feat: 添加启动绑定地址解析功能,支持从配置文件和环境变量迁移

This commit is contained in:
DrSmoothl
2026-04-04 20:13:04 +08:00
parent d87e6ec0bb
commit 2fb911a8d5
11 changed files with 437 additions and 41 deletions

2
bot.py
View File

@@ -30,7 +30,7 @@ if env_path.exists():
load_dotenv(str(env_path), override=True) load_dotenv(str(env_path), override=True)
else: else:
print("[WIP] no .env file found, and templates is not ready yet.") print("[WIP] no .env file found, and templates is not ready yet.")
raise print("[WIP] continue startup, use environment and existing config values.")
# try: # try:
# if template_env_path.exists(): # if template_env_path.exists():
# shutil.copyfile(template_env_path, env_path) # shutil.copyfile(template_env_path, env_path)

View File

@@ -1,4 +1,4 @@
import { useCallback, useMemo, useState } from 'react' import { useCallback, useEffect, useMemo, useState } from 'react'
import { Check, ChevronsUpDown, Copy, Eye, EyeOff } from 'lucide-react' import { Check, ChevronsUpDown, Copy, Eye, EyeOff } from 'lucide-react'
import { Button } from '@/components/ui/button' import { Button } from '@/components/ui/button'
@@ -42,8 +42,16 @@ export function ProviderForm({
const [localProvider, setLocalProvider] = useState<APIProvider | null>(editingProvider) const [localProvider, setLocalProvider] = useState<APIProvider | null>(editingProvider)
const { toast } = useToast() const { toast } = useToast()
// 同步外部状态到本地 // 当弹窗打开时,根据当前编辑对象同步一次本地编辑状态
if (editingProvider !== localProvider && open) { useEffect(() => {
if (!open) {
setLocalProvider(null)
setFormErrors({})
setShowApiKey(false)
setSelectedTemplate('custom')
return
}
setLocalProvider(editingProvider) setLocalProvider(editingProvider)
setFormErrors({}) setFormErrors({})
setShowApiKey(false) setShowApiKey(false)
@@ -57,7 +65,7 @@ export function ProviderForm({
} else { } else {
setSelectedTemplate('custom') setSelectedTemplate('custom')
} }
} }, [open, editingProvider, editingIndex])
const isUsingTemplate = useMemo(() => selectedTemplate !== 'custom', [selectedTemplate]) const isUsingTemplate = useMemo(() => selectedTemplate !== 'custom', [selectedTemplate])

View File

@@ -0,0 +1,104 @@
from pathlib import Path
from types import SimpleNamespace
import sys
from src.config.legacy_migration import migrate_legacy_bind_env_to_bot_config_dict
from src.config.startup_bindings import (
BindAddress,
get_startup_main_bind_address,
get_startup_webui_bind_address,
resolve_main_bind_address,
resolve_webui_bind_address,
)
def test_startup_bindings_use_defaults_when_config_file_missing(tmp_path: Path):
missing_path = tmp_path / "missing_bot_config.toml"
assert get_startup_main_bind_address(missing_path) == BindAddress(host="127.0.0.1", port=8080)
assert get_startup_webui_bind_address(missing_path) == BindAddress(host="127.0.0.1", port=8001)
def test_startup_bindings_can_read_addresses_from_bot_config(tmp_path: Path):
config_path = tmp_path / "bot_config.toml"
config_path.write_text(
"""
[inner]
version = "8.3.1"
[maim_message]
ws_server_host = "0.0.0.0"
ws_server_port = 22345
[webui]
host = "192.168.1.9"
port = 18001
""".strip(),
encoding="utf-8",
)
assert get_startup_main_bind_address(config_path) == BindAddress(host="0.0.0.0", port=22345)
assert get_startup_webui_bind_address(config_path) == BindAddress(host="192.168.1.9", port=18001)
def test_resolve_bindings_prefer_initialized_global_config(monkeypatch):
fake_config_module = SimpleNamespace(
global_config=SimpleNamespace(
maim_message=SimpleNamespace(ws_server_host="10.0.0.2", ws_server_port=32000),
webui=SimpleNamespace(host="10.0.0.3", port=32001),
)
)
monkeypatch.setitem(sys.modules, "src.config.config", fake_config_module)
assert resolve_main_bind_address() == BindAddress(host="10.0.0.2", port=32000)
assert resolve_webui_bind_address() == BindAddress(host="10.0.0.3", port=32001)
def test_legacy_env_bindings_are_migrated_when_fields_missing_or_default(monkeypatch):
monkeypatch.setenv("HOST", "0.0.0.0")
monkeypatch.setenv("PORT", "22345")
monkeypatch.setenv("WEBUI_HOST", "192.168.1.8")
monkeypatch.setenv("WEBUI_PORT", "19001")
payload = {
"maim_message": {
"ws_server_host": "127.0.0.1",
"ws_server_port": 8080,
},
"webui": {},
}
result = migrate_legacy_bind_env_to_bot_config_dict(payload)
assert result.migrated is True
assert payload["maim_message"]["ws_server_host"] == "0.0.0.0"
assert payload["maim_message"]["ws_server_port"] == 22345
assert payload["webui"]["host"] == "192.168.1.8"
assert payload["webui"]["port"] == 19001
def test_legacy_env_bindings_do_not_override_explicit_config(monkeypatch):
monkeypatch.setenv("HOST", "0.0.0.0")
monkeypatch.setenv("PORT", "22345")
monkeypatch.setenv("WEBUI_HOST", "192.168.1.8")
monkeypatch.setenv("WEBUI_PORT", "19001")
payload = {
"maim_message": {
"ws_server_host": "10.1.1.1",
"ws_server_port": 30000,
},
"webui": {
"host": "10.1.1.2",
"port": 30001,
},
}
result = migrate_legacy_bind_env_to_bot_config_dict(payload)
assert result.migrated is False
assert payload["maim_message"]["ws_server_host"] == "10.1.1.1"
assert payload["maim_message"]["ws_server_port"] == 30000
assert payload["webui"]["host"] == "10.1.1.2"
assert payload["webui"]["port"] == 30001

View File

@@ -480,8 +480,20 @@ class EmojiManager:
logger.error(f"[注册表情包] 表情包文件不存在: {emoji.full_path}") logger.error(f"[注册表情包] 表情包文件不存在: {emoji.full_path}")
return False return False
# 将表情包移动到已注册目录
target_path = EMOJI_REGISTERED_DIR / emoji.file_name target_path = EMOJI_REGISTERED_DIR / emoji.file_name
# 先查库,避免重复记录导致文件被误移动后无法回收
original_path = emoji.full_path
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
existing_record = session.exec(statement).first()
if existing_record and not existing_record.no_file_flag:
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
return False
except Exception as e:
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
return False
try: try:
emoji.full_path.replace(target_path) emoji.full_path.replace(target_path)
emoji.full_path = target_path emoji.full_path = target_path
@@ -490,6 +502,7 @@ class EmojiManager:
return False return False
# 注册到数据库 # 注册到数据库
restore_file = False
try: try:
with get_db_session() as session: with get_db_session() as session:
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1) statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
@@ -509,6 +522,7 @@ class EmojiManager:
) )
else: else:
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}") logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
restore_file = True
return False return False
else: else:
image_record = emoji.to_db_instance() image_record = emoji.to_db_instance()
@@ -521,7 +535,15 @@ class EmojiManager:
logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}") logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}")
except Exception as e: except Exception as e:
logger.error(f"[注册表情包] 注册到数据库时出错: {e}") logger.error(f"[注册表情包] 注册到数据库时出错: {e}")
restore_file = True
return False return False
finally:
if restore_file:
try:
emoji.full_path.replace(original_path)
emoji.full_path = original_path
except Exception as e:
logger.error(f"[注册表情包] 回滚文件移动失败: {e}")
return True return True
def delete_emoji(self, emoji: MaiEmoji, no_desc: bool = False) -> bool: def delete_emoji(self, emoji: MaiEmoji, no_desc: bool = False) -> bool:
@@ -1045,7 +1067,13 @@ class EmojiManager:
logger.error(f"[注册表情包] 创建表情包对象时出错: {e}") logger.error(f"[注册表情包] 创建表情包对象时出错: {e}")
return False return False
# 0. 先验证数据库中是否已经存在相同哈希的表情包,避免重复构建 calc_success = await target_emoji.calculate_hash_format()
if not calc_success:
logger.error(f"[注册表情包] 计算表情包哈希值和格式失败: {file_full_path}")
return False
file_full_path = target_emoji.full_path # 更新为可能修正后的路径
# 2. 先验证数据库中是否已经存在相同哈希的表情包,避免重复构建
try: try:
with get_db_session_manual() as session: with get_db_session_manual() as session:
statement = ( statement = (
@@ -1068,13 +1096,7 @@ class EmojiManager:
logger.error(f"[注册表情包] 查询数据库时出错: {e}") logger.error(f"[注册表情包] 查询数据库时出错: {e}")
return False return False
# 1. 计算哈希值和格式 # 3. 检查内存缓存是否已经存在
calc_success = await target_emoji.calculate_hash_format()
if not calc_success:
logger.error(f"[注册表情包] 计算表情包哈希值和格式失败: {file_full_path}")
return False
file_full_path = target_emoji.full_path # 更新为可能修正后的路径
# 2. 检查是否已经存在过
if existing_emoji := self.get_emoji_by_hash(target_emoji.file_hash): if existing_emoji := self.get_emoji_by_hash(target_emoji.file_hash):
logger.warning(f"[注册表情包] 表情包已存在,跳过注册: {existing_emoji.file_name}") logger.warning(f"[注册表情包] 表情包已存在,跳过注册: {existing_emoji.file_name}")
return False return False

View File

@@ -1,11 +1,10 @@
from maim_message import MessageServer from importlib import metadata
import traceback import traceback
import importlib.metadata
from maim_message import MessageServer
from src.common.logger import adopt_library_logger, get_logger from src.common.logger import adopt_library_logger, get_logger
from src.common.utils.port_checker import assert_port_available from src.common.utils.port_checker import assert_port_available
from src.config.config import global_config
from .server import get_global_server from .server import get_global_server
global_api = None global_api = None
@@ -14,10 +13,12 @@ adopt_library_logger("maim_message", handler_names={"maim_message_default_handle
def get_global_api() -> MessageServer: # sourcery skip: extract-method def get_global_api() -> MessageServer: # sourcery skip: extract-method
"""获取全局MessageServer实例""" """获取全局MessageServer实例"""
from src.config.config import global_config
global global_api global global_api
if global_api is None: if global_api is None:
# 检查maim_message版本 # 检查maim_message版本
maim_message_version = importlib.metadata.version("maim_message") maim_message_version = metadata.version("maim_message")
version_int = [int(x) for x in maim_message_version.split(".")] version_int = [int(x) for x in maim_message_version.split(".")]
if version_int < [0, 6, 2]: if version_int < [0, 6, 2]:
raise RuntimeError("maim_message 版本过低,请升级到 0.6.2 或更高版本。") raise RuntimeError("maim_message 版本过低,请升级到 0.6.2 或更高版本。")

View File

@@ -1,12 +1,14 @@
from typing import Optional
import asyncio import asyncio
from fastapi import FastAPI, APIRouter from fastapi import APIRouter, FastAPI
from rich.traceback import install from rich.traceback import install
from typing import Optional
from uvicorn import Config, Server as UvicornServer from uvicorn import Config, Server as UvicornServer
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.utils.port_checker import assert_port_available, is_port_conflict_error, log_port_conflict from src.common.utils.port_checker import assert_port_available, is_port_conflict_error, log_port_conflict
from src.config.startup_bindings import resolve_main_bind_address
install(extra_lines=3) install(extra_lines=3)
@@ -21,7 +23,7 @@ class Server:
self._server: Optional[UvicornServer] = None self._server: Optional[UvicornServer] = None
self.set_address(host, port) self.set_address(host, port)
def register_router(self, router: APIRouter, prefix: str = ""): def register_router(self, router: APIRouter, prefix: str = ""):
"""注册路由 """注册路由
APIRouter 用于对相关的路由端点进行分组和模块化管理: APIRouter 用于对相关的路由端点进行分组和模块化管理:
@@ -121,11 +123,8 @@ global_server = None
def get_global_server() -> Server: def get_global_server() -> Server:
"""获取全局服务器实例""" """获取全局服务器实例"""
from src.config.config import global_config
global global_server global global_server
if global_server is None: if global_server is None:
global_server = Server( bind_address = resolve_main_bind_address()
host=global_config.maim_message.ws_server_host, port=global_config.maim_message.ws_server_port global_server = Server(host=bind_address.host, port=bind_address.port)
)
return global_server return global_server

View File

@@ -12,7 +12,7 @@ import tomlkit
from .config_base import AttributeData, ConfigBase, Field from .config_base import AttributeData, ConfigBase, Field
from .config_utils import compare_versions, output_config_changes, recursive_parse_item_to_table from .config_utils import compare_versions, output_config_changes, recursive_parse_item_to_table
from .file_watcher import FileChange, FileWatcher from .file_watcher import FileChange, FileWatcher
from .legacy_migration import try_migrate_legacy_bot_config_dict from .legacy_migration import migrate_legacy_bind_env_to_bot_config_dict, try_migrate_legacy_bot_config_dict
from .model_configs import APIProvider, ModelInfo, ModelTaskConfig from .model_configs import APIProvider, ModelInfo, ModelTaskConfig
from .official_configs import ( from .official_configs import (
BotConfig, BotConfig,
@@ -55,7 +55,7 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config"
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute() BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute() MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
MMC_VERSION: str = "1.0.0" MMC_VERSION: str = "1.0.0"
CONFIG_VERSION: str = "8.3.0" CONFIG_VERSION: str = "8.3.1"
MODEL_CONFIG_VERSION: str = "1.13.1" MODEL_CONFIG_VERSION: str = "1.13.1"
logger = get_logger("config") logger = get_logger("config")
@@ -472,6 +472,11 @@ def load_config_from_file(
old_ver: str = inner_version old_ver: str = inner_version
config_data.remove("inner") # 移除 inner 部分,避免干扰后续处理 config_data.remove("inner") # 移除 inner 部分,避免干扰后续处理
config_data = config_data.unwrap() # 转换为普通字典,方便后续处理 config_data = config_data.unwrap() # 转换为普通字典,方便后续处理
if config_path.name == "bot_config.toml" and config_class.__name__ == "Config":
env_migration = migrate_legacy_bind_env_to_bot_config_dict(config_data)
if env_migration.migrated:
logger.warning(f"检测到旧版环境变量绑定配置,已迁移到主配置: {env_migration.reason}")
config_data = env_migration.data
# 保留一份“干净”的原始数据副本,避免第一次 from_dict 过程中对 dict 的就地修改 # 保留一份“干净”的原始数据副本,避免第一次 from_dict 过程中对 dict 的就地修改
original_data: dict[str, Any] = copy.deepcopy(config_data) original_data: dict[str, Any] = copy.deepcopy(config_data)
try: try:

View File

@@ -14,6 +14,8 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional from typing import Any, Optional
import os
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("legacy_migration") logger = get_logger("legacy_migration")
@@ -38,6 +40,41 @@ def _as_list(x: Any) -> Optional[list[Any]]:
return x if isinstance(x, list) else None return x if isinstance(x, list) else None
def _parse_host_env(value: Any) -> Optional[str]:
if not isinstance(value, str):
return None
normalized_value = value.strip()
return normalized_value or None
def _parse_port_env(value: Any) -> Optional[int]:
if isinstance(value, bool):
return None
try:
normalized_value = int(str(value).strip())
except (TypeError, ValueError):
return None
if normalized_value <= 0 or normalized_value > 65535:
return None
return normalized_value
def _migrate_env_value(section: dict[str, Any], key: str, parsed_env_value: Any, default_value: Any) -> bool:
if parsed_env_value is None:
return False
current_value = section.get(key)
if current_value == parsed_env_value:
return False
if key in section and current_value != default_value:
return False
section[key] = parsed_env_value
return True
def _parse_triplet_target(s: str) -> Optional[dict[str, str]]: def _parse_triplet_target(s: str) -> Optional[dict[str, str]]:
""" """
解析 "platform:id:type" -> {platform,item_id,rule_type} 解析 "platform:id:type" -> {platform,item_id,rule_type}
@@ -236,6 +273,43 @@ def _migrate_extra_prompt_list(exp: dict[str, Any], key: str) -> bool:
return True return True
def migrate_legacy_bind_env_to_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
"""将旧版环境变量中的绑定地址迁移到主配置结构。"""
migrated_any = False
reasons: list[str] = []
main_host_env = _parse_host_env(os.getenv("HOST"))
main_port_env = _parse_port_env(os.getenv("PORT"))
maim_message = _as_dict(data.get("maim_message"))
if maim_message is None and (main_host_env is not None or main_port_env is not None):
maim_message = {}
data["maim_message"] = maim_message
if maim_message is not None and _migrate_env_value(maim_message, "ws_server_host", main_host_env, "127.0.0.1"):
migrated_any = True
reasons.append("HOST->maim_message.ws_server_host")
if maim_message is not None and _migrate_env_value(maim_message, "ws_server_port", main_port_env, 8080):
migrated_any = True
reasons.append("PORT->maim_message.ws_server_port")
webui_host_env = _parse_host_env(os.getenv("WEBUI_HOST"))
webui_port_env = _parse_port_env(os.getenv("WEBUI_PORT"))
webui = _as_dict(data.get("webui"))
if webui is None and (webui_host_env is not None or webui_port_env is not None):
webui = {}
data["webui"] = webui
if webui is not None and _migrate_env_value(webui, "host", webui_host_env, "127.0.0.1"):
migrated_any = True
reasons.append("WEBUI_HOST->webui.host")
if webui is not None and _migrate_env_value(webui, "port", webui_port_env, 8001):
migrated_any = True
reasons.append("WEBUI_PORT->webui.port")
return MigrationResult(data=data, migrated=migrated_any, reason=",".join(reasons))
def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult: def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
""" """
尝试对“总配置 bot_config.toml”的 dict已 unwrap进行旧格式修复。 尝试对“总配置 bot_config.toml”的 dict已 unwrap进行旧格式修复。

View File

@@ -1414,6 +1414,24 @@ class WebUIConfig(ConfigBase):
) )
"""是否启用WebUI""" """是否启用WebUI"""
host: str = Field(
default="127.0.0.1",
json_schema_extra={
"x-widget": "input",
"x-icon": "globe",
},
)
"""WebUI 绑定主机地址"""
port: int = Field(
default=8001,
json_schema_extra={
"x-widget": "input",
"x-icon": "hash",
},
)
"""WebUI 绑定端口"""
mode: Literal["development", "production"] = Field( mode: Literal["development", "production"] = Field(
default="production", default="production",
json_schema_extra={ json_schema_extra={

View File

@@ -0,0 +1,135 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Mapping, Optional
import sys
import tomlkit
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
CONFIG_DIR: Path = PROJECT_ROOT / "config"
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
@dataclass(frozen=True)
class BindAddress:
"""启动阶段使用的绑定地址。"""
host: str
port: int
_DEFAULT_MAIN_BIND_ADDRESS = BindAddress(host="127.0.0.1", port=8080)
_DEFAULT_WEBUI_BIND_ADDRESS = BindAddress(host="127.0.0.1", port=8001)
def _as_mapping(value: Any) -> Optional[Mapping[str, Any]]:
return value if isinstance(value, Mapping) else None
def _normalize_host(value: Any, default_host: str) -> str:
if not isinstance(value, str):
return default_host
normalized_host = value.strip()
return normalized_host or default_host
def _normalize_port(value: Any, default_port: int) -> int:
if isinstance(value, bool):
return default_port
try:
normalized_port = int(value)
except (TypeError, ValueError):
return default_port
if normalized_port <= 0 or normalized_port > 65535:
return default_port
return normalized_port
def _load_bootstrap_config_dict(config_path: Path = BOT_CONFIG_PATH) -> Dict[str, Any]:
"""读取启动阶段需要的最小配置,不依赖完整 ConfigManager。"""
if not config_path.exists():
return {}
try:
with open(config_path, "r", encoding="utf-8") as file_obj:
config_data = tomlkit.load(file_obj).unwrap()
except Exception:
return {}
if not isinstance(config_data, dict):
return {}
return config_data
def _resolve_bind_address_from_section(
section: Mapping[str, Any],
host_key: str,
port_key: str,
default_address: BindAddress,
) -> BindAddress:
return BindAddress(
host=_normalize_host(section.get(host_key), default_address.host),
port=_normalize_port(section.get(port_key), default_address.port),
)
def _get_loaded_global_config() -> Optional[Any]:
config_module = sys.modules.get("src.config.config")
if config_module is None:
return None
return getattr(config_module, "global_config", None)
def get_startup_main_bind_address(config_path: Path = BOT_CONFIG_PATH) -> BindAddress:
"""读取主程序消息服务绑定地址。"""
config_data = _load_bootstrap_config_dict(config_path)
maim_message_config = _as_mapping(config_data.get("maim_message")) or {}
return _resolve_bind_address_from_section(
maim_message_config,
host_key="ws_server_host",
port_key="ws_server_port",
default_address=_DEFAULT_MAIN_BIND_ADDRESS,
)
def get_startup_webui_bind_address(config_path: Path = BOT_CONFIG_PATH) -> BindAddress:
"""读取 WebUI 绑定地址。"""
config_data = _load_bootstrap_config_dict(config_path)
webui_config = _as_mapping(config_data.get("webui")) or {}
return _resolve_bind_address_from_section(
webui_config,
host_key="host",
port_key="port",
default_address=_DEFAULT_WEBUI_BIND_ADDRESS,
)
def resolve_main_bind_address(config_path: Path = BOT_CONFIG_PATH) -> BindAddress:
"""优先读取已初始化的主配置,否则回退到启动阶段配置读取。"""
global_config = _get_loaded_global_config()
if global_config is not None:
return BindAddress(
host=global_config.maim_message.ws_server_host,
port=global_config.maim_message.ws_server_port,
)
return get_startup_main_bind_address(config_path)
def resolve_webui_bind_address(config_path: Path = BOT_CONFIG_PATH) -> BindAddress:
"""优先读取已初始化的主配置,否则回退到启动阶段配置读取。"""
global_config = _get_loaded_global_config()
if global_config is not None:
return BindAddress(
host=global_config.webui.host,
port=global_config.webui.port,
)
return get_startup_webui_bind_address(config_path)

View File

@@ -1,18 +1,28 @@
"""独立的 WebUI 服务器 - 运行在 0.0.0.0:8001""" """独立的 WebUI 服务器"""
from typing import Any, Optional
import asyncio import asyncio
import sys
from uvicorn import Config from uvicorn import Config
from uvicorn import Server as UvicornServer from uvicorn import Server as UvicornServer
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.utils.port_checker import assert_port_available, is_port_conflict_error, log_port_conflict from src.common.utils.port_checker import assert_port_available, is_port_conflict_error, log_port_conflict
from src.config.config import config_manager from src.config.startup_bindings import resolve_webui_bind_address
from src.webui.app import create_app, show_access_token from src.webui.app import create_app, show_access_token
logger = get_logger("webui_server") logger = get_logger("webui_server")
def _get_loaded_config_manager() -> Optional[Any]:
config_module = sys.modules.get("src.config.config")
if config_module is None:
return None
return getattr(config_module, "config_manager", None)
class _ASGIProxy: class _ASGIProxy:
def __init__(self, app): def __init__(self, app):
self._app = app self._app = app
@@ -32,10 +42,33 @@ class WebUIServer:
self.port = port self.port = port
self._app = create_app(host=host, port=port, enable_static=True) self._app = create_app(host=host, port=port, enable_static=True)
self.app = _ASGIProxy(self._app) self.app = _ASGIProxy(self._app)
self._server = None self._server: Optional[UvicornServer] = None
self._reload_callback_registered = False
show_access_token() show_access_token()
self._maybe_register_reload_callback()
def _maybe_register_reload_callback(self) -> None:
if self._reload_callback_registered:
return
config_manager = _get_loaded_config_manager()
if config_manager is None:
return
config_manager.register_reload_callback(self.reload_app) config_manager.register_reload_callback(self.reload_app)
self._reload_callback_registered = True
def _maybe_unregister_reload_callback(self) -> None:
if not self._reload_callback_registered:
return
config_manager = _get_loaded_config_manager()
if config_manager is None:
return
config_manager.unregister_reload_callback(self.reload_app)
self._reload_callback_registered = False
async def reload_app(self) -> None: async def reload_app(self) -> None:
self._app = create_app(host=self.host, port=self.port, enable_static=True) self._app = create_app(host=self.host, port=self.port, enable_static=True)
@@ -44,12 +77,13 @@ class WebUIServer:
async def start(self): async def start(self):
"""启动服务器""" """启动服务器"""
self._maybe_register_reload_callback()
assert_port_available( assert_port_available(
host=self.host, host=self.host,
port=self.port, port=self.port,
service_name="WebUI 服务器", service_name="WebUI 服务器",
logger=logger, logger=logger,
config_hint="WEBUI_PORT (.env)", config_hint="webui.port (config/bot_config.toml)",
allow_reuse_addr=True, allow_reuse_addr=True,
) )
@@ -88,7 +122,7 @@ class WebUIServer:
service_name="WebUI 服务器", service_name="WebUI 服务器",
host=self.host, host=self.host,
port=self.port, port=self.port,
config_hint="WEBUI_PORT (.env)", config_hint="webui.port (config/bot_config.toml)",
) )
else: else:
logger.error(f"❌ WebUI 服务器启动失败 (网络错误): {e}") logger.error(f"❌ WebUI 服务器启动失败 (网络错误): {e}")
@@ -97,7 +131,7 @@ class WebUIServer:
logger.error(f"❌ WebUI 服务器运行错误: {e}", exc_info=True) logger.error(f"❌ WebUI 服务器运行错误: {e}", exc_info=True)
raise raise
finally: finally:
config_manager.unregister_reload_callback(self.reload_app) self._maybe_unregister_reload_callback()
async def shutdown(self): async def shutdown(self):
"""关闭服务器""" """关闭服务器"""
@@ -123,10 +157,6 @@ def get_webui_server() -> WebUIServer:
"""获取全局 WebUI 服务器实例""" """获取全局 WebUI 服务器实例"""
global _webui_server global _webui_server
if _webui_server is None: if _webui_server is None:
# 从环境变量读取 bind_address = resolve_webui_bind_address()
import os _webui_server = WebUIServer(host=bind_address.host, port=bind_address.port)
host = os.getenv("WEBUI_HOST", "127.0.0.1")
port = int(os.getenv("WEBUI_PORT", "8001"))
_webui_server = WebUIServer(host=host, port=port)
return _webui_server return _webui_server