chore: import deployable mai-bot source tree
This commit is contained in:
533
pytests/config_test/test_config_base.py
Normal file
533
pytests/config_test/test_config_base.py
Normal file
@@ -0,0 +1,533 @@
|
||||
import logging
|
||||
import sys
|
||||
from importlib import util
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 测试环境准备:补全 logger 和 AttrDocBase 依赖
|
||||
# -------------------------------------------------------------
|
||||
|
||||
TEST_ROOT = Path(__file__).parent.parent.absolute().resolve()
|
||||
logger_file = TEST_ROOT / "logger.py"
|
||||
spec = util.spec_from_file_location("src.common.logger", logger_file)
|
||||
module = util.module_from_spec(spec) # type: ignore
|
||||
assert spec is not None and spec.loader is not None
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
sys.modules["src.common.logger"] = module
|
||||
|
||||
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
sys.path.insert(0, str(PROJECT_ROOT / "src" / "config"))
|
||||
|
||||
from src.config.config_base import ConfigBase # noqa: E402
|
||||
import src.config.config_base as config_base_module # noqa: E402
|
||||
|
||||
|
||||
class AttrDocBase:
|
||||
"""用于测试的轻量级 AttrDocBase 替身"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# 被 ConfigBase.model_post_init 调用
|
||||
self.__post_init_called__ = True
|
||||
|
||||
|
||||
# 打补丁,让 ConfigBase 使用测试替身
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_attrdoc_post_init():
|
||||
orig = config_base_module.AttrDocBase.__post_init__
|
||||
config_base_module.AttrDocBase.__post_init__ = AttrDocBase.__post_init__ # type: ignore
|
||||
yield
|
||||
config_base_module.AttrDocBase.__post_init__ = orig
|
||||
|
||||
|
||||
config_base_module.logger = logging.getLogger("config_base_test_logger")
|
||||
|
||||
|
||||
class SimpleClass(ConfigBase):
|
||||
a: int = 1
|
||||
b: str = "test"
|
||||
|
||||
|
||||
class TestConfigBase:
|
||||
# ---------------------------------------------------------
|
||||
# happy path:整体 model_post_init 测试
|
||||
# ---------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"model_cls, init_kwargs, expected_fields",
|
||||
[
|
||||
pytest.param(
|
||||
# 简单原子类型字段
|
||||
type(
|
||||
"SimpleAtomic",
|
||||
(ConfigBase,),
|
||||
{
|
||||
"__annotations__": {
|
||||
"a": int,
|
||||
"b": str,
|
||||
"c": bool,
|
||||
"d": float,
|
||||
},
|
||||
"a": Field(default=1),
|
||||
"b": Field(default="x"),
|
||||
"c": Field(default=True),
|
||||
"d": Field(default=1.5),
|
||||
},
|
||||
),
|
||||
{},
|
||||
{"a", "b", "c", "d"},
|
||||
id="happy-simple-atomic-fields",
|
||||
),
|
||||
pytest.param(
|
||||
# list/set/dict 泛型 + 原子内部类型
|
||||
type(
|
||||
"AtomicContainers",
|
||||
(ConfigBase,),
|
||||
{
|
||||
"__annotations__": {
|
||||
"ints": List[int],
|
||||
"names": Set[str],
|
||||
"mapping": Dict[str, int],
|
||||
},
|
||||
"ints": Field(default_factory=lambda: [1, 2]),
|
||||
"names": Field(default_factory=lambda: {"a", "b"}),
|
||||
"mapping": Field(default_factory=lambda: {"x": 1}),
|
||||
},
|
||||
),
|
||||
{},
|
||||
{"ints", "names", "mapping"},
|
||||
id="happy-atomic-containers",
|
||||
),
|
||||
pytest.param(
|
||||
# Optional 原子和 Optional 容器
|
||||
type(
|
||||
"OptionalFields",
|
||||
(ConfigBase,),
|
||||
{
|
||||
"__annotations__": {
|
||||
"maybe_int": Optional[int],
|
||||
"maybe_str_list": Optional[List[str]],
|
||||
},
|
||||
"maybe_int": Field(default=None),
|
||||
"maybe_str_list": Field(default=None),
|
||||
},
|
||||
),
|
||||
{},
|
||||
{"maybe_int", "maybe_str_list"},
|
||||
id="happy-optional-fields",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_model_post_init_happy_paths(self, model_cls, init_kwargs, expected_fields):
|
||||
# Act
|
||||
instance = model_cls(**init_kwargs)
|
||||
|
||||
# Assert
|
||||
for field_name in expected_fields:
|
||||
assert field_name in type(instance).model_fields
|
||||
_ = getattr(instance, field_name)
|
||||
assert getattr(instance, "__post_init_called__", False) is True
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# _get_real_type
|
||||
# ---------------------------------------------------------
|
||||
def test_get_real_type_non_generic_and_generic(self):
|
||||
class Sample(ConfigBase):
|
||||
x: int = 1
|
||||
y: List[int] = Field(default_factory=list)
|
||||
|
||||
instance = Sample()
|
||||
|
||||
# Act
|
||||
origin_x, args_x = instance._get_real_type(int)
|
||||
|
||||
# Assert
|
||||
assert origin_x is int
|
||||
assert args_x == ()
|
||||
|
||||
# Act
|
||||
origin_y, args_y = instance._get_real_type(List[int])
|
||||
|
||||
# Assert
|
||||
assert origin_y in (list, List)
|
||||
assert args_y == (int,)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# _validate_union_type
|
||||
# ---------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"annotation, expect_error, error_fragment, expected_origin_type",
|
||||
[
|
||||
pytest.param(
|
||||
int,
|
||||
False,
|
||||
None,
|
||||
int,
|
||||
id="union-validation-atomic-non-union",
|
||||
),
|
||||
pytest.param(
|
||||
Optional[int],
|
||||
False,
|
||||
None,
|
||||
int,
|
||||
id="union-validation-optional-atomic",
|
||||
),
|
||||
pytest.param(
|
||||
Optional[List[int]],
|
||||
False,
|
||||
None,
|
||||
list,
|
||||
id="union-validation-optional-container",
|
||||
),
|
||||
pytest.param(
|
||||
Union[int, str],
|
||||
True,
|
||||
"不允许使用 Union 类型注解",
|
||||
None,
|
||||
id="union-validation-disallow-non-optional-union",
|
||||
),
|
||||
pytest.param(
|
||||
int | str,
|
||||
True,
|
||||
"不允许使用 Union 类型注解",
|
||||
None,
|
||||
id="union-validation-pep604-disallow-non-optional-union",
|
||||
),
|
||||
pytest.param(
|
||||
Union[int, None, str],
|
||||
True,
|
||||
"不允许使用 Union 类型注解",
|
||||
None,
|
||||
id="union-validation-disallow-union-more-than-two",
|
||||
),
|
||||
pytest.param(
|
||||
Optional[Union[int, str]],
|
||||
True,
|
||||
"不允许使用 Union 类型注解",
|
||||
None,
|
||||
id="union-validation-disallow-nested-optional-union",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_validate_union_type(self, annotation, expect_error, error_fragment, expected_origin_type):
|
||||
# 这里我们不实例化 Sample,以避免在 __init__/model_post_init 阶段触发验证。
|
||||
# 直接通过一个“哑实例”调用受测方法,仅测试类型注解逻辑。
|
||||
|
||||
class Dummy(ConfigBase):
|
||||
pass
|
||||
|
||||
dummy = Dummy() # 最小初始化,避免字段校验
|
||||
|
||||
field_name = "v"
|
||||
|
||||
if expect_error:
|
||||
# Act / Assert
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
dummy._validate_union_type(annotation, field_name)
|
||||
assert error_fragment in str(exc_info.value)
|
||||
else:
|
||||
# Act
|
||||
origin, args, other = dummy._validate_union_type(annotation, field_name)
|
||||
|
||||
# Assert
|
||||
assert origin is expected_origin_type
|
||||
assert other is not None
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# _validate_list_set_type
|
||||
# ---------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"annotation, expect_error, error_fragment",
|
||||
[
|
||||
pytest.param(
|
||||
List[int],
|
||||
False,
|
||||
None,
|
||||
id="listset-validation-list-happy",
|
||||
),
|
||||
pytest.param(
|
||||
Set[str],
|
||||
False,
|
||||
None,
|
||||
id="listset-validation-set-happy",
|
||||
),
|
||||
pytest.param(
|
||||
list,
|
||||
True,
|
||||
"必须指定且仅指定一个类型参数",
|
||||
id="listset-validation-missing-type-arg",
|
||||
),
|
||||
pytest.param(
|
||||
List[int | None],
|
||||
True,
|
||||
"不允许嵌套泛型类型",
|
||||
id="listset-validation-nested-generic-inner-union",
|
||||
),
|
||||
pytest.param(
|
||||
List[List[int]],
|
||||
True,
|
||||
"不允许嵌套泛型类型",
|
||||
id="listset-validation-nested-generic-inner-list",
|
||||
),
|
||||
pytest.param(
|
||||
List[SimpleClass],
|
||||
False,
|
||||
None,
|
||||
id="listset-validation-list-configbase-element_allow",
|
||||
),
|
||||
pytest.param(
|
||||
Set[SimpleClass],
|
||||
True,
|
||||
"ConfigBase is not Hashable",
|
||||
id="listset-validation-set-configbase-element_reject",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_validate_list_set_type(self, annotation, expect_error, error_fragment):
|
||||
# 不实例化带有这些字段的模型,避免在 __init__/model_post_init 阶段就失败,
|
||||
# 只测试 _validate_list_set_type 本身的逻辑。
|
||||
|
||||
class Dummy(ConfigBase):
|
||||
pass
|
||||
|
||||
dummy = Dummy()
|
||||
|
||||
field_name = "items"
|
||||
|
||||
if expect_error:
|
||||
# Act / Assert
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
dummy._validate_list_set_type(annotation, field_name)
|
||||
assert error_fragment in str(exc_info.value)
|
||||
else:
|
||||
# Act
|
||||
dummy._validate_list_set_type(annotation, field_name)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# _validate_dict_type
|
||||
# ---------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"annotation, expect_error, error_fragment",
|
||||
[
|
||||
pytest.param(
|
||||
Dict[str, int],
|
||||
False,
|
||||
None,
|
||||
id="dict-validation-happy-atomic",
|
||||
),
|
||||
pytest.param(
|
||||
Dict[str, Any],
|
||||
True,
|
||||
"不允许使用 Any 类型注解",
|
||||
id="dict-validation-any-value-disallowed",
|
||||
),
|
||||
pytest.param(
|
||||
Dict[str, Dict[str, int]],
|
||||
True,
|
||||
"不允许嵌套泛型类型",
|
||||
id="dict-validation-optional-nested-list",
|
||||
),
|
||||
pytest.param(
|
||||
Dict,
|
||||
True,
|
||||
"必须指定键和值的类型参数",
|
||||
id="dict-validation-missing-args",
|
||||
),
|
||||
pytest.param(
|
||||
Dict[str, SimpleClass],
|
||||
False,
|
||||
None,
|
||||
id="dict-validation-happy-configbase-value",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_validate_dict_type(self, annotation, expect_error, error_fragment):
|
||||
# 同样不通过字段定义来触发 model_post_init,只测试 _validate_dict_type 本身。
|
||||
|
||||
class Dummy(ConfigBase):
|
||||
_validate_any: bool = True
|
||||
|
||||
dummy = Dummy()
|
||||
field_name = "mapping"
|
||||
|
||||
if expect_error:
|
||||
# Act / Assert
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
dummy._validate_dict_type(annotation, field_name)
|
||||
assert error_fragment in str(exc_info.value)
|
||||
else:
|
||||
# Act
|
||||
dummy._validate_dict_type(annotation, field_name)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# _discourage_any_usage
|
||||
# ---------------------------------------------------------
|
||||
def test_discourage_any_usage_raises_when_validate_any_true(self, caplog):
|
||||
class Sample(ConfigBase):
|
||||
_validate_any: bool = True
|
||||
|
||||
instance = Sample()
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
instance._discourage_any_usage("field_x")
|
||||
assert "不允许使用 Any 类型注解" in str(exc_info.value)
|
||||
assert "建议避免使用" not in caplog.text
|
||||
|
||||
def test_discourage_any_usage_logs_when_validate_any_false(self, caplog):
|
||||
class Sample(ConfigBase):
|
||||
_validate_any: bool = False
|
||||
|
||||
instance = Sample()
|
||||
|
||||
# Arrange
|
||||
caplog.set_level(logging.WARNING, logger="config_base_test_logger")
|
||||
|
||||
# Act
|
||||
instance._discourage_any_usage("field_y")
|
||||
|
||||
# Assert
|
||||
assert "字段'field_y'中使用了 Any 类型注解" in caplog.text
|
||||
|
||||
def test_discourage_any_usage_suppressed_warning(self, caplog):
|
||||
class Sample(ConfigBase):
|
||||
_validate_any: bool = False
|
||||
suppress_any_warning: bool = True
|
||||
|
||||
instance = Sample()
|
||||
|
||||
# Arrange
|
||||
caplog.set_level(logging.WARNING, logger="config_base_test_logger")
|
||||
|
||||
# Act
|
||||
instance._discourage_any_usage("field_z")
|
||||
|
||||
# Assert
|
||||
assert "字段'field_z'中使用了 Any 类型注解" not in caplog.text
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# model_post_init 规则覆盖(错误与边界情况)
|
||||
# ---------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"field_annotation, expect_error, error_fragment, test_id",
|
||||
[
|
||||
(
|
||||
Tuple[int, int],
|
||||
True,
|
||||
"不允许使用 Tuple 类型注解",
|
||||
"model-post-init-disallow-tuple-typing-tuple",
|
||||
),
|
||||
(
|
||||
tuple[int, int],
|
||||
True,
|
||||
"不允许使用 Tuple 类型注解",
|
||||
"model-post-init-disallow-pep604-tuple",
|
||||
),
|
||||
(
|
||||
Union[int, str],
|
||||
True,
|
||||
"不允许使用 Union 类型注解",
|
||||
"model-post-init-disallow-union-field",
|
||||
),
|
||||
(
|
||||
list,
|
||||
True,
|
||||
"必须指定且仅指定一个类型参数",
|
||||
"model-post-init-list-missing-type-arg",
|
||||
),
|
||||
(
|
||||
List[List[int]],
|
||||
True,
|
||||
"不允许嵌套泛型类型",
|
||||
"model-post-init-list-nested-generic",
|
||||
),
|
||||
(
|
||||
Dict[str, Any],
|
||||
True,
|
||||
"不允许使用 Any 类型注解",
|
||||
"model-post-init-dict-value-any",
|
||||
),
|
||||
(
|
||||
Any,
|
||||
True,
|
||||
"不允许使用 Any 类型注解",
|
||||
"model-post-init-field-any-disallowed",
|
||||
),
|
||||
(
|
||||
Set[int],
|
||||
False,
|
||||
None,
|
||||
"model-post-init-allow-set-int",
|
||||
),
|
||||
(
|
||||
Dict[str, Optional[int]],
|
||||
False,
|
||||
None,
|
||||
"model-post-init-allow-dict-optional-int",
|
||||
),
|
||||
],
|
||||
ids=lambda v: v[3] if isinstance(v, tuple) else v,
|
||||
)
|
||||
def test_model_post_init_type_rules(self, field_annotation, expect_error, error_fragment, test_id):
|
||||
# Arrange
|
||||
attrs = {
|
||||
"__annotations__": {"f": field_annotation},
|
||||
"f": Field(default=None),
|
||||
}
|
||||
model_cls = type("DynamicModel" + test_id.replace("-", "_"), (ConfigBase,), attrs)
|
||||
|
||||
if expect_error:
|
||||
# Act / Assert
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
model_cls()
|
||||
assert error_fragment in str(exc_info.value)
|
||||
else:
|
||||
# Act
|
||||
instance = model_cls()
|
||||
|
||||
# Assert
|
||||
assert hasattr(instance, "f")
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 嵌套 ConfigBase & 非支持泛型 origin
|
||||
# ---------------------------------------------------------
|
||||
def test_model_post_init_allows_configbase_nested_class(self):
|
||||
class Child(ConfigBase):
|
||||
value: int = 1
|
||||
|
||||
class Parent(ConfigBase):
|
||||
child: Child = Field(default_factory=Child)
|
||||
|
||||
# Act
|
||||
parent = Parent()
|
||||
|
||||
# Assert
|
||||
assert isinstance(parent.child, Child)
|
||||
|
||||
def test_model_post_init_disallow_non_supported_generic_origin(self):
|
||||
class CustomGeneric(BaseModel):
|
||||
pass
|
||||
|
||||
class Sample(ConfigBase):
|
||||
f: CustomGeneric = Field(default_factory=CustomGeneric)
|
||||
|
||||
# Arrange / Act / Assert
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
Sample()
|
||||
assert "仅允许使用list, set, dict三种泛型类型注解" in str(exc_info.value)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# super().model_post_init 和 AttrDocBase.__post_init__ 调用
|
||||
# ---------------------------------------------------------
|
||||
def test_super_model_post_init_and_attrdoc_post_init_called(self):
|
||||
class Sample(ConfigBase):
|
||||
value: int = 1
|
||||
|
||||
# Act
|
||||
instance = Sample()
|
||||
|
||||
# Assert
|
||||
assert getattr(instance, "__post_init_called__", False) is True
|
||||
104
pytests/config_test/test_config_manager_hot_reload.py
Normal file
104
pytests/config_test/test_config_manager_hot_reload.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from pathlib import Path
|
||||
|
||||
from watchfiles import Change
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
|
||||
from src.config.config import ConfigManager
|
||||
from src.config.file_watcher import FileChange, FileWatcherStats
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_file_changes_throttles_reload():
|
||||
manager = ConfigManager()
|
||||
manager._hot_reload_min_interval_s = 100.0
|
||||
|
||||
called = 0
|
||||
|
||||
async def reload_stub(changed_scopes=None) -> bool:
|
||||
nonlocal called
|
||||
called += 1
|
||||
return True
|
||||
|
||||
manager.reload_config = reload_stub # type: ignore[method-assign]
|
||||
changes = [FileChange(change_type=Change.modified, path=Path("/tmp/bot_config.toml"))]
|
||||
|
||||
await manager._handle_file_changes(changes)
|
||||
await manager._handle_file_changes(changes)
|
||||
|
||||
assert called == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_file_changes_timeout_logged(caplog):
|
||||
manager = ConfigManager()
|
||||
manager._hot_reload_min_interval_s = 0.0
|
||||
manager._hot_reload_timeout_s = 0.01
|
||||
|
||||
async def reload_stub(changed_scopes=None) -> bool:
|
||||
await asyncio.sleep(0.05)
|
||||
return True
|
||||
|
||||
manager.reload_config = reload_stub # type: ignore[method-assign]
|
||||
changes = [FileChange(change_type=Change.modified, path=Path("/tmp/model_config.toml"))]
|
||||
|
||||
with caplog.at_level("ERROR"):
|
||||
await manager._handle_file_changes(changes)
|
||||
|
||||
assert "配置热重载超时" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_file_changes_empty_skips_reload():
|
||||
manager = ConfigManager()
|
||||
|
||||
called = 0
|
||||
|
||||
async def reload_stub(changed_scopes=None) -> bool:
|
||||
nonlocal called
|
||||
called += 1
|
||||
return True
|
||||
|
||||
manager.reload_config = reload_stub # type: ignore[method-assign]
|
||||
|
||||
await manager._handle_file_changes([])
|
||||
|
||||
assert called == 0
|
||||
|
||||
|
||||
class _FakeWatcher:
|
||||
def __init__(self):
|
||||
self.unsubscribe_called_with: str | None = None
|
||||
self.stop_called = False
|
||||
self.stats = FileWatcherStats(
|
||||
batches_seen=1,
|
||||
changes_seen=2,
|
||||
callbacks_succeeded=3,
|
||||
callbacks_failed=4,
|
||||
callbacks_timed_out=5,
|
||||
callbacks_skipped_cooldown=6,
|
||||
restart_count=7,
|
||||
)
|
||||
|
||||
def unsubscribe(self, subscription_id: str) -> bool:
|
||||
self.unsubscribe_called_with = subscription_id
|
||||
return True
|
||||
|
||||
async def stop(self) -> None:
|
||||
self.stop_called = True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_file_watcher_cleans_state():
|
||||
manager = ConfigManager()
|
||||
fake_watcher = _FakeWatcher()
|
||||
manager._file_watcher = fake_watcher # type: ignore[assignment]
|
||||
manager._file_watcher_subscription_id = "sub-1"
|
||||
|
||||
await manager.stop_file_watcher()
|
||||
|
||||
assert fake_watcher.unsubscribe_called_with == "sub-1"
|
||||
assert fake_watcher.stop_called is True
|
||||
assert manager._file_watcher is None
|
||||
assert manager._file_watcher_subscription_id is None
|
||||
22
pytests/config_test/test_config_manager_startup_upgrade.py
Normal file
22
pytests/config_test/test_config_manager_startup_upgrade.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Any
|
||||
|
||||
from src.config import config as config_module
|
||||
from src.config.config import Config, ConfigManager, ModelConfig
|
||||
|
||||
|
||||
def test_initialize_upgrades_bot_and_model_config_without_exit(monkeypatch):
|
||||
manager = ConfigManager()
|
||||
loaded_config_classes: list[type[Any]] = []
|
||||
warnings: list[Any] = []
|
||||
|
||||
def fake_load_config_from_file(config_class, config_path, new_ver, override_repr=False):
|
||||
loaded_config_classes.append(config_class)
|
||||
return object(), True
|
||||
|
||||
monkeypatch.setattr(config_module, "load_config_from_file", fake_load_config_from_file)
|
||||
monkeypatch.setattr(ConfigManager, "_warn_if_vlm_not_configured", lambda self, model_config: warnings.append(model_config))
|
||||
|
||||
manager.initialize()
|
||||
|
||||
assert loaded_config_classes == [Config, ModelConfig]
|
||||
assert warnings == [manager.model_config]
|
||||
138
pytests/config_test/test_file_watcher.py
Normal file
138
pytests/config_test/test_file_watcher.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from pathlib import Path
|
||||
|
||||
from watchfiles import Change
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
|
||||
from src.config.file_watcher import FileChange, FileWatcher
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_changes_with_path_and_change_type_filters(tmp_path: Path):
|
||||
watcher = FileWatcher(paths=[tmp_path])
|
||||
target_file = (tmp_path / "bot_config.toml").resolve()
|
||||
|
||||
received: list[list[FileChange]] = []
|
||||
|
||||
async def callback(changes):
|
||||
received.append(list(changes))
|
||||
|
||||
watcher.subscribe(callback, paths=[target_file], change_types=[Change.modified])
|
||||
|
||||
await watcher._dispatch_changes(
|
||||
[
|
||||
FileChange(change_type=Change.added, path=target_file),
|
||||
FileChange(change_type=Change.modified, path=target_file),
|
||||
FileChange(change_type=Change.modified, path=(tmp_path / "other.toml").resolve()),
|
||||
]
|
||||
)
|
||||
|
||||
assert len(received) == 1
|
||||
assert len(received[0]) == 1
|
||||
assert received[0][0].change_type == Change.modified
|
||||
assert received[0][0].path == target_file
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_callback_supported(tmp_path: Path):
|
||||
watcher = FileWatcher(paths=[tmp_path])
|
||||
target_file = (tmp_path / "model_config.toml").resolve()
|
||||
|
||||
received_paths: list[Path] = []
|
||||
|
||||
def sync_callback(changes):
|
||||
received_paths.extend(change.path for change in changes)
|
||||
|
||||
watcher.subscribe(sync_callback, paths=[target_file])
|
||||
|
||||
await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
|
||||
|
||||
assert received_paths == [target_file]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_timeout_and_cooldown(tmp_path: Path):
|
||||
watcher = FileWatcher(
|
||||
paths=[tmp_path],
|
||||
callback_timeout_s=0.05,
|
||||
callback_failure_threshold=2,
|
||||
callback_cooldown_s=0.2,
|
||||
)
|
||||
target_file = (tmp_path / "bot_config.toml").resolve()
|
||||
|
||||
async def slow_callback(changes):
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
watcher.subscribe(slow_callback, paths=[target_file])
|
||||
|
||||
await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
|
||||
await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
|
||||
|
||||
stats_after_failures = watcher.stats
|
||||
assert stats_after_failures.callbacks_timed_out == 2
|
||||
assert stats_after_failures.callbacks_failed == 2
|
||||
|
||||
await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
|
||||
stats_after_cooldown_skip = watcher.stats
|
||||
assert stats_after_cooldown_skip.callbacks_skipped_cooldown >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_requires_subscription(tmp_path: Path):
|
||||
watcher = FileWatcher(paths=[tmp_path])
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await watcher.start()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe_stops_dispatch(tmp_path: Path):
|
||||
watcher = FileWatcher(paths=[tmp_path])
|
||||
target_file = (tmp_path / "bot_config.toml").resolve()
|
||||
|
||||
calls = 0
|
||||
|
||||
async def callback(changes):
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
|
||||
subscription_id = watcher.subscribe(callback, paths=[target_file])
|
||||
assert watcher.unsubscribe(subscription_id) is True
|
||||
|
||||
await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
|
||||
|
||||
assert calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_callback_while_watcher_running(tmp_path: Path):
|
||||
dirs = (tmp_path / "a_dir").resolve()
|
||||
dirs.mkdir(exist_ok=True)
|
||||
file = (dirs / "a.toml").resolve()
|
||||
file.touch()
|
||||
watcher = FileWatcher(paths=[dirs], debounce_ms=200)
|
||||
|
||||
calls = 0
|
||||
|
||||
async def callback(changes: Sequence[FileChange]):
|
||||
nonlocal calls
|
||||
print(f"Callback called with changes: {[f'{change.change_type} {change.path}' for change in changes]}")
|
||||
calls += 1
|
||||
|
||||
uuid = watcher.subscribe(callback, paths=[file])
|
||||
await watcher.start()
|
||||
try:
|
||||
with file.open("w") as f:
|
||||
f.write("change")
|
||||
await asyncio.sleep(0.5)
|
||||
assert calls == 1
|
||||
watcher.unsubscribe(uuid)
|
||||
with file.open("w") as f:
|
||||
f.write("change2")
|
||||
await asyncio.sleep(0.5)
|
||||
assert calls == 1
|
||||
finally:
|
||||
await watcher.stop()
|
||||
76
pytests/config_test/test_llm_request_hot_reload.py
Normal file
76
pytests/config_test/test_llm_request_hot_reload.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from types import SimpleNamespace
|
||||
from importlib import util
|
||||
from pathlib import Path
|
||||
|
||||
from src.config.config import config_manager
|
||||
from src.config.model_configs import TaskConfig
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
|
||||
def _load_llm_api_module():
|
||||
file_path = Path(__file__).parent.parent.parent / "src" / "plugin_system" / "apis" / "llm_api.py"
|
||||
spec = util.spec_from_file_location("test_llm_api_module", file_path)
|
||||
assert spec is not None and spec.loader is not None
|
||||
module = util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def _make_model_config(task_config: TaskConfig, attr_name: str = "utils"):
|
||||
model_task_config = SimpleNamespace(**{attr_name: task_config})
|
||||
return SimpleNamespace(model_task_config=model_task_config, models=[], api_providers=[])
|
||||
|
||||
|
||||
def test_llm_request_resolve_task_config_by_signature(monkeypatch):
|
||||
old_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
|
||||
current_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
|
||||
|
||||
monkeypatch.setattr(config_manager, "get_model_config", lambda: _make_model_config(current_task, "utils"))
|
||||
|
||||
req = LLMRequest(model_set=old_task, request_type="test")
|
||||
|
||||
assert req._task_config_name == "utils"
|
||||
|
||||
|
||||
def test_llm_request_refresh_task_config_updates_runtime_state(monkeypatch):
|
||||
old_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
|
||||
initial_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
|
||||
updated_task = TaskConfig(model_list=["gpt-b", "gpt-c"], max_tokens=1024, temperature=0.5, slow_threshold=20.0)
|
||||
|
||||
current = {"task": initial_task}
|
||||
|
||||
def get_model_config_stub():
|
||||
return _make_model_config(current["task"], "replyer")
|
||||
|
||||
monkeypatch.setattr(config_manager, "get_model_config", get_model_config_stub)
|
||||
|
||||
req = LLMRequest(model_set=old_task, request_type="test")
|
||||
assert req._task_config_name == "replyer"
|
||||
|
||||
current["task"] = updated_task
|
||||
req._refresh_task_config()
|
||||
|
||||
assert req.model_for_task.model_list == ["gpt-b", "gpt-c"]
|
||||
assert list(req.model_usage.keys()) == ["gpt-b", "gpt-c"]
|
||||
|
||||
|
||||
def test_llm_api_get_available_models_reads_latest_config(monkeypatch):
|
||||
llm_api = _load_llm_api_module()
|
||||
|
||||
first_utils = TaskConfig(model_list=["gpt-a"])
|
||||
second_utils = TaskConfig(model_list=["gpt-z"])
|
||||
|
||||
state = {"task": first_utils}
|
||||
|
||||
def get_model_config_stub():
|
||||
model_task_config = SimpleNamespace(utils=state["task"], planner=TaskConfig(model_list=["gpt-p"]))
|
||||
return SimpleNamespace(model_task_config=model_task_config)
|
||||
|
||||
monkeypatch.setattr(config_manager, "get_model_config", get_model_config_stub)
|
||||
|
||||
first = llm_api.get_available_models()
|
||||
assert first["utils"].model_list == ["gpt-a"]
|
||||
|
||||
state["task"] = second_utils
|
||||
second = llm_api.get_available_models()
|
||||
assert second["utils"].model_list == ["gpt-z"]
|
||||
11
pytests/config_test/test_model_info_normalization.py
Normal file
11
pytests/config_test/test_model_info_normalization.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from src.config.model_configs import ModelInfo
|
||||
|
||||
|
||||
def test_model_identifier_strips_surrounding_whitespace() -> None:
|
||||
model_info = ModelInfo(
|
||||
api_provider="test-provider",
|
||||
model_identifier=" glm-5.1 ",
|
||||
name="test-model",
|
||||
)
|
||||
|
||||
assert model_info.model_identifier == "glm-5.1"
|
||||
104
pytests/config_test/test_startup_bindings.py
Normal file
104
pytests/config_test/test_startup_bindings.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
import sys
|
||||
|
||||
from src.config.legacy_migration import migrate_legacy_bind_env_to_bot_config_dict
|
||||
from src.config.startup_bindings import (
|
||||
BindAddress,
|
||||
get_startup_main_bind_address,
|
||||
get_startup_webui_bind_address,
|
||||
resolve_main_bind_address,
|
||||
resolve_webui_bind_address,
|
||||
)
|
||||
|
||||
|
||||
def test_startup_bindings_use_defaults_when_config_file_missing(tmp_path: Path):
|
||||
missing_path = tmp_path / "missing_bot_config.toml"
|
||||
|
||||
assert get_startup_main_bind_address(missing_path) == BindAddress(host="127.0.0.1", port=8080)
|
||||
assert get_startup_webui_bind_address(missing_path) == BindAddress(host="127.0.0.1", port=8001)
|
||||
|
||||
|
||||
def test_startup_bindings_can_read_addresses_from_bot_config(tmp_path: Path):
|
||||
config_path = tmp_path / "bot_config.toml"
|
||||
config_path.write_text(
|
||||
"""
|
||||
[inner]
|
||||
version = "8.3.1"
|
||||
|
||||
[maim_message]
|
||||
ws_server_host = "0.0.0.0"
|
||||
ws_server_port = 22345
|
||||
|
||||
[webui]
|
||||
host = "192.168.1.9"
|
||||
port = 18001
|
||||
""".strip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
assert get_startup_main_bind_address(config_path) == BindAddress(host="0.0.0.0", port=22345)
|
||||
assert get_startup_webui_bind_address(config_path) == BindAddress(host="192.168.1.9", port=18001)
|
||||
|
||||
|
||||
def test_resolve_bindings_prefer_initialized_global_config(monkeypatch):
|
||||
fake_config_module = SimpleNamespace(
|
||||
global_config=SimpleNamespace(
|
||||
maim_message=SimpleNamespace(ws_server_host="10.0.0.2", ws_server_port=32000),
|
||||
webui=SimpleNamespace(host="10.0.0.3", port=32001),
|
||||
)
|
||||
)
|
||||
|
||||
monkeypatch.setitem(sys.modules, "src.config.config", fake_config_module)
|
||||
|
||||
assert resolve_main_bind_address() == BindAddress(host="10.0.0.2", port=32000)
|
||||
assert resolve_webui_bind_address() == BindAddress(host="10.0.0.3", port=32001)
|
||||
|
||||
|
||||
def test_legacy_env_bindings_are_migrated_when_fields_missing_or_default(monkeypatch):
|
||||
monkeypatch.setenv("HOST", "0.0.0.0")
|
||||
monkeypatch.setenv("PORT", "22345")
|
||||
monkeypatch.setenv("WEBUI_HOST", "192.168.1.8")
|
||||
monkeypatch.setenv("WEBUI_PORT", "19001")
|
||||
|
||||
payload = {
|
||||
"maim_message": {
|
||||
"ws_server_host": "127.0.0.1",
|
||||
"ws_server_port": 8080,
|
||||
},
|
||||
"webui": {},
|
||||
}
|
||||
|
||||
result = migrate_legacy_bind_env_to_bot_config_dict(payload)
|
||||
|
||||
assert result.migrated is True
|
||||
assert payload["maim_message"]["ws_server_host"] == "0.0.0.0"
|
||||
assert payload["maim_message"]["ws_server_port"] == 22345
|
||||
assert payload["webui"]["host"] == "192.168.1.8"
|
||||
assert payload["webui"]["port"] == 19001
|
||||
|
||||
|
||||
def test_legacy_env_bindings_do_not_override_explicit_config(monkeypatch):
|
||||
monkeypatch.setenv("HOST", "0.0.0.0")
|
||||
monkeypatch.setenv("PORT", "22345")
|
||||
monkeypatch.setenv("WEBUI_HOST", "192.168.1.8")
|
||||
monkeypatch.setenv("WEBUI_PORT", "19001")
|
||||
|
||||
payload = {
|
||||
"maim_message": {
|
||||
"ws_server_host": "10.1.1.1",
|
||||
"ws_server_port": 30000,
|
||||
},
|
||||
"webui": {
|
||||
"host": "10.1.1.2",
|
||||
"port": 30001,
|
||||
},
|
||||
}
|
||||
|
||||
result = migrate_legacy_bind_env_to_bot_config_dict(payload)
|
||||
|
||||
assert result.migrated is False
|
||||
assert payload["maim_message"]["ws_server_host"] == "10.1.1.1"
|
||||
assert payload["maim_message"]["ws_server_port"] == 30000
|
||||
assert payload["webui"]["host"] == "10.1.1.2"
|
||||
assert payload["webui"]["port"] == 30001
|
||||
Reference in New Issue
Block a user