better:优化部分导入以更快的启动

This commit is contained in:
SengokuCola
2026-04-25 14:26:02 +08:00
parent dcc6748a76
commit 8168fe0d8a
5 changed files with 64 additions and 27 deletions

5
bot.py
View File

@@ -105,6 +105,8 @@ if os.environ.get("MAIBOT_WORKER_PROCESS") != "1":
require_legacy_upgrade_confirmation(Path(script_dir)) require_legacy_upgrade_confirmation(Path(script_dir))
logger.info(t("startup.worker_dir_set", script_dir=script_dir))
from src.main import MainSystem # noqa from src.main import MainSystem # noqa
from src.manager.async_task_manager import async_task_manager # noqa from src.manager.async_task_manager import async_task_manager # noqa
@@ -117,9 +119,6 @@ from src.manager.async_task_manager import async_task_manager # noqa
# 设置工作目录为脚本所在目录 # 设置工作目录为脚本所在目录
# script_dir = os.path.dirname(os.path.abspath(__file__)) # script_dir = os.path.dirname(os.path.abspath(__file__))
# os.chdir(script_dir) # os.chdir(script_dir)
logger.info(t("startup.worker_dir_set", script_dir=script_dir))
confirm_logger = get_logger("confirm") confirm_logger = get_logger("confirm")
# 获取没有加载env时的环境变量 # 获取没有加载env时的环境变量
env_mask = {key: os.getenv(key) for key in os.environ} env_mask = {key: os.getenv(key) for key in os.environ}

View File

@@ -4,7 +4,6 @@ import asyncio
import json import json
import pickle import pickle
import time import time
import uuid
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from pathlib import Path from pathlib import Path
@@ -19,7 +18,7 @@ from src.services.llm_service import LLMServiceClient
from ...paths import default_data_dir, resolve_repo_path from ...paths import default_data_dir, resolve_repo_path
from ..embedding import create_embedding_api_adapter from ..embedding import create_embedding_api_adapter
from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index, TemporalQueryOptions from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index
from ..storage import GraphStore, MetadataStore, QuantizationType, SparseMatrixFormat, VectorStore from ..storage import GraphStore, MetadataStore, QuantizationType, SparseMatrixFormat, VectorStore
from ..utils.aggregate_query_service import AggregateQueryService from ..utils.aggregate_query_service import AggregateQueryService
from ..utils.episode_retrieval_service import EpisodeRetrievalService from ..utils.episode_retrieval_service import EpisodeRetrievalService
@@ -634,23 +633,18 @@ class SDKMemoryKernel:
model_name=str(self._cfg("embedding.model_name", "auto") or "auto"), model_name=str(self._cfg("embedding.model_name", "auto") or "auto"),
retry_config=self._cfg("embedding.retry", {}) or {}, retry_config=self._cfg("embedding.retry", {}) or {},
) )
detected_dimension = int(await self.embedding_manager._detect_dimension()) dimension_detection_task = asyncio.create_task(
self.embedding_dimension = detected_dimension asyncio.to_thread(lambda: asyncio.run(self.embedding_manager._detect_dimension()))
)
await asyncio.sleep(0)
stored_dimension = self._stored_vector_dimension() stored_dimension = self._stored_vector_dimension()
if stored_dimension is not None and stored_dimension != detected_dimension: provisional_dimension = stored_dimension or self.embedding_dimension
raise RuntimeError(
self._vector_mismatch_error(
stored_dimension=stored_dimension,
detected_dimension=detected_dimension,
)
)
matrix_format = str(self._cfg("graph.sparse_matrix_format", "csr") or "csr").strip().lower() matrix_format = str(self._cfg("graph.sparse_matrix_format", "csr") or "csr").strip().lower()
graph_format = SparseMatrixFormat.CSC if matrix_format == "csc" else SparseMatrixFormat.CSR graph_format = SparseMatrixFormat.CSC if matrix_format == "csc" else SparseMatrixFormat.CSR
self.vector_store = VectorStore( self.vector_store = VectorStore(
dimension=detected_dimension, dimension=provisional_dimension,
quantization_type=QuantizationType.INT8, quantization_type=QuantizationType.INT8,
data_dir=self.data_dir / "vectors", data_dir=self.data_dir / "vectors",
) )
@@ -658,9 +652,11 @@ class SDKMemoryKernel:
self.metadata_store = MetadataStore(data_dir=self.data_dir / "metadata") self.metadata_store = MetadataStore(data_dir=self.data_dir / "metadata")
self.metadata_store.connect() self.metadata_store.connect()
if self.vector_store.has_data(): vector_store_loaded = False
if stored_dimension is not None and self.vector_store.has_data():
self.vector_store.load() self.vector_store.load()
self.vector_store.warmup_index(force_train=True) self.vector_store.warmup_index(force_train=True)
vector_store_loaded = True
if self.graph_store.has_data(): if self.graph_store.has_data():
self.graph_store.load() self.graph_store.load()
@@ -674,6 +670,33 @@ class SDKMemoryKernel:
if getattr(self.sparse_index.config, "enabled", False): if getattr(self.sparse_index.config, "enabled", False):
self.sparse_index.ensure_loaded() self.sparse_index.ensure_loaded()
try:
detected_dimension = int(await dimension_detection_task)
except Exception:
if not dimension_detection_task.done():
dimension_detection_task.cancel()
raise
self.embedding_dimension = detected_dimension
if stored_dimension is not None and stored_dimension != detected_dimension:
raise RuntimeError(
self._vector_mismatch_error(
stored_dimension=stored_dimension,
detected_dimension=detected_dimension,
)
)
if self.vector_store.dimension != detected_dimension:
self.vector_store = VectorStore(
dimension=detected_dimension,
quantization_type=QuantizationType.INT8,
data_dir=self.data_dir / "vectors",
)
if not vector_store_loaded and self.vector_store.has_data():
self.vector_store.load()
self.vector_store.warmup_index(force_train=True)
self.relation_write_service = RelationWriteService( self.relation_write_service = RelationWriteService(
metadata_store=self.metadata_store, metadata_store=self.metadata_store,
graph_store=self.graph_store, graph_store=self.graph_store,

View File

@@ -279,9 +279,14 @@ class AMemorixHostService:
async with self._lock: async with self._lock:
if self._kernel is None: if self._kernel is None:
config = self._read_config() config = self._read_config()
self._kernel = SDKMemoryKernel(plugin_root=repo_root(), config=config) kernel = SDKMemoryKernel(plugin_root=repo_root(), config=config)
await self._kernel.initialize() try:
set_runtime_kernel(self._kernel) await kernel.initialize()
except Exception:
kernel.close()
raise
self._kernel = kernel
set_runtime_kernel(kernel)
return self._kernel return self._kernel
def _read_config(self) -> Dict[str, Any]: def _read_config(self) -> Dict[str, Any]:

View File

@@ -1,8 +1,8 @@
from typing import Any
import aiohttp import aiohttp
import asyncio import asyncio
import certifi
import platform import platform
import ssl
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, MMC_VERSION from src.config.config import global_config, MMC_VERSION
@@ -13,11 +13,19 @@ logger = get_logger("remote")
TELEMETRY_SERVER_URL = "http://hyybuth.xyz:10058" TELEMETRY_SERVER_URL = "http://hyybuth.xyz:10058"
"""遥测服务地址""" """遥测服务地址"""
ssl_context = ssl.create_default_context(cafile=certifi.where()) _ssl_context: Any = None
async def get_tcp_connector(): async def get_tcp_connector():
return aiohttp.TCPConnector(ssl=ssl_context) global _ssl_context
if _ssl_context is None:
import certifi
import ssl
_ssl_context = ssl.create_default_context(cafile=certifi.where())
return aiohttp.TCPConnector(ssl=_ssl_context)
class TelemetryHeartBeatTask(AsyncTask): class TelemetryHeartBeatTask(AsyncTask):

View File

@@ -1,4 +1,3 @@
from maim_message import MessageServer
from rich.traceback import install from rich.traceback import install
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -13,9 +12,7 @@ from src.chat.message_receive.chat_manager import chat_manager
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
from src.common.i18n import t from src.common.i18n import t
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.message_server import get_global_api
from src.common.message_server.server import Server, get_global_server from src.common.message_server.server import Server, get_global_server
from src.common.remote import TelemetryHeartBeatTask
from src.config.config import config_manager, global_config from src.config.config import config_manager, global_config
from src.manager.async_task_manager import async_task_manager from src.manager.async_task_manager import async_task_manager
from src.maisaka.display.stage_status_board import disable_stage_status_board, enable_stage_status_board from src.maisaka.display.stage_status_board import disable_stage_status_board, enable_stage_status_board
@@ -35,12 +32,15 @@ logger = get_logger("main")
if TYPE_CHECKING: if TYPE_CHECKING:
from maim_message import MessageServer
from src.webui.webui_server import WebUIServer from src.webui.webui_server import WebUIServer
class MainSystem: class MainSystem:
def __init__(self) -> None: def __init__(self) -> None:
# 使用消息API替代直接的FastAPI实例 # 使用消息API替代直接的FastAPI实例
from src.common.message_server import get_global_api
self.app: MessageServer = get_global_api() self.app: MessageServer = get_global_api()
self.server: Server = get_global_server() self.server: Server = get_global_server()
self.webui_server: WebUIServer | None = None # 独立的 WebUI 服务器 self.webui_server: WebUIServer | None = None # 独立的 WebUI 服务器
@@ -88,6 +88,8 @@ class MainSystem:
await async_task_manager.add_task(StatisticOutputTask()) await async_task_manager.add_task(StatisticOutputTask())
# 添加遥测心跳任务 # 添加遥测心跳任务
from src.common.remote import TelemetryHeartBeatTask
await async_task_manager.add_task(TelemetryHeartBeatTask()) await async_task_manager.add_task(TelemetryHeartBeatTask())
# 添加表达方式自动检查任务 # 添加表达方式自动检查任务