better:优化部分导入以更快的启动

This commit is contained in:
SengokuCola
2026-04-25 14:26:02 +08:00
parent dcc6748a76
commit 8168fe0d8a
5 changed files with 64 additions and 27 deletions

5
bot.py
View File

@@ -105,6 +105,8 @@ if os.environ.get("MAIBOT_WORKER_PROCESS") != "1":
require_legacy_upgrade_confirmation(Path(script_dir))
logger.info(t("startup.worker_dir_set", script_dir=script_dir))
from src.main import MainSystem # noqa
from src.manager.async_task_manager import async_task_manager # noqa
@@ -117,9 +119,6 @@ from src.manager.async_task_manager import async_task_manager # noqa
# 设置工作目录为脚本所在目录
# script_dir = os.path.dirname(os.path.abspath(__file__))
# os.chdir(script_dir)
logger.info(t("startup.worker_dir_set", script_dir=script_dir))
confirm_logger = get_logger("confirm")
# 获取没有加载env时的环境变量
env_mask = {key: os.getenv(key) for key in os.environ}

View File

@@ -4,7 +4,6 @@ import asyncio
import json
import pickle
import time
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
@@ -19,7 +18,7 @@ from src.services.llm_service import LLMServiceClient
from ...paths import default_data_dir, resolve_repo_path
from ..embedding import create_embedding_api_adapter
from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index, TemporalQueryOptions
from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index
from ..storage import GraphStore, MetadataStore, QuantizationType, SparseMatrixFormat, VectorStore
from ..utils.aggregate_query_service import AggregateQueryService
from ..utils.episode_retrieval_service import EpisodeRetrievalService
@@ -634,23 +633,18 @@ class SDKMemoryKernel:
model_name=str(self._cfg("embedding.model_name", "auto") or "auto"),
retry_config=self._cfg("embedding.retry", {}) or {},
)
detected_dimension = int(await self.embedding_manager._detect_dimension())
self.embedding_dimension = detected_dimension
dimension_detection_task = asyncio.create_task(
asyncio.to_thread(lambda: asyncio.run(self.embedding_manager._detect_dimension()))
)
await asyncio.sleep(0)
stored_dimension = self._stored_vector_dimension()
if stored_dimension is not None and stored_dimension != detected_dimension:
raise RuntimeError(
self._vector_mismatch_error(
stored_dimension=stored_dimension,
detected_dimension=detected_dimension,
)
)
provisional_dimension = stored_dimension or self.embedding_dimension
matrix_format = str(self._cfg("graph.sparse_matrix_format", "csr") or "csr").strip().lower()
graph_format = SparseMatrixFormat.CSC if matrix_format == "csc" else SparseMatrixFormat.CSR
self.vector_store = VectorStore(
dimension=detected_dimension,
dimension=provisional_dimension,
quantization_type=QuantizationType.INT8,
data_dir=self.data_dir / "vectors",
)
@@ -658,9 +652,11 @@ class SDKMemoryKernel:
self.metadata_store = MetadataStore(data_dir=self.data_dir / "metadata")
self.metadata_store.connect()
if self.vector_store.has_data():
vector_store_loaded = False
if stored_dimension is not None and self.vector_store.has_data():
self.vector_store.load()
self.vector_store.warmup_index(force_train=True)
vector_store_loaded = True
if self.graph_store.has_data():
self.graph_store.load()
@@ -674,6 +670,33 @@ class SDKMemoryKernel:
if getattr(self.sparse_index.config, "enabled", False):
self.sparse_index.ensure_loaded()
try:
detected_dimension = int(await dimension_detection_task)
except Exception:
if not dimension_detection_task.done():
dimension_detection_task.cancel()
raise
self.embedding_dimension = detected_dimension
if stored_dimension is not None and stored_dimension != detected_dimension:
raise RuntimeError(
self._vector_mismatch_error(
stored_dimension=stored_dimension,
detected_dimension=detected_dimension,
)
)
if self.vector_store.dimension != detected_dimension:
self.vector_store = VectorStore(
dimension=detected_dimension,
quantization_type=QuantizationType.INT8,
data_dir=self.data_dir / "vectors",
)
if not vector_store_loaded and self.vector_store.has_data():
self.vector_store.load()
self.vector_store.warmup_index(force_train=True)
self.relation_write_service = RelationWriteService(
metadata_store=self.metadata_store,
graph_store=self.graph_store,

View File

@@ -279,9 +279,14 @@ class AMemorixHostService:
async with self._lock:
if self._kernel is None:
config = self._read_config()
self._kernel = SDKMemoryKernel(plugin_root=repo_root(), config=config)
await self._kernel.initialize()
set_runtime_kernel(self._kernel)
kernel = SDKMemoryKernel(plugin_root=repo_root(), config=config)
try:
await kernel.initialize()
except Exception:
kernel.close()
raise
self._kernel = kernel
set_runtime_kernel(kernel)
return self._kernel
def _read_config(self) -> Dict[str, Any]:

View File

@@ -1,8 +1,8 @@
from typing import Any
import aiohttp
import asyncio
import certifi
import platform
import ssl
from src.common.logger import get_logger
from src.config.config import global_config, MMC_VERSION
@@ -13,11 +13,19 @@ logger = get_logger("remote")
TELEMETRY_SERVER_URL = "http://hyybuth.xyz:10058"
"""遥测服务地址"""
ssl_context = ssl.create_default_context(cafile=certifi.where())
_ssl_context: Any = None
async def get_tcp_connector():
return aiohttp.TCPConnector(ssl=ssl_context)
global _ssl_context
if _ssl_context is None:
import certifi
import ssl
_ssl_context = ssl.create_default_context(cafile=certifi.where())
return aiohttp.TCPConnector(ssl=_ssl_context)
class TelemetryHeartBeatTask(AsyncTask):

View File

@@ -1,4 +1,3 @@
from maim_message import MessageServer
from rich.traceback import install
from typing import TYPE_CHECKING
@@ -13,9 +12,7 @@ from src.chat.message_receive.chat_manager import chat_manager
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
from src.common.i18n import t
from src.common.logger import get_logger
from src.common.message_server import get_global_api
from src.common.message_server.server import Server, get_global_server
from src.common.remote import TelemetryHeartBeatTask
from src.config.config import config_manager, global_config
from src.manager.async_task_manager import async_task_manager
from src.maisaka.display.stage_status_board import disable_stage_status_board, enable_stage_status_board
@@ -35,12 +32,15 @@ logger = get_logger("main")
if TYPE_CHECKING:
from maim_message import MessageServer
from src.webui.webui_server import WebUIServer
class MainSystem:
def __init__(self) -> None:
# 使用消息API替代直接的FastAPI实例
from src.common.message_server import get_global_api
self.app: MessageServer = get_global_api()
self.server: Server = get_global_server()
self.webui_server: WebUIServer | None = None # 独立的 WebUI 服务器
@@ -88,6 +88,8 @@ class MainSystem:
await async_task_manager.add_task(StatisticOutputTask())
# 添加遥测心跳任务
from src.common.remote import TelemetryHeartBeatTask
await async_task_manager.add_task(TelemetryHeartBeatTask())
# 添加表达方式自动检查任务