feat: 增强国际化验证功能,添加对共享翻译字符串的支持,优化提示模板加载逻辑
This commit is contained in:
@@ -3,18 +3,16 @@ from __future__ import annotations
|
||||
from collections.abc import Iterator
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from functools import lru_cache
|
||||
|
||||
from .loaders import DEFAULT_LOCALE
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_manager():
|
||||
from .manager import I18nManager
|
||||
|
||||
manager = getattr(_get_manager, "_manager", None)
|
||||
if manager is None:
|
||||
manager = I18nManager()
|
||||
_get_manager._manager = manager
|
||||
return manager
|
||||
return I18nManager()
|
||||
|
||||
|
||||
def set_locale(locale: str) -> str:
|
||||
|
||||
@@ -67,22 +67,12 @@ class I18nManager:
|
||||
self._catalog_cache.pop(normalize_locale(locale), None)
|
||||
|
||||
def t(self, key: str, locale: str | None = None, **kwargs: object) -> str:
|
||||
translation_value, _ = self._get_translation_value(key, locale)
|
||||
if translation_value is None:
|
||||
translation_value, translation_locale = self._get_translation_value(key, locale)
|
||||
template = self._get_standard_template(key, translation_value, translation_locale)
|
||||
if template is None:
|
||||
return key
|
||||
|
||||
if isinstance(translation_value, dict):
|
||||
template = translation_value.get("other")
|
||||
if template is None:
|
||||
self._log_once(
|
||||
("plural_missing_other", self.get_locale(), key),
|
||||
logging.WARNING,
|
||||
"翻译 key '%s' 缺少 other plural category,已回退到 key 本身",
|
||||
key,
|
||||
)
|
||||
return key
|
||||
return self._format_translation(key, template, kwargs)
|
||||
return self._format_translation(key, translation_value, kwargs)
|
||||
return self._format_translation(key, template, kwargs)
|
||||
|
||||
def tn(self, key: str, count: int | float, locale: str | None = None, **kwargs: object) -> str:
|
||||
translation_value, translation_locale = self._get_translation_value(key, locale)
|
||||
@@ -118,6 +108,27 @@ class I18nManager:
|
||||
formatting_kwargs["count"] = count
|
||||
return self._format_translation(key, template, formatting_kwargs)
|
||||
|
||||
def _get_standard_template(
|
||||
self,
|
||||
key: str,
|
||||
translation_value: TranslationValue | None,
|
||||
translation_locale: str,
|
||||
) -> str | None:
|
||||
if translation_value is None:
|
||||
return None
|
||||
if not isinstance(translation_value, dict):
|
||||
return translation_value
|
||||
|
||||
template = translation_value.get("other")
|
||||
if template is None:
|
||||
self._log_once(
|
||||
("plural_missing_other", translation_locale, key),
|
||||
logging.WARNING,
|
||||
"翻译 key '%s' 缺少 other plural category,已回退到 key 本身",
|
||||
key,
|
||||
)
|
||||
return template
|
||||
|
||||
def _format_translation(self, key: str, template: str, kwargs: dict[str, object]) -> str:
|
||||
try:
|
||||
return format_template(template, **kwargs)
|
||||
@@ -161,14 +172,15 @@ class I18nManager:
|
||||
try:
|
||||
return normalize_locale(locale)
|
||||
except InvalidLocaleError:
|
||||
current_locale = self.get_locale()
|
||||
self._log_once(
|
||||
("invalid_locale", "explicit", locale),
|
||||
logging.WARNING,
|
||||
"检测到非法 locale='%s',已回退到当前默认 locale %s",
|
||||
locale,
|
||||
self.get_locale(),
|
||||
current_locale,
|
||||
)
|
||||
return self.get_locale()
|
||||
return current_locale
|
||||
|
||||
def _get_catalog(self, locale: str) -> dict[str, TranslationValue]:
|
||||
normalized_locale = normalize_locale(locale)
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
|
||||
from .i18n import get_locale, t
|
||||
from .i18n.loaders import DEFAULT_LOCALE, extract_placeholders as extract_prompt_placeholders, normalize_locale
|
||||
from .i18n.loaders import DEFAULT_LOCALE, extract_placeholders, normalize_locale
|
||||
|
||||
logger = logging.getLogger("maibot.prompt_i18n")
|
||||
|
||||
@@ -17,9 +17,9 @@ PROMPTS_ROOT = (PROJECT_ROOT / "prompts").resolve()
|
||||
PROMPT_EXTENSIONS = (".prompt",)
|
||||
SAFE_SEGMENT_PATTERN = re.compile(r"^[A-Za-z0-9_.-]+$")
|
||||
STRICT_ENV_KEYS = ("MAIBOT_PROMPT_I18N_STRICT", "MAIBOT_I18N_STRICT")
|
||||
STRICT_ENV_VALUES = {"1", "true", "yes", "on"}
|
||||
|
||||
_prompt_cache: dict[Path, str] = {}
|
||||
_cache_lock = threading.RLock()
|
||||
extract_prompt_placeholders = extract_placeholders
|
||||
|
||||
|
||||
def get_prompts_root(prompts_root: Path | None = None) -> Path:
|
||||
@@ -56,91 +56,100 @@ def is_strict_prompt_i18n_mode() -> bool:
|
||||
if os.getenv("PYTEST_CURRENT_TEST"):
|
||||
return True
|
||||
|
||||
return any(os.getenv(env_key, "").strip().lower() in {"1", "true", "yes", "on"} for env_key in STRICT_ENV_KEYS)
|
||||
return any(os.getenv(env_key, "").strip().lower() in STRICT_ENV_VALUES for env_key in STRICT_ENV_KEYS)
|
||||
|
||||
|
||||
def _supported_prompt_files(directory: Path, recursive: bool = True) -> list[Path]:
|
||||
def discover_prompt_locales(prompts_root: Path | None = None) -> list[str]:
|
||||
resolved_prompts_root = get_prompts_root(prompts_root)
|
||||
if not resolved_prompts_root.exists():
|
||||
return []
|
||||
|
||||
locale_names = [path.name for path in resolved_prompts_root.iterdir() if path.is_dir()]
|
||||
return sorted(locale_names)
|
||||
|
||||
|
||||
def iter_prompt_files(directory: Path, recursive: bool = True) -> list[Path]:
|
||||
if not directory.exists():
|
||||
return []
|
||||
|
||||
search = directory.rglob if recursive else directory.glob
|
||||
matched_files: list[Path] = []
|
||||
prompt_files: list[Path] = []
|
||||
for suffix in PROMPT_EXTENSIONS:
|
||||
matched_files.extend(path for path in search(f"*{suffix}") if path.is_file())
|
||||
return sorted(set(matched_files))
|
||||
prompt_files.extend(path for path in search(f"*{suffix}") if path.is_file())
|
||||
return sorted(set(prompt_files))
|
||||
|
||||
|
||||
def _scan_prompt_directory(directory: Path, prompts_root: Path) -> dict[str, Path]:
|
||||
def _raise_duplicate_prompt_name(name: str, first_path: Path, second_path: Path, prompts_root: Path) -> None:
|
||||
raise ValueError(
|
||||
t(
|
||||
"prompt.duplicate_template_name",
|
||||
name=name,
|
||||
path_a=first_path.relative_to(prompts_root),
|
||||
path_b=second_path.relative_to(prompts_root),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _scan_prompt_directory(directory: Path, prompts_root: Path, recursive: bool = True) -> dict[str, Path]:
|
||||
prompt_paths: dict[str, Path] = {}
|
||||
if not directory.exists():
|
||||
return prompt_paths
|
||||
|
||||
for prompt_path in _supported_prompt_files(directory):
|
||||
for prompt_path in iter_prompt_files(directory, recursive=recursive):
|
||||
prompt_name = prompt_path.stem
|
||||
if prompt_name in prompt_paths:
|
||||
raise ValueError(
|
||||
t(
|
||||
"prompt.duplicate_template_name",
|
||||
name=prompt_name,
|
||||
path_a=prompt_paths[prompt_name].relative_to(prompts_root),
|
||||
path_b=prompt_path.relative_to(prompts_root),
|
||||
)
|
||||
)
|
||||
existing_path = prompt_paths.get(prompt_name)
|
||||
if existing_path is not None:
|
||||
_raise_duplicate_prompt_name(prompt_name, existing_path, prompt_path, prompts_root)
|
||||
prompt_paths[prompt_name] = prompt_path
|
||||
return prompt_paths
|
||||
|
||||
|
||||
def _scan_legacy_prompt_directory(directory: Path, prompts_root: Path) -> dict[str, Path]:
|
||||
prompt_paths: dict[str, Path] = {}
|
||||
if not directory.exists():
|
||||
return prompt_paths
|
||||
def _iter_prompt_template_layers(prompts_root: Path, requested_locale: str) -> list[tuple[Path, bool]]:
|
||||
prompt_layers: list[tuple[Path, bool]] = [
|
||||
(prompts_root, False),
|
||||
(prompts_root / DEFAULT_LOCALE, True),
|
||||
]
|
||||
if requested_locale != DEFAULT_LOCALE:
|
||||
prompt_layers.append((prompts_root / requested_locale, True))
|
||||
return prompt_layers
|
||||
|
||||
for prompt_path in _supported_prompt_files(directory, recursive=False):
|
||||
prompt_name = prompt_path.stem
|
||||
if prompt_name in prompt_paths:
|
||||
raise ValueError(
|
||||
t(
|
||||
"prompt.duplicate_template_name",
|
||||
name=prompt_name,
|
||||
path_a=prompt_paths[prompt_name].relative_to(prompts_root),
|
||||
path_b=prompt_path.relative_to(prompts_root),
|
||||
)
|
||||
)
|
||||
prompt_paths[prompt_name] = prompt_path
|
||||
return prompt_paths
|
||||
|
||||
def _iter_locale_candidates(requested_locale: str) -> list[str | None]:
|
||||
locale_candidates: list[str | None] = [requested_locale]
|
||||
if requested_locale != DEFAULT_LOCALE:
|
||||
locale_candidates.append(DEFAULT_LOCALE)
|
||||
locale_candidates.append(None)
|
||||
return locale_candidates
|
||||
|
||||
|
||||
def list_prompt_templates(locale: str | None = None, prompts_root: Path | None = None) -> dict[str, Path]:
|
||||
resolved_prompts_root = get_prompts_root(prompts_root)
|
||||
requested_locale = normalize_locale(locale or get_locale())
|
||||
|
||||
prompt_paths = _scan_legacy_prompt_directory(resolved_prompts_root, resolved_prompts_root)
|
||||
prompt_paths.update(_scan_prompt_directory(resolved_prompts_root / DEFAULT_LOCALE, resolved_prompts_root))
|
||||
|
||||
if requested_locale != DEFAULT_LOCALE:
|
||||
prompt_paths.update(_scan_prompt_directory(resolved_prompts_root / requested_locale, resolved_prompts_root))
|
||||
prompt_paths: dict[str, Path] = {}
|
||||
for directory, recursive in _iter_prompt_template_layers(resolved_prompts_root, requested_locale):
|
||||
prompt_paths.update(_scan_prompt_directory(directory, resolved_prompts_root, recursive=recursive))
|
||||
|
||||
return prompt_paths
|
||||
|
||||
|
||||
def resolve_prompt_path(name: str, locale: str | None = None, category: str | None = None, prompts_root: Path | None = None) -> Path:
|
||||
def resolve_prompt_path(
|
||||
name: str, locale: str | None = None, category: str | None = None, prompts_root: Path | None = None
|
||||
) -> Path:
|
||||
resolved_prompts_root = get_prompts_root(prompts_root)
|
||||
normalized_name = normalize_prompt_name(name)
|
||||
normalized_category = normalize_prompt_category(category)
|
||||
requested_locale = normalize_locale(locale or get_locale())
|
||||
|
||||
locale_candidates: list[str | None] = [requested_locale]
|
||||
if requested_locale != DEFAULT_LOCALE:
|
||||
locale_candidates.append(DEFAULT_LOCALE)
|
||||
locale_candidates.append(None)
|
||||
|
||||
if normalized_category is not None:
|
||||
for locale_candidate in locale_candidates:
|
||||
for locale_candidate in _iter_locale_candidates(requested_locale):
|
||||
base_dir = resolved_prompts_root if locale_candidate is None else resolved_prompts_root / locale_candidate
|
||||
for suffix in PROMPT_EXTENSIONS:
|
||||
candidate_paths = [(base_dir / normalized_category / f"{normalized_name}{suffix}").resolve()]
|
||||
candidate_path = (base_dir / normalized_category / f"{normalized_name}{suffix}").resolve()
|
||||
if candidate_path.is_file():
|
||||
return candidate_path
|
||||
|
||||
# 允许带 category 的调用在旧版平铺目录或未迁移完的 locale 目录中继续工作。
|
||||
candidate_paths.append((base_dir / f"{normalized_name}{suffix}").resolve())
|
||||
for candidate_path in candidate_paths:
|
||||
if candidate_path.is_file():
|
||||
return candidate_path
|
||||
fallback_path = (base_dir / f"{normalized_name}{suffix}").resolve()
|
||||
if fallback_path.is_file():
|
||||
return fallback_path
|
||||
else:
|
||||
prompt_paths = list_prompt_templates(locale=requested_locale, prompts_root=resolved_prompts_root)
|
||||
if normalized_name in prompt_paths:
|
||||
@@ -149,6 +158,31 @@ def resolve_prompt_path(name: str, locale: str | None = None, category: str | No
|
||||
raise FileNotFoundError(t("prompt.template_not_found", locale=requested_locale, name=normalized_name))
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def _read_prompt_template(prompt_path: Path) -> str:
|
||||
return prompt_path.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def _format_prompt_template(name: str, template: str, **kwargs: object) -> str:
|
||||
if not kwargs:
|
||||
return template
|
||||
|
||||
try:
|
||||
return template.format(**kwargs)
|
||||
except KeyError as exc:
|
||||
missing_placeholder = exc.args[0]
|
||||
error = KeyError(t("prompt.missing_placeholder", name=name, placeholder=missing_placeholder))
|
||||
if is_strict_prompt_i18n_mode():
|
||||
raise error from exc
|
||||
logger.error("%s", error)
|
||||
return template
|
||||
except Exception as exc:
|
||||
logger.error(t("prompt.format_failed", name=name, error=exc))
|
||||
if is_strict_prompt_i18n_mode():
|
||||
raise
|
||||
return template
|
||||
|
||||
|
||||
def load_prompt(
|
||||
name: str,
|
||||
locale: str | None = None,
|
||||
@@ -156,40 +190,11 @@ def load_prompt(
|
||||
prompts_root: Path | None = None,
|
||||
**kwargs: object,
|
||||
) -> str:
|
||||
prompt_path = resolve_prompt_path(name=name, locale=locale, category=category, prompts_root=prompts_root)
|
||||
with _cache_lock:
|
||||
template = _prompt_cache.get(prompt_path)
|
||||
if template is None:
|
||||
template = prompt_path.read_text(encoding="utf-8")
|
||||
with _cache_lock:
|
||||
_prompt_cache.setdefault(prompt_path, template)
|
||||
template = _prompt_cache[prompt_path]
|
||||
|
||||
if not kwargs:
|
||||
return template
|
||||
|
||||
try:
|
||||
return template.format(**kwargs)
|
||||
except KeyError as exc:
|
||||
missing_placeholder = exc.args[0]
|
||||
error = KeyError(
|
||||
t(
|
||||
"prompt.missing_placeholder",
|
||||
name=normalize_prompt_name(name),
|
||||
placeholder=missing_placeholder,
|
||||
)
|
||||
)
|
||||
if is_strict_prompt_i18n_mode():
|
||||
raise error from exc
|
||||
logger.error("%s", error)
|
||||
return template
|
||||
except Exception as exc:
|
||||
logger.error(t("prompt.format_failed", name=normalize_prompt_name(name), error=exc))
|
||||
if is_strict_prompt_i18n_mode():
|
||||
raise
|
||||
return template
|
||||
normalized_name = normalize_prompt_name(name)
|
||||
prompt_path = resolve_prompt_path(name=normalized_name, locale=locale, category=category, prompts_root=prompts_root)
|
||||
template = _read_prompt_template(prompt_path)
|
||||
return _format_prompt_template(normalized_name, template, **kwargs)
|
||||
|
||||
|
||||
def clear_prompt_cache() -> None:
|
||||
with _cache_lock:
|
||||
_prompt_cache.clear()
|
||||
_read_prompt_template.cache_clear()
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from collections.abc import Callable, Coroutine
|
||||
from pathlib import Path
|
||||
from string import Formatter
|
||||
from typing import Any, Optional
|
||||
|
||||
from collections.abc import Callable, Coroutine
|
||||
import inspect
|
||||
|
||||
from src.common.prompt_i18n import list_prompt_templates, load_prompt
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import list_prompt_templates, load_prompt
|
||||
|
||||
|
||||
logger = get_logger("Prompt")
|
||||
@@ -14,7 +14,7 @@ logger = get_logger("Prompt")
|
||||
_LEFT_BRACE = "\ufde9"
|
||||
_RIGHT_BRACE = "\ufdea"
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute().resolve()
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
PROMPTS_DIR = PROJECT_ROOT / "prompts"
|
||||
DATA_DIR = PROJECT_ROOT / "data"
|
||||
CUSTOM_PROMPTS_DIR = DATA_DIR / "custom_prompts"
|
||||
@@ -240,18 +240,23 @@ class PromptManager:
|
||||
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
|
||||
try:
|
||||
prompt_file.unlink()
|
||||
except Exception as e:
|
||||
logger.error(f"删除自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}")
|
||||
raise e
|
||||
except Exception as exc:
|
||||
logger.error(f"删除自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
|
||||
raise
|
||||
for prompt_name in self._prompt_to_save:
|
||||
prompt = self.prompts[prompt_name]
|
||||
file_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
|
||||
try:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(prompt.template)
|
||||
except Exception as e:
|
||||
logger.error(f"保存 Prompt '{prompt_name}' 时出错,文件路径: '{file_path}',错误信息: {e}")
|
||||
raise e
|
||||
file_path.write_text(prompt.template, encoding="utf-8")
|
||||
except Exception as exc:
|
||||
logger.error(f"保存 Prompt '{prompt_name}' 时出错,文件路径: '{file_path}',错误信息: {exc}")
|
||||
raise
|
||||
|
||||
def _load_prompt_template(self, prompt_name: str) -> tuple[str, bool]:
|
||||
custom_prompt_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
|
||||
if custom_prompt_path.exists():
|
||||
return custom_prompt_path.read_text(encoding="utf-8"), True
|
||||
return load_prompt(prompt_name, prompts_root=PROMPTS_DIR), False
|
||||
|
||||
def load_prompts(self) -> None:
|
||||
"""
|
||||
@@ -259,34 +264,23 @@ class PromptManager:
|
||||
Raises:
|
||||
Exception: 如果在加载过程中出现任何文件操作错误则引发该异常
|
||||
"""
|
||||
prompt_files = list_prompt_templates(prompts_root=PROMPTS_DIR)
|
||||
for prompt_name, prompt_file in prompt_files.items():
|
||||
prompt_templates = list_prompt_templates(prompts_root=PROMPTS_DIR)
|
||||
for prompt_name, prompt_file in prompt_templates.items():
|
||||
try:
|
||||
prompt_to_load = prompt_file
|
||||
need_save = False
|
||||
custom_prompt_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
|
||||
if custom_prompt_path.exists():
|
||||
# 优先加载自定义目录下的 Prompt 文件
|
||||
prompt_to_load = custom_prompt_path
|
||||
need_save = True
|
||||
with open(prompt_to_load, "r", encoding="utf-8") as f:
|
||||
template = f.read()
|
||||
else:
|
||||
template = load_prompt(prompt_name, prompts_root=PROMPTS_DIR)
|
||||
template, need_save = self._load_prompt_template(prompt_name)
|
||||
self.add_prompt(Prompt(prompt_name=prompt_name, template=template), need_save=need_save)
|
||||
except Exception as e:
|
||||
logger.error(f"加载 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}")
|
||||
raise e
|
||||
except Exception as exc:
|
||||
logger.error(f"加载 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
|
||||
raise
|
||||
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
|
||||
if prompt_file.stem in prompt_files:
|
||||
if prompt_file.stem in prompt_templates:
|
||||
continue # 已经加载过了,跳过
|
||||
try:
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
template = f.read()
|
||||
template = prompt_file.read_text(encoding="utf-8")
|
||||
self.add_prompt(Prompt(prompt_name=prompt_file.stem, template=template), need_save=True)
|
||||
except Exception as e:
|
||||
logger.error(f"加载自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}")
|
||||
raise e
|
||||
except Exception as exc:
|
||||
logger.error(f"加载自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
|
||||
raise
|
||||
|
||||
async def _get_function_result(
|
||||
self,
|
||||
@@ -301,12 +295,12 @@ class PromptManager:
|
||||
if isinstance(res, Coroutine):
|
||||
res = await res
|
||||
return res
|
||||
except Exception as e:
|
||||
except Exception as exc:
|
||||
if is_prompt_context:
|
||||
logger.error(f"调用 Prompt '{prompt_name}' 内部上下文构造函数 '{field_name}' 时出错,错误信息: {e}")
|
||||
logger.error(f"调用 Prompt '{prompt_name}' 内部上下文构造函数 '{field_name}' 时出错,错误信息: {exc}")
|
||||
else:
|
||||
logger.error(f"调用上下文构造函数 '{field_name}' 时出错,所属模块: '{module}',错误信息: {e}")
|
||||
raise e
|
||||
logger.error(f"调用上下文构造函数 '{field_name}' 时出错,所属模块: '{module}',错误信息: {exc}")
|
||||
raise
|
||||
|
||||
|
||||
prompt_manager = PromptManager()
|
||||
|
||||
Reference in New Issue
Block a user