better:优化部分导入以更快的启动
This commit is contained in:
5
bot.py
5
bot.py
@@ -105,6 +105,8 @@ if os.environ.get("MAIBOT_WORKER_PROCESS") != "1":
|
||||
|
||||
require_legacy_upgrade_confirmation(Path(script_dir))
|
||||
|
||||
logger.info(t("startup.worker_dir_set", script_dir=script_dir))
|
||||
|
||||
from src.main import MainSystem # noqa
|
||||
from src.manager.async_task_manager import async_task_manager # noqa
|
||||
|
||||
@@ -117,9 +119,6 @@ from src.manager.async_task_manager import async_task_manager # noqa
|
||||
# 设置工作目录为脚本所在目录
|
||||
# script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# os.chdir(script_dir)
|
||||
logger.info(t("startup.worker_dir_set", script_dir=script_dir))
|
||||
|
||||
|
||||
confirm_logger = get_logger("confirm")
|
||||
# 获取没有加载env时的环境变量
|
||||
env_mask = {key: os.getenv(key) for key in os.environ}
|
||||
|
||||
@@ -4,7 +4,6 @@ import asyncio
|
||||
import json
|
||||
import pickle
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
@@ -19,7 +18,7 @@ from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from ...paths import default_data_dir, resolve_repo_path
|
||||
from ..embedding import create_embedding_api_adapter
|
||||
from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index, TemporalQueryOptions
|
||||
from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index
|
||||
from ..storage import GraphStore, MetadataStore, QuantizationType, SparseMatrixFormat, VectorStore
|
||||
from ..utils.aggregate_query_service import AggregateQueryService
|
||||
from ..utils.episode_retrieval_service import EpisodeRetrievalService
|
||||
@@ -634,23 +633,18 @@ class SDKMemoryKernel:
|
||||
model_name=str(self._cfg("embedding.model_name", "auto") or "auto"),
|
||||
retry_config=self._cfg("embedding.retry", {}) or {},
|
||||
)
|
||||
detected_dimension = int(await self.embedding_manager._detect_dimension())
|
||||
self.embedding_dimension = detected_dimension
|
||||
|
||||
dimension_detection_task = asyncio.create_task(
|
||||
asyncio.to_thread(lambda: asyncio.run(self.embedding_manager._detect_dimension()))
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
stored_dimension = self._stored_vector_dimension()
|
||||
if stored_dimension is not None and stored_dimension != detected_dimension:
|
||||
raise RuntimeError(
|
||||
self._vector_mismatch_error(
|
||||
stored_dimension=stored_dimension,
|
||||
detected_dimension=detected_dimension,
|
||||
)
|
||||
)
|
||||
provisional_dimension = stored_dimension or self.embedding_dimension
|
||||
|
||||
matrix_format = str(self._cfg("graph.sparse_matrix_format", "csr") or "csr").strip().lower()
|
||||
graph_format = SparseMatrixFormat.CSC if matrix_format == "csc" else SparseMatrixFormat.CSR
|
||||
|
||||
self.vector_store = VectorStore(
|
||||
dimension=detected_dimension,
|
||||
dimension=provisional_dimension,
|
||||
quantization_type=QuantizationType.INT8,
|
||||
data_dir=self.data_dir / "vectors",
|
||||
)
|
||||
@@ -658,9 +652,11 @@ class SDKMemoryKernel:
|
||||
self.metadata_store = MetadataStore(data_dir=self.data_dir / "metadata")
|
||||
self.metadata_store.connect()
|
||||
|
||||
if self.vector_store.has_data():
|
||||
vector_store_loaded = False
|
||||
if stored_dimension is not None and self.vector_store.has_data():
|
||||
self.vector_store.load()
|
||||
self.vector_store.warmup_index(force_train=True)
|
||||
vector_store_loaded = True
|
||||
if self.graph_store.has_data():
|
||||
self.graph_store.load()
|
||||
|
||||
@@ -674,6 +670,33 @@ class SDKMemoryKernel:
|
||||
if getattr(self.sparse_index.config, "enabled", False):
|
||||
self.sparse_index.ensure_loaded()
|
||||
|
||||
try:
|
||||
detected_dimension = int(await dimension_detection_task)
|
||||
except Exception:
|
||||
if not dimension_detection_task.done():
|
||||
dimension_detection_task.cancel()
|
||||
raise
|
||||
self.embedding_dimension = detected_dimension
|
||||
|
||||
if stored_dimension is not None and stored_dimension != detected_dimension:
|
||||
raise RuntimeError(
|
||||
self._vector_mismatch_error(
|
||||
stored_dimension=stored_dimension,
|
||||
detected_dimension=detected_dimension,
|
||||
)
|
||||
)
|
||||
|
||||
if self.vector_store.dimension != detected_dimension:
|
||||
self.vector_store = VectorStore(
|
||||
dimension=detected_dimension,
|
||||
quantization_type=QuantizationType.INT8,
|
||||
data_dir=self.data_dir / "vectors",
|
||||
)
|
||||
|
||||
if not vector_store_loaded and self.vector_store.has_data():
|
||||
self.vector_store.load()
|
||||
self.vector_store.warmup_index(force_train=True)
|
||||
|
||||
self.relation_write_service = RelationWriteService(
|
||||
metadata_store=self.metadata_store,
|
||||
graph_store=self.graph_store,
|
||||
|
||||
@@ -279,9 +279,14 @@ class AMemorixHostService:
|
||||
async with self._lock:
|
||||
if self._kernel is None:
|
||||
config = self._read_config()
|
||||
self._kernel = SDKMemoryKernel(plugin_root=repo_root(), config=config)
|
||||
await self._kernel.initialize()
|
||||
set_runtime_kernel(self._kernel)
|
||||
kernel = SDKMemoryKernel(plugin_root=repo_root(), config=config)
|
||||
try:
|
||||
await kernel.initialize()
|
||||
except Exception:
|
||||
kernel.close()
|
||||
raise
|
||||
self._kernel = kernel
|
||||
set_runtime_kernel(kernel)
|
||||
return self._kernel
|
||||
|
||||
def _read_config(self) -> Dict[str, Any]:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import certifi
|
||||
import platform
|
||||
import ssl
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, MMC_VERSION
|
||||
@@ -13,11 +13,19 @@ logger = get_logger("remote")
|
||||
|
||||
TELEMETRY_SERVER_URL = "http://hyybuth.xyz:10058"
|
||||
"""遥测服务地址"""
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
_ssl_context: Any = None
|
||||
|
||||
|
||||
async def get_tcp_connector():
|
||||
return aiohttp.TCPConnector(ssl=ssl_context)
|
||||
global _ssl_context
|
||||
|
||||
if _ssl_context is None:
|
||||
import certifi
|
||||
import ssl
|
||||
|
||||
_ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
|
||||
return aiohttp.TCPConnector(ssl=_ssl_context)
|
||||
|
||||
|
||||
class TelemetryHeartBeatTask(AsyncTask):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from maim_message import MessageServer
|
||||
from rich.traceback import install
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -13,9 +12,7 @@ from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||
from src.common.i18n import t
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_server import get_global_api
|
||||
from src.common.message_server.server import Server, get_global_server
|
||||
from src.common.remote import TelemetryHeartBeatTask
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.manager.async_task_manager import async_task_manager
|
||||
from src.maisaka.display.stage_status_board import disable_stage_status_board, enable_stage_status_board
|
||||
@@ -35,12 +32,15 @@ logger = get_logger("main")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from maim_message import MessageServer
|
||||
from src.webui.webui_server import WebUIServer
|
||||
|
||||
|
||||
class MainSystem:
|
||||
def __init__(self) -> None:
|
||||
# 使用消息API替代直接的FastAPI实例
|
||||
from src.common.message_server import get_global_api
|
||||
|
||||
self.app: MessageServer = get_global_api()
|
||||
self.server: Server = get_global_server()
|
||||
self.webui_server: WebUIServer | None = None # 独立的 WebUI 服务器
|
||||
@@ -88,6 +88,8 @@ class MainSystem:
|
||||
await async_task_manager.add_task(StatisticOutputTask())
|
||||
|
||||
# 添加遥测心跳任务
|
||||
from src.common.remote import TelemetryHeartBeatTask
|
||||
|
||||
await async_task_manager.add_task(TelemetryHeartBeatTask())
|
||||
|
||||
# 添加表达方式自动检查任务
|
||||
|
||||
Reference in New Issue
Block a user