efactor(network): centralize port validation and enforce strict configured ports
add a shared port checker utility for availability and conflict detection migrate WebUI, message server, and additional API server to use the new module fail fast with clear error hints when a configured port is occupied (no auto-increment)
This commit is contained in:
@@ -4,6 +4,7 @@ import traceback
|
|||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.utils.port_checker import assert_port_available
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from .server import get_global_server
|
from .server import get_global_server
|
||||||
|
|
||||||
@@ -47,17 +48,23 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
|||||||
|
|
||||||
# 如果启用了API Server,则初始化额外服务器
|
# 如果启用了API Server,则初始化额外服务器
|
||||||
if enable_api_server:
|
if enable_api_server:
|
||||||
|
api_logger = get_logger("maim_message_api_server")
|
||||||
|
api_server_host = maim_message_config.api_server_host
|
||||||
|
api_server_port = maim_message_config.api_server_port
|
||||||
|
use_wss = maim_message_config.api_server_use_wss
|
||||||
|
|
||||||
|
assert_port_available(
|
||||||
|
host=api_server_host,
|
||||||
|
port=api_server_port,
|
||||||
|
service_name="Additional API Server",
|
||||||
|
logger=api_logger,
|
||||||
|
config_hint="maim_message.api_server_port (config/bot_config.toml)",
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from maim_message.server import WebSocketServer, ServerConfig
|
from maim_message.server import WebSocketServer, ServerConfig
|
||||||
from maim_message.message import APIMessageBase
|
from maim_message.message import APIMessageBase
|
||||||
|
|
||||||
api_logger = get_logger("maim_message_api_server")
|
|
||||||
|
|
||||||
# 1. Prepare Config
|
|
||||||
api_server_host = maim_message_config.api_server_host
|
|
||||||
api_server_port = maim_message_config.api_server_port
|
|
||||||
use_wss = maim_message_config.api_server_use_wss
|
|
||||||
|
|
||||||
server_config = ServerConfig(
|
server_config = ServerConfig(
|
||||||
host=api_server_host,
|
host=api_server_host,
|
||||||
port=api_server_port,
|
port=api_server_port,
|
||||||
|
|||||||
@@ -5,8 +5,13 @@ from rich.traceback import install
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from uvicorn import Config, Server as UvicornServer
|
from uvicorn import Config, Server as UvicornServer
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.common.utils.port_checker import assert_port_available, is_port_conflict_error, log_port_conflict
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
logger = get_logger("message_server")
|
||||||
|
|
||||||
|
|
||||||
class Server:
|
class Server:
|
||||||
def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
|
def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
|
||||||
@@ -49,6 +54,14 @@ class Server:
|
|||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""启动服务器"""
|
"""启动服务器"""
|
||||||
|
assert_port_available(
|
||||||
|
host=self._host,
|
||||||
|
port=self._port,
|
||||||
|
service_name="消息服务器",
|
||||||
|
logger=logger,
|
||||||
|
config_hint="maim_message.ws_server_port (config/bot_config.toml)",
|
||||||
|
)
|
||||||
|
|
||||||
# 禁用 uvicorn 默认日志和访问日志
|
# 禁用 uvicorn 默认日志和访问日志
|
||||||
# 设置 ws_max_size 为 100MB,支持大消息(如包含多张图片的转发消息)
|
# 设置 ws_max_size 为 100MB,支持大消息(如包含多张图片的转发消息)
|
||||||
config = Config(
|
config = Config(
|
||||||
@@ -65,6 +78,17 @@ class Server:
|
|||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
await self.shutdown()
|
await self.shutdown()
|
||||||
raise
|
raise
|
||||||
|
except OSError as e:
|
||||||
|
if is_port_conflict_error(e):
|
||||||
|
log_port_conflict(
|
||||||
|
logger,
|
||||||
|
service_name="消息服务器",
|
||||||
|
host=self._host,
|
||||||
|
port=self._port,
|
||||||
|
config_hint="maim_message.ws_server_port (config/bot_config.toml)",
|
||||||
|
)
|
||||||
|
await self.shutdown()
|
||||||
|
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self.shutdown()
|
await self.shutdown()
|
||||||
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
|
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
|
||||||
|
|||||||
79
src/common/utils/port_checker.py
Normal file
79
src/common/utils/port_checker.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import socket
|
||||||
|
|
||||||
|
|
||||||
|
PORT_CONFLICT_ERRNOS = {48, 98, 10048}
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_socket_family(host: str) -> socket.AddressFamily:
|
||||||
|
return socket.AF_INET6 if ":" in host else socket.AF_INET
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_test_host(host: str) -> str:
|
||||||
|
if host == "0.0.0.0":
|
||||||
|
return "127.0.0.1"
|
||||||
|
return "::1" if host == "::" else host
|
||||||
|
|
||||||
|
|
||||||
|
def is_port_conflict_error(error: OSError) -> bool:
|
||||||
|
errno = getattr(error, "errno", None)
|
||||||
|
if errno in PORT_CONFLICT_ERRNOS:
|
||||||
|
return True
|
||||||
|
|
||||||
|
message = str(error).lower()
|
||||||
|
return "address already in use" in message or "已被占用" in message
|
||||||
|
|
||||||
|
|
||||||
|
def check_port_available(host: str, port: int) -> bool:
|
||||||
|
family = _detect_socket_family(host)
|
||||||
|
test_host = _normalize_test_host(host)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with socket.socket(family, socket.SOCK_STREAM) as test_socket:
|
||||||
|
test_socket.settimeout(1)
|
||||||
|
test_socket.bind((test_host, port))
|
||||||
|
return True
|
||||||
|
except OSError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def build_port_conflict_message(service_name: str, host: str, port: int) -> str:
|
||||||
|
return f"{service_name} 启动失败: 端口 {port} 已被占用 (host={host})"
|
||||||
|
|
||||||
|
|
||||||
|
def log_port_conflict(
|
||||||
|
logger,
|
||||||
|
*,
|
||||||
|
service_name: str,
|
||||||
|
host: str,
|
||||||
|
port: int,
|
||||||
|
config_hint: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
logger.error(f"❌ {build_port_conflict_message(service_name=service_name, host=host, port=port)}")
|
||||||
|
logger.error(f"💡 请检查是否有其他程序正在使用端口 {port}")
|
||||||
|
if config_hint:
|
||||||
|
logger.error(f"💡 请修改配置项 {config_hint} 来更改端口")
|
||||||
|
logger.error(f"💡 Windows 用户可以运行: netstat -ano | findstr :{port}")
|
||||||
|
logger.error(f"💡 Linux/Mac 用户可以运行: lsof -i :{port}")
|
||||||
|
|
||||||
|
|
||||||
|
def assert_port_available(
|
||||||
|
*,
|
||||||
|
host: str,
|
||||||
|
port: int,
|
||||||
|
service_name: str,
|
||||||
|
logger,
|
||||||
|
config_hint: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
if check_port_available(host=host, port=port):
|
||||||
|
return
|
||||||
|
|
||||||
|
log_port_conflict(
|
||||||
|
logger,
|
||||||
|
service_name=service_name,
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
config_hint=config_hint,
|
||||||
|
)
|
||||||
|
raise OSError(build_port_conflict_message(service_name=service_name, host=host, port=port))
|
||||||
@@ -5,6 +5,7 @@ from uvicorn import Config, Server as UvicornServer
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.utils.port_checker import assert_port_available, is_port_conflict_error, log_port_conflict
|
||||||
from src.config.config import config_manager
|
from src.config.config import config_manager
|
||||||
from src.webui.app import create_app, show_access_token
|
from src.webui.app import create_app, show_access_token
|
||||||
|
|
||||||
@@ -42,15 +43,13 @@ class WebUIServer:
|
|||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
"""启动服务器"""
|
"""启动服务器"""
|
||||||
# 预先检查端口是否可用
|
assert_port_available(
|
||||||
if not self._check_port_available():
|
host=self.host,
|
||||||
error_msg = f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用"
|
port=self.port,
|
||||||
logger.error(error_msg)
|
service_name="WebUI 服务器",
|
||||||
logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}")
|
logger=logger,
|
||||||
logger.error("💡 可以在 .env 文件中修改 WEBUI_PORT 来更改 WebUI 端口")
|
config_hint="WEBUI_PORT (.env)",
|
||||||
logger.error(f"💡 Windows 用户可以运行: netstat -ano | findstr :{self.port}")
|
)
|
||||||
logger.error(f"💡 Linux/Mac 用户可以运行: lsof -i :{self.port}")
|
|
||||||
raise OSError(f"端口 {self.port} 已被占用,无法启动 WebUI 服务器")
|
|
||||||
|
|
||||||
config = Config(
|
config = Config(
|
||||||
app=self.app,
|
app=self.app,
|
||||||
@@ -81,11 +80,14 @@ class WebUIServer:
|
|||||||
try:
|
try:
|
||||||
await self._server.serve()
|
await self._server.serve()
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
# 处理端口绑定相关的错误
|
if is_port_conflict_error(e):
|
||||||
if "address already in use" in str(e).lower() or e.errno in (98, 10048): # 98: Linux, 10048: Windows
|
log_port_conflict(
|
||||||
logger.error(f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用")
|
logger,
|
||||||
logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}")
|
service_name="WebUI 服务器",
|
||||||
logger.error("💡 可以在 .env 文件中修改 WEBUI_PORT 来更改 WebUI 端口")
|
host=self.host,
|
||||||
|
port=self.port,
|
||||||
|
config_hint="WEBUI_PORT (.env)",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(f"❌ WebUI 服务器启动失败 (网络错误): {e}")
|
logger.error(f"❌ WebUI 服务器启动失败 (网络错误): {e}")
|
||||||
raise
|
raise
|
||||||
@@ -95,29 +97,6 @@ class WebUIServer:
|
|||||||
finally:
|
finally:
|
||||||
config_manager.unregister_reload_callback(self.reload_app)
|
config_manager.unregister_reload_callback(self.reload_app)
|
||||||
|
|
||||||
def _check_port_available(self) -> bool:
|
|
||||||
"""检查端口是否可用(支持 IPv4 和 IPv6)"""
|
|
||||||
import socket
|
|
||||||
|
|
||||||
# 判断使用 IPv4 还是 IPv6
|
|
||||||
if ":" in self.host:
|
|
||||||
# IPv6 地址
|
|
||||||
family = socket.AF_INET6
|
|
||||||
test_host = self.host if self.host != "::" else "::1"
|
|
||||||
else:
|
|
||||||
# IPv4 地址
|
|
||||||
family = socket.AF_INET
|
|
||||||
test_host = self.host if self.host != "0.0.0.0" else "127.0.0.1"
|
|
||||||
|
|
||||||
try:
|
|
||||||
with socket.socket(family, socket.SOCK_STREAM) as s:
|
|
||||||
s.settimeout(1)
|
|
||||||
# 尝试绑定端口
|
|
||||||
s.bind((test_host, self.port))
|
|
||||||
return True
|
|
||||||
except OSError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def shutdown(self):
|
async def shutdown(self):
|
||||||
"""关闭服务器"""
|
"""关闭服务器"""
|
||||||
if self._server:
|
if self._server:
|
||||||
|
|||||||
Reference in New Issue
Block a user