@@ -19,7 +19,7 @@ RUN uv sync --frozen --no-dev --no-install-project
|
||||
# Copy project source
|
||||
COPY . .
|
||||
|
||||
RUN git clone --depth 1 --branch plugin https://github.com/Mai-with-u/MaiBot-Napcat-Adapter.git plugin-templates/MaiBot-Napcat-Adapter
|
||||
RUN git clone --depth 1 --branch main https://github.com/Mai-with-u/MaiBot-Napcat-Adapter.git plugin-templates/MaiBot-Napcat-Adapter
|
||||
RUN chmod +x docker-entrypoint.sh
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
@@ -18,7 +18,7 @@ const PLUGIN_DETAILS_FILE = 'plugin_details.json'
|
||||
* 插件列表 API 响应类型(只包含我们需要的字段)
|
||||
*/
|
||||
interface PluginApiResponse {
|
||||
id: string
|
||||
id?: string
|
||||
manifest: {
|
||||
manifest_version: number
|
||||
id?: string
|
||||
@@ -110,7 +110,7 @@ export async function fetchPluginList(): Promise<ApiResponse<PluginInfo[]>> {
|
||||
console.warn('跳过无效插件数据:', item)
|
||||
return false
|
||||
}
|
||||
const pluginId = item.manifest.id || item.id
|
||||
const pluginId = item.manifest.id
|
||||
if (!pluginId) {
|
||||
console.warn('跳过缺少 ID 的插件:', item)
|
||||
return false
|
||||
@@ -122,7 +122,7 @@ export async function fetchPluginList(): Promise<ApiResponse<PluginInfo[]>> {
|
||||
return true
|
||||
})
|
||||
.map((item) => ({
|
||||
id: item.manifest.id || item.id,
|
||||
id: item.manifest.id!,
|
||||
manifest: normalizePluginManifest(item.manifest),
|
||||
downloads: 0,
|
||||
rating: 0,
|
||||
|
||||
@@ -25,7 +25,7 @@ services:
|
||||
- ./data/MaiMBot/emoji:/data/emoji # 持久化表情包
|
||||
- ./data/MaiMBot/plugins:/MaiMBot/plugins # 插件目录
|
||||
- ./data/MaiMBot/logs:/MaiMBot/logs # 日志目录
|
||||
- ./depends-data:/MaiMBot/depends-data:ro # 运行时资源文件
|
||||
- ./depends-data:/MaiMBot/depends-data # 运行时资源文件
|
||||
# - site-packages:/usr/local/lib/python3.13/site-packages # 持久化Python包,需要时启用
|
||||
restart: always
|
||||
networks:
|
||||
|
||||
@@ -1 +1 @@
|
||||
请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题、直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本
|
||||
请用中文详细描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题、直观感受,输出为一段平文本,最多100字,请注意不要分点,就输出一段文本
|
||||
|
||||
@@ -4,9 +4,9 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "MaiBot"
|
||||
version = "1.0.0"
|
||||
version = "1.0.0-pre.16"
|
||||
description = "MaiCore 是一个基于大语言模型的可交互智能体"
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"Babel>=2.17.0",
|
||||
"aiohttp>=3.12.14",
|
||||
|
||||
@@ -6,6 +6,7 @@ from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
|
||||
from src.common.data_models.message_component_data_model import AtComponent, MessageSequence, ReplyComponent, TextComponent
|
||||
from src.config.config import global_config
|
||||
from src.maisaka.builtin_tool.context import BuiltinToolRuntimeContext
|
||||
from src.maisaka.runtime import MaisakaHeartFlowChatting
|
||||
|
||||
|
||||
def _build_sent_message() -> SessionMessage:
|
||||
@@ -63,7 +64,9 @@ def test_post_process_reply_message_sequences_converts_at_marker_before_bracket_
|
||||
)
|
||||
)
|
||||
)
|
||||
runtime = SimpleNamespace(_source_messages_by_id={"12160142": target_message})
|
||||
runtime = SimpleNamespace(
|
||||
find_source_message_by_id=lambda message_id: target_message if message_id == "12160142" else None
|
||||
)
|
||||
engine = SimpleNamespace(_get_runtime_manager=lambda: None)
|
||||
tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime)
|
||||
|
||||
@@ -85,7 +88,7 @@ def test_post_process_reply_message_sequences_ignores_at_marker_when_disabled(mo
|
||||
"src.maisaka.builtin_tool.context.process_llm_response",
|
||||
lambda text: [text.strip()] if text.strip() else [],
|
||||
)
|
||||
runtime = SimpleNamespace(_source_messages_by_id={})
|
||||
runtime = SimpleNamespace(find_source_message_by_id=lambda message_id: None)
|
||||
engine = SimpleNamespace(_get_runtime_manager=lambda: None)
|
||||
tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime)
|
||||
|
||||
@@ -96,3 +99,15 @@ def test_post_process_reply_message_sequences_ignores_at_marker_when_disabled(mo
|
||||
assert len(components) == 1
|
||||
assert isinstance(components[0], TextComponent)
|
||||
assert components[0].text == "at[12160142] 就这个群"
|
||||
|
||||
|
||||
def test_runtime_finds_source_message_from_history() -> None:
|
||||
target_message = _build_sent_message()
|
||||
runtime = object.__new__(MaisakaHeartFlowChatting)
|
||||
runtime._chat_history = [
|
||||
SimpleNamespace(message_id="other-message-id", original_message=SimpleNamespace()),
|
||||
SimpleNamespace(message_id="real-message-id", original_message=target_message),
|
||||
]
|
||||
|
||||
assert runtime.find_source_message_by_id("real-message-id") is target_message
|
||||
assert runtime.find_source_message_by_id("missing-message-id") is None
|
||||
|
||||
105
pytests/test_maisaka_memory_retention.py
Normal file
105
pytests/test_maisaka_memory_retention.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import time
|
||||
|
||||
from src.chat.heart_flow import heartflow_manager as heartflow_manager_module
|
||||
from src.chat.heart_flow.heartflow_manager import HEARTFLOW_ACTIVE_RETENTION_SECONDS, HeartflowManager
|
||||
from src.learners.expression_learner import ExpressionLearner
|
||||
from src.maisaka.runtime import MAX_RETAINED_MESSAGE_CACHE_SIZE, MaisakaHeartFlowChatting
|
||||
|
||||
|
||||
def _build_runtime_with_messages(message_count: int) -> MaisakaHeartFlowChatting:
|
||||
runtime = object.__new__(MaisakaHeartFlowChatting)
|
||||
runtime.log_prefix = "[test]"
|
||||
runtime.message_cache = [SimpleNamespace(message_id=f"msg-{index}") for index in range(message_count)]
|
||||
runtime._last_processed_index = message_count
|
||||
runtime._expression_learner = ExpressionLearner("session-1")
|
||||
runtime._expression_learner.mark_all_processed(runtime.message_cache)
|
||||
return runtime
|
||||
|
||||
|
||||
def test_prune_processed_message_cache_keeps_bounded_recent_window() -> None:
|
||||
runtime = _build_runtime_with_messages(MAX_RETAINED_MESSAGE_CACHE_SIZE + 25)
|
||||
|
||||
runtime._prune_processed_message_cache()
|
||||
|
||||
assert len(runtime.message_cache) == MAX_RETAINED_MESSAGE_CACHE_SIZE
|
||||
assert runtime.message_cache[0].message_id == "msg-25"
|
||||
assert runtime._last_processed_index == MAX_RETAINED_MESSAGE_CACHE_SIZE
|
||||
assert runtime._expression_learner.last_processed_index == MAX_RETAINED_MESSAGE_CACHE_SIZE
|
||||
|
||||
|
||||
def test_prune_processed_message_cache_keeps_unlearned_messages() -> None:
|
||||
runtime = _build_runtime_with_messages(MAX_RETAINED_MESSAGE_CACHE_SIZE + 25)
|
||||
runtime._expression_learner.discard_processed_prefix(MAX_RETAINED_MESSAGE_CACHE_SIZE + 5)
|
||||
|
||||
runtime._prune_processed_message_cache()
|
||||
|
||||
assert len(runtime.message_cache) == MAX_RETAINED_MESSAGE_CACHE_SIZE + 5
|
||||
assert runtime.message_cache[0].message_id == "msg-20"
|
||||
assert runtime._expression_learner.last_processed_index == 0
|
||||
|
||||
|
||||
def test_collect_pending_messages_uses_single_pending_received_time() -> None:
|
||||
runtime = _build_runtime_with_messages(2)
|
||||
runtime._last_processed_index = 0
|
||||
runtime._oldest_pending_message_received_at = 123.0
|
||||
runtime._last_message_received_at = 456.0
|
||||
runtime._reply_latency_measurement_started_at = None
|
||||
|
||||
pending_messages = runtime._collect_pending_messages()
|
||||
|
||||
assert [message.message_id for message in pending_messages] == ["msg-0", "msg-1"]
|
||||
assert runtime._reply_latency_measurement_started_at == 123.0
|
||||
assert runtime._oldest_pending_message_received_at is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heartflow_manager_evicts_lru_chat_over_limit(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
manager = HeartflowManager()
|
||||
stopped_session_ids: list[str] = []
|
||||
old_active_at = time.time() - HEARTFLOW_ACTIVE_RETENTION_SECONDS - 1
|
||||
|
||||
class FakeChat:
|
||||
def __init__(self, session_id: str) -> None:
|
||||
self.session_id = session_id
|
||||
|
||||
async def stop(self) -> None:
|
||||
stopped_session_ids.append(self.session_id)
|
||||
|
||||
monkeypatch.setattr(heartflow_manager_module, "HEARTFLOW_MAX_ACTIVE_CHATS", 2)
|
||||
manager.heartflow_chat_list["session-1"] = FakeChat("session-1")
|
||||
manager.heartflow_chat_list["session-2"] = FakeChat("session-2")
|
||||
manager.heartflow_chat_list["session-3"] = FakeChat("session-3")
|
||||
manager._chat_last_active_at["session-1"] = old_active_at
|
||||
manager._chat_last_active_at["session-2"] = old_active_at
|
||||
manager._chat_last_active_at["session-3"] = time.time()
|
||||
|
||||
await manager._evict_over_limit_chats(protected_session_id="session-3")
|
||||
|
||||
assert stopped_session_ids == ["session-1"]
|
||||
assert list(manager.heartflow_chat_list) == ["session-2", "session-3"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heartflow_manager_keeps_recent_chats_even_over_limit(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
manager = HeartflowManager()
|
||||
stopped_session_ids: list[str] = []
|
||||
|
||||
class FakeChat:
|
||||
def __init__(self, session_id: str) -> None:
|
||||
self.session_id = session_id
|
||||
|
||||
async def stop(self) -> None:
|
||||
stopped_session_ids.append(self.session_id)
|
||||
|
||||
monkeypatch.setattr(heartflow_manager_module, "HEARTFLOW_MAX_ACTIVE_CHATS", 2)
|
||||
for session_id in ("session-1", "session-2", "session-3"):
|
||||
manager.heartflow_chat_list[session_id] = FakeChat(session_id)
|
||||
manager._chat_last_active_at[session_id] = time.time()
|
||||
|
||||
await manager._evict_over_limit_chats(protected_session_id="session-3")
|
||||
|
||||
assert stopped_session_ids == []
|
||||
assert list(manager.heartflow_chat_list) == ["session-1", "session-2", "session-3"]
|
||||
@@ -199,7 +199,7 @@ async def test_reply_tool_puts_monitor_detail_into_metadata(monkeypatch: pytest.
|
||||
),
|
||||
)
|
||||
runtime = SimpleNamespace(
|
||||
_source_messages_by_id={"msg-1": target_message},
|
||||
find_source_message_by_id=lambda message_id: target_message if message_id == "msg-1" else None,
|
||||
log_prefix="[test]",
|
||||
chat_stream=SimpleNamespace(platform=reply_tool_module.CLI_PLATFORM_NAME),
|
||||
session_id="session-1",
|
||||
|
||||
@@ -1169,6 +1169,8 @@ class TestVersionComparator:
|
||||
|
||||
assert VersionComparator.normalize_version("0.8.0-snapshot.1") == "0.8.0"
|
||||
assert VersionComparator.normalize_version("1.2") == "1.2.0"
|
||||
assert VersionComparator.normalize_version("1.0.0rc16") == "1.0.0"
|
||||
assert VersionComparator.normalize_version("1.0.0-pre.16") == "1.0.0"
|
||||
assert VersionComparator.normalize_version("") == "0.0.0"
|
||||
|
||||
def test_compare(self):
|
||||
@@ -2890,22 +2892,105 @@ class TestIntegration:
|
||||
monkeypatch.setattr(real_database_service, "db_get", fake_db_get)
|
||||
monkeypatch.setattr(real_db_models, "DemoTable", DummyModel, raising=False)
|
||||
|
||||
result = await integration_module.PluginRuntimeManager._cap_database_get(
|
||||
manager = object.__new__(integration_module.PluginRuntimeManager)
|
||||
result = await manager._cap_database_get(
|
||||
"plugin_a",
|
||||
"database.get",
|
||||
{
|
||||
"table": "DemoTable",
|
||||
"model_name": "DemoTable",
|
||||
"filters": {"status": "active"},
|
||||
"limit": 5,
|
||||
},
|
||||
)
|
||||
|
||||
assert result == {"success": True, "result": [{"id": 1}]}
|
||||
assert result == [{"id": 1}]
|
||||
assert captured["model_class"] is DummyModel
|
||||
assert captured["filters"] == {"status": "active"}
|
||||
assert captured["limit"] == 5
|
||||
assert captured["single_result"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cap_database_get_response_is_not_double_wrapped(self, monkeypatch):
|
||||
from src.plugin_runtime import integration as integration_module
|
||||
import src.common.database.database_model as real_db_models
|
||||
from src.plugin_runtime.host.capability_service import CapabilityService
|
||||
from src.plugin_runtime.protocol.envelope import CapabilityRequestPayload, Envelope, MessageType
|
||||
from src.services import database_service as real_database_service
|
||||
|
||||
class AllowAllAuthorization:
|
||||
def check_capability(self, plugin_id, capability):
|
||||
return True, ""
|
||||
|
||||
class DummyModel:
|
||||
pass
|
||||
|
||||
async def fake_db_get(model_class, filters=None, limit=None, order_by=None, single_result=False):
|
||||
return {"id": 1, "full_path": "E:\\test.png"}
|
||||
|
||||
monkeypatch.setattr(real_database_service, "db_get", fake_db_get)
|
||||
monkeypatch.setattr(real_db_models, "DemoTable", DummyModel, raising=False)
|
||||
|
||||
manager = object.__new__(integration_module.PluginRuntimeManager)
|
||||
service = CapabilityService(AllowAllAuthorization())
|
||||
service.register_capability("database.get", manager._cap_database_get)
|
||||
|
||||
request = Envelope(
|
||||
request_id=1,
|
||||
message_type=MessageType.REQUEST,
|
||||
method="cap.call",
|
||||
plugin_id="plugin_a",
|
||||
payload=CapabilityRequestPayload(
|
||||
capability="database.get",
|
||||
args={"model_name": "DemoTable", "single_result": True},
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
response = await service.handle_capability_request(request)
|
||||
|
||||
assert response.payload == {
|
||||
"success": True,
|
||||
"result": {"id": 1, "full_path": "E:\\test.png"},
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cap_database_success_handlers_return_raw_results(self, monkeypatch):
|
||||
from src.plugin_runtime import integration as integration_module
|
||||
import src.common.database.database_model as real_db_models
|
||||
from src.services import database_service as real_database_service
|
||||
|
||||
class DummyModel:
|
||||
pass
|
||||
|
||||
async def fake_db_get(**kwargs):
|
||||
return [{"id": 1}]
|
||||
|
||||
async def fake_db_save(**kwargs):
|
||||
return {"id": 2}
|
||||
|
||||
async def fake_db_delete(**kwargs):
|
||||
return 3
|
||||
|
||||
async def fake_db_count(**kwargs):
|
||||
return 4
|
||||
|
||||
monkeypatch.setattr(real_database_service, "db_get", fake_db_get)
|
||||
monkeypatch.setattr(real_database_service, "db_save", fake_db_save)
|
||||
monkeypatch.setattr(real_database_service, "db_delete", fake_db_delete)
|
||||
monkeypatch.setattr(real_database_service, "db_count", fake_db_count)
|
||||
monkeypatch.setattr(real_db_models, "DemoTable", DummyModel, raising=False)
|
||||
|
||||
manager = object.__new__(integration_module.PluginRuntimeManager)
|
||||
base_args = {"model_name": "DemoTable"}
|
||||
|
||||
assert await manager._cap_database_query("plugin_a", "database.query", base_args) == [{"id": 1}]
|
||||
assert await manager._cap_database_save(
|
||||
"plugin_a", "database.save", {**base_args, "data": {"name": "demo"}}
|
||||
) == {"id": 2}
|
||||
assert await manager._cap_database_delete(
|
||||
"plugin_a", "database.delete", {**base_args, "filters": {"id": 2}}
|
||||
) == 3
|
||||
assert await manager._cap_database_count("plugin_a", "database.count", base_args) == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_component_enable_rejects_ambiguous_short_name(self, monkeypatch):
|
||||
from src.plugin_runtime import integration as integration_module
|
||||
|
||||
187
pytests/webui/test_model_routes.py
Normal file
187
pytests/webui/test_model_routes.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""模型路由测试
|
||||
|
||||
验证 Gemini 提供商连接测试会使用查询参数传递 API Key,
|
||||
并且不会回退到 OpenAI 兼容接口使用的 Bearer 认证方式。
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def load_model_routes(monkeypatch: pytest.MonkeyPatch):
|
||||
"""在导入路由前 stub 配置与认证依赖模块,避免测试时触发真实初始化。"""
|
||||
config_module = ModuleType("src.config.config")
|
||||
config_module.__dict__["CONFIG_DIR"] = "."
|
||||
monkeypatch.setitem(sys.modules, "src.config.config", config_module)
|
||||
|
||||
dependencies_module = ModuleType("src.webui.dependencies")
|
||||
|
||||
async def require_auth():
|
||||
return "test-token"
|
||||
|
||||
dependencies_module.__dict__["require_auth"] = require_auth
|
||||
monkeypatch.setitem(sys.modules, "src.webui.dependencies", dependencies_module)
|
||||
|
||||
sys.modules.pop("src.webui.routers.model", None)
|
||||
return importlib.import_module("src.webui.routers.model")
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
"""简化版 HTTP 响应对象。"""
|
||||
|
||||
def __init__(self, status_code: int):
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
def build_async_client_factory(
|
||||
responses: list[FakeResponse],
|
||||
calls: list[dict[str, Any]],
|
||||
):
|
||||
"""构造一个可记录请求参数的 AsyncClient 替身。"""
|
||||
|
||||
response_iter = iter(responses)
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
async def __aenter__(self) -> "FakeAsyncClient":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
return False
|
||||
|
||||
async def get(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> FakeResponse:
|
||||
calls.append(
|
||||
{
|
||||
"url": url,
|
||||
"headers": headers or {},
|
||||
"params": params or {},
|
||||
}
|
||||
)
|
||||
return next(response_iter)
|
||||
|
||||
return FakeAsyncClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_test_provider_connection_uses_query_api_key_for_gemini(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Gemini 连接测试应通过查询参数传递 API Key。"""
|
||||
model_routes = load_model_routes(monkeypatch)
|
||||
calls: list[dict[str, Any]] = []
|
||||
fake_client_class = build_async_client_factory(
|
||||
responses=[FakeResponse(200), FakeResponse(200)],
|
||||
calls=calls,
|
||||
)
|
||||
monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class)
|
||||
|
||||
result = await model_routes.test_provider_connection(
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta",
|
||||
api_key="valid-gemini-key",
|
||||
client_type="gemini",
|
||||
)
|
||||
|
||||
assert result["network_ok"] is True
|
||||
assert result["api_key_valid"] is True
|
||||
assert len(calls) == 2
|
||||
|
||||
network_call = calls[0]
|
||||
validation_call = calls[1]
|
||||
|
||||
assert network_call["url"] == "https://generativelanguage.googleapis.com/v1beta"
|
||||
assert network_call["headers"] == {}
|
||||
assert network_call["params"] == {}
|
||||
|
||||
assert validation_call["url"] == "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
assert validation_call["params"] == {"key": "valid-gemini-key"}
|
||||
assert validation_call["headers"] == {"Content-Type": "application/json"}
|
||||
assert "Authorization" not in validation_call["headers"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_test_provider_connection_uses_bearer_auth_for_openai_compatible(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""非 Gemini 提供商连接测试应继续使用 Bearer 认证。"""
|
||||
model_routes = load_model_routes(monkeypatch)
|
||||
calls: list[dict[str, Any]] = []
|
||||
fake_client_class = build_async_client_factory(
|
||||
responses=[FakeResponse(200), FakeResponse(200)],
|
||||
calls=calls,
|
||||
)
|
||||
monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class)
|
||||
|
||||
result = await model_routes.test_provider_connection(
|
||||
base_url="https://example.com/v1",
|
||||
api_key="valid-openai-key",
|
||||
client_type="openai",
|
||||
)
|
||||
|
||||
assert result["network_ok"] is True
|
||||
assert result["api_key_valid"] is True
|
||||
assert len(calls) == 2
|
||||
|
||||
validation_call = calls[1]
|
||||
|
||||
assert validation_call["url"] == "https://example.com/v1/models"
|
||||
assert validation_call["params"] == {}
|
||||
assert validation_call["headers"]["Content-Type"] == "application/json"
|
||||
assert validation_call["headers"]["Authorization"] == "Bearer valid-openai-key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_test_provider_connection_by_name_forwards_provider_client_type(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
) -> None:
|
||||
"""按提供商名称测试连接时,应透传配置中的 client_type。"""
|
||||
model_routes = load_model_routes(monkeypatch)
|
||||
config_path = tmp_path / "model_config.toml"
|
||||
config_path.write_text(
|
||||
"""
|
||||
[[api_providers]]
|
||||
name = "Gemini"
|
||||
base_url = "https://generativelanguage.googleapis.com/v1beta"
|
||||
api_key = "valid-gemini-key"
|
||||
client_type = "gemini"
|
||||
""".strip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(model_routes, "CONFIG_DIR", str(tmp_path))
|
||||
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
async def fake_test_provider_connection(**kwargs: Any) -> dict[str, Any]:
|
||||
captured_kwargs.update(kwargs)
|
||||
return {
|
||||
"network_ok": True,
|
||||
"api_key_valid": True,
|
||||
"latency_ms": 12.34,
|
||||
"error": None,
|
||||
"http_status": 200,
|
||||
}
|
||||
|
||||
monkeypatch.setattr(model_routes, "test_provider_connection", fake_test_provider_connection)
|
||||
|
||||
result = await model_routes.test_provider_connection_by_name(provider_name="Gemini")
|
||||
|
||||
assert result["network_ok"] is True
|
||||
assert result["api_key_valid"] is True
|
||||
assert captured_kwargs == {
|
||||
"base_url": "https://generativelanguage.googleapis.com/v1beta",
|
||||
"api_key": "valid-gemini-key",
|
||||
"client_type": "gemini",
|
||||
}
|
||||
@@ -61,3 +61,76 @@ def test_resolve_installed_plugin_path_accepts_manifest_id_case_mismatch(client:
|
||||
|
||||
assert plugin_path is not None
|
||||
assert plugin_path.name == "demo_plugin"
|
||||
|
||||
|
||||
def test_install_plugin_preserves_manifest_declared_id(client: TestClient, monkeypatch):
|
||||
class FakeGitMirrorService:
|
||||
async def clone_repository(self, **kwargs):
|
||||
target_path = kwargs["target_path"]
|
||||
target_path.mkdir(parents=True, exist_ok=True)
|
||||
(target_path / "_manifest.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"manifest_version": 2,
|
||||
"id": "author.declared",
|
||||
"name": "Declared Plugin",
|
||||
"version": "1.0.0",
|
||||
"author": {"name": "author"},
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return {"success": True}
|
||||
|
||||
monkeypatch.setattr(management_module, "get_git_mirror_service", lambda: FakeGitMirrorService())
|
||||
|
||||
response = client.post(
|
||||
"/api/webui/plugins/install",
|
||||
json={
|
||||
"plugin_id": "market.plugin",
|
||||
"repository_url": "https://github.com/author/declared",
|
||||
"branch": "main",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
plugin_path = support_module.resolve_installed_plugin_path("author.declared")
|
||||
assert plugin_path is not None
|
||||
manifest = json.loads((plugin_path / "_manifest.json").read_text(encoding="utf-8"))
|
||||
assert manifest["id"] == "author.declared"
|
||||
|
||||
|
||||
def test_install_plugin_backfills_missing_manifest_id(client: TestClient, monkeypatch):
|
||||
class FakeGitMirrorService:
|
||||
async def clone_repository(self, **kwargs):
|
||||
target_path = kwargs["target_path"]
|
||||
target_path.mkdir(parents=True, exist_ok=True)
|
||||
(target_path / "_manifest.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"manifest_version": 2,
|
||||
"name": "Legacy Plugin",
|
||||
"version": "1.0.0",
|
||||
"author": {"name": "author"},
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return {"success": True}
|
||||
|
||||
monkeypatch.setattr(management_module, "get_git_mirror_service", lambda: FakeGitMirrorService())
|
||||
|
||||
response = client.post(
|
||||
"/api/webui/plugins/install",
|
||||
json={
|
||||
"plugin_id": "market.legacy",
|
||||
"repository_url": "https://github.com/author/legacy",
|
||||
"branch": "main",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
plugin_path = support_module.resolve_installed_plugin_path("market.legacy")
|
||||
assert plugin_path is not None
|
||||
manifest = json.loads((plugin_path / "_manifest.json").read_text(encoding="utf-8"))
|
||||
assert manifest["id"] == "market.legacy"
|
||||
|
||||
@@ -16,6 +16,7 @@ from src.common.logger import get_logger
|
||||
from ..storage import VectorStore, GraphStore, MetadataStore
|
||||
from ..embedding import EmbeddingAPIAdapter
|
||||
from ..utils.matcher import AhoCorasick
|
||||
from ..utils.metadata import coerce_metadata_dict
|
||||
from ..utils.time_parser import format_timestamp
|
||||
from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService
|
||||
from .pagerank import PersonalizedPageRank, PageRankConfig
|
||||
@@ -482,7 +483,7 @@ class DualPathRetriever:
|
||||
score=float(item.score),
|
||||
result_type=item.result_type,
|
||||
source=item.source,
|
||||
metadata=dict(item.metadata or {}),
|
||||
metadata=coerce_metadata_dict(item.metadata),
|
||||
)
|
||||
|
||||
def _extract_graph_seed_entities(self, query: str, limit: int = 2) -> List[str]:
|
||||
@@ -762,7 +763,7 @@ class DualPathRetriever:
|
||||
existing = self._clone_retrieval_result(item)
|
||||
merged[item.hash_value] = existing
|
||||
else:
|
||||
for key, value in dict(item.metadata or {}).items():
|
||||
for key, value in coerce_metadata_dict(item.metadata).items():
|
||||
if key not in existing.metadata or existing.metadata.get(key) in (None, "", []):
|
||||
existing.metadata[key] = value
|
||||
source_sets.setdefault(item.hash_value, set()).add(str(item.source or "").strip() or "relation_search")
|
||||
|
||||
@@ -25,6 +25,7 @@ from ..utils.episode_retrieval_service import EpisodeRetrievalService
|
||||
from ..utils.episode_segmentation_service import EpisodeSegmentationService
|
||||
from ..utils.episode_service import EpisodeService
|
||||
from ..utils.hash import compute_hash, normalize_text
|
||||
from ..utils.metadata import coerce_metadata_dict
|
||||
from ..utils.person_profile_service import PersonProfileService
|
||||
from ..utils.relation_write_service import RelationWriteService
|
||||
from ..utils.retrieval_tuning_manager import RetrievalTuningManager
|
||||
@@ -871,7 +872,7 @@ class SDKMemoryKernel:
|
||||
"detail": "chat_filtered",
|
||||
}
|
||||
|
||||
summary_meta = dict(metadata or {})
|
||||
summary_meta = coerce_metadata_dict(metadata)
|
||||
summary_meta.setdefault("kind", "chat_summary")
|
||||
if not str(text or "").strip() or bool(summary_meta.get("generate_from_chat", False)):
|
||||
result = await self.summarize_chat_stream(
|
||||
@@ -961,7 +962,7 @@ class SDKMemoryKernel:
|
||||
participant_tokens = self._tokens(participants)
|
||||
entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens)
|
||||
source = self._build_source(source_type, chat_id, person_tokens)
|
||||
paragraph_meta = dict(metadata or {})
|
||||
paragraph_meta = coerce_metadata_dict(metadata)
|
||||
paragraph_meta.update(
|
||||
{
|
||||
"external_id": external_token,
|
||||
|
||||
11
src/A_memorix/core/utils/metadata.py
Normal file
11
src/A_memorix/core/utils/metadata.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def coerce_metadata_dict(value: Any) -> Dict[str, Any]:
|
||||
"""返回字典,如果输入值不是字典则返回空字典。"""
|
||||
if isinstance(value, Mapping):
|
||||
return dict(value)
|
||||
return {}
|
||||
@@ -27,6 +27,7 @@ from ..retrieval import (
|
||||
GraphRelationRecallConfig,
|
||||
)
|
||||
from ..storage import MetadataStore, GraphStore, VectorStore
|
||||
from .metadata import coerce_metadata_dict
|
||||
|
||||
logger = get_logger("A_Memorix.PersonProfileService")
|
||||
|
||||
@@ -334,7 +335,7 @@ class PersonProfileService:
|
||||
if not pid:
|
||||
return False
|
||||
|
||||
metadata = self._metadata_dict(relation.get("metadata"))
|
||||
metadata = coerce_metadata_dict(relation.get("metadata"))
|
||||
if str(metadata.get("person_id", "") or "").strip() == pid:
|
||||
return True
|
||||
if pid in self._list_tokens(metadata.get("person_ids")):
|
||||
@@ -350,7 +351,7 @@ class PersonProfileService:
|
||||
payload = {
|
||||
"hash": source_paragraph,
|
||||
"source": str(paragraph.get("source", "") or ""),
|
||||
"metadata": self._metadata_dict(paragraph.get("metadata")),
|
||||
"metadata": coerce_metadata_dict(paragraph.get("metadata")),
|
||||
}
|
||||
return self._is_evidence_bound_to_person(payload, person_id=pid)
|
||||
|
||||
@@ -385,15 +386,11 @@ class PersonProfileService:
|
||||
"score": 1.1,
|
||||
"content": content[:220],
|
||||
"source": str(row.get("source", "") or source),
|
||||
"metadata": dict(row.get("metadata", {}) or {}),
|
||||
"metadata": coerce_metadata_dict(row.get("metadata")),
|
||||
}
|
||||
)
|
||||
return self._filter_stale_paragraph_evidence(evidence)
|
||||
|
||||
@staticmethod
|
||||
def _metadata_dict(value: Any) -> Dict[str, Any]:
|
||||
return dict(value) if isinstance(value, dict) else {}
|
||||
|
||||
@staticmethod
|
||||
def _list_tokens(value: Any) -> List[str]:
|
||||
if value is None:
|
||||
@@ -414,7 +411,7 @@ class PersonProfileService:
|
||||
if not pid:
|
||||
return False
|
||||
|
||||
metadata = self._metadata_dict(item.get("metadata"))
|
||||
metadata = coerce_metadata_dict(item.get("metadata"))
|
||||
source = str(item.get("source", "") or metadata.get("source", "") or "").strip()
|
||||
if source == f"person_fact:{pid}":
|
||||
return True
|
||||
@@ -440,15 +437,15 @@ class PersonProfileService:
|
||||
paragraph_hash: str,
|
||||
metadata: Dict[str, Any],
|
||||
) -> Tuple[Dict[str, Any], str]:
|
||||
merged = self._metadata_dict(metadata)
|
||||
merged = coerce_metadata_dict(metadata)
|
||||
source = str(merged.get("source", "") or "").strip()
|
||||
try:
|
||||
paragraph = self.metadata_store.get_paragraph(paragraph_hash)
|
||||
except Exception:
|
||||
paragraph = None
|
||||
if isinstance(paragraph, dict):
|
||||
paragraph_metadata = paragraph.get("metadata", {}) or {}
|
||||
if isinstance(paragraph_metadata, dict):
|
||||
paragraph_metadata = coerce_metadata_dict(paragraph.get("metadata"))
|
||||
if paragraph_metadata:
|
||||
merged = {**paragraph_metadata, **merged}
|
||||
source = source or str(paragraph.get("source", "") or "").strip()
|
||||
source_type = str(merged.get("source_type", "") or "").strip() or self._source_type_from_source(source)
|
||||
@@ -538,7 +535,7 @@ class PersonProfileService:
|
||||
"score": 0.0,
|
||||
"content": str(para.get("content", ""))[:180],
|
||||
"source": str(para.get("source", "") or ""),
|
||||
"metadata": self._metadata_dict(para.get("metadata")),
|
||||
"metadata": coerce_metadata_dict(para.get("metadata")),
|
||||
}
|
||||
)
|
||||
if not self._is_evidence_bound_to_person(fallback[-1], person_id=person_id):
|
||||
@@ -562,18 +559,18 @@ class PersonProfileService:
|
||||
logger.warning(f"向量证据召回失败: alias={alias}, err={e}")
|
||||
continue
|
||||
for item in results:
|
||||
h = str(getattr(item, "hash_value", "") or "")
|
||||
h = str(item.hash_value or "")
|
||||
if not h or h in seen_hash:
|
||||
continue
|
||||
metadata, source = self._enrich_paragraph_evidence_metadata(
|
||||
h,
|
||||
self._metadata_dict(getattr(item, "metadata", {})),
|
||||
coerce_metadata_dict(item.metadata),
|
||||
)
|
||||
payload = {
|
||||
"hash": h,
|
||||
"type": str(getattr(item, "result_type", "")),
|
||||
"score": float(getattr(item, "score", 0.0) or 0.0),
|
||||
"content": str(getattr(item, "content", "") or "")[:220],
|
||||
"type": str(item.result_type),
|
||||
"score": float(item.score or 0.0),
|
||||
"content": str(item.content or "")[:220],
|
||||
"source": source,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
@@ -14,7 +14,8 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..retrieval import TemporalQueryOptions
|
||||
from ..retrieval import RetrievalResult, TemporalQueryOptions
|
||||
from .metadata import coerce_metadata_dict
|
||||
from .search_postprocess import (
|
||||
apply_safe_content_dedup,
|
||||
maybe_apply_smart_path_fallback,
|
||||
@@ -286,8 +287,11 @@ class SearchExecutionService:
|
||||
)
|
||||
|
||||
async def _executor() -> Dict[str, Any]:
|
||||
original_ppr = bool(getattr(retriever.config, "enable_ppr", True))
|
||||
setattr(retriever.config, "enable_ppr", bool(request.enable_ppr))
|
||||
retriever_config = getattr(retriever, "config", None)
|
||||
has_runtime_ppr_switch = retriever_config is not None and hasattr(retriever_config, "enable_ppr")
|
||||
original_ppr = bool(retriever_config.enable_ppr) if has_runtime_ppr_switch else None
|
||||
if has_runtime_ppr_switch:
|
||||
retriever_config.enable_ppr = bool(request.enable_ppr)
|
||||
started_at = time.time()
|
||||
try:
|
||||
retrieved = await retriever.retrieve(
|
||||
@@ -380,7 +384,8 @@ class SearchExecutionService:
|
||||
elapsed_ms = (time.time() - started_at) * 1000.0
|
||||
return {"results": retrieved, "elapsed_ms": elapsed_ms}
|
||||
finally:
|
||||
setattr(retriever.config, "enable_ppr", original_ppr)
|
||||
if has_runtime_ppr_switch:
|
||||
retriever_config.enable_ppr = bool(original_ppr)
|
||||
|
||||
dedup_hit = False
|
||||
try:
|
||||
@@ -421,18 +426,18 @@ class SearchExecutionService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]:
|
||||
def to_serializable_results(results: List[RetrievalResult]) -> List[Dict[str, Any]]:
|
||||
serialized: List[Dict[str, Any]] = []
|
||||
for item in results:
|
||||
metadata = dict(getattr(item, "metadata", {}) or {})
|
||||
metadata = coerce_metadata_dict(item.metadata)
|
||||
if "time_meta" not in metadata:
|
||||
metadata["time_meta"] = {}
|
||||
serialized.append(
|
||||
{
|
||||
"hash": getattr(item, "hash_value", ""),
|
||||
"type": getattr(item, "result_type", ""),
|
||||
"score": float(getattr(item, "score", 0.0)),
|
||||
"content": getattr(item, "content", ""),
|
||||
"hash": item.hash_value,
|
||||
"type": item.result_type,
|
||||
"score": float(item.score),
|
||||
"content": item.content,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,31 +1,39 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Dict
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.maisaka.runtime import MaisakaHeartFlowChatting
|
||||
|
||||
logger = get_logger("heartflow")
|
||||
|
||||
HEARTFLOW_ACTIVE_RETENTION_SECONDS = 24 * 60 * 60
|
||||
HEARTFLOW_MAX_ACTIVE_CHATS = 100
|
||||
|
||||
|
||||
class HeartflowManager:
|
||||
"""管理 session 级别的 Maisaka 心流实例。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.heartflow_chat_list: Dict[str, MaisakaHeartFlowChatting] = {}
|
||||
self.heartflow_chat_list: OrderedDict[str, MaisakaHeartFlowChatting] = OrderedDict()
|
||||
self._chat_create_locks: Dict[str, asyncio.Lock] = {}
|
||||
self._chat_last_active_at: Dict[str, float] = {}
|
||||
|
||||
async def get_or_create_heartflow_chat(self, session_id: str) -> MaisakaHeartFlowChatting:
|
||||
"""获取或创建指定会话对应的 Maisaka runtime。"""
|
||||
try:
|
||||
if chat := self.heartflow_chat_list.get(session_id):
|
||||
self._touch_chat(session_id)
|
||||
return chat
|
||||
|
||||
create_lock = self._chat_create_locks.setdefault(session_id, asyncio.Lock())
|
||||
async with create_lock:
|
||||
if chat := self.heartflow_chat_list.get(session_id):
|
||||
self._touch_chat(session_id)
|
||||
return chat
|
||||
|
||||
chat_session = chat_manager.get_session_by_session_id(session_id)
|
||||
@@ -35,16 +43,59 @@ class HeartflowManager:
|
||||
new_chat = MaisakaHeartFlowChatting(session_id=session_id)
|
||||
await new_chat.start()
|
||||
self.heartflow_chat_list[session_id] = new_chat
|
||||
self._touch_chat(session_id)
|
||||
await self._evict_over_limit_chats(protected_session_id=session_id)
|
||||
return new_chat
|
||||
except Exception as exc:
|
||||
logger.error(f"创建心流聊天 {session_id} 失败: {exc}", exc_info=True)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
def _touch_chat(self, session_id: str) -> None:
|
||||
"""记录会话最近活跃时间,并维护心流实例的 LRU 顺序。"""
|
||||
self._chat_last_active_at[session_id] = time.time()
|
||||
self.heartflow_chat_list.move_to_end(session_id)
|
||||
|
||||
async def _evict_over_limit_chats(self, *, protected_session_id: str) -> None:
|
||||
"""当实例数量超过上限时,仅淘汰 24 小时内无消息的旧会话。"""
|
||||
while len(self.heartflow_chat_list) > HEARTFLOW_MAX_ACTIVE_CHATS:
|
||||
session_id = self._find_evictable_session_id(protected_session_id=protected_session_id)
|
||||
if session_id is None:
|
||||
return
|
||||
await self._evict_chat(session_id, reason="cache_limit")
|
||||
|
||||
def _find_evictable_session_id(self, *, protected_session_id: str) -> str | None:
|
||||
"""按 LRU 查找超过活跃保护窗口的可淘汰会话。"""
|
||||
expire_before = time.time() - HEARTFLOW_ACTIVE_RETENTION_SECONDS
|
||||
for session_id in self.heartflow_chat_list:
|
||||
if session_id == protected_session_id:
|
||||
continue
|
||||
last_active_at = self._chat_last_active_at.get(session_id, 0.0)
|
||||
if last_active_at <= expire_before:
|
||||
return session_id
|
||||
return None
|
||||
|
||||
async def _evict_chat(self, session_id: str, *, reason: str) -> None:
|
||||
"""停止并移除指定会话的心流实例。"""
|
||||
chat = self.heartflow_chat_list.pop(session_id, None)
|
||||
self._chat_last_active_at.pop(session_id, None)
|
||||
lock = self._chat_create_locks.get(session_id)
|
||||
if lock is not None and not lock.locked():
|
||||
self._chat_create_locks.pop(session_id, None)
|
||||
if chat is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await chat.stop()
|
||||
logger.info(f"已淘汰心流聊天 {session_id}: reason={reason}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"淘汰心流聊天 {session_id} 失败: {exc}", exc_info=True)
|
||||
|
||||
def adjust_talk_frequency(self, session_id: str, frequency: float) -> None:
|
||||
"""调整指定聊天流的说话频率。"""
|
||||
chat = self.heartflow_chat_list.get(session_id)
|
||||
if chat:
|
||||
self._touch_chat(session_id)
|
||||
chat.adjust_talk_frequency(frequency)
|
||||
logger.info(f"已调整聊天 {session_id} 的说话频率为 {frequency}")
|
||||
else:
|
||||
|
||||
@@ -86,7 +86,6 @@ class BaseMaisakaReplyGenerator:
|
||||
request_type=request_type,
|
||||
session_id=getattr(chat_stream, "session_id", "") if chat_stream is not None else "",
|
||||
)
|
||||
self._personality_prompt = self._build_personality_prompt()
|
||||
|
||||
def _build_personality_prompt(self) -> str:
|
||||
"""构建 replyer 使用的人设提示。"""
|
||||
@@ -272,7 +271,7 @@ class BaseMaisakaReplyGenerator:
|
||||
bot_name=global_config.bot.nickname,
|
||||
group_chat_attention_block=self._build_group_chat_attention_block(session_id),
|
||||
replyer_at_block=self._build_replyer_at_block(),
|
||||
identity=self._personality_prompt,
|
||||
identity=self._build_personality_prompt(),
|
||||
reply_style=self._select_reply_style(),
|
||||
)
|
||||
except Exception:
|
||||
|
||||
@@ -56,7 +56,7 @@ BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
||||
LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
|
||||
A_MEMORIX_LEGACY_CONFIG_PATH: Path = (CONFIG_DIR / "a_memorix.toml").resolve().absolute()
|
||||
MMC_VERSION: str = "1.0.0-pre.16"
|
||||
MMC_VERSION: str = "1.0.0-pre.17"
|
||||
CONFIG_VERSION: str = "8.10.15"
|
||||
MODEL_CONFIG_VERSION: str = "1.16.1"
|
||||
|
||||
@@ -234,6 +234,7 @@ class ConfigManager:
|
||||
self._hot_reload_min_interval_s: float = 1.0
|
||||
self._hot_reload_timeout_s: float = 20.0
|
||||
self._last_hot_reload_monotonic: float = 0.0
|
||||
self.reload_revision: int = 0
|
||||
|
||||
def initialize(self):
|
||||
logger.info(t("config.current_version", version=MMC_VERSION))
|
||||
@@ -424,9 +425,7 @@ class ConfigManager:
|
||||
|
||||
self.global_config = global_config_new
|
||||
self.model_config = model_config_new
|
||||
global global_config, model_config
|
||||
global_config = global_config_new
|
||||
model_config = model_config_new
|
||||
self.reload_revision += 1
|
||||
logger.info(t("config.hot_reload_completed"))
|
||||
|
||||
for callback in list(self._reload_callbacks):
|
||||
@@ -657,8 +656,30 @@ def write_config_to_file(
|
||||
tomlkit.dump(full_config_data, f)
|
||||
|
||||
|
||||
class _ConfigProxy:
|
||||
"""稳定配置代理,确保热重载后旧导入也能读取最新配置。"""
|
||||
|
||||
def __init__(self, getter: Callable[[], ConfigBase]) -> None:
|
||||
self._getter = getter
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._getter(), name)
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self._getter()[key]
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name == "_getter":
|
||||
object.__setattr__(self, name, value)
|
||||
return
|
||||
setattr(self._getter(), name, value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(self._getter())
|
||||
|
||||
|
||||
# generate_new_config_file(Config, BOT_CONFIG_PATH, CONFIG_VERSION)
|
||||
config_manager = ConfigManager()
|
||||
config_manager.initialize()
|
||||
global_config = config_manager.get_global_config()
|
||||
model_config = config_manager.get_model_config()
|
||||
global_config: Config = cast(Config, _ConfigProxy(config_manager.get_global_config))
|
||||
model_config: ModelConfig = cast(ModelConfig, _ConfigProxy(config_manager.get_model_config))
|
||||
|
||||
@@ -160,6 +160,23 @@ class ExpressionLearner:
|
||||
self._last_processed_index = 0
|
||||
self.min_messages_for_extraction = 10
|
||||
|
||||
@property
|
||||
def last_processed_index(self) -> int:
|
||||
"""返回表达学习已经消费到的消息缓存下标。"""
|
||||
return self._last_processed_index
|
||||
|
||||
def mark_all_processed(self, message_cache: List["SessionMessage"]) -> None:
|
||||
"""在跳过表达学习时,将现有消息标记为已处理,避免阻塞缓存裁剪。"""
|
||||
self._last_processed_index = len(message_cache)
|
||||
|
||||
def mark_processed_until(self, processed_end_index: int) -> None:
|
||||
"""将指定缓存下标之前的消息标记为已处理。"""
|
||||
self._last_processed_index = max(self._last_processed_index, processed_end_index)
|
||||
|
||||
def discard_processed_prefix(self, removed_count: int) -> None:
|
||||
"""同步 runtime 对消息缓存前缀的裁剪。"""
|
||||
self._last_processed_index = max(0, self._last_processed_index - removed_count)
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
@@ -274,7 +291,8 @@ class ExpressionLearner:
|
||||
jargon_miner: Optional["JargonMiner"] = None,
|
||||
) -> bool:
|
||||
"""学习表达方式"""
|
||||
pending_messages = message_cache[self._last_processed_index :]
|
||||
processed_end_index = len(message_cache)
|
||||
pending_messages = message_cache[self._last_processed_index : processed_end_index]
|
||||
if not pending_messages:
|
||||
logger.debug("没有待处理消息")
|
||||
return False
|
||||
@@ -303,6 +321,7 @@ class ExpressionLearner:
|
||||
response = generation_result.response
|
||||
except Exception as e:
|
||||
logger.error(f"学习表达方式失败: {e}")
|
||||
self._last_processed_index = processed_end_index
|
||||
return False
|
||||
|
||||
expressions: List[Tuple[str, str, str]]
|
||||
@@ -336,7 +355,7 @@ class ExpressionLearner:
|
||||
)
|
||||
if after_extract_result.aborted:
|
||||
logger.info(f"{self.session_id} 表达方式选择 Hook 中止")
|
||||
self._last_processed_index = len(message_cache)
|
||||
self._last_processed_index = processed_end_index
|
||||
return False
|
||||
|
||||
after_extract_kwargs = after_extract_result.kwargs
|
||||
@@ -352,7 +371,7 @@ class ExpressionLearner:
|
||||
|
||||
if not expressions:
|
||||
logger.info("没有可学习的表达方式")
|
||||
self._last_processed_index = len(message_cache)
|
||||
self._last_processed_index = processed_end_index
|
||||
return False
|
||||
|
||||
logger.info(f"可学习的表达方式: {expressions}")
|
||||
@@ -361,7 +380,7 @@ class ExpressionLearner:
|
||||
learnt_expressions = self._filter_expressions(expressions, pending_messages)
|
||||
if not learnt_expressions:
|
||||
logger.info("没有可学习的表达方式通过过滤")
|
||||
self._last_processed_index = len(message_cache)
|
||||
self._last_processed_index = processed_end_index
|
||||
return False
|
||||
|
||||
learnt_expressions_str = "\n".join(f"{situation}->{style}" for situation, style in learnt_expressions)
|
||||
@@ -386,7 +405,7 @@ class ExpressionLearner:
|
||||
continue
|
||||
await self._upsert_expression_to_db(situation, style)
|
||||
|
||||
self._last_processed_index = len(message_cache)
|
||||
self._last_processed_index = processed_end_index
|
||||
return True
|
||||
|
||||
def _check_cached_jargons_in_messages(
|
||||
|
||||
@@ -149,7 +149,7 @@ class BuiltinToolRuntimeContext:
|
||||
def _build_at_component_for_message_id(self, message_id: str) -> Optional[AtComponent]:
|
||||
"""根据消息编号构造 at 组件。"""
|
||||
|
||||
target_message = self.runtime._source_messages_by_id.get(message_id)
|
||||
target_message = self.runtime.find_source_message_by_id(message_id)
|
||||
if target_message is None:
|
||||
return None
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ async def handle_tool(
|
||||
"reply 工具需要提供有效的 `msg_id` 参数。",
|
||||
)
|
||||
|
||||
target_message = tool_ctx.runtime._source_messages_by_id.get(target_message_id)
|
||||
target_message = tool_ctx.runtime.find_source_message_by_id(target_message_id)
|
||||
if target_message is None:
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
@@ -263,6 +263,7 @@ async def handle_tool(
|
||||
|
||||
target_user_info = target_message.message_info.user_info
|
||||
target_user_name = target_user_info.user_cardname or target_user_info.user_nickname or target_user_info.user_id
|
||||
bot_name = config_module.global_config.bot.nickname.strip() or "MaiSaka"
|
||||
|
||||
if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME:
|
||||
tool_ctx.append_guided_reply_to_chat_history(combined_reply_text)
|
||||
@@ -291,7 +292,7 @@ async def handle_tool(
|
||||
)
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
f'已生成并发送回复"{combined_reply_text}"\n发送对象:{target_user_name}',
|
||||
f'"{bot_name}"已生成并向"{target_user_name}"发送了回复"{combined_reply_text}"',
|
||||
structured_content={
|
||||
"msg_id": target_message_id,
|
||||
"set_quote": set_quote,
|
||||
|
||||
@@ -52,7 +52,7 @@ async def handle_tool(
|
||||
"查看复杂消息工具需要提供有效的 `msg_id` 参数。",
|
||||
)
|
||||
|
||||
target_message = tool_ctx.runtime._source_messages_by_id.get(target_message_id)
|
||||
target_message = tool_ctx.runtime.find_source_message_by_id(target_message_id)
|
||||
if target_message is None:
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
|
||||
@@ -10,7 +10,7 @@ from rich.console import RenderableType
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.i18n import get_locale
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import get_prompt_cache_revision, load_prompt
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.common.utils.utils_config import ChatConfigUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolAvailabilityContext, ToolRegistry
|
||||
@@ -218,21 +218,15 @@ class MaisakaChatLoopService:
|
||||
self._extra_tools: List[ToolOption] = []
|
||||
self._interrupt_flag: asyncio.Event | None = None
|
||||
self._tool_registry: ToolRegistry | None = None
|
||||
self._prompts_loaded = chat_system_prompt is not None
|
||||
self._prompt_cache_revision = get_prompt_cache_revision()
|
||||
self._custom_chat_system_prompt = chat_system_prompt
|
||||
self._prompt_load_lock = asyncio.Lock()
|
||||
self._personality_prompt = self._build_personality_prompt()
|
||||
if chat_system_prompt is None:
|
||||
self._chat_system_prompt = f"{self._personality_prompt}\n\nYou are a helpful AI assistant."
|
||||
else:
|
||||
self._chat_system_prompt = chat_system_prompt
|
||||
self._llm_chat_clients: dict[str, LLMServiceClient] = {}
|
||||
|
||||
@property
|
||||
def personality_prompt(self) -> str:
|
||||
"""返回当前人格提示词。"""
|
||||
|
||||
return self._personality_prompt
|
||||
return self._build_personality_prompt()
|
||||
|
||||
@staticmethod
|
||||
def _resolve_llm_request_type(request_kind: str) -> str:
|
||||
@@ -349,13 +343,15 @@ class MaisakaChatLoopService:
|
||||
tools_section: 额外注入到提示词中的工具说明片段。
|
||||
"""
|
||||
async with self._prompt_load_lock:
|
||||
try:
|
||||
self._chat_system_prompt = load_prompt("maisaka_chat", **self.build_prompt_template_context(tools_section))
|
||||
except Exception:
|
||||
self._chat_system_prompt = f"{self._personality_prompt}\n\nYou are a helpful AI assistant."
|
||||
self._build_chat_system_prompt(tools_section)
|
||||
|
||||
self._prompts_loaded = True
|
||||
self._prompt_cache_revision = get_prompt_cache_revision()
|
||||
def _build_chat_system_prompt(self, tools_section: str = "") -> str:
|
||||
"""基于当前配置实时构造主聊天系统提示词。"""
|
||||
|
||||
try:
|
||||
return load_prompt("maisaka_chat", **self.build_prompt_template_context(tools_section))
|
||||
except Exception:
|
||||
return f"{self.personality_prompt}\n\nYou are a helpful AI assistant."
|
||||
|
||||
def build_prompt_template_context(self, tools_section: str = "") -> dict[str, str]:
|
||||
"""构造 Maisaka prompt 模板的公共渲染参数。"""
|
||||
@@ -364,7 +360,7 @@ class MaisakaChatLoopService:
|
||||
"bot_name": global_config.bot.nickname,
|
||||
"file_tools_section": tools_section,
|
||||
"group_chat_attention_block": self._build_group_chat_attention_block(),
|
||||
"identity": self._personality_prompt,
|
||||
"identity": self.personality_prompt,
|
||||
"timing_gate_wait_rule": self._build_timing_gate_wait_rule(),
|
||||
"time_block": self._build_time_block(),
|
||||
}
|
||||
@@ -471,7 +467,13 @@ class MaisakaChatLoopService:
|
||||
|
||||
messages: List[Message] = []
|
||||
system_msg = MessageBuilder().set_role(RoleType.System)
|
||||
system_msg.add_text_content(system_prompt if system_prompt is not None else self._chat_system_prompt)
|
||||
if system_prompt is not None:
|
||||
resolved_system_prompt = system_prompt
|
||||
elif self._custom_chat_system_prompt is not None:
|
||||
resolved_system_prompt = self._custom_chat_system_prompt
|
||||
else:
|
||||
resolved_system_prompt = self._build_chat_system_prompt()
|
||||
system_msg.add_text_content(resolved_system_prompt)
|
||||
messages.append(system_msg.build())
|
||||
|
||||
for msg in selected_history:
|
||||
@@ -521,8 +523,6 @@ class MaisakaChatLoopService:
|
||||
ChatResponse: 本轮规划器返回结果。
|
||||
"""
|
||||
|
||||
if not self._prompts_loaded or self._prompt_cache_revision != get_prompt_cache_revision():
|
||||
await self.ensure_chat_prompt_loaded()
|
||||
enable_visual_message = self._resolve_enable_visual_message(request_kind)
|
||||
selected_history, selection_reason = self.select_llm_context_messages(
|
||||
chat_history,
|
||||
|
||||
@@ -29,6 +29,7 @@ from src.learners.jargon_miner import JargonMiner
|
||||
from src.llm_models.payload_content.resp_format import RespFormat
|
||||
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
|
||||
from src.mcp_module import MCPManager
|
||||
from src.mcp_module.config import build_mcp_server_runtime_configs
|
||||
from src.mcp_module.host_llm_bridge import MCPHostLLMBridge
|
||||
from src.mcp_module.provider import MCPToolProvider
|
||||
from src.plugin_runtime.tool_provider import PluginToolProvider
|
||||
@@ -56,6 +57,7 @@ from .tool_provider import MaisakaBuiltinToolProvider
|
||||
logger = get_logger("maisaka_runtime")
|
||||
|
||||
MAX_INTERNAL_ROUNDS = 10
|
||||
MAX_RETAINED_MESSAGE_CACHE_SIZE = 200
|
||||
|
||||
|
||||
class MaisakaHeartFlowChatting:
|
||||
@@ -90,7 +92,6 @@ class MaisakaHeartFlowChatting:
|
||||
self._mcp_manager: Optional[MCPManager] = None
|
||||
self._mcp_host_bridge: Optional[MCPHostLLMBridge] = None
|
||||
self._current_cycle_detail: Optional[CycleDetail] = None
|
||||
self._source_messages_by_id: dict[str, SessionMessage] = {}
|
||||
self._running = False
|
||||
self._cycle_counter = 0
|
||||
self._internal_loop_task: Optional[asyncio.Task] = None
|
||||
@@ -98,19 +99,13 @@ class MaisakaHeartFlowChatting:
|
||||
self._deferred_message_turn_task: Optional[asyncio.Task[None]] = None
|
||||
self._message_debounce_seconds = 1.0
|
||||
self._message_debounce_required = False
|
||||
self._message_received_at_by_id: dict[str, float] = {}
|
||||
self._oldest_pending_message_received_at: Optional[float] = None
|
||||
self._last_message_received_at = 0.0
|
||||
self._talk_frequency_adjust = 1.0
|
||||
self._reply_latency_measurement_started_at: Optional[float] = None
|
||||
self._recent_reply_latencies: deque[tuple[float, float]] = deque()
|
||||
self._wait_timeout_task: Optional[asyncio.Task[None]] = None
|
||||
self._max_internal_rounds = MAX_INTERNAL_ROUNDS
|
||||
configured_context_size = (
|
||||
global_config.chat.max_context_size
|
||||
if self.chat_stream.is_group_session
|
||||
else global_config.chat.max_private_context_size
|
||||
)
|
||||
self._max_context_size = max(1, int(configured_context_size))
|
||||
self._agent_state: Literal["running", "wait", "stop"] = self._STATE_STOP
|
||||
self._pending_wait_tool_call_id: Optional[str] = None
|
||||
self._force_next_timing_continue = False
|
||||
@@ -152,6 +147,17 @@ class MaisakaHeartFlowChatting:
|
||||
self._register_tool_providers()
|
||||
self._emit_monitor_session_start()
|
||||
|
||||
@property
|
||||
def _max_context_size(self) -> int:
|
||||
"""返回当前会话实时生效的上下文窗口大小。"""
|
||||
|
||||
configured_context_size = (
|
||||
global_config.chat.max_context_size
|
||||
if self.chat_stream.is_group_session
|
||||
else global_config.chat.max_private_context_size
|
||||
)
|
||||
return max(1, int(configured_context_size))
|
||||
|
||||
def _emit_monitor_session_start(self) -> None:
|
||||
"""向 WebUI 监控面板同步当前会话的展示标识。"""
|
||||
|
||||
@@ -312,10 +318,11 @@ class MaisakaHeartFlowChatting:
|
||||
self._ensure_background_tasks_running()
|
||||
received_at = time.time()
|
||||
self._last_message_received_at = received_at
|
||||
if self._oldest_pending_message_received_at is None:
|
||||
self._oldest_pending_message_received_at = received_at
|
||||
self._update_message_trigger_state(message)
|
||||
self.message_cache.append(message)
|
||||
self._message_received_at_by_id[message.message_id] = received_at
|
||||
self._source_messages_by_id[message.message_id] = message
|
||||
self._prune_processed_message_cache()
|
||||
if self._is_reply_effect_tracking_enabled():
|
||||
asyncio.create_task(self._reply_effect_tracker.observe_user_message(message))
|
||||
if self._agent_state == self._STATE_RUNNING:
|
||||
@@ -487,6 +494,45 @@ class MaisakaHeartFlowChatting:
|
||||
f"最近10分钟样本数={len(self._recent_reply_latencies)}"
|
||||
)
|
||||
|
||||
def find_source_message_by_id(self, message_id: str) -> Optional[SessionMessage]:
|
||||
"""从 Maisaka 历史中查找指定消息编号对应的原始消息。"""
|
||||
normalized_message_id = str(message_id or "").strip()
|
||||
if not normalized_message_id:
|
||||
return None
|
||||
|
||||
for history_message in reversed(self._chat_history):
|
||||
if str(getattr(history_message, "message_id", "") or "").strip() != normalized_message_id:
|
||||
continue
|
||||
|
||||
original_message = getattr(history_message, "original_message", None)
|
||||
if original_message is None:
|
||||
continue
|
||||
return original_message
|
||||
|
||||
return None
|
||||
|
||||
def _prune_processed_message_cache(self) -> None:
|
||||
"""裁剪 runtime 与表达学习器都已经消费过的旧消息。"""
|
||||
excess_count = len(self.message_cache) - MAX_RETAINED_MESSAGE_CACHE_SIZE
|
||||
if excess_count <= 0:
|
||||
return
|
||||
|
||||
removable_count = min(
|
||||
excess_count,
|
||||
self._last_processed_index,
|
||||
self._expression_learner.last_processed_index,
|
||||
)
|
||||
if removable_count <= 0:
|
||||
return
|
||||
|
||||
del self.message_cache[:removable_count]
|
||||
self._last_processed_index = max(0, self._last_processed_index - removable_count)
|
||||
self._expression_learner.discard_processed_prefix(removable_count)
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 已清理 Maisaka 旧消息缓存: "
|
||||
f"清理数量={removable_count} 保留数量={len(self.message_cache)}"
|
||||
)
|
||||
|
||||
def _should_trigger_message_turn_by_idle_compensation(
|
||||
self,
|
||||
*,
|
||||
@@ -637,6 +683,7 @@ class MaisakaHeartFlowChatting:
|
||||
return
|
||||
|
||||
if self._internal_loop_task is None or self._internal_loop_task.done():
|
||||
is_restart = self._internal_loop_task is not None
|
||||
if self._internal_loop_task is not None and not self._internal_loop_task.cancelled():
|
||||
try:
|
||||
exc = self._internal_loop_task.exception()
|
||||
@@ -645,7 +692,10 @@ class MaisakaHeartFlowChatting:
|
||||
if exc is not None:
|
||||
logger.error(f"{self.log_prefix} 内部循环任务异常退出: {exc}")
|
||||
self._internal_loop_task = asyncio.create_task(self._reasoning_engine.run_loop())
|
||||
logger.warning(f"{self.log_prefix} 已重新拉起 Maisaka 内部循环任务")
|
||||
if is_restart:
|
||||
logger.warning(f"{self.log_prefix} 已重新拉起 Maisaka 内部循环任务")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} 已启动 Maisaka 内部循环任务")
|
||||
|
||||
def _register_tool_providers(self) -> None:
|
||||
"""注册 Maisaka 运行时默认启用的工具 Provider。"""
|
||||
@@ -1001,12 +1051,10 @@ class MaisakaHeartFlowChatting:
|
||||
# f"收集 {len(unique_messages)} 条新消息"
|
||||
# )
|
||||
if unique_messages and self._reply_latency_measurement_started_at is None:
|
||||
self._reply_latency_measurement_started_at = min(
|
||||
self._message_received_at_by_id.get(message.message_id, self._last_message_received_at)
|
||||
for message in unique_messages
|
||||
self._reply_latency_measurement_started_at = (
|
||||
self._oldest_pending_message_received_at or self._last_message_received_at
|
||||
)
|
||||
for message in unique_messages:
|
||||
self._message_received_at_by_id.pop(message.message_id, None)
|
||||
self._oldest_pending_message_received_at = None
|
||||
return unique_messages
|
||||
|
||||
async def _wait_for_message_quiet_period(self) -> None:
|
||||
@@ -1075,10 +1123,19 @@ class MaisakaHeartFlowChatting:
|
||||
|
||||
async def _trigger_batch_learning(self, messages: list[SessionMessage]) -> None:
|
||||
"""按同一批消息触发表达方式和黑话学习。"""
|
||||
processed_end_index = len(self.message_cache)
|
||||
if not self._enable_expression_learning:
|
||||
self._expression_learner.mark_all_processed(self.message_cache)
|
||||
self._prune_processed_message_cache()
|
||||
return
|
||||
|
||||
try:
|
||||
await self._trigger_expression_learning(messages)
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} 表达学习任务异常退出: {exc}")
|
||||
self._expression_learner.mark_processed_until(processed_end_index)
|
||||
finally:
|
||||
self._prune_processed_message_cache()
|
||||
|
||||
def _should_trigger_learning(
|
||||
self,
|
||||
@@ -1145,6 +1202,10 @@ class MaisakaHeartFlowChatting:
|
||||
|
||||
async def _init_mcp(self) -> None:
|
||||
"""初始化 MCP 工具并注册到统一工具层。"""
|
||||
if not build_mcp_server_runtime_configs(global_config.mcp):
|
||||
logger.debug(f"{self.log_prefix} 未配置可用的 MCP 服务,跳过 Maisaka MCP 初始化")
|
||||
return
|
||||
|
||||
self._mcp_host_bridge = MCPHostLLMBridge(
|
||||
sampling_task_name=global_config.mcp.client.sampling.task_name,
|
||||
)
|
||||
@@ -1153,7 +1214,7 @@ class MaisakaHeartFlowChatting:
|
||||
host_callbacks=self._mcp_host_bridge.build_callbacks(),
|
||||
)
|
||||
if self._mcp_manager is None:
|
||||
logger.info(f"{self.log_prefix} Maisaka MCP 管理器不可用")
|
||||
logger.warning(f"{self.log_prefix} Maisaka MCP 管理器初始化失败,MCP 工具不会注册")
|
||||
return
|
||||
|
||||
mcp_tool_specs = self._mcp_manager.get_tool_specs()
|
||||
|
||||
@@ -110,7 +110,7 @@ class RuntimeDataCapabilityMixin:
|
||||
result = await database_service.db_count(model_class=model_class, filters=args.get("filters"))
|
||||
else:
|
||||
return {"success": False, "error": f"不支持的 query_type: {query_type}"}
|
||||
return {"success": True, "result": result}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.database.query] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
@@ -136,7 +136,7 @@ class RuntimeDataCapabilityMixin:
|
||||
key_field=args.get("key_field"),
|
||||
key_value=args.get("key_value"),
|
||||
)
|
||||
return {"success": True, "result": result}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.database.save] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
@@ -162,7 +162,7 @@ class RuntimeDataCapabilityMixin:
|
||||
order_by=args.get("order_by"),
|
||||
single_result=args.get("single_result", False),
|
||||
)
|
||||
return {"success": True, "result": result}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.database.get] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
@@ -185,7 +185,7 @@ class RuntimeDataCapabilityMixin:
|
||||
return {"success": False, "error": f"未找到数据模型: {model_name}"}
|
||||
|
||||
result = await database_service.db_delete(model_class=model_class, filters=filters)
|
||||
return {"success": True, "result": result}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.database.delete] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
@@ -205,7 +205,7 @@ class RuntimeDataCapabilityMixin:
|
||||
return {"success": False, "error": f"未找到数据模型: {model_name}"}
|
||||
|
||||
result = await database_service.db_count(model_class=model_class, filters=args.get("filters"))
|
||||
return {"success": True, "count": result}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.database.count] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@@ -47,6 +47,15 @@ class VersionComparator:
|
||||
return "0.0.0"
|
||||
|
||||
normalized = re.sub(r"-snapshot\.\d+", "", str(version).strip())
|
||||
try:
|
||||
parsed_version = Version(normalized)
|
||||
parts = [str(part) for part in parsed_version.release[:3]]
|
||||
while len(parts) < 3:
|
||||
parts.append("0")
|
||||
return ".".join(parts)
|
||||
except InvalidVersion:
|
||||
pass
|
||||
|
||||
if not re.match(r"^\d+(\.\d+){0,2}$", normalized):
|
||||
return "0.0.0"
|
||||
|
||||
@@ -132,6 +141,17 @@ class VersionComparator:
|
||||
"""
|
||||
return bool(_SEMVER_PATTERN.fullmatch(str(version or "").strip()))
|
||||
|
||||
@staticmethod
|
||||
def is_valid_project_version(version: str) -> bool:
|
||||
"""判断主程序或 SDK 的项目版本号是否可被解析。
|
||||
|
||||
``pyproject.toml`` 遵循 Python 包版本规范,允许 ``1.0.0rc16`` 或
|
||||
``1.0.0-pre.16`` 这类预发布版本;兼容性比较时只取其 release 部分。
|
||||
"""
|
||||
|
||||
normalized = VersionComparator.normalize_version(version)
|
||||
return normalized != "0.0.0" or str(version or "").strip() == "0.0.0"
|
||||
|
||||
|
||||
class _StrictManifestModel(BaseModel):
|
||||
"""Manifest 解析使用的严格基类模型。"""
|
||||
@@ -1030,7 +1050,7 @@ class ManifestValidator:
|
||||
return ""
|
||||
|
||||
raw_version = str(project_data.get("version", "") or "").strip()
|
||||
if VersionComparator.is_valid_semver(raw_version):
|
||||
if VersionComparator.is_valid_project_version(raw_version):
|
||||
return raw_version
|
||||
return ""
|
||||
|
||||
@@ -1047,7 +1067,7 @@ class ManifestValidator:
|
||||
"""
|
||||
try:
|
||||
raw_version = importlib_metadata.version("maibot-plugin-sdk")
|
||||
if VersionComparator.is_valid_semver(raw_version):
|
||||
if VersionComparator.is_valid_project_version(raw_version):
|
||||
return raw_version
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
pass
|
||||
@@ -1064,7 +1084,7 @@ class ManifestValidator:
|
||||
return ""
|
||||
|
||||
raw_version = str(project_data.get("version", "") or "").strip()
|
||||
if VersionComparator.is_valid_semver(raw_version):
|
||||
if VersionComparator.is_valid_project_version(raw_version):
|
||||
return raw_version
|
||||
return ""
|
||||
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import date, datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
from sqlalchemy import delete, func
|
||||
@@ -19,14 +20,28 @@ if TYPE_CHECKING:
|
||||
logger = get_logger("database_service")
|
||||
|
||||
|
||||
def _to_msgpack_value(value: Any) -> Any:
|
||||
if isinstance(value, datetime):
|
||||
return value.isoformat()
|
||||
if isinstance(value, date):
|
||||
return value.isoformat()
|
||||
if isinstance(value, Enum):
|
||||
return value.value
|
||||
if isinstance(value, dict):
|
||||
return {_to_msgpack_value(key): _to_msgpack_value(item) for key, item in value.items()}
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
return [_to_msgpack_value(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def _to_dict(record: Any) -> dict[str, Any]:
|
||||
if record is None:
|
||||
return {}
|
||||
if isinstance(record, dict):
|
||||
return record
|
||||
return _to_msgpack_value(record)
|
||||
if hasattr(record, "model_dump"):
|
||||
return record.model_dump()
|
||||
return dict(record.__dict__) if hasattr(record, "__dict__") else {}
|
||||
return _to_msgpack_value(record.model_dump())
|
||||
return _to_msgpack_value(dict(record.__dict__)) if hasattr(record, "__dict__") else {}
|
||||
|
||||
|
||||
def _get_model_field(model_class: type[SQLModel], field_name: str) -> Any:
|
||||
|
||||
@@ -296,6 +296,7 @@ async def get_models_by_url(
|
||||
async def test_provider_connection(
|
||||
base_url: str = Query(..., description="提供商的基础 URL"),
|
||||
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
|
||||
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
|
||||
):
|
||||
"""
|
||||
测试提供商连接状态
|
||||
@@ -359,13 +360,19 @@ async def test_provider_connection(
|
||||
try:
|
||||
start_time = time.time()
|
||||
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
headers = {"Content-Type": "application/json"}
|
||||
params = {}
|
||||
|
||||
if client_type == "gemini":
|
||||
# Gemini 使用 URL 参数传递 API Key
|
||||
params["key"] = api_key
|
||||
else:
|
||||
# OpenAI 兼容格式使用 Authorization 头
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
# 尝试获取模型列表
|
||||
models_url = f"{base_url}/models"
|
||||
response = await client.get(models_url, headers=headers)
|
||||
response = await client.get(models_url, headers=headers, params=params)
|
||||
|
||||
if response.status_code == 200:
|
||||
result["api_key_valid"] = True
|
||||
@@ -408,9 +415,14 @@ async def test_provider_connection_by_name(
|
||||
|
||||
base_url = provider.get("base_url", "")
|
||||
api_key = provider.get("api_key", "")
|
||||
client_type = provider.get("client_type", "openai")
|
||||
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
||||
|
||||
# 调用测试接口
|
||||
return await test_provider_connection(base_url=base_url, api_key=api_key or None)
|
||||
return await test_provider_connection(
|
||||
base_url=base_url,
|
||||
api_key=api_key if api_key else None,
|
||||
client_type=client_type,
|
||||
)
|
||||
|
||||
@@ -187,9 +187,10 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
|
||||
for field in ["manifest_version", "name", "version", "author"]:
|
||||
if field not in manifest:
|
||||
raise ValueError(f"缺少必需字段: {field}")
|
||||
manifest["id"] = plugin_id
|
||||
with open(manifest_path, "w", encoding="utf-8") as file_obj:
|
||||
json.dump(manifest, file_obj, ensure_ascii=False, indent=2)
|
||||
if not str(manifest.get("id", "")).strip():
|
||||
manifest["id"] = plugin_id
|
||||
with open(manifest_path, "w", encoding="utf-8") as file_obj:
|
||||
json.dump(manifest, file_obj, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
remove_tree(target_path)
|
||||
await update_progress(
|
||||
|
||||
Reference in New Issue
Block a user