maim_message强制版本要求;忽略APIServer下乱七八糟的类型注解和Patch

This commit is contained in:
UnCLAS-Prommer
2026-02-05 16:53:48 +08:00
parent 4dbe919cfe
commit 049027a48f
5 changed files with 50 additions and 73 deletions

View File

@@ -0,0 +1,8 @@
"""Maim Message - A message handling library"""
__version__ = "0.2.0"
from .api import get_global_api
__all__ = ["get_global_api"]

View File

@@ -0,0 +1,180 @@
from maim_message import MessageServer
import traceback
import importlib.metadata
from src.common.logger import get_logger
from src.config.config import global_config
from .server import get_global_server
global_api = None
def get_global_api() -> MessageServer: # sourcery skip: extract-method
"""获取全局MessageServer实例"""
global global_api
if global_api is None:
# 检查maim_message版本
maim_message_version = importlib.metadata.version("maim_message")
version_int = [int(x) for x in maim_message_version.split(".")]
if version_int < [0, 6, 2]:
raise RuntimeError("maim_message 版本过低,请升级到 0.6.2 或更高版本。")
# 读取配置项
maim_message_config = global_config.maim_message
# 设置基本参数 (Legacy Server Mode)
kwargs = {
"host": maim_message_config.ws_server_host,
"port": maim_message_config.ws_server_port,
"app": get_global_server().get_app(),
"custom_logger": get_logger("maim_message"),
"enable_custom_uvicorn_logger": False,
}
# 添加token认证
if maim_message_config.auth_token and len(maim_message_config.auth_token) > 0:
kwargs["enable_token"] = True
global_api = MessageServer(**kwargs)
if maim_message_config.auth_token:
for token in maim_message_config.auth_token:
global_api.add_valid_token(token)
# ---------------------------------------------------------------------
# Additional API Server Configuration
# ---------------------------------------------------------------------
enable_api_server = maim_message_config.enable_api_server
# 如果启用了API Server则初始化额外服务器
if enable_api_server:
try:
from maim_message.server import WebSocketServer, ServerConfig
from maim_message.message import APIMessageBase
api_logger = get_logger("maim_message_api_server")
# 1. Prepare Config
api_server_host = maim_message_config.api_server_host
api_server_port = maim_message_config.api_server_port
use_wss = maim_message_config.api_server_use_wss
server_config = ServerConfig(
host=api_server_host,
port=api_server_port,
ssl_enabled=use_wss,
ssl_certfile=maim_message_config.api_server_cert_file if use_wss else None,
ssl_keyfile=maim_message_config.api_server_key_file if use_wss else None,
custom_logger=api_logger, # 传入自定义logger
)
# 2. Setup Auth Handler
async def auth_handler(metadata: dict) -> bool:
allowed_keys = maim_message_config.api_server_allowed_api_keys
# If list is empty/None, allow all (default behavior of returning True)
if not allowed_keys:
return True
api_key = metadata.get("api_key")
if api_key in allowed_keys:
return True
api_logger.warning(f"Rejected connection with invalid API Key: {api_key}")
return False
server_config.on_auth = auth_handler # type: ignore # maim_message库写错类型了
# 3. Setup Message Bridge
# Initialize refined route map if not exists
if not hasattr(global_api, "platform_map"):
global_api.platform_map = {} # type: ignore # 不知道这是什么神奇写法
async def bridge_message_handler(message: APIMessageBase, metadata: dict):
# 使用 MessageConverter 转换 APIMessageBase 到 Legacy MessageBase
# 接收场景:收到从 Adapter 转发的外部消息
# sender_info 包含消息发送者信息,需要提取到 group_info/user_info
from maim_message import MessageConverter
legacy_message = MessageConverter.from_api_receive(message)
msg_dict = legacy_message.to_dict()
# Compatibility Layer: Ensure format_info exists with defaults
if "message_info" in msg_dict:
msg_info = msg_dict["message_info"]
# Route Caching Logic: Map platform to API Key (or connection uuid as fallback)
# This allows us to send messages back to the correct API client for this platform
try:
# Get api_key from metadata, use uuid as fallback if api_key is empty
api_key = metadata.get("api_key") or metadata.get("uuid") or "unknown"
platform = msg_info.get("platform")
api_logger.debug(f"Bridge received: api_key='{api_key}', platform='{platform}'")
if platform:
global_api.platform_map[platform] = api_key # type: ignore
api_logger.info(f"Updated platform_map: {platform} -> {api_key}")
except Exception as e:
api_logger.warning(f"Failed to update platform map: {e}")
# Compatibility Layer: Ensure raw_message exists (even if None) as it's part of MessageBase
if "raw_message" not in msg_dict:
msg_dict["raw_message"] = None
await global_api.process_message(msg_dict) # type: ignore
server_config.on_message = bridge_message_handler # type: ignore # maim_message库写错类型了
# 3.5. Register custom message handlers (bridge to Legacy handlers)
# message_id_echo: handles message ID echo from adapters
# 兼容新旧两个版本的 maim_message:
# - 旧版: handler(payload)
# - 新版: handler(payload, metadata)
async def custom_message_id_echo_handler(payload: dict, metadata: dict = None): # type: ignore
# Bridge to the Legacy custom handler registered in main.py
try:
# The Legacy handler expects the payload format directly
if hasattr(global_api, "_custom_message_handlers"):
handler = global_api._custom_message_handlers.get("message_id_echo") # type: ignore # 已经不知道这是什么了
if handler:
await handler(payload)
api_logger.debug(f"Processed message_id_echo: {payload}")
else:
api_logger.debug(f"No handler for message_id_echo, payload: {payload}")
except Exception as e:
api_logger.warning(f"Failed to process message_id_echo: {e}")
server_config.register_custom_handler("message_id_echo", custom_message_id_echo_handler) # type: ignore # maim_message库写错类型了
# 4. Initialize Server
extra_server = WebSocketServer(config=server_config)
# 5. Patch global_api lifecycle methods to manage both servers
original_run = global_api.run
original_stop = global_api.stop
async def patched_run():
api_logger.info(
f"Starting Additional API Server on {api_server_host}:{api_server_port} (WSS: {use_wss})"
)
# Start the extra server (non-blocking start)
await extra_server.start()
# Run the original legacy server (this usually keeps running)
await original_run()
async def patched_stop():
api_logger.info("Stopping Additional API Server...")
await extra_server.stop()
await original_stop()
global_api.run = patched_run
global_api.stop = patched_stop
# Attach for reference
global_api.extra_server = extra_server # type: ignore # 这是什么
except ImportError:
get_logger("maim_message").error(
"Cannot import maim_message.server components. Is maim_message >= 0.6.0 installed?"
)
except Exception as e:
get_logger("maim_message").error(f"Failed to initialize Additional API Server: {e}")
get_logger("maim_message").debug(traceback.format_exc())
return global_api

View File

@@ -0,0 +1,107 @@
import asyncio
from fastapi import FastAPI, APIRouter
from rich.traceback import install
from typing import Optional
from uvicorn import Config, Server as UvicornServer
install(extra_lines=3)
class Server:
def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
self.app = FastAPI(title=app_name)
self._host: str = "127.0.0.1"
self._port: int = 8080
self._server: Optional[UvicornServer] = None
self.set_address(host, port)
def register_router(self, router: APIRouter, prefix: str = ""):
"""注册路由
APIRouter 用于对相关的路由端点进行分组和模块化管理:
1. 可以将相关的端点组织在一起,便于管理
2. 支持添加统一的路由前缀
3. 可以为一组路由添加共同的依赖项、标签等
示例:
router = APIRouter()
@router.get("/users")
def get_users():
return {"users": [...]}
@router.post("/users")
def create_user():
return {"msg": "user created"}
# 注册路由,添加前缀 "/api/v1"
server.register_router(router, prefix="/api/v1")
"""
self.app.include_router(router, prefix=prefix)
def set_address(self, host: Optional[str] = None, port: Optional[int] = None):
"""设置服务器地址和端口"""
if host:
self._host = host
if port:
self._port = port
async def run(self):
"""启动服务器"""
# 禁用 uvicorn 默认日志和访问日志
# 设置 ws_max_size 为 100MB支持大消息如包含多张图片的转发消息
config = Config(
app=self.app,
host=self._host,
port=self._port,
log_config=None,
access_log=False,
ws_max_size=104_857_600, # 100MB
)
self._server = UvicornServer(config=config)
try:
await self._server.serve()
except KeyboardInterrupt:
await self.shutdown()
raise
except Exception as e:
await self.shutdown()
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
finally:
await self.shutdown()
async def shutdown(self):
"""安全关闭服务器"""
if self._server:
self._server.should_exit = True
try:
# 添加 3 秒超时,避免 shutdown 永久挂起
await asyncio.wait_for(self._server.shutdown(), timeout=3.0)
except asyncio.TimeoutError:
# 超时就强制标记为 None让垃圾回收处理
pass
except Exception:
# 忽略其他异常
pass
finally:
self._server = None
def get_app(self) -> FastAPI:
"""获取 FastAPI 实例"""
return self.app
global_server = None
def get_global_server() -> Server:
"""获取全局服务器实例"""
from src.config.config import global_config
global global_server
if global_server is None:
global_server = Server(
host=global_config.maim_message.ws_server_host, port=global_config.maim_message.ws_server_port
)
return global_server