feat: add unified WebSocket connection manager and routing

- Implemented UnifiedWebSocketManager for managing WebSocket connections, including subscription handling and message sending.
- Created unified WebSocket router to handle client messages, including authentication, subscription, and chat session management.
- Added support for logging and plugin progress subscriptions.
- Enhanced error handling and response structure for WebSocket operations.
This commit is contained in:
DrSmoothl
2026-04-02 22:08:52 +08:00
parent 7d0d429640
commit 1906890b67
28 changed files with 3845 additions and 1137 deletions

View File

@@ -0,0 +1,161 @@
import { unifiedWsClient, type ConnectionStatus } from './unified-ws'
interface ChatSessionOpenPayload {
group_id?: string
group_name?: string
person_id?: string
platform?: string
user_id?: string
user_name?: string
}
type ChatSessionListener = (message: Record<string, unknown>) => void
class ChatWsClient {
private initialized = false
private listeners: Map<string, Set<ChatSessionListener>> = new Map()
private sessionPayloads: Map<string, ChatSessionOpenPayload> = new Map()
private initialize(): void {
if (this.initialized) {
return
}
unifiedWsClient.addEventListener((message) => {
if (message.domain !== 'chat' || !message.session) {
return
}
const sessionListeners = this.listeners.get(message.session)
if (!sessionListeners) {
return
}
sessionListeners.forEach((listener) => {
try {
listener(message.data)
} catch (error) {
console.error('聊天会话监听器执行失败:', error)
}
})
})
unifiedWsClient.onReconnect(() => {
void this.reopenSessions()
})
this.initialized = true
}
private async reopenSessions(): Promise<void> {
const reopenTargets = Array.from(this.sessionPayloads.entries())
for (const [sessionId, payload] of reopenTargets) {
try {
await unifiedWsClient.call({
domain: 'chat',
method: 'session.open',
session: sessionId,
data: {
...payload,
restore: true,
} as Record<string, unknown>,
})
} catch (error) {
console.error(`恢复聊天会话失败 (${sessionId}):`, error)
}
}
}
async openSession(sessionId: string, payload: ChatSessionOpenPayload): Promise<void> {
this.initialize()
this.sessionPayloads.set(sessionId, payload)
await unifiedWsClient.call({
domain: 'chat',
method: 'session.open',
session: sessionId,
data: payload as Record<string, unknown>,
})
}
async closeSession(sessionId: string): Promise<void> {
this.sessionPayloads.delete(sessionId)
if (unifiedWsClient.getStatus() !== 'connected') {
return
}
try {
await unifiedWsClient.call({
domain: 'chat',
method: 'session.close',
session: sessionId,
data: {},
})
} catch (error) {
console.warn(`关闭聊天会话失败 (${sessionId}):`, error)
}
}
async sendMessage(sessionId: string, content: string, userName: string): Promise<void> {
await unifiedWsClient.call({
domain: 'chat',
method: 'message.send',
session: sessionId,
data: {
content,
user_name: userName,
},
})
}
async updateNickname(sessionId: string, userName: string): Promise<void> {
const currentPayload = this.sessionPayloads.get(sessionId)
if (currentPayload) {
this.sessionPayloads.set(sessionId, {
...currentPayload,
user_name: userName,
})
}
await unifiedWsClient.call({
domain: 'chat',
method: 'session.update_nickname',
session: sessionId,
data: {
user_name: userName,
},
})
}
onSessionMessage(sessionId: string, listener: ChatSessionListener): () => void {
this.initialize()
const sessionListeners = this.listeners.get(sessionId) ?? new Set<ChatSessionListener>()
sessionListeners.add(listener)
this.listeners.set(sessionId, sessionListeners)
return () => {
const currentListeners = this.listeners.get(sessionId)
if (!currentListeners) {
return
}
currentListeners.delete(listener)
if (currentListeners.size === 0) {
this.listeners.delete(sessionId)
}
}
}
onConnectionChange(listener: (connected: boolean) => void): () => void {
return unifiedWsClient.onConnectionChange(listener)
}
onStatusChange(listener: (status: ConnectionStatus) => void): () => void {
return unifiedWsClient.onStatusChange(listener)
}
async restart(): Promise<void> {
await unifiedWsClient.restart()
}
}
export const chatWsClient = new ChatWsClient()

View File

@@ -1,13 +1,11 @@
/**
* 全局日志 WebSocket 管理器
* 确保整个应用只有一个 WebSocket 连接
* 确保整个应用只通过统一连接层订阅日志流
*/
import { checkAuthStatus } from './fetch-with-auth'
import { getSetting } from './settings-manager'
import { createReconnectingWebSocket } from './ws-utils'
import { getWsBaseUrl } from '@/lib/api-base'
import { unifiedWsClient } from './unified-ws'
export interface LogEntry {
id: string
@@ -17,165 +15,79 @@ export interface LogEntry {
message: string
}
type LogCallback = (log: LogEntry) => void
type LogCallback = () => void
type ConnectionCallback = (connected: boolean) => void
class LogWebSocketManager {
private wsControl: ReturnType<typeof createReconnectingWebSocket> | null = null
// 订阅者
private logCallbacks: Set<LogCallback> = new Set()
private connectionCallbacks: Set<ConnectionCallback> = new Set()
private initialized = false
private isConnected = false
// 日志缓存 - 保存所有接收到的日志
private logCache: LogEntry[] = []
private logCallbacks: Set<LogCallback> = new Set()
private subscriptionActive = false
/**
* 获取最大缓存大小(从设置读取)
*/
private getMaxCacheSize(): number {
return getSetting('logCacheSize')
}
/**
* 获取最大重连次数(从设置读取)
*/
private getMaxReconnectAttempts(): number {
return getSetting('wsMaxReconnectAttempts')
}
/**
* 获取重连间隔(从设置读取)
*/
private getReconnectInterval(): number {
return getSetting('wsReconnectInterval')
}
/**
* 获取 WebSocket URL不含 token 参数)
*/
private async getWebSocketUrl(): Promise<string> {
const wsBase = await getWsBaseUrl()
return `${wsBase}/ws/logs`
}
/**
* 连接 WebSocket会先检查登录状态
*/
async connect() {
// 检查是否在登录页面
if (window.location.pathname === '/auth') {
console.log('📡 在登录页面,跳过 WebSocket 连接')
private initialize(): void {
if (this.initialized) {
return
}
// 检查登录状态,避免未登录时尝试连接
const isAuthenticated = await checkAuthStatus()
if (!isAuthenticated) {
console.log('📡 未登录,跳过 WebSocket 连接')
return
}
unifiedWsClient.addEventListener((message) => {
if (message.domain !== 'logs') {
return
}
const wsUrl = await this.getWebSocketUrl()
if (message.event === 'snapshot') {
const entries = Array.isArray(message.data.entries)
? (message.data.entries as LogEntry[])
: []
this.logCache = entries.slice(-this.getMaxCacheSize())
this.notifyLogChange()
return
}
// 使用 ws-utils 创建 WebSocket
this.wsControl = createReconnectingWebSocket(wsUrl, {
onMessage: (data: string) => {
try {
const log: LogEntry = JSON.parse(data)
this.notifyLog(log)
} catch (error) {
console.error('解析日志消息失败:', error)
}
},
onOpen: () => {
this.isConnected = true
this.notifyConnection(true)
},
onClose: () => {
this.isConnected = false
this.notifyConnection(false)
},
onError: (error) => {
console.error('❌ WebSocket 错误:', error)
this.isConnected = false
this.notifyConnection(false)
},
heartbeatInterval: 30000,
maxRetries: this.getMaxReconnectAttempts(),
backoffBase: this.getReconnectInterval(),
maxBackoff: 30000,
if (message.event === 'entry' && message.data.entry) {
this.appendLog(message.data.entry as LogEntry)
}
})
// 启动连接
await this.wsControl.connect()
unifiedWsClient.onConnectionChange((connected) => {
this.isConnected = connected
this.notifyConnection(connected)
})
this.initialized = true
}
/**
* 断开连接
*/
disconnect() {
if (this.wsControl) {
this.wsControl.disconnect()
this.wsControl = null
}
this.isConnected = false
}
/**
* 订阅日志消息
*/
onLog(callback: LogCallback) {
this.logCallbacks.add(callback)
return () => this.logCallbacks.delete(callback)
}
/**
* 订阅连接状态
*/
onConnectionChange(callback: ConnectionCallback) {
this.connectionCallbacks.add(callback)
// 立即通知当前状态
callback(this.isConnected)
return () => this.connectionCallbacks.delete(callback)
}
/**
* 通知所有订阅者新日志
*/
private notifyLog(log: LogEntry) {
// 检查是否已存在(通过 id 去重)
private appendLog(log: LogEntry): void {
const exists = this.logCache.some(existingLog => existingLog.id === log.id)
if (!exists) {
// 添加到缓存
this.logCache.push(log)
// 限制缓存大小(动态读取配置)
const maxCacheSize = this.getMaxCacheSize()
if (this.logCache.length > maxCacheSize) {
this.logCache = this.logCache.slice(-maxCacheSize)
}
// 只有新日志才通知订阅者
this.logCallbacks.forEach(callback => {
try {
callback(log)
} catch (error) {
console.error('日志回调执行失败:', error)
}
})
if (exists) {
return
}
this.logCache.push(log)
const maxCacheSize = this.getMaxCacheSize()
if (this.logCache.length > maxCacheSize) {
this.logCache = this.logCache.slice(-maxCacheSize)
}
this.notifyLogChange()
}
/**
* 通知所有订阅者连接状态变化
*/
private notifyConnection(connected: boolean) {
this.connectionCallbacks.forEach(callback => {
private notifyLogChange(): void {
this.logCallbacks.forEach((callback) => {
try {
callback()
} catch (error) {
console.error('日志回调执行失败:', error)
}
})
}
private notifyConnection(connected: boolean): void {
this.connectionCallbacks.forEach((callback) => {
try {
callback(connected)
} catch (error) {
@@ -184,35 +96,65 @@ class LogWebSocketManager {
})
}
/**
* 获取缓存的所有日志
*/
async connect(): Promise<void> {
if (window.location.pathname === '/auth') {
return
}
const isAuthenticated = await checkAuthStatus()
if (!isAuthenticated) {
return
}
this.initialize()
if (this.subscriptionActive) {
return
}
try {
await unifiedWsClient.subscribe('logs', 'main', { replay: 100 })
this.subscriptionActive = true
} catch (error) {
console.error('订阅日志流失败:', error)
}
}
disconnect(): void {
this.subscriptionActive = false
void unifiedWsClient.unsubscribe('logs', 'main')
this.isConnected = false
this.notifyConnection(false)
}
onLog(callback: LogCallback): () => void {
this.logCallbacks.add(callback)
return () => this.logCallbacks.delete(callback)
}
onConnectionChange(callback: ConnectionCallback): () => void {
this.connectionCallbacks.add(callback)
callback(this.isConnected)
return () => this.connectionCallbacks.delete(callback)
}
getAllLogs(): LogEntry[] {
return [...this.logCache]
}
/**
* 清空日志缓存
*/
clearLogs() {
clearLogs(): void {
this.logCache = []
this.notifyLogChange()
}
/**
* 获取当前连接状态
*/
getConnectionStatus(): boolean {
return this.isConnected
}
}
// 导出单例
export const logWebSocket = new LogWebSocketManager()
// 自动连接(应用启动时)
if (typeof window !== 'undefined') {
// 延迟一下确保页面加载完成
setTimeout(() => {
logWebSocket.connect()
void logWebSocket.connect()
}, 100)
}

View File

@@ -1,9 +1,9 @@
import type { ApiResponse } from '@/types/api'
import type { PluginInfo } from '@/types/plugin'
import { getWsBaseUrl } from '@/lib/api-base'
import { fetchWithAuth } from '@/lib/fetch-with-auth'
import { parseResponse } from '@/lib/api-helpers'
import { pluginProgressClient } from '@/lib/plugin-progress-client'
import type { GitStatus, MaimaiVersion } from './types'
/**
@@ -211,41 +211,13 @@ export function isPluginCompatible(
*/
export async function connectPluginProgressWebSocket(
onProgress: (progress: import('./types').PluginLoadProgress) => void,
onError?: (error: Event) => void
): Promise<WebSocket | null> {
const wsBase = await getWsBaseUrl()
const wsUrl = `${wsBase}/api/webui/ws/plugin-progress`
// 使用 ws-utils 创建 WebSocket
const { createReconnectingWebSocket } = await import('@/lib/ws-utils')
const wsControl = createReconnectingWebSocket(wsUrl, {
onMessage: (data: string) => {
try {
const progressData = JSON.parse(data) as import('./types').PluginLoadProgress
onProgress(progressData)
} catch (error) {
console.error('Failed to parse progress data:', error)
}
},
onOpen: () => {
console.log('Plugin progress WebSocket connected')
},
onClose: () => {
console.log('Plugin progress WebSocket disconnected')
},
onError: (error) => {
console.error('Plugin progress WebSocket error:', error)
onError?.(error)
},
heartbeatInterval: 30000,
maxRetries: 10,
backoffBase: 1000,
maxBackoff: 30000,
})
// 启动连接
await wsControl.connect()
// 返回 WebSocket 实例(用于外部检查连接状态)
return wsControl.getWebSocket()
onError?: (error: Error) => void
): Promise<() => Promise<void>> {
try {
return await pluginProgressClient.subscribe(onProgress)
} catch (error) {
const normalizedError = error instanceof Error ? error : new Error('插件进度订阅失败')
onError?.(normalizedError)
return async () => {}
}
}

View File

@@ -0,0 +1,58 @@
import type { PluginLoadProgress } from '@/lib/plugin-api/types'
import { unifiedWsClient } from './unified-ws'
type ProgressListener = (progress: PluginLoadProgress) => void
class PluginProgressClient {
private initialized = false
private listeners: Set<ProgressListener> = new Set()
private subscriptionActive = false
private initialize(): void {
if (this.initialized) {
return
}
unifiedWsClient.addEventListener((message) => {
if (message.domain !== 'plugin_progress') {
return
}
const progress = message.data.progress as PluginLoadProgress | undefined
if (!progress) {
return
}
this.listeners.forEach((listener) => {
try {
listener(progress)
} catch (error) {
console.error('插件进度监听器执行失败:', error)
}
})
})
this.initialized = true
}
async subscribe(listener: ProgressListener): Promise<() => Promise<void>> {
this.initialize()
this.listeners.add(listener)
if (!this.subscriptionActive) {
await unifiedWsClient.subscribe('plugin_progress', 'main')
this.subscriptionActive = true
}
return async () => {
this.listeners.delete(listener)
if (this.listeners.size === 0 && this.subscriptionActive) {
this.subscriptionActive = false
await unifiedWsClient.unsubscribe('plugin_progress', 'main')
}
}
}
}
export const pluginProgressClient = new PluginProgressClient()

View File

@@ -0,0 +1,495 @@
import { fetchWithAuth } from './fetch-with-auth'
import { getSetting } from './settings-manager'
import { getWsBaseUrl } from '@/lib/api-base'
export type ConnectionStatus = 'idle' | 'connecting' | 'connected'
export interface WsErrorPayload {
code?: string
message: string
}
export interface WsEventEnvelope {
op: 'event'
domain: string
event: string
session?: string
topic?: string
data: Record<string, unknown>
}
interface WsResponseEnvelope {
op: 'response'
id?: string
ok: boolean
data?: Record<string, unknown>
error?: WsErrorPayload
}
interface WsPongEnvelope {
op: 'pong'
ts: number
}
type WsServerEnvelope = WsEventEnvelope | WsPongEnvelope | WsResponseEnvelope
interface PendingRequest {
reject: (error: Error) => void
resolve: (data: Record<string, unknown>) => void
timeoutId: number
}
interface SubscriptionDefinition {
data?: Record<string, unknown>
domain: string
topic: string
}
type EventListener = (message: WsEventEnvelope) => void
type ConnectionListener = (connected: boolean) => void
type StatusListener = (status: ConnectionStatus) => void
type ReconnectListener = () => void
function isResponseEnvelope(message: WsServerEnvelope): message is WsResponseEnvelope {
return message.op === 'response'
}
function isEventEnvelope(message: WsServerEnvelope): message is WsEventEnvelope {
return message.op === 'event'
}
async function getWsToken(): Promise<string | null> {
try {
const response = await fetchWithAuth('/api/webui/ws-token', {
method: 'GET',
credentials: 'include',
})
if (!response.ok) {
return null
}
const data = await response.json()
if (data.success && data.token) {
return data.token as string
}
return null
} catch (error) {
console.error('获取统一 WebSocket token 失败:', error)
return null
}
}
class UnifiedWebSocketClient {
private connectPromise: Promise<void> | null = null
private connectionListeners: Set<ConnectionListener> = new Set()
private eventListeners: Set<EventListener> = new Set()
private hasConnectedOnce = false
private heartbeatIntervalId: number | null = null
private manualDisconnect = false
private pendingRequests: Map<string, PendingRequest> = new Map()
private reconnectAttempts = 0
private reconnectListeners: Set<ReconnectListener> = new Set()
private reconnectTimeout: number | null = null
private requestCounter = 0
private status: ConnectionStatus = 'idle'
private statusListeners: Set<StatusListener> = new Set()
private subscriptions: Map<string, SubscriptionDefinition> = new Map()
private ws: WebSocket | null = null
private getReconnectDelay(): number {
const baseDelay = getSetting('wsReconnectInterval')
return Math.min(baseDelay * Math.max(this.reconnectAttempts, 1), 30000)
}
private getMaxReconnectAttempts(): number {
return getSetting('wsMaxReconnectAttempts')
}
private getSubscriptionKey(domain: string, topic: string): string {
return `${domain}:${topic}`
}
private nextRequestId(): string {
this.requestCounter += 1
return `ws-${Date.now()}-${this.requestCounter}`
}
private setStatus(status: ConnectionStatus): void {
if (this.status === status) {
return
}
this.status = status
this.statusListeners.forEach((listener) => {
try {
listener(status)
} catch (error) {
console.error('WebSocket 状态监听器执行失败:', error)
}
})
const connected = status === 'connected'
this.connectionListeners.forEach((listener) => {
try {
listener(connected)
} catch (error) {
console.error('WebSocket 连接监听器执行失败:', error)
}
})
}
private stopHeartbeat(): void {
if (this.heartbeatIntervalId !== null) {
clearInterval(this.heartbeatIntervalId)
this.heartbeatIntervalId = null
}
}
private startHeartbeat(): void {
this.stopHeartbeat()
this.heartbeatIntervalId = window.setInterval(() => {
if (this.ws?.readyState === WebSocket.OPEN) {
this.ws.send(JSON.stringify({ op: 'ping' }))
}
}, 30000)
}
private clearReconnectTimer(): void {
if (this.reconnectTimeout !== null) {
clearTimeout(this.reconnectTimeout)
this.reconnectTimeout = null
}
}
private rejectPendingRequests(error: Error): void {
this.pendingRequests.forEach((pendingRequest, requestId) => {
clearTimeout(pendingRequest.timeoutId)
pendingRequest.reject(error)
this.pendingRequests.delete(requestId)
})
}
private scheduleReconnect(): void {
if (this.manualDisconnect) {
return
}
if (this.reconnectAttempts >= this.getMaxReconnectAttempts()) {
console.warn(`统一 WebSocket 达到最大重连次数 (${this.getMaxReconnectAttempts()}),停止重连`)
return
}
this.reconnectAttempts += 1
const delay = this.getReconnectDelay()
this.clearReconnectTimer()
this.reconnectTimeout = window.setTimeout(() => {
void this.connect().catch((error) => {
console.error('统一 WebSocket 重连失败:', error)
})
}, delay)
}
private async createWebSocketUrl(): Promise<string | null> {
const wsBaseUrl = await getWsBaseUrl()
const wsToken = await getWsToken()
if (!wsBaseUrl || !wsToken) {
return null
}
return `${wsBaseUrl}/api/webui/ws?token=${encodeURIComponent(wsToken)}`
}
private async sendRequest(
payload: Record<string, unknown>,
timeoutMs = 10000,
): Promise<Record<string, unknown>> {
if (this.ws?.readyState !== WebSocket.OPEN) {
throw new Error('统一 WebSocket 尚未连接')
}
const requestId = payload.id as string
return await new Promise<Record<string, unknown>>((resolve, reject) => {
const timeoutId = window.setTimeout(() => {
this.pendingRequests.delete(requestId)
reject(new Error(`统一 WebSocket 请求超时: ${requestId}`))
}, timeoutMs)
this.pendingRequests.set(requestId, {
resolve,
reject,
timeoutId,
})
this.ws?.send(JSON.stringify(payload))
})
}
private async restoreState(shouldNotifyReconnect: boolean): Promise<void> {
const subscriptions = Array.from(this.subscriptions.values())
for (const subscription of subscriptions) {
try {
await this.sendRequest({
op: 'subscribe',
id: this.nextRequestId(),
domain: subscription.domain,
topic: subscription.topic,
data: subscription.data ?? {},
})
} catch (error) {
console.error('恢复统一 WebSocket 订阅失败:', error)
}
}
if (shouldNotifyReconnect) {
this.reconnectListeners.forEach((listener) => {
try {
listener()
} catch (error) {
console.error('统一 WebSocket 重连监听器执行失败:', error)
}
})
}
}
private handleServerMessage(rawData: string): void {
let message: WsServerEnvelope
try {
message = JSON.parse(rawData) as WsServerEnvelope
} catch (error) {
console.error('解析统一 WebSocket 消息失败:', error)
return
}
if (message.op === 'pong') {
return
}
if (isResponseEnvelope(message)) {
const requestId = message.id
if (!requestId) {
return
}
const pendingRequest = this.pendingRequests.get(requestId)
if (!pendingRequest) {
return
}
clearTimeout(pendingRequest.timeoutId)
this.pendingRequests.delete(requestId)
if (message.ok) {
pendingRequest.resolve(message.data ?? {})
} else {
pendingRequest.reject(new Error(message.error?.message ?? '统一 WebSocket 请求失败'))
}
return
}
if (isEventEnvelope(message)) {
this.eventListeners.forEach((listener) => {
try {
listener(message)
} catch (error) {
console.error('统一 WebSocket 事件监听器执行失败:', error)
}
})
}
}
private handleClose(event: CloseEvent): void {
this.stopHeartbeat()
this.ws = null
this.connectPromise = null
this.setStatus('idle')
this.rejectPendingRequests(new Error(`统一 WebSocket 已关闭 (${event.code})`))
if (event.code === 4001) {
this.manualDisconnect = true
if (window.location.pathname !== '/auth') {
window.location.href = '/auth'
}
return
}
this.scheduleReconnect()
}
async connect(): Promise<void> {
if (this.ws?.readyState === WebSocket.OPEN) {
return
}
if (this.connectPromise) {
return await this.connectPromise
}
this.manualDisconnect = false
this.setStatus('connecting')
this.connectPromise = (async () => {
const wsUrl = await this.createWebSocketUrl()
if (!wsUrl) {
this.setStatus('idle')
throw new Error('无法建立统一 WebSocket 连接')
}
await new Promise<void>((resolve, reject) => {
let settled = false
const socket = new WebSocket(wsUrl)
this.ws = socket
socket.onopen = () => {
settled = true
const shouldNotifyReconnect = this.hasConnectedOnce
this.hasConnectedOnce = true
this.reconnectAttempts = 0
this.startHeartbeat()
this.setStatus('connected')
resolve()
void this.restoreState(shouldNotifyReconnect)
}
socket.onmessage = (event) => {
this.handleServerMessage(event.data)
}
socket.onerror = () => {
if (!settled) {
settled = true
reject(new Error('统一 WebSocket 连接失败'))
}
}
socket.onclose = (event) => {
if (!settled) {
settled = true
reject(new Error(`统一 WebSocket 已关闭 (${event.code})`))
}
this.handleClose(event)
}
})
})()
try {
await this.connectPromise
} finally {
if (this.status !== 'connected') {
this.connectPromise = null
}
}
}
disconnect(): void {
this.manualDisconnect = true
this.clearReconnectTimer()
this.stopHeartbeat()
this.rejectPendingRequests(new Error('统一 WebSocket 已手动断开'))
this.connectPromise = null
if (this.ws) {
this.ws.close()
this.ws = null
}
this.setStatus('idle')
}
async restart(): Promise<void> {
this.manualDisconnect = false
this.clearReconnectTimer()
if (this.ws) {
this.ws.close()
return
}
await this.connect()
}
async call(params: {
data?: Record<string, unknown>
domain: string
method: string
session?: string
}): Promise<Record<string, unknown>> {
await this.connect()
const requestId = this.nextRequestId()
return await this.sendRequest({
op: 'call',
id: requestId,
domain: params.domain,
method: params.method,
session: params.session,
data: params.data ?? {},
})
}
async subscribe(
domain: string,
topic: string,
data?: Record<string, unknown>,
): Promise<Record<string, unknown>> {
await this.connect()
this.subscriptions.set(this.getSubscriptionKey(domain, topic), {
domain,
topic,
data,
})
return await this.sendRequest({
op: 'subscribe',
id: this.nextRequestId(),
domain,
topic,
data: data ?? {},
})
}
async unsubscribe(domain: string, topic: string): Promise<Record<string, unknown> | null> {
this.subscriptions.delete(this.getSubscriptionKey(domain, topic))
if (this.ws?.readyState !== WebSocket.OPEN) {
return null
}
return await this.sendRequest({
op: 'unsubscribe',
id: this.nextRequestId(),
domain,
topic,
data: {},
})
}
addEventListener(listener: EventListener): () => void {
this.eventListeners.add(listener)
return () => {
this.eventListeners.delete(listener)
}
}
onConnectionChange(listener: ConnectionListener): () => void {
this.connectionListeners.add(listener)
listener(this.status === 'connected')
return () => {
this.connectionListeners.delete(listener)
}
}
onStatusChange(listener: StatusListener): () => void {
this.statusListeners.add(listener)
listener(this.status)
return () => {
this.statusListeners.delete(listener)
}
}
onReconnect(listener: ReconnectListener): () => void {
this.reconnectListeners.add(listener)
return () => {
this.reconnectListeners.delete(listener)
}
}
getStatus(): ConnectionStatus {
return this.status
}
}
export const unifiedWsClient = new UnifiedWebSocketClient()

View File

@@ -1,211 +0,0 @@
import { fetchWithAuth } from './fetch-with-auth'
/**
* WebSocket 配置选项
*/
export interface WebSocketOptions {
onMessage?: (data: string) => void
onOpen?: () => void
onClose?: () => void
onError?: (error: Event) => void
heartbeatInterval?: number // 心跳间隔(毫秒)
maxRetries?: number // 最大重连次数
backoffBase?: number // 重连基础间隔(毫秒)
maxBackoff?: number // 最大重连间隔(毫秒)
}
/**
* 获取 WebSocket 临时认证 token
*/
export async function getWsToken(): Promise<string | null> {
try {
// 使用相对路径,让前端代理处理请求,避免 CORS 问题
const response = await fetchWithAuth('/api/webui/ws-token', {
method: 'GET',
credentials: 'include', // 携带 Cookie
})
if (!response.ok) {
console.error('获取 WebSocket token 失败:', response.status)
return null
}
const data = await response.json()
if (data.success && data.token) {
return data.token
}
return null
} catch (error) {
console.error('获取 WebSocket token 失败:', error)
return null
}
}
/**
* 创建带重连、心跳的 WebSocket 封装
*
* @param url WebSocket URL不含 token 参数)
* @param options 配置选项
* @returns WebSocket 控制对象,包含 connect、disconnect、send 方法
*/
export function createReconnectingWebSocket(
url: string,
options: WebSocketOptions = {}
) {
const {
onMessage,
onOpen,
onClose,
onError,
heartbeatInterval = 30000,
maxRetries = 10,
backoffBase = 1000,
maxBackoff = 30000,
} = options
let ws: WebSocket | null = null
let reconnectTimeout: number | null = null
let reconnectAttempts = 0
let heartbeatIntervalId: number | null = null
let isManualDisconnect = false
/**
* 启动心跳
*/
function startHeartbeat() {
stopHeartbeat()
heartbeatIntervalId = window.setInterval(() => {
if (ws?.readyState === WebSocket.OPEN) {
ws.send('ping')
}
}, heartbeatInterval)
}
/**
* 停止心跳
*/
function stopHeartbeat() {
if (heartbeatIntervalId !== null) {
clearInterval(heartbeatIntervalId)
heartbeatIntervalId = null
}
}
/**
* 尝试重连
*/
function attemptReconnect() {
if (isManualDisconnect) {
return
}
if (reconnectAttempts >= maxRetries) {
console.warn(`WebSocket 达到最大重连次数 (${maxRetries}),停止重连`)
return
}
reconnectAttempts += 1
const delay = Math.min(backoffBase * reconnectAttempts, maxBackoff)
console.log(`WebSocket 将在 ${delay}ms 后重连(第 ${reconnectAttempts} 次)`)
reconnectTimeout = window.setTimeout(() => {
connect()
}, delay)
}
/**
* 连接 WebSocket
*/
async function connect() {
if (ws?.readyState === WebSocket.OPEN || ws?.readyState === WebSocket.CONNECTING) {
return
}
// 先获取临时认证 token
const wsToken = await getWsToken()
if (!wsToken) {
console.warn('无法获取 WebSocket token跳过连接')
return
}
const wsUrl = `${url}?token=${encodeURIComponent(wsToken)}`
try {
ws = new WebSocket(wsUrl)
ws.onopen = () => {
reconnectAttempts = 0
startHeartbeat()
onOpen?.()
}
ws.onmessage = (event) => {
// 忽略心跳响应
if (event.data === 'pong') {
return
}
onMessage?.(event.data)
}
ws.onerror = (error) => {
console.error('WebSocket 错误:', error)
onError?.(error)
}
ws.onclose = () => {
stopHeartbeat()
onClose?.()
attemptReconnect()
}
} catch (error) {
console.error('创建 WebSocket 连接失败:', error)
attemptReconnect()
}
}
/**
* 断开连接
*/
function disconnect() {
isManualDisconnect = true
if (reconnectTimeout !== null) {
clearTimeout(reconnectTimeout)
reconnectTimeout = null
}
stopHeartbeat()
if (ws) {
ws.close()
ws = null
}
reconnectAttempts = 0
}
/**
* 发送消息
*/
function send(data: string) {
if (ws?.readyState === WebSocket.OPEN) {
ws.send(data)
} else {
console.warn('WebSocket 未连接,无法发送消息')
}
}
/**
* 获取当前 WebSocket 实例
*/
function getWebSocket(): WebSocket | null {
return ws
}
return {
connect,
disconnect,
send,
getWebSocket,
}
}

View File

@@ -5,7 +5,7 @@ import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { ScrollArea } from '@/components/ui/scroll-area'
import { useToast } from '@/hooks/use-toast'
import { getWsBaseUrl } from '@/lib/api-base'
import { chatWsClient } from '@/lib/chat-ws-client'
import { fetchWithAuth } from '@/lib/fetch-with-auth'
import { cn } from '@/lib/utils'
import { Bot, Edit2, Loader2, RefreshCw, User, Send, Wifi, WifiOff, UserCircle2 } from 'lucide-react'
@@ -85,14 +85,17 @@ export function ChatPage() {
// 持久化用户 ID
const userIdRef = useRef(getOrCreateUserId())
// 每个标签页的 WebSocket 连接
const wsMapRef = useRef<Map<string, WebSocket>>(new Map())
const messagesEndRef = useRef<HTMLDivElement>(null)
const reconnectTimeoutMapRef = useRef<Map<string, number>>(new Map())
const messageIdCounterRef = useRef(0)
const processedMessagesMapRef = useRef<Map<string, Set<string>>>(new Map())
const sessionUnsubscribeMapRef = useRef<Map<string, () => void>>(new Map())
const tabsRef = useRef<ChatTab[]>([])
const { toast } = useToast()
useEffect(() => {
tabsRef.current = tabs
}, [tabs])
// 生成唯一消息 ID
const generateMessageId = (prefix: string) => {
messageIdCounterRef.current += 1
@@ -197,357 +200,218 @@ export function ChatPage() {
}
}, [tempVirtualConfig.platform, personSearchQuery, fetchPersons])
// 加载聊天历史到指定标签页
const loadChatHistoryForTab = useCallback(async (tabId: string, groupId?: string) => {
const handleSessionMessage = useCallback((
tabId: string,
tabType: 'webui' | 'virtual',
config: VirtualIdentityConfig | undefined,
data: WsMessage,
) => {
switch (data.type) {
case 'session_info':
updateTab(tabId, {
sessionInfo: {
session_id: data.session_id,
user_id: data.user_id,
user_name: data.user_name,
bot_name: data.bot_name,
}
})
break
case 'system':
addMessageToTab(tabId, {
id: generateMessageId('sys'),
type: 'system',
content: data.content || '',
timestamp: data.timestamp || Date.now() / 1000,
})
break
case 'user_message': {
const senderUserId = data.sender?.user_id
const currentUserId = tabType === 'virtual' && config
? config.userId
: userIdRef.current
const normalizeSenderId = senderUserId ? senderUserId.replace(/^webui_user_/, '') : ''
const normalizeCurrentId = currentUserId ? currentUserId.replace(/^webui_user_/, '') : ''
if (normalizeSenderId && normalizeCurrentId && normalizeSenderId === normalizeCurrentId) {
break
}
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
const contentHash = `user-${data.content}-${Math.floor((data.timestamp || 0) * 1000)}`
if (processedSet.has(contentHash)) {
break
}
processedSet.add(contentHash)
processedMessagesMapRef.current.set(tabId, processedSet)
if (processedSet.size > 100) {
const firstKey = processedSet.values().next().value
if (firstKey) processedSet.delete(firstKey)
}
addMessageToTab(tabId, {
id: data.message_id || generateMessageId('user'),
type: 'user',
content: data.content || '',
timestamp: data.timestamp || Date.now() / 1000,
sender: data.sender,
})
break
}
case 'bot_message': {
updateTab(tabId, { isTyping: false })
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
const contentHash = `bot-${data.content}-${Math.floor((data.timestamp || 0) * 1000)}`
if (processedSet.has(contentHash)) {
break
}
processedSet.add(contentHash)
processedMessagesMapRef.current.set(tabId, processedSet)
if (processedSet.size > 100) {
const firstKey = processedSet.values().next().value
if (firstKey) processedSet.delete(firstKey)
}
setTabs(prev => prev.map(tab => {
if (tab.id !== tabId) return tab
const filteredMessages = tab.messages.filter(msg => msg.type !== 'thinking')
const newMessage: ChatMessage = {
id: generateMessageId('bot'),
type: 'bot',
content: data.content || '',
message_type: (data.message_type === 'rich' ? 'rich' : 'text') as 'text' | 'rich',
segments: data.segments,
timestamp: data.timestamp || Date.now() / 1000,
sender: data.sender,
}
return {
...tab,
messages: [...filteredMessages, newMessage]
}
}))
break
}
case 'typing':
updateTab(tabId, { isTyping: data.is_typing || false })
break
case 'error':
setTabs(prev => prev.map(tab => {
if (tab.id !== tabId) return tab
const filteredMessages = tab.messages.filter(msg => msg.type !== 'thinking')
return {
...tab,
messages: [...filteredMessages, {
id: generateMessageId('error'),
type: 'error' as const,
content: data.content || '发生错误',
timestamp: data.timestamp || Date.now() / 1000,
}]
}
}))
toast({
title: '错误',
description: data.content,
variant: 'destructive',
})
break
case 'history': {
const historyMessages = data.messages || []
const processedSet = new Set<string>()
const formattedMessages: ChatMessage[] = historyMessages.map((msg: {
id?: string
content: string
timestamp: number
sender_name?: string
sender_id?: string
is_bot?: boolean
}) => {
const isBot = msg.is_bot || false
const msgId = msg.id || generateMessageId(isBot ? 'bot' : 'user')
const contentHash = `${isBot ? 'bot' : 'user'}-${msg.content}-${Math.floor(msg.timestamp * 1000)}`
processedSet.add(contentHash)
return {
id: msgId,
type: isBot ? 'bot' : 'user' as const,
content: msg.content,
timestamp: msg.timestamp,
sender: {
name: msg.sender_name || (isBot ? '麦麦' : '用户'),
user_id: msg.sender_id,
is_bot: isBot,
},
}
})
processedMessagesMapRef.current.set(tabId, processedSet)
updateTab(tabId, { messages: formattedMessages })
setIsLoadingHistory(false)
break
}
default:
break
}
}, [addMessageToTab, toast, updateTab])
const ensureSessionListener = useCallback((
tabId: string,
tabType: 'webui' | 'virtual',
config?: VirtualIdentityConfig,
) => {
if (sessionUnsubscribeMapRef.current.has(tabId)) {
return
}
const unsubscribe = chatWsClient.onSessionMessage(tabId, (message) => {
handleSessionMessage(tabId, tabType, config, message as unknown as WsMessage)
})
sessionUnsubscribeMapRef.current.set(tabId, unsubscribe)
}, [handleSessionMessage])
const openSessionForTab = useCallback(async (
tabId: string,
tabType: 'webui' | 'virtual',
config?: VirtualIdentityConfig,
) => {
ensureSessionListener(tabId, tabType, config)
setIsLoadingHistory(true)
try {
const params = new URLSearchParams()
params.append('user_id', userIdRef.current)
params.append('limit', '50')
if (groupId) {
params.append('group_id', groupId)
if (tabType === 'virtual' && config) {
await chatWsClient.openSession(tabId, {
user_id: config.userId,
user_name: config.userName,
platform: config.platform,
person_id: config.personId,
group_name: config.groupName || 'WebUI虚拟群聊',
group_id: config.groupId,
})
} else {
await chatWsClient.openSession(tabId, {
user_id: userIdRef.current,
user_name: userName,
})
}
const url = `/api/chat/history?${params.toString()}`
console.log('[Chat] 正在加载历史消息:', url)
const response = await fetchWithAuth(url)
if (response.ok) {
const text = await response.text()
try {
const data = JSON.parse(text)
if (data.messages && data.messages.length > 0) {
const historyMessages: ChatMessage[] = data.messages.map((msg: {
id: string
type: string
content: string
timestamp: number
sender_name?: string
user_id?: string
is_bot?: boolean
}) => ({
id: msg.id,
type: msg.type as 'user' | 'bot' | 'system' | 'error',
content: msg.content,
timestamp: msg.timestamp,
sender: {
name: msg.sender_name || (msg.is_bot ? '麦麦' : 'WebUI用户'),
user_id: msg.user_id,
is_bot: msg.is_bot
}
}))
// 更新标签页的消息
updateTab(tabId, { messages: historyMessages })
// 将历史消息添加到去重缓存
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
historyMessages.forEach(msg => {
if (msg.type === 'bot') {
const contentHash = `bot-${msg.content}-${Math.floor(msg.timestamp * 1000)}`
processedSet.add(contentHash)
}
})
processedMessagesMapRef.current.set(tabId, processedSet)
}
} catch (parseError) {
console.error('[Chat] JSON 解析失败:', parseError)
}
}
} catch (e) {
console.error('[Chat] 加载历史消息失败:', e)
} finally {
setIsLoadingHistory(false)
}
}, [updateTab])
// 为指定标签页连接 WebSocket异步需要先获取认证 token
const connectWebSocketForTab = useCallback(async (tabId: string, tabType: 'webui' | 'virtual', config?: VirtualIdentityConfig) => {
// 如果已经有连接,不要重复创建
const existingWs = wsMapRef.current.get(tabId)
if (existingWs?.readyState === WebSocket.OPEN ||
existingWs?.readyState === WebSocket.CONNECTING) {
console.log(`[Tab ${tabId}] WebSocket 已存在,跳过连接`)
return
}
setIsConnecting(true)
// 先获取临时 WebSocket token
let wsToken: string | null = null
try {
const tokenResponse = await fetchWithAuth('/api/webui/ws-token')
if (tokenResponse.ok) {
const tokenData = await tokenResponse.json()
if (tokenData.success && tokenData.token) {
wsToken = tokenData.token
} else {
console.warn(`[Tab ${tabId}] 获取 WebSocket token 失败: ${tokenData.message || '未登录'}`)
setIsConnecting(false)
return
}
}
updateTab(tabId, { isConnected: true })
} catch (error) {
console.error(`[Tab ${tabId}] 获取 WebSocket token 失败:`, error)
setIsConnecting(false)
return
console.error(`[Tab ${tabId}] 打开聊天会话失败:`, error)
setIsLoadingHistory(false)
toast({
title: '连接失败',
description: '无法建立聊天会话,请稍后重试',
variant: 'destructive',
})
}
// 此时 wsToken 一定有值(前面已经 return
if (!wsToken) {
setIsConnecting(false)
return
}
const wsBase = await getWsBaseUrl()
const params = new URLSearchParams()
// 添加 token 到参数
params.append('token', wsToken)
if (tabType === 'virtual' && config) {
params.append('user_id', config.userId)
params.append('user_name', config.userName)
params.append('platform', config.platform)
params.append('person_id', config.personId)
params.append('group_name', config.groupName || 'WebUI虚拟群聊')
// 传递稳定的 group_id确保历史记录能正确加载
if (config.groupId) {
params.append('group_id', config.groupId)
}
} else {
params.append('user_id', userIdRef.current)
params.append('user_name', userName)
}
const wsUrl = `${wsBase}/api/chat/ws?${params.toString()}`
console.log(`[Tab ${tabId}] 正在连接 WebSocket:`, wsUrl)
try {
const ws = new WebSocket(wsUrl)
wsMapRef.current.set(tabId, ws)
ws.onopen = () => {
updateTab(tabId, { isConnected: true })
setIsConnecting(false)
console.log(`[Tab ${tabId}] WebSocket 已连接`)
}
ws.onmessage = (event) => {
try {
const data: WsMessage = JSON.parse(event.data)
switch (data.type) {
case 'session_info':
updateTab(tabId, {
sessionInfo: {
session_id: data.session_id,
user_id: data.user_id,
user_name: data.user_name,
bot_name: data.bot_name,
}
})
break
case 'system':
addMessageToTab(tabId, {
id: generateMessageId('sys'),
type: 'system',
content: data.content || '',
timestamp: data.timestamp || Date.now() / 1000,
})
break
case 'user_message': {
// 检查是否是自己发的消息(已在发送时显示,跳过广播回来的)
const senderUserId = data.sender?.user_id
const currentUserId = tabType === 'virtual' && config
? config.userId
: userIdRef.current
console.log(`[Tab ${tabId}] 收到 user_message, sender: ${senderUserId}, current: ${currentUserId}`)
// 标准化 user_id去掉可能的前缀
const normalizeSenderId = senderUserId ? senderUserId.replace(/^webui_user_/, '') : ''
const normalizeCurrentId = currentUserId ? currentUserId.replace(/^webui_user_/, '') : ''
// 如果是自己发的消息,跳过(避免重复显示)
if (normalizeSenderId && normalizeCurrentId && normalizeSenderId === normalizeCurrentId) {
console.log(`[Tab ${tabId}] 跳过自己的消息user_id 匹配)`)
break
}
// 额外的消息去重:检查内容和时间戳
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
const contentHash = `user-${data.content}-${Math.floor((data.timestamp || 0) * 1000)}`
if (processedSet.has(contentHash)) {
console.log(`[Tab ${tabId}] 跳过自己的消息(内容去重)`)
break
}
processedSet.add(contentHash)
processedMessagesMapRef.current.set(tabId, processedSet)
if (processedSet.size > 100) {
const firstKey = processedSet.values().next().value
if (firstKey) processedSet.delete(firstKey)
}
addMessageToTab(tabId, {
id: data.message_id || generateMessageId('user'),
type: 'user',
content: data.content || '',
timestamp: data.timestamp || Date.now() / 1000,
sender: data.sender,
})
break
}
case 'bot_message': {
updateTab(tabId, { isTyping: false })
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
const contentHash = `bot-${data.content}-${Math.floor((data.timestamp || 0) * 1000)}`
if (processedSet.has(contentHash)) {
break
}
processedSet.add(contentHash)
processedMessagesMapRef.current.set(tabId, processedSet)
if (processedSet.size > 100) {
const firstKey = processedSet.values().next().value
if (firstKey) processedSet.delete(firstKey)
}
// 移除"思考中"占位消息,添加真实的机器人回复
setTabs(prev => prev.map(tab => {
if (tab.id !== tabId) return tab
// 过滤掉 thinking 类型的消息
const filteredMessages = tab.messages.filter(msg => msg.type !== 'thinking')
const newMessage: ChatMessage = {
id: generateMessageId('bot'),
type: 'bot',
content: data.content || '',
message_type: (data.message_type === 'rich' ? 'rich' : 'text') as 'text' | 'rich',
segments: data.segments,
timestamp: data.timestamp || Date.now() / 1000,
sender: data.sender,
}
return {
...tab,
messages: [...filteredMessages, newMessage]
}
}))
break
}
case 'typing':
updateTab(tabId, { isTyping: data.is_typing || false })
break
case 'error':
// 移除"思考中"占位消息,显示错误
setTabs(prev => prev.map(tab => {
if (tab.id !== tabId) return tab
const filteredMessages = tab.messages.filter(msg => msg.type !== 'thinking')
return {
...tab,
messages: [...filteredMessages, {
id: generateMessageId('error'),
type: 'error' as const,
content: data.content || '发生错误',
timestamp: data.timestamp || Date.now() / 1000,
}]
}
}))
toast({
title: '错误',
description: data.content,
variant: 'destructive',
})
break
case 'pong':
break
case 'history': {
// 处理服务端发送的历史消息
const historyMessages = data.messages || []
if (historyMessages.length > 0) {
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
const formattedMessages: ChatMessage[] = historyMessages.map((msg: {
id?: string
content: string
timestamp: number
sender_name?: string
sender_id?: string
is_bot?: boolean
}) => {
const isBot = msg.is_bot || false
const msgId = msg.id || generateMessageId(isBot ? 'bot' : 'user')
// 添加到去重集合
const contentHash = `${isBot ? 'bot' : 'user'}-${msg.content}-${Math.floor(msg.timestamp * 1000)}`
processedSet.add(contentHash)
return {
id: msgId,
type: isBot ? 'bot' : 'user' as const,
content: msg.content,
timestamp: msg.timestamp,
sender: {
name: msg.sender_name || (isBot ? '麦麦' : '用户'),
user_id: msg.sender_id,
is_bot: isBot,
},
}
})
processedMessagesMapRef.current.set(tabId, processedSet)
// 替换当前标签页的所有消息
updateTab(tabId, { messages: formattedMessages })
console.log(`[Tab ${tabId}] 已加载 ${formattedMessages.length} 条历史消息`)
}
break
}
default:
console.log('未知消息类型:', data.type)
}
} catch (e) {
console.error('解析消息失败:', e)
}
}
ws.onclose = () => {
updateTab(tabId, { isConnected: false })
setIsConnecting(false)
wsMapRef.current.delete(tabId)
console.log(`[Tab ${tabId}] WebSocket 已断开`)
// 清除旧的重连定时器
const oldTimeout = reconnectTimeoutMapRef.current.get(tabId)
if (oldTimeout) {
clearTimeout(oldTimeout)
}
// 5秒后尝试重连
const timeout = window.setTimeout(() => {
if (!isUnmountedRef.current) {
const tab = tabs.find(t => t.id === tabId)
if (tab) {
connectWebSocketForTab(tabId, tab.type, tab.virtualConfig)
}
}
}, 5000)
reconnectTimeoutMapRef.current.set(tabId, timeout)
}
ws.onerror = (error) => {
console.error(`[Tab ${tabId}] WebSocket 错误:`, error)
setIsConnecting(false)
}
} catch (e) {
console.error(`[Tab ${tabId}] 创建 WebSocket 失败:`, e)
setIsConnecting(false)
}
}, [userName, updateTab, addMessageToTab, toast, tabs])
}, [ensureSessionListener, toast, updateTab, userName])
// 用于追踪组件是否已卸载
const isUnmountedRef = useRef(false)
@@ -555,69 +419,49 @@ export function ChatPage() {
// 初始化连接(默认 WebUI 标签页)
useEffect(() => {
isUnmountedRef.current = false
// 保存 ref 的当前值,用于清理
const wsMap = wsMapRef.current
const reconnectTimeoutMap = reconnectTimeoutMapRef.current
const processedMessagesMap = processedMessagesMapRef.current
// 加载默认标签页历史消息
loadChatHistoryForTab('webui-default')
// 延迟连接
const connectTimer = setTimeout(() => {
if (!isUnmountedRef.current) {
connectWebSocketForTab('webui-default', 'webui')
// 恢复的虚拟标签页也需要建立连接
tabs.forEach(tab => {
if (tab.type === 'virtual' && tab.virtualConfig) {
// 初始化去重缓存
processedMessagesMap.set(tab.id, new Set())
// 建立 WebSocket 连接
setTimeout(() => {
if (!isUnmountedRef.current) {
connectWebSocketForTab(tab.id, 'virtual', tab.virtualConfig)
}
}, 200)
}
})
}
}, 100)
// 心跳定时器 - 向所有活动连接发送
const heartbeat = setInterval(() => {
wsMap.forEach((ws) => {
if (ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({ type: 'ping' }))
}
})
}, 30000)
const unsubscribeConnection = chatWsClient.onConnectionChange((connected) => {
if (isUnmountedRef.current) {
return
}
setTabs(prev => prev.map(tab => ({
...tab,
isConnected: connected,
})))
})
const unsubscribeStatus = chatWsClient.onStatusChange((status) => {
if (!isUnmountedRef.current) {
setIsConnecting(status === 'connecting')
}
})
tabs.forEach(tab => {
processedMessagesMapRef.current.set(tab.id, new Set())
void openSessionForTab(tab.id, tab.type, tab.virtualConfig)
})
return () => {
isUnmountedRef.current = true
clearTimeout(connectTimer)
clearInterval(heartbeat)
// 清理所有重连定时器
reconnectTimeoutMap.forEach((timeout) => {
clearTimeout(timeout)
unsubscribeConnection()
unsubscribeStatus()
sessionUnsubscribeMapRef.current.forEach((unsubscribe) => {
unsubscribe()
})
reconnectTimeoutMap.clear()
// 关闭所有 WebSocket 连接
wsMap.forEach((ws) => {
ws.close()
sessionUnsubscribeMapRef.current.clear()
tabsRef.current.forEach(tab => {
void chatWsClient.closeSession(tab.id)
})
wsMap.clear()
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
// 发送消息到当前活动标签页
const sendMessage = useCallback(() => {
const ws = wsMapRef.current.get(activeTabId)
if (!inputValue.trim() || !ws || ws.readyState !== WebSocket.OPEN) {
const sendMessage = useCallback(async () => {
if (!inputValue.trim() || !activeTab?.isConnected) {
return
}
@@ -628,12 +472,6 @@ export function ChatPage() {
const messageContent = inputValue.trim()
const currentTimestamp = Date.now() / 1000
ws.send(JSON.stringify({
type: 'message',
content: messageContent,
user_name: displayName,
}))
// 添加到去重缓存,防止服务器广播回来的消息重复显示
const processedSet = processedMessagesMapRef.current.get(activeTabId) || new Set()
const contentHash = `user-${messageContent}-${Math.floor(currentTimestamp * 1000)}`
@@ -672,13 +510,32 @@ export function ChatPage() {
addMessageToTab(activeTabId, thinkingMessage)
setInputValue('')
}, [inputValue, userName, activeTabId, activeTab, addMessageToTab])
try {
await chatWsClient.sendMessage(activeTabId, messageContent, displayName)
} catch (error) {
console.error('发送聊天消息失败:', error)
setTabs(prev => prev.map(tab => {
if (tab.id !== activeTabId) return tab
return {
...tab,
isTyping: false,
messages: tab.messages.filter(msg => msg.type !== 'thinking')
}
}))
toast({
title: '发送失败',
description: '当前聊天会话不可用,请稍后重试',
variant: 'destructive',
})
}
}, [activeTab, activeTabId, addMessageToTab, inputValue, toast, userName])
// 处理键盘事件
const handleKeyDown = (e: React.KeyboardEvent) => {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault()
sendMessage()
void sendMessage()
}
}
@@ -693,13 +550,9 @@ export function ChatPage() {
setUserName(newName)
saveUserName(newName)
setIsEditingName(false)
// 通知当前标签页的后端昵称变更
const ws = wsMapRef.current.get(activeTabId)
if (ws?.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({
type: 'update_nickname',
user_name: newName
}))
if (activeTab?.isConnected) {
void chatWsClient.updateNickname(activeTabId, newName)
}
}
@@ -719,12 +572,7 @@ export function ChatPage() {
// 重新连接当前标签页
const handleReconnect = () => {
const ws = wsMapRef.current.get(activeTabId)
if (ws) {
ws.close()
wsMapRef.current.delete(activeTabId)
}
connectWebSocketForTab(activeTabId, activeTab?.type || 'webui', activeTab?.virtualConfig)
void chatWsClient.restart()
}
// 打开虚拟身份配置对话框(新建标签页用)
@@ -795,10 +643,10 @@ export function ChatPage() {
// 初始化去重缓存
processedMessagesMapRef.current.set(newTabId, new Set())
// 连接 WebSocket
setTimeout(() => {
connectWebSocketForTab(newTabId, 'virtual', tempVirtualConfig)
}, 100)
void openSessionForTab(newTabId, 'virtual', {
...tempVirtualConfig,
groupId: stableGroupId,
})
toast({
title: '虚拟身份标签页',
@@ -814,20 +662,14 @@ export function ChatPage() {
if (tabId === 'webui-default') {
return
}
// 关闭 WebSocket 连接
const ws = wsMapRef.current.get(tabId)
if (ws) {
ws.close()
wsMapRef.current.delete(tabId)
}
// 清理重连定时器
const timeout = reconnectTimeoutMapRef.current.get(tabId)
if (timeout) {
clearTimeout(timeout)
reconnectTimeoutMapRef.current.delete(tabId)
const unsubscribe = sessionUnsubscribeMapRef.current.get(tabId)
if (unsubscribe) {
unsubscribe()
sessionUnsubscribeMapRef.current.delete(tabId)
}
void chatWsClient.closeSession(tabId)
// 清理去重缓存
processedMessagesMapRef.current.delete(tabId)
@@ -1133,7 +975,7 @@ export function ChatPage() {
className="flex-1 h-10 sm:h-10"
/>
<Button
onClick={sendMessage}
onClick={() => { void sendMessage() }}
disabled={!activeTab?.isConnected || !inputValue.trim()}
size="icon"
className="h-10 w-10 shrink-0"

View File

@@ -93,12 +93,12 @@ function PluginsPageContent() {
// 统一管理 WebSocket 和数据加载
useEffect(() => {
let ws: WebSocket | null = null
let unsubscribeProgress: (() => Promise<void>) | null = null
let isUnmounted = false
const init = async () => {
// 1. 先连接 WebSocket异步获取 token
ws = await connectPluginProgressWebSocket(
unsubscribeProgress = await connectPluginProgressWebSocket(
(progress) => {
if (isUnmounted) return
@@ -128,29 +128,7 @@ function PluginsPageContent() {
}
)
// 2. 等待 WebSocket 连接建立
await new Promise<void>((resolve) => {
if (!ws) {
resolve()
return
}
const checkConnection = () => {
if (ws && ws.readyState === WebSocket.OPEN) {
console.log('WebSocket connected, starting to load plugins')
resolve()
} else if (ws && ws.readyState === WebSocket.CLOSED) {
console.warn('WebSocket closed before loading plugins')
resolve()
} else {
setTimeout(checkConnection, 100)
}
}
checkConnection()
})
// 3. 检查 Git 状态
// 2. 检查 Git 状态
if (!isUnmounted) {
const statusResult = await checkGitStatus()
if (!statusResult.success) {
@@ -173,7 +151,7 @@ function PluginsPageContent() {
}
}
// 4. 获取麦麦版本
// 3. 获取麦麦版本
if (!isUnmounted) {
const versionResult = await getMaimaiVersion()
if (!versionResult.success) {
@@ -186,7 +164,7 @@ function PluginsPageContent() {
setMaimaiVersion(versionResult.data)
}
}
// 5. 加载插件列表(包含已安装信息)
// 4. 加载插件列表(包含已安装信息)
if (!isUnmounted) {
try {
setLoading(true)
@@ -282,8 +260,8 @@ function PluginsPageContent() {
return () => {
isUnmounted = true
if (ws) {
ws.close()
if (unsubscribeProgress) {
void unsubscribeProgress()
}
}
}, [toast])

View File

@@ -0,0 +1,136 @@
"""业务命名 Hook 集成测试。"""
from types import SimpleNamespace
from typing import Any
import os
import sys
import pytest
# 确保项目根目录在 sys.path 中
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
# SDK 包路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "maibot-plugin-sdk"))
class _FakeHookManager:
"""用于业务 Hook 测试的最小运行时管理器。"""
def __init__(self, responses: dict[str, SimpleNamespace]) -> None:
"""初始化测试管理器。
Args:
responses: 按 Hook 名称预设的返回结果映射。
"""
self._responses = responses
self.calls: list[tuple[str, dict[str, Any]]] = []
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> SimpleNamespace:
"""模拟调用运行时命名 Hook。
Args:
hook_name: 目标 Hook 名称。
**kwargs: 传入 Hook 的参数。
Returns:
SimpleNamespace: 预设的 Hook 返回结果。
"""
self.calls.append((hook_name, dict(kwargs)))
return self._responses.get(hook_name, SimpleNamespace(kwargs=dict(kwargs), aborted=False))
def test_builtin_hook_catalog_includes_new_business_hooks(monkeypatch: pytest.MonkeyPatch) -> None:
"""内置 Hook 目录应包含三个业务系统新增的 Hook。"""
monkeypatch.setattr(sys, "exit", lambda code=0: None)
from src.plugin_runtime.hook_catalog import register_builtin_hook_specs
from src.plugin_runtime.host.hook_spec_registry import HookSpecRegistry
registry = HookSpecRegistry()
hook_names = {spec.name for spec in register_builtin_hook_specs(registry)}
assert "emoji.maisaka.before_select" in hook_names
assert "emoji.register.after_build_emotion" in hook_names
assert "jargon.extract.before_persist" in hook_names
assert "jargon.query.after_search" in hook_names
assert "expression.select.before_select" in hook_names
assert "expression.learn.before_upsert" in hook_names
@pytest.mark.asyncio
async def test_send_emoji_for_maisaka_can_be_aborted_by_hook(monkeypatch: pytest.MonkeyPatch) -> None:
"""表情包系统应允许在选择前被 Hook 中止。"""
from src.chat.emoji_system import maisaka_tool
fake_manager = _FakeHookManager(
{
"emoji.maisaka.before_select": SimpleNamespace(
kwargs={"abort_message": "插件阻止了表情发送。"},
aborted=True,
)
}
)
monkeypatch.setattr(maisaka_tool, "_get_runtime_manager", lambda: fake_manager)
result = await maisaka_tool.send_emoji_for_maisaka(stream_id="stream-1", requested_emotion="开心")
assert result.success is False
assert result.message == "插件阻止了表情发送。"
assert fake_manager.calls[0][0] == "emoji.maisaka.before_select"
@pytest.mark.asyncio
async def test_jargon_extract_can_be_aborted_before_persist(monkeypatch: pytest.MonkeyPatch) -> None:
"""黑话提取结果应允许在写库前被 Hook 中止。"""
from src.learners.jargon_miner import JargonMiner
fake_manager = _FakeHookManager(
{
"jargon.extract.before_persist": SimpleNamespace(
kwargs={"entries": []},
aborted=True,
)
}
)
monkeypatch.setattr(JargonMiner, "_get_runtime_manager", staticmethod(lambda: fake_manager))
miner = JargonMiner(session_id="session-1", session_name="测试会话")
await miner.process_extracted_entries(
[{"content": "yyds", "raw_content": {"[1] yyds 太强了"}}],
)
assert fake_manager.calls[0][0] == "jargon.extract.before_persist"
assert fake_manager.calls[0][1]["session_id"] == "session-1"
@pytest.mark.asyncio
async def test_expression_selection_can_be_aborted_by_hook(monkeypatch: pytest.MonkeyPatch) -> None:
"""表达方式选择流程应允许在开始前被 Hook 中止。"""
from src.learners.expression_selector import ExpressionSelector
fake_manager = _FakeHookManager(
{
"expression.select.before_select": SimpleNamespace(
kwargs={},
aborted=True,
)
}
)
monkeypatch.setattr(ExpressionSelector, "_get_runtime_manager", staticmethod(lambda: fake_manager))
monkeypatch.setattr(ExpressionSelector, "can_use_expression_for_chat", lambda self, chat_id: True)
selector = ExpressionSelector()
selected_expressions, selected_ids = await selector.select_suitable_expressions(
chat_id="session-1",
chat_info="用户刚刚发来一条消息。",
)
assert selected_expressions == []
assert selected_ids == []
assert fake_manager.calls[0][0] == "expression.select.before_select"

View File

@@ -1,25 +1,28 @@
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from rich.traceback import install
from sqlmodel import select
from typing import Any, Dict, List, Optional, Tuple
import asyncio
import hashlib
import heapq
import Levenshtein
import random
import re
from src.common.logger import get_logger
from rich.traceback import install
from sqlmodel import select
import Levenshtein
from src.common.data_models.image_data_model import MaiEmoji
from src.common.database.database_model import Images, ImageType
from src.common.database.database import get_db_session, get_db_session_manual
from src.common.utils.utils_image import ImageUtils
from src.prompt.prompt_manager import prompt_manager
from src.config.config import config_manager, global_config
from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMImageOptions
from src.common.database.database import get_db_session, get_db_session_manual
from src.common.database.database_model import Images, ImageType
from src.common.logger import get_logger
from src.common.utils.utils_image import ImageUtils
from src.config.config import config_manager, global_config
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
logger = get_logger("emoji")
@@ -33,6 +36,171 @@ EMOJI_REGISTERED_DIR = DATA_DIR / "emoji_registered" # 已注册的表情包注
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册表情包系统内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
emoji_schema = {
"type": "object",
"description": "当前表情包的序列化信息,主要包含 file_hash、description、emotions 等字段。",
}
string_array_schema = {
"type": "array",
"items": {"type": "string"},
}
return registry.register_hook_specs(
[
HookSpec(
name="emoji.maisaka.before_select",
description="Maisaka 表情发送工具选择表情前触发,可改写情绪、上下文和采样参数,或中止本次选择。",
parameters_schema=build_object_schema(
{
"stream_id": {"type": "string", "description": "目标会话 ID。"},
"requested_emotion": {"type": "string", "description": "请求的目标情绪标签。"},
"reasoning": {"type": "string", "description": "本次发送表情的推理理由。"},
"context_texts": {
**string_array_schema,
"description": "最近聊天上下文文本列表。",
},
"sample_size": {"type": "integer", "description": "候选表情采样数量。"},
"abort_message": {
"type": "string",
"description": "当 Hook 主动中止时可附带的失败提示。",
},
},
required=["stream_id", "requested_emotion", "reasoning", "context_texts", "sample_size"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="emoji.maisaka.after_select",
description="Maisaka 已选出表情后触发,可替换选中的表情哈希、补充匹配情绪,或中止发送。",
parameters_schema=build_object_schema(
{
"stream_id": {"type": "string", "description": "目标会话 ID。"},
"requested_emotion": {"type": "string", "description": "请求的目标情绪标签。"},
"reasoning": {"type": "string", "description": "本次发送表情的推理理由。"},
"context_texts": {
**string_array_schema,
"description": "最近聊天上下文文本列表。",
},
"sample_size": {"type": "integer", "description": "候选表情采样数量。"},
"selected_emoji": emoji_schema,
"selected_emoji_hash": {"type": "string", "description": "选中的表情哈希。"},
"matched_emotion": {"type": "string", "description": "最终命中的情绪标签。"},
"abort_message": {
"type": "string",
"description": "当 Hook 主动中止时可附带的失败提示。",
},
},
required=[
"stream_id",
"requested_emotion",
"reasoning",
"context_texts",
"sample_size",
"matched_emotion",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="emoji.register.after_build_description",
description="表情包描述生成并通过内容审查后触发,可改写描述文本或拒绝本次注册。",
parameters_schema=build_object_schema(
{
"emoji": emoji_schema,
"description": {"type": "string", "description": "当前生成出的表情包描述。"},
"image_format": {"type": "string", "description": "表情图片格式。"},
},
required=["emoji", "description", "image_format"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="emoji.register.after_build_emotion",
description="表情包情绪标签生成完成后触发,可改写标签列表或拒绝本次注册。",
parameters_schema=build_object_schema(
{
"emoji": emoji_schema,
"description": {"type": "string", "description": "当前表情包描述。"},
"emotions": {
**string_array_schema,
"description": "当前生成出的情绪标签列表。",
},
},
required=["emoji", "description", "emotions"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
]
)
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, Any]]:
"""将表情包对象序列化为 Hook 可传输载荷。
Args:
emoji: 待序列化的表情包对象。
Returns:
Optional[Dict[str, Any]]: 序列化后的字典;当表情为空时返回 ``None``。
"""
if emoji is None:
return None
return {
"file_hash": str(emoji.file_hash or "").strip(),
"file_name": emoji.file_name,
"full_path": str(emoji.full_path),
"description": emoji.description,
"emotions": [str(item).strip() for item in emoji.emotion if str(item).strip()],
"query_count": int(emoji.query_count),
}
def _normalize_string_list(raw_values: Any) -> List[str]:
"""将任意列表值规范化为字符串列表。
Args:
raw_values: 待规范化的原始值。
Returns:
List[str]: 去空白后的字符串列表。
"""
if not isinstance(raw_values, list):
return []
return [str(item).strip() for item in raw_values if str(item).strip()]
def _ensure_directories() -> None:
"""确保表情包相关目录存在"""
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
@@ -642,6 +810,22 @@ class EmojiManager:
if "" in llm_response:
logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}")
return False, target_emoji
hook_result = await _get_runtime_manager().invoke_hook(
"emoji.register.after_build_description",
emoji=_serialize_emoji_for_hook(target_emoji),
description=description,
image_format=image_format,
)
if hook_result.aborted:
logger.info(f"[构建描述] 表情包描述被 Hook 中止注册: {target_emoji.file_name}")
return False, target_emoji
normalized_description = str(hook_result.kwargs.get("description", description) or "").strip()
if not normalized_description:
logger.warning(f"[构建描述] Hook 返回空描述,拒绝注册: {target_emoji.file_name}")
return False, target_emoji
description = normalized_description
target_emoji.description = description
logger.info(f"[构建描述] 成功为表情包构建描述: {target_emoji.description}")
return True, target_emoji
@@ -687,6 +871,23 @@ class EmojiManager:
elif len(emotions) > 2:
emotions = random.sample(emotions, 2)
hook_result = await _get_runtime_manager().invoke_hook(
"emoji.register.after_build_emotion",
emoji=_serialize_emoji_for_hook(target_emoji),
description=target_emoji.description,
emotions=list(emotions),
)
if hook_result.aborted:
logger.info(f"[构建情感标签] 表情包情感标签被 Hook 中止注册: {target_emoji.file_name}")
return False, target_emoji
raw_emotions = hook_result.kwargs.get("emotions")
if raw_emotions is not None:
emotions = _normalize_string_list(raw_emotions)
if not emotions:
logger.warning(f"[构建情感标签] Hook 返回空情绪标签,拒绝注册: {target_emoji.file_name}")
return False, target_emoji
logger.info(f"[构建情感标签] 成功为表情包构建情感标签: {','.join(emotions)}")
target_emoji.emotion = emotions
return True, target_emoji

View File

@@ -1,7 +1,7 @@
"""Maisaka 表情工具内置能力。"""
from dataclasses import dataclass, field
from typing import Sequence
from typing import Any, Optional, Sequence
import random
@@ -11,7 +11,7 @@ from src.common.logger import get_logger
from src.common.utils.utils_image import ImageUtils
from src.services import send_service
from .emoji_manager import emoji_manager, emoji_manager_emotion_judge_llm
from .emoji_manager import _serialize_emoji_for_hook, emoji_manager, emoji_manager_emotion_judge_llm
logger = get_logger("emoji_maisaka_tool")
@@ -29,6 +29,76 @@ class MaisakaEmojiSendResult:
matched_emotion: str = ""
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
def _coerce_positive_int(value: Any, default: int) -> int:
"""将任意值安全转换为正整数。
Args:
value: 待转换的值。
default: 转换失败时使用的默认值。
Returns:
int: 规范化后的正整数。
"""
try:
normalized_value = int(value)
except (TypeError, ValueError):
return default
return normalized_value if normalized_value > 0 else default
def _normalize_context_texts(context_texts: Sequence[str] | None) -> list[str]:
"""清洗 Hook 和调用链传入的上下文文本列表。
Args:
context_texts: 原始上下文文本序列。
Returns:
list[str]: 过滤空白后的上下文文本列表。
"""
if not context_texts:
return []
return [str(item).strip() for item in context_texts if str(item).strip()]
def _resolve_selected_emoji(raw_value: Any) -> Optional[MaiEmoji]:
"""根据 Hook 返回值解析目标表情包对象。
Args:
raw_value: Hook 返回的 ``selected_emoji`` 或 ``selected_emoji_hash``。
Returns:
Optional[MaiEmoji]: 命中的表情包对象;未命中时返回 ``None``。
"""
raw_hash: str = ""
if isinstance(raw_value, dict):
raw_hash = str(raw_value.get("file_hash") or raw_value.get("hash") or "").strip()
elif isinstance(raw_value, str):
raw_hash = raw_value.strip()
if not raw_hash:
return None
for emoji in emoji_manager.emojis:
if emoji.file_hash == raw_hash:
return emoji
return None
def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
"""提取并清洗单个表情的情绪标签。"""
@@ -129,16 +199,81 @@ async def send_emoji_for_maisaka(
) -> MaisakaEmojiSendResult:
"""为 Maisaka 选择并发送一个表情。"""
selected_emoji, matched_emotion = await select_emoji_for_maisaka(
requested_emotion=requested_emotion,
reasoning=reasoning,
context_texts=context_texts,
normalized_requested_emotion = requested_emotion.strip()
normalized_reasoning = reasoning.strip()
normalized_context_texts = _normalize_context_texts(context_texts)
sample_size = 30
before_select_result = await _get_runtime_manager().invoke_hook(
"emoji.maisaka.before_select",
stream_id=stream_id,
requested_emotion=normalized_requested_emotion,
reasoning=normalized_reasoning,
context_texts=list(normalized_context_texts),
sample_size=sample_size,
abort_message="表情选择已被 Hook 中止。",
)
if before_select_result.aborted:
abort_message = str(before_select_result.kwargs.get("abort_message") or "表情选择已被 Hook 中止。").strip()
return MaisakaEmojiSendResult(
success=False,
message=abort_message or "表情选择已被 Hook 中止。",
requested_emotion=normalized_requested_emotion,
)
before_select_kwargs = before_select_result.kwargs
normalized_requested_emotion = str(
before_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
).strip()
normalized_reasoning = str(before_select_kwargs.get("reasoning", normalized_reasoning) or "").strip()
if isinstance(before_select_kwargs.get("context_texts"), list):
normalized_context_texts = _normalize_context_texts(before_select_kwargs.get("context_texts"))
sample_size = _coerce_positive_int(before_select_kwargs.get("sample_size"), sample_size)
selected_emoji, matched_emotion = await select_emoji_for_maisaka(
requested_emotion=normalized_requested_emotion,
reasoning=normalized_reasoning,
context_texts=normalized_context_texts,
sample_size=sample_size,
)
after_select_result = await _get_runtime_manager().invoke_hook(
"emoji.maisaka.after_select",
stream_id=stream_id,
requested_emotion=normalized_requested_emotion,
reasoning=normalized_reasoning,
context_texts=list(normalized_context_texts),
sample_size=sample_size,
selected_emoji=_serialize_emoji_for_hook(selected_emoji),
selected_emoji_hash=str(selected_emoji.file_hash or "").strip() if selected_emoji is not None else "",
matched_emotion=matched_emotion,
abort_message="表情发送已被 Hook 中止。",
)
if after_select_result.aborted:
abort_message = str(after_select_result.kwargs.get("abort_message") or "表情发送已被 Hook 中止。").strip()
return MaisakaEmojiSendResult(
success=False,
message=abort_message or "表情发送已被 Hook 中止。",
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
after_select_kwargs = after_select_result.kwargs
normalized_requested_emotion = str(
after_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
).strip()
matched_emotion = str(after_select_kwargs.get("matched_emotion", matched_emotion) or "").strip()
override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji_hash"))
if override_emoji is None:
override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji"))
if override_emoji is not None:
selected_emoji = override_emoji
if selected_emoji is None:
return MaisakaEmojiSendResult(
success=False,
message="当前表情包库中没有可用表情。",
requested_emotion=requested_emotion.strip(),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
try:
@@ -151,7 +286,7 @@ async def send_emoji_for_maisaka(
message=f"发送表情包失败:{exc}",
description=selected_emoji.description.strip(),
emotions=_normalize_emotions(selected_emoji),
requested_emotion=requested_emotion.strip(),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
@@ -169,7 +304,7 @@ async def send_emoji_for_maisaka(
message=f"发送表情包时发生异常:{exc}",
description=selected_emoji.description.strip(),
emotions=_normalize_emotions(selected_emoji),
requested_emotion=requested_emotion.strip(),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
@@ -181,7 +316,7 @@ async def send_emoji_for_maisaka(
message="发送表情包失败。",
description=description,
emotions=emotions,
requested_emotion=requested_emotion.strip(),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
@@ -197,6 +332,6 @@ async def send_emoji_for_maisaka(
emoji_base64=emoji_base64,
description=description,
emotions=emotions,
requested_emotion=requested_emotion.strip(),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)

View File

@@ -18,28 +18,29 @@ install(extra_lines=3)
logger = get_logger("sender")
# WebUI 聊天室的消息广播器(延迟导入避免循环依赖)
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str], Optional[str]]] = None
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
# TODO: 重构完成后完成webui相关
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]:
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str], Optional[str]]:
"""获取 WebUI 聊天室广播器。
Returns:
Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 元组;
Tuple[Any, Optional[str], Optional[str]]: ``(chat_manager, platform_name, default_group_id)`` 元组;
若 WebUI 相关模块不可用,则元素会退化为 ``None``。
"""
global _webui_chat_broadcaster
if _webui_chat_broadcaster is None:
try:
from src.webui.routers.chat import WEBUI_CHAT_PLATFORM, chat_manager
from src.webui.routers.chat.service import WEBUI_CHAT_GROUP_ID
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM, WEBUI_CHAT_GROUP_ID)
except ImportError:
_webui_chat_broadcaster = (None, None)
_webui_chat_broadcaster = (None, None, None)
return _webui_chat_broadcaster
@@ -76,7 +77,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
try:
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
chat_manager, webui_platform = get_webui_chat_broadcaster()
chat_manager, webui_platform, default_group_id = get_webui_chat_broadcaster()
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
if is_webui_message and chat_manager is not None:
@@ -97,8 +98,9 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
message_type = "rich"
segments = message_segments
await chat_manager.broadcast(
{
await chat_manager.broadcast_to_group(
group_id=group_id or default_group_id or "",
message={
"type": "bot_message",
"content": message.processed_plain_text,
"message_type": message_type,
@@ -110,7 +112,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
"avatar": None,
"is_bot": True,
},
}
},
)
# 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库

View File

@@ -1,22 +1,25 @@
from datetime import datetime
from sqlmodel import select
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
import asyncio
import difflib
import json
import re
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
from src.prompt.prompt_manager import prompt_manager
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.database.database import get_db_session
from src.common.data_models.expression_data_model import MaiExpression
from sqlmodel import select
from src.chat.utils.utils import is_bot_self
from src.common.data_models.expression_data_model import MaiExpression
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.common.utils.utils_message import MessageUtils
from src.config.config import global_config
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
from .expression_utils import check_expression_suitability, parse_expression_response
@@ -34,8 +37,122 @@ summary_model = LLMServiceClient(task_name="utils", request_type="expression.sum
check_model = LLMServiceClient(task_name="utils", request_type="expression.check")
def register_expression_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册表达方式系统内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="expression.select.before_select",
description="表达方式选择流程开始前触发,可改写会话上下文、选择参数或中止本次选择。",
parameters_schema=build_object_schema(
{
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
},
required=["chat_id", "chat_info", "max_num", "think_level"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.select.after_selection",
description="表达方式选择完成后触发,可改写最终选中的表达方式列表与 ID。",
parameters_schema=build_object_schema(
{
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
"selected_expressions": {
"type": "array",
"items": {"type": "object"},
"description": "当前已选中的表达方式列表。",
},
"selected_expression_ids": {
"type": "array",
"items": {"type": "integer"},
"description": "当前已选中的表达方式 ID 列表。",
},
},
required=[
"chat_id",
"chat_info",
"max_num",
"think_level",
"selected_expressions",
"selected_expression_ids",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.learn.after_extract",
description="表达方式学习解析出表达/黑话候选后触发,可改写候选集或直接终止本轮学习。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"message_count": {"type": "integer", "description": "本轮参与学习的消息数量。"},
"expressions": {
"type": "array",
"items": {"type": "object"},
"description": "解析出的表达方式候选列表。",
},
"jargon_entries": {
"type": "array",
"items": {"type": "object"},
"description": "解析出的黑话候选列表。",
},
},
required=["session_id", "message_count", "expressions", "jargon_entries"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.learn.before_upsert",
description="表达方式写入数据库前触发,可改写情景/风格文本或跳过本条写入。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"situation": {"type": "string", "description": "即将写入的情景文本。"},
"style": {"type": "string", "description": "即将写入的风格文本。"},
},
required=["session_id", "situation", "style"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
]
)
class ExpressionLearner:
def __init__(self, session_id: str) -> None:
"""初始化表达方式学习器。
Args:
session_id: 当前会话 ID。
"""
self.session_id = session_id
# 学习锁,防止并发执行学习任务
@@ -44,6 +161,110 @@ class ExpressionLearner:
# 消息缓存
self._messages_cache: List["SessionMessage"] = []
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _serialize_expressions(expressions: List[Tuple[str, str, str]]) -> List[dict[str, str]]:
"""将表达方式候选序列化为 Hook 载荷。
Args:
expressions: 原始表达方式候选列表。
Returns:
List[dict[str, str]]: 序列化后的表达方式候选。
"""
return [
{
"situation": str(situation).strip(),
"style": str(style).strip(),
"source_id": str(source_id).strip(),
}
for situation, style, source_id in expressions
if str(situation).strip() and str(style).strip()
]
@staticmethod
def _deserialize_expressions(raw_expressions: Any) -> List[Tuple[str, str, str]]:
"""从 Hook 载荷恢复表达方式候选列表。
Args:
raw_expressions: Hook 返回的表达方式候选。
Returns:
List[Tuple[str, str, str]]: 恢复后的表达方式候选列表。
"""
if not isinstance(raw_expressions, list):
return []
normalized_expressions: List[Tuple[str, str, str]] = []
for raw_expression in raw_expressions:
if not isinstance(raw_expression, dict):
continue
situation = str(raw_expression.get("situation") or "").strip()
style = str(raw_expression.get("style") or "").strip()
source_id = str(raw_expression.get("source_id") or "").strip()
if not situation or not style:
continue
normalized_expressions.append((situation, style, source_id))
return normalized_expressions
@staticmethod
def _serialize_jargon_entries(jargon_entries: List[Tuple[str, str]]) -> List[dict[str, str]]:
"""将黑话候选序列化为 Hook 载荷。
Args:
jargon_entries: 原始黑话候选列表。
Returns:
List[dict[str, str]]: 序列化后的黑话候选列表。
"""
return [
{
"content": str(content).strip(),
"source_id": str(source_id).strip(),
}
for content, source_id in jargon_entries
if str(content).strip()
]
@staticmethod
def _deserialize_jargon_entries(raw_jargon_entries: Any) -> List[Tuple[str, str]]:
"""从 Hook 载荷恢复黑话候选列表。
Args:
raw_jargon_entries: Hook 返回的黑话候选列表。
Returns:
List[Tuple[str, str]]: 恢复后的黑话候选列表。
"""
if not isinstance(raw_jargon_entries, list):
return []
normalized_entries: List[Tuple[str, str]] = []
for raw_entry in raw_jargon_entries:
if not isinstance(raw_entry, dict):
continue
content = str(raw_entry.get("content") or "").strip()
source_id = str(raw_entry.get("source_id") or "").strip()
if not content:
continue
normalized_entries.append((content, source_id))
return normalized_entries
def add_messages(self, messages: List["SessionMessage"]) -> None:
"""添加消息到缓存"""
self._messages_cache.extend(messages)
@@ -52,8 +273,12 @@ class ExpressionLearner:
"""获取当前消息缓存的大小"""
return len(self._messages_cache)
async def learn(self, jargon_miner: Optional["JargonMiner"] = None):
"""学习主流程"""
async def learn(self, jargon_miner: Optional["JargonMiner"] = None) -> None:
"""执行表达方式学习主流程
Args:
jargon_miner: 可选的黑话学习器实例,用于同步处理黑话候选。
"""
if not self._messages_cache:
logger.debug("没有消息可供学习,跳过学习过程")
return
@@ -109,6 +334,25 @@ class ExpressionLearner:
logger.info(f"黑话提取数量超过 30 个(实际{len(jargon_entries)}个),放弃本次黑话学习")
jargon_entries = []
after_extract_result = await self._get_runtime_manager().invoke_hook(
"expression.learn.after_extract",
session_id=self.session_id,
message_count=len(self._messages_cache),
expressions=self._serialize_expressions(expressions),
jargon_entries=self._serialize_jargon_entries(jargon_entries),
)
if after_extract_result.aborted:
logger.info(f"{self.session_id} 的表达方式学习结果被 Hook 中止")
return
after_extract_kwargs = after_extract_result.kwargs
raw_expressions = after_extract_kwargs.get("expressions")
if raw_expressions is not None:
expressions = self._deserialize_expressions(raw_expressions)
raw_jargon_entries = after_extract_kwargs.get("jargon_entries")
if raw_jargon_entries is not None:
jargon_entries = self._deserialize_jargon_entries(raw_jargon_entries)
# 处理黑话条目,路由到 jargon_miner即使没有表达方式也要处理黑话
# TODO: 检测是否开启了
if jargon_entries:
@@ -135,6 +379,22 @@ class ExpressionLearner:
# 存储到数据库 Expression 表
for situation, style in learnt_expressions:
before_upsert_result = await self._get_runtime_manager().invoke_hook(
"expression.learn.before_upsert",
session_id=self.session_id,
situation=situation,
style=style,
)
if before_upsert_result.aborted:
logger.info(f"{self.session_id} 的表达方式写入被 Hook 跳过: situation={situation!r}")
continue
upsert_kwargs = before_upsert_result.kwargs
situation = str(upsert_kwargs.get("situation", situation) or "").strip()
style = str(upsert_kwargs.get("style", style) or "").strip()
if not situation or not style:
logger.info(f"{self.session_id} 的表达方式写入被 Hook 清空,已跳过")
continue
await self._upsert_expression_to_db(situation, style)
# ====== 黑话相关 ======

View File

@@ -1,27 +1,109 @@
from typing import Any, Dict, List, Optional, Tuple
import json
import time
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.utils.utils_session import SessionUtils
from src.prompt.prompt_manager import prompt_manager
from src.learners.learner_utils_old import weighted_sample
from src.chat.utils.common_utils import TempMethodsExpression
from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.learners.learner_utils_old import weighted_sample
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
logger = get_logger("expression_selector")
class ExpressionSelector:
def __init__(self):
def __init__(self) -> None:
"""初始化表达方式选择器。"""
self.llm_model = LLMServiceClient(
task_name="utils", request_type="expression.selector"
)
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
"""将任意值安全转换为整数。
Args:
value: 待转换的值。
default: 转换失败时的默认值。
Returns:
int: 转换后的整数结果。
"""
try:
return int(value)
except (TypeError, ValueError):
return default
@staticmethod
def _normalize_selected_expressions(raw_expressions: Any) -> List[Dict[str, Any]]:
"""从 Hook 载荷恢复表达方式选择结果。
Args:
raw_expressions: Hook 返回的表达方式列表。
Returns:
List[Dict[str, Any]]: 恢复后的表达方式列表。
"""
if not isinstance(raw_expressions, list):
return []
normalized_expressions: List[Dict[str, Any]] = []
for raw_expression in raw_expressions:
if not isinstance(raw_expression, dict):
continue
expression_id = raw_expression.get("id")
situation = str(raw_expression.get("situation") or "").strip()
style = str(raw_expression.get("style") or "").strip()
source_id = str(raw_expression.get("source_id") or "").strip()
if not isinstance(expression_id, int) or not situation or not style or not source_id:
continue
normalized_expression = dict(raw_expression)
normalized_expression["id"] = expression_id
normalized_expression["situation"] = situation
normalized_expression["style"] = style
normalized_expression["source_id"] = source_id
normalized_expressions.append(normalized_expression)
return normalized_expressions
@staticmethod
def _normalize_selected_expression_ids(raw_ids: Any, expressions: List[Dict[str, Any]]) -> List[int]:
"""规范化最终选中的表达方式 ID 列表。
Args:
raw_ids: Hook 返回的 ID 列表。
expressions: 当前最终表达方式列表。
Returns:
List[int]: 规范化后的 ID 列表。
"""
if isinstance(raw_ids, list):
normalized_ids = [item for item in raw_ids if isinstance(item, int)]
if normalized_ids:
return normalized_ids
return [expression["id"] for expression in expressions if isinstance(expression.get("id"), int)]
def can_use_expression_for_chat(self, chat_id: str) -> bool:
"""
检查指定聊天流是否允许使用表达
@@ -214,8 +296,7 @@ class ExpressionSelector:
reply_reason: Optional[str] = None,
think_level: int = 1,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
选择适合的表达方式使用classic模式随机选择+LLM选择
"""选择适合的表达方式。
Args:
chat_id: 聊天流ID
@@ -233,11 +314,60 @@ class ExpressionSelector:
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return [], []
before_select_result = await self._get_runtime_manager().invoke_hook(
"expression.select.before_select",
chat_id=chat_id,
chat_info=chat_info,
max_num=max_num,
target_message=target_message or "",
reply_reason=reply_reason or "",
think_level=think_level,
)
if before_select_result.aborted:
logger.info(f"聊天流 {chat_id} 的表达方式选择被 Hook 中止")
return [], []
before_select_kwargs = before_select_result.kwargs
chat_id = str(before_select_kwargs.get("chat_id", chat_id) or "").strip() or chat_id
chat_info = str(before_select_kwargs.get("chat_info", chat_info) or "")
max_num = max(self._coerce_int(before_select_kwargs.get("max_num"), max_num), 1)
raw_target_message = before_select_kwargs.get("target_message", target_message or "")
target_message = str(raw_target_message or "").strip() or None
raw_reply_reason = before_select_kwargs.get("reply_reason", reply_reason or "")
reply_reason = str(raw_reply_reason or "").strip() or None
think_level = self._coerce_int(before_select_kwargs.get("think_level"), think_level)
# 使用classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式think_level={think_level}")
return await self._select_expressions_classic(
selected_expressions, selected_ids = await self._select_expressions_classic(
chat_id, chat_info, max_num, target_message, reply_reason, think_level
)
after_selection_result = await self._get_runtime_manager().invoke_hook(
"expression.select.after_selection",
chat_id=chat_id,
chat_info=chat_info,
max_num=max_num,
target_message=target_message or "",
reply_reason=reply_reason or "",
think_level=think_level,
selected_expressions=[dict(item) for item in selected_expressions],
selected_expression_ids=list(selected_ids),
)
if after_selection_result.aborted:
logger.info(f"聊天流 {chat_id} 的表达方式选择结果被 Hook 中止")
return [], []
after_selection_kwargs = after_selection_result.kwargs
raw_selected_expressions = after_selection_kwargs.get("selected_expressions")
if raw_selected_expressions is not None:
selected_expressions = self._normalize_selected_expressions(raw_selected_expressions)
selected_ids = self._normalize_selected_expression_ids(
after_selection_kwargs.get("selected_expression_ids"),
selected_expressions,
)
if selected_expressions:
self.update_expressions_last_active_time(selected_expressions)
return selected_expressions, selected_ids
async def _select_expressions_classic(
self,

View File

@@ -1,5 +1,5 @@
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Set, TypedDict
from typing import Any, Callable, Dict, List, Optional, Set, TypedDict
import asyncio
import json
@@ -9,13 +9,15 @@ from json_repair import repair_json
from sqlmodel import select
from src.common.data_models.jargon_data_model import MaiJargon
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.database.database import get_db_session
from src.common.database.database_model import Jargon
from src.common.logger import get_logger
from src.config.config import global_config
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
from .expression_utils import is_single_char_jargon
@@ -35,8 +37,140 @@ class JargonMeaningEntry(TypedDict):
meaning: str
def register_jargon_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册 jargon 系统内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="jargon.query.before_search",
description="Maisaka 黑话查询工具执行检索前触发,可改写词条列表、检索参数或直接中止。",
parameters_schema=build_object_schema(
{
"words": {
"type": "array",
"items": {"type": "string"},
"description": "准备查询的黑话词条列表。",
},
"session_id": {"type": "string", "description": "当前会话 ID。"},
"limit": {"type": "integer", "description": "单个词条的最大返回条数。"},
"case_sensitive": {"type": "boolean", "description": "是否大小写敏感。"},
"enable_fuzzy_fallback": {"type": "boolean", "description": "是否允许精确命中失败后回退模糊检索。"},
"abort_message": {"type": "string", "description": "Hook 主动中止时的失败提示。"},
},
required=["words", "session_id", "limit", "case_sensitive", "enable_fuzzy_fallback"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="jargon.query.after_search",
description="Maisaka 黑话查询工具完成检索后触发,可改写结果列表或中止返回。",
parameters_schema=build_object_schema(
{
"words": {
"type": "array",
"items": {"type": "string"},
"description": "实际查询的黑话词条列表。",
},
"session_id": {"type": "string", "description": "当前会话 ID。"},
"limit": {"type": "integer", "description": "单个词条的最大返回条数。"},
"case_sensitive": {"type": "boolean", "description": "是否大小写敏感。"},
"enable_fuzzy_fallback": {"type": "boolean", "description": "是否启用了模糊检索回退。"},
"results": {
"type": "array",
"items": {"type": "object"},
"description": "查询结果列表。",
},
"abort_message": {"type": "string", "description": "Hook 主动中止时的失败提示。"},
},
required=["words", "session_id", "limit", "case_sensitive", "enable_fuzzy_fallback", "results"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="jargon.extract.before_persist",
description="黑话条目准备写入数据库前触发,可改写去重后的条目列表或跳过本次持久化。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"session_name": {"type": "string", "description": "当前会话展示名称。"},
"entries": {
"type": "array",
"items": {"type": "object"},
"description": "即将持久化的黑话条目列表。",
},
},
required=["session_id", "session_name", "entries"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="jargon.inference.before_finalize",
description="黑话含义推断完成、写回数据库前触发,可改写最终判定与含义结果。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"session_name": {"type": "string", "description": "当前会话展示名称。"},
"content": {"type": "string", "description": "当前黑话词条。"},
"count": {"type": "integer", "description": "当前词条累计命中次数。"},
"raw_content_list": {
"type": "array",
"items": {"type": "string"},
"description": "用于推断的原始上下文片段列表。",
},
"inference_with_context": {"type": "object", "description": "基于上下文的推断结果。"},
"inference_with_content_only": {"type": "object", "description": "仅基于词条内容的推断结果。"},
"comparison_result": {"type": "object", "description": "比较阶段输出结果。"},
"is_jargon": {"type": "boolean", "description": "当前推断是否判定为黑话。"},
"meaning": {"type": "string", "description": "当前推断出的黑话含义。"},
"is_complete": {"type": "boolean", "description": "当前是否已完成全部推断流程。"},
"last_inference_count": {"type": "integer", "description": "本次推断完成后应写回的 last_inference_count。"},
},
required=[
"session_id",
"session_name",
"content",
"count",
"raw_content_list",
"inference_with_context",
"inference_with_content_only",
"comparison_result",
"is_jargon",
"meaning",
"is_complete",
"last_inference_count",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
]
)
class JargonMiner:
def __init__(self, session_id: str, session_name: str) -> None:
"""初始化黑话学习器。
Args:
session_id: 当前会话 ID。
session_name: 当前会话展示名称。
"""
self.session_id = session_id
self.session_name = session_name
@@ -46,13 +180,92 @@ class JargonMiner:
# 黑话提取锁,防止并发执行
self._extraction_lock = asyncio.Lock()
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
"""将任意值安全转换为整数。
Args:
value: 待转换的值。
default: 转换失败时使用的默认值。
Returns:
int: 转换后的整数结果。
"""
try:
return int(value)
except (TypeError, ValueError):
return default
@staticmethod
def _serialize_jargon_entries(entries: List[JargonEntry]) -> List[Dict[str, object]]:
"""将黑话条目列表序列化为 Hook 可传输结构。
Args:
entries: 原始黑话条目列表。
Returns:
List[Dict[str, object]]: 序列化后的条目列表。
"""
return [
{
"content": str(entry["content"]).strip(),
"raw_content": sorted(str(item).strip() for item in entry["raw_content"] if str(item).strip()),
}
for entry in entries
if str(entry["content"]).strip()
]
@staticmethod
def _deserialize_jargon_entries(raw_entries: Any) -> List[JargonEntry]:
"""从 Hook 载荷恢复黑话条目列表。
Args:
raw_entries: Hook 返回的条目数据。
Returns:
List[JargonEntry]: 恢复后的黑话条目列表。
"""
if not isinstance(raw_entries, list):
return []
normalized_entries: List[JargonEntry] = []
for raw_entry in raw_entries:
if not isinstance(raw_entry, dict):
continue
content = str(raw_entry.get("content") or "").strip()
if not content:
continue
raw_content_values = raw_entry.get("raw_content")
raw_content: Set[str] = set()
if isinstance(raw_content_values, list):
raw_content = {str(item).strip() for item in raw_content_values if str(item).strip()}
normalized_entries.append({"content": content, "raw_content": raw_content})
return normalized_entries
def get_cached_jargons(self) -> List[str]:
"""获取缓存中的所有黑话列表"""
return list(self.cache.keys())
async def infer_meaning(self, jargon_obj: MaiJargon) -> None:
"""
对jargon进行含义推断
"""对黑话条目执行含义推断。
Args:
jargon_obj: 待推断的黑话数据对象。
"""
content = jargon_obj.content
# 解析raw_content列表
@@ -175,15 +388,45 @@ class JargonMiner:
is_similar = comparison_result.get("is_similar", False)
is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话
finalized_meaning = inference1.get("meaning", "") if is_jargon else ""
is_complete = (jargon_obj.count or 0) >= 100
last_inference_count = jargon_obj.count or 0
finalize_result = await self._get_runtime_manager().invoke_hook(
"jargon.inference.before_finalize",
session_id=self.session_id,
session_name=self.session_name,
content=content,
count=current_count,
raw_content_list=list(raw_content_list),
inference_with_context=dict(inference1),
inference_with_content_only=dict(inference2),
comparison_result=dict(comparison_result),
is_jargon=is_jargon,
meaning=finalized_meaning,
is_complete=is_complete,
last_inference_count=last_inference_count,
)
if finalize_result.aborted:
logger.info(f"jargon {content} 的推断结果被 Hook 中止写回")
return
finalize_kwargs = finalize_result.kwargs
is_jargon = bool(finalize_kwargs.get("is_jargon", is_jargon))
finalized_meaning = str(finalize_kwargs.get("meaning", finalized_meaning) or "").strip() if is_jargon else ""
is_complete = bool(finalize_kwargs.get("is_complete", is_complete))
last_inference_count = self._coerce_int(
finalize_kwargs.get("last_inference_count"),
last_inference_count,
)
# 更新数据库记录
jargon_obj.is_jargon = is_jargon
jargon_obj.meaning = inference1.get("meaning", "") if is_jargon else ""
jargon_obj.meaning = finalized_meaning
# 更新最后一次判定的count值避免重启后重复判定
jargon_obj.last_inference_count = jargon_obj.count or 0
jargon_obj.last_inference_count = last_inference_count
# 如果count>=100标记为完成不再进行推断
if (jargon_obj.count or 0) >= 100:
jargon_obj.is_complete = True
jargon_obj.is_complete = is_complete
try:
self._modify_jargon_entry(jargon_obj)
@@ -232,6 +475,22 @@ class JargonMiner:
merged_entries[content] = {"content": content, "raw_content": set(raw_list)}
uniq_entries: List[JargonEntry] = list(merged_entries.values())
before_persist_result = await self._get_runtime_manager().invoke_hook(
"jargon.extract.before_persist",
session_id=self.session_id,
session_name=self.session_name,
entries=self._serialize_jargon_entries(uniq_entries),
)
if before_persist_result.aborted:
logger.info(f"[{self.session_name}] 黑话提取结果被 Hook 中止,不写入数据库")
return
raw_hook_entries = before_persist_result.kwargs.get("entries")
if raw_hook_entries is not None:
uniq_entries = self._deserialize_jargon_entries(raw_hook_entries)
if not uniq_entries:
logger.info(f"[{self.session_name}] Hook 过滤后没有可写入的黑话条目")
return
saved = 0
updated = 0

View File

@@ -54,6 +54,84 @@ class MaisakaReasoningEngine:
self._runtime = runtime
self._last_reasoning_content: str = ""
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _normalize_words(raw_words: Any) -> list[str]:
"""清洗黑话查询词条列表。
Args:
raw_words: 原始词条列表。
Returns:
list[str]: 去重去空白后的词条列表。
"""
if not isinstance(raw_words, list):
return []
normalized_words: list[str] = []
seen_words: set[str] = set()
for item in raw_words:
if not isinstance(item, str):
continue
word = item.strip()
if not word or word in seen_words:
continue
seen_words.add(word)
normalized_words.append(word)
return normalized_words
@staticmethod
def _normalize_jargon_query_results(raw_results: Any) -> list[dict[str, object]]:
"""规范化黑话查询结果列表。
Args:
raw_results: Hook 返回的结果列表。
Returns:
list[dict[str, object]]: 清洗后的结果列表。
"""
if not isinstance(raw_results, list):
return []
normalized_results: list[dict[str, object]] = []
for raw_item in raw_results:
if not isinstance(raw_item, dict):
continue
word = str(raw_item.get("word") or "").strip()
matches = raw_item.get("matches")
normalized_matches: list[dict[str, str]] = []
if isinstance(matches, list):
for match in matches:
if not isinstance(match, dict):
continue
content = str(match.get("content") or "").strip()
meaning = str(match.get("meaning") or "").strip()
if not content or not meaning:
continue
normalized_matches.append({"content": content, "meaning": meaning})
normalized_results.append(
{
"word": word,
"found": bool(raw_item.get("found", bool(normalized_matches))),
"matches": normalized_matches,
}
)
return normalized_results
def build_builtin_tool_handlers(self) -> dict[str, "BuiltinToolHandler"]:
"""构造 Maisaka 内置工具处理器映射。
@@ -1012,6 +1090,35 @@ class MaisakaReasoningEngine:
"查询黑话工具至少需要一个非空词条。",
)
limit = 5
case_sensitive = False
enable_fuzzy_fallback = True
before_search_result = await self._get_runtime_manager().invoke_hook(
"jargon.query.before_search",
words=list(words),
session_id=self._runtime.session_id,
limit=limit,
case_sensitive=case_sensitive,
enable_fuzzy_fallback=enable_fuzzy_fallback,
abort_message="黑话查询已被 Hook 中止。",
)
if before_search_result.aborted:
abort_message = str(before_search_result.kwargs.get("abort_message") or "黑话查询已被 Hook 中止。").strip()
return self._build_tool_failure_result(tool_call.func_name, abort_message or "黑话查询已被 Hook 中止。")
before_search_kwargs = before_search_result.kwargs
if before_search_kwargs.get("words") is not None:
words = self._normalize_words(before_search_kwargs.get("words"))
if not words:
return self._build_tool_failure_result(tool_call.func_name, "Hook 过滤后没有可查询的黑话词条。")
try:
limit = int(before_search_kwargs.get("limit", limit))
except (TypeError, ValueError):
limit = 5
limit = max(limit, 1)
case_sensitive = bool(before_search_kwargs.get("case_sensitive", case_sensitive))
enable_fuzzy_fallback = bool(before_search_kwargs.get("enable_fuzzy_fallback", enable_fuzzy_fallback))
logger.info(f"{self._runtime.log_prefix} 已触发黑话查询: 词条={words!r}")
results: list[dict[str, object]] = []
@@ -1019,17 +1126,19 @@ class MaisakaReasoningEngine:
exact_matches = search_jargon(
keyword=word,
chat_id=self._runtime.session_id,
limit=5,
case_sensitive=False,
limit=limit,
case_sensitive=case_sensitive,
fuzzy=False,
)
matched_entries = exact_matches or search_jargon(
keyword=word,
chat_id=self._runtime.session_id,
limit=5,
case_sensitive=False,
fuzzy=True,
)
matched_entries = exact_matches
if not matched_entries and enable_fuzzy_fallback:
matched_entries = search_jargon(
keyword=word,
chat_id=self._runtime.session_id,
limit=limit,
case_sensitive=case_sensitive,
fuzzy=True,
)
results.append(
{
@@ -1039,6 +1148,27 @@ class MaisakaReasoningEngine:
}
)
after_search_result = await self._get_runtime_manager().invoke_hook(
"jargon.query.after_search",
words=list(words),
session_id=self._runtime.session_id,
limit=limit,
case_sensitive=case_sensitive,
enable_fuzzy_fallback=enable_fuzzy_fallback,
results=list(results),
abort_message="黑话查询结果已被 Hook 中止。",
)
if after_search_result.aborted:
abort_message = str(after_search_result.kwargs.get("abort_message") or "黑话查询结果已被 Hook 中止。").strip()
return self._build_tool_failure_result(
tool_call.func_name,
abort_message or "黑话查询结果已被 Hook 中止。",
)
raw_results = after_search_result.kwargs.get("results")
if raw_results is not None:
results = self._normalize_jargon_query_results(raw_results)
logger.info(f"{self._runtime.log_prefix} 黑话查询完成: 结果={results!r}")
return self._build_tool_success_result(
tool_call.func_name,

View File

@@ -20,11 +20,17 @@ def _get_builtin_hook_spec_registrars() -> List[HookSpecRegistrar]:
"""
from src.chat.message_receive.bot import register_chat_hook_specs
from src.chat.emoji_system.emoji_manager import register_emoji_hook_specs
from src.learners.expression_learner import register_expression_hook_specs
from src.learners.jargon_miner import register_jargon_hook_specs
from src.maisaka.chat_loop_service import register_maisaka_hook_specs
from src.services.send_service import register_send_service_hook_specs
return [
register_chat_hook_specs,
register_emoji_hook_specs,
register_jargon_hook_specs,
register_expression_hook_specs,
register_send_service_hook_specs,
register_maisaka_hook_specs,
]

View File

@@ -9,6 +9,7 @@ from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
from src.webui.routers.websocket.manager import websocket_manager
logger = get_logger("webui.logs_ws")
router = APIRouter()
@@ -148,24 +149,9 @@ async def broadcast_log(log_data: Dict):
Args:
log_data: 日志数据字典
"""
if not active_connections:
return
# 格式化为 JSON
message = json.dumps(log_data, ensure_ascii=False)
# 记录需要断开的连接
disconnected = set()
# 广播到所有客户端
for connection in active_connections:
try:
await connection.send_text(message)
except Exception:
# 发送失败,标记为断开
disconnected.add(connection)
# 清理断开的连接
if disconnected:
active_connections.difference_update(disconnected)
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")
await websocket_manager.broadcast_to_topic(
domain="logs",
topic="main",
event="entry",
data={"entry": log_data},
)

View File

@@ -18,12 +18,10 @@ def get_all_routers() -> List[APIRouter]:
from src.webui.api.replier import router as replier_router
from src.webui.routers.chat import router as chat_router
from src.webui.routers.knowledge import router as knowledge_router
from src.webui.routers.websocket.logs import router as logs_router
from src.webui.routes import router as main_router
return [
main_router,
logs_router,
knowledge_router,
chat_router,
planner_router,

View File

@@ -1,7 +1,7 @@
from typing import Tuple
from .routes import router
from .support import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
from .service import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
def get_webui_chat_broadcaster() -> Tuple[ChatConnectionManager, str]:

View File

@@ -1,9 +1,8 @@
"""本地聊天室路由 - WebUI 与麦麦直接对话。"""
import uuid
from typing import Dict, Optional
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Depends, Query
from sqlalchemy import case, func
from sqlmodel import col, select
@@ -13,16 +12,11 @@ from src.common.logger import get_logger
from src.config.config import global_config
from src.webui.dependencies import require_auth
from .support import (
from .service import (
WEBUI_CHAT_GROUP_ID,
WEBUI_CHAT_PLATFORM,
authenticate_chat_websocket,
chat_history,
chat_manager,
dispatch_chat_event,
normalize_webui_user_id,
resolve_initial_virtual_identity,
send_initial_chat_state,
)
logger = get_logger("webui.chat")
@@ -113,55 +107,6 @@ async def clear_chat_history(
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
@router.websocket("/ws")
async def websocket_chat(
websocket: WebSocket,
user_id: Optional[str] = Query(default=None),
user_name: Optional[str] = Query(default="WebUI用户"),
platform: Optional[str] = Query(default=None),
person_id: Optional[str] = Query(default=None),
group_name: Optional[str] = Query(default=None),
group_id: Optional[str] = Query(default=None),
token: Optional[str] = Query(default=None),
) -> None:
"""WebSocket 聊天端点。"""
if not await authenticate_chat_websocket(websocket, token):
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
session_id = str(uuid.uuid4())
normalized_user_id = normalize_webui_user_id(user_id)
current_user_name = user_name or "WebUI用户"
current_virtual_config = resolve_initial_virtual_identity(platform, person_id, group_name, group_id)
await chat_manager.connect(websocket, session_id, normalized_user_id)
try:
await send_initial_chat_state(
session_id=session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
)
while True:
data = await websocket.receive_json()
current_user_name, current_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data=data,
current_user_name=current_user_name,
normalized_user_id=normalized_user_id,
current_virtual_config=current_virtual_config,
)
except WebSocketDisconnect:
logger.info(f"WebSocket 断开: session={session_id}, user={normalized_user_id}")
except Exception as e:
logger.error(f"WebSocket 错误: {e}")
finally:
chat_manager.disconnect(session_id, normalized_user_id)
@router.get("/info")
async def get_chat_info() -> Dict[str, object]:
"""获取聊天室信息。"""

View File

@@ -1,10 +1,10 @@
"""WebUI 聊天路由支持逻辑"""
"""WebUI 聊天运行时服务"""
from dataclasses import dataclass
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, cast
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, cast
from fastapi import WebSocket
from pydantic import BaseModel
from sqlmodel import col, delete, select
@@ -17,8 +17,6 @@ from src.common.logger import get_logger
from src.common.message_repository import find_messages
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.chat")
@@ -27,6 +25,8 @@ WEBUI_CHAT_PLATFORM = "webui"
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
WEBUI_USER_ID_PREFIX = "webui_user_"
AsyncMessageSender = Callable[[Dict[str, Any]], Awaitable[None]]
class VirtualIdentityConfig(BaseModel):
"""虚拟身份配置。"""
@@ -52,13 +52,42 @@ class ChatHistoryMessage(BaseModel):
is_bot: bool = False
@dataclass
class ChatSessionConnection:
"""逻辑聊天会话连接信息。"""
session_id: str
connection_id: str
client_session_id: str
user_id: str
user_name: str
active_group_id: str
virtual_config: Optional[VirtualIdentityConfig]
sender: AsyncMessageSender
class ChatHistoryManager:
"""聊天历史管理器。"""
def __init__(self, max_messages: int = 200) -> None:
"""初始化聊天历史管理器。
Args:
max_messages: 内存中允许处理的最大消息数
"""
self.max_messages = max_messages
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> Dict[str, Any]:
"""将内部消息对象转换为前端可消费的字典。
Args:
msg: 内部统一消息对象
group_id: 当前会话所属的群组标识
Returns:
Dict[str, Any]: 面向 WebUI 的消息字典
"""
del group_id
user_info = msg.message_info.user_info
user_id = user_info.user_id or ""
is_bot = is_bot_self(msg.platform, user_id)
@@ -74,10 +103,27 @@ class ChatHistoryManager:
}
def _resolve_session_id(self, group_id: Optional[str]) -> str:
"""根据群组标识解析聊天会话 ID。
Args:
group_id: 群组标识
Returns:
str: 内部聊天会话 ID
"""
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""获取指定会话的历史消息。
Args:
limit: 最大返回条数
group_id: 群组标识
Returns:
List[Dict[str, Any]]: 历史消息列表
"""
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
session_id = self._resolve_session_id(target_group_id)
try:
@@ -90,11 +136,19 @@ class ChatHistoryManager:
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
return result
except Exception as e:
logger.error(f"从数据库加载聊天记录失败: {e}")
except Exception as exc:
logger.error(f"从数据库加载聊天记录失败: {exc}")
return []
def clear_history(self, group_id: Optional[str] = None) -> int:
"""清空指定会话的历史消息。
Args:
group_id: 群组标识
Returns:
int: 被删除的消息数量
"""
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
session_id = self._resolve_session_id(target_group_id)
try:
@@ -104,66 +158,245 @@ class ChatHistoryManager:
deleted = result.rowcount or 0
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
return deleted
except Exception as e:
logger.error(f"清空聊天记录失败: {e}")
except Exception as exc:
logger.error(f"清空聊天记录失败: {exc}")
return 0
class ChatConnectionManager:
"""聊天连接管理器。"""
"""统一聊天逻辑会话管理器。"""
def __init__(self) -> None:
self.active_connections: Dict[str, WebSocket] = {}
self.user_sessions: Dict[str, str] = {}
"""初始化聊天逻辑会话管理器。"""
self.active_connections: Dict[str, ChatSessionConnection] = {}
self.client_sessions: Dict[Tuple[str, str], str] = {}
self.connection_sessions: Dict[str, Set[str]] = {}
self.group_sessions: Dict[str, Set[str]] = {}
self.user_sessions: Dict[str, Set[str]] = {}
async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None:
await websocket.accept()
self.active_connections[session_id] = websocket
self.user_sessions[user_id] = session_id
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
def _bind_group(self, session_id: str, group_id: str) -> None:
"""为会话绑定群组索引。
def disconnect(self, session_id: str, user_id: str) -> None:
if session_id in self.active_connections:
del self.active_connections[session_id]
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
del self.user_sessions[user_id]
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
Args:
session_id: 内部会话 ID
group_id: 群组标识
"""
group_session_ids = self.group_sessions.setdefault(group_id, set())
group_session_ids.add(session_id)
def _unbind_group(self, session_id: str, group_id: str) -> None:
"""移除会话与群组的索引关系。
Args:
session_id: 内部会话 ID
group_id: 群组标识
"""
group_session_ids = self.group_sessions.get(group_id)
if group_session_ids is None:
return
group_session_ids.discard(session_id)
if not group_session_ids:
del self.group_sessions[group_id]
async def connect(
self,
session_id: str,
connection_id: str,
client_session_id: str,
user_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
sender: AsyncMessageSender,
) -> None:
"""注册一个新的逻辑聊天会话。
Args:
session_id: 内部逻辑会话 ID
connection_id: 物理 WebSocket 连接 ID
client_session_id: 前端标签页使用的会话 ID
user_id: 规范化后的用户 ID
user_name: 当前展示昵称
virtual_config: 当前虚拟身份配置
sender: 发送消息到前端的异步回调
"""
existing_session_id = self.client_sessions.get((connection_id, client_session_id))
if existing_session_id is not None:
self.disconnect(existing_session_id)
active_group_id = get_current_group_id(virtual_config)
session_connection = ChatSessionConnection(
session_id=session_id,
connection_id=connection_id,
client_session_id=client_session_id,
user_id=user_id,
user_name=user_name,
active_group_id=active_group_id,
virtual_config=virtual_config,
sender=sender,
)
self.active_connections[session_id] = session_connection
self.client_sessions[(connection_id, client_session_id)] = session_id
self.connection_sessions.setdefault(connection_id, set()).add(session_id)
self.user_sessions.setdefault(user_id, set()).add(session_id)
self._bind_group(session_id, active_group_id)
logger.info(
"WebUI 聊天会话已连接: session=%s, connection=%s, client_session=%s, user=%s, group=%s",
session_id,
connection_id,
client_session_id,
user_id,
active_group_id,
)
def disconnect(self, session_id: str) -> None:
"""断开一个逻辑聊天会话。
Args:
session_id: 内部逻辑会话 ID
"""
session_connection = self.active_connections.pop(session_id, None)
if session_connection is None:
return
self.client_sessions.pop((session_connection.connection_id, session_connection.client_session_id), None)
self._unbind_group(session_id, session_connection.active_group_id)
connection_session_ids = self.connection_sessions.get(session_connection.connection_id)
if connection_session_ids is not None:
connection_session_ids.discard(session_id)
if not connection_session_ids:
del self.connection_sessions[session_connection.connection_id]
user_session_ids = self.user_sessions.get(session_connection.user_id)
if user_session_ids is not None:
user_session_ids.discard(session_id)
if not user_session_ids:
del self.user_sessions[session_connection.user_id]
logger.info("WebUI 聊天会话已断开: session=%s", session_id)
def disconnect_connection(self, connection_id: str) -> None:
"""断开物理连接下的全部逻辑聊天会话。
Args:
connection_id: 物理 WebSocket 连接 ID
"""
session_ids = list(self.connection_sessions.get(connection_id, set()))
for session_id in session_ids:
self.disconnect(session_id)
def get_session(self, session_id: str) -> Optional[ChatSessionConnection]:
"""获取逻辑聊天会话信息。
Args:
session_id: 内部逻辑会话 ID
Returns:
Optional[ChatSessionConnection]: 会话存在时返回对应信息
"""
return self.active_connections.get(session_id)
def get_session_id(self, connection_id: str, client_session_id: str) -> Optional[str]:
"""根据连接 ID 和前端会话 ID 查询内部会话 ID。
Args:
connection_id: 物理 WebSocket 连接 ID
client_session_id: 前端标签页使用的会话 ID
Returns:
Optional[str]: 找到时返回内部会话 ID
"""
return self.client_sessions.get((connection_id, client_session_id))
def update_session_context(
self,
session_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> None:
"""更新会话上下文信息。
Args:
session_id: 内部逻辑会话 ID
user_name: 最新昵称
virtual_config: 最新虚拟身份配置
"""
session_connection = self.active_connections.get(session_id)
if session_connection is None:
return
next_group_id = get_current_group_id(virtual_config)
if next_group_id != session_connection.active_group_id:
self._unbind_group(session_id, session_connection.active_group_id)
self._bind_group(session_id, next_group_id)
session_connection.active_group_id = next_group_id
session_connection.user_name = user_name
session_connection.virtual_config = virtual_config
async def send_message(self, session_id: str, message: Dict[str, Any]) -> None:
if session_id in self.active_connections:
try:
await self.active_connections[session_id].send_json(message)
except Exception as e:
logger.error(f"发送消息失败: {e}")
"""向指定逻辑会话发送消息。
Args:
session_id: 内部逻辑会话 ID
message: 发送消息内容
"""
session_connection = self.active_connections.get(session_id)
if session_connection is None:
return
try:
await session_connection.sender(message)
except Exception as exc:
logger.error("发送聊天消息失败: session=%s, error=%s", session_id, exc)
async def broadcast(self, message: Dict[str, Any]) -> None:
"""向全部逻辑聊天会话广播消息。
Args:
message: 待广播的消息内容
"""
for session_id in list(self.active_connections.keys()):
await self.send_message(session_id, message)
async def broadcast_to_group(self, group_id: str, message: Dict[str, Any]) -> None:
"""向指定群组下的全部逻辑会话广播消息。
Args:
group_id: 群组标识
message: 待广播的消息内容
"""
for session_id in list(self.group_sessions.get(group_id, set())):
await self.send_message(session_id, message)
chat_history = ChatHistoryManager()
chat_manager = ChatConnectionManager()
def is_virtual_mode_enabled(virtual_config: Optional[VirtualIdentityConfig]) -> bool:
"""判断当前是否启用了虚拟身份模式。
Args:
virtual_config: 虚拟身份配置
Returns:
bool: 已启用时返回 ``True``
"""
return bool(virtual_config and virtual_config.enabled)
async def authenticate_chat_websocket(websocket: WebSocket, token: Optional[str]) -> bool:
if token and verify_ws_token(token):
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
return True
if cookie_token := websocket.cookies.get("maibot_session"):
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
return True
return False
def normalize_webui_user_id(user_id: Optional[str]) -> str:
"""标准化 WebUI 用户 ID。
Args:
user_id: 原始用户 ID
Returns:
str: 带统一前缀的用户 ID
"""
if not user_id:
return f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
if user_id.startswith(WEBUI_USER_ID_PREFIX):
@@ -172,12 +405,30 @@ def normalize_webui_user_id(user_id: Optional[str]) -> str:
def get_person_by_person_id(person_id: str) -> Optional[PersonInfo]:
"""根据人物 ID 查询人物信息。
Args:
person_id: 人物 ID
Returns:
Optional[PersonInfo]: 查到时返回人物信息
"""
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
return session.exec(statement).first()
def build_virtual_identity_config(person: PersonInfo, group_id: str, group_name: str) -> VirtualIdentityConfig:
"""根据人物信息构建虚拟身份配置。
Args:
person: 人物信息对象
group_id: 逻辑群组 ID
group_name: 逻辑群组名称
Returns:
VirtualIdentityConfig: 虚拟身份配置对象
"""
return VirtualIdentityConfig(
enabled=True,
platform=person.platform,
@@ -195,6 +446,17 @@ def resolve_initial_virtual_identity(
group_name: Optional[str],
group_id: Optional[str],
) -> Optional[VirtualIdentityConfig]:
"""根据初始参数解析虚拟身份配置。
Args:
platform: 平台名称
person_id: 人物 ID
group_name: 群组名称
group_id: 群组 ID
Returns:
Optional[VirtualIdentityConfig]: 解析成功时返回虚拟身份配置
"""
if not (platform and person_id):
return None
@@ -210,11 +472,14 @@ def resolve_initial_virtual_identity(
group_name=group_name or "WebUI虚拟群聊",
)
logger.info(
f"虚拟身份模式已通过 URL 参数激活: {virtual_config.user_nickname} @ {virtual_config.platform}, group_id={virtual_group_id}"
"虚拟身份模式已通过参数激活: %s @ %s, group_id=%s",
virtual_config.user_nickname,
virtual_config.platform,
virtual_group_id,
)
return virtual_config
except Exception as e:
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
except Exception as exc:
logger.warning(f"通过参数配置虚拟身份失败: {exc}")
return None
@@ -224,6 +489,17 @@ def build_session_info_message(
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> Dict[str, Any]:
"""构建会话信息消息。
Args:
session_id: 内部逻辑会话 ID
user_id: 规范化后的用户 ID
user_name: 当前昵称
virtual_config: 虚拟身份配置
Returns:
Dict[str, Any]: 会话信息消息
"""
session_info_data: Dict[str, Any] = {
"type": "session_info",
"session_id": session_id,
@@ -247,13 +523,41 @@ def build_session_info_message(
def get_active_history_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> Optional[str]:
"""获取当前虚拟身份对应的历史群组 ID。
Args:
virtual_config: 虚拟身份配置
Returns:
Optional[str]: 虚拟身份启用时返回对应群组 ID
"""
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return virtual_config.group_id
return None
def get_current_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> str:
"""获取当前会话的有效群组 ID。
Args:
virtual_config: 虚拟身份配置
Returns:
str: 当前会话应使用的群组 ID
"""
return get_active_history_group_id(virtual_config) or WEBUI_CHAT_GROUP_ID
def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> str:
"""构建欢迎消息。
Args:
virtual_config: 虚拟身份配置
Returns:
str: 欢迎消息文本
"""
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return (
@@ -264,6 +568,12 @@ def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> st
async def send_chat_error(session_id: str, content: str) -> None:
"""向指定会话发送错误消息。
Args:
session_id: 内部逻辑会话 ID
content: 错误消息内容
"""
await chat_manager.send_message(
session_id,
{
@@ -279,7 +589,17 @@ async def send_initial_chat_state(
user_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
include_welcome: bool = True,
) -> None:
"""向新会话发送初始化状态。
Args:
session_id: 内部逻辑会话 ID
user_id: 规范化后的用户 ID
user_name: 当前昵称
virtual_config: 虚拟身份配置
include_welcome: 是否发送欢迎消息
"""
await chat_manager.send_message(
session_id,
build_session_info_message(
@@ -290,30 +610,43 @@ async def send_initial_chat_state(
),
)
if history := chat_history.get_history(50, get_active_history_group_id(virtual_config)):
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": history,
},
)
history_group_id = get_active_history_group_id(virtual_config)
history = chat_history.get_history(50, history_group_id)
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": build_welcome_message(virtual_config),
"timestamp": time.time(),
"type": "history",
"messages": history,
"group_id": get_current_group_id(virtual_config),
},
)
if include_welcome:
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": build_welcome_message(virtual_config),
"timestamp": time.time(),
},
)
def resolve_sender_identity(
current_user_name: str,
normalized_user_id: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> Tuple[str, str]:
"""解析当前发送者身份。
Args:
current_user_name: 当前昵称
normalized_user_id: 规范化后的用户 ID
virtual_config: 虚拟身份配置
Returns:
Tuple[str, str]: ``(发送者昵称, 发送者用户 ID)``
"""
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id
@@ -328,6 +661,19 @@ def create_message_data(
is_at_bot: bool = True,
virtual_config: Optional[VirtualIdentityConfig] = None,
) -> Dict[str, Any]:
"""构建发送给聊天核心的消息数据。
Args:
content: 文本内容
user_id: 用户 ID
user_name: 用户昵称
message_id: 消息 ID
is_at_bot: 是否默认艾特机器人
virtual_config: 虚拟身份配置
Returns:
Dict[str, Any]: 聊天核心可处理的消息数据
"""
if message_id is None:
message_id = str(uuid.uuid4())
@@ -389,6 +735,18 @@ async def handle_chat_message(
normalized_user_id: str,
current_virtual_config: Optional[VirtualIdentityConfig],
) -> str:
"""处理用户发送的聊天消息。
Args:
session_id: 内部逻辑会话 ID
data: 前端提交的消息数据
current_user_name: 当前昵称
normalized_user_id: 规范化后的用户 ID
current_virtual_config: 当前虚拟身份配置
Returns:
str: 处理后的最新昵称
"""
content = str(data.get("content", "")).strip()
if not content:
return current_user_name
@@ -401,11 +759,14 @@ async def handle_chat_message(
normalized_user_id=normalized_user_id,
virtual_config=current_virtual_config,
)
target_group_id = get_current_group_id(current_virtual_config)
await chat_manager.broadcast(
await chat_manager.broadcast_to_group(
target_group_id,
{
"type": "user_message",
"content": content,
"group_id": target_group_id,
"message_id": message_id,
"timestamp": timestamp,
"sender": {
@@ -414,7 +775,7 @@ async def handle_chat_message(
"is_bot": False,
},
"virtual_mode": is_virtual_mode_enabled(current_virtual_config),
}
},
)
message_data = create_message_data(
@@ -427,22 +788,37 @@ async def handle_chat_message(
)
try:
await chat_manager.broadcast({"type": "typing", "is_typing": True})
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": True})
await chat_bot.message_process(message_data)
except Exception as e:
logger.error(f"处理消息时出错: {e}")
await send_chat_error(session_id, f"处理消息时出错: {str(e)}")
except Exception as exc:
logger.error(f"处理消息时出错: {exc}")
await send_chat_error(session_id, f"处理消息时出错: {str(exc)}")
finally:
await chat_manager.broadcast({"type": "typing", "is_typing": False})
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": False})
return next_user_name
async def handle_chat_ping(session_id: str) -> None:
"""处理聊天心跳。
Args:
session_id: 内部逻辑会话 ID
"""
await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()})
async def handle_nickname_update(session_id: str, data: Dict[str, Any], current_user_name: str) -> str:
"""处理昵称更新请求。
Args:
session_id: 内部逻辑会话 ID
data: 前端提交的数据
current_user_name: 当前昵称
Returns:
str: 更新后的昵称
"""
new_name = str(data.get("user_name", "")).strip()
if not new_name:
return current_user_name
@@ -463,6 +839,16 @@ async def enable_virtual_identity(
session_prefix: str,
virtual_data: Dict[str, Any],
) -> Optional[VirtualIdentityConfig]:
"""启用虚拟身份模式。
Args:
session_id: 内部逻辑会话 ID
session_prefix: 会话前缀用于生成默认群组 ID
virtual_data: 前端提交的虚拟身份配置
Returns:
Optional[VirtualIdentityConfig]: 启用成功时返回新的虚拟身份配置
"""
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id")
return None
@@ -470,16 +856,18 @@ async def enable_virtual_identity(
person_id_value = str(virtual_data.get("person_id"))
try:
person = get_person_by_person_id(person_id_value)
if not person:
if person is None:
await send_chat_error(session_id, f"找不到用户: {person_id_value}")
return None
custom_group_id = virtual_data.get("group_id")
current_group_id = (
f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
if custom_group_id
else f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
)
custom_group_id = str(virtual_data.get("group_id") or "").strip()
if custom_group_id:
current_group_id = custom_group_id
if not current_group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{current_group_id}"
else:
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
current_virtual_config = build_virtual_identity_config(
person=person,
group_id=current_group_id,
@@ -521,13 +909,18 @@ async def enable_virtual_identity(
},
)
return current_virtual_config
except Exception as e:
logger.error(f"设置虚拟身份失败: {e}")
await send_chat_error(session_id, f"设置虚拟身份失败: {str(e)}")
except Exception as exc:
logger.error(f"设置虚拟身份失败: {exc}")
await send_chat_error(session_id, f"设置虚拟身份失败: {str(exc)}")
return None
async def disable_virtual_identity(session_id: str) -> None:
"""关闭虚拟身份模式。
Args:
session_id: 内部逻辑会话 ID
"""
await chat_manager.send_message(
session_id,
{
@@ -560,7 +953,18 @@ async def handle_virtual_identity_update(
data: Dict[str, Any],
current_virtual_config: Optional[VirtualIdentityConfig],
) -> Optional[VirtualIdentityConfig]:
virtual_data = cast(dict[str, Any], data.get("config", {}))
"""处理虚拟身份切换请求。
Args:
session_id: 内部逻辑会话 ID
session_id_prefix: 会话前缀
data: 前端提交的数据
current_virtual_config: 当前虚拟身份配置
Returns:
Optional[VirtualIdentityConfig]: 更新后的虚拟身份配置
"""
virtual_data = cast(Dict[str, Any], data.get("config", {}))
if virtual_data.get("enabled"):
next_config = await enable_virtual_identity(session_id, session_id_prefix, virtual_data)
return next_config if next_config is not None else current_virtual_config
@@ -577,6 +981,19 @@ async def dispatch_chat_event(
normalized_user_id: str,
current_virtual_config: Optional[VirtualIdentityConfig],
) -> Tuple[str, Optional[VirtualIdentityConfig]]:
"""分发聊天事件到对应的处理器。
Args:
session_id: 内部逻辑会话 ID
session_id_prefix: 会话前缀
data: 前端提交的数据
current_user_name: 当前昵称
normalized_user_id: 规范化后的用户 ID
current_virtual_config: 当前虚拟身份配置
Returns:
Tuple[str, Optional[VirtualIdentityConfig]]: ``(最新昵称, 最新虚拟身份配置)``
"""
event_type = data.get("type")
if event_type == "message":
next_user_name = await handle_chat_message(

View File

@@ -1,12 +1,15 @@
"""插件进度实时推送支持。"""
from typing import Any, Dict, Optional, Set
import asyncio
import json
from typing import Any, Dict, Optional, Set
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
from src.webui.routers.websocket.manager import websocket_manager
logger = get_logger("webui.plugin_progress")
@@ -25,25 +28,29 @@ current_progress: Dict[str, Any] = {
}
def get_current_progress() -> Dict[str, Any]:
"""获取当前插件进度快照。
Returns:
Dict[str, Any]: 当前插件进度数据副本。
"""
return current_progress.copy()
async def broadcast_progress(progress_data: Dict[str, Any]) -> None:
"""向统一连接层广播插件进度更新。
Args:
progress_data: 插件进度数据。
"""
global current_progress
current_progress = progress_data.copy()
if not active_connections:
return
message = json.dumps(progress_data, ensure_ascii=False)
disconnected: Set[WebSocket] = set()
for websocket in active_connections:
try:
await websocket.send_text(message)
except Exception as e:
logger.error(f"发送进度更新失败: {e}")
disconnected.add(websocket)
for websocket in disconnected:
active_connections.discard(websocket)
await websocket_manager.broadcast_to_topic(
domain="plugin_progress",
topic="main",
event="update",
data={"progress": progress_data},
)
async def update_progress(
@@ -56,6 +63,18 @@ async def update_progress(
total_plugins: int = 0,
loaded_plugins: int = 0,
) -> None:
"""更新当前插件进度并广播。
Args:
stage: 当前阶段。
progress: 当前进度百分比。
message: 进度说明消息。
operation: 当前操作类型。
error: 可选的错误信息。
plugin_id: 当前处理的插件 ID。
total_plugins: 总插件数量。
loaded_plugins: 已处理插件数量。
"""
progress_data = {
"operation": operation,
"stage": stage,
@@ -74,6 +93,12 @@ async def update_progress(
@router.websocket("/ws/plugin-progress")
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
"""旧版插件进度 WebSocket 入口。
Args:
websocket: FastAPI WebSocket 对象。
token: 可选的一次性握手 Token。
"""
is_authenticated = False
if token and verify_ws_token(token):
@@ -105,17 +130,22 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
data = await websocket.receive_text()
if data == "ping":
await websocket.send_text("pong")
except Exception as e:
logger.error(f"处理客户端消息时出错: {e}")
except Exception as exc:
logger.error(f"处理客户端消息时出错: {exc}")
break
except WebSocketDisconnect:
active_connections.discard(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
except Exception as e:
logger.error(f"❌ WebSocket 错误: {e}")
except Exception as exc:
logger.error(f"❌ WebSocket 错误: {exc}")
active_connections.discard(websocket)
def get_progress_router() -> APIRouter:
"""获取旧版插件进度路由对象。
Returns:
APIRouter: 插件进度路由对象。
"""
return router

View File

@@ -1,7 +1,9 @@
"""WebSocket 路由聚合导出。"""
from .auth import router as ws_auth_router
from .logs import router as logs_router
from .unified import router as unified_ws_router
__all__ = [
"logs_router",
"unified_ws_router",
"ws_auth_router",
]

View File

@@ -1,11 +0,0 @@
"""WebSocket 日志推送路由兼容导出。"""
from src.webui.logs_ws import active_connections, broadcast_log, load_recent_logs, router, websocket_logs
__all__ = [
"active_connections",
"broadcast_log",
"load_recent_logs",
"router",
"websocket_logs",
]

View File

@@ -0,0 +1,297 @@
"""统一 WebSocket 连接管理器。"""
import asyncio
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Set
from fastapi import WebSocket
from src.common.logger import get_logger
logger = get_logger("webui.websocket")
@dataclass
class WebSocketConnection:
"""统一 WebSocket 连接上下文。"""
connection_id: str
websocket: WebSocket
subscriptions: Set[str] = field(default_factory=set)
chat_sessions: Dict[str, str] = field(default_factory=dict)
send_queue: "asyncio.Queue[Optional[Dict[str, Any]]]" = field(default_factory=asyncio.Queue)
sender_task: Optional["asyncio.Task[None]"] = None
class UnifiedWebSocketManager:
"""统一 WebSocket 连接管理器。"""
def __init__(self) -> None:
"""初始化统一 WebSocket 连接管理器。"""
self.connections: Dict[str, WebSocketConnection] = {}
def _build_subscription_key(self, domain: str, topic: str) -> str:
"""构建订阅索引键。
Args:
domain: 业务域名称。
topic: 主题名称。
Returns:
str: 订阅索引键。
"""
return f"{domain}:{topic}"
async def _sender_loop(self, connection: WebSocketConnection) -> None:
"""串行发送指定连接的出站消息。
Args:
connection: 目标连接上下文。
"""
try:
while True:
message = await connection.send_queue.get()
if message is None:
return
await connection.websocket.send_json(message)
except asyncio.CancelledError:
raise
except Exception as exc:
logger.error("统一 WebSocket 发送失败: connection=%s, error=%s", connection.connection_id, exc)
async def connect(self, connection_id: str, websocket: WebSocket) -> WebSocketConnection:
"""注册一个新的物理 WebSocket 连接。
Args:
connection_id: 连接 ID。
websocket: FastAPI WebSocket 对象。
Returns:
WebSocketConnection: 新建的连接上下文。
"""
await websocket.accept()
connection = WebSocketConnection(connection_id=connection_id, websocket=websocket)
connection.sender_task = asyncio.create_task(self._sender_loop(connection))
self.connections[connection_id] = connection
return connection
async def disconnect(self, connection_id: str) -> None:
"""断开并清理指定连接。
Args:
connection_id: 连接 ID。
"""
connection = self.connections.pop(connection_id, None)
if connection is None:
return
await connection.send_queue.put(None)
if connection.sender_task is not None:
try:
await connection.sender_task
except asyncio.CancelledError:
pass
except Exception as exc:
logger.debug("等待发送协程退出时出现异常: connection=%s, error=%s", connection_id, exc)
def get_connection(self, connection_id: str) -> Optional[WebSocketConnection]:
"""获取指定连接上下文。
Args:
connection_id: 连接 ID。
Returns:
Optional[WebSocketConnection]: 找到时返回连接上下文。
"""
return self.connections.get(connection_id)
def register_chat_session(self, connection_id: str, client_session_id: str, session_id: str) -> None:
"""登记连接下的逻辑聊天会话。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
session_id: 内部会话 ID。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.chat_sessions[client_session_id] = session_id
def unregister_chat_session(self, connection_id: str, client_session_id: str) -> None:
"""移除连接下的逻辑聊天会话登记。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.chat_sessions.pop(client_session_id, None)
def get_chat_session_id(self, connection_id: str, client_session_id: str) -> Optional[str]:
"""查询连接下的内部聊天会话 ID。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
Returns:
Optional[str]: 找到时返回内部会话 ID。
"""
connection = self.connections.get(connection_id)
if connection is None:
return None
return connection.chat_sessions.get(client_session_id)
def subscribe(self, connection_id: str, domain: str, topic: str) -> None:
"""登记连接的主题订阅。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
topic: 主题名称。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.subscriptions.add(self._build_subscription_key(domain, topic))
def unsubscribe(self, connection_id: str, domain: str, topic: str) -> None:
"""移除连接的主题订阅。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
topic: 主题名称。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.subscriptions.discard(self._build_subscription_key(domain, topic))
def is_subscribed(self, connection_id: str, domain: str, topic: str) -> bool:
"""判断连接是否订阅了指定主题。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
topic: 主题名称。
Returns:
bool: 已订阅时返回 ``True``。
"""
connection = self.connections.get(connection_id)
if connection is None:
return False
return self._build_subscription_key(domain, topic) in connection.subscriptions
async def enqueue(self, connection_id: str, message: Dict[str, Any]) -> None:
"""向指定连接的发送队列压入消息。
Args:
connection_id: 连接 ID。
message: 待发送的消息。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
await connection.send_queue.put(message)
async def send_response(
self,
connection_id: str,
request_id: Optional[str],
ok: bool,
data: Optional[Dict[str, Any]] = None,
error: Optional[Dict[str, Any]] = None,
) -> None:
"""发送统一响应消息。
Args:
connection_id: 连接 ID。
request_id: 请求 ID。
ok: 请求是否成功。
data: 成功响应数据。
error: 失败响应数据。
"""
response_message: Dict[str, Any] = {
"op": "response",
"id": request_id,
"ok": ok,
}
if data is not None:
response_message["data"] = data
if error is not None:
response_message["error"] = error
await self.enqueue(connection_id, response_message)
async def send_event(
self,
connection_id: str,
domain: str,
event: str,
data: Dict[str, Any],
session: Optional[str] = None,
topic: Optional[str] = None,
) -> None:
"""发送统一事件消息。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
event: 事件名称。
data: 事件数据。
session: 可选的逻辑会话 ID。
topic: 可选的主题名称。
"""
event_message: Dict[str, Any] = {
"op": "event",
"domain": domain,
"event": event,
"data": data,
}
if session is not None:
event_message["session"] = session
if topic is not None:
event_message["topic"] = topic
await self.enqueue(connection_id, event_message)
async def send_pong(self, connection_id: str, timestamp: float) -> None:
"""发送心跳响应。
Args:
connection_id: 连接 ID。
timestamp: 当前时间戳。
"""
await self.enqueue(
connection_id,
{
"op": "pong",
"ts": timestamp,
},
)
async def broadcast_to_topic(self, domain: str, topic: str, event: str, data: Dict[str, Any]) -> None:
"""向订阅指定主题的全部连接广播事件。
Args:
domain: 业务域名称。
topic: 主题名称。
event: 事件名称。
data: 事件数据。
"""
subscription_key = self._build_subscription_key(domain, topic)
for connection in list(self.connections.values()):
if subscription_key in connection.subscriptions:
await self.send_event(
connection.connection_id,
domain=domain,
event=event,
data=data,
topic=topic,
)
websocket_manager = UnifiedWebSocketManager()

View File

@@ -0,0 +1,548 @@
"""统一 WebSocket 路由。"""
from typing import Any, Dict, Optional, Set, cast
import asyncio
import time
import uuid
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.logs_ws import load_recent_logs
from src.webui.routers.chat.service import (
chat_manager,
dispatch_chat_event,
normalize_webui_user_id,
resolve_initial_virtual_identity,
send_initial_chat_state,
)
from src.webui.routers.plugin.progress import get_current_progress
from src.webui.routers.websocket.auth import verify_ws_token
from src.webui.routers.websocket.manager import websocket_manager
logger = get_logger("webui.unified_ws")
router = APIRouter()
_background_tasks: Set["asyncio.Task[None]"] = set()
def _build_error(code: str, message: str) -> Dict[str, Any]:
"""构建统一错误响应体。
Args:
code: 错误码。
message: 错误描述。
Returns:
Dict[str, Any]: 统一错误对象。
"""
return {
"code": code,
"message": message,
}
def _get_request_data(message: Dict[str, Any]) -> Dict[str, Any]:
"""从客户端消息中提取数据字段。
Args:
message: 客户端消息。
Returns:
Dict[str, Any]: 标准化后的数据字典。
"""
data = message.get("data", {})
if isinstance(data, dict):
return cast(Dict[str, Any], data)
return {}
def _track_background_task(task: "asyncio.Task[None]") -> None:
"""登记后台任务并在完成后自动清理。
Args:
task: 后台协程任务。
"""
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
async def authenticate_websocket_connection(websocket: WebSocket, token: Optional[str]) -> bool:
"""校验统一 WebSocket 连接的认证状态。
Args:
websocket: FastAPI WebSocket 对象。
token: 可选的一次性握手 Token。
Returns:
bool: 认证通过时返回 ``True``。
"""
if token and verify_ws_token(token):
logger.debug("统一 WebSocket 使用临时 token 认证成功")
return True
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
logger.debug("统一 WebSocket 使用 Cookie 认证成功")
return True
return False
async def _handle_logs_subscribe(connection_id: str, request_id: Optional[str], data: Dict[str, Any]) -> None:
"""处理日志域订阅请求。
Args:
connection_id: 连接 ID。
request_id: 请求 ID。
data: 订阅参数。
"""
replay_limit = int(data.get("replay", 100) or 100)
replay_limit = max(0, min(replay_limit, 500))
websocket_manager.subscribe(connection_id, domain="logs", topic="main")
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": "logs", "topic": "main"},
)
await websocket_manager.send_event(
connection_id,
domain="logs",
event="snapshot",
topic="main",
data={"entries": load_recent_logs(limit=replay_limit)},
)
async def _handle_plugin_progress_subscribe(connection_id: str, request_id: Optional[str]) -> None:
"""处理插件进度域订阅请求。
Args:
connection_id: 连接 ID。
request_id: 请求 ID。
"""
websocket_manager.subscribe(connection_id, domain="plugin_progress", topic="main")
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": "plugin_progress", "topic": "main"},
)
await websocket_manager.send_event(
connection_id,
domain="plugin_progress",
event="snapshot",
topic="main",
data={"progress": get_current_progress()},
)
async def _handle_subscribe(connection_id: str, message: Dict[str, Any]) -> None:
"""处理主题订阅请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
topic = str(message.get("topic") or "").strip()
data = _get_request_data(message)
if domain == "logs" and topic == "main":
await _handle_logs_subscribe(connection_id, request_id, data)
return
if domain == "plugin_progress" and topic == "main":
await _handle_plugin_progress_subscribe(connection_id, request_id)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_subscription", f"不支持的订阅目标: {domain}:{topic}"),
)
async def _handle_unsubscribe(connection_id: str, message: Dict[str, Any]) -> None:
"""处理主题退订请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
topic = str(message.get("topic") or "").strip()
if not domain or not topic:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("invalid_unsubscribe", "退订请求缺少 domain 或 topic"),
)
return
websocket_manager.unsubscribe(connection_id, domain=domain, topic=topic)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": domain, "topic": topic},
)
async def _open_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
"""打开一个逻辑聊天会话。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
if not client_session_id:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("missing_session", "聊天会话打开请求缺少 session"),
)
return
data = _get_request_data(message)
normalized_user_id = normalize_webui_user_id(cast(Optional[str], data.get("user_id")))
current_user_name = str(data.get("user_name") or "WebUI用户")
current_virtual_config = resolve_initial_virtual_identity(
platform=cast(Optional[str], data.get("platform")),
person_id=cast(Optional[str], data.get("person_id")),
group_name=cast(Optional[str], data.get("group_name")),
group_id=cast(Optional[str], data.get("group_id")),
)
restore = bool(data.get("restore"))
session_id = f"{connection_id}:{client_session_id}"
async def send_chat_event(chat_message: Dict[str, Any]) -> None:
"""将聊天消息封装为统一事件并发送。
Args:
chat_message: 聊天消息体。
"""
event_name = str(chat_message.get("type") or "message")
await websocket_manager.send_event(
connection_id,
domain="chat",
event=event_name,
session=client_session_id,
data=chat_message,
)
await chat_manager.connect(
session_id=session_id,
connection_id=connection_id,
client_session_id=client_session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
sender=send_chat_event,
)
websocket_manager.register_chat_session(connection_id, client_session_id, session_id)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id, "session_id": session_id},
)
await send_initial_chat_state(
session_id=session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
include_welcome=not restore,
)
async def _close_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
"""关闭一个逻辑聊天会话。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
chat_manager.disconnect(session_id)
websocket_manager.unregister_chat_session(connection_id, client_session_id)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id},
)
async def _process_chat_message(connection_id: str, client_session_id: str, data: Dict[str, Any]) -> None:
"""在后台处理聊天消息事件。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
data: 客户端提交的消息数据。
"""
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
return
session_state = chat_manager.get_session(session_id)
if session_state is None:
return
next_user_name, next_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data=data,
current_user_name=session_state.user_name,
normalized_user_id=session_state.user_id,
current_virtual_config=session_state.virtual_config,
)
chat_manager.update_session_context(
session_id=session_id,
user_name=next_user_name,
virtual_config=next_virtual_config,
)
async def _handle_chat_message_send(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天消息发送请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
data = _get_request_data(message)
payload = {
"type": "message",
"content": data.get("content", ""),
"user_name": data.get("user_name", ""),
}
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"accepted": True, "session": client_session_id},
)
_track_background_task(asyncio.create_task(_process_chat_message(connection_id, client_session_id, payload)))
async def _handle_chat_nickname_update(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天昵称更新请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
data = _get_request_data(message)
session_state = chat_manager.get_session(session_id)
if session_state is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
next_user_name, next_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data={
"type": "update_nickname",
"user_name": data.get("user_name", ""),
},
current_user_name=session_state.user_name,
normalized_user_id=session_state.user_id,
current_virtual_config=session_state.virtual_config,
)
chat_manager.update_session_context(
session_id=session_id,
user_name=next_user_name,
virtual_config=next_virtual_config,
)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id, "user_name": next_user_name},
)
async def _handle_chat_call(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天域调用请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
method = str(message.get("method") or "").strip()
if method == "session.open":
await _open_chat_session(connection_id, message)
return
if method == "session.close":
await _close_chat_session(connection_id, message)
return
if method == "message.send":
await _handle_chat_message_send(connection_id, message)
return
if method == "session.update_nickname":
await _handle_chat_nickname_update(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_method", f"不支持的聊天方法: {method}"),
)
async def _handle_call(connection_id: str, message: Dict[str, Any]) -> None:
"""处理统一调用请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
if domain == "chat":
await _handle_chat_call(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_domain", f"不支持的调用域: {domain}"),
)
async def handle_client_message(connection_id: str, message: Dict[str, Any]) -> None:
"""处理统一 WebSocket 客户端消息。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
operation = str(message.get("op") or "").strip()
request_id = cast(Optional[str], message.get("id"))
if operation == "ping":
await websocket_manager.send_pong(connection_id, time.time())
return
if operation == "subscribe":
await _handle_subscribe(connection_id, message)
return
if operation == "unsubscribe":
await _handle_unsubscribe(connection_id, message)
return
if operation == "call":
await _handle_call(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_operation", f"不支持的操作: {operation}"),
)
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
"""统一 WebSocket 入口。
Args:
websocket: FastAPI WebSocket 对象。
token: 可选的一次性握手 Token。
"""
if not await authenticate_websocket_connection(websocket, token):
logger.warning("统一 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
connection_id = uuid.uuid4().hex
await websocket_manager.connect(connection_id=connection_id, websocket=websocket)
logger.info("统一 WebSocket 客户端已连接: connection=%s", connection_id)
await websocket_manager.send_event(
connection_id,
domain="system",
event="ready",
data={"connection_id": connection_id, "timestamp": time.time()},
)
try:
while True:
raw_message = await websocket.receive_json()
if not isinstance(raw_message, dict):
await websocket_manager.send_response(
connection_id,
request_id=None,
ok=False,
error=_build_error("invalid_message", "消息必须是 JSON 对象"),
)
continue
await handle_client_message(connection_id, cast(Dict[str, Any], raw_message))
except WebSocketDisconnect:
logger.info("统一 WebSocket 客户端已断开: connection=%s", connection_id)
except Exception as exc:
logger.error(f"统一 WebSocket 处理失败: {exc}")
finally:
chat_manager.disconnect_connection(connection_id)
await websocket_manager.disconnect(connection_id)

View File

@@ -18,11 +18,11 @@ from src.webui.routers.expression import router as expression_router
from src.webui.routers.jargon import router as jargon_router
from src.webui.routers.model import router as model_router
from src.webui.routers.person import router as person_router
from src.webui.routers.plugin import get_progress_router
from src.webui.routers.plugin import router as plugin_router
from src.webui.routers.statistics import router as statistics_router
from src.webui.routers.system import router as system_router
from src.webui.routers.websocket.auth import router as ws_auth_router
from src.webui.routers.websocket.unified import router as unified_ws_router
logger = get_logger("webui.api")
@@ -43,14 +43,14 @@ router.include_router(jargon_router)
router.include_router(emoji_router)
# 注册插件管理路由
router.include_router(plugin_router)
# 注册插件进度 WebSocket 路由
router.include_router(get_progress_router())
# 注册系统控制路由
router.include_router(system_router)
# 注册模型列表获取路由
router.include_router(model_router)
# 注册 WebSocket 认证路由
router.include_router(ws_auth_router)
# 注册统一 WebSocket 路由
router.include_router(unified_ws_router)
class TokenVerifyRequest(BaseModel):