Merge branch 'r-dev' of https://github.com/Mai-with-u/MaiBot into r-dev

This commit is contained in:
SengokuCola
2026-04-03 00:00:55 +08:00
49 changed files with 6543 additions and 2257 deletions

1
.gitignore vendored
View File

@@ -1,5 +1,6 @@
data/
data1/
mai_knowledge/knowledge.json
mongodb/
NapCat.Framework.Windows.Once/
NapCat.Framework.Windows.OneKey/

View File

@@ -0,0 +1,161 @@
import { unifiedWsClient, type ConnectionStatus } from './unified-ws'
interface ChatSessionOpenPayload {
group_id?: string
group_name?: string
person_id?: string
platform?: string
user_id?: string
user_name?: string
}
type ChatSessionListener = (message: Record<string, unknown>) => void
class ChatWsClient {
private initialized = false
private listeners: Map<string, Set<ChatSessionListener>> = new Map()
private sessionPayloads: Map<string, ChatSessionOpenPayload> = new Map()
private initialize(): void {
if (this.initialized) {
return
}
unifiedWsClient.addEventListener((message) => {
if (message.domain !== 'chat' || !message.session) {
return
}
const sessionListeners = this.listeners.get(message.session)
if (!sessionListeners) {
return
}
sessionListeners.forEach((listener) => {
try {
listener(message.data)
} catch (error) {
console.error('聊天会话监听器执行失败:', error)
}
})
})
unifiedWsClient.onReconnect(() => {
void this.reopenSessions()
})
this.initialized = true
}
private async reopenSessions(): Promise<void> {
const reopenTargets = Array.from(this.sessionPayloads.entries())
for (const [sessionId, payload] of reopenTargets) {
try {
await unifiedWsClient.call({
domain: 'chat',
method: 'session.open',
session: sessionId,
data: {
...payload,
restore: true,
} as Record<string, unknown>,
})
} catch (error) {
console.error(`恢复聊天会话失败 (${sessionId}):`, error)
}
}
}
async openSession(sessionId: string, payload: ChatSessionOpenPayload): Promise<void> {
this.initialize()
this.sessionPayloads.set(sessionId, payload)
await unifiedWsClient.call({
domain: 'chat',
method: 'session.open',
session: sessionId,
data: payload as Record<string, unknown>,
})
}
async closeSession(sessionId: string): Promise<void> {
this.sessionPayloads.delete(sessionId)
if (unifiedWsClient.getStatus() !== 'connected') {
return
}
try {
await unifiedWsClient.call({
domain: 'chat',
method: 'session.close',
session: sessionId,
data: {},
})
} catch (error) {
console.warn(`关闭聊天会话失败 (${sessionId}):`, error)
}
}
async sendMessage(sessionId: string, content: string, userName: string): Promise<void> {
await unifiedWsClient.call({
domain: 'chat',
method: 'message.send',
session: sessionId,
data: {
content,
user_name: userName,
},
})
}
async updateNickname(sessionId: string, userName: string): Promise<void> {
const currentPayload = this.sessionPayloads.get(sessionId)
if (currentPayload) {
this.sessionPayloads.set(sessionId, {
...currentPayload,
user_name: userName,
})
}
await unifiedWsClient.call({
domain: 'chat',
method: 'session.update_nickname',
session: sessionId,
data: {
user_name: userName,
},
})
}
onSessionMessage(sessionId: string, listener: ChatSessionListener): () => void {
this.initialize()
const sessionListeners = this.listeners.get(sessionId) ?? new Set<ChatSessionListener>()
sessionListeners.add(listener)
this.listeners.set(sessionId, sessionListeners)
return () => {
const currentListeners = this.listeners.get(sessionId)
if (!currentListeners) {
return
}
currentListeners.delete(listener)
if (currentListeners.size === 0) {
this.listeners.delete(sessionId)
}
}
}
onConnectionChange(listener: (connected: boolean) => void): () => void {
return unifiedWsClient.onConnectionChange(listener)
}
onStatusChange(listener: (status: ConnectionStatus) => void): () => void {
return unifiedWsClient.onStatusChange(listener)
}
async restart(): Promise<void> {
await unifiedWsClient.restart()
}
}
export const chatWsClient = new ChatWsClient()

View File

@@ -1,13 +1,11 @@
/**
* 全局日志 WebSocket 管理器
* 确保整个应用只有一个 WebSocket 连接
* 确保整个应用只通过统一连接层订阅日志流
*/
import { checkAuthStatus } from './fetch-with-auth'
import { getSetting } from './settings-manager'
import { createReconnectingWebSocket } from './ws-utils'
import { getWsBaseUrl } from '@/lib/api-base'
import { unifiedWsClient } from './unified-ws'
export interface LogEntry {
id: string
@@ -17,165 +15,79 @@ export interface LogEntry {
message: string
}
type LogCallback = (log: LogEntry) => void
type LogCallback = () => void
type ConnectionCallback = (connected: boolean) => void
class LogWebSocketManager {
private wsControl: ReturnType<typeof createReconnectingWebSocket> | null = null
// 订阅者
private logCallbacks: Set<LogCallback> = new Set()
private connectionCallbacks: Set<ConnectionCallback> = new Set()
private initialized = false
private isConnected = false
// 日志缓存 - 保存所有接收到的日志
private logCache: LogEntry[] = []
private logCallbacks: Set<LogCallback> = new Set()
private subscriptionActive = false
/**
* 获取最大缓存大小(从设置读取)
*/
private getMaxCacheSize(): number {
return getSetting('logCacheSize')
}
/**
* 获取最大重连次数(从设置读取)
*/
private getMaxReconnectAttempts(): number {
return getSetting('wsMaxReconnectAttempts')
}
/**
* 获取重连间隔(从设置读取)
*/
private getReconnectInterval(): number {
return getSetting('wsReconnectInterval')
}
/**
* 获取 WebSocket URL不含 token 参数)
*/
private async getWebSocketUrl(): Promise<string> {
const wsBase = await getWsBaseUrl()
return `${wsBase}/ws/logs`
}
/**
* 连接 WebSocket会先检查登录状态
*/
async connect() {
// 检查是否在登录页面
if (window.location.pathname === '/auth') {
console.log('📡 在登录页面,跳过 WebSocket 连接')
private initialize(): void {
if (this.initialized) {
return
}
// 检查登录状态,避免未登录时尝试连接
const isAuthenticated = await checkAuthStatus()
if (!isAuthenticated) {
console.log('📡 未登录,跳过 WebSocket 连接')
return
}
unifiedWsClient.addEventListener((message) => {
if (message.domain !== 'logs') {
return
}
const wsUrl = await this.getWebSocketUrl()
if (message.event === 'snapshot') {
const entries = Array.isArray(message.data.entries)
? (message.data.entries as LogEntry[])
: []
this.logCache = entries.slice(-this.getMaxCacheSize())
this.notifyLogChange()
return
}
// 使用 ws-utils 创建 WebSocket
this.wsControl = createReconnectingWebSocket(wsUrl, {
onMessage: (data: string) => {
try {
const log: LogEntry = JSON.parse(data)
this.notifyLog(log)
} catch (error) {
console.error('解析日志消息失败:', error)
}
},
onOpen: () => {
this.isConnected = true
this.notifyConnection(true)
},
onClose: () => {
this.isConnected = false
this.notifyConnection(false)
},
onError: (error) => {
console.error('❌ WebSocket 错误:', error)
this.isConnected = false
this.notifyConnection(false)
},
heartbeatInterval: 30000,
maxRetries: this.getMaxReconnectAttempts(),
backoffBase: this.getReconnectInterval(),
maxBackoff: 30000,
if (message.event === 'entry' && message.data.entry) {
this.appendLog(message.data.entry as LogEntry)
}
})
// 启动连接
await this.wsControl.connect()
unifiedWsClient.onConnectionChange((connected) => {
this.isConnected = connected
this.notifyConnection(connected)
})
this.initialized = true
}
/**
* 断开连接
*/
disconnect() {
if (this.wsControl) {
this.wsControl.disconnect()
this.wsControl = null
}
this.isConnected = false
}
/**
* 订阅日志消息
*/
onLog(callback: LogCallback) {
this.logCallbacks.add(callback)
return () => this.logCallbacks.delete(callback)
}
/**
* 订阅连接状态
*/
onConnectionChange(callback: ConnectionCallback) {
this.connectionCallbacks.add(callback)
// 立即通知当前状态
callback(this.isConnected)
return () => this.connectionCallbacks.delete(callback)
}
/**
* 通知所有订阅者新日志
*/
private notifyLog(log: LogEntry) {
// 检查是否已存在(通过 id 去重)
private appendLog(log: LogEntry): void {
const exists = this.logCache.some(existingLog => existingLog.id === log.id)
if (!exists) {
// 添加到缓存
this.logCache.push(log)
// 限制缓存大小(动态读取配置)
const maxCacheSize = this.getMaxCacheSize()
if (this.logCache.length > maxCacheSize) {
this.logCache = this.logCache.slice(-maxCacheSize)
}
// 只有新日志才通知订阅者
this.logCallbacks.forEach(callback => {
try {
callback(log)
} catch (error) {
console.error('日志回调执行失败:', error)
}
})
if (exists) {
return
}
this.logCache.push(log)
const maxCacheSize = this.getMaxCacheSize()
if (this.logCache.length > maxCacheSize) {
this.logCache = this.logCache.slice(-maxCacheSize)
}
this.notifyLogChange()
}
/**
* 通知所有订阅者连接状态变化
*/
private notifyConnection(connected: boolean) {
this.connectionCallbacks.forEach(callback => {
private notifyLogChange(): void {
this.logCallbacks.forEach((callback) => {
try {
callback()
} catch (error) {
console.error('日志回调执行失败:', error)
}
})
}
private notifyConnection(connected: boolean): void {
this.connectionCallbacks.forEach((callback) => {
try {
callback(connected)
} catch (error) {
@@ -184,35 +96,65 @@ class LogWebSocketManager {
})
}
/**
* 获取缓存的所有日志
*/
async connect(): Promise<void> {
if (window.location.pathname === '/auth') {
return
}
const isAuthenticated = await checkAuthStatus()
if (!isAuthenticated) {
return
}
this.initialize()
if (this.subscriptionActive) {
return
}
try {
await unifiedWsClient.subscribe('logs', 'main', { replay: 100 })
this.subscriptionActive = true
} catch (error) {
console.error('订阅日志流失败:', error)
}
}
disconnect(): void {
this.subscriptionActive = false
void unifiedWsClient.unsubscribe('logs', 'main')
this.isConnected = false
this.notifyConnection(false)
}
onLog(callback: LogCallback): () => void {
this.logCallbacks.add(callback)
return () => this.logCallbacks.delete(callback)
}
onConnectionChange(callback: ConnectionCallback): () => void {
this.connectionCallbacks.add(callback)
callback(this.isConnected)
return () => this.connectionCallbacks.delete(callback)
}
getAllLogs(): LogEntry[] {
return [...this.logCache]
}
/**
* 清空日志缓存
*/
clearLogs() {
clearLogs(): void {
this.logCache = []
this.notifyLogChange()
}
/**
* 获取当前连接状态
*/
getConnectionStatus(): boolean {
return this.isConnected
}
}
// 导出单例
export const logWebSocket = new LogWebSocketManager()
// 自动连接(应用启动时)
if (typeof window !== 'undefined') {
// 延迟一下确保页面加载完成
setTimeout(() => {
logWebSocket.connect()
void logWebSocket.connect()
}, 100)
}

View File

@@ -1,9 +1,9 @@
import type { ApiResponse } from '@/types/api'
import type { PluginInfo } from '@/types/plugin'
import { getWsBaseUrl } from '@/lib/api-base'
import { fetchWithAuth } from '@/lib/fetch-with-auth'
import { parseResponse } from '@/lib/api-helpers'
import { pluginProgressClient } from '@/lib/plugin-progress-client'
import type { GitStatus, MaimaiVersion } from './types'
/**
@@ -211,41 +211,13 @@ export function isPluginCompatible(
*/
export async function connectPluginProgressWebSocket(
onProgress: (progress: import('./types').PluginLoadProgress) => void,
onError?: (error: Event) => void
): Promise<WebSocket | null> {
const wsBase = await getWsBaseUrl()
const wsUrl = `${wsBase}/api/webui/ws/plugin-progress`
// 使用 ws-utils 创建 WebSocket
const { createReconnectingWebSocket } = await import('@/lib/ws-utils')
const wsControl = createReconnectingWebSocket(wsUrl, {
onMessage: (data: string) => {
try {
const progressData = JSON.parse(data) as import('./types').PluginLoadProgress
onProgress(progressData)
} catch (error) {
console.error('Failed to parse progress data:', error)
}
},
onOpen: () => {
console.log('Plugin progress WebSocket connected')
},
onClose: () => {
console.log('Plugin progress WebSocket disconnected')
},
onError: (error) => {
console.error('Plugin progress WebSocket error:', error)
onError?.(error)
},
heartbeatInterval: 30000,
maxRetries: 10,
backoffBase: 1000,
maxBackoff: 30000,
})
// 启动连接
await wsControl.connect()
// 返回 WebSocket 实例(用于外部检查连接状态)
return wsControl.getWebSocket()
onError?: (error: Error) => void
): Promise<() => Promise<void>> {
try {
return await pluginProgressClient.subscribe(onProgress)
} catch (error) {
const normalizedError = error instanceof Error ? error : new Error('插件进度订阅失败')
onError?.(normalizedError)
return async () => {}
}
}

View File

@@ -0,0 +1,58 @@
import type { PluginLoadProgress } from '@/lib/plugin-api/types'
import { unifiedWsClient } from './unified-ws'
type ProgressListener = (progress: PluginLoadProgress) => void
class PluginProgressClient {
private initialized = false
private listeners: Set<ProgressListener> = new Set()
private subscriptionActive = false
private initialize(): void {
if (this.initialized) {
return
}
unifiedWsClient.addEventListener((message) => {
if (message.domain !== 'plugin_progress') {
return
}
const progress = message.data.progress as PluginLoadProgress | undefined
if (!progress) {
return
}
this.listeners.forEach((listener) => {
try {
listener(progress)
} catch (error) {
console.error('插件进度监听器执行失败:', error)
}
})
})
this.initialized = true
}
async subscribe(listener: ProgressListener): Promise<() => Promise<void>> {
this.initialize()
this.listeners.add(listener)
if (!this.subscriptionActive) {
await unifiedWsClient.subscribe('plugin_progress', 'main')
this.subscriptionActive = true
}
return async () => {
this.listeners.delete(listener)
if (this.listeners.size === 0 && this.subscriptionActive) {
this.subscriptionActive = false
await unifiedWsClient.unsubscribe('plugin_progress', 'main')
}
}
}
}
export const pluginProgressClient = new PluginProgressClient()

View File

@@ -0,0 +1,495 @@
import { fetchWithAuth } from './fetch-with-auth'
import { getSetting } from './settings-manager'
import { getWsBaseUrl } from '@/lib/api-base'
export type ConnectionStatus = 'idle' | 'connecting' | 'connected'
export interface WsErrorPayload {
code?: string
message: string
}
export interface WsEventEnvelope {
op: 'event'
domain: string
event: string
session?: string
topic?: string
data: Record<string, unknown>
}
interface WsResponseEnvelope {
op: 'response'
id?: string
ok: boolean
data?: Record<string, unknown>
error?: WsErrorPayload
}
interface WsPongEnvelope {
op: 'pong'
ts: number
}
type WsServerEnvelope = WsEventEnvelope | WsPongEnvelope | WsResponseEnvelope
interface PendingRequest {
reject: (error: Error) => void
resolve: (data: Record<string, unknown>) => void
timeoutId: number
}
interface SubscriptionDefinition {
data?: Record<string, unknown>
domain: string
topic: string
}
type EventListener = (message: WsEventEnvelope) => void
type ConnectionListener = (connected: boolean) => void
type StatusListener = (status: ConnectionStatus) => void
type ReconnectListener = () => void
function isResponseEnvelope(message: WsServerEnvelope): message is WsResponseEnvelope {
return message.op === 'response'
}
function isEventEnvelope(message: WsServerEnvelope): message is WsEventEnvelope {
return message.op === 'event'
}
async function getWsToken(): Promise<string | null> {
try {
const response = await fetchWithAuth('/api/webui/ws-token', {
method: 'GET',
credentials: 'include',
})
if (!response.ok) {
return null
}
const data = await response.json()
if (data.success && data.token) {
return data.token as string
}
return null
} catch (error) {
console.error('获取统一 WebSocket token 失败:', error)
return null
}
}
class UnifiedWebSocketClient {
private connectPromise: Promise<void> | null = null
private connectionListeners: Set<ConnectionListener> = new Set()
private eventListeners: Set<EventListener> = new Set()
private hasConnectedOnce = false
private heartbeatIntervalId: number | null = null
private manualDisconnect = false
private pendingRequests: Map<string, PendingRequest> = new Map()
private reconnectAttempts = 0
private reconnectListeners: Set<ReconnectListener> = new Set()
private reconnectTimeout: number | null = null
private requestCounter = 0
private status: ConnectionStatus = 'idle'
private statusListeners: Set<StatusListener> = new Set()
private subscriptions: Map<string, SubscriptionDefinition> = new Map()
private ws: WebSocket | null = null
private getReconnectDelay(): number {
const baseDelay = getSetting('wsReconnectInterval')
return Math.min(baseDelay * Math.max(this.reconnectAttempts, 1), 30000)
}
private getMaxReconnectAttempts(): number {
return getSetting('wsMaxReconnectAttempts')
}
private getSubscriptionKey(domain: string, topic: string): string {
return `${domain}:${topic}`
}
private nextRequestId(): string {
this.requestCounter += 1
return `ws-${Date.now()}-${this.requestCounter}`
}
private setStatus(status: ConnectionStatus): void {
if (this.status === status) {
return
}
this.status = status
this.statusListeners.forEach((listener) => {
try {
listener(status)
} catch (error) {
console.error('WebSocket 状态监听器执行失败:', error)
}
})
const connected = status === 'connected'
this.connectionListeners.forEach((listener) => {
try {
listener(connected)
} catch (error) {
console.error('WebSocket 连接监听器执行失败:', error)
}
})
}
private stopHeartbeat(): void {
if (this.heartbeatIntervalId !== null) {
clearInterval(this.heartbeatIntervalId)
this.heartbeatIntervalId = null
}
}
private startHeartbeat(): void {
this.stopHeartbeat()
this.heartbeatIntervalId = window.setInterval(() => {
if (this.ws?.readyState === WebSocket.OPEN) {
this.ws.send(JSON.stringify({ op: 'ping' }))
}
}, 30000)
}
private clearReconnectTimer(): void {
if (this.reconnectTimeout !== null) {
clearTimeout(this.reconnectTimeout)
this.reconnectTimeout = null
}
}
private rejectPendingRequests(error: Error): void {
this.pendingRequests.forEach((pendingRequest, requestId) => {
clearTimeout(pendingRequest.timeoutId)
pendingRequest.reject(error)
this.pendingRequests.delete(requestId)
})
}
private scheduleReconnect(): void {
if (this.manualDisconnect) {
return
}
if (this.reconnectAttempts >= this.getMaxReconnectAttempts()) {
console.warn(`统一 WebSocket 达到最大重连次数 (${this.getMaxReconnectAttempts()}),停止重连`)
return
}
this.reconnectAttempts += 1
const delay = this.getReconnectDelay()
this.clearReconnectTimer()
this.reconnectTimeout = window.setTimeout(() => {
void this.connect().catch((error) => {
console.error('统一 WebSocket 重连失败:', error)
})
}, delay)
}
private async createWebSocketUrl(): Promise<string | null> {
const wsBaseUrl = await getWsBaseUrl()
const wsToken = await getWsToken()
if (!wsBaseUrl || !wsToken) {
return null
}
return `${wsBaseUrl}/api/webui/ws?token=${encodeURIComponent(wsToken)}`
}
private async sendRequest(
payload: Record<string, unknown>,
timeoutMs = 10000,
): Promise<Record<string, unknown>> {
if (this.ws?.readyState !== WebSocket.OPEN) {
throw new Error('统一 WebSocket 尚未连接')
}
const requestId = payload.id as string
return await new Promise<Record<string, unknown>>((resolve, reject) => {
const timeoutId = window.setTimeout(() => {
this.pendingRequests.delete(requestId)
reject(new Error(`统一 WebSocket 请求超时: ${requestId}`))
}, timeoutMs)
this.pendingRequests.set(requestId, {
resolve,
reject,
timeoutId,
})
this.ws?.send(JSON.stringify(payload))
})
}
private async restoreState(shouldNotifyReconnect: boolean): Promise<void> {
const subscriptions = Array.from(this.subscriptions.values())
for (const subscription of subscriptions) {
try {
await this.sendRequest({
op: 'subscribe',
id: this.nextRequestId(),
domain: subscription.domain,
topic: subscription.topic,
data: subscription.data ?? {},
})
} catch (error) {
console.error('恢复统一 WebSocket 订阅失败:', error)
}
}
if (shouldNotifyReconnect) {
this.reconnectListeners.forEach((listener) => {
try {
listener()
} catch (error) {
console.error('统一 WebSocket 重连监听器执行失败:', error)
}
})
}
}
private handleServerMessage(rawData: string): void {
let message: WsServerEnvelope
try {
message = JSON.parse(rawData) as WsServerEnvelope
} catch (error) {
console.error('解析统一 WebSocket 消息失败:', error)
return
}
if (message.op === 'pong') {
return
}
if (isResponseEnvelope(message)) {
const requestId = message.id
if (!requestId) {
return
}
const pendingRequest = this.pendingRequests.get(requestId)
if (!pendingRequest) {
return
}
clearTimeout(pendingRequest.timeoutId)
this.pendingRequests.delete(requestId)
if (message.ok) {
pendingRequest.resolve(message.data ?? {})
} else {
pendingRequest.reject(new Error(message.error?.message ?? '统一 WebSocket 请求失败'))
}
return
}
if (isEventEnvelope(message)) {
this.eventListeners.forEach((listener) => {
try {
listener(message)
} catch (error) {
console.error('统一 WebSocket 事件监听器执行失败:', error)
}
})
}
}
private handleClose(event: CloseEvent): void {
this.stopHeartbeat()
this.ws = null
this.connectPromise = null
this.setStatus('idle')
this.rejectPendingRequests(new Error(`统一 WebSocket 已关闭 (${event.code})`))
if (event.code === 4001) {
this.manualDisconnect = true
if (window.location.pathname !== '/auth') {
window.location.href = '/auth'
}
return
}
this.scheduleReconnect()
}
async connect(): Promise<void> {
if (this.ws?.readyState === WebSocket.OPEN) {
return
}
if (this.connectPromise) {
return await this.connectPromise
}
this.manualDisconnect = false
this.setStatus('connecting')
this.connectPromise = (async () => {
const wsUrl = await this.createWebSocketUrl()
if (!wsUrl) {
this.setStatus('idle')
throw new Error('无法建立统一 WebSocket 连接')
}
await new Promise<void>((resolve, reject) => {
let settled = false
const socket = new WebSocket(wsUrl)
this.ws = socket
socket.onopen = () => {
settled = true
const shouldNotifyReconnect = this.hasConnectedOnce
this.hasConnectedOnce = true
this.reconnectAttempts = 0
this.startHeartbeat()
this.setStatus('connected')
resolve()
void this.restoreState(shouldNotifyReconnect)
}
socket.onmessage = (event) => {
this.handleServerMessage(event.data)
}
socket.onerror = () => {
if (!settled) {
settled = true
reject(new Error('统一 WebSocket 连接失败'))
}
}
socket.onclose = (event) => {
if (!settled) {
settled = true
reject(new Error(`统一 WebSocket 已关闭 (${event.code})`))
}
this.handleClose(event)
}
})
})()
try {
await this.connectPromise
} finally {
if (this.status !== 'connected') {
this.connectPromise = null
}
}
}
disconnect(): void {
this.manualDisconnect = true
this.clearReconnectTimer()
this.stopHeartbeat()
this.rejectPendingRequests(new Error('统一 WebSocket 已手动断开'))
this.connectPromise = null
if (this.ws) {
this.ws.close()
this.ws = null
}
this.setStatus('idle')
}
async restart(): Promise<void> {
this.manualDisconnect = false
this.clearReconnectTimer()
if (this.ws) {
this.ws.close()
return
}
await this.connect()
}
async call(params: {
data?: Record<string, unknown>
domain: string
method: string
session?: string
}): Promise<Record<string, unknown>> {
await this.connect()
const requestId = this.nextRequestId()
return await this.sendRequest({
op: 'call',
id: requestId,
domain: params.domain,
method: params.method,
session: params.session,
data: params.data ?? {},
})
}
async subscribe(
domain: string,
topic: string,
data?: Record<string, unknown>,
): Promise<Record<string, unknown>> {
await this.connect()
this.subscriptions.set(this.getSubscriptionKey(domain, topic), {
domain,
topic,
data,
})
return await this.sendRequest({
op: 'subscribe',
id: this.nextRequestId(),
domain,
topic,
data: data ?? {},
})
}
async unsubscribe(domain: string, topic: string): Promise<Record<string, unknown> | null> {
this.subscriptions.delete(this.getSubscriptionKey(domain, topic))
if (this.ws?.readyState !== WebSocket.OPEN) {
return null
}
return await this.sendRequest({
op: 'unsubscribe',
id: this.nextRequestId(),
domain,
topic,
data: {},
})
}
addEventListener(listener: EventListener): () => void {
this.eventListeners.add(listener)
return () => {
this.eventListeners.delete(listener)
}
}
onConnectionChange(listener: ConnectionListener): () => void {
this.connectionListeners.add(listener)
listener(this.status === 'connected')
return () => {
this.connectionListeners.delete(listener)
}
}
onStatusChange(listener: StatusListener): () => void {
this.statusListeners.add(listener)
listener(this.status)
return () => {
this.statusListeners.delete(listener)
}
}
onReconnect(listener: ReconnectListener): () => void {
this.reconnectListeners.add(listener)
return () => {
this.reconnectListeners.delete(listener)
}
}
getStatus(): ConnectionStatus {
return this.status
}
}
export const unifiedWsClient = new UnifiedWebSocketClient()

View File

@@ -1,211 +0,0 @@
import { fetchWithAuth } from './fetch-with-auth'
/**
* WebSocket 配置选项
*/
export interface WebSocketOptions {
onMessage?: (data: string) => void
onOpen?: () => void
onClose?: () => void
onError?: (error: Event) => void
heartbeatInterval?: number // 心跳间隔(毫秒)
maxRetries?: number // 最大重连次数
backoffBase?: number // 重连基础间隔(毫秒)
maxBackoff?: number // 最大重连间隔(毫秒)
}
/**
* 获取 WebSocket 临时认证 token
*/
export async function getWsToken(): Promise<string | null> {
try {
// 使用相对路径,让前端代理处理请求,避免 CORS 问题
const response = await fetchWithAuth('/api/webui/ws-token', {
method: 'GET',
credentials: 'include', // 携带 Cookie
})
if (!response.ok) {
console.error('获取 WebSocket token 失败:', response.status)
return null
}
const data = await response.json()
if (data.success && data.token) {
return data.token
}
return null
} catch (error) {
console.error('获取 WebSocket token 失败:', error)
return null
}
}
/**
* 创建带重连、心跳的 WebSocket 封装
*
* @param url WebSocket URL不含 token 参数)
* @param options 配置选项
* @returns WebSocket 控制对象,包含 connect、disconnect、send 方法
*/
export function createReconnectingWebSocket(
url: string,
options: WebSocketOptions = {}
) {
const {
onMessage,
onOpen,
onClose,
onError,
heartbeatInterval = 30000,
maxRetries = 10,
backoffBase = 1000,
maxBackoff = 30000,
} = options
let ws: WebSocket | null = null
let reconnectTimeout: number | null = null
let reconnectAttempts = 0
let heartbeatIntervalId: number | null = null
let isManualDisconnect = false
/**
* 启动心跳
*/
function startHeartbeat() {
stopHeartbeat()
heartbeatIntervalId = window.setInterval(() => {
if (ws?.readyState === WebSocket.OPEN) {
ws.send('ping')
}
}, heartbeatInterval)
}
/**
* 停止心跳
*/
function stopHeartbeat() {
if (heartbeatIntervalId !== null) {
clearInterval(heartbeatIntervalId)
heartbeatIntervalId = null
}
}
/**
* 尝试重连
*/
function attemptReconnect() {
if (isManualDisconnect) {
return
}
if (reconnectAttempts >= maxRetries) {
console.warn(`WebSocket 达到最大重连次数 (${maxRetries}),停止重连`)
return
}
reconnectAttempts += 1
const delay = Math.min(backoffBase * reconnectAttempts, maxBackoff)
console.log(`WebSocket 将在 ${delay}ms 后重连(第 ${reconnectAttempts} 次)`)
reconnectTimeout = window.setTimeout(() => {
connect()
}, delay)
}
/**
* 连接 WebSocket
*/
async function connect() {
if (ws?.readyState === WebSocket.OPEN || ws?.readyState === WebSocket.CONNECTING) {
return
}
// 先获取临时认证 token
const wsToken = await getWsToken()
if (!wsToken) {
console.warn('无法获取 WebSocket token跳过连接')
return
}
const wsUrl = `${url}?token=${encodeURIComponent(wsToken)}`
try {
ws = new WebSocket(wsUrl)
ws.onopen = () => {
reconnectAttempts = 0
startHeartbeat()
onOpen?.()
}
ws.onmessage = (event) => {
// 忽略心跳响应
if (event.data === 'pong') {
return
}
onMessage?.(event.data)
}
ws.onerror = (error) => {
console.error('WebSocket 错误:', error)
onError?.(error)
}
ws.onclose = () => {
stopHeartbeat()
onClose?.()
attemptReconnect()
}
} catch (error) {
console.error('创建 WebSocket 连接失败:', error)
attemptReconnect()
}
}
/**
* 断开连接
*/
function disconnect() {
isManualDisconnect = true
if (reconnectTimeout !== null) {
clearTimeout(reconnectTimeout)
reconnectTimeout = null
}
stopHeartbeat()
if (ws) {
ws.close()
ws = null
}
reconnectAttempts = 0
}
/**
* 发送消息
*/
function send(data: string) {
if (ws?.readyState === WebSocket.OPEN) {
ws.send(data)
} else {
console.warn('WebSocket 未连接,无法发送消息')
}
}
/**
* 获取当前 WebSocket 实例
*/
function getWebSocket(): WebSocket | null {
return ws
}
return {
connect,
disconnect,
send,
getWebSocket,
}
}

View File

@@ -5,7 +5,7 @@ import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { ScrollArea } from '@/components/ui/scroll-area'
import { useToast } from '@/hooks/use-toast'
import { getWsBaseUrl } from '@/lib/api-base'
import { chatWsClient } from '@/lib/chat-ws-client'
import { fetchWithAuth } from '@/lib/fetch-with-auth'
import { cn } from '@/lib/utils'
import { Bot, Edit2, Loader2, RefreshCw, User, Send, Wifi, WifiOff, UserCircle2 } from 'lucide-react'
@@ -85,14 +85,17 @@ export function ChatPage() {
// 持久化用户 ID
const userIdRef = useRef(getOrCreateUserId())
// 每个标签页的 WebSocket 连接
const wsMapRef = useRef<Map<string, WebSocket>>(new Map())
const messagesEndRef = useRef<HTMLDivElement>(null)
const reconnectTimeoutMapRef = useRef<Map<string, number>>(new Map())
const messageIdCounterRef = useRef(0)
const processedMessagesMapRef = useRef<Map<string, Set<string>>>(new Map())
const sessionUnsubscribeMapRef = useRef<Map<string, () => void>>(new Map())
const tabsRef = useRef<ChatTab[]>([])
const { toast } = useToast()
useEffect(() => {
tabsRef.current = tabs
}, [tabs])
// 生成唯一消息 ID
const generateMessageId = (prefix: string) => {
messageIdCounterRef.current += 1
@@ -197,357 +200,218 @@ export function ChatPage() {
}
}, [tempVirtualConfig.platform, personSearchQuery, fetchPersons])
// 加载聊天历史到指定标签页
const loadChatHistoryForTab = useCallback(async (tabId: string, groupId?: string) => {
const handleSessionMessage = useCallback((
tabId: string,
tabType: 'webui' | 'virtual',
config: VirtualIdentityConfig | undefined,
data: WsMessage,
) => {
switch (data.type) {
case 'session_info':
updateTab(tabId, {
sessionInfo: {
session_id: data.session_id,
user_id: data.user_id,
user_name: data.user_name,
bot_name: data.bot_name,
}
})
break
case 'system':
addMessageToTab(tabId, {
id: generateMessageId('sys'),
type: 'system',
content: data.content || '',
timestamp: data.timestamp || Date.now() / 1000,
})
break
case 'user_message': {
const senderUserId = data.sender?.user_id
const currentUserId = tabType === 'virtual' && config
? config.userId
: userIdRef.current
const normalizeSenderId = senderUserId ? senderUserId.replace(/^webui_user_/, '') : ''
const normalizeCurrentId = currentUserId ? currentUserId.replace(/^webui_user_/, '') : ''
if (normalizeSenderId && normalizeCurrentId && normalizeSenderId === normalizeCurrentId) {
break
}
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
const contentHash = `user-${data.content}-${Math.floor((data.timestamp || 0) * 1000)}`
if (processedSet.has(contentHash)) {
break
}
processedSet.add(contentHash)
processedMessagesMapRef.current.set(tabId, processedSet)
if (processedSet.size > 100) {
const firstKey = processedSet.values().next().value
if (firstKey) processedSet.delete(firstKey)
}
addMessageToTab(tabId, {
id: data.message_id || generateMessageId('user'),
type: 'user',
content: data.content || '',
timestamp: data.timestamp || Date.now() / 1000,
sender: data.sender,
})
break
}
case 'bot_message': {
updateTab(tabId, { isTyping: false })
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
const contentHash = `bot-${data.content}-${Math.floor((data.timestamp || 0) * 1000)}`
if (processedSet.has(contentHash)) {
break
}
processedSet.add(contentHash)
processedMessagesMapRef.current.set(tabId, processedSet)
if (processedSet.size > 100) {
const firstKey = processedSet.values().next().value
if (firstKey) processedSet.delete(firstKey)
}
setTabs(prev => prev.map(tab => {
if (tab.id !== tabId) return tab
const filteredMessages = tab.messages.filter(msg => msg.type !== 'thinking')
const newMessage: ChatMessage = {
id: generateMessageId('bot'),
type: 'bot',
content: data.content || '',
message_type: (data.message_type === 'rich' ? 'rich' : 'text') as 'text' | 'rich',
segments: data.segments,
timestamp: data.timestamp || Date.now() / 1000,
sender: data.sender,
}
return {
...tab,
messages: [...filteredMessages, newMessage]
}
}))
break
}
case 'typing':
updateTab(tabId, { isTyping: data.is_typing || false })
break
case 'error':
setTabs(prev => prev.map(tab => {
if (tab.id !== tabId) return tab
const filteredMessages = tab.messages.filter(msg => msg.type !== 'thinking')
return {
...tab,
messages: [...filteredMessages, {
id: generateMessageId('error'),
type: 'error' as const,
content: data.content || '发生错误',
timestamp: data.timestamp || Date.now() / 1000,
}]
}
}))
toast({
title: '错误',
description: data.content,
variant: 'destructive',
})
break
case 'history': {
const historyMessages = data.messages || []
const processedSet = new Set<string>()
const formattedMessages: ChatMessage[] = historyMessages.map((msg: {
id?: string
content: string
timestamp: number
sender_name?: string
sender_id?: string
is_bot?: boolean
}) => {
const isBot = msg.is_bot || false
const msgId = msg.id || generateMessageId(isBot ? 'bot' : 'user')
const contentHash = `${isBot ? 'bot' : 'user'}-${msg.content}-${Math.floor(msg.timestamp * 1000)}`
processedSet.add(contentHash)
return {
id: msgId,
type: isBot ? 'bot' : 'user' as const,
content: msg.content,
timestamp: msg.timestamp,
sender: {
name: msg.sender_name || (isBot ? '麦麦' : '用户'),
user_id: msg.sender_id,
is_bot: isBot,
},
}
})
processedMessagesMapRef.current.set(tabId, processedSet)
updateTab(tabId, { messages: formattedMessages })
setIsLoadingHistory(false)
break
}
default:
break
}
}, [addMessageToTab, toast, updateTab])
const ensureSessionListener = useCallback((
tabId: string,
tabType: 'webui' | 'virtual',
config?: VirtualIdentityConfig,
) => {
if (sessionUnsubscribeMapRef.current.has(tabId)) {
return
}
const unsubscribe = chatWsClient.onSessionMessage(tabId, (message) => {
handleSessionMessage(tabId, tabType, config, message as unknown as WsMessage)
})
sessionUnsubscribeMapRef.current.set(tabId, unsubscribe)
}, [handleSessionMessage])
const openSessionForTab = useCallback(async (
tabId: string,
tabType: 'webui' | 'virtual',
config?: VirtualIdentityConfig,
) => {
ensureSessionListener(tabId, tabType, config)
setIsLoadingHistory(true)
try {
const params = new URLSearchParams()
params.append('user_id', userIdRef.current)
params.append('limit', '50')
if (groupId) {
params.append('group_id', groupId)
if (tabType === 'virtual' && config) {
await chatWsClient.openSession(tabId, {
user_id: config.userId,
user_name: config.userName,
platform: config.platform,
person_id: config.personId,
group_name: config.groupName || 'WebUI虚拟群聊',
group_id: config.groupId,
})
} else {
await chatWsClient.openSession(tabId, {
user_id: userIdRef.current,
user_name: userName,
})
}
const url = `/api/chat/history?${params.toString()}`
console.log('[Chat] 正在加载历史消息:', url)
const response = await fetchWithAuth(url)
if (response.ok) {
const text = await response.text()
try {
const data = JSON.parse(text)
if (data.messages && data.messages.length > 0) {
const historyMessages: ChatMessage[] = data.messages.map((msg: {
id: string
type: string
content: string
timestamp: number
sender_name?: string
user_id?: string
is_bot?: boolean
}) => ({
id: msg.id,
type: msg.type as 'user' | 'bot' | 'system' | 'error',
content: msg.content,
timestamp: msg.timestamp,
sender: {
name: msg.sender_name || (msg.is_bot ? '麦麦' : 'WebUI用户'),
user_id: msg.user_id,
is_bot: msg.is_bot
}
}))
// 更新标签页的消息
updateTab(tabId, { messages: historyMessages })
// 将历史消息添加到去重缓存
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
historyMessages.forEach(msg => {
if (msg.type === 'bot') {
const contentHash = `bot-${msg.content}-${Math.floor(msg.timestamp * 1000)}`
processedSet.add(contentHash)
}
})
processedMessagesMapRef.current.set(tabId, processedSet)
}
} catch (parseError) {
console.error('[Chat] JSON 解析失败:', parseError)
}
}
} catch (e) {
console.error('[Chat] 加载历史消息失败:', e)
} finally {
setIsLoadingHistory(false)
}
}, [updateTab])
// 为指定标签页连接 WebSocket异步需要先获取认证 token
const connectWebSocketForTab = useCallback(async (tabId: string, tabType: 'webui' | 'virtual', config?: VirtualIdentityConfig) => {
// 如果已经有连接,不要重复创建
const existingWs = wsMapRef.current.get(tabId)
if (existingWs?.readyState === WebSocket.OPEN ||
existingWs?.readyState === WebSocket.CONNECTING) {
console.log(`[Tab ${tabId}] WebSocket 已存在,跳过连接`)
return
}
setIsConnecting(true)
// 先获取临时 WebSocket token
let wsToken: string | null = null
try {
const tokenResponse = await fetchWithAuth('/api/webui/ws-token')
if (tokenResponse.ok) {
const tokenData = await tokenResponse.json()
if (tokenData.success && tokenData.token) {
wsToken = tokenData.token
} else {
console.warn(`[Tab ${tabId}] 获取 WebSocket token 失败: ${tokenData.message || '未登录'}`)
setIsConnecting(false)
return
}
}
updateTab(tabId, { isConnected: true })
} catch (error) {
console.error(`[Tab ${tabId}] 获取 WebSocket token 失败:`, error)
setIsConnecting(false)
return
console.error(`[Tab ${tabId}] 打开聊天会话失败:`, error)
setIsLoadingHistory(false)
toast({
title: '连接失败',
description: '无法建立聊天会话,请稍后重试',
variant: 'destructive',
})
}
// 此时 wsToken 一定有值(前面已经 return
if (!wsToken) {
setIsConnecting(false)
return
}
const wsBase = await getWsBaseUrl()
const params = new URLSearchParams()
// 添加 token 到参数
params.append('token', wsToken)
if (tabType === 'virtual' && config) {
params.append('user_id', config.userId)
params.append('user_name', config.userName)
params.append('platform', config.platform)
params.append('person_id', config.personId)
params.append('group_name', config.groupName || 'WebUI虚拟群聊')
// 传递稳定的 group_id确保历史记录能正确加载
if (config.groupId) {
params.append('group_id', config.groupId)
}
} else {
params.append('user_id', userIdRef.current)
params.append('user_name', userName)
}
const wsUrl = `${wsBase}/api/chat/ws?${params.toString()}`
console.log(`[Tab ${tabId}] 正在连接 WebSocket:`, wsUrl)
try {
const ws = new WebSocket(wsUrl)
wsMapRef.current.set(tabId, ws)
ws.onopen = () => {
updateTab(tabId, { isConnected: true })
setIsConnecting(false)
console.log(`[Tab ${tabId}] WebSocket 已连接`)
}
ws.onmessage = (event) => {
try {
const data: WsMessage = JSON.parse(event.data)
switch (data.type) {
case 'session_info':
updateTab(tabId, {
sessionInfo: {
session_id: data.session_id,
user_id: data.user_id,
user_name: data.user_name,
bot_name: data.bot_name,
}
})
break
case 'system':
addMessageToTab(tabId, {
id: generateMessageId('sys'),
type: 'system',
content: data.content || '',
timestamp: data.timestamp || Date.now() / 1000,
})
break
case 'user_message': {
// 检查是否是自己发的消息(已在发送时显示,跳过广播回来的)
const senderUserId = data.sender?.user_id
const currentUserId = tabType === 'virtual' && config
? config.userId
: userIdRef.current
console.log(`[Tab ${tabId}] 收到 user_message, sender: ${senderUserId}, current: ${currentUserId}`)
// 标准化 user_id去掉可能的前缀
const normalizeSenderId = senderUserId ? senderUserId.replace(/^webui_user_/, '') : ''
const normalizeCurrentId = currentUserId ? currentUserId.replace(/^webui_user_/, '') : ''
// 如果是自己发的消息,跳过(避免重复显示)
if (normalizeSenderId && normalizeCurrentId && normalizeSenderId === normalizeCurrentId) {
console.log(`[Tab ${tabId}] 跳过自己的消息user_id 匹配)`)
break
}
// 额外的消息去重:检查内容和时间戳
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
const contentHash = `user-${data.content}-${Math.floor((data.timestamp || 0) * 1000)}`
if (processedSet.has(contentHash)) {
console.log(`[Tab ${tabId}] 跳过自己的消息(内容去重)`)
break
}
processedSet.add(contentHash)
processedMessagesMapRef.current.set(tabId, processedSet)
if (processedSet.size > 100) {
const firstKey = processedSet.values().next().value
if (firstKey) processedSet.delete(firstKey)
}
addMessageToTab(tabId, {
id: data.message_id || generateMessageId('user'),
type: 'user',
content: data.content || '',
timestamp: data.timestamp || Date.now() / 1000,
sender: data.sender,
})
break
}
case 'bot_message': {
updateTab(tabId, { isTyping: false })
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
const contentHash = `bot-${data.content}-${Math.floor((data.timestamp || 0) * 1000)}`
if (processedSet.has(contentHash)) {
break
}
processedSet.add(contentHash)
processedMessagesMapRef.current.set(tabId, processedSet)
if (processedSet.size > 100) {
const firstKey = processedSet.values().next().value
if (firstKey) processedSet.delete(firstKey)
}
// 移除"思考中"占位消息,添加真实的机器人回复
setTabs(prev => prev.map(tab => {
if (tab.id !== tabId) return tab
// 过滤掉 thinking 类型的消息
const filteredMessages = tab.messages.filter(msg => msg.type !== 'thinking')
const newMessage: ChatMessage = {
id: generateMessageId('bot'),
type: 'bot',
content: data.content || '',
message_type: (data.message_type === 'rich' ? 'rich' : 'text') as 'text' | 'rich',
segments: data.segments,
timestamp: data.timestamp || Date.now() / 1000,
sender: data.sender,
}
return {
...tab,
messages: [...filteredMessages, newMessage]
}
}))
break
}
case 'typing':
updateTab(tabId, { isTyping: data.is_typing || false })
break
case 'error':
// 移除"思考中"占位消息,显示错误
setTabs(prev => prev.map(tab => {
if (tab.id !== tabId) return tab
const filteredMessages = tab.messages.filter(msg => msg.type !== 'thinking')
return {
...tab,
messages: [...filteredMessages, {
id: generateMessageId('error'),
type: 'error' as const,
content: data.content || '发生错误',
timestamp: data.timestamp || Date.now() / 1000,
}]
}
}))
toast({
title: '错误',
description: data.content,
variant: 'destructive',
})
break
case 'pong':
break
case 'history': {
// 处理服务端发送的历史消息
const historyMessages = data.messages || []
if (historyMessages.length > 0) {
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
const formattedMessages: ChatMessage[] = historyMessages.map((msg: {
id?: string
content: string
timestamp: number
sender_name?: string
sender_id?: string
is_bot?: boolean
}) => {
const isBot = msg.is_bot || false
const msgId = msg.id || generateMessageId(isBot ? 'bot' : 'user')
// 添加到去重集合
const contentHash = `${isBot ? 'bot' : 'user'}-${msg.content}-${Math.floor(msg.timestamp * 1000)}`
processedSet.add(contentHash)
return {
id: msgId,
type: isBot ? 'bot' : 'user' as const,
content: msg.content,
timestamp: msg.timestamp,
sender: {
name: msg.sender_name || (isBot ? '麦麦' : '用户'),
user_id: msg.sender_id,
is_bot: isBot,
},
}
})
processedMessagesMapRef.current.set(tabId, processedSet)
// 替换当前标签页的所有消息
updateTab(tabId, { messages: formattedMessages })
console.log(`[Tab ${tabId}] 已加载 ${formattedMessages.length} 条历史消息`)
}
break
}
default:
console.log('未知消息类型:', data.type)
}
} catch (e) {
console.error('解析消息失败:', e)
}
}
ws.onclose = () => {
updateTab(tabId, { isConnected: false })
setIsConnecting(false)
wsMapRef.current.delete(tabId)
console.log(`[Tab ${tabId}] WebSocket 已断开`)
// 清除旧的重连定时器
const oldTimeout = reconnectTimeoutMapRef.current.get(tabId)
if (oldTimeout) {
clearTimeout(oldTimeout)
}
// 5秒后尝试重连
const timeout = window.setTimeout(() => {
if (!isUnmountedRef.current) {
const tab = tabs.find(t => t.id === tabId)
if (tab) {
connectWebSocketForTab(tabId, tab.type, tab.virtualConfig)
}
}
}, 5000)
reconnectTimeoutMapRef.current.set(tabId, timeout)
}
ws.onerror = (error) => {
console.error(`[Tab ${tabId}] WebSocket 错误:`, error)
setIsConnecting(false)
}
} catch (e) {
console.error(`[Tab ${tabId}] 创建 WebSocket 失败:`, e)
setIsConnecting(false)
}
}, [userName, updateTab, addMessageToTab, toast, tabs])
}, [ensureSessionListener, toast, updateTab, userName])
// 用于追踪组件是否已卸载
const isUnmountedRef = useRef(false)
@@ -555,69 +419,49 @@ export function ChatPage() {
// 初始化连接(默认 WebUI 标签页)
useEffect(() => {
isUnmountedRef.current = false
// 保存 ref 的当前值,用于清理
const wsMap = wsMapRef.current
const reconnectTimeoutMap = reconnectTimeoutMapRef.current
const processedMessagesMap = processedMessagesMapRef.current
// 加载默认标签页历史消息
loadChatHistoryForTab('webui-default')
// 延迟连接
const connectTimer = setTimeout(() => {
if (!isUnmountedRef.current) {
connectWebSocketForTab('webui-default', 'webui')
// 恢复的虚拟标签页也需要建立连接
tabs.forEach(tab => {
if (tab.type === 'virtual' && tab.virtualConfig) {
// 初始化去重缓存
processedMessagesMap.set(tab.id, new Set())
// 建立 WebSocket 连接
setTimeout(() => {
if (!isUnmountedRef.current) {
connectWebSocketForTab(tab.id, 'virtual', tab.virtualConfig)
}
}, 200)
}
})
}
}, 100)
// 心跳定时器 - 向所有活动连接发送
const heartbeat = setInterval(() => {
wsMap.forEach((ws) => {
if (ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({ type: 'ping' }))
}
})
}, 30000)
const unsubscribeConnection = chatWsClient.onConnectionChange((connected) => {
if (isUnmountedRef.current) {
return
}
setTabs(prev => prev.map(tab => ({
...tab,
isConnected: connected,
})))
})
const unsubscribeStatus = chatWsClient.onStatusChange((status) => {
if (!isUnmountedRef.current) {
setIsConnecting(status === 'connecting')
}
})
tabs.forEach(tab => {
processedMessagesMapRef.current.set(tab.id, new Set())
void openSessionForTab(tab.id, tab.type, tab.virtualConfig)
})
return () => {
isUnmountedRef.current = true
clearTimeout(connectTimer)
clearInterval(heartbeat)
// 清理所有重连定时器
reconnectTimeoutMap.forEach((timeout) => {
clearTimeout(timeout)
unsubscribeConnection()
unsubscribeStatus()
sessionUnsubscribeMapRef.current.forEach((unsubscribe) => {
unsubscribe()
})
reconnectTimeoutMap.clear()
// 关闭所有 WebSocket 连接
wsMap.forEach((ws) => {
ws.close()
sessionUnsubscribeMapRef.current.clear()
tabsRef.current.forEach(tab => {
void chatWsClient.closeSession(tab.id)
})
wsMap.clear()
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
// 发送消息到当前活动标签页
const sendMessage = useCallback(() => {
const ws = wsMapRef.current.get(activeTabId)
if (!inputValue.trim() || !ws || ws.readyState !== WebSocket.OPEN) {
const sendMessage = useCallback(async () => {
if (!inputValue.trim() || !activeTab?.isConnected) {
return
}
@@ -628,12 +472,6 @@ export function ChatPage() {
const messageContent = inputValue.trim()
const currentTimestamp = Date.now() / 1000
ws.send(JSON.stringify({
type: 'message',
content: messageContent,
user_name: displayName,
}))
// 添加到去重缓存,防止服务器广播回来的消息重复显示
const processedSet = processedMessagesMapRef.current.get(activeTabId) || new Set()
const contentHash = `user-${messageContent}-${Math.floor(currentTimestamp * 1000)}`
@@ -672,13 +510,32 @@ export function ChatPage() {
addMessageToTab(activeTabId, thinkingMessage)
setInputValue('')
}, [inputValue, userName, activeTabId, activeTab, addMessageToTab])
try {
await chatWsClient.sendMessage(activeTabId, messageContent, displayName)
} catch (error) {
console.error('发送聊天消息失败:', error)
setTabs(prev => prev.map(tab => {
if (tab.id !== activeTabId) return tab
return {
...tab,
isTyping: false,
messages: tab.messages.filter(msg => msg.type !== 'thinking')
}
}))
toast({
title: '发送失败',
description: '当前聊天会话不可用,请稍后重试',
variant: 'destructive',
})
}
}, [activeTab, activeTabId, addMessageToTab, inputValue, toast, userName])
// 处理键盘事件
const handleKeyDown = (e: React.KeyboardEvent) => {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault()
sendMessage()
void sendMessage()
}
}
@@ -693,13 +550,9 @@ export function ChatPage() {
setUserName(newName)
saveUserName(newName)
setIsEditingName(false)
// 通知当前标签页的后端昵称变更
const ws = wsMapRef.current.get(activeTabId)
if (ws?.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({
type: 'update_nickname',
user_name: newName
}))
if (activeTab?.isConnected) {
void chatWsClient.updateNickname(activeTabId, newName)
}
}
@@ -719,12 +572,7 @@ export function ChatPage() {
// 重新连接当前标签页
const handleReconnect = () => {
const ws = wsMapRef.current.get(activeTabId)
if (ws) {
ws.close()
wsMapRef.current.delete(activeTabId)
}
connectWebSocketForTab(activeTabId, activeTab?.type || 'webui', activeTab?.virtualConfig)
void chatWsClient.restart()
}
// 打开虚拟身份配置对话框(新建标签页用)
@@ -795,10 +643,10 @@ export function ChatPage() {
// 初始化去重缓存
processedMessagesMapRef.current.set(newTabId, new Set())
// 连接 WebSocket
setTimeout(() => {
connectWebSocketForTab(newTabId, 'virtual', tempVirtualConfig)
}, 100)
void openSessionForTab(newTabId, 'virtual', {
...tempVirtualConfig,
groupId: stableGroupId,
})
toast({
title: '虚拟身份标签页',
@@ -814,20 +662,14 @@ export function ChatPage() {
if (tabId === 'webui-default') {
return
}
// 关闭 WebSocket 连接
const ws = wsMapRef.current.get(tabId)
if (ws) {
ws.close()
wsMapRef.current.delete(tabId)
}
// 清理重连定时器
const timeout = reconnectTimeoutMapRef.current.get(tabId)
if (timeout) {
clearTimeout(timeout)
reconnectTimeoutMapRef.current.delete(tabId)
const unsubscribe = sessionUnsubscribeMapRef.current.get(tabId)
if (unsubscribe) {
unsubscribe()
sessionUnsubscribeMapRef.current.delete(tabId)
}
void chatWsClient.closeSession(tabId)
// 清理去重缓存
processedMessagesMapRef.current.delete(tabId)
@@ -1133,7 +975,7 @@ export function ChatPage() {
className="flex-1 h-10 sm:h-10"
/>
<Button
onClick={sendMessage}
onClick={() => { void sendMessage() }}
disabled={!activeTab?.isConnected || !inputValue.trim()}
size="icon"
className="h-10 w-10 shrink-0"

View File

@@ -93,12 +93,12 @@ function PluginsPageContent() {
// 统一管理 WebSocket 和数据加载
useEffect(() => {
let ws: WebSocket | null = null
let unsubscribeProgress: (() => Promise<void>) | null = null
let isUnmounted = false
const init = async () => {
// 1. 先连接 WebSocket异步获取 token
ws = await connectPluginProgressWebSocket(
unsubscribeProgress = await connectPluginProgressWebSocket(
(progress) => {
if (isUnmounted) return
@@ -128,29 +128,7 @@ function PluginsPageContent() {
}
)
// 2. 等待 WebSocket 连接建立
await new Promise<void>((resolve) => {
if (!ws) {
resolve()
return
}
const checkConnection = () => {
if (ws && ws.readyState === WebSocket.OPEN) {
console.log('WebSocket connected, starting to load plugins')
resolve()
} else if (ws && ws.readyState === WebSocket.CLOSED) {
console.warn('WebSocket closed before loading plugins')
resolve()
} else {
setTimeout(checkConnection, 100)
}
}
checkConnection()
})
// 3. 检查 Git 状态
// 2. 检查 Git 状态
if (!isUnmounted) {
const statusResult = await checkGitStatus()
if (!statusResult.success) {
@@ -173,7 +151,7 @@ function PluginsPageContent() {
}
}
// 4. 获取麦麦版本
// 3. 获取麦麦版本
if (!isUnmounted) {
const versionResult = await getMaimaiVersion()
if (!versionResult.success) {
@@ -186,7 +164,7 @@ function PluginsPageContent() {
setMaimaiVersion(versionResult.data)
}
}
// 5. 加载插件列表(包含已安装信息)
// 4. 加载插件列表(包含已安装信息)
if (!isUnmounted) {
try {
setLoading(true)
@@ -282,8 +260,8 @@ function PluginsPageContent() {
return () => {
isUnmounted = true
if (ws) {
ws.close()
if (unsubscribeProgress) {
void unsubscribeProgress()
}
}
}, [toast])

View File

@@ -1,887 +0,0 @@
{
"1": [
{
"id": "know_1_1774770946.623486",
"content": "备战中考",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:55:46.623486"
},
{
"id": "know_1_1774771765.051286",
"content": "性别为女性",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:09:25.051286"
},
{
"id": "know_1_1774771851.333504",
"content": "用户是I人内向型人格",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:10:51.333504"
},
{
"id": "know_1_1774771894.517183",
"content": "用户名为小千,被他人称为“宝宝”,结合语境推测为女性",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:11:34.517183"
},
{
"id": "know_1_1774771923.859455",
"content": "小千是I人内向型人格",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:12:03.859455"
},
{
"id": "know_1_1774771993.479732",
"content": "小千是女性",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:13:13.479732"
},
{
"id": "know_1_1774772079.496335",
"content": "用户名为小千,被他人称为“宝宝”,推测为女性或处于亲密社交语境中(注:性别非明确陈述,但基于昵称高频使用及语境,高置信度归纳为女性或女性化称呼偏好,若严格遵循“明确表达”则此项存疑。鉴于指令要求“高置信度可归纳”,且群内互动模式符合典型女性向昵称习惯,此处提取为倾向性事实)",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:14:39.496335"
},
{
"id": "know_1_1774773435.68612",
"content": "用户名为小千",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:37:15.686120"
},
{
"id": "know_1_1774773676.69252",
"content": "用户自称猫娘(二次元人设)",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:41:16.692520"
}
],
"2": [
{
"id": "know_2_1774768612.298128",
"content": "性格自信,常以“真理在我这边”自居",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:16:52.298128"
},
{
"id": "know_2_1774768645.029561",
"content": "性格自信且带有自嘲精神,喜欢用轻松调侃的方式应对他人评价",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:17:25.029561"
},
{
"id": "know_2_1774771068.355999",
"content": "喜欢用夸张、幽默或古风修辞表达观点",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:57:48.355999"
},
{
"id": "know_2_1774771397.764996",
"content": "性格幽默,喜欢使用夸张比喻和古风表达",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:03:17.764996"
},
{
"id": "know_2_1774771471.03367",
"content": "幽默风趣,喜欢使用夸张比喻和玩梗",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:04:31.033670"
},
{
"id": "know_2_1774771765.052285",
"content": "性格不孤僻,社交圈较广",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:09:25.052285"
},
{
"id": "know_2_1774771851.33601",
"content": "用户表现出社恐倾向,喜欢回避社交互动",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:10:51.336010"
},
{
"id": "know_2_1774771894.520185",
"content": "性格偏向内向I人有社恐倾向喜欢回避社交压力",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:11:34.520185"
},
{
"id": "know_2_1774771958.585244",
"content": "小千是内向型人格I人",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:12:38.585244"
},
{
"id": "know_2_1774771993.481732",
"content": "小千性格内向I人",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:13:13.481732"
}
],
"3": [
{
"id": "know_3_1774773676.695521",
"content": "喜欢冰淇淋",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:41:16.695521"
}
],
"4": [],
"5": [],
"6": [
{
"id": "know_6_1774768486.451792",
"content": "正在搭建 RAG 测试集",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:14:46.451792"
},
{
"id": "know_6_1774768517.122405",
"content": "熟悉 NapCat、RAG 等技术工具及互联网梗文化",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:15:17.122405"
},
{
"id": "know_6_1774769406.247087",
"content": "喜欢动漫风格插画",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:30:06.247087"
},
{
"id": "know_6_1774770487.207364",
"content": "关注显卡硬件参数(如显存、型号)及深度学习/炼丹应用",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:48:07.207364"
},
{
"id": "know_6_1774770487.209372",
"content": "对游戏光影效果感兴趣",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:48:07.209372"
},
{
"id": "know_6_1774770603.063873",
"content": "喜欢玩《我的世界》和VRChat",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:50:03.063873"
},
{
"id": "know_6_1774770654.654349",
"content": "关注显卡硬件参数如4090、48G显存、5090",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:50:54.654349"
},
{
"id": "know_6_1774770654.655356",
"content": "使用VRChat进行社交娱乐",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:50:54.655356"
},
{
"id": "know_6_1774770734.287947",
"content": "关注显卡硬件如4090、3050及AI炼丹技术",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:52:14.287947"
},
{
"id": "know_6_1774770734.289944",
"content": "玩《我的世界》并配置光影效果",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:52:14.289944"
},
{
"id": "know_6_1774770734.291944",
"content": "计划游玩VRChat",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:52:14.291944"
},
{
"id": "know_6_1774771033.111011",
"content": "喜欢玩VRChat",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:57:13.111011"
},
{
"id": "know_6_1774771068.358999",
"content": "关注VRChat等虚拟现实游戏及硬件性能话题",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:57:48.358999"
},
{
"id": "know_6_1774771233.980219",
"content": "使用VRChatVRC",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:00:33.980219"
},
{
"id": "know_6_1774771397.766996",
"content": "对VRChatVRC及虚拟形象社交感兴趣",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:03:17.766996"
},
{
"id": "know_6_1774771471.03567",
"content": "对VRChat等虚拟社交游戏感兴趣",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:04:31.035670"
},
{
"id": "know_6_1774771894.521183",
"content": "熟悉二次元文化、动漫角色及互联网流行梗Meme",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:11:34.521183"
},
{
"id": "know_6_1774771923.861534",
"content": "小千玩CS:GO游戏",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:12:03.861534"
},
{
"id": "know_6_1774771958.587243",
"content": "回声者_Echoderd喜欢玩CS:GO游戏",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:12:38.587243"
},
{
"id": "know_6_1774771993.483732",
"content": "小千喜欢二次元文化及动漫游戏圈梗",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:13:13.483732"
},
{
"id": "know_6_1774772079.499335",
"content": "熟悉并喜爱二次元文化、动漫角色及互联网梗图(如阴间美学、病娇系、黑长直萌妹等风格)",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:14:39.499335"
},
{
"id": "know_6_1774772112.716455",
"content": "小千关注CS:GO游戏及中考备考话题",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:15:12.716455"
},
{
"id": "know_6_1774772154.873237",
"content": "用户玩CS:GO游戏",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:15:54.873237"
},
{
"id": "know_6_1774772186.438797",
"content": "玩CS:GO游戏",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:16:26.438797"
},
{
"id": "know_6_1774772730.867535",
"content": "熟悉《我的青春恋爱物语果然有问题》及二次元表情包文化",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:25:30.867535"
},
{
"id": "know_6_1774773338.849271",
"content": "熟悉《原神》等二次元游戏及网络梗文化",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:35:38.849271"
},
{
"id": "know_6_1774773371.406209",
"content": "关注高分屏字体显示效果",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:36:11.406209"
},
{
"id": "know_6_1774773401.48921",
"content": "熟悉电脑显示技术(如高分屏字体选择)",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:36:41.489210"
},
{
"id": "know_6_1774773435.688119",
"content": "关注高分屏显示效果与字体选择(无衬线/衬线体)",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:37:15.688119"
},
{
"id": "know_6_1774773608.256103",
"content": "关注屏幕字体与分辨率(无衬线/有衬线)",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:40:08.256103"
},
{
"id": "know_6_1774773645.671546",
"content": "关注屏幕分辨率与字体显示效果(高分屏/无衬线体)",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:40:45.671546"
},
{
"id": "know_6_1774773676.698035",
"content": "关注字体设计(无衬线体)",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:41:16.698035"
},
{
"id": "know_6_1774773740.83822",
"content": "喜欢二次元文化及 VTuber 风格内容",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:42:20.838220"
}
],
"7": [
{
"id": "know_7_1774768517.120403",
"content": "从事 RAG 测试集搭建或相关技术工作",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:15:17.120403"
},
{
"id": "know_7_1774768573.741823",
"content": "从事 RAG检索增强生成测试集搭建相关工作",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:16:13.741823"
},
{
"id": "know_7_1774770603.062873",
"content": "备战中考",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:50:03.062873"
},
{
"id": "know_7_1774771471.036668",
"content": "正在备战中考的学生",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:04:31.036668"
},
{
"id": "know_7_1774771923.862535",
"content": "小千正在备战中考",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:12:03.862535"
},
{
"id": "know_7_1774771958.588749",
"content": "回声者_Echoderd正在备战中考",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:12:38.588749"
},
{
"id": "know_7_1774772112.714455",
"content": "小千使用AI模型进行对话",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:15:12.714455"
},
{
"id": "know_7_1774772154.870238",
"content": "用户正在备战中考",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:15:54.870238"
},
{
"id": "know_7_1774773185.194069",
"content": "使用 NapCat 框架",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:33:05.194069"
},
{
"id": "know_7_1774773338.851275",
"content": "使用 NapCat 框架,具备技术平台认知能力",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:35:38.851275"
},
{
"id": "know_7_1774773371.403696",
"content": "熟悉 NapCat 框架",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:36:11.403696"
}
],
"8": [
{
"id": "know_8_1774770946.624486",
"content": "日常逛游戏地图",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:55:46.624486"
},
{
"id": "know_8_1774771397.769034",
"content": "备考中考期间仍保持日常游戏娱乐习惯",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:03:17.769034"
},
{
"id": "know_8_1774771851.338018",
"content": "用户有备考中考的学习任务",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:10:51.338018"
},
{
"id": "know_8_1774771894.523189",
"content": "备考中(备战中考)",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:11:34.523189"
},
{
"id": "know_8_1774771993.484733",
"content": "小千有打CS:GO的游戏习惯",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:13:13.484733"
},
{
"id": "know_8_1774772079.501334",
"content": "有在高压环境下如中考前进行游戏娱乐CS:GO的习惯自称或认同“摆烂”的生活态度",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:14:39.501334"
},
{
"id": "know_8_1774772154.875743",
"content": "用户在备考期间有打游戏摸鱼的习惯",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:15:54.875743"
},
{
"id": "know_8_1774773435.690121",
"content": "习惯使用表情包表达情绪或进行网络互动",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:37:15.690121"
},
{
"id": "know_8_1774773676.701034",
"content": "备战中考",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:41:16.701034"
}
],
"9": [],
"10": [
{
"id": "know_10_1774768486.452792",
"content": "沟通风格带有调侃和自信,习惯用反问句表达观点",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:14:46.452792"
},
{
"id": "know_10_1774768517.121403",
"content": "沟通风格带有较强的好胜心和防御性,习惯用反问和调侃回应质疑",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:15:17.121403"
},
{
"id": "know_10_1774768573.742824",
"content": "沟通风格幽默,擅长使用逻辑闭环和反问句式进行辩论或调侃",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:16:13.742824"
},
{
"id": "know_10_1774768612.299126",
"content": "沟通风格幽默风趣,擅长使用网络梗和表情包互动",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:16:52.299126"
},
{
"id": "know_10_1774768612.299845",
"content": "偶尔会文绉绉地表达(自称“文青病犯了”),但能迅速切换回口语化",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:16:52.299845"
},
{
"id": "know_10_1774768645.028561",
"content": "沟通风格幽默风趣,偶尔会文青病发作使用古风表达",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:17:25.028561"
},
{
"id": "know_10_1774769406.249584",
"content": "沟通中常使用文言文或半文言表达",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:30:06.249584"
},
{
"id": "know_10_1774769406.251097",
"content": "习惯用反问句和夸张语气进行互动",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:30:06.251097"
},
{
"id": "know_10_1774770487.211056",
"content": "沟通风格幽默,常使用网络梗和夸张表达",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:48:07.211056"
},
{
"id": "know_10_1774771471.038677",
"content": "沟通风格轻松随意,善于接话和调侃",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:04:31.038677"
},
{
"id": "know_10_1774771765.053285",
"content": "沟通风格活泼,喜欢使用语气词和表情符号撒娇",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:09:25.053285"
},
{
"id": "know_10_1774772079.503333",
"content": "沟通风格幽默调侃,擅长用反话(如“烦到了”)和夸张修辞(如“耳朵起茧子”、“要报警了”)表达情绪",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:14:39.503333"
},
{
"id": "know_10_1774773338.853274",
"content": "沟通风格幽默风趣,擅长玩梗与自嘲",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:35:38.853274"
},
{
"id": "know_10_1774773371.408719",
"content": "喜欢用幽默调侃的方式回应他人",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:36:11.408719"
},
{
"id": "know_10_1774773401.491209",
"content": "沟通风格幽默风趣,擅长玩梗和角色扮演",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:36:41.491209"
},
{
"id": "know_10_1774773435.693121",
"content": "沟通风格幽默、喜欢玩梗和自嘲,擅长接话茬",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:37:15.693121"
},
{
"id": "know_10_1774773532.488374",
"content": "沟通风格幽默,喜欢使用网络梗和表情包活跃气氛",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:38:52.488374"
},
{
"id": "know_10_1774773532.490959",
"content": "在争论中倾向于据理力争,并自嘲或调侃对方阅读理解能力",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:38:52.490959"
},
{
"id": "know_10_1774773569.709356",
"content": "喜欢用幽默、夸张和自嘲的方式活跃气氛",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:39:29.709356"
}
],
"11": [
{
"id": "know_11_1774771068.360999",
"content": "乐于接受并学习新的技术技巧(如加速器用法)",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:57:48.360999"
}
],
"12": [
{
"id": "know_12_1774770654.657355",
"content": "面对网络延迟问题倾向于寻找加速器解决方案",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:50:54.657355"
},
{
"id": "know_12_1774773185.196068",
"content": "备战中考",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:33:05.196068"
},
{
"id": "know_12_1774773740.836223",
"content": "面对压力或冲突时,倾向于通过撒娇、耍赖和寻求盟友支持来应对",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:42:20.836223"
}
]
}

View File

@@ -13,12 +13,13 @@ import pytest
from src.plugin_runtime.component_query import component_query_service
from src.plugin_runtime.protocol.envelope import (
Envelope,
InspectPluginConfigPayload,
MessageType,
RegisterPluginPayload,
ValidatePluginConfigPayload,
)
from src.plugin_runtime.runner.runner_main import PluginRunner
from src.webui.routers.plugin.config_routes import update_plugin_config
from src.webui.routers.plugin.config_routes import get_plugin_config, get_plugin_config_schema, update_plugin_config
from src.webui.routers.plugin.schemas import UpdatePluginConfigRequest
@@ -56,6 +57,61 @@ class _DemoConfigPlugin:
self.received_config = config
def get_default_config(self) -> Dict[str, Any]:
"""返回测试插件的默认配置。
Returns:
Dict[str, Any]: 默认配置字典。
"""
return {"plugin": {"enabled": True, "retry_count": 3}}
def get_webui_config_schema(
self,
*,
plugin_id: str = "",
plugin_name: str = "",
plugin_version: str = "",
plugin_description: str = "",
plugin_author: str = "",
) -> Dict[str, Any]:
"""返回测试插件的 WebUI 配置 Schema。
Args:
plugin_id: 插件 ID。
plugin_name: 插件名称。
plugin_version: 插件版本。
plugin_description: 插件描述。
plugin_author: 插件作者。
Returns:
Dict[str, Any]: 测试配置 Schema。
"""
del plugin_name, plugin_description, plugin_author
return {
"plugin_id": plugin_id,
"plugin_info": {
"name": "Demo",
"version": plugin_version,
"description": "",
"author": "",
},
"sections": {
"plugin": {
"fields": {
"enabled": {
"type": "boolean",
"label": "启用",
"default": True,
"ui_type": "switch",
}
}
}
},
"layout": {"type": "auto", "tabs": []},
}
class _StrictConfigPlugin:
"""用于测试配置校验错误的伪插件。"""
@@ -173,6 +229,63 @@ async def test_runner_validate_plugin_config_handler_returns_normalized_config(m
assert response.payload["normalized_config"] == {"plugin": {"enabled": False, "retry_count": 3}}
@pytest.mark.asyncio
async def test_runner_inspect_plugin_config_handler_supports_unloaded_plugin(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Runner 应支持对未加载插件执行冷检查。"""
plugin = _DemoConfigPlugin()
runner = PluginRunner(
host_address="ipc://unused",
session_token="session-token",
plugin_dirs=[],
)
meta = SimpleNamespace(
plugin_id="demo.plugin",
plugin_dir="/tmp/demo-plugin",
instance=plugin,
manifest=SimpleNamespace(
name="Demo",
description="",
author=SimpleNamespace(name="tester"),
),
version="1.0.0",
)
purged_plugins: list[tuple[str, str]] = []
monkeypatch.setattr(
runner,
"_resolve_plugin_meta_for_config_request",
lambda plugin_id: (meta, True, None) if plugin_id == "demo.plugin" else (None, False, "not-found"),
)
monkeypatch.setattr(
runner._loader,
"purge_plugin_modules",
lambda plugin_id, plugin_dir: purged_plugins.append((plugin_id, plugin_dir)),
)
envelope = Envelope(
request_id=1,
message_type=MessageType.REQUEST,
method="plugin.inspect_config",
plugin_id="demo.plugin",
payload=InspectPluginConfigPayload(
config_data={"plugin": {"enabled": False}},
use_provided_config=True,
).model_dump(),
)
response = await runner._handle_inspect_plugin_config(envelope)
assert response.error is None
assert response.payload["success"] is True
assert response.payload["enabled"] is False
assert response.payload["normalized_config"] == {"plugin": {"enabled": False, "retry_count": 3}}
assert response.payload["default_config"] == {"plugin": {"enabled": True, "retry_count": 3}}
assert purged_plugins == [("demo.plugin", "/tmp/demo-plugin")]
@pytest.mark.asyncio
async def test_runner_validate_plugin_config_handler_returns_error_on_invalid_config(
monkeypatch: pytest.MonkeyPatch,
@@ -251,3 +364,73 @@ async def test_update_plugin_config_prefers_runtime_validation(
with config_path.open("rb") as handle:
saved_config = tomllib.load(handle)
assert saved_config == {"plugin": {"enabled": False, "retry_count": 3}}
@pytest.mark.asyncio
async def test_webui_config_endpoints_use_runtime_inspection_for_unloaded_plugin(
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> None:
"""WebUI 在插件未加载时也应从代码定义返回配置与 Schema。"""
async def _mock_inspect_plugin_config(
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
*,
use_provided_config: bool = False,
) -> SimpleNamespace | None:
"""返回运行时冷检查结果。
Args:
plugin_id: 插件 ID。
config_data: 可选配置。
use_provided_config: 是否使用传入配置。
Returns:
SimpleNamespace | None: 冷检查结果。
"""
del config_data, use_provided_config
if plugin_id != "demo.plugin":
return None
return SimpleNamespace(
config_schema={
"plugin_id": "demo.plugin",
"plugin_info": {
"name": "Demo",
"version": "1.0.0",
"description": "",
"author": "",
},
"sections": {"plugin": {"fields": {}}},
"layout": {"type": "auto", "tabs": []},
},
normalized_config={"plugin": {"enabled": True, "retry_count": 3}},
enabled=True,
)
fake_runtime_manager = SimpleNamespace(inspect_plugin_config=_mock_inspect_plugin_config)
monkeypatch.setattr(
"src.webui.routers.plugin.config_routes.require_plugin_token",
lambda session: session or "session-token",
)
monkeypatch.setattr(
"src.webui.routers.plugin.config_routes.find_plugin_path_by_id",
lambda plugin_id: tmp_path if plugin_id == "demo.plugin" else None,
)
monkeypatch.setattr(
"src.plugin_runtime.integration.get_plugin_runtime_manager",
lambda: fake_runtime_manager,
)
schema_response = await get_plugin_config_schema("demo.plugin", maibot_session="session-token")
config_response = await get_plugin_config("demo.plugin", maibot_session="session-token")
assert schema_response["success"] is True
assert schema_response["schema"]["plugin_id"] == "demo.plugin"
assert config_response == {
"success": True,
"config": {"plugin": {"enabled": True, "retry_count": 3}},
"message": "配置文件不存在,已返回默认配置",
}

View File

@@ -5,7 +5,7 @@
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Awaitable, Callable, Dict, List, Optional
from typing import Any, Awaitable, Callable, Dict, List, Optional, Sequence
import asyncio
import json
@@ -1405,6 +1405,57 @@ class TestComponentRegistry:
assert warnings
assert "plugin_a.broken" in warnings[0]
def test_register_hook_handler_rejects_unknown_hook(self):
from src.plugin_runtime.host.component_registry import ComponentRegistrationError, ComponentRegistry
from src.plugin_runtime.host.hook_spec_registry import HookSpecRegistry
reg = ComponentRegistry(hook_spec_registry=HookSpecRegistry())
with pytest.raises(ComponentRegistrationError, match="未注册的 Hook"):
reg.register_component(
"broken_hook",
"hook_handler",
"plugin_a",
{
"hook": "chat.receive.unknown",
"mode": "blocking",
},
)
def test_register_plugin_components_is_atomic_when_hook_invalid(self):
from src.plugin_runtime.host.component_registry import ComponentRegistrationError, ComponentRegistry
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
hook_spec_registry = HookSpecRegistry()
hook_spec_registry.register_hook_spec(HookSpec(name="chat.receive.before_process"))
reg = ComponentRegistry(hook_spec_registry=hook_spec_registry)
reg.register_plugin_components(
"plugin_a",
[
{"name": "cmd_old", "component_type": "command", "metadata": {"command_pattern": r"^/old"}},
],
)
with pytest.raises(ComponentRegistrationError, match="未注册的 Hook"):
reg.register_plugin_components(
"plugin_a",
[
{
"name": "hook_ok",
"component_type": "hook_handler",
"metadata": {"hook": "chat.receive.before_process", "mode": "blocking"},
},
{
"name": "hook_bad",
"component_type": "hook_handler",
"metadata": {"hook": "chat.receive.missing", "mode": "blocking"},
},
],
)
assert reg.get_component("plugin_a.cmd_old") is not None
assert reg.get_component("plugin_a.hook_ok") is None
def test_query_by_type(self):
from src.plugin_runtime.host.component_registry import ComponentRegistry
@@ -2142,6 +2193,18 @@ class TestPluginRuntimeHookEntry:
assert result.kwargs["session_id"] == "s-1"
assert ("b1", "builtin_guard") in call_log
def test_manager_lists_builtin_hook_specs(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""PluginRuntimeManager 应暴露内置 Hook 规格清单。"""
_ComponentRegistry, PluginRuntimeManager = self._import_manager_modules(monkeypatch)
manager = PluginRuntimeManager()
hook_names = {spec.name for spec in manager.list_hook_specs()}
assert "chat.receive.before_process" in hook_names
assert "send_service.before_send" in hook_names
assert "maisaka.planner.after_response" in hook_names
class TestRPCServer:
"""RPC Server 代际保护测试"""
@@ -2974,6 +3037,16 @@ class TestIntegration:
self._registered_plugins = {plugin_id: object() for plugin_id in plugins}
self.config_updates = []
async def inspect_plugin_config(
self,
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
use_provided_config: bool = False,
) -> SimpleNamespace:
"""返回测试用的配置解析结果。"""
del config_data, use_provided_config
return SimpleNamespace(enabled=True, normalized_config={"enabled": True}, plugin_id=plugin_id)
async def notify_plugin_config_updated(
self,
plugin_id,
@@ -2997,6 +3070,110 @@ class TestIntegration:
assert manager._builtin_supervisor.config_updates == [("test.alpha", {"enabled": True}, "", "self")]
assert manager._third_party_supervisor.config_updates == []
@pytest.mark.asyncio
async def test_handle_plugin_config_changes_loads_unloaded_enabled_plugin(self, monkeypatch, tmp_path):
from src.plugin_runtime import integration as integration_module
from src.config.file_watcher import FileChange
import json
thirdparty_root = tmp_path / "plugins"
alpha_dir = thirdparty_root / "alpha"
alpha_dir.mkdir(parents=True)
(alpha_dir / "config.toml").write_text("[plugin]\nenabled = true\n", encoding="utf-8")
(alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
(alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8")
monkeypatch.chdir(tmp_path)
class FakeSupervisor:
def __init__(self, plugin_dirs):
self._plugin_dirs = plugin_dirs
self._registered_plugins = {}
async def inspect_plugin_config(
self,
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
use_provided_config: bool = False,
) -> SimpleNamespace:
"""返回测试用的启用配置快照。"""
del config_data, use_provided_config
return SimpleNamespace(enabled=True, normalized_config={"plugin": {"enabled": True}}, plugin_id=plugin_id)
manager = integration_module.PluginRuntimeManager()
manager._started = True
manager._third_party_supervisor = FakeSupervisor([thirdparty_root])
load_calls = []
async def fake_load_plugin_globally(plugin_id: str, reason: str = "manual") -> bool:
"""记录自动加载调用。"""
load_calls.append((plugin_id, reason))
return True
monkeypatch.setattr(manager, "load_plugin_globally", fake_load_plugin_globally)
await manager._handle_plugin_config_changes(
"test.alpha",
[FileChange(change_type=1, path=alpha_dir / "config.toml")],
)
assert load_calls == [("test.alpha", "config_enabled")]
@pytest.mark.asyncio
async def test_handle_plugin_config_changes_unloads_loaded_disabled_plugin(self, monkeypatch, tmp_path):
from src.plugin_runtime import integration as integration_module
from src.config.file_watcher import FileChange
import json
builtin_root = tmp_path / "src" / "plugins" / "built_in"
alpha_dir = builtin_root / "alpha"
alpha_dir.mkdir(parents=True)
(alpha_dir / "config.toml").write_text("[plugin]\nenabled = false\n", encoding="utf-8")
(alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
(alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8")
monkeypatch.chdir(tmp_path)
class FakeSupervisor:
def __init__(self, plugin_dirs, plugins):
self._plugin_dirs = plugin_dirs
self._registered_plugins = {plugin_id: object() for plugin_id in plugins}
async def inspect_plugin_config(
self,
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
use_provided_config: bool = False,
) -> SimpleNamespace:
"""返回测试用的禁用配置快照。"""
del config_data, use_provided_config
return SimpleNamespace(
enabled=False,
normalized_config={"plugin": {"enabled": False}},
plugin_id=plugin_id,
)
manager = integration_module.PluginRuntimeManager()
manager._started = True
manager._builtin_supervisor = FakeSupervisor([builtin_root], ["test.alpha"])
reload_calls = []
async def fake_reload_plugins_globally(plugin_ids: Sequence[str], reason: str = "manual") -> bool:
"""记录自动卸载调用。"""
reload_calls.append((list(plugin_ids), reason))
return True
monkeypatch.setattr(manager, "reload_plugins_globally", fake_reload_plugins_globally)
await manager._handle_plugin_config_changes(
"test.alpha",
[FileChange(change_type=1, path=alpha_dir / "config.toml")],
)
assert reload_calls == [(["test.alpha"], "config_disabled")]
@pytest.mark.asyncio
async def test_handle_main_config_reload_only_notifies_subscribers(self, monkeypatch):
from src.plugin_runtime import integration as integration_module
@@ -3108,6 +3285,55 @@ class TestIntegration:
subscription["paths"][0] for subscription in manager._plugin_file_watcher.subscriptions
} == {alpha_dir / "config.toml", beta_dir / "config.toml"}
def test_refresh_plugin_config_watch_subscriptions_includes_unloaded_plugins(self, tmp_path):
from src.plugin_runtime import integration as integration_module
import json
thirdparty_root = tmp_path / "plugins"
alpha_dir = thirdparty_root / "alpha"
beta_dir = thirdparty_root / "beta"
alpha_dir.mkdir(parents=True)
beta_dir.mkdir(parents=True)
(alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
(beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
(alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8")
(beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8")
class FakeWatcher:
def __init__(self):
self.subscriptions = []
def subscribe(
self,
callback: Any,
*,
paths: Optional[Sequence[Path]] = None,
change_types: Any = None,
) -> str:
"""记录新的监听订阅。"""
del callback, change_types
subscription_id = f"sub-{len(self.subscriptions) + 1}"
self.subscriptions.append({"id": subscription_id, "paths": tuple(paths or ())})
return subscription_id
def unsubscribe(self, subscription_id: str) -> bool:
"""兼容 watcher 取消订阅接口。"""
del subscription_id
return True
class FakeSupervisor:
def __init__(self, plugin_dirs, plugins):
self._plugin_dirs = plugin_dirs
self._registered_plugins = {plugin_id: object() for plugin_id in plugins}
manager = integration_module.PluginRuntimeManager()
manager._plugin_file_watcher = FakeWatcher()
manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["test.alpha"])
manager._refresh_plugin_config_watch_subscriptions()
assert set(manager._plugin_config_watcher_subscriptions.keys()) == {"test.alpha", "test.beta"}
@pytest.mark.asyncio
async def test_component_reload_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch):
from src.plugin_runtime import integration as integration_module

View File

@@ -0,0 +1,136 @@
"""业务命名 Hook 集成测试。"""
from types import SimpleNamespace
from typing import Any
import os
import sys
import pytest
# 确保项目根目录在 sys.path 中
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
# SDK 包路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "maibot-plugin-sdk"))
class _FakeHookManager:
"""用于业务 Hook 测试的最小运行时管理器。"""
def __init__(self, responses: dict[str, SimpleNamespace]) -> None:
"""初始化测试管理器。
Args:
responses: 按 Hook 名称预设的返回结果映射。
"""
self._responses = responses
self.calls: list[tuple[str, dict[str, Any]]] = []
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> SimpleNamespace:
"""模拟调用运行时命名 Hook。
Args:
hook_name: 目标 Hook 名称。
**kwargs: 传入 Hook 的参数。
Returns:
SimpleNamespace: 预设的 Hook 返回结果。
"""
self.calls.append((hook_name, dict(kwargs)))
return self._responses.get(hook_name, SimpleNamespace(kwargs=dict(kwargs), aborted=False))
def test_builtin_hook_catalog_includes_new_business_hooks(monkeypatch: pytest.MonkeyPatch) -> None:
"""内置 Hook 目录应包含三个业务系统新增的 Hook。"""
monkeypatch.setattr(sys, "exit", lambda code=0: None)
from src.plugin_runtime.hook_catalog import register_builtin_hook_specs
from src.plugin_runtime.host.hook_spec_registry import HookSpecRegistry
registry = HookSpecRegistry()
hook_names = {spec.name for spec in register_builtin_hook_specs(registry)}
assert "emoji.maisaka.before_select" in hook_names
assert "emoji.register.after_build_emotion" in hook_names
assert "jargon.extract.before_persist" in hook_names
assert "jargon.query.after_search" in hook_names
assert "expression.select.before_select" in hook_names
assert "expression.learn.before_upsert" in hook_names
@pytest.mark.asyncio
async def test_send_emoji_for_maisaka_can_be_aborted_by_hook(monkeypatch: pytest.MonkeyPatch) -> None:
"""表情包系统应允许在选择前被 Hook 中止。"""
from src.chat.emoji_system import maisaka_tool
fake_manager = _FakeHookManager(
{
"emoji.maisaka.before_select": SimpleNamespace(
kwargs={"abort_message": "插件阻止了表情发送。"},
aborted=True,
)
}
)
monkeypatch.setattr(maisaka_tool, "_get_runtime_manager", lambda: fake_manager)
result = await maisaka_tool.send_emoji_for_maisaka(stream_id="stream-1", requested_emotion="开心")
assert result.success is False
assert result.message == "插件阻止了表情发送。"
assert fake_manager.calls[0][0] == "emoji.maisaka.before_select"
@pytest.mark.asyncio
async def test_jargon_extract_can_be_aborted_before_persist(monkeypatch: pytest.MonkeyPatch) -> None:
"""黑话提取结果应允许在写库前被 Hook 中止。"""
from src.learners.jargon_miner import JargonMiner
fake_manager = _FakeHookManager(
{
"jargon.extract.before_persist": SimpleNamespace(
kwargs={"entries": []},
aborted=True,
)
}
)
monkeypatch.setattr(JargonMiner, "_get_runtime_manager", staticmethod(lambda: fake_manager))
miner = JargonMiner(session_id="session-1", session_name="测试会话")
await miner.process_extracted_entries(
[{"content": "yyds", "raw_content": {"[1] yyds 太强了"}}],
)
assert fake_manager.calls[0][0] == "jargon.extract.before_persist"
assert fake_manager.calls[0][1]["session_id"] == "session-1"
@pytest.mark.asyncio
async def test_expression_selection_can_be_aborted_by_hook(monkeypatch: pytest.MonkeyPatch) -> None:
"""表达方式选择流程应允许在开始前被 Hook 中止。"""
from src.learners.expression_selector import ExpressionSelector
fake_manager = _FakeHookManager(
{
"expression.select.before_select": SimpleNamespace(
kwargs={},
aborted=True,
)
}
)
monkeypatch.setattr(ExpressionSelector, "_get_runtime_manager", staticmethod(lambda: fake_manager))
monkeypatch.setattr(ExpressionSelector, "can_use_expression_for_chat", lambda self, chat_id: True)
selector = ExpressionSelector()
selected_expressions, selected_ids = await selector.select_suitable_expressions(
chat_id="session-1",
chat_info="用户刚刚发来一条消息。",
)
assert selected_expressions == []
assert selected_ids == []
assert fake_manager.calls[0][0] == "expression.select.before_select"

View File

@@ -1,25 +1,28 @@
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from rich.traceback import install
from sqlmodel import select
from typing import Any, Dict, List, Optional, Tuple
import asyncio
import hashlib
import heapq
import Levenshtein
import random
import re
from src.common.logger import get_logger
from rich.traceback import install
from sqlmodel import select
import Levenshtein
from src.common.data_models.image_data_model import MaiEmoji
from src.common.database.database_model import Images, ImageType
from src.common.database.database import get_db_session, get_db_session_manual
from src.common.utils.utils_image import ImageUtils
from src.prompt.prompt_manager import prompt_manager
from src.config.config import config_manager, global_config
from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMImageOptions
from src.common.database.database import get_db_session, get_db_session_manual
from src.common.database.database_model import Images, ImageType
from src.common.logger import get_logger
from src.common.utils.utils_image import ImageUtils
from src.config.config import config_manager, global_config
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
logger = get_logger("emoji")
@@ -33,6 +36,171 @@ EMOJI_REGISTERED_DIR = DATA_DIR / "emoji_registered" # 已注册的表情包注
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册表情包系统内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
emoji_schema = {
"type": "object",
"description": "当前表情包的序列化信息,主要包含 file_hash、description、emotions 等字段。",
}
string_array_schema = {
"type": "array",
"items": {"type": "string"},
}
return registry.register_hook_specs(
[
HookSpec(
name="emoji.maisaka.before_select",
description="Maisaka 表情发送工具选择表情前触发,可改写情绪、上下文和采样参数,或中止本次选择。",
parameters_schema=build_object_schema(
{
"stream_id": {"type": "string", "description": "目标会话 ID。"},
"requested_emotion": {"type": "string", "description": "请求的目标情绪标签。"},
"reasoning": {"type": "string", "description": "本次发送表情的推理理由。"},
"context_texts": {
**string_array_schema,
"description": "最近聊天上下文文本列表。",
},
"sample_size": {"type": "integer", "description": "候选表情采样数量。"},
"abort_message": {
"type": "string",
"description": "当 Hook 主动中止时可附带的失败提示。",
},
},
required=["stream_id", "requested_emotion", "reasoning", "context_texts", "sample_size"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="emoji.maisaka.after_select",
description="Maisaka 已选出表情后触发,可替换选中的表情哈希、补充匹配情绪,或中止发送。",
parameters_schema=build_object_schema(
{
"stream_id": {"type": "string", "description": "目标会话 ID。"},
"requested_emotion": {"type": "string", "description": "请求的目标情绪标签。"},
"reasoning": {"type": "string", "description": "本次发送表情的推理理由。"},
"context_texts": {
**string_array_schema,
"description": "最近聊天上下文文本列表。",
},
"sample_size": {"type": "integer", "description": "候选表情采样数量。"},
"selected_emoji": emoji_schema,
"selected_emoji_hash": {"type": "string", "description": "选中的表情哈希。"},
"matched_emotion": {"type": "string", "description": "最终命中的情绪标签。"},
"abort_message": {
"type": "string",
"description": "当 Hook 主动中止时可附带的失败提示。",
},
},
required=[
"stream_id",
"requested_emotion",
"reasoning",
"context_texts",
"sample_size",
"matched_emotion",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="emoji.register.after_build_description",
description="表情包描述生成并通过内容审查后触发,可改写描述文本或拒绝本次注册。",
parameters_schema=build_object_schema(
{
"emoji": emoji_schema,
"description": {"type": "string", "description": "当前生成出的表情包描述。"},
"image_format": {"type": "string", "description": "表情图片格式。"},
},
required=["emoji", "description", "image_format"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="emoji.register.after_build_emotion",
description="表情包情绪标签生成完成后触发,可改写标签列表或拒绝本次注册。",
parameters_schema=build_object_schema(
{
"emoji": emoji_schema,
"description": {"type": "string", "description": "当前表情包描述。"},
"emotions": {
**string_array_schema,
"description": "当前生成出的情绪标签列表。",
},
},
required=["emoji", "description", "emotions"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
]
)
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, Any]]:
"""将表情包对象序列化为 Hook 可传输载荷。
Args:
emoji: 待序列化的表情包对象。
Returns:
Optional[Dict[str, Any]]: 序列化后的字典;当表情为空时返回 ``None``。
"""
if emoji is None:
return None
return {
"file_hash": str(emoji.file_hash or "").strip(),
"file_name": emoji.file_name,
"full_path": str(emoji.full_path),
"description": emoji.description,
"emotions": [str(item).strip() for item in emoji.emotion if str(item).strip()],
"query_count": int(emoji.query_count),
}
def _normalize_string_list(raw_values: Any) -> List[str]:
"""将任意列表值规范化为字符串列表。
Args:
raw_values: 待规范化的原始值。
Returns:
List[str]: 去空白后的字符串列表。
"""
if not isinstance(raw_values, list):
return []
return [str(item).strip() for item in raw_values if str(item).strip()]
def _ensure_directories() -> None:
"""确保表情包相关目录存在"""
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
@@ -642,6 +810,22 @@ class EmojiManager:
if "" in llm_response:
logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}")
return False, target_emoji
hook_result = await _get_runtime_manager().invoke_hook(
"emoji.register.after_build_description",
emoji=_serialize_emoji_for_hook(target_emoji),
description=description,
image_format=image_format,
)
if hook_result.aborted:
logger.info(f"[构建描述] 表情包描述被 Hook 中止注册: {target_emoji.file_name}")
return False, target_emoji
normalized_description = str(hook_result.kwargs.get("description", description) or "").strip()
if not normalized_description:
logger.warning(f"[构建描述] Hook 返回空描述,拒绝注册: {target_emoji.file_name}")
return False, target_emoji
description = normalized_description
target_emoji.description = description
logger.info(f"[构建描述] 成功为表情包构建描述: {target_emoji.description}")
return True, target_emoji
@@ -687,6 +871,23 @@ class EmojiManager:
elif len(emotions) > 2:
emotions = random.sample(emotions, 2)
hook_result = await _get_runtime_manager().invoke_hook(
"emoji.register.after_build_emotion",
emoji=_serialize_emoji_for_hook(target_emoji),
description=target_emoji.description,
emotions=list(emotions),
)
if hook_result.aborted:
logger.info(f"[构建情感标签] 表情包情感标签被 Hook 中止注册: {target_emoji.file_name}")
return False, target_emoji
raw_emotions = hook_result.kwargs.get("emotions")
if raw_emotions is not None:
emotions = _normalize_string_list(raw_emotions)
if not emotions:
logger.warning(f"[构建情感标签] Hook 返回空情绪标签,拒绝注册: {target_emoji.file_name}")
return False, target_emoji
logger.info(f"[构建情感标签] 成功为表情包构建情感标签: {','.join(emotions)}")
target_emoji.emotion = emotions
return True, target_emoji

View File

@@ -1,7 +1,7 @@
"""Maisaka 表情工具内置能力。"""
from dataclasses import dataclass, field
from typing import Sequence
from typing import Any, Optional, Sequence
import random
@@ -11,7 +11,7 @@ from src.common.logger import get_logger
from src.common.utils.utils_image import ImageUtils
from src.services import send_service
from .emoji_manager import emoji_manager, emoji_manager_emotion_judge_llm
from .emoji_manager import _serialize_emoji_for_hook, emoji_manager, emoji_manager_emotion_judge_llm
logger = get_logger("emoji_maisaka_tool")
@@ -29,6 +29,76 @@ class MaisakaEmojiSendResult:
matched_emotion: str = ""
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
def _coerce_positive_int(value: Any, default: int) -> int:
"""将任意值安全转换为正整数。
Args:
value: 待转换的值。
default: 转换失败时使用的默认值。
Returns:
int: 规范化后的正整数。
"""
try:
normalized_value = int(value)
except (TypeError, ValueError):
return default
return normalized_value if normalized_value > 0 else default
def _normalize_context_texts(context_texts: Sequence[str] | None) -> list[str]:
"""清洗 Hook 和调用链传入的上下文文本列表。
Args:
context_texts: 原始上下文文本序列。
Returns:
list[str]: 过滤空白后的上下文文本列表。
"""
if not context_texts:
return []
return [str(item).strip() for item in context_texts if str(item).strip()]
def _resolve_selected_emoji(raw_value: Any) -> Optional[MaiEmoji]:
"""根据 Hook 返回值解析目标表情包对象。
Args:
raw_value: Hook 返回的 ``selected_emoji`` 或 ``selected_emoji_hash``。
Returns:
Optional[MaiEmoji]: 命中的表情包对象;未命中时返回 ``None``。
"""
raw_hash: str = ""
if isinstance(raw_value, dict):
raw_hash = str(raw_value.get("file_hash") or raw_value.get("hash") or "").strip()
elif isinstance(raw_value, str):
raw_hash = raw_value.strip()
if not raw_hash:
return None
for emoji in emoji_manager.emojis:
if emoji.file_hash == raw_hash:
return emoji
return None
def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
"""提取并清洗单个表情的情绪标签。"""
@@ -129,16 +199,81 @@ async def send_emoji_for_maisaka(
) -> MaisakaEmojiSendResult:
"""为 Maisaka 选择并发送一个表情。"""
selected_emoji, matched_emotion = await select_emoji_for_maisaka(
requested_emotion=requested_emotion,
reasoning=reasoning,
context_texts=context_texts,
normalized_requested_emotion = requested_emotion.strip()
normalized_reasoning = reasoning.strip()
normalized_context_texts = _normalize_context_texts(context_texts)
sample_size = 30
before_select_result = await _get_runtime_manager().invoke_hook(
"emoji.maisaka.before_select",
stream_id=stream_id,
requested_emotion=normalized_requested_emotion,
reasoning=normalized_reasoning,
context_texts=list(normalized_context_texts),
sample_size=sample_size,
abort_message="表情选择已被 Hook 中止。",
)
if before_select_result.aborted:
abort_message = str(before_select_result.kwargs.get("abort_message") or "表情选择已被 Hook 中止。").strip()
return MaisakaEmojiSendResult(
success=False,
message=abort_message or "表情选择已被 Hook 中止。",
requested_emotion=normalized_requested_emotion,
)
before_select_kwargs = before_select_result.kwargs
normalized_requested_emotion = str(
before_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
).strip()
normalized_reasoning = str(before_select_kwargs.get("reasoning", normalized_reasoning) or "").strip()
if isinstance(before_select_kwargs.get("context_texts"), list):
normalized_context_texts = _normalize_context_texts(before_select_kwargs.get("context_texts"))
sample_size = _coerce_positive_int(before_select_kwargs.get("sample_size"), sample_size)
selected_emoji, matched_emotion = await select_emoji_for_maisaka(
requested_emotion=normalized_requested_emotion,
reasoning=normalized_reasoning,
context_texts=normalized_context_texts,
sample_size=sample_size,
)
after_select_result = await _get_runtime_manager().invoke_hook(
"emoji.maisaka.after_select",
stream_id=stream_id,
requested_emotion=normalized_requested_emotion,
reasoning=normalized_reasoning,
context_texts=list(normalized_context_texts),
sample_size=sample_size,
selected_emoji=_serialize_emoji_for_hook(selected_emoji),
selected_emoji_hash=str(selected_emoji.file_hash or "").strip() if selected_emoji is not None else "",
matched_emotion=matched_emotion,
abort_message="表情发送已被 Hook 中止。",
)
if after_select_result.aborted:
abort_message = str(after_select_result.kwargs.get("abort_message") or "表情发送已被 Hook 中止。").strip()
return MaisakaEmojiSendResult(
success=False,
message=abort_message or "表情发送已被 Hook 中止。",
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
after_select_kwargs = after_select_result.kwargs
normalized_requested_emotion = str(
after_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
).strip()
matched_emotion = str(after_select_kwargs.get("matched_emotion", matched_emotion) or "").strip()
override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji_hash"))
if override_emoji is None:
override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji"))
if override_emoji is not None:
selected_emoji = override_emoji
if selected_emoji is None:
return MaisakaEmojiSendResult(
success=False,
message="当前表情包库中没有可用表情。",
requested_emotion=requested_emotion.strip(),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
try:
@@ -151,7 +286,7 @@ async def send_emoji_for_maisaka(
message=f"发送表情包失败:{exc}",
description=selected_emoji.description.strip(),
emotions=_normalize_emotions(selected_emoji),
requested_emotion=requested_emotion.strip(),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
@@ -169,7 +304,7 @@ async def send_emoji_for_maisaka(
message=f"发送表情包时发生异常:{exc}",
description=selected_emoji.description.strip(),
emotions=_normalize_emotions(selected_emoji),
requested_emotion=requested_emotion.strip(),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
@@ -181,7 +316,7 @@ async def send_emoji_for_maisaka(
message="发送表情包失败。",
description=description,
emotions=emotions,
requested_emotion=requested_emotion.strip(),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
@@ -197,6 +332,6 @@ async def send_emoji_for_maisaka(
emoji_base64=emoji_base64,
description=description,
emotions=emotions,
requested_emotion=requested_emotion.strip(),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)

View File

@@ -1,6 +1,8 @@
"""聊天消息入口与主链路调度。"""
from contextlib import suppress
from copy import deepcopy
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
import os
import traceback
@@ -13,12 +15,15 @@ from src.common.utils.utils_message import MessageUtils
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.platform_io.route_key_factory import RouteKeyFactory
from src.core.announcement_manager import global_announcement_manager
from src.plugin_runtime.component_query import component_query_service
from src.plugin_runtime.hook_payloads import deserialize_session_message, serialize_session_message
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from .message import SessionMessage
from .chat_manager import chat_manager
from .message import SessionMessage
# 定义日志配置
@@ -29,7 +34,137 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..
logger = get_logger("chat")
def register_chat_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册聊天消息主链内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="chat.receive.before_process",
description="在入站消息执行 `SessionMessage.process()` 之前触发,可拦截或改写消息。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "当前入站消息的序列化 SessionMessage。",
},
},
required=["message"],
),
default_timeout_ms=8000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="chat.receive.after_process",
description="在入站消息完成轻量预处理后触发,可改写文本、消息体或中止后续链路。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "已完成 `process()` 的序列化 SessionMessage。",
},
},
required=["message"],
),
default_timeout_ms=8000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="chat.command.before_execute",
description="在命令匹配成功、实际执行前触发,可拦截命令或改写命令上下文。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "当前命令消息的序列化 SessionMessage。",
},
"command_name": {
"type": "string",
"description": "命中的命令名称。",
},
"plugin_id": {
"type": "string",
"description": "命令所属插件 ID。",
},
"matched_groups": {
"type": "object",
"description": "命令正则命名捕获结果。",
},
},
required=["message", "command_name", "plugin_id", "matched_groups"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="chat.command.after_execute",
description="在命令执行结束后触发,可调整返回文本和是否继续主链处理。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "当前命令消息的序列化 SessionMessage。",
},
"command_name": {
"type": "string",
"description": "命令名称。",
},
"plugin_id": {
"type": "string",
"description": "命令所属插件 ID。",
},
"matched_groups": {
"type": "object",
"description": "命令正则命名捕获结果。",
},
"success": {
"type": "boolean",
"description": "命令执行是否成功。",
},
"response": {
"type": "string",
"description": "命令返回文本。",
},
"intercept_message_level": {
"type": "integer",
"description": "命令拦截等级。",
},
"continue_process": {
"type": "boolean",
"description": "命令执行后是否继续后续消息处理。",
},
},
required=[
"message",
"command_name",
"plugin_id",
"matched_groups",
"success",
"intercept_message_level",
"continue_process",
],
),
default_timeout_ms=5000,
allow_abort=False,
allow_kwargs_mutation=True,
),
]
)
class ChatBot:
"""聊天机器人入口协调器。"""
def __init__(self) -> None:
"""初始化聊天机器人入口。"""
@@ -44,6 +179,66 @@ class ChatBot:
self._started = True
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
"""将任意值安全转换为整数。
Args:
value: 待转换的值。
default: 转换失败时的默认值。
Returns:
int: 转换后的整数结果。
"""
try:
return int(value)
except (TypeError, ValueError):
return default
async def _invoke_message_hook(
self,
hook_name: str,
message: SessionMessage,
**kwargs: Any,
) -> tuple[HookDispatchResult, SessionMessage]:
"""触发携带会话消息的命名 Hook。
Args:
hook_name: 目标 Hook 名称。
message: 当前会话消息。
**kwargs: 需要附带传递的额外参数。
Returns:
tuple[HookDispatchResult, SessionMessage]: Hook 聚合结果以及可能被改写后的消息对象。
"""
hook_result = await self._get_runtime_manager().invoke_hook(
hook_name,
message=serialize_session_message(message),
**kwargs,
)
mutated_message = message
raw_message = hook_result.kwargs.get("message")
if raw_message is not None:
try:
mutated_message = deserialize_session_message(raw_message)
except Exception as exc:
logger.warning(f"Hook {hook_name} 返回的 message 无法反序列化,已忽略: {exc}")
return hook_result, mutated_message
async def _process_commands(self, message: SessionMessage) -> tuple[bool, Optional[str], bool]:
"""使用统一组件注册表处理命令。
@@ -71,6 +266,25 @@ class ChatBot:
return False, None, True
message.is_command = True
before_result, message = await self._invoke_message_hook(
"chat.command.before_execute",
message,
command_name=command_name,
plugin_id=plugin_name,
matched_groups=dict(matched_groups),
)
if before_result.aborted:
logger.info(f"命令 {command_name} 被 Hook 中止,跳过命令执行")
return True, None, False
hook_kwargs = before_result.kwargs
command_name = str(hook_kwargs.get("command_name", command_name) or command_name)
plugin_name = str(hook_kwargs.get("plugin_id", plugin_name) or plugin_name)
matched_groups = (
dict(hook_kwargs["matched_groups"])
if isinstance(hook_kwargs.get("matched_groups"), dict)
else dict(matched_groups)
)
# 获取插件配置
plugin_config = component_query_service.get_plugin_config(plugin_name)
@@ -82,27 +296,43 @@ class ChatBot:
plugin_config=plugin_config,
matched_groups=matched_groups,
)
self._mark_command_message(message, intercept_message_level)
# 记录命令执行结果
if success:
logger.info(f"命令执行成功: {command_name} (拦截等级: {intercept_message_level})")
else:
logger.warning(f"命令执行失败: {command_name} - {response}")
# 根据命令的拦截设置决定是否继续处理消息
return (
True,
response,
not bool(intercept_message_level),
) # 找到命令根据intercept_message决定是否继续
except Exception as e:
logger.error(f"执行命令时出错: {command_name} - {e}")
continue_process = not bool(intercept_message_level)
except Exception as exc:
logger.error(f"执行命令时出错: {command_name} - {exc}")
logger.error(traceback.format_exc())
success = False
response = str(exc)
intercept_message_level = 1
continue_process = False
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
return True, str(e), False # 出错时继续处理消息
after_result, message = await self._invoke_message_hook(
"chat.command.after_execute",
message,
command_name=command_name,
plugin_id=plugin_name,
matched_groups=dict(matched_groups),
success=success,
response=response,
intercept_message_level=intercept_message_level,
continue_process=continue_process,
)
after_kwargs = after_result.kwargs
success = bool(after_kwargs.get("success", success))
raw_response = after_kwargs.get("response", response)
response = None if raw_response is None else str(raw_response)
intercept_message_level = self._coerce_int(
after_kwargs.get("intercept_message_level", intercept_message_level),
intercept_message_level,
)
continue_process = bool(after_kwargs.get("continue_process", continue_process))
self._mark_command_message(message, intercept_message_level)
if success:
logger.info(f"命令执行成功: {command_name} (拦截等级: {intercept_message_level})")
else:
logger.warning(f"命令执行失败: {command_name} - {response}")
return True, response, continue_process
return False, None, True
@@ -138,6 +368,17 @@ class ChatBot:
cmd_result: Optional[str],
continue_process: bool,
) -> bool:
"""处理命令链结果并决定是否终止主消息链。
Args:
message: 当前命令消息。
cmd_result: 命令响应文本。
continue_process: 是否继续后续主链处理。
Returns:
bool: ``True`` 表示已经终止后续主链。
"""
if continue_process:
return False
@@ -145,9 +386,18 @@ class ChatBot:
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
return True
async def handle_notice_message(self, message: SessionMessage):
async def handle_notice_message(self, message: SessionMessage) -> bool:
"""处理通知类消息。
Args:
message: 当前通知消息。
Returns:
bool: 当前消息是否为通知消息。
"""
if message.message_id != "notice":
return
return False
message.is_notify = True
logger.debug("notice消息")
@@ -203,9 +453,12 @@ class ChatBot:
return True
async def echo_message_process(self, raw_data: Dict[str, Any]) -> None:
"""处理消息回送 ID 对应关系。
Args:
raw_data: 平台适配器上报的原始回送载荷。
"""
用于专门处理回送消息ID的函数
"""
message_data: Dict[str, Any] = raw_data.get("content", {})
if not message_data:
return
@@ -218,18 +471,10 @@ class ChatBot:
logger.debug(f"收到回送消息ID: {mmc_message_id} -> {actual_message_id}")
async def message_process(self, message_data: Dict[str, Any]) -> None:
"""处理转化后的统一格式消息
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
heart_flow模式使用思维流系统进行回复
- 包含思维流状态管理
- 在回复前进行观察和状态更新
- 回复后更新思维流状态
- 消息过滤
- 记忆激活
- 意愿计算
- 消息生成和发送
- 表情包处理
- 性能计时
"""处理统一格式的入站消息字典。
Args:
message_data: 适配器整理后的统一消息字典。
"""
try:
# 确保所有任务已启动
@@ -253,7 +498,13 @@ class ChatBot:
logger.error(f"预处理消息失败: {e}")
traceback.print_exc()
async def receive_message(self, message: SessionMessage):
async def receive_message(self, message: SessionMessage) -> None:
"""处理单条入站会话消息。
Args:
message: 待处理的会话消息。
"""
try:
group_info = message.message_info.group_info
user_info = message.message_info.user_info
@@ -272,6 +523,19 @@ class ChatBot:
)
message.session_id = session_id # 正确初始化session_id
before_process_result, message = await self._invoke_message_hook(
"chat.receive.before_process",
message,
)
if before_process_result.aborted:
logger.info(f"消息 {message.message_id} 在预处理前被 Hook 中止")
return
group_info = message.message_info.group_info
user_info = message.message_info.user_info
additional_config = message.message_info.additional_config
if isinstance(additional_config, dict):
account_id, scope = RouteKeyFactory.extract_components(additional_config)
# TODO: 修复事件预处理部分
# continue_flag, modified_message = await events_manager.handle_mai_events(
@@ -294,6 +558,16 @@ class ChatBot:
enable_heavy_media_analysis=False,
enable_voice_transcription=False,
)
after_process_result, message = await self._invoke_message_hook(
"chat.receive.after_process",
message,
)
if after_process_result.aborted:
logger.info(f"消息 {message.message_id} 在预处理后被 Hook 中止")
return
group_info = message.message_info.group_info
user_info = message.message_info.user_info
# 平台层的 @ 检测由底层 is_mentioned_bot_in_message 统一处理;此处不做用户名硬编码匹配

View File

@@ -18,28 +18,29 @@ install(extra_lines=3)
logger = get_logger("sender")
# WebUI 聊天室的消息广播器(延迟导入避免循环依赖)
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str], Optional[str]]] = None
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
# TODO: 重构完成后完成webui相关
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]:
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str], Optional[str]]:
"""获取 WebUI 聊天室广播器。
Returns:
Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 元组;
Tuple[Any, Optional[str], Optional[str]]: ``(chat_manager, platform_name, default_group_id)`` 元组;
若 WebUI 相关模块不可用,则元素会退化为 ``None``。
"""
global _webui_chat_broadcaster
if _webui_chat_broadcaster is None:
try:
from src.webui.routers.chat import WEBUI_CHAT_PLATFORM, chat_manager
from src.webui.routers.chat.service import WEBUI_CHAT_GROUP_ID
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM, WEBUI_CHAT_GROUP_ID)
except ImportError:
_webui_chat_broadcaster = (None, None)
_webui_chat_broadcaster = (None, None, None)
return _webui_chat_broadcaster
@@ -76,7 +77,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
try:
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
chat_manager, webui_platform = get_webui_chat_broadcaster()
chat_manager, webui_platform, default_group_id = get_webui_chat_broadcaster()
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
if is_webui_message and chat_manager is not None:
@@ -97,8 +98,9 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
message_type = "rich"
segments = message_segments
await chat_manager.broadcast(
{
await chat_manager.broadcast_to_group(
group_id=group_id or default_group_id or "",
message={
"type": "bot_message",
"content": message.processed_plain_text,
"message_type": message_type,
@@ -110,7 +112,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
"avatar": None,
"is_bot": True,
},
}
},
)
# 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库

View File

@@ -1,22 +1,25 @@
from datetime import datetime
from sqlmodel import select
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
import asyncio
import difflib
import json
import re
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
from src.prompt.prompt_manager import prompt_manager
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.database.database import get_db_session
from src.common.data_models.expression_data_model import MaiExpression
from sqlmodel import select
from src.chat.utils.utils import is_bot_self
from src.common.data_models.expression_data_model import MaiExpression
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.common.utils.utils_message import MessageUtils
from src.config.config import global_config
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
from .expression_utils import check_expression_suitability, parse_expression_response
@@ -34,8 +37,122 @@ summary_model = LLMServiceClient(task_name="utils", request_type="expression.sum
check_model = LLMServiceClient(task_name="utils", request_type="expression.check")
def register_expression_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册表达方式系统内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="expression.select.before_select",
description="表达方式选择流程开始前触发,可改写会话上下文、选择参数或中止本次选择。",
parameters_schema=build_object_schema(
{
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
},
required=["chat_id", "chat_info", "max_num", "think_level"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.select.after_selection",
description="表达方式选择完成后触发,可改写最终选中的表达方式列表与 ID。",
parameters_schema=build_object_schema(
{
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
"selected_expressions": {
"type": "array",
"items": {"type": "object"},
"description": "当前已选中的表达方式列表。",
},
"selected_expression_ids": {
"type": "array",
"items": {"type": "integer"},
"description": "当前已选中的表达方式 ID 列表。",
},
},
required=[
"chat_id",
"chat_info",
"max_num",
"think_level",
"selected_expressions",
"selected_expression_ids",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.learn.after_extract",
description="表达方式学习解析出表达/黑话候选后触发,可改写候选集或直接终止本轮学习。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"message_count": {"type": "integer", "description": "本轮参与学习的消息数量。"},
"expressions": {
"type": "array",
"items": {"type": "object"},
"description": "解析出的表达方式候选列表。",
},
"jargon_entries": {
"type": "array",
"items": {"type": "object"},
"description": "解析出的黑话候选列表。",
},
},
required=["session_id", "message_count", "expressions", "jargon_entries"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.learn.before_upsert",
description="表达方式写入数据库前触发,可改写情景/风格文本或跳过本条写入。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"situation": {"type": "string", "description": "即将写入的情景文本。"},
"style": {"type": "string", "description": "即将写入的风格文本。"},
},
required=["session_id", "situation", "style"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
]
)
class ExpressionLearner:
def __init__(self, session_id: str) -> None:
"""初始化表达方式学习器。
Args:
session_id: 当前会话 ID。
"""
self.session_id = session_id
# 学习锁,防止并发执行学习任务
@@ -44,6 +161,110 @@ class ExpressionLearner:
# 消息缓存
self._messages_cache: List["SessionMessage"] = []
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _serialize_expressions(expressions: List[Tuple[str, str, str]]) -> List[dict[str, str]]:
"""将表达方式候选序列化为 Hook 载荷。
Args:
expressions: 原始表达方式候选列表。
Returns:
List[dict[str, str]]: 序列化后的表达方式候选。
"""
return [
{
"situation": str(situation).strip(),
"style": str(style).strip(),
"source_id": str(source_id).strip(),
}
for situation, style, source_id in expressions
if str(situation).strip() and str(style).strip()
]
@staticmethod
def _deserialize_expressions(raw_expressions: Any) -> List[Tuple[str, str, str]]:
"""从 Hook 载荷恢复表达方式候选列表。
Args:
raw_expressions: Hook 返回的表达方式候选。
Returns:
List[Tuple[str, str, str]]: 恢复后的表达方式候选列表。
"""
if not isinstance(raw_expressions, list):
return []
normalized_expressions: List[Tuple[str, str, str]] = []
for raw_expression in raw_expressions:
if not isinstance(raw_expression, dict):
continue
situation = str(raw_expression.get("situation") or "").strip()
style = str(raw_expression.get("style") or "").strip()
source_id = str(raw_expression.get("source_id") or "").strip()
if not situation or not style:
continue
normalized_expressions.append((situation, style, source_id))
return normalized_expressions
@staticmethod
def _serialize_jargon_entries(jargon_entries: List[Tuple[str, str]]) -> List[dict[str, str]]:
"""将黑话候选序列化为 Hook 载荷。
Args:
jargon_entries: 原始黑话候选列表。
Returns:
List[dict[str, str]]: 序列化后的黑话候选列表。
"""
return [
{
"content": str(content).strip(),
"source_id": str(source_id).strip(),
}
for content, source_id in jargon_entries
if str(content).strip()
]
@staticmethod
def _deserialize_jargon_entries(raw_jargon_entries: Any) -> List[Tuple[str, str]]:
"""从 Hook 载荷恢复黑话候选列表。
Args:
raw_jargon_entries: Hook 返回的黑话候选列表。
Returns:
List[Tuple[str, str]]: 恢复后的黑话候选列表。
"""
if not isinstance(raw_jargon_entries, list):
return []
normalized_entries: List[Tuple[str, str]] = []
for raw_entry in raw_jargon_entries:
if not isinstance(raw_entry, dict):
continue
content = str(raw_entry.get("content") or "").strip()
source_id = str(raw_entry.get("source_id") or "").strip()
if not content:
continue
normalized_entries.append((content, source_id))
return normalized_entries
def add_messages(self, messages: List["SessionMessage"]) -> None:
"""添加消息到缓存"""
self._messages_cache.extend(messages)
@@ -52,8 +273,12 @@ class ExpressionLearner:
"""获取当前消息缓存的大小"""
return len(self._messages_cache)
async def learn(self, jargon_miner: Optional["JargonMiner"] = None):
"""学习主流程"""
async def learn(self, jargon_miner: Optional["JargonMiner"] = None) -> None:
"""执行表达方式学习主流程
Args:
jargon_miner: 可选的黑话学习器实例,用于同步处理黑话候选。
"""
if not self._messages_cache:
logger.debug("没有消息可供学习,跳过学习过程")
return
@@ -109,6 +334,25 @@ class ExpressionLearner:
logger.info(f"黑话提取数量超过 30 个(实际{len(jargon_entries)}个),放弃本次黑话学习")
jargon_entries = []
after_extract_result = await self._get_runtime_manager().invoke_hook(
"expression.learn.after_extract",
session_id=self.session_id,
message_count=len(self._messages_cache),
expressions=self._serialize_expressions(expressions),
jargon_entries=self._serialize_jargon_entries(jargon_entries),
)
if after_extract_result.aborted:
logger.info(f"{self.session_id} 的表达方式学习结果被 Hook 中止")
return
after_extract_kwargs = after_extract_result.kwargs
raw_expressions = after_extract_kwargs.get("expressions")
if raw_expressions is not None:
expressions = self._deserialize_expressions(raw_expressions)
raw_jargon_entries = after_extract_kwargs.get("jargon_entries")
if raw_jargon_entries is not None:
jargon_entries = self._deserialize_jargon_entries(raw_jargon_entries)
# 处理黑话条目,路由到 jargon_miner即使没有表达方式也要处理黑话
# TODO: 检测是否开启了
if jargon_entries:
@@ -135,6 +379,22 @@ class ExpressionLearner:
# 存储到数据库 Expression 表
for situation, style in learnt_expressions:
before_upsert_result = await self._get_runtime_manager().invoke_hook(
"expression.learn.before_upsert",
session_id=self.session_id,
situation=situation,
style=style,
)
if before_upsert_result.aborted:
logger.info(f"{self.session_id} 的表达方式写入被 Hook 跳过: situation={situation!r}")
continue
upsert_kwargs = before_upsert_result.kwargs
situation = str(upsert_kwargs.get("situation", situation) or "").strip()
style = str(upsert_kwargs.get("style", style) or "").strip()
if not situation or not style:
logger.info(f"{self.session_id} 的表达方式写入被 Hook 清空,已跳过")
continue
await self._upsert_expression_to_db(situation, style)
# ====== 黑话相关 ======

View File

@@ -1,27 +1,109 @@
from typing import Any, Dict, List, Optional, Tuple
import json
import time
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.utils.utils_session import SessionUtils
from src.prompt.prompt_manager import prompt_manager
from src.learners.learner_utils_old import weighted_sample
from src.chat.utils.common_utils import TempMethodsExpression
from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.learners.learner_utils_old import weighted_sample
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
logger = get_logger("expression_selector")
class ExpressionSelector:
def __init__(self):
def __init__(self) -> None:
"""初始化表达方式选择器。"""
self.llm_model = LLMServiceClient(
task_name="utils", request_type="expression.selector"
)
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
"""将任意值安全转换为整数。
Args:
value: 待转换的值。
default: 转换失败时的默认值。
Returns:
int: 转换后的整数结果。
"""
try:
return int(value)
except (TypeError, ValueError):
return default
@staticmethod
def _normalize_selected_expressions(raw_expressions: Any) -> List[Dict[str, Any]]:
"""从 Hook 载荷恢复表达方式选择结果。
Args:
raw_expressions: Hook 返回的表达方式列表。
Returns:
List[Dict[str, Any]]: 恢复后的表达方式列表。
"""
if not isinstance(raw_expressions, list):
return []
normalized_expressions: List[Dict[str, Any]] = []
for raw_expression in raw_expressions:
if not isinstance(raw_expression, dict):
continue
expression_id = raw_expression.get("id")
situation = str(raw_expression.get("situation") or "").strip()
style = str(raw_expression.get("style") or "").strip()
source_id = str(raw_expression.get("source_id") or "").strip()
if not isinstance(expression_id, int) or not situation or not style or not source_id:
continue
normalized_expression = dict(raw_expression)
normalized_expression["id"] = expression_id
normalized_expression["situation"] = situation
normalized_expression["style"] = style
normalized_expression["source_id"] = source_id
normalized_expressions.append(normalized_expression)
return normalized_expressions
@staticmethod
def _normalize_selected_expression_ids(raw_ids: Any, expressions: List[Dict[str, Any]]) -> List[int]:
"""规范化最终选中的表达方式 ID 列表。
Args:
raw_ids: Hook 返回的 ID 列表。
expressions: 当前最终表达方式列表。
Returns:
List[int]: 规范化后的 ID 列表。
"""
if isinstance(raw_ids, list):
normalized_ids = [item for item in raw_ids if isinstance(item, int)]
if normalized_ids:
return normalized_ids
return [expression["id"] for expression in expressions if isinstance(expression.get("id"), int)]
def can_use_expression_for_chat(self, chat_id: str) -> bool:
"""
检查指定聊天流是否允许使用表达
@@ -214,8 +296,7 @@ class ExpressionSelector:
reply_reason: Optional[str] = None,
think_level: int = 1,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
选择适合的表达方式使用classic模式随机选择+LLM选择
"""选择适合的表达方式。
Args:
chat_id: 聊天流ID
@@ -233,11 +314,60 @@ class ExpressionSelector:
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return [], []
before_select_result = await self._get_runtime_manager().invoke_hook(
"expression.select.before_select",
chat_id=chat_id,
chat_info=chat_info,
max_num=max_num,
target_message=target_message or "",
reply_reason=reply_reason or "",
think_level=think_level,
)
if before_select_result.aborted:
logger.info(f"聊天流 {chat_id} 的表达方式选择被 Hook 中止")
return [], []
before_select_kwargs = before_select_result.kwargs
chat_id = str(before_select_kwargs.get("chat_id", chat_id) or "").strip() or chat_id
chat_info = str(before_select_kwargs.get("chat_info", chat_info) or "")
max_num = max(self._coerce_int(before_select_kwargs.get("max_num"), max_num), 1)
raw_target_message = before_select_kwargs.get("target_message", target_message or "")
target_message = str(raw_target_message or "").strip() or None
raw_reply_reason = before_select_kwargs.get("reply_reason", reply_reason or "")
reply_reason = str(raw_reply_reason or "").strip() or None
think_level = self._coerce_int(before_select_kwargs.get("think_level"), think_level)
# 使用classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式think_level={think_level}")
return await self._select_expressions_classic(
selected_expressions, selected_ids = await self._select_expressions_classic(
chat_id, chat_info, max_num, target_message, reply_reason, think_level
)
after_selection_result = await self._get_runtime_manager().invoke_hook(
"expression.select.after_selection",
chat_id=chat_id,
chat_info=chat_info,
max_num=max_num,
target_message=target_message or "",
reply_reason=reply_reason or "",
think_level=think_level,
selected_expressions=[dict(item) for item in selected_expressions],
selected_expression_ids=list(selected_ids),
)
if after_selection_result.aborted:
logger.info(f"聊天流 {chat_id} 的表达方式选择结果被 Hook 中止")
return [], []
after_selection_kwargs = after_selection_result.kwargs
raw_selected_expressions = after_selection_kwargs.get("selected_expressions")
if raw_selected_expressions is not None:
selected_expressions = self._normalize_selected_expressions(raw_selected_expressions)
selected_ids = self._normalize_selected_expression_ids(
after_selection_kwargs.get("selected_expression_ids"),
selected_expressions,
)
if selected_expressions:
self.update_expressions_last_active_time(selected_expressions)
return selected_expressions, selected_ids
async def _select_expressions_classic(
self,

View File

@@ -1,5 +1,5 @@
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Set, TypedDict
from typing import Any, Callable, Dict, List, Optional, Set, TypedDict
import asyncio
import json
@@ -9,13 +9,15 @@ from json_repair import repair_json
from sqlmodel import select
from src.common.data_models.jargon_data_model import MaiJargon
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.database.database import get_db_session
from src.common.database.database_model import Jargon
from src.common.logger import get_logger
from src.config.config import global_config
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
from .expression_utils import is_single_char_jargon
@@ -35,8 +37,140 @@ class JargonMeaningEntry(TypedDict):
meaning: str
def register_jargon_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册 jargon 系统内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="jargon.query.before_search",
description="Maisaka 黑话查询工具执行检索前触发,可改写词条列表、检索参数或直接中止。",
parameters_schema=build_object_schema(
{
"words": {
"type": "array",
"items": {"type": "string"},
"description": "准备查询的黑话词条列表。",
},
"session_id": {"type": "string", "description": "当前会话 ID。"},
"limit": {"type": "integer", "description": "单个词条的最大返回条数。"},
"case_sensitive": {"type": "boolean", "description": "是否大小写敏感。"},
"enable_fuzzy_fallback": {"type": "boolean", "description": "是否允许精确命中失败后回退模糊检索。"},
"abort_message": {"type": "string", "description": "Hook 主动中止时的失败提示。"},
},
required=["words", "session_id", "limit", "case_sensitive", "enable_fuzzy_fallback"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="jargon.query.after_search",
description="Maisaka 黑话查询工具完成检索后触发,可改写结果列表或中止返回。",
parameters_schema=build_object_schema(
{
"words": {
"type": "array",
"items": {"type": "string"},
"description": "实际查询的黑话词条列表。",
},
"session_id": {"type": "string", "description": "当前会话 ID。"},
"limit": {"type": "integer", "description": "单个词条的最大返回条数。"},
"case_sensitive": {"type": "boolean", "description": "是否大小写敏感。"},
"enable_fuzzy_fallback": {"type": "boolean", "description": "是否启用了模糊检索回退。"},
"results": {
"type": "array",
"items": {"type": "object"},
"description": "查询结果列表。",
},
"abort_message": {"type": "string", "description": "Hook 主动中止时的失败提示。"},
},
required=["words", "session_id", "limit", "case_sensitive", "enable_fuzzy_fallback", "results"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="jargon.extract.before_persist",
description="黑话条目准备写入数据库前触发,可改写去重后的条目列表或跳过本次持久化。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"session_name": {"type": "string", "description": "当前会话展示名称。"},
"entries": {
"type": "array",
"items": {"type": "object"},
"description": "即将持久化的黑话条目列表。",
},
},
required=["session_id", "session_name", "entries"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="jargon.inference.before_finalize",
description="黑话含义推断完成、写回数据库前触发,可改写最终判定与含义结果。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"session_name": {"type": "string", "description": "当前会话展示名称。"},
"content": {"type": "string", "description": "当前黑话词条。"},
"count": {"type": "integer", "description": "当前词条累计命中次数。"},
"raw_content_list": {
"type": "array",
"items": {"type": "string"},
"description": "用于推断的原始上下文片段列表。",
},
"inference_with_context": {"type": "object", "description": "基于上下文的推断结果。"},
"inference_with_content_only": {"type": "object", "description": "仅基于词条内容的推断结果。"},
"comparison_result": {"type": "object", "description": "比较阶段输出结果。"},
"is_jargon": {"type": "boolean", "description": "当前推断是否判定为黑话。"},
"meaning": {"type": "string", "description": "当前推断出的黑话含义。"},
"is_complete": {"type": "boolean", "description": "当前是否已完成全部推断流程。"},
"last_inference_count": {"type": "integer", "description": "本次推断完成后应写回的 last_inference_count。"},
},
required=[
"session_id",
"session_name",
"content",
"count",
"raw_content_list",
"inference_with_context",
"inference_with_content_only",
"comparison_result",
"is_jargon",
"meaning",
"is_complete",
"last_inference_count",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
]
)
class JargonMiner:
def __init__(self, session_id: str, session_name: str) -> None:
"""初始化黑话学习器。
Args:
session_id: 当前会话 ID。
session_name: 当前会话展示名称。
"""
self.session_id = session_id
self.session_name = session_name
@@ -46,13 +180,92 @@ class JargonMiner:
# 黑话提取锁,防止并发执行
self._extraction_lock = asyncio.Lock()
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
"""将任意值安全转换为整数。
Args:
value: 待转换的值。
default: 转换失败时使用的默认值。
Returns:
int: 转换后的整数结果。
"""
try:
return int(value)
except (TypeError, ValueError):
return default
@staticmethod
def _serialize_jargon_entries(entries: List[JargonEntry]) -> List[Dict[str, object]]:
"""将黑话条目列表序列化为 Hook 可传输结构。
Args:
entries: 原始黑话条目列表。
Returns:
List[Dict[str, object]]: 序列化后的条目列表。
"""
return [
{
"content": str(entry["content"]).strip(),
"raw_content": sorted(str(item).strip() for item in entry["raw_content"] if str(item).strip()),
}
for entry in entries
if str(entry["content"]).strip()
]
@staticmethod
def _deserialize_jargon_entries(raw_entries: Any) -> List[JargonEntry]:
"""从 Hook 载荷恢复黑话条目列表。
Args:
raw_entries: Hook 返回的条目数据。
Returns:
List[JargonEntry]: 恢复后的黑话条目列表。
"""
if not isinstance(raw_entries, list):
return []
normalized_entries: List[JargonEntry] = []
for raw_entry in raw_entries:
if not isinstance(raw_entry, dict):
continue
content = str(raw_entry.get("content") or "").strip()
if not content:
continue
raw_content_values = raw_entry.get("raw_content")
raw_content: Set[str] = set()
if isinstance(raw_content_values, list):
raw_content = {str(item).strip() for item in raw_content_values if str(item).strip()}
normalized_entries.append({"content": content, "raw_content": raw_content})
return normalized_entries
def get_cached_jargons(self) -> List[str]:
"""获取缓存中的所有黑话列表"""
return list(self.cache.keys())
async def infer_meaning(self, jargon_obj: MaiJargon) -> None:
"""
对jargon进行含义推断
"""对黑话条目执行含义推断。
Args:
jargon_obj: 待推断的黑话数据对象。
"""
content = jargon_obj.content
# 解析raw_content列表
@@ -175,15 +388,45 @@ class JargonMiner:
is_similar = comparison_result.get("is_similar", False)
is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话
finalized_meaning = inference1.get("meaning", "") if is_jargon else ""
is_complete = (jargon_obj.count or 0) >= 100
last_inference_count = jargon_obj.count or 0
finalize_result = await self._get_runtime_manager().invoke_hook(
"jargon.inference.before_finalize",
session_id=self.session_id,
session_name=self.session_name,
content=content,
count=current_count,
raw_content_list=list(raw_content_list),
inference_with_context=dict(inference1),
inference_with_content_only=dict(inference2),
comparison_result=dict(comparison_result),
is_jargon=is_jargon,
meaning=finalized_meaning,
is_complete=is_complete,
last_inference_count=last_inference_count,
)
if finalize_result.aborted:
logger.info(f"jargon {content} 的推断结果被 Hook 中止写回")
return
finalize_kwargs = finalize_result.kwargs
is_jargon = bool(finalize_kwargs.get("is_jargon", is_jargon))
finalized_meaning = str(finalize_kwargs.get("meaning", finalized_meaning) or "").strip() if is_jargon else ""
is_complete = bool(finalize_kwargs.get("is_complete", is_complete))
last_inference_count = self._coerce_int(
finalize_kwargs.get("last_inference_count"),
last_inference_count,
)
# 更新数据库记录
jargon_obj.is_jargon = is_jargon
jargon_obj.meaning = inference1.get("meaning", "") if is_jargon else ""
jargon_obj.meaning = finalized_meaning
# 更新最后一次判定的count值避免重启后重复判定
jargon_obj.last_inference_count = jargon_obj.count or 0
jargon_obj.last_inference_count = last_inference_count
# 如果count>=100标记为完成不再进行推断
if (jargon_obj.count or 0) >= 100:
jargon_obj.is_complete = True
jargon_obj.is_complete = is_complete
try:
self._modify_jargon_entry(jargon_obj)
@@ -232,6 +475,22 @@ class JargonMiner:
merged_entries[content] = {"content": content, "raw_content": set(raw_list)}
uniq_entries: List[JargonEntry] = list(merged_entries.values())
before_persist_result = await self._get_runtime_manager().invoke_hook(
"jargon.extract.before_persist",
session_id=self.session_id,
session_name=self.session_name,
entries=self._serialize_jargon_entries(uniq_entries),
)
if before_persist_result.aborted:
logger.info(f"[{self.session_name}] 黑话提取结果被 Hook 中止,不写入数据库")
return
raw_hook_entries = before_persist_result.kwargs.get("entries")
if raw_hook_entries is not None:
uniq_entries = self._deserialize_jargon_entries(raw_hook_entries)
if not uniq_entries:
logger.info(f"[{self.session_name}] Hook 过滤后没有可写入的黑话条目")
return
saved = 0
updated = 0

View File

@@ -3,7 +3,7 @@
from dataclasses import dataclass
from datetime import datetime
from time import perf_counter
from typing import List, Optional, Sequence
from typing import Any, List, Optional, Sequence
import asyncio
import json
@@ -26,6 +26,15 @@ from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, ToolOption, normalize_tool_options
from src.plugin_runtime.hook_payloads import (
deserialize_prompt_messages,
deserialize_tool_calls,
serialize_prompt_messages,
serialize_tool_calls,
serialize_tool_definitions,
)
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.services.llm_service import LLMServiceClient
from .builtin_tools import get_builtin_tools
@@ -58,6 +67,123 @@ class ToolFilterSelection(BaseModel):
logger = get_logger("maisaka_chat_loop")
def register_maisaka_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册 Maisaka 规划器内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="maisaka.planner.before_request",
description="在 Maisaka 向模型发起规划请求前触发,可改写消息窗口与工具定义。",
parameters_schema=build_object_schema(
{
"messages": {
"type": "array",
"description": "即将发给模型的 PromptMessage 列表。",
},
"tool_definitions": {
"type": "array",
"description": "当前候选工具定义列表。",
},
"selected_history_count": {
"type": "integer",
"description": "当前选中的上下文消息数量。",
},
"built_message_count": {
"type": "integer",
"description": "实际发送给模型的消息数量。",
},
"selection_reason": {
"type": "string",
"description": "上下文选择说明。",
},
"session_id": {
"type": "string",
"description": "当前会话 ID。",
},
},
required=[
"messages",
"tool_definitions",
"selected_history_count",
"built_message_count",
"selection_reason",
"session_id",
],
),
default_timeout_ms=6000,
allow_abort=False,
allow_kwargs_mutation=True,
),
HookSpec(
name="maisaka.planner.after_response",
description="在 Maisaka 收到模型响应后触发,可调整文本结果与工具调用列表。",
parameters_schema=build_object_schema(
{
"response": {
"type": "string",
"description": "模型返回的文本内容。",
},
"tool_calls": {
"type": "array",
"description": "模型返回的工具调用列表。",
},
"selected_history_count": {
"type": "integer",
"description": "当前选中的上下文消息数量。",
},
"built_message_count": {
"type": "integer",
"description": "实际发送给模型的消息数量。",
},
"selection_reason": {
"type": "string",
"description": "上下文选择说明。",
},
"session_id": {
"type": "string",
"description": "当前会话 ID。",
},
"prompt_tokens": {
"type": "integer",
"description": "输入 Token 数。",
},
"completion_tokens": {
"type": "integer",
"description": "输出 Token 数。",
},
"total_tokens": {
"type": "integer",
"description": "总 Token 数。",
},
},
required=[
"response",
"tool_calls",
"selected_history_count",
"built_message_count",
"selection_reason",
"session_id",
"prompt_tokens",
"completion_tokens",
"total_tokens",
],
),
default_timeout_ms=6000,
allow_abort=False,
allow_kwargs_mutation=True,
),
]
)
class MaisakaChatLoopService:
"""负责 Maisaka 主对话循环、系统提示词和终端渲染。"""
@@ -105,6 +231,35 @@ class MaisakaChatLoopService:
return self._personality_prompt
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
"""将任意值安全转换为整数。
Args:
value: 待转换的输入值。
default: 转换失败时的默认值。
Returns:
int: 转换后的整数结果。
"""
try:
return int(value)
except (TypeError, ValueError):
return default
def _build_personality_prompt(self) -> str:
"""构造人格提示词。"""
@@ -580,6 +735,26 @@ class MaisakaChatLoopService:
else:
all_tools = [*get_builtin_tools(), *self._extra_tools]
before_request_result = await self._get_runtime_manager().invoke_hook(
"maisaka.planner.before_request",
messages=serialize_prompt_messages(built_messages),
tool_definitions=serialize_tool_definitions(all_tools),
selected_history_count=len(selected_history),
built_message_count=len(built_messages),
selection_reason=selection_reason,
session_id=self._session_id,
)
before_request_kwargs = before_request_result.kwargs
raw_messages = before_request_kwargs.get("messages")
if isinstance(raw_messages, list):
try:
built_messages = deserialize_prompt_messages(raw_messages)
except Exception as exc:
logger.warning(f"Hook maisaka.planner.before_request 返回的 messages 无法反序列化,已忽略: {exc}")
raw_tool_definitions = before_request_kwargs.get("tool_definitions")
if isinstance(raw_tool_definitions, list):
all_tools = [item for item in raw_tool_definitions if isinstance(item, dict)]
ordered_panels = PromptCLIVisualizer.build_prompt_panels(
built_messages,
image_display_mode=global_config.maisaka.terminal_image_display_mode,
@@ -625,33 +800,63 @@ class MaisakaChatLoopService:
)
logger.info(f"本轮Prompt统计: {prompt_stats_text}")
final_response = generation_result.response or ""
final_tool_calls = list(generation_result.tool_calls or [])
after_response_result = await self._get_runtime_manager().invoke_hook(
"maisaka.planner.after_response",
response=final_response,
tool_calls=serialize_tool_calls(final_tool_calls),
selected_history_count=len(selected_history),
built_message_count=len(built_messages),
selection_reason=selection_reason,
session_id=self._session_id,
prompt_tokens=generation_result.prompt_tokens,
completion_tokens=generation_result.completion_tokens,
total_tokens=generation_result.total_tokens,
)
after_response_kwargs = after_response_result.kwargs
if "response" in after_response_kwargs:
final_response = str(after_response_kwargs.get("response") or "")
raw_tool_calls = after_response_kwargs.get("tool_calls")
if isinstance(raw_tool_calls, list):
try:
final_tool_calls = deserialize_tool_calls(raw_tool_calls)
except Exception as exc:
logger.warning(f"Hook maisaka.planner.after_response 返回的 tool_calls 无法反序列化,已忽略: {exc}")
prompt_tokens = self._coerce_int(after_response_kwargs.get("prompt_tokens"), generation_result.prompt_tokens)
completion_tokens = self._coerce_int(
after_response_kwargs.get("completion_tokens"),
generation_result.completion_tokens,
)
total_tokens = self._coerce_int(after_response_kwargs.get("total_tokens"), generation_result.total_tokens)
tool_call_summaries = [
{
"调用编号": getattr(tool_call, "call_id", getattr(tool_call, "id", None)),
"工具名": getattr(tool_call, "func_name", getattr(tool_call, "name", None)),
"参数": getattr(tool_call, "args", getattr(tool_call, "arguments", None)),
}
for tool_call in (generation_result.tool_calls or [])
for tool_call in final_tool_calls
]
logger.info(
f"Maisaka 规划器返回结果: 内容={generation_result.response or ''!r} "
f"Maisaka 规划器返回结果: 内容={final_response!r} "
f"工具调用={tool_call_summaries}"
)
raw_message = AssistantMessage(
content=generation_result.response or "",
content=final_response,
timestamp=datetime.now(),
tool_calls=generation_result.tool_calls or [],
tool_calls=final_tool_calls,
)
return ChatResponse(
content=generation_result.response,
tool_calls=generation_result.tool_calls or [],
content=final_response or None,
tool_calls=final_tool_calls,
raw_message=raw_message,
selected_history_count=len(selected_history),
prompt_tokens=generation_result.prompt_tokens,
prompt_tokens=prompt_tokens,
built_message_count=len(built_messages),
completion_tokens=generation_result.completion_tokens,
total_tokens=generation_result.total_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
@staticmethod

View File

@@ -54,6 +54,84 @@ class MaisakaReasoningEngine:
self._runtime = runtime
self._last_reasoning_content: str = ""
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _normalize_words(raw_words: Any) -> list[str]:
"""清洗黑话查询词条列表。
Args:
raw_words: 原始词条列表。
Returns:
list[str]: 去重去空白后的词条列表。
"""
if not isinstance(raw_words, list):
return []
normalized_words: list[str] = []
seen_words: set[str] = set()
for item in raw_words:
if not isinstance(item, str):
continue
word = item.strip()
if not word or word in seen_words:
continue
seen_words.add(word)
normalized_words.append(word)
return normalized_words
@staticmethod
def _normalize_jargon_query_results(raw_results: Any) -> list[dict[str, object]]:
"""规范化黑话查询结果列表。
Args:
raw_results: Hook 返回的结果列表。
Returns:
list[dict[str, object]]: 清洗后的结果列表。
"""
if not isinstance(raw_results, list):
return []
normalized_results: list[dict[str, object]] = []
for raw_item in raw_results:
if not isinstance(raw_item, dict):
continue
word = str(raw_item.get("word") or "").strip()
matches = raw_item.get("matches")
normalized_matches: list[dict[str, str]] = []
if isinstance(matches, list):
for match in matches:
if not isinstance(match, dict):
continue
content = str(match.get("content") or "").strip()
meaning = str(match.get("meaning") or "").strip()
if not content or not meaning:
continue
normalized_matches.append({"content": content, "meaning": meaning})
normalized_results.append(
{
"word": word,
"found": bool(raw_item.get("found", bool(normalized_matches))),
"matches": normalized_matches,
}
)
return normalized_results
def build_builtin_tool_handlers(self) -> dict[str, "BuiltinToolHandler"]:
"""构造 Maisaka 内置工具处理器映射。
@@ -1012,6 +1090,35 @@ class MaisakaReasoningEngine:
"查询黑话工具至少需要一个非空词条。",
)
limit = 5
case_sensitive = False
enable_fuzzy_fallback = True
before_search_result = await self._get_runtime_manager().invoke_hook(
"jargon.query.before_search",
words=list(words),
session_id=self._runtime.session_id,
limit=limit,
case_sensitive=case_sensitive,
enable_fuzzy_fallback=enable_fuzzy_fallback,
abort_message="黑话查询已被 Hook 中止。",
)
if before_search_result.aborted:
abort_message = str(before_search_result.kwargs.get("abort_message") or "黑话查询已被 Hook 中止。").strip()
return self._build_tool_failure_result(tool_call.func_name, abort_message or "黑话查询已被 Hook 中止。")
before_search_kwargs = before_search_result.kwargs
if before_search_kwargs.get("words") is not None:
words = self._normalize_words(before_search_kwargs.get("words"))
if not words:
return self._build_tool_failure_result(tool_call.func_name, "Hook 过滤后没有可查询的黑话词条。")
try:
limit = int(before_search_kwargs.get("limit", limit))
except (TypeError, ValueError):
limit = 5
limit = max(limit, 1)
case_sensitive = bool(before_search_kwargs.get("case_sensitive", case_sensitive))
enable_fuzzy_fallback = bool(before_search_kwargs.get("enable_fuzzy_fallback", enable_fuzzy_fallback))
logger.info(f"{self._runtime.log_prefix} 已触发黑话查询: 词条={words!r}")
results: list[dict[str, object]] = []
@@ -1019,17 +1126,19 @@ class MaisakaReasoningEngine:
exact_matches = search_jargon(
keyword=word,
chat_id=self._runtime.session_id,
limit=5,
case_sensitive=False,
limit=limit,
case_sensitive=case_sensitive,
fuzzy=False,
)
matched_entries = exact_matches or search_jargon(
keyword=word,
chat_id=self._runtime.session_id,
limit=5,
case_sensitive=False,
fuzzy=True,
)
matched_entries = exact_matches
if not matched_entries and enable_fuzzy_fallback:
matched_entries = search_jargon(
keyword=word,
chat_id=self._runtime.session_id,
limit=limit,
case_sensitive=case_sensitive,
fuzzy=True,
)
results.append(
{
@@ -1039,6 +1148,27 @@ class MaisakaReasoningEngine:
}
)
after_search_result = await self._get_runtime_manager().invoke_hook(
"jargon.query.after_search",
words=list(words),
session_id=self._runtime.session_id,
limit=limit,
case_sensitive=case_sensitive,
enable_fuzzy_fallback=enable_fuzzy_fallback,
results=list(results),
abort_message="黑话查询结果已被 Hook 中止。",
)
if after_search_result.aborted:
abort_message = str(after_search_result.kwargs.get("abort_message") or "黑话查询结果已被 Hook 中止。").strip()
return self._build_tool_failure_result(
tool_call.func_name,
abort_message or "黑话查询结果已被 Hook 中止。",
)
raw_results = after_search_result.kwargs.get("results")
if raw_results is not None:
results = self._normalize_jargon_query_results(raw_results)
logger.info(f"{self._runtime.log_prefix} 黑话查询完成: 结果={results!r}")
return self._build_tool_success_result(
tool_call.func_name,

View File

@@ -6,6 +6,7 @@
from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple, cast
from src.common.logger import get_logger
@@ -908,5 +909,27 @@ class ComponentQueryService:
return None
return dict(registration.config_schema)
def list_hook_specs(self) -> list[dict[str, Any]]:
"""返回当前运行时公开的 Hook 规格清单。
Returns:
list[dict[str, Any]]: 可直接序列化给 WebUI 的 Hook 规格列表。
"""
runtime_manager = self._get_runtime_manager()
return [
{
"name": spec.name,
"description": spec.description,
"parameters_schema": deepcopy(spec.parameters_schema),
"default_timeout_ms": spec.default_timeout_ms,
"allow_blocking": spec.allow_blocking,
"allow_observe": spec.allow_observe,
"allow_abort": spec.allow_abort,
"allow_kwargs_mutation": spec.allow_kwargs_mutation,
}
for spec in runtime_manager.list_hook_specs()
]
component_query_service = ComponentQueryService()

View File

@@ -0,0 +1,52 @@
"""内置命名 Hook 目录注册器。"""
from __future__ import annotations
from collections.abc import Callable
from typing import List
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
HookSpecRegistrar = Callable[[HookSpecRegistry], List[HookSpec]]
"""单个业务模块向注册中心写入 Hook 规格的注册器签名。"""
def _get_builtin_hook_spec_registrars() -> List[HookSpecRegistrar]:
"""返回当前内置 Hook 规格注册器列表。
Returns:
List[HookSpecRegistrar]: 已启用的内置 Hook 注册器列表。
"""
from src.chat.message_receive.bot import register_chat_hook_specs
from src.chat.emoji_system.emoji_manager import register_emoji_hook_specs
from src.learners.expression_learner import register_expression_hook_specs
from src.learners.jargon_miner import register_jargon_hook_specs
from src.maisaka.chat_loop_service import register_maisaka_hook_specs
from src.services.send_service import register_send_service_hook_specs
return [
register_chat_hook_specs,
register_emoji_hook_specs,
register_jargon_hook_specs,
register_expression_hook_specs,
register_send_service_hook_specs,
register_maisaka_hook_specs,
]
def register_builtin_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""向注册中心写入全部内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 本次完成注册后的全部内置 Hook 规格。
"""
registered_specs: List[HookSpec] = []
for registrar in _get_builtin_hook_spec_registrars():
registered_specs.extend(registrar(registry))
return registered_specs

View File

@@ -0,0 +1,178 @@
"""运行时 Hook 载荷序列化辅助。"""
from __future__ import annotations
from typing import Any, Dict, List, Sequence
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.llm_service_data_models import PromptMessage
from src.llm_models.payload_content.message import Message
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, normalize_tool_options
from src.plugin_runtime.host.message_utils import PluginMessageUtils
def serialize_session_message(message: SessionMessage) -> Dict[str, Any]:
"""将会话消息序列化为 Hook 可传输载荷。
Args:
message: 待序列化的会话消息。
Returns:
Dict[str, Any]: 可通过插件运行时传输的消息字典。
"""
return dict(PluginMessageUtils._session_message_to_dict(message))
def deserialize_session_message(raw_message: Any) -> SessionMessage:
"""从 Hook 载荷恢复会话消息。
Args:
raw_message: Hook 返回的消息字典。
Returns:
SessionMessage: 恢复后的会话消息对象。
Raises:
ValueError: 消息结构不合法时抛出。
"""
if not isinstance(raw_message, dict):
raise ValueError("Hook 返回的 `message` 必须是字典")
return PluginMessageUtils._build_session_message_from_dict(raw_message)
def serialize_tool_calls(tool_calls: Sequence[ToolCall] | None) -> List[Dict[str, Any]]:
"""将工具调用列表序列化为 Hook 可传输载荷。
Args:
tool_calls: 原始工具调用列表。
Returns:
List[Dict[str, Any]]: 序列化后的工具调用列表。
"""
if not tool_calls:
return []
return [
{
"id": tool_call.call_id,
"function": {
"name": tool_call.func_name,
"arguments": dict(tool_call.args or {}),
},
}
for tool_call in tool_calls
]
def deserialize_tool_calls(raw_tool_calls: Any) -> List[ToolCall]:
"""从 Hook 载荷恢复工具调用列表。
Args:
raw_tool_calls: Hook 返回的工具调用列表。
Returns:
List[ToolCall]: 恢复后的工具调用列表。
Raises:
ValueError: 结构不合法时抛出。
"""
if raw_tool_calls in (None, []):
return []
if not isinstance(raw_tool_calls, list):
raise ValueError("Hook 返回的 `tool_calls` 必须是列表")
normalized_tool_calls: List[ToolCall] = []
for raw_tool_call in raw_tool_calls:
if not isinstance(raw_tool_call, dict):
raise ValueError("Hook 返回的工具调用项必须是字典")
function_info = raw_tool_call.get("function", {})
if isinstance(function_info, dict):
function_name = function_info.get("name")
function_arguments = function_info.get("arguments")
else:
function_name = raw_tool_call.get("name")
function_arguments = raw_tool_call.get("arguments")
call_id = raw_tool_call.get("id") or raw_tool_call.get("call_id")
if not isinstance(call_id, str) or not isinstance(function_name, str):
raise ValueError("Hook 返回的工具调用缺少 `id` 或函数名称")
normalized_tool_calls.append(
ToolCall(
call_id=call_id,
func_name=function_name,
args=function_arguments if isinstance(function_arguments, dict) else {},
)
)
return normalized_tool_calls
def serialize_prompt_messages(messages: Sequence[Message]) -> List[PromptMessage]:
"""将 LLM 消息列表序列化为 Hook 可传输载荷。
Args:
messages: 原始 LLM 消息列表。
Returns:
List[PromptMessage]: 序列化后的消息字典列表。
"""
serialized_messages: List[PromptMessage] = []
for message in messages:
serialized_message: PromptMessage = {
"role": message.role.value,
"content": message.content,
}
if message.tool_call_id:
serialized_message["tool_call_id"] = message.tool_call_id
if message.tool_calls:
serialized_message["tool_calls"] = serialize_tool_calls(message.tool_calls)
serialized_messages.append(serialized_message)
return serialized_messages
def deserialize_prompt_messages(raw_messages: Any) -> List[Message]:
"""从 Hook 载荷恢复 LLM 消息列表。
Args:
raw_messages: Hook 返回的消息列表。
Returns:
List[Message]: 恢复后的 LLM 消息列表。
Raises:
ValueError: 结构不合法时抛出。
"""
if not isinstance(raw_messages, list):
raise ValueError("Hook 返回的 `messages` 必须是列表")
from src.services.llm_service import _build_message_from_dict
normalized_messages: List[Message] = []
for raw_message in raw_messages:
if not isinstance(raw_message, dict):
raise ValueError("Hook 返回的消息项必须是字典")
normalized_messages.append(_build_message_from_dict(raw_message))
return normalized_messages
def serialize_tool_definitions(tool_definitions: Sequence[ToolDefinitionInput]) -> List[Dict[str, Any]]:
"""将工具定义列表序列化为 Hook 可传输载荷。
Args:
tool_definitions: 原始工具定义列表。
Returns:
List[Dict[str, Any]]: 序列化后的工具定义列表。
"""
normalized_tool_options = normalize_tool_options(list(tool_definitions))
if not normalized_tool_options:
return []
return [tool_option.to_openai_function_schema() for tool_option in normalized_tool_options]

View File

@@ -0,0 +1,31 @@
"""Hook 参数模型构造辅助。"""
from __future__ import annotations
from copy import deepcopy
from typing import Any, Dict, Sequence
def build_object_schema(
properties: Dict[str, Dict[str, Any]],
*,
required: Sequence[str] | None = None,
) -> Dict[str, Any]:
"""构造对象级 JSON Schema。
Args:
properties: 字段定义映射。
required: 必填字段名列表。
Returns:
Dict[str, Any]: 标准化后的对象级 Schema。
"""
schema: Dict[str, Any] = {
"type": "object",
"properties": deepcopy(properties),
}
normalized_required = [str(item).strip() for item in (required or []) if str(item).strip()]
if normalized_required:
schema["required"] = normalized_required
return schema

View File

@@ -18,9 +18,37 @@ import re
from src.common.logger import get_logger
from src.core.tooling import build_tool_detailed_description
from .hook_spec_registry import HookSpecRegistry
logger = get_logger("plugin_runtime.host.component_registry")
class ComponentRegistrationError(ValueError):
"""组件注册失败异常。"""
def __init__(
self,
message: str,
*,
component_name: str = "",
component_type: str = "",
plugin_id: str = "",
) -> None:
"""初始化组件注册失败异常。
Args:
message: 原始错误信息。
component_name: 组件名称。
component_type: 组件类型。
plugin_id: 插件 ID。
"""
self.component_name = str(component_name or "").strip()
self.component_type = str(component_type or "").strip()
self.plugin_id = str(plugin_id or "").strip()
super().__init__(message)
class ComponentTypes(str, Enum):
ACTION = "ACTION"
COMMAND = "COMMAND"
@@ -359,7 +387,14 @@ class ComponentRegistry:
供业务层查询可用组件、匹配命令、调度 action/event 等。
"""
def __init__(self) -> None:
def __init__(self, hook_spec_registry: Optional[HookSpecRegistry] = None) -> None:
"""初始化组件注册表。
Args:
hook_spec_registry: 可选的 Hook 规格注册中心;提供后会在注册
HookHandler 时执行规格校验。
"""
# 全量索引
self._components: Dict[str, ComponentEntry] = {} # full_name -> comp
@@ -370,6 +405,7 @@ class ComponentRegistry:
# 按插件索引
self._by_plugin: Dict[str, List[ComponentEntry]] = {}
self._hook_spec_registry = hook_spec_registry
@staticmethod
def _convert_action_metadata_to_tool_metadata(
@@ -475,77 +511,211 @@ class ComponentRegistry:
type_dict.clear()
self._by_plugin.clear()
# ====== 注册 / 注销 ======
def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
"""注册单个组件
@staticmethod
def _is_legacy_action_component(component: ComponentEntry) -> bool:
"""判断组件是否为兼容旧 Action 的 Tool 条目。
Args:
name: 组件名称不含插件id前缀
component_type: 组件类型(如 `ACTION`、`COMMAND` 等)
plugin_id: 插件id
metadata: 组件元数据
component: 待判断的组件条目。
Returns:
success (bool): 是否成功注册(失败原因通常是组件类型无效)
bool: 是否为兼容旧 Action 组件。
"""
if not isinstance(component, ToolEntry):
return False
return str(component.metadata.get("legacy_component_type", "") or "").strip().upper() == "ACTION"
def _validate_hook_handler_entry(self, component: HookHandlerEntry) -> None:
"""校验 HookHandler 是否满足已注册的 Hook 规格。
Args:
component: 待校验的 HookHandler 条目。
Raises:
ComponentRegistrationError: HookHandler 声明不合法时抛出。
"""
if self._hook_spec_registry is None:
return
hook_spec = self._hook_spec_registry.get_hook_spec(component.hook)
if hook_spec is None:
raise ComponentRegistrationError(
f"HookHandler {component.full_name} 声明了未注册的 Hook: {component.hook}",
component_name=component.name,
component_type=component.component_type.value,
plugin_id=component.plugin_id,
)
if component.is_blocking and not hook_spec.allow_blocking:
raise ComponentRegistrationError(
f"HookHandler {component.full_name} 不能注册为 blockingHook {component.hook} 不允许 blocking 处理器",
component_name=component.name,
component_type=component.component_type.value,
plugin_id=component.plugin_id,
)
if component.is_observe and not hook_spec.allow_observe:
raise ComponentRegistrationError(
f"HookHandler {component.full_name} 不能注册为 observeHook {component.hook} 不允许 observe 处理器",
component_name=component.name,
component_type=component.component_type.value,
plugin_id=component.plugin_id,
)
if component.error_policy == "abort" and not hook_spec.allow_abort:
raise ComponentRegistrationError(
f"HookHandler {component.full_name} 不能使用 error_policy=abortHook {component.hook} 不允许 abort",
component_name=component.name,
component_type=component.component_type.value,
plugin_id=component.plugin_id,
)
def _build_component_entry(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
) -> ComponentEntry:
"""根据声明构造组件条目。
Args:
name: 组件名称。
component_type: 组件类型。
plugin_id: 插件 ID。
metadata: 组件元数据。
Returns:
ComponentEntry: 已构造并完成校验的组件条目。
Raises:
ComponentRegistrationError: 组件声明不合法时抛出。
"""
try:
normalized_type = self._normalize_component_type(component_type)
normalized_metadata = dict(metadata)
if normalized_type == ComponentTypes.ACTION:
normalized_metadata = self._convert_action_metadata_to_tool_metadata(name, normalized_metadata)
comp = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata)
component = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata)
elif normalized_type == ComponentTypes.COMMAND:
comp = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata)
component = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata)
elif normalized_type == ComponentTypes.TOOL:
comp = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata)
component = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata)
elif normalized_type == ComponentTypes.EVENT_HANDLER:
comp = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
component = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
elif normalized_type == ComponentTypes.HOOK_HANDLER:
comp = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
component = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
self._validate_hook_handler_entry(component)
elif normalized_type == ComponentTypes.MESSAGE_GATEWAY:
comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata)
component = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata)
else:
raise ValueError(f"组件类型 {component_type} 不存在")
except ValueError:
logger.error(f"组件类型 {component_type} 不存在")
return False
raise ComponentRegistrationError(
f"组件类型 {component_type} 不存在",
component_name=name,
component_type=component_type,
plugin_id=plugin_id,
)
except ComponentRegistrationError:
raise
except Exception as exc:
raise ComponentRegistrationError(
str(exc),
component_name=name,
component_type=component_type,
plugin_id=plugin_id,
) from exc
if comp.full_name in self._components:
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
old_comp = self._components[comp.full_name]
# 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积
old_list = self._by_plugin.get(old_comp.plugin_id)
if old_list is not None:
with contextlib.suppress(ValueError):
old_list.remove(old_comp)
# 从旧类型索引中移除,防止类型变更时幽灵残留
if old_type_dict := self._by_type.get(old_comp.component_type):
old_type_dict.pop(comp.full_name, None)
return component
self._components[comp.full_name] = comp
self._by_type[comp.component_type][comp.full_name] = comp
self._by_plugin.setdefault(plugin_id, []).append(comp)
def _remove_existing_component_entry(self, component: ComponentEntry) -> None:
"""移除同名旧组件条目。
Args:
component: 即将写入的新组件条目。
"""
if component.full_name not in self._components:
return
logger.warning(f"组件 {component.full_name} 已存在,覆盖")
old_component = self._components[component.full_name]
old_list = self._by_plugin.get(old_component.plugin_id)
if old_list is not None:
with contextlib.suppress(ValueError):
old_list.remove(old_component)
if old_type_dict := self._by_type.get(old_component.component_type):
old_type_dict.pop(component.full_name, None)
def _add_component_entry(self, component: ComponentEntry) -> None:
"""写入单个组件条目到全部索引。
Args:
component: 待写入的组件条目。
"""
self._remove_existing_component_entry(component)
self._components[component.full_name] = component
self._by_type[component.component_type][component.full_name] = component
self._by_plugin.setdefault(component.plugin_id, []).append(component)
# ====== 注册 / 注销 ======
def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
"""注册单个组件。
Args:
name: 组件名称(不含插件 ID 前缀)。
component_type: 组件类型(如 ``ACTION``、``COMMAND`` 等)。
plugin_id: 插件 ID。
metadata: 组件元数据。
Returns:
bool: 注册成功时恒为 ``True``。
Raises:
ComponentRegistrationError: 组件声明不合法时抛出。
"""
component = self._build_component_entry(name, component_type, plugin_id, metadata)
self._add_component_entry(component)
return True
def register_plugin_components(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
"""批量注册一个插件的所有组件,返回成功注册数
"""批量替换一个插件的组件集合
该方法会先完整校验所有组件声明,只有全部通过后才会替换旧组件,
从而避免插件进入半注册状态。
Args:
plugin_id (str): 插件id
components (List[Dict[str, Any]]): 组件字典列表,每个组件包含 name, component_type, metadata 等字段
plugin_id: 插件 ID。
components: 组件声明字典列表
Returns:
count (int): 成功注册的组件数量
int: 实际注册的组件数量
Raises:
ComponentRegistrationError: 任一组件声明不合法时抛出。
"""
count = 0
for comp_data in components:
ok = self.register_component(
name=comp_data.get("name", ""),
component_type=comp_data.get("component_type", ""),
plugin_id=plugin_id,
metadata=comp_data.get("metadata", {}),
prepared_components: List[ComponentEntry] = []
for component_data in components:
prepared_components.append(
self._build_component_entry(
name=str(component_data.get("name", "") or ""),
component_type=str(component_data.get("component_type", "") or ""),
plugin_id=plugin_id,
metadata=component_data.get("metadata", {})
if isinstance(component_data.get("metadata"), dict)
else {},
)
)
if ok:
count += 1
return count
self.remove_components_by_plugin(plugin_id)
for component in prepared_components:
self._add_component_entry(component)
return len(prepared_components)
def remove_components_by_plugin(self, plugin_id: str) -> int:
"""移除某个插件的所有组件,返回移除数量。
@@ -652,6 +822,17 @@ class ComponentRegistry:
except ValueError:
logger.error(f"组件类型 {component_type} 不存在")
raise
if comp_type == ComponentTypes.ACTION:
action_components = [
component
for component in self._by_type.get(ComponentTypes.TOOL, {}).values()
if self._is_legacy_action_component(component)
]
if enabled_only:
return [component for component in action_components if self.check_component_enabled(component, session_id)]
return action_components
type_dict = self._by_type.get(comp_type, {})
if enabled_only:
return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)]
@@ -854,6 +1035,34 @@ class ComponentRegistry:
tools.append(comp)
return tools
def get_tools_for_llm(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""兼容旧接口,返回可供 LLM 使用的工具条目列表。
Args:
enabled_only: 是否仅返回启用的组件。
session_id: 可选的会话 ID若提供则考虑会话禁用状态。
Returns:
List[Dict[str, Any]]: 兼容旧结构的工具组件字典列表。
"""
return [
{
"name": tool.full_name,
"description": tool.description,
"parameters": (
dict(tool.parameters_raw)
if isinstance(tool.parameters_raw, dict) and tool.parameters_raw
else tool._get_parameters_schema() or {}
),
"parameters_raw": tool.parameters_raw,
"enabled": tool.enabled,
"plugin_id": tool.plugin_id,
}
for tool in self.get_tools(enabled_only=enabled_only, session_id=session_id)
if not self._is_legacy_action_component(tool)
]
# ====== 统计信息 ======
def get_stats(self) -> StatusDict:
"""获取注册统计。
@@ -863,9 +1072,21 @@ class ComponentRegistry:
"""
return StatusDict(
total=len(self._components),
action=len(self._by_type[ComponentTypes.ACTION]),
action=len(
[
component
for component in self._by_type.get(ComponentTypes.TOOL, {}).values()
if self._is_legacy_action_component(component)
]
),
command=len(self._by_type[ComponentTypes.COMMAND]),
tool=len(self._by_type[ComponentTypes.TOOL]),
tool=len(
[
component
for component in self._by_type.get(ComponentTypes.TOOL, {}).values()
if not self._is_legacy_action_component(component)
]
),
event_handler=len(self._by_type[ComponentTypes.EVENT_HANDLER]),
hook_handler=len(self._by_type[ComponentTypes.HOOK_HANDLER]),
message_gateway=len(self._by_type[ComponentTypes.MESSAGE_GATEWAY]),

View File

@@ -26,6 +26,8 @@ import contextlib
from src.common.logger import get_logger
from src.config.config import global_config
from .hook_spec_registry import HookSpec, HookSpecRegistry
if TYPE_CHECKING:
from .component_registry import HookHandlerEntry
from .supervisor import PluginRunnerSupervisor
@@ -33,29 +35,6 @@ if TYPE_CHECKING:
logger = get_logger("plugin_runtime.host.hook_dispatcher")
@dataclass(slots=True)
class HookSpec:
"""命名 Hook 的静态规格定义。
Attributes:
name: Hook 的唯一名称。
description: Hook 描述。
default_timeout_ms: 默认超时毫秒数;为 `0` 时退回系统默认值。
allow_blocking: 是否允许注册阻塞处理器。
allow_observe: 是否允许注册观察处理器。
allow_abort: 是否允许处理器中止当前 Hook 调用。
allow_kwargs_mutation: 是否允许阻塞处理器修改 `kwargs`。
"""
name: str
description: str = ""
default_timeout_ms: int = 0
allow_blocking: bool = True
allow_observe: bool = True
allow_abort: bool = True
allow_kwargs_mutation: bool = True
@dataclass(slots=True)
class HookHandlerExecutionResult:
"""单个 HookHandler 的执行结果。
@@ -121,17 +100,19 @@ class HookDispatcher:
def __init__(
self,
supervisors_provider: Optional[Callable[[], Sequence["PluginRunnerSupervisor"]]] = None,
hook_spec_registry: Optional[HookSpecRegistry] = None,
) -> None:
"""初始化 Hook 分发器。
Args:
supervisors_provider: 可选的 Supervisor 提供器。若调用 `invoke_hook()`
时未显式传入 `supervisors`,则使用该回调获取目标 Supervisor 列表。
hook_spec_registry: 可选的 Hook 规格注册中心;留空时使用独立注册中心。
"""
self._background_tasks: Set[asyncio.Task[Any]] = set()
self._hook_specs: Dict[str, HookSpec] = {}
self._supervisors_provider = supervisors_provider
self._hook_spec_registry = hook_spec_registry or HookSpecRegistry()
async def stop(self) -> None:
"""停止分发器并取消所有未完成的观察任务。"""
@@ -148,16 +129,7 @@ class HookDispatcher:
spec: 需要注册的 Hook 规格。
"""
normalized_name = self._normalize_hook_name(spec.name)
self._hook_specs[normalized_name] = HookSpec(
name=normalized_name,
description=spec.description,
default_timeout_ms=max(int(spec.default_timeout_ms), 0),
allow_blocking=bool(spec.allow_blocking),
allow_observe=bool(spec.allow_observe),
allow_abort=bool(spec.allow_abort),
allow_kwargs_mutation=bool(spec.allow_kwargs_mutation),
)
self._hook_spec_registry.register_hook_spec(spec)
def register_hook_specs(self, specs: Sequence[HookSpec]) -> None:
"""批量注册命名 Hook 规格。
@@ -180,14 +152,37 @@ class HookDispatcher:
"""
normalized_name = self._normalize_hook_name(hook_name)
if normalized_name in self._hook_specs:
return self._hook_specs[normalized_name]
registered_spec = self._hook_spec_registry.get_hook_spec(normalized_name)
if registered_spec is not None:
return registered_spec
return HookSpec(
name=normalized_name,
parameters_schema={},
default_timeout_ms=self._get_default_timeout_ms(),
)
def unregister_hook_spec(self, hook_name: str) -> bool:
"""注销指定命名 Hook 规格。
Args:
hook_name: 目标 Hook 名称。
Returns:
bool: 是否成功注销。
"""
return self._hook_spec_registry.unregister_hook_spec(hook_name)
def list_hook_specs(self) -> List[HookSpec]:
"""返回当前全部显式注册的 Hook 规格。
Returns:
List[HookSpec]: 已注册 Hook 规格列表。
"""
return self._hook_spec_registry.list_hook_specs()
async def invoke_hook(
self,
hook_name: str,

View File

@@ -0,0 +1,190 @@
"""命名 Hook 规格注册中心。"""
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence
@dataclass(slots=True)
class HookSpec:
"""命名 Hook 的静态规格定义。
Attributes:
name: Hook 的唯一名称。
description: Hook 描述。
parameters_schema: Hook 参数模型,使用对象级 JSON Schema 表示。
default_timeout_ms: 默认超时毫秒数;为 ``0`` 时退回系统默认值。
allow_blocking: 是否允许注册阻塞处理器。
allow_observe: 是否允许注册观察处理器。
allow_abort: 是否允许处理器中止当前 Hook 调用。
allow_kwargs_mutation: 是否允许阻塞处理器修改 ``kwargs``。
"""
name: str
description: str = ""
parameters_schema: Dict[str, Any] = field(default_factory=dict)
default_timeout_ms: int = 0
allow_blocking: bool = True
allow_observe: bool = True
allow_abort: bool = True
allow_kwargs_mutation: bool = True
class HookSpecRegistry:
"""命名 Hook 规格注册中心。"""
def __init__(self) -> None:
"""初始化 Hook 规格注册中心。"""
self._hook_specs: Dict[str, HookSpec] = {}
@staticmethod
def _normalize_hook_name(hook_name: str) -> str:
"""规范化 Hook 名称。
Args:
hook_name: 原始 Hook 名称。
Returns:
str: 规范化后的 Hook 名称。
Raises:
ValueError: Hook 名称为空时抛出。
"""
normalized_name = str(hook_name or "").strip()
if not normalized_name:
raise ValueError("Hook 名称不能为空")
return normalized_name
@staticmethod
def _normalize_parameters_schema(raw_schema: Any) -> Dict[str, Any]:
"""规范化 Hook 参数模型。
Args:
raw_schema: 原始参数模型。
Returns:
Dict[str, Any]: 规范化后的对象级 JSON Schema。
Raises:
ValueError: 参数模型不是合法对象级 Schema 时抛出。
"""
if raw_schema is None:
return {}
if not isinstance(raw_schema, dict):
raise ValueError("Hook 参数模型必须是字典")
if not raw_schema:
return {}
normalized_schema = deepcopy(raw_schema)
schema_type = normalized_schema.get("type")
properties = normalized_schema.get("properties")
if schema_type not in {"", None, "object"} and properties is None:
raise ValueError("Hook 参数模型必须是 object 类型或属性映射")
if schema_type in {"", None} and properties is None:
normalized_schema = {
"type": "object",
"properties": normalized_schema,
}
elif schema_type in {"", None}:
normalized_schema["type"] = "object"
if normalized_schema.get("type") != "object":
raise ValueError("Hook 参数模型必须是 object 类型")
return normalized_schema
@classmethod
def _normalize_spec(cls, spec: HookSpec) -> HookSpec:
"""规范化 Hook 规格对象。
Args:
spec: 原始 Hook 规格。
Returns:
HookSpec: 规范化后的 Hook 规格副本。
"""
return HookSpec(
name=cls._normalize_hook_name(spec.name),
description=str(spec.description or "").strip(),
parameters_schema=cls._normalize_parameters_schema(spec.parameters_schema),
default_timeout_ms=max(int(spec.default_timeout_ms), 0),
allow_blocking=bool(spec.allow_blocking),
allow_observe=bool(spec.allow_observe),
allow_abort=bool(spec.allow_abort),
allow_kwargs_mutation=bool(spec.allow_kwargs_mutation),
)
def clear(self) -> None:
"""清空全部 Hook 规格。"""
self._hook_specs.clear()
def register_hook_spec(self, spec: HookSpec) -> HookSpec:
"""注册单个 Hook 规格。
Args:
spec: 需要注册的 Hook 规格。
Returns:
HookSpec: 规范化后实际注册的 Hook 规格。
"""
normalized_spec = self._normalize_spec(spec)
self._hook_specs[normalized_spec.name] = normalized_spec
return normalized_spec
def register_hook_specs(self, specs: Sequence[HookSpec]) -> List[HookSpec]:
"""批量注册 Hook 规格。
Args:
specs: 需要注册的 Hook 规格列表。
Returns:
List[HookSpec]: 规范化后实际注册的 Hook 规格列表。
"""
return [self.register_hook_spec(spec) for spec in specs]
def unregister_hook_spec(self, hook_name: str) -> bool:
"""注销指定 Hook 规格。
Args:
hook_name: 目标 Hook 名称。
Returns:
bool: 是否成功删除。
"""
normalized_name = self._normalize_hook_name(hook_name)
return self._hook_specs.pop(normalized_name, None) is not None
def get_hook_spec(self, hook_name: str) -> Optional[HookSpec]:
"""获取指定 Hook 的显式规格。
Args:
hook_name: 目标 Hook 名称。
Returns:
Optional[HookSpec]: 已注册时返回规格副本,否则返回 ``None``。
"""
normalized_name = self._normalize_hook_name(hook_name)
spec = self._hook_specs.get(normalized_name)
return None if spec is None else self._normalize_spec(spec)
def list_hook_specs(self) -> List[HookSpec]:
"""返回当前全部 Hook 规格。
Returns:
List[HookSpec]: 按 Hook 名称升序排列的规格副本列表。
"""
return [
self._normalize_spec(spec)
for _, spec in sorted(self._hook_specs.items(), key=lambda item: item[0])
]

View File

@@ -27,6 +27,8 @@ from src.plugin_runtime.protocol.envelope import (
ConfigUpdatedPayload,
Envelope,
HealthPayload,
InspectPluginConfigPayload,
InspectPluginConfigResultPayload,
MessageGatewayStateUpdatePayload,
MessageGatewayStateUpdateResultPayload,
PROTOCOL_VERSION,
@@ -52,6 +54,7 @@ from .capability_service import CapabilityService
from .component_registry import ComponentRegistry
from .event_dispatcher import EventDispatcher
from .hook_dispatcher import HookDispatchResult, HookDispatcher
from .hook_spec_registry import HookSpecRegistry
from .logger_bridge import RunnerLogBridge
from .message_gateway import MessageGateway
from .rpc_server import RPCServer
@@ -84,6 +87,7 @@ class PluginRunnerSupervisor:
self,
plugin_dirs: Optional[List[Path]] = None,
group_name: str = "third_party",
hook_spec_registry: Optional[HookSpecRegistry] = None,
socket_path: Optional[str] = None,
health_check_interval_sec: Optional[float] = None,
max_restart_attempts: Optional[int] = None,
@@ -94,6 +98,7 @@ class PluginRunnerSupervisor:
Args:
plugin_dirs: 由当前 Runner 负责加载的插件目录列表。
group_name: 当前 Supervisor 所属运行时分组名称。
hook_spec_registry: 可选的共享 Hook 规格注册中心。
socket_path: 自定义 IPC 地址;留空时由传输层自动生成。
health_check_interval_sec: 健康检查间隔,单位秒。
max_restart_attempts: 自动重启 Runner 的最大次数。
@@ -110,9 +115,12 @@ class PluginRunnerSupervisor:
self._authorization = AuthorizationManager()
self._capability_service = CapabilityService(self._authorization)
self._api_registry = APIRegistry()
self._component_registry = ComponentRegistry()
self._component_registry = ComponentRegistry(hook_spec_registry=hook_spec_registry)
self._event_dispatcher = EventDispatcher(self._component_registry)
self._hook_dispatcher = HookDispatcher(lambda: [self])
self._hook_dispatcher = HookDispatcher(
lambda: [self],
hook_spec_registry=hook_spec_registry,
)
self._message_gateway = MessageGateway(self._component_registry)
self._log_bridge = RunnerLogBridge()
@@ -581,6 +589,49 @@ class PluginRunnerSupervisor:
raise ValueError("插件配置校验失败")
return dict(result.normalized_config)
async def inspect_plugin_config(
self,
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
*,
use_provided_config: bool = False,
) -> InspectPluginConfigResultPayload:
"""请求 Runner 解析插件配置元数据。
Args:
plugin_id: 目标插件 ID。
config_data: 可选的配置内容。
use_provided_config: 是否优先使用传入配置而不是磁盘配置。
Returns:
InspectPluginConfigResultPayload: 插件配置解析结果。
Raises:
ValueError: Runner 无法解析插件或返回了错误响应时抛出。
"""
payload = InspectPluginConfigPayload(
config_data=config_data or {},
use_provided_config=use_provided_config,
)
try:
response = await self._rpc_server.send_request(
"plugin.inspect_config",
plugin_id=plugin_id,
payload=payload.model_dump(),
timeout_ms=10000,
)
except Exception as exc:
raise ValueError(f"插件配置解析请求失败: {exc}") from exc
if response.error:
raise ValueError(str(response.error.get("message", "插件配置解析失败")))
result = InspectPluginConfigResultPayload.model_validate(response.payload)
if not result.success:
raise ValueError("插件配置解析失败")
return result
def get_config_reload_subscribers(self, scope: str) -> List[str]:
"""返回订阅指定全局配置广播的插件列表。
@@ -713,15 +764,25 @@ class PluginRunnerSupervisor:
component_declarations = [component.model_dump() for component in payload.components]
runtime_components, api_components = self._split_component_declarations(component_declarations)
self._component_registry.remove_components_by_plugin(payload.plugin_id)
self._api_registry.remove_apis_by_plugin(payload.plugin_id)
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
try:
registered_count = self._component_registry.register_plugin_components(
payload.plugin_id,
runtime_components,
)
except Exception as exc:
logger.error(f"插件 {payload.plugin_id} 组件注册失败: {exc}")
return envelope.make_error_response(
ErrorCode.E_BAD_PAYLOAD.value,
str(exc),
details={
"plugin_id": payload.plugin_id,
"component_count": len(runtime_components),
},
)
registered_count = self._component_registry.register_plugin_components(
payload.plugin_id,
runtime_components,
)
self._api_registry.remove_apis_by_plugin(payload.plugin_id)
registered_api_count = self._api_registry.register_plugin_apis(payload.plugin_id, api_components)
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
self._registered_plugins[payload.plugin_id] = payload
self._message_gateway_states[payload.plugin_id] = {}

View File

@@ -25,6 +25,7 @@ from typing import (
)
import asyncio
import inspect
import tomlkit
@@ -32,14 +33,17 @@ from src.common.logger import get_logger
from src.config.config import config_manager
from src.config.file_watcher import FileChange, FileWatcher
from src.platform_io import DeliveryBatch, InboundMessageEnvelope, get_platform_io_manager
from src.plugin_runtime.hook_catalog import register_builtin_hook_specs
from src.plugin_runtime.capabilities import (
RuntimeComponentCapabilityMixin,
RuntimeCoreCapabilityMixin,
RuntimeDataCapabilityMixin,
)
from src.plugin_runtime.capabilities.registry import register_capability_impls
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult, HookDispatcher, HookSpec
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult, HookDispatcher
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils
from src.plugin_runtime.protocol.envelope import InspectPluginConfigResultPayload
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
if TYPE_CHECKING:
@@ -87,7 +91,12 @@ class PluginRuntimeManager(
self._manifest_validator: ManifestValidator = ManifestValidator()
self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload
self._config_reload_callback_registered: bool = False
self._hook_dispatcher: HookDispatcher = HookDispatcher(lambda: self.supervisors)
self._hook_spec_registry: HookSpecRegistry = HookSpecRegistry()
self._builtin_hook_specs_registered: bool = False
self._hook_dispatcher: HookDispatcher = HookDispatcher(
lambda: self.supervisors,
hook_spec_registry=self._hook_spec_registry,
)
async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
"""接收 Platform IO 审核后的入站消息并送入主消息链。
@@ -155,6 +164,33 @@ class PluginRuntimeManager(
return ["third_party", "builtin"]
return ["builtin", "third_party"]
@staticmethod
def _instantiate_supervisor(supervisor_cls: Any, **kwargs: Any) -> Any:
"""兼容不同构造签名地实例化 Supervisor。
Args:
supervisor_cls: 目标 Supervisor 类。
**kwargs: 期望传入的构造参数。
Returns:
Any: 实例化后的 Supervisor。
"""
signature = inspect.signature(supervisor_cls)
accepts_var_keyword = any(
parameter.kind == inspect.Parameter.VAR_KEYWORD
for parameter in signature.parameters.values()
)
if accepts_var_keyword:
return supervisor_cls(**kwargs)
supported_kwargs = {
key: value
for key, value in kwargs.items()
if key in signature.parameters
}
return supervisor_cls(**supported_kwargs)
# ─── 生命周期 ─────────────────────────────────────────────
async def start(self) -> None:
@@ -185,6 +221,7 @@ class PluginRuntimeManager(
logger.info("未找到任何插件目录,跳过插件运行时启动")
return
self.ensure_builtin_hook_specs_registered()
platform_io_manager = get_platform_io_manager()
# 从配置读取自定义 IPC socket 路径(留空则自动生成)
@@ -196,17 +233,21 @@ class PluginRuntimeManager(
# 创建两个 Supervisor各自拥有独立的 socket / Runner 子进程
if builtin_dirs:
self._builtin_supervisor = PluginSupervisor(
self._builtin_supervisor = self._instantiate_supervisor(
PluginSupervisor,
plugin_dirs=builtin_dirs,
group_name="builtin",
hook_spec_registry=self._hook_spec_registry,
socket_path=builtin_socket,
)
self._register_capability_impls(self._builtin_supervisor)
if third_party_dirs:
self._third_party_supervisor = PluginSupervisor(
self._third_party_supervisor = self._instantiate_supervisor(
PluginSupervisor,
plugin_dirs=third_party_dirs,
group_name="third_party",
hook_spec_registry=self._hook_spec_registry,
socket_path=third_party_socket,
)
self._register_capability_impls(self._third_party_supervisor)
@@ -328,6 +369,7 @@ class PluginRuntimeManager(
spec: 需要注册的 Hook 规格。
"""
self.ensure_builtin_hook_specs_registered()
self._hook_dispatcher.register_hook_spec(spec)
def register_hook_specs(self, specs: Sequence[HookSpec]) -> None:
@@ -337,8 +379,41 @@ class PluginRuntimeManager(
specs: 需要注册的 Hook 规格序列。
"""
self.ensure_builtin_hook_specs_registered()
self._hook_dispatcher.register_hook_specs(specs)
def unregister_hook_spec(self, hook_name: str) -> bool:
"""注销指定命名 Hook 规格。
Args:
hook_name: 目标 Hook 名称。
Returns:
bool: 是否成功注销。
"""
self.ensure_builtin_hook_specs_registered()
return self._hook_dispatcher.unregister_hook_spec(hook_name)
def list_hook_specs(self) -> List[HookSpec]:
"""返回当前全部命名 Hook 规格。
Returns:
List[HookSpec]: 当前已注册的 Hook 规格列表。
"""
self.ensure_builtin_hook_specs_registered()
return self._hook_dispatcher.list_hook_specs()
def ensure_builtin_hook_specs_registered(self) -> None:
"""确保内置 Hook 规格已经注册到共享中心表。"""
if self._builtin_hook_specs_registered:
return
register_builtin_hook_specs(self._hook_spec_registry)
self._builtin_hook_specs_registered = True
def _build_registered_dependency_map(self) -> Dict[str, Set[str]]:
"""根据当前已注册插件构建全局依赖图。"""
@@ -542,8 +617,8 @@ class PluginRuntimeManager(
config_data: 待校验的配置内容。
Returns:
Dict[str, Any] | None: 校验成功时返回规范化后的配置;若插件当前未加载
或运行时不可用,则返回 ``None`` 以便调用方回退到静态 Schema 方案。
Dict[str, Any] | None: 校验成功时返回规范化后的配置;若插件不存在、
当前不可路由或运行时不可用,则返回 ``None`` 以便调用方回退到弱推断方案。
Raises:
ValueError: 插件已加载,但配置校验失败时抛出。
@@ -558,6 +633,8 @@ class PluginRuntimeManager(
logger.warning(f"插件 {plugin_id} 配置校验路由失败,将回退到静态 Schema: {exc}")
return None
if supervisor is None:
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
if supervisor is None:
return None
@@ -569,6 +646,54 @@ class PluginRuntimeManager(
logger.warning(f"插件 {plugin_id} 运行时配置校验不可用,将回退到静态 Schema: {exc}")
return None
async def inspect_plugin_config(
self,
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
*,
use_provided_config: bool = False,
) -> InspectPluginConfigResultPayload | None:
"""请求运行时解析插件配置元数据。
Args:
plugin_id: 目标插件 ID。
config_data: 可选的配置内容。
use_provided_config: 是否优先使用传入的配置内容而不是磁盘配置。
Returns:
InspectPluginConfigResultPayload | None: 解析成功时返回结构化结果;若插件
当前不可路由或运行时不可用,则返回 ``None``。
Raises:
ValueError: 插件存在,但运行时明确拒绝解析请求时抛出。
"""
if not self._started:
return None
try:
supervisor = self._get_supervisor_for_plugin(plugin_id)
except RuntimeError as exc:
logger.warning(f"插件 {plugin_id} 配置解析路由失败: {exc}")
return None
if supervisor is None:
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
if supervisor is None:
return None
try:
return await supervisor.inspect_plugin_config(
plugin_id=plugin_id,
config_data=config_data,
use_provided_config=use_provided_config,
)
except ValueError:
raise
except Exception as exc:
logger.warning(f"插件 {plugin_id} 配置解析不可用: {exc}")
return None
@staticmethod
def _normalize_config_reload_scopes(changed_scopes: Sequence[str]) -> tuple[str, ...]:
"""规范化配置热重载范围列表。
@@ -771,7 +896,15 @@ class PluginRuntimeManager(
return matches[0] if matches else None
async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool:
"""加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。"""
"""加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。
Args:
plugin_id: 目标插件 ID。
reason: 加载或重载原因。
Returns:
bool: 插件最终是否处于已加载状态。
"""
normalized_plugin_id = str(plugin_id or "").strip()
if not normalized_plugin_id:
@@ -789,11 +922,12 @@ class PluginRuntimeManager(
if supervisor is None:
return False
return await supervisor.reload_plugins(
reloaded = await supervisor.reload_plugins(
plugin_ids=[normalized_plugin_id],
reason=reason,
external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
)
return reloaded and normalized_plugin_id in supervisor.get_loaded_plugin_ids()
@classmethod
def _find_duplicate_plugin_ids(cls, plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
@@ -920,15 +1054,16 @@ class PluginRuntimeManager(
return None
def _refresh_plugin_config_watch_subscriptions(self) -> None:
"""按当前已注册插件集合刷新 config.toml 的单插件订阅。
"""按当前可识别插件集合刷新 config.toml 的单插件订阅。
当插件热重载后,插件集合或目录位置可能发生变化,因此需要重新对齐
watcher 的订阅,确保每个插件配置变更只触发对应 plugin_id。
这里不仅覆盖当前已注册插件,也覆盖已存在但暂未激活的合法插件。
"""
if self._plugin_file_watcher is None:
return
desired_plugin_paths = dict(self._iter_registered_plugin_paths())
desired_plugin_paths = dict(self._iter_watchable_plugin_paths())
self._plugin_path_cache = desired_plugin_paths.copy()
desired_config_paths = {
plugin_id: plugin_path / "config.toml" for plugin_id, plugin_path in desired_plugin_paths.items()
@@ -970,6 +1105,18 @@ class PluginRuntimeManager(
if plugin_path := self._get_plugin_path_for_supervisor(supervisor, plugin_id):
yield plugin_id, plugin_path
def _iter_watchable_plugin_paths(self) -> Iterable[Tuple[str, Path]]:
"""迭代应被配置监听器追踪的插件目录。
Returns:
Iterable[Tuple[str, Path]]: ``(plugin_id, plugin_path)`` 迭代器。
"""
watchable_plugin_paths = dict(self._iter_discovered_plugin_paths(self._iter_plugin_dirs()))
for plugin_id, plugin_path in self._iter_registered_plugin_paths():
watchable_plugin_paths.setdefault(plugin_id, plugin_path)
yield from watchable_plugin_paths.items()
def _get_plugin_config_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
"""从指定 Supervisor 的插件目录中定位某个插件的 config.toml。"""
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
@@ -993,18 +1140,43 @@ class PluginRuntimeManager(
return
if supervisor is None:
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
if supervisor is None:
return
plugin_is_loaded = plugin_id in getattr(supervisor, "_registered_plugins", {})
try:
snapshot = await supervisor.inspect_plugin_config(plugin_id)
except Exception as exc:
logger.warning(f"插件 {plugin_id} 配置文件变更解析失败: {exc}")
return
try:
config_payload = self._load_plugin_config_for_supervisor(supervisor, plugin_id)
delivered = await supervisor.notify_plugin_config_updated(
plugin_id=plugin_id,
config_data=config_payload,
config_version="",
config_scope="self",
)
if not delivered:
logger.warning(f"插件 {plugin_id} 配置文件变更后通知失败")
if plugin_is_loaded and snapshot.enabled:
delivered = await supervisor.notify_plugin_config_updated(
plugin_id=plugin_id,
config_data=dict(snapshot.normalized_config),
config_version="",
config_scope="self",
)
if not delivered:
logger.warning(f"插件 {plugin_id} 配置文件变更后通知失败")
return
if plugin_is_loaded and not snapshot.enabled:
reloaded = await self.reload_plugins_globally([plugin_id], reason="config_disabled")
if not reloaded:
logger.warning(f"插件 {plugin_id} 禁用配置已写入,但运行时卸载失败")
return
if not snapshot.enabled:
logger.info(f"插件 {plugin_id} 当前处于禁用状态,跳过自动加载")
return
loaded = await self.load_plugin_globally(plugin_id, reason="config_enabled")
if not loaded:
logger.warning(f"插件 {plugin_id} 配置文件变更后自动加载失败")
except Exception as exc:
logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}")

View File

@@ -288,6 +288,8 @@ class RunnerReadyPayload(BaseModel):
"""已完成初始化的插件列表"""
failed_plugins: List[str] = Field(default_factory=list, description="初始化失败的插件列表")
"""初始化失败的插件列表"""
inactive_plugins: List[str] = Field(default_factory=list, description="当前因禁用或依赖不可用而未激活的插件列表")
"""当前因禁用或依赖不可用而未激活的插件列表"""
# ====== 配置更新 ======
@@ -311,6 +313,32 @@ class ValidatePluginConfigPayload(BaseModel):
"""待校验的配置内容"""
class InspectPluginConfigPayload(BaseModel):
"""plugin.inspect_config 请求 payload。"""
config_data: Dict[str, Any] = Field(default_factory=dict, description="可选的配置内容")
"""可选的配置内容"""
use_provided_config: bool = Field(default=False, description="是否优先使用请求中携带的配置内容")
"""是否优先使用请求中携带的配置内容"""
class InspectPluginConfigResultPayload(BaseModel):
"""plugin.inspect_config 响应 payload。"""
success: bool = Field(description="是否解析成功")
"""是否解析成功"""
default_config: Dict[str, Any] = Field(default_factory=dict, description="插件默认配置")
"""插件默认配置"""
config_schema: Dict[str, Any] = Field(default_factory=dict, description="插件配置 Schema")
"""插件配置 Schema"""
normalized_config: Dict[str, Any] = Field(default_factory=dict, description="归一化后的配置内容")
"""归一化后的配置内容"""
changed: bool = Field(default=False, description="是否在归一化过程中自动补齐或修正了配置")
"""是否在归一化过程中自动补齐或修正了配置"""
enabled: bool = Field(default=True, description="插件在当前配置下是否应被视为启用")
"""插件在当前配置下是否应被视为启用"""
class ValidatePluginConfigResultPayload(BaseModel):
"""plugin.validate_config 响应 payload。"""
@@ -380,6 +408,8 @@ class ReloadPluginResultPayload(BaseModel):
"""成功完成重载的插件列表"""
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
"""本次已卸载的插件列表"""
inactive_plugins: List[str] = Field(default_factory=list, description="本次处于未激活状态的插件列表")
"""本次处于未激活状态的插件列表"""
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
"""重载失败的插件及原因"""
@@ -395,6 +425,8 @@ class ReloadPluginsResultPayload(BaseModel):
"""成功完成重载的插件列表"""
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
"""本次已卸载的插件列表"""
inactive_plugins: List[str] = Field(default_factory=list, description="本次处于未激活状态的插件列表")
"""本次处于未激活状态的插件列表"""
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
"""重载失败的插件及原因"""

View File

@@ -9,8 +9,10 @@
6. 转发插件的能力调用到 Host
"""
from collections.abc import Mapping
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Protocol, Set, Tuple, cast
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast
import asyncio
import contextlib
@@ -39,6 +41,8 @@ from src.plugin_runtime.protocol.envelope import (
ConfigUpdatedPayload,
Envelope,
HealthPayload,
InspectPluginConfigPayload,
InspectPluginConfigResultPayload,
InvokePayload,
InvokeResultPayload,
RegisterPluginPayload,
@@ -141,6 +145,14 @@ class _ConfigAwarePlugin(Protocol):
...
class PluginActivationStatus(str, Enum):
"""描述插件激活结果。"""
LOADED = "loaded"
INACTIVE = "inactive"
FAILED = "failed"
def _install_shutdown_signal_handlers(
mark_runner_shutting_down: Callable[[], None],
loop: Optional[asyncio.AbstractEventLoop] = None,
@@ -236,13 +248,43 @@ class PluginRunner:
# 4. 注入 PluginContext + 调用 on_load 生命周期钩子
failed_plugins: Set[str] = set(self._loader.failed_plugins.keys())
inactive_plugins: Set[str] = set()
available_plugin_versions: Dict[str, str] = dict(self._external_available_plugins)
for meta in plugins:
ok = await self._activate_plugin(meta)
if not ok:
unsatisfied_dependencies = [
dependency.id
for dependency in meta.manifest.plugin_dependencies
if dependency.id not in available_plugin_versions
or not self._loader.manifest_validator.is_plugin_dependency_satisfied(
dependency,
available_plugin_versions[dependency.id],
)
]
if unsatisfied_dependencies:
if any(dependency_id in inactive_plugins for dependency_id in unsatisfied_dependencies):
logger.info(
f"插件 {meta.plugin_id} 依赖的插件当前未激活,跳过本次启动: {', '.join(unsatisfied_dependencies)}"
)
inactive_plugins.add(meta.plugin_id)
continue
failed_plugins.add(meta.plugin_id)
continue
successful_plugins = [meta.plugin_id for meta in plugins if meta.plugin_id not in failed_plugins]
await self._notify_ready(successful_plugins, sorted(failed_plugins))
activation_status = await self._activate_plugin(meta)
if activation_status == PluginActivationStatus.LOADED:
available_plugin_versions[meta.plugin_id] = meta.version
continue
if activation_status == PluginActivationStatus.INACTIVE:
inactive_plugins.add(meta.plugin_id)
continue
failed_plugins.add(meta.plugin_id)
successful_plugins = [
meta.plugin_id
for meta in plugins
if meta.plugin_id not in failed_plugins and meta.plugin_id not in inactive_plugins
]
await self._notify_ready(successful_plugins, sorted(failed_plugins), sorted(inactive_plugins))
# 5. 等待直到收到关停信号
with contextlib.suppress(asyncio.CancelledError):
@@ -352,17 +394,17 @@ class PluginRunner:
cast(_ContextAwarePlugin, instance)._set_context(ctx)
logger.debug(f"已为插件 {plugin_id} 注入 PluginContext")
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[Dict[str, Any]] = None) -> None:
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""在 Runner 侧为插件实例注入当前插件配置。
Args:
meta: 插件元数据。
config_data: 可选的配置数据;留空时自动从插件目录读取。
Returns:
Dict[str, Any]: 归一化后的当前插件配置。
"""
instance = meta.instance
if not hasattr(instance, "set_plugin_config"):
return
raw_config = config_data if config_data is not None else self._load_plugin_config(meta.plugin_dir)
plugin_config, should_persist = self._normalize_plugin_config(instance, raw_config)
config_path = Path(meta.plugin_dir) / "config.toml"
@@ -370,10 +412,12 @@ class PluginRunner:
should_initialize_file = not config_path.exists() and bool(default_config)
if should_persist or should_initialize_file:
self._save_plugin_config(meta.plugin_dir, plugin_config)
try:
cast(_ConfigAwarePlugin, instance).set_plugin_config(plugin_config)
except Exception as exc:
logger.warning(f"插件 {meta.plugin_id} 配置注入失败: {exc}")
if hasattr(instance, "set_plugin_config"):
try:
cast(_ConfigAwarePlugin, instance).set_plugin_config(plugin_config)
except Exception as exc:
logger.warning(f"插件 {meta.plugin_id} 配置注入失败: {exc}")
return plugin_config
def _normalize_plugin_config(
self,
@@ -405,6 +449,33 @@ class PluginRunner:
logger.warning(f"插件配置归一化失败,将回退为原始配置: {exc}")
return normalized_config, False
@staticmethod
def _is_plugin_enabled(config_data: Optional[Mapping[str, Any]]) -> bool:
"""根据配置内容判断插件是否应被视为启用。
Args:
config_data: 当前插件配置。
Returns:
bool: 插件是否启用。
"""
if not isinstance(config_data, Mapping):
return True
plugin_section = config_data.get("plugin")
if not isinstance(plugin_section, Mapping):
return True
enabled_value = plugin_section.get("enabled", True)
if isinstance(enabled_value, str):
normalized_value = enabled_value.strip().lower()
if normalized_value in {"0", "false", "no", "off"}:
return False
if normalized_value in {"1", "true", "yes", "on"}:
return True
return bool(enabled_value)
@staticmethod
def _save_plugin_config(plugin_dir: str, config_data: Dict[str, Any]) -> None:
"""将插件配置写回到 ``config.toml``。
@@ -435,6 +506,99 @@ class PluginRunner:
return loaded if isinstance(loaded, dict) else {}
def _resolve_plugin_candidate(self, plugin_id: str) -> Tuple[Optional[PluginCandidate], Optional[str]]:
"""解析指定插件的候选目录。
Args:
plugin_id: 目标插件 ID。
Returns:
Tuple[Optional[PluginCandidate], Optional[str]]: 候选插件与错误信息。
"""
candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs)
if plugin_id in duplicate_candidates:
conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id])
return None, f"检测到重复插件 ID: {conflict_paths}"
candidate = candidates.get(plugin_id)
if candidate is None:
return None, f"未找到插件: {plugin_id}"
return candidate, None
def _resolve_plugin_meta_for_config_request(
self,
plugin_id: str,
) -> Tuple[Optional[PluginMeta], bool, Optional[str]]:
"""为配置相关请求解析插件元数据。
Args:
plugin_id: 目标插件 ID。
Returns:
Tuple[Optional[PluginMeta], bool, Optional[str]]: 依次为插件元数据、
是否为临时冷加载实例、以及错误信息。
"""
loaded_meta = self._loader.get_plugin(plugin_id)
if loaded_meta is not None:
return loaded_meta, False, None
candidate, error_message = self._resolve_plugin_candidate(plugin_id)
if candidate is None:
return None, False, error_message
try:
meta = self._loader.load_candidate(plugin_id, candidate)
except Exception as exc:
return None, False, str(exc)
if meta is None:
return None, False, "插件模块加载失败"
return meta, True, None
def _inspect_plugin_config(
self,
meta: PluginMeta,
*,
config_data: Optional[Dict[str, Any]] = None,
use_provided_config: bool = False,
suppress_errors: bool = True,
) -> InspectPluginConfigResultPayload:
"""解析插件代码定义的配置元数据。
Args:
meta: 插件元数据。
config_data: 可选的配置内容。
use_provided_config: 是否优先使用传入的配置内容。
suppress_errors: 是否在归一化失败时回退原始配置。
Returns:
InspectPluginConfigResultPayload: 结构化解析结果。
"""
raw_config = config_data if use_provided_config else self._load_plugin_config(meta.plugin_dir)
if use_provided_config and config_data is None:
raw_config = {}
normalized_config, changed = self._normalize_plugin_config(
meta.instance,
raw_config,
suppress_errors=suppress_errors,
)
default_config = self._get_plugin_default_config(meta.instance)
if not normalized_config and not raw_config and default_config:
normalized_config = dict(default_config)
changed = True
return InspectPluginConfigResultPayload(
success=True,
default_config=default_config,
config_schema=self._get_plugin_config_schema(meta),
normalized_config=normalized_config,
changed=changed,
enabled=self._is_plugin_enabled(normalized_config),
)
def _register_handlers(self) -> None:
"""注册 Host -> Runner 的方法处理器。"""
self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke)
@@ -448,6 +612,7 @@ class PluginRunner:
self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown)
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated)
self._rpc_client.register_method("plugin.inspect_config", self._handle_inspect_plugin_config)
self._rpc_client.register_method("plugin.validate_config", self._handle_validate_plugin_config)
self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin)
self._rpc_client.register_method("plugin.reload_batch", self._handle_reload_plugins)
@@ -579,6 +744,9 @@ class PluginRunner:
)
if response.error:
raise RuntimeError(response.error.get("message", "插件注册失败"))
response_payload = response.payload if isinstance(response.payload, dict) else {}
if not bool(response_payload.get("accepted", True)):
raise RuntimeError(str(response_payload.get("reason", "插件注册失败")))
logger.info(f"插件 {meta.plugin_id} 注册完成")
return True
except Exception as e:
@@ -689,36 +857,40 @@ class PluginRunner:
except Exception as exc:
logger.error(f"插件 {meta.plugin_id} on_unload 失败: {exc}", exc_info=True)
async def _activate_plugin(self, meta: PluginMeta) -> bool:
async def _activate_plugin(self, meta: PluginMeta) -> PluginActivationStatus:
"""完成插件注入、授权、生命周期和组件注册。
Args:
meta: 待激活的插件元数据。
Returns:
bool: 是否激活成功
PluginActivationStatus: 插件激活结果
"""
self._inject_context(meta.plugin_id, meta.instance)
self._apply_plugin_config(meta)
plugin_config = self._apply_plugin_config(meta)
if not self._is_plugin_enabled(plugin_config):
logger.info(f"插件 {meta.plugin_id} 已在配置中禁用,跳过激活")
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return PluginActivationStatus.INACTIVE
if not await self._bootstrap_plugin(meta):
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False
return PluginActivationStatus.FAILED
if not await self._register_plugin(meta):
await self._invoke_plugin_on_unload(meta)
await self._deactivate_plugin(meta)
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False
return PluginActivationStatus.FAILED
if not await self._invoke_plugin_on_load(meta):
await self._unregister_plugin(meta.plugin_id, reason="on_load_failed")
await self._deactivate_plugin(meta)
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False
return PluginActivationStatus.FAILED
self._loader.set_loaded_plugin(meta)
return True
return PluginActivationStatus.LOADED
async def _unload_plugin(self, meta: PluginMeta, reason: str, *, purge_modules: bool = True) -> None:
"""卸载单个插件并清理 Host/Runner 两侧状态。
@@ -879,6 +1051,7 @@ class PluginRunner:
requested_plugin_id=plugin_id,
reloaded_plugins=batch_result.reloaded_plugins,
unloaded_plugins=batch_result.unloaded_plugins,
inactive_plugins=batch_result.inactive_plugins,
failed_plugins=batch_result.failed_plugins,
)
@@ -973,6 +1146,8 @@ class PluginRunner:
},
}
reloaded_plugins: List[str] = []
inactive_plugins: List[str] = []
inactive_plugin_ids: Set[str] = set()
for load_plugin_id in load_order:
if load_plugin_id in failed_plugins:
@@ -983,10 +1158,28 @@ class PluginRunner:
continue
_, manifest, _ = candidate
unsatisfied_dependency_ids = [
dependency.id
for dependency in manifest.plugin_dependencies
if dependency.id not in available_plugins
or not self._loader.manifest_validator.is_plugin_dependency_satisfied(
dependency,
available_plugins[dependency.id],
)
]
if unsatisfied_dependencies := self._loader.manifest_validator.get_unsatisfied_plugin_dependencies(
manifest,
available_plugin_versions=available_plugins,
):
if load_plugin_id not in reload_root_ids and any(
dependency_id in inactive_plugin_ids for dependency_id in unsatisfied_dependency_ids
):
logger.info(
f"插件 {load_plugin_id} 的依赖当前未激活,保留为未激活状态: {', '.join(unsatisfied_dependencies)}"
)
inactive_plugin_ids.add(load_plugin_id)
inactive_plugins.append(load_plugin_id)
continue
failed_plugins[load_plugin_id] = f"依赖未满足: {', '.join(unsatisfied_dependencies)}"
continue
@@ -996,9 +1189,13 @@ class PluginRunner:
continue
activated = await self._activate_plugin(meta)
if not activated:
if activated == PluginActivationStatus.FAILED:
failed_plugins[load_plugin_id] = "插件初始化失败"
continue
if activated == PluginActivationStatus.INACTIVE:
inactive_plugin_ids.add(load_plugin_id)
inactive_plugins.append(load_plugin_id)
continue
available_plugins[load_plugin_id] = meta.version
reloaded_plugins.append(load_plugin_id)
@@ -1033,7 +1230,7 @@ class PluginRunner:
rollback_failures[rollback_plugin_id] = str(exc)
continue
if not restored:
if restored != PluginActivationStatus.LOADED:
rollback_failures[rollback_plugin_id] = "无法重新激活旧版本"
return ReloadPluginsResultPayload(
@@ -1041,29 +1238,40 @@ class PluginRunner:
requested_plugin_ids=normalized_plugin_ids,
reloaded_plugins=[],
unloaded_plugins=unloaded_plugins,
inactive_plugins=[],
failed_plugins=self._finalize_failed_reload_messages(failed_plugins, rollback_failures),
)
requested_plugin_success = all(plugin_id in reloaded_plugins for plugin_id in reload_root_ids)
requested_plugin_success = all(
plugin_id in reloaded_plugins or plugin_id in inactive_plugins for plugin_id in reload_root_ids
)
return ReloadPluginsResultPayload(
success=requested_plugin_success and not failed_plugins,
requested_plugin_ids=normalized_plugin_ids,
reloaded_plugins=reloaded_plugins,
unloaded_plugins=unloaded_plugins,
inactive_plugins=inactive_plugins,
failed_plugins=failed_plugins,
)
async def _notify_ready(self, loaded_plugins: List[str], failed_plugins: List[str]) -> None:
async def _notify_ready(
self,
loaded_plugins: List[str],
failed_plugins: List[str],
inactive_plugins: List[str],
) -> None:
"""通知 Host 当前 Runner 已完成插件初始化。
Args:
loaded_plugins: 成功初始化的插件列表。
failed_plugins: 初始化失败的插件列表。
inactive_plugins: 因禁用或依赖不可用而未激活的插件列表。
"""
payload = RunnerReadyPayload(
loaded_plugins=loaded_plugins,
failed_plugins=failed_plugins,
inactive_plugins=inactive_plugins,
)
await self._rpc_client.send_request(
"runner.ready",
@@ -1289,6 +1497,44 @@ class PluginRunner:
return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
return envelope.make_response(payload={"acknowledged": True})
async def _handle_inspect_plugin_config(self, envelope: Envelope) -> Envelope:
"""处理插件配置元数据解析请求。
Args:
envelope: RPC 请求信封。
Returns:
Envelope: RPC 响应信封。
"""
try:
payload = InspectPluginConfigPayload.model_validate(envelope.payload)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
plugin_id = envelope.plugin_id
meta, is_temporary_meta, error_message = self._resolve_plugin_meta_for_config_request(plugin_id)
if meta is None:
return envelope.make_error_response(
ErrorCode.E_PLUGIN_NOT_FOUND.value,
error_message or f"未找到插件: {plugin_id}",
)
try:
result = self._inspect_plugin_config(
meta,
config_data=payload.config_data,
use_provided_config=payload.use_provided_config,
suppress_errors=True,
)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
finally:
if is_temporary_meta:
self._loader.purge_plugin_modules(plugin_id, meta.plugin_dir)
return envelope.make_response(payload=result.model_dump())
async def _handle_validate_plugin_config(self, envelope: Envelope) -> Envelope:
"""处理插件配置校验请求。
@@ -1305,23 +1551,30 @@ class PluginRunner:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
plugin_id = envelope.plugin_id
meta = self._loader.get_plugin(plugin_id)
meta, is_temporary_meta, error_message = self._resolve_plugin_meta_for_config_request(plugin_id)
if meta is None:
return envelope.make_error_response(ErrorCode.E_PLUGIN_NOT_FOUND.value, f"未找到插件: {plugin_id}")
return envelope.make_error_response(
ErrorCode.E_PLUGIN_NOT_FOUND.value,
error_message or f"未找到插件: {plugin_id}",
)
try:
normalized_config, changed = self._normalize_plugin_config(
meta.instance,
payload.config_data,
inspection_result = self._inspect_plugin_config(
meta,
config_data=payload.config_data,
use_provided_config=True,
suppress_errors=False,
)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
finally:
if is_temporary_meta:
self._loader.purge_plugin_modules(plugin_id, meta.plugin_dir)
result = ValidatePluginConfigResultPayload(
success=True,
normalized_config=normalized_config,
changed=changed,
normalized_config=inspection_result.normalized_config,
changed=inspection_result.changed,
)
return envelope.make_response(payload=result.model_dump())

View File

@@ -40,10 +40,213 @@ from src.common.utils.utils_message import MessageUtils
from src.config.config import global_config
from src.platform_io import DeliveryBatch, get_platform_io_manager
from src.platform_io.route_key_factory import RouteKeyFactory
from src.plugin_runtime.hook_payloads import deserialize_session_message, serialize_session_message
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
logger = get_logger("send_service")
def register_send_service_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册发送服务内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="send_service.after_build_message",
description="在出站 SessionMessage 构建完成后触发,可改写消息体或取消发送。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "待发送消息的序列化 SessionMessage。",
},
"stream_id": {
"type": "string",
"description": "目标会话 ID。",
},
"display_message": {
"type": "string",
"description": "展示层文本。",
},
"typing": {
"type": "boolean",
"description": "是否模拟打字。",
},
"set_reply": {
"type": "boolean",
"description": "是否附带引用回复。",
},
"storage_message": {
"type": "boolean",
"description": "发送成功后是否写库。",
},
"show_log": {
"type": "boolean",
"description": "是否输出发送日志。",
},
},
required=[
"message",
"stream_id",
"display_message",
"typing",
"set_reply",
"storage_message",
"show_log",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="send_service.before_send",
description="在真正调用 Platform IO 发送前触发,可改写消息或取消本次发送。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "待发送消息的序列化 SessionMessage。",
},
"typing": {
"type": "boolean",
"description": "是否模拟打字。",
},
"set_reply": {
"type": "boolean",
"description": "是否附带引用回复。",
},
"reply_message_id": {
"type": "string",
"description": "被引用消息 ID。",
},
"storage_message": {
"type": "boolean",
"description": "发送成功后是否写库。",
},
"show_log": {
"type": "boolean",
"description": "是否输出发送日志。",
},
},
required=["message", "typing", "set_reply", "storage_message", "show_log"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="send_service.after_send",
description="在发送流程结束后触发,用于观察最终发送结果。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "本次发送消息的序列化 SessionMessage。",
},
"sent": {
"type": "boolean",
"description": "本次发送是否成功。",
},
"typing": {
"type": "boolean",
"description": "是否模拟打字。",
},
"set_reply": {
"type": "boolean",
"description": "是否附带引用回复。",
},
"reply_message_id": {
"type": "string",
"description": "被引用消息 ID。",
},
"storage_message": {
"type": "boolean",
"description": "发送成功后是否写库。",
},
"show_log": {
"type": "boolean",
"description": "是否输出发送日志。",
},
},
required=["message", "sent", "typing", "set_reply", "storage_message", "show_log"],
),
default_timeout_ms=5000,
allow_abort=False,
allow_kwargs_mutation=False,
),
]
)
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
def _coerce_bool(value: Any, default: bool) -> bool:
"""将任意值安全转换为布尔值。
Args:
value: 待转换的值。
default: 当值为空时使用的默认值。
Returns:
bool: 转换后的布尔值。
"""
if value is None:
return default
return bool(value)
async def _invoke_send_hook(
hook_name: str,
message: SessionMessage,
**kwargs: Any,
) -> tuple[HookDispatchResult, SessionMessage]:
"""触发携带出站消息的命名 Hook。
Args:
hook_name: 目标 Hook 名称。
message: 当前待发送消息。
**kwargs: 需要附带的额外参数。
Returns:
tuple[HookDispatchResult, SessionMessage]: Hook 聚合结果以及可能被改写后的消息对象。
"""
hook_result = await _get_runtime_manager().invoke_hook(
hook_name,
message=serialize_session_message(message),
**kwargs,
)
mutated_message = message
raw_message = hook_result.kwargs.get("message")
if raw_message is not None:
try:
mutated_message = deserialize_session_message(raw_message)
except Exception as exc:
logger.warning(f"Hook {hook_name} 返回的 message 无法反序列化,已忽略: {exc}")
return hook_result, mutated_message
def _inherit_platform_io_route_metadata(target_stream: BotChatSession) -> Dict[str, object]:
"""从目标会话继承 Platform IO 路由元数据。
@@ -469,6 +672,27 @@ async def _send_via_platform_io(
Returns:
bool: 发送成功时返回 ``True``。
"""
before_send_result, message = await _invoke_send_hook(
"send_service.before_send",
message,
typing=typing,
set_reply=set_reply,
reply_message_id=reply_message_id,
storage_message=storage_message,
show_log=show_log,
)
if before_send_result.aborted:
logger.info(f"[SendService] 消息 {message.message_id} 在发送前被 Hook 中止")
return False
before_kwargs = before_send_result.kwargs
typing = _coerce_bool(before_kwargs.get("typing"), typing)
set_reply = _coerce_bool(before_kwargs.get("set_reply"), set_reply)
storage_message = _coerce_bool(before_kwargs.get("storage_message"), storage_message)
show_log = _coerce_bool(before_kwargs.get("show_log"), show_log)
raw_reply_message_id = before_kwargs.get("reply_message_id", reply_message_id)
reply_message_id = None if raw_reply_message_id in {None, ""} else str(raw_reply_message_id)
platform_io_manager = get_platform_io_manager()
try:
await platform_io_manager.ensure_send_pipeline_ready()
@@ -500,6 +724,18 @@ async def _send_via_platform_io(
logger.debug(traceback.format_exc())
return False
sent = bool(delivery_batch.has_success)
await _invoke_send_hook(
"send_service.after_send",
message,
sent=sent,
typing=typing,
set_reply=set_reply,
reply_message_id=reply_message_id,
storage_message=storage_message,
show_log=show_log,
)
if delivery_batch.has_success:
if storage_message:
_store_sent_message(message)
@@ -606,6 +842,26 @@ async def _send_to_target(
if outbound_message is None:
return False
after_build_result, outbound_message = await _invoke_send_hook(
"send_service.after_build_message",
outbound_message,
stream_id=stream_id,
display_message=display_message,
typing=typing,
set_reply=set_reply,
storage_message=storage_message,
show_log=show_log,
)
if after_build_result.aborted:
logger.info(f"[SendService] 消息 {outbound_message.message_id} 在构建后被 Hook 中止")
return False
after_build_kwargs = after_build_result.kwargs
typing = _coerce_bool(after_build_kwargs.get("typing"), typing)
set_reply = _coerce_bool(after_build_kwargs.get("set_reply"), set_reply)
storage_message = _coerce_bool(after_build_kwargs.get("storage_message"), storage_message)
show_log = _coerce_bool(after_build_kwargs.get("show_log"), show_log)
sent = await send_session_message(
outbound_message,
typing=typing,

View File

@@ -9,6 +9,7 @@ from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
from src.webui.routers.websocket.manager import websocket_manager
logger = get_logger("webui.logs_ws")
router = APIRouter()
@@ -148,24 +149,9 @@ async def broadcast_log(log_data: Dict):
Args:
log_data: 日志数据字典
"""
if not active_connections:
return
# 格式化为 JSON
message = json.dumps(log_data, ensure_ascii=False)
# 记录需要断开的连接
disconnected = set()
# 广播到所有客户端
for connection in active_connections:
try:
await connection.send_text(message)
except Exception:
# 发送失败,标记为断开
disconnected.add(connection)
# 清理断开的连接
if disconnected:
active_connections.difference_update(disconnected)
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")
await websocket_manager.broadcast_to_topic(
domain="logs",
topic="main",
event="entry",
data={"entry": log_data},
)

View File

@@ -18,12 +18,10 @@ def get_all_routers() -> List[APIRouter]:
from src.webui.api.replier import router as replier_router
from src.webui.routers.chat import router as chat_router
from src.webui.routers.knowledge import router as knowledge_router
from src.webui.routers.websocket.logs import router as logs_router
from src.webui.routes import router as main_router
return [
main_router,
logs_router,
knowledge_router,
chat_router,
planner_router,

View File

@@ -1,7 +1,7 @@
from typing import Tuple
from .routes import router
from .support import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
from .service import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
def get_webui_chat_broadcaster() -> Tuple[ChatConnectionManager, str]:

View File

@@ -1,9 +1,8 @@
"""本地聊天室路由 - WebUI 与麦麦直接对话。"""
import uuid
from typing import Dict, Optional
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Depends, Query
from sqlalchemy import case, func
from sqlmodel import col, select
@@ -13,16 +12,11 @@ from src.common.logger import get_logger
from src.config.config import global_config
from src.webui.dependencies import require_auth
from .support import (
from .service import (
WEBUI_CHAT_GROUP_ID,
WEBUI_CHAT_PLATFORM,
authenticate_chat_websocket,
chat_history,
chat_manager,
dispatch_chat_event,
normalize_webui_user_id,
resolve_initial_virtual_identity,
send_initial_chat_state,
)
logger = get_logger("webui.chat")
@@ -113,55 +107,6 @@ async def clear_chat_history(
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
@router.websocket("/ws")
async def websocket_chat(
websocket: WebSocket,
user_id: Optional[str] = Query(default=None),
user_name: Optional[str] = Query(default="WebUI用户"),
platform: Optional[str] = Query(default=None),
person_id: Optional[str] = Query(default=None),
group_name: Optional[str] = Query(default=None),
group_id: Optional[str] = Query(default=None),
token: Optional[str] = Query(default=None),
) -> None:
"""WebSocket 聊天端点。"""
if not await authenticate_chat_websocket(websocket, token):
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
session_id = str(uuid.uuid4())
normalized_user_id = normalize_webui_user_id(user_id)
current_user_name = user_name or "WebUI用户"
current_virtual_config = resolve_initial_virtual_identity(platform, person_id, group_name, group_id)
await chat_manager.connect(websocket, session_id, normalized_user_id)
try:
await send_initial_chat_state(
session_id=session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
)
while True:
data = await websocket.receive_json()
current_user_name, current_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data=data,
current_user_name=current_user_name,
normalized_user_id=normalized_user_id,
current_virtual_config=current_virtual_config,
)
except WebSocketDisconnect:
logger.info(f"WebSocket 断开: session={session_id}, user={normalized_user_id}")
except Exception as e:
logger.error(f"WebSocket 错误: {e}")
finally:
chat_manager.disconnect(session_id, normalized_user_id)
@router.get("/info")
async def get_chat_info() -> Dict[str, object]:
"""获取聊天室信息。"""

View File

@@ -1,10 +1,10 @@
"""WebUI 聊天路由支持逻辑"""
"""WebUI 聊天运行时服务"""
from dataclasses import dataclass
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, cast
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, cast
from fastapi import WebSocket
from pydantic import BaseModel
from sqlmodel import col, delete, select
@@ -17,8 +17,6 @@ from src.common.logger import get_logger
from src.common.message_repository import find_messages
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.chat")
@@ -27,6 +25,8 @@ WEBUI_CHAT_PLATFORM = "webui"
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
WEBUI_USER_ID_PREFIX = "webui_user_"
AsyncMessageSender = Callable[[Dict[str, Any]], Awaitable[None]]
class VirtualIdentityConfig(BaseModel):
"""虚拟身份配置。"""
@@ -52,13 +52,42 @@ class ChatHistoryMessage(BaseModel):
is_bot: bool = False
@dataclass
class ChatSessionConnection:
"""逻辑聊天会话连接信息。"""
session_id: str
connection_id: str
client_session_id: str
user_id: str
user_name: str
active_group_id: str
virtual_config: Optional[VirtualIdentityConfig]
sender: AsyncMessageSender
class ChatHistoryManager:
"""聊天历史管理器。"""
def __init__(self, max_messages: int = 200) -> None:
"""初始化聊天历史管理器。
Args:
max_messages: 内存中允许处理的最大消息数
"""
self.max_messages = max_messages
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> Dict[str, Any]:
"""将内部消息对象转换为前端可消费的字典。
Args:
msg: 内部统一消息对象
group_id: 当前会话所属的群组标识
Returns:
Dict[str, Any]: 面向 WebUI 的消息字典
"""
del group_id
user_info = msg.message_info.user_info
user_id = user_info.user_id or ""
is_bot = is_bot_self(msg.platform, user_id)
@@ -74,10 +103,27 @@ class ChatHistoryManager:
}
def _resolve_session_id(self, group_id: Optional[str]) -> str:
"""根据群组标识解析聊天会话 ID。
Args:
group_id: 群组标识
Returns:
str: 内部聊天会话 ID
"""
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""获取指定会话的历史消息。
Args:
limit: 最大返回条数
group_id: 群组标识
Returns:
List[Dict[str, Any]]: 历史消息列表
"""
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
session_id = self._resolve_session_id(target_group_id)
try:
@@ -90,11 +136,19 @@ class ChatHistoryManager:
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
return result
except Exception as e:
logger.error(f"从数据库加载聊天记录失败: {e}")
except Exception as exc:
logger.error(f"从数据库加载聊天记录失败: {exc}")
return []
def clear_history(self, group_id: Optional[str] = None) -> int:
"""清空指定会话的历史消息。
Args:
group_id: 群组标识
Returns:
int: 被删除的消息数量
"""
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
session_id = self._resolve_session_id(target_group_id)
try:
@@ -104,66 +158,245 @@ class ChatHistoryManager:
deleted = result.rowcount or 0
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
return deleted
except Exception as e:
logger.error(f"清空聊天记录失败: {e}")
except Exception as exc:
logger.error(f"清空聊天记录失败: {exc}")
return 0
class ChatConnectionManager:
"""聊天连接管理器。"""
"""统一聊天逻辑会话管理器。"""
def __init__(self) -> None:
self.active_connections: Dict[str, WebSocket] = {}
self.user_sessions: Dict[str, str] = {}
"""初始化聊天逻辑会话管理器。"""
self.active_connections: Dict[str, ChatSessionConnection] = {}
self.client_sessions: Dict[Tuple[str, str], str] = {}
self.connection_sessions: Dict[str, Set[str]] = {}
self.group_sessions: Dict[str, Set[str]] = {}
self.user_sessions: Dict[str, Set[str]] = {}
async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None:
await websocket.accept()
self.active_connections[session_id] = websocket
self.user_sessions[user_id] = session_id
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
def _bind_group(self, session_id: str, group_id: str) -> None:
"""为会话绑定群组索引。
def disconnect(self, session_id: str, user_id: str) -> None:
if session_id in self.active_connections:
del self.active_connections[session_id]
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
del self.user_sessions[user_id]
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
Args:
session_id: 内部会话 ID
group_id: 群组标识
"""
group_session_ids = self.group_sessions.setdefault(group_id, set())
group_session_ids.add(session_id)
def _unbind_group(self, session_id: str, group_id: str) -> None:
"""移除会话与群组的索引关系。
Args:
session_id: 内部会话 ID
group_id: 群组标识
"""
group_session_ids = self.group_sessions.get(group_id)
if group_session_ids is None:
return
group_session_ids.discard(session_id)
if not group_session_ids:
del self.group_sessions[group_id]
async def connect(
self,
session_id: str,
connection_id: str,
client_session_id: str,
user_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
sender: AsyncMessageSender,
) -> None:
"""注册一个新的逻辑聊天会话。
Args:
session_id: 内部逻辑会话 ID
connection_id: 物理 WebSocket 连接 ID
client_session_id: 前端标签页使用的会话 ID
user_id: 规范化后的用户 ID
user_name: 当前展示昵称
virtual_config: 当前虚拟身份配置
sender: 发送消息到前端的异步回调
"""
existing_session_id = self.client_sessions.get((connection_id, client_session_id))
if existing_session_id is not None:
self.disconnect(existing_session_id)
active_group_id = get_current_group_id(virtual_config)
session_connection = ChatSessionConnection(
session_id=session_id,
connection_id=connection_id,
client_session_id=client_session_id,
user_id=user_id,
user_name=user_name,
active_group_id=active_group_id,
virtual_config=virtual_config,
sender=sender,
)
self.active_connections[session_id] = session_connection
self.client_sessions[(connection_id, client_session_id)] = session_id
self.connection_sessions.setdefault(connection_id, set()).add(session_id)
self.user_sessions.setdefault(user_id, set()).add(session_id)
self._bind_group(session_id, active_group_id)
logger.info(
"WebUI 聊天会话已连接: session=%s, connection=%s, client_session=%s, user=%s, group=%s",
session_id,
connection_id,
client_session_id,
user_id,
active_group_id,
)
def disconnect(self, session_id: str) -> None:
"""断开一个逻辑聊天会话。
Args:
session_id: 内部逻辑会话 ID
"""
session_connection = self.active_connections.pop(session_id, None)
if session_connection is None:
return
self.client_sessions.pop((session_connection.connection_id, session_connection.client_session_id), None)
self._unbind_group(session_id, session_connection.active_group_id)
connection_session_ids = self.connection_sessions.get(session_connection.connection_id)
if connection_session_ids is not None:
connection_session_ids.discard(session_id)
if not connection_session_ids:
del self.connection_sessions[session_connection.connection_id]
user_session_ids = self.user_sessions.get(session_connection.user_id)
if user_session_ids is not None:
user_session_ids.discard(session_id)
if not user_session_ids:
del self.user_sessions[session_connection.user_id]
logger.info("WebUI 聊天会话已断开: session=%s", session_id)
def disconnect_connection(self, connection_id: str) -> None:
"""断开物理连接下的全部逻辑聊天会话。
Args:
connection_id: 物理 WebSocket 连接 ID
"""
session_ids = list(self.connection_sessions.get(connection_id, set()))
for session_id in session_ids:
self.disconnect(session_id)
def get_session(self, session_id: str) -> Optional[ChatSessionConnection]:
"""获取逻辑聊天会话信息。
Args:
session_id: 内部逻辑会话 ID
Returns:
Optional[ChatSessionConnection]: 会话存在时返回对应信息
"""
return self.active_connections.get(session_id)
def get_session_id(self, connection_id: str, client_session_id: str) -> Optional[str]:
"""根据连接 ID 和前端会话 ID 查询内部会话 ID。
Args:
connection_id: 物理 WebSocket 连接 ID
client_session_id: 前端标签页使用的会话 ID
Returns:
Optional[str]: 找到时返回内部会话 ID
"""
return self.client_sessions.get((connection_id, client_session_id))
def update_session_context(
self,
session_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> None:
"""更新会话上下文信息。
Args:
session_id: 内部逻辑会话 ID
user_name: 最新昵称
virtual_config: 最新虚拟身份配置
"""
session_connection = self.active_connections.get(session_id)
if session_connection is None:
return
next_group_id = get_current_group_id(virtual_config)
if next_group_id != session_connection.active_group_id:
self._unbind_group(session_id, session_connection.active_group_id)
self._bind_group(session_id, next_group_id)
session_connection.active_group_id = next_group_id
session_connection.user_name = user_name
session_connection.virtual_config = virtual_config
async def send_message(self, session_id: str, message: Dict[str, Any]) -> None:
if session_id in self.active_connections:
try:
await self.active_connections[session_id].send_json(message)
except Exception as e:
logger.error(f"发送消息失败: {e}")
"""向指定逻辑会话发送消息。
Args:
session_id: 内部逻辑会话 ID
message: 发送消息内容
"""
session_connection = self.active_connections.get(session_id)
if session_connection is None:
return
try:
await session_connection.sender(message)
except Exception as exc:
logger.error("发送聊天消息失败: session=%s, error=%s", session_id, exc)
async def broadcast(self, message: Dict[str, Any]) -> None:
"""向全部逻辑聊天会话广播消息。
Args:
message: 待广播的消息内容
"""
for session_id in list(self.active_connections.keys()):
await self.send_message(session_id, message)
async def broadcast_to_group(self, group_id: str, message: Dict[str, Any]) -> None:
"""向指定群组下的全部逻辑会话广播消息。
Args:
group_id: 群组标识
message: 待广播的消息内容
"""
for session_id in list(self.group_sessions.get(group_id, set())):
await self.send_message(session_id, message)
chat_history = ChatHistoryManager()
chat_manager = ChatConnectionManager()
def is_virtual_mode_enabled(virtual_config: Optional[VirtualIdentityConfig]) -> bool:
"""判断当前是否启用了虚拟身份模式。
Args:
virtual_config: 虚拟身份配置
Returns:
bool: 已启用时返回 ``True``
"""
return bool(virtual_config and virtual_config.enabled)
async def authenticate_chat_websocket(websocket: WebSocket, token: Optional[str]) -> bool:
if token and verify_ws_token(token):
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
return True
if cookie_token := websocket.cookies.get("maibot_session"):
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
return True
return False
def normalize_webui_user_id(user_id: Optional[str]) -> str:
"""标准化 WebUI 用户 ID。
Args:
user_id: 原始用户 ID
Returns:
str: 带统一前缀的用户 ID
"""
if not user_id:
return f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
if user_id.startswith(WEBUI_USER_ID_PREFIX):
@@ -172,12 +405,30 @@ def normalize_webui_user_id(user_id: Optional[str]) -> str:
def get_person_by_person_id(person_id: str) -> Optional[PersonInfo]:
"""根据人物 ID 查询人物信息。
Args:
person_id: 人物 ID
Returns:
Optional[PersonInfo]: 查到时返回人物信息
"""
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
return session.exec(statement).first()
def build_virtual_identity_config(person: PersonInfo, group_id: str, group_name: str) -> VirtualIdentityConfig:
"""根据人物信息构建虚拟身份配置。
Args:
person: 人物信息对象
group_id: 逻辑群组 ID
group_name: 逻辑群组名称
Returns:
VirtualIdentityConfig: 虚拟身份配置对象
"""
return VirtualIdentityConfig(
enabled=True,
platform=person.platform,
@@ -195,6 +446,17 @@ def resolve_initial_virtual_identity(
group_name: Optional[str],
group_id: Optional[str],
) -> Optional[VirtualIdentityConfig]:
"""根据初始参数解析虚拟身份配置。
Args:
platform: 平台名称
person_id: 人物 ID
group_name: 群组名称
group_id: 群组 ID
Returns:
Optional[VirtualIdentityConfig]: 解析成功时返回虚拟身份配置
"""
if not (platform and person_id):
return None
@@ -210,11 +472,14 @@ def resolve_initial_virtual_identity(
group_name=group_name or "WebUI虚拟群聊",
)
logger.info(
f"虚拟身份模式已通过 URL 参数激活: {virtual_config.user_nickname} @ {virtual_config.platform}, group_id={virtual_group_id}"
"虚拟身份模式已通过参数激活: %s @ %s, group_id=%s",
virtual_config.user_nickname,
virtual_config.platform,
virtual_group_id,
)
return virtual_config
except Exception as e:
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
except Exception as exc:
logger.warning(f"通过参数配置虚拟身份失败: {exc}")
return None
@@ -224,6 +489,17 @@ def build_session_info_message(
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> Dict[str, Any]:
"""构建会话信息消息。
Args:
session_id: 内部逻辑会话 ID
user_id: 规范化后的用户 ID
user_name: 当前昵称
virtual_config: 虚拟身份配置
Returns:
Dict[str, Any]: 会话信息消息
"""
session_info_data: Dict[str, Any] = {
"type": "session_info",
"session_id": session_id,
@@ -247,13 +523,41 @@ def build_session_info_message(
def get_active_history_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> Optional[str]:
"""获取当前虚拟身份对应的历史群组 ID。
Args:
virtual_config: 虚拟身份配置
Returns:
Optional[str]: 虚拟身份启用时返回对应群组 ID
"""
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return virtual_config.group_id
return None
def get_current_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> str:
"""获取当前会话的有效群组 ID。
Args:
virtual_config: 虚拟身份配置
Returns:
str: 当前会话应使用的群组 ID
"""
return get_active_history_group_id(virtual_config) or WEBUI_CHAT_GROUP_ID
def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> str:
"""构建欢迎消息。
Args:
virtual_config: 虚拟身份配置
Returns:
str: 欢迎消息文本
"""
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return (
@@ -264,6 +568,12 @@ def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> st
async def send_chat_error(session_id: str, content: str) -> None:
"""向指定会话发送错误消息。
Args:
session_id: 内部逻辑会话 ID
content: 错误消息内容
"""
await chat_manager.send_message(
session_id,
{
@@ -279,7 +589,17 @@ async def send_initial_chat_state(
user_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
include_welcome: bool = True,
) -> None:
"""向新会话发送初始化状态。
Args:
session_id: 内部逻辑会话 ID
user_id: 规范化后的用户 ID
user_name: 当前昵称
virtual_config: 虚拟身份配置
include_welcome: 是否发送欢迎消息
"""
await chat_manager.send_message(
session_id,
build_session_info_message(
@@ -290,30 +610,43 @@ async def send_initial_chat_state(
),
)
if history := chat_history.get_history(50, get_active_history_group_id(virtual_config)):
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": history,
},
)
history_group_id = get_active_history_group_id(virtual_config)
history = chat_history.get_history(50, history_group_id)
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": build_welcome_message(virtual_config),
"timestamp": time.time(),
"type": "history",
"messages": history,
"group_id": get_current_group_id(virtual_config),
},
)
if include_welcome:
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": build_welcome_message(virtual_config),
"timestamp": time.time(),
},
)
def resolve_sender_identity(
current_user_name: str,
normalized_user_id: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> Tuple[str, str]:
"""解析当前发送者身份。
Args:
current_user_name: 当前昵称
normalized_user_id: 规范化后的用户 ID
virtual_config: 虚拟身份配置
Returns:
Tuple[str, str]: ``(发送者昵称, 发送者用户 ID)``
"""
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id
@@ -328,6 +661,19 @@ def create_message_data(
is_at_bot: bool = True,
virtual_config: Optional[VirtualIdentityConfig] = None,
) -> Dict[str, Any]:
"""构建发送给聊天核心的消息数据。
Args:
content: 文本内容
user_id: 用户 ID
user_name: 用户昵称
message_id: 消息 ID
is_at_bot: 是否默认艾特机器人
virtual_config: 虚拟身份配置
Returns:
Dict[str, Any]: 聊天核心可处理的消息数据
"""
if message_id is None:
message_id = str(uuid.uuid4())
@@ -389,6 +735,18 @@ async def handle_chat_message(
normalized_user_id: str,
current_virtual_config: Optional[VirtualIdentityConfig],
) -> str:
"""处理用户发送的聊天消息。
Args:
session_id: 内部逻辑会话 ID
data: 前端提交的消息数据
current_user_name: 当前昵称
normalized_user_id: 规范化后的用户 ID
current_virtual_config: 当前虚拟身份配置
Returns:
str: 处理后的最新昵称
"""
content = str(data.get("content", "")).strip()
if not content:
return current_user_name
@@ -401,11 +759,14 @@ async def handle_chat_message(
normalized_user_id=normalized_user_id,
virtual_config=current_virtual_config,
)
target_group_id = get_current_group_id(current_virtual_config)
await chat_manager.broadcast(
await chat_manager.broadcast_to_group(
target_group_id,
{
"type": "user_message",
"content": content,
"group_id": target_group_id,
"message_id": message_id,
"timestamp": timestamp,
"sender": {
@@ -414,7 +775,7 @@ async def handle_chat_message(
"is_bot": False,
},
"virtual_mode": is_virtual_mode_enabled(current_virtual_config),
}
},
)
message_data = create_message_data(
@@ -427,22 +788,37 @@ async def handle_chat_message(
)
try:
await chat_manager.broadcast({"type": "typing", "is_typing": True})
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": True})
await chat_bot.message_process(message_data)
except Exception as e:
logger.error(f"处理消息时出错: {e}")
await send_chat_error(session_id, f"处理消息时出错: {str(e)}")
except Exception as exc:
logger.error(f"处理消息时出错: {exc}")
await send_chat_error(session_id, f"处理消息时出错: {str(exc)}")
finally:
await chat_manager.broadcast({"type": "typing", "is_typing": False})
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": False})
return next_user_name
async def handle_chat_ping(session_id: str) -> None:
"""处理聊天心跳。
Args:
session_id: 内部逻辑会话 ID
"""
await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()})
async def handle_nickname_update(session_id: str, data: Dict[str, Any], current_user_name: str) -> str:
"""处理昵称更新请求。
Args:
session_id: 内部逻辑会话 ID
data: 前端提交的数据
current_user_name: 当前昵称
Returns:
str: 更新后的昵称
"""
new_name = str(data.get("user_name", "")).strip()
if not new_name:
return current_user_name
@@ -463,6 +839,16 @@ async def enable_virtual_identity(
session_prefix: str,
virtual_data: Dict[str, Any],
) -> Optional[VirtualIdentityConfig]:
"""启用虚拟身份模式。
Args:
session_id: 内部逻辑会话 ID
session_prefix: 会话前缀用于生成默认群组 ID
virtual_data: 前端提交的虚拟身份配置
Returns:
Optional[VirtualIdentityConfig]: 启用成功时返回新的虚拟身份配置
"""
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id")
return None
@@ -470,16 +856,18 @@ async def enable_virtual_identity(
person_id_value = str(virtual_data.get("person_id"))
try:
person = get_person_by_person_id(person_id_value)
if not person:
if person is None:
await send_chat_error(session_id, f"找不到用户: {person_id_value}")
return None
custom_group_id = virtual_data.get("group_id")
current_group_id = (
f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
if custom_group_id
else f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
)
custom_group_id = str(virtual_data.get("group_id") or "").strip()
if custom_group_id:
current_group_id = custom_group_id
if not current_group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{current_group_id}"
else:
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
current_virtual_config = build_virtual_identity_config(
person=person,
group_id=current_group_id,
@@ -521,13 +909,18 @@ async def enable_virtual_identity(
},
)
return current_virtual_config
except Exception as e:
logger.error(f"设置虚拟身份失败: {e}")
await send_chat_error(session_id, f"设置虚拟身份失败: {str(e)}")
except Exception as exc:
logger.error(f"设置虚拟身份失败: {exc}")
await send_chat_error(session_id, f"设置虚拟身份失败: {str(exc)}")
return None
async def disable_virtual_identity(session_id: str) -> None:
"""关闭虚拟身份模式。
Args:
session_id: 内部逻辑会话 ID
"""
await chat_manager.send_message(
session_id,
{
@@ -560,7 +953,18 @@ async def handle_virtual_identity_update(
data: Dict[str, Any],
current_virtual_config: Optional[VirtualIdentityConfig],
) -> Optional[VirtualIdentityConfig]:
virtual_data = cast(dict[str, Any], data.get("config", {}))
"""处理虚拟身份切换请求。
Args:
session_id: 内部逻辑会话 ID
session_id_prefix: 会话前缀
data: 前端提交的数据
current_virtual_config: 当前虚拟身份配置
Returns:
Optional[VirtualIdentityConfig]: 更新后的虚拟身份配置
"""
virtual_data = cast(Dict[str, Any], data.get("config", {}))
if virtual_data.get("enabled"):
next_config = await enable_virtual_identity(session_id, session_id_prefix, virtual_data)
return next_config if next_config is not None else current_virtual_config
@@ -577,6 +981,19 @@ async def dispatch_chat_event(
normalized_user_id: str,
current_virtual_config: Optional[VirtualIdentityConfig],
) -> Tuple[str, Optional[VirtualIdentityConfig]]:
"""分发聊天事件到对应的处理器。
Args:
session_id: 内部逻辑会话 ID
session_id_prefix: 会话前缀
data: 前端提交的数据
current_user_name: 当前昵称
normalized_user_id: 规范化后的用户 ID
current_virtual_config: 当前虚拟身份配置
Returns:
Tuple[str, Optional[VirtualIdentityConfig]]: ``(最新昵称, 最新虚拟身份配置)``
"""
event_type = data.get("type")
if event_type == "message":
next_user_name = await handle_chat_message(

View File

@@ -6,11 +6,13 @@ from .catalog import router as catalog_router
from .config_routes import router as config_router
from .management import router as management_router
from .progress import get_progress_router, update_progress
from .runtime_routes import router as runtime_router
router = APIRouter(prefix="/plugins", tags=["插件管理"])
router.include_router(catalog_router)
router.include_router(management_router)
router.include_router(config_router)
router.include_router(runtime_router)
set_update_progress_callback(update_progress)

View File

@@ -1,13 +1,13 @@
"""插件配置相关 WebUI 路由。"""
import json
from pathlib import Path
from typing import Any, Dict, Optional, cast
import tomlkit
from fastapi import APIRouter, Cookie, HTTPException
from src.common.logger import get_logger
from src.plugin_runtime.component_query import component_query_service
from src.plugin_runtime.protocol.envelope import InspectPluginConfigResultPayload
from src.webui.utils.toml_utils import save_toml_with_format
from .schemas import UpdatePluginConfigRequest, UpdatePluginRawConfigRequest
@@ -207,6 +207,55 @@ def _build_toml_document(config_data: Dict[str, Any]) -> tomlkit.TOMLDocument:
return tomlkit.parse(tomlkit.dumps(config_data))
def _load_plugin_config_from_disk(plugin_path: Path) -> Dict[str, Any]:
"""从磁盘读取插件配置。
Args:
plugin_path: 插件目录。
Returns:
Dict[str, Any]: 当前配置字典;文件不存在时返回空字典。
"""
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
if not config_path.exists():
return {}
with open(config_path, "r", encoding="utf-8") as file_obj:
loaded_config = tomlkit.load(file_obj).unwrap()
return loaded_config if isinstance(loaded_config, dict) else {}
async def _inspect_plugin_config_via_runtime(
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
*,
use_provided_config: bool = False,
) -> InspectPluginConfigResultPayload | None:
"""通过插件运行时解析配置元数据。
Args:
plugin_id: 插件 ID。
config_data: 可选的配置内容。
use_provided_config: 是否优先使用传入配置而不是磁盘配置。
Returns:
InspectPluginConfigResultPayload | None: 运行时可用时返回解析结果,否则返回 ``None``。
Raises:
ValueError: 插件运行时明确拒绝解析请求时抛出。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
runtime_manager = get_plugin_runtime_manager()
return await runtime_manager.inspect_plugin_config(
plugin_id,
config_data,
use_provided_config=use_provided_config,
)
async def _validate_plugin_config_via_runtime(plugin_id: str, config_data: Dict[str, Any]) -> Dict[str, Any] | None:
"""通过插件运行时对配置进行校验。
@@ -244,27 +293,24 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
logger.info(f"获取插件配置 Schema: {plugin_id}")
try:
registration_schema = component_query_service.get_plugin_config_schema(plugin_id)
if isinstance(registration_schema, dict) and registration_schema:
return {"success": True, "schema": registration_schema}
plugin_path = find_plugin_path_by_id(plugin_id)
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
schema_json_path = resolve_plugin_file_path(plugin_path, "config_schema.json")
if schema_json_path.exists():
try:
with open(schema_json_path, "r", encoding="utf-8") as file_obj:
return {"success": True, "schema": json.load(file_obj)}
except Exception as e:
logger.warning(f"读取 config_schema.json 失败,回退到自动推断: {e}")
try:
runtime_snapshot = await _inspect_plugin_config_via_runtime(plugin_id)
except ValueError as exc:
logger.warning(f"插件 {plugin_id} 配置 Schema 解析失败,将回退到弱推断: {exc}")
runtime_snapshot = None
current_config: Any = component_query_service.get_plugin_default_config(plugin_id) or {}
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as file_obj:
current_config = tomlkit.load(file_obj)
if runtime_snapshot is not None and runtime_snapshot.config_schema:
return {"success": True, "schema": dict(runtime_snapshot.config_schema)}
current_config: Any = (
dict(runtime_snapshot.normalized_config)
if runtime_snapshot is not None
else _load_plugin_config_from_disk(plugin_path)
)
return {"success": True, "schema": _build_schema_from_current_config(plugin_id, current_config)}
except HTTPException:
@@ -375,15 +421,24 @@ async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cook
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
try:
runtime_snapshot = await _inspect_plugin_config_via_runtime(plugin_id)
except ValueError as exc:
logger.warning(f"插件 {plugin_id} 配置读取失败,将回退到磁盘内容: {exc}")
runtime_snapshot = None
if runtime_snapshot is not None:
message = "配置文件不存在,已返回默认配置" if not config_path.exists() else ""
return {
"success": True,
"config": dict(runtime_snapshot.normalized_config),
"message": message,
}
if not config_path.exists():
default_config = component_query_service.get_plugin_default_config(plugin_id)
if isinstance(default_config, dict):
return {"success": True, "config": default_config, "message": "配置文件不存在,已返回默认配置"}
return {"success": True, "config": {}, "message": "配置文件不存在"}
with open(config_path, "r", encoding="utf-8") as file_obj:
config = tomlkit.load(file_obj)
return {"success": True, "config": dict(config)}
return {"success": True, "config": _load_plugin_config_from_disk(plugin_path)}
except HTTPException:
raise
except Exception as e:
@@ -412,6 +467,10 @@ async def update_plugin_config(
logger.info(f"更新插件配置: {plugin_id}")
try:
plugin_path = find_plugin_path_by_id(plugin_id)
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_data = request.config or {}
if isinstance(config_data, dict):
config_data = normalize_dotted_keys(config_data)
@@ -419,12 +478,13 @@ async def update_plugin_config(
if isinstance(runtime_validated_config, dict):
config_data = runtime_validated_config
else:
plugin_schema = component_query_service.get_plugin_config_schema(plugin_id)
if isinstance(plugin_schema, dict) and plugin_schema:
_coerce_config_by_plugin_schema(plugin_schema, config_data)
plugin_path = find_plugin_path_by_id(plugin_id)
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
runtime_snapshot = await _inspect_plugin_config_via_runtime(
plugin_id,
config_data,
use_provided_config=True,
)
if runtime_snapshot is not None and runtime_snapshot.config_schema:
_coerce_config_by_plugin_schema(dict(runtime_snapshot.config_schema), config_data)
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
backup_path = backup_file(config_path, "backup")
@@ -498,17 +558,29 @@ async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(N
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
default_config = component_query_service.get_plugin_default_config(plugin_id)
config = _build_toml_document(default_config if isinstance(default_config, dict) else {})
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as file_obj:
config = tomlkit.load(file_obj)
try:
runtime_snapshot = await _inspect_plugin_config_via_runtime(plugin_id)
except ValueError as exc:
logger.warning(f"插件 {plugin_id} 状态切换前配置解析失败,将回退到磁盘内容: {exc}")
runtime_snapshot = None
if "plugin" not in config:
current_config = (
dict(runtime_snapshot.normalized_config)
if runtime_snapshot is not None
else _load_plugin_config_from_disk(plugin_path)
)
config = _build_toml_document(current_config)
plugin_section = config.get("plugin")
if plugin_section is None or not hasattr(plugin_section, "get"):
config["plugin"] = tomlkit.table()
plugin_config = cast(Any, config["plugin"])
current_enabled = bool(plugin_config.get("enabled", True))
current_enabled = (
bool(runtime_snapshot.enabled)
if runtime_snapshot is not None
else bool(plugin_config.get("enabled", True))
)
new_enabled = not current_enabled
plugin_config["enabled"] = new_enabled
save_toml_with_format(config, str(config_path))
@@ -519,7 +591,7 @@ async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(N
"success": True,
"enabled": new_enabled,
"message": f"插件已{status}",
"note": "状态更改将在下次加载插件时生效",
"note": "状态更改将自动热更新到对应插件",
}
except HTTPException:
raise

View File

@@ -1,12 +1,15 @@
"""插件进度实时推送支持。"""
from typing import Any, Dict, Optional, Set
import asyncio
import json
from typing import Any, Dict, Optional, Set
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
from src.webui.routers.websocket.manager import websocket_manager
logger = get_logger("webui.plugin_progress")
@@ -25,25 +28,29 @@ current_progress: Dict[str, Any] = {
}
def get_current_progress() -> Dict[str, Any]:
"""获取当前插件进度快照。
Returns:
Dict[str, Any]: 当前插件进度数据副本。
"""
return current_progress.copy()
async def broadcast_progress(progress_data: Dict[str, Any]) -> None:
"""向统一连接层广播插件进度更新。
Args:
progress_data: 插件进度数据。
"""
global current_progress
current_progress = progress_data.copy()
if not active_connections:
return
message = json.dumps(progress_data, ensure_ascii=False)
disconnected: Set[WebSocket] = set()
for websocket in active_connections:
try:
await websocket.send_text(message)
except Exception as e:
logger.error(f"发送进度更新失败: {e}")
disconnected.add(websocket)
for websocket in disconnected:
active_connections.discard(websocket)
await websocket_manager.broadcast_to_topic(
domain="plugin_progress",
topic="main",
event="update",
data={"progress": progress_data},
)
async def update_progress(
@@ -56,6 +63,18 @@ async def update_progress(
total_plugins: int = 0,
loaded_plugins: int = 0,
) -> None:
"""更新当前插件进度并广播。
Args:
stage: 当前阶段。
progress: 当前进度百分比。
message: 进度说明消息。
operation: 当前操作类型。
error: 可选的错误信息。
plugin_id: 当前处理的插件 ID。
total_plugins: 总插件数量。
loaded_plugins: 已处理插件数量。
"""
progress_data = {
"operation": operation,
"stage": stage,
@@ -74,6 +93,12 @@ async def update_progress(
@router.websocket("/ws/plugin-progress")
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
"""旧版插件进度 WebSocket 入口。
Args:
websocket: FastAPI WebSocket 对象。
token: 可选的一次性握手 Token。
"""
is_authenticated = False
if token and verify_ws_token(token):
@@ -105,17 +130,22 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
data = await websocket.receive_text()
if data == "ping":
await websocket.send_text("pong")
except Exception as e:
logger.error(f"处理客户端消息时出错: {e}")
except Exception as exc:
logger.error(f"处理客户端消息时出错: {exc}")
break
except WebSocketDisconnect:
active_connections.discard(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
except Exception as e:
logger.error(f"❌ WebSocket 错误: {e}")
except Exception as exc:
logger.error(f"❌ WebSocket 错误: {exc}")
active_connections.discard(websocket)
def get_progress_router() -> APIRouter:
"""获取旧版插件进度路由对象。
Returns:
APIRouter: 插件进度路由对象。
"""
return router

View File

@@ -0,0 +1,28 @@
"""插件运行时相关 WebUI 路由。"""
from typing import Optional
from fastapi import APIRouter, Cookie
from src.plugin_runtime.component_query import component_query_service
from .schemas import HookSpecListResponse, HookSpecResponse
from .support import require_plugin_token
router = APIRouter()
@router.get("/runtime/hooks", response_model=HookSpecListResponse)
async def list_runtime_hook_specs(maibot_session: Optional[str] = Cookie(None)) -> HookSpecListResponse:
"""返回当前插件运行时公开的 Hook 规格清单。
Args:
maibot_session: 当前 WebUI 会话令牌。
Returns:
HookSpecListResponse: Hook 规格列表响应。
"""
require_plugin_token(maibot_session)
hooks = [HookSpecResponse(**hook_data) for hook_data in component_query_service.list_hook_specs()]
return HookSpecListResponse(success=True, hooks=hooks)

View File

@@ -111,3 +111,19 @@ class UpdatePluginConfigRequest(BaseModel):
class UpdatePluginRawConfigRequest(BaseModel):
config: str = Field(..., description="原始 TOML 配置内容")
class HookSpecResponse(BaseModel):
name: str = Field(..., description="Hook 名称")
description: str = Field("", description="Hook 描述")
parameters_schema: Dict[str, Any] = Field(default_factory=dict, description="Hook 参数模型")
default_timeout_ms: int = Field(..., description="默认超时毫秒数")
allow_blocking: bool = Field(..., description="是否允许 blocking 处理器")
allow_observe: bool = Field(..., description="是否允许 observe 处理器")
allow_abort: bool = Field(..., description="是否允许 abort")
allow_kwargs_mutation: bool = Field(..., description="是否允许修改 kwargs")
class HookSpecListResponse(BaseModel):
success: bool = Field(..., description="是否成功")
hooks: List[HookSpecResponse] = Field(default_factory=list, description="Hook 规格列表")

View File

@@ -1,7 +1,9 @@
"""WebSocket 路由聚合导出。"""
from .auth import router as ws_auth_router
from .logs import router as logs_router
from .unified import router as unified_ws_router
__all__ = [
"logs_router",
"unified_ws_router",
"ws_auth_router",
]

View File

@@ -1,11 +0,0 @@
"""WebSocket 日志推送路由兼容导出。"""
from src.webui.logs_ws import active_connections, broadcast_log, load_recent_logs, router, websocket_logs
__all__ = [
"active_connections",
"broadcast_log",
"load_recent_logs",
"router",
"websocket_logs",
]

View File

@@ -0,0 +1,297 @@
"""统一 WebSocket 连接管理器。"""
import asyncio
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Set
from fastapi import WebSocket
from src.common.logger import get_logger
logger = get_logger("webui.websocket")
@dataclass
class WebSocketConnection:
"""统一 WebSocket 连接上下文。"""
connection_id: str
websocket: WebSocket
subscriptions: Set[str] = field(default_factory=set)
chat_sessions: Dict[str, str] = field(default_factory=dict)
send_queue: "asyncio.Queue[Optional[Dict[str, Any]]]" = field(default_factory=asyncio.Queue)
sender_task: Optional["asyncio.Task[None]"] = None
class UnifiedWebSocketManager:
"""统一 WebSocket 连接管理器。"""
def __init__(self) -> None:
"""初始化统一 WebSocket 连接管理器。"""
self.connections: Dict[str, WebSocketConnection] = {}
def _build_subscription_key(self, domain: str, topic: str) -> str:
"""构建订阅索引键。
Args:
domain: 业务域名称。
topic: 主题名称。
Returns:
str: 订阅索引键。
"""
return f"{domain}:{topic}"
async def _sender_loop(self, connection: WebSocketConnection) -> None:
"""串行发送指定连接的出站消息。
Args:
connection: 目标连接上下文。
"""
try:
while True:
message = await connection.send_queue.get()
if message is None:
return
await connection.websocket.send_json(message)
except asyncio.CancelledError:
raise
except Exception as exc:
logger.error("统一 WebSocket 发送失败: connection=%s, error=%s", connection.connection_id, exc)
async def connect(self, connection_id: str, websocket: WebSocket) -> WebSocketConnection:
"""注册一个新的物理 WebSocket 连接。
Args:
connection_id: 连接 ID。
websocket: FastAPI WebSocket 对象。
Returns:
WebSocketConnection: 新建的连接上下文。
"""
await websocket.accept()
connection = WebSocketConnection(connection_id=connection_id, websocket=websocket)
connection.sender_task = asyncio.create_task(self._sender_loop(connection))
self.connections[connection_id] = connection
return connection
async def disconnect(self, connection_id: str) -> None:
"""断开并清理指定连接。
Args:
connection_id: 连接 ID。
"""
connection = self.connections.pop(connection_id, None)
if connection is None:
return
await connection.send_queue.put(None)
if connection.sender_task is not None:
try:
await connection.sender_task
except asyncio.CancelledError:
pass
except Exception as exc:
logger.debug("等待发送协程退出时出现异常: connection=%s, error=%s", connection_id, exc)
def get_connection(self, connection_id: str) -> Optional[WebSocketConnection]:
"""获取指定连接上下文。
Args:
connection_id: 连接 ID。
Returns:
Optional[WebSocketConnection]: 找到时返回连接上下文。
"""
return self.connections.get(connection_id)
def register_chat_session(self, connection_id: str, client_session_id: str, session_id: str) -> None:
"""登记连接下的逻辑聊天会话。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
session_id: 内部会话 ID。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.chat_sessions[client_session_id] = session_id
def unregister_chat_session(self, connection_id: str, client_session_id: str) -> None:
"""移除连接下的逻辑聊天会话登记。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.chat_sessions.pop(client_session_id, None)
def get_chat_session_id(self, connection_id: str, client_session_id: str) -> Optional[str]:
"""查询连接下的内部聊天会话 ID。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
Returns:
Optional[str]: 找到时返回内部会话 ID。
"""
connection = self.connections.get(connection_id)
if connection is None:
return None
return connection.chat_sessions.get(client_session_id)
def subscribe(self, connection_id: str, domain: str, topic: str) -> None:
"""登记连接的主题订阅。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
topic: 主题名称。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.subscriptions.add(self._build_subscription_key(domain, topic))
def unsubscribe(self, connection_id: str, domain: str, topic: str) -> None:
"""移除连接的主题订阅。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
topic: 主题名称。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.subscriptions.discard(self._build_subscription_key(domain, topic))
def is_subscribed(self, connection_id: str, domain: str, topic: str) -> bool:
"""判断连接是否订阅了指定主题。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
topic: 主题名称。
Returns:
bool: 已订阅时返回 ``True``。
"""
connection = self.connections.get(connection_id)
if connection is None:
return False
return self._build_subscription_key(domain, topic) in connection.subscriptions
async def enqueue(self, connection_id: str, message: Dict[str, Any]) -> None:
"""向指定连接的发送队列压入消息。
Args:
connection_id: 连接 ID。
message: 待发送的消息。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
await connection.send_queue.put(message)
async def send_response(
self,
connection_id: str,
request_id: Optional[str],
ok: bool,
data: Optional[Dict[str, Any]] = None,
error: Optional[Dict[str, Any]] = None,
) -> None:
"""发送统一响应消息。
Args:
connection_id: 连接 ID。
request_id: 请求 ID。
ok: 请求是否成功。
data: 成功响应数据。
error: 失败响应数据。
"""
response_message: Dict[str, Any] = {
"op": "response",
"id": request_id,
"ok": ok,
}
if data is not None:
response_message["data"] = data
if error is not None:
response_message["error"] = error
await self.enqueue(connection_id, response_message)
async def send_event(
self,
connection_id: str,
domain: str,
event: str,
data: Dict[str, Any],
session: Optional[str] = None,
topic: Optional[str] = None,
) -> None:
"""发送统一事件消息。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
event: 事件名称。
data: 事件数据。
session: 可选的逻辑会话 ID。
topic: 可选的主题名称。
"""
event_message: Dict[str, Any] = {
"op": "event",
"domain": domain,
"event": event,
"data": data,
}
if session is not None:
event_message["session"] = session
if topic is not None:
event_message["topic"] = topic
await self.enqueue(connection_id, event_message)
async def send_pong(self, connection_id: str, timestamp: float) -> None:
"""发送心跳响应。
Args:
connection_id: 连接 ID。
timestamp: 当前时间戳。
"""
await self.enqueue(
connection_id,
{
"op": "pong",
"ts": timestamp,
},
)
async def broadcast_to_topic(self, domain: str, topic: str, event: str, data: Dict[str, Any]) -> None:
"""向订阅指定主题的全部连接广播事件。
Args:
domain: 业务域名称。
topic: 主题名称。
event: 事件名称。
data: 事件数据。
"""
subscription_key = self._build_subscription_key(domain, topic)
for connection in list(self.connections.values()):
if subscription_key in connection.subscriptions:
await self.send_event(
connection.connection_id,
domain=domain,
event=event,
data=data,
topic=topic,
)
websocket_manager = UnifiedWebSocketManager()

View File

@@ -0,0 +1,548 @@
"""统一 WebSocket 路由。"""
from typing import Any, Dict, Optional, Set, cast
import asyncio
import time
import uuid
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.logs_ws import load_recent_logs
from src.webui.routers.chat.service import (
chat_manager,
dispatch_chat_event,
normalize_webui_user_id,
resolve_initial_virtual_identity,
send_initial_chat_state,
)
from src.webui.routers.plugin.progress import get_current_progress
from src.webui.routers.websocket.auth import verify_ws_token
from src.webui.routers.websocket.manager import websocket_manager
logger = get_logger("webui.unified_ws")
router = APIRouter()
_background_tasks: Set["asyncio.Task[None]"] = set()
def _build_error(code: str, message: str) -> Dict[str, Any]:
"""构建统一错误响应体。
Args:
code: 错误码。
message: 错误描述。
Returns:
Dict[str, Any]: 统一错误对象。
"""
return {
"code": code,
"message": message,
}
def _get_request_data(message: Dict[str, Any]) -> Dict[str, Any]:
"""从客户端消息中提取数据字段。
Args:
message: 客户端消息。
Returns:
Dict[str, Any]: 标准化后的数据字典。
"""
data = message.get("data", {})
if isinstance(data, dict):
return cast(Dict[str, Any], data)
return {}
def _track_background_task(task: "asyncio.Task[None]") -> None:
"""登记后台任务并在完成后自动清理。
Args:
task: 后台协程任务。
"""
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
async def authenticate_websocket_connection(websocket: WebSocket, token: Optional[str]) -> bool:
"""校验统一 WebSocket 连接的认证状态。
Args:
websocket: FastAPI WebSocket 对象。
token: 可选的一次性握手 Token。
Returns:
bool: 认证通过时返回 ``True``。
"""
if token and verify_ws_token(token):
logger.debug("统一 WebSocket 使用临时 token 认证成功")
return True
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
logger.debug("统一 WebSocket 使用 Cookie 认证成功")
return True
return False
async def _handle_logs_subscribe(connection_id: str, request_id: Optional[str], data: Dict[str, Any]) -> None:
"""处理日志域订阅请求。
Args:
connection_id: 连接 ID。
request_id: 请求 ID。
data: 订阅参数。
"""
replay_limit = int(data.get("replay", 100) or 100)
replay_limit = max(0, min(replay_limit, 500))
websocket_manager.subscribe(connection_id, domain="logs", topic="main")
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": "logs", "topic": "main"},
)
await websocket_manager.send_event(
connection_id,
domain="logs",
event="snapshot",
topic="main",
data={"entries": load_recent_logs(limit=replay_limit)},
)
async def _handle_plugin_progress_subscribe(connection_id: str, request_id: Optional[str]) -> None:
"""处理插件进度域订阅请求。
Args:
connection_id: 连接 ID。
request_id: 请求 ID。
"""
websocket_manager.subscribe(connection_id, domain="plugin_progress", topic="main")
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": "plugin_progress", "topic": "main"},
)
await websocket_manager.send_event(
connection_id,
domain="plugin_progress",
event="snapshot",
topic="main",
data={"progress": get_current_progress()},
)
async def _handle_subscribe(connection_id: str, message: Dict[str, Any]) -> None:
"""处理主题订阅请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
topic = str(message.get("topic") or "").strip()
data = _get_request_data(message)
if domain == "logs" and topic == "main":
await _handle_logs_subscribe(connection_id, request_id, data)
return
if domain == "plugin_progress" and topic == "main":
await _handle_plugin_progress_subscribe(connection_id, request_id)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_subscription", f"不支持的订阅目标: {domain}:{topic}"),
)
async def _handle_unsubscribe(connection_id: str, message: Dict[str, Any]) -> None:
"""处理主题退订请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
topic = str(message.get("topic") or "").strip()
if not domain or not topic:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("invalid_unsubscribe", "退订请求缺少 domain 或 topic"),
)
return
websocket_manager.unsubscribe(connection_id, domain=domain, topic=topic)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": domain, "topic": topic},
)
async def _open_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
"""打开一个逻辑聊天会话。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
if not client_session_id:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("missing_session", "聊天会话打开请求缺少 session"),
)
return
data = _get_request_data(message)
normalized_user_id = normalize_webui_user_id(cast(Optional[str], data.get("user_id")))
current_user_name = str(data.get("user_name") or "WebUI用户")
current_virtual_config = resolve_initial_virtual_identity(
platform=cast(Optional[str], data.get("platform")),
person_id=cast(Optional[str], data.get("person_id")),
group_name=cast(Optional[str], data.get("group_name")),
group_id=cast(Optional[str], data.get("group_id")),
)
restore = bool(data.get("restore"))
session_id = f"{connection_id}:{client_session_id}"
async def send_chat_event(chat_message: Dict[str, Any]) -> None:
"""将聊天消息封装为统一事件并发送。
Args:
chat_message: 聊天消息体。
"""
event_name = str(chat_message.get("type") or "message")
await websocket_manager.send_event(
connection_id,
domain="chat",
event=event_name,
session=client_session_id,
data=chat_message,
)
await chat_manager.connect(
session_id=session_id,
connection_id=connection_id,
client_session_id=client_session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
sender=send_chat_event,
)
websocket_manager.register_chat_session(connection_id, client_session_id, session_id)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id, "session_id": session_id},
)
await send_initial_chat_state(
session_id=session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
include_welcome=not restore,
)
async def _close_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
"""关闭一个逻辑聊天会话。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
chat_manager.disconnect(session_id)
websocket_manager.unregister_chat_session(connection_id, client_session_id)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id},
)
async def _process_chat_message(connection_id: str, client_session_id: str, data: Dict[str, Any]) -> None:
"""在后台处理聊天消息事件。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
data: 客户端提交的消息数据。
"""
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
return
session_state = chat_manager.get_session(session_id)
if session_state is None:
return
next_user_name, next_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data=data,
current_user_name=session_state.user_name,
normalized_user_id=session_state.user_id,
current_virtual_config=session_state.virtual_config,
)
chat_manager.update_session_context(
session_id=session_id,
user_name=next_user_name,
virtual_config=next_virtual_config,
)
async def _handle_chat_message_send(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天消息发送请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
data = _get_request_data(message)
payload = {
"type": "message",
"content": data.get("content", ""),
"user_name": data.get("user_name", ""),
}
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"accepted": True, "session": client_session_id},
)
_track_background_task(asyncio.create_task(_process_chat_message(connection_id, client_session_id, payload)))
async def _handle_chat_nickname_update(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天昵称更新请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
data = _get_request_data(message)
session_state = chat_manager.get_session(session_id)
if session_state is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
next_user_name, next_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data={
"type": "update_nickname",
"user_name": data.get("user_name", ""),
},
current_user_name=session_state.user_name,
normalized_user_id=session_state.user_id,
current_virtual_config=session_state.virtual_config,
)
chat_manager.update_session_context(
session_id=session_id,
user_name=next_user_name,
virtual_config=next_virtual_config,
)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id, "user_name": next_user_name},
)
async def _handle_chat_call(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天域调用请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
method = str(message.get("method") or "").strip()
if method == "session.open":
await _open_chat_session(connection_id, message)
return
if method == "session.close":
await _close_chat_session(connection_id, message)
return
if method == "message.send":
await _handle_chat_message_send(connection_id, message)
return
if method == "session.update_nickname":
await _handle_chat_nickname_update(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_method", f"不支持的聊天方法: {method}"),
)
async def _handle_call(connection_id: str, message: Dict[str, Any]) -> None:
"""处理统一调用请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
if domain == "chat":
await _handle_chat_call(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_domain", f"不支持的调用域: {domain}"),
)
async def handle_client_message(connection_id: str, message: Dict[str, Any]) -> None:
"""处理统一 WebSocket 客户端消息。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
operation = str(message.get("op") or "").strip()
request_id = cast(Optional[str], message.get("id"))
if operation == "ping":
await websocket_manager.send_pong(connection_id, time.time())
return
if operation == "subscribe":
await _handle_subscribe(connection_id, message)
return
if operation == "unsubscribe":
await _handle_unsubscribe(connection_id, message)
return
if operation == "call":
await _handle_call(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_operation", f"不支持的操作: {operation}"),
)
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
"""统一 WebSocket 入口。
Args:
websocket: FastAPI WebSocket 对象。
token: 可选的一次性握手 Token。
"""
if not await authenticate_websocket_connection(websocket, token):
logger.warning("统一 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
connection_id = uuid.uuid4().hex
await websocket_manager.connect(connection_id=connection_id, websocket=websocket)
logger.info("统一 WebSocket 客户端已连接: connection=%s", connection_id)
await websocket_manager.send_event(
connection_id,
domain="system",
event="ready",
data={"connection_id": connection_id, "timestamp": time.time()},
)
try:
while True:
raw_message = await websocket.receive_json()
if not isinstance(raw_message, dict):
await websocket_manager.send_response(
connection_id,
request_id=None,
ok=False,
error=_build_error("invalid_message", "消息必须是 JSON 对象"),
)
continue
await handle_client_message(connection_id, cast(Dict[str, Any], raw_message))
except WebSocketDisconnect:
logger.info("统一 WebSocket 客户端已断开: connection=%s", connection_id)
except Exception as exc:
logger.error(f"统一 WebSocket 处理失败: {exc}")
finally:
chat_manager.disconnect_connection(connection_id)
await websocket_manager.disconnect(connection_id)

View File

@@ -18,11 +18,11 @@ from src.webui.routers.expression import router as expression_router
from src.webui.routers.jargon import router as jargon_router
from src.webui.routers.model import router as model_router
from src.webui.routers.person import router as person_router
from src.webui.routers.plugin import get_progress_router
from src.webui.routers.plugin import router as plugin_router
from src.webui.routers.statistics import router as statistics_router
from src.webui.routers.system import router as system_router
from src.webui.routers.websocket.auth import router as ws_auth_router
from src.webui.routers.websocket.unified import router as unified_ws_router
logger = get_logger("webui.api")
@@ -43,14 +43,14 @@ router.include_router(jargon_router)
router.include_router(emoji_router)
# 注册插件管理路由
router.include_router(plugin_router)
# 注册插件进度 WebSocket 路由
router.include_router(get_progress_router())
# 注册系统控制路由
router.include_router(system_router)
# 注册模型列表获取路由
router.include_router(model_router)
# 注册 WebSocket 认证路由
router.include_router(ws_auth_router)
# 注册统一 WebSocket 路由
router.include_router(unified_ws_router)
class TokenVerifyRequest(BaseModel):