Merge pull request #1668 from Mai-with-u/dev

Dev
This commit is contained in:
SengokuCola
2026-05-09 23:18:08 +08:00
committed by GitHub
31 changed files with 823 additions and 1263 deletions

View File

@@ -19,7 +19,7 @@ RUN uv sync --frozen --no-dev --no-install-project
# Copy project source # Copy project source
COPY . . COPY . .
RUN git clone --depth 1 --branch plugin https://github.com/Mai-with-u/MaiBot-Napcat-Adapter.git plugin-templates/MaiBot-Napcat-Adapter RUN git clone --depth 1 --branch main https://github.com/Mai-with-u/MaiBot-Napcat-Adapter.git plugin-templates/MaiBot-Napcat-Adapter
RUN chmod +x docker-entrypoint.sh RUN chmod +x docker-entrypoint.sh
EXPOSE 8000 EXPOSE 8000

View File

@@ -18,7 +18,7 @@ const PLUGIN_DETAILS_FILE = 'plugin_details.json'
* 插件列表 API 响应类型(只包含我们需要的字段) * 插件列表 API 响应类型(只包含我们需要的字段)
*/ */
interface PluginApiResponse { interface PluginApiResponse {
id: string id?: string
manifest: { manifest: {
manifest_version: number manifest_version: number
id?: string id?: string
@@ -110,7 +110,7 @@ export async function fetchPluginList(): Promise<ApiResponse<PluginInfo[]>> {
console.warn('跳过无效插件数据:', item) console.warn('跳过无效插件数据:', item)
return false return false
} }
const pluginId = item.manifest.id || item.id const pluginId = item.manifest.id
if (!pluginId) { if (!pluginId) {
console.warn('跳过缺少 ID 的插件:', item) console.warn('跳过缺少 ID 的插件:', item)
return false return false
@@ -122,7 +122,7 @@ export async function fetchPluginList(): Promise<ApiResponse<PluginInfo[]>> {
return true return true
}) })
.map((item) => ({ .map((item) => ({
id: item.manifest.id || item.id, id: item.manifest.id!,
manifest: normalizePluginManifest(item.manifest), manifest: normalizePluginManifest(item.manifest),
downloads: 0, downloads: 0,
rating: 0, rating: 0,

View File

@@ -25,7 +25,7 @@ services:
- ./data/MaiMBot/emoji:/data/emoji # 持久化表情包 - ./data/MaiMBot/emoji:/data/emoji # 持久化表情包
- ./data/MaiMBot/plugins:/MaiMBot/plugins # 插件目录 - ./data/MaiMBot/plugins:/MaiMBot/plugins # 插件目录
- ./data/MaiMBot/logs:/MaiMBot/logs # 日志目录 - ./data/MaiMBot/logs:/MaiMBot/logs # 日志目录
- ./depends-data:/MaiMBot/depends-data:ro # 运行时资源文件 - ./depends-data:/MaiMBot/depends-data # 运行时资源文件
# - site-packages:/usr/local/lib/python3.13/site-packages # 持久化Python包需要时启用 # - site-packages:/usr/local/lib/python3.13/site-packages # 持久化Python包需要时启用
restart: always restart: always
networks: networks:

View File

@@ -1 +1 @@
请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题、直观感受,输出为一段平文本,最多30字请注意不要分点就输出一段文本 请用中文详细描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题、直观感受,输出为一段平文本,最多100字请注意不要分点就输出一段文本

View File

@@ -4,9 +4,9 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "MaiBot" name = "MaiBot"
version = "1.0.0" version = "1.0.0-pre.16"
description = "MaiCore 是一个基于大语言模型的可交互智能体" description = "MaiCore 是一个基于大语言模型的可交互智能体"
requires-python = ">=3.10" requires-python = ">=3.12"
dependencies = [ dependencies = [
"Babel>=2.17.0", "Babel>=2.17.0",
"aiohttp>=3.12.14", "aiohttp>=3.12.14",

View File

@@ -6,6 +6,7 @@ from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
from src.common.data_models.message_component_data_model import AtComponent, MessageSequence, ReplyComponent, TextComponent from src.common.data_models.message_component_data_model import AtComponent, MessageSequence, ReplyComponent, TextComponent
from src.config.config import global_config from src.config.config import global_config
from src.maisaka.builtin_tool.context import BuiltinToolRuntimeContext from src.maisaka.builtin_tool.context import BuiltinToolRuntimeContext
from src.maisaka.runtime import MaisakaHeartFlowChatting
def _build_sent_message() -> SessionMessage: def _build_sent_message() -> SessionMessage:
@@ -63,7 +64,9 @@ def test_post_process_reply_message_sequences_converts_at_marker_before_bracket_
) )
) )
) )
runtime = SimpleNamespace(_source_messages_by_id={"12160142": target_message}) runtime = SimpleNamespace(
find_source_message_by_id=lambda message_id: target_message if message_id == "12160142" else None
)
engine = SimpleNamespace(_get_runtime_manager=lambda: None) engine = SimpleNamespace(_get_runtime_manager=lambda: None)
tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime) tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime)
@@ -85,7 +88,7 @@ def test_post_process_reply_message_sequences_ignores_at_marker_when_disabled(mo
"src.maisaka.builtin_tool.context.process_llm_response", "src.maisaka.builtin_tool.context.process_llm_response",
lambda text: [text.strip()] if text.strip() else [], lambda text: [text.strip()] if text.strip() else [],
) )
runtime = SimpleNamespace(_source_messages_by_id={}) runtime = SimpleNamespace(find_source_message_by_id=lambda message_id: None)
engine = SimpleNamespace(_get_runtime_manager=lambda: None) engine = SimpleNamespace(_get_runtime_manager=lambda: None)
tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime) tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime)
@@ -96,3 +99,15 @@ def test_post_process_reply_message_sequences_ignores_at_marker_when_disabled(mo
assert len(components) == 1 assert len(components) == 1
assert isinstance(components[0], TextComponent) assert isinstance(components[0], TextComponent)
assert components[0].text == "at[12160142] 就这个群" assert components[0].text == "at[12160142] 就这个群"
def test_runtime_finds_source_message_from_history() -> None:
target_message = _build_sent_message()
runtime = object.__new__(MaisakaHeartFlowChatting)
runtime._chat_history = [
SimpleNamespace(message_id="other-message-id", original_message=SimpleNamespace()),
SimpleNamespace(message_id="real-message-id", original_message=target_message),
]
assert runtime.find_source_message_by_id("real-message-id") is target_message
assert runtime.find_source_message_by_id("missing-message-id") is None

View File

@@ -0,0 +1,105 @@
from types import SimpleNamespace
import pytest
import time
from src.chat.heart_flow import heartflow_manager as heartflow_manager_module
from src.chat.heart_flow.heartflow_manager import HEARTFLOW_ACTIVE_RETENTION_SECONDS, HeartflowManager
from src.learners.expression_learner import ExpressionLearner
from src.maisaka.runtime import MAX_RETAINED_MESSAGE_CACHE_SIZE, MaisakaHeartFlowChatting
def _build_runtime_with_messages(message_count: int) -> MaisakaHeartFlowChatting:
runtime = object.__new__(MaisakaHeartFlowChatting)
runtime.log_prefix = "[test]"
runtime.message_cache = [SimpleNamespace(message_id=f"msg-{index}") for index in range(message_count)]
runtime._last_processed_index = message_count
runtime._expression_learner = ExpressionLearner("session-1")
runtime._expression_learner.mark_all_processed(runtime.message_cache)
return runtime
def test_prune_processed_message_cache_keeps_bounded_recent_window() -> None:
runtime = _build_runtime_with_messages(MAX_RETAINED_MESSAGE_CACHE_SIZE + 25)
runtime._prune_processed_message_cache()
assert len(runtime.message_cache) == MAX_RETAINED_MESSAGE_CACHE_SIZE
assert runtime.message_cache[0].message_id == "msg-25"
assert runtime._last_processed_index == MAX_RETAINED_MESSAGE_CACHE_SIZE
assert runtime._expression_learner.last_processed_index == MAX_RETAINED_MESSAGE_CACHE_SIZE
def test_prune_processed_message_cache_keeps_unlearned_messages() -> None:
runtime = _build_runtime_with_messages(MAX_RETAINED_MESSAGE_CACHE_SIZE + 25)
runtime._expression_learner.discard_processed_prefix(MAX_RETAINED_MESSAGE_CACHE_SIZE + 5)
runtime._prune_processed_message_cache()
assert len(runtime.message_cache) == MAX_RETAINED_MESSAGE_CACHE_SIZE + 5
assert runtime.message_cache[0].message_id == "msg-20"
assert runtime._expression_learner.last_processed_index == 0
def test_collect_pending_messages_uses_single_pending_received_time() -> None:
runtime = _build_runtime_with_messages(2)
runtime._last_processed_index = 0
runtime._oldest_pending_message_received_at = 123.0
runtime._last_message_received_at = 456.0
runtime._reply_latency_measurement_started_at = None
pending_messages = runtime._collect_pending_messages()
assert [message.message_id for message in pending_messages] == ["msg-0", "msg-1"]
assert runtime._reply_latency_measurement_started_at == 123.0
assert runtime._oldest_pending_message_received_at is None
@pytest.mark.asyncio
async def test_heartflow_manager_evicts_lru_chat_over_limit(monkeypatch: pytest.MonkeyPatch) -> None:
manager = HeartflowManager()
stopped_session_ids: list[str] = []
old_active_at = time.time() - HEARTFLOW_ACTIVE_RETENTION_SECONDS - 1
class FakeChat:
def __init__(self, session_id: str) -> None:
self.session_id = session_id
async def stop(self) -> None:
stopped_session_ids.append(self.session_id)
monkeypatch.setattr(heartflow_manager_module, "HEARTFLOW_MAX_ACTIVE_CHATS", 2)
manager.heartflow_chat_list["session-1"] = FakeChat("session-1")
manager.heartflow_chat_list["session-2"] = FakeChat("session-2")
manager.heartflow_chat_list["session-3"] = FakeChat("session-3")
manager._chat_last_active_at["session-1"] = old_active_at
manager._chat_last_active_at["session-2"] = old_active_at
manager._chat_last_active_at["session-3"] = time.time()
await manager._evict_over_limit_chats(protected_session_id="session-3")
assert stopped_session_ids == ["session-1"]
assert list(manager.heartflow_chat_list) == ["session-2", "session-3"]
@pytest.mark.asyncio
async def test_heartflow_manager_keeps_recent_chats_even_over_limit(monkeypatch: pytest.MonkeyPatch) -> None:
manager = HeartflowManager()
stopped_session_ids: list[str] = []
class FakeChat:
def __init__(self, session_id: str) -> None:
self.session_id = session_id
async def stop(self) -> None:
stopped_session_ids.append(self.session_id)
monkeypatch.setattr(heartflow_manager_module, "HEARTFLOW_MAX_ACTIVE_CHATS", 2)
for session_id in ("session-1", "session-2", "session-3"):
manager.heartflow_chat_list[session_id] = FakeChat(session_id)
manager._chat_last_active_at[session_id] = time.time()
await manager._evict_over_limit_chats(protected_session_id="session-3")
assert stopped_session_ids == []
assert list(manager.heartflow_chat_list) == ["session-1", "session-2", "session-3"]

View File

@@ -199,7 +199,7 @@ async def test_reply_tool_puts_monitor_detail_into_metadata(monkeypatch: pytest.
), ),
) )
runtime = SimpleNamespace( runtime = SimpleNamespace(
_source_messages_by_id={"msg-1": target_message}, find_source_message_by_id=lambda message_id: target_message if message_id == "msg-1" else None,
log_prefix="[test]", log_prefix="[test]",
chat_stream=SimpleNamespace(platform=reply_tool_module.CLI_PLATFORM_NAME), chat_stream=SimpleNamespace(platform=reply_tool_module.CLI_PLATFORM_NAME),
session_id="session-1", session_id="session-1",

View File

@@ -1169,6 +1169,8 @@ class TestVersionComparator:
assert VersionComparator.normalize_version("0.8.0-snapshot.1") == "0.8.0" assert VersionComparator.normalize_version("0.8.0-snapshot.1") == "0.8.0"
assert VersionComparator.normalize_version("1.2") == "1.2.0" assert VersionComparator.normalize_version("1.2") == "1.2.0"
assert VersionComparator.normalize_version("1.0.0rc16") == "1.0.0"
assert VersionComparator.normalize_version("1.0.0-pre.16") == "1.0.0"
assert VersionComparator.normalize_version("") == "0.0.0" assert VersionComparator.normalize_version("") == "0.0.0"
def test_compare(self): def test_compare(self):
@@ -2890,22 +2892,105 @@ class TestIntegration:
monkeypatch.setattr(real_database_service, "db_get", fake_db_get) monkeypatch.setattr(real_database_service, "db_get", fake_db_get)
monkeypatch.setattr(real_db_models, "DemoTable", DummyModel, raising=False) monkeypatch.setattr(real_db_models, "DemoTable", DummyModel, raising=False)
result = await integration_module.PluginRuntimeManager._cap_database_get( manager = object.__new__(integration_module.PluginRuntimeManager)
result = await manager._cap_database_get(
"plugin_a", "plugin_a",
"database.get", "database.get",
{ {
"table": "DemoTable", "model_name": "DemoTable",
"filters": {"status": "active"}, "filters": {"status": "active"},
"limit": 5, "limit": 5,
}, },
) )
assert result == {"success": True, "result": [{"id": 1}]} assert result == [{"id": 1}]
assert captured["model_class"] is DummyModel assert captured["model_class"] is DummyModel
assert captured["filters"] == {"status": "active"} assert captured["filters"] == {"status": "active"}
assert captured["limit"] == 5 assert captured["limit"] == 5
assert captured["single_result"] is False assert captured["single_result"] is False
@pytest.mark.asyncio
async def test_cap_database_get_response_is_not_double_wrapped(self, monkeypatch):
from src.plugin_runtime import integration as integration_module
import src.common.database.database_model as real_db_models
from src.plugin_runtime.host.capability_service import CapabilityService
from src.plugin_runtime.protocol.envelope import CapabilityRequestPayload, Envelope, MessageType
from src.services import database_service as real_database_service
class AllowAllAuthorization:
def check_capability(self, plugin_id, capability):
return True, ""
class DummyModel:
pass
async def fake_db_get(model_class, filters=None, limit=None, order_by=None, single_result=False):
return {"id": 1, "full_path": "E:\\test.png"}
monkeypatch.setattr(real_database_service, "db_get", fake_db_get)
monkeypatch.setattr(real_db_models, "DemoTable", DummyModel, raising=False)
manager = object.__new__(integration_module.PluginRuntimeManager)
service = CapabilityService(AllowAllAuthorization())
service.register_capability("database.get", manager._cap_database_get)
request = Envelope(
request_id=1,
message_type=MessageType.REQUEST,
method="cap.call",
plugin_id="plugin_a",
payload=CapabilityRequestPayload(
capability="database.get",
args={"model_name": "DemoTable", "single_result": True},
).model_dump(),
)
response = await service.handle_capability_request(request)
assert response.payload == {
"success": True,
"result": {"id": 1, "full_path": "E:\\test.png"},
}
@pytest.mark.asyncio
async def test_cap_database_success_handlers_return_raw_results(self, monkeypatch):
from src.plugin_runtime import integration as integration_module
import src.common.database.database_model as real_db_models
from src.services import database_service as real_database_service
class DummyModel:
pass
async def fake_db_get(**kwargs):
return [{"id": 1}]
async def fake_db_save(**kwargs):
return {"id": 2}
async def fake_db_delete(**kwargs):
return 3
async def fake_db_count(**kwargs):
return 4
monkeypatch.setattr(real_database_service, "db_get", fake_db_get)
monkeypatch.setattr(real_database_service, "db_save", fake_db_save)
monkeypatch.setattr(real_database_service, "db_delete", fake_db_delete)
monkeypatch.setattr(real_database_service, "db_count", fake_db_count)
monkeypatch.setattr(real_db_models, "DemoTable", DummyModel, raising=False)
manager = object.__new__(integration_module.PluginRuntimeManager)
base_args = {"model_name": "DemoTable"}
assert await manager._cap_database_query("plugin_a", "database.query", base_args) == [{"id": 1}]
assert await manager._cap_database_save(
"plugin_a", "database.save", {**base_args, "data": {"name": "demo"}}
) == {"id": 2}
assert await manager._cap_database_delete(
"plugin_a", "database.delete", {**base_args, "filters": {"id": 2}}
) == 3
assert await manager._cap_database_count("plugin_a", "database.count", base_args) == 4
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_component_enable_rejects_ambiguous_short_name(self, monkeypatch): async def test_component_enable_rejects_ambiguous_short_name(self, monkeypatch):
from src.plugin_runtime import integration as integration_module from src.plugin_runtime import integration as integration_module

View File

@@ -0,0 +1,187 @@
"""模型路由测试
验证 Gemini 提供商连接测试会使用查询参数传递 API Key
并且不会回退到 OpenAI 兼容接口使用的 Bearer 认证方式。
"""
import importlib
import sys
from types import ModuleType
from typing import Any
import pytest
def load_model_routes(monkeypatch: pytest.MonkeyPatch):
"""在导入路由前 stub 配置与认证依赖模块,避免测试时触发真实初始化。"""
config_module = ModuleType("src.config.config")
config_module.__dict__["CONFIG_DIR"] = "."
monkeypatch.setitem(sys.modules, "src.config.config", config_module)
dependencies_module = ModuleType("src.webui.dependencies")
async def require_auth():
return "test-token"
dependencies_module.__dict__["require_auth"] = require_auth
monkeypatch.setitem(sys.modules, "src.webui.dependencies", dependencies_module)
sys.modules.pop("src.webui.routers.model", None)
return importlib.import_module("src.webui.routers.model")
class FakeResponse:
"""简化版 HTTP 响应对象。"""
def __init__(self, status_code: int):
self.status_code = status_code
def build_async_client_factory(
responses: list[FakeResponse],
calls: list[dict[str, Any]],
):
"""构造一个可记录请求参数的 AsyncClient 替身。"""
response_iter = iter(responses)
class FakeAsyncClient:
def __init__(self, *args: Any, **kwargs: Any):
self.args = args
self.kwargs = kwargs
async def __aenter__(self) -> "FakeAsyncClient":
return self
async def __aexit__(self, exc_type, exc, tb) -> bool:
return False
async def get(
self,
url: str,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
) -> FakeResponse:
calls.append(
{
"url": url,
"headers": headers or {},
"params": params or {},
}
)
return next(response_iter)
return FakeAsyncClient
@pytest.mark.asyncio
async def test_test_provider_connection_uses_query_api_key_for_gemini(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Gemini 连接测试应通过查询参数传递 API Key。"""
model_routes = load_model_routes(monkeypatch)
calls: list[dict[str, Any]] = []
fake_client_class = build_async_client_factory(
responses=[FakeResponse(200), FakeResponse(200)],
calls=calls,
)
monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class)
result = await model_routes.test_provider_connection(
base_url="https://generativelanguage.googleapis.com/v1beta",
api_key="valid-gemini-key",
client_type="gemini",
)
assert result["network_ok"] is True
assert result["api_key_valid"] is True
assert len(calls) == 2
network_call = calls[0]
validation_call = calls[1]
assert network_call["url"] == "https://generativelanguage.googleapis.com/v1beta"
assert network_call["headers"] == {}
assert network_call["params"] == {}
assert validation_call["url"] == "https://generativelanguage.googleapis.com/v1beta/models"
assert validation_call["params"] == {"key": "valid-gemini-key"}
assert validation_call["headers"] == {"Content-Type": "application/json"}
assert "Authorization" not in validation_call["headers"]
@pytest.mark.asyncio
async def test_test_provider_connection_uses_bearer_auth_for_openai_compatible(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""非 Gemini 提供商连接测试应继续使用 Bearer 认证。"""
model_routes = load_model_routes(monkeypatch)
calls: list[dict[str, Any]] = []
fake_client_class = build_async_client_factory(
responses=[FakeResponse(200), FakeResponse(200)],
calls=calls,
)
monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class)
result = await model_routes.test_provider_connection(
base_url="https://example.com/v1",
api_key="valid-openai-key",
client_type="openai",
)
assert result["network_ok"] is True
assert result["api_key_valid"] is True
assert len(calls) == 2
validation_call = calls[1]
assert validation_call["url"] == "https://example.com/v1/models"
assert validation_call["params"] == {}
assert validation_call["headers"]["Content-Type"] == "application/json"
assert validation_call["headers"]["Authorization"] == "Bearer valid-openai-key"
@pytest.mark.asyncio
async def test_test_provider_connection_by_name_forwards_provider_client_type(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
) -> None:
"""按提供商名称测试连接时,应透传配置中的 client_type。"""
model_routes = load_model_routes(monkeypatch)
config_path = tmp_path / "model_config.toml"
config_path.write_text(
"""
[[api_providers]]
name = "Gemini"
base_url = "https://generativelanguage.googleapis.com/v1beta"
api_key = "valid-gemini-key"
client_type = "gemini"
""".strip(),
encoding="utf-8",
)
monkeypatch.setattr(model_routes, "CONFIG_DIR", str(tmp_path))
captured_kwargs: dict[str, Any] = {}
async def fake_test_provider_connection(**kwargs: Any) -> dict[str, Any]:
captured_kwargs.update(kwargs)
return {
"network_ok": True,
"api_key_valid": True,
"latency_ms": 12.34,
"error": None,
"http_status": 200,
}
monkeypatch.setattr(model_routes, "test_provider_connection", fake_test_provider_connection)
result = await model_routes.test_provider_connection_by_name(provider_name="Gemini")
assert result["network_ok"] is True
assert result["api_key_valid"] is True
assert captured_kwargs == {
"base_url": "https://generativelanguage.googleapis.com/v1beta",
"api_key": "valid-gemini-key",
"client_type": "gemini",
}

View File

@@ -61,3 +61,76 @@ def test_resolve_installed_plugin_path_accepts_manifest_id_case_mismatch(client:
assert plugin_path is not None assert plugin_path is not None
assert plugin_path.name == "demo_plugin" assert plugin_path.name == "demo_plugin"
def test_install_plugin_preserves_manifest_declared_id(client: TestClient, monkeypatch):
class FakeGitMirrorService:
async def clone_repository(self, **kwargs):
target_path = kwargs["target_path"]
target_path.mkdir(parents=True, exist_ok=True)
(target_path / "_manifest.json").write_text(
json.dumps(
{
"manifest_version": 2,
"id": "author.declared",
"name": "Declared Plugin",
"version": "1.0.0",
"author": {"name": "author"},
}
),
encoding="utf-8",
)
return {"success": True}
monkeypatch.setattr(management_module, "get_git_mirror_service", lambda: FakeGitMirrorService())
response = client.post(
"/api/webui/plugins/install",
json={
"plugin_id": "market.plugin",
"repository_url": "https://github.com/author/declared",
"branch": "main",
},
)
assert response.status_code == 200
plugin_path = support_module.resolve_installed_plugin_path("author.declared")
assert plugin_path is not None
manifest = json.loads((plugin_path / "_manifest.json").read_text(encoding="utf-8"))
assert manifest["id"] == "author.declared"
def test_install_plugin_backfills_missing_manifest_id(client: TestClient, monkeypatch):
class FakeGitMirrorService:
async def clone_repository(self, **kwargs):
target_path = kwargs["target_path"]
target_path.mkdir(parents=True, exist_ok=True)
(target_path / "_manifest.json").write_text(
json.dumps(
{
"manifest_version": 2,
"name": "Legacy Plugin",
"version": "1.0.0",
"author": {"name": "author"},
}
),
encoding="utf-8",
)
return {"success": True}
monkeypatch.setattr(management_module, "get_git_mirror_service", lambda: FakeGitMirrorService())
response = client.post(
"/api/webui/plugins/install",
json={
"plugin_id": "market.legacy",
"repository_url": "https://github.com/author/legacy",
"branch": "main",
},
)
assert response.status_code == 200
plugin_path = support_module.resolve_installed_plugin_path("market.legacy")
assert plugin_path is not None
manifest = json.loads((plugin_path / "_manifest.json").read_text(encoding="utf-8"))
assert manifest["id"] == "market.legacy"

View File

@@ -16,6 +16,7 @@ from src.common.logger import get_logger
from ..storage import VectorStore, GraphStore, MetadataStore from ..storage import VectorStore, GraphStore, MetadataStore
from ..embedding import EmbeddingAPIAdapter from ..embedding import EmbeddingAPIAdapter
from ..utils.matcher import AhoCorasick from ..utils.matcher import AhoCorasick
from ..utils.metadata import coerce_metadata_dict
from ..utils.time_parser import format_timestamp from ..utils.time_parser import format_timestamp
from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService
from .pagerank import PersonalizedPageRank, PageRankConfig from .pagerank import PersonalizedPageRank, PageRankConfig
@@ -482,7 +483,7 @@ class DualPathRetriever:
score=float(item.score), score=float(item.score),
result_type=item.result_type, result_type=item.result_type,
source=item.source, source=item.source,
metadata=dict(item.metadata or {}), metadata=coerce_metadata_dict(item.metadata),
) )
def _extract_graph_seed_entities(self, query: str, limit: int = 2) -> List[str]: def _extract_graph_seed_entities(self, query: str, limit: int = 2) -> List[str]:
@@ -762,7 +763,7 @@ class DualPathRetriever:
existing = self._clone_retrieval_result(item) existing = self._clone_retrieval_result(item)
merged[item.hash_value] = existing merged[item.hash_value] = existing
else: else:
for key, value in dict(item.metadata or {}).items(): for key, value in coerce_metadata_dict(item.metadata).items():
if key not in existing.metadata or existing.metadata.get(key) in (None, "", []): if key not in existing.metadata or existing.metadata.get(key) in (None, "", []):
existing.metadata[key] = value existing.metadata[key] = value
source_sets.setdefault(item.hash_value, set()).add(str(item.source or "").strip() or "relation_search") source_sets.setdefault(item.hash_value, set()).add(str(item.source or "").strip() or "relation_search")

View File

@@ -25,6 +25,7 @@ from ..utils.episode_retrieval_service import EpisodeRetrievalService
from ..utils.episode_segmentation_service import EpisodeSegmentationService from ..utils.episode_segmentation_service import EpisodeSegmentationService
from ..utils.episode_service import EpisodeService from ..utils.episode_service import EpisodeService
from ..utils.hash import compute_hash, normalize_text from ..utils.hash import compute_hash, normalize_text
from ..utils.metadata import coerce_metadata_dict
from ..utils.person_profile_service import PersonProfileService from ..utils.person_profile_service import PersonProfileService
from ..utils.relation_write_service import RelationWriteService from ..utils.relation_write_service import RelationWriteService
from ..utils.retrieval_tuning_manager import RetrievalTuningManager from ..utils.retrieval_tuning_manager import RetrievalTuningManager
@@ -871,7 +872,7 @@ class SDKMemoryKernel:
"detail": "chat_filtered", "detail": "chat_filtered",
} }
summary_meta = dict(metadata or {}) summary_meta = coerce_metadata_dict(metadata)
summary_meta.setdefault("kind", "chat_summary") summary_meta.setdefault("kind", "chat_summary")
if not str(text or "").strip() or bool(summary_meta.get("generate_from_chat", False)): if not str(text or "").strip() or bool(summary_meta.get("generate_from_chat", False)):
result = await self.summarize_chat_stream( result = await self.summarize_chat_stream(
@@ -961,7 +962,7 @@ class SDKMemoryKernel:
participant_tokens = self._tokens(participants) participant_tokens = self._tokens(participants)
entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens) entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens)
source = self._build_source(source_type, chat_id, person_tokens) source = self._build_source(source_type, chat_id, person_tokens)
paragraph_meta = dict(metadata or {}) paragraph_meta = coerce_metadata_dict(metadata)
paragraph_meta.update( paragraph_meta.update(
{ {
"external_id": external_token, "external_id": external_token,

View File

@@ -0,0 +1,11 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Dict
def coerce_metadata_dict(value: Any) -> Dict[str, Any]:
"""返回字典,如果输入值不是字典则返回空字典。"""
if isinstance(value, Mapping):
return dict(value)
return {}

View File

@@ -27,6 +27,7 @@ from ..retrieval import (
GraphRelationRecallConfig, GraphRelationRecallConfig,
) )
from ..storage import MetadataStore, GraphStore, VectorStore from ..storage import MetadataStore, GraphStore, VectorStore
from .metadata import coerce_metadata_dict
logger = get_logger("A_Memorix.PersonProfileService") logger = get_logger("A_Memorix.PersonProfileService")
@@ -334,7 +335,7 @@ class PersonProfileService:
if not pid: if not pid:
return False return False
metadata = self._metadata_dict(relation.get("metadata")) metadata = coerce_metadata_dict(relation.get("metadata"))
if str(metadata.get("person_id", "") or "").strip() == pid: if str(metadata.get("person_id", "") or "").strip() == pid:
return True return True
if pid in self._list_tokens(metadata.get("person_ids")): if pid in self._list_tokens(metadata.get("person_ids")):
@@ -350,7 +351,7 @@ class PersonProfileService:
payload = { payload = {
"hash": source_paragraph, "hash": source_paragraph,
"source": str(paragraph.get("source", "") or ""), "source": str(paragraph.get("source", "") or ""),
"metadata": self._metadata_dict(paragraph.get("metadata")), "metadata": coerce_metadata_dict(paragraph.get("metadata")),
} }
return self._is_evidence_bound_to_person(payload, person_id=pid) return self._is_evidence_bound_to_person(payload, person_id=pid)
@@ -385,15 +386,11 @@ class PersonProfileService:
"score": 1.1, "score": 1.1,
"content": content[:220], "content": content[:220],
"source": str(row.get("source", "") or source), "source": str(row.get("source", "") or source),
"metadata": dict(row.get("metadata", {}) or {}), "metadata": coerce_metadata_dict(row.get("metadata")),
} }
) )
return self._filter_stale_paragraph_evidence(evidence) return self._filter_stale_paragraph_evidence(evidence)
@staticmethod
def _metadata_dict(value: Any) -> Dict[str, Any]:
return dict(value) if isinstance(value, dict) else {}
@staticmethod @staticmethod
def _list_tokens(value: Any) -> List[str]: def _list_tokens(value: Any) -> List[str]:
if value is None: if value is None:
@@ -414,7 +411,7 @@ class PersonProfileService:
if not pid: if not pid:
return False return False
metadata = self._metadata_dict(item.get("metadata")) metadata = coerce_metadata_dict(item.get("metadata"))
source = str(item.get("source", "") or metadata.get("source", "") or "").strip() source = str(item.get("source", "") or metadata.get("source", "") or "").strip()
if source == f"person_fact:{pid}": if source == f"person_fact:{pid}":
return True return True
@@ -440,15 +437,15 @@ class PersonProfileService:
paragraph_hash: str, paragraph_hash: str,
metadata: Dict[str, Any], metadata: Dict[str, Any],
) -> Tuple[Dict[str, Any], str]: ) -> Tuple[Dict[str, Any], str]:
merged = self._metadata_dict(metadata) merged = coerce_metadata_dict(metadata)
source = str(merged.get("source", "") or "").strip() source = str(merged.get("source", "") or "").strip()
try: try:
paragraph = self.metadata_store.get_paragraph(paragraph_hash) paragraph = self.metadata_store.get_paragraph(paragraph_hash)
except Exception: except Exception:
paragraph = None paragraph = None
if isinstance(paragraph, dict): if isinstance(paragraph, dict):
paragraph_metadata = paragraph.get("metadata", {}) or {} paragraph_metadata = coerce_metadata_dict(paragraph.get("metadata"))
if isinstance(paragraph_metadata, dict): if paragraph_metadata:
merged = {**paragraph_metadata, **merged} merged = {**paragraph_metadata, **merged}
source = source or str(paragraph.get("source", "") or "").strip() source = source or str(paragraph.get("source", "") or "").strip()
source_type = str(merged.get("source_type", "") or "").strip() or self._source_type_from_source(source) source_type = str(merged.get("source_type", "") or "").strip() or self._source_type_from_source(source)
@@ -538,7 +535,7 @@ class PersonProfileService:
"score": 0.0, "score": 0.0,
"content": str(para.get("content", ""))[:180], "content": str(para.get("content", ""))[:180],
"source": str(para.get("source", "") or ""), "source": str(para.get("source", "") or ""),
"metadata": self._metadata_dict(para.get("metadata")), "metadata": coerce_metadata_dict(para.get("metadata")),
} }
) )
if not self._is_evidence_bound_to_person(fallback[-1], person_id=person_id): if not self._is_evidence_bound_to_person(fallback[-1], person_id=person_id):
@@ -562,18 +559,18 @@ class PersonProfileService:
logger.warning(f"向量证据召回失败: alias={alias}, err={e}") logger.warning(f"向量证据召回失败: alias={alias}, err={e}")
continue continue
for item in results: for item in results:
h = str(getattr(item, "hash_value", "") or "") h = str(item.hash_value or "")
if not h or h in seen_hash: if not h or h in seen_hash:
continue continue
metadata, source = self._enrich_paragraph_evidence_metadata( metadata, source = self._enrich_paragraph_evidence_metadata(
h, h,
self._metadata_dict(getattr(item, "metadata", {})), coerce_metadata_dict(item.metadata),
) )
payload = { payload = {
"hash": h, "hash": h,
"type": str(getattr(item, "result_type", "")), "type": str(item.result_type),
"score": float(getattr(item, "score", 0.0) or 0.0), "score": float(item.score or 0.0),
"content": str(getattr(item, "content", "") or "")[:220], "content": str(item.content or "")[:220],
"source": source, "source": source,
"metadata": metadata, "metadata": metadata,
} }

View File

@@ -14,7 +14,8 @@ from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from ..retrieval import TemporalQueryOptions from ..retrieval import RetrievalResult, TemporalQueryOptions
from .metadata import coerce_metadata_dict
from .search_postprocess import ( from .search_postprocess import (
apply_safe_content_dedup, apply_safe_content_dedup,
maybe_apply_smart_path_fallback, maybe_apply_smart_path_fallback,
@@ -286,8 +287,11 @@ class SearchExecutionService:
) )
async def _executor() -> Dict[str, Any]: async def _executor() -> Dict[str, Any]:
original_ppr = bool(getattr(retriever.config, "enable_ppr", True)) retriever_config = getattr(retriever, "config", None)
setattr(retriever.config, "enable_ppr", bool(request.enable_ppr)) has_runtime_ppr_switch = retriever_config is not None and hasattr(retriever_config, "enable_ppr")
original_ppr = bool(retriever_config.enable_ppr) if has_runtime_ppr_switch else None
if has_runtime_ppr_switch:
retriever_config.enable_ppr = bool(request.enable_ppr)
started_at = time.time() started_at = time.time()
try: try:
retrieved = await retriever.retrieve( retrieved = await retriever.retrieve(
@@ -380,7 +384,8 @@ class SearchExecutionService:
elapsed_ms = (time.time() - started_at) * 1000.0 elapsed_ms = (time.time() - started_at) * 1000.0
return {"results": retrieved, "elapsed_ms": elapsed_ms} return {"results": retrieved, "elapsed_ms": elapsed_ms}
finally: finally:
setattr(retriever.config, "enable_ppr", original_ppr) if has_runtime_ppr_switch:
retriever_config.enable_ppr = bool(original_ppr)
dedup_hit = False dedup_hit = False
try: try:
@@ -421,18 +426,18 @@ class SearchExecutionService:
) )
@staticmethod @staticmethod
def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]: def to_serializable_results(results: List[RetrievalResult]) -> List[Dict[str, Any]]:
serialized: List[Dict[str, Any]] = [] serialized: List[Dict[str, Any]] = []
for item in results: for item in results:
metadata = dict(getattr(item, "metadata", {}) or {}) metadata = coerce_metadata_dict(item.metadata)
if "time_meta" not in metadata: if "time_meta" not in metadata:
metadata["time_meta"] = {} metadata["time_meta"] = {}
serialized.append( serialized.append(
{ {
"hash": getattr(item, "hash_value", ""), "hash": item.hash_value,
"type": getattr(item, "result_type", ""), "type": item.result_type,
"score": float(getattr(item, "score", 0.0)), "score": float(item.score),
"content": getattr(item, "content", ""), "content": item.content,
"metadata": metadata, "metadata": metadata,
} }
) )

View File

@@ -1,31 +1,39 @@
import asyncio from collections import OrderedDict
import traceback
from typing import Dict from typing import Dict
import asyncio
import time
import traceback
from src.chat.message_receive.chat_manager import chat_manager from src.chat.message_receive.chat_manager import chat_manager
from src.common.logger import get_logger from src.common.logger import get_logger
from src.maisaka.runtime import MaisakaHeartFlowChatting from src.maisaka.runtime import MaisakaHeartFlowChatting
logger = get_logger("heartflow") logger = get_logger("heartflow")
HEARTFLOW_ACTIVE_RETENTION_SECONDS = 24 * 60 * 60
HEARTFLOW_MAX_ACTIVE_CHATS = 100
class HeartflowManager: class HeartflowManager:
"""管理 session 级别的 Maisaka 心流实例。""" """管理 session 级别的 Maisaka 心流实例。"""
def __init__(self) -> None: def __init__(self) -> None:
self.heartflow_chat_list: Dict[str, MaisakaHeartFlowChatting] = {} self.heartflow_chat_list: OrderedDict[str, MaisakaHeartFlowChatting] = OrderedDict()
self._chat_create_locks: Dict[str, asyncio.Lock] = {} self._chat_create_locks: Dict[str, asyncio.Lock] = {}
self._chat_last_active_at: Dict[str, float] = {}
async def get_or_create_heartflow_chat(self, session_id: str) -> MaisakaHeartFlowChatting: async def get_or_create_heartflow_chat(self, session_id: str) -> MaisakaHeartFlowChatting:
"""获取或创建指定会话对应的 Maisaka runtime。""" """获取或创建指定会话对应的 Maisaka runtime。"""
try: try:
if chat := self.heartflow_chat_list.get(session_id): if chat := self.heartflow_chat_list.get(session_id):
self._touch_chat(session_id)
return chat return chat
create_lock = self._chat_create_locks.setdefault(session_id, asyncio.Lock()) create_lock = self._chat_create_locks.setdefault(session_id, asyncio.Lock())
async with create_lock: async with create_lock:
if chat := self.heartflow_chat_list.get(session_id): if chat := self.heartflow_chat_list.get(session_id):
self._touch_chat(session_id)
return chat return chat
chat_session = chat_manager.get_session_by_session_id(session_id) chat_session = chat_manager.get_session_by_session_id(session_id)
@@ -35,16 +43,59 @@ class HeartflowManager:
new_chat = MaisakaHeartFlowChatting(session_id=session_id) new_chat = MaisakaHeartFlowChatting(session_id=session_id)
await new_chat.start() await new_chat.start()
self.heartflow_chat_list[session_id] = new_chat self.heartflow_chat_list[session_id] = new_chat
self._touch_chat(session_id)
await self._evict_over_limit_chats(protected_session_id=session_id)
return new_chat return new_chat
except Exception as exc: except Exception as exc:
logger.error(f"创建心流聊天 {session_id} 失败: {exc}", exc_info=True) logger.error(f"创建心流聊天 {session_id} 失败: {exc}", exc_info=True)
traceback.print_exc() traceback.print_exc()
raise raise
def _touch_chat(self, session_id: str) -> None:
"""记录会话最近活跃时间,并维护心流实例的 LRU 顺序。"""
self._chat_last_active_at[session_id] = time.time()
self.heartflow_chat_list.move_to_end(session_id)
async def _evict_over_limit_chats(self, *, protected_session_id: str) -> None:
"""当实例数量超过上限时,仅淘汰 24 小时内无消息的旧会话。"""
while len(self.heartflow_chat_list) > HEARTFLOW_MAX_ACTIVE_CHATS:
session_id = self._find_evictable_session_id(protected_session_id=protected_session_id)
if session_id is None:
return
await self._evict_chat(session_id, reason="cache_limit")
def _find_evictable_session_id(self, *, protected_session_id: str) -> str | None:
"""按 LRU 查找超过活跃保护窗口的可淘汰会话。"""
expire_before = time.time() - HEARTFLOW_ACTIVE_RETENTION_SECONDS
for session_id in self.heartflow_chat_list:
if session_id == protected_session_id:
continue
last_active_at = self._chat_last_active_at.get(session_id, 0.0)
if last_active_at <= expire_before:
return session_id
return None
async def _evict_chat(self, session_id: str, *, reason: str) -> None:
"""停止并移除指定会话的心流实例。"""
chat = self.heartflow_chat_list.pop(session_id, None)
self._chat_last_active_at.pop(session_id, None)
lock = self._chat_create_locks.get(session_id)
if lock is not None and not lock.locked():
self._chat_create_locks.pop(session_id, None)
if chat is None:
return
try:
await chat.stop()
logger.info(f"已淘汰心流聊天 {session_id}: reason={reason}")
except Exception as exc:
logger.warning(f"淘汰心流聊天 {session_id} 失败: {exc}", exc_info=True)
def adjust_talk_frequency(self, session_id: str, frequency: float) -> None: def adjust_talk_frequency(self, session_id: str, frequency: float) -> None:
"""调整指定聊天流的说话频率。""" """调整指定聊天流的说话频率。"""
chat = self.heartflow_chat_list.get(session_id) chat = self.heartflow_chat_list.get(session_id)
if chat: if chat:
self._touch_chat(session_id)
chat.adjust_talk_frequency(frequency) chat.adjust_talk_frequency(frequency)
logger.info(f"已调整聊天 {session_id} 的说话频率为 {frequency}") logger.info(f"已调整聊天 {session_id} 的说话频率为 {frequency}")
else: else:

View File

@@ -86,7 +86,6 @@ class BaseMaisakaReplyGenerator:
request_type=request_type, request_type=request_type,
session_id=getattr(chat_stream, "session_id", "") if chat_stream is not None else "", session_id=getattr(chat_stream, "session_id", "") if chat_stream is not None else "",
) )
self._personality_prompt = self._build_personality_prompt()
def _build_personality_prompt(self) -> str: def _build_personality_prompt(self) -> str:
"""构建 replyer 使用的人设提示。""" """构建 replyer 使用的人设提示。"""
@@ -272,7 +271,7 @@ class BaseMaisakaReplyGenerator:
bot_name=global_config.bot.nickname, bot_name=global_config.bot.nickname,
group_chat_attention_block=self._build_group_chat_attention_block(session_id), group_chat_attention_block=self._build_group_chat_attention_block(session_id),
replyer_at_block=self._build_replyer_at_block(), replyer_at_block=self._build_replyer_at_block(),
identity=self._personality_prompt, identity=self._build_personality_prompt(),
reply_style=self._select_reply_style(), reply_style=self._select_reply_style(),
) )
except Exception: except Exception:

View File

@@ -56,7 +56,7 @@ BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute() MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute() LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
A_MEMORIX_LEGACY_CONFIG_PATH: Path = (CONFIG_DIR / "a_memorix.toml").resolve().absolute() A_MEMORIX_LEGACY_CONFIG_PATH: Path = (CONFIG_DIR / "a_memorix.toml").resolve().absolute()
MMC_VERSION: str = "1.0.0-pre.16" MMC_VERSION: str = "1.0.0-pre.17"
CONFIG_VERSION: str = "8.10.15" CONFIG_VERSION: str = "8.10.15"
MODEL_CONFIG_VERSION: str = "1.16.1" MODEL_CONFIG_VERSION: str = "1.16.1"
@@ -234,6 +234,7 @@ class ConfigManager:
self._hot_reload_min_interval_s: float = 1.0 self._hot_reload_min_interval_s: float = 1.0
self._hot_reload_timeout_s: float = 20.0 self._hot_reload_timeout_s: float = 20.0
self._last_hot_reload_monotonic: float = 0.0 self._last_hot_reload_monotonic: float = 0.0
self.reload_revision: int = 0
def initialize(self): def initialize(self):
logger.info(t("config.current_version", version=MMC_VERSION)) logger.info(t("config.current_version", version=MMC_VERSION))
@@ -424,9 +425,7 @@ class ConfigManager:
self.global_config = global_config_new self.global_config = global_config_new
self.model_config = model_config_new self.model_config = model_config_new
global global_config, model_config self.reload_revision += 1
global_config = global_config_new
model_config = model_config_new
logger.info(t("config.hot_reload_completed")) logger.info(t("config.hot_reload_completed"))
for callback in list(self._reload_callbacks): for callback in list(self._reload_callbacks):
@@ -657,8 +656,30 @@ def write_config_to_file(
tomlkit.dump(full_config_data, f) tomlkit.dump(full_config_data, f)
class _ConfigProxy:
"""稳定配置代理,确保热重载后旧导入也能读取最新配置。"""
def __init__(self, getter: Callable[[], ConfigBase]) -> None:
self._getter = getter
def __getattr__(self, name: str) -> Any:
return getattr(self._getter(), name)
def __getitem__(self, key: str) -> Any:
return self._getter()[key]
def __setattr__(self, name: str, value: Any) -> None:
if name == "_getter":
object.__setattr__(self, name, value)
return
setattr(self._getter(), name, value)
def __repr__(self) -> str:
return repr(self._getter())
# generate_new_config_file(Config, BOT_CONFIG_PATH, CONFIG_VERSION) # generate_new_config_file(Config, BOT_CONFIG_PATH, CONFIG_VERSION)
config_manager = ConfigManager() config_manager = ConfigManager()
config_manager.initialize() config_manager.initialize()
global_config = config_manager.get_global_config() global_config: Config = cast(Config, _ConfigProxy(config_manager.get_global_config))
model_config = config_manager.get_model_config() model_config: ModelConfig = cast(ModelConfig, _ConfigProxy(config_manager.get_model_config))

View File

@@ -160,6 +160,23 @@ class ExpressionLearner:
self._last_processed_index = 0 self._last_processed_index = 0
self.min_messages_for_extraction = 10 self.min_messages_for_extraction = 10
@property
def last_processed_index(self) -> int:
"""返回表达学习已经消费到的消息缓存下标。"""
return self._last_processed_index
def mark_all_processed(self, message_cache: List["SessionMessage"]) -> None:
"""在跳过表达学习时,将现有消息标记为已处理,避免阻塞缓存裁剪。"""
self._last_processed_index = len(message_cache)
def mark_processed_until(self, processed_end_index: int) -> None:
"""将指定缓存下标之前的消息标记为已处理。"""
self._last_processed_index = max(self._last_processed_index, processed_end_index)
def discard_processed_prefix(self, removed_count: int) -> None:
"""同步 runtime 对消息缓存前缀的裁剪。"""
self._last_processed_index = max(0, self._last_processed_index - removed_count)
@staticmethod @staticmethod
def _get_runtime_manager() -> Any: def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。 """获取插件运行时管理器。
@@ -274,7 +291,8 @@ class ExpressionLearner:
jargon_miner: Optional["JargonMiner"] = None, jargon_miner: Optional["JargonMiner"] = None,
) -> bool: ) -> bool:
"""学习表达方式""" """学习表达方式"""
pending_messages = message_cache[self._last_processed_index :] processed_end_index = len(message_cache)
pending_messages = message_cache[self._last_processed_index : processed_end_index]
if not pending_messages: if not pending_messages:
logger.debug("没有待处理消息") logger.debug("没有待处理消息")
return False return False
@@ -303,6 +321,7 @@ class ExpressionLearner:
response = generation_result.response response = generation_result.response
except Exception as e: except Exception as e:
logger.error(f"学习表达方式失败: {e}") logger.error(f"学习表达方式失败: {e}")
self._last_processed_index = processed_end_index
return False return False
expressions: List[Tuple[str, str, str]] expressions: List[Tuple[str, str, str]]
@@ -336,7 +355,7 @@ class ExpressionLearner:
) )
if after_extract_result.aborted: if after_extract_result.aborted:
logger.info(f"{self.session_id} 表达方式选择 Hook 中止") logger.info(f"{self.session_id} 表达方式选择 Hook 中止")
self._last_processed_index = len(message_cache) self._last_processed_index = processed_end_index
return False return False
after_extract_kwargs = after_extract_result.kwargs after_extract_kwargs = after_extract_result.kwargs
@@ -352,7 +371,7 @@ class ExpressionLearner:
if not expressions: if not expressions:
logger.info("没有可学习的表达方式") logger.info("没有可学习的表达方式")
self._last_processed_index = len(message_cache) self._last_processed_index = processed_end_index
return False return False
logger.info(f"可学习的表达方式: {expressions}") logger.info(f"可学习的表达方式: {expressions}")
@@ -361,7 +380,7 @@ class ExpressionLearner:
learnt_expressions = self._filter_expressions(expressions, pending_messages) learnt_expressions = self._filter_expressions(expressions, pending_messages)
if not learnt_expressions: if not learnt_expressions:
logger.info("没有可学习的表达方式通过过滤") logger.info("没有可学习的表达方式通过过滤")
self._last_processed_index = len(message_cache) self._last_processed_index = processed_end_index
return False return False
learnt_expressions_str = "\n".join(f"{situation}->{style}" for situation, style in learnt_expressions) learnt_expressions_str = "\n".join(f"{situation}->{style}" for situation, style in learnt_expressions)
@@ -386,7 +405,7 @@ class ExpressionLearner:
continue continue
await self._upsert_expression_to_db(situation, style) await self._upsert_expression_to_db(situation, style)
self._last_processed_index = len(message_cache) self._last_processed_index = processed_end_index
return True return True
def _check_cached_jargons_in_messages( def _check_cached_jargons_in_messages(

View File

@@ -149,7 +149,7 @@ class BuiltinToolRuntimeContext:
def _build_at_component_for_message_id(self, message_id: str) -> Optional[AtComponent]: def _build_at_component_for_message_id(self, message_id: str) -> Optional[AtComponent]:
"""根据消息编号构造 at 组件。""" """根据消息编号构造 at 组件。"""
target_message = self.runtime._source_messages_by_id.get(message_id) target_message = self.runtime.find_source_message_by_id(message_id)
if target_message is None: if target_message is None:
return None return None

View File

@@ -113,7 +113,7 @@ async def handle_tool(
"reply 工具需要提供有效的 `msg_id` 参数。", "reply 工具需要提供有效的 `msg_id` 参数。",
) )
target_message = tool_ctx.runtime._source_messages_by_id.get(target_message_id) target_message = tool_ctx.runtime.find_source_message_by_id(target_message_id)
if target_message is None: if target_message is None:
return tool_ctx.build_failure_result( return tool_ctx.build_failure_result(
invocation.tool_name, invocation.tool_name,
@@ -263,6 +263,7 @@ async def handle_tool(
target_user_info = target_message.message_info.user_info target_user_info = target_message.message_info.user_info
target_user_name = target_user_info.user_cardname or target_user_info.user_nickname or target_user_info.user_id target_user_name = target_user_info.user_cardname or target_user_info.user_nickname or target_user_info.user_id
bot_name = config_module.global_config.bot.nickname.strip() or "MaiSaka"
if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME: if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME:
tool_ctx.append_guided_reply_to_chat_history(combined_reply_text) tool_ctx.append_guided_reply_to_chat_history(combined_reply_text)
@@ -291,7 +292,7 @@ async def handle_tool(
) )
return tool_ctx.build_success_result( return tool_ctx.build_success_result(
invocation.tool_name, invocation.tool_name,
f'已生成并发送回复"{combined_reply_text}"\n发送对象:{target_user_name}', f'"{bot_name}"已生成并向"{target_user_name}"发送了回复"{combined_reply_text}"',
structured_content={ structured_content={
"msg_id": target_message_id, "msg_id": target_message_id,
"set_quote": set_quote, "set_quote": set_quote,

View File

@@ -52,7 +52,7 @@ async def handle_tool(
"查看复杂消息工具需要提供有效的 `msg_id` 参数。", "查看复杂消息工具需要提供有效的 `msg_id` 参数。",
) )
target_message = tool_ctx.runtime._source_messages_by_id.get(target_message_id) target_message = tool_ctx.runtime.find_source_message_by_id(target_message_id)
if target_message is None: if target_message is None:
return tool_ctx.build_failure_result( return tool_ctx.build_failure_result(
invocation.tool_name, invocation.tool_name,

View File

@@ -10,7 +10,7 @@ from rich.console import RenderableType
from src.common.data_models.llm_service_data_models import LLMGenerationOptions from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.i18n import get_locale from src.common.i18n import get_locale
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.prompt_i18n import get_prompt_cache_revision, load_prompt from src.common.prompt_i18n import load_prompt
from src.common.utils.utils_config import ChatConfigUtils from src.common.utils.utils_config import ChatConfigUtils
from src.config.config import global_config from src.config.config import global_config
from src.core.tooling import ToolAvailabilityContext, ToolRegistry from src.core.tooling import ToolAvailabilityContext, ToolRegistry
@@ -218,21 +218,15 @@ class MaisakaChatLoopService:
self._extra_tools: List[ToolOption] = [] self._extra_tools: List[ToolOption] = []
self._interrupt_flag: asyncio.Event | None = None self._interrupt_flag: asyncio.Event | None = None
self._tool_registry: ToolRegistry | None = None self._tool_registry: ToolRegistry | None = None
self._prompts_loaded = chat_system_prompt is not None self._custom_chat_system_prompt = chat_system_prompt
self._prompt_cache_revision = get_prompt_cache_revision()
self._prompt_load_lock = asyncio.Lock() self._prompt_load_lock = asyncio.Lock()
self._personality_prompt = self._build_personality_prompt()
if chat_system_prompt is None:
self._chat_system_prompt = f"{self._personality_prompt}\n\nYou are a helpful AI assistant."
else:
self._chat_system_prompt = chat_system_prompt
self._llm_chat_clients: dict[str, LLMServiceClient] = {} self._llm_chat_clients: dict[str, LLMServiceClient] = {}
@property @property
def personality_prompt(self) -> str: def personality_prompt(self) -> str:
"""返回当前人格提示词。""" """返回当前人格提示词。"""
return self._personality_prompt return self._build_personality_prompt()
@staticmethod @staticmethod
def _resolve_llm_request_type(request_kind: str) -> str: def _resolve_llm_request_type(request_kind: str) -> str:
@@ -349,13 +343,15 @@ class MaisakaChatLoopService:
tools_section: 额外注入到提示词中的工具说明片段。 tools_section: 额外注入到提示词中的工具说明片段。
""" """
async with self._prompt_load_lock: async with self._prompt_load_lock:
try: self._build_chat_system_prompt(tools_section)
self._chat_system_prompt = load_prompt("maisaka_chat", **self.build_prompt_template_context(tools_section))
except Exception:
self._chat_system_prompt = f"{self._personality_prompt}\n\nYou are a helpful AI assistant."
self._prompts_loaded = True def _build_chat_system_prompt(self, tools_section: str = "") -> str:
self._prompt_cache_revision = get_prompt_cache_revision() """基于当前配置实时构造主聊天系统提示词。"""
try:
return load_prompt("maisaka_chat", **self.build_prompt_template_context(tools_section))
except Exception:
return f"{self.personality_prompt}\n\nYou are a helpful AI assistant."
def build_prompt_template_context(self, tools_section: str = "") -> dict[str, str]: def build_prompt_template_context(self, tools_section: str = "") -> dict[str, str]:
"""构造 Maisaka prompt 模板的公共渲染参数。""" """构造 Maisaka prompt 模板的公共渲染参数。"""
@@ -364,7 +360,7 @@ class MaisakaChatLoopService:
"bot_name": global_config.bot.nickname, "bot_name": global_config.bot.nickname,
"file_tools_section": tools_section, "file_tools_section": tools_section,
"group_chat_attention_block": self._build_group_chat_attention_block(), "group_chat_attention_block": self._build_group_chat_attention_block(),
"identity": self._personality_prompt, "identity": self.personality_prompt,
"timing_gate_wait_rule": self._build_timing_gate_wait_rule(), "timing_gate_wait_rule": self._build_timing_gate_wait_rule(),
"time_block": self._build_time_block(), "time_block": self._build_time_block(),
} }
@@ -471,7 +467,13 @@ class MaisakaChatLoopService:
messages: List[Message] = [] messages: List[Message] = []
system_msg = MessageBuilder().set_role(RoleType.System) system_msg = MessageBuilder().set_role(RoleType.System)
system_msg.add_text_content(system_prompt if system_prompt is not None else self._chat_system_prompt) if system_prompt is not None:
resolved_system_prompt = system_prompt
elif self._custom_chat_system_prompt is not None:
resolved_system_prompt = self._custom_chat_system_prompt
else:
resolved_system_prompt = self._build_chat_system_prompt()
system_msg.add_text_content(resolved_system_prompt)
messages.append(system_msg.build()) messages.append(system_msg.build())
for msg in selected_history: for msg in selected_history:
@@ -521,8 +523,6 @@ class MaisakaChatLoopService:
ChatResponse: 本轮规划器返回结果。 ChatResponse: 本轮规划器返回结果。
""" """
if not self._prompts_loaded or self._prompt_cache_revision != get_prompt_cache_revision():
await self.ensure_chat_prompt_loaded()
enable_visual_message = self._resolve_enable_visual_message(request_kind) enable_visual_message = self._resolve_enable_visual_message(request_kind)
selected_history, selection_reason = self.select_llm_context_messages( selected_history, selection_reason = self.select_llm_context_messages(
chat_history, chat_history,

View File

@@ -29,6 +29,7 @@ from src.learners.jargon_miner import JargonMiner
from src.llm_models.payload_content.resp_format import RespFormat from src.llm_models.payload_content.resp_format import RespFormat
from src.llm_models.payload_content.tool_option import ToolDefinitionInput from src.llm_models.payload_content.tool_option import ToolDefinitionInput
from src.mcp_module import MCPManager from src.mcp_module import MCPManager
from src.mcp_module.config import build_mcp_server_runtime_configs
from src.mcp_module.host_llm_bridge import MCPHostLLMBridge from src.mcp_module.host_llm_bridge import MCPHostLLMBridge
from src.mcp_module.provider import MCPToolProvider from src.mcp_module.provider import MCPToolProvider
from src.plugin_runtime.tool_provider import PluginToolProvider from src.plugin_runtime.tool_provider import PluginToolProvider
@@ -56,6 +57,7 @@ from .tool_provider import MaisakaBuiltinToolProvider
logger = get_logger("maisaka_runtime") logger = get_logger("maisaka_runtime")
MAX_INTERNAL_ROUNDS = 10 MAX_INTERNAL_ROUNDS = 10
MAX_RETAINED_MESSAGE_CACHE_SIZE = 200
class MaisakaHeartFlowChatting: class MaisakaHeartFlowChatting:
@@ -90,7 +92,6 @@ class MaisakaHeartFlowChatting:
self._mcp_manager: Optional[MCPManager] = None self._mcp_manager: Optional[MCPManager] = None
self._mcp_host_bridge: Optional[MCPHostLLMBridge] = None self._mcp_host_bridge: Optional[MCPHostLLMBridge] = None
self._current_cycle_detail: Optional[CycleDetail] = None self._current_cycle_detail: Optional[CycleDetail] = None
self._source_messages_by_id: dict[str, SessionMessage] = {}
self._running = False self._running = False
self._cycle_counter = 0 self._cycle_counter = 0
self._internal_loop_task: Optional[asyncio.Task] = None self._internal_loop_task: Optional[asyncio.Task] = None
@@ -98,19 +99,13 @@ class MaisakaHeartFlowChatting:
self._deferred_message_turn_task: Optional[asyncio.Task[None]] = None self._deferred_message_turn_task: Optional[asyncio.Task[None]] = None
self._message_debounce_seconds = 1.0 self._message_debounce_seconds = 1.0
self._message_debounce_required = False self._message_debounce_required = False
self._message_received_at_by_id: dict[str, float] = {} self._oldest_pending_message_received_at: Optional[float] = None
self._last_message_received_at = 0.0 self._last_message_received_at = 0.0
self._talk_frequency_adjust = 1.0 self._talk_frequency_adjust = 1.0
self._reply_latency_measurement_started_at: Optional[float] = None self._reply_latency_measurement_started_at: Optional[float] = None
self._recent_reply_latencies: deque[tuple[float, float]] = deque() self._recent_reply_latencies: deque[tuple[float, float]] = deque()
self._wait_timeout_task: Optional[asyncio.Task[None]] = None self._wait_timeout_task: Optional[asyncio.Task[None]] = None
self._max_internal_rounds = MAX_INTERNAL_ROUNDS self._max_internal_rounds = MAX_INTERNAL_ROUNDS
configured_context_size = (
global_config.chat.max_context_size
if self.chat_stream.is_group_session
else global_config.chat.max_private_context_size
)
self._max_context_size = max(1, int(configured_context_size))
self._agent_state: Literal["running", "wait", "stop"] = self._STATE_STOP self._agent_state: Literal["running", "wait", "stop"] = self._STATE_STOP
self._pending_wait_tool_call_id: Optional[str] = None self._pending_wait_tool_call_id: Optional[str] = None
self._force_next_timing_continue = False self._force_next_timing_continue = False
@@ -152,6 +147,17 @@ class MaisakaHeartFlowChatting:
self._register_tool_providers() self._register_tool_providers()
self._emit_monitor_session_start() self._emit_monitor_session_start()
@property
def _max_context_size(self) -> int:
"""返回当前会话实时生效的上下文窗口大小。"""
configured_context_size = (
global_config.chat.max_context_size
if self.chat_stream.is_group_session
else global_config.chat.max_private_context_size
)
return max(1, int(configured_context_size))
def _emit_monitor_session_start(self) -> None: def _emit_monitor_session_start(self) -> None:
"""向 WebUI 监控面板同步当前会话的展示标识。""" """向 WebUI 监控面板同步当前会话的展示标识。"""
@@ -312,10 +318,11 @@ class MaisakaHeartFlowChatting:
self._ensure_background_tasks_running() self._ensure_background_tasks_running()
received_at = time.time() received_at = time.time()
self._last_message_received_at = received_at self._last_message_received_at = received_at
if self._oldest_pending_message_received_at is None:
self._oldest_pending_message_received_at = received_at
self._update_message_trigger_state(message) self._update_message_trigger_state(message)
self.message_cache.append(message) self.message_cache.append(message)
self._message_received_at_by_id[message.message_id] = received_at self._prune_processed_message_cache()
self._source_messages_by_id[message.message_id] = message
if self._is_reply_effect_tracking_enabled(): if self._is_reply_effect_tracking_enabled():
asyncio.create_task(self._reply_effect_tracker.observe_user_message(message)) asyncio.create_task(self._reply_effect_tracker.observe_user_message(message))
if self._agent_state == self._STATE_RUNNING: if self._agent_state == self._STATE_RUNNING:
@@ -487,6 +494,45 @@ class MaisakaHeartFlowChatting:
f"最近10分钟样本数={len(self._recent_reply_latencies)}" f"最近10分钟样本数={len(self._recent_reply_latencies)}"
) )
def find_source_message_by_id(self, message_id: str) -> Optional[SessionMessage]:
"""从 Maisaka 历史中查找指定消息编号对应的原始消息。"""
normalized_message_id = str(message_id or "").strip()
if not normalized_message_id:
return None
for history_message in reversed(self._chat_history):
if str(getattr(history_message, "message_id", "") or "").strip() != normalized_message_id:
continue
original_message = getattr(history_message, "original_message", None)
if original_message is None:
continue
return original_message
return None
def _prune_processed_message_cache(self) -> None:
"""裁剪 runtime 与表达学习器都已经消费过的旧消息。"""
excess_count = len(self.message_cache) - MAX_RETAINED_MESSAGE_CACHE_SIZE
if excess_count <= 0:
return
removable_count = min(
excess_count,
self._last_processed_index,
self._expression_learner.last_processed_index,
)
if removable_count <= 0:
return
del self.message_cache[:removable_count]
self._last_processed_index = max(0, self._last_processed_index - removable_count)
self._expression_learner.discard_processed_prefix(removable_count)
logger.debug(
f"{self.log_prefix} 已清理 Maisaka 旧消息缓存: "
f"清理数量={removable_count} 保留数量={len(self.message_cache)}"
)
def _should_trigger_message_turn_by_idle_compensation( def _should_trigger_message_turn_by_idle_compensation(
self, self,
*, *,
@@ -637,6 +683,7 @@ class MaisakaHeartFlowChatting:
return return
if self._internal_loop_task is None or self._internal_loop_task.done(): if self._internal_loop_task is None or self._internal_loop_task.done():
is_restart = self._internal_loop_task is not None
if self._internal_loop_task is not None and not self._internal_loop_task.cancelled(): if self._internal_loop_task is not None and not self._internal_loop_task.cancelled():
try: try:
exc = self._internal_loop_task.exception() exc = self._internal_loop_task.exception()
@@ -645,7 +692,10 @@ class MaisakaHeartFlowChatting:
if exc is not None: if exc is not None:
logger.error(f"{self.log_prefix} 内部循环任务异常退出: {exc}") logger.error(f"{self.log_prefix} 内部循环任务异常退出: {exc}")
self._internal_loop_task = asyncio.create_task(self._reasoning_engine.run_loop()) self._internal_loop_task = asyncio.create_task(self._reasoning_engine.run_loop())
logger.warning(f"{self.log_prefix} 已重新拉起 Maisaka 内部循环任务") if is_restart:
logger.warning(f"{self.log_prefix} 已重新拉起 Maisaka 内部循环任务")
else:
logger.debug(f"{self.log_prefix} 已启动 Maisaka 内部循环任务")
def _register_tool_providers(self) -> None: def _register_tool_providers(self) -> None:
"""注册 Maisaka 运行时默认启用的工具 Provider。""" """注册 Maisaka 运行时默认启用的工具 Provider。"""
@@ -1001,12 +1051,10 @@ class MaisakaHeartFlowChatting:
# f"收集 {len(unique_messages)} 条新消息" # f"收集 {len(unique_messages)} 条新消息"
# ) # )
if unique_messages and self._reply_latency_measurement_started_at is None: if unique_messages and self._reply_latency_measurement_started_at is None:
self._reply_latency_measurement_started_at = min( self._reply_latency_measurement_started_at = (
self._message_received_at_by_id.get(message.message_id, self._last_message_received_at) self._oldest_pending_message_received_at or self._last_message_received_at
for message in unique_messages
) )
for message in unique_messages: self._oldest_pending_message_received_at = None
self._message_received_at_by_id.pop(message.message_id, None)
return unique_messages return unique_messages
async def _wait_for_message_quiet_period(self) -> None: async def _wait_for_message_quiet_period(self) -> None:
@@ -1075,10 +1123,19 @@ class MaisakaHeartFlowChatting:
async def _trigger_batch_learning(self, messages: list[SessionMessage]) -> None: async def _trigger_batch_learning(self, messages: list[SessionMessage]) -> None:
"""按同一批消息触发表达方式和黑话学习。""" """按同一批消息触发表达方式和黑话学习。"""
processed_end_index = len(self.message_cache)
if not self._enable_expression_learning:
self._expression_learner.mark_all_processed(self.message_cache)
self._prune_processed_message_cache()
return
try: try:
await self._trigger_expression_learning(messages) await self._trigger_expression_learning(messages)
except Exception as exc: except Exception as exc:
logger.error(f"{self.log_prefix} 表达学习任务异常退出: {exc}") logger.error(f"{self.log_prefix} 表达学习任务异常退出: {exc}")
self._expression_learner.mark_processed_until(processed_end_index)
finally:
self._prune_processed_message_cache()
def _should_trigger_learning( def _should_trigger_learning(
self, self,
@@ -1145,6 +1202,10 @@ class MaisakaHeartFlowChatting:
async def _init_mcp(self) -> None: async def _init_mcp(self) -> None:
"""初始化 MCP 工具并注册到统一工具层。""" """初始化 MCP 工具并注册到统一工具层。"""
if not build_mcp_server_runtime_configs(global_config.mcp):
logger.debug(f"{self.log_prefix} 未配置可用的 MCP 服务,跳过 Maisaka MCP 初始化")
return
self._mcp_host_bridge = MCPHostLLMBridge( self._mcp_host_bridge = MCPHostLLMBridge(
sampling_task_name=global_config.mcp.client.sampling.task_name, sampling_task_name=global_config.mcp.client.sampling.task_name,
) )
@@ -1153,7 +1214,7 @@ class MaisakaHeartFlowChatting:
host_callbacks=self._mcp_host_bridge.build_callbacks(), host_callbacks=self._mcp_host_bridge.build_callbacks(),
) )
if self._mcp_manager is None: if self._mcp_manager is None:
logger.info(f"{self.log_prefix} Maisaka MCP 管理器不可用") logger.warning(f"{self.log_prefix} Maisaka MCP 管理器初始化失败MCP 工具不会注册")
return return
mcp_tool_specs = self._mcp_manager.get_tool_specs() mcp_tool_specs = self._mcp_manager.get_tool_specs()

View File

@@ -110,7 +110,7 @@ class RuntimeDataCapabilityMixin:
result = await database_service.db_count(model_class=model_class, filters=args.get("filters")) result = await database_service.db_count(model_class=model_class, filters=args.get("filters"))
else: else:
return {"success": False, "error": f"不支持的 query_type: {query_type}"} return {"success": False, "error": f"不支持的 query_type: {query_type}"}
return {"success": True, "result": result} return result
except Exception as e: except Exception as e:
logger.error(f"[cap.database.query] 执行失败: {e}", exc_info=True) logger.error(f"[cap.database.query] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@@ -136,7 +136,7 @@ class RuntimeDataCapabilityMixin:
key_field=args.get("key_field"), key_field=args.get("key_field"),
key_value=args.get("key_value"), key_value=args.get("key_value"),
) )
return {"success": True, "result": result} return result
except Exception as e: except Exception as e:
logger.error(f"[cap.database.save] 执行失败: {e}", exc_info=True) logger.error(f"[cap.database.save] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@@ -162,7 +162,7 @@ class RuntimeDataCapabilityMixin:
order_by=args.get("order_by"), order_by=args.get("order_by"),
single_result=args.get("single_result", False), single_result=args.get("single_result", False),
) )
return {"success": True, "result": result} return result
except Exception as e: except Exception as e:
logger.error(f"[cap.database.get] 执行失败: {e}", exc_info=True) logger.error(f"[cap.database.get] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@@ -185,7 +185,7 @@ class RuntimeDataCapabilityMixin:
return {"success": False, "error": f"未找到数据模型: {model_name}"} return {"success": False, "error": f"未找到数据模型: {model_name}"}
result = await database_service.db_delete(model_class=model_class, filters=filters) result = await database_service.db_delete(model_class=model_class, filters=filters)
return {"success": True, "result": result} return result
except Exception as e: except Exception as e:
logger.error(f"[cap.database.delete] 执行失败: {e}", exc_info=True) logger.error(f"[cap.database.delete] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@@ -205,7 +205,7 @@ class RuntimeDataCapabilityMixin:
return {"success": False, "error": f"未找到数据模型: {model_name}"} return {"success": False, "error": f"未找到数据模型: {model_name}"}
result = await database_service.db_count(model_class=model_class, filters=args.get("filters")) result = await database_service.db_count(model_class=model_class, filters=args.get("filters"))
return {"success": True, "count": result} return result
except Exception as e: except Exception as e:
logger.error(f"[cap.database.count] 执行失败: {e}", exc_info=True) logger.error(f"[cap.database.count] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}

View File

@@ -47,6 +47,15 @@ class VersionComparator:
return "0.0.0" return "0.0.0"
normalized = re.sub(r"-snapshot\.\d+", "", str(version).strip()) normalized = re.sub(r"-snapshot\.\d+", "", str(version).strip())
try:
parsed_version = Version(normalized)
parts = [str(part) for part in parsed_version.release[:3]]
while len(parts) < 3:
parts.append("0")
return ".".join(parts)
except InvalidVersion:
pass
if not re.match(r"^\d+(\.\d+){0,2}$", normalized): if not re.match(r"^\d+(\.\d+){0,2}$", normalized):
return "0.0.0" return "0.0.0"
@@ -132,6 +141,17 @@ class VersionComparator:
""" """
return bool(_SEMVER_PATTERN.fullmatch(str(version or "").strip())) return bool(_SEMVER_PATTERN.fullmatch(str(version or "").strip()))
@staticmethod
def is_valid_project_version(version: str) -> bool:
"""判断主程序或 SDK 的项目版本号是否可被解析。
``pyproject.toml`` 遵循 Python 包版本规范,允许 ``1.0.0rc16`` 或
``1.0.0-pre.16`` 这类预发布版本;兼容性比较时只取其 release 部分。
"""
normalized = VersionComparator.normalize_version(version)
return normalized != "0.0.0" or str(version or "").strip() == "0.0.0"
class _StrictManifestModel(BaseModel): class _StrictManifestModel(BaseModel):
"""Manifest 解析使用的严格基类模型。""" """Manifest 解析使用的严格基类模型。"""
@@ -1030,7 +1050,7 @@ class ManifestValidator:
return "" return ""
raw_version = str(project_data.get("version", "") or "").strip() raw_version = str(project_data.get("version", "") or "").strip()
if VersionComparator.is_valid_semver(raw_version): if VersionComparator.is_valid_project_version(raw_version):
return raw_version return raw_version
return "" return ""
@@ -1047,7 +1067,7 @@ class ManifestValidator:
""" """
try: try:
raw_version = importlib_metadata.version("maibot-plugin-sdk") raw_version = importlib_metadata.version("maibot-plugin-sdk")
if VersionComparator.is_valid_semver(raw_version): if VersionComparator.is_valid_project_version(raw_version):
return raw_version return raw_version
except importlib_metadata.PackageNotFoundError: except importlib_metadata.PackageNotFoundError:
pass pass
@@ -1064,7 +1084,7 @@ class ManifestValidator:
return "" return ""
raw_version = str(project_data.get("version", "") or "").strip() raw_version = str(project_data.get("version", "") or "").strip()
if VersionComparator.is_valid_semver(raw_version): if VersionComparator.is_valid_project_version(raw_version):
return raw_version return raw_version
return "" return ""

View File

@@ -3,7 +3,8 @@
import json import json
import time import time
import traceback import traceback
from datetime import datetime from datetime import date, datetime
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast
from sqlalchemy import delete, func from sqlalchemy import delete, func
@@ -19,14 +20,28 @@ if TYPE_CHECKING:
logger = get_logger("database_service") logger = get_logger("database_service")
def _to_msgpack_value(value: Any) -> Any:
if isinstance(value, datetime):
return value.isoformat()
if isinstance(value, date):
return value.isoformat()
if isinstance(value, Enum):
return value.value
if isinstance(value, dict):
return {_to_msgpack_value(key): _to_msgpack_value(item) for key, item in value.items()}
if isinstance(value, (list, tuple, set)):
return [_to_msgpack_value(item) for item in value]
return value
def _to_dict(record: Any) -> dict[str, Any]: def _to_dict(record: Any) -> dict[str, Any]:
if record is None: if record is None:
return {} return {}
if isinstance(record, dict): if isinstance(record, dict):
return record return _to_msgpack_value(record)
if hasattr(record, "model_dump"): if hasattr(record, "model_dump"):
return record.model_dump() return _to_msgpack_value(record.model_dump())
return dict(record.__dict__) if hasattr(record, "__dict__") else {} return _to_msgpack_value(dict(record.__dict__)) if hasattr(record, "__dict__") else {}
def _get_model_field(model_class: type[SQLModel], field_name: str) -> Any: def _get_model_field(model_class: type[SQLModel], field_name: str) -> Any:

View File

@@ -296,6 +296,7 @@ async def get_models_by_url(
async def test_provider_connection( async def test_provider_connection(
base_url: str = Query(..., description="提供商的基础 URL"), base_url: str = Query(..., description="提供商的基础 URL"),
api_key: Optional[str] = Query(None, description="API Key可选用于验证 Key 有效性)"), api_key: Optional[str] = Query(None, description="API Key可选用于验证 Key 有效性)"),
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
): ):
""" """
测试提供商连接状态 测试提供商连接状态
@@ -359,13 +360,19 @@ async def test_provider_connection(
try: try:
start_time = time.time() start_time = time.time()
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client: async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
headers = { headers = {"Content-Type": "application/json"}
"Authorization": f"Bearer {api_key}", params = {}
"Content-Type": "application/json",
} if client_type == "gemini":
# Gemini 使用 URL 参数传递 API Key
params["key"] = api_key
else:
# OpenAI 兼容格式使用 Authorization 头
headers["Authorization"] = f"Bearer {api_key}"
# 尝试获取模型列表 # 尝试获取模型列表
models_url = f"{base_url}/models" models_url = f"{base_url}/models"
response = await client.get(models_url, headers=headers) response = await client.get(models_url, headers=headers, params=params)
if response.status_code == 200: if response.status_code == 200:
result["api_key_valid"] = True result["api_key_valid"] = True
@@ -408,9 +415,14 @@ async def test_provider_connection_by_name(
base_url = provider.get("base_url", "") base_url = provider.get("base_url", "")
api_key = provider.get("api_key", "") api_key = provider.get("api_key", "")
client_type = provider.get("client_type", "openai")
if not base_url: if not base_url:
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url") raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
# 调用测试接口 # 调用测试接口
return await test_provider_connection(base_url=base_url, api_key=api_key or None) return await test_provider_connection(
base_url=base_url,
api_key=api_key if api_key else None,
client_type=client_type,
)

View File

@@ -187,9 +187,10 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
for field in ["manifest_version", "name", "version", "author"]: for field in ["manifest_version", "name", "version", "author"]:
if field not in manifest: if field not in manifest:
raise ValueError(f"缺少必需字段: {field}") raise ValueError(f"缺少必需字段: {field}")
manifest["id"] = plugin_id if not str(manifest.get("id", "")).strip():
with open(manifest_path, "w", encoding="utf-8") as file_obj: manifest["id"] = plugin_id
json.dump(manifest, file_obj, ensure_ascii=False, indent=2) with open(manifest_path, "w", encoding="utf-8") as file_obj:
json.dump(manifest, file_obj, ensure_ascii=False, indent=2)
except Exception as e: except Exception as e:
remove_tree(target_path) remove_tree(target_path)
await update_progress( await update_progress(

1160
uv.lock generated

File diff suppressed because it is too large Load Diff