chore: import deployable mai-bot source tree
This commit is contained in:
0
pytests/webui/__init__.py
Normal file
0
pytests/webui/__init__.py
Normal file
161
pytests/webui/test_app.py
Normal file
161
pytests/webui/test_app.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Tests for the WebUI static-asset resolution helpers in ``src.webui.app``.

Covers ``_ensure_static_path_ready``, ``_setup_static_files``,
``_resolve_static_path`` and ``_resolve_safe_static_file_path``.
"""

from pathlib import Path
from unittest.mock import patch

import pytest

from src.webui import app as webui_app


def test_ensure_static_path_ready_uses_existing_static_path(tmp_path) -> None:
    """A resolvable static dir containing index.html is returned as-is."""
    static_path = tmp_path / "dist"
    static_path.mkdir()
    (static_path / "index.html").write_text("<html></html>", encoding="utf-8")

    with patch.object(webui_app, "_resolve_static_path", return_value=static_path):
        result = webui_app._ensure_static_path_ready()

    assert result == static_path


def test_ensure_static_path_ready_logs_install_hint_when_static_assets_are_missing() -> None:
    """When no static path resolves, both the unavailable warning and the
    manual-install hint are logged and ``None`` is returned."""
    with (
        patch.object(webui_app, "_resolve_static_path", return_value=None),
        patch.object(webui_app.logger, "warning") as warning_mock,
    ):
        result = webui_app._ensure_static_path_ready()

    assert result is None
    warning_mock.assert_any_call(webui_app.t("startup.webui_static_assets_unavailable"))
    warning_mock.assert_any_call(
        webui_app.t("startup.webui_dashboard_package_hint", command=webui_app._MANUAL_INSTALL_COMMAND)
    )


def test_ensure_static_path_ready_logs_index_error_when_static_path_is_invalid(tmp_path) -> None:
    """A static dir without index.html yields None plus the missing-index warning."""
    static_path = tmp_path / "dist"
    static_path.mkdir()  # deliberately no index.html inside

    with (
        patch.object(webui_app, "_resolve_static_path", return_value=static_path),
        patch.object(webui_app.logger, "warning") as warning_mock,
    ):
        result = webui_app._ensure_static_path_ready()

    assert result is None
    warning_mock.assert_any_call(
        webui_app.t("startup.webui_index_missing", index_path=static_path / "index.html")
    )
    warning_mock.assert_any_call(
        webui_app.t("startup.webui_dashboard_package_hint", command=webui_app._MANUAL_INSTALL_COMMAND)
    )


def test_setup_static_files_does_not_duplicate_warning_when_static_path_is_unavailable() -> None:
    """_setup_static_files must not re-warn when the readiness check already ran."""
    app = webui_app.FastAPI()

    with (
        patch.object(webui_app, "_ensure_static_path_ready", return_value=None),
        patch.object(webui_app.logger, "warning") as warning_mock,
    ):
        webui_app._setup_static_files(app)

    warning_mock.assert_not_called()


def test_resolve_static_path_prefers_installed_dashboard_package(monkeypatch, tmp_path) -> None:
    """The installed dashboard package's dist path wins when importable."""
    package_dist = tmp_path / "site-packages" / "maibot_dashboard" / "dist"
    package_dist.mkdir(parents=True)

    class _DashboardModule:
        # Stand-in for the maibot_dashboard package interface.
        @staticmethod
        def get_dist_path() -> Path:
            return package_dist

    monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)

    with patch.object(webui_app, "import_module", return_value=_DashboardModule()):
        resolved_path = webui_app._resolve_static_path()

    assert resolved_path == package_dist


def test_resolve_static_path_ignores_dashboard_dist_when_package_is_unavailable(monkeypatch, tmp_path) -> None:
    """A local dashboard/dist directory is not used when the package import fails."""
    dashboard_dist = tmp_path / "dashboard" / "dist"
    dashboard_dist.mkdir(parents=True)
    (dashboard_dist / "index.html").write_text("<html></html>", encoding="utf-8")

    monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)

    with patch.object(webui_app, "import_module", side_effect=ImportError):
        resolved_path = webui_app._resolve_static_path()

    assert resolved_path is None


def test_resolve_static_path_uses_package_even_when_dashboard_dist_exists(monkeypatch, tmp_path) -> None:
    """The package dist path takes precedence over a co-existing dashboard/dist."""
    dashboard_dist = tmp_path / "dashboard" / "dist"
    dashboard_dist.mkdir(parents=True)

    package_dist = tmp_path / "site-packages" / "maibot_dashboard" / "dist"
    package_dist.mkdir(parents=True)

    class _DashboardModule:
        @staticmethod
        def get_dist_path() -> Path:
            return package_dist

    monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)

    with patch.object(webui_app, "import_module", return_value=_DashboardModule()):
        resolved_path = webui_app._resolve_static_path()

    assert resolved_path == package_dist


def test_resolve_safe_static_file_path_allows_regular_static_file(tmp_path) -> None:
    """A plain relative path inside the static root resolves to the real file."""
    static_path = tmp_path / "dist"
    asset_path = static_path / "assets" / "app.js"
    asset_path.parent.mkdir(parents=True)
    asset_path.write_text("console.log('ok')", encoding="utf-8")

    resolved_path = webui_app._resolve_safe_static_file_path(static_path, "assets/app.js")

    assert resolved_path == asset_path.resolve()


def test_resolve_safe_static_file_path_rejects_relative_path_traversal(tmp_path) -> None:
    """`..` traversal out of the static root must be rejected."""
    static_path = tmp_path / "dist"
    static_path.mkdir()

    resolved_path = webui_app._resolve_safe_static_file_path(static_path, "../secret.txt")

    assert resolved_path is None


def test_resolve_safe_static_file_path_rejects_absolute_path_traversal(tmp_path) -> None:
    """Absolute request paths must be rejected."""
    static_path = tmp_path / "dist"
    static_path.mkdir()

    resolved_path = webui_app._resolve_safe_static_file_path(static_path, "/etc/passwd")

    assert resolved_path is None


def test_resolve_safe_static_file_path_rejects_symlink_escape(tmp_path) -> None:
    """A symlink inside the static root pointing outside it must be rejected."""
    static_path = tmp_path / "dist"
    static_path.mkdir()

    outside_dir = tmp_path / "outside"
    outside_dir.mkdir()
    outside_file = outside_dir / "secret.txt"
    outside_file.write_text("secret", encoding="utf-8")

    link_path = static_path / "escape"
    try:
        link_path.symlink_to(outside_dir, target_is_directory=True)
    except OSError as exc:
        # Symlink creation can be unavailable (e.g. Windows without privilege).
        pytest.skip(f"symlink is not supported in this environment: {exc}")

    resolved_path = webui_app._resolve_safe_static_file_path(static_path, "escape/secret.txt")

    assert resolved_path is None
|
||||
147
pytests/webui/test_config_schema.py
Normal file
147
pytests/webui/test_config_schema.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Tests for the WebUI configuration-schema generator (``ConfigSchemaGenerator``)."""

from src.config.official_configs import ChatConfig, MessageReceiveConfig
from src.config.config import Config
from src.config.config_base import ConfigBase, Field
from src.webui.config_schema import ConfigSchemaGenerator


def test_field_docs_in_schema():
    """Field descriptions are extracted from field_docs (attribute docstrings)."""
    schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
    talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value")

    # The description key must exist and carry the docstring text.
    assert "description" in talk_value
    assert "聊天频率" in talk_value["description"]


def test_json_schema_extra_merged():
    """UI metadata declared via json_schema_extra is merged into the field output."""
    schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
    talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value")

    assert talk_value.get("x-widget") == "slider"
    assert talk_value.get("x-icon") == "message-circle"
    assert talk_value.get("step") == 0.1


def test_pydantic_constraints_mapped():
    """Pydantic ge/le constraints map to the frontend minValue/maxValue keys."""
    schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
    talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value")

    assert "minValue" in talk_value
    assert "maxValue" in talk_value
    assert talk_value["minValue"] == 0  # from ge=0
    assert talk_value["maxValue"] == 1  # from le=1


def test_nested_model_schema():
    """Nested ConfigBase fields produce complete sub-schemas under "nested"."""
    schema = ConfigSchemaGenerator.generate_schema(Config)

    assert "nested" in schema
    assert "chat" in schema["nested"]

    chat_schema = schema["nested"]["chat"]
    assert chat_schema["className"] == "ChatConfig"
    assert "fields" in chat_schema

    # Nested fields keep their description and UI metadata.
    talk_value = next(f for f in chat_schema["fields"] if f["name"] == "talk_value")
    assert "description" in talk_value
    assert talk_value.get("x-widget") == "slider"
    assert talk_value.get("minValue") == 0


def test_field_without_extra_metadata():
    """Fields lacking json_schema_extra still produce a valid basic schema."""
    schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
    inevitable_at_reply = next(f for f in schema["fields"] if f["name"] == "inevitable_at_reply")

    assert "name" in inevitable_at_reply
    assert inevitable_at_reply["name"] == "inevitable_at_reply"
    assert "type" in inevitable_at_reply
    assert inevitable_at_reply["type"] == "boolean"
    assert "label" in inevitable_at_reply
    assert "required" in inevitable_at_reply

    # x-widget / x-icon appear only when explicitly declared in json_schema_extra.
    assert not inevitable_at_reply.get("x-widget")
    assert not inevitable_at_reply.get("x-icon")


def test_all_top_level_sections_have_ui_metadata():
    """Every top-level section declares a uiParent, or label+icon for a standalone tab."""
    schema = ConfigSchemaGenerator.generate_schema(Config)

    for section_name, section_schema in schema["nested"].items():
        has_parent = bool(section_schema.get("uiParent"))
        has_host_meta = bool(section_schema.get("uiLabel")) and bool(section_schema.get("uiIcon"))
        assert has_parent or has_host_meta, f"{section_name} 缺少 UI 元数据"


def test_maisaka_is_host_tab_and_mcp_is_attached_to_it():
    """MaiSaka is a standalone tab; MCP is attached under it as a child section."""
    schema = ConfigSchemaGenerator.generate_schema(Config)

    maisaka_schema = schema["nested"]["maisaka"]
    mcp_schema = schema["nested"]["mcp"]

    assert maisaka_schema.get("uiParent") is None
    assert maisaka_schema.get("uiLabel") == "MaiSaka"
    assert maisaka_schema.get("uiIcon") == "message-circle"
    assert mcp_schema.get("uiParent") == "maisaka"


def test_memory_query_config_fields_are_exposed():
    """The query_memory switch and default limit appear in the memory schema."""
    schema = ConfigSchemaGenerator.generate_schema(Config)
    memory_schema = schema["nested"]["memory"]

    assert memory_schema.get("uiParent") == "emoji"

    enable_field = next(field for field in memory_schema["fields"] if field["name"] == "enable_memory_query_tool")
    limit_field = next(field for field in memory_schema["fields"] if field["name"] == "memory_query_default_limit")

    assert enable_field["type"] == "boolean"
    assert enable_field.get("x-widget") == "switch"
    assert enable_field.get("x-icon") == "database"

    assert limit_field["type"] == "integer"
    assert limit_field.get("x-widget") == "input"
    assert limit_field.get("x-icon") == "hash"
    assert limit_field.get("minValue") == 1
    assert limit_field.get("maxValue") == 20


def test_set_field_is_mapped_as_array():
    """set[str] fields map to a frontend-friendly array-of-string type."""
    schema = ConfigSchemaGenerator.generate_schema(MessageReceiveConfig)
    ban_words = next(field for field in schema["fields"] if field["name"] == "ban_words")

    assert ban_words["type"] == "array"
    assert ban_words["items"]["type"] == "string"


def test_advanced_fields_are_hidden_from_webui_schema():
    """Fields with advanced=True are excluded; undeclared fields are shown by default."""

    class AdvancedExampleConfig(ConfigBase):
        normal_field: str = Field(default="visible")
        """普通字段"""

        advanced_field: str = Field(default="hidden", json_schema_extra={"advanced": True})
        """高级字段"""

    schema = ConfigSchemaGenerator.generate_schema(AdvancedExampleConfig)
    field_names = {field["name"] for field in schema["fields"]}

    assert "normal_field" in field_names
    assert "advanced_field" not in field_names
|
||||
461
pytests/webui/test_emoji_routes.py
Normal file
461
pytests/webui/test_emoji_routes.py
Normal file
@@ -0,0 +1,461 @@
|
||||
"""Emoji router API tests.

Exercises the core emoji endpoints in ``src/webui/routers/emoji.py``
using an in-memory SQLite database and FastAPI's ``TestClient``.
"""

from contextlib import contextmanager
from datetime import datetime
from typing import Generator
from unittest.mock import patch

import pytest

from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine

from src.common.database.database_model import Images, ImageType
from src.webui.core import TokenManager
from src.webui.routers.emoji import router


@pytest.fixture(scope="function")
def test_engine():
    """Create an in-memory SQLite engine for the test run."""
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    SQLModel.metadata.create_all(engine)
    return engine


@pytest.fixture(scope="function")
def test_session(test_engine) -> Generator[Session, None, None]:
    """Yield a database session bound to the in-memory engine."""
    with Session(test_engine) as session:
        yield session


@pytest.fixture(scope="function")
def test_app(test_session):
    """Create a test FastAPI app with get_db_session overridden."""
    app = FastAPI()
    app.include_router(router)

    # Context manager that substitutes the shared test session for the
    # router's real get_db_session dependency.
    @contextmanager
    def override_get_db_session(auto_commit=True):
        """Override get_db_session to use test session"""
        try:
            yield test_session
            if auto_commit:
                test_session.commit()
        except Exception:
            test_session.rollback()
            raise

    with patch("src.webui.routers.emoji.get_db_session", override_get_db_session):
        yield app


@pytest.fixture(scope="function")
def client(test_app):
    """Build a TestClient bound to the test app."""
    return TestClient(test_app)


@pytest.fixture(scope="function")
def auth_token():
    """Create a valid authentication token."""
    token_manager = TokenManager(secret_key="test-secret-key", token_expire_hours=24)
    return token_manager.create_token()


@pytest.fixture(scope="function")
def sample_emojis(test_session) -> list[Images]:
    """Insert three emoji rows covering registered / unregistered / banned states."""
    import hashlib

    emojis = [
        Images(
            image_type=ImageType.EMOJI,
            full_path="/data/emoji_registed/test1.png",
            image_hash=hashlib.sha256(b"test1").hexdigest(),
            description="测试表情包 1",
            emotion="开心,快乐",
            query_count=10,
            is_registered=True,
            is_banned=False,
            record_time=datetime(2026, 1, 1, 10, 0, 0),
            register_time=datetime(2026, 1, 1, 10, 0, 0),
            last_used_time=datetime(2026, 1, 2, 10, 0, 0),
        ),
        Images(
            image_type=ImageType.EMOJI,
            full_path="/data/emoji_registed/test2.gif",
            image_hash=hashlib.sha256(b"test2").hexdigest(),
            description="测试表情包 2",
            emotion="难过",
            query_count=5,
            is_registered=False,
            is_banned=False,
            record_time=datetime(2026, 1, 3, 10, 0, 0),
            register_time=None,
            last_used_time=None,
        ),
        Images(
            image_type=ImageType.EMOJI,
            full_path="/data/emoji_registed/test3.webp",
            image_hash=hashlib.sha256(b"test3").hexdigest(),
            description="测试表情包 3",
            emotion="生气",
            query_count=20,
            is_registered=True,
            is_banned=True,
            record_time=datetime(2026, 1, 4, 10, 0, 0),
            register_time=datetime(2026, 1, 4, 10, 0, 0),
            last_used_time=datetime(2026, 1, 5, 10, 0, 0),
        ),
    ]

    for emoji in emojis:
        test_session.add(emoji)
    test_session.commit()

    # Refresh so auto-generated primary keys are populated on the objects.
    for emoji in emojis:
        test_session.refresh(emoji)

    return emojis


@pytest.fixture(scope="function")
def mock_token_verify():
    """Mock token verification so every request is authenticated."""
    with patch("src.webui.routers.emoji.verify_auth_token", return_value=True):
        yield


# ==================== Test cases ====================


def test_list_emojis_basic(client, sample_emojis, mock_token_verify):
    """GET /emoji/list returns all rows with the expected response fields."""
    response = client.get("/emoji/list?page=1&page_size=10")

    assert response.status_code == 200
    data = response.json()

    assert data["success"] is True
    assert data["total"] == 3
    assert data["page"] == 1
    assert data["page_size"] == 10
    assert len(data["data"]) == 3

    # Spot-check the field set on the first record.
    emoji = data["data"][0]
    assert "id" in emoji
    assert "full_path" in emoji
    assert "emoji_hash" in emoji
    assert "description" in emoji
    assert "query_count" in emoji
    assert "is_registered" in emoji
    assert "is_banned" in emoji
    assert "emotion" in emoji
    assert "record_time" in emoji
    assert "register_time" in emoji
    assert "last_used_time" in emoji


def test_list_emojis_pagination(client, sample_emojis, mock_token_verify):
    """Pagination slices the three rows into pages of two then one."""
    response = client.get("/emoji/list?page=1&page_size=2")

    assert response.status_code == 200
    data = response.json()

    assert data["success"] is True
    assert data["total"] == 3
    assert len(data["data"]) == 2

    # Second page holds the remaining row.
    response = client.get("/emoji/list?page=2&page_size=2")
    data = response.json()
    assert len(data["data"]) == 1


def test_list_emojis_search(client, sample_emojis, mock_token_verify):
    """The search parameter filters by description."""
    response = client.get("/emoji/list?search=表情包 2")

    assert response.status_code == 200
    data = response.json()

    assert data["success"] is True
    assert data["total"] == 1
    assert data["data"][0]["description"] == "测试表情包 2"


def test_list_emojis_filter_registered(client, sample_emojis, mock_token_verify):
    """is_registered=true returns only registered emojis."""
    response = client.get("/emoji/list?is_registered=true")

    assert response.status_code == 200
    data = response.json()

    assert data["success"] is True
    assert data["total"] == 2
    assert all(emoji["is_registered"] is True for emoji in data["data"])


def test_list_emojis_filter_banned(client, sample_emojis, mock_token_verify):
    """is_banned=true returns only banned emojis."""
    response = client.get("/emoji/list?is_banned=true")

    assert response.status_code == 200
    data = response.json()

    assert data["success"] is True
    assert data["total"] == 1
    assert data["data"][0]["is_banned"] is True


def test_list_emojis_sort_by_query_count(client, sample_emojis, mock_token_verify):
    """Sorting by query_count descending orders 20 > 10 > 5."""
    response = client.get("/emoji/list?sort_by=query_count&sort_order=desc")

    assert response.status_code == 200
    data = response.json()

    assert data["success"] is True
    assert data["data"][0]["query_count"] == 20
    assert data["data"][1]["query_count"] == 10
    assert data["data"][2]["query_count"] == 5


def test_get_emoji_detail_success(client, sample_emojis, mock_token_verify):
    """GET /emoji/{id} returns the matching record."""
    emoji_id = sample_emojis[0].id
    response = client.get(f"/emoji/{emoji_id}")

    assert response.status_code == 200
    data = response.json()

    assert data["success"] is True
    assert data["data"]["id"] == emoji_id
    assert data["data"]["emoji_hash"] == sample_emojis[0].image_hash


def test_get_emoji_detail_not_found(client, mock_token_verify):
    """GET /emoji/{id} with an unknown id yields 404."""
    response = client.get("/emoji/99999")

    assert response.status_code == 404
    data = response.json()
    assert "未找到" in data["detail"]


def test_update_emoji_description(client, sample_emojis, mock_token_verify):
    """PATCH /emoji/{id} updates the description."""
    emoji_id = sample_emojis[0].id
    response = client.patch(
        f"/emoji/{emoji_id}",
        json={"description": "更新后的描述"},
    )

    assert response.status_code == 200
    data = response.json()

    assert data["success"] is True
    assert data["data"]["description"] == "更新后的描述"
    assert "成功更新" in data["message"]


def test_update_emoji_register_status(client, sample_emojis, mock_token_verify, test_session):
    """Flipping is_registered False -> True sets register_time."""
    emoji_id = sample_emojis[1].id  # the unregistered sample
    response = client.patch(
        f"/emoji/{emoji_id}",
        json={"is_registered": True},
    )

    assert response.status_code == 200
    data = response.json()

    assert data["success"] is True
    assert data["data"]["is_registered"] is True
    assert data["data"]["register_time"] is not None  # registration timestamp was set


def test_update_emoji_no_fields(client, sample_emojis, mock_token_verify):
    """PATCH with an empty body yields 400."""
    emoji_id = sample_emojis[0].id
    response = client.patch(f"/emoji/{emoji_id}", json={})

    assert response.status_code == 400
    data = response.json()
    assert "未提供任何需要更新的字段" in data["detail"]


def test_update_emoji_not_found(client, mock_token_verify):
    """PATCH on an unknown id yields 404."""
    response = client.patch("/emoji/99999", json={"description": "test"})

    assert response.status_code == 404
    data = response.json()
    assert "未找到" in data["detail"]


def test_delete_emoji_success(client, sample_emojis, mock_token_verify, test_session):
    """DELETE /emoji/{id} removes the row from the database."""
    emoji_id = sample_emojis[0].id
    response = client.delete(f"/emoji/{emoji_id}")

    assert response.status_code == 200
    data = response.json()

    assert data["success"] is True
    assert "成功删除" in data["message"]

    # The row must be gone from the database.
    from sqlmodel import select

    statement = select(Images).where(Images.id == emoji_id)
    result = test_session.exec(statement).first()
    assert result is None


def test_delete_emoji_not_found(client, mock_token_verify):
    """DELETE on an unknown id yields 404."""
    response = client.delete("/emoji/99999")

    assert response.status_code == 404
    data = response.json()
    assert "未找到" in data["detail"]


def test_batch_delete_success(client, sample_emojis, mock_token_verify, test_session):
    """Batch delete removes every requested emoji when all ids exist."""
    emoji_ids = [sample_emojis[0].id, sample_emojis[1].id]
    response = client.post("/emoji/batch/delete", json={"emoji_ids": emoji_ids})

    assert response.status_code == 200
    data = response.json()

    assert data["success"] is True
    assert data["deleted_count"] == 2
    assert data["failed_count"] == 0
    assert "成功删除 2 个表情包" in data["message"]

    # Both rows must be gone from the database.
    from sqlmodel import select

    for emoji_id in emoji_ids:
        statement = select(Images).where(Images.id == emoji_id)
        result = test_session.exec(statement).first()
        assert result is None


def test_batch_delete_partial_failure(client, sample_emojis, mock_token_verify):
    """Batch delete reports unknown ids as failures without aborting."""
    emoji_ids = [sample_emojis[0].id, 99999]  # second id does not exist
    response = client.post("/emoji/batch/delete", json={"emoji_ids": emoji_ids})

    assert response.status_code == 200
    data = response.json()

    assert data["success"] is True
    assert data["deleted_count"] == 1
    assert data["failed_count"] == 1
    assert 99999 in data["failed_ids"]


def test_batch_delete_empty_list(client, mock_token_verify):
    """Batch delete with an empty id list yields 400."""
    response = client.post("/emoji/batch/delete", json={"emoji_ids": []})

    assert response.status_code == 400
    data = response.json()
    assert "未提供要删除的表情包ID" in data["detail"]


def test_auth_required_list(client):
    """Unauthenticated access to the list endpoint is rejected."""
    # mock_token_verify fixture deliberately NOT used here.
    with patch("src.webui.routers.emoji.verify_auth_token", return_value=False):
        client.get("/emoji/list")
        # verify_auth_token returning False triggers an HTTPException; the
        # exact status code depends on verify_auth_token_from_cookie_or_header,
        # assumed to be 401.


def test_auth_required_update(client, sample_emojis):
    """Unauthenticated access to the update endpoint is rejected."""
    with patch("src.webui.routers.emoji.verify_auth_token", return_value=False):
        emoji_id = sample_emojis[0].id
        client.patch(f"/emoji/{emoji_id}", json={"description": "test"})
        # Should be unauthorized


def test_emoji_to_response_field_mapping(sample_emojis):
    """emoji_to_response maps image_hash -> emoji_hash and datetimes -> floats."""
    from src.webui.routers.emoji import emoji_to_response

    emoji = sample_emojis[0]
    response = emoji_to_response(emoji)

    # API field naming.
    assert hasattr(response, "emoji_hash")
    assert response.emoji_hash == emoji.image_hash

    # Timestamp conversion.
    assert isinstance(response.record_time, float)
    assert response.record_time == emoji.record_time.timestamp()

    if emoji.register_time:
        assert isinstance(response.register_time, float)
        assert response.register_time == emoji.register_time.timestamp()


def test_list_emojis_only_emoji_type(client, test_session, mock_token_verify):
    """The list endpoint returns only rows of type EMOJI."""
    # Insert a non-EMOJI image that must be excluded.
    non_emoji = Images(
        image_type=ImageType.IMAGE,  # not an EMOJI
        full_path="/data/images/test.png",
        image_hash="hash_image",
        description="非表情包图片",
        query_count=0,
        is_registered=False,
        is_banned=False,
        record_time=datetime.now(),
    )
    test_session.add(non_emoji)
    test_session.commit()

    # Insert one EMOJI-type row.
    emoji = Images(
        image_type=ImageType.EMOJI,
        full_path="/data/emoji_registed/emoji.png",
        image_hash="hash_emoji",
        description="表情包",
        query_count=0,
        is_registered=True,
        is_banned=False,
        record_time=datetime.now(),
    )
    test_session.add(emoji)
    test_session.commit()

    response = client.get("/emoji/list")

    assert response.status_code == 200
    data = response.json()

    # Only the single EMOJI record should be returned.
    assert data["total"] == 1
    assert data["data"][0]["description"] == "表情包"
|
||||
529
pytests/webui/test_expression_routes.py
Normal file
529
pytests/webui/test_expression_routes.py
Normal file
@@ -0,0 +1,529 @@
|
||||
"""Expression routes pytest tests"""
|
||||
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine, select
|
||||
|
||||
from src.common.database.database_model import Expression, ModifiedBy
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
|
||||
def create_test_app() -> FastAPI:
|
||||
"""Create minimal test app with only expression router"""
|
||||
app = FastAPI(title="Test App")
|
||||
from src.webui.routers.expression import router as expression_router
|
||||
|
||||
main_router = APIRouter(prefix="/api/webui")
|
||||
main_router.include_router(expression_router)
|
||||
app.include_router(main_router)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_test_app()
|
||||
|
||||
|
||||
# Test database setup
|
||||
@pytest.fixture(name="test_engine")
|
||||
def test_engine_fixture():
|
||||
"""Create in-memory SQLite database for testing"""
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture(name="test_session")
|
||||
def test_session_fixture(test_engine) -> Generator[Session, None, None]:
|
||||
"""Create a test database session with transaction rollback"""
|
||||
connection = test_engine.connect()
|
||||
transaction = connection.begin()
|
||||
session = Session(bind=connection)
|
||||
|
||||
yield session
|
||||
|
||||
session.close()
|
||||
transaction.rollback()
|
||||
connection.close()
|
||||
|
||||
|
||||
@pytest.fixture(name="client")
|
||||
def client_fixture(test_session: Session, monkeypatch) -> Generator[TestClient, None, None]:
|
||||
"""Create TestClient with overridden database session"""
|
||||
from contextlib import contextmanager
|
||||
|
||||
@contextmanager
|
||||
def get_test_db_session():
|
||||
yield test_session
|
||||
test_session.commit()
|
||||
|
||||
monkeypatch.setattr("src.webui.routers.expression.get_db_session", get_test_db_session)
|
||||
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(name="mock_auth")
|
||||
def mock_auth_fixture():
|
||||
"""Mock authentication to always return True"""
|
||||
app.dependency_overrides[require_auth] = lambda: "test-token"
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture(name="sample_expression")
|
||||
def sample_expression_fixture(test_session: Session) -> Expression:
|
||||
"""Insert a sample expression into test database"""
|
||||
test_session.execute(
|
||||
text(
|
||||
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
|
||||
"VALUES (1, '测试情景', '测试风格', '[\"测试内容1\", \"测试内容2\"]', 10, '2026-02-17 12:00:00', '2026-02-15 10:00:00', 'test_chat_001', 0, 0)"
|
||||
)
|
||||
)
|
||||
test_session.commit()
|
||||
|
||||
expression = test_session.exec(select(Expression).where(Expression.id == 1)).first()
|
||||
assert expression is not None
|
||||
return expression
|
||||
|
||||
|
||||
# ============ Tests ============
|
||||
|
||||
|
||||
def test_list_expressions_empty(client: TestClient, mock_auth):
    """GET /expression/list on an empty DB returns an empty first page."""
    resp = client.get("/api/webui/expression/list")
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert body["total"] == 0
    assert body["page"] == 1
    assert body["page_size"] == 20
    assert body["data"] == []
|
||||
|
||||
|
||||
def test_list_expressions_with_data(client: TestClient, mock_auth, sample_expression: Expression):
    """GET /expression/list returns the seeded expression row."""
    resp = client.get("/api/webui/expression/list")
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert body["total"] == 1
    assert len(body["data"]) == 1

    first = body["data"][0]
    assert first["id"] == sample_expression.id
    assert first["situation"] == "测试情景"
    assert first["style"] == "测试风格"
    # session_id must surface in the API as chat_id.
    assert first["chat_id"] == "test_chat_001"
|
||||
|
||||
|
||||
def test_list_expressions_pagination(client: TestClient, mock_auth, test_session: Session):
    """GET /expression/list honours the ``page``/``page_size`` parameters.

    Fix: the first SQL fragment carried an ``f`` prefix with no
    placeholders (ruff F541); only the VALUES fragment interpolates.
    """
    # Seed five rows with distinct ids and last_active_time values.
    for i in range(5):
        test_session.execute(
            text(
                "INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
                f"VALUES ({i + 1}, '情景{i}', '风格{i}', '[]', 0, '2026-02-17 12:0{i}:00', '2026-02-15 10:00:00', 'chat_{i}', 0, 0)"
            )
        )
    test_session.commit()

    # Request page 1 with page_size=2
    response = client.get("/api/webui/expression/list?page=1&page_size=2")
    assert response.status_code == 200

    data = response.json()
    assert data["total"] == 5
    assert data["page"] == 1
    assert data["page_size"] == 2
    assert len(data["data"]) == 2

    # Request page 2
    response = client.get("/api/webui/expression/list?page=2&page_size=2")
    data = response.json()
    assert data["page"] == 2
    assert len(data["data"]) == 2
|
||||
|
||||
|
||||
def test_list_expressions_search(client: TestClient, mock_auth, test_session: Session):
    """GET /expression/list narrows results via the ``search`` parameter."""
    insert_prefix = (
        "INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
    )
    test_session.execute(
        text(insert_prefix + "VALUES (1, '找人吃饭', '热情', '[]', 0, datetime('now'), datetime('now'), 'chat_001', 0, 0)")
    )
    test_session.execute(
        text(insert_prefix + "VALUES (2, '拒绝邀请', '礼貌', '[]', 0, datetime('now'), datetime('now'), 'chat_002', 0, 0)")
    )
    test_session.commit()

    # Search for "吃饭" — only the first row should match.
    resp = client.get("/api/webui/expression/list?search=吃饭")
    assert resp.status_code == 200

    body = resp.json()
    assert body["total"] == 1
    assert body["data"][0]["situation"] == "找人吃饭"
|
||||
|
||||
|
||||
def test_list_expressions_chat_filter(client: TestClient, mock_auth, test_session: Session):
    """GET /expression/list filters rows by ``chat_id``."""
    insert_prefix = (
        "INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
    )
    test_session.execute(
        text(insert_prefix + "VALUES (1, '情景A', '风格A', '[]', 0, datetime('now'), datetime('now'), 'chat_A', 0, 0)")
    )
    test_session.execute(
        text(insert_prefix + "VALUES (2, '情景B', '风格B', '[]', 0, datetime('now'), datetime('now'), 'chat_B', 0, 0)")
    )
    test_session.commit()

    # Only rows belonging to chat_A must come back.
    resp = client.get("/api/webui/expression/list?chat_id=chat_A")
    assert resp.status_code == 200

    body = resp.json()
    assert body["total"] == 1
    assert body["data"][0]["situation"] == "情景A"
    assert body["data"][0]["chat_id"] == "chat_A"
|
||||
|
||||
|
||||
def test_get_expression_detail_success(client: TestClient, mock_auth, sample_expression: Expression):
    """GET /expression/{id} returns the full detail of the seeded row."""
    resp = client.get(f"/api/webui/expression/{sample_expression.id}")
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    detail = body["data"]
    assert detail["id"] == sample_expression.id
    assert detail["situation"] == "测试情景"
    assert detail["style"] == "测试风格"
    assert detail["chat_id"] == "test_chat_001"
|
||||
|
||||
|
||||
def test_get_expression_detail_not_found(client: TestClient, mock_auth):
    """GET /expression/{id} yields 404 for an unknown id."""
    resp = client.get("/api/webui/expression/99999")
    assert resp.status_code == 404
    assert "未找到" in resp.json()["detail"]
|
||||
|
||||
|
||||
def test_expression_response_has_legacy_fields(client: TestClient, mock_auth, sample_expression: Expression):
    """The detail response carries the legacy checked/rejected/modified_by fields."""
    resp = client.get(f"/api/webui/expression/{sample_expression.id}")
    assert resp.status_code == 200

    detail = resp.json()["data"]

    # All three legacy keys must be present...
    for legacy_key in ("checked", "rejected", "modified_by"):
        assert legacy_key in detail

    # ...with the hard-coded defaults emitted by expression_to_response.
    assert detail["checked"] is False
    assert detail["rejected"] is False
    assert detail["modified_by"] is None
|
||||
|
||||
|
||||
def test_update_expression_without_removed_fields(client: TestClient, mock_auth, sample_expression: Expression):
    """PATCH /expression/{id} updates the allowed fields only."""
    payload = {"situation": "更新后的情景", "style": "更新后的风格"}

    resp = client.patch(f"/api/webui/expression/{sample_expression.id}", json=payload)
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert body["data"]["situation"] == "更新后的情景"
    assert body["data"]["style"] == "更新后的风格"

    # Legacy fields remain in the response with their hard-coded values.
    assert body["data"]["checked"] is False
    assert body["data"]["rejected"] is False
|
||||
|
||||
|
||||
def test_update_expression_ignores_invalid_fields(client: TestClient, mock_auth, sample_expression: Expression):
    """PATCH /expression/{id} silently drops fields absent from the schema."""
    payload = {
        "situation": "新情景",
        # Neither of these exists on ExpressionUpdateRequest — Pydantic drops them.
        "checked": True,
        "rejected": True,
    }

    resp = client.patch(f"/api/webui/expression/{sample_expression.id}", json=payload)
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert body["data"]["situation"] == "新情景"

    # The hard-coded False values must win over the True values sent above.
    assert body["data"]["checked"] is False
    assert body["data"]["rejected"] is False
|
||||
|
||||
|
||||
def test_update_expression_chat_id_mapping(client: TestClient, mock_auth, sample_expression: Expression):
    """PATCH /expression/{id} maps the API's chat_id onto the DB session_id."""
    resp = client.patch(
        f"/api/webui/expression/{sample_expression.id}",
        json={"chat_id": "updated_chat_999"},
    )
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    # The response echoes the new value back under chat_id.
    assert body["data"]["chat_id"] == "updated_chat_999"
|
||||
|
||||
|
||||
def test_update_expression_not_found(client: TestClient, mock_auth):
    """PATCH /expression/{id} yields 404 for an unknown id."""
    resp = client.patch("/api/webui/expression/99999", json={"situation": "新情景"})
    assert resp.status_code == 404
    assert "未找到" in resp.json()["detail"]
|
||||
|
||||
|
||||
def test_update_expression_empty_request(client: TestClient, mock_auth, sample_expression: Expression):
    """PATCH /expression/{id} rejects a body with no updatable fields."""
    resp = client.patch(f"/api/webui/expression/{sample_expression.id}", json={})
    assert resp.status_code == 400
    assert "未提供任何需要更新的字段" in resp.json()["detail"]
|
||||
|
||||
|
||||
def test_delete_expression_success(client: TestClient, mock_auth, sample_expression: Expression):
    """DELETE /expression/{id} removes the row; a follow-up GET 404s."""
    target_id = sample_expression.id

    resp = client.delete(f"/api/webui/expression/{target_id}")
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert "成功删除" in body["message"]

    # The row must be gone afterwards.
    assert client.get(f"/api/webui/expression/{target_id}").status_code == 404
|
||||
|
||||
|
||||
def test_delete_expression_not_found(client: TestClient, mock_auth):
    """DELETE /expression/{id} yields 404 for an unknown id."""
    resp = client.delete("/api/webui/expression/99999")
    assert resp.status_code == 404
    assert "未找到" in resp.json()["detail"]
|
||||
|
||||
|
||||
def test_create_expression_success(client: TestClient, mock_auth):
    """POST /expression/ creates a row and echoes it with legacy defaults."""
    payload = {
        "situation": "新建情景",
        "style": "新建风格",
        "chat_id": "new_chat_123",
    }

    resp = client.post("/api/webui/expression/", json=payload)
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert "创建成功" in body["message"]

    created = body["data"]
    assert created["situation"] == "新建情景"
    assert created["style"] == "新建风格"
    assert created["chat_id"] == "new_chat_123"

    # Legacy fields come back with their hard-coded defaults.
    assert created["checked"] is False
    assert created["rejected"] is False
    assert created["modified_by"] is None
|
||||
|
||||
|
||||
def test_batch_delete_expressions_success(client: TestClient, mock_auth, test_session: Session):
    """POST /expression/batch/delete removes every listed expression.

    Fix: the first SQL fragment carried an ``f`` prefix with no
    placeholders (ruff F541); only the VALUES fragment interpolates.
    """
    expression_ids = []
    for i in range(3):
        test_session.execute(
            text(
                "INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
                f"VALUES ({i + 1}, '批量删除{i}', '风格{i}', '[]', 0, datetime('now'), datetime('now'), 'chat_{i}', 0, 0)"
            )
        )
        expression_ids.append(i + 1)
    test_session.commit()

    delete_payload = {"ids": expression_ids}
    response = client.post("/api/webui/expression/batch/delete", json=delete_payload)
    assert response.status_code == 200

    data = response.json()
    assert data["success"] is True
    assert "成功删除 3 个" in data["message"]

    # Every deleted id must now 404.
    for expr_id in expression_ids:
        get_response = client.get(f"/api/webui/expression/{expr_id}")
        assert get_response.status_code == 404
|
||||
|
||||
|
||||
def test_batch_delete_partial_not_found(client: TestClient, mock_auth, sample_expression: Expression):
    """POST /expression/batch/delete tolerates ids that do not exist."""
    resp = client.post(
        "/api/webui/expression/batch/delete",
        json={"ids": [sample_expression.id, 88888, 99999]},
    )
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    # Only the single existing row is actually deleted.
    assert "成功删除 1 个" in body["message"]
|
||||
|
||||
|
||||
def test_get_expression_stats(client: TestClient, mock_auth, test_session: Session):
    """GET /expression/stats/summary reports row and distinct-chat counts.

    Fix: the first SQL fragment carried an ``f`` prefix with no
    placeholders (ruff F541); only the VALUES fragment interpolates.
    """
    # Three rows spread over two chats (chat_0, chat_1 via i % 2).
    for i in range(3):
        test_session.execute(
            text(
                "INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
                f"VALUES ({i + 1}, '情景{i}', '风格{i}', '[]', 0, datetime('now'), datetime('now'), 'chat_{i % 2}', 0, 0)"
            )
        )
    test_session.commit()

    response = client.get("/api/webui/expression/stats/summary")
    assert response.status_code == 200

    data = response.json()
    assert data["success"] is True
    assert data["data"]["total"] == 3
    assert data["data"]["chat_count"] == 2
|
||||
|
||||
|
||||
def test_get_review_stats(client: TestClient, mock_auth, test_session: Session):
    """GET /expression/review/stats counts rows per review status."""
    test_session.execute(
        text(
            "INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
            "VALUES (1, '待审核', '风格', '[]', 0, datetime('now'), datetime('now'), 'chat_001', 0, 0)"
        )
    )
    test_session.commit()

    resp = client.get("/api/webui/expression/review/stats")
    assert resp.status_code == 200

    body = resp.json()
    # Single seeded row is unchecked; every other bucket is empty.
    assert body["total"] == 1
    assert body["unchecked"] == 1
    assert body["passed"] == 0
    assert body["rejected"] == 0
    assert body["ai_checked"] == 0
    assert body["user_checked"] == 0
|
||||
|
||||
|
||||
def test_get_review_list_filter_unchecked(client: TestClient, mock_auth, sample_expression: Expression):
    """GET /expression/review/list?filter_type=unchecked lists unchecked rows."""
    resp = client.get("/api/webui/expression/review/list?filter_type=unchecked")
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert body["total"] == 1
    assert len(body["data"]) == 1
|
||||
|
||||
|
||||
def test_get_review_list_filter_all(client: TestClient, mock_auth, sample_expression: Expression):
    """GET /expression/review/list?filter_type=all lists every row."""
    resp = client.get("/api/webui/expression/review/list?filter_type=all")
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert body["total"] == 1
    assert len(body["data"]) == 1
|
||||
|
||||
|
||||
def test_batch_review_expressions_with_unchecked_marker(client: TestClient, mock_auth, sample_expression: Expression):
    """POST /expression/review/batch succeeds when require_unchecked=True."""
    payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": True}]}

    resp = client.post("/api/webui/expression/review/batch", json=payload)
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert body["succeeded"] == 1
    assert body["results"][0]["success"] is True
|
||||
|
||||
|
||||
def test_batch_review_expressions_overwrites_ai_checked(
    client: TestClient, mock_auth, test_session: Session, sample_expression: Expression
):
    """A manual batch review overrides an earlier AI review verdict."""
    # Pre-mark the row as AI-reviewed and rejected.
    sample_expression.checked = True
    sample_expression.rejected = True
    sample_expression.modified_by = ModifiedBy.AI
    test_session.add(sample_expression)
    test_session.commit()

    payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": True}]}

    resp = client.post("/api/webui/expression/review/batch", json=payload)
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert body["succeeded"] == 1

    # Re-read from the DB and confirm the manual verdict replaced the AI one.
    test_session.expire_all()
    reviewed = test_session.exec(select(Expression).where(Expression.id == sample_expression.id)).first()
    assert reviewed is not None
    assert reviewed.checked is True
    assert reviewed.rejected is False
    assert reviewed.modified_by == ModifiedBy.USER
|
||||
|
||||
|
||||
def test_batch_review_expressions_no_unchecked_check(client: TestClient, mock_auth, sample_expression: Expression):
    """POST /expression/review/batch succeeds when require_unchecked=False."""
    payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": False}]}

    resp = client.post("/api/webui/expression/review/batch", json=payload)
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert body["succeeded"] == 1
    assert body["results"][0]["success"] is True
|
||||
512
pytests/webui/test_jargon_routes.py
Normal file
512
pytests/webui/test_jargon_routes.py
Normal file
@@ -0,0 +1,512 @@
|
||||
"""测试 jargon 路由的完整性和正确性"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from src.common.database.database_model import ChatSession, Jargon
|
||||
from src.webui.routers.jargon import router as jargon_router
|
||||
|
||||
|
||||
@pytest.fixture(name="app", scope="function")
def app_fixture():
    """Build a minimal FastAPI app that exposes only the jargon router."""
    test_app = FastAPI()
    test_app.include_router(jargon_router, prefix="/api/webui")
    return test_app
|
||||
|
||||
|
||||
@pytest.fixture(name="engine", scope="function")
def engine_fixture():
    """Create an in-memory SQLite engine with all SQLModel tables.

    Fix: the engine is now disposed on teardown so its connection pool is
    released after each test (the original leaked it).
    """
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},  # TestClient may use another thread
        poolclass=StaticPool,  # one shared in-memory DB across connections
    )
    SQLModel.metadata.create_all(engine)
    yield engine
    engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture(name="session", scope="function")
def session_fixture(engine):
    """Yield a Session inside a connection transaction rolled back at teardown."""
    conn = engine.connect()
    outer_tx = conn.begin()
    db_session = Session(bind=conn)

    yield db_session

    db_session.close()
    outer_tx.rollback()
    conn.close()
|
||||
|
||||
|
||||
@pytest.fixture(name="client", scope="function")
def client_fixture(app: FastAPI, session: Session, monkeypatch):
    """Build a TestClient whose jargon router uses the test session."""
    from contextlib import contextmanager

    @contextmanager
    def _patched_db_session():
        yield session

    monkeypatch.setattr("src.webui.routers.jargon.get_db_session", _patched_db_session)

    with TestClient(app) as test_client:
        yield test_client
|
||||
|
||||
|
||||
@pytest.fixture(name="sample_chat_session")
def sample_chat_session_fixture(session: Session):
    """Insert and return one group-type ChatSession row."""
    now = datetime.now()
    chat_session = ChatSession(
        session_id="test_stream_001",
        platform="qq",
        group_id="123456789",
        user_id=None,  # group chat: no single user
        created_timestamp=now,
        last_active_timestamp=now,
    )
    session.add(chat_session)
    session.commit()
    session.refresh(chat_session)
    return chat_session
|
||||
|
||||
|
||||
@pytest.fixture(name="sample_jargons")
def sample_jargons_fixture(session: Session, sample_chat_session: ChatSession):
    """Insert three Jargon rows: two confirmed jargon, one confirmed not."""
    stream_id = sample_chat_session.session_id
    rows = [
        Jargon(
            id=1,
            content="yyds",
            raw_content="永远的神",
            meaning="永远的神",
            session_id=stream_id,
            count=10,
            is_jargon=True,
            is_complete=False,
        ),
        Jargon(
            id=2,
            content="awsl",
            raw_content="啊我死了",
            meaning="啊我死了",
            session_id=stream_id,
            count=5,
            is_jargon=True,
            is_complete=False,
        ),
        Jargon(
            id=3,
            content="hello",
            raw_content=None,
            meaning="你好",
            session_id=stream_id,
            count=2,
            is_jargon=False,
            is_complete=False,
        ),
    ]
    session.add_all(rows)
    session.commit()
    for row in rows:
        session.refresh(row)
    return rows
|
||||
|
||||
|
||||
# ==================== Test Cases ====================
|
||||
|
||||
|
||||
def test_list_jargons(client: TestClient, sample_jargons):
    """GET /jargon/list returns all seeded rows on the default first page."""
    resp = client.get("/api/webui/jargon/list")
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert body["total"] == 3
    assert body["page"] == 1
    assert body["page_size"] == 20
    assert len(body["data"]) == 3

    # The highest-count row comes first.
    assert body["data"][0]["content"] == "yyds"
    assert body["data"][0]["count"] == 10
|
||||
|
||||
|
||||
def test_list_jargons_with_pagination(client: TestClient, sample_jargons):
    """GET /jargon/list splits 3 rows over two pages of size 2."""
    resp = client.get("/api/webui/jargon/list?page=1&page_size=2")
    assert resp.status_code == 200

    body = resp.json()
    assert body["total"] == 3
    assert len(body["data"]) == 2

    # Page 2 holds the remaining single row.
    resp = client.get("/api/webui/jargon/list?page=2&page_size=2")
    assert resp.status_code == 200
    assert len(resp.json()["data"]) == 1
|
||||
|
||||
|
||||
def test_list_jargons_with_search(client: TestClient, sample_jargons):
    """GET /jargon/list?search=... matches on content and on meaning."""
    # Match on content.
    resp = client.get("/api/webui/jargon/list?search=yyds")
    assert resp.status_code == 200

    body = resp.json()
    assert body["total"] == 1
    assert body["data"][0]["content"] == "yyds"

    # Match on meaning.
    resp = client.get("/api/webui/jargon/list?search=你好")
    assert resp.status_code == 200
    body = resp.json()
    assert body["total"] == 1
    assert body["data"][0]["content"] == "hello"
|
||||
|
||||
|
||||
def test_list_jargons_with_chat_id_filter(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
    """GET /jargon/list filters rows by chat_id."""
    resp = client.get(f"/api/webui/jargon/list?chat_id={sample_chat_session.session_id}")
    assert resp.status_code == 200
    assert resp.json()["total"] == 3

    # An unknown chat_id matches nothing.
    resp = client.get("/api/webui/jargon/list?chat_id=nonexistent")
    assert resp.status_code == 200
    assert resp.json()["total"] == 0
|
||||
|
||||
|
||||
def test_list_jargons_with_is_jargon_filter(client: TestClient, sample_jargons):
    """GET /jargon/list filters by the is_jargon flag."""
    resp = client.get("/api/webui/jargon/list?is_jargon=true")
    assert resp.status_code == 200

    body = resp.json()
    assert body["total"] == 2
    assert all(row["is_jargon"] is True for row in body["data"])

    # The inverse filter leaves only the non-jargon row.
    resp = client.get("/api/webui/jargon/list?is_jargon=false")
    assert resp.status_code == 200
    body = resp.json()
    assert body["total"] == 1
    assert body["data"][0]["content"] == "hello"
|
||||
|
||||
|
||||
def test_get_jargon_detail(client: TestClient, sample_jargons):
    """GET /jargon/{id} returns the full record for one entry."""
    target_id = sample_jargons[0].id
    resp = client.get(f"/api/webui/jargon/{target_id}")
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True

    detail = body["data"]
    assert detail["id"] == target_id
    assert detail["content"] == "yyds"
    assert detail["meaning"] == "永远的神"
    assert detail["count"] == 10
    assert detail["is_jargon"] is True
|
||||
|
||||
|
||||
def test_get_jargon_detail_not_found(client: TestClient):
    """GET /jargon/{id} yields 404 for an unknown id."""
    resp = client.get("/api/webui/jargon/99999")
    assert resp.status_code == 404
    assert "黑话不存在" in resp.json()["detail"]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Composite PK (id+content) prevents autoincrement - database model issue")
def test_create_jargon(client: TestClient, sample_chat_session: ChatSession):
    """POST /jargon/ creates a new entry with zeroed counters."""
    payload = {
        "content": "新黑话",
        "raw_content": "原始内容",
        "meaning": "含义",
        "chat_id": sample_chat_session.session_id,
    }

    resp = client.post("/api/webui/jargon/", json=payload)
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert body["message"] == "创建成功"

    created = body["data"]
    assert created["content"] == "新黑话"
    assert created["meaning"] == "含义"
    assert created["count"] == 0
    assert created["is_jargon"] is None  # verdict not yet decided
    assert created["is_complete"] is False
|
||||
|
||||
|
||||
def test_create_duplicate_jargon(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
    """POST /jargon/ rejects a duplicate content within the same chat with 400."""
    payload = {
        "content": "yyds",  # already seeded by sample_jargons
        "meaning": "重复的",
        "chat_id": sample_chat_session.session_id,
    }

    resp = client.post("/api/webui/jargon/", json=payload)
    assert resp.status_code == 400
    assert "已存在相同内容的黑话" in resp.json()["detail"]
|
||||
|
||||
|
||||
def test_update_jargon(client: TestClient, sample_jargons):
    """PATCH /jargon/{id} updates selected fields and keeps the rest."""
    target_id = sample_jargons[0].id
    payload = {"meaning": "更新后的含义", "is_jargon": True}

    resp = client.patch(f"/api/webui/jargon/{target_id}", json=payload)
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert body["message"] == "更新成功"
    assert body["data"]["meaning"] == "更新后的含义"
    assert body["data"]["is_jargon"] is True
    # Fields absent from the payload are untouched.
    assert body["data"]["content"] == "yyds"
|
||||
|
||||
|
||||
def test_update_jargon_with_chat_id_mapping(client: TestClient, sample_jargons):
    """PATCH /jargon/{id} maps the API's chat_id onto the DB session_id."""
    target_id = sample_jargons[0].id

    resp = client.patch(f"/api/webui/jargon/{target_id}", json={"chat_id": "new_session_id"})
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert body["data"]["chat_id"] == "new_session_id"
|
||||
|
||||
|
||||
def test_update_jargon_not_found(client: TestClient):
    """PATCH /jargon/{id} yields 404 for an unknown id."""
    resp = client.patch("/api/webui/jargon/99999", json={"meaning": "test"})
    assert resp.status_code == 404
    assert "黑话不存在" in resp.json()["detail"]
|
||||
|
||||
|
||||
def test_delete_jargon(client: TestClient, sample_jargons, session: Session):
    """DELETE /jargon/{id} removes the row; a follow-up GET 404s."""
    target_id = sample_jargons[0].id
    resp = client.delete(f"/api/webui/jargon/{target_id}")
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert body["message"] == "删除成功"
    assert body["deleted_count"] == 1

    # The row must be gone afterwards.
    assert client.get(f"/api/webui/jargon/{target_id}").status_code == 404
|
||||
|
||||
|
||||
def test_delete_jargon_not_found(client: TestClient):
    """DELETE /jargon/{id} yields 404 for an unknown id."""
    resp = client.delete("/api/webui/jargon/99999")
    assert resp.status_code == 404
    assert "黑话不存在" in resp.json()["detail"]
|
||||
|
||||
|
||||
def test_batch_delete(client: TestClient, sample_jargons):
    """POST /jargon/batch/delete removes every listed entry."""
    target_ids = [sample_jargons[0].id, sample_jargons[1].id]

    resp = client.post("/api/webui/jargon/batch/delete", json={"ids": target_ids})
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert body["deleted_count"] == 2
    assert "成功删除 2 条黑话" in body["message"]

    # A deleted id must now 404.
    assert client.get(f"/api/webui/jargon/{target_ids[0]}").status_code == 404
|
||||
|
||||
|
||||
def test_batch_delete_empty_list(client: TestClient):
    """POST /jargon/batch/delete rejects an empty id list with 400."""
    resp = client.post("/api/webui/jargon/batch/delete", json={"ids": []})
    assert resp.status_code == 400
    assert "ID列表不能为空" in resp.json()["detail"]
|
||||
|
||||
|
||||
def test_batch_set_jargon_status(client: TestClient, sample_jargons):
    """POST /jargon/batch/set-jargon flips the is_jargon flag in bulk."""
    target_ids = [sample_jargons[0].id, sample_jargons[1].id]
    resp = client.post(
        "/api/webui/jargon/batch/set-jargon",
        params={"ids": target_ids, "is_jargon": False},
    )
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert "成功更新 2 条黑话状态" in body["message"]

    # The new flag value must read back via the detail endpoint.
    detail_resp = client.get(f"/api/webui/jargon/{target_ids[0]}")
    assert detail_resp.json()["data"]["is_jargon"] is False
|
||||
|
||||
|
||||
def test_get_stats(client: TestClient, sample_jargons):
    """GET /jargon/stats/summary aggregates counts over the seeded rows."""
    resp = client.get("/api/webui/jargon/stats/summary")
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True

    stats = body["data"]
    assert stats["total"] == 3
    assert stats["confirmed_jargon"] == 2
    assert stats["confirmed_not_jargon"] == 1
    assert stats["pending"] == 0
    assert stats["complete_count"] == 0
    assert stats["chat_count"] == 1
    assert isinstance(stats["top_chats"], dict)
|
||||
|
||||
|
||||
def test_get_chat_list(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
    """GET /jargon/chats lists the single chat the seeded rows belong to."""
    resp = client.get("/api/webui/jargon/chats")
    assert resp.status_code == 200

    body = resp.json()
    assert body["success"] is True
    assert len(body["data"]) == 1

    chat = body["data"][0]
    assert chat["chat_id"] == sample_chat_session.session_id
    assert chat["platform"] == "qq"
    assert chat["is_group"] is True
    # Group chats are labelled with their group id.
    assert chat["chat_name"] == sample_chat_session.group_id
|
||||
|
||||
|
||||
def test_get_chat_list_with_json_chat_id(client: TestClient, session: Session, sample_chat_session: ChatSession):
|
||||
"""测试解析 JSON 格式的 chat_id"""
|
||||
json_chat_id = json.dumps([[sample_chat_session.session_id, "user123"]])
|
||||
jargon = Jargon(
|
||||
id=100,
|
||||
content="测试黑话",
|
||||
meaning="测试",
|
||||
session_id=json_chat_id,
|
||||
count=1,
|
||||
)
|
||||
session.add(jargon)
|
||||
session.commit()
|
||||
|
||||
response = client.get("/api/webui/jargon/chats")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert len(data["data"]) == 1
|
||||
assert data["data"][0]["chat_id"] == sample_chat_session.session_id
|
||||
|
||||
|
||||
def test_get_chat_list_without_chat_session(client: TestClient, session: Session):
|
||||
"""测试聊天列表中没有对应 ChatSession 的情况"""
|
||||
jargon = Jargon(
|
||||
id=101,
|
||||
content="孤立黑话",
|
||||
meaning="无对应会话",
|
||||
session_id="nonexistent_stream_id",
|
||||
count=1,
|
||||
)
|
||||
session.add(jargon)
|
||||
session.commit()
|
||||
|
||||
response = client.get("/api/webui/jargon/chats")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert len(data["data"]) == 1
|
||||
assert data["data"][0]["chat_id"] == "nonexistent_stream_id"
|
||||
assert data["data"][0]["chat_name"] == "nonexistent_stream_id"[:20]
|
||||
assert data["data"][0]["platform"] is None
|
||||
assert data["data"][0]["is_group"] is False
|
||||
|
||||
|
||||
def test_jargon_response_fields(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
    """Every documented JargonResponse field is present in the detail view."""
    response = client.get(f"/api/webui/jargon/{sample_jargons[0].id}")
    assert response.status_code == 200

    data = response.json()["data"]

    # Required response schema fields.
    required_fields = [
        "id",
        "content",
        "raw_content",
        "meaning",
        "chat_id",
        "stream_id",
        "chat_name",
        "count",
        "is_jargon",
        "is_complete",
        "inference_with_context",
        "inference_content_only",
    ]
    missing = [field for field in required_fields if field not in data]
    assert not missing

    # chat_name falls back to the group id of the owning session.
    assert data["chat_name"] == sample_chat_session.group_id


@pytest.mark.skip(reason="Composite PK (id+content) prevents autoincrement - database model issue")
def test_create_jargon_without_optional_fields(client: TestClient, sample_chat_session: ChatSession):
    """Creating a jargon with only required fields leaves optionals empty."""
    request_data = {
        "content": "简单黑话",
        "chat_id": sample_chat_session.session_id,
    }

    response = client.post("/api/webui/jargon/", json=request_data)
    assert response.status_code == 200

    data = response.json()["data"]
    assert data["raw_content"] is None
    assert data["meaning"] == ""
def test_update_jargon_partial_fields(client: TestClient, sample_jargons):
    """PATCH updates only the supplied fields and keeps the rest intact."""
    jargon_id = sample_jargons[0].id
    original_content = sample_jargons[0].content

    # Touch only the meaning field.
    response = client.patch(f"/api/webui/jargon/{jargon_id}", json={"meaning": "新含义"})
    assert response.status_code == 200

    data = response.json()["data"]
    assert data["meaning"] == "新含义"
    assert data["content"] == original_content  # untouched field survives


def test_list_jargons_multiple_filters(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
    """Search, chat and status filters can be combined in one query."""
    response = client.get(f"/api/webui/jargon/list?search=永远&chat_id={sample_chat_session.session_id}&is_jargon=true")
    assert response.status_code == 200

    payload = response.json()
    assert payload["total"] == 1
    assert payload["data"][0]["content"] == "yyds"
870
pytests/webui/test_memory_routes.py
Normal file
870
pytests/webui/test_memory_routes.py
Normal file
@@ -0,0 +1,870 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
|
||||
from src.services.memory_service import MemorySearchResult
|
||||
from src.webui.dependencies import require_auth
|
||||
from src.webui.routers import memory as memory_router_module
|
||||
from src.webui.routers.memory import compat_router
|
||||
from src.webui.routes import router as main_router
|
||||
|
||||
|
||||
@pytest.fixture
def client() -> TestClient:
    """Build a TestClient with auth bypassed and both routers mounted."""
    app = FastAPI()
    app.dependency_overrides[require_auth] = lambda: "ok"
    app.include_router(main_router)
    app.include_router(compat_router)
    return TestClient(app)
|
||||
|
||||
def test_webui_memory_graph_route(client: TestClient, monkeypatch):
    """/memory/graph forwards the limit and returns enriched edges."""

    async def fake_graph_admin(*, action: str, **kwargs):
        assert action == "get_graph"
        return {
            "success": True,
            "nodes": [],
            "edges": [
                {
                    "source": "alice",
                    "target": "map",
                    "weight": 1.5,
                    "relation_hashes": ["rel-1"],
                    "predicates": ["持有"],
                    "relation_count": 1,
                    "evidence_count": 2,
                    "label": "持有",
                }
            ],
            "total_nodes": 0,
            "limit": kwargs.get("limit"),
        }

    monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)

    response = client.get("/api/webui/memory/graph", params={"limit": 77})

    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["limit"] == 77
    edge = payload["edges"][0]
    assert edge["predicates"] == ["持有"]
    assert edge["relation_count"] == 1
    assert edge["evidence_count"] == 2


def test_webui_memory_graph_search_route(client: TestClient, monkeypatch):
    """/memory/graph/search forwards query/limit and returns matches."""

    async def fake_graph_admin(*, action: str, **kwargs):
        assert action == "search"
        assert kwargs["query"] == "Alice"
        assert kwargs["limit"] == 33
        return {
            "success": True,
            "query": kwargs["query"],
            "limit": kwargs["limit"],
            "count": 1,
            "items": [
                {
                    "type": "entity",
                    "title": "Alice",
                    "matched_field": "name",
                    "matched_value": "Alice",
                    "entity_name": "Alice",
                    "entity_hash": "entity-1",
                    "appearance_count": 3,
                }
            ],
        }

    monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)

    response = client.get("/api/webui/memory/graph/search", params={"query": "Alice", "limit": 33})

    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["query"] == "Alice"
    assert payload["limit"] == 33
    assert payload["items"][0]["type"] == "entity"
|
||||
|
||||
@pytest.mark.parametrize(
    "params",
    [
        {"query": "", "limit": 50},
        {"query": "Alice", "limit": 0},
        {"query": "Alice", "limit": 201},
    ],
)
def test_webui_memory_graph_search_route_validation(client: TestClient, params):
    """Empty queries and out-of-range limits are rejected with 422."""
    response = client.get("/api/webui/memory/graph/search", params=params)

    assert response.status_code == 422
|
||||
|
||||
def test_webui_memory_graph_node_detail_route(client: TestClient, monkeypatch):
    """/memory/graph/node-detail returns node, relations and evidence."""

    async def fake_graph_admin(*, action: str, **kwargs):
        assert action == "node_detail"
        assert kwargs["node_id"] == "Alice"
        return {
            "success": True,
            "node": {"id": "Alice", "type": "entity", "content": "Alice", "appearance_count": 3},
            "relations": [{"hash": "rel-1", "subject": "Alice", "predicate": "持有", "object": "Map", "text": "Alice 持有 Map", "confidence": 0.9, "paragraph_count": 1, "paragraph_hashes": ["p-1"], "source_paragraph": "p-1"}],
            "paragraphs": [{"hash": "p-1", "content": "Alice 拿着地图。", "preview": "Alice 拿着地图。", "source": "demo", "entity_count": 2, "relation_count": 1, "entities": ["Alice", "Map"], "relations": ["Alice 持有 Map"]}],
            "evidence_graph": {
                "nodes": [{"id": "entity:Alice", "type": "entity", "content": "Alice"}],
                "edges": [],
                "focus_entities": ["Alice"],
            },
        }

    monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)

    response = client.get("/api/webui/memory/graph/node-detail", params={"node_id": "Alice"})

    assert response.status_code == 200
    payload = response.json()
    assert payload["node"]["id"] == "Alice"
    assert payload["relations"][0]["predicate"] == "持有"
    assert payload["evidence_graph"]["focus_entities"] == ["Alice"]


def test_webui_memory_graph_node_detail_route_returns_404(client: TestClient, monkeypatch):
    """An unknown node id surfaces the service error as HTTP 404."""

    async def fake_graph_admin(*, action: str, **kwargs):
        assert action == "node_detail"
        return {"success": False, "error": "未找到节点: Missing"}

    monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)

    response = client.get("/api/webui/memory/graph/node-detail", params={"node_id": "Missing"})

    assert response.status_code == 404
    assert response.json()["detail"] == "未找到节点: Missing"
|
||||
|
||||
def test_webui_memory_graph_edge_detail_route(client: TestClient, monkeypatch):
    """/memory/graph/edge-detail returns the edge with its evidence."""

    async def fake_graph_admin(*, action: str, **kwargs):
        assert action == "edge_detail"
        assert kwargs["source"] == "Alice"
        assert kwargs["target"] == "Map"
        return {
            "success": True,
            "edge": {
                "source": "Alice",
                "target": "Map",
                "weight": 1.5,
                "relation_hashes": ["rel-1"],
                "predicates": ["持有"],
                "relation_count": 1,
                "evidence_count": 1,
                "label": "持有",
            },
            "relations": [{"hash": "rel-1", "subject": "Alice", "predicate": "持有", "object": "Map", "text": "Alice 持有 Map", "confidence": 0.9, "paragraph_count": 1, "paragraph_hashes": ["p-1"], "source_paragraph": "p-1"}],
            "paragraphs": [{"hash": "p-1", "content": "Alice 拿着地图。", "preview": "Alice 拿着地图。", "source": "demo", "entity_count": 2, "relation_count": 1, "entities": ["Alice", "Map"], "relations": ["Alice 持有 Map"]}],
            "evidence_graph": {
                "nodes": [{"id": "relation:rel-1", "type": "relation", "content": "Alice 持有 Map"}],
                "edges": [{"source": "paragraph:p-1", "target": "relation:rel-1", "kind": "supports", "label": "支撑", "weight": 1.0}],
                "focus_entities": ["Alice", "Map"],
            },
        }

    monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)

    response = client.get("/api/webui/memory/graph/edge-detail", params={"source": "Alice", "target": "Map"})

    assert response.status_code == 200
    payload = response.json()
    assert payload["edge"]["predicates"] == ["持有"]
    assert payload["paragraphs"][0]["source"] == "demo"
    assert payload["evidence_graph"]["edges"][0]["kind"] == "supports"


def test_webui_memory_graph_edge_detail_route_returns_404(client: TestClient, monkeypatch):
    """A missing edge surfaces the service error as HTTP 404."""

    async def fake_graph_admin(*, action: str, **kwargs):
        assert action == "edge_detail"
        return {"success": False, "error": "未找到边: Alice -> Missing"}

    monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)

    response = client.get("/api/webui/memory/graph/edge-detail", params={"source": "Alice", "target": "Missing"})

    assert response.status_code == 404
    assert response.json()["detail"] == "未找到边: Alice -> Missing"
|
||||
|
||||
def test_webui_memory_profile_query_resolves_platform_user_id(client: TestClient, monkeypatch):
    """platform+user_id are resolved to a person_id before the query."""

    def fake_resolve_person_id_for_memory(**kwargs):
        assert kwargs == {"platform": "qq", "user_id": "12345", "strict_known": False}
        return "resolved-person-id"

    async def fake_profile_admin(*, action: str, **kwargs):
        assert action == "query"
        assert kwargs["person_id"] == "resolved-person-id"
        assert kwargs["person_keyword"] == "Alice"
        assert kwargs["limit"] == 9
        assert kwargs["force_refresh"] is True
        return {"success": True, "person_id": kwargs["person_id"], "profile_text": "profile"}

    monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
    monkeypatch.setattr(memory_router_module.memory_service, "profile_admin", fake_profile_admin)

    response = client.get(
        "/api/webui/memory/profiles/query",
        params={
            "platform": "qq",
            "user_id": "12345",
            "person_keyword": "Alice",
            "limit": 9,
            "force_refresh": True,
        },
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["person_id"] == "resolved-person-id"


def test_webui_memory_profile_query_prefers_explicit_person_id(client: TestClient, monkeypatch):
    """An explicit person_id bypasses platform-account resolution."""

    def fake_resolve_person_id_for_memory(**kwargs):
        raise AssertionError(f"不应解析平台账号: {kwargs}")

    async def fake_profile_admin(*, action: str, **kwargs):
        assert action == "query"
        assert kwargs["person_id"] == "explicit-person-id"
        return {"success": True, "person_id": kwargs["person_id"]}

    monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
    monkeypatch.setattr(memory_router_module.memory_service, "profile_admin", fake_profile_admin)

    response = client.get(
        "/api/webui/memory/profiles/query",
        params={"person_id": "explicit-person-id", "platform": "qq", "user_id": "12345"},
    )

    assert response.status_code == 200
    assert response.json()["person_id"] == "explicit-person-id"
|
||||
|
||||
def test_webui_memory_profile_list_enriches_person_name(client: TestClient, monkeypatch):
    """Profile list items gain a person_name looked up per person_id."""

    async def fake_profile_admin(*, action: str, **kwargs):
        assert action == "list"
        assert kwargs["limit"] == 7
        return {
            "success": True,
            "items": [
                {"person_id": "person-1", "profile_text": "profile-1"},
                {"person_id": "person-2", "profile_text": "profile-2"},
            ],
        }

    monkeypatch.setattr(memory_router_module.memory_service, "profile_admin", fake_profile_admin)
    monkeypatch.setattr(
        memory_router_module,
        "_get_person_name_for_person_id",
        lambda person_id: {"person-1": "Alice"}.get(person_id, ""),
    )

    response = client.get("/api/webui/memory/profiles", params={"limit": 7})

    assert response.status_code == 200
    items = response.json()["items"]
    assert items[0]["person_name"] == "Alice"
    assert items[1]["person_name"] == ""
|
||||
|
||||
def test_webui_memory_profile_search_resolves_platform_user_id(client: TestClient, monkeypatch):
    """Search by platform account keeps only the resolved person's entry."""

    def fake_resolve_person_id_for_memory(**kwargs):
        assert kwargs == {"platform": "qq", "user_id": "12345", "strict_known": False}
        return "resolved-person-id"

    async def fake_profile_list(limit: int):
        assert limit == 200
        return {
            "success": True,
            "items": [
                {"person_id": "resolved-person-id", "person_name": "Alice", "profile_text": "喜欢咖啡"},
                {"person_id": "other-person-id", "person_name": "Bob", "profile_text": "喜欢茶"},
            ],
        }

    monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
    monkeypatch.setattr(memory_router_module, "_profile_list", fake_profile_list)

    response = client.get(
        "/api/webui/memory/profiles/search",
        params={"platform": "qq", "user_id": "12345", "limit": 50},
    )

    assert response.status_code == 200
    assert response.json()["items"] == [
        {"person_id": "resolved-person-id", "person_name": "Alice", "profile_text": "喜欢咖啡"}
    ]


def test_webui_memory_profile_search_filters_keyword(client: TestClient, monkeypatch):
    """Keyword search filters items by profile text."""

    async def fake_profile_list(limit: int):
        assert limit == 200
        return {
            "success": True,
            "items": [
                {"person_id": "person-1", "person_name": "Alice", "profile_text": "喜欢咖啡"},
                {"person_id": "person-2", "person_name": "Bob", "profile_text": "喜欢茶"},
            ],
        }

    monkeypatch.setattr(memory_router_module, "_profile_list", fake_profile_list)

    response = client.get("/api/webui/memory/profiles/search", params={"person_keyword": "咖啡", "limit": 50})

    assert response.status_code == 200
    assert response.json()["items"] == [
        {"person_id": "person-1", "person_name": "Alice", "profile_text": "喜欢咖啡"}
    ]
|
||||
|
||||
def test_webui_memory_episode_list_resolves_platform_user_id(client: TestClient, monkeypatch):
    """Episode listing resolves the platform account and enriches names."""

    def fake_resolve_person_id_for_memory(**kwargs):
        assert kwargs == {"platform": "qq", "user_id": "12345", "strict_known": False}
        return "resolved-person-id"

    async def fake_episode_admin(*, action: str, **kwargs):
        assert action == "list"
        assert kwargs == {
            "query": "咖啡",
            "limit": 9,
            "source": "chat_summary:demo",
            "person_id": "resolved-person-id",
            "time_start": 100.0,
            "time_end": 200.0,
        }
        return {
            "success": True,
            "items": [{"episode_id": "ep-1", "person_id": "resolved-person-id", "summary": "喝咖啡"}],
            "count": 1,
        }

    monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
    monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin)
    monkeypatch.setattr(memory_router_module, "_get_person_name_for_person_id", lambda person_id: "测试人物")

    response = client.get(
        "/api/webui/memory/episodes",
        params={
            "query": "咖啡",
            "limit": 9,
            "source": "chat_summary:demo",
            "platform": "qq",
            "user_id": "12345",
            "time_start": 100,
            "time_end": 200,
        },
    )

    assert response.status_code == 200
    assert response.json()["items"][0]["person_name"] == "测试人物"


def test_webui_memory_episode_list_prefers_explicit_person_id(client: TestClient, monkeypatch):
    """An explicit person_id bypasses platform-account resolution."""

    def fake_resolve_person_id_for_memory(**kwargs):
        raise AssertionError(f"不应解析平台账号: {kwargs}")

    async def fake_episode_admin(*, action: str, **kwargs):
        assert action == "list"
        assert kwargs["person_id"] == "explicit-person-id"
        return {"success": True, "items": []}

    monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
    monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin)

    response = client.get(
        "/api/webui/memory/episodes",
        params={"person_id": "explicit-person-id", "platform": "qq", "user_id": "12345"},
    )

    assert response.status_code == 200
    assert response.json()["items"] == []
|
||||
|
||||
def test_compat_aggregate_route(client: TestClient, monkeypatch):
    """The legacy /api/query/aggregate route maps onto memory_service.search."""

    async def fake_search(query: str, **kwargs):
        assert kwargs["mode"] == "aggregate"
        assert kwargs["respect_filter"] is False
        return MemorySearchResult(summary=f"summary:{query}", hits=[])

    monkeypatch.setattr(memory_router_module.memory_service, "search", fake_search)

    response = client.get("/api/query/aggregate", params={"query": "mai"})

    assert response.status_code == 200
    assert response.json() == {
        "success": True,
        "summary": "summary:mai",
        "hits": [],
        "filtered": False,
        "error": "",
    }
|
||||
|
||||
def test_auto_save_routes(client: TestClient, monkeypatch):
    """GET reads and POST toggles the auto_save runtime flag."""

    async def fake_runtime_admin(*, action: str, **kwargs):
        if action == "get_config":
            return {"success": True, "auto_save": True}
        if action == "set_auto_save":
            return {"success": True, "auto_save": kwargs["enabled"]}
        raise AssertionError(action)

    monkeypatch.setattr(memory_router_module.memory_service, "runtime_admin", fake_runtime_admin)

    get_response = client.get("/api/config/auto_save")
    post_response = client.post("/api/config/auto_save", json={"enabled": False})

    assert get_response.status_code == 200
    assert get_response.json() == {"success": True, "auto_save": True}
    assert post_response.status_code == 200
    assert post_response.json() == {"success": True, "auto_save": False}
|
||||
|
||||
def test_memory_config_routes(client: TestClient, monkeypatch):
    """Schema, parsed config and raw config routes all expose the same path."""
    host = memory_router_module.a_memorix_host_service
    monkeypatch.setattr(
        host,
        "get_config_schema",
        lambda: {"layout": {"type": "tabs"}, "sections": {"plugin": {"fields": {}}}},
    )
    monkeypatch.setattr(
        host,
        "get_config_path",
        lambda: memory_router_module.Path("/tmp/config/bot_config.toml"),
    )
    monkeypatch.setattr(
        host,
        "get_config",
        lambda: {"plugin": {"enabled": True}},
    )
    monkeypatch.setattr(
        host,
        "get_raw_config",
        lambda: "[plugin]\nenabled = true\n",
    )
    monkeypatch.setattr(
        host,
        "get_raw_config_with_meta",
        lambda: {
            "config": "[plugin]\nenabled = true\n",
            "exists": True,
            "using_default": False,
        },
    )

    schema_response = client.get("/api/webui/memory/config/schema")
    config_response = client.get("/api/webui/memory/config")
    raw_response = client.get("/api/webui/memory/config/raw")
    expected_path = memory_router_module.Path("/tmp/config/bot_config.toml").as_posix()

    assert schema_response.status_code == 200
    assert memory_router_module.Path(schema_response.json()["path"]).as_posix() == expected_path
    assert schema_response.json()["schema"]["layout"]["type"] == "tabs"

    assert config_response.status_code == 200
    assert config_response.json()["success"] is True
    assert config_response.json()["config"] == {"plugin": {"enabled": True}}
    assert memory_router_module.Path(config_response.json()["path"]).as_posix() == expected_path

    assert raw_response.status_code == 200
    assert raw_response.json()["success"] is True
    assert raw_response.json()["config"] == "[plugin]\nenabled = true\n"
    assert memory_router_module.Path(raw_response.json()["path"]).as_posix() == expected_path
|
||||
|
||||
def test_memory_config_raw_returns_default_template_when_file_missing(client: TestClient, monkeypatch):
    """When the config file is absent, the default template is served."""
    monkeypatch.setattr(
        memory_router_module.a_memorix_host_service,
        "get_config_path",
        lambda: memory_router_module.Path("/tmp/config/bot_config.toml"),
    )
    monkeypatch.setattr(
        memory_router_module.a_memorix_host_service,
        "get_raw_config_with_meta",
        lambda: {
            "config": "[plugin]\nenabled = true\n",
            "exists": False,
            "using_default": True,
        },
    )

    response = client.get("/api/webui/memory/config/raw")

    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["config"] == "[plugin]\nenabled = true\n"
    assert payload["exists"] is False
    assert payload["using_default"] is True
|
||||
|
||||
def test_memory_config_update_routes(client: TestClient, monkeypatch):
    """PUT on config and raw-config forwards the payload to the host service."""

    async def fake_update_config(config):
        assert config == {"plugin": {"enabled": False}}
        return {"success": True, "config_path": "config/bot_config.toml"}

    async def fake_update_raw(raw_config):
        assert raw_config == "[plugin]\nenabled = false\n"
        return {"success": True, "config_path": "config/bot_config.toml"}

    monkeypatch.setattr(memory_router_module.a_memorix_host_service, "update_config", fake_update_config)
    monkeypatch.setattr(memory_router_module.a_memorix_host_service, "update_raw_config", fake_update_raw)

    config_response = client.put("/api/webui/memory/config", json={"config": {"plugin": {"enabled": False}}})
    raw_response = client.put("/api/webui/memory/config/raw", json={"config": "[plugin]\nenabled = false\n"})

    assert config_response.status_code == 200
    assert config_response.json() == {"success": True, "config_path": "config/bot_config.toml"}

    assert raw_response.status_code == 200
    assert raw_response.json() == {"success": True, "config_path": "config/bot_config.toml"}


def test_memory_config_raw_rejects_invalid_toml(client: TestClient):
    """Malformed TOML is rejected with a 400 before reaching the service."""
    response = client.put("/api/webui/memory/config/raw", json={"config": "[plugin\nenabled = true"})

    assert response.status_code == 400
    assert "TOML 格式错误" in response.json()["detail"]
|
||||
|
||||
def test_recycle_bin_route(client: TestClient, monkeypatch):
    """/api/memory/recycle_bin forwards the limit to the service."""

    async def fake_get_recycle_bin(*, limit: int):
        return {"success": True, "items": [{"hash": "deadbeef"}], "count": 1, "limit": limit}

    monkeypatch.setattr(memory_router_module.memory_service, "get_recycle_bin", fake_get_recycle_bin)

    response = client.get("/api/memory/recycle_bin", params={"limit": 10})

    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["count"] == 1
    assert payload["limit"] == 10


def test_import_guide_route(client: TestClient, monkeypatch):
    """The import guide route merges guide content with import settings."""

    async def fake_import_admin(*, action: str, **kwargs):
        assert kwargs == {}
        if action == "get_guide":
            return {"success": True}
        if action == "get_settings":
            return {"success": True, "settings": {"path_aliases": {"raw": "/tmp/raw"}}}
        raise AssertionError(action)

    monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)

    response = client.get("/api/webui/memory/import/guide")

    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["source"] == "local"
    assert "长期记忆导入说明" in payload["content"]
|
||||
|
||||
def test_import_upload_route(client: TestClient, monkeypatch, tmp_path):
    """Uploads are staged on disk, handed to the service, then cleaned up."""
    monkeypatch.setattr(memory_router_module, "STAGING_ROOT", tmp_path)

    async def fake_import_admin(*, action: str, **kwargs):
        assert action == "create_upload"
        staged_files = kwargs["staged_files"]
        assert len(staged_files) == 1
        assert staged_files[0]["filename"] == "demo.txt"
        # The staged file must exist while the service processes it.
        assert memory_router_module.Path(staged_files[0]["staged_path"]).exists()
        return {"success": True, "task_id": "task-1"}

    monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)

    response = client.post(
        "/api/import/upload",
        data={"payload_json": "{\"source\": \"upload\"}"},
        files=[("files", ("demo.txt", b"hello world", "text/plain"))],
    )

    assert response.status_code == 200
    assert response.json() == {"success": True, "task_id": "task-1"}
    # The staging directory is emptied once the task has been created.
    assert list(tmp_path.iterdir()) == []
|
||||
|
||||
def test_v5_status_route(client: TestClient, monkeypatch):
    """/memory/v5/status forwards the target and returns the counters."""

    async def fake_v5_admin(*, action: str, **kwargs):
        assert action == "status"
        assert kwargs["target"] == "mai"
        return {"success": True, "active_count": 1, "inactive_count": 2, "deleted_count": 3}

    monkeypatch.setattr(memory_router_module.memory_service, "v5_admin", fake_v5_admin)

    response = client.get("/api/webui/memory/v5/status", params={"target": "mai"})

    assert response.status_code == 200
    assert response.json()["success"] is True
    assert response.json()["deleted_count"] == 3


def test_delete_preview_route(client: TestClient, monkeypatch):
    """Delete preview passes mode/selector through and stays a dry run."""

    async def fake_delete_admin(*, action: str, **kwargs):
        assert action == "preview"
        assert kwargs["mode"] == "paragraph"
        assert kwargs["selector"] == {"query": "demo"}
        return {"success": True, "counts": {"paragraphs": 1}, "dry_run": True}

    monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)

    response = client.post(
        "/api/webui/memory/delete/preview",
        json={"mode": "paragraph", "selector": {"query": "demo"}},
    )

    assert response.status_code == 200
    assert response.json() == {"success": True, "counts": {"paragraphs": 1}, "dry_run": True}
|
||||
|
||||
def test_delete_preview_route_supports_mixed_mode(client: TestClient, monkeypatch):
    """Mixed-mode preview forwards every selector category at once."""

    async def fake_delete_admin(*, action: str, **kwargs):
        assert action == "preview"
        assert kwargs["mode"] == "mixed"
        assert kwargs["selector"] == {
            "entity_hashes": ["entity-1"],
            "paragraph_hashes": ["p-1"],
            "relation_hashes": ["rel-1"],
            "sources": ["demo"],
        }
        return {"success": True, "mode": "mixed", "counts": {"entities": 1, "paragraphs": 1, "relations": 1, "sources": 1}}

    monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)

    response = client.post(
        "/api/webui/memory/delete/preview",
        json={
            "mode": "mixed",
            "selector": {
                "entity_hashes": ["entity-1"],
                "paragraph_hashes": ["p-1"],
                "relation_hashes": ["rel-1"],
                "sources": ["demo"],
            },
        },
    )

    assert response.status_code == 200
    assert response.json()["mode"] == "mixed"
    assert response.json()["counts"]["entities"] == 1


def test_delete_execute_route_supports_mixed_mode(client: TestClient, monkeypatch):
    """Mixed-mode execute forwards selector, reason and requester."""

    async def fake_delete_admin(*, action: str, **kwargs):
        assert action == "execute"
        assert kwargs["mode"] == "mixed"
        assert kwargs["selector"] == {
            "entity_hashes": ["entity-1"],
            "paragraph_hashes": ["p-1"],
            "relation_hashes": ["rel-1"],
            "sources": ["demo"],
        }
        assert kwargs["reason"] == "knowledge_graph_delete_entity"
        assert kwargs["requested_by"] == "knowledge_graph"
        return {
            "success": True,
            "mode": "mixed",
            "operation_id": "op-mixed-1",
            "deleted_count": 4,
            "deleted_entity_count": 1,
            "deleted_relation_count": 1,
            "deleted_paragraph_count": 1,
            "deleted_source_count": 1,
            "counts": {"entities": 1, "paragraphs": 1, "relations": 1, "sources": 1},
        }

    monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)

    response = client.post(
        "/api/webui/memory/delete/execute",
        json={
            "mode": "mixed",
            "selector": {
                "entity_hashes": ["entity-1"],
                "paragraph_hashes": ["p-1"],
                "relation_hashes": ["rel-1"],
                "sources": ["demo"],
            },
            "reason": "knowledge_graph_delete_entity",
            "requested_by": "knowledge_graph",
        },
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["mode"] == "mixed"
    assert payload["operation_id"] == "op-mixed-1"
|
||||
|
||||
def test_episode_process_pending_route(client: TestClient, monkeypatch):
    """Posting to the process-pending route proxies limit/max_retry to the service."""
    seen = []

    async def fake_episode_admin(*, action: str, **kwargs):
        seen.append((action, kwargs))
        return {"success": True, "processed": 3}

    monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin)

    response = client.post("/api/webui/memory/episodes/process-pending", json={"limit": 7, "max_retry": 4})

    assert response.status_code == 200
    assert response.json() == {"success": True, "processed": 3}
    assert seen == [("process_pending", {"limit": 7, "max_retry": 4})]
|
||||
|
||||
|
||||
def test_import_list_route_includes_settings(client: TestClient, monkeypatch):
    """GET import/tasks merges the task list with the importer settings."""
    recorded = []

    async def fake_import_admin(*, action: str, **kwargs):
        recorded.append((action, kwargs))
        if action == "list":
            return {"success": True, "items": [{"task_id": "task-1"}]}
        if action == "get_settings":
            return {"success": True, "settings": {"path_aliases": {"lpmm": "/tmp/lpmm"}}}
        raise AssertionError(action)

    monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)

    response = client.get("/api/webui/memory/import/tasks", params={"limit": 9})

    assert response.status_code == 200
    payload = response.json()
    assert payload["items"] == [{"task_id": "task-1"}]
    assert payload["settings"] == {"path_aliases": {"lpmm": "/tmp/lpmm"}}
    # The route must fetch the list first, then the settings, in that order.
    assert recorded == [("list", {"limit": 9}), ("get_settings", {})]
|
||||
|
||||
|
||||
def test_tuning_profile_route_backfills_settings(client: TestClient, monkeypatch):
    """GET retrieval_tuning/profile combines the profile with tuning settings."""
    observed = []
    canned = {
        "get_profile": {"success": True, "profile": {"retrieval": {"top_k": 8}}},
        "get_settings": {"success": True, "settings": {"profiles": ["default"]}},
    }

    async def fake_tuning_admin(*, action: str, **kwargs):
        observed.append((action, kwargs))
        if action not in canned:
            raise AssertionError(action)
        return canned[action]

    monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin)

    response = client.get("/api/webui/memory/retrieval_tuning/profile")

    assert response.status_code == 200
    payload = response.json()
    assert payload["profile"] == {"retrieval": {"top_k": 8}}
    assert payload["settings"] == {"profiles": ["default"]}
    assert observed == [("get_profile", {}), ("get_settings", {})]
|
||||
|
||||
|
||||
def test_tuning_report_route_flattens_report_payload(client: TestClient, monkeypatch):
    """The report route lifts nested report fields into the top-level payload."""
    async def fake_tuning_admin(*, action: str, **kwargs):
        assert action == "get_report"
        assert kwargs == {"task_id": "task-1", "format": "json"}
        report = {"format": "json", "content": "{\"ok\": true}", "path": "/tmp/report.json"}
        return {"success": True, "report": report}

    monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin)

    response = client.get("/api/webui/memory/retrieval_tuning/tasks/task-1/report", params={"format": "json"})

    assert response.status_code == 200
    expected = {
        "success": True,
        "format": "json",
        "content": "{\"ok\": true}",
        "path": "/tmp/report.json",
        "error": "",
    }
    assert response.json() == expected
|
||||
|
||||
|
||||
def test_delete_execute_route(client: TestClient, monkeypatch):
    """A source-mode delete request is forwarded verbatim and its result returned."""
    request_body = {
        "mode": "source",
        "selector": {"source": "chat_summary:stream-1"},
        "reason": "cleanup",
        "requested_by": "tester",
    }

    async def fake_delete_admin(*, action: str, **kwargs):
        assert action == "execute"
        # Every request field must be passed through unchanged.
        for field in ("mode", "selector", "reason", "requested_by"):
            assert kwargs[field] == request_body[field]
        return {"success": True, "operation_id": "del-1"}

    monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)

    response = client.post("/api/webui/memory/delete/execute", json=request_body)

    assert response.status_code == 200
    assert response.json() == {"success": True, "operation_id": "del-1"}
|
||||
|
||||
|
||||
def test_sources_route(client: TestClient, monkeypatch):
    """GET /sources returns the service's item listing untouched."""
    async def fake_source_admin(*, action: str, **kwargs):
        assert (action, kwargs) == ("list", {})
        return {"success": True, "items": [{"source": "demo", "paragraph_count": 2}], "count": 1}

    monkeypatch.setattr(memory_router_module.memory_service, "source_admin", fake_source_admin)

    response = client.get("/api/webui/memory/sources")

    assert response.status_code == 200
    assert response.json()["items"] == [{"source": "demo", "paragraph_count": 2}]
|
||||
|
||||
|
||||
def test_delete_operation_routes(client: TestClient, monkeypatch):
    """Both the operation-list and operation-detail routes reach delete_admin."""
    async def fake_delete_admin(*, action: str, **kwargs):
        if action == "list_operations":
            assert kwargs == {"limit": 5, "mode": "paragraph"}
            return {"success": True, "items": [{"operation_id": "del-1"}], "count": 1}
        if action == "get_operation":
            assert kwargs == {"operation_id": "del-1"}
            return {"success": True, "operation": {"operation_id": "del-1", "mode": "paragraph"}}
        raise AssertionError(action)

    monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)

    list_response = client.get("/api/webui/memory/delete/operations", params={"limit": 5, "mode": "paragraph"})
    get_response = client.get("/api/webui/memory/delete/operations/del-1")

    assert (list_response.status_code, get_response.status_code) == (200, 200)
    assert list_response.json()["count"] == 1
    assert get_response.json()["operation"]["operation_id"] == "del-1"
|
||||
|
||||
|
||||
def test_feedback_correction_routes(client: TestClient, monkeypatch):
    """List, detail and rollback feedback-correction routes all hit feedback_admin."""
    expected_kwargs = {
        "list": {
            "limit": 7,
            "statuses": ["applied"],
            "rollback_statuses": ["none"],
            "query": "green",
        },
        "get": {"task_id": 11},
        "rollback": {"task_id": 11, "requested_by": "tester", "reason": "manual revert"},
    }
    canned = {
        "list": {"success": True, "items": [{"task_id": 11, "query_text": "what color"}], "count": 1},
        "get": {"success": True, "task": {"task_id": 11, "query_text": "what color", "action_logs": []}},
        "rollback": {"success": True, "result": {"restored_relation_hashes": ["rel-1"]}},
    }

    async def fake_feedback_admin(*, action: str, **kwargs):
        if action not in canned:
            raise AssertionError(action)
        assert kwargs == expected_kwargs[action]
        return canned[action]

    monkeypatch.setattr(memory_router_module.memory_service, "feedback_admin", fake_feedback_admin)

    list_response = client.get(
        "/api/webui/memory/feedback-corrections",
        params={"limit": 7, "status": "applied", "rollback_status": "none", "query": "green"},
    )
    get_response = client.get("/api/webui/memory/feedback-corrections/11")
    rollback_response = client.post(
        "/api/webui/memory/feedback-corrections/11/rollback",
        json={"requested_by": "tester", "reason": "manual revert"},
    )

    assert list_response.status_code == 200
    assert list_response.json()["items"][0]["task_id"] == 11
    assert get_response.status_code == 200
    assert get_response.json()["task"]["task_id"] == 11
    assert rollback_response.status_code == 200
    assert rollback_response.json()["result"]["restored_relation_hashes"] == ["rel-1"]
|
||||
533
pytests/webui/test_memory_routes_integration.py
Normal file
533
pytests/webui/test_memory_routes_integration.py
Normal file
@@ -0,0 +1,533 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from time import monotonic, sleep
|
||||
from typing import Any, Dict, Generator
|
||||
from uuid import uuid4
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
import tomlkit
|
||||
|
||||
from src.A_memorix import host_service as host_service_module
|
||||
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
||||
from src.A_memorix.core.utils import retrieval_tuning_manager as tuning_manager_module
|
||||
from src.webui.dependencies import require_auth
|
||||
from src.webui.routers import memory as memory_router_module
|
||||
|
||||
|
||||
# Budget for a single simple API request (seconds).
REQUEST_TIMEOUT_SECONDS = 30
# Import tasks parse files and embed their paragraphs, so they get a longer budget.
IMPORT_TIMEOUT_SECONDS = 120
# Retrieval-tuning tasks run multi-round evaluation loops and are slowest of all.
TUNING_TIMEOUT_SECONDS = 420

# Statuses after which an import task will no longer change.
IMPORT_TERMINAL_STATUSES = {"completed", "completed_with_errors", "failed", "cancelled"}
# Statuses after which a tuning task will no longer change.
TUNING_TERMINAL_STATUSES = {"completed", "failed", "cancelled"}
|
||||
|
||||
|
||||
class _FakeEmbeddingManager:
|
||||
def __init__(self, dimension: int = 64) -> None:
|
||||
self.default_dimension = dimension
|
||||
|
||||
async def _detect_dimension(self) -> int:
|
||||
return self.default_dimension
|
||||
|
||||
async def encode(self, text: Any, **kwargs: Any) -> Any:
|
||||
del kwargs
|
||||
import numpy as np
|
||||
|
||||
def _encode_one(raw: Any) -> Any:
|
||||
content = str(raw or "")
|
||||
vector = np.zeros(self.default_dimension, dtype=np.float32)
|
||||
for index, byte in enumerate(content.encode("utf-8")):
|
||||
vector[index % self.default_dimension] += float((byte % 17) + 1)
|
||||
norm = float(np.linalg.norm(vector))
|
||||
if norm > 0:
|
||||
vector /= norm
|
||||
return vector
|
||||
|
||||
if isinstance(text, (list, tuple)):
|
||||
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
|
||||
return _encode_one(text).astype(np.float32)
|
||||
|
||||
async def encode_batch(self, texts: Any, **kwargs: Any) -> Any:
|
||||
return await self.encode(texts, **kwargs)
|
||||
|
||||
|
||||
def _build_test_config(data_dir: Path) -> Dict[str, Any]:
|
||||
return {
|
||||
"storage": {
|
||||
"data_dir": str(data_dir),
|
||||
},
|
||||
"advanced": {
|
||||
"enable_auto_save": False,
|
||||
},
|
||||
"embedding": {
|
||||
"dimension": 64,
|
||||
"batch_size": 4,
|
||||
"max_concurrent": 1,
|
||||
"retry": {
|
||||
"max_attempts": 1,
|
||||
"min_wait_seconds": 0.1,
|
||||
"max_wait_seconds": 0.2,
|
||||
"backoff_multiplier": 1.0,
|
||||
},
|
||||
"fallback": {
|
||||
"enabled": True,
|
||||
"allow_metadata_only_write": True,
|
||||
"probe_interval_seconds": 30,
|
||||
},
|
||||
"paragraph_vector_backfill": {
|
||||
"enabled": False,
|
||||
"interval_seconds": 60,
|
||||
"batch_size": 32,
|
||||
"max_retry": 2,
|
||||
},
|
||||
},
|
||||
"retrieval": {
|
||||
"enable_parallel": False,
|
||||
"enable_ppr": False,
|
||||
"top_k_paragraphs": 20,
|
||||
"top_k_relations": 10,
|
||||
"top_k_final": 10,
|
||||
"alpha": 0.5,
|
||||
"search": {
|
||||
"smart_fallback": {
|
||||
"enabled": True,
|
||||
},
|
||||
},
|
||||
"sparse": {
|
||||
"enabled": True,
|
||||
"mode": "auto",
|
||||
"candidate_k": 80,
|
||||
"relation_candidate_k": 60,
|
||||
},
|
||||
"fusion": {
|
||||
"method": "weighted_rrf",
|
||||
"rrf_k": 60,
|
||||
"vector_weight": 0.7,
|
||||
"bm25_weight": 0.3,
|
||||
},
|
||||
},
|
||||
"threshold": {
|
||||
"percentile": 70.0,
|
||||
"min_results": 1,
|
||||
},
|
||||
"web": {
|
||||
"tuning": {
|
||||
"enabled": True,
|
||||
"poll_interval_ms": 300,
|
||||
"max_queue_size": 4,
|
||||
"default_objective": "balanced",
|
||||
"default_intensity": "quick",
|
||||
"default_sample_size": 4,
|
||||
"default_top_k_eval": 5,
|
||||
"eval_query_timeout_seconds": 1.0,
|
||||
"llm_retry": {
|
||||
"max_attempts": 1,
|
||||
"min_wait_seconds": 0.1,
|
||||
"max_wait_seconds": 0.2,
|
||||
"backoff_multiplier": 1.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _assert_response_ok(response: Any) -> Dict[str, Any]:
|
||||
assert response.status_code == 200, response.text
|
||||
payload = response.json()
|
||||
assert payload.get("success", True) is True, payload
|
||||
return payload
|
||||
|
||||
|
||||
def _wait_for_import_task_terminal(client: TestClient, task_id: str, *, timeout_seconds: float = IMPORT_TIMEOUT_SECONDS) -> Dict[str, Any]:
    """Poll the import-task detail route until the task reaches a terminal status."""
    stop_at = monotonic() + timeout_seconds
    latest: Dict[str, Any] = {}
    while monotonic() < stop_at:
        latest = _assert_response_ok(
            client.get(
                f"/api/webui/memory/import/tasks/{task_id}",
                params={"include_chunks": True},
            )
        )
        task = latest.get("task") or {}
        if str(task.get("status", "") or "") in IMPORT_TERMINAL_STATUSES:
            return task
        sleep(0.2)
    raise AssertionError(f"导入任务超时: task_id={task_id}, last_payload={latest}")
|
||||
|
||||
|
||||
def _wait_for_tuning_task_terminal(client: TestClient, task_id: str, *, timeout_seconds: float = TUNING_TIMEOUT_SECONDS) -> Dict[str, Any]:
    """Poll the tuning-task detail route until the task reaches a terminal status."""
    stop_at = monotonic() + timeout_seconds
    latest: Dict[str, Any] = {}
    while monotonic() < stop_at:
        latest = _assert_response_ok(
            client.get(
                f"/api/webui/memory/retrieval_tuning/tasks/{task_id}",
                params={"include_rounds": False},
            )
        )
        task = latest.get("task") or {}
        if str(task.get("status", "") or "") in TUNING_TERMINAL_STATUSES:
            return task
        sleep(0.3)
    raise AssertionError(f"调优任务超时: task_id={task_id}, last_payload={latest}")
|
||||
|
||||
|
||||
def _wait_for_query_hit(client: TestClient, query: str, *, timeout_seconds: float = 30.0) -> Dict[str, Any]:
    """Poll the aggregate-query route until it returns at least one hit."""
    stop_at = monotonic() + timeout_seconds
    latest: Dict[str, Any] = {}
    while monotonic() < stop_at:
        latest = _assert_response_ok(
            client.get(
                "/api/webui/memory/query/aggregate",
                params={"query": query, "limit": 20},
            )
        )
        hits = latest.get("hits") or []
        if isinstance(hits, list) and hits:
            return latest
        sleep(0.2)
    raise AssertionError(f"检索命中超时: query={query}, last_payload={latest}")
|
||||
|
||||
|
||||
def _get_source_item(client: TestClient, source_name: str) -> Dict[str, Any] | None:
    """Return the /sources entry matching *source_name*, or None if absent."""
    listing = _assert_response_ok(client.get("/api/webui/memory/sources"))
    for entry in listing.get("items") or []:
        if isinstance(entry, dict) and str(entry.get("source", "") or "") == source_name:
            return entry
    return None
|
||||
|
||||
|
||||
def _source_paragraph_count(item: Dict[str, Any] | None) -> int:
|
||||
payload = item or {}
|
||||
if "paragraph_count" in payload:
|
||||
return int(payload.get("paragraph_count", 0) or 0)
|
||||
return int(payload.get("count", 0) or 0)
|
||||
|
||||
|
||||
def _wait_for_source_paragraph_count(
    client: TestClient,
    source_name: str,
    *,
    min_count: int,
    timeout_seconds: float = 30.0,
) -> Dict[str, Any]:
    """Poll /sources until *source_name* reports at least *min_count* paragraphs."""
    stop_at = monotonic() + timeout_seconds
    seen: Dict[str, Any] = {}
    while monotonic() < stop_at:
        entry = _get_source_item(client, source_name)
        if _source_paragraph_count(entry) >= int(min_count):
            return entry or {}
        if entry:
            # Keep a snapshot of the last non-empty entry for the failure message.
            seen = dict(entry)
        sleep(0.2)
    raise AssertionError(
        f"等待来源段落计数超时: source={source_name}, min_count={min_count}, last_item={seen}"
    )
|
||||
|
||||
|
||||
def _create_multitype_upload_task(client: TestClient) -> str:
    """Create one import task from four uploads (txt, md and two structured JSON files).

    Returns the new task id; fails the test if the upload response carries none.
    """
    # Structured payload with pre-extracted entities/relations, so no LLM pass is needed.
    structured_json = {
        "paragraphs": [
            {
                "content": "Alice 携带地图前往火星港。",
                "source": "integration-upload-json",
                "entities": ["Alice", "地图", "火星港"],
                "relations": [
                    {"subject": "Alice", "predicate": "携带", "object": "地图"},
                    {"subject": "Alice", "predicate": "前往", "object": "火星港"},
                ],
            }
        ]
    }
    # A second structured file under a different source, with no relations.
    extra_json = {
        "paragraphs": [
            {
                "content": "Carol 记录了一条补充说明。",
                "source": "integration-upload-json-extra",
                "entities": ["Carol"],
                "relations": [],
            }
        ]
    }
    # Import options sent alongside the multipart upload; LLM extraction disabled.
    payload_json = json.dumps(
        {
            "input_mode": "text",
            "llm_enabled": False,
            "file_concurrency": 2,
            "chunk_concurrency": 2,
            "dedupe_policy": "none",
        },
        ensure_ascii=False,
    )
    # Four files covering plain text, markdown and two JSON payloads.
    files = [
        ("files", ("integration-notes.txt", "Alice 在测试环境记录了一条长期记忆。".encode("utf-8"), "text/plain")),
        ("files", ("integration-diary.md", "# 日志\nBob 与 Alice 讨论了导图。".encode("utf-8"), "text/markdown")),
        ("files", ("integration-structured.json", json.dumps(structured_json, ensure_ascii=False).encode("utf-8"), "application/json")),
        ("files", ("integration-extra.json", json.dumps(extra_json, ensure_ascii=False).encode("utf-8"), "application/json")),
    ]

    response = client.post(
        "/api/webui/memory/import/upload",
        data={"payload_json": payload_json},
        files=files,
    )
    payload = _assert_response_ok(response)
    task_id = str((payload.get("task") or {}).get("task_id") or "").strip()
    assert task_id, payload
    return task_id
|
||||
|
||||
|
||||
def _create_seed_paste_task(client: TestClient, *, source: str, unique_token: str) -> str:
    """Seed two paragraphs under *source* via the paste-import route.

    Each paragraph embeds *unique_token* so later retrieval tests can find the
    seeded data unambiguously. Returns the created task id.
    """
    seed_payload = {
        "paragraphs": [
            {
                "content": f"Alice 在火星港携带地图并记录了口令 {unique_token}。",
                "source": source,
                "entities": ["Alice", "火星港", "地图"],
                "relations": [
                    {"subject": "Alice", "predicate": "前往", "object": "火星港"},
                    {"subject": "Alice", "predicate": "携带", "object": "地图"},
                ],
            },
            {
                "content": f"Bob 在火星港遇见 Alice,并重复口令 {unique_token}。",
                "source": source,
                "entities": ["Bob", "Alice", "火星港"],
                "relations": [
                    {"subject": "Bob", "predicate": "遇见", "object": "Alice"},
                    {"subject": "Bob", "predicate": "位于", "object": "火星港"},
                ],
            },
        ]
    }
    # JSON input mode with LLM disabled: paragraphs are ingested exactly as given.
    response = client.post(
        "/api/webui/memory/import/paste",
        json={
            "name": "integration-seed.json",
            "input_mode": "json",
            "llm_enabled": False,
            "content": json.dumps(seed_payload, ensure_ascii=False),
            "dedupe_policy": "none",
        },
    )
    payload = _assert_response_ok(response)
    task_id = str((payload.get("task") or {}).get("task_id") or "").strip()
    assert task_id, payload
    return task_id
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def integration_state(tmp_path_factory: pytest.TempPathFactory) -> Generator[Dict[str, Any], None, None]:
    """Module-scoped harness: patched host service + seeded data behind a TestClient.

    Patches config/embedding/staging/artifacts onto temp directories, restarts
    the host service, runs the upload and seed imports once, and yields a dict
    with the client, both finished tasks, and the seed source/token.
    """
    tmp_root = tmp_path_factory.mktemp("memory_routes_integration")
    data_dir = (tmp_root / "data").resolve()
    staging_dir = (tmp_root / "upload_staging").resolve()
    artifacts_dir = (tmp_root / "artifacts").resolve()
    config_file = (tmp_root / "config" / "bot_config.toml").resolve()
    runtime_config = _build_test_config(data_dir)

    # Manual MonkeyPatch: fixture is module-scoped, so the function-scoped
    # `monkeypatch` fixture cannot be used here.
    patches = pytest.MonkeyPatch()
    patches.setattr(host_service_module.a_memorix_host_service, "_read_config", lambda: dict(runtime_config))
    patches.setattr(host_service_module.a_memorix_host_service, "get_config_path", lambda: config_file)
    patches.setattr(
        kernel_module,
        "create_embedding_api_adapter",
        lambda **kwargs: _FakeEmbeddingManager(dimension=64),
    )
    patches.setattr(memory_router_module, "STAGING_ROOT", staging_dir)
    patches.setattr(tuning_manager_module, "artifacts_root", lambda: artifacts_dir)

    # Stop any previously running service and drop its cached config so the
    # patched config above is picked up on next start.
    asyncio.run(host_service_module.a_memorix_host_service.stop())
    host_service_module.a_memorix_host_service._config_cache = None  # type: ignore[attr-defined]

    app = FastAPI()
    app.dependency_overrides[require_auth] = lambda: "ok"
    app.include_router(memory_router_module.router, prefix="/api/webui")
    app.include_router(memory_router_module.compat_router)

    # Random token/source keep this run's data distinguishable from any residue.
    unique_token = f"INTEG_TOKEN_{uuid4().hex[:12]}"
    source_name = f"integration-source-{uuid4().hex[:8]}"

    with TestClient(app) as client:
        upload_task_id = _create_multitype_upload_task(client)
        upload_task = _wait_for_import_task_terminal(client, upload_task_id)

        seed_task_id = _create_seed_paste_task(client, source=source_name, unique_token=unique_token)
        seed_task = _wait_for_import_task_terminal(client, seed_task_id)
        assert str(seed_task.get("status", "") or "") in {"completed", "completed_with_errors"}, seed_task

        # Make sure the seeded token is retrievable before any test runs.
        _wait_for_query_hit(client, unique_token, timeout_seconds=45.0)

        yield {
            "client": client,
            "upload_task": upload_task,
            "seed_task": seed_task,
            "source_name": source_name,
            "unique_token": unique_token,
        }

    # Teardown mirrors setup: stop the service, clear its cache, undo patches.
    asyncio.run(host_service_module.a_memorix_host_service.stop())
    host_service_module.a_memorix_host_service._config_cache = None  # type: ignore[attr-defined]
    patches.undo()
|
||||
|
||||
|
||||
def test_import_module_end_to_end_supports_multitype_upload(integration_state: Dict[str, Any]) -> None:
    """The multitype upload created by the fixture finished with all four files."""
    upload_task = integration_state["upload_task"]

    assert str(upload_task.get("status", "") or "") in {"completed", "completed_with_errors"}, upload_task
    files = upload_task.get("files") or []
    assert isinstance(files, list)
    assert len(files) >= 4

    file_names = {str(item.get("name", "") or "") for item in files if isinstance(item, dict)}
    expected = {
        "integration-notes.txt",
        "integration-diary.md",
        "integration-structured.json",
        "integration-extra.json",
    }
    assert expected <= file_names
|
||||
|
||||
|
||||
def test_retrieval_module_end_to_end_queries_seeded_data(integration_state: Dict[str, Any]) -> None:
    """Seeded paragraphs are retrievable via aggregate query and graph search."""
    client = integration_state["client"]
    unique_token = integration_state["unique_token"]

    aggregate_payload = _wait_for_query_hit(client, unique_token, timeout_seconds=45.0)
    hits = aggregate_payload.get("hits") or []
    contents = [str(item.get("content", "") or "") for item in hits if isinstance(item, dict)]
    assert unique_token in "\n".join(contents)

    graph_payload = _assert_response_ok(
        client.get(
            "/api/webui/memory/graph/search",
            params={"query": "Alice", "limit": 20},
        )
    )
    graph_items = graph_payload.get("items") or []
    assert isinstance(graph_items, list)
    entity_rows = [
        item
        for item in graph_items
        if isinstance(item, dict) and str(item.get("type", "") or "") == "entity"
    ]
    assert entity_rows, graph_items
|
||||
|
||||
|
||||
def test_tuning_module_end_to_end_create_and_apply_best(integration_state: Dict[str, Any]) -> None:
    """Create a quick tuning task, wait for completion, then apply its best profile."""
    client = integration_state["client"]

    # Minimal-size quick run; fixed seed keeps the run reproducible.
    create_payload = _assert_response_ok(
        client.post(
            "/api/webui/memory/retrieval_tuning/tasks",
            json={
                "objective": "balanced",
                "intensity": "quick",
                "rounds": 2,
                "sample_size": 4,
                "top_k_eval": 5,
                "llm_enabled": False,
                "eval_query_timeout_seconds": 1.0,
                "seed": 20260403,
            },
        )
    )
    task_id = str((create_payload.get("task") or {}).get("task_id") or "").strip()
    assert task_id, create_payload

    task = _wait_for_tuning_task_terminal(client, task_id)
    assert str(task.get("status", "") or "") == "completed", task

    # Applying the best round's profile must succeed and report an "applied" field.
    apply_payload = _assert_response_ok(
        client.post(
            f"/api/webui/memory/retrieval_tuning/tasks/{task_id}/apply-best",
        )
    )
    assert "applied" in apply_payload
|
||||
|
||||
|
||||
def test_delete_module_end_to_end_preview_execute_restore(integration_state: Dict[str, Any]) -> None:
    """Full delete lifecycle: preview, execute, verify gone, restore, verify back."""
    client = integration_state["client"]
    unique_token = integration_state["unique_token"]
    source_name = integration_state["source_name"]

    # Precondition: the seeded source is visible with at least one paragraph.
    before_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0)
    assert _source_paragraph_count(before_source_item) >= 1

    # Preview must report the paragraphs that an execute would remove.
    preview_payload = _assert_response_ok(
        client.post(
            "/api/webui/memory/delete/preview",
            json={
                "mode": "source",
                "selector": {"sources": [source_name]},
                "reason": "integration_delete_preview",
                "requested_by": "pytest_integration",
            },
        )
    )
    preview_counts = preview_payload.get("counts") or {}
    assert int(preview_counts.get("paragraphs", 0) or 0) >= 1, preview_payload

    # Execute the deletion and keep its operation id for the restore step.
    execute_payload = _assert_response_ok(
        client.post(
            "/api/webui/memory/delete/execute",
            json={
                "mode": "source",
                "selector": {"sources": [source_name]},
                "reason": "integration_delete_execute",
                "requested_by": "pytest_integration",
            },
        )
    )
    operation_id = str(execute_payload.get("operation_id", "") or "").strip()
    assert operation_id, execute_payload

    # After deletion the seeded token must no longer be retrievable.
    after_delete_payload = _assert_response_ok(
        client.get(
            "/api/webui/memory/query/aggregate",
            params={"query": unique_token, "limit": 20},
        )
    )
    after_delete_hits = after_delete_payload.get("hits") or []
    after_delete_text = "\n".join(
        str(item.get("content", "") or "")
        for item in after_delete_hits
        if isinstance(item, dict)
    )
    assert unique_token not in after_delete_text
    assert int(execute_payload.get("deleted_paragraph_count", 0) or 0) >= 1, execute_payload

    # Restore the deleted operation.
    _assert_response_ok(
        client.post(
            "/api/webui/memory/delete/restore",
            json={
                "operation_id": operation_id,
                "requested_by": "pytest_integration",
            },
        )
    )

    # The source's paragraphs must come back after the restore.
    restored_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0)
    assert _source_paragraph_count(restored_source_item) >= 1

    # The operation must be listed and its detail must show status "restored".
    operations_payload = _assert_response_ok(
        client.get(
            "/api/webui/memory/delete/operations",
            params={"limit": 20, "mode": "source"},
        )
    )
    operation_items = operations_payload.get("items") or []
    operation_ids = {
        str(item.get("operation_id", "") or "")
        for item in operation_items
        if isinstance(item, dict)
    }
    assert operation_id in operation_ids

    operation_detail_payload = _assert_response_ok(client.get(f"/api/webui/memory/delete/operations/{operation_id}"))
    detail_operation = operation_detail_payload.get("operation") or {}
    assert str(detail_operation.get("status", "") or "") == "restored"
|
||||
187
pytests/webui/test_model_routes.py
Normal file
187
pytests/webui/test_model_routes.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""模型路由测试
|
||||
|
||||
验证 Gemini 提供商连接测试会使用查询参数传递 API Key,
|
||||
并且不会回退到 OpenAI 兼容接口使用的 Bearer 认证方式。
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def load_model_routes(monkeypatch: pytest.MonkeyPatch):
    """Import the model router with stubbed config/auth modules.

    Stub modules are registered in ``sys.modules`` before the import so the
    router never triggers real configuration loading or authentication setup.
    """
    fake_config = ModuleType("src.config.config")
    setattr(fake_config, "CONFIG_DIR", ".")
    monkeypatch.setitem(sys.modules, "src.config.config", fake_config)

    fake_dependencies = ModuleType("src.webui.dependencies")

    async def require_auth():
        return "test-token"

    setattr(fake_dependencies, "require_auth", require_auth)
    monkeypatch.setitem(sys.modules, "src.webui.dependencies", fake_dependencies)

    # Drop any cached copy so the stubs above take effect on import.
    sys.modules.pop("src.webui.routers.model", None)
    return importlib.import_module("src.webui.routers.model")
|
||||
|
||||
|
||||
class FakeResponse:
    """Minimal stand-in for an HTTP response: only the status code matters."""

    def __init__(self, status_code: int) -> None:
        self.status_code = status_code
|
||||
|
||||
|
||||
def build_async_client_factory(
    responses: list[FakeResponse],
    calls: list[dict[str, Any]],
):
    """Build an ``httpx.AsyncClient`` replacement that records GET requests.

    Responses are consumed from *responses* in order across all client
    instances; each GET's url, headers and params are appended to *calls*.
    """
    pending = iter(responses)

    class FakeAsyncClient:
        def __init__(self, *args: Any, **kwargs: Any):
            self.args = args
            self.kwargs = kwargs

        async def __aenter__(self) -> "FakeAsyncClient":
            return self

        async def __aexit__(self, exc_type, exc, tb) -> bool:
            # Never suppress exceptions raised inside the ``async with`` body.
            return False

        async def get(
            self,
            url: str,
            headers: dict[str, Any] | None = None,
            params: dict[str, Any] | None = None,
        ) -> FakeResponse:
            calls.append(
                {
                    "url": url,
                    "headers": headers or {},
                    "params": params or {},
                }
            )
            return next(pending)

    return FakeAsyncClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_test_provider_connection_uses_query_api_key_for_gemini(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Gemini connection tests must pass the API key as a query parameter."""
    model_routes = load_model_routes(monkeypatch)
    calls: list[dict[str, Any]] = []
    monkeypatch.setattr(
        model_routes.httpx,
        "AsyncClient",
        build_async_client_factory(responses=[FakeResponse(200), FakeResponse(200)], calls=calls),
    )

    result = await model_routes.test_provider_connection(
        base_url="https://generativelanguage.googleapis.com/v1beta",
        api_key="valid-gemini-key",
        client_type="gemini",
    )

    assert result["network_ok"] is True
    assert result["api_key_valid"] is True
    assert len(calls) == 2

    network_call, validation_call = calls

    # The plain reachability probe carries no credentials at all.
    assert network_call["url"] == "https://generativelanguage.googleapis.com/v1beta"
    assert network_call["headers"] == {}
    assert network_call["params"] == {}

    # The validation call authenticates via ?key=... and must not send a Bearer header.
    assert validation_call["url"] == "https://generativelanguage.googleapis.com/v1beta/models"
    assert validation_call["params"] == {"key": "valid-gemini-key"}
    assert validation_call["headers"] == {"Content-Type": "application/json"}
    assert "Authorization" not in validation_call["headers"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_test_provider_connection_uses_bearer_auth_for_openai_compatible(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Non-Gemini providers keep using Bearer authentication."""
    model_routes = load_model_routes(monkeypatch)
    calls: list[dict[str, Any]] = []
    monkeypatch.setattr(
        model_routes.httpx,
        "AsyncClient",
        build_async_client_factory(responses=[FakeResponse(200), FakeResponse(200)], calls=calls),
    )

    result = await model_routes.test_provider_connection(
        base_url="https://example.com/v1",
        api_key="valid-openai-key",
        client_type="openai",
    )

    assert result["network_ok"] is True
    assert result["api_key_valid"] is True
    assert len(calls) == 2

    validation_call = calls[-1]

    assert validation_call["url"] == "https://example.com/v1/models"
    assert validation_call["params"] == {}
    assert validation_call["headers"]["Content-Type"] == "application/json"
    assert validation_call["headers"]["Authorization"] == "Bearer valid-openai-key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_test_provider_connection_by_name_forwards_provider_client_type(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
) -> None:
    """When testing a connection by provider name, the configured client_type is forwarded."""
    model_routes = load_model_routes(monkeypatch)
    # Write a one-provider model config and point the router's CONFIG_DIR at it.
    config_path = tmp_path / "model_config.toml"
    config_path.write_text(
        """
[[api_providers]]
name = "Gemini"
base_url = "https://generativelanguage.googleapis.com/v1beta"
api_key = "valid-gemini-key"
client_type = "gemini"
""".strip(),
        encoding="utf-8",
    )

    monkeypatch.setattr(model_routes, "CONFIG_DIR", str(tmp_path))

    captured_kwargs: dict[str, Any] = {}

    # Replace the low-level connection test so we only observe the forwarded kwargs.
    async def fake_test_provider_connection(**kwargs: Any) -> dict[str, Any]:
        captured_kwargs.update(kwargs)
        return {
            "network_ok": True,
            "api_key_valid": True,
            "latency_ms": 12.34,
            "error": None,
            "http_status": 200,
        }

    monkeypatch.setattr(model_routes, "test_provider_connection", fake_test_provider_connection)

    result = await model_routes.test_provider_connection_by_name(provider_name="Gemini")

    assert result["network_ok"] is True
    assert result["api_key_valid"] is True
    # The by-name route must pass through exactly the provider's configured fields.
    assert captured_kwargs == {
        "base_url": "https://generativelanguage.googleapis.com/v1beta",
        "api_key": "valid-gemini-key",
        "client_type": "gemini",
    }
|
||||
136
pytests/webui/test_plugin_management_routes.py
Normal file
136
pytests/webui/test_plugin_management_routes.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from src.webui.routers.plugin import management as management_module
|
||||
from src.webui.routers.plugin import support as support_module
|
||||
|
||||
|
||||
@pytest.fixture
def client(tmp_path, monkeypatch) -> TestClient:
    """Build a TestClient over a temp plugins dir containing one demo plugin."""
    plugins_dir = tmp_path / "plugins"
    demo_dir = plugins_dir / "demo_plugin"
    demo_dir.mkdir(parents=True, exist_ok=True)

    manifest = {
        "manifest_version": 2,
        "id": "test.demo",
        "name": "Demo Plugin",
        "version": "1.0.0",
        "description": "demo plugin",
    }
    (demo_dir / "_manifest.json").write_text(json.dumps(manifest), encoding="utf-8")

    # Bypass auth and point plugin discovery at the temp directory.
    monkeypatch.setattr(management_module, "require_plugin_token", lambda _: "ok")
    monkeypatch.setattr(support_module, "get_plugins_dir", lambda: plugins_dir)

    app = FastAPI()
    app.include_router(management_module.router, prefix="/api/webui/plugins")
    return TestClient(app)
|
||||
|
||||
def test_installed_plugins_only_scan_plugins_dir_and_exclude_a_memorix(client: TestClient):
    """The installed listing scans only the plugins dir, never built-ins."""
    response = client.get("/api/webui/plugins/installed")
    assert response.status_code == 200

    payload = response.json()
    assert payload["success"] is True

    plugin_ids = [entry["id"] for entry in payload["plugins"]]
    assert plugin_ids == ["test.demo"]
    assert "a-dawn.a-memorix" not in plugin_ids
    assert not any("/src/plugins/built_in/" in entry["path"] for entry in payload["plugins"])
|
||||
|
||||
def test_resolve_installed_plugin_path_falls_back_to_manifest_id(client: TestClient):
    """A manifest id resolves to the plugin directory even when names differ."""
    resolved = support_module.resolve_installed_plugin_path("test.demo")

    assert resolved is not None
    assert resolved.name == "demo_plugin"
|
||||
|
||||
def test_resolve_installed_plugin_path_accepts_manifest_id_case_mismatch(client: TestClient):
    """Manifest id lookup is case-insensitive."""
    resolved = support_module.resolve_installed_plugin_path("Test.Demo")

    assert resolved is not None
    assert resolved.name == "demo_plugin"
|
||||
|
||||
def test_install_plugin_preserves_manifest_declared_id(client: TestClient, monkeypatch):
    """Installation keeps the id declared inside the cloned manifest."""

    class FakeGitMirrorService:
        async def clone_repository(self, **kwargs):
            destination = kwargs["target_path"]
            destination.mkdir(parents=True, exist_ok=True)
            manifest = {
                "manifest_version": 2,
                "id": "author.declared",
                "name": "Declared Plugin",
                "version": "1.0.0",
                "author": {"name": "author"},
            }
            (destination / "_manifest.json").write_text(json.dumps(manifest), encoding="utf-8")
            return {"success": True}

    monkeypatch.setattr(management_module, "get_git_mirror_service", lambda: FakeGitMirrorService())

    response = client.post(
        "/api/webui/plugins/install",
        json={
            "plugin_id": "market.plugin",
            "repository_url": "https://github.com/author/declared",
            "branch": "main",
        },
    )
    assert response.status_code == 200

    # The manifest's own id wins over the marketplace-supplied plugin_id.
    plugin_path = support_module.resolve_installed_plugin_path("author.declared")
    assert plugin_path is not None
    manifest_on_disk = json.loads((plugin_path / "_manifest.json").read_text(encoding="utf-8"))
    assert manifest_on_disk["id"] == "author.declared"
|
||||
|
||||
def test_install_plugin_backfills_missing_manifest_id(client: TestClient, monkeypatch):
    """When the cloned manifest lacks an id, install backfills the requested id."""

    class FakeGitMirrorService:
        async def clone_repository(self, **kwargs):
            destination = kwargs["target_path"]
            destination.mkdir(parents=True, exist_ok=True)
            # Legacy manifest: no "id" field at all.
            manifest = {
                "manifest_version": 2,
                "name": "Legacy Plugin",
                "version": "1.0.0",
                "author": {"name": "author"},
            }
            (destination / "_manifest.json").write_text(json.dumps(manifest), encoding="utf-8")
            return {"success": True}

    monkeypatch.setattr(management_module, "get_git_mirror_service", lambda: FakeGitMirrorService())

    response = client.post(
        "/api/webui/plugins/install",
        json={
            "plugin_id": "market.legacy",
            "repository_url": "https://github.com/author/legacy",
            "branch": "main",
        },
    )
    assert response.status_code == 200

    plugin_path = support_module.resolve_installed_plugin_path("market.legacy")
    assert plugin_path is not None
    manifest_on_disk = json.loads((plugin_path / "_manifest.json").read_text(encoding="utf-8"))
    assert manifest_on_disk["id"] == "market.legacy"
||||
332
pytests/webui/test_statistics_service.py
Normal file
332
pytests/webui/test_statistics_service.py
Normal file
@@ -0,0 +1,332 @@
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Iterator
|
||||
|
||||
import pytest
|
||||
|
||||
from src.services import statistics_service
|
||||
from src.webui.schemas.statistics import DashboardData, StatisticsSummary, TimeSeriesData
|
||||
|
||||
|
||||
class _Result:
|
||||
def __init__(self, *, first_value: Any = None, all_values: list[Any] | None = None) -> None:
|
||||
self._first_value = first_value
|
||||
self._all_values = all_values or []
|
||||
|
||||
def first(self) -> Any:
|
||||
return self._first_value
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
return self._all_values
|
||||
|
||||
|
||||
class _Session:
    """Fake DB session whose ``exec`` serves pre-queued results in FIFO order."""

    def __init__(self, results: list[_Result]) -> None:
        # Keep the caller's list object so a result queue can be shared
        # across several sessions and consumed in order.
        self._queue = results

    def exec(self, statement: Any) -> _Result:
        """Return the next queued result; the statement itself is ignored."""
        del statement
        return self._queue.pop(0)
|
||||
|
||||
class _MemoryStore:
|
||||
def __init__(self) -> None:
|
||||
self.store: dict[str, Any] = {}
|
||||
|
||||
def __getitem__(self, item: str) -> Any:
|
||||
return self.store.get(item)
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
self.store[key] = value
|
||||
|
||||
|
||||
def _patch_session_results(monkeypatch: pytest.MonkeyPatch, results: list[_Result]) -> list[bool]:
    """Patch get_db_session so each opened session serves exactly one result.

    Returns the list that records every session's ``auto_commit`` flag so
    tests can assert read-only database access.
    """
    observed_auto_commit: list[bool] = []

    @contextmanager
    def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_Session]:
        observed_auto_commit.append(auto_commit)
        # Each session consumes exactly one result from the shared queue.
        yield _Session([results.pop(0)])

    monkeypatch.setattr(statistics_service, "get_db_session", _fake_get_db_session)
    return observed_auto_commit
|
||||
|
||||
def _patch_session_result_group(monkeypatch: pytest.MonkeyPatch, results: list[_Result]) -> list[bool]:
    """Patch get_db_session so every session shares one result queue.

    Unlike _patch_session_results, a single session may serve several
    queries, popping from the same list in order. Returns the recorded
    ``auto_commit`` flags.
    """
    observed_auto_commit: list[bool] = []

    @contextmanager
    def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_Session]:
        observed_auto_commit.append(auto_commit)
        yield _Session(results)

    monkeypatch.setattr(statistics_service, "get_db_session", _fake_get_db_session)
    return observed_auto_commit
|
||||
|
||||
def _build_dashboard_data(total_requests: int = 1) -> DashboardData:
    """Build a minimal DashboardData carrying only a request total."""
    summary = StatisticsSummary(total_requests=total_requests)
    return DashboardData(
        summary=summary,
        model_stats=[],
        hourly_data=[],
        daily_data=[],
        recent_activity=[],
    )
|
||||
|
||||
def _build_dashboard_data_with_time_series() -> DashboardData:
    """Build DashboardData whose series mix zero and non-zero buckets."""

    def _point(timestamp: str, requests: int, cost: float, tokens: int) -> TimeSeriesData:
        return TimeSeriesData(timestamp=timestamp, requests=requests, cost=cost, tokens=tokens)

    hourly = [
        _point("2026-05-06T10:00:00", 0, 0.0, 0),
        _point("2026-05-06T11:00:00", 2, 0.5, 50),
        _point("2026-05-06T12:00:00", 0, 0.0, 0),
    ]
    daily = [
        _point("2026-05-05T00:00:00", 0, 0.0, 0),
        _point("2026-05-06T00:00:00", 3, 0.7, 70),
    ]
    return DashboardData(
        summary=StatisticsSummary(total_requests=1),
        model_stats=[],
        hourly_data=hourly,
        daily_data=daily,
        recent_activity=[],
    )
|
||||
|
||||
def test_shared_fetch_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None:
    """All shared fetch_* helpers open read-only sessions (auto_commit=False).

    One result is queued per helper call and consumed strictly in call order:
    online time, model usage, messages, tool records.
    """
    now = datetime(2026, 5, 6, 12, 0, 0)
    online_record = SimpleNamespace(start_timestamp=now - timedelta(minutes=5), end_timestamp=now)
    usage_record = SimpleNamespace(
        timestamp=now,
        request_type="chat.reply",
        model_api_provider_name="provider",
        model_assign_name="chat-main",
        model_name="gpt-a",
        prompt_tokens=10,
        completion_tokens=5,
        cost=0.01,
        time_cost=1.2,
    )
    message_record = SimpleNamespace(timestamp=now, message_id="msg-1")
    tool_record = SimpleNamespace(timestamp=now, tool_name="reply")
    # Queue order must match the fetch calls below exactly.
    auto_commit_calls = _patch_session_results(
        monkeypatch,
        [
            _Result(all_values=[online_record]),
            _Result(all_values=[usage_record]),
            _Result(all_values=[message_record]),
            _Result(all_values=[tool_record]),
        ],
    )

    online_ranges = statistics_service.fetch_online_time_since(now - timedelta(hours=1))
    usage_records = statistics_service.fetch_model_usage_since(now - timedelta(hours=1))
    messages = statistics_service.fetch_messages_since(now - timedelta(hours=1))
    tool_records = statistics_service.fetch_tool_records_since(now - timedelta(hours=1))

    # Online ranges come back as (start, end) tuples.
    assert online_ranges == [(online_record.start_timestamp, online_record.end_timestamp)]
    # Usage rows are converted into plain dicts.
    assert usage_records == [
        {
            "timestamp": now,
            "request_type": "chat.reply",
            "model_api_provider_name": "provider",
            "model_assign_name": "chat-main",
            "model_name": "gpt-a",
            "prompt_tokens": 10,
            "completion_tokens": 5,
            "cost": 0.01,
            "time_cost": 1.2,
        }
    ]
    assert messages == [message_record]
    assert tool_records == [tool_record]
    # Every one of the four sessions was opened with auto_commit=False.
    assert auto_commit_calls == [False, False, False, False]
|
||||
|
||||
def test_get_earliest_statistics_time_uses_min_valid_timestamp(monkeypatch: pytest.MonkeyPatch) -> None:
    """The earliest statistics time is the minimum non-None first() value."""
    fallback_time = datetime(2026, 5, 6, 12, 0, 0)
    earliest_time = datetime(2026, 5, 1, 8, 30, 0)
    first_values = [
        datetime(2026, 5, 3, 9, 0, 0),
        earliest_time,
        None,  # a table with no rows must not break the minimum
        datetime(2026, 5, 2, 9, 0, 0),
    ]
    auto_commit_calls = _patch_session_result_group(
        monkeypatch,
        [_Result(first_value=value) for value in first_values],
    )

    result = statistics_service.get_earliest_statistics_time(fallback_time)

    assert result == earliest_time
    assert auto_commit_calls == [False]
|
||||
|
||||
def test_get_earliest_statistics_time_falls_back_when_query_fails(monkeypatch: pytest.MonkeyPatch) -> None:
    """A database failure must make the service return the fallback time."""
    fallback_time = datetime(2026, 5, 6, 12, 0, 0)

    @contextmanager
    def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_Session]:
        del auto_commit
        raise RuntimeError("database unavailable")
        # Unreachable on purpose: the yield keeps this a generator function
        # so @contextmanager accepts it; entering the context then raises.
        yield _Session([])

    monkeypatch.setattr(statistics_service, "get_db_session", _fake_get_db_session)

    assert statistics_service.get_earliest_statistics_time(fallback_time) == fallback_time
|
||||
|
||||
def test_dashboard_statistics_cache_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None:
    """Data stored in the dashboard cache can be read back intact."""
    monkeypatch.setattr(statistics_service, "local_storage", _MemoryStore())

    statistics_service.store_dashboard_statistics_cache(
        {24: _build_dashboard_data(total_requests=7)},
        generated_at=datetime.now(),
    )
    cached_data = statistics_service.get_cached_dashboard_statistics(24)

    assert cached_data is not None
    assert cached_data.summary.total_requests == 7
|
||||
|
||||
def test_dashboard_statistics_cache_stores_sparse_time_series(monkeypatch: pytest.MonkeyPatch) -> None:
    """The cache drops all-zero buckets on write and restores them on read.

    Storing must serialize only non-zero hourly/daily points (``sparse``
    flag set); reading must densify the series back to the full bucket range.
    """
    memory_store = _MemoryStore()
    generated_at = datetime(2026, 5, 6, 12, 0, 0)
    dashboard_data = _build_dashboard_data_with_time_series()
    monkeypatch.setattr(statistics_service, "local_storage", memory_store)

    statistics_service.store_dashboard_statistics_cache({2: dashboard_data}, generated_at=generated_at)

    # Inspect the raw serialized entry: zero-valued buckets must be gone.
    raw_cache = memory_store[statistics_service.DASHBOARD_STATISTICS_CACHE_KEY]
    raw_entry = raw_cache["entries"]["2"]
    assert raw_entry["sparse"] is True
    assert raw_entry["hourly_data"] == [
        {"timestamp": "2026-05-06T11:00:00", "requests": 2, "cost": 0.5, "tokens": 50}
    ]
    assert raw_entry["daily_data"] == [
        {"timestamp": "2026-05-06T00:00:00", "requests": 3, "cost": 0.7, "tokens": 70}
    ]

    # Reading back densifies: the zero buckets reappear around the stored one.
    cached_data = statistics_service.get_cached_dashboard_statistics(2, max_age_seconds=10**9)
    assert cached_data is not None
    assert [item.timestamp for item in cached_data.hourly_data] == [
        "2026-05-06T10:00:00",
        "2026-05-06T11:00:00",
        "2026-05-06T12:00:00",
    ]
    assert cached_data.hourly_data[0].requests == 0
    assert cached_data.hourly_data[1].requests == 2
    assert cached_data.hourly_data[2].requests == 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_dashboard_statistics_prefers_cache(monkeypatch: pytest.MonkeyPatch) -> None:
    """When a fresh cache entry exists, the compute path must not run."""
    monkeypatch.setattr(statistics_service, "local_storage", _MemoryStore())
    statistics_service.store_dashboard_statistics_cache(
        {24: _build_dashboard_data(total_requests=9)},
        generated_at=datetime.now(),
    )

    async def _fail_compute_dashboard_statistics(hours: int = 24) -> DashboardData:
        del hours
        raise AssertionError("cache should be used")

    monkeypatch.setattr(statistics_service, "compute_dashboard_statistics", _fail_compute_dashboard_statistics)

    result = await statistics_service.get_dashboard_statistics(24)

    assert result.summary.total_requests == 9
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_dashboard_statistics_returns_empty_when_cache_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """With no cache entry the API returns empty stats instead of computing."""
    monkeypatch.setattr(statistics_service, "local_storage", _MemoryStore())

    async def _fail_compute_dashboard_statistics(hours: int = 24) -> DashboardData:
        del hours
        raise AssertionError("dashboard API should not compute fallback data")

    monkeypatch.setattr(statistics_service, "compute_dashboard_statistics", _fail_compute_dashboard_statistics)

    result = await statistics_service.get_dashboard_statistics(24)

    assert result.summary.total_requests == 0
    assert result.model_stats == []
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_summary_statistics_aggregates_database_and_message_counts(monkeypatch: pytest.MonkeyPatch) -> None:
    """Summary stats combine the usage aggregate, online time and message counts."""
    start_time = datetime(2026, 5, 6, 10, 0, 0)
    end_time = datetime(2026, 5, 6, 12, 0, 0)
    # Both online ranges extend past the window; only the overlap should count.
    online_records = [
        SimpleNamespace(
            start_timestamp=start_time - timedelta(minutes=30),
            end_timestamp=start_time + timedelta(minutes=30),
        ),
        SimpleNamespace(
            start_timestamp=start_time + timedelta(hours=1),
            end_timestamp=end_time + timedelta(minutes=30),
        ),
    ]
    # Query order: first the (requests, cost, tokens, avg_time) aggregate row,
    # then the online-time records.
    auto_commit_calls = _patch_session_results(
        monkeypatch,
        [
            _Result(first_value=(3, 1.5, 900, 2.5)),
            _Result(all_values=online_records),
        ],
    )

    def _fake_count_messages(**kwargs: Any) -> int:
        # 5 total messages; 2 of them when the has_reply_to filter is set.
        return 5 if kwargs.get("has_reply_to") is None else 2

    monkeypatch.setattr(statistics_service, "count_messages", _fake_count_messages)

    summary = await statistics_service.get_summary_statistics(start_time, end_time)

    assert summary.total_requests == 3
    assert summary.total_cost == 1.5
    assert summary.total_tokens == 900
    assert summary.avg_response_time == 2.5
    # Window-clipped overlap: 30 min from the first range + 60 min from the
    # second = 5400 s.
    assert summary.online_time == 5400
    assert summary.total_messages == 5
    assert summary.total_replies == 2
    # Per-hour rates are consistent with dividing by the 1.5 h of online time:
    # 1.5 / 1.5 and 900 / 1.5.
    assert summary.cost_per_hour == pytest.approx(1.0)
    assert summary.tokens_per_hour == pytest.approx(600.0)
    assert auto_commit_calls == [False, False]
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_model_statistics_groups_by_display_model_name(monkeypatch: pytest.MonkeyPatch) -> None:
    """Usage rows group by assign name, falling back to the raw model name."""
    now = datetime(2026, 5, 6, 12, 0, 0)
    rows = [
        ("chat-main", "gpt-a", 0.4, 100, 2.0),
        ("chat-main", "gpt-a", 0.6, 200, 4.0),
        (None, "gpt-b", 0.2, 50, 0.0),
    ]
    records = [
        SimpleNamespace(
            model_assign_name=assign_name,
            model_name=model_name,
            cost=cost,
            total_tokens=total_tokens,
            time_cost=time_cost,
        )
        for assign_name, model_name, cost, total_tokens, time_cost in rows
    ]
    _patch_session_results(monkeypatch, [_Result(all_values=records)])

    stats = await statistics_service.get_model_statistics(now - timedelta(hours=24))

    assert [item.model_name for item in stats] == ["chat-main", "gpt-b"]
    assert stats[0].request_count == 2
    assert stats[0].total_cost == pytest.approx(1.0)
    assert stats[0].total_tokens == 300
    assert stats[0].avg_response_time == pytest.approx(3.0)
    assert stats[1].avg_response_time == 0.0
||||
13
pytests/webui/test_system_routes.py
Normal file
13
pytests/webui/test_system_routes.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from src.webui.routers import system
|
||||
|
||||
|
||||
def test_is_newer_version_detects_patch_update() -> None:
    """A higher patch component counts as a newer version."""
    is_newer = system._is_newer_version("1.0.7", "1.0.6")
    assert is_newer is True
|
||||
|
||||
def test_is_newer_version_ignores_same_version_with_shorter_parts() -> None:
    """Versions "1.0" and "1.0.0" compare equal, so no update is flagged."""
    is_newer = system._is_newer_version("1.0.0", "1.0")
    assert is_newer is False
|
||||
|
||||
def test_is_newer_version_handles_unknown_current_version() -> None:
    """An unparsable current version never reports an available update."""
    is_newer = system._is_newer_version("1.0.7", "unknown")
    assert is_newer is False
|
||||
Reference in New Issue
Block a user