maim_message强制版本要求;忽略APIServer下乱七八糟的类型注解和Patch
This commit is contained in:
8
src/common/message_server/__init__.py
Normal file
8
src/common/message_server/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Maim Message - A message handling library"""
|
||||
|
||||
__version__ = "0.2.0"
|
||||
|
||||
from .api import get_global_api
|
||||
|
||||
|
||||
__all__ = ["get_global_api"]
|
||||
180
src/common/message_server/api.py
Normal file
180
src/common/message_server/api.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from maim_message import MessageServer
|
||||
|
||||
import traceback
|
||||
import importlib.metadata
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from .server import get_global_server
|
||||
|
||||
global_api = None
|
||||
|
||||
def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
"""获取全局MessageServer实例"""
|
||||
global global_api
|
||||
if global_api is None:
|
||||
# 检查maim_message版本
|
||||
maim_message_version = importlib.metadata.version("maim_message")
|
||||
version_int = [int(x) for x in maim_message_version.split(".")]
|
||||
if version_int < [0, 6, 2]:
|
||||
raise RuntimeError("maim_message 版本过低,请升级到 0.6.2 或更高版本。")
|
||||
# 读取配置项
|
||||
maim_message_config = global_config.maim_message
|
||||
|
||||
# 设置基本参数 (Legacy Server Mode)
|
||||
kwargs = {
|
||||
"host": maim_message_config.ws_server_host,
|
||||
"port": maim_message_config.ws_server_port,
|
||||
"app": get_global_server().get_app(),
|
||||
"custom_logger": get_logger("maim_message"),
|
||||
"enable_custom_uvicorn_logger": False,
|
||||
}
|
||||
|
||||
# 添加token认证
|
||||
if maim_message_config.auth_token and len(maim_message_config.auth_token) > 0:
|
||||
kwargs["enable_token"] = True
|
||||
|
||||
global_api = MessageServer(**kwargs)
|
||||
if maim_message_config.auth_token:
|
||||
for token in maim_message_config.auth_token:
|
||||
global_api.add_valid_token(token)
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Additional API Server Configuration
|
||||
# ---------------------------------------------------------------------
|
||||
enable_api_server = maim_message_config.enable_api_server
|
||||
|
||||
# 如果启用了API Server,则初始化额外服务器
|
||||
if enable_api_server:
|
||||
try:
|
||||
from maim_message.server import WebSocketServer, ServerConfig
|
||||
from maim_message.message import APIMessageBase
|
||||
|
||||
api_logger = get_logger("maim_message_api_server")
|
||||
|
||||
# 1. Prepare Config
|
||||
api_server_host = maim_message_config.api_server_host
|
||||
api_server_port = maim_message_config.api_server_port
|
||||
use_wss = maim_message_config.api_server_use_wss
|
||||
|
||||
server_config = ServerConfig(
|
||||
host=api_server_host,
|
||||
port=api_server_port,
|
||||
ssl_enabled=use_wss,
|
||||
ssl_certfile=maim_message_config.api_server_cert_file if use_wss else None,
|
||||
ssl_keyfile=maim_message_config.api_server_key_file if use_wss else None,
|
||||
custom_logger=api_logger, # 传入自定义logger
|
||||
)
|
||||
|
||||
# 2. Setup Auth Handler
|
||||
async def auth_handler(metadata: dict) -> bool:
|
||||
allowed_keys = maim_message_config.api_server_allowed_api_keys
|
||||
# If list is empty/None, allow all (default behavior of returning True)
|
||||
if not allowed_keys:
|
||||
return True
|
||||
|
||||
api_key = metadata.get("api_key")
|
||||
if api_key in allowed_keys:
|
||||
return True
|
||||
|
||||
api_logger.warning(f"Rejected connection with invalid API Key: {api_key}")
|
||||
return False
|
||||
|
||||
server_config.on_auth = auth_handler # type: ignore # maim_message库写错类型了
|
||||
|
||||
# 3. Setup Message Bridge
|
||||
# Initialize refined route map if not exists
|
||||
if not hasattr(global_api, "platform_map"):
|
||||
global_api.platform_map = {} # type: ignore # 不知道这是什么神奇写法
|
||||
|
||||
async def bridge_message_handler(message: APIMessageBase, metadata: dict):
|
||||
# 使用 MessageConverter 转换 APIMessageBase 到 Legacy MessageBase
|
||||
# 接收场景:收到从 Adapter 转发的外部消息
|
||||
# sender_info 包含消息发送者信息,需要提取到 group_info/user_info
|
||||
from maim_message import MessageConverter
|
||||
|
||||
legacy_message = MessageConverter.from_api_receive(message)
|
||||
msg_dict = legacy_message.to_dict()
|
||||
|
||||
# Compatibility Layer: Ensure format_info exists with defaults
|
||||
if "message_info" in msg_dict:
|
||||
msg_info = msg_dict["message_info"]
|
||||
# Route Caching Logic: Map platform to API Key (or connection uuid as fallback)
|
||||
# This allows us to send messages back to the correct API client for this platform
|
||||
try:
|
||||
# Get api_key from metadata, use uuid as fallback if api_key is empty
|
||||
api_key = metadata.get("api_key") or metadata.get("uuid") or "unknown"
|
||||
platform = msg_info.get("platform")
|
||||
api_logger.debug(f"Bridge received: api_key='{api_key}', platform='{platform}'")
|
||||
|
||||
if platform:
|
||||
global_api.platform_map[platform] = api_key # type: ignore
|
||||
api_logger.info(f"Updated platform_map: {platform} -> {api_key}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Failed to update platform map: {e}")
|
||||
|
||||
# Compatibility Layer: Ensure raw_message exists (even if None) as it's part of MessageBase
|
||||
if "raw_message" not in msg_dict:
|
||||
msg_dict["raw_message"] = None
|
||||
|
||||
await global_api.process_message(msg_dict) # type: ignore
|
||||
|
||||
server_config.on_message = bridge_message_handler # type: ignore # maim_message库写错类型了
|
||||
|
||||
# 3.5. Register custom message handlers (bridge to Legacy handlers)
|
||||
# message_id_echo: handles message ID echo from adapters
|
||||
# 兼容新旧两个版本的 maim_message:
|
||||
# - 旧版: handler(payload)
|
||||
# - 新版: handler(payload, metadata)
|
||||
async def custom_message_id_echo_handler(payload: dict, metadata: dict = None): # type: ignore
|
||||
# Bridge to the Legacy custom handler registered in main.py
|
||||
try:
|
||||
# The Legacy handler expects the payload format directly
|
||||
if hasattr(global_api, "_custom_message_handlers"):
|
||||
handler = global_api._custom_message_handlers.get("message_id_echo") # type: ignore # 已经不知道这是什么了
|
||||
if handler:
|
||||
await handler(payload)
|
||||
api_logger.debug(f"Processed message_id_echo: {payload}")
|
||||
else:
|
||||
api_logger.debug(f"No handler for message_id_echo, payload: {payload}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Failed to process message_id_echo: {e}")
|
||||
|
||||
server_config.register_custom_handler("message_id_echo", custom_message_id_echo_handler) # type: ignore # maim_message库写错类型了
|
||||
|
||||
# 4. Initialize Server
|
||||
extra_server = WebSocketServer(config=server_config)
|
||||
|
||||
# 5. Patch global_api lifecycle methods to manage both servers
|
||||
original_run = global_api.run
|
||||
original_stop = global_api.stop
|
||||
|
||||
async def patched_run():
|
||||
api_logger.info(
|
||||
f"Starting Additional API Server on {api_server_host}:{api_server_port} (WSS: {use_wss})"
|
||||
)
|
||||
# Start the extra server (non-blocking start)
|
||||
await extra_server.start()
|
||||
# Run the original legacy server (this usually keeps running)
|
||||
await original_run()
|
||||
|
||||
async def patched_stop():
|
||||
api_logger.info("Stopping Additional API Server...")
|
||||
await extra_server.stop()
|
||||
await original_stop()
|
||||
|
||||
global_api.run = patched_run
|
||||
global_api.stop = patched_stop
|
||||
|
||||
# Attach for reference
|
||||
global_api.extra_server = extra_server # type: ignore # 这是什么
|
||||
|
||||
except ImportError:
|
||||
get_logger("maim_message").error(
|
||||
"Cannot import maim_message.server components. Is maim_message >= 0.6.0 installed?"
|
||||
)
|
||||
except Exception as e:
|
||||
get_logger("maim_message").error(f"Failed to initialize Additional API Server: {e}")
|
||||
get_logger("maim_message").debug(traceback.format_exc())
|
||||
|
||||
return global_api
|
||||
107
src/common/message_server/server.py
Normal file
107
src/common/message_server/server.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import asyncio
|
||||
|
||||
from fastapi import FastAPI, APIRouter
|
||||
from rich.traceback import install
|
||||
from typing import Optional
|
||||
from uvicorn import Config, Server as UvicornServer
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
|
||||
self.app = FastAPI(title=app_name)
|
||||
self._host: str = "127.0.0.1"
|
||||
self._port: int = 8080
|
||||
self._server: Optional[UvicornServer] = None
|
||||
self.set_address(host, port)
|
||||
|
||||
def register_router(self, router: APIRouter, prefix: str = ""):
|
||||
"""注册路由
|
||||
|
||||
APIRouter 用于对相关的路由端点进行分组和模块化管理:
|
||||
1. 可以将相关的端点组织在一起,便于管理
|
||||
2. 支持添加统一的路由前缀
|
||||
3. 可以为一组路由添加共同的依赖项、标签等
|
||||
|
||||
示例:
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/users")
|
||||
def get_users():
|
||||
return {"users": [...]}
|
||||
|
||||
@router.post("/users")
|
||||
def create_user():
|
||||
return {"msg": "user created"}
|
||||
|
||||
# 注册路由,添加前缀 "/api/v1"
|
||||
server.register_router(router, prefix="/api/v1")
|
||||
"""
|
||||
self.app.include_router(router, prefix=prefix)
|
||||
|
||||
def set_address(self, host: Optional[str] = None, port: Optional[int] = None):
|
||||
"""设置服务器地址和端口"""
|
||||
if host:
|
||||
self._host = host
|
||||
if port:
|
||||
self._port = port
|
||||
|
||||
async def run(self):
|
||||
"""启动服务器"""
|
||||
# 禁用 uvicorn 默认日志和访问日志
|
||||
# 设置 ws_max_size 为 100MB,支持大消息(如包含多张图片的转发消息)
|
||||
config = Config(
|
||||
app=self.app,
|
||||
host=self._host,
|
||||
port=self._port,
|
||||
log_config=None,
|
||||
access_log=False,
|
||||
ws_max_size=104_857_600, # 100MB
|
||||
)
|
||||
self._server = UvicornServer(config=config)
|
||||
try:
|
||||
await self._server.serve()
|
||||
except KeyboardInterrupt:
|
||||
await self.shutdown()
|
||||
raise
|
||||
except Exception as e:
|
||||
await self.shutdown()
|
||||
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
|
||||
finally:
|
||||
await self.shutdown()
|
||||
|
||||
async def shutdown(self):
|
||||
"""安全关闭服务器"""
|
||||
if self._server:
|
||||
self._server.should_exit = True
|
||||
try:
|
||||
# 添加 3 秒超时,避免 shutdown 永久挂起
|
||||
await asyncio.wait_for(self._server.shutdown(), timeout=3.0)
|
||||
except asyncio.TimeoutError:
|
||||
# 超时就强制标记为 None,让垃圾回收处理
|
||||
pass
|
||||
except Exception:
|
||||
# 忽略其他异常
|
||||
pass
|
||||
finally:
|
||||
self._server = None
|
||||
|
||||
def get_app(self) -> FastAPI:
|
||||
"""获取 FastAPI 实例"""
|
||||
return self.app
|
||||
|
||||
|
||||
global_server = None
|
||||
|
||||
|
||||
def get_global_server() -> Server:
|
||||
"""获取全局服务器实例"""
|
||||
from src.config.config import global_config
|
||||
|
||||
global global_server
|
||||
if global_server is None:
|
||||
global_server = Server(
|
||||
host=global_config.maim_message.ws_server_host, port=global_config.maim_message.ws_server_port
|
||||
)
|
||||
return global_server
|
||||
Reference in New Issue
Block a user