diff --git a/src/common/message_server/api.py b/src/common/message_server/api.py index b280b66e..b3cd5916 100644 --- a/src/common/message_server/api.py +++ b/src/common/message_server/api.py @@ -4,6 +4,7 @@ import traceback import importlib.metadata from src.common.logger import get_logger +from src.common.utils.port_checker import assert_port_available from src.config.config import global_config from .server import get_global_server @@ -47,17 +48,23 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method # 如果启用了API Server,则初始化额外服务器 if enable_api_server: + api_logger = get_logger("maim_message_api_server") + api_server_host = maim_message_config.api_server_host + api_server_port = maim_message_config.api_server_port + use_wss = maim_message_config.api_server_use_wss + + assert_port_available( + host=api_server_host, + port=api_server_port, + service_name="Additional API Server", + logger=api_logger, + config_hint="maim_message.api_server_port (config/bot_config.toml)", + ) + try: from maim_message.server import WebSocketServer, ServerConfig from maim_message.message import APIMessageBase - api_logger = get_logger("maim_message_api_server") - - # 1. Prepare Config - api_server_host = maim_message_config.api_server_host - api_server_port = maim_message_config.api_server_port - use_wss = maim_message_config.api_server_use_wss - server_config = ServerConfig( host=api_server_host, port=api_server_port, diff --git a/src/common/message_server/server.py b/src/common/message_server/server.py index d28a48e1..77a931e5 100644 --- a/src/common/message_server/server.py +++ b/src/common/message_server/server.py @@ -5,8 +5,13 @@ from rich.traceback import install from typing import Optional from uvicorn import Config, Server as UvicornServer +from src.common.logger import get_logger +from src.common.utils.port_checker import assert_port_available, is_port_conflict_error, log_port_conflict + install(extra_lines=3) +logger = get_logger("message_server") + class Server: def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"): @@ -49,6 +54,14 @@ class Server: async def run(self): """启动服务器""" + assert_port_available( + host=self._host, + port=self._port, + service_name="消息服务器", + logger=logger, + config_hint="maim_message.ws_server_port (config/bot_config.toml)", + ) + # 禁用 uvicorn 默认日志和访问日志 # 设置 ws_max_size 为 100MB,支持大消息(如包含多张图片的转发消息) config = Config( @@ -65,6 +78,17 @@ class Server: except KeyboardInterrupt: await self.shutdown() raise + except OSError as e: + if is_port_conflict_error(e): + log_port_conflict( + logger, + service_name="消息服务器", + host=self._host, + port=self._port, + config_hint="maim_message.ws_server_port (config/bot_config.toml)", + ) + await self.shutdown() + raise RuntimeError(f"服务器运行错误: {str(e)}") from e except Exception as e: await self.shutdown() raise RuntimeError(f"服务器运行错误: {str(e)}") from e diff --git a/src/common/utils/port_checker.py b/src/common/utils/port_checker.py new file mode 100644 index 00000000..8ff3c598 --- /dev/null +++ b/src/common/utils/port_checker.py @@ -0,0 +1,79 @@ +from typing import Optional + +import socket + + +PORT_CONFLICT_ERRNOS = {48, 98, 10048} + + +def _detect_socket_family(host: str) -> socket.AddressFamily: + return socket.AF_INET6 if ":" in host else socket.AF_INET + + +def _normalize_test_host(host: str) -> str: + if host == "0.0.0.0": + return "127.0.0.1" + return "::1" if host == "::" else host + + +def is_port_conflict_error(error: OSError) -> bool: + errno = getattr(error, "errno", None) + if errno in PORT_CONFLICT_ERRNOS: + return True + + message = str(error).lower() + return "address already in use" in message or "已被占用" in message + + +def check_port_available(host: str, port: int) -> bool: + family = _detect_socket_family(host) + test_host = _normalize_test_host(host) + + try: + with socket.socket(family, socket.SOCK_STREAM) as test_socket: + test_socket.settimeout(1) + test_socket.bind((test_host, port)) + return True + except OSError: + return False + + +def build_port_conflict_message(service_name: str, host: str, port: int) -> str: + return f"{service_name} 启动失败: 端口 {port} 已被占用 (host={host})" + + +def log_port_conflict( + logger, + *, + service_name: str, + host: str, + port: int, + config_hint: Optional[str] = None, +) -> None: + logger.error(f"❌ {build_port_conflict_message(service_name=service_name, host=host, port=port)}") + logger.error(f"💡 请检查是否有其他程序正在使用端口 {port}") + if config_hint: + logger.error(f"💡 请修改配置项 {config_hint} 来更改端口") + logger.error(f"💡 Windows 用户可以运行: netstat -ano | findstr :{port}") + logger.error(f"💡 Linux/Mac 用户可以运行: lsof -i :{port}") + + +def assert_port_available( + *, + host: str, + port: int, + service_name: str, + logger, + config_hint: Optional[str] = None, +) -> None: + if check_port_available(host=host, port=port): + return + + log_port_conflict( + logger, + service_name=service_name, + host=host, + port=port, + config_hint=config_hint, + ) + raise OSError(build_port_conflict_message(service_name=service_name, host=host, port=port)) \ No newline at end of file diff --git a/src/webui/webui_server.py b/src/webui/webui_server.py index 34b2d353..95eb6546 100644 --- a/src/webui/webui_server.py +++ b/src/webui/webui_server.py @@ -5,6 +5,7 @@ from uvicorn import Config, Server as UvicornServer import asyncio from src.common.logger import get_logger +from src.common.utils.port_checker import assert_port_available, is_port_conflict_error, log_port_conflict from src.config.config import config_manager from src.webui.app import create_app, show_access_token @@ -42,15 +43,13 @@ class WebUIServer: async def start(self): """启动服务器""" - # 预先检查端口是否可用 - if not self._check_port_available(): - error_msg = f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用" - logger.error(error_msg) - logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}") - logger.error("💡 可以在 .env 文件中修改 WEBUI_PORT 来更改 WebUI 端口") - logger.error(f"💡 Windows 用户可以运行: netstat -ano | findstr :{self.port}") - logger.error(f"💡 Linux/Mac 用户可以运行: lsof -i :{self.port}") - raise OSError(f"端口 {self.port} 已被占用,无法启动 WebUI 服务器") + assert_port_available( + host=self.host, + port=self.port, + service_name="WebUI 服务器", + logger=logger, + config_hint="WEBUI_PORT (.env)", + ) config = Config( app=self.app, @@ -81,11 +80,14 @@ class WebUIServer: try: await self._server.serve() except OSError as e: - # 处理端口绑定相关的错误 - if "address already in use" in str(e).lower() or e.errno in (98, 10048): # 98: Linux, 10048: Windows - logger.error(f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用") - logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}") - logger.error("💡 可以在 .env 文件中修改 WEBUI_PORT 来更改 WebUI 端口") + if is_port_conflict_error(e): + log_port_conflict( + logger, + service_name="WebUI 服务器", + host=self.host, + port=self.port, + config_hint="WEBUI_PORT (.env)", + ) else: logger.error(f"❌ WebUI 服务器启动失败 (网络错误): {e}") raise @@ -95,29 +97,6 @@ class WebUIServer: finally: config_manager.unregister_reload_callback(self.reload_app) - def _check_port_available(self) -> bool: - """检查端口是否可用(支持 IPv4 和 IPv6)""" - import socket - - # 判断使用 IPv4 还是 IPv6 - if ":" in self.host: - # IPv6 地址 - family = socket.AF_INET6 - test_host = self.host if self.host != "::" else "::1" - else: - # IPv4 地址 - family = socket.AF_INET - test_host = self.host if self.host != "0.0.0.0" else "127.0.0.1" - - try: - with socket.socket(family, socket.SOCK_STREAM) as s: - s.settimeout(1) - # 尝试绑定端口 - s.bind((test_host, self.port)) - return True - except OSError: - return False - async def shutdown(self): """关闭服务器""" if self._server: