chore: import deployable mai-bot source tree

This commit is contained in:
2026-05-11 00:51:12 +00:00
parent 4813699b3e
commit 7a54015f94
1009 changed files with 312999 additions and 16 deletions

View File

161
pytests/webui/test_app.py Normal file
View File

@@ -0,0 +1,161 @@
from pathlib import Path
from unittest.mock import patch
import pytest
from src.webui import app as webui_app
def test_ensure_static_path_ready_uses_existing_static_path(tmp_path) -> None:
static_path = tmp_path / "dist"
static_path.mkdir()
(static_path / "index.html").write_text("<html></html>", encoding="utf-8")
with patch.object(webui_app, "_resolve_static_path", return_value=static_path):
result = webui_app._ensure_static_path_ready()
assert result == static_path
def test_ensure_static_path_ready_logs_install_hint_when_static_assets_are_missing() -> None:
with (
patch.object(webui_app, "_resolve_static_path", return_value=None),
patch.object(webui_app.logger, "warning") as warning_mock,
):
result = webui_app._ensure_static_path_ready()
assert result is None
warning_mock.assert_any_call(webui_app.t("startup.webui_static_assets_unavailable"))
warning_mock.assert_any_call(
webui_app.t("startup.webui_dashboard_package_hint", command=webui_app._MANUAL_INSTALL_COMMAND)
)
def test_ensure_static_path_ready_logs_index_error_when_static_path_is_invalid(tmp_path) -> None:
static_path = tmp_path / "dist"
static_path.mkdir()
with (
patch.object(webui_app, "_resolve_static_path", return_value=static_path),
patch.object(webui_app.logger, "warning") as warning_mock,
):
result = webui_app._ensure_static_path_ready()
assert result is None
warning_mock.assert_any_call(
webui_app.t("startup.webui_index_missing", index_path=static_path / "index.html")
)
warning_mock.assert_any_call(
webui_app.t("startup.webui_dashboard_package_hint", command=webui_app._MANUAL_INSTALL_COMMAND)
)
def test_setup_static_files_does_not_duplicate_warning_when_static_path_is_unavailable() -> None:
app = webui_app.FastAPI()
with (
patch.object(webui_app, "_ensure_static_path_ready", return_value=None),
patch.object(webui_app.logger, "warning") as warning_mock,
):
webui_app._setup_static_files(app)
warning_mock.assert_not_called()
def test_resolve_static_path_prefers_installed_dashboard_package(monkeypatch, tmp_path) -> None:
package_dist = tmp_path / "site-packages" / "maibot_dashboard" / "dist"
package_dist.mkdir(parents=True)
class _DashboardModule:
@staticmethod
def get_dist_path() -> Path:
return package_dist
monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)
with patch.object(webui_app, "import_module", return_value=_DashboardModule()):
resolved_path = webui_app._resolve_static_path()
assert resolved_path == package_dist
def test_resolve_static_path_ignores_dashboard_dist_when_package_is_unavailable(monkeypatch, tmp_path) -> None:
dashboard_dist = tmp_path / "dashboard" / "dist"
dashboard_dist.mkdir(parents=True)
(dashboard_dist / "index.html").write_text("<html></html>", encoding="utf-8")
monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)
with patch.object(webui_app, "import_module", side_effect=ImportError):
resolved_path = webui_app._resolve_static_path()
assert resolved_path is None
def test_resolve_static_path_uses_package_even_when_dashboard_dist_exists(monkeypatch, tmp_path) -> None:
dashboard_dist = tmp_path / "dashboard" / "dist"
dashboard_dist.mkdir(parents=True)
package_dist = tmp_path / "site-packages" / "maibot_dashboard" / "dist"
package_dist.mkdir(parents=True)
class _DashboardModule:
@staticmethod
def get_dist_path() -> Path:
return package_dist
monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)
with patch.object(webui_app, "import_module", return_value=_DashboardModule()):
resolved_path = webui_app._resolve_static_path()
assert resolved_path == package_dist
def test_resolve_safe_static_file_path_allows_regular_static_file(tmp_path) -> None:
static_path = tmp_path / "dist"
asset_path = static_path / "assets" / "app.js"
asset_path.parent.mkdir(parents=True)
asset_path.write_text("console.log('ok')", encoding="utf-8")
resolved_path = webui_app._resolve_safe_static_file_path(static_path, "assets/app.js")
assert resolved_path == asset_path.resolve()
def test_resolve_safe_static_file_path_rejects_relative_path_traversal(tmp_path) -> None:
static_path = tmp_path / "dist"
static_path.mkdir()
resolved_path = webui_app._resolve_safe_static_file_path(static_path, "../secret.txt")
assert resolved_path is None
def test_resolve_safe_static_file_path_rejects_absolute_path_traversal(tmp_path) -> None:
static_path = tmp_path / "dist"
static_path.mkdir()
resolved_path = webui_app._resolve_safe_static_file_path(static_path, "/etc/passwd")
assert resolved_path is None
def test_resolve_safe_static_file_path_rejects_symlink_escape(tmp_path) -> None:
static_path = tmp_path / "dist"
static_path.mkdir()
outside_dir = tmp_path / "outside"
outside_dir.mkdir()
outside_file = outside_dir / "secret.txt"
outside_file.write_text("secret", encoding="utf-8")
link_path = static_path / "escape"
try:
link_path.symlink_to(outside_dir, target_is_directory=True)
except OSError as exc:
pytest.skip(f"symlink is not supported in this environment: {exc}")
resolved_path = webui_app._resolve_safe_static_file_path(static_path, "escape/secret.txt")
assert resolved_path is None

View File

@@ -0,0 +1,147 @@
from src.config.official_configs import ChatConfig, MessageReceiveConfig
from src.config.config import Config
from src.config.config_base import ConfigBase, Field
from src.webui.config_schema import ConfigSchemaGenerator
def test_field_docs_in_schema():
"""Test that field descriptions are correctly extracted from field_docs (docstrings)."""
schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value")
# Verify description field exists
assert "description" in talk_value
# Verify description contains expected Chinese text from the docstring
assert "聊天频率" in talk_value["description"]
def test_json_schema_extra_merged():
"""Test that json_schema_extra fields are correctly merged into output."""
schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value")
# Verify UI metadata fields from json_schema_extra exist
assert talk_value.get("x-widget") == "slider"
assert talk_value.get("x-icon") == "message-circle"
assert talk_value.get("step") == 0.1
def test_pydantic_constraints_mapped():
"""Test that Pydantic constraints (ge/le) are correctly mapped to minValue/maxValue."""
schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value")
# Verify constraints are mapped to frontend naming convention
assert "minValue" in talk_value
assert "maxValue" in talk_value
assert talk_value["minValue"] == 0 # From ge=0
assert talk_value["maxValue"] == 1 # From le=1
def test_nested_model_schema():
"""Test that nested models (ConfigBase fields) are correctly handled."""
schema = ConfigSchemaGenerator.generate_schema(Config)
# Verify nested structure exists
assert "nested" in schema
assert "chat" in schema["nested"]
# Verify nested chat schema is complete
chat_schema = schema["nested"]["chat"]
assert chat_schema["className"] == "ChatConfig"
assert "fields" in chat_schema
# Verify nested schema fields include description and metadata
talk_value = next(f for f in chat_schema["fields"] if f["name"] == "talk_value")
assert "description" in talk_value
assert talk_value.get("x-widget") == "slider"
assert talk_value.get("minValue") == 0
def test_field_without_extra_metadata():
"""Test that fields without json_schema_extra still generate valid schema."""
schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
inevitable_at_reply = next(f for f in schema["fields"] if f["name"] == "inevitable_at_reply")
# Verify basic fields are generated
assert "name" in inevitable_at_reply
assert inevitable_at_reply["name"] == "inevitable_at_reply"
assert "type" in inevitable_at_reply
assert inevitable_at_reply["type"] == "boolean"
assert "label" in inevitable_at_reply
assert "required" in inevitable_at_reply
# Verify no x-widget or x-icon from json_schema_extra (since field has none)
# These fields should only be present if explicitly defined in json_schema_extra
assert not inevitable_at_reply.get("x-widget")
assert not inevitable_at_reply.get("x-icon")
def test_all_top_level_sections_have_ui_metadata():
"""所有顶层配置节都必须声明 uiParent 或独立 Tab 的标签与图标。"""
schema = ConfigSchemaGenerator.generate_schema(Config)
for section_name, section_schema in schema["nested"].items():
has_parent = bool(section_schema.get("uiParent"))
has_host_meta = bool(section_schema.get("uiLabel")) and bool(section_schema.get("uiIcon"))
assert has_parent or has_host_meta, f"{section_name} 缺少 UI 元数据"
def test_maisaka_is_host_tab_and_mcp_is_attached_to_it():
"""MaiSaka 应作为独立 TabMCP 作为其子配置挂载。"""
schema = ConfigSchemaGenerator.generate_schema(Config)
maisaka_schema = schema["nested"]["maisaka"]
mcp_schema = schema["nested"]["mcp"]
assert maisaka_schema.get("uiParent") is None
assert maisaka_schema.get("uiLabel") == "MaiSaka"
assert maisaka_schema.get("uiIcon") == "message-circle"
assert mcp_schema.get("uiParent") == "maisaka"
def test_memory_query_config_fields_are_exposed():
"""query_memory 开关和默认条数应出现在记忆配置 schema 中。"""
schema = ConfigSchemaGenerator.generate_schema(Config)
memory_schema = schema["nested"]["memory"]
assert memory_schema.get("uiParent") == "emoji"
enable_field = next(field for field in memory_schema["fields"] if field["name"] == "enable_memory_query_tool")
limit_field = next(field for field in memory_schema["fields"] if field["name"] == "memory_query_default_limit")
assert enable_field["type"] == "boolean"
assert enable_field.get("x-widget") == "switch"
assert enable_field.get("x-icon") == "database"
assert limit_field["type"] == "integer"
assert limit_field.get("x-widget") == "input"
assert limit_field.get("x-icon") == "hash"
assert limit_field.get("minValue") == 1
assert limit_field.get("maxValue") == 20
def test_set_field_is_mapped_as_array():
"""set[str] 应映射为前端可识别的 array。"""
schema = ConfigSchemaGenerator.generate_schema(MessageReceiveConfig)
ban_words = next(field for field in schema["fields"] if field["name"] == "ban_words")
assert ban_words["type"] == "array"
assert ban_words["items"]["type"] == "string"
def test_advanced_fields_are_hidden_from_webui_schema():
"""advanced=True 的字段不应出现在 WebUI 配置 schema 中,未声明时默认展示。"""
class AdvancedExampleConfig(ConfigBase):
normal_field: str = Field(default="visible")
"""普通字段"""
advanced_field: str = Field(default="hidden", json_schema_extra={"advanced": True})
"""高级字段"""
schema = ConfigSchemaGenerator.generate_schema(AdvancedExampleConfig)
field_names = {field["name"] for field in schema["fields"]}
assert "normal_field" in field_names
assert "advanced_field" not in field_names

View File

@@ -0,0 +1,461 @@
"""表情包路由 API 测试
测试 src/webui/routers/emoji.py 中的核心 emoji 路由端点
使用内存 SQLite 数据库和 FastAPI TestClient
"""
from contextlib import contextmanager
from datetime import datetime
from typing import Generator
from unittest.mock import patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.common.database.database_model import Images, ImageType
from src.webui.core import TokenManager
from src.webui.routers.emoji import router
@pytest.fixture(scope="function")
def test_engine():
"""创建内存 SQLite 引擎用于测试"""
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
return engine
@pytest.fixture(scope="function")
def test_session(test_engine) -> Generator[Session, None, None]:
"""创建测试数据库会话"""
with Session(test_engine) as session:
yield session
@pytest.fixture(scope="function")
def test_app(test_session):
"""创建测试 FastAPI 应用并覆盖 get_db_session 依赖"""
app = FastAPI()
app.include_router(router)
# Create a context manager that yields the test session
@contextmanager
def override_get_db_session(auto_commit=True):
"""Override get_db_session to use test session"""
try:
yield test_session
if auto_commit:
test_session.commit()
except Exception:
test_session.rollback()
raise
with patch("src.webui.routers.emoji.get_db_session", override_get_db_session):
yield app
@pytest.fixture(scope="function")
def client(test_app):
"""创建 TestClient"""
return TestClient(test_app)
@pytest.fixture(scope="function")
def auth_token():
"""创建有效的认证 token"""
token_manager = TokenManager(secret_key="test-secret-key", token_expire_hours=24)
return token_manager.create_token()
@pytest.fixture(scope="function")
def sample_emojis(test_session) -> list[Images]:
"""插入测试用表情包数据"""
import hashlib
emojis = [
Images(
image_type=ImageType.EMOJI,
full_path="/data/emoji_registed/test1.png",
image_hash=hashlib.sha256(b"test1").hexdigest(),
description="测试表情包 1",
emotion="开心,快乐",
query_count=10,
is_registered=True,
is_banned=False,
record_time=datetime(2026, 1, 1, 10, 0, 0),
register_time=datetime(2026, 1, 1, 10, 0, 0),
last_used_time=datetime(2026, 1, 2, 10, 0, 0),
),
Images(
image_type=ImageType.EMOJI,
full_path="/data/emoji_registed/test2.gif",
image_hash=hashlib.sha256(b"test2").hexdigest(),
description="测试表情包 2",
emotion="难过",
query_count=5,
is_registered=False,
is_banned=False,
record_time=datetime(2026, 1, 3, 10, 0, 0),
register_time=None,
last_used_time=None,
),
Images(
image_type=ImageType.EMOJI,
full_path="/data/emoji_registed/test3.webp",
image_hash=hashlib.sha256(b"test3").hexdigest(),
description="测试表情包 3",
emotion="生气",
query_count=20,
is_registered=True,
is_banned=True,
record_time=datetime(2026, 1, 4, 10, 0, 0),
register_time=datetime(2026, 1, 4, 10, 0, 0),
last_used_time=datetime(2026, 1, 5, 10, 0, 0),
),
]
for emoji in emojis:
test_session.add(emoji)
test_session.commit()
for emoji in emojis:
test_session.refresh(emoji)
return emojis
@pytest.fixture(scope="function")
def mock_token_verify():
"""Mock token verification to always succeed"""
with patch("src.webui.routers.emoji.verify_auth_token", return_value=True):
yield
# ==================== 测试用例 ====================
def test_list_emojis_basic(client, sample_emojis, mock_token_verify):
"""测试获取表情包列表(基本分页)"""
response = client.get("/emoji/list?page=1&page_size=10")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 3
assert data["page"] == 1
assert data["page_size"] == 10
assert len(data["data"]) == 3
# 验证第一个表情包字段
emoji = data["data"][0]
assert "id" in emoji
assert "full_path" in emoji
assert "emoji_hash" in emoji
assert "description" in emoji
assert "query_count" in emoji
assert "is_registered" in emoji
assert "is_banned" in emoji
assert "emotion" in emoji
assert "record_time" in emoji
assert "register_time" in emoji
assert "last_used_time" in emoji
def test_list_emojis_pagination(client, sample_emojis, mock_token_verify):
"""测试分页功能"""
response = client.get("/emoji/list?page=1&page_size=2")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 3
assert len(data["data"]) == 2
# 第二页
response = client.get("/emoji/list?page=2&page_size=2")
data = response.json()
assert len(data["data"]) == 1
def test_list_emojis_search(client, sample_emojis, mock_token_verify):
"""测试搜索过滤"""
response = client.get("/emoji/list?search=表情包 2")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 1
assert data["data"][0]["description"] == "测试表情包 2"
def test_list_emojis_filter_registered(client, sample_emojis, mock_token_verify):
"""测试 is_registered 过滤"""
response = client.get("/emoji/list?is_registered=true")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 2
assert all(emoji["is_registered"] is True for emoji in data["data"])
def test_list_emojis_filter_banned(client, sample_emojis, mock_token_verify):
"""测试 is_banned 过滤"""
response = client.get("/emoji/list?is_banned=true")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 1
assert data["data"][0]["is_banned"] is True
def test_list_emojis_sort_by_query_count(client, sample_emojis, mock_token_verify):
"""测试按 query_count 排序"""
response = client.get("/emoji/list?sort_by=query_count&sort_order=desc")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
# 验证降序排列 (20 > 10 > 5)
assert data["data"][0]["query_count"] == 20
assert data["data"][1]["query_count"] == 10
assert data["data"][2]["query_count"] == 5
def test_get_emoji_detail_success(client, sample_emojis, mock_token_verify):
"""测试获取表情包详情(成功)"""
emoji_id = sample_emojis[0].id
response = client.get(f"/emoji/{emoji_id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["id"] == emoji_id
assert data["data"]["emoji_hash"] == sample_emojis[0].image_hash
def test_get_emoji_detail_not_found(client, mock_token_verify):
"""测试获取不存在的表情包404"""
response = client.get("/emoji/99999")
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_update_emoji_description(client, sample_emojis, mock_token_verify):
"""测试更新表情包描述"""
emoji_id = sample_emojis[0].id
response = client.patch(
f"/emoji/{emoji_id}",
json={"description": "更新后的描述"},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["description"] == "更新后的描述"
assert "成功更新" in data["message"]
def test_update_emoji_register_status(client, sample_emojis, mock_token_verify, test_session):
"""测试更新注册状态False -> True 应设置 register_time"""
emoji_id = sample_emojis[1].id # 未注册的表情包
response = client.patch(
f"/emoji/{emoji_id}",
json={"is_registered": True},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["is_registered"] is True
assert data["data"]["register_time"] is not None # 应该设置了注册时间
def test_update_emoji_no_fields(client, sample_emojis, mock_token_verify):
"""测试更新请求未提供任何字段400"""
emoji_id = sample_emojis[0].id
response = client.patch(f"/emoji/{emoji_id}", json={})
assert response.status_code == 400
data = response.json()
assert "未提供任何需要更新的字段" in data["detail"]
def test_update_emoji_not_found(client, mock_token_verify):
"""测试更新不存在的表情包404"""
response = client.patch("/emoji/99999", json={"description": "test"})
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_delete_emoji_success(client, sample_emojis, mock_token_verify, test_session):
"""测试删除表情包(成功)"""
emoji_id = sample_emojis[0].id
response = client.delete(f"/emoji/{emoji_id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "成功删除" in data["message"]
# 验证数据库中已删除
from sqlmodel import select
statement = select(Images).where(Images.id == emoji_id)
result = test_session.exec(statement).first()
assert result is None
def test_delete_emoji_not_found(client, mock_token_verify):
"""测试删除不存在的表情包404"""
response = client.delete("/emoji/99999")
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_batch_delete_success(client, sample_emojis, mock_token_verify, test_session):
"""测试批量删除表情包(全部成功)"""
emoji_ids = [sample_emojis[0].id, sample_emojis[1].id]
response = client.post("/emoji/batch/delete", json={"emoji_ids": emoji_ids})
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["deleted_count"] == 2
assert data["failed_count"] == 0
assert "成功删除 2 个表情包" in data["message"]
# 验证数据库中已删除
from sqlmodel import select
for emoji_id in emoji_ids:
statement = select(Images).where(Images.id == emoji_id)
result = test_session.exec(statement).first()
assert result is None
def test_batch_delete_partial_failure(client, sample_emojis, mock_token_verify):
"""测试批量删除(部分失败)"""
emoji_ids = [sample_emojis[0].id, 99999] # 第二个 ID 不存在
response = client.post("/emoji/batch/delete", json={"emoji_ids": emoji_ids})
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["deleted_count"] == 1
assert data["failed_count"] == 1
assert 99999 in data["failed_ids"]
def test_batch_delete_empty_list(client, mock_token_verify):
"""测试批量删除空列表400"""
response = client.post("/emoji/batch/delete", json={"emoji_ids": []})
assert response.status_code == 400
data = response.json()
assert "未提供要删除的表情包ID" in data["detail"]
def test_auth_required_list(client):
"""测试未认证访问列表端点401"""
# Without mock_token_verify fixture
with patch("src.webui.routers.emoji.verify_auth_token", return_value=False):
client.get("/emoji/list")
# verify_auth_token 返回 False 会触发 HTTPException
# 但具体状态码取决于 verify_auth_token_from_cookie_or_header 的实现
# 这里假设它抛出 401
def test_auth_required_update(client, sample_emojis):
"""测试未认证访问更新端点401"""
with patch("src.webui.routers.emoji.verify_auth_token", return_value=False):
emoji_id = sample_emojis[0].id
client.patch(f"/emoji/{emoji_id}", json={"description": "test"})
# Should be unauthorized
def test_emoji_to_response_field_mapping(sample_emojis):
"""测试 emoji_to_response 字段映射image_hash -> emoji_hash"""
from src.webui.routers.emoji import emoji_to_response
emoji = sample_emojis[0]
response = emoji_to_response(emoji)
# 验证 API 字段名称
assert hasattr(response, "emoji_hash")
assert response.emoji_hash == emoji.image_hash
# 验证时间戳转换
assert isinstance(response.record_time, float)
assert response.record_time == emoji.record_time.timestamp()
if emoji.register_time:
assert isinstance(response.register_time, float)
assert response.register_time == emoji.register_time.timestamp()
def test_list_emojis_only_emoji_type(client, test_session, mock_token_verify):
"""测试列表只返回 type=EMOJI 的记录(不包括其他类型)"""
# 插入一个非 EMOJI 类型的图片
non_emoji = Images(
image_type=ImageType.IMAGE, # 不是 EMOJI
full_path="/data/images/test.png",
image_hash="hash_image",
description="非表情包图片",
query_count=0,
is_registered=False,
is_banned=False,
record_time=datetime.now(),
)
test_session.add(non_emoji)
test_session.commit()
# 插入一个 EMOJI 类型
emoji = Images(
image_type=ImageType.EMOJI,
full_path="/data/emoji_registed/emoji.png",
image_hash="hash_emoji",
description="表情包",
query_count=0,
is_registered=True,
is_banned=False,
record_time=datetime.now(),
)
test_session.add(emoji)
test_session.commit()
response = client.get("/emoji/list")
assert response.status_code == 200
data = response.json()
# 只应该返回 1 个 EMOJI 类型的记录
assert data["total"] == 1
assert data["data"][0]["description"] == "表情包"

View File

@@ -0,0 +1,529 @@
"""Expression routes pytest tests"""
from typing import Generator
import pytest
from fastapi import APIRouter, FastAPI
from fastapi.testclient import TestClient
from sqlalchemy import text
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine, select
from src.common.database.database_model import Expression, ModifiedBy
from src.webui.dependencies import require_auth
def create_test_app() -> FastAPI:
"""Create minimal test app with only expression router"""
app = FastAPI(title="Test App")
from src.webui.routers.expression import router as expression_router
main_router = APIRouter(prefix="/api/webui")
main_router.include_router(expression_router)
app.include_router(main_router)
return app
app = create_test_app()
# Test database setup
@pytest.fixture(name="test_engine")
def test_engine_fixture():
"""Create in-memory SQLite database for testing"""
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
return engine
@pytest.fixture(name="test_session")
def test_session_fixture(test_engine) -> Generator[Session, None, None]:
"""Create a test database session with transaction rollback"""
connection = test_engine.connect()
transaction = connection.begin()
session = Session(bind=connection)
yield session
session.close()
transaction.rollback()
connection.close()
@pytest.fixture(name="client")
def client_fixture(test_session: Session, monkeypatch) -> Generator[TestClient, None, None]:
"""Create TestClient with overridden database session"""
from contextlib import contextmanager
@contextmanager
def get_test_db_session():
yield test_session
test_session.commit()
monkeypatch.setattr("src.webui.routers.expression.get_db_session", get_test_db_session)
with TestClient(app) as client:
yield client
@pytest.fixture(name="mock_auth")
def mock_auth_fixture():
"""Mock authentication to always return True"""
app.dependency_overrides[require_auth] = lambda: "test-token"
yield
app.dependency_overrides.clear()
@pytest.fixture(name="sample_expression")
def sample_expression_fixture(test_session: Session) -> Expression:
"""Insert a sample expression into test database"""
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
"VALUES (1, '测试情景', '测试风格', '[\"测试内容1\", \"测试内容2\"]', 10, '2026-02-17 12:00:00', '2026-02-15 10:00:00', 'test_chat_001', 0, 0)"
)
)
test_session.commit()
expression = test_session.exec(select(Expression).where(Expression.id == 1)).first()
assert expression is not None
return expression
# ============ Tests ============
def test_list_expressions_empty(client: TestClient, mock_auth):
"""Test GET /expression/list with empty database"""
response = client.get("/api/webui/expression/list")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 0
assert data["page"] == 1
assert data["page_size"] == 20
assert data["data"] == []
def test_list_expressions_with_data(client: TestClient, mock_auth, sample_expression: Expression):
"""Test GET /expression/list returns expression data"""
response = client.get("/api/webui/expression/list")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 1
assert len(data["data"]) == 1
expr_data = data["data"][0]
assert expr_data["id"] == sample_expression.id
assert expr_data["situation"] == "测试情景"
assert expr_data["style"] == "测试风格"
assert expr_data["chat_id"] == "test_chat_001"
def test_list_expressions_pagination(client: TestClient, mock_auth, test_session: Session):
"""Test GET /expression/list pagination works correctly"""
for i in range(5):
test_session.execute(
text(
f"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
f"VALUES ({i + 1}, '情景{i}', '风格{i}', '[]', 0, '2026-02-17 12:0{i}:00', '2026-02-15 10:00:00', 'chat_{i}', 0, 0)"
)
)
test_session.commit()
# Request page 1 with page_size=2
response = client.get("/api/webui/expression/list?page=1&page_size=2")
assert response.status_code == 200
data = response.json()
assert data["total"] == 5
assert data["page"] == 1
assert data["page_size"] == 2
assert len(data["data"]) == 2
# Request page 2
response = client.get("/api/webui/expression/list?page=2&page_size=2")
data = response.json()
assert data["page"] == 2
assert len(data["data"]) == 2
def test_list_expressions_search(client: TestClient, mock_auth, test_session: Session):
"""Test GET /expression/list with search filter"""
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
"VALUES (1, '找人吃饭', '热情', '[]', 0, datetime('now'), datetime('now'), 'chat_001', 0, 0)"
)
)
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
"VALUES (2, '拒绝邀请', '礼貌', '[]', 0, datetime('now'), datetime('now'), 'chat_002', 0, 0)"
)
)
test_session.commit()
# Search for "吃饭"
response = client.get("/api/webui/expression/list?search=吃饭")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["situation"] == "找人吃饭"
def test_list_expressions_chat_filter(client: TestClient, mock_auth, test_session: Session):
"""Test GET /expression/list with chat_id filter"""
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
"VALUES (1, '情景A', '风格A', '[]', 0, datetime('now'), datetime('now'), 'chat_A', 0, 0)"
)
)
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
"VALUES (2, '情景B', '风格B', '[]', 0, datetime('now'), datetime('now'), 'chat_B', 0, 0)"
)
)
test_session.commit()
# Filter by chat_A
response = client.get("/api/webui/expression/list?chat_id=chat_A")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["situation"] == "情景A"
assert data["data"][0]["chat_id"] == "chat_A"
def test_get_expression_detail_success(client: TestClient, mock_auth, sample_expression: Expression):
"""Test GET /expression/{id} returns correct detail"""
response = client.get(f"/api/webui/expression/{sample_expression.id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["id"] == sample_expression.id
assert data["data"]["situation"] == "测试情景"
assert data["data"]["style"] == "测试风格"
assert data["data"]["chat_id"] == "test_chat_001"
def test_get_expression_detail_not_found(client: TestClient, mock_auth):
"""Test GET /expression/{id} returns 404 for non-existent ID"""
response = client.get("/api/webui/expression/99999")
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_expression_response_has_legacy_fields(client: TestClient, mock_auth, sample_expression: Expression):
"""Test that ExpressionResponse includes legacy fields (checked/rejected/modified_by)"""
response = client.get(f"/api/webui/expression/{sample_expression.id}")
assert response.status_code == 200
data = response.json()["data"]
# Verify legacy fields exist and have default values
assert "checked" in data
assert "rejected" in data
assert "modified_by" in data
# Verify hardcoded default values (from expression_to_response)
assert data["checked"] is False
assert data["rejected"] is False
assert data["modified_by"] is None
def test_update_expression_without_removed_fields(client: TestClient, mock_auth, sample_expression: Expression):
"""Test PATCH /expression/{id} does not accept checked/rejected fields"""
# Valid update request (only allowed fields)
update_payload = {
"situation": "更新后的情景",
"style": "更新后的风格",
}
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["situation"] == "更新后的情景"
assert data["data"]["style"] == "更新后的风格"
# Verify legacy fields still returned (hardcoded values)
assert data["data"]["checked"] is False
assert data["data"]["rejected"] is False
def test_update_expression_ignores_invalid_fields(client: TestClient, mock_auth, sample_expression: Expression):
"""Test PATCH /expression/{id} ignores fields not in ExpressionUpdateRequest"""
# Request with invalid field (checked not in schema)
update_payload = {
"situation": "新情景",
"checked": True, # This field should be ignored by Pydantic
"rejected": True, # This field should be ignored
}
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["situation"] == "新情景"
# Response should have hardcoded False values (not True from request)
assert data["data"]["checked"] is False
assert data["data"]["rejected"] is False
def test_update_expression_chat_id_mapping(client: TestClient, mock_auth, sample_expression: Expression):
"""Test PATCH /expression/{id} correctly maps chat_id to session_id"""
update_payload = {"chat_id": "updated_chat_999"}
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
# Verify chat_id is returned in response (mapped from session_id)
assert data["data"]["chat_id"] == "updated_chat_999"
def test_update_expression_not_found(client: TestClient, mock_auth):
"""Test PATCH /expression/{id} returns 404 for non-existent ID"""
update_payload = {"situation": "新情景"}
response = client.patch("/api/webui/expression/99999", json=update_payload)
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_update_expression_empty_request(client: TestClient, mock_auth, sample_expression: Expression):
"""Test PATCH /expression/{id} returns 400 for empty update request"""
update_payload = {}
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
assert response.status_code == 400
data = response.json()
assert "未提供任何需要更新的字段" in data["detail"]
def test_delete_expression_success(client: TestClient, mock_auth, sample_expression: Expression):
"""Test DELETE /expression/{id} successfully deletes expression"""
expression_id = sample_expression.id
response = client.delete(f"/api/webui/expression/{expression_id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "成功删除" in data["message"]
# Verify expression is deleted
get_response = client.get(f"/api/webui/expression/{expression_id}")
assert get_response.status_code == 404
def test_delete_expression_not_found(client: TestClient, mock_auth):
"""Test DELETE /expression/{id} returns 404 for non-existent ID"""
response = client.delete("/api/webui/expression/99999")
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_create_expression_success(client: TestClient, mock_auth):
"""Test POST /expression/ successfully creates expression"""
create_payload = {
"situation": "新建情景",
"style": "新建风格",
"chat_id": "new_chat_123",
}
response = client.post("/api/webui/expression/", json=create_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "创建成功" in data["message"]
assert data["data"]["situation"] == "新建情景"
assert data["data"]["style"] == "新建风格"
assert data["data"]["chat_id"] == "new_chat_123"
# Verify legacy fields
assert data["data"]["checked"] is False
assert data["data"]["rejected"] is False
assert data["data"]["modified_by"] is None
def test_batch_delete_expressions_success(client: TestClient, mock_auth, test_session: Session):
"""Test POST /expression/batch/delete successfully deletes multiple expressions"""
expression_ids = []
for i in range(3):
test_session.execute(
text(
f"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
f"VALUES ({i + 1}, '批量删除{i}', '风格{i}', '[]', 0, datetime('now'), datetime('now'), 'chat_{i}', 0, 0)"
)
)
expression_ids.append(i + 1)
test_session.commit()
delete_payload = {"ids": expression_ids}
response = client.post("/api/webui/expression/batch/delete", json=delete_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "成功删除 3 个" in data["message"]
for expr_id in expression_ids:
get_response = client.get(f"/api/webui/expression/{expr_id}")
assert get_response.status_code == 404
def test_batch_delete_partial_not_found(client: TestClient, mock_auth, sample_expression: Expression):
"""Test POST /expression/batch/delete handles partial not found IDs"""
delete_payload = {"ids": [sample_expression.id, 88888, 99999]}
response = client.post("/api/webui/expression/batch/delete", json=delete_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
# Should delete only the 1 valid ID
assert "成功删除 1 个" in data["message"]
def test_get_expression_stats(client: TestClient, mock_auth, test_session: Session):
"""Test GET /expression/stats/summary returns correct statistics"""
for i in range(3):
test_session.execute(
text(
f"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
f"VALUES ({i + 1}, '情景{i}', '风格{i}', '[]', 0, datetime('now'), datetime('now'), 'chat_{i % 2}', 0, 0)"
)
)
test_session.commit()
response = client.get("/api/webui/expression/stats/summary")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["total"] == 3
assert data["data"]["chat_count"] == 2
def test_get_review_stats(client: TestClient, mock_auth, test_session: Session):
"""Test GET /expression/review/stats returns review status counts"""
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
"VALUES (1, '待审核', '风格', '[]', 0, datetime('now'), datetime('now'), 'chat_001', 0, 0)"
)
)
test_session.commit()
response = client.get("/api/webui/expression/review/stats")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1 # Total expressions exists
assert data["unchecked"] == 1
assert data["passed"] == 0
assert data["rejected"] == 0
assert data["ai_checked"] == 0
assert data["user_checked"] == 0
def test_get_review_list_filter_unchecked(client: TestClient, mock_auth, sample_expression: Expression):
"""Test GET /expression/review/list with filter_type=unchecked returns unchecked expressions"""
response = client.get("/api/webui/expression/review/list?filter_type=unchecked")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 1
assert len(data["data"]) == 1
def test_get_review_list_filter_all(client: TestClient, mock_auth, sample_expression: Expression):
"""Test GET /expression/review/list with filter_type=all returns all expressions"""
response = client.get("/api/webui/expression/review/list?filter_type=all")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 1
assert len(data["data"]) == 1
def test_batch_review_expressions_with_unchecked_marker(client: TestClient, mock_auth, sample_expression: Expression):
"""Test POST /expression/review/batch succeeds with require_unchecked=True"""
review_payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": True}]}
response = client.post("/api/webui/expression/review/batch", json=review_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["succeeded"] == 1
assert data["results"][0]["success"] is True
def test_batch_review_expressions_overwrites_ai_checked(
client: TestClient, mock_auth, test_session: Session, sample_expression: Expression
):
"""Test POST /expression/review/batch lets manual review override AI checked state"""
sample_expression.checked = True
sample_expression.rejected = True
sample_expression.modified_by = ModifiedBy.AI
test_session.add(sample_expression)
test_session.commit()
review_payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": True}]}
response = client.post("/api/webui/expression/review/batch", json=review_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["succeeded"] == 1
test_session.expire_all()
reviewed_expression = test_session.exec(select(Expression).where(Expression.id == sample_expression.id)).first()
assert reviewed_expression is not None
assert reviewed_expression.checked is True
assert reviewed_expression.rejected is False
assert reviewed_expression.modified_by == ModifiedBy.USER
def test_batch_review_expressions_no_unchecked_check(client: TestClient, mock_auth, sample_expression: Expression):
"""Test POST /expression/review/batch succeeds when require_unchecked=False"""
review_payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": False}]}
response = client.post("/api/webui/expression/review/batch", json=review_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["succeeded"] == 1
assert data["results"][0]["success"] is True

View File

@@ -0,0 +1,512 @@
"""测试 jargon 路由的完整性和正确性"""
import json
from datetime import datetime
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.common.database.database_model import ChatSession, Jargon
from src.webui.routers.jargon import router as jargon_router
@pytest.fixture(name="app", scope="function")
def app_fixture():
app = FastAPI()
app.include_router(jargon_router, prefix="/api/webui")
return app
@pytest.fixture(name="engine", scope="function")
def engine_fixture():
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
yield engine
@pytest.fixture(name="session", scope="function")
def session_fixture(engine):
connection = engine.connect()
transaction = connection.begin()
session = Session(bind=connection)
yield session
session.close()
transaction.rollback()
connection.close()
@pytest.fixture(name="client", scope="function")
def client_fixture(app: FastAPI, session: Session, monkeypatch):
from contextlib import contextmanager
@contextmanager
def mock_get_db_session():
yield session
monkeypatch.setattr("src.webui.routers.jargon.get_db_session", mock_get_db_session)
with TestClient(app) as client:
yield client
@pytest.fixture(name="sample_chat_session")
def sample_chat_session_fixture(session: Session):
"""创建示例 ChatSession"""
chat_session = ChatSession(
session_id="test_stream_001",
platform="qq",
group_id="123456789",
user_id=None,
created_timestamp=datetime.now(),
last_active_timestamp=datetime.now(),
)
session.add(chat_session)
session.commit()
session.refresh(chat_session)
return chat_session
@pytest.fixture(name="sample_jargons")
def sample_jargons_fixture(session: Session, sample_chat_session: ChatSession):
"""创建示例 Jargon 数据"""
jargons = [
Jargon(
id=1,
content="yyds",
raw_content="永远的神",
meaning="永远的神",
session_id=sample_chat_session.session_id,
count=10,
is_jargon=True,
is_complete=False,
),
Jargon(
id=2,
content="awsl",
raw_content="啊我死了",
meaning="啊我死了",
session_id=sample_chat_session.session_id,
count=5,
is_jargon=True,
is_complete=False,
),
Jargon(
id=3,
content="hello",
raw_content=None,
meaning="你好",
session_id=sample_chat_session.session_id,
count=2,
is_jargon=False,
is_complete=False,
),
]
for jargon in jargons:
session.add(jargon)
session.commit()
for jargon in jargons:
session.refresh(jargon)
return jargons
# ==================== Test Cases ====================
def test_list_jargons(client: TestClient, sample_jargons):
"""测试 GET /jargon/list 基础列表功能"""
response = client.get("/api/webui/jargon/list")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 3
assert data["page"] == 1
assert data["page_size"] == 20
assert len(data["data"]) == 3
assert data["data"][0]["content"] == "yyds"
assert data["data"][0]["count"] == 10
def test_list_jargons_with_pagination(client: TestClient, sample_jargons):
"""测试分页功能"""
response = client.get("/api/webui/jargon/list?page=1&page_size=2")
assert response.status_code == 200
data = response.json()
assert data["total"] == 3
assert len(data["data"]) == 2
response = client.get("/api/webui/jargon/list?page=2&page_size=2")
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 1
def test_list_jargons_with_search(client: TestClient, sample_jargons):
"""测试 GET /jargon/list?search=xxx 搜索功能"""
response = client.get("/api/webui/jargon/list?search=yyds")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["content"] == "yyds"
# 测试搜索 meaning
response = client.get("/api/webui/jargon/list?search=你好")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["content"] == "hello"
def test_list_jargons_with_chat_id_filter(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
"""测试按 chat_id 筛选"""
response = client.get(f"/api/webui/jargon/list?chat_id={sample_chat_session.session_id}")
assert response.status_code == 200
data = response.json()
assert data["total"] == 3
# 测试不存在的 chat_id
response = client.get("/api/webui/jargon/list?chat_id=nonexistent")
assert response.status_code == 200
data = response.json()
assert data["total"] == 0
def test_list_jargons_with_is_jargon_filter(client: TestClient, sample_jargons):
"""测试按 is_jargon 筛选"""
response = client.get("/api/webui/jargon/list?is_jargon=true")
assert response.status_code == 200
data = response.json()
assert data["total"] == 2
assert all(item["is_jargon"] is True for item in data["data"])
response = client.get("/api/webui/jargon/list?is_jargon=false")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["content"] == "hello"
def test_get_jargon_detail(client: TestClient, sample_jargons):
"""测试 GET /jargon/{id} 获取详情"""
jargon_id = sample_jargons[0].id
response = client.get(f"/api/webui/jargon/{jargon_id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["id"] == jargon_id
assert data["data"]["content"] == "yyds"
assert data["data"]["meaning"] == "永远的神"
assert data["data"]["count"] == 10
assert data["data"]["is_jargon"] is True
def test_get_jargon_detail_not_found(client: TestClient):
"""测试获取不存在的黑话详情"""
response = client.get("/api/webui/jargon/99999")
assert response.status_code == 404
assert "黑话不存在" in response.json()["detail"]
@pytest.mark.skip(reason="Composite PK (id+content) prevents autoincrement - database model issue")
def test_create_jargon(client: TestClient, sample_chat_session: ChatSession):
"""测试 POST /jargon/ 创建黑话"""
request_data = {
"content": "新黑话",
"raw_content": "原始内容",
"meaning": "含义",
"chat_id": sample_chat_session.session_id,
}
response = client.post("/api/webui/jargon/", json=request_data)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["message"] == "创建成功"
assert data["data"]["content"] == "新黑话"
assert data["data"]["meaning"] == "含义"
assert data["data"]["count"] == 0
assert data["data"]["is_jargon"] is None
assert data["data"]["is_complete"] is False
def test_create_duplicate_jargon(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
"""测试创建重复黑话返回 400"""
request_data = {
"content": "yyds",
"meaning": "重复的",
"chat_id": sample_chat_session.session_id,
}
response = client.post("/api/webui/jargon/", json=request_data)
assert response.status_code == 400
assert "已存在相同内容的黑话" in response.json()["detail"]
def test_update_jargon(client: TestClient, sample_jargons):
"""测试 PATCH /jargon/{id} 更新黑话"""
jargon_id = sample_jargons[0].id
update_data = {
"meaning": "更新后的含义",
"is_jargon": True,
}
response = client.patch(f"/api/webui/jargon/{jargon_id}", json=update_data)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["message"] == "更新成功"
assert data["data"]["meaning"] == "更新后的含义"
assert data["data"]["is_jargon"] is True
assert data["data"]["content"] == "yyds" # 未改变的字段保持不变
def test_update_jargon_with_chat_id_mapping(client: TestClient, sample_jargons):
"""测试更新时 chat_id → session_id 的映射"""
jargon_id = sample_jargons[0].id
update_data = {
"chat_id": "new_session_id",
}
response = client.patch(f"/api/webui/jargon/{jargon_id}", json=update_data)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["chat_id"] == "new_session_id"
def test_update_jargon_not_found(client: TestClient):
"""测试更新不存在的黑话"""
response = client.patch("/api/webui/jargon/99999", json={"meaning": "test"})
assert response.status_code == 404
assert "黑话不存在" in response.json()["detail"]
def test_delete_jargon(client: TestClient, sample_jargons, session: Session):
"""测试 DELETE /jargon/{id} 删除黑话"""
jargon_id = sample_jargons[0].id
response = client.delete(f"/api/webui/jargon/{jargon_id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["message"] == "删除成功"
assert data["deleted_count"] == 1
# 验证数据库中已删除
response = client.get(f"/api/webui/jargon/{jargon_id}")
assert response.status_code == 404
def test_delete_jargon_not_found(client: TestClient):
"""测试删除不存在的黑话"""
response = client.delete("/api/webui/jargon/99999")
assert response.status_code == 404
assert "黑话不存在" in response.json()["detail"]
def test_batch_delete(client: TestClient, sample_jargons):
"""测试 POST /jargon/batch/delete 批量删除"""
ids_to_delete = [sample_jargons[0].id, sample_jargons[1].id]
request_data = {"ids": ids_to_delete}
response = client.post("/api/webui/jargon/batch/delete", json=request_data)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["deleted_count"] == 2
assert "成功删除 2 条黑话" in data["message"]
# 验证已删除
response = client.get(f"/api/webui/jargon/{ids_to_delete[0]}")
assert response.status_code == 404
def test_batch_delete_empty_list(client: TestClient):
"""测试批量删除空列表返回 400"""
response = client.post("/api/webui/jargon/batch/delete", json={"ids": []})
assert response.status_code == 400
assert "ID列表不能为空" in response.json()["detail"]
def test_batch_set_jargon_status(client: TestClient, sample_jargons):
"""测试批量设置黑话状态"""
ids = [sample_jargons[0].id, sample_jargons[1].id]
response = client.post(
"/api/webui/jargon/batch/set-jargon",
params={"ids": ids, "is_jargon": False},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "成功更新 2 条黑话状态" in data["message"]
# 验证状态已更新
detail_response = client.get(f"/api/webui/jargon/{ids[0]}")
assert detail_response.json()["data"]["is_jargon"] is False
def test_get_stats(client: TestClient, sample_jargons):
"""测试 GET /jargon/stats/summary 统计数据"""
response = client.get("/api/webui/jargon/stats/summary")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
stats = data["data"]
assert stats["total"] == 3
assert stats["confirmed_jargon"] == 2
assert stats["confirmed_not_jargon"] == 1
assert stats["pending"] == 0
assert stats["complete_count"] == 0
assert stats["chat_count"] == 1
assert isinstance(stats["top_chats"], dict)
def test_get_chat_list(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
"""测试 GET /jargon/chats 获取聊天列表"""
response = client.get("/api/webui/jargon/chats")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert len(data["data"]) == 1
chat_info = data["data"][0]
assert chat_info["chat_id"] == sample_chat_session.session_id
assert chat_info["platform"] == "qq"
assert chat_info["is_group"] is True
assert chat_info["chat_name"] == sample_chat_session.group_id
def test_get_chat_list_with_json_chat_id(client: TestClient, session: Session, sample_chat_session: ChatSession):
"""测试解析 JSON 格式的 chat_id"""
json_chat_id = json.dumps([[sample_chat_session.session_id, "user123"]])
jargon = Jargon(
id=100,
content="测试黑话",
meaning="测试",
session_id=json_chat_id,
count=1,
)
session.add(jargon)
session.commit()
response = client.get("/api/webui/jargon/chats")
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 1
assert data["data"][0]["chat_id"] == sample_chat_session.session_id
def test_get_chat_list_without_chat_session(client: TestClient, session: Session):
"""测试聊天列表中没有对应 ChatSession 的情况"""
jargon = Jargon(
id=101,
content="孤立黑话",
meaning="无对应会话",
session_id="nonexistent_stream_id",
count=1,
)
session.add(jargon)
session.commit()
response = client.get("/api/webui/jargon/chats")
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 1
assert data["data"][0]["chat_id"] == "nonexistent_stream_id"
assert data["data"][0]["chat_name"] == "nonexistent_stream_id"[:20]
assert data["data"][0]["platform"] is None
assert data["data"][0]["is_group"] is False
def test_jargon_response_fields(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
"""测试 JargonResponse 字段完整性"""
response = client.get(f"/api/webui/jargon/{sample_jargons[0].id}")
assert response.status_code == 200
data = response.json()["data"]
# 验证所有必需字段存在
required_fields = [
"id",
"content",
"raw_content",
"meaning",
"chat_id",
"stream_id",
"chat_name",
"count",
"is_jargon",
"is_complete",
"inference_with_context",
"inference_content_only",
]
for field in required_fields:
assert field in data
# 验证 chat_name 显示逻辑
assert data["chat_name"] == sample_chat_session.group_id
@pytest.mark.skip(reason="Composite PK (id+content) prevents autoincrement - database model issue")
def test_create_jargon_without_optional_fields(client: TestClient, sample_chat_session: ChatSession):
"""测试创建黑话时可选字段为空"""
request_data = {
"content": "简单黑话",
"chat_id": sample_chat_session.session_id,
}
response = client.post("/api/webui/jargon/", json=request_data)
assert response.status_code == 200
data = response.json()["data"]
assert data["raw_content"] is None
assert data["meaning"] == ""
def test_update_jargon_partial_fields(client: TestClient, sample_jargons):
"""测试增量更新(只更新部分字段)"""
jargon_id = sample_jargons[0].id
original_content = sample_jargons[0].content
# 只更新 meaning
response = client.patch(f"/api/webui/jargon/{jargon_id}", json={"meaning": "新含义"})
assert response.status_code == 200
data = response.json()["data"]
assert data["meaning"] == "新含义"
assert data["content"] == original_content # 其他字段不变
def test_list_jargons_multiple_filters(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
"""测试组合多个过滤条件"""
response = client.get(f"/api/webui/jargon/list?search=永远&chat_id={sample_chat_session.session_id}&is_jargon=true")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["content"] == "yyds"

View File

@@ -0,0 +1,870 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient
import pytest
from src.services.memory_service import MemorySearchResult
from src.webui.dependencies import require_auth
from src.webui.routers import memory as memory_router_module
from src.webui.routers.memory import compat_router
from src.webui.routes import router as main_router
@pytest.fixture
def client() -> TestClient:
app = FastAPI()
app.dependency_overrides[require_auth] = lambda: "ok"
app.include_router(main_router)
app.include_router(compat_router)
return TestClient(app)
def test_webui_memory_graph_route(client: TestClient, monkeypatch):
async def fake_graph_admin(*, action: str, **kwargs):
assert action == "get_graph"
return {
"success": True,
"nodes": [],
"edges": [
{
"source": "alice",
"target": "map",
"weight": 1.5,
"relation_hashes": ["rel-1"],
"predicates": ["持有"],
"relation_count": 1,
"evidence_count": 2,
"label": "持有",
}
],
"total_nodes": 0,
"limit": kwargs.get("limit"),
}
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
response = client.get("/api/webui/memory/graph", params={"limit": 77})
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["limit"] == 77
assert response.json()["edges"][0]["predicates"] == ["持有"]
assert response.json()["edges"][0]["relation_count"] == 1
assert response.json()["edges"][0]["evidence_count"] == 2
def test_webui_memory_graph_search_route(client: TestClient, monkeypatch):
async def fake_graph_admin(*, action: str, **kwargs):
assert action == "search"
assert kwargs["query"] == "Alice"
assert kwargs["limit"] == 33
return {
"success": True,
"query": kwargs["query"],
"limit": kwargs["limit"],
"count": 1,
"items": [
{
"type": "entity",
"title": "Alice",
"matched_field": "name",
"matched_value": "Alice",
"entity_name": "Alice",
"entity_hash": "entity-1",
"appearance_count": 3,
}
],
}
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
response = client.get("/api/webui/memory/graph/search", params={"query": "Alice", "limit": 33})
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["query"] == "Alice"
assert response.json()["limit"] == 33
assert response.json()["items"][0]["type"] == "entity"
@pytest.mark.parametrize(
"params",
[
{"query": "", "limit": 50},
{"query": "Alice", "limit": 0},
{"query": "Alice", "limit": 201},
],
)
def test_webui_memory_graph_search_route_validation(client: TestClient, params):
response = client.get("/api/webui/memory/graph/search", params=params)
assert response.status_code == 422
def test_webui_memory_graph_node_detail_route(client: TestClient, monkeypatch):
async def fake_graph_admin(*, action: str, **kwargs):
assert action == "node_detail"
assert kwargs["node_id"] == "Alice"
return {
"success": True,
"node": {"id": "Alice", "type": "entity", "content": "Alice", "appearance_count": 3},
"relations": [{"hash": "rel-1", "subject": "Alice", "predicate": "持有", "object": "Map", "text": "Alice 持有 Map", "confidence": 0.9, "paragraph_count": 1, "paragraph_hashes": ["p-1"], "source_paragraph": "p-1"}],
"paragraphs": [{"hash": "p-1", "content": "Alice 拿着地图。", "preview": "Alice 拿着地图。", "source": "demo", "entity_count": 2, "relation_count": 1, "entities": ["Alice", "Map"], "relations": ["Alice 持有 Map"]}],
"evidence_graph": {
"nodes": [{"id": "entity:Alice", "type": "entity", "content": "Alice"}],
"edges": [],
"focus_entities": ["Alice"],
},
}
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
response = client.get("/api/webui/memory/graph/node-detail", params={"node_id": "Alice"})
assert response.status_code == 200
assert response.json()["node"]["id"] == "Alice"
assert response.json()["relations"][0]["predicate"] == "持有"
assert response.json()["evidence_graph"]["focus_entities"] == ["Alice"]
def test_webui_memory_graph_node_detail_route_returns_404(client: TestClient, monkeypatch):
async def fake_graph_admin(*, action: str, **kwargs):
assert action == "node_detail"
return {"success": False, "error": "未找到节点: Missing"}
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
response = client.get("/api/webui/memory/graph/node-detail", params={"node_id": "Missing"})
assert response.status_code == 404
assert response.json()["detail"] == "未找到节点: Missing"
def test_webui_memory_graph_edge_detail_route(client: TestClient, monkeypatch):
async def fake_graph_admin(*, action: str, **kwargs):
assert action == "edge_detail"
assert kwargs["source"] == "Alice"
assert kwargs["target"] == "Map"
return {
"success": True,
"edge": {
"source": "Alice",
"target": "Map",
"weight": 1.5,
"relation_hashes": ["rel-1"],
"predicates": ["持有"],
"relation_count": 1,
"evidence_count": 1,
"label": "持有",
},
"relations": [{"hash": "rel-1", "subject": "Alice", "predicate": "持有", "object": "Map", "text": "Alice 持有 Map", "confidence": 0.9, "paragraph_count": 1, "paragraph_hashes": ["p-1"], "source_paragraph": "p-1"}],
"paragraphs": [{"hash": "p-1", "content": "Alice 拿着地图。", "preview": "Alice 拿着地图。", "source": "demo", "entity_count": 2, "relation_count": 1, "entities": ["Alice", "Map"], "relations": ["Alice 持有 Map"]}],
"evidence_graph": {
"nodes": [{"id": "relation:rel-1", "type": "relation", "content": "Alice 持有 Map"}],
"edges": [{"source": "paragraph:p-1", "target": "relation:rel-1", "kind": "supports", "label": "支撑", "weight": 1.0}],
"focus_entities": ["Alice", "Map"],
},
}
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
response = client.get("/api/webui/memory/graph/edge-detail", params={"source": "Alice", "target": "Map"})
assert response.status_code == 200
assert response.json()["edge"]["predicates"] == ["持有"]
assert response.json()["paragraphs"][0]["source"] == "demo"
assert response.json()["evidence_graph"]["edges"][0]["kind"] == "supports"
def test_webui_memory_graph_edge_detail_route_returns_404(client: TestClient, monkeypatch):
async def fake_graph_admin(*, action: str, **kwargs):
assert action == "edge_detail"
return {"success": False, "error": "未找到边: Alice -> Missing"}
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
response = client.get("/api/webui/memory/graph/edge-detail", params={"source": "Alice", "target": "Missing"})
assert response.status_code == 404
assert response.json()["detail"] == "未找到边: Alice -> Missing"
def test_webui_memory_profile_query_resolves_platform_user_id(client: TestClient, monkeypatch):
def fake_resolve_person_id_for_memory(**kwargs):
assert kwargs == {"platform": "qq", "user_id": "12345", "strict_known": False}
return "resolved-person-id"
async def fake_profile_admin(*, action: str, **kwargs):
assert action == "query"
assert kwargs["person_id"] == "resolved-person-id"
assert kwargs["person_keyword"] == "Alice"
assert kwargs["limit"] == 9
assert kwargs["force_refresh"] is True
return {"success": True, "person_id": kwargs["person_id"], "profile_text": "profile"}
monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
monkeypatch.setattr(memory_router_module.memory_service, "profile_admin", fake_profile_admin)
response = client.get(
"/api/webui/memory/profiles/query",
params={
"platform": "qq",
"user_id": "12345",
"person_keyword": "Alice",
"limit": 9,
"force_refresh": True,
},
)
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["person_id"] == "resolved-person-id"
def test_webui_memory_profile_query_prefers_explicit_person_id(client: TestClient, monkeypatch):
def fake_resolve_person_id_for_memory(**kwargs):
raise AssertionError(f"不应解析平台账号: {kwargs}")
async def fake_profile_admin(*, action: str, **kwargs):
assert action == "query"
assert kwargs["person_id"] == "explicit-person-id"
return {"success": True, "person_id": kwargs["person_id"]}
monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
monkeypatch.setattr(memory_router_module.memory_service, "profile_admin", fake_profile_admin)
response = client.get(
"/api/webui/memory/profiles/query",
params={"person_id": "explicit-person-id", "platform": "qq", "user_id": "12345"},
)
assert response.status_code == 200
assert response.json()["person_id"] == "explicit-person-id"
def test_webui_memory_profile_list_enriches_person_name(client: TestClient, monkeypatch):
async def fake_profile_admin(*, action: str, **kwargs):
assert action == "list"
assert kwargs["limit"] == 7
return {
"success": True,
"items": [
{"person_id": "person-1", "profile_text": "profile-1"},
{"person_id": "person-2", "profile_text": "profile-2"},
],
}
monkeypatch.setattr(memory_router_module.memory_service, "profile_admin", fake_profile_admin)
monkeypatch.setattr(
memory_router_module,
"_get_person_name_for_person_id",
lambda person_id: {"person-1": "Alice"}.get(person_id, ""),
)
response = client.get("/api/webui/memory/profiles", params={"limit": 7})
assert response.status_code == 200
assert response.json()["items"][0]["person_name"] == "Alice"
assert response.json()["items"][1]["person_name"] == ""
def test_webui_memory_profile_search_resolves_platform_user_id(client: TestClient, monkeypatch):
def fake_resolve_person_id_for_memory(**kwargs):
assert kwargs == {"platform": "qq", "user_id": "12345", "strict_known": False}
return "resolved-person-id"
async def fake_profile_list(limit: int):
assert limit == 200
return {
"success": True,
"items": [
{"person_id": "resolved-person-id", "person_name": "Alice", "profile_text": "喜欢咖啡"},
{"person_id": "other-person-id", "person_name": "Bob", "profile_text": "喜欢茶"},
],
}
monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
monkeypatch.setattr(memory_router_module, "_profile_list", fake_profile_list)
response = client.get(
"/api/webui/memory/profiles/search",
params={"platform": "qq", "user_id": "12345", "limit": 50},
)
assert response.status_code == 200
assert response.json()["items"] == [
{"person_id": "resolved-person-id", "person_name": "Alice", "profile_text": "喜欢咖啡"}
]
def test_webui_memory_profile_search_filters_keyword(client: TestClient, monkeypatch):
async def fake_profile_list(limit: int):
assert limit == 200
return {
"success": True,
"items": [
{"person_id": "person-1", "person_name": "Alice", "profile_text": "喜欢咖啡"},
{"person_id": "person-2", "person_name": "Bob", "profile_text": "喜欢茶"},
],
}
monkeypatch.setattr(memory_router_module, "_profile_list", fake_profile_list)
response = client.get("/api/webui/memory/profiles/search", params={"person_keyword": "咖啡", "limit": 50})
assert response.status_code == 200
assert response.json()["items"] == [
{"person_id": "person-1", "person_name": "Alice", "profile_text": "喜欢咖啡"}
]
def test_webui_memory_episode_list_resolves_platform_user_id(client: TestClient, monkeypatch):
def fake_resolve_person_id_for_memory(**kwargs):
assert kwargs == {"platform": "qq", "user_id": "12345", "strict_known": False}
return "resolved-person-id"
async def fake_episode_admin(*, action: str, **kwargs):
assert action == "list"
assert kwargs == {
"query": "咖啡",
"limit": 9,
"source": "chat_summary:demo",
"person_id": "resolved-person-id",
"time_start": 100.0,
"time_end": 200.0,
}
return {
"success": True,
"items": [{"episode_id": "ep-1", "person_id": "resolved-person-id", "summary": "喝咖啡"}],
"count": 1,
}
monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin)
monkeypatch.setattr(memory_router_module, "_get_person_name_for_person_id", lambda person_id: "测试人物")
response = client.get(
"/api/webui/memory/episodes",
params={
"query": "咖啡",
"limit": 9,
"source": "chat_summary:demo",
"platform": "qq",
"user_id": "12345",
"time_start": 100,
"time_end": 200,
},
)
assert response.status_code == 200
assert response.json()["items"][0]["person_name"] == "测试人物"
def test_webui_memory_episode_list_prefers_explicit_person_id(client: TestClient, monkeypatch):
def fake_resolve_person_id_for_memory(**kwargs):
raise AssertionError(f"不应解析平台账号: {kwargs}")
async def fake_episode_admin(*, action: str, **kwargs):
assert action == "list"
assert kwargs["person_id"] == "explicit-person-id"
return {"success": True, "items": []}
monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin)
response = client.get(
"/api/webui/memory/episodes",
params={"person_id": "explicit-person-id", "platform": "qq", "user_id": "12345"},
)
assert response.status_code == 200
assert response.json()["items"] == []
def test_compat_aggregate_route(client: TestClient, monkeypatch):
async def fake_search(query: str, **kwargs):
assert kwargs["mode"] == "aggregate"
assert kwargs["respect_filter"] is False
return MemorySearchResult(summary=f"summary:{query}", hits=[])
monkeypatch.setattr(memory_router_module.memory_service, "search", fake_search)
response = client.get("/api/query/aggregate", params={"query": "mai"})
assert response.status_code == 200
assert response.json() == {
"success": True,
"summary": "summary:mai",
"hits": [],
"filtered": False,
"error": "",
}
def test_auto_save_routes(client: TestClient, monkeypatch):
async def fake_runtime_admin(*, action: str, **kwargs):
if action == "get_config":
return {"success": True, "auto_save": True}
if action == "set_auto_save":
return {"success": True, "auto_save": kwargs["enabled"]}
raise AssertionError(action)
monkeypatch.setattr(memory_router_module.memory_service, "runtime_admin", fake_runtime_admin)
get_response = client.get("/api/config/auto_save")
post_response = client.post("/api/config/auto_save", json={"enabled": False})
assert get_response.status_code == 200
assert get_response.json() == {"success": True, "auto_save": True}
assert post_response.status_code == 200
assert post_response.json() == {"success": True, "auto_save": False}
def test_memory_config_routes(client: TestClient, monkeypatch):
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_config_schema",
lambda: {"layout": {"type": "tabs"}, "sections": {"plugin": {"fields": {}}}},
)
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_config_path",
lambda: memory_router_module.Path("/tmp/config/bot_config.toml"),
)
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_config",
lambda: {"plugin": {"enabled": True}},
)
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_raw_config",
lambda: "[plugin]\nenabled = true\n",
)
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_raw_config_with_meta",
lambda: {
"config": "[plugin]\nenabled = true\n",
"exists": True,
"using_default": False,
},
)
schema_response = client.get("/api/webui/memory/config/schema")
config_response = client.get("/api/webui/memory/config")
raw_response = client.get("/api/webui/memory/config/raw")
expected_path = memory_router_module.Path("/tmp/config/bot_config.toml").as_posix()
assert schema_response.status_code == 200
assert memory_router_module.Path(schema_response.json()["path"]).as_posix() == expected_path
assert schema_response.json()["schema"]["layout"]["type"] == "tabs"
assert config_response.status_code == 200
assert config_response.json()["success"] is True
assert config_response.json()["config"] == {"plugin": {"enabled": True}}
assert memory_router_module.Path(config_response.json()["path"]).as_posix() == expected_path
assert raw_response.status_code == 200
assert raw_response.json()["success"] is True
assert raw_response.json()["config"] == "[plugin]\nenabled = true\n"
assert memory_router_module.Path(raw_response.json()["path"]).as_posix() == expected_path
def test_memory_config_raw_returns_default_template_when_file_missing(client: TestClient, monkeypatch):
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_config_path",
lambda: memory_router_module.Path("/tmp/config/bot_config.toml"),
)
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_raw_config_with_meta",
lambda: {
"config": "[plugin]\nenabled = true\n",
"exists": False,
"using_default": True,
},
)
response = client.get("/api/webui/memory/config/raw")
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["config"] == "[plugin]\nenabled = true\n"
assert response.json()["exists"] is False
assert response.json()["using_default"] is True
def test_memory_config_update_routes(client: TestClient, monkeypatch):
async def fake_update_config(config):
assert config == {"plugin": {"enabled": False}}
return {"success": True, "config_path": "config/bot_config.toml"}
async def fake_update_raw(raw_config):
assert raw_config == "[plugin]\nenabled = false\n"
return {"success": True, "config_path": "config/bot_config.toml"}
monkeypatch.setattr(memory_router_module.a_memorix_host_service, "update_config", fake_update_config)
monkeypatch.setattr(memory_router_module.a_memorix_host_service, "update_raw_config", fake_update_raw)
config_response = client.put("/api/webui/memory/config", json={"config": {"plugin": {"enabled": False}}})
raw_response = client.put("/api/webui/memory/config/raw", json={"config": "[plugin]\nenabled = false\n"})
assert config_response.status_code == 200
assert config_response.json() == {"success": True, "config_path": "config/bot_config.toml"}
assert raw_response.status_code == 200
assert raw_response.json() == {"success": True, "config_path": "config/bot_config.toml"}
def test_memory_config_raw_rejects_invalid_toml(client: TestClient):
response = client.put("/api/webui/memory/config/raw", json={"config": "[plugin\nenabled = true"})
assert response.status_code == 400
assert "TOML 格式错误" in response.json()["detail"]
def test_recycle_bin_route(client: TestClient, monkeypatch):
async def fake_get_recycle_bin(*, limit: int):
return {"success": True, "items": [{"hash": "deadbeef"}], "count": 1, "limit": limit}
monkeypatch.setattr(memory_router_module.memory_service, "get_recycle_bin", fake_get_recycle_bin)
response = client.get("/api/memory/recycle_bin", params={"limit": 10})
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["count"] == 1
assert response.json()["limit"] == 10
def test_import_guide_route(client: TestClient, monkeypatch):
async def fake_import_admin(*, action: str, **kwargs):
assert kwargs == {}
if action == "get_guide":
return {"success": True}
if action == "get_settings":
return {"success": True, "settings": {"path_aliases": {"raw": "/tmp/raw"}}}
raise AssertionError(action)
monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)
response = client.get("/api/webui/memory/import/guide")
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["source"] == "local"
assert "长期记忆导入说明" in response.json()["content"]
def test_import_upload_route(client: TestClient, monkeypatch, tmp_path):
monkeypatch.setattr(memory_router_module, "STAGING_ROOT", tmp_path)
async def fake_import_admin(*, action: str, **kwargs):
assert action == "create_upload"
staged_files = kwargs["staged_files"]
assert len(staged_files) == 1
assert staged_files[0]["filename"] == "demo.txt"
assert memory_router_module.Path(staged_files[0]["staged_path"]).exists()
return {"success": True, "task_id": "task-1"}
monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)
response = client.post(
"/api/import/upload",
data={"payload_json": "{\"source\": \"upload\"}"},
files=[("files", ("demo.txt", b"hello world", "text/plain"))],
)
assert response.status_code == 200
assert response.json() == {"success": True, "task_id": "task-1"}
assert list(tmp_path.iterdir()) == []
def test_v5_status_route(client: TestClient, monkeypatch):
async def fake_v5_admin(*, action: str, **kwargs):
assert action == "status"
assert kwargs["target"] == "mai"
return {"success": True, "active_count": 1, "inactive_count": 2, "deleted_count": 3}
monkeypatch.setattr(memory_router_module.memory_service, "v5_admin", fake_v5_admin)
response = client.get("/api/webui/memory/v5/status", params={"target": "mai"})
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["deleted_count"] == 3
def test_delete_preview_route(client: TestClient, monkeypatch):
async def fake_delete_admin(*, action: str, **kwargs):
assert action == "preview"
assert kwargs["mode"] == "paragraph"
assert kwargs["selector"] == {"query": "demo"}
return {"success": True, "counts": {"paragraphs": 1}, "dry_run": True}
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
response = client.post(
"/api/webui/memory/delete/preview",
json={"mode": "paragraph", "selector": {"query": "demo"}},
)
assert response.status_code == 200
assert response.json() == {"success": True, "counts": {"paragraphs": 1}, "dry_run": True}
def test_delete_preview_route_supports_mixed_mode(client: TestClient, monkeypatch):
async def fake_delete_admin(*, action: str, **kwargs):
assert action == "preview"
assert kwargs["mode"] == "mixed"
assert kwargs["selector"] == {
"entity_hashes": ["entity-1"],
"paragraph_hashes": ["p-1"],
"relation_hashes": ["rel-1"],
"sources": ["demo"],
}
return {"success": True, "mode": "mixed", "counts": {"entities": 1, "paragraphs": 1, "relations": 1, "sources": 1}}
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
response = client.post(
"/api/webui/memory/delete/preview",
json={
"mode": "mixed",
"selector": {
"entity_hashes": ["entity-1"],
"paragraph_hashes": ["p-1"],
"relation_hashes": ["rel-1"],
"sources": ["demo"],
},
},
)
assert response.status_code == 200
assert response.json()["mode"] == "mixed"
assert response.json()["counts"]["entities"] == 1
def test_delete_execute_route_supports_mixed_mode(client: TestClient, monkeypatch):
async def fake_delete_admin(*, action: str, **kwargs):
assert action == "execute"
assert kwargs["mode"] == "mixed"
assert kwargs["selector"] == {
"entity_hashes": ["entity-1"],
"paragraph_hashes": ["p-1"],
"relation_hashes": ["rel-1"],
"sources": ["demo"],
}
assert kwargs["reason"] == "knowledge_graph_delete_entity"
assert kwargs["requested_by"] == "knowledge_graph"
return {
"success": True,
"mode": "mixed",
"operation_id": "op-mixed-1",
"deleted_count": 4,
"deleted_entity_count": 1,
"deleted_relation_count": 1,
"deleted_paragraph_count": 1,
"deleted_source_count": 1,
"counts": {"entities": 1, "paragraphs": 1, "relations": 1, "sources": 1},
}
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
response = client.post(
"/api/webui/memory/delete/execute",
json={
"mode": "mixed",
"selector": {
"entity_hashes": ["entity-1"],
"paragraph_hashes": ["p-1"],
"relation_hashes": ["rel-1"],
"sources": ["demo"],
},
"reason": "knowledge_graph_delete_entity",
"requested_by": "knowledge_graph",
},
)
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["mode"] == "mixed"
assert response.json()["operation_id"] == "op-mixed-1"
def test_episode_process_pending_route(client: TestClient, monkeypatch):
async def fake_episode_admin(*, action: str, **kwargs):
assert action == "process_pending"
assert kwargs == {"limit": 7, "max_retry": 4}
return {"success": True, "processed": 3}
monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin)
response = client.post("/api/webui/memory/episodes/process-pending", json={"limit": 7, "max_retry": 4})
assert response.status_code == 200
assert response.json() == {"success": True, "processed": 3}
def test_import_list_route_includes_settings(client: TestClient, monkeypatch):
calls = []
async def fake_import_admin(*, action: str, **kwargs):
calls.append((action, kwargs))
if action == "list":
return {"success": True, "items": [{"task_id": "task-1"}]}
if action == "get_settings":
return {"success": True, "settings": {"path_aliases": {"lpmm": "/tmp/lpmm"}}}
raise AssertionError(action)
monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)
response = client.get("/api/webui/memory/import/tasks", params={"limit": 9})
assert response.status_code == 200
assert response.json()["items"] == [{"task_id": "task-1"}]
assert response.json()["settings"] == {"path_aliases": {"lpmm": "/tmp/lpmm"}}
assert calls == [("list", {"limit": 9}), ("get_settings", {})]
def test_tuning_profile_route_backfills_settings(client: TestClient, monkeypatch):
calls = []
async def fake_tuning_admin(*, action: str, **kwargs):
calls.append((action, kwargs))
if action == "get_profile":
return {"success": True, "profile": {"retrieval": {"top_k": 8}}}
if action == "get_settings":
return {"success": True, "settings": {"profiles": ["default"]}}
raise AssertionError(action)
monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin)
response = client.get("/api/webui/memory/retrieval_tuning/profile")
assert response.status_code == 200
assert response.json()["profile"] == {"retrieval": {"top_k": 8}}
assert response.json()["settings"] == {"profiles": ["default"]}
assert calls == [("get_profile", {}), ("get_settings", {})]
def test_tuning_report_route_flattens_report_payload(client: TestClient, monkeypatch):
async def fake_tuning_admin(*, action: str, **kwargs):
assert action == "get_report"
assert kwargs == {"task_id": "task-1", "format": "json"}
return {
"success": True,
"report": {"format": "json", "content": "{\"ok\": true}", "path": "/tmp/report.json"},
}
monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin)
response = client.get("/api/webui/memory/retrieval_tuning/tasks/task-1/report", params={"format": "json"})
assert response.status_code == 200
assert response.json() == {
"success": True,
"format": "json",
"content": "{\"ok\": true}",
"path": "/tmp/report.json",
"error": "",
}
def test_delete_execute_route(client: TestClient, monkeypatch):
async def fake_delete_admin(*, action: str, **kwargs):
assert action == "execute"
assert kwargs["mode"] == "source"
assert kwargs["selector"] == {"source": "chat_summary:stream-1"}
assert kwargs["reason"] == "cleanup"
assert kwargs["requested_by"] == "tester"
return {"success": True, "operation_id": "del-1"}
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
response = client.post(
"/api/webui/memory/delete/execute",
json={
"mode": "source",
"selector": {"source": "chat_summary:stream-1"},
"reason": "cleanup",
"requested_by": "tester",
},
)
assert response.status_code == 200
assert response.json() == {"success": True, "operation_id": "del-1"}
def test_sources_route(client: TestClient, monkeypatch):
async def fake_source_admin(*, action: str, **kwargs):
assert action == "list"
assert kwargs == {}
return {"success": True, "items": [{"source": "demo", "paragraph_count": 2}], "count": 1}
monkeypatch.setattr(memory_router_module.memory_service, "source_admin", fake_source_admin)
response = client.get("/api/webui/memory/sources")
assert response.status_code == 200
assert response.json()["items"] == [{"source": "demo", "paragraph_count": 2}]
def test_delete_operation_routes(client: TestClient, monkeypatch):
async def fake_delete_admin(*, action: str, **kwargs):
if action == "list_operations":
assert kwargs == {"limit": 5, "mode": "paragraph"}
return {"success": True, "items": [{"operation_id": "del-1"}], "count": 1}
if action == "get_operation":
assert kwargs == {"operation_id": "del-1"}
return {"success": True, "operation": {"operation_id": "del-1", "mode": "paragraph"}}
raise AssertionError(action)
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
list_response = client.get("/api/webui/memory/delete/operations", params={"limit": 5, "mode": "paragraph"})
get_response = client.get("/api/webui/memory/delete/operations/del-1")
assert list_response.status_code == 200
assert list_response.json()["count"] == 1
assert get_response.status_code == 200
assert get_response.json()["operation"]["operation_id"] == "del-1"
def test_feedback_correction_routes(client: TestClient, monkeypatch):
async def fake_feedback_admin(*, action: str, **kwargs):
if action == "list":
assert kwargs == {
"limit": 7,
"statuses": ["applied"],
"rollback_statuses": ["none"],
"query": "green",
}
return {"success": True, "items": [{"task_id": 11, "query_text": "what color"}], "count": 1}
if action == "get":
assert kwargs == {"task_id": 11}
return {"success": True, "task": {"task_id": 11, "query_text": "what color", "action_logs": []}}
if action == "rollback":
assert kwargs == {"task_id": 11, "requested_by": "tester", "reason": "manual revert"}
return {"success": True, "result": {"restored_relation_hashes": ["rel-1"]}}
raise AssertionError(action)
monkeypatch.setattr(memory_router_module.memory_service, "feedback_admin", fake_feedback_admin)
list_response = client.get(
"/api/webui/memory/feedback-corrections",
params={"limit": 7, "status": "applied", "rollback_status": "none", "query": "green"},
)
get_response = client.get("/api/webui/memory/feedback-corrections/11")
rollback_response = client.post(
"/api/webui/memory/feedback-corrections/11/rollback",
json={"requested_by": "tester", "reason": "manual revert"},
)
assert list_response.status_code == 200
assert list_response.json()["items"][0]["task_id"] == 11
assert get_response.status_code == 200
assert get_response.json()["task"]["task_id"] == 11
assert rollback_response.status_code == 200
assert rollback_response.json()["result"]["restored_relation_hashes"] == ["rel-1"]

View File

@@ -0,0 +1,533 @@
from __future__ import annotations
from pathlib import Path
from time import monotonic, sleep
from typing import Any, Dict, Generator
from uuid import uuid4
import asyncio
import json
from fastapi import FastAPI
from fastapi.testclient import TestClient
import pytest
import tomlkit
from src.A_memorix import host_service as host_service_module
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from src.A_memorix.core.utils import retrieval_tuning_manager as tuning_manager_module
from src.webui.dependencies import require_auth
from src.webui.routers import memory as memory_router_module
REQUEST_TIMEOUT_SECONDS = 30
IMPORT_TIMEOUT_SECONDS = 120
TUNING_TIMEOUT_SECONDS = 420
IMPORT_TERMINAL_STATUSES = {"completed", "completed_with_errors", "failed", "cancelled"}
TUNING_TERMINAL_STATUSES = {"completed", "failed", "cancelled"}
class _FakeEmbeddingManager:
def __init__(self, dimension: int = 64) -> None:
self.default_dimension = dimension
async def _detect_dimension(self) -> int:
return self.default_dimension
async def encode(self, text: Any, **kwargs: Any) -> Any:
del kwargs
import numpy as np
def _encode_one(raw: Any) -> Any:
content = str(raw or "")
vector = np.zeros(self.default_dimension, dtype=np.float32)
for index, byte in enumerate(content.encode("utf-8")):
vector[index % self.default_dimension] += float((byte % 17) + 1)
norm = float(np.linalg.norm(vector))
if norm > 0:
vector /= norm
return vector
if isinstance(text, (list, tuple)):
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
return _encode_one(text).astype(np.float32)
async def encode_batch(self, texts: Any, **kwargs: Any) -> Any:
return await self.encode(texts, **kwargs)
def _build_test_config(data_dir: Path) -> Dict[str, Any]:
return {
"storage": {
"data_dir": str(data_dir),
},
"advanced": {
"enable_auto_save": False,
},
"embedding": {
"dimension": 64,
"batch_size": 4,
"max_concurrent": 1,
"retry": {
"max_attempts": 1,
"min_wait_seconds": 0.1,
"max_wait_seconds": 0.2,
"backoff_multiplier": 1.0,
},
"fallback": {
"enabled": True,
"allow_metadata_only_write": True,
"probe_interval_seconds": 30,
},
"paragraph_vector_backfill": {
"enabled": False,
"interval_seconds": 60,
"batch_size": 32,
"max_retry": 2,
},
},
"retrieval": {
"enable_parallel": False,
"enable_ppr": False,
"top_k_paragraphs": 20,
"top_k_relations": 10,
"top_k_final": 10,
"alpha": 0.5,
"search": {
"smart_fallback": {
"enabled": True,
},
},
"sparse": {
"enabled": True,
"mode": "auto",
"candidate_k": 80,
"relation_candidate_k": 60,
},
"fusion": {
"method": "weighted_rrf",
"rrf_k": 60,
"vector_weight": 0.7,
"bm25_weight": 0.3,
},
},
"threshold": {
"percentile": 70.0,
"min_results": 1,
},
"web": {
"tuning": {
"enabled": True,
"poll_interval_ms": 300,
"max_queue_size": 4,
"default_objective": "balanced",
"default_intensity": "quick",
"default_sample_size": 4,
"default_top_k_eval": 5,
"eval_query_timeout_seconds": 1.0,
"llm_retry": {
"max_attempts": 1,
"min_wait_seconds": 0.1,
"max_wait_seconds": 0.2,
"backoff_multiplier": 1.0,
},
},
},
}
def _assert_response_ok(response: Any) -> Dict[str, Any]:
assert response.status_code == 200, response.text
payload = response.json()
assert payload.get("success", True) is True, payload
return payload
def _wait_for_import_task_terminal(client: TestClient, task_id: str, *, timeout_seconds: float = IMPORT_TIMEOUT_SECONDS) -> Dict[str, Any]:
deadline = monotonic() + timeout_seconds
last_payload: Dict[str, Any] = {}
while monotonic() < deadline:
response = client.get(
f"/api/webui/memory/import/tasks/{task_id}",
params={"include_chunks": True},
)
payload = _assert_response_ok(response)
last_payload = payload
task = payload.get("task") or {}
status = str(task.get("status", "") or "")
if status in IMPORT_TERMINAL_STATUSES:
return task
sleep(0.2)
raise AssertionError(f"导入任务超时: task_id={task_id}, last_payload={last_payload}")
def _wait_for_tuning_task_terminal(client: TestClient, task_id: str, *, timeout_seconds: float = TUNING_TIMEOUT_SECONDS) -> Dict[str, Any]:
deadline = monotonic() + timeout_seconds
last_payload: Dict[str, Any] = {}
while monotonic() < deadline:
response = client.get(
f"/api/webui/memory/retrieval_tuning/tasks/{task_id}",
params={"include_rounds": False},
)
payload = _assert_response_ok(response)
last_payload = payload
task = payload.get("task") or {}
status = str(task.get("status", "") or "")
if status in TUNING_TERMINAL_STATUSES:
return task
sleep(0.3)
raise AssertionError(f"调优任务超时: task_id={task_id}, last_payload={last_payload}")
def _wait_for_query_hit(client: TestClient, query: str, *, timeout_seconds: float = 30.0) -> Dict[str, Any]:
deadline = monotonic() + timeout_seconds
last_payload: Dict[str, Any] = {}
while monotonic() < deadline:
payload = _assert_response_ok(
client.get(
"/api/webui/memory/query/aggregate",
params={"query": query, "limit": 20},
)
)
last_payload = payload
hits = payload.get("hits") or []
if isinstance(hits, list) and len(hits) > 0:
return payload
sleep(0.2)
raise AssertionError(f"检索命中超时: query={query}, last_payload={last_payload}")
def _get_source_item(client: TestClient, source_name: str) -> Dict[str, Any] | None:
payload = _assert_response_ok(client.get("/api/webui/memory/sources"))
items = payload.get("items") or []
for item in items:
if not isinstance(item, dict):
continue
if str(item.get("source", "") or "") == source_name:
return item
return None
def _source_paragraph_count(item: Dict[str, Any] | None) -> int:
payload = item or {}
if "paragraph_count" in payload:
return int(payload.get("paragraph_count", 0) or 0)
return int(payload.get("count", 0) or 0)
def _wait_for_source_paragraph_count(
client: TestClient,
source_name: str,
*,
min_count: int,
timeout_seconds: float = 30.0,
) -> Dict[str, Any]:
deadline = monotonic() + timeout_seconds
last_item: Dict[str, Any] = {}
while monotonic() < deadline:
item = _get_source_item(client, source_name)
count = _source_paragraph_count(item)
if count >= int(min_count):
return item or {}
if item:
last_item = dict(item)
sleep(0.2)
raise AssertionError(
f"等待来源段落计数超时: source={source_name}, min_count={min_count}, last_item={last_item}"
)
def _create_multitype_upload_task(client: TestClient) -> str:
structured_json = {
"paragraphs": [
{
"content": "Alice 携带地图前往火星港。",
"source": "integration-upload-json",
"entities": ["Alice", "地图", "火星港"],
"relations": [
{"subject": "Alice", "predicate": "携带", "object": "地图"},
{"subject": "Alice", "predicate": "前往", "object": "火星港"},
],
}
]
}
extra_json = {
"paragraphs": [
{
"content": "Carol 记录了一条补充说明。",
"source": "integration-upload-json-extra",
"entities": ["Carol"],
"relations": [],
}
]
}
payload_json = json.dumps(
{
"input_mode": "text",
"llm_enabled": False,
"file_concurrency": 2,
"chunk_concurrency": 2,
"dedupe_policy": "none",
},
ensure_ascii=False,
)
files = [
("files", ("integration-notes.txt", "Alice 在测试环境记录了一条长期记忆。".encode("utf-8"), "text/plain")),
("files", ("integration-diary.md", "# 日志\nBob 与 Alice 讨论了导图。".encode("utf-8"), "text/markdown")),
("files", ("integration-structured.json", json.dumps(structured_json, ensure_ascii=False).encode("utf-8"), "application/json")),
("files", ("integration-extra.json", json.dumps(extra_json, ensure_ascii=False).encode("utf-8"), "application/json")),
]
response = client.post(
"/api/webui/memory/import/upload",
data={"payload_json": payload_json},
files=files,
)
payload = _assert_response_ok(response)
task_id = str((payload.get("task") or {}).get("task_id") or "").strip()
assert task_id, payload
return task_id
def _create_seed_paste_task(client: TestClient, *, source: str, unique_token: str) -> str:
seed_payload = {
"paragraphs": [
{
"content": f"Alice 在火星港携带地图并记录了口令 {unique_token}",
"source": source,
"entities": ["Alice", "火星港", "地图"],
"relations": [
{"subject": "Alice", "predicate": "前往", "object": "火星港"},
{"subject": "Alice", "predicate": "携带", "object": "地图"},
],
},
{
"content": f"Bob 在火星港遇见 Alice并重复口令 {unique_token}",
"source": source,
"entities": ["Bob", "Alice", "火星港"],
"relations": [
{"subject": "Bob", "predicate": "遇见", "object": "Alice"},
{"subject": "Bob", "predicate": "位于", "object": "火星港"},
],
},
]
}
response = client.post(
"/api/webui/memory/import/paste",
json={
"name": "integration-seed.json",
"input_mode": "json",
"llm_enabled": False,
"content": json.dumps(seed_payload, ensure_ascii=False),
"dedupe_policy": "none",
},
)
payload = _assert_response_ok(response)
task_id = str((payload.get("task") or {}).get("task_id") or "").strip()
assert task_id, payload
return task_id
@pytest.fixture(scope="module")
def integration_state(tmp_path_factory: pytest.TempPathFactory) -> Generator[Dict[str, Any], None, None]:
tmp_root = tmp_path_factory.mktemp("memory_routes_integration")
data_dir = (tmp_root / "data").resolve()
staging_dir = (tmp_root / "upload_staging").resolve()
artifacts_dir = (tmp_root / "artifacts").resolve()
config_file = (tmp_root / "config" / "bot_config.toml").resolve()
runtime_config = _build_test_config(data_dir)
patches = pytest.MonkeyPatch()
patches.setattr(host_service_module.a_memorix_host_service, "_read_config", lambda: dict(runtime_config))
patches.setattr(host_service_module.a_memorix_host_service, "get_config_path", lambda: config_file)
patches.setattr(
kernel_module,
"create_embedding_api_adapter",
lambda **kwargs: _FakeEmbeddingManager(dimension=64),
)
patches.setattr(memory_router_module, "STAGING_ROOT", staging_dir)
patches.setattr(tuning_manager_module, "artifacts_root", lambda: artifacts_dir)
asyncio.run(host_service_module.a_memorix_host_service.stop())
host_service_module.a_memorix_host_service._config_cache = None # type: ignore[attr-defined]
app = FastAPI()
app.dependency_overrides[require_auth] = lambda: "ok"
app.include_router(memory_router_module.router, prefix="/api/webui")
app.include_router(memory_router_module.compat_router)
unique_token = f"INTEG_TOKEN_{uuid4().hex[:12]}"
source_name = f"integration-source-{uuid4().hex[:8]}"
with TestClient(app) as client:
upload_task_id = _create_multitype_upload_task(client)
upload_task = _wait_for_import_task_terminal(client, upload_task_id)
seed_task_id = _create_seed_paste_task(client, source=source_name, unique_token=unique_token)
seed_task = _wait_for_import_task_terminal(client, seed_task_id)
assert str(seed_task.get("status", "") or "") in {"completed", "completed_with_errors"}, seed_task
_wait_for_query_hit(client, unique_token, timeout_seconds=45.0)
yield {
"client": client,
"upload_task": upload_task,
"seed_task": seed_task,
"source_name": source_name,
"unique_token": unique_token,
}
asyncio.run(host_service_module.a_memorix_host_service.stop())
host_service_module.a_memorix_host_service._config_cache = None # type: ignore[attr-defined]
patches.undo()
def test_import_module_end_to_end_supports_multitype_upload(integration_state: Dict[str, Any]) -> None:
upload_task = integration_state["upload_task"]
assert str(upload_task.get("status", "") or "") in {"completed", "completed_with_errors"}, upload_task
files = upload_task.get("files") or []
assert isinstance(files, list)
assert len(files) >= 4
file_names = {str(item.get("name", "") or "") for item in files if isinstance(item, dict)}
assert "integration-notes.txt" in file_names
assert "integration-diary.md" in file_names
assert "integration-structured.json" in file_names
assert "integration-extra.json" in file_names
def test_retrieval_module_end_to_end_queries_seeded_data(integration_state: Dict[str, Any]) -> None:
client = integration_state["client"]
unique_token = integration_state["unique_token"]
aggregate_payload = _wait_for_query_hit(client, unique_token, timeout_seconds=45.0)
hits = aggregate_payload.get("hits") or []
joined_content = "\n".join(str(item.get("content", "") or "") for item in hits if isinstance(item, dict))
assert unique_token in joined_content
graph_payload = _assert_response_ok(
client.get(
"/api/webui/memory/graph/search",
params={"query": "Alice", "limit": 20},
)
)
graph_items = graph_payload.get("items") or []
assert isinstance(graph_items, list)
assert any(str(item.get("type", "") or "") == "entity" for item in graph_items if isinstance(item, dict)), graph_items
def test_tuning_module_end_to_end_create_and_apply_best(integration_state: Dict[str, Any]) -> None:
client = integration_state["client"]
create_payload = _assert_response_ok(
client.post(
"/api/webui/memory/retrieval_tuning/tasks",
json={
"objective": "balanced",
"intensity": "quick",
"rounds": 2,
"sample_size": 4,
"top_k_eval": 5,
"llm_enabled": False,
"eval_query_timeout_seconds": 1.0,
"seed": 20260403,
},
)
)
task_id = str((create_payload.get("task") or {}).get("task_id") or "").strip()
assert task_id, create_payload
task = _wait_for_tuning_task_terminal(client, task_id)
assert str(task.get("status", "") or "") == "completed", task
apply_payload = _assert_response_ok(
client.post(
f"/api/webui/memory/retrieval_tuning/tasks/{task_id}/apply-best",
)
)
assert "applied" in apply_payload
def test_delete_module_end_to_end_preview_execute_restore(integration_state: Dict[str, Any]) -> None:
client = integration_state["client"]
unique_token = integration_state["unique_token"]
source_name = integration_state["source_name"]
before_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0)
assert _source_paragraph_count(before_source_item) >= 1
preview_payload = _assert_response_ok(
client.post(
"/api/webui/memory/delete/preview",
json={
"mode": "source",
"selector": {"sources": [source_name]},
"reason": "integration_delete_preview",
"requested_by": "pytest_integration",
},
)
)
preview_counts = preview_payload.get("counts") or {}
assert int(preview_counts.get("paragraphs", 0) or 0) >= 1, preview_payload
execute_payload = _assert_response_ok(
client.post(
"/api/webui/memory/delete/execute",
json={
"mode": "source",
"selector": {"sources": [source_name]},
"reason": "integration_delete_execute",
"requested_by": "pytest_integration",
},
)
)
operation_id = str(execute_payload.get("operation_id", "") or "").strip()
assert operation_id, execute_payload
after_delete_payload = _assert_response_ok(
client.get(
"/api/webui/memory/query/aggregate",
params={"query": unique_token, "limit": 20},
)
)
after_delete_hits = after_delete_payload.get("hits") or []
after_delete_text = "\n".join(
str(item.get("content", "") or "")
for item in after_delete_hits
if isinstance(item, dict)
)
assert unique_token not in after_delete_text
assert int(execute_payload.get("deleted_paragraph_count", 0) or 0) >= 1, execute_payload
_assert_response_ok(
client.post(
"/api/webui/memory/delete/restore",
json={
"operation_id": operation_id,
"requested_by": "pytest_integration",
},
)
)
restored_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0)
assert _source_paragraph_count(restored_source_item) >= 1
operations_payload = _assert_response_ok(
client.get(
"/api/webui/memory/delete/operations",
params={"limit": 20, "mode": "source"},
)
)
operation_items = operations_payload.get("items") or []
operation_ids = {
str(item.get("operation_id", "") or "")
for item in operation_items
if isinstance(item, dict)
}
assert operation_id in operation_ids
operation_detail_payload = _assert_response_ok(client.get(f"/api/webui/memory/delete/operations/{operation_id}"))
detail_operation = operation_detail_payload.get("operation") or {}
assert str(detail_operation.get("status", "") or "") == "restored"

View File

@@ -0,0 +1,187 @@
"""模型路由测试
验证 Gemini 提供商连接测试会使用查询参数传递 API Key
并且不会回退到 OpenAI 兼容接口使用的 Bearer 认证方式。
"""
import importlib
import sys
from types import ModuleType
from typing import Any
import pytest
def load_model_routes(monkeypatch: pytest.MonkeyPatch):
"""在导入路由前 stub 配置与认证依赖模块,避免测试时触发真实初始化。"""
config_module = ModuleType("src.config.config")
config_module.__dict__["CONFIG_DIR"] = "."
monkeypatch.setitem(sys.modules, "src.config.config", config_module)
dependencies_module = ModuleType("src.webui.dependencies")
async def require_auth():
return "test-token"
dependencies_module.__dict__["require_auth"] = require_auth
monkeypatch.setitem(sys.modules, "src.webui.dependencies", dependencies_module)
sys.modules.pop("src.webui.routers.model", None)
return importlib.import_module("src.webui.routers.model")
class FakeResponse:
"""简化版 HTTP 响应对象。"""
def __init__(self, status_code: int):
self.status_code = status_code
def build_async_client_factory(
responses: list[FakeResponse],
calls: list[dict[str, Any]],
):
"""构造一个可记录请求参数的 AsyncClient 替身。"""
response_iter = iter(responses)
class FakeAsyncClient:
def __init__(self, *args: Any, **kwargs: Any):
self.args = args
self.kwargs = kwargs
async def __aenter__(self) -> "FakeAsyncClient":
return self
async def __aexit__(self, exc_type, exc, tb) -> bool:
return False
async def get(
self,
url: str,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
) -> FakeResponse:
calls.append(
{
"url": url,
"headers": headers or {},
"params": params or {},
}
)
return next(response_iter)
return FakeAsyncClient
@pytest.mark.asyncio
async def test_test_provider_connection_uses_query_api_key_for_gemini(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Gemini 连接测试应通过查询参数传递 API Key。"""
model_routes = load_model_routes(monkeypatch)
calls: list[dict[str, Any]] = []
fake_client_class = build_async_client_factory(
responses=[FakeResponse(200), FakeResponse(200)],
calls=calls,
)
monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class)
result = await model_routes.test_provider_connection(
base_url="https://generativelanguage.googleapis.com/v1beta",
api_key="valid-gemini-key",
client_type="gemini",
)
assert result["network_ok"] is True
assert result["api_key_valid"] is True
assert len(calls) == 2
network_call = calls[0]
validation_call = calls[1]
assert network_call["url"] == "https://generativelanguage.googleapis.com/v1beta"
assert network_call["headers"] == {}
assert network_call["params"] == {}
assert validation_call["url"] == "https://generativelanguage.googleapis.com/v1beta/models"
assert validation_call["params"] == {"key": "valid-gemini-key"}
assert validation_call["headers"] == {"Content-Type": "application/json"}
assert "Authorization" not in validation_call["headers"]
@pytest.mark.asyncio
async def test_test_provider_connection_uses_bearer_auth_for_openai_compatible(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""非 Gemini 提供商连接测试应继续使用 Bearer 认证。"""
model_routes = load_model_routes(monkeypatch)
calls: list[dict[str, Any]] = []
fake_client_class = build_async_client_factory(
responses=[FakeResponse(200), FakeResponse(200)],
calls=calls,
)
monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class)
result = await model_routes.test_provider_connection(
base_url="https://example.com/v1",
api_key="valid-openai-key",
client_type="openai",
)
assert result["network_ok"] is True
assert result["api_key_valid"] is True
assert len(calls) == 2
validation_call = calls[1]
assert validation_call["url"] == "https://example.com/v1/models"
assert validation_call["params"] == {}
assert validation_call["headers"]["Content-Type"] == "application/json"
assert validation_call["headers"]["Authorization"] == "Bearer valid-openai-key"
@pytest.mark.asyncio
async def test_test_provider_connection_by_name_forwards_provider_client_type(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
) -> None:
"""按提供商名称测试连接时,应透传配置中的 client_type。"""
model_routes = load_model_routes(monkeypatch)
config_path = tmp_path / "model_config.toml"
config_path.write_text(
"""
[[api_providers]]
name = "Gemini"
base_url = "https://generativelanguage.googleapis.com/v1beta"
api_key = "valid-gemini-key"
client_type = "gemini"
""".strip(),
encoding="utf-8",
)
monkeypatch.setattr(model_routes, "CONFIG_DIR", str(tmp_path))
captured_kwargs: dict[str, Any] = {}
async def fake_test_provider_connection(**kwargs: Any) -> dict[str, Any]:
captured_kwargs.update(kwargs)
return {
"network_ok": True,
"api_key_valid": True,
"latency_ms": 12.34,
"error": None,
"http_status": 200,
}
monkeypatch.setattr(model_routes, "test_provider_connection", fake_test_provider_connection)
result = await model_routes.test_provider_connection_by_name(provider_name="Gemini")
assert result["network_ok"] is True
assert result["api_key_valid"] is True
assert captured_kwargs == {
"base_url": "https://generativelanguage.googleapis.com/v1beta",
"api_key": "valid-gemini-key",
"client_type": "gemini",
}

View File

@@ -0,0 +1,136 @@
import json
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.webui.routers.plugin import management as management_module
from src.webui.routers.plugin import support as support_module
@pytest.fixture
def client(tmp_path, monkeypatch) -> TestClient:
plugins_dir = tmp_path / "plugins"
plugins_dir.mkdir(parents=True, exist_ok=True)
demo_dir = plugins_dir / "demo_plugin"
demo_dir.mkdir()
(demo_dir / "_manifest.json").write_text(
json.dumps(
{
"manifest_version": 2,
"id": "test.demo",
"name": "Demo Plugin",
"version": "1.0.0",
"description": "demo plugin",
}
),
encoding="utf-8",
)
monkeypatch.setattr(management_module, "require_plugin_token", lambda _: "ok")
monkeypatch.setattr(support_module, "get_plugins_dir", lambda: plugins_dir)
app = FastAPI()
app.include_router(management_module.router, prefix="/api/webui/plugins")
return TestClient(app)
def test_installed_plugins_only_scan_plugins_dir_and_exclude_a_memorix(client: TestClient):
response = client.get("/api/webui/plugins/installed")
assert response.status_code == 200
payload = response.json()
assert payload["success"] is True
ids = [plugin["id"] for plugin in payload["plugins"]]
assert ids == ["test.demo"]
assert "a-dawn.a-memorix" not in ids
assert all("/src/plugins/built_in/" not in plugin["path"] for plugin in payload["plugins"])
def test_resolve_installed_plugin_path_falls_back_to_manifest_id(client: TestClient):
plugin_path = support_module.resolve_installed_plugin_path("test.demo")
assert plugin_path is not None
assert plugin_path.name == "demo_plugin"
def test_resolve_installed_plugin_path_accepts_manifest_id_case_mismatch(client: TestClient):
plugin_path = support_module.resolve_installed_plugin_path("Test.Demo")
assert plugin_path is not None
assert plugin_path.name == "demo_plugin"
def test_install_plugin_preserves_manifest_declared_id(client: TestClient, monkeypatch):
class FakeGitMirrorService:
async def clone_repository(self, **kwargs):
target_path = kwargs["target_path"]
target_path.mkdir(parents=True, exist_ok=True)
(target_path / "_manifest.json").write_text(
json.dumps(
{
"manifest_version": 2,
"id": "author.declared",
"name": "Declared Plugin",
"version": "1.0.0",
"author": {"name": "author"},
}
),
encoding="utf-8",
)
return {"success": True}
monkeypatch.setattr(management_module, "get_git_mirror_service", lambda: FakeGitMirrorService())
response = client.post(
"/api/webui/plugins/install",
json={
"plugin_id": "market.plugin",
"repository_url": "https://github.com/author/declared",
"branch": "main",
},
)
assert response.status_code == 200
plugin_path = support_module.resolve_installed_plugin_path("author.declared")
assert plugin_path is not None
manifest = json.loads((plugin_path / "_manifest.json").read_text(encoding="utf-8"))
assert manifest["id"] == "author.declared"
def test_install_plugin_backfills_missing_manifest_id(client: TestClient, monkeypatch):
class FakeGitMirrorService:
async def clone_repository(self, **kwargs):
target_path = kwargs["target_path"]
target_path.mkdir(parents=True, exist_ok=True)
(target_path / "_manifest.json").write_text(
json.dumps(
{
"manifest_version": 2,
"name": "Legacy Plugin",
"version": "1.0.0",
"author": {"name": "author"},
}
),
encoding="utf-8",
)
return {"success": True}
monkeypatch.setattr(management_module, "get_git_mirror_service", lambda: FakeGitMirrorService())
response = client.post(
"/api/webui/plugins/install",
json={
"plugin_id": "market.legacy",
"repository_url": "https://github.com/author/legacy",
"branch": "main",
},
)
assert response.status_code == 200
plugin_path = support_module.resolve_installed_plugin_path("market.legacy")
assert plugin_path is not None
manifest = json.loads((plugin_path / "_manifest.json").read_text(encoding="utf-8"))
assert manifest["id"] == "market.legacy"

View File

@@ -0,0 +1,332 @@
from contextlib import contextmanager
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import Any, Iterator
import pytest
from src.services import statistics_service
from src.webui.schemas.statistics import DashboardData, StatisticsSummary, TimeSeriesData
class _Result:
def __init__(self, *, first_value: Any = None, all_values: list[Any] | None = None) -> None:
self._first_value = first_value
self._all_values = all_values or []
def first(self) -> Any:
return self._first_value
def all(self) -> list[Any]:
return self._all_values
class _Session:
def __init__(self, results: list[_Result]) -> None:
self._results = results
def exec(self, statement: Any) -> _Result:
del statement
return self._results.pop(0)
class _MemoryStore:
def __init__(self) -> None:
self.store: dict[str, Any] = {}
def __getitem__(self, item: str) -> Any:
return self.store.get(item)
def __setitem__(self, key: str, value: Any) -> None:
self.store[key] = value
def _patch_session_results(monkeypatch: pytest.MonkeyPatch, results: list[_Result]) -> list[bool]:
auto_commit_calls: list[bool] = []
@contextmanager
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_Session]:
auto_commit_calls.append(auto_commit)
yield _Session([results.pop(0)])
monkeypatch.setattr(statistics_service, "get_db_session", _fake_get_db_session)
return auto_commit_calls
def _patch_session_result_group(monkeypatch: pytest.MonkeyPatch, results: list[_Result]) -> list[bool]:
auto_commit_calls: list[bool] = []
@contextmanager
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_Session]:
auto_commit_calls.append(auto_commit)
yield _Session(results)
monkeypatch.setattr(statistics_service, "get_db_session", _fake_get_db_session)
return auto_commit_calls
def _build_dashboard_data(total_requests: int = 1) -> DashboardData:
return DashboardData(
summary=StatisticsSummary(total_requests=total_requests),
model_stats=[],
hourly_data=[],
daily_data=[],
recent_activity=[],
)
def _build_dashboard_data_with_time_series() -> DashboardData:
return DashboardData(
summary=StatisticsSummary(total_requests=1),
model_stats=[],
hourly_data=[
TimeSeriesData(timestamp="2026-05-06T10:00:00", requests=0, cost=0.0, tokens=0),
TimeSeriesData(timestamp="2026-05-06T11:00:00", requests=2, cost=0.5, tokens=50),
TimeSeriesData(timestamp="2026-05-06T12:00:00", requests=0, cost=0.0, tokens=0),
],
daily_data=[
TimeSeriesData(timestamp="2026-05-05T00:00:00", requests=0, cost=0.0, tokens=0),
TimeSeriesData(timestamp="2026-05-06T00:00:00", requests=3, cost=0.7, tokens=70),
],
recent_activity=[],
)
def test_shared_fetch_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None:
now = datetime(2026, 5, 6, 12, 0, 0)
online_record = SimpleNamespace(start_timestamp=now - timedelta(minutes=5), end_timestamp=now)
usage_record = SimpleNamespace(
timestamp=now,
request_type="chat.reply",
model_api_provider_name="provider",
model_assign_name="chat-main",
model_name="gpt-a",
prompt_tokens=10,
completion_tokens=5,
cost=0.01,
time_cost=1.2,
)
message_record = SimpleNamespace(timestamp=now, message_id="msg-1")
tool_record = SimpleNamespace(timestamp=now, tool_name="reply")
auto_commit_calls = _patch_session_results(
monkeypatch,
[
_Result(all_values=[online_record]),
_Result(all_values=[usage_record]),
_Result(all_values=[message_record]),
_Result(all_values=[tool_record]),
],
)
online_ranges = statistics_service.fetch_online_time_since(now - timedelta(hours=1))
usage_records = statistics_service.fetch_model_usage_since(now - timedelta(hours=1))
messages = statistics_service.fetch_messages_since(now - timedelta(hours=1))
tool_records = statistics_service.fetch_tool_records_since(now - timedelta(hours=1))
assert online_ranges == [(online_record.start_timestamp, online_record.end_timestamp)]
assert usage_records == [
{
"timestamp": now,
"request_type": "chat.reply",
"model_api_provider_name": "provider",
"model_assign_name": "chat-main",
"model_name": "gpt-a",
"prompt_tokens": 10,
"completion_tokens": 5,
"cost": 0.01,
"time_cost": 1.2,
}
]
assert messages == [message_record]
assert tool_records == [tool_record]
assert auto_commit_calls == [False, False, False, False]
def test_get_earliest_statistics_time_uses_min_valid_timestamp(monkeypatch: pytest.MonkeyPatch) -> None:
fallback_time = datetime(2026, 5, 6, 12, 0, 0)
earliest_time = datetime(2026, 5, 1, 8, 30, 0)
auto_commit_calls = _patch_session_result_group(
monkeypatch,
[
_Result(first_value=datetime(2026, 5, 3, 9, 0, 0)),
_Result(first_value=earliest_time),
_Result(first_value=None),
_Result(first_value=datetime(2026, 5, 2, 9, 0, 0)),
],
)
result = statistics_service.get_earliest_statistics_time(fallback_time)
assert result == earliest_time
assert auto_commit_calls == [False]
def test_get_earliest_statistics_time_falls_back_when_query_fails(monkeypatch: pytest.MonkeyPatch) -> None:
fallback_time = datetime(2026, 5, 6, 12, 0, 0)
@contextmanager
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_Session]:
del auto_commit
raise RuntimeError("database unavailable")
yield _Session([])
monkeypatch.setattr(statistics_service, "get_db_session", _fake_get_db_session)
assert statistics_service.get_earliest_statistics_time(fallback_time) == fallback_time
def test_dashboard_statistics_cache_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None:
memory_store = _MemoryStore()
now = datetime.now()
dashboard_data = _build_dashboard_data(total_requests=7)
monkeypatch.setattr(statistics_service, "local_storage", memory_store)
statistics_service.store_dashboard_statistics_cache({24: dashboard_data}, generated_at=now)
cached_data = statistics_service.get_cached_dashboard_statistics(24)
assert cached_data is not None
assert cached_data.summary.total_requests == 7
def test_dashboard_statistics_cache_stores_sparse_time_series(monkeypatch: pytest.MonkeyPatch) -> None:
memory_store = _MemoryStore()
generated_at = datetime(2026, 5, 6, 12, 0, 0)
dashboard_data = _build_dashboard_data_with_time_series()
monkeypatch.setattr(statistics_service, "local_storage", memory_store)
statistics_service.store_dashboard_statistics_cache({2: dashboard_data}, generated_at=generated_at)
raw_cache = memory_store[statistics_service.DASHBOARD_STATISTICS_CACHE_KEY]
raw_entry = raw_cache["entries"]["2"]
assert raw_entry["sparse"] is True
assert raw_entry["hourly_data"] == [
{"timestamp": "2026-05-06T11:00:00", "requests": 2, "cost": 0.5, "tokens": 50}
]
assert raw_entry["daily_data"] == [
{"timestamp": "2026-05-06T00:00:00", "requests": 3, "cost": 0.7, "tokens": 70}
]
cached_data = statistics_service.get_cached_dashboard_statistics(2, max_age_seconds=10**9)
assert cached_data is not None
assert [item.timestamp for item in cached_data.hourly_data] == [
"2026-05-06T10:00:00",
"2026-05-06T11:00:00",
"2026-05-06T12:00:00",
]
assert cached_data.hourly_data[0].requests == 0
assert cached_data.hourly_data[1].requests == 2
assert cached_data.hourly_data[2].requests == 0
@pytest.mark.asyncio
async def test_get_dashboard_statistics_prefers_cache(monkeypatch: pytest.MonkeyPatch) -> None:
memory_store = _MemoryStore()
dashboard_data = _build_dashboard_data(total_requests=9)
monkeypatch.setattr(statistics_service, "local_storage", memory_store)
statistics_service.store_dashboard_statistics_cache({24: dashboard_data}, generated_at=datetime.now())
async def _fail_compute_dashboard_statistics(hours: int = 24) -> DashboardData:
del hours
raise AssertionError("cache should be used")
monkeypatch.setattr(statistics_service, "compute_dashboard_statistics", _fail_compute_dashboard_statistics)
result = await statistics_service.get_dashboard_statistics(24)
assert result.summary.total_requests == 9
@pytest.mark.asyncio
async def test_get_dashboard_statistics_returns_empty_when_cache_missing(monkeypatch: pytest.MonkeyPatch) -> None:
memory_store = _MemoryStore()
monkeypatch.setattr(statistics_service, "local_storage", memory_store)
async def _fail_compute_dashboard_statistics(hours: int = 24) -> DashboardData:
del hours
raise AssertionError("dashboard API should not compute fallback data")
monkeypatch.setattr(statistics_service, "compute_dashboard_statistics", _fail_compute_dashboard_statistics)
result = await statistics_service.get_dashboard_statistics(24)
assert result.summary.total_requests == 0
assert result.model_stats == []
@pytest.mark.asyncio
async def test_get_summary_statistics_aggregates_database_and_message_counts(monkeypatch: pytest.MonkeyPatch) -> None:
start_time = datetime(2026, 5, 6, 10, 0, 0)
end_time = datetime(2026, 5, 6, 12, 0, 0)
online_records = [
SimpleNamespace(
start_timestamp=start_time - timedelta(minutes=30),
end_timestamp=start_time + timedelta(minutes=30),
),
SimpleNamespace(
start_timestamp=start_time + timedelta(hours=1),
end_timestamp=end_time + timedelta(minutes=30),
),
]
auto_commit_calls = _patch_session_results(
monkeypatch,
[
_Result(first_value=(3, 1.5, 900, 2.5)),
_Result(all_values=online_records),
],
)
def _fake_count_messages(**kwargs: Any) -> int:
return 5 if kwargs.get("has_reply_to") is None else 2
monkeypatch.setattr(statistics_service, "count_messages", _fake_count_messages)
summary = await statistics_service.get_summary_statistics(start_time, end_time)
assert summary.total_requests == 3
assert summary.total_cost == 1.5
assert summary.total_tokens == 900
assert summary.avg_response_time == 2.5
assert summary.online_time == 5400
assert summary.total_messages == 5
assert summary.total_replies == 2
assert summary.cost_per_hour == pytest.approx(1.0)
assert summary.tokens_per_hour == pytest.approx(600.0)
assert auto_commit_calls == [False, False]
@pytest.mark.asyncio
async def test_get_model_statistics_groups_by_display_model_name(monkeypatch: pytest.MonkeyPatch) -> None:
now = datetime(2026, 5, 6, 12, 0, 0)
records = [
SimpleNamespace(
model_assign_name="chat-main",
model_name="gpt-a",
cost=0.4,
total_tokens=100,
time_cost=2.0,
),
SimpleNamespace(
model_assign_name="chat-main",
model_name="gpt-a",
cost=0.6,
total_tokens=200,
time_cost=4.0,
),
SimpleNamespace(
model_assign_name=None,
model_name="gpt-b",
cost=0.2,
total_tokens=50,
time_cost=0.0,
),
]
_patch_session_results(monkeypatch, [_Result(all_values=records)])
stats = await statistics_service.get_model_statistics(now - timedelta(hours=24))
assert [item.model_name for item in stats] == ["chat-main", "gpt-b"]
assert stats[0].request_count == 2
assert stats[0].total_cost == pytest.approx(1.0)
assert stats[0].total_tokens == 300
assert stats[0].avg_response_time == pytest.approx(3.0)
assert stats[1].avg_response_time == 0.0

View File

@@ -0,0 +1,13 @@
from src.webui.routers import system
def test_is_newer_version_detects_patch_update() -> None:
assert system._is_newer_version("1.0.7", "1.0.6") is True
def test_is_newer_version_ignores_same_version_with_shorter_parts() -> None:
assert system._is_newer_version("1.0.0", "1.0") is False
def test_is_newer_version_handles_unknown_current_version() -> None:
assert system._is_newer_version("1.0.7", "unknown") is False