diff --git a/bot.py b/bot.py index 1a13e46e..6c5ea8e1 100644 --- a/bot.py +++ b/bot.py @@ -105,6 +105,8 @@ if os.environ.get("MAIBOT_WORKER_PROCESS") != "1": require_legacy_upgrade_confirmation(Path(script_dir)) +logger.info(t("startup.worker_dir_set", script_dir=script_dir)) + from src.main import MainSystem # noqa from src.manager.async_task_manager import async_task_manager # noqa @@ -117,9 +119,6 @@ from src.manager.async_task_manager import async_task_manager # noqa # 设置工作目录为脚本所在目录 # script_dir = os.path.dirname(os.path.abspath(__file__)) # os.chdir(script_dir) -logger.info(t("startup.worker_dir_set", script_dir=script_dir)) - - confirm_logger = get_logger("confirm") # 获取没有加载env时的环境变量 env_mask = {key: os.getenv(key) for key in os.environ} diff --git a/src/A_memorix/core/runtime/sdk_memory_kernel.py b/src/A_memorix/core/runtime/sdk_memory_kernel.py index 12681e05..a9ca167c 100644 --- a/src/A_memorix/core/runtime/sdk_memory_kernel.py +++ b/src/A_memorix/core/runtime/sdk_memory_kernel.py @@ -4,7 +4,6 @@ import asyncio import json import pickle import time -import uuid from dataclasses import dataclass from datetime import datetime, timedelta from pathlib import Path @@ -19,7 +18,7 @@ from src.services.llm_service import LLMServiceClient from ...paths import default_data_dir, resolve_repo_path from ..embedding import create_embedding_api_adapter -from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index, TemporalQueryOptions +from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index from ..storage import GraphStore, MetadataStore, QuantizationType, SparseMatrixFormat, VectorStore from ..utils.aggregate_query_service import AggregateQueryService from ..utils.episode_retrieval_service import EpisodeRetrievalService @@ -634,23 +633,18 @@ class SDKMemoryKernel: model_name=str(self._cfg("embedding.model_name", "auto") or "auto"), retry_config=self._cfg("embedding.retry", {}) or {}, ) - detected_dimension = int(await self.embedding_manager._detect_dimension()) - self.embedding_dimension = detected_dimension - + dimension_detection_task = asyncio.create_task( + asyncio.to_thread(lambda: asyncio.run(self.embedding_manager._detect_dimension())) + ) + await asyncio.sleep(0) stored_dimension = self._stored_vector_dimension() - if stored_dimension is not None and stored_dimension != detected_dimension: - raise RuntimeError( - self._vector_mismatch_error( - stored_dimension=stored_dimension, - detected_dimension=detected_dimension, - ) - ) + provisional_dimension = stored_dimension or self.embedding_dimension matrix_format = str(self._cfg("graph.sparse_matrix_format", "csr") or "csr").strip().lower() graph_format = SparseMatrixFormat.CSC if matrix_format == "csc" else SparseMatrixFormat.CSR self.vector_store = VectorStore( - dimension=detected_dimension, + dimension=provisional_dimension, quantization_type=QuantizationType.INT8, data_dir=self.data_dir / "vectors", ) @@ -658,9 +652,11 @@ class SDKMemoryKernel: self.metadata_store = MetadataStore(data_dir=self.data_dir / "metadata") self.metadata_store.connect() - if self.vector_store.has_data(): + vector_store_loaded = False + if stored_dimension is not None and self.vector_store.has_data(): self.vector_store.load() self.vector_store.warmup_index(force_train=True) + vector_store_loaded = True if self.graph_store.has_data(): self.graph_store.load() @@ -674,6 +670,33 @@ class SDKMemoryKernel: if getattr(self.sparse_index.config, "enabled", False): self.sparse_index.ensure_loaded() + try: + detected_dimension = int(await dimension_detection_task) + except Exception: + if not dimension_detection_task.done(): + dimension_detection_task.cancel() + raise + self.embedding_dimension = detected_dimension + + if stored_dimension is not None and stored_dimension != detected_dimension: + raise RuntimeError( + self._vector_mismatch_error( + stored_dimension=stored_dimension, + detected_dimension=detected_dimension, + ) + ) + + if self.vector_store.dimension != detected_dimension: + self.vector_store = VectorStore( + dimension=detected_dimension, + quantization_type=QuantizationType.INT8, + data_dir=self.data_dir / "vectors", + ) + + if not vector_store_loaded and self.vector_store.has_data(): + self.vector_store.load() + self.vector_store.warmup_index(force_train=True) + self.relation_write_service = RelationWriteService( metadata_store=self.metadata_store, graph_store=self.graph_store, diff --git a/src/A_memorix/host_service.py b/src/A_memorix/host_service.py index b3766dd6..8b05127d 100644 --- a/src/A_memorix/host_service.py +++ b/src/A_memorix/host_service.py @@ -279,9 +279,14 @@ class AMemorixHostService: async with self._lock: if self._kernel is None: config = self._read_config() - self._kernel = SDKMemoryKernel(plugin_root=repo_root(), config=config) - await self._kernel.initialize() - set_runtime_kernel(self._kernel) + kernel = SDKMemoryKernel(plugin_root=repo_root(), config=config) + try: + await kernel.initialize() + except Exception: + kernel.close() + raise + self._kernel = kernel + set_runtime_kernel(kernel) return self._kernel def _read_config(self) -> Dict[str, Any]: diff --git a/src/common/remote.py b/src/common/remote.py index 0a91b3ab..4a9d5ddc 100644 --- a/src/common/remote.py +++ b/src/common/remote.py @@ -1,8 +1,8 @@ +from typing import Any + import aiohttp import asyncio -import certifi import platform -import ssl from src.common.logger import get_logger from src.config.config import global_config, MMC_VERSION @@ -13,11 +13,19 @@ logger = get_logger("remote") TELEMETRY_SERVER_URL = "http://hyybuth.xyz:10058" """遥测服务地址""" -ssl_context = ssl.create_default_context(cafile=certifi.where()) +_ssl_context: Any = None async def get_tcp_connector(): - return aiohttp.TCPConnector(ssl=ssl_context) + global _ssl_context + + if _ssl_context is None: + import certifi + import ssl + + _ssl_context = ssl.create_default_context(cafile=certifi.where()) + + return aiohttp.TCPConnector(ssl=_ssl_context) class TelemetryHeartBeatTask(AsyncTask): diff --git a/src/main.py b/src/main.py index 1e515fe9..1e184b28 100644 --- a/src/main.py +++ b/src/main.py @@ -1,4 +1,3 @@ -from maim_message import MessageServer from rich.traceback import install from typing import TYPE_CHECKING @@ -13,9 +12,7 @@ from src.chat.message_receive.chat_manager import chat_manager from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask from src.common.i18n import t from src.common.logger import get_logger -from src.common.message_server import get_global_api from src.common.message_server.server import Server, get_global_server -from src.common.remote import TelemetryHeartBeatTask from src.config.config import config_manager, global_config from src.manager.async_task_manager import async_task_manager from src.maisaka.display.stage_status_board import disable_stage_status_board, enable_stage_status_board @@ -35,12 +32,15 @@ logger = get_logger("main") if TYPE_CHECKING: + from maim_message import MessageServer from src.webui.webui_server import WebUIServer class MainSystem: def __init__(self) -> None: # 使用消息API替代直接的FastAPI实例 + from src.common.message_server import get_global_api + self.app: MessageServer = get_global_api() self.server: Server = get_global_server() self.webui_server: WebUIServer | None = None # 独立的 WebUI 服务器 @@ -88,6 +88,8 @@ class MainSystem: await async_task_manager.add_task(StatisticOutputTask()) # 添加遥测心跳任务 + from src.common.remote import TelemetryHeartBeatTask + await async_task_manager.add_task(TelemetryHeartBeatTask()) # 添加表达方式自动检查任务