This commit is contained in:
SengokuCola
2025-12-18 10:53:02 +08:00
20 changed files with 385 additions and 173 deletions

View File

@@ -182,11 +182,60 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室")
return True
# 直接调用API发送消息
await get_global_api().send_message(message)
if show_log:
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
return True
# 尝试通过Legacy API发送消息如果失败尝试API Server Fallback
try:
send_result = await get_global_api().send_message(message)
if not send_result:
raise Exception("Legacy API send_message returned False (Target platform not found or connection lost)")
if show_log:
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
return True
except Exception as legacy_error:
# Fallback: 尝试使用额外的 API Server 发送
global_api = get_global_api()
extra_server = getattr(global_api, "extra_server", None)
if extra_server and extra_server.is_running():
platform = message.message_info.platform
# Fallback: 使用极其简单的 Platform -> API Key 映射
# 只有收到过该平台的消息,我们才知道该平台的 API Key才能回传消息
platform_map = getattr(global_api, "platform_map", {})
target_api_key = platform_map.get(platform)
if target_api_key:
logger.warning(f"Legacy API发送失败: {legacy_error}。使用缓存API Key通过API Server发送...")
try:
# 构造 APIMessageBase
from maim_message.message import APIMessageBase, MessageDim
msg_dim = MessageDim(api_key=target_api_key, platform=platform)
api_message = APIMessageBase(
message_info=message.message_info,
message_segment=message.message_segment,
message_dim=msg_dim
)
# 直接调用 Server 的 send_message 接口,它会自动处理路由
results = await extra_server.send_message(api_message)
# 检查是否有任何连接发送成功
if any(results.values()):
if show_log:
logger.info(f"已通过API Server Fallback将消息 '{message_preview}' 发往平台'{platform}' (key: {target_api_key})")
return True
else:
logger.error(f"API Server Fallback发送失败: 目标用户(Key={target_api_key})无活跃连接")
except Exception as fallback_error:
logger.error(f"API Server Fallback发送出错: {fallback_error}")
else:
logger.warning(f"Legacy API发送失败且无可用API Server缓存 (未收到过来自 '{platform}' 的消息无法获取API Key)")
# 如果没有fallback或fallback失败抛出原始异常
raise legacy_error
except Exception as e:
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}")

View File

@@ -15,14 +15,18 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
# 检查maim_message版本
try:
maim_message_version = importlib.metadata.version("maim_message")
version_compatible = [int(x) for x in maim_message_version.split(".")] >= [0, 3, 3]
version_int = [int(x) for x in maim_message_version.split(".")]
version_compatible = version_int >= [0, 3, 3]
# Check for API Server feature (>= 0.6.0)
has_api_server_feature = version_int >= [0, 6, 0]
except (importlib.metadata.PackageNotFoundError, ValueError):
version_compatible = False
has_api_server_feature = False
# 读取配置项
maim_message_config = global_config.maim_message
# 设置基本参数
# 设置基本参数 (Legacy Server Mode)
kwargs = {
"host": os.environ["HOST"],
"port": int(os.environ["PORT"]),
@@ -39,21 +43,129 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
if maim_message_config.auth_token and len(maim_message_config.auth_token) > 0:
kwargs["enable_token"] = True
if maim_message_config.use_custom:
# 添加WSS模式支持
del kwargs["app"]
kwargs["host"] = maim_message_config.host
kwargs["port"] = maim_message_config.port
kwargs["mode"] = maim_message_config.mode
if maim_message_config.use_wss:
if maim_message_config.cert_file:
kwargs["ssl_certfile"] = maim_message_config.cert_file
if maim_message_config.key_file:
kwargs["ssl_keyfile"] = maim_message_config.key_file
kwargs["enable_custom_uvicorn_logger"] = False
# Removed legacy custom config block (use_custom) as requested.
kwargs["enable_custom_uvicorn_logger"] = False
global_api = MessageServer(**kwargs)
if version_compatible and maim_message_config.auth_token:
for token in maim_message_config.auth_token:
global_api.add_valid_token(token)
# ---------------------------------------------------------------------
# Additional API Server Configuration (maim_message >= 6.0)
# ---------------------------------------------------------------------
enable_api_server = maim_message_config.enable_api_server
# 如果版本支持且启用了API Server则初始化额外服务器
if has_api_server_feature and enable_api_server:
try:
from maim_message.server import WebSocketServer, ServerConfig
from maim_message.message import APIMessageBase
api_logger = get_logger("maim_message_api_server")
# 1. Prepare Config
api_server_host = maim_message_config.api_server_host
api_server_port = maim_message_config.api_server_port
use_wss = maim_message_config.api_server_use_wss
server_config = ServerConfig(
host=api_server_host,
port=api_server_port,
ssl_enabled=use_wss,
ssl_certfile=maim_message_config.api_server_cert_file if use_wss else None,
ssl_keyfile=maim_message_config.api_server_key_file if use_wss else None,
)
# 2. Setup Auth Handler
async def auth_handler(metadata: dict) -> bool:
allowed_keys = maim_message_config.api_server_allowed_api_keys
# If list is empty/None, allow all (default behavior of returning True)
if not allowed_keys:
return True
api_key = metadata.get("api_key")
if api_key in allowed_keys:
return True
api_logger.warning(f"Rejected connection with invalid API Key: {api_key}")
return False
server_config.on_auth = auth_handler
# 3. Setup Message Bridge
# Initialize refined route map if not exists
if not hasattr(global_api, "platform_map"):
global_api.platform_map = {}
async def bridge_message_handler(message: APIMessageBase, metadata: dict):
# Bridge message to the main bot logic
# We convert APIMessageBase to dict to be compatible with legacy handlers
# that MainBot (ChatManager) expects.
msg_dict = message.to_dict()
# Compatibility Layer: Flatten sender_info to top-level user_info/group_info
# Legacy MessageBase expects message_info to have user_info and group_info directly.
if "message_info" in msg_dict:
msg_info = msg_dict["message_info"]
sender_info = msg_info.get("sender_info")
if sender_info:
# If direct user_info/group_info are missing, populate them from sender_info
if "user_info" not in msg_info and (ui := sender_info.get("user_info")):
msg_info["user_info"] = ui
if "group_info" not in msg_info and (gi := sender_info.get("group_info")):
msg_info["group_info"] = gi
# Route Caching Logic: Simply map platform to API Key
# This allows us to send messages back to the correct API client for this platform
try:
api_key = metadata.get("api_key")
if api_key:
platform = msg_info.get("platform")
if platform:
global_api.platform_map[platform] = api_key
except Exception as e:
api_logger.warning(f"Failed to update platform map: {e}")
# Compatibility Layer: Ensure raw_message exists (even if None) as it's part of MessageBase
if "raw_message" not in msg_dict:
msg_dict["raw_message"] = None
await global_api.process_message(msg_dict)
server_config.on_message = bridge_message_handler
# 4. Initialize Server
extra_server = WebSocketServer(config=server_config)
# 5. Patch global_api lifecycle methods to manage both servers
original_run = global_api.run
original_stop = global_api.stop
async def patched_run():
api_logger.info(f"Starting Additional API Server on {api_server_host}:{api_server_port} (WSS: {use_wss})")
# Start the extra server (non-blocking start)
await extra_server.start()
# Run the original legacy server (this usually keeps running)
await original_run()
async def patched_stop():
api_logger.info("Stopping Additional API Server...")
await extra_server.stop()
await original_stop()
global_api.run = patched_run
global_api.stop = patched_stop
# Attach for reference
global_api.extra_server = extra_server
except ImportError:
get_logger("maim_message").error("Cannot import maim_message.server components. Is maim_message >= 0.6.0 installed?")
except Exception as e:
get_logger("maim_message").error(f"Failed to initialize Additional API Server: {e}")
import traceback
get_logger("maim_message").debug(traceback.format_exc())
return global_api

View File

@@ -648,29 +648,29 @@ class ExperimentalConfig(ConfigBase):
class MaimMessageConfig(ConfigBase):
"""maim_message配置类"""
use_custom: bool = False
"""是否使用自定义的maim_message配置"""
host: str = "127.0.0.1"
"""主机地址"""
port: int = 8090
""""端口号"""
mode: Literal["ws", "tcp"] = "ws"
"""连接模式支持ws和tcp"""
use_wss: bool = False
"""是否使用WSS安全连接"""
cert_file: str = ""
"""SSL证书文件路径仅在use_wss=True时有效"""
key_file: str = ""
"""SSL密钥文件路径仅在use_wss=True时有效"""
auth_token: list[str] = field(default_factory=lambda: [])
"""认证令牌用于API验证为空则不启用验证"""
"""认证令牌,用于旧版API验证为空则不启用验证"""
enable_api_server: bool = False
"""是否启用额外的新版API Server"""
api_server_host: str = "0.0.0.0"
"""新版API Server主机地址"""
api_server_port: int = 8090
"""新版API Server端口号"""
api_server_use_wss: bool = False
"""新版API Server是否启用WSS"""
api_server_cert_file: str = ""
"""新版API Server SSL证书文件路径"""
api_server_key_file: str = ""
"""新版API Server SSL密钥文件路径"""
api_server_allowed_api_keys: list[str] = field(default_factory=lambda: [])
"""新版API Server允许的API Key列表为空则允许所有连接"""
@dataclass

View File

@@ -154,6 +154,10 @@ def _parse_allowed_ips(ip_string: str) -> list:
ip_entry = ip_entry.strip() # 去除空格
if not ip_entry:
continue
# 跳过注释行(以#开头)
if ip_entry.startswith("#"):
continue
# 检查通配符格式(包含*
if "*" in ip_entry:

View File

@@ -24,17 +24,22 @@ def _is_secure_environment() -> bool:
bool: 如果应该使用 secure cookie 则返回 True
"""
# 检查环境变量
if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("true", "1", "yes"):
secure_cookie_env = os.environ.get("WEBUI_SECURE_COOKIE", "")
if secure_cookie_env.lower() in ("true", "1", "yes"):
logger.info(f"WEBUI_SECURE_COOKIE 设置为 {secure_cookie_env},启用 secure cookie")
return True
if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("false", "0", "no"):
if secure_cookie_env.lower() in ("false", "0", "no"):
logger.info(f"WEBUI_SECURE_COOKIE 设置为 {secure_cookie_env},禁用 secure cookie")
return False
# 检查是否是生产环境
env = os.environ.get("WEBUI_MODE", "").lower()
if env in ("production", "prod"):
logger.info(f"WEBUI_MODE 设置为 {env},启用 secure cookie")
return True
# 默认:开发环境不启用(因为通常是 HTTP
logger.debug(f"未设置特殊环境变量 (WEBUI_SECURE_COOKIE={secure_cookie_env}, WEBUI_MODE={env}),禁用 secure cookie")
return False
@@ -77,27 +82,53 @@ def get_current_token(
return token
def set_auth_cookie(response: Response, token: str) -> None:
def set_auth_cookie(response: Response, token: str, request: Optional[Request] = None) -> None:
"""
设置认证 Cookie
Args:
response: FastAPI Response 对象
token: 要设置的 token
request: FastAPI Request 对象(可选,用于检测协议)
"""
# 根据环境决定安全设置
# 根据环境和实际请求协议决定安全设置
is_secure = _is_secure_environment()
# 如果提供了 request检测实际使用的协议
if request:
# 检查 X-Forwarded-Proto header代理/负载均衡器)
forwarded_proto = request.headers.get("x-forwarded-proto", "").lower()
if forwarded_proto:
is_https = forwarded_proto == "https"
logger.debug(f"检测到 X-Forwarded-Proto: {forwarded_proto}, is_https={is_https}")
else:
# 检查 request.url.scheme
is_https = request.url.scheme == "https"
logger.debug(f"检测到 scheme: {request.url.scheme}, is_https={is_https}")
# 如果是 HTTP 连接,强制禁用 secure 标志
if not is_https and is_secure:
logger.warning("=" * 80)
logger.warning("检测到 HTTP 连接但环境配置要求 HTTPS (secure cookie)")
logger.warning("已自动禁用 secure 标志以允许登录,但建议修改配置:")
logger.warning("1. 在 .env 文件中设置: WEBUI_SECURE_COOKIE=false")
logger.warning("2. 如果使用反向代理,请确保正确配置 X-Forwarded-Proto 头")
logger.warning("=" * 80)
is_secure = False
# 设置 Cookie
response.set_cookie(
key=COOKIE_NAME,
value=token,
max_age=COOKIE_MAX_AGE,
httponly=True, # 防止 JS 读取,阻止 XSS 窃取
samesite="strict" if is_secure else "lax", # 生产环境使用 strict 防止 CSRF
secure=is_secure, # 生产环境强制 HTTPS
samesite="lax", # 使用 lax 以兼容更多场景(开发和生产)
secure=is_secure, # 根据实际协议决定
path="/", # 确保 Cookie 在所有路径下可用
)
logger.debug(f"已设置认证 Cookie: {token[:8]}... (secure={is_secure})")
logger.info(f"已设置认证 Cookie: {token[:8]}... (secure={is_secure}, samesite=lax, httponly=True, path=/, max_age={COOKIE_MAX_AGE})")
logger.debug(f"完整 token 前缀: {token[:20]}...")
def clear_auth_cookie(response: Response) -> None:

View File

@@ -137,8 +137,8 @@ async def verify_token(
if is_valid:
# 认证成功,重置失败计数
rate_limiter.reset_failures(request)
# 设置 HttpOnly Cookie
set_auth_cookie(response, request_body.token)
# 设置 HttpOnly Cookie(传入 request 以检测协议)
set_auth_cookie(response, request_body.token, request)
# 同时返回首次配置状态,避免额外请求
is_first_setup = token_manager.is_first_setup()
return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup)
@@ -195,23 +195,33 @@ async def check_auth_status(
"""
try:
token = None
# 记录请求信息用于调试
logger.debug(f"检查认证状态 - Cookie: {maibot_session[:20] if maibot_session else 'None'}..., Authorization: {'Present' if authorization else 'None'}")
# 优先从 Cookie 获取
if maibot_session:
token = maibot_session
logger.debug("使用 Cookie 中的 token")
# 其次从 Header 获取
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
logger.debug("使用 Header 中的 token")
if not token:
logger.debug("未找到 token返回未认证")
return {"authenticated": False}
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_valid = token_manager.verify_token(token)
logger.debug(f"Token 验证结果: {is_valid}")
if is_valid:
return {"authenticated": True}
else:
return {"authenticated": False}
except Exception:
except Exception as e:
logger.error(f"认证检查失败: {e}", exc_info=True)
return {"authenticated": False}