feat: 添加启动绑定地址解析功能,支持从配置文件和环境变量迁移
This commit is contained in:
2
bot.py
2
bot.py
@@ -30,7 +30,7 @@ if env_path.exists():
|
||||
load_dotenv(str(env_path), override=True)
|
||||
else:
|
||||
print("[WIP] no .env file found, and templates is not ready yet.")
|
||||
raise
|
||||
print("[WIP] continue startup, use environment and existing config values.")
|
||||
# try:
|
||||
# if template_env_path.exists():
|
||||
# shutil.copyfile(template_env_path, env_path)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react'
|
||||
import { Check, ChevronsUpDown, Copy, Eye, EyeOff } from 'lucide-react'
|
||||
|
||||
import { Button } from '@/components/ui/button'
|
||||
@@ -42,8 +42,16 @@ export function ProviderForm({
|
||||
const [localProvider, setLocalProvider] = useState<APIProvider | null>(editingProvider)
|
||||
const { toast } = useToast()
|
||||
|
||||
// 同步外部状态到本地
|
||||
if (editingProvider !== localProvider && open) {
|
||||
// 当弹窗打开时,根据当前编辑对象同步一次本地编辑状态
|
||||
useEffect(() => {
|
||||
if (!open) {
|
||||
setLocalProvider(null)
|
||||
setFormErrors({})
|
||||
setShowApiKey(false)
|
||||
setSelectedTemplate('custom')
|
||||
return
|
||||
}
|
||||
|
||||
setLocalProvider(editingProvider)
|
||||
setFormErrors({})
|
||||
setShowApiKey(false)
|
||||
@@ -57,7 +65,7 @@ export function ProviderForm({
|
||||
} else {
|
||||
setSelectedTemplate('custom')
|
||||
}
|
||||
}
|
||||
}, [open, editingProvider, editingIndex])
|
||||
|
||||
const isUsingTemplate = useMemo(() => selectedTemplate !== 'custom', [selectedTemplate])
|
||||
|
||||
|
||||
104
pytests/config_test/test_startup_bindings.py
Normal file
104
pytests/config_test/test_startup_bindings.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
import sys
|
||||
|
||||
from src.config.legacy_migration import migrate_legacy_bind_env_to_bot_config_dict
|
||||
from src.config.startup_bindings import (
|
||||
BindAddress,
|
||||
get_startup_main_bind_address,
|
||||
get_startup_webui_bind_address,
|
||||
resolve_main_bind_address,
|
||||
resolve_webui_bind_address,
|
||||
)
|
||||
|
||||
|
||||
def test_startup_bindings_use_defaults_when_config_file_missing(tmp_path: Path):
|
||||
missing_path = tmp_path / "missing_bot_config.toml"
|
||||
|
||||
assert get_startup_main_bind_address(missing_path) == BindAddress(host="127.0.0.1", port=8080)
|
||||
assert get_startup_webui_bind_address(missing_path) == BindAddress(host="127.0.0.1", port=8001)
|
||||
|
||||
|
||||
def test_startup_bindings_can_read_addresses_from_bot_config(tmp_path: Path):
|
||||
config_path = tmp_path / "bot_config.toml"
|
||||
config_path.write_text(
|
||||
"""
|
||||
[inner]
|
||||
version = "8.3.1"
|
||||
|
||||
[maim_message]
|
||||
ws_server_host = "0.0.0.0"
|
||||
ws_server_port = 22345
|
||||
|
||||
[webui]
|
||||
host = "192.168.1.9"
|
||||
port = 18001
|
||||
""".strip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
assert get_startup_main_bind_address(config_path) == BindAddress(host="0.0.0.0", port=22345)
|
||||
assert get_startup_webui_bind_address(config_path) == BindAddress(host="192.168.1.9", port=18001)
|
||||
|
||||
|
||||
def test_resolve_bindings_prefer_initialized_global_config(monkeypatch):
|
||||
fake_config_module = SimpleNamespace(
|
||||
global_config=SimpleNamespace(
|
||||
maim_message=SimpleNamespace(ws_server_host="10.0.0.2", ws_server_port=32000),
|
||||
webui=SimpleNamespace(host="10.0.0.3", port=32001),
|
||||
)
|
||||
)
|
||||
|
||||
monkeypatch.setitem(sys.modules, "src.config.config", fake_config_module)
|
||||
|
||||
assert resolve_main_bind_address() == BindAddress(host="10.0.0.2", port=32000)
|
||||
assert resolve_webui_bind_address() == BindAddress(host="10.0.0.3", port=32001)
|
||||
|
||||
|
||||
def test_legacy_env_bindings_are_migrated_when_fields_missing_or_default(monkeypatch):
|
||||
monkeypatch.setenv("HOST", "0.0.0.0")
|
||||
monkeypatch.setenv("PORT", "22345")
|
||||
monkeypatch.setenv("WEBUI_HOST", "192.168.1.8")
|
||||
monkeypatch.setenv("WEBUI_PORT", "19001")
|
||||
|
||||
payload = {
|
||||
"maim_message": {
|
||||
"ws_server_host": "127.0.0.1",
|
||||
"ws_server_port": 8080,
|
||||
},
|
||||
"webui": {},
|
||||
}
|
||||
|
||||
result = migrate_legacy_bind_env_to_bot_config_dict(payload)
|
||||
|
||||
assert result.migrated is True
|
||||
assert payload["maim_message"]["ws_server_host"] == "0.0.0.0"
|
||||
assert payload["maim_message"]["ws_server_port"] == 22345
|
||||
assert payload["webui"]["host"] == "192.168.1.8"
|
||||
assert payload["webui"]["port"] == 19001
|
||||
|
||||
|
||||
def test_legacy_env_bindings_do_not_override_explicit_config(monkeypatch):
|
||||
monkeypatch.setenv("HOST", "0.0.0.0")
|
||||
monkeypatch.setenv("PORT", "22345")
|
||||
monkeypatch.setenv("WEBUI_HOST", "192.168.1.8")
|
||||
monkeypatch.setenv("WEBUI_PORT", "19001")
|
||||
|
||||
payload = {
|
||||
"maim_message": {
|
||||
"ws_server_host": "10.1.1.1",
|
||||
"ws_server_port": 30000,
|
||||
},
|
||||
"webui": {
|
||||
"host": "10.1.1.2",
|
||||
"port": 30001,
|
||||
},
|
||||
}
|
||||
|
||||
result = migrate_legacy_bind_env_to_bot_config_dict(payload)
|
||||
|
||||
assert result.migrated is False
|
||||
assert payload["maim_message"]["ws_server_host"] == "10.1.1.1"
|
||||
assert payload["maim_message"]["ws_server_port"] == 30000
|
||||
assert payload["webui"]["host"] == "10.1.1.2"
|
||||
assert payload["webui"]["port"] == 30001
|
||||
@@ -480,8 +480,20 @@ class EmojiManager:
|
||||
logger.error(f"[注册表情包] 表情包文件不存在: {emoji.full_path}")
|
||||
return False
|
||||
|
||||
# 将表情包移动到已注册目录
|
||||
target_path = EMOJI_REGISTERED_DIR / emoji.file_name
|
||||
|
||||
# 先查库,避免重复记录导致文件被误移动后无法回收
|
||||
original_path = emoji.full_path
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
existing_record = session.exec(statement).first()
|
||||
if existing_record and not existing_record.no_file_flag:
|
||||
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
|
||||
return False
|
||||
try:
|
||||
emoji.full_path.replace(target_path)
|
||||
emoji.full_path = target_path
|
||||
@@ -490,6 +502,7 @@ class EmojiManager:
|
||||
return False
|
||||
|
||||
# 注册到数据库
|
||||
restore_file = False
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
@@ -509,6 +522,7 @@ class EmojiManager:
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
|
||||
restore_file = True
|
||||
return False
|
||||
else:
|
||||
image_record = emoji.to_db_instance()
|
||||
@@ -521,7 +535,15 @@ class EmojiManager:
|
||||
logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"[注册表情包] 注册到数据库时出错: {e}")
|
||||
restore_file = True
|
||||
return False
|
||||
finally:
|
||||
if restore_file:
|
||||
try:
|
||||
emoji.full_path.replace(original_path)
|
||||
emoji.full_path = original_path
|
||||
except Exception as e:
|
||||
logger.error(f"[注册表情包] 回滚文件移动失败: {e}")
|
||||
return True
|
||||
|
||||
def delete_emoji(self, emoji: MaiEmoji, no_desc: bool = False) -> bool:
|
||||
@@ -1045,7 +1067,13 @@ class EmojiManager:
|
||||
logger.error(f"[注册表情包] 创建表情包对象时出错: {e}")
|
||||
return False
|
||||
|
||||
# 0. 先验证数据库中是否已经存在相同哈希的表情包,避免重复构建
|
||||
calc_success = await target_emoji.calculate_hash_format()
|
||||
if not calc_success:
|
||||
logger.error(f"[注册表情包] 计算表情包哈希值和格式失败: {file_full_path}")
|
||||
return False
|
||||
file_full_path = target_emoji.full_path # 更新为可能修正后的路径
|
||||
|
||||
# 2. 先验证数据库中是否已经存在相同哈希的表情包,避免重复构建
|
||||
try:
|
||||
with get_db_session_manual() as session:
|
||||
statement = (
|
||||
@@ -1068,13 +1096,7 @@ class EmojiManager:
|
||||
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
|
||||
return False
|
||||
|
||||
# 1. 计算哈希值和格式
|
||||
calc_success = await target_emoji.calculate_hash_format()
|
||||
if not calc_success:
|
||||
logger.error(f"[注册表情包] 计算表情包哈希值和格式失败: {file_full_path}")
|
||||
return False
|
||||
file_full_path = target_emoji.full_path # 更新为可能修正后的路径
|
||||
# 2. 检查是否已经存在过
|
||||
# 3. 检查内存缓存是否已经存在
|
||||
if existing_emoji := self.get_emoji_by_hash(target_emoji.file_hash):
|
||||
logger.warning(f"[注册表情包] 表情包已存在,跳过注册: {existing_emoji.file_name}")
|
||||
return False
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from maim_message import MessageServer
|
||||
|
||||
from importlib import metadata
|
||||
import traceback
|
||||
import importlib.metadata
|
||||
|
||||
from maim_message import MessageServer
|
||||
|
||||
from src.common.logger import adopt_library_logger, get_logger
|
||||
from src.common.utils.port_checker import assert_port_available
|
||||
from src.config.config import global_config
|
||||
from .server import get_global_server
|
||||
|
||||
global_api = None
|
||||
@@ -14,10 +13,12 @@ adopt_library_logger("maim_message", handler_names={"maim_message_default_handle
|
||||
|
||||
def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
"""获取全局MessageServer实例"""
|
||||
from src.config.config import global_config
|
||||
|
||||
global global_api
|
||||
if global_api is None:
|
||||
# 检查maim_message版本
|
||||
maim_message_version = importlib.metadata.version("maim_message")
|
||||
maim_message_version = metadata.version("maim_message")
|
||||
version_int = [int(x) for x in maim_message_version.split(".")]
|
||||
if version_int < [0, 6, 2]:
|
||||
raise RuntimeError("maim_message 版本过低,请升级到 0.6.2 或更高版本。")
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from typing import Optional
|
||||
|
||||
import asyncio
|
||||
|
||||
from fastapi import FastAPI, APIRouter
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from rich.traceback import install
|
||||
from typing import Optional
|
||||
from uvicorn import Config, Server as UvicornServer
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.port_checker import assert_port_available, is_port_conflict_error, log_port_conflict
|
||||
from src.config.startup_bindings import resolve_main_bind_address
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -121,11 +123,8 @@ global_server = None
|
||||
|
||||
def get_global_server() -> Server:
|
||||
"""获取全局服务器实例"""
|
||||
from src.config.config import global_config
|
||||
|
||||
global global_server
|
||||
if global_server is None:
|
||||
global_server = Server(
|
||||
host=global_config.maim_message.ws_server_host, port=global_config.maim_message.ws_server_port
|
||||
)
|
||||
bind_address = resolve_main_bind_address()
|
||||
global_server = Server(host=bind_address.host, port=bind_address.port)
|
||||
return global_server
|
||||
|
||||
@@ -12,7 +12,7 @@ import tomlkit
|
||||
from .config_base import AttributeData, ConfigBase, Field
|
||||
from .config_utils import compare_versions, output_config_changes, recursive_parse_item_to_table
|
||||
from .file_watcher import FileChange, FileWatcher
|
||||
from .legacy_migration import try_migrate_legacy_bot_config_dict
|
||||
from .legacy_migration import migrate_legacy_bind_env_to_bot_config_dict, try_migrate_legacy_bot_config_dict
|
||||
from .model_configs import APIProvider, ModelInfo, ModelTaskConfig
|
||||
from .official_configs import (
|
||||
BotConfig,
|
||||
@@ -55,7 +55,7 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config"
|
||||
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
||||
MMC_VERSION: str = "1.0.0"
|
||||
CONFIG_VERSION: str = "8.3.0"
|
||||
CONFIG_VERSION: str = "8.3.1"
|
||||
MODEL_CONFIG_VERSION: str = "1.13.1"
|
||||
|
||||
logger = get_logger("config")
|
||||
@@ -472,6 +472,11 @@ def load_config_from_file(
|
||||
old_ver: str = inner_version
|
||||
config_data.remove("inner") # 移除 inner 部分,避免干扰后续处理
|
||||
config_data = config_data.unwrap() # 转换为普通字典,方便后续处理
|
||||
if config_path.name == "bot_config.toml" and config_class.__name__ == "Config":
|
||||
env_migration = migrate_legacy_bind_env_to_bot_config_dict(config_data)
|
||||
if env_migration.migrated:
|
||||
logger.warning(f"检测到旧版环境变量绑定配置,已迁移到主配置: {env_migration.reason}")
|
||||
config_data = env_migration.data
|
||||
# 保留一份“干净”的原始数据副本,避免第一次 from_dict 过程中对 dict 的就地修改
|
||||
original_data: dict[str, Any] = copy.deepcopy(config_data)
|
||||
try:
|
||||
|
||||
@@ -14,6 +14,8 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import os
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("legacy_migration")
|
||||
@@ -38,6 +40,41 @@ def _as_list(x: Any) -> Optional[list[Any]]:
|
||||
return x if isinstance(x, list) else None
|
||||
|
||||
|
||||
def _parse_host_env(value: Any) -> Optional[str]:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
normalized_value = value.strip()
|
||||
return normalized_value or None
|
||||
|
||||
|
||||
def _parse_port_env(value: Any) -> Optional[int]:
|
||||
if isinstance(value, bool):
|
||||
return None
|
||||
|
||||
try:
|
||||
normalized_value = int(str(value).strip())
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
if normalized_value <= 0 or normalized_value > 65535:
|
||||
return None
|
||||
return normalized_value
|
||||
|
||||
|
||||
def _migrate_env_value(section: dict[str, Any], key: str, parsed_env_value: Any, default_value: Any) -> bool:
|
||||
if parsed_env_value is None:
|
||||
return False
|
||||
|
||||
current_value = section.get(key)
|
||||
if current_value == parsed_env_value:
|
||||
return False
|
||||
if key in section and current_value != default_value:
|
||||
return False
|
||||
|
||||
section[key] = parsed_env_value
|
||||
return True
|
||||
|
||||
|
||||
def _parse_triplet_target(s: str) -> Optional[dict[str, str]]:
|
||||
"""
|
||||
解析 "platform:id:type" -> {platform,item_id,rule_type}
|
||||
@@ -236,6 +273,43 @@ def _migrate_extra_prompt_list(exp: dict[str, Any], key: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def migrate_legacy_bind_env_to_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
"""将旧版环境变量中的绑定地址迁移到主配置结构。"""
|
||||
|
||||
migrated_any = False
|
||||
reasons: list[str] = []
|
||||
|
||||
main_host_env = _parse_host_env(os.getenv("HOST"))
|
||||
main_port_env = _parse_port_env(os.getenv("PORT"))
|
||||
maim_message = _as_dict(data.get("maim_message"))
|
||||
if maim_message is None and (main_host_env is not None or main_port_env is not None):
|
||||
maim_message = {}
|
||||
data["maim_message"] = maim_message
|
||||
|
||||
if maim_message is not None and _migrate_env_value(maim_message, "ws_server_host", main_host_env, "127.0.0.1"):
|
||||
migrated_any = True
|
||||
reasons.append("HOST->maim_message.ws_server_host")
|
||||
if maim_message is not None and _migrate_env_value(maim_message, "ws_server_port", main_port_env, 8080):
|
||||
migrated_any = True
|
||||
reasons.append("PORT->maim_message.ws_server_port")
|
||||
|
||||
webui_host_env = _parse_host_env(os.getenv("WEBUI_HOST"))
|
||||
webui_port_env = _parse_port_env(os.getenv("WEBUI_PORT"))
|
||||
webui = _as_dict(data.get("webui"))
|
||||
if webui is None and (webui_host_env is not None or webui_port_env is not None):
|
||||
webui = {}
|
||||
data["webui"] = webui
|
||||
|
||||
if webui is not None and _migrate_env_value(webui, "host", webui_host_env, "127.0.0.1"):
|
||||
migrated_any = True
|
||||
reasons.append("WEBUI_HOST->webui.host")
|
||||
if webui is not None and _migrate_env_value(webui, "port", webui_port_env, 8001):
|
||||
migrated_any = True
|
||||
reasons.append("WEBUI_PORT->webui.port")
|
||||
|
||||
return MigrationResult(data=data, migrated=migrated_any, reason=",".join(reasons))
|
||||
|
||||
|
||||
def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
"""
|
||||
尝试对“总配置 bot_config.toml”的 dict(已 unwrap)进行旧格式修复。
|
||||
|
||||
@@ -1414,6 +1414,24 @@ class WebUIConfig(ConfigBase):
|
||||
)
|
||||
"""是否启用WebUI"""
|
||||
|
||||
host: str = Field(
|
||||
default="127.0.0.1",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "globe",
|
||||
},
|
||||
)
|
||||
"""WebUI 绑定主机地址"""
|
||||
|
||||
port: int = Field(
|
||||
default=8001,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "hash",
|
||||
},
|
||||
)
|
||||
"""WebUI 绑定端口"""
|
||||
|
||||
mode: Literal["development", "production"] = Field(
|
||||
default="production",
|
||||
json_schema_extra={
|
||||
|
||||
135
src/config/startup_bindings.py
Normal file
135
src/config/startup_bindings.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
import sys
|
||||
|
||||
import tomlkit
|
||||
|
||||
|
||||
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
|
||||
CONFIG_DIR: Path = PROJECT_ROOT / "config"
|
||||
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BindAddress:
|
||||
"""启动阶段使用的绑定地址。"""
|
||||
|
||||
host: str
|
||||
port: int
|
||||
|
||||
|
||||
_DEFAULT_MAIN_BIND_ADDRESS = BindAddress(host="127.0.0.1", port=8080)
|
||||
_DEFAULT_WEBUI_BIND_ADDRESS = BindAddress(host="127.0.0.1", port=8001)
|
||||
|
||||
|
||||
def _as_mapping(value: Any) -> Optional[Mapping[str, Any]]:
|
||||
return value if isinstance(value, Mapping) else None
|
||||
|
||||
|
||||
def _normalize_host(value: Any, default_host: str) -> str:
|
||||
if not isinstance(value, str):
|
||||
return default_host
|
||||
|
||||
normalized_host = value.strip()
|
||||
return normalized_host or default_host
|
||||
|
||||
|
||||
def _normalize_port(value: Any, default_port: int) -> int:
|
||||
if isinstance(value, bool):
|
||||
return default_port
|
||||
|
||||
try:
|
||||
normalized_port = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default_port
|
||||
|
||||
if normalized_port <= 0 or normalized_port > 65535:
|
||||
return default_port
|
||||
return normalized_port
|
||||
|
||||
|
||||
def _load_bootstrap_config_dict(config_path: Path = BOT_CONFIG_PATH) -> Dict[str, Any]:
|
||||
"""读取启动阶段需要的最小配置,不依赖完整 ConfigManager。"""
|
||||
|
||||
if not config_path.exists():
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as file_obj:
|
||||
config_data = tomlkit.load(file_obj).unwrap()
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
if not isinstance(config_data, dict):
|
||||
return {}
|
||||
return config_data
|
||||
|
||||
|
||||
def _resolve_bind_address_from_section(
|
||||
section: Mapping[str, Any],
|
||||
host_key: str,
|
||||
port_key: str,
|
||||
default_address: BindAddress,
|
||||
) -> BindAddress:
|
||||
return BindAddress(
|
||||
host=_normalize_host(section.get(host_key), default_address.host),
|
||||
port=_normalize_port(section.get(port_key), default_address.port),
|
||||
)
|
||||
|
||||
|
||||
def _get_loaded_global_config() -> Optional[Any]:
|
||||
config_module = sys.modules.get("src.config.config")
|
||||
if config_module is None:
|
||||
return None
|
||||
return getattr(config_module, "global_config", None)
|
||||
|
||||
|
||||
def get_startup_main_bind_address(config_path: Path = BOT_CONFIG_PATH) -> BindAddress:
|
||||
"""读取主程序消息服务绑定地址。"""
|
||||
|
||||
config_data = _load_bootstrap_config_dict(config_path)
|
||||
maim_message_config = _as_mapping(config_data.get("maim_message")) or {}
|
||||
return _resolve_bind_address_from_section(
|
||||
maim_message_config,
|
||||
host_key="ws_server_host",
|
||||
port_key="ws_server_port",
|
||||
default_address=_DEFAULT_MAIN_BIND_ADDRESS,
|
||||
)
|
||||
|
||||
|
||||
def get_startup_webui_bind_address(config_path: Path = BOT_CONFIG_PATH) -> BindAddress:
|
||||
"""读取 WebUI 绑定地址。"""
|
||||
|
||||
config_data = _load_bootstrap_config_dict(config_path)
|
||||
webui_config = _as_mapping(config_data.get("webui")) or {}
|
||||
return _resolve_bind_address_from_section(
|
||||
webui_config,
|
||||
host_key="host",
|
||||
port_key="port",
|
||||
default_address=_DEFAULT_WEBUI_BIND_ADDRESS,
|
||||
)
|
||||
|
||||
|
||||
def resolve_main_bind_address(config_path: Path = BOT_CONFIG_PATH) -> BindAddress:
|
||||
"""优先读取已初始化的主配置,否则回退到启动阶段配置读取。"""
|
||||
|
||||
global_config = _get_loaded_global_config()
|
||||
if global_config is not None:
|
||||
return BindAddress(
|
||||
host=global_config.maim_message.ws_server_host,
|
||||
port=global_config.maim_message.ws_server_port,
|
||||
)
|
||||
return get_startup_main_bind_address(config_path)
|
||||
|
||||
|
||||
def resolve_webui_bind_address(config_path: Path = BOT_CONFIG_PATH) -> BindAddress:
|
||||
"""优先读取已初始化的主配置,否则回退到启动阶段配置读取。"""
|
||||
|
||||
global_config = _get_loaded_global_config()
|
||||
if global_config is not None:
|
||||
return BindAddress(
|
||||
host=global_config.webui.host,
|
||||
port=global_config.webui.port,
|
||||
)
|
||||
return get_startup_webui_bind_address(config_path)
|
||||
@@ -1,18 +1,28 @@
|
||||
"""独立的 WebUI 服务器 - 运行在 0.0.0.0:8001"""
|
||||
"""独立的 WebUI 服务器。"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from uvicorn import Config
|
||||
from uvicorn import Server as UvicornServer
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.port_checker import assert_port_available, is_port_conflict_error, log_port_conflict
|
||||
from src.config.config import config_manager
|
||||
from src.config.startup_bindings import resolve_webui_bind_address
|
||||
from src.webui.app import create_app, show_access_token
|
||||
|
||||
logger = get_logger("webui_server")
|
||||
|
||||
|
||||
def _get_loaded_config_manager() -> Optional[Any]:
|
||||
config_module = sys.modules.get("src.config.config")
|
||||
if config_module is None:
|
||||
return None
|
||||
return getattr(config_module, "config_manager", None)
|
||||
|
||||
|
||||
class _ASGIProxy:
|
||||
def __init__(self, app):
|
||||
self._app = app
|
||||
@@ -32,10 +42,33 @@ class WebUIServer:
|
||||
self.port = port
|
||||
self._app = create_app(host=host, port=port, enable_static=True)
|
||||
self.app = _ASGIProxy(self._app)
|
||||
self._server = None
|
||||
self._server: Optional[UvicornServer] = None
|
||||
self._reload_callback_registered = False
|
||||
|
||||
show_access_token()
|
||||
self._maybe_register_reload_callback()
|
||||
|
||||
def _maybe_register_reload_callback(self) -> None:
|
||||
if self._reload_callback_registered:
|
||||
return
|
||||
|
||||
config_manager = _get_loaded_config_manager()
|
||||
if config_manager is None:
|
||||
return
|
||||
|
||||
config_manager.register_reload_callback(self.reload_app)
|
||||
self._reload_callback_registered = True
|
||||
|
||||
def _maybe_unregister_reload_callback(self) -> None:
|
||||
if not self._reload_callback_registered:
|
||||
return
|
||||
|
||||
config_manager = _get_loaded_config_manager()
|
||||
if config_manager is None:
|
||||
return
|
||||
|
||||
config_manager.unregister_reload_callback(self.reload_app)
|
||||
self._reload_callback_registered = False
|
||||
|
||||
async def reload_app(self) -> None:
|
||||
self._app = create_app(host=self.host, port=self.port, enable_static=True)
|
||||
@@ -44,12 +77,13 @@ class WebUIServer:
|
||||
|
||||
async def start(self):
|
||||
"""启动服务器"""
|
||||
self._maybe_register_reload_callback()
|
||||
assert_port_available(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
service_name="WebUI 服务器",
|
||||
logger=logger,
|
||||
config_hint="WEBUI_PORT (.env)",
|
||||
config_hint="webui.port (config/bot_config.toml)",
|
||||
allow_reuse_addr=True,
|
||||
)
|
||||
|
||||
@@ -88,7 +122,7 @@ class WebUIServer:
|
||||
service_name="WebUI 服务器",
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
config_hint="WEBUI_PORT (.env)",
|
||||
config_hint="webui.port (config/bot_config.toml)",
|
||||
)
|
||||
else:
|
||||
logger.error(f"❌ WebUI 服务器启动失败 (网络错误): {e}")
|
||||
@@ -97,7 +131,7 @@ class WebUIServer:
|
||||
logger.error(f"❌ WebUI 服务器运行错误: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
config_manager.unregister_reload_callback(self.reload_app)
|
||||
self._maybe_unregister_reload_callback()
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭服务器"""
|
||||
@@ -123,10 +157,6 @@ def get_webui_server() -> WebUIServer:
|
||||
"""获取全局 WebUI 服务器实例"""
|
||||
global _webui_server
|
||||
if _webui_server is None:
|
||||
# 从环境变量读取
|
||||
import os
|
||||
|
||||
host = os.getenv("WEBUI_HOST", "127.0.0.1")
|
||||
port = int(os.getenv("WEBUI_PORT", "8001"))
|
||||
_webui_server = WebUIServer(host=host, port=port)
|
||||
bind_address = resolve_webui_bind_address()
|
||||
_webui_server = WebUIServer(host=bind_address.host, port=bind_address.port)
|
||||
return _webui_server
|
||||
|
||||
Reference in New Issue
Block a user