Merge branch 'r-dev' of https://github.com/Mai-with-u/MaiBot into r-dev
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,5 +1,6 @@
|
||||
data/
|
||||
data1/
|
||||
mai_knowledge/knowledge.json
|
||||
mongodb/
|
||||
NapCat.Framework.Windows.Once/
|
||||
NapCat.Framework.Windows.OneKey/
|
||||
|
||||
161
dashboard/src/lib/chat-ws-client.ts
Normal file
161
dashboard/src/lib/chat-ws-client.ts
Normal file
@@ -0,0 +1,161 @@
|
||||
import { unifiedWsClient, type ConnectionStatus } from './unified-ws'
|
||||
|
||||
interface ChatSessionOpenPayload {
|
||||
group_id?: string
|
||||
group_name?: string
|
||||
person_id?: string
|
||||
platform?: string
|
||||
user_id?: string
|
||||
user_name?: string
|
||||
}
|
||||
|
||||
type ChatSessionListener = (message: Record<string, unknown>) => void
|
||||
|
||||
class ChatWsClient {
|
||||
private initialized = false
|
||||
private listeners: Map<string, Set<ChatSessionListener>> = new Map()
|
||||
private sessionPayloads: Map<string, ChatSessionOpenPayload> = new Map()
|
||||
|
||||
private initialize(): void {
|
||||
if (this.initialized) {
|
||||
return
|
||||
}
|
||||
|
||||
unifiedWsClient.addEventListener((message) => {
|
||||
if (message.domain !== 'chat' || !message.session) {
|
||||
return
|
||||
}
|
||||
|
||||
const sessionListeners = this.listeners.get(message.session)
|
||||
if (!sessionListeners) {
|
||||
return
|
||||
}
|
||||
|
||||
sessionListeners.forEach((listener) => {
|
||||
try {
|
||||
listener(message.data)
|
||||
} catch (error) {
|
||||
console.error('聊天会话监听器执行失败:', error)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
unifiedWsClient.onReconnect(() => {
|
||||
void this.reopenSessions()
|
||||
})
|
||||
|
||||
this.initialized = true
|
||||
}
|
||||
|
||||
private async reopenSessions(): Promise<void> {
|
||||
const reopenTargets = Array.from(this.sessionPayloads.entries())
|
||||
for (const [sessionId, payload] of reopenTargets) {
|
||||
try {
|
||||
await unifiedWsClient.call({
|
||||
domain: 'chat',
|
||||
method: 'session.open',
|
||||
session: sessionId,
|
||||
data: {
|
||||
...payload,
|
||||
restore: true,
|
||||
} as Record<string, unknown>,
|
||||
})
|
||||
} catch (error) {
|
||||
console.error(`恢复聊天会话失败 (${sessionId}):`, error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async openSession(sessionId: string, payload: ChatSessionOpenPayload): Promise<void> {
|
||||
this.initialize()
|
||||
this.sessionPayloads.set(sessionId, payload)
|
||||
await unifiedWsClient.call({
|
||||
domain: 'chat',
|
||||
method: 'session.open',
|
||||
session: sessionId,
|
||||
data: payload as Record<string, unknown>,
|
||||
})
|
||||
}
|
||||
|
||||
async closeSession(sessionId: string): Promise<void> {
|
||||
this.sessionPayloads.delete(sessionId)
|
||||
if (unifiedWsClient.getStatus() !== 'connected') {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
await unifiedWsClient.call({
|
||||
domain: 'chat',
|
||||
method: 'session.close',
|
||||
session: sessionId,
|
||||
data: {},
|
||||
})
|
||||
} catch (error) {
|
||||
console.warn(`关闭聊天会话失败 (${sessionId}):`, error)
|
||||
}
|
||||
}
|
||||
|
||||
async sendMessage(sessionId: string, content: string, userName: string): Promise<void> {
|
||||
await unifiedWsClient.call({
|
||||
domain: 'chat',
|
||||
method: 'message.send',
|
||||
session: sessionId,
|
||||
data: {
|
||||
content,
|
||||
user_name: userName,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
async updateNickname(sessionId: string, userName: string): Promise<void> {
|
||||
const currentPayload = this.sessionPayloads.get(sessionId)
|
||||
if (currentPayload) {
|
||||
this.sessionPayloads.set(sessionId, {
|
||||
...currentPayload,
|
||||
user_name: userName,
|
||||
})
|
||||
}
|
||||
|
||||
await unifiedWsClient.call({
|
||||
domain: 'chat',
|
||||
method: 'session.update_nickname',
|
||||
session: sessionId,
|
||||
data: {
|
||||
user_name: userName,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
onSessionMessage(sessionId: string, listener: ChatSessionListener): () => void {
|
||||
this.initialize()
|
||||
const sessionListeners = this.listeners.get(sessionId) ?? new Set<ChatSessionListener>()
|
||||
sessionListeners.add(listener)
|
||||
this.listeners.set(sessionId, sessionListeners)
|
||||
|
||||
return () => {
|
||||
const currentListeners = this.listeners.get(sessionId)
|
||||
if (!currentListeners) {
|
||||
return
|
||||
}
|
||||
|
||||
currentListeners.delete(listener)
|
||||
if (currentListeners.size === 0) {
|
||||
this.listeners.delete(sessionId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
onConnectionChange(listener: (connected: boolean) => void): () => void {
|
||||
return unifiedWsClient.onConnectionChange(listener)
|
||||
}
|
||||
|
||||
onStatusChange(listener: (status: ConnectionStatus) => void): () => void {
|
||||
return unifiedWsClient.onStatusChange(listener)
|
||||
}
|
||||
|
||||
async restart(): Promise<void> {
|
||||
await unifiedWsClient.restart()
|
||||
}
|
||||
}
|
||||
|
||||
export const chatWsClient = new ChatWsClient()
|
||||
@@ -1,13 +1,11 @@
|
||||
/**
|
||||
* 全局日志 WebSocket 管理器
|
||||
* 确保整个应用只有一个 WebSocket 连接
|
||||
* 确保整个应用只通过统一连接层订阅日志流
|
||||
*/
|
||||
|
||||
import { checkAuthStatus } from './fetch-with-auth'
|
||||
import { getSetting } from './settings-manager'
|
||||
import { createReconnectingWebSocket } from './ws-utils'
|
||||
|
||||
import { getWsBaseUrl } from '@/lib/api-base'
|
||||
import { unifiedWsClient } from './unified-ws'
|
||||
|
||||
export interface LogEntry {
|
||||
id: string
|
||||
@@ -17,165 +15,79 @@ export interface LogEntry {
|
||||
message: string
|
||||
}
|
||||
|
||||
type LogCallback = (log: LogEntry) => void
|
||||
type LogCallback = () => void
|
||||
type ConnectionCallback = (connected: boolean) => void
|
||||
|
||||
class LogWebSocketManager {
|
||||
private wsControl: ReturnType<typeof createReconnectingWebSocket> | null = null
|
||||
|
||||
// 订阅者
|
||||
private logCallbacks: Set<LogCallback> = new Set()
|
||||
private connectionCallbacks: Set<ConnectionCallback> = new Set()
|
||||
|
||||
private initialized = false
|
||||
private isConnected = false
|
||||
|
||||
// 日志缓存 - 保存所有接收到的日志
|
||||
private logCache: LogEntry[] = []
|
||||
private logCallbacks: Set<LogCallback> = new Set()
|
||||
private subscriptionActive = false
|
||||
|
||||
/**
|
||||
* 获取最大缓存大小(从设置读取)
|
||||
*/
|
||||
private getMaxCacheSize(): number {
|
||||
return getSetting('logCacheSize')
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取最大重连次数(从设置读取)
|
||||
*/
|
||||
private getMaxReconnectAttempts(): number {
|
||||
return getSetting('wsMaxReconnectAttempts')
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取重连间隔(从设置读取)
|
||||
*/
|
||||
private getReconnectInterval(): number {
|
||||
return getSetting('wsReconnectInterval')
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 WebSocket URL(不含 token 参数)
|
||||
*/
|
||||
private async getWebSocketUrl(): Promise<string> {
|
||||
const wsBase = await getWsBaseUrl()
|
||||
return `${wsBase}/ws/logs`
|
||||
}
|
||||
|
||||
/**
|
||||
* 连接 WebSocket(会先检查登录状态)
|
||||
*/
|
||||
async connect() {
|
||||
// 检查是否在登录页面
|
||||
if (window.location.pathname === '/auth') {
|
||||
console.log('📡 在登录页面,跳过 WebSocket 连接')
|
||||
private initialize(): void {
|
||||
if (this.initialized) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查登录状态,避免未登录时尝试连接
|
||||
const isAuthenticated = await checkAuthStatus()
|
||||
if (!isAuthenticated) {
|
||||
console.log('📡 未登录,跳过 WebSocket 连接')
|
||||
return
|
||||
}
|
||||
unifiedWsClient.addEventListener((message) => {
|
||||
if (message.domain !== 'logs') {
|
||||
return
|
||||
}
|
||||
|
||||
const wsUrl = await this.getWebSocketUrl()
|
||||
if (message.event === 'snapshot') {
|
||||
const entries = Array.isArray(message.data.entries)
|
||||
? (message.data.entries as LogEntry[])
|
||||
: []
|
||||
this.logCache = entries.slice(-this.getMaxCacheSize())
|
||||
this.notifyLogChange()
|
||||
return
|
||||
}
|
||||
|
||||
// 使用 ws-utils 创建 WebSocket
|
||||
this.wsControl = createReconnectingWebSocket(wsUrl, {
|
||||
onMessage: (data: string) => {
|
||||
try {
|
||||
const log: LogEntry = JSON.parse(data)
|
||||
this.notifyLog(log)
|
||||
} catch (error) {
|
||||
console.error('解析日志消息失败:', error)
|
||||
}
|
||||
},
|
||||
onOpen: () => {
|
||||
this.isConnected = true
|
||||
this.notifyConnection(true)
|
||||
},
|
||||
onClose: () => {
|
||||
this.isConnected = false
|
||||
this.notifyConnection(false)
|
||||
},
|
||||
onError: (error) => {
|
||||
console.error('❌ WebSocket 错误:', error)
|
||||
this.isConnected = false
|
||||
this.notifyConnection(false)
|
||||
},
|
||||
heartbeatInterval: 30000,
|
||||
maxRetries: this.getMaxReconnectAttempts(),
|
||||
backoffBase: this.getReconnectInterval(),
|
||||
maxBackoff: 30000,
|
||||
if (message.event === 'entry' && message.data.entry) {
|
||||
this.appendLog(message.data.entry as LogEntry)
|
||||
}
|
||||
})
|
||||
|
||||
// 启动连接
|
||||
await this.wsControl.connect()
|
||||
unifiedWsClient.onConnectionChange((connected) => {
|
||||
this.isConnected = connected
|
||||
this.notifyConnection(connected)
|
||||
})
|
||||
|
||||
this.initialized = true
|
||||
}
|
||||
|
||||
/**
|
||||
* 断开连接
|
||||
*/
|
||||
disconnect() {
|
||||
if (this.wsControl) {
|
||||
this.wsControl.disconnect()
|
||||
this.wsControl = null
|
||||
}
|
||||
|
||||
this.isConnected = false
|
||||
}
|
||||
|
||||
/**
|
||||
* 订阅日志消息
|
||||
*/
|
||||
onLog(callback: LogCallback) {
|
||||
this.logCallbacks.add(callback)
|
||||
return () => this.logCallbacks.delete(callback)
|
||||
}
|
||||
|
||||
/**
|
||||
* 订阅连接状态
|
||||
*/
|
||||
onConnectionChange(callback: ConnectionCallback) {
|
||||
this.connectionCallbacks.add(callback)
|
||||
// 立即通知当前状态
|
||||
callback(this.isConnected)
|
||||
return () => this.connectionCallbacks.delete(callback)
|
||||
}
|
||||
|
||||
/**
|
||||
* 通知所有订阅者新日志
|
||||
*/
|
||||
private notifyLog(log: LogEntry) {
|
||||
// 检查是否已存在(通过 id 去重)
|
||||
private appendLog(log: LogEntry): void {
|
||||
const exists = this.logCache.some(existingLog => existingLog.id === log.id)
|
||||
|
||||
if (!exists) {
|
||||
// 添加到缓存
|
||||
this.logCache.push(log)
|
||||
|
||||
// 限制缓存大小(动态读取配置)
|
||||
const maxCacheSize = this.getMaxCacheSize()
|
||||
if (this.logCache.length > maxCacheSize) {
|
||||
this.logCache = this.logCache.slice(-maxCacheSize)
|
||||
}
|
||||
|
||||
// 只有新日志才通知订阅者
|
||||
this.logCallbacks.forEach(callback => {
|
||||
try {
|
||||
callback(log)
|
||||
} catch (error) {
|
||||
console.error('日志回调执行失败:', error)
|
||||
}
|
||||
})
|
||||
if (exists) {
|
||||
return
|
||||
}
|
||||
|
||||
this.logCache.push(log)
|
||||
const maxCacheSize = this.getMaxCacheSize()
|
||||
if (this.logCache.length > maxCacheSize) {
|
||||
this.logCache = this.logCache.slice(-maxCacheSize)
|
||||
}
|
||||
this.notifyLogChange()
|
||||
}
|
||||
|
||||
/**
|
||||
* 通知所有订阅者连接状态变化
|
||||
*/
|
||||
private notifyConnection(connected: boolean) {
|
||||
this.connectionCallbacks.forEach(callback => {
|
||||
private notifyLogChange(): void {
|
||||
this.logCallbacks.forEach((callback) => {
|
||||
try {
|
||||
callback()
|
||||
} catch (error) {
|
||||
console.error('日志回调执行失败:', error)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private notifyConnection(connected: boolean): void {
|
||||
this.connectionCallbacks.forEach((callback) => {
|
||||
try {
|
||||
callback(connected)
|
||||
} catch (error) {
|
||||
@@ -184,35 +96,65 @@ class LogWebSocketManager {
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取缓存的所有日志
|
||||
*/
|
||||
async connect(): Promise<void> {
|
||||
if (window.location.pathname === '/auth') {
|
||||
return
|
||||
}
|
||||
|
||||
const isAuthenticated = await checkAuthStatus()
|
||||
if (!isAuthenticated) {
|
||||
return
|
||||
}
|
||||
|
||||
this.initialize()
|
||||
if (this.subscriptionActive) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
await unifiedWsClient.subscribe('logs', 'main', { replay: 100 })
|
||||
this.subscriptionActive = true
|
||||
} catch (error) {
|
||||
console.error('订阅日志流失败:', error)
|
||||
}
|
||||
}
|
||||
|
||||
disconnect(): void {
|
||||
this.subscriptionActive = false
|
||||
void unifiedWsClient.unsubscribe('logs', 'main')
|
||||
this.isConnected = false
|
||||
this.notifyConnection(false)
|
||||
}
|
||||
|
||||
onLog(callback: LogCallback): () => void {
|
||||
this.logCallbacks.add(callback)
|
||||
return () => this.logCallbacks.delete(callback)
|
||||
}
|
||||
|
||||
onConnectionChange(callback: ConnectionCallback): () => void {
|
||||
this.connectionCallbacks.add(callback)
|
||||
callback(this.isConnected)
|
||||
return () => this.connectionCallbacks.delete(callback)
|
||||
}
|
||||
|
||||
getAllLogs(): LogEntry[] {
|
||||
return [...this.logCache]
|
||||
}
|
||||
|
||||
/**
|
||||
* 清空日志缓存
|
||||
*/
|
||||
clearLogs() {
|
||||
clearLogs(): void {
|
||||
this.logCache = []
|
||||
this.notifyLogChange()
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前连接状态
|
||||
*/
|
||||
getConnectionStatus(): boolean {
|
||||
return this.isConnected
|
||||
}
|
||||
}
|
||||
|
||||
// 导出单例
|
||||
export const logWebSocket = new LogWebSocketManager()
|
||||
|
||||
// 自动连接(应用启动时)
|
||||
if (typeof window !== 'undefined') {
|
||||
// 延迟一下确保页面加载完成
|
||||
setTimeout(() => {
|
||||
logWebSocket.connect()
|
||||
void logWebSocket.connect()
|
||||
}, 100)
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import type { ApiResponse } from '@/types/api'
|
||||
import type { PluginInfo } from '@/types/plugin'
|
||||
|
||||
import { getWsBaseUrl } from '@/lib/api-base'
|
||||
import { fetchWithAuth } from '@/lib/fetch-with-auth'
|
||||
import { parseResponse } from '@/lib/api-helpers'
|
||||
import { pluginProgressClient } from '@/lib/plugin-progress-client'
|
||||
import type { GitStatus, MaimaiVersion } from './types'
|
||||
|
||||
/**
|
||||
@@ -211,41 +211,13 @@ export function isPluginCompatible(
|
||||
*/
|
||||
export async function connectPluginProgressWebSocket(
|
||||
onProgress: (progress: import('./types').PluginLoadProgress) => void,
|
||||
onError?: (error: Event) => void
|
||||
): Promise<WebSocket | null> {
|
||||
const wsBase = await getWsBaseUrl()
|
||||
const wsUrl = `${wsBase}/api/webui/ws/plugin-progress`
|
||||
|
||||
// 使用 ws-utils 创建 WebSocket
|
||||
const { createReconnectingWebSocket } = await import('@/lib/ws-utils')
|
||||
const wsControl = createReconnectingWebSocket(wsUrl, {
|
||||
onMessage: (data: string) => {
|
||||
try {
|
||||
const progressData = JSON.parse(data) as import('./types').PluginLoadProgress
|
||||
onProgress(progressData)
|
||||
} catch (error) {
|
||||
console.error('Failed to parse progress data:', error)
|
||||
}
|
||||
},
|
||||
onOpen: () => {
|
||||
console.log('Plugin progress WebSocket connected')
|
||||
},
|
||||
onClose: () => {
|
||||
console.log('Plugin progress WebSocket disconnected')
|
||||
},
|
||||
onError: (error) => {
|
||||
console.error('Plugin progress WebSocket error:', error)
|
||||
onError?.(error)
|
||||
},
|
||||
heartbeatInterval: 30000,
|
||||
maxRetries: 10,
|
||||
backoffBase: 1000,
|
||||
maxBackoff: 30000,
|
||||
})
|
||||
|
||||
// 启动连接
|
||||
await wsControl.connect()
|
||||
|
||||
// 返回 WebSocket 实例(用于外部检查连接状态)
|
||||
return wsControl.getWebSocket()
|
||||
onError?: (error: Error) => void
|
||||
): Promise<() => Promise<void>> {
|
||||
try {
|
||||
return await pluginProgressClient.subscribe(onProgress)
|
||||
} catch (error) {
|
||||
const normalizedError = error instanceof Error ? error : new Error('插件进度订阅失败')
|
||||
onError?.(normalizedError)
|
||||
return async () => {}
|
||||
}
|
||||
}
|
||||
|
||||
58
dashboard/src/lib/plugin-progress-client.ts
Normal file
58
dashboard/src/lib/plugin-progress-client.ts
Normal file
@@ -0,0 +1,58 @@
|
||||
import type { PluginLoadProgress } from '@/lib/plugin-api/types'
|
||||
|
||||
import { unifiedWsClient } from './unified-ws'
|
||||
|
||||
type ProgressListener = (progress: PluginLoadProgress) => void
|
||||
|
||||
class PluginProgressClient {
|
||||
private initialized = false
|
||||
private listeners: Set<ProgressListener> = new Set()
|
||||
private subscriptionActive = false
|
||||
|
||||
private initialize(): void {
|
||||
if (this.initialized) {
|
||||
return
|
||||
}
|
||||
|
||||
unifiedWsClient.addEventListener((message) => {
|
||||
if (message.domain !== 'plugin_progress') {
|
||||
return
|
||||
}
|
||||
|
||||
const progress = message.data.progress as PluginLoadProgress | undefined
|
||||
if (!progress) {
|
||||
return
|
||||
}
|
||||
|
||||
this.listeners.forEach((listener) => {
|
||||
try {
|
||||
listener(progress)
|
||||
} catch (error) {
|
||||
console.error('插件进度监听器执行失败:', error)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
this.initialized = true
|
||||
}
|
||||
|
||||
async subscribe(listener: ProgressListener): Promise<() => Promise<void>> {
|
||||
this.initialize()
|
||||
this.listeners.add(listener)
|
||||
|
||||
if (!this.subscriptionActive) {
|
||||
await unifiedWsClient.subscribe('plugin_progress', 'main')
|
||||
this.subscriptionActive = true
|
||||
}
|
||||
|
||||
return async () => {
|
||||
this.listeners.delete(listener)
|
||||
if (this.listeners.size === 0 && this.subscriptionActive) {
|
||||
this.subscriptionActive = false
|
||||
await unifiedWsClient.unsubscribe('plugin_progress', 'main')
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const pluginProgressClient = new PluginProgressClient()
|
||||
495
dashboard/src/lib/unified-ws.ts
Normal file
495
dashboard/src/lib/unified-ws.ts
Normal file
@@ -0,0 +1,495 @@
|
||||
import { fetchWithAuth } from './fetch-with-auth'
|
||||
import { getSetting } from './settings-manager'
|
||||
|
||||
import { getWsBaseUrl } from '@/lib/api-base'
|
||||
|
||||
export type ConnectionStatus = 'idle' | 'connecting' | 'connected'
|
||||
|
||||
export interface WsErrorPayload {
|
||||
code?: string
|
||||
message: string
|
||||
}
|
||||
|
||||
export interface WsEventEnvelope {
|
||||
op: 'event'
|
||||
domain: string
|
||||
event: string
|
||||
session?: string
|
||||
topic?: string
|
||||
data: Record<string, unknown>
|
||||
}
|
||||
|
||||
interface WsResponseEnvelope {
|
||||
op: 'response'
|
||||
id?: string
|
||||
ok: boolean
|
||||
data?: Record<string, unknown>
|
||||
error?: WsErrorPayload
|
||||
}
|
||||
|
||||
interface WsPongEnvelope {
|
||||
op: 'pong'
|
||||
ts: number
|
||||
}
|
||||
|
||||
type WsServerEnvelope = WsEventEnvelope | WsPongEnvelope | WsResponseEnvelope
|
||||
|
||||
interface PendingRequest {
|
||||
reject: (error: Error) => void
|
||||
resolve: (data: Record<string, unknown>) => void
|
||||
timeoutId: number
|
||||
}
|
||||
|
||||
interface SubscriptionDefinition {
|
||||
data?: Record<string, unknown>
|
||||
domain: string
|
||||
topic: string
|
||||
}
|
||||
|
||||
type EventListener = (message: WsEventEnvelope) => void
|
||||
type ConnectionListener = (connected: boolean) => void
|
||||
type StatusListener = (status: ConnectionStatus) => void
|
||||
type ReconnectListener = () => void
|
||||
|
||||
function isResponseEnvelope(message: WsServerEnvelope): message is WsResponseEnvelope {
|
||||
return message.op === 'response'
|
||||
}
|
||||
|
||||
function isEventEnvelope(message: WsServerEnvelope): message is WsEventEnvelope {
|
||||
return message.op === 'event'
|
||||
}
|
||||
|
||||
async function getWsToken(): Promise<string | null> {
|
||||
try {
|
||||
const response = await fetchWithAuth('/api/webui/ws-token', {
|
||||
method: 'GET',
|
||||
credentials: 'include',
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
return null
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
if (data.success && data.token) {
|
||||
return data.token as string
|
||||
}
|
||||
|
||||
return null
|
||||
} catch (error) {
|
||||
console.error('获取统一 WebSocket token 失败:', error)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
class UnifiedWebSocketClient {
|
||||
private connectPromise: Promise<void> | null = null
|
||||
private connectionListeners: Set<ConnectionListener> = new Set()
|
||||
private eventListeners: Set<EventListener> = new Set()
|
||||
private hasConnectedOnce = false
|
||||
private heartbeatIntervalId: number | null = null
|
||||
private manualDisconnect = false
|
||||
private pendingRequests: Map<string, PendingRequest> = new Map()
|
||||
private reconnectAttempts = 0
|
||||
private reconnectListeners: Set<ReconnectListener> = new Set()
|
||||
private reconnectTimeout: number | null = null
|
||||
private requestCounter = 0
|
||||
private status: ConnectionStatus = 'idle'
|
||||
private statusListeners: Set<StatusListener> = new Set()
|
||||
private subscriptions: Map<string, SubscriptionDefinition> = new Map()
|
||||
private ws: WebSocket | null = null
|
||||
|
||||
private getReconnectDelay(): number {
|
||||
const baseDelay = getSetting('wsReconnectInterval')
|
||||
return Math.min(baseDelay * Math.max(this.reconnectAttempts, 1), 30000)
|
||||
}
|
||||
|
||||
private getMaxReconnectAttempts(): number {
|
||||
return getSetting('wsMaxReconnectAttempts')
|
||||
}
|
||||
|
||||
private getSubscriptionKey(domain: string, topic: string): string {
|
||||
return `${domain}:${topic}`
|
||||
}
|
||||
|
||||
private nextRequestId(): string {
|
||||
this.requestCounter += 1
|
||||
return `ws-${Date.now()}-${this.requestCounter}`
|
||||
}
|
||||
|
||||
private setStatus(status: ConnectionStatus): void {
|
||||
if (this.status === status) {
|
||||
return
|
||||
}
|
||||
|
||||
this.status = status
|
||||
this.statusListeners.forEach((listener) => {
|
||||
try {
|
||||
listener(status)
|
||||
} catch (error) {
|
||||
console.error('WebSocket 状态监听器执行失败:', error)
|
||||
}
|
||||
})
|
||||
|
||||
const connected = status === 'connected'
|
||||
this.connectionListeners.forEach((listener) => {
|
||||
try {
|
||||
listener(connected)
|
||||
} catch (error) {
|
||||
console.error('WebSocket 连接监听器执行失败:', error)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private stopHeartbeat(): void {
|
||||
if (this.heartbeatIntervalId !== null) {
|
||||
clearInterval(this.heartbeatIntervalId)
|
||||
this.heartbeatIntervalId = null
|
||||
}
|
||||
}
|
||||
|
||||
private startHeartbeat(): void {
|
||||
this.stopHeartbeat()
|
||||
this.heartbeatIntervalId = window.setInterval(() => {
|
||||
if (this.ws?.readyState === WebSocket.OPEN) {
|
||||
this.ws.send(JSON.stringify({ op: 'ping' }))
|
||||
}
|
||||
}, 30000)
|
||||
}
|
||||
|
||||
private clearReconnectTimer(): void {
|
||||
if (this.reconnectTimeout !== null) {
|
||||
clearTimeout(this.reconnectTimeout)
|
||||
this.reconnectTimeout = null
|
||||
}
|
||||
}
|
||||
|
||||
private rejectPendingRequests(error: Error): void {
|
||||
this.pendingRequests.forEach((pendingRequest, requestId) => {
|
||||
clearTimeout(pendingRequest.timeoutId)
|
||||
pendingRequest.reject(error)
|
||||
this.pendingRequests.delete(requestId)
|
||||
})
|
||||
}
|
||||
|
||||
private scheduleReconnect(): void {
|
||||
if (this.manualDisconnect) {
|
||||
return
|
||||
}
|
||||
|
||||
if (this.reconnectAttempts >= this.getMaxReconnectAttempts()) {
|
||||
console.warn(`统一 WebSocket 达到最大重连次数 (${this.getMaxReconnectAttempts()}),停止重连`)
|
||||
return
|
||||
}
|
||||
|
||||
this.reconnectAttempts += 1
|
||||
const delay = this.getReconnectDelay()
|
||||
this.clearReconnectTimer()
|
||||
this.reconnectTimeout = window.setTimeout(() => {
|
||||
void this.connect().catch((error) => {
|
||||
console.error('统一 WebSocket 重连失败:', error)
|
||||
})
|
||||
}, delay)
|
||||
}
|
||||
|
||||
private async createWebSocketUrl(): Promise<string | null> {
|
||||
const wsBaseUrl = await getWsBaseUrl()
|
||||
const wsToken = await getWsToken()
|
||||
if (!wsBaseUrl || !wsToken) {
|
||||
return null
|
||||
}
|
||||
return `${wsBaseUrl}/api/webui/ws?token=${encodeURIComponent(wsToken)}`
|
||||
}
|
||||
|
||||
private async sendRequest(
|
||||
payload: Record<string, unknown>,
|
||||
timeoutMs = 10000,
|
||||
): Promise<Record<string, unknown>> {
|
||||
if (this.ws?.readyState !== WebSocket.OPEN) {
|
||||
throw new Error('统一 WebSocket 尚未连接')
|
||||
}
|
||||
|
||||
const requestId = payload.id as string
|
||||
return await new Promise<Record<string, unknown>>((resolve, reject) => {
|
||||
const timeoutId = window.setTimeout(() => {
|
||||
this.pendingRequests.delete(requestId)
|
||||
reject(new Error(`统一 WebSocket 请求超时: ${requestId}`))
|
||||
}, timeoutMs)
|
||||
|
||||
this.pendingRequests.set(requestId, {
|
||||
resolve,
|
||||
reject,
|
||||
timeoutId,
|
||||
})
|
||||
this.ws?.send(JSON.stringify(payload))
|
||||
})
|
||||
}
|
||||
|
||||
private async restoreState(shouldNotifyReconnect: boolean): Promise<void> {
|
||||
const subscriptions = Array.from(this.subscriptions.values())
|
||||
for (const subscription of subscriptions) {
|
||||
try {
|
||||
await this.sendRequest({
|
||||
op: 'subscribe',
|
||||
id: this.nextRequestId(),
|
||||
domain: subscription.domain,
|
||||
topic: subscription.topic,
|
||||
data: subscription.data ?? {},
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('恢复统一 WebSocket 订阅失败:', error)
|
||||
}
|
||||
}
|
||||
|
||||
if (shouldNotifyReconnect) {
|
||||
this.reconnectListeners.forEach((listener) => {
|
||||
try {
|
||||
listener()
|
||||
} catch (error) {
|
||||
console.error('统一 WebSocket 重连监听器执行失败:', error)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
private handleServerMessage(rawData: string): void {
|
||||
let message: WsServerEnvelope
|
||||
try {
|
||||
message = JSON.parse(rawData) as WsServerEnvelope
|
||||
} catch (error) {
|
||||
console.error('解析统一 WebSocket 消息失败:', error)
|
||||
return
|
||||
}
|
||||
|
||||
if (message.op === 'pong') {
|
||||
return
|
||||
}
|
||||
|
||||
if (isResponseEnvelope(message)) {
|
||||
const requestId = message.id
|
||||
if (!requestId) {
|
||||
return
|
||||
}
|
||||
|
||||
const pendingRequest = this.pendingRequests.get(requestId)
|
||||
if (!pendingRequest) {
|
||||
return
|
||||
}
|
||||
|
||||
clearTimeout(pendingRequest.timeoutId)
|
||||
this.pendingRequests.delete(requestId)
|
||||
if (message.ok) {
|
||||
pendingRequest.resolve(message.data ?? {})
|
||||
} else {
|
||||
pendingRequest.reject(new Error(message.error?.message ?? '统一 WebSocket 请求失败'))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if (isEventEnvelope(message)) {
|
||||
this.eventListeners.forEach((listener) => {
|
||||
try {
|
||||
listener(message)
|
||||
} catch (error) {
|
||||
console.error('统一 WebSocket 事件监听器执行失败:', error)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
private handleClose(event: CloseEvent): void {
|
||||
this.stopHeartbeat()
|
||||
this.ws = null
|
||||
this.connectPromise = null
|
||||
this.setStatus('idle')
|
||||
this.rejectPendingRequests(new Error(`统一 WebSocket 已关闭 (${event.code})`))
|
||||
|
||||
if (event.code === 4001) {
|
||||
this.manualDisconnect = true
|
||||
if (window.location.pathname !== '/auth') {
|
||||
window.location.href = '/auth'
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
this.scheduleReconnect()
|
||||
}
|
||||
|
||||
async connect(): Promise<void> {
|
||||
if (this.ws?.readyState === WebSocket.OPEN) {
|
||||
return
|
||||
}
|
||||
|
||||
if (this.connectPromise) {
|
||||
return await this.connectPromise
|
||||
}
|
||||
|
||||
this.manualDisconnect = false
|
||||
this.setStatus('connecting')
|
||||
|
||||
this.connectPromise = (async () => {
|
||||
const wsUrl = await this.createWebSocketUrl()
|
||||
if (!wsUrl) {
|
||||
this.setStatus('idle')
|
||||
throw new Error('无法建立统一 WebSocket 连接')
|
||||
}
|
||||
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
let settled = false
|
||||
const socket = new WebSocket(wsUrl)
|
||||
this.ws = socket
|
||||
|
||||
socket.onopen = () => {
|
||||
settled = true
|
||||
const shouldNotifyReconnect = this.hasConnectedOnce
|
||||
this.hasConnectedOnce = true
|
||||
this.reconnectAttempts = 0
|
||||
this.startHeartbeat()
|
||||
this.setStatus('connected')
|
||||
resolve()
|
||||
void this.restoreState(shouldNotifyReconnect)
|
||||
}
|
||||
|
||||
socket.onmessage = (event) => {
|
||||
this.handleServerMessage(event.data)
|
||||
}
|
||||
|
||||
socket.onerror = () => {
|
||||
if (!settled) {
|
||||
settled = true
|
||||
reject(new Error('统一 WebSocket 连接失败'))
|
||||
}
|
||||
}
|
||||
|
||||
socket.onclose = (event) => {
|
||||
if (!settled) {
|
||||
settled = true
|
||||
reject(new Error(`统一 WebSocket 已关闭 (${event.code})`))
|
||||
}
|
||||
this.handleClose(event)
|
||||
}
|
||||
})
|
||||
})()
|
||||
|
||||
try {
|
||||
await this.connectPromise
|
||||
} finally {
|
||||
if (this.status !== 'connected') {
|
||||
this.connectPromise = null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
disconnect(): void {
|
||||
this.manualDisconnect = true
|
||||
this.clearReconnectTimer()
|
||||
this.stopHeartbeat()
|
||||
this.rejectPendingRequests(new Error('统一 WebSocket 已手动断开'))
|
||||
this.connectPromise = null
|
||||
if (this.ws) {
|
||||
this.ws.close()
|
||||
this.ws = null
|
||||
}
|
||||
this.setStatus('idle')
|
||||
}
|
||||
|
||||
async restart(): Promise<void> {
|
||||
this.manualDisconnect = false
|
||||
this.clearReconnectTimer()
|
||||
if (this.ws) {
|
||||
this.ws.close()
|
||||
return
|
||||
}
|
||||
await this.connect()
|
||||
}
|
||||
|
||||
async call(params: {
|
||||
data?: Record<string, unknown>
|
||||
domain: string
|
||||
method: string
|
||||
session?: string
|
||||
}): Promise<Record<string, unknown>> {
|
||||
await this.connect()
|
||||
const requestId = this.nextRequestId()
|
||||
return await this.sendRequest({
|
||||
op: 'call',
|
||||
id: requestId,
|
||||
domain: params.domain,
|
||||
method: params.method,
|
||||
session: params.session,
|
||||
data: params.data ?? {},
|
||||
})
|
||||
}
|
||||
|
||||
async subscribe(
|
||||
domain: string,
|
||||
topic: string,
|
||||
data?: Record<string, unknown>,
|
||||
): Promise<Record<string, unknown>> {
|
||||
await this.connect()
|
||||
this.subscriptions.set(this.getSubscriptionKey(domain, topic), {
|
||||
domain,
|
||||
topic,
|
||||
data,
|
||||
})
|
||||
|
||||
return await this.sendRequest({
|
||||
op: 'subscribe',
|
||||
id: this.nextRequestId(),
|
||||
domain,
|
||||
topic,
|
||||
data: data ?? {},
|
||||
})
|
||||
}
|
||||
|
||||
async unsubscribe(domain: string, topic: string): Promise<Record<string, unknown> | null> {
|
||||
this.subscriptions.delete(this.getSubscriptionKey(domain, topic))
|
||||
if (this.ws?.readyState !== WebSocket.OPEN) {
|
||||
return null
|
||||
}
|
||||
|
||||
return await this.sendRequest({
|
||||
op: 'unsubscribe',
|
||||
id: this.nextRequestId(),
|
||||
domain,
|
||||
topic,
|
||||
data: {},
|
||||
})
|
||||
}
|
||||
|
||||
addEventListener(listener: EventListener): () => void {
|
||||
this.eventListeners.add(listener)
|
||||
return () => {
|
||||
this.eventListeners.delete(listener)
|
||||
}
|
||||
}
|
||||
|
||||
onConnectionChange(listener: ConnectionListener): () => void {
|
||||
this.connectionListeners.add(listener)
|
||||
listener(this.status === 'connected')
|
||||
return () => {
|
||||
this.connectionListeners.delete(listener)
|
||||
}
|
||||
}
|
||||
|
||||
onStatusChange(listener: StatusListener): () => void {
|
||||
this.statusListeners.add(listener)
|
||||
listener(this.status)
|
||||
return () => {
|
||||
this.statusListeners.delete(listener)
|
||||
}
|
||||
}
|
||||
|
||||
onReconnect(listener: ReconnectListener): () => void {
|
||||
this.reconnectListeners.add(listener)
|
||||
return () => {
|
||||
this.reconnectListeners.delete(listener)
|
||||
}
|
||||
}
|
||||
|
||||
getStatus(): ConnectionStatus {
|
||||
return this.status
|
||||
}
|
||||
}
|
||||
|
||||
export const unifiedWsClient = new UnifiedWebSocketClient()
|
||||
@@ -1,211 +0,0 @@
|
||||
import { fetchWithAuth } from './fetch-with-auth'
|
||||
|
||||
/**
|
||||
* WebSocket 配置选项
|
||||
*/
|
||||
export interface WebSocketOptions {
|
||||
onMessage?: (data: string) => void
|
||||
onOpen?: () => void
|
||||
onClose?: () => void
|
||||
onError?: (error: Event) => void
|
||||
heartbeatInterval?: number // 心跳间隔(毫秒)
|
||||
maxRetries?: number // 最大重连次数
|
||||
backoffBase?: number // 重连基础间隔(毫秒)
|
||||
maxBackoff?: number // 最大重连间隔(毫秒)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 WebSocket 临时认证 token
|
||||
*/
|
||||
export async function getWsToken(): Promise<string | null> {
|
||||
try {
|
||||
// 使用相对路径,让前端代理处理请求,避免 CORS 问题
|
||||
const response = await fetchWithAuth('/api/webui/ws-token', {
|
||||
method: 'GET',
|
||||
credentials: 'include', // 携带 Cookie
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
console.error('获取 WebSocket token 失败:', response.status)
|
||||
return null
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
if (data.success && data.token) {
|
||||
return data.token
|
||||
}
|
||||
return null
|
||||
} catch (error) {
|
||||
console.error('获取 WebSocket token 失败:', error)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建带重连、心跳的 WebSocket 封装
|
||||
*
|
||||
* @param url WebSocket URL(不含 token 参数)
|
||||
* @param options 配置选项
|
||||
* @returns WebSocket 控制对象,包含 connect、disconnect、send 方法
|
||||
*/
|
||||
export function createReconnectingWebSocket(
|
||||
url: string,
|
||||
options: WebSocketOptions = {}
|
||||
) {
|
||||
const {
|
||||
onMessage,
|
||||
onOpen,
|
||||
onClose,
|
||||
onError,
|
||||
heartbeatInterval = 30000,
|
||||
maxRetries = 10,
|
||||
backoffBase = 1000,
|
||||
maxBackoff = 30000,
|
||||
} = options
|
||||
|
||||
let ws: WebSocket | null = null
|
||||
let reconnectTimeout: number | null = null
|
||||
let reconnectAttempts = 0
|
||||
let heartbeatIntervalId: number | null = null
|
||||
let isManualDisconnect = false
|
||||
|
||||
/**
|
||||
* 启动心跳
|
||||
*/
|
||||
function startHeartbeat() {
|
||||
stopHeartbeat()
|
||||
heartbeatIntervalId = window.setInterval(() => {
|
||||
if (ws?.readyState === WebSocket.OPEN) {
|
||||
ws.send('ping')
|
||||
}
|
||||
}, heartbeatInterval)
|
||||
}
|
||||
|
||||
/**
|
||||
* 停止心跳
|
||||
*/
|
||||
function stopHeartbeat() {
|
||||
if (heartbeatIntervalId !== null) {
|
||||
clearInterval(heartbeatIntervalId)
|
||||
heartbeatIntervalId = null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 尝试重连
|
||||
*/
|
||||
function attemptReconnect() {
|
||||
if (isManualDisconnect) {
|
||||
return
|
||||
}
|
||||
|
||||
if (reconnectAttempts >= maxRetries) {
|
||||
console.warn(`WebSocket 达到最大重连次数 (${maxRetries}),停止重连`)
|
||||
return
|
||||
}
|
||||
|
||||
reconnectAttempts += 1
|
||||
const delay = Math.min(backoffBase * reconnectAttempts, maxBackoff)
|
||||
|
||||
console.log(`WebSocket 将在 ${delay}ms 后重连(第 ${reconnectAttempts} 次)`)
|
||||
reconnectTimeout = window.setTimeout(() => {
|
||||
connect()
|
||||
}, delay)
|
||||
}
|
||||
|
||||
/**
|
||||
* 连接 WebSocket
|
||||
*/
|
||||
async function connect() {
|
||||
if (ws?.readyState === WebSocket.OPEN || ws?.readyState === WebSocket.CONNECTING) {
|
||||
return
|
||||
}
|
||||
|
||||
// 先获取临时认证 token
|
||||
const wsToken = await getWsToken()
|
||||
if (!wsToken) {
|
||||
console.warn('无法获取 WebSocket token,跳过连接')
|
||||
return
|
||||
}
|
||||
|
||||
const wsUrl = `${url}?token=${encodeURIComponent(wsToken)}`
|
||||
|
||||
try {
|
||||
ws = new WebSocket(wsUrl)
|
||||
|
||||
ws.onopen = () => {
|
||||
reconnectAttempts = 0
|
||||
startHeartbeat()
|
||||
onOpen?.()
|
||||
}
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
// 忽略心跳响应
|
||||
if (event.data === 'pong') {
|
||||
return
|
||||
}
|
||||
onMessage?.(event.data)
|
||||
}
|
||||
|
||||
ws.onerror = (error) => {
|
||||
console.error('WebSocket 错误:', error)
|
||||
onError?.(error)
|
||||
}
|
||||
|
||||
ws.onclose = () => {
|
||||
stopHeartbeat()
|
||||
onClose?.()
|
||||
attemptReconnect()
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('创建 WebSocket 连接失败:', error)
|
||||
attemptReconnect()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 断开连接
|
||||
*/
|
||||
function disconnect() {
|
||||
isManualDisconnect = true
|
||||
|
||||
if (reconnectTimeout !== null) {
|
||||
clearTimeout(reconnectTimeout)
|
||||
reconnectTimeout = null
|
||||
}
|
||||
|
||||
stopHeartbeat()
|
||||
|
||||
if (ws) {
|
||||
ws.close()
|
||||
ws = null
|
||||
}
|
||||
|
||||
reconnectAttempts = 0
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送消息
|
||||
*/
|
||||
function send(data: string) {
|
||||
if (ws?.readyState === WebSocket.OPEN) {
|
||||
ws.send(data)
|
||||
} else {
|
||||
console.warn('WebSocket 未连接,无法发送消息')
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前 WebSocket 实例
|
||||
*/
|
||||
function getWebSocket(): WebSocket | null {
|
||||
return ws
|
||||
}
|
||||
|
||||
return {
|
||||
connect,
|
||||
disconnect,
|
||||
send,
|
||||
getWebSocket,
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import { Button } from '@/components/ui/button'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { ScrollArea } from '@/components/ui/scroll-area'
|
||||
import { useToast } from '@/hooks/use-toast'
|
||||
import { getWsBaseUrl } from '@/lib/api-base'
|
||||
import { chatWsClient } from '@/lib/chat-ws-client'
|
||||
import { fetchWithAuth } from '@/lib/fetch-with-auth'
|
||||
import { cn } from '@/lib/utils'
|
||||
import { Bot, Edit2, Loader2, RefreshCw, User, Send, Wifi, WifiOff, UserCircle2 } from 'lucide-react'
|
||||
@@ -85,14 +85,17 @@ export function ChatPage() {
|
||||
// 持久化用户 ID
|
||||
const userIdRef = useRef(getOrCreateUserId())
|
||||
|
||||
// 每个标签页的 WebSocket 连接
|
||||
const wsMapRef = useRef<Map<string, WebSocket>>(new Map())
|
||||
const messagesEndRef = useRef<HTMLDivElement>(null)
|
||||
const reconnectTimeoutMapRef = useRef<Map<string, number>>(new Map())
|
||||
const messageIdCounterRef = useRef(0)
|
||||
const processedMessagesMapRef = useRef<Map<string, Set<string>>>(new Map())
|
||||
const sessionUnsubscribeMapRef = useRef<Map<string, () => void>>(new Map())
|
||||
const tabsRef = useRef<ChatTab[]>([])
|
||||
const { toast } = useToast()
|
||||
|
||||
useEffect(() => {
|
||||
tabsRef.current = tabs
|
||||
}, [tabs])
|
||||
|
||||
// 生成唯一消息 ID
|
||||
const generateMessageId = (prefix: string) => {
|
||||
messageIdCounterRef.current += 1
|
||||
@@ -197,357 +200,218 @@ export function ChatPage() {
|
||||
}
|
||||
}, [tempVirtualConfig.platform, personSearchQuery, fetchPersons])
|
||||
|
||||
// 加载聊天历史到指定标签页
|
||||
const loadChatHistoryForTab = useCallback(async (tabId: string, groupId?: string) => {
|
||||
const handleSessionMessage = useCallback((
|
||||
tabId: string,
|
||||
tabType: 'webui' | 'virtual',
|
||||
config: VirtualIdentityConfig | undefined,
|
||||
data: WsMessage,
|
||||
) => {
|
||||
switch (data.type) {
|
||||
case 'session_info':
|
||||
updateTab(tabId, {
|
||||
sessionInfo: {
|
||||
session_id: data.session_id,
|
||||
user_id: data.user_id,
|
||||
user_name: data.user_name,
|
||||
bot_name: data.bot_name,
|
||||
}
|
||||
})
|
||||
break
|
||||
|
||||
case 'system':
|
||||
addMessageToTab(tabId, {
|
||||
id: generateMessageId('sys'),
|
||||
type: 'system',
|
||||
content: data.content || '',
|
||||
timestamp: data.timestamp || Date.now() / 1000,
|
||||
})
|
||||
break
|
||||
|
||||
case 'user_message': {
|
||||
const senderUserId = data.sender?.user_id
|
||||
const currentUserId = tabType === 'virtual' && config
|
||||
? config.userId
|
||||
: userIdRef.current
|
||||
|
||||
const normalizeSenderId = senderUserId ? senderUserId.replace(/^webui_user_/, '') : ''
|
||||
const normalizeCurrentId = currentUserId ? currentUserId.replace(/^webui_user_/, '') : ''
|
||||
if (normalizeSenderId && normalizeCurrentId && normalizeSenderId === normalizeCurrentId) {
|
||||
break
|
||||
}
|
||||
|
||||
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
|
||||
const contentHash = `user-${data.content}-${Math.floor((data.timestamp || 0) * 1000)}`
|
||||
if (processedSet.has(contentHash)) {
|
||||
break
|
||||
}
|
||||
|
||||
processedSet.add(contentHash)
|
||||
processedMessagesMapRef.current.set(tabId, processedSet)
|
||||
if (processedSet.size > 100) {
|
||||
const firstKey = processedSet.values().next().value
|
||||
if (firstKey) processedSet.delete(firstKey)
|
||||
}
|
||||
|
||||
addMessageToTab(tabId, {
|
||||
id: data.message_id || generateMessageId('user'),
|
||||
type: 'user',
|
||||
content: data.content || '',
|
||||
timestamp: data.timestamp || Date.now() / 1000,
|
||||
sender: data.sender,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
case 'bot_message': {
|
||||
updateTab(tabId, { isTyping: false })
|
||||
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
|
||||
const contentHash = `bot-${data.content}-${Math.floor((data.timestamp || 0) * 1000)}`
|
||||
if (processedSet.has(contentHash)) {
|
||||
break
|
||||
}
|
||||
|
||||
processedSet.add(contentHash)
|
||||
processedMessagesMapRef.current.set(tabId, processedSet)
|
||||
if (processedSet.size > 100) {
|
||||
const firstKey = processedSet.values().next().value
|
||||
if (firstKey) processedSet.delete(firstKey)
|
||||
}
|
||||
|
||||
setTabs(prev => prev.map(tab => {
|
||||
if (tab.id !== tabId) return tab
|
||||
const filteredMessages = tab.messages.filter(msg => msg.type !== 'thinking')
|
||||
const newMessage: ChatMessage = {
|
||||
id: generateMessageId('bot'),
|
||||
type: 'bot',
|
||||
content: data.content || '',
|
||||
message_type: (data.message_type === 'rich' ? 'rich' : 'text') as 'text' | 'rich',
|
||||
segments: data.segments,
|
||||
timestamp: data.timestamp || Date.now() / 1000,
|
||||
sender: data.sender,
|
||||
}
|
||||
return {
|
||||
...tab,
|
||||
messages: [...filteredMessages, newMessage]
|
||||
}
|
||||
}))
|
||||
break
|
||||
}
|
||||
|
||||
case 'typing':
|
||||
updateTab(tabId, { isTyping: data.is_typing || false })
|
||||
break
|
||||
|
||||
case 'error':
|
||||
setTabs(prev => prev.map(tab => {
|
||||
if (tab.id !== tabId) return tab
|
||||
const filteredMessages = tab.messages.filter(msg => msg.type !== 'thinking')
|
||||
return {
|
||||
...tab,
|
||||
messages: [...filteredMessages, {
|
||||
id: generateMessageId('error'),
|
||||
type: 'error' as const,
|
||||
content: data.content || '发生错误',
|
||||
timestamp: data.timestamp || Date.now() / 1000,
|
||||
}]
|
||||
}
|
||||
}))
|
||||
toast({
|
||||
title: '错误',
|
||||
description: data.content,
|
||||
variant: 'destructive',
|
||||
})
|
||||
break
|
||||
|
||||
case 'history': {
|
||||
const historyMessages = data.messages || []
|
||||
const processedSet = new Set<string>()
|
||||
const formattedMessages: ChatMessage[] = historyMessages.map((msg: {
|
||||
id?: string
|
||||
content: string
|
||||
timestamp: number
|
||||
sender_name?: string
|
||||
sender_id?: string
|
||||
is_bot?: boolean
|
||||
}) => {
|
||||
const isBot = msg.is_bot || false
|
||||
const msgId = msg.id || generateMessageId(isBot ? 'bot' : 'user')
|
||||
const contentHash = `${isBot ? 'bot' : 'user'}-${msg.content}-${Math.floor(msg.timestamp * 1000)}`
|
||||
processedSet.add(contentHash)
|
||||
return {
|
||||
id: msgId,
|
||||
type: isBot ? 'bot' : 'user' as const,
|
||||
content: msg.content,
|
||||
timestamp: msg.timestamp,
|
||||
sender: {
|
||||
name: msg.sender_name || (isBot ? '麦麦' : '用户'),
|
||||
user_id: msg.sender_id,
|
||||
is_bot: isBot,
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
processedMessagesMapRef.current.set(tabId, processedSet)
|
||||
updateTab(tabId, { messages: formattedMessages })
|
||||
setIsLoadingHistory(false)
|
||||
break
|
||||
}
|
||||
|
||||
default:
|
||||
break
|
||||
}
|
||||
}, [addMessageToTab, toast, updateTab])
|
||||
|
||||
const ensureSessionListener = useCallback((
|
||||
tabId: string,
|
||||
tabType: 'webui' | 'virtual',
|
||||
config?: VirtualIdentityConfig,
|
||||
) => {
|
||||
if (sessionUnsubscribeMapRef.current.has(tabId)) {
|
||||
return
|
||||
}
|
||||
|
||||
const unsubscribe = chatWsClient.onSessionMessage(tabId, (message) => {
|
||||
handleSessionMessage(tabId, tabType, config, message as unknown as WsMessage)
|
||||
})
|
||||
sessionUnsubscribeMapRef.current.set(tabId, unsubscribe)
|
||||
}, [handleSessionMessage])
|
||||
|
||||
const openSessionForTab = useCallback(async (
|
||||
tabId: string,
|
||||
tabType: 'webui' | 'virtual',
|
||||
config?: VirtualIdentityConfig,
|
||||
) => {
|
||||
ensureSessionListener(tabId, tabType, config)
|
||||
setIsLoadingHistory(true)
|
||||
|
||||
try {
|
||||
const params = new URLSearchParams()
|
||||
params.append('user_id', userIdRef.current)
|
||||
params.append('limit', '50')
|
||||
if (groupId) {
|
||||
params.append('group_id', groupId)
|
||||
if (tabType === 'virtual' && config) {
|
||||
await chatWsClient.openSession(tabId, {
|
||||
user_id: config.userId,
|
||||
user_name: config.userName,
|
||||
platform: config.platform,
|
||||
person_id: config.personId,
|
||||
group_name: config.groupName || 'WebUI虚拟群聊',
|
||||
group_id: config.groupId,
|
||||
})
|
||||
} else {
|
||||
await chatWsClient.openSession(tabId, {
|
||||
user_id: userIdRef.current,
|
||||
user_name: userName,
|
||||
})
|
||||
}
|
||||
const url = `/api/chat/history?${params.toString()}`
|
||||
console.log('[Chat] 正在加载历史消息:', url)
|
||||
|
||||
const response = await fetchWithAuth(url)
|
||||
|
||||
if (response.ok) {
|
||||
const text = await response.text()
|
||||
try {
|
||||
const data = JSON.parse(text)
|
||||
|
||||
if (data.messages && data.messages.length > 0) {
|
||||
const historyMessages: ChatMessage[] = data.messages.map((msg: {
|
||||
id: string
|
||||
type: string
|
||||
content: string
|
||||
timestamp: number
|
||||
sender_name?: string
|
||||
user_id?: string
|
||||
is_bot?: boolean
|
||||
}) => ({
|
||||
id: msg.id,
|
||||
type: msg.type as 'user' | 'bot' | 'system' | 'error',
|
||||
content: msg.content,
|
||||
timestamp: msg.timestamp,
|
||||
sender: {
|
||||
name: msg.sender_name || (msg.is_bot ? '麦麦' : 'WebUI用户'),
|
||||
user_id: msg.user_id,
|
||||
is_bot: msg.is_bot
|
||||
}
|
||||
}))
|
||||
|
||||
// 更新标签页的消息
|
||||
updateTab(tabId, { messages: historyMessages })
|
||||
|
||||
// 将历史消息添加到去重缓存
|
||||
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
|
||||
historyMessages.forEach(msg => {
|
||||
if (msg.type === 'bot') {
|
||||
const contentHash = `bot-${msg.content}-${Math.floor(msg.timestamp * 1000)}`
|
||||
processedSet.add(contentHash)
|
||||
}
|
||||
})
|
||||
processedMessagesMapRef.current.set(tabId, processedSet)
|
||||
}
|
||||
} catch (parseError) {
|
||||
console.error('[Chat] JSON 解析失败:', parseError)
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('[Chat] 加载历史消息失败:', e)
|
||||
} finally {
|
||||
setIsLoadingHistory(false)
|
||||
}
|
||||
}, [updateTab])
|
||||
|
||||
// 为指定标签页连接 WebSocket(异步,需要先获取认证 token)
|
||||
const connectWebSocketForTab = useCallback(async (tabId: string, tabType: 'webui' | 'virtual', config?: VirtualIdentityConfig) => {
|
||||
// 如果已经有连接,不要重复创建
|
||||
const existingWs = wsMapRef.current.get(tabId)
|
||||
if (existingWs?.readyState === WebSocket.OPEN ||
|
||||
existingWs?.readyState === WebSocket.CONNECTING) {
|
||||
console.log(`[Tab ${tabId}] WebSocket 已存在,跳过连接`)
|
||||
return
|
||||
}
|
||||
|
||||
setIsConnecting(true)
|
||||
|
||||
// 先获取临时 WebSocket token
|
||||
let wsToken: string | null = null
|
||||
try {
|
||||
const tokenResponse = await fetchWithAuth('/api/webui/ws-token')
|
||||
if (tokenResponse.ok) {
|
||||
const tokenData = await tokenResponse.json()
|
||||
if (tokenData.success && tokenData.token) {
|
||||
wsToken = tokenData.token
|
||||
} else {
|
||||
console.warn(`[Tab ${tabId}] 获取 WebSocket token 失败: ${tokenData.message || '未登录'}`)
|
||||
setIsConnecting(false)
|
||||
return
|
||||
}
|
||||
}
|
||||
updateTab(tabId, { isConnected: true })
|
||||
} catch (error) {
|
||||
console.error(`[Tab ${tabId}] 获取 WebSocket token 失败:`, error)
|
||||
setIsConnecting(false)
|
||||
return
|
||||
console.error(`[Tab ${tabId}] 打开聊天会话失败:`, error)
|
||||
setIsLoadingHistory(false)
|
||||
toast({
|
||||
title: '连接失败',
|
||||
description: '无法建立聊天会话,请稍后重试',
|
||||
variant: 'destructive',
|
||||
})
|
||||
}
|
||||
|
||||
// 此时 wsToken 一定有值(前面已经 return)
|
||||
if (!wsToken) {
|
||||
setIsConnecting(false)
|
||||
return
|
||||
}
|
||||
|
||||
const wsBase = await getWsBaseUrl()
|
||||
const params = new URLSearchParams()
|
||||
|
||||
// 添加 token 到参数
|
||||
params.append('token', wsToken)
|
||||
|
||||
if (tabType === 'virtual' && config) {
|
||||
params.append('user_id', config.userId)
|
||||
params.append('user_name', config.userName)
|
||||
params.append('platform', config.platform)
|
||||
params.append('person_id', config.personId)
|
||||
params.append('group_name', config.groupName || 'WebUI虚拟群聊')
|
||||
// 传递稳定的 group_id,确保历史记录能正确加载
|
||||
if (config.groupId) {
|
||||
params.append('group_id', config.groupId)
|
||||
}
|
||||
} else {
|
||||
params.append('user_id', userIdRef.current)
|
||||
params.append('user_name', userName)
|
||||
}
|
||||
|
||||
const wsUrl = `${wsBase}/api/chat/ws?${params.toString()}`
|
||||
console.log(`[Tab ${tabId}] 正在连接 WebSocket:`, wsUrl)
|
||||
|
||||
try {
|
||||
const ws = new WebSocket(wsUrl)
|
||||
wsMapRef.current.set(tabId, ws)
|
||||
|
||||
ws.onopen = () => {
|
||||
updateTab(tabId, { isConnected: true })
|
||||
setIsConnecting(false)
|
||||
console.log(`[Tab ${tabId}] WebSocket 已连接`)
|
||||
}
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
try {
|
||||
const data: WsMessage = JSON.parse(event.data)
|
||||
|
||||
switch (data.type) {
|
||||
case 'session_info':
|
||||
updateTab(tabId, {
|
||||
sessionInfo: {
|
||||
session_id: data.session_id,
|
||||
user_id: data.user_id,
|
||||
user_name: data.user_name,
|
||||
bot_name: data.bot_name,
|
||||
}
|
||||
})
|
||||
break
|
||||
|
||||
case 'system':
|
||||
addMessageToTab(tabId, {
|
||||
id: generateMessageId('sys'),
|
||||
type: 'system',
|
||||
content: data.content || '',
|
||||
timestamp: data.timestamp || Date.now() / 1000,
|
||||
})
|
||||
break
|
||||
|
||||
case 'user_message': {
|
||||
// 检查是否是自己发的消息(已在发送时显示,跳过广播回来的)
|
||||
const senderUserId = data.sender?.user_id
|
||||
const currentUserId = tabType === 'virtual' && config
|
||||
? config.userId
|
||||
: userIdRef.current
|
||||
|
||||
console.log(`[Tab ${tabId}] 收到 user_message, sender: ${senderUserId}, current: ${currentUserId}`)
|
||||
|
||||
// 标准化 user_id(去掉可能的前缀)
|
||||
const normalizeSenderId = senderUserId ? senderUserId.replace(/^webui_user_/, '') : ''
|
||||
const normalizeCurrentId = currentUserId ? currentUserId.replace(/^webui_user_/, '') : ''
|
||||
|
||||
// 如果是自己发的消息,跳过(避免重复显示)
|
||||
if (normalizeSenderId && normalizeCurrentId && normalizeSenderId === normalizeCurrentId) {
|
||||
console.log(`[Tab ${tabId}] 跳过自己的消息(user_id 匹配)`)
|
||||
break
|
||||
}
|
||||
|
||||
// 额外的消息去重:检查内容和时间戳
|
||||
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
|
||||
const contentHash = `user-${data.content}-${Math.floor((data.timestamp || 0) * 1000)}`
|
||||
if (processedSet.has(contentHash)) {
|
||||
console.log(`[Tab ${tabId}] 跳过自己的消息(内容去重)`)
|
||||
break
|
||||
}
|
||||
processedSet.add(contentHash)
|
||||
processedMessagesMapRef.current.set(tabId, processedSet)
|
||||
|
||||
if (processedSet.size > 100) {
|
||||
const firstKey = processedSet.values().next().value
|
||||
if (firstKey) processedSet.delete(firstKey)
|
||||
}
|
||||
|
||||
addMessageToTab(tabId, {
|
||||
id: data.message_id || generateMessageId('user'),
|
||||
type: 'user',
|
||||
content: data.content || '',
|
||||
timestamp: data.timestamp || Date.now() / 1000,
|
||||
sender: data.sender,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
case 'bot_message': {
|
||||
updateTab(tabId, { isTyping: false })
|
||||
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
|
||||
const contentHash = `bot-${data.content}-${Math.floor((data.timestamp || 0) * 1000)}`
|
||||
if (processedSet.has(contentHash)) {
|
||||
break
|
||||
}
|
||||
processedSet.add(contentHash)
|
||||
processedMessagesMapRef.current.set(tabId, processedSet)
|
||||
|
||||
if (processedSet.size > 100) {
|
||||
const firstKey = processedSet.values().next().value
|
||||
if (firstKey) processedSet.delete(firstKey)
|
||||
}
|
||||
|
||||
// 移除"思考中"占位消息,添加真实的机器人回复
|
||||
setTabs(prev => prev.map(tab => {
|
||||
if (tab.id !== tabId) return tab
|
||||
// 过滤掉 thinking 类型的消息
|
||||
const filteredMessages = tab.messages.filter(msg => msg.type !== 'thinking')
|
||||
const newMessage: ChatMessage = {
|
||||
id: generateMessageId('bot'),
|
||||
type: 'bot',
|
||||
content: data.content || '',
|
||||
message_type: (data.message_type === 'rich' ? 'rich' : 'text') as 'text' | 'rich',
|
||||
segments: data.segments,
|
||||
timestamp: data.timestamp || Date.now() / 1000,
|
||||
sender: data.sender,
|
||||
}
|
||||
return {
|
||||
...tab,
|
||||
messages: [...filteredMessages, newMessage]
|
||||
}
|
||||
}))
|
||||
break
|
||||
}
|
||||
|
||||
case 'typing':
|
||||
updateTab(tabId, { isTyping: data.is_typing || false })
|
||||
break
|
||||
|
||||
case 'error':
|
||||
// 移除"思考中"占位消息,显示错误
|
||||
setTabs(prev => prev.map(tab => {
|
||||
if (tab.id !== tabId) return tab
|
||||
const filteredMessages = tab.messages.filter(msg => msg.type !== 'thinking')
|
||||
return {
|
||||
...tab,
|
||||
messages: [...filteredMessages, {
|
||||
id: generateMessageId('error'),
|
||||
type: 'error' as const,
|
||||
content: data.content || '发生错误',
|
||||
timestamp: data.timestamp || Date.now() / 1000,
|
||||
}]
|
||||
}
|
||||
}))
|
||||
toast({
|
||||
title: '错误',
|
||||
description: data.content,
|
||||
variant: 'destructive',
|
||||
})
|
||||
break
|
||||
|
||||
case 'pong':
|
||||
break
|
||||
|
||||
case 'history': {
|
||||
// 处理服务端发送的历史消息
|
||||
const historyMessages = data.messages || []
|
||||
if (historyMessages.length > 0) {
|
||||
const processedSet = processedMessagesMapRef.current.get(tabId) || new Set()
|
||||
const formattedMessages: ChatMessage[] = historyMessages.map((msg: {
|
||||
id?: string
|
||||
content: string
|
||||
timestamp: number
|
||||
sender_name?: string
|
||||
sender_id?: string
|
||||
is_bot?: boolean
|
||||
}) => {
|
||||
const isBot = msg.is_bot || false
|
||||
const msgId = msg.id || generateMessageId(isBot ? 'bot' : 'user')
|
||||
// 添加到去重集合
|
||||
const contentHash = `${isBot ? 'bot' : 'user'}-${msg.content}-${Math.floor(msg.timestamp * 1000)}`
|
||||
processedSet.add(contentHash)
|
||||
return {
|
||||
id: msgId,
|
||||
type: isBot ? 'bot' : 'user' as const,
|
||||
content: msg.content,
|
||||
timestamp: msg.timestamp,
|
||||
sender: {
|
||||
name: msg.sender_name || (isBot ? '麦麦' : '用户'),
|
||||
user_id: msg.sender_id,
|
||||
is_bot: isBot,
|
||||
},
|
||||
}
|
||||
})
|
||||
processedMessagesMapRef.current.set(tabId, processedSet)
|
||||
// 替换当前标签页的所有消息
|
||||
updateTab(tabId, { messages: formattedMessages })
|
||||
console.log(`[Tab ${tabId}] 已加载 ${formattedMessages.length} 条历史消息`)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
default:
|
||||
console.log('未知消息类型:', data.type)
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('解析消息失败:', e)
|
||||
}
|
||||
}
|
||||
|
||||
ws.onclose = () => {
|
||||
updateTab(tabId, { isConnected: false })
|
||||
setIsConnecting(false)
|
||||
wsMapRef.current.delete(tabId)
|
||||
console.log(`[Tab ${tabId}] WebSocket 已断开`)
|
||||
|
||||
// 清除旧的重连定时器
|
||||
const oldTimeout = reconnectTimeoutMapRef.current.get(tabId)
|
||||
if (oldTimeout) {
|
||||
clearTimeout(oldTimeout)
|
||||
}
|
||||
|
||||
// 5秒后尝试重连
|
||||
const timeout = window.setTimeout(() => {
|
||||
if (!isUnmountedRef.current) {
|
||||
const tab = tabs.find(t => t.id === tabId)
|
||||
if (tab) {
|
||||
connectWebSocketForTab(tabId, tab.type, tab.virtualConfig)
|
||||
}
|
||||
}
|
||||
}, 5000)
|
||||
reconnectTimeoutMapRef.current.set(tabId, timeout)
|
||||
}
|
||||
|
||||
ws.onerror = (error) => {
|
||||
console.error(`[Tab ${tabId}] WebSocket 错误:`, error)
|
||||
setIsConnecting(false)
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(`[Tab ${tabId}] 创建 WebSocket 失败:`, e)
|
||||
setIsConnecting(false)
|
||||
}
|
||||
}, [userName, updateTab, addMessageToTab, toast, tabs])
|
||||
}, [ensureSessionListener, toast, updateTab, userName])
|
||||
|
||||
// 用于追踪组件是否已卸载
|
||||
const isUnmountedRef = useRef(false)
|
||||
@@ -555,69 +419,49 @@ export function ChatPage() {
|
||||
// 初始化连接(默认 WebUI 标签页)
|
||||
useEffect(() => {
|
||||
isUnmountedRef.current = false
|
||||
|
||||
// 保存 ref 的当前值,用于清理
|
||||
const wsMap = wsMapRef.current
|
||||
const reconnectTimeoutMap = reconnectTimeoutMapRef.current
|
||||
const processedMessagesMap = processedMessagesMapRef.current
|
||||
|
||||
// 加载默认标签页历史消息
|
||||
loadChatHistoryForTab('webui-default')
|
||||
|
||||
// 延迟连接
|
||||
const connectTimer = setTimeout(() => {
|
||||
if (!isUnmountedRef.current) {
|
||||
connectWebSocketForTab('webui-default', 'webui')
|
||||
|
||||
// 恢复的虚拟标签页也需要建立连接
|
||||
tabs.forEach(tab => {
|
||||
if (tab.type === 'virtual' && tab.virtualConfig) {
|
||||
// 初始化去重缓存
|
||||
processedMessagesMap.set(tab.id, new Set())
|
||||
// 建立 WebSocket 连接
|
||||
setTimeout(() => {
|
||||
if (!isUnmountedRef.current) {
|
||||
connectWebSocketForTab(tab.id, 'virtual', tab.virtualConfig)
|
||||
}
|
||||
}, 200)
|
||||
}
|
||||
})
|
||||
}
|
||||
}, 100)
|
||||
|
||||
// 心跳定时器 - 向所有活动连接发送
|
||||
const heartbeat = setInterval(() => {
|
||||
wsMap.forEach((ws) => {
|
||||
if (ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(JSON.stringify({ type: 'ping' }))
|
||||
}
|
||||
})
|
||||
}, 30000)
|
||||
const unsubscribeConnection = chatWsClient.onConnectionChange((connected) => {
|
||||
if (isUnmountedRef.current) {
|
||||
return
|
||||
}
|
||||
|
||||
setTabs(prev => prev.map(tab => ({
|
||||
...tab,
|
||||
isConnected: connected,
|
||||
})))
|
||||
})
|
||||
|
||||
const unsubscribeStatus = chatWsClient.onStatusChange((status) => {
|
||||
if (!isUnmountedRef.current) {
|
||||
setIsConnecting(status === 'connecting')
|
||||
}
|
||||
})
|
||||
|
||||
tabs.forEach(tab => {
|
||||
processedMessagesMapRef.current.set(tab.id, new Set())
|
||||
void openSessionForTab(tab.id, tab.type, tab.virtualConfig)
|
||||
})
|
||||
|
||||
return () => {
|
||||
isUnmountedRef.current = true
|
||||
clearTimeout(connectTimer)
|
||||
clearInterval(heartbeat)
|
||||
|
||||
// 清理所有重连定时器
|
||||
reconnectTimeoutMap.forEach((timeout) => {
|
||||
clearTimeout(timeout)
|
||||
unsubscribeConnection()
|
||||
unsubscribeStatus()
|
||||
|
||||
sessionUnsubscribeMapRef.current.forEach((unsubscribe) => {
|
||||
unsubscribe()
|
||||
})
|
||||
reconnectTimeoutMap.clear()
|
||||
|
||||
// 关闭所有 WebSocket 连接
|
||||
wsMap.forEach((ws) => {
|
||||
ws.close()
|
||||
sessionUnsubscribeMapRef.current.clear()
|
||||
|
||||
tabsRef.current.forEach(tab => {
|
||||
void chatWsClient.closeSession(tab.id)
|
||||
})
|
||||
wsMap.clear()
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [])
|
||||
|
||||
// 发送消息到当前活动标签页
|
||||
const sendMessage = useCallback(() => {
|
||||
const ws = wsMapRef.current.get(activeTabId)
|
||||
if (!inputValue.trim() || !ws || ws.readyState !== WebSocket.OPEN) {
|
||||
const sendMessage = useCallback(async () => {
|
||||
if (!inputValue.trim() || !activeTab?.isConnected) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -628,12 +472,6 @@ export function ChatPage() {
|
||||
const messageContent = inputValue.trim()
|
||||
const currentTimestamp = Date.now() / 1000
|
||||
|
||||
ws.send(JSON.stringify({
|
||||
type: 'message',
|
||||
content: messageContent,
|
||||
user_name: displayName,
|
||||
}))
|
||||
|
||||
// 添加到去重缓存,防止服务器广播回来的消息重复显示
|
||||
const processedSet = processedMessagesMapRef.current.get(activeTabId) || new Set()
|
||||
const contentHash = `user-${messageContent}-${Math.floor(currentTimestamp * 1000)}`
|
||||
@@ -672,13 +510,32 @@ export function ChatPage() {
|
||||
addMessageToTab(activeTabId, thinkingMessage)
|
||||
|
||||
setInputValue('')
|
||||
}, [inputValue, userName, activeTabId, activeTab, addMessageToTab])
|
||||
|
||||
try {
|
||||
await chatWsClient.sendMessage(activeTabId, messageContent, displayName)
|
||||
} catch (error) {
|
||||
console.error('发送聊天消息失败:', error)
|
||||
setTabs(prev => prev.map(tab => {
|
||||
if (tab.id !== activeTabId) return tab
|
||||
return {
|
||||
...tab,
|
||||
isTyping: false,
|
||||
messages: tab.messages.filter(msg => msg.type !== 'thinking')
|
||||
}
|
||||
}))
|
||||
toast({
|
||||
title: '发送失败',
|
||||
description: '当前聊天会话不可用,请稍后重试',
|
||||
variant: 'destructive',
|
||||
})
|
||||
}
|
||||
}, [activeTab, activeTabId, addMessageToTab, inputValue, toast, userName])
|
||||
|
||||
// 处理键盘事件
|
||||
const handleKeyDown = (e: React.KeyboardEvent) => {
|
||||
if (e.key === 'Enter' && !e.shiftKey) {
|
||||
e.preventDefault()
|
||||
sendMessage()
|
||||
void sendMessage()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -693,13 +550,9 @@ export function ChatPage() {
|
||||
setUserName(newName)
|
||||
saveUserName(newName)
|
||||
setIsEditingName(false)
|
||||
// 通知当前标签页的后端昵称变更
|
||||
const ws = wsMapRef.current.get(activeTabId)
|
||||
if (ws?.readyState === WebSocket.OPEN) {
|
||||
ws.send(JSON.stringify({
|
||||
type: 'update_nickname',
|
||||
user_name: newName
|
||||
}))
|
||||
|
||||
if (activeTab?.isConnected) {
|
||||
void chatWsClient.updateNickname(activeTabId, newName)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -719,12 +572,7 @@ export function ChatPage() {
|
||||
|
||||
// 重新连接当前标签页
|
||||
const handleReconnect = () => {
|
||||
const ws = wsMapRef.current.get(activeTabId)
|
||||
if (ws) {
|
||||
ws.close()
|
||||
wsMapRef.current.delete(activeTabId)
|
||||
}
|
||||
connectWebSocketForTab(activeTabId, activeTab?.type || 'webui', activeTab?.virtualConfig)
|
||||
void chatWsClient.restart()
|
||||
}
|
||||
|
||||
// 打开虚拟身份配置对话框(新建标签页用)
|
||||
@@ -795,10 +643,10 @@ export function ChatPage() {
|
||||
// 初始化去重缓存
|
||||
processedMessagesMapRef.current.set(newTabId, new Set())
|
||||
|
||||
// 连接 WebSocket
|
||||
setTimeout(() => {
|
||||
connectWebSocketForTab(newTabId, 'virtual', tempVirtualConfig)
|
||||
}, 100)
|
||||
void openSessionForTab(newTabId, 'virtual', {
|
||||
...tempVirtualConfig,
|
||||
groupId: stableGroupId,
|
||||
})
|
||||
|
||||
toast({
|
||||
title: '虚拟身份标签页',
|
||||
@@ -814,20 +662,14 @@ export function ChatPage() {
|
||||
if (tabId === 'webui-default') {
|
||||
return
|
||||
}
|
||||
|
||||
// 关闭 WebSocket 连接
|
||||
const ws = wsMapRef.current.get(tabId)
|
||||
if (ws) {
|
||||
ws.close()
|
||||
wsMapRef.current.delete(tabId)
|
||||
}
|
||||
|
||||
// 清理重连定时器
|
||||
const timeout = reconnectTimeoutMapRef.current.get(tabId)
|
||||
if (timeout) {
|
||||
clearTimeout(timeout)
|
||||
reconnectTimeoutMapRef.current.delete(tabId)
|
||||
|
||||
const unsubscribe = sessionUnsubscribeMapRef.current.get(tabId)
|
||||
if (unsubscribe) {
|
||||
unsubscribe()
|
||||
sessionUnsubscribeMapRef.current.delete(tabId)
|
||||
}
|
||||
|
||||
void chatWsClient.closeSession(tabId)
|
||||
|
||||
// 清理去重缓存
|
||||
processedMessagesMapRef.current.delete(tabId)
|
||||
@@ -1133,7 +975,7 @@ export function ChatPage() {
|
||||
className="flex-1 h-10 sm:h-10"
|
||||
/>
|
||||
<Button
|
||||
onClick={sendMessage}
|
||||
onClick={() => { void sendMessage() }}
|
||||
disabled={!activeTab?.isConnected || !inputValue.trim()}
|
||||
size="icon"
|
||||
className="h-10 w-10 shrink-0"
|
||||
|
||||
@@ -93,12 +93,12 @@ function PluginsPageContent() {
|
||||
|
||||
// 统一管理 WebSocket 和数据加载
|
||||
useEffect(() => {
|
||||
let ws: WebSocket | null = null
|
||||
let unsubscribeProgress: (() => Promise<void>) | null = null
|
||||
let isUnmounted = false
|
||||
|
||||
const init = async () => {
|
||||
// 1. 先连接 WebSocket(异步获取 token)
|
||||
ws = await connectPluginProgressWebSocket(
|
||||
unsubscribeProgress = await connectPluginProgressWebSocket(
|
||||
(progress) => {
|
||||
if (isUnmounted) return
|
||||
|
||||
@@ -128,29 +128,7 @@ function PluginsPageContent() {
|
||||
}
|
||||
)
|
||||
|
||||
// 2. 等待 WebSocket 连接建立
|
||||
await new Promise<void>((resolve) => {
|
||||
if (!ws) {
|
||||
resolve()
|
||||
return
|
||||
}
|
||||
|
||||
const checkConnection = () => {
|
||||
if (ws && ws.readyState === WebSocket.OPEN) {
|
||||
console.log('WebSocket connected, starting to load plugins')
|
||||
resolve()
|
||||
} else if (ws && ws.readyState === WebSocket.CLOSED) {
|
||||
console.warn('WebSocket closed before loading plugins')
|
||||
resolve()
|
||||
} else {
|
||||
setTimeout(checkConnection, 100)
|
||||
}
|
||||
}
|
||||
|
||||
checkConnection()
|
||||
})
|
||||
|
||||
// 3. 检查 Git 状态
|
||||
// 2. 检查 Git 状态
|
||||
if (!isUnmounted) {
|
||||
const statusResult = await checkGitStatus()
|
||||
if (!statusResult.success) {
|
||||
@@ -173,7 +151,7 @@ function PluginsPageContent() {
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 获取麦麦版本
|
||||
// 3. 获取麦麦版本
|
||||
if (!isUnmounted) {
|
||||
const versionResult = await getMaimaiVersion()
|
||||
if (!versionResult.success) {
|
||||
@@ -186,7 +164,7 @@ function PluginsPageContent() {
|
||||
setMaimaiVersion(versionResult.data)
|
||||
}
|
||||
}
|
||||
// 5. 加载插件列表(包含已安装信息)
|
||||
// 4. 加载插件列表(包含已安装信息)
|
||||
if (!isUnmounted) {
|
||||
try {
|
||||
setLoading(true)
|
||||
@@ -282,8 +260,8 @@ function PluginsPageContent() {
|
||||
|
||||
return () => {
|
||||
isUnmounted = true
|
||||
if (ws) {
|
||||
ws.close()
|
||||
if (unsubscribeProgress) {
|
||||
void unsubscribeProgress()
|
||||
}
|
||||
}
|
||||
}, [toast])
|
||||
|
||||
@@ -1,887 +0,0 @@
|
||||
{
|
||||
"1": [
|
||||
{
|
||||
"id": "know_1_1774770946.623486",
|
||||
"content": "备战中考",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:55:46.623486"
|
||||
},
|
||||
{
|
||||
"id": "know_1_1774771765.051286",
|
||||
"content": "性别为女性",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:09:25.051286"
|
||||
},
|
||||
{
|
||||
"id": "know_1_1774771851.333504",
|
||||
"content": "用户是I人(内向型人格)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:10:51.333504"
|
||||
},
|
||||
{
|
||||
"id": "know_1_1774771894.517183",
|
||||
"content": "用户名为小千,被他人称为“宝宝”,结合语境推测为女性",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:11:34.517183"
|
||||
},
|
||||
{
|
||||
"id": "know_1_1774771923.859455",
|
||||
"content": "小千是I人(内向型人格)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:12:03.859455"
|
||||
},
|
||||
{
|
||||
"id": "know_1_1774771993.479732",
|
||||
"content": "小千是女性",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:13:13.479732"
|
||||
},
|
||||
{
|
||||
"id": "know_1_1774772079.496335",
|
||||
"content": "用户名为小千,被他人称为“宝宝”,推测为女性或处于亲密社交语境中(注:性别非明确陈述,但基于昵称高频使用及语境,高置信度归纳为女性或女性化称呼偏好,若严格遵循“明确表达”则此项存疑。鉴于指令要求“高置信度可归纳”,且群内互动模式符合典型女性向昵称习惯,此处提取为倾向性事实)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:14:39.496335"
|
||||
},
|
||||
{
|
||||
"id": "know_1_1774773435.68612",
|
||||
"content": "用户名为小千",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:37:15.686120"
|
||||
},
|
||||
{
|
||||
"id": "know_1_1774773676.69252",
|
||||
"content": "用户自称猫娘(二次元人设)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:41:16.692520"
|
||||
}
|
||||
],
|
||||
"2": [
|
||||
{
|
||||
"id": "know_2_1774768612.298128",
|
||||
"content": "性格自信,常以“真理在我这边”自居",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:16:52.298128"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774768645.029561",
|
||||
"content": "性格自信且带有自嘲精神,喜欢用轻松调侃的方式应对他人评价",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:17:25.029561"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774771068.355999",
|
||||
"content": "喜欢用夸张、幽默或古风修辞表达观点",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:57:48.355999"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774771397.764996",
|
||||
"content": "性格幽默,喜欢使用夸张比喻和古风表达",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:03:17.764996"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774771471.03367",
|
||||
"content": "幽默风趣,喜欢使用夸张比喻和玩梗",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:04:31.033670"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774771765.052285",
|
||||
"content": "性格不孤僻,社交圈较广",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:09:25.052285"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774771851.33601",
|
||||
"content": "用户表现出社恐倾向,喜欢回避社交互动",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:10:51.336010"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774771894.520185",
|
||||
"content": "性格偏向内向(I人),有社恐倾向,喜欢回避社交压力",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:11:34.520185"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774771958.585244",
|
||||
"content": "小千是内向型人格(I人)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:12:38.585244"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774771993.481732",
|
||||
"content": "小千性格内向(I人)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:13:13.481732"
|
||||
}
|
||||
],
|
||||
"3": [
|
||||
{
|
||||
"id": "know_3_1774773676.695521",
|
||||
"content": "喜欢冰淇淋",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:41:16.695521"
|
||||
}
|
||||
],
|
||||
"4": [],
|
||||
"5": [],
|
||||
"6": [
|
||||
{
|
||||
"id": "know_6_1774768486.451792",
|
||||
"content": "正在搭建 RAG 测试集",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:14:46.451792"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774768517.122405",
|
||||
"content": "熟悉 NapCat、RAG 等技术工具及互联网梗文化",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:15:17.122405"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774769406.247087",
|
||||
"content": "喜欢动漫风格插画",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:30:06.247087"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774770487.207364",
|
||||
"content": "关注显卡硬件参数(如显存、型号)及深度学习/炼丹应用",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:48:07.207364"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774770487.209372",
|
||||
"content": "对游戏光影效果感兴趣",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:48:07.209372"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774770603.063873",
|
||||
"content": "喜欢玩《我的世界》和VRChat",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:50:03.063873"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774770654.654349",
|
||||
"content": "关注显卡硬件参数(如4090、48G显存、5090)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:50:54.654349"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774770654.655356",
|
||||
"content": "使用VRChat进行社交娱乐",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:50:54.655356"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774770734.287947",
|
||||
"content": "关注显卡硬件(如4090、3050)及AI炼丹技术",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:52:14.287947"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774770734.289944",
|
||||
"content": "玩《我的世界》并配置光影效果",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:52:14.289944"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774770734.291944",
|
||||
"content": "计划游玩VRChat",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:52:14.291944"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771033.111011",
|
||||
"content": "喜欢玩VRChat",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:57:13.111011"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771068.358999",
|
||||
"content": "关注VRChat等虚拟现实游戏及硬件性能话题",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:57:48.358999"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771233.980219",
|
||||
"content": "使用VRChat(VRC)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:00:33.980219"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771397.766996",
|
||||
"content": "对VRChat(VRC)及虚拟形象社交感兴趣",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:03:17.766996"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771471.03567",
|
||||
"content": "对VRChat等虚拟社交游戏感兴趣",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:04:31.035670"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771894.521183",
|
||||
"content": "熟悉二次元文化、动漫角色及互联网流行梗(Meme)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:11:34.521183"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771923.861534",
|
||||
"content": "小千玩CS:GO游戏",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:12:03.861534"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771958.587243",
|
||||
"content": "回声者_Echoderd喜欢玩CS:GO游戏",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:12:38.587243"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771993.483732",
|
||||
"content": "小千喜欢二次元文化及动漫游戏圈梗",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:13:13.483732"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774772079.499335",
|
||||
"content": "熟悉并喜爱二次元文化、动漫角色及互联网梗图(如阴间美学、病娇系、黑长直萌妹等风格)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:14:39.499335"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774772112.716455",
|
||||
"content": "小千关注CS:GO游戏及中考备考话题",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:15:12.716455"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774772154.873237",
|
||||
"content": "用户玩CS:GO游戏",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:15:54.873237"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774772186.438797",
|
||||
"content": "玩CS:GO游戏",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:16:26.438797"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774772730.867535",
|
||||
"content": "熟悉《我的青春恋爱物语果然有问题》及二次元表情包文化",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:25:30.867535"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774773338.849271",
|
||||
"content": "熟悉《原神》等二次元游戏及网络梗文化",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:35:38.849271"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774773371.406209",
|
||||
"content": "关注高分屏字体显示效果",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:36:11.406209"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774773401.48921",
|
||||
"content": "熟悉电脑显示技术(如高分屏字体选择)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:36:41.489210"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774773435.688119",
|
||||
"content": "关注高分屏显示效果与字体选择(无衬线/衬线体)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:37:15.688119"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774773608.256103",
|
||||
"content": "关注屏幕字体与分辨率(无衬线/有衬线)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:40:08.256103"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774773645.671546",
|
||||
"content": "关注屏幕分辨率与字体显示效果(高分屏/无衬线体)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:40:45.671546"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774773676.698035",
|
||||
"content": "关注字体设计(无衬线体)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:41:16.698035"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774773740.83822",
|
||||
"content": "喜欢二次元文化及 VTuber 风格内容",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:42:20.838220"
|
||||
}
|
||||
],
|
||||
"7": [
|
||||
{
|
||||
"id": "know_7_1774768517.120403",
|
||||
"content": "从事 RAG 测试集搭建或相关技术工作",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:15:17.120403"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774768573.741823",
|
||||
"content": "从事 RAG(检索增强生成)测试集搭建相关工作",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:16:13.741823"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774770603.062873",
|
||||
"content": "备战中考",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:50:03.062873"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774771471.036668",
|
||||
"content": "正在备战中考的学生",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:04:31.036668"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774771923.862535",
|
||||
"content": "小千正在备战中考",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:12:03.862535"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774771958.588749",
|
||||
"content": "回声者_Echoderd正在备战中考",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:12:38.588749"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774772112.714455",
|
||||
"content": "小千使用AI模型进行对话",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:15:12.714455"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774772154.870238",
|
||||
"content": "用户正在备战中考",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:15:54.870238"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774773185.194069",
|
||||
"content": "使用 NapCat 框架",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:33:05.194069"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774773338.851275",
|
||||
"content": "使用 NapCat 框架,具备技术平台认知能力",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:35:38.851275"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774773371.403696",
|
||||
"content": "熟悉 NapCat 框架",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:36:11.403696"
|
||||
}
|
||||
],
|
||||
"8": [
|
||||
{
|
||||
"id": "know_8_1774770946.624486",
|
||||
"content": "日常逛游戏地图",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:55:46.624486"
|
||||
},
|
||||
{
|
||||
"id": "know_8_1774771397.769034",
|
||||
"content": "备考中考期间仍保持日常游戏娱乐习惯",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:03:17.769034"
|
||||
},
|
||||
{
|
||||
"id": "know_8_1774771851.338018",
|
||||
"content": "用户有备考中考的学习任务",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:10:51.338018"
|
||||
},
|
||||
{
|
||||
"id": "know_8_1774771894.523189",
|
||||
"content": "备考中(备战中考)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:11:34.523189"
|
||||
},
|
||||
{
|
||||
"id": "know_8_1774771993.484733",
|
||||
"content": "小千有打CS:GO的游戏习惯",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:13:13.484733"
|
||||
},
|
||||
{
|
||||
"id": "know_8_1774772079.501334",
|
||||
"content": "有在高压环境下(如中考前)进行游戏娱乐(CS:GO)的习惯,自称或认同“摆烂”的生活态度",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:14:39.501334"
|
||||
},
|
||||
{
|
||||
"id": "know_8_1774772154.875743",
|
||||
"content": "用户在备考期间有打游戏摸鱼的习惯",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:15:54.875743"
|
||||
},
|
||||
{
|
||||
"id": "know_8_1774773435.690121",
|
||||
"content": "习惯使用表情包表达情绪或进行网络互动",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:37:15.690121"
|
||||
},
|
||||
{
|
||||
"id": "know_8_1774773676.701034",
|
||||
"content": "备战中考",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:41:16.701034"
|
||||
}
|
||||
],
|
||||
"9": [],
|
||||
"10": [
|
||||
{
|
||||
"id": "know_10_1774768486.452792",
|
||||
"content": "沟通风格带有调侃和自信,习惯用反问句表达观点",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:14:46.452792"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774768517.121403",
|
||||
"content": "沟通风格带有较强的好胜心和防御性,习惯用反问和调侃回应质疑",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:15:17.121403"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774768573.742824",
|
||||
"content": "沟通风格幽默,擅长使用逻辑闭环和反问句式进行辩论或调侃",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:16:13.742824"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774768612.299126",
|
||||
"content": "沟通风格幽默风趣,擅长使用网络梗和表情包互动",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:16:52.299126"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774768612.299845",
|
||||
"content": "偶尔会文绉绉地表达(自称“文青病犯了”),但能迅速切换回口语化",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:16:52.299845"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774768645.028561",
|
||||
"content": "沟通风格幽默风趣,偶尔会文青病发作使用古风表达",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:17:25.028561"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774769406.249584",
|
||||
"content": "沟通中常使用文言文或半文言表达",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:30:06.249584"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774769406.251097",
|
||||
"content": "习惯用反问句和夸张语气进行互动",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:30:06.251097"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774770487.211056",
|
||||
"content": "沟通风格幽默,常使用网络梗和夸张表达",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:48:07.211056"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774771471.038677",
|
||||
"content": "沟通风格轻松随意,善于接话和调侃",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:04:31.038677"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774771765.053285",
|
||||
"content": "沟通风格活泼,喜欢使用语气词和表情符号撒娇",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:09:25.053285"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774772079.503333",
|
||||
"content": "沟通风格幽默调侃,擅长用反话(如“烦到了”)和夸张修辞(如“耳朵起茧子”、“要报警了”)表达情绪",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:14:39.503333"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774773338.853274",
|
||||
"content": "沟通风格幽默风趣,擅长玩梗与自嘲",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:35:38.853274"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774773371.408719",
|
||||
"content": "喜欢用幽默调侃的方式回应他人",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:36:11.408719"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774773401.491209",
|
||||
"content": "沟通风格幽默风趣,擅长玩梗和角色扮演",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:36:41.491209"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774773435.693121",
|
||||
"content": "沟通风格幽默、喜欢玩梗和自嘲,擅长接话茬",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:37:15.693121"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774773532.488374",
|
||||
"content": "沟通风格幽默,喜欢使用网络梗和表情包活跃气氛",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:38:52.488374"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774773532.490959",
|
||||
"content": "在争论中倾向于据理力争,并自嘲或调侃对方阅读理解能力",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:38:52.490959"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774773569.709356",
|
||||
"content": "喜欢用幽默、夸张和自嘲的方式活跃气氛",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:39:29.709356"
|
||||
}
|
||||
],
|
||||
"11": [
|
||||
{
|
||||
"id": "know_11_1774771068.360999",
|
||||
"content": "乐于接受并学习新的技术技巧(如加速器用法)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:57:48.360999"
|
||||
}
|
||||
],
|
||||
"12": [
|
||||
{
|
||||
"id": "know_12_1774770654.657355",
|
||||
"content": "面对网络延迟问题倾向于寻找加速器解决方案",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:50:54.657355"
|
||||
},
|
||||
{
|
||||
"id": "know_12_1774773185.196068",
|
||||
"content": "备战中考",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:33:05.196068"
|
||||
},
|
||||
{
|
||||
"id": "know_12_1774773740.836223",
|
||||
"content": "面对压力或冲突时,倾向于通过撒娇、耍赖和寻求盟友支持来应对",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:42:20.836223"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -13,12 +13,13 @@ import pytest
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
Envelope,
|
||||
InspectPluginConfigPayload,
|
||||
MessageType,
|
||||
RegisterPluginPayload,
|
||||
ValidatePluginConfigPayload,
|
||||
)
|
||||
from src.plugin_runtime.runner.runner_main import PluginRunner
|
||||
from src.webui.routers.plugin.config_routes import update_plugin_config
|
||||
from src.webui.routers.plugin.config_routes import get_plugin_config, get_plugin_config_schema, update_plugin_config
|
||||
from src.webui.routers.plugin.schemas import UpdatePluginConfigRequest
|
||||
|
||||
|
||||
@@ -56,6 +57,61 @@ class _DemoConfigPlugin:
|
||||
|
||||
self.received_config = config
|
||||
|
||||
def get_default_config(self) -> Dict[str, Any]:
|
||||
"""返回测试插件的默认配置。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 默认配置字典。
|
||||
"""
|
||||
|
||||
return {"plugin": {"enabled": True, "retry_count": 3}}
|
||||
|
||||
def get_webui_config_schema(
|
||||
self,
|
||||
*,
|
||||
plugin_id: str = "",
|
||||
plugin_name: str = "",
|
||||
plugin_version: str = "",
|
||||
plugin_description: str = "",
|
||||
plugin_author: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
"""返回测试插件的 WebUI 配置 Schema。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
plugin_name: 插件名称。
|
||||
plugin_version: 插件版本。
|
||||
plugin_description: 插件描述。
|
||||
plugin_author: 插件作者。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 测试配置 Schema。
|
||||
"""
|
||||
|
||||
del plugin_name, plugin_description, plugin_author
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_info": {
|
||||
"name": "Demo",
|
||||
"version": plugin_version,
|
||||
"description": "",
|
||||
"author": "",
|
||||
},
|
||||
"sections": {
|
||||
"plugin": {
|
||||
"fields": {
|
||||
"enabled": {
|
||||
"type": "boolean",
|
||||
"label": "启用",
|
||||
"default": True,
|
||||
"ui_type": "switch",
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"layout": {"type": "auto", "tabs": []},
|
||||
}
|
||||
|
||||
|
||||
class _StrictConfigPlugin:
|
||||
"""用于测试配置校验错误的伪插件。"""
|
||||
@@ -173,6 +229,63 @@ async def test_runner_validate_plugin_config_handler_returns_normalized_config(m
|
||||
assert response.payload["normalized_config"] == {"plugin": {"enabled": False, "retry_count": 3}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_inspect_plugin_config_handler_supports_unloaded_plugin(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Runner 应支持对未加载插件执行冷检查。"""
|
||||
|
||||
plugin = _DemoConfigPlugin()
|
||||
runner = PluginRunner(
|
||||
host_address="ipc://unused",
|
||||
session_token="session-token",
|
||||
plugin_dirs=[],
|
||||
)
|
||||
meta = SimpleNamespace(
|
||||
plugin_id="demo.plugin",
|
||||
plugin_dir="/tmp/demo-plugin",
|
||||
instance=plugin,
|
||||
manifest=SimpleNamespace(
|
||||
name="Demo",
|
||||
description="",
|
||||
author=SimpleNamespace(name="tester"),
|
||||
),
|
||||
version="1.0.0",
|
||||
)
|
||||
purged_plugins: list[tuple[str, str]] = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
runner,
|
||||
"_resolve_plugin_meta_for_config_request",
|
||||
lambda plugin_id: (meta, True, None) if plugin_id == "demo.plugin" else (None, False, "not-found"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
runner._loader,
|
||||
"purge_plugin_modules",
|
||||
lambda plugin_id, plugin_dir: purged_plugins.append((plugin_id, plugin_dir)),
|
||||
)
|
||||
|
||||
envelope = Envelope(
|
||||
request_id=1,
|
||||
message_type=MessageType.REQUEST,
|
||||
method="plugin.inspect_config",
|
||||
plugin_id="demo.plugin",
|
||||
payload=InspectPluginConfigPayload(
|
||||
config_data={"plugin": {"enabled": False}},
|
||||
use_provided_config=True,
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
response = await runner._handle_inspect_plugin_config(envelope)
|
||||
|
||||
assert response.error is None
|
||||
assert response.payload["success"] is True
|
||||
assert response.payload["enabled"] is False
|
||||
assert response.payload["normalized_config"] == {"plugin": {"enabled": False, "retry_count": 3}}
|
||||
assert response.payload["default_config"] == {"plugin": {"enabled": True, "retry_count": 3}}
|
||||
assert purged_plugins == [("demo.plugin", "/tmp/demo-plugin")]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_validate_plugin_config_handler_returns_error_on_invalid_config(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
@@ -251,3 +364,73 @@ async def test_update_plugin_config_prefers_runtime_validation(
|
||||
with config_path.open("rb") as handle:
|
||||
saved_config = tomllib.load(handle)
|
||||
assert saved_config == {"plugin": {"enabled": False, "retry_count": 3}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webui_config_endpoints_use_runtime_inspection_for_unloaded_plugin(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""WebUI 在插件未加载时也应从代码定义返回配置与 Schema。"""
|
||||
|
||||
async def _mock_inspect_plugin_config(
|
||||
plugin_id: str,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
use_provided_config: bool = False,
|
||||
) -> SimpleNamespace | None:
|
||||
"""返回运行时冷检查结果。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
config_data: 可选配置。
|
||||
use_provided_config: 是否使用传入配置。
|
||||
|
||||
Returns:
|
||||
SimpleNamespace | None: 冷检查结果。
|
||||
"""
|
||||
|
||||
del config_data, use_provided_config
|
||||
if plugin_id != "demo.plugin":
|
||||
return None
|
||||
return SimpleNamespace(
|
||||
config_schema={
|
||||
"plugin_id": "demo.plugin",
|
||||
"plugin_info": {
|
||||
"name": "Demo",
|
||||
"version": "1.0.0",
|
||||
"description": "",
|
||||
"author": "",
|
||||
},
|
||||
"sections": {"plugin": {"fields": {}}},
|
||||
"layout": {"type": "auto", "tabs": []},
|
||||
},
|
||||
normalized_config={"plugin": {"enabled": True, "retry_count": 3}},
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
fake_runtime_manager = SimpleNamespace(inspect_plugin_config=_mock_inspect_plugin_config)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"src.webui.routers.plugin.config_routes.require_plugin_token",
|
||||
lambda session: session or "session-token",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"src.webui.routers.plugin.config_routes.find_plugin_path_by_id",
|
||||
lambda plugin_id: tmp_path if plugin_id == "demo.plugin" else None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"src.plugin_runtime.integration.get_plugin_runtime_manager",
|
||||
lambda: fake_runtime_manager,
|
||||
)
|
||||
|
||||
schema_response = await get_plugin_config_schema("demo.plugin", maibot_session="session-token")
|
||||
config_response = await get_plugin_config("demo.plugin", maibot_session="session-token")
|
||||
|
||||
assert schema_response["success"] is True
|
||||
assert schema_response["schema"]["plugin_id"] == "demo.plugin"
|
||||
assert config_response == {
|
||||
"success": True,
|
||||
"config": {"plugin": {"enabled": True, "retry_count": 3}},
|
||||
"message": "配置文件不存在,已返回默认配置",
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Sequence
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
@@ -1405,6 +1405,57 @@ class TestComponentRegistry:
|
||||
assert warnings
|
||||
assert "plugin_a.broken" in warnings[0]
|
||||
|
||||
def test_register_hook_handler_rejects_unknown_hook(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistrationError, ComponentRegistry
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpecRegistry
|
||||
|
||||
reg = ComponentRegistry(hook_spec_registry=HookSpecRegistry())
|
||||
|
||||
with pytest.raises(ComponentRegistrationError, match="未注册的 Hook"):
|
||||
reg.register_component(
|
||||
"broken_hook",
|
||||
"hook_handler",
|
||||
"plugin_a",
|
||||
{
|
||||
"hook": "chat.receive.unknown",
|
||||
"mode": "blocking",
|
||||
},
|
||||
)
|
||||
|
||||
def test_register_plugin_components_is_atomic_when_hook_invalid(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistrationError, ComponentRegistry
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
|
||||
hook_spec_registry = HookSpecRegistry()
|
||||
hook_spec_registry.register_hook_spec(HookSpec(name="chat.receive.before_process"))
|
||||
reg = ComponentRegistry(hook_spec_registry=hook_spec_registry)
|
||||
reg.register_plugin_components(
|
||||
"plugin_a",
|
||||
[
|
||||
{"name": "cmd_old", "component_type": "command", "metadata": {"command_pattern": r"^/old"}},
|
||||
],
|
||||
)
|
||||
|
||||
with pytest.raises(ComponentRegistrationError, match="未注册的 Hook"):
|
||||
reg.register_plugin_components(
|
||||
"plugin_a",
|
||||
[
|
||||
{
|
||||
"name": "hook_ok",
|
||||
"component_type": "hook_handler",
|
||||
"metadata": {"hook": "chat.receive.before_process", "mode": "blocking"},
|
||||
},
|
||||
{
|
||||
"name": "hook_bad",
|
||||
"component_type": "hook_handler",
|
||||
"metadata": {"hook": "chat.receive.missing", "mode": "blocking"},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
assert reg.get_component("plugin_a.cmd_old") is not None
|
||||
assert reg.get_component("plugin_a.hook_ok") is None
|
||||
|
||||
def test_query_by_type(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
@@ -2142,6 +2193,18 @@ class TestPluginRuntimeHookEntry:
|
||||
assert result.kwargs["session_id"] == "s-1"
|
||||
assert ("b1", "builtin_guard") in call_log
|
||||
|
||||
def test_manager_lists_builtin_hook_specs(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""PluginRuntimeManager 应暴露内置 Hook 规格清单。"""
|
||||
|
||||
_ComponentRegistry, PluginRuntimeManager = self._import_manager_modules(monkeypatch)
|
||||
|
||||
manager = PluginRuntimeManager()
|
||||
hook_names = {spec.name for spec in manager.list_hook_specs()}
|
||||
|
||||
assert "chat.receive.before_process" in hook_names
|
||||
assert "send_service.before_send" in hook_names
|
||||
assert "maisaka.planner.after_response" in hook_names
|
||||
|
||||
|
||||
class TestRPCServer:
|
||||
"""RPC Server 代际保护测试"""
|
||||
@@ -2974,6 +3037,16 @@ class TestIntegration:
|
||||
self._registered_plugins = {plugin_id: object() for plugin_id in plugins}
|
||||
self.config_updates = []
|
||||
|
||||
async def inspect_plugin_config(
|
||||
self,
|
||||
plugin_id: str,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
use_provided_config: bool = False,
|
||||
) -> SimpleNamespace:
|
||||
"""返回测试用的配置解析结果。"""
|
||||
del config_data, use_provided_config
|
||||
return SimpleNamespace(enabled=True, normalized_config={"enabled": True}, plugin_id=plugin_id)
|
||||
|
||||
async def notify_plugin_config_updated(
|
||||
self,
|
||||
plugin_id,
|
||||
@@ -2997,6 +3070,110 @@ class TestIntegration:
|
||||
assert manager._builtin_supervisor.config_updates == [("test.alpha", {"enabled": True}, "", "self")]
|
||||
assert manager._third_party_supervisor.config_updates == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_plugin_config_changes_loads_unloaded_enabled_plugin(self, monkeypatch, tmp_path):
|
||||
from src.plugin_runtime import integration as integration_module
|
||||
from src.config.file_watcher import FileChange
|
||||
import json
|
||||
|
||||
thirdparty_root = tmp_path / "plugins"
|
||||
alpha_dir = thirdparty_root / "alpha"
|
||||
alpha_dir.mkdir(parents=True)
|
||||
(alpha_dir / "config.toml").write_text("[plugin]\nenabled = true\n", encoding="utf-8")
|
||||
(alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
|
||||
(alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8")
|
||||
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
class FakeSupervisor:
|
||||
def __init__(self, plugin_dirs):
|
||||
self._plugin_dirs = plugin_dirs
|
||||
self._registered_plugins = {}
|
||||
|
||||
async def inspect_plugin_config(
|
||||
self,
|
||||
plugin_id: str,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
use_provided_config: bool = False,
|
||||
) -> SimpleNamespace:
|
||||
"""返回测试用的启用配置快照。"""
|
||||
del config_data, use_provided_config
|
||||
return SimpleNamespace(enabled=True, normalized_config={"plugin": {"enabled": True}}, plugin_id=plugin_id)
|
||||
|
||||
manager = integration_module.PluginRuntimeManager()
|
||||
manager._started = True
|
||||
manager._third_party_supervisor = FakeSupervisor([thirdparty_root])
|
||||
|
||||
load_calls = []
|
||||
|
||||
async def fake_load_plugin_globally(plugin_id: str, reason: str = "manual") -> bool:
|
||||
"""记录自动加载调用。"""
|
||||
load_calls.append((plugin_id, reason))
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(manager, "load_plugin_globally", fake_load_plugin_globally)
|
||||
|
||||
await manager._handle_plugin_config_changes(
|
||||
"test.alpha",
|
||||
[FileChange(change_type=1, path=alpha_dir / "config.toml")],
|
||||
)
|
||||
|
||||
assert load_calls == [("test.alpha", "config_enabled")]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_plugin_config_changes_unloads_loaded_disabled_plugin(self, monkeypatch, tmp_path):
|
||||
from src.plugin_runtime import integration as integration_module
|
||||
from src.config.file_watcher import FileChange
|
||||
import json
|
||||
|
||||
builtin_root = tmp_path / "src" / "plugins" / "built_in"
|
||||
alpha_dir = builtin_root / "alpha"
|
||||
alpha_dir.mkdir(parents=True)
|
||||
(alpha_dir / "config.toml").write_text("[plugin]\nenabled = false\n", encoding="utf-8")
|
||||
(alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
|
||||
(alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8")
|
||||
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
class FakeSupervisor:
|
||||
def __init__(self, plugin_dirs, plugins):
|
||||
self._plugin_dirs = plugin_dirs
|
||||
self._registered_plugins = {plugin_id: object() for plugin_id in plugins}
|
||||
|
||||
async def inspect_plugin_config(
|
||||
self,
|
||||
plugin_id: str,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
use_provided_config: bool = False,
|
||||
) -> SimpleNamespace:
|
||||
"""返回测试用的禁用配置快照。"""
|
||||
del config_data, use_provided_config
|
||||
return SimpleNamespace(
|
||||
enabled=False,
|
||||
normalized_config={"plugin": {"enabled": False}},
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
manager = integration_module.PluginRuntimeManager()
|
||||
manager._started = True
|
||||
manager._builtin_supervisor = FakeSupervisor([builtin_root], ["test.alpha"])
|
||||
|
||||
reload_calls = []
|
||||
|
||||
async def fake_reload_plugins_globally(plugin_ids: Sequence[str], reason: str = "manual") -> bool:
|
||||
"""记录自动卸载调用。"""
|
||||
reload_calls.append((list(plugin_ids), reason))
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(manager, "reload_plugins_globally", fake_reload_plugins_globally)
|
||||
|
||||
await manager._handle_plugin_config_changes(
|
||||
"test.alpha",
|
||||
[FileChange(change_type=1, path=alpha_dir / "config.toml")],
|
||||
)
|
||||
|
||||
assert reload_calls == [(["test.alpha"], "config_disabled")]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_main_config_reload_only_notifies_subscribers(self, monkeypatch):
|
||||
from src.plugin_runtime import integration as integration_module
|
||||
@@ -3108,6 +3285,55 @@ class TestIntegration:
|
||||
subscription["paths"][0] for subscription in manager._plugin_file_watcher.subscriptions
|
||||
} == {alpha_dir / "config.toml", beta_dir / "config.toml"}
|
||||
|
||||
def test_refresh_plugin_config_watch_subscriptions_includes_unloaded_plugins(self, tmp_path):
|
||||
from src.plugin_runtime import integration as integration_module
|
||||
import json
|
||||
|
||||
thirdparty_root = tmp_path / "plugins"
|
||||
alpha_dir = thirdparty_root / "alpha"
|
||||
beta_dir = thirdparty_root / "beta"
|
||||
alpha_dir.mkdir(parents=True)
|
||||
beta_dir.mkdir(parents=True)
|
||||
(alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
|
||||
(beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
|
||||
(alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8")
|
||||
(beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8")
|
||||
|
||||
class FakeWatcher:
|
||||
def __init__(self):
|
||||
self.subscriptions = []
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
callback: Any,
|
||||
*,
|
||||
paths: Optional[Sequence[Path]] = None,
|
||||
change_types: Any = None,
|
||||
) -> str:
|
||||
"""记录新的监听订阅。"""
|
||||
del callback, change_types
|
||||
subscription_id = f"sub-{len(self.subscriptions) + 1}"
|
||||
self.subscriptions.append({"id": subscription_id, "paths": tuple(paths or ())})
|
||||
return subscription_id
|
||||
|
||||
def unsubscribe(self, subscription_id: str) -> bool:
|
||||
"""兼容 watcher 取消订阅接口。"""
|
||||
del subscription_id
|
||||
return True
|
||||
|
||||
class FakeSupervisor:
|
||||
def __init__(self, plugin_dirs, plugins):
|
||||
self._plugin_dirs = plugin_dirs
|
||||
self._registered_plugins = {plugin_id: object() for plugin_id in plugins}
|
||||
|
||||
manager = integration_module.PluginRuntimeManager()
|
||||
manager._plugin_file_watcher = FakeWatcher()
|
||||
manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["test.alpha"])
|
||||
|
||||
manager._refresh_plugin_config_watch_subscriptions()
|
||||
|
||||
assert set(manager._plugin_config_watcher_subscriptions.keys()) == {"test.alpha", "test.beta"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_component_reload_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch):
|
||||
from src.plugin_runtime import integration as integration_module
|
||||
|
||||
136
pytests/test_runtime_business_hooks.py
Normal file
136
pytests/test_runtime_business_hooks.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""业务命名 Hook 集成测试。"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
# 确保项目根目录在 sys.path 中
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
# SDK 包路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "maibot-plugin-sdk"))
|
||||
|
||||
|
||||
class _FakeHookManager:
|
||||
"""用于业务 Hook 测试的最小运行时管理器。"""
|
||||
|
||||
def __init__(self, responses: dict[str, SimpleNamespace]) -> None:
|
||||
"""初始化测试管理器。
|
||||
|
||||
Args:
|
||||
responses: 按 Hook 名称预设的返回结果映射。
|
||||
"""
|
||||
|
||||
self._responses = responses
|
||||
self.calls: list[tuple[str, dict[str, Any]]] = []
|
||||
|
||||
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> SimpleNamespace:
|
||||
"""模拟调用运行时命名 Hook。
|
||||
|
||||
Args:
|
||||
hook_name: 目标 Hook 名称。
|
||||
**kwargs: 传入 Hook 的参数。
|
||||
|
||||
Returns:
|
||||
SimpleNamespace: 预设的 Hook 返回结果。
|
||||
"""
|
||||
|
||||
self.calls.append((hook_name, dict(kwargs)))
|
||||
return self._responses.get(hook_name, SimpleNamespace(kwargs=dict(kwargs), aborted=False))
|
||||
|
||||
|
||||
def test_builtin_hook_catalog_includes_new_business_hooks(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""内置 Hook 目录应包含三个业务系统新增的 Hook。"""
|
||||
|
||||
monkeypatch.setattr(sys, "exit", lambda code=0: None)
|
||||
from src.plugin_runtime.hook_catalog import register_builtin_hook_specs
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpecRegistry
|
||||
|
||||
registry = HookSpecRegistry()
|
||||
hook_names = {spec.name for spec in register_builtin_hook_specs(registry)}
|
||||
|
||||
assert "emoji.maisaka.before_select" in hook_names
|
||||
assert "emoji.register.after_build_emotion" in hook_names
|
||||
assert "jargon.extract.before_persist" in hook_names
|
||||
assert "jargon.query.after_search" in hook_names
|
||||
assert "expression.select.before_select" in hook_names
|
||||
assert "expression.learn.before_upsert" in hook_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_emoji_for_maisaka_can_be_aborted_by_hook(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""表情包系统应允许在选择前被 Hook 中止。"""
|
||||
|
||||
from src.chat.emoji_system import maisaka_tool
|
||||
|
||||
fake_manager = _FakeHookManager(
|
||||
{
|
||||
"emoji.maisaka.before_select": SimpleNamespace(
|
||||
kwargs={"abort_message": "插件阻止了表情发送。"},
|
||||
aborted=True,
|
||||
)
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(maisaka_tool, "_get_runtime_manager", lambda: fake_manager)
|
||||
|
||||
result = await maisaka_tool.send_emoji_for_maisaka(stream_id="stream-1", requested_emotion="开心")
|
||||
|
||||
assert result.success is False
|
||||
assert result.message == "插件阻止了表情发送。"
|
||||
assert fake_manager.calls[0][0] == "emoji.maisaka.before_select"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jargon_extract_can_be_aborted_before_persist(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""黑话提取结果应允许在写库前被 Hook 中止。"""
|
||||
|
||||
from src.learners.jargon_miner import JargonMiner
|
||||
|
||||
fake_manager = _FakeHookManager(
|
||||
{
|
||||
"jargon.extract.before_persist": SimpleNamespace(
|
||||
kwargs={"entries": []},
|
||||
aborted=True,
|
||||
)
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(JargonMiner, "_get_runtime_manager", staticmethod(lambda: fake_manager))
|
||||
|
||||
miner = JargonMiner(session_id="session-1", session_name="测试会话")
|
||||
await miner.process_extracted_entries(
|
||||
[{"content": "yyds", "raw_content": {"[1] yyds 太强了"}}],
|
||||
)
|
||||
|
||||
assert fake_manager.calls[0][0] == "jargon.extract.before_persist"
|
||||
assert fake_manager.calls[0][1]["session_id"] == "session-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expression_selection_can_be_aborted_by_hook(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""表达方式选择流程应允许在开始前被 Hook 中止。"""
|
||||
|
||||
from src.learners.expression_selector import ExpressionSelector
|
||||
|
||||
fake_manager = _FakeHookManager(
|
||||
{
|
||||
"expression.select.before_select": SimpleNamespace(
|
||||
kwargs={},
|
||||
aborted=True,
|
||||
)
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(ExpressionSelector, "_get_runtime_manager", staticmethod(lambda: fake_manager))
|
||||
monkeypatch.setattr(ExpressionSelector, "can_use_expression_for_chat", lambda self, chat_id: True)
|
||||
|
||||
selector = ExpressionSelector()
|
||||
selected_expressions, selected_ids = await selector.select_suitable_expressions(
|
||||
chat_id="session-1",
|
||||
chat_info="用户刚刚发来一条消息。",
|
||||
)
|
||||
|
||||
assert selected_expressions == []
|
||||
assert selected_ids == []
|
||||
assert fake_manager.calls[0][0] == "expression.select.before_select"
|
||||
@@ -1,25 +1,28 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import heapq
|
||||
import Levenshtein
|
||||
import random
|
||||
import re
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
|
||||
import Levenshtein
|
||||
|
||||
from src.common.data_models.image_data_model import MaiEmoji
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.common.database.database import get_db_session, get_db_session_manual
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMImageOptions
|
||||
from src.common.database.database import get_db_session, get_db_session_manual
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("emoji")
|
||||
@@ -33,6 +36,171 @@ EMOJI_REGISTERED_DIR = DATA_DIR / "emoji_registered" # 已注册的表情包注
|
||||
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
|
||||
|
||||
|
||||
def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""注册表情包系统内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
emoji_schema = {
|
||||
"type": "object",
|
||||
"description": "当前表情包的序列化信息,主要包含 file_hash、description、emotions 等字段。",
|
||||
}
|
||||
string_array_schema = {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
return registry.register_hook_specs(
|
||||
[
|
||||
HookSpec(
|
||||
name="emoji.maisaka.before_select",
|
||||
description="Maisaka 表情发送工具选择表情前触发,可改写情绪、上下文和采样参数,或中止本次选择。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"stream_id": {"type": "string", "description": "目标会话 ID。"},
|
||||
"requested_emotion": {"type": "string", "description": "请求的目标情绪标签。"},
|
||||
"reasoning": {"type": "string", "description": "本次发送表情的推理理由。"},
|
||||
"context_texts": {
|
||||
**string_array_schema,
|
||||
"description": "最近聊天上下文文本列表。",
|
||||
},
|
||||
"sample_size": {"type": "integer", "description": "候选表情采样数量。"},
|
||||
"abort_message": {
|
||||
"type": "string",
|
||||
"description": "当 Hook 主动中止时可附带的失败提示。",
|
||||
},
|
||||
},
|
||||
required=["stream_id", "requested_emotion", "reasoning", "context_texts", "sample_size"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="emoji.maisaka.after_select",
|
||||
description="Maisaka 已选出表情后触发,可替换选中的表情哈希、补充匹配情绪,或中止发送。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"stream_id": {"type": "string", "description": "目标会话 ID。"},
|
||||
"requested_emotion": {"type": "string", "description": "请求的目标情绪标签。"},
|
||||
"reasoning": {"type": "string", "description": "本次发送表情的推理理由。"},
|
||||
"context_texts": {
|
||||
**string_array_schema,
|
||||
"description": "最近聊天上下文文本列表。",
|
||||
},
|
||||
"sample_size": {"type": "integer", "description": "候选表情采样数量。"},
|
||||
"selected_emoji": emoji_schema,
|
||||
"selected_emoji_hash": {"type": "string", "description": "选中的表情哈希。"},
|
||||
"matched_emotion": {"type": "string", "description": "最终命中的情绪标签。"},
|
||||
"abort_message": {
|
||||
"type": "string",
|
||||
"description": "当 Hook 主动中止时可附带的失败提示。",
|
||||
},
|
||||
},
|
||||
required=[
|
||||
"stream_id",
|
||||
"requested_emotion",
|
||||
"reasoning",
|
||||
"context_texts",
|
||||
"sample_size",
|
||||
"matched_emotion",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="emoji.register.after_build_description",
|
||||
description="表情包描述生成并通过内容审查后触发,可改写描述文本或拒绝本次注册。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"emoji": emoji_schema,
|
||||
"description": {"type": "string", "description": "当前生成出的表情包描述。"},
|
||||
"image_format": {"type": "string", "description": "表情图片格式。"},
|
||||
},
|
||||
required=["emoji", "description", "image_format"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="emoji.register.after_build_emotion",
|
||||
description="表情包情绪标签生成完成后触发,可改写标签列表或拒绝本次注册。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"emoji": emoji_schema,
|
||||
"description": {"type": "string", "description": "当前表情包描述。"},
|
||||
"emotions": {
|
||||
**string_array_schema,
|
||||
"description": "当前生成出的情绪标签列表。",
|
||||
},
|
||||
},
|
||||
required=["emoji", "description", "emotions"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
|
||||
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, Any]]:
|
||||
"""将表情包对象序列化为 Hook 可传输载荷。
|
||||
|
||||
Args:
|
||||
emoji: 待序列化的表情包对象。
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 序列化后的字典;当表情为空时返回 ``None``。
|
||||
"""
|
||||
|
||||
if emoji is None:
|
||||
return None
|
||||
|
||||
return {
|
||||
"file_hash": str(emoji.file_hash or "").strip(),
|
||||
"file_name": emoji.file_name,
|
||||
"full_path": str(emoji.full_path),
|
||||
"description": emoji.description,
|
||||
"emotions": [str(item).strip() for item in emoji.emotion if str(item).strip()],
|
||||
"query_count": int(emoji.query_count),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_string_list(raw_values: Any) -> List[str]:
|
||||
"""将任意列表值规范化为字符串列表。
|
||||
|
||||
Args:
|
||||
raw_values: 待规范化的原始值。
|
||||
|
||||
Returns:
|
||||
List[str]: 去空白后的字符串列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_values, list):
|
||||
return []
|
||||
return [str(item).strip() for item in raw_values if str(item).strip()]
|
||||
|
||||
|
||||
def _ensure_directories() -> None:
|
||||
"""确保表情包相关目录存在"""
|
||||
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
|
||||
@@ -642,6 +810,22 @@ class EmojiManager:
|
||||
if "否" in llm_response:
|
||||
logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
hook_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.register.after_build_description",
|
||||
emoji=_serialize_emoji_for_hook(target_emoji),
|
||||
description=description,
|
||||
image_format=image_format,
|
||||
)
|
||||
if hook_result.aborted:
|
||||
logger.info(f"[构建描述] 表情包描述被 Hook 中止注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
normalized_description = str(hook_result.kwargs.get("description", description) or "").strip()
|
||||
if not normalized_description:
|
||||
logger.warning(f"[构建描述] Hook 返回空描述,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
description = normalized_description
|
||||
target_emoji.description = description
|
||||
logger.info(f"[构建描述] 成功为表情包构建描述: {target_emoji.description}")
|
||||
return True, target_emoji
|
||||
@@ -687,6 +871,23 @@ class EmojiManager:
|
||||
elif len(emotions) > 2:
|
||||
emotions = random.sample(emotions, 2)
|
||||
|
||||
hook_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.register.after_build_emotion",
|
||||
emoji=_serialize_emoji_for_hook(target_emoji),
|
||||
description=target_emoji.description,
|
||||
emotions=list(emotions),
|
||||
)
|
||||
if hook_result.aborted:
|
||||
logger.info(f"[构建情感标签] 表情包情感标签被 Hook 中止注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
raw_emotions = hook_result.kwargs.get("emotions")
|
||||
if raw_emotions is not None:
|
||||
emotions = _normalize_string_list(raw_emotions)
|
||||
if not emotions:
|
||||
logger.warning(f"[构建情感标签] Hook 返回空情绪标签,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
logger.info(f"[构建情感标签] 成功为表情包构建情感标签: {','.join(emotions)}")
|
||||
target_emoji.emotion = emotions
|
||||
return True, target_emoji
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Maisaka 表情工具内置能力。"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Sequence
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
import random
|
||||
|
||||
@@ -11,7 +11,7 @@ from src.common.logger import get_logger
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.services import send_service
|
||||
|
||||
from .emoji_manager import emoji_manager, emoji_manager_emotion_judge_llm
|
||||
from .emoji_manager import _serialize_emoji_for_hook, emoji_manager, emoji_manager_emotion_judge_llm
|
||||
|
||||
logger = get_logger("emoji_maisaka_tool")
|
||||
|
||||
@@ -29,6 +29,76 @@ class MaisakaEmojiSendResult:
|
||||
matched_emotion: str = ""
|
||||
|
||||
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
|
||||
def _coerce_positive_int(value: Any, default: int) -> int:
|
||||
"""将任意值安全转换为正整数。
|
||||
|
||||
Args:
|
||||
value: 待转换的值。
|
||||
default: 转换失败时使用的默认值。
|
||||
|
||||
Returns:
|
||||
int: 规范化后的正整数。
|
||||
"""
|
||||
|
||||
try:
|
||||
normalized_value = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
return normalized_value if normalized_value > 0 else default
|
||||
|
||||
|
||||
def _normalize_context_texts(context_texts: Sequence[str] | None) -> list[str]:
|
||||
"""清洗 Hook 和调用链传入的上下文文本列表。
|
||||
|
||||
Args:
|
||||
context_texts: 原始上下文文本序列。
|
||||
|
||||
Returns:
|
||||
list[str]: 过滤空白后的上下文文本列表。
|
||||
"""
|
||||
|
||||
if not context_texts:
|
||||
return []
|
||||
return [str(item).strip() for item in context_texts if str(item).strip()]
|
||||
|
||||
|
||||
def _resolve_selected_emoji(raw_value: Any) -> Optional[MaiEmoji]:
|
||||
"""根据 Hook 返回值解析目标表情包对象。
|
||||
|
||||
Args:
|
||||
raw_value: Hook 返回的 ``selected_emoji`` 或 ``selected_emoji_hash``。
|
||||
|
||||
Returns:
|
||||
Optional[MaiEmoji]: 命中的表情包对象;未命中时返回 ``None``。
|
||||
"""
|
||||
|
||||
raw_hash: str = ""
|
||||
if isinstance(raw_value, dict):
|
||||
raw_hash = str(raw_value.get("file_hash") or raw_value.get("hash") or "").strip()
|
||||
elif isinstance(raw_value, str):
|
||||
raw_hash = raw_value.strip()
|
||||
|
||||
if not raw_hash:
|
||||
return None
|
||||
|
||||
for emoji in emoji_manager.emojis:
|
||||
if emoji.file_hash == raw_hash:
|
||||
return emoji
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
|
||||
"""提取并清洗单个表情的情绪标签。"""
|
||||
|
||||
@@ -129,16 +199,81 @@ async def send_emoji_for_maisaka(
|
||||
) -> MaisakaEmojiSendResult:
|
||||
"""为 Maisaka 选择并发送一个表情。"""
|
||||
|
||||
selected_emoji, matched_emotion = await select_emoji_for_maisaka(
|
||||
requested_emotion=requested_emotion,
|
||||
reasoning=reasoning,
|
||||
context_texts=context_texts,
|
||||
normalized_requested_emotion = requested_emotion.strip()
|
||||
normalized_reasoning = reasoning.strip()
|
||||
normalized_context_texts = _normalize_context_texts(context_texts)
|
||||
sample_size = 30
|
||||
|
||||
before_select_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.maisaka.before_select",
|
||||
stream_id=stream_id,
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
reasoning=normalized_reasoning,
|
||||
context_texts=list(normalized_context_texts),
|
||||
sample_size=sample_size,
|
||||
abort_message="表情选择已被 Hook 中止。",
|
||||
)
|
||||
if before_select_result.aborted:
|
||||
abort_message = str(before_select_result.kwargs.get("abort_message") or "表情选择已被 Hook 中止。").strip()
|
||||
return MaisakaEmojiSendResult(
|
||||
success=False,
|
||||
message=abort_message or "表情选择已被 Hook 中止。",
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
)
|
||||
|
||||
before_select_kwargs = before_select_result.kwargs
|
||||
normalized_requested_emotion = str(
|
||||
before_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
|
||||
).strip()
|
||||
normalized_reasoning = str(before_select_kwargs.get("reasoning", normalized_reasoning) or "").strip()
|
||||
if isinstance(before_select_kwargs.get("context_texts"), list):
|
||||
normalized_context_texts = _normalize_context_texts(before_select_kwargs.get("context_texts"))
|
||||
sample_size = _coerce_positive_int(before_select_kwargs.get("sample_size"), sample_size)
|
||||
|
||||
selected_emoji, matched_emotion = await select_emoji_for_maisaka(
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
reasoning=normalized_reasoning,
|
||||
context_texts=normalized_context_texts,
|
||||
sample_size=sample_size,
|
||||
)
|
||||
after_select_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.maisaka.after_select",
|
||||
stream_id=stream_id,
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
reasoning=normalized_reasoning,
|
||||
context_texts=list(normalized_context_texts),
|
||||
sample_size=sample_size,
|
||||
selected_emoji=_serialize_emoji_for_hook(selected_emoji),
|
||||
selected_emoji_hash=str(selected_emoji.file_hash or "").strip() if selected_emoji is not None else "",
|
||||
matched_emotion=matched_emotion,
|
||||
abort_message="表情发送已被 Hook 中止。",
|
||||
)
|
||||
if after_select_result.aborted:
|
||||
abort_message = str(after_select_result.kwargs.get("abort_message") or "表情发送已被 Hook 中止。").strip()
|
||||
return MaisakaEmojiSendResult(
|
||||
success=False,
|
||||
message=abort_message or "表情发送已被 Hook 中止。",
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
|
||||
after_select_kwargs = after_select_result.kwargs
|
||||
normalized_requested_emotion = str(
|
||||
after_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
|
||||
).strip()
|
||||
matched_emotion = str(after_select_kwargs.get("matched_emotion", matched_emotion) or "").strip()
|
||||
override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji_hash"))
|
||||
if override_emoji is None:
|
||||
override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji"))
|
||||
if override_emoji is not None:
|
||||
selected_emoji = override_emoji
|
||||
|
||||
if selected_emoji is None:
|
||||
return MaisakaEmojiSendResult(
|
||||
success=False,
|
||||
message="当前表情包库中没有可用表情。",
|
||||
requested_emotion=requested_emotion.strip(),
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -151,7 +286,7 @@ async def send_emoji_for_maisaka(
|
||||
message=f"发送表情包失败:{exc}",
|
||||
description=selected_emoji.description.strip(),
|
||||
emotions=_normalize_emotions(selected_emoji),
|
||||
requested_emotion=requested_emotion.strip(),
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
|
||||
@@ -169,7 +304,7 @@ async def send_emoji_for_maisaka(
|
||||
message=f"发送表情包时发生异常:{exc}",
|
||||
description=selected_emoji.description.strip(),
|
||||
emotions=_normalize_emotions(selected_emoji),
|
||||
requested_emotion=requested_emotion.strip(),
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
|
||||
@@ -181,7 +316,7 @@ async def send_emoji_for_maisaka(
|
||||
message="发送表情包失败。",
|
||||
description=description,
|
||||
emotions=emotions,
|
||||
requested_emotion=requested_emotion.strip(),
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
|
||||
@@ -197,6 +332,6 @@ async def send_emoji_for_maisaka(
|
||||
emoji_base64=emoji_base64,
|
||||
description=description,
|
||||
emotions=emotions,
|
||||
requested_emotion=requested_emotion.strip(),
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""聊天消息入口与主链路调度。"""
|
||||
|
||||
from contextlib import suppress
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import os
|
||||
import traceback
|
||||
@@ -13,12 +15,15 @@ from src.common.utils.utils_message import MessageUtils
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
|
||||
from src.core.announcement_manager import global_announcement_manager
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.plugin_runtime.hook_payloads import deserialize_session_message, serialize_session_message
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
|
||||
from .message import SessionMessage
|
||||
from .chat_manager import chat_manager
|
||||
from .message import SessionMessage
|
||||
|
||||
# 定义日志配置
|
||||
|
||||
@@ -29,7 +34,137 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
def register_chat_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""注册聊天消息主链内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
return registry.register_hook_specs(
|
||||
[
|
||||
HookSpec(
|
||||
name="chat.receive.before_process",
|
||||
description="在入站消息执行 `SessionMessage.process()` 之前触发,可拦截或改写消息。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"message": {
|
||||
"type": "object",
|
||||
"description": "当前入站消息的序列化 SessionMessage。",
|
||||
},
|
||||
},
|
||||
required=["message"],
|
||||
),
|
||||
default_timeout_ms=8000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="chat.receive.after_process",
|
||||
description="在入站消息完成轻量预处理后触发,可改写文本、消息体或中止后续链路。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"message": {
|
||||
"type": "object",
|
||||
"description": "已完成 `process()` 的序列化 SessionMessage。",
|
||||
},
|
||||
},
|
||||
required=["message"],
|
||||
),
|
||||
default_timeout_ms=8000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="chat.command.before_execute",
|
||||
description="在命令匹配成功、实际执行前触发,可拦截命令或改写命令上下文。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"message": {
|
||||
"type": "object",
|
||||
"description": "当前命令消息的序列化 SessionMessage。",
|
||||
},
|
||||
"command_name": {
|
||||
"type": "string",
|
||||
"description": "命中的命令名称。",
|
||||
},
|
||||
"plugin_id": {
|
||||
"type": "string",
|
||||
"description": "命令所属插件 ID。",
|
||||
},
|
||||
"matched_groups": {
|
||||
"type": "object",
|
||||
"description": "命令正则命名捕获结果。",
|
||||
},
|
||||
},
|
||||
required=["message", "command_name", "plugin_id", "matched_groups"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="chat.command.after_execute",
|
||||
description="在命令执行结束后触发,可调整返回文本和是否继续主链处理。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"message": {
|
||||
"type": "object",
|
||||
"description": "当前命令消息的序列化 SessionMessage。",
|
||||
},
|
||||
"command_name": {
|
||||
"type": "string",
|
||||
"description": "命令名称。",
|
||||
},
|
||||
"plugin_id": {
|
||||
"type": "string",
|
||||
"description": "命令所属插件 ID。",
|
||||
},
|
||||
"matched_groups": {
|
||||
"type": "object",
|
||||
"description": "命令正则命名捕获结果。",
|
||||
},
|
||||
"success": {
|
||||
"type": "boolean",
|
||||
"description": "命令执行是否成功。",
|
||||
},
|
||||
"response": {
|
||||
"type": "string",
|
||||
"description": "命令返回文本。",
|
||||
},
|
||||
"intercept_message_level": {
|
||||
"type": "integer",
|
||||
"description": "命令拦截等级。",
|
||||
},
|
||||
"continue_process": {
|
||||
"type": "boolean",
|
||||
"description": "命令执行后是否继续后续消息处理。",
|
||||
},
|
||||
},
|
||||
required=[
|
||||
"message",
|
||||
"command_name",
|
||||
"plugin_id",
|
||||
"matched_groups",
|
||||
"success",
|
||||
"intercept_message_level",
|
||||
"continue_process",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=False,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ChatBot:
|
||||
"""聊天机器人入口协调器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化聊天机器人入口。"""
|
||||
|
||||
@@ -44,6 +179,66 @@ class ChatBot:
|
||||
|
||||
self._started = True
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
"""将任意值安全转换为整数。
|
||||
|
||||
Args:
|
||||
value: 待转换的值。
|
||||
default: 转换失败时的默认值。
|
||||
|
||||
Returns:
|
||||
int: 转换后的整数结果。
|
||||
"""
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
async def _invoke_message_hook(
|
||||
self,
|
||||
hook_name: str,
|
||||
message: SessionMessage,
|
||||
**kwargs: Any,
|
||||
) -> tuple[HookDispatchResult, SessionMessage]:
|
||||
"""触发携带会话消息的命名 Hook。
|
||||
|
||||
Args:
|
||||
hook_name: 目标 Hook 名称。
|
||||
message: 当前会话消息。
|
||||
**kwargs: 需要附带传递的额外参数。
|
||||
|
||||
Returns:
|
||||
tuple[HookDispatchResult, SessionMessage]: Hook 聚合结果以及可能被改写后的消息对象。
|
||||
"""
|
||||
|
||||
hook_result = await self._get_runtime_manager().invoke_hook(
|
||||
hook_name,
|
||||
message=serialize_session_message(message),
|
||||
**kwargs,
|
||||
)
|
||||
mutated_message = message
|
||||
raw_message = hook_result.kwargs.get("message")
|
||||
if raw_message is not None:
|
||||
try:
|
||||
mutated_message = deserialize_session_message(raw_message)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Hook {hook_name} 返回的 message 无法反序列化,已忽略: {exc}")
|
||||
return hook_result, mutated_message
|
||||
|
||||
async def _process_commands(self, message: SessionMessage) -> tuple[bool, Optional[str], bool]:
|
||||
"""使用统一组件注册表处理命令。
|
||||
|
||||
@@ -71,6 +266,25 @@ class ChatBot:
|
||||
return False, None, True
|
||||
|
||||
message.is_command = True
|
||||
before_result, message = await self._invoke_message_hook(
|
||||
"chat.command.before_execute",
|
||||
message,
|
||||
command_name=command_name,
|
||||
plugin_id=plugin_name,
|
||||
matched_groups=dict(matched_groups),
|
||||
)
|
||||
if before_result.aborted:
|
||||
logger.info(f"命令 {command_name} 被 Hook 中止,跳过命令执行")
|
||||
return True, None, False
|
||||
|
||||
hook_kwargs = before_result.kwargs
|
||||
command_name = str(hook_kwargs.get("command_name", command_name) or command_name)
|
||||
plugin_name = str(hook_kwargs.get("plugin_id", plugin_name) or plugin_name)
|
||||
matched_groups = (
|
||||
dict(hook_kwargs["matched_groups"])
|
||||
if isinstance(hook_kwargs.get("matched_groups"), dict)
|
||||
else dict(matched_groups)
|
||||
)
|
||||
|
||||
# 获取插件配置
|
||||
plugin_config = component_query_service.get_plugin_config(plugin_name)
|
||||
@@ -82,27 +296,43 @@ class ChatBot:
|
||||
plugin_config=plugin_config,
|
||||
matched_groups=matched_groups,
|
||||
)
|
||||
self._mark_command_message(message, intercept_message_level)
|
||||
|
||||
# 记录命令执行结果
|
||||
if success:
|
||||
logger.info(f"命令执行成功: {command_name} (拦截等级: {intercept_message_level})")
|
||||
else:
|
||||
logger.warning(f"命令执行失败: {command_name} - {response}")
|
||||
|
||||
# 根据命令的拦截设置决定是否继续处理消息
|
||||
return (
|
||||
True,
|
||||
response,
|
||||
not bool(intercept_message_level),
|
||||
) # 找到命令,根据intercept_message决定是否继续
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令时出错: {command_name} - {e}")
|
||||
continue_process = not bool(intercept_message_level)
|
||||
except Exception as exc:
|
||||
logger.error(f"执行命令时出错: {command_name} - {exc}")
|
||||
logger.error(traceback.format_exc())
|
||||
success = False
|
||||
response = str(exc)
|
||||
intercept_message_level = 1
|
||||
continue_process = False
|
||||
|
||||
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
|
||||
return True, str(e), False # 出错时继续处理消息
|
||||
after_result, message = await self._invoke_message_hook(
|
||||
"chat.command.after_execute",
|
||||
message,
|
||||
command_name=command_name,
|
||||
plugin_id=plugin_name,
|
||||
matched_groups=dict(matched_groups),
|
||||
success=success,
|
||||
response=response,
|
||||
intercept_message_level=intercept_message_level,
|
||||
continue_process=continue_process,
|
||||
)
|
||||
after_kwargs = after_result.kwargs
|
||||
success = bool(after_kwargs.get("success", success))
|
||||
raw_response = after_kwargs.get("response", response)
|
||||
response = None if raw_response is None else str(raw_response)
|
||||
intercept_message_level = self._coerce_int(
|
||||
after_kwargs.get("intercept_message_level", intercept_message_level),
|
||||
intercept_message_level,
|
||||
)
|
||||
continue_process = bool(after_kwargs.get("continue_process", continue_process))
|
||||
self._mark_command_message(message, intercept_message_level)
|
||||
|
||||
if success:
|
||||
logger.info(f"命令执行成功: {command_name} (拦截等级: {intercept_message_level})")
|
||||
else:
|
||||
logger.warning(f"命令执行失败: {command_name} - {response}")
|
||||
|
||||
return True, response, continue_process
|
||||
|
||||
return False, None, True
|
||||
|
||||
@@ -138,6 +368,17 @@ class ChatBot:
|
||||
cmd_result: Optional[str],
|
||||
continue_process: bool,
|
||||
) -> bool:
|
||||
"""处理命令链结果并决定是否终止主消息链。
|
||||
|
||||
Args:
|
||||
message: 当前命令消息。
|
||||
cmd_result: 命令响应文本。
|
||||
continue_process: 是否继续后续主链处理。
|
||||
|
||||
Returns:
|
||||
bool: ``True`` 表示已经终止后续主链。
|
||||
"""
|
||||
|
||||
if continue_process:
|
||||
return False
|
||||
|
||||
@@ -145,9 +386,18 @@ class ChatBot:
|
||||
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
||||
return True
|
||||
|
||||
async def handle_notice_message(self, message: SessionMessage):
|
||||
async def handle_notice_message(self, message: SessionMessage) -> bool:
|
||||
"""处理通知类消息。
|
||||
|
||||
Args:
|
||||
message: 当前通知消息。
|
||||
|
||||
Returns:
|
||||
bool: 当前消息是否为通知消息。
|
||||
"""
|
||||
|
||||
if message.message_id != "notice":
|
||||
return
|
||||
return False
|
||||
|
||||
message.is_notify = True
|
||||
logger.debug("notice消息")
|
||||
@@ -203,9 +453,12 @@ class ChatBot:
|
||||
return True
|
||||
|
||||
async def echo_message_process(self, raw_data: Dict[str, Any]) -> None:
|
||||
"""处理消息回送 ID 对应关系。
|
||||
|
||||
Args:
|
||||
raw_data: 平台适配器上报的原始回送载荷。
|
||||
"""
|
||||
用于专门处理回送消息ID的函数
|
||||
"""
|
||||
|
||||
message_data: Dict[str, Any] = raw_data.get("content", {})
|
||||
if not message_data:
|
||||
return
|
||||
@@ -218,18 +471,10 @@ class ChatBot:
|
||||
logger.debug(f"收到回送消息ID: {mmc_message_id} -> {actual_message_id}")
|
||||
|
||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||
"""处理转化后的统一格式消息
|
||||
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
|
||||
heart_flow模式:使用思维流系统进行回复
|
||||
- 包含思维流状态管理
|
||||
- 在回复前进行观察和状态更新
|
||||
- 回复后更新思维流状态
|
||||
- 消息过滤
|
||||
- 记忆激活
|
||||
- 意愿计算
|
||||
- 消息生成和发送
|
||||
- 表情包处理
|
||||
- 性能计时
|
||||
"""处理统一格式的入站消息字典。
|
||||
|
||||
Args:
|
||||
message_data: 适配器整理后的统一消息字典。
|
||||
"""
|
||||
try:
|
||||
# 确保所有任务已启动
|
||||
@@ -253,7 +498,13 @@ class ChatBot:
|
||||
logger.error(f"预处理消息失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def receive_message(self, message: SessionMessage):
|
||||
async def receive_message(self, message: SessionMessage) -> None:
|
||||
"""处理单条入站会话消息。
|
||||
|
||||
Args:
|
||||
message: 待处理的会话消息。
|
||||
"""
|
||||
|
||||
try:
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
@@ -272,6 +523,19 @@ class ChatBot:
|
||||
)
|
||||
|
||||
message.session_id = session_id # 正确初始化session_id
|
||||
before_process_result, message = await self._invoke_message_hook(
|
||||
"chat.receive.before_process",
|
||||
message,
|
||||
)
|
||||
if before_process_result.aborted:
|
||||
logger.info(f"消息 {message.message_id} 在预处理前被 Hook 中止")
|
||||
return
|
||||
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
additional_config = message.message_info.additional_config
|
||||
if isinstance(additional_config, dict):
|
||||
account_id, scope = RouteKeyFactory.extract_components(additional_config)
|
||||
|
||||
# TODO: 修复事件预处理部分
|
||||
# continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
@@ -294,6 +558,16 @@ class ChatBot:
|
||||
enable_heavy_media_analysis=False,
|
||||
enable_voice_transcription=False,
|
||||
)
|
||||
after_process_result, message = await self._invoke_message_hook(
|
||||
"chat.receive.after_process",
|
||||
message,
|
||||
)
|
||||
if after_process_result.aborted:
|
||||
logger.info(f"消息 {message.message_id} 在预处理后被 Hook 中止")
|
||||
return
|
||||
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
|
||||
# 平台层的 @ 检测由底层 is_mentioned_bot_in_message 统一处理;此处不做用户名硬编码匹配
|
||||
|
||||
|
||||
@@ -18,28 +18,29 @@ install(extra_lines=3)
|
||||
logger = get_logger("sender")
|
||||
|
||||
# WebUI 聊天室的消息广播器(延迟导入避免循环依赖)
|
||||
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None
|
||||
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str], Optional[str]]] = None
|
||||
|
||||
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
|
||||
|
||||
# TODO: 重构完成后完成webui相关
|
||||
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]:
|
||||
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str], Optional[str]]:
|
||||
"""获取 WebUI 聊天室广播器。
|
||||
|
||||
Returns:
|
||||
Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 二元组;
|
||||
Tuple[Any, Optional[str], Optional[str]]: ``(chat_manager, platform_name, default_group_id)`` 三元组;
|
||||
若 WebUI 相关模块不可用,则元素会退化为 ``None``。
|
||||
"""
|
||||
global _webui_chat_broadcaster
|
||||
if _webui_chat_broadcaster is None:
|
||||
try:
|
||||
from src.webui.routers.chat import WEBUI_CHAT_PLATFORM, chat_manager
|
||||
from src.webui.routers.chat.service import WEBUI_CHAT_GROUP_ID
|
||||
|
||||
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
|
||||
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM, WEBUI_CHAT_GROUP_ID)
|
||||
except ImportError:
|
||||
_webui_chat_broadcaster = (None, None)
|
||||
_webui_chat_broadcaster = (None, None, None)
|
||||
return _webui_chat_broadcaster
|
||||
|
||||
|
||||
@@ -76,7 +77,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
|
||||
try:
|
||||
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
|
||||
chat_manager, webui_platform = get_webui_chat_broadcaster()
|
||||
chat_manager, webui_platform, default_group_id = get_webui_chat_broadcaster()
|
||||
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
|
||||
|
||||
if is_webui_message and chat_manager is not None:
|
||||
@@ -97,8 +98,9 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
message_type = "rich"
|
||||
segments = message_segments
|
||||
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
await chat_manager.broadcast_to_group(
|
||||
group_id=group_id or default_group_id or "",
|
||||
message={
|
||||
"type": "bot_message",
|
||||
"content": message.processed_plain_text,
|
||||
"message_type": message_type,
|
||||
@@ -110,7 +112,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
"avatar": None,
|
||||
"is_bot": True,
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库
|
||||
|
||||
@@ -1,22 +1,25 @@
|
||||
from datetime import datetime
|
||||
from sqlmodel import select
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import difflib
|
||||
import json
|
||||
import re
|
||||
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.data_models.expression_data_model import MaiExpression
|
||||
from sqlmodel import select
|
||||
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
from src.common.data_models.expression_data_model import MaiExpression
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
from src.config.config import global_config
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from .expression_utils import check_expression_suitability, parse_expression_response
|
||||
|
||||
@@ -34,8 +37,122 @@ summary_model = LLMServiceClient(task_name="utils", request_type="expression.sum
|
||||
check_model = LLMServiceClient(task_name="utils", request_type="expression.check")
|
||||
|
||||
|
||||
def register_expression_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""注册表达方式系统内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
return registry.register_hook_specs(
|
||||
[
|
||||
HookSpec(
|
||||
name="expression.select.before_select",
|
||||
description="表达方式选择流程开始前触发,可改写会话上下文、选择参数或中止本次选择。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
|
||||
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
|
||||
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
|
||||
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
|
||||
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
|
||||
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
|
||||
},
|
||||
required=["chat_id", "chat_info", "max_num", "think_level"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="expression.select.after_selection",
|
||||
description="表达方式选择完成后触发,可改写最终选中的表达方式列表与 ID。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
|
||||
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
|
||||
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
|
||||
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
|
||||
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
|
||||
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
|
||||
"selected_expressions": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "当前已选中的表达方式列表。",
|
||||
},
|
||||
"selected_expression_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "integer"},
|
||||
"description": "当前已选中的表达方式 ID 列表。",
|
||||
},
|
||||
},
|
||||
required=[
|
||||
"chat_id",
|
||||
"chat_info",
|
||||
"max_num",
|
||||
"think_level",
|
||||
"selected_expressions",
|
||||
"selected_expression_ids",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="expression.learn.after_extract",
|
||||
description="表达方式学习解析出表达/黑话候选后触发,可改写候选集或直接终止本轮学习。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"message_count": {"type": "integer", "description": "本轮参与学习的消息数量。"},
|
||||
"expressions": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "解析出的表达方式候选列表。",
|
||||
},
|
||||
"jargon_entries": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "解析出的黑话候选列表。",
|
||||
},
|
||||
},
|
||||
required=["session_id", "message_count", "expressions", "jargon_entries"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="expression.learn.before_upsert",
|
||||
description="表达方式写入数据库前触发,可改写情景/风格文本或跳过本条写入。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"situation": {"type": "string", "description": "即将写入的情景文本。"},
|
||||
"style": {"type": "string", "description": "即将写入的风格文本。"},
|
||||
},
|
||||
required=["session_id", "situation", "style"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ExpressionLearner:
|
||||
def __init__(self, session_id: str) -> None:
|
||||
"""初始化表达方式学习器。
|
||||
|
||||
Args:
|
||||
session_id: 当前会话 ID。
|
||||
"""
|
||||
|
||||
self.session_id = session_id
|
||||
|
||||
# 学习锁,防止并发执行学习任务
|
||||
@@ -44,6 +161,110 @@ class ExpressionLearner:
|
||||
# 消息缓存
|
||||
self._messages_cache: List["SessionMessage"] = []
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_expressions(expressions: List[Tuple[str, str, str]]) -> List[dict[str, str]]:
|
||||
"""将表达方式候选序列化为 Hook 载荷。
|
||||
|
||||
Args:
|
||||
expressions: 原始表达方式候选列表。
|
||||
|
||||
Returns:
|
||||
List[dict[str, str]]: 序列化后的表达方式候选。
|
||||
"""
|
||||
|
||||
return [
|
||||
{
|
||||
"situation": str(situation).strip(),
|
||||
"style": str(style).strip(),
|
||||
"source_id": str(source_id).strip(),
|
||||
}
|
||||
for situation, style, source_id in expressions
|
||||
if str(situation).strip() and str(style).strip()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_expressions(raw_expressions: Any) -> List[Tuple[str, str, str]]:
|
||||
"""从 Hook 载荷恢复表达方式候选列表。
|
||||
|
||||
Args:
|
||||
raw_expressions: Hook 返回的表达方式候选。
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 恢复后的表达方式候选列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_expressions, list):
|
||||
return []
|
||||
|
||||
normalized_expressions: List[Tuple[str, str, str]] = []
|
||||
for raw_expression in raw_expressions:
|
||||
if not isinstance(raw_expression, dict):
|
||||
continue
|
||||
situation = str(raw_expression.get("situation") or "").strip()
|
||||
style = str(raw_expression.get("style") or "").strip()
|
||||
source_id = str(raw_expression.get("source_id") or "").strip()
|
||||
if not situation or not style:
|
||||
continue
|
||||
normalized_expressions.append((situation, style, source_id))
|
||||
return normalized_expressions
|
||||
|
||||
@staticmethod
|
||||
def _serialize_jargon_entries(jargon_entries: List[Tuple[str, str]]) -> List[dict[str, str]]:
|
||||
"""将黑话候选序列化为 Hook 载荷。
|
||||
|
||||
Args:
|
||||
jargon_entries: 原始黑话候选列表。
|
||||
|
||||
Returns:
|
||||
List[dict[str, str]]: 序列化后的黑话候选列表。
|
||||
"""
|
||||
|
||||
return [
|
||||
{
|
||||
"content": str(content).strip(),
|
||||
"source_id": str(source_id).strip(),
|
||||
}
|
||||
for content, source_id in jargon_entries
|
||||
if str(content).strip()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_jargon_entries(raw_jargon_entries: Any) -> List[Tuple[str, str]]:
|
||||
"""从 Hook 载荷恢复黑话候选列表。
|
||||
|
||||
Args:
|
||||
raw_jargon_entries: Hook 返回的黑话候选列表。
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str]]: 恢复后的黑话候选列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_jargon_entries, list):
|
||||
return []
|
||||
|
||||
normalized_entries: List[Tuple[str, str]] = []
|
||||
for raw_entry in raw_jargon_entries:
|
||||
if not isinstance(raw_entry, dict):
|
||||
continue
|
||||
content = str(raw_entry.get("content") or "").strip()
|
||||
source_id = str(raw_entry.get("source_id") or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
normalized_entries.append((content, source_id))
|
||||
return normalized_entries
|
||||
|
||||
def add_messages(self, messages: List["SessionMessage"]) -> None:
|
||||
"""添加消息到缓存"""
|
||||
self._messages_cache.extend(messages)
|
||||
@@ -52,8 +273,12 @@ class ExpressionLearner:
|
||||
"""获取当前消息缓存的大小"""
|
||||
return len(self._messages_cache)
|
||||
|
||||
async def learn(self, jargon_miner: Optional["JargonMiner"] = None):
|
||||
"""学习主流程"""
|
||||
async def learn(self, jargon_miner: Optional["JargonMiner"] = None) -> None:
|
||||
"""执行表达方式学习主流程。
|
||||
|
||||
Args:
|
||||
jargon_miner: 可选的黑话学习器实例,用于同步处理黑话候选。
|
||||
"""
|
||||
if not self._messages_cache:
|
||||
logger.debug("没有消息可供学习,跳过学习过程")
|
||||
return
|
||||
@@ -109,6 +334,25 @@ class ExpressionLearner:
|
||||
logger.info(f"黑话提取数量超过 30 个(实际{len(jargon_entries)}个),放弃本次黑话学习")
|
||||
jargon_entries = []
|
||||
|
||||
after_extract_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.learn.after_extract",
|
||||
session_id=self.session_id,
|
||||
message_count=len(self._messages_cache),
|
||||
expressions=self._serialize_expressions(expressions),
|
||||
jargon_entries=self._serialize_jargon_entries(jargon_entries),
|
||||
)
|
||||
if after_extract_result.aborted:
|
||||
logger.info(f"{self.session_id} 的表达方式学习结果被 Hook 中止")
|
||||
return
|
||||
|
||||
after_extract_kwargs = after_extract_result.kwargs
|
||||
raw_expressions = after_extract_kwargs.get("expressions")
|
||||
if raw_expressions is not None:
|
||||
expressions = self._deserialize_expressions(raw_expressions)
|
||||
raw_jargon_entries = after_extract_kwargs.get("jargon_entries")
|
||||
if raw_jargon_entries is not None:
|
||||
jargon_entries = self._deserialize_jargon_entries(raw_jargon_entries)
|
||||
|
||||
# 处理黑话条目,路由到 jargon_miner(即使没有表达方式也要处理黑话)
|
||||
# TODO: 检测是否开启了
|
||||
if jargon_entries:
|
||||
@@ -135,6 +379,22 @@ class ExpressionLearner:
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
for situation, style in learnt_expressions:
|
||||
before_upsert_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.learn.before_upsert",
|
||||
session_id=self.session_id,
|
||||
situation=situation,
|
||||
style=style,
|
||||
)
|
||||
if before_upsert_result.aborted:
|
||||
logger.info(f"{self.session_id} 的表达方式写入被 Hook 跳过: situation={situation!r}")
|
||||
continue
|
||||
|
||||
upsert_kwargs = before_upsert_result.kwargs
|
||||
situation = str(upsert_kwargs.get("situation", situation) or "").strip()
|
||||
style = str(upsert_kwargs.get("style", style) or "").strip()
|
||||
if not situation or not style:
|
||||
logger.info(f"{self.session_id} 的表达方式写入被 Hook 清空,已跳过")
|
||||
continue
|
||||
await self._upsert_expression_to_db(situation, style)
|
||||
|
||||
# ====== 黑话相关 ======
|
||||
|
||||
@@ -1,27 +1,109 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import json
|
||||
import time
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.learners.learner_utils_old import weighted_sample
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.learners.learner_utils_old import weighted_sample
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
|
||||
class ExpressionSelector:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""初始化表达方式选择器。"""
|
||||
|
||||
self.llm_model = LLMServiceClient(
|
||||
task_name="utils", request_type="expression.selector"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
"""将任意值安全转换为整数。
|
||||
|
||||
Args:
|
||||
value: 待转换的值。
|
||||
default: 转换失败时的默认值。
|
||||
|
||||
Returns:
|
||||
int: 转换后的整数结果。
|
||||
"""
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _normalize_selected_expressions(raw_expressions: Any) -> List[Dict[str, Any]]:
|
||||
"""从 Hook 载荷恢复表达方式选择结果。
|
||||
|
||||
Args:
|
||||
raw_expressions: Hook 返回的表达方式列表。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 恢复后的表达方式列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_expressions, list):
|
||||
return []
|
||||
|
||||
normalized_expressions: List[Dict[str, Any]] = []
|
||||
for raw_expression in raw_expressions:
|
||||
if not isinstance(raw_expression, dict):
|
||||
continue
|
||||
expression_id = raw_expression.get("id")
|
||||
situation = str(raw_expression.get("situation") or "").strip()
|
||||
style = str(raw_expression.get("style") or "").strip()
|
||||
source_id = str(raw_expression.get("source_id") or "").strip()
|
||||
if not isinstance(expression_id, int) or not situation or not style or not source_id:
|
||||
continue
|
||||
normalized_expression = dict(raw_expression)
|
||||
normalized_expression["id"] = expression_id
|
||||
normalized_expression["situation"] = situation
|
||||
normalized_expression["style"] = style
|
||||
normalized_expression["source_id"] = source_id
|
||||
normalized_expressions.append(normalized_expression)
|
||||
return normalized_expressions
|
||||
|
||||
@staticmethod
|
||||
def _normalize_selected_expression_ids(raw_ids: Any, expressions: List[Dict[str, Any]]) -> List[int]:
|
||||
"""规范化最终选中的表达方式 ID 列表。
|
||||
|
||||
Args:
|
||||
raw_ids: Hook 返回的 ID 列表。
|
||||
expressions: 当前最终表达方式列表。
|
||||
|
||||
Returns:
|
||||
List[int]: 规范化后的 ID 列表。
|
||||
"""
|
||||
|
||||
if isinstance(raw_ids, list):
|
||||
normalized_ids = [item for item in raw_ids if isinstance(item, int)]
|
||||
if normalized_ids:
|
||||
return normalized_ids
|
||||
return [expression["id"] for expression in expressions if isinstance(expression.get("id"), int)]
|
||||
|
||||
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许使用表达
|
||||
@@ -214,8 +296,7 @@ class ExpressionSelector:
|
||||
reply_reason: Optional[str] = None,
|
||||
think_level: int = 1,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
选择适合的表达方式(使用classic模式:随机选择+LLM选择)
|
||||
"""选择适合的表达方式。
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
@@ -233,11 +314,60 @@ class ExpressionSelector:
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return [], []
|
||||
|
||||
before_select_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.select.before_select",
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
max_num=max_num,
|
||||
target_message=target_message or "",
|
||||
reply_reason=reply_reason or "",
|
||||
think_level=think_level,
|
||||
)
|
||||
if before_select_result.aborted:
|
||||
logger.info(f"聊天流 {chat_id} 的表达方式选择被 Hook 中止")
|
||||
return [], []
|
||||
|
||||
before_select_kwargs = before_select_result.kwargs
|
||||
chat_id = str(before_select_kwargs.get("chat_id", chat_id) or "").strip() or chat_id
|
||||
chat_info = str(before_select_kwargs.get("chat_info", chat_info) or "")
|
||||
max_num = max(self._coerce_int(before_select_kwargs.get("max_num"), max_num), 1)
|
||||
raw_target_message = before_select_kwargs.get("target_message", target_message or "")
|
||||
target_message = str(raw_target_message or "").strip() or None
|
||||
raw_reply_reason = before_select_kwargs.get("reply_reason", reply_reason or "")
|
||||
reply_reason = str(raw_reply_reason or "").strip() or None
|
||||
think_level = self._coerce_int(before_select_kwargs.get("think_level"), think_level)
|
||||
|
||||
# 使用classic模式(随机选择+LLM选择)
|
||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式,think_level={think_level}")
|
||||
return await self._select_expressions_classic(
|
||||
selected_expressions, selected_ids = await self._select_expressions_classic(
|
||||
chat_id, chat_info, max_num, target_message, reply_reason, think_level
|
||||
)
|
||||
after_selection_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.select.after_selection",
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
max_num=max_num,
|
||||
target_message=target_message or "",
|
||||
reply_reason=reply_reason or "",
|
||||
think_level=think_level,
|
||||
selected_expressions=[dict(item) for item in selected_expressions],
|
||||
selected_expression_ids=list(selected_ids),
|
||||
)
|
||||
if after_selection_result.aborted:
|
||||
logger.info(f"聊天流 {chat_id} 的表达方式选择结果被 Hook 中止")
|
||||
return [], []
|
||||
|
||||
after_selection_kwargs = after_selection_result.kwargs
|
||||
raw_selected_expressions = after_selection_kwargs.get("selected_expressions")
|
||||
if raw_selected_expressions is not None:
|
||||
selected_expressions = self._normalize_selected_expressions(raw_selected_expressions)
|
||||
selected_ids = self._normalize_selected_expression_ids(
|
||||
after_selection_kwargs.get("selected_expression_ids"),
|
||||
selected_expressions,
|
||||
)
|
||||
if selected_expressions:
|
||||
self.update_expressions_last_active_time(selected_expressions)
|
||||
return selected_expressions, selected_ids
|
||||
|
||||
async def _select_expressions_classic(
|
||||
self,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, Dict, List, Optional, Set, TypedDict
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, TypedDict
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
@@ -9,13 +9,15 @@ from json_repair import repair_json
|
||||
from sqlmodel import select
|
||||
|
||||
from src.common.data_models.jargon_data_model import MaiJargon
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Jargon
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from .expression_utils import is_single_char_jargon
|
||||
|
||||
@@ -35,8 +37,140 @@ class JargonMeaningEntry(TypedDict):
|
||||
meaning: str
|
||||
|
||||
|
||||
def register_jargon_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""注册 jargon 系统内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
return registry.register_hook_specs(
|
||||
[
|
||||
HookSpec(
|
||||
name="jargon.query.before_search",
|
||||
description="Maisaka 黑话查询工具执行检索前触发,可改写词条列表、检索参数或直接中止。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"words": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "准备查询的黑话词条列表。",
|
||||
},
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"limit": {"type": "integer", "description": "单个词条的最大返回条数。"},
|
||||
"case_sensitive": {"type": "boolean", "description": "是否大小写敏感。"},
|
||||
"enable_fuzzy_fallback": {"type": "boolean", "description": "是否允许精确命中失败后回退模糊检索。"},
|
||||
"abort_message": {"type": "string", "description": "Hook 主动中止时的失败提示。"},
|
||||
},
|
||||
required=["words", "session_id", "limit", "case_sensitive", "enable_fuzzy_fallback"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="jargon.query.after_search",
|
||||
description="Maisaka 黑话查询工具完成检索后触发,可改写结果列表或中止返回。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"words": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "实际查询的黑话词条列表。",
|
||||
},
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"limit": {"type": "integer", "description": "单个词条的最大返回条数。"},
|
||||
"case_sensitive": {"type": "boolean", "description": "是否大小写敏感。"},
|
||||
"enable_fuzzy_fallback": {"type": "boolean", "description": "是否启用了模糊检索回退。"},
|
||||
"results": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "查询结果列表。",
|
||||
},
|
||||
"abort_message": {"type": "string", "description": "Hook 主动中止时的失败提示。"},
|
||||
},
|
||||
required=["words", "session_id", "limit", "case_sensitive", "enable_fuzzy_fallback", "results"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="jargon.extract.before_persist",
|
||||
description="黑话条目准备写入数据库前触发,可改写去重后的条目列表或跳过本次持久化。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"session_name": {"type": "string", "description": "当前会话展示名称。"},
|
||||
"entries": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "即将持久化的黑话条目列表。",
|
||||
},
|
||||
},
|
||||
required=["session_id", "session_name", "entries"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="jargon.inference.before_finalize",
|
||||
description="黑话含义推断完成、写回数据库前触发,可改写最终判定与含义结果。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"session_name": {"type": "string", "description": "当前会话展示名称。"},
|
||||
"content": {"type": "string", "description": "当前黑话词条。"},
|
||||
"count": {"type": "integer", "description": "当前词条累计命中次数。"},
|
||||
"raw_content_list": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "用于推断的原始上下文片段列表。",
|
||||
},
|
||||
"inference_with_context": {"type": "object", "description": "基于上下文的推断结果。"},
|
||||
"inference_with_content_only": {"type": "object", "description": "仅基于词条内容的推断结果。"},
|
||||
"comparison_result": {"type": "object", "description": "比较阶段输出结果。"},
|
||||
"is_jargon": {"type": "boolean", "description": "当前推断是否判定为黑话。"},
|
||||
"meaning": {"type": "string", "description": "当前推断出的黑话含义。"},
|
||||
"is_complete": {"type": "boolean", "description": "当前是否已完成全部推断流程。"},
|
||||
"last_inference_count": {"type": "integer", "description": "本次推断完成后应写回的 last_inference_count。"},
|
||||
},
|
||||
required=[
|
||||
"session_id",
|
||||
"session_name",
|
||||
"content",
|
||||
"count",
|
||||
"raw_content_list",
|
||||
"inference_with_context",
|
||||
"inference_with_content_only",
|
||||
"comparison_result",
|
||||
"is_jargon",
|
||||
"meaning",
|
||||
"is_complete",
|
||||
"last_inference_count",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class JargonMiner:
|
||||
def __init__(self, session_id: str, session_name: str) -> None:
|
||||
"""初始化黑话学习器。
|
||||
|
||||
Args:
|
||||
session_id: 当前会话 ID。
|
||||
session_name: 当前会话展示名称。
|
||||
"""
|
||||
|
||||
self.session_id = session_id
|
||||
self.session_name = session_name
|
||||
|
||||
@@ -46,13 +180,92 @@ class JargonMiner:
|
||||
# 黑话提取锁,防止并发执行
|
||||
self._extraction_lock = asyncio.Lock()
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
"""将任意值安全转换为整数。
|
||||
|
||||
Args:
|
||||
value: 待转换的值。
|
||||
default: 转换失败时使用的默认值。
|
||||
|
||||
Returns:
|
||||
int: 转换后的整数结果。
|
||||
"""
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _serialize_jargon_entries(entries: List[JargonEntry]) -> List[Dict[str, object]]:
|
||||
"""将黑话条目列表序列化为 Hook 可传输结构。
|
||||
|
||||
Args:
|
||||
entries: 原始黑话条目列表。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, object]]: 序列化后的条目列表。
|
||||
"""
|
||||
|
||||
return [
|
||||
{
|
||||
"content": str(entry["content"]).strip(),
|
||||
"raw_content": sorted(str(item).strip() for item in entry["raw_content"] if str(item).strip()),
|
||||
}
|
||||
for entry in entries
|
||||
if str(entry["content"]).strip()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_jargon_entries(raw_entries: Any) -> List[JargonEntry]:
|
||||
"""从 Hook 载荷恢复黑话条目列表。
|
||||
|
||||
Args:
|
||||
raw_entries: Hook 返回的条目数据。
|
||||
|
||||
Returns:
|
||||
List[JargonEntry]: 恢复后的黑话条目列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_entries, list):
|
||||
return []
|
||||
|
||||
normalized_entries: List[JargonEntry] = []
|
||||
for raw_entry in raw_entries:
|
||||
if not isinstance(raw_entry, dict):
|
||||
continue
|
||||
content = str(raw_entry.get("content") or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
raw_content_values = raw_entry.get("raw_content")
|
||||
raw_content: Set[str] = set()
|
||||
if isinstance(raw_content_values, list):
|
||||
raw_content = {str(item).strip() for item in raw_content_values if str(item).strip()}
|
||||
normalized_entries.append({"content": content, "raw_content": raw_content})
|
||||
return normalized_entries
|
||||
|
||||
def get_cached_jargons(self) -> List[str]:
|
||||
"""获取缓存中的所有黑话列表"""
|
||||
return list(self.cache.keys())
|
||||
|
||||
async def infer_meaning(self, jargon_obj: MaiJargon) -> None:
|
||||
"""
|
||||
对jargon进行含义推断
|
||||
"""对黑话条目执行含义推断。
|
||||
|
||||
Args:
|
||||
jargon_obj: 待推断的黑话数据对象。
|
||||
"""
|
||||
content = jargon_obj.content
|
||||
# 解析raw_content列表
|
||||
@@ -175,15 +388,45 @@ class JargonMiner:
|
||||
is_similar = comparison_result.get("is_similar", False)
|
||||
is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话
|
||||
|
||||
finalized_meaning = inference1.get("meaning", "") if is_jargon else ""
|
||||
is_complete = (jargon_obj.count or 0) >= 100
|
||||
last_inference_count = jargon_obj.count or 0
|
||||
finalize_result = await self._get_runtime_manager().invoke_hook(
|
||||
"jargon.inference.before_finalize",
|
||||
session_id=self.session_id,
|
||||
session_name=self.session_name,
|
||||
content=content,
|
||||
count=current_count,
|
||||
raw_content_list=list(raw_content_list),
|
||||
inference_with_context=dict(inference1),
|
||||
inference_with_content_only=dict(inference2),
|
||||
comparison_result=dict(comparison_result),
|
||||
is_jargon=is_jargon,
|
||||
meaning=finalized_meaning,
|
||||
is_complete=is_complete,
|
||||
last_inference_count=last_inference_count,
|
||||
)
|
||||
if finalize_result.aborted:
|
||||
logger.info(f"jargon {content} 的推断结果被 Hook 中止写回")
|
||||
return
|
||||
|
||||
finalize_kwargs = finalize_result.kwargs
|
||||
is_jargon = bool(finalize_kwargs.get("is_jargon", is_jargon))
|
||||
finalized_meaning = str(finalize_kwargs.get("meaning", finalized_meaning) or "").strip() if is_jargon else ""
|
||||
is_complete = bool(finalize_kwargs.get("is_complete", is_complete))
|
||||
last_inference_count = self._coerce_int(
|
||||
finalize_kwargs.get("last_inference_count"),
|
||||
last_inference_count,
|
||||
)
|
||||
|
||||
# 更新数据库记录
|
||||
jargon_obj.is_jargon = is_jargon
|
||||
jargon_obj.meaning = inference1.get("meaning", "") if is_jargon else ""
|
||||
jargon_obj.meaning = finalized_meaning
|
||||
# 更新最后一次判定的count值,避免重启后重复判定
|
||||
jargon_obj.last_inference_count = jargon_obj.count or 0
|
||||
jargon_obj.last_inference_count = last_inference_count
|
||||
|
||||
# 如果count>=100,标记为完成,不再进行推断
|
||||
if (jargon_obj.count or 0) >= 100:
|
||||
jargon_obj.is_complete = True
|
||||
jargon_obj.is_complete = is_complete
|
||||
|
||||
try:
|
||||
self._modify_jargon_entry(jargon_obj)
|
||||
@@ -232,6 +475,22 @@ class JargonMiner:
|
||||
merged_entries[content] = {"content": content, "raw_content": set(raw_list)}
|
||||
|
||||
uniq_entries: List[JargonEntry] = list(merged_entries.values())
|
||||
before_persist_result = await self._get_runtime_manager().invoke_hook(
|
||||
"jargon.extract.before_persist",
|
||||
session_id=self.session_id,
|
||||
session_name=self.session_name,
|
||||
entries=self._serialize_jargon_entries(uniq_entries),
|
||||
)
|
||||
if before_persist_result.aborted:
|
||||
logger.info(f"[{self.session_name}] 黑话提取结果被 Hook 中止,不写入数据库")
|
||||
return
|
||||
|
||||
raw_hook_entries = before_persist_result.kwargs.get("entries")
|
||||
if raw_hook_entries is not None:
|
||||
uniq_entries = self._deserialize_jargon_entries(raw_hook_entries)
|
||||
if not uniq_entries:
|
||||
logger.info(f"[{self.session_name}] Hook 过滤后没有可写入的黑话条目")
|
||||
return
|
||||
|
||||
saved = 0
|
||||
updated = 0
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from time import perf_counter
|
||||
from typing import List, Optional, Sequence
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
@@ -26,6 +26,15 @@ from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, ToolOption, normalize_tool_options
|
||||
from src.plugin_runtime.hook_payloads import (
|
||||
deserialize_prompt_messages,
|
||||
deserialize_tool_calls,
|
||||
serialize_prompt_messages,
|
||||
serialize_tool_calls,
|
||||
serialize_tool_definitions,
|
||||
)
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from .builtin_tools import get_builtin_tools
|
||||
@@ -58,6 +67,123 @@ class ToolFilterSelection(BaseModel):
|
||||
logger = get_logger("maisaka_chat_loop")
|
||||
|
||||
|
||||
def register_maisaka_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""注册 Maisaka 规划器内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
return registry.register_hook_specs(
|
||||
[
|
||||
HookSpec(
|
||||
name="maisaka.planner.before_request",
|
||||
description="在 Maisaka 向模型发起规划请求前触发,可改写消息窗口与工具定义。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"messages": {
|
||||
"type": "array",
|
||||
"description": "即将发给模型的 PromptMessage 列表。",
|
||||
},
|
||||
"tool_definitions": {
|
||||
"type": "array",
|
||||
"description": "当前候选工具定义列表。",
|
||||
},
|
||||
"selected_history_count": {
|
||||
"type": "integer",
|
||||
"description": "当前选中的上下文消息数量。",
|
||||
},
|
||||
"built_message_count": {
|
||||
"type": "integer",
|
||||
"description": "实际发送给模型的消息数量。",
|
||||
},
|
||||
"selection_reason": {
|
||||
"type": "string",
|
||||
"description": "上下文选择说明。",
|
||||
},
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "当前会话 ID。",
|
||||
},
|
||||
},
|
||||
required=[
|
||||
"messages",
|
||||
"tool_definitions",
|
||||
"selected_history_count",
|
||||
"built_message_count",
|
||||
"selection_reason",
|
||||
"session_id",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=6000,
|
||||
allow_abort=False,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="maisaka.planner.after_response",
|
||||
description="在 Maisaka 收到模型响应后触发,可调整文本结果与工具调用列表。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"response": {
|
||||
"type": "string",
|
||||
"description": "模型返回的文本内容。",
|
||||
},
|
||||
"tool_calls": {
|
||||
"type": "array",
|
||||
"description": "模型返回的工具调用列表。",
|
||||
},
|
||||
"selected_history_count": {
|
||||
"type": "integer",
|
||||
"description": "当前选中的上下文消息数量。",
|
||||
},
|
||||
"built_message_count": {
|
||||
"type": "integer",
|
||||
"description": "实际发送给模型的消息数量。",
|
||||
},
|
||||
"selection_reason": {
|
||||
"type": "string",
|
||||
"description": "上下文选择说明。",
|
||||
},
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "当前会话 ID。",
|
||||
},
|
||||
"prompt_tokens": {
|
||||
"type": "integer",
|
||||
"description": "输入 Token 数。",
|
||||
},
|
||||
"completion_tokens": {
|
||||
"type": "integer",
|
||||
"description": "输出 Token 数。",
|
||||
},
|
||||
"total_tokens": {
|
||||
"type": "integer",
|
||||
"description": "总 Token 数。",
|
||||
},
|
||||
},
|
||||
required=[
|
||||
"response",
|
||||
"tool_calls",
|
||||
"selected_history_count",
|
||||
"built_message_count",
|
||||
"selection_reason",
|
||||
"session_id",
|
||||
"prompt_tokens",
|
||||
"completion_tokens",
|
||||
"total_tokens",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=6000,
|
||||
allow_abort=False,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class MaisakaChatLoopService:
|
||||
"""负责 Maisaka 主对话循环、系统提示词和终端渲染。"""
|
||||
|
||||
@@ -105,6 +231,35 @@ class MaisakaChatLoopService:
|
||||
|
||||
return self._personality_prompt
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
"""将任意值安全转换为整数。
|
||||
|
||||
Args:
|
||||
value: 待转换的输入值。
|
||||
default: 转换失败时的默认值。
|
||||
|
||||
Returns:
|
||||
int: 转换后的整数结果。
|
||||
"""
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
def _build_personality_prompt(self) -> str:
|
||||
"""构造人格提示词。"""
|
||||
|
||||
@@ -580,6 +735,26 @@ class MaisakaChatLoopService:
|
||||
else:
|
||||
all_tools = [*get_builtin_tools(), *self._extra_tools]
|
||||
|
||||
before_request_result = await self._get_runtime_manager().invoke_hook(
|
||||
"maisaka.planner.before_request",
|
||||
messages=serialize_prompt_messages(built_messages),
|
||||
tool_definitions=serialize_tool_definitions(all_tools),
|
||||
selected_history_count=len(selected_history),
|
||||
built_message_count=len(built_messages),
|
||||
selection_reason=selection_reason,
|
||||
session_id=self._session_id,
|
||||
)
|
||||
before_request_kwargs = before_request_result.kwargs
|
||||
raw_messages = before_request_kwargs.get("messages")
|
||||
if isinstance(raw_messages, list):
|
||||
try:
|
||||
built_messages = deserialize_prompt_messages(raw_messages)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Hook maisaka.planner.before_request 返回的 messages 无法反序列化,已忽略: {exc}")
|
||||
raw_tool_definitions = before_request_kwargs.get("tool_definitions")
|
||||
if isinstance(raw_tool_definitions, list):
|
||||
all_tools = [item for item in raw_tool_definitions if isinstance(item, dict)]
|
||||
|
||||
ordered_panels = PromptCLIVisualizer.build_prompt_panels(
|
||||
built_messages,
|
||||
image_display_mode=global_config.maisaka.terminal_image_display_mode,
|
||||
@@ -625,33 +800,63 @@ class MaisakaChatLoopService:
|
||||
)
|
||||
logger.info(f"本轮Prompt统计: {prompt_stats_text}")
|
||||
|
||||
final_response = generation_result.response or ""
|
||||
final_tool_calls = list(generation_result.tool_calls or [])
|
||||
after_response_result = await self._get_runtime_manager().invoke_hook(
|
||||
"maisaka.planner.after_response",
|
||||
response=final_response,
|
||||
tool_calls=serialize_tool_calls(final_tool_calls),
|
||||
selected_history_count=len(selected_history),
|
||||
built_message_count=len(built_messages),
|
||||
selection_reason=selection_reason,
|
||||
session_id=self._session_id,
|
||||
prompt_tokens=generation_result.prompt_tokens,
|
||||
completion_tokens=generation_result.completion_tokens,
|
||||
total_tokens=generation_result.total_tokens,
|
||||
)
|
||||
after_response_kwargs = after_response_result.kwargs
|
||||
if "response" in after_response_kwargs:
|
||||
final_response = str(after_response_kwargs.get("response") or "")
|
||||
raw_tool_calls = after_response_kwargs.get("tool_calls")
|
||||
if isinstance(raw_tool_calls, list):
|
||||
try:
|
||||
final_tool_calls = deserialize_tool_calls(raw_tool_calls)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Hook maisaka.planner.after_response 返回的 tool_calls 无法反序列化,已忽略: {exc}")
|
||||
prompt_tokens = self._coerce_int(after_response_kwargs.get("prompt_tokens"), generation_result.prompt_tokens)
|
||||
completion_tokens = self._coerce_int(
|
||||
after_response_kwargs.get("completion_tokens"),
|
||||
generation_result.completion_tokens,
|
||||
)
|
||||
total_tokens = self._coerce_int(after_response_kwargs.get("total_tokens"), generation_result.total_tokens)
|
||||
|
||||
tool_call_summaries = [
|
||||
{
|
||||
"调用编号": getattr(tool_call, "call_id", getattr(tool_call, "id", None)),
|
||||
"工具名": getattr(tool_call, "func_name", getattr(tool_call, "name", None)),
|
||||
"参数": getattr(tool_call, "args", getattr(tool_call, "arguments", None)),
|
||||
}
|
||||
for tool_call in (generation_result.tool_calls or [])
|
||||
for tool_call in final_tool_calls
|
||||
]
|
||||
logger.info(
|
||||
f"Maisaka 规划器返回结果: 内容={generation_result.response or ''!r} "
|
||||
f"Maisaka 规划器返回结果: 内容={final_response!r} "
|
||||
f"工具调用={tool_call_summaries}"
|
||||
)
|
||||
|
||||
raw_message = AssistantMessage(
|
||||
content=generation_result.response or "",
|
||||
content=final_response,
|
||||
timestamp=datetime.now(),
|
||||
tool_calls=generation_result.tool_calls or [],
|
||||
tool_calls=final_tool_calls,
|
||||
)
|
||||
return ChatResponse(
|
||||
content=generation_result.response,
|
||||
tool_calls=generation_result.tool_calls or [],
|
||||
content=final_response or None,
|
||||
tool_calls=final_tool_calls,
|
||||
raw_message=raw_message,
|
||||
selected_history_count=len(selected_history),
|
||||
prompt_tokens=generation_result.prompt_tokens,
|
||||
prompt_tokens=prompt_tokens,
|
||||
built_message_count=len(built_messages),
|
||||
completion_tokens=generation_result.completion_tokens,
|
||||
total_tokens=generation_result.total_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -54,6 +54,84 @@ class MaisakaReasoningEngine:
|
||||
self._runtime = runtime
|
||||
self._last_reasoning_content: str = ""
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_words(raw_words: Any) -> list[str]:
|
||||
"""清洗黑话查询词条列表。
|
||||
|
||||
Args:
|
||||
raw_words: 原始词条列表。
|
||||
|
||||
Returns:
|
||||
list[str]: 去重去空白后的词条列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_words, list):
|
||||
return []
|
||||
|
||||
normalized_words: list[str] = []
|
||||
seen_words: set[str] = set()
|
||||
for item in raw_words:
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
word = item.strip()
|
||||
if not word or word in seen_words:
|
||||
continue
|
||||
seen_words.add(word)
|
||||
normalized_words.append(word)
|
||||
return normalized_words
|
||||
|
||||
@staticmethod
|
||||
def _normalize_jargon_query_results(raw_results: Any) -> list[dict[str, object]]:
|
||||
"""规范化黑话查询结果列表。
|
||||
|
||||
Args:
|
||||
raw_results: Hook 返回的结果列表。
|
||||
|
||||
Returns:
|
||||
list[dict[str, object]]: 清洗后的结果列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_results, list):
|
||||
return []
|
||||
|
||||
normalized_results: list[dict[str, object]] = []
|
||||
for raw_item in raw_results:
|
||||
if not isinstance(raw_item, dict):
|
||||
continue
|
||||
word = str(raw_item.get("word") or "").strip()
|
||||
matches = raw_item.get("matches")
|
||||
normalized_matches: list[dict[str, str]] = []
|
||||
if isinstance(matches, list):
|
||||
for match in matches:
|
||||
if not isinstance(match, dict):
|
||||
continue
|
||||
content = str(match.get("content") or "").strip()
|
||||
meaning = str(match.get("meaning") or "").strip()
|
||||
if not content or not meaning:
|
||||
continue
|
||||
normalized_matches.append({"content": content, "meaning": meaning})
|
||||
|
||||
normalized_results.append(
|
||||
{
|
||||
"word": word,
|
||||
"found": bool(raw_item.get("found", bool(normalized_matches))),
|
||||
"matches": normalized_matches,
|
||||
}
|
||||
)
|
||||
return normalized_results
|
||||
|
||||
def build_builtin_tool_handlers(self) -> dict[str, "BuiltinToolHandler"]:
|
||||
"""构造 Maisaka 内置工具处理器映射。
|
||||
|
||||
@@ -1012,6 +1090,35 @@ class MaisakaReasoningEngine:
|
||||
"查询黑话工具至少需要一个非空词条。",
|
||||
)
|
||||
|
||||
limit = 5
|
||||
case_sensitive = False
|
||||
enable_fuzzy_fallback = True
|
||||
before_search_result = await self._get_runtime_manager().invoke_hook(
|
||||
"jargon.query.before_search",
|
||||
words=list(words),
|
||||
session_id=self._runtime.session_id,
|
||||
limit=limit,
|
||||
case_sensitive=case_sensitive,
|
||||
enable_fuzzy_fallback=enable_fuzzy_fallback,
|
||||
abort_message="黑话查询已被 Hook 中止。",
|
||||
)
|
||||
if before_search_result.aborted:
|
||||
abort_message = str(before_search_result.kwargs.get("abort_message") or "黑话查询已被 Hook 中止。").strip()
|
||||
return self._build_tool_failure_result(tool_call.func_name, abort_message or "黑话查询已被 Hook 中止。")
|
||||
|
||||
before_search_kwargs = before_search_result.kwargs
|
||||
if before_search_kwargs.get("words") is not None:
|
||||
words = self._normalize_words(before_search_kwargs.get("words"))
|
||||
if not words:
|
||||
return self._build_tool_failure_result(tool_call.func_name, "Hook 过滤后没有可查询的黑话词条。")
|
||||
try:
|
||||
limit = int(before_search_kwargs.get("limit", limit))
|
||||
except (TypeError, ValueError):
|
||||
limit = 5
|
||||
limit = max(limit, 1)
|
||||
case_sensitive = bool(before_search_kwargs.get("case_sensitive", case_sensitive))
|
||||
enable_fuzzy_fallback = bool(before_search_kwargs.get("enable_fuzzy_fallback", enable_fuzzy_fallback))
|
||||
|
||||
logger.info(f"{self._runtime.log_prefix} 已触发黑话查询: 词条={words!r}")
|
||||
|
||||
results: list[dict[str, object]] = []
|
||||
@@ -1019,17 +1126,19 @@ class MaisakaReasoningEngine:
|
||||
exact_matches = search_jargon(
|
||||
keyword=word,
|
||||
chat_id=self._runtime.session_id,
|
||||
limit=5,
|
||||
case_sensitive=False,
|
||||
limit=limit,
|
||||
case_sensitive=case_sensitive,
|
||||
fuzzy=False,
|
||||
)
|
||||
matched_entries = exact_matches or search_jargon(
|
||||
keyword=word,
|
||||
chat_id=self._runtime.session_id,
|
||||
limit=5,
|
||||
case_sensitive=False,
|
||||
fuzzy=True,
|
||||
)
|
||||
matched_entries = exact_matches
|
||||
if not matched_entries and enable_fuzzy_fallback:
|
||||
matched_entries = search_jargon(
|
||||
keyword=word,
|
||||
chat_id=self._runtime.session_id,
|
||||
limit=limit,
|
||||
case_sensitive=case_sensitive,
|
||||
fuzzy=True,
|
||||
)
|
||||
|
||||
results.append(
|
||||
{
|
||||
@@ -1039,6 +1148,27 @@ class MaisakaReasoningEngine:
|
||||
}
|
||||
)
|
||||
|
||||
after_search_result = await self._get_runtime_manager().invoke_hook(
|
||||
"jargon.query.after_search",
|
||||
words=list(words),
|
||||
session_id=self._runtime.session_id,
|
||||
limit=limit,
|
||||
case_sensitive=case_sensitive,
|
||||
enable_fuzzy_fallback=enable_fuzzy_fallback,
|
||||
results=list(results),
|
||||
abort_message="黑话查询结果已被 Hook 中止。",
|
||||
)
|
||||
if after_search_result.aborted:
|
||||
abort_message = str(after_search_result.kwargs.get("abort_message") or "黑话查询结果已被 Hook 中止。").strip()
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
abort_message or "黑话查询结果已被 Hook 中止。",
|
||||
)
|
||||
|
||||
raw_results = after_search_result.kwargs.get("results")
|
||||
if raw_results is not None:
|
||||
results = self._normalize_jargon_query_results(raw_results)
|
||||
|
||||
logger.info(f"{self._runtime.log_prefix} 黑话查询完成: 结果={results!r}")
|
||||
return self._build_tool_success_result(
|
||||
tool_call.func_name,
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple, cast
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -908,5 +909,27 @@ class ComponentQueryService:
|
||||
return None
|
||||
return dict(registration.config_schema)
|
||||
|
||||
def list_hook_specs(self) -> list[dict[str, Any]]:
|
||||
"""返回当前运行时公开的 Hook 规格清单。
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]: 可直接序列化给 WebUI 的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
runtime_manager = self._get_runtime_manager()
|
||||
return [
|
||||
{
|
||||
"name": spec.name,
|
||||
"description": spec.description,
|
||||
"parameters_schema": deepcopy(spec.parameters_schema),
|
||||
"default_timeout_ms": spec.default_timeout_ms,
|
||||
"allow_blocking": spec.allow_blocking,
|
||||
"allow_observe": spec.allow_observe,
|
||||
"allow_abort": spec.allow_abort,
|
||||
"allow_kwargs_mutation": spec.allow_kwargs_mutation,
|
||||
}
|
||||
for spec in runtime_manager.list_hook_specs()
|
||||
]
|
||||
|
||||
|
||||
component_query_service = ComponentQueryService()
|
||||
|
||||
52
src/plugin_runtime/hook_catalog.py
Normal file
52
src/plugin_runtime/hook_catalog.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""内置命名 Hook 目录注册器。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import List
|
||||
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
|
||||
|
||||
HookSpecRegistrar = Callable[[HookSpecRegistry], List[HookSpec]]
|
||||
"""单个业务模块向注册中心写入 Hook 规格的注册器签名。"""
|
||||
|
||||
|
||||
def _get_builtin_hook_spec_registrars() -> List[HookSpecRegistrar]:
|
||||
"""返回当前内置 Hook 规格注册器列表。
|
||||
|
||||
Returns:
|
||||
List[HookSpecRegistrar]: 已启用的内置 Hook 注册器列表。
|
||||
"""
|
||||
|
||||
from src.chat.message_receive.bot import register_chat_hook_specs
|
||||
from src.chat.emoji_system.emoji_manager import register_emoji_hook_specs
|
||||
from src.learners.expression_learner import register_expression_hook_specs
|
||||
from src.learners.jargon_miner import register_jargon_hook_specs
|
||||
from src.maisaka.chat_loop_service import register_maisaka_hook_specs
|
||||
from src.services.send_service import register_send_service_hook_specs
|
||||
|
||||
return [
|
||||
register_chat_hook_specs,
|
||||
register_emoji_hook_specs,
|
||||
register_jargon_hook_specs,
|
||||
register_expression_hook_specs,
|
||||
register_send_service_hook_specs,
|
||||
register_maisaka_hook_specs,
|
||||
]
|
||||
|
||||
|
||||
def register_builtin_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""向注册中心写入全部内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 本次完成注册后的全部内置 Hook 规格。
|
||||
"""
|
||||
|
||||
registered_specs: List[HookSpec] = []
|
||||
for registrar in _get_builtin_hook_spec_registrars():
|
||||
registered_specs.extend(registrar(registry))
|
||||
return registered_specs
|
||||
178
src/plugin_runtime/hook_payloads.py
Normal file
178
src/plugin_runtime/hook_payloads.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""运行时 Hook 载荷序列化辅助。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Sequence
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.llm_service_data_models import PromptMessage
|
||||
from src.llm_models.payload_content.message import Message
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, normalize_tool_options
|
||||
from src.plugin_runtime.host.message_utils import PluginMessageUtils
|
||||
|
||||
|
||||
def serialize_session_message(message: SessionMessage) -> Dict[str, Any]:
|
||||
"""将会话消息序列化为 Hook 可传输载荷。
|
||||
|
||||
Args:
|
||||
message: 待序列化的会话消息。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 可通过插件运行时传输的消息字典。
|
||||
"""
|
||||
|
||||
return dict(PluginMessageUtils._session_message_to_dict(message))
|
||||
|
||||
|
||||
def deserialize_session_message(raw_message: Any) -> SessionMessage:
|
||||
"""从 Hook 载荷恢复会话消息。
|
||||
|
||||
Args:
|
||||
raw_message: Hook 返回的消息字典。
|
||||
|
||||
Returns:
|
||||
SessionMessage: 恢复后的会话消息对象。
|
||||
|
||||
Raises:
|
||||
ValueError: 消息结构不合法时抛出。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_message, dict):
|
||||
raise ValueError("Hook 返回的 `message` 必须是字典")
|
||||
return PluginMessageUtils._build_session_message_from_dict(raw_message)
|
||||
|
||||
|
||||
def serialize_tool_calls(tool_calls: Sequence[ToolCall] | None) -> List[Dict[str, Any]]:
|
||||
"""将工具调用列表序列化为 Hook 可传输载荷。
|
||||
|
||||
Args:
|
||||
tool_calls: 原始工具调用列表。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 序列化后的工具调用列表。
|
||||
"""
|
||||
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
return [
|
||||
{
|
||||
"id": tool_call.call_id,
|
||||
"function": {
|
||||
"name": tool_call.func_name,
|
||||
"arguments": dict(tool_call.args or {}),
|
||||
},
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
|
||||
|
||||
def deserialize_tool_calls(raw_tool_calls: Any) -> List[ToolCall]:
|
||||
"""从 Hook 载荷恢复工具调用列表。
|
||||
|
||||
Args:
|
||||
raw_tool_calls: Hook 返回的工具调用列表。
|
||||
|
||||
Returns:
|
||||
List[ToolCall]: 恢复后的工具调用列表。
|
||||
|
||||
Raises:
|
||||
ValueError: 结构不合法时抛出。
|
||||
"""
|
||||
|
||||
if raw_tool_calls in (None, []):
|
||||
return []
|
||||
if not isinstance(raw_tool_calls, list):
|
||||
raise ValueError("Hook 返回的 `tool_calls` 必须是列表")
|
||||
|
||||
normalized_tool_calls: List[ToolCall] = []
|
||||
for raw_tool_call in raw_tool_calls:
|
||||
if not isinstance(raw_tool_call, dict):
|
||||
raise ValueError("Hook 返回的工具调用项必须是字典")
|
||||
|
||||
function_info = raw_tool_call.get("function", {})
|
||||
if isinstance(function_info, dict):
|
||||
function_name = function_info.get("name")
|
||||
function_arguments = function_info.get("arguments")
|
||||
else:
|
||||
function_name = raw_tool_call.get("name")
|
||||
function_arguments = raw_tool_call.get("arguments")
|
||||
|
||||
call_id = raw_tool_call.get("id") or raw_tool_call.get("call_id")
|
||||
if not isinstance(call_id, str) or not isinstance(function_name, str):
|
||||
raise ValueError("Hook 返回的工具调用缺少 `id` 或函数名称")
|
||||
|
||||
normalized_tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=call_id,
|
||||
func_name=function_name,
|
||||
args=function_arguments if isinstance(function_arguments, dict) else {},
|
||||
)
|
||||
)
|
||||
return normalized_tool_calls
|
||||
|
||||
|
||||
def serialize_prompt_messages(messages: Sequence[Message]) -> List[PromptMessage]:
|
||||
"""将 LLM 消息列表序列化为 Hook 可传输载荷。
|
||||
|
||||
Args:
|
||||
messages: 原始 LLM 消息列表。
|
||||
|
||||
Returns:
|
||||
List[PromptMessage]: 序列化后的消息字典列表。
|
||||
"""
|
||||
|
||||
serialized_messages: List[PromptMessage] = []
|
||||
for message in messages:
|
||||
serialized_message: PromptMessage = {
|
||||
"role": message.role.value,
|
||||
"content": message.content,
|
||||
}
|
||||
if message.tool_call_id:
|
||||
serialized_message["tool_call_id"] = message.tool_call_id
|
||||
if message.tool_calls:
|
||||
serialized_message["tool_calls"] = serialize_tool_calls(message.tool_calls)
|
||||
serialized_messages.append(serialized_message)
|
||||
return serialized_messages
|
||||
|
||||
|
||||
def deserialize_prompt_messages(raw_messages: Any) -> List[Message]:
|
||||
"""从 Hook 载荷恢复 LLM 消息列表。
|
||||
|
||||
Args:
|
||||
raw_messages: Hook 返回的消息列表。
|
||||
|
||||
Returns:
|
||||
List[Message]: 恢复后的 LLM 消息列表。
|
||||
|
||||
Raises:
|
||||
ValueError: 结构不合法时抛出。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_messages, list):
|
||||
raise ValueError("Hook 返回的 `messages` 必须是列表")
|
||||
|
||||
from src.services.llm_service import _build_message_from_dict
|
||||
|
||||
normalized_messages: List[Message] = []
|
||||
for raw_message in raw_messages:
|
||||
if not isinstance(raw_message, dict):
|
||||
raise ValueError("Hook 返回的消息项必须是字典")
|
||||
normalized_messages.append(_build_message_from_dict(raw_message))
|
||||
return normalized_messages
|
||||
|
||||
|
||||
def serialize_tool_definitions(tool_definitions: Sequence[ToolDefinitionInput]) -> List[Dict[str, Any]]:
|
||||
"""将工具定义列表序列化为 Hook 可传输载荷。
|
||||
|
||||
Args:
|
||||
tool_definitions: 原始工具定义列表。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 序列化后的工具定义列表。
|
||||
"""
|
||||
|
||||
normalized_tool_options = normalize_tool_options(list(tool_definitions))
|
||||
if not normalized_tool_options:
|
||||
return []
|
||||
return [tool_option.to_openai_function_schema() for tool_option in normalized_tool_options]
|
||||
31
src/plugin_runtime/hook_schema_utils.py
Normal file
31
src/plugin_runtime/hook_schema_utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Hook 参数模型构造辅助。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Sequence
|
||||
|
||||
|
||||
def build_object_schema(
|
||||
properties: Dict[str, Dict[str, Any]],
|
||||
*,
|
||||
required: Sequence[str] | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""构造对象级 JSON Schema。
|
||||
|
||||
Args:
|
||||
properties: 字段定义映射。
|
||||
required: 必填字段名列表。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 标准化后的对象级 Schema。
|
||||
"""
|
||||
|
||||
schema: Dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": deepcopy(properties),
|
||||
}
|
||||
normalized_required = [str(item).strip() for item in (required or []) if str(item).strip()]
|
||||
if normalized_required:
|
||||
schema["required"] = normalized_required
|
||||
return schema
|
||||
@@ -18,9 +18,37 @@ import re
|
||||
from src.common.logger import get_logger
|
||||
from src.core.tooling import build_tool_detailed_description
|
||||
|
||||
from .hook_spec_registry import HookSpecRegistry
|
||||
|
||||
logger = get_logger("plugin_runtime.host.component_registry")
|
||||
|
||||
|
||||
class ComponentRegistrationError(ValueError):
|
||||
"""组件注册失败异常。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
component_name: str = "",
|
||||
component_type: str = "",
|
||||
plugin_id: str = "",
|
||||
) -> None:
|
||||
"""初始化组件注册失败异常。
|
||||
|
||||
Args:
|
||||
message: 原始错误信息。
|
||||
component_name: 组件名称。
|
||||
component_type: 组件类型。
|
||||
plugin_id: 插件 ID。
|
||||
"""
|
||||
|
||||
self.component_name = str(component_name or "").strip()
|
||||
self.component_type = str(component_type or "").strip()
|
||||
self.plugin_id = str(plugin_id or "").strip()
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ComponentTypes(str, Enum):
|
||||
ACTION = "ACTION"
|
||||
COMMAND = "COMMAND"
|
||||
@@ -359,7 +387,14 @@ class ComponentRegistry:
|
||||
供业务层查询可用组件、匹配命令、调度 action/event 等。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, hook_spec_registry: Optional[HookSpecRegistry] = None) -> None:
|
||||
"""初始化组件注册表。
|
||||
|
||||
Args:
|
||||
hook_spec_registry: 可选的 Hook 规格注册中心;提供后会在注册
|
||||
HookHandler 时执行规格校验。
|
||||
"""
|
||||
|
||||
# 全量索引
|
||||
self._components: Dict[str, ComponentEntry] = {} # full_name -> comp
|
||||
|
||||
@@ -370,6 +405,7 @@ class ComponentRegistry:
|
||||
|
||||
# 按插件索引
|
||||
self._by_plugin: Dict[str, List[ComponentEntry]] = {}
|
||||
self._hook_spec_registry = hook_spec_registry
|
||||
|
||||
@staticmethod
|
||||
def _convert_action_metadata_to_tool_metadata(
|
||||
@@ -475,77 +511,211 @@ class ComponentRegistry:
|
||||
type_dict.clear()
|
||||
self._by_plugin.clear()
|
||||
|
||||
# ====== 注册 / 注销 ======
|
||||
def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""注册单个组件
|
||||
@staticmethod
|
||||
def _is_legacy_action_component(component: ComponentEntry) -> bool:
|
||||
"""判断组件是否为兼容旧 Action 的 Tool 条目。
|
||||
|
||||
Args:
|
||||
name: 组件名称(不含插件id前缀)
|
||||
component_type: 组件类型(如 `ACTION`、`COMMAND` 等)
|
||||
plugin_id: 插件id
|
||||
metadata: 组件元数据
|
||||
component: 待判断的组件条目。
|
||||
|
||||
Returns:
|
||||
success (bool): 是否成功注册(失败原因通常是组件类型无效)
|
||||
bool: 是否为兼容旧 Action 组件。
|
||||
"""
|
||||
|
||||
if not isinstance(component, ToolEntry):
|
||||
return False
|
||||
return str(component.metadata.get("legacy_component_type", "") or "").strip().upper() == "ACTION"
|
||||
|
||||
def _validate_hook_handler_entry(self, component: HookHandlerEntry) -> None:
|
||||
"""校验 HookHandler 是否满足已注册的 Hook 规格。
|
||||
|
||||
Args:
|
||||
component: 待校验的 HookHandler 条目。
|
||||
|
||||
Raises:
|
||||
ComponentRegistrationError: HookHandler 声明不合法时抛出。
|
||||
"""
|
||||
|
||||
if self._hook_spec_registry is None:
|
||||
return
|
||||
|
||||
hook_spec = self._hook_spec_registry.get_hook_spec(component.hook)
|
||||
if hook_spec is None:
|
||||
raise ComponentRegistrationError(
|
||||
f"HookHandler {component.full_name} 声明了未注册的 Hook: {component.hook}",
|
||||
component_name=component.name,
|
||||
component_type=component.component_type.value,
|
||||
plugin_id=component.plugin_id,
|
||||
)
|
||||
|
||||
if component.is_blocking and not hook_spec.allow_blocking:
|
||||
raise ComponentRegistrationError(
|
||||
f"HookHandler {component.full_name} 不能注册为 blocking:Hook {component.hook} 不允许 blocking 处理器",
|
||||
component_name=component.name,
|
||||
component_type=component.component_type.value,
|
||||
plugin_id=component.plugin_id,
|
||||
)
|
||||
|
||||
if component.is_observe and not hook_spec.allow_observe:
|
||||
raise ComponentRegistrationError(
|
||||
f"HookHandler {component.full_name} 不能注册为 observe:Hook {component.hook} 不允许 observe 处理器",
|
||||
component_name=component.name,
|
||||
component_type=component.component_type.value,
|
||||
plugin_id=component.plugin_id,
|
||||
)
|
||||
|
||||
if component.error_policy == "abort" and not hook_spec.allow_abort:
|
||||
raise ComponentRegistrationError(
|
||||
f"HookHandler {component.full_name} 不能使用 error_policy=abort:Hook {component.hook} 不允许 abort",
|
||||
component_name=component.name,
|
||||
component_type=component.component_type.value,
|
||||
plugin_id=component.plugin_id,
|
||||
)
|
||||
|
||||
def _build_component_entry(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
) -> ComponentEntry:
|
||||
"""根据声明构造组件条目。
|
||||
|
||||
Args:
|
||||
name: 组件名称。
|
||||
component_type: 组件类型。
|
||||
plugin_id: 插件 ID。
|
||||
metadata: 组件元数据。
|
||||
|
||||
Returns:
|
||||
ComponentEntry: 已构造并完成校验的组件条目。
|
||||
|
||||
Raises:
|
||||
ComponentRegistrationError: 组件声明不合法时抛出。
|
||||
"""
|
||||
|
||||
try:
|
||||
normalized_type = self._normalize_component_type(component_type)
|
||||
normalized_metadata = dict(metadata)
|
||||
if normalized_type == ComponentTypes.ACTION:
|
||||
normalized_metadata = self._convert_action_metadata_to_tool_metadata(name, normalized_metadata)
|
||||
comp = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata)
|
||||
component = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.COMMAND:
|
||||
comp = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.TOOL:
|
||||
comp = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.EVENT_HANDLER:
|
||||
comp = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.HOOK_HANDLER:
|
||||
comp = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
self._validate_hook_handler_entry(component)
|
||||
elif normalized_type == ComponentTypes.MESSAGE_GATEWAY:
|
||||
comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
else:
|
||||
raise ValueError(f"组件类型 {component_type} 不存在")
|
||||
except ValueError:
|
||||
logger.error(f"组件类型 {component_type} 不存在")
|
||||
return False
|
||||
raise ComponentRegistrationError(
|
||||
f"组件类型 {component_type} 不存在",
|
||||
component_name=name,
|
||||
component_type=component_type,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
except ComponentRegistrationError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise ComponentRegistrationError(
|
||||
str(exc),
|
||||
component_name=name,
|
||||
component_type=component_type,
|
||||
plugin_id=plugin_id,
|
||||
) from exc
|
||||
|
||||
if comp.full_name in self._components:
|
||||
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
|
||||
old_comp = self._components[comp.full_name]
|
||||
# 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积
|
||||
old_list = self._by_plugin.get(old_comp.plugin_id)
|
||||
if old_list is not None:
|
||||
with contextlib.suppress(ValueError):
|
||||
old_list.remove(old_comp)
|
||||
# 从旧类型索引中移除,防止类型变更时幽灵残留
|
||||
if old_type_dict := self._by_type.get(old_comp.component_type):
|
||||
old_type_dict.pop(comp.full_name, None)
|
||||
return component
|
||||
|
||||
self._components[comp.full_name] = comp
|
||||
self._by_type[comp.component_type][comp.full_name] = comp
|
||||
self._by_plugin.setdefault(plugin_id, []).append(comp)
|
||||
def _remove_existing_component_entry(self, component: ComponentEntry) -> None:
|
||||
"""移除同名旧组件条目。
|
||||
|
||||
Args:
|
||||
component: 即将写入的新组件条目。
|
||||
"""
|
||||
|
||||
if component.full_name not in self._components:
|
||||
return
|
||||
|
||||
logger.warning(f"组件 {component.full_name} 已存在,覆盖")
|
||||
old_component = self._components[component.full_name]
|
||||
old_list = self._by_plugin.get(old_component.plugin_id)
|
||||
if old_list is not None:
|
||||
with contextlib.suppress(ValueError):
|
||||
old_list.remove(old_component)
|
||||
if old_type_dict := self._by_type.get(old_component.component_type):
|
||||
old_type_dict.pop(component.full_name, None)
|
||||
|
||||
def _add_component_entry(self, component: ComponentEntry) -> None:
|
||||
"""写入单个组件条目到全部索引。
|
||||
|
||||
Args:
|
||||
component: 待写入的组件条目。
|
||||
"""
|
||||
|
||||
self._remove_existing_component_entry(component)
|
||||
self._components[component.full_name] = component
|
||||
self._by_type[component.component_type][component.full_name] = component
|
||||
self._by_plugin.setdefault(component.plugin_id, []).append(component)
|
||||
|
||||
# ====== 注册 / 注销 ======
|
||||
def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""注册单个组件。
|
||||
|
||||
Args:
|
||||
name: 组件名称(不含插件 ID 前缀)。
|
||||
component_type: 组件类型(如 ``ACTION``、``COMMAND`` 等)。
|
||||
plugin_id: 插件 ID。
|
||||
metadata: 组件元数据。
|
||||
|
||||
Returns:
|
||||
bool: 注册成功时恒为 ``True``。
|
||||
|
||||
Raises:
|
||||
ComponentRegistrationError: 组件声明不合法时抛出。
|
||||
"""
|
||||
|
||||
component = self._build_component_entry(name, component_type, plugin_id, metadata)
|
||||
self._add_component_entry(component)
|
||||
return True
|
||||
|
||||
def register_plugin_components(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
|
||||
"""批量注册一个插件的所有组件,返回成功注册数。
|
||||
"""批量替换一个插件的组件集合。
|
||||
|
||||
该方法会先完整校验所有组件声明,只有全部通过后才会替换旧组件,
|
||||
从而避免插件进入半注册状态。
|
||||
|
||||
Args:
|
||||
plugin_id (str): 插件id
|
||||
components (List[Dict[str, Any]]): 组件字典列表,每个组件包含 name, component_type, metadata 等字段
|
||||
plugin_id: 插件 ID。
|
||||
components: 组件声明字典列表。
|
||||
|
||||
Returns:
|
||||
count (int): 成功注册的组件数量
|
||||
int: 实际注册的组件数量。
|
||||
|
||||
Raises:
|
||||
ComponentRegistrationError: 任一组件声明不合法时抛出。
|
||||
"""
|
||||
count = 0
|
||||
for comp_data in components:
|
||||
ok = self.register_component(
|
||||
name=comp_data.get("name", ""),
|
||||
component_type=comp_data.get("component_type", ""),
|
||||
plugin_id=plugin_id,
|
||||
metadata=comp_data.get("metadata", {}),
|
||||
|
||||
prepared_components: List[ComponentEntry] = []
|
||||
for component_data in components:
|
||||
prepared_components.append(
|
||||
self._build_component_entry(
|
||||
name=str(component_data.get("name", "") or ""),
|
||||
component_type=str(component_data.get("component_type", "") or ""),
|
||||
plugin_id=plugin_id,
|
||||
metadata=component_data.get("metadata", {})
|
||||
if isinstance(component_data.get("metadata"), dict)
|
||||
else {},
|
||||
)
|
||||
)
|
||||
if ok:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
self.remove_components_by_plugin(plugin_id)
|
||||
for component in prepared_components:
|
||||
self._add_component_entry(component)
|
||||
return len(prepared_components)
|
||||
|
||||
def remove_components_by_plugin(self, plugin_id: str) -> int:
|
||||
"""移除某个插件的所有组件,返回移除数量。
|
||||
@@ -652,6 +822,17 @@ class ComponentRegistry:
|
||||
except ValueError:
|
||||
logger.error(f"组件类型 {component_type} 不存在")
|
||||
raise
|
||||
|
||||
if comp_type == ComponentTypes.ACTION:
|
||||
action_components = [
|
||||
component
|
||||
for component in self._by_type.get(ComponentTypes.TOOL, {}).values()
|
||||
if self._is_legacy_action_component(component)
|
||||
]
|
||||
if enabled_only:
|
||||
return [component for component in action_components if self.check_component_enabled(component, session_id)]
|
||||
return action_components
|
||||
|
||||
type_dict = self._by_type.get(comp_type, {})
|
||||
if enabled_only:
|
||||
return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)]
|
||||
@@ -854,6 +1035,34 @@ class ComponentRegistry:
|
||||
tools.append(comp)
|
||||
return tools
|
||||
|
||||
def get_tools_for_llm(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""兼容旧接口,返回可供 LLM 使用的工具条目列表。
|
||||
|
||||
Args:
|
||||
enabled_only: 是否仅返回启用的组件。
|
||||
session_id: 可选的会话 ID,若提供则考虑会话禁用状态。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 兼容旧结构的工具组件字典列表。
|
||||
"""
|
||||
|
||||
return [
|
||||
{
|
||||
"name": tool.full_name,
|
||||
"description": tool.description,
|
||||
"parameters": (
|
||||
dict(tool.parameters_raw)
|
||||
if isinstance(tool.parameters_raw, dict) and tool.parameters_raw
|
||||
else tool._get_parameters_schema() or {}
|
||||
),
|
||||
"parameters_raw": tool.parameters_raw,
|
||||
"enabled": tool.enabled,
|
||||
"plugin_id": tool.plugin_id,
|
||||
}
|
||||
for tool in self.get_tools(enabled_only=enabled_only, session_id=session_id)
|
||||
if not self._is_legacy_action_component(tool)
|
||||
]
|
||||
|
||||
# ====== 统计信息 ======
|
||||
def get_stats(self) -> StatusDict:
|
||||
"""获取注册统计。
|
||||
@@ -863,9 +1072,21 @@ class ComponentRegistry:
|
||||
"""
|
||||
return StatusDict(
|
||||
total=len(self._components),
|
||||
action=len(self._by_type[ComponentTypes.ACTION]),
|
||||
action=len(
|
||||
[
|
||||
component
|
||||
for component in self._by_type.get(ComponentTypes.TOOL, {}).values()
|
||||
if self._is_legacy_action_component(component)
|
||||
]
|
||||
),
|
||||
command=len(self._by_type[ComponentTypes.COMMAND]),
|
||||
tool=len(self._by_type[ComponentTypes.TOOL]),
|
||||
tool=len(
|
||||
[
|
||||
component
|
||||
for component in self._by_type.get(ComponentTypes.TOOL, {}).values()
|
||||
if not self._is_legacy_action_component(component)
|
||||
]
|
||||
),
|
||||
event_handler=len(self._by_type[ComponentTypes.EVENT_HANDLER]),
|
||||
hook_handler=len(self._by_type[ComponentTypes.HOOK_HANDLER]),
|
||||
message_gateway=len(self._by_type[ComponentTypes.MESSAGE_GATEWAY]),
|
||||
|
||||
@@ -26,6 +26,8 @@ import contextlib
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
from .hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .component_registry import HookHandlerEntry
|
||||
from .supervisor import PluginRunnerSupervisor
|
||||
@@ -33,29 +35,6 @@ if TYPE_CHECKING:
|
||||
logger = get_logger("plugin_runtime.host.hook_dispatcher")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HookSpec:
|
||||
"""命名 Hook 的静态规格定义。
|
||||
|
||||
Attributes:
|
||||
name: Hook 的唯一名称。
|
||||
description: Hook 描述。
|
||||
default_timeout_ms: 默认超时毫秒数;为 `0` 时退回系统默认值。
|
||||
allow_blocking: 是否允许注册阻塞处理器。
|
||||
allow_observe: 是否允许注册观察处理器。
|
||||
allow_abort: 是否允许处理器中止当前 Hook 调用。
|
||||
allow_kwargs_mutation: 是否允许阻塞处理器修改 `kwargs`。
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str = ""
|
||||
default_timeout_ms: int = 0
|
||||
allow_blocking: bool = True
|
||||
allow_observe: bool = True
|
||||
allow_abort: bool = True
|
||||
allow_kwargs_mutation: bool = True
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HookHandlerExecutionResult:
|
||||
"""单个 HookHandler 的执行结果。
|
||||
@@ -121,17 +100,19 @@ class HookDispatcher:
|
||||
def __init__(
|
||||
self,
|
||||
supervisors_provider: Optional[Callable[[], Sequence["PluginRunnerSupervisor"]]] = None,
|
||||
hook_spec_registry: Optional[HookSpecRegistry] = None,
|
||||
) -> None:
|
||||
"""初始化 Hook 分发器。
|
||||
|
||||
Args:
|
||||
supervisors_provider: 可选的 Supervisor 提供器。若调用 `invoke_hook()`
|
||||
时未显式传入 `supervisors`,则使用该回调获取目标 Supervisor 列表。
|
||||
hook_spec_registry: 可选的 Hook 规格注册中心;留空时使用独立注册中心。
|
||||
"""
|
||||
|
||||
self._background_tasks: Set[asyncio.Task[Any]] = set()
|
||||
self._hook_specs: Dict[str, HookSpec] = {}
|
||||
self._supervisors_provider = supervisors_provider
|
||||
self._hook_spec_registry = hook_spec_registry or HookSpecRegistry()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止分发器并取消所有未完成的观察任务。"""
|
||||
@@ -148,16 +129,7 @@ class HookDispatcher:
|
||||
spec: 需要注册的 Hook 规格。
|
||||
"""
|
||||
|
||||
normalized_name = self._normalize_hook_name(spec.name)
|
||||
self._hook_specs[normalized_name] = HookSpec(
|
||||
name=normalized_name,
|
||||
description=spec.description,
|
||||
default_timeout_ms=max(int(spec.default_timeout_ms), 0),
|
||||
allow_blocking=bool(spec.allow_blocking),
|
||||
allow_observe=bool(spec.allow_observe),
|
||||
allow_abort=bool(spec.allow_abort),
|
||||
allow_kwargs_mutation=bool(spec.allow_kwargs_mutation),
|
||||
)
|
||||
self._hook_spec_registry.register_hook_spec(spec)
|
||||
|
||||
def register_hook_specs(self, specs: Sequence[HookSpec]) -> None:
|
||||
"""批量注册命名 Hook 规格。
|
||||
@@ -180,14 +152,37 @@ class HookDispatcher:
|
||||
"""
|
||||
|
||||
normalized_name = self._normalize_hook_name(hook_name)
|
||||
if normalized_name in self._hook_specs:
|
||||
return self._hook_specs[normalized_name]
|
||||
registered_spec = self._hook_spec_registry.get_hook_spec(normalized_name)
|
||||
if registered_spec is not None:
|
||||
return registered_spec
|
||||
|
||||
return HookSpec(
|
||||
name=normalized_name,
|
||||
parameters_schema={},
|
||||
default_timeout_ms=self._get_default_timeout_ms(),
|
||||
)
|
||||
|
||||
def unregister_hook_spec(self, hook_name: str) -> bool:
|
||||
"""注销指定命名 Hook 规格。
|
||||
|
||||
Args:
|
||||
hook_name: 目标 Hook 名称。
|
||||
|
||||
Returns:
|
||||
bool: 是否成功注销。
|
||||
"""
|
||||
|
||||
return self._hook_spec_registry.unregister_hook_spec(hook_name)
|
||||
|
||||
def list_hook_specs(self) -> List[HookSpec]:
|
||||
"""返回当前全部显式注册的 Hook 规格。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 已注册 Hook 规格列表。
|
||||
"""
|
||||
|
||||
return self._hook_spec_registry.list_hook_specs()
|
||||
|
||||
async def invoke_hook(
|
||||
self,
|
||||
hook_name: str,
|
||||
|
||||
190
src/plugin_runtime/host/hook_spec_registry.py
Normal file
190
src/plugin_runtime/host/hook_spec_registry.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""命名 Hook 规格注册中心。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HookSpec:
|
||||
"""命名 Hook 的静态规格定义。
|
||||
|
||||
Attributes:
|
||||
name: Hook 的唯一名称。
|
||||
description: Hook 描述。
|
||||
parameters_schema: Hook 参数模型,使用对象级 JSON Schema 表示。
|
||||
default_timeout_ms: 默认超时毫秒数;为 ``0`` 时退回系统默认值。
|
||||
allow_blocking: 是否允许注册阻塞处理器。
|
||||
allow_observe: 是否允许注册观察处理器。
|
||||
allow_abort: 是否允许处理器中止当前 Hook 调用。
|
||||
allow_kwargs_mutation: 是否允许阻塞处理器修改 ``kwargs``。
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str = ""
|
||||
parameters_schema: Dict[str, Any] = field(default_factory=dict)
|
||||
default_timeout_ms: int = 0
|
||||
allow_blocking: bool = True
|
||||
allow_observe: bool = True
|
||||
allow_abort: bool = True
|
||||
allow_kwargs_mutation: bool = True
|
||||
|
||||
|
||||
class HookSpecRegistry:
|
||||
"""命名 Hook 规格注册中心。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化 Hook 规格注册中心。"""
|
||||
|
||||
self._hook_specs: Dict[str, HookSpec] = {}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_hook_name(hook_name: str) -> str:
|
||||
"""规范化 Hook 名称。
|
||||
|
||||
Args:
|
||||
hook_name: 原始 Hook 名称。
|
||||
|
||||
Returns:
|
||||
str: 规范化后的 Hook 名称。
|
||||
|
||||
Raises:
|
||||
ValueError: Hook 名称为空时抛出。
|
||||
"""
|
||||
|
||||
normalized_name = str(hook_name or "").strip()
|
||||
if not normalized_name:
|
||||
raise ValueError("Hook 名称不能为空")
|
||||
return normalized_name
|
||||
|
||||
@staticmethod
|
||||
def _normalize_parameters_schema(raw_schema: Any) -> Dict[str, Any]:
|
||||
"""规范化 Hook 参数模型。
|
||||
|
||||
Args:
|
||||
raw_schema: 原始参数模型。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 规范化后的对象级 JSON Schema。
|
||||
|
||||
Raises:
|
||||
ValueError: 参数模型不是合法对象级 Schema 时抛出。
|
||||
"""
|
||||
|
||||
if raw_schema is None:
|
||||
return {}
|
||||
if not isinstance(raw_schema, dict):
|
||||
raise ValueError("Hook 参数模型必须是字典")
|
||||
if not raw_schema:
|
||||
return {}
|
||||
|
||||
normalized_schema = deepcopy(raw_schema)
|
||||
schema_type = normalized_schema.get("type")
|
||||
properties = normalized_schema.get("properties")
|
||||
if schema_type not in {"", None, "object"} and properties is None:
|
||||
raise ValueError("Hook 参数模型必须是 object 类型或属性映射")
|
||||
if schema_type in {"", None} and properties is None:
|
||||
normalized_schema = {
|
||||
"type": "object",
|
||||
"properties": normalized_schema,
|
||||
}
|
||||
elif schema_type in {"", None}:
|
||||
normalized_schema["type"] = "object"
|
||||
|
||||
if normalized_schema.get("type") != "object":
|
||||
raise ValueError("Hook 参数模型必须是 object 类型")
|
||||
return normalized_schema
|
||||
|
||||
@classmethod
|
||||
def _normalize_spec(cls, spec: HookSpec) -> HookSpec:
|
||||
"""规范化 Hook 规格对象。
|
||||
|
||||
Args:
|
||||
spec: 原始 Hook 规格。
|
||||
|
||||
Returns:
|
||||
HookSpec: 规范化后的 Hook 规格副本。
|
||||
"""
|
||||
|
||||
return HookSpec(
|
||||
name=cls._normalize_hook_name(spec.name),
|
||||
description=str(spec.description or "").strip(),
|
||||
parameters_schema=cls._normalize_parameters_schema(spec.parameters_schema),
|
||||
default_timeout_ms=max(int(spec.default_timeout_ms), 0),
|
||||
allow_blocking=bool(spec.allow_blocking),
|
||||
allow_observe=bool(spec.allow_observe),
|
||||
allow_abort=bool(spec.allow_abort),
|
||||
allow_kwargs_mutation=bool(spec.allow_kwargs_mutation),
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空全部 Hook 规格。"""
|
||||
|
||||
self._hook_specs.clear()
|
||||
|
||||
def register_hook_spec(self, spec: HookSpec) -> HookSpec:
|
||||
"""注册单个 Hook 规格。
|
||||
|
||||
Args:
|
||||
spec: 需要注册的 Hook 规格。
|
||||
|
||||
Returns:
|
||||
HookSpec: 规范化后实际注册的 Hook 规格。
|
||||
"""
|
||||
|
||||
normalized_spec = self._normalize_spec(spec)
|
||||
self._hook_specs[normalized_spec.name] = normalized_spec
|
||||
return normalized_spec
|
||||
|
||||
def register_hook_specs(self, specs: Sequence[HookSpec]) -> List[HookSpec]:
|
||||
"""批量注册 Hook 规格。
|
||||
|
||||
Args:
|
||||
specs: 需要注册的 Hook 规格列表。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 规范化后实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
return [self.register_hook_spec(spec) for spec in specs]
|
||||
|
||||
def unregister_hook_spec(self, hook_name: str) -> bool:
|
||||
"""注销指定 Hook 规格。
|
||||
|
||||
Args:
|
||||
hook_name: 目标 Hook 名称。
|
||||
|
||||
Returns:
|
||||
bool: 是否成功删除。
|
||||
"""
|
||||
|
||||
normalized_name = self._normalize_hook_name(hook_name)
|
||||
return self._hook_specs.pop(normalized_name, None) is not None
|
||||
|
||||
def get_hook_spec(self, hook_name: str) -> Optional[HookSpec]:
|
||||
"""获取指定 Hook 的显式规格。
|
||||
|
||||
Args:
|
||||
hook_name: 目标 Hook 名称。
|
||||
|
||||
Returns:
|
||||
Optional[HookSpec]: 已注册时返回规格副本,否则返回 ``None``。
|
||||
"""
|
||||
|
||||
normalized_name = self._normalize_hook_name(hook_name)
|
||||
spec = self._hook_specs.get(normalized_name)
|
||||
return None if spec is None else self._normalize_spec(spec)
|
||||
|
||||
def list_hook_specs(self) -> List[HookSpec]:
|
||||
"""返回当前全部 Hook 规格。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 按 Hook 名称升序排列的规格副本列表。
|
||||
"""
|
||||
|
||||
return [
|
||||
self._normalize_spec(spec)
|
||||
for _, spec in sorted(self._hook_specs.items(), key=lambda item: item[0])
|
||||
]
|
||||
@@ -27,6 +27,8 @@ from src.plugin_runtime.protocol.envelope import (
|
||||
ConfigUpdatedPayload,
|
||||
Envelope,
|
||||
HealthPayload,
|
||||
InspectPluginConfigPayload,
|
||||
InspectPluginConfigResultPayload,
|
||||
MessageGatewayStateUpdatePayload,
|
||||
MessageGatewayStateUpdateResultPayload,
|
||||
PROTOCOL_VERSION,
|
||||
@@ -52,6 +54,7 @@ from .capability_service import CapabilityService
|
||||
from .component_registry import ComponentRegistry
|
||||
from .event_dispatcher import EventDispatcher
|
||||
from .hook_dispatcher import HookDispatchResult, HookDispatcher
|
||||
from .hook_spec_registry import HookSpecRegistry
|
||||
from .logger_bridge import RunnerLogBridge
|
||||
from .message_gateway import MessageGateway
|
||||
from .rpc_server import RPCServer
|
||||
@@ -84,6 +87,7 @@ class PluginRunnerSupervisor:
|
||||
self,
|
||||
plugin_dirs: Optional[List[Path]] = None,
|
||||
group_name: str = "third_party",
|
||||
hook_spec_registry: Optional[HookSpecRegistry] = None,
|
||||
socket_path: Optional[str] = None,
|
||||
health_check_interval_sec: Optional[float] = None,
|
||||
max_restart_attempts: Optional[int] = None,
|
||||
@@ -94,6 +98,7 @@ class PluginRunnerSupervisor:
|
||||
Args:
|
||||
plugin_dirs: 由当前 Runner 负责加载的插件目录列表。
|
||||
group_name: 当前 Supervisor 所属运行时分组名称。
|
||||
hook_spec_registry: 可选的共享 Hook 规格注册中心。
|
||||
socket_path: 自定义 IPC 地址;留空时由传输层自动生成。
|
||||
health_check_interval_sec: 健康检查间隔,单位秒。
|
||||
max_restart_attempts: 自动重启 Runner 的最大次数。
|
||||
@@ -110,9 +115,12 @@ class PluginRunnerSupervisor:
|
||||
self._authorization = AuthorizationManager()
|
||||
self._capability_service = CapabilityService(self._authorization)
|
||||
self._api_registry = APIRegistry()
|
||||
self._component_registry = ComponentRegistry()
|
||||
self._component_registry = ComponentRegistry(hook_spec_registry=hook_spec_registry)
|
||||
self._event_dispatcher = EventDispatcher(self._component_registry)
|
||||
self._hook_dispatcher = HookDispatcher(lambda: [self])
|
||||
self._hook_dispatcher = HookDispatcher(
|
||||
lambda: [self],
|
||||
hook_spec_registry=hook_spec_registry,
|
||||
)
|
||||
self._message_gateway = MessageGateway(self._component_registry)
|
||||
self._log_bridge = RunnerLogBridge()
|
||||
|
||||
@@ -581,6 +589,49 @@ class PluginRunnerSupervisor:
|
||||
raise ValueError("插件配置校验失败")
|
||||
return dict(result.normalized_config)
|
||||
|
||||
async def inspect_plugin_config(
|
||||
self,
|
||||
plugin_id: str,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
use_provided_config: bool = False,
|
||||
) -> InspectPluginConfigResultPayload:
|
||||
"""请求 Runner 解析插件配置元数据。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
config_data: 可选的配置内容。
|
||||
use_provided_config: 是否优先使用传入配置而不是磁盘配置。
|
||||
|
||||
Returns:
|
||||
InspectPluginConfigResultPayload: 插件配置解析结果。
|
||||
|
||||
Raises:
|
||||
ValueError: Runner 无法解析插件或返回了错误响应时抛出。
|
||||
"""
|
||||
|
||||
payload = InspectPluginConfigPayload(
|
||||
config_data=config_data or {},
|
||||
use_provided_config=use_provided_config,
|
||||
)
|
||||
try:
|
||||
response = await self._rpc_server.send_request(
|
||||
"plugin.inspect_config",
|
||||
plugin_id=plugin_id,
|
||||
payload=payload.model_dump(),
|
||||
timeout_ms=10000,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise ValueError(f"插件配置解析请求失败: {exc}") from exc
|
||||
|
||||
if response.error:
|
||||
raise ValueError(str(response.error.get("message", "插件配置解析失败")))
|
||||
|
||||
result = InspectPluginConfigResultPayload.model_validate(response.payload)
|
||||
if not result.success:
|
||||
raise ValueError("插件配置解析失败")
|
||||
return result
|
||||
|
||||
def get_config_reload_subscribers(self, scope: str) -> List[str]:
|
||||
"""返回订阅指定全局配置广播的插件列表。
|
||||
|
||||
@@ -713,15 +764,25 @@ class PluginRunnerSupervisor:
|
||||
|
||||
component_declarations = [component.model_dump() for component in payload.components]
|
||||
runtime_components, api_components = self._split_component_declarations(component_declarations)
|
||||
self._component_registry.remove_components_by_plugin(payload.plugin_id)
|
||||
self._api_registry.remove_apis_by_plugin(payload.plugin_id)
|
||||
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
|
||||
try:
|
||||
registered_count = self._component_registry.register_plugin_components(
|
||||
payload.plugin_id,
|
||||
runtime_components,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"插件 {payload.plugin_id} 组件注册失败: {exc}")
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_BAD_PAYLOAD.value,
|
||||
str(exc),
|
||||
details={
|
||||
"plugin_id": payload.plugin_id,
|
||||
"component_count": len(runtime_components),
|
||||
},
|
||||
)
|
||||
|
||||
registered_count = self._component_registry.register_plugin_components(
|
||||
payload.plugin_id,
|
||||
runtime_components,
|
||||
)
|
||||
self._api_registry.remove_apis_by_plugin(payload.plugin_id)
|
||||
registered_api_count = self._api_registry.register_plugin_apis(payload.plugin_id, api_components)
|
||||
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
|
||||
self._registered_plugins[payload.plugin_id] = payload
|
||||
self._message_gateway_states[payload.plugin_id] = {}
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
|
||||
import tomlkit
|
||||
|
||||
@@ -32,14 +33,17 @@ from src.common.logger import get_logger
|
||||
from src.config.config import config_manager
|
||||
from src.config.file_watcher import FileChange, FileWatcher
|
||||
from src.platform_io import DeliveryBatch, InboundMessageEnvelope, get_platform_io_manager
|
||||
from src.plugin_runtime.hook_catalog import register_builtin_hook_specs
|
||||
from src.plugin_runtime.capabilities import (
|
||||
RuntimeComponentCapabilityMixin,
|
||||
RuntimeCoreCapabilityMixin,
|
||||
RuntimeDataCapabilityMixin,
|
||||
)
|
||||
from src.plugin_runtime.capabilities.registry import register_capability_impls
|
||||
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult, HookDispatcher, HookSpec
|
||||
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult, HookDispatcher
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils
|
||||
from src.plugin_runtime.protocol.envelope import InspectPluginConfigResultPayload
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -87,7 +91,12 @@ class PluginRuntimeManager(
|
||||
self._manifest_validator: ManifestValidator = ManifestValidator()
|
||||
self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload
|
||||
self._config_reload_callback_registered: bool = False
|
||||
self._hook_dispatcher: HookDispatcher = HookDispatcher(lambda: self.supervisors)
|
||||
self._hook_spec_registry: HookSpecRegistry = HookSpecRegistry()
|
||||
self._builtin_hook_specs_registered: bool = False
|
||||
self._hook_dispatcher: HookDispatcher = HookDispatcher(
|
||||
lambda: self.supervisors,
|
||||
hook_spec_registry=self._hook_spec_registry,
|
||||
)
|
||||
|
||||
async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
|
||||
"""接收 Platform IO 审核后的入站消息并送入主消息链。
|
||||
@@ -155,6 +164,33 @@ class PluginRuntimeManager(
|
||||
return ["third_party", "builtin"]
|
||||
return ["builtin", "third_party"]
|
||||
|
||||
@staticmethod
|
||||
def _instantiate_supervisor(supervisor_cls: Any, **kwargs: Any) -> Any:
|
||||
"""兼容不同构造签名地实例化 Supervisor。
|
||||
|
||||
Args:
|
||||
supervisor_cls: 目标 Supervisor 类。
|
||||
**kwargs: 期望传入的构造参数。
|
||||
|
||||
Returns:
|
||||
Any: 实例化后的 Supervisor。
|
||||
"""
|
||||
|
||||
signature = inspect.signature(supervisor_cls)
|
||||
accepts_var_keyword = any(
|
||||
parameter.kind == inspect.Parameter.VAR_KEYWORD
|
||||
for parameter in signature.parameters.values()
|
||||
)
|
||||
if accepts_var_keyword:
|
||||
return supervisor_cls(**kwargs)
|
||||
|
||||
supported_kwargs = {
|
||||
key: value
|
||||
for key, value in kwargs.items()
|
||||
if key in signature.parameters
|
||||
}
|
||||
return supervisor_cls(**supported_kwargs)
|
||||
|
||||
# ─── 生命周期 ─────────────────────────────────────────────
|
||||
|
||||
async def start(self) -> None:
|
||||
@@ -185,6 +221,7 @@ class PluginRuntimeManager(
|
||||
logger.info("未找到任何插件目录,跳过插件运行时启动")
|
||||
return
|
||||
|
||||
self.ensure_builtin_hook_specs_registered()
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
|
||||
# 从配置读取自定义 IPC socket 路径(留空则自动生成)
|
||||
@@ -196,17 +233,21 @@ class PluginRuntimeManager(
|
||||
|
||||
# 创建两个 Supervisor,各自拥有独立的 socket / Runner 子进程
|
||||
if builtin_dirs:
|
||||
self._builtin_supervisor = PluginSupervisor(
|
||||
self._builtin_supervisor = self._instantiate_supervisor(
|
||||
PluginSupervisor,
|
||||
plugin_dirs=builtin_dirs,
|
||||
group_name="builtin",
|
||||
hook_spec_registry=self._hook_spec_registry,
|
||||
socket_path=builtin_socket,
|
||||
)
|
||||
self._register_capability_impls(self._builtin_supervisor)
|
||||
|
||||
if third_party_dirs:
|
||||
self._third_party_supervisor = PluginSupervisor(
|
||||
self._third_party_supervisor = self._instantiate_supervisor(
|
||||
PluginSupervisor,
|
||||
plugin_dirs=third_party_dirs,
|
||||
group_name="third_party",
|
||||
hook_spec_registry=self._hook_spec_registry,
|
||||
socket_path=third_party_socket,
|
||||
)
|
||||
self._register_capability_impls(self._third_party_supervisor)
|
||||
@@ -328,6 +369,7 @@ class PluginRuntimeManager(
|
||||
spec: 需要注册的 Hook 规格。
|
||||
"""
|
||||
|
||||
self.ensure_builtin_hook_specs_registered()
|
||||
self._hook_dispatcher.register_hook_spec(spec)
|
||||
|
||||
def register_hook_specs(self, specs: Sequence[HookSpec]) -> None:
|
||||
@@ -337,8 +379,41 @@ class PluginRuntimeManager(
|
||||
specs: 需要注册的 Hook 规格序列。
|
||||
"""
|
||||
|
||||
self.ensure_builtin_hook_specs_registered()
|
||||
self._hook_dispatcher.register_hook_specs(specs)
|
||||
|
||||
def unregister_hook_spec(self, hook_name: str) -> bool:
|
||||
"""注销指定命名 Hook 规格。
|
||||
|
||||
Args:
|
||||
hook_name: 目标 Hook 名称。
|
||||
|
||||
Returns:
|
||||
bool: 是否成功注销。
|
||||
"""
|
||||
|
||||
self.ensure_builtin_hook_specs_registered()
|
||||
return self._hook_dispatcher.unregister_hook_spec(hook_name)
|
||||
|
||||
def list_hook_specs(self) -> List[HookSpec]:
|
||||
"""返回当前全部命名 Hook 规格。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 当前已注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
self.ensure_builtin_hook_specs_registered()
|
||||
return self._hook_dispatcher.list_hook_specs()
|
||||
|
||||
def ensure_builtin_hook_specs_registered(self) -> None:
|
||||
"""确保内置 Hook 规格已经注册到共享中心表。"""
|
||||
|
||||
if self._builtin_hook_specs_registered:
|
||||
return
|
||||
|
||||
register_builtin_hook_specs(self._hook_spec_registry)
|
||||
self._builtin_hook_specs_registered = True
|
||||
|
||||
def _build_registered_dependency_map(self) -> Dict[str, Set[str]]:
|
||||
"""根据当前已注册插件构建全局依赖图。"""
|
||||
|
||||
@@ -542,8 +617,8 @@ class PluginRuntimeManager(
|
||||
config_data: 待校验的配置内容。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | None: 校验成功时返回规范化后的配置;若插件当前未加载
|
||||
或运行时不可用,则返回 ``None`` 以便调用方回退到静态 Schema 方案。
|
||||
Dict[str, Any] | None: 校验成功时返回规范化后的配置;若插件不存在、
|
||||
当前不可路由或运行时不可用,则返回 ``None`` 以便调用方回退到弱推断方案。
|
||||
|
||||
Raises:
|
||||
ValueError: 插件已加载,但配置校验失败时抛出。
|
||||
@@ -558,6 +633,8 @@ class PluginRuntimeManager(
|
||||
logger.warning(f"插件 {plugin_id} 配置校验路由失败,将回退到静态 Schema: {exc}")
|
||||
return None
|
||||
|
||||
if supervisor is None:
|
||||
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
|
||||
if supervisor is None:
|
||||
return None
|
||||
|
||||
@@ -569,6 +646,54 @@ class PluginRuntimeManager(
|
||||
logger.warning(f"插件 {plugin_id} 运行时配置校验不可用,将回退到静态 Schema: {exc}")
|
||||
return None
|
||||
|
||||
async def inspect_plugin_config(
|
||||
self,
|
||||
plugin_id: str,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
use_provided_config: bool = False,
|
||||
) -> InspectPluginConfigResultPayload | None:
|
||||
"""请求运行时解析插件配置元数据。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
config_data: 可选的配置内容。
|
||||
use_provided_config: 是否优先使用传入的配置内容而不是磁盘配置。
|
||||
|
||||
Returns:
|
||||
InspectPluginConfigResultPayload | None: 解析成功时返回结构化结果;若插件
|
||||
当前不可路由或运行时不可用,则返回 ``None``。
|
||||
|
||||
Raises:
|
||||
ValueError: 插件存在,但运行时明确拒绝解析请求时抛出。
|
||||
"""
|
||||
|
||||
if not self._started:
|
||||
return None
|
||||
|
||||
try:
|
||||
supervisor = self._get_supervisor_for_plugin(plugin_id)
|
||||
except RuntimeError as exc:
|
||||
logger.warning(f"插件 {plugin_id} 配置解析路由失败: {exc}")
|
||||
return None
|
||||
|
||||
if supervisor is None:
|
||||
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
|
||||
if supervisor is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return await supervisor.inspect_plugin_config(
|
||||
plugin_id=plugin_id,
|
||||
config_data=config_data,
|
||||
use_provided_config=use_provided_config,
|
||||
)
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.warning(f"插件 {plugin_id} 配置解析不可用: {exc}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_config_reload_scopes(changed_scopes: Sequence[str]) -> tuple[str, ...]:
|
||||
"""规范化配置热重载范围列表。
|
||||
@@ -771,7 +896,15 @@ class PluginRuntimeManager(
|
||||
return matches[0] if matches else None
|
||||
|
||||
async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool:
|
||||
"""加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。"""
|
||||
"""加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
reason: 加载或重载原因。
|
||||
|
||||
Returns:
|
||||
bool: 插件最终是否处于已加载状态。
|
||||
"""
|
||||
|
||||
normalized_plugin_id = str(plugin_id or "").strip()
|
||||
if not normalized_plugin_id:
|
||||
@@ -789,11 +922,12 @@ class PluginRuntimeManager(
|
||||
if supervisor is None:
|
||||
return False
|
||||
|
||||
return await supervisor.reload_plugins(
|
||||
reloaded = await supervisor.reload_plugins(
|
||||
plugin_ids=[normalized_plugin_id],
|
||||
reason=reason,
|
||||
external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
|
||||
)
|
||||
return reloaded and normalized_plugin_id in supervisor.get_loaded_plugin_ids()
|
||||
|
||||
@classmethod
|
||||
def _find_duplicate_plugin_ids(cls, plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
|
||||
@@ -920,15 +1054,16 @@ class PluginRuntimeManager(
|
||||
return None
|
||||
|
||||
def _refresh_plugin_config_watch_subscriptions(self) -> None:
|
||||
"""按当前已注册插件集合刷新 config.toml 的单插件订阅。
|
||||
"""按当前可识别插件集合刷新 config.toml 的单插件订阅。
|
||||
|
||||
当插件热重载后,插件集合或目录位置可能发生变化,因此需要重新对齐
|
||||
watcher 的订阅,确保每个插件配置变更只触发对应 plugin_id。
|
||||
这里不仅覆盖当前已注册插件,也覆盖已存在但暂未激活的合法插件。
|
||||
"""
|
||||
if self._plugin_file_watcher is None:
|
||||
return
|
||||
|
||||
desired_plugin_paths = dict(self._iter_registered_plugin_paths())
|
||||
desired_plugin_paths = dict(self._iter_watchable_plugin_paths())
|
||||
self._plugin_path_cache = desired_plugin_paths.copy()
|
||||
desired_config_paths = {
|
||||
plugin_id: plugin_path / "config.toml" for plugin_id, plugin_path in desired_plugin_paths.items()
|
||||
@@ -970,6 +1105,18 @@ class PluginRuntimeManager(
|
||||
if plugin_path := self._get_plugin_path_for_supervisor(supervisor, plugin_id):
|
||||
yield plugin_id, plugin_path
|
||||
|
||||
def _iter_watchable_plugin_paths(self) -> Iterable[Tuple[str, Path]]:
|
||||
"""迭代应被配置监听器追踪的插件目录。
|
||||
|
||||
Returns:
|
||||
Iterable[Tuple[str, Path]]: ``(plugin_id, plugin_path)`` 迭代器。
|
||||
"""
|
||||
|
||||
watchable_plugin_paths = dict(self._iter_discovered_plugin_paths(self._iter_plugin_dirs()))
|
||||
for plugin_id, plugin_path in self._iter_registered_plugin_paths():
|
||||
watchable_plugin_paths.setdefault(plugin_id, plugin_path)
|
||||
yield from watchable_plugin_paths.items()
|
||||
|
||||
def _get_plugin_config_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
|
||||
"""从指定 Supervisor 的插件目录中定位某个插件的 config.toml。"""
|
||||
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
|
||||
@@ -993,18 +1140,43 @@ class PluginRuntimeManager(
|
||||
return
|
||||
|
||||
if supervisor is None:
|
||||
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
|
||||
if supervisor is None:
|
||||
return
|
||||
|
||||
plugin_is_loaded = plugin_id in getattr(supervisor, "_registered_plugins", {})
|
||||
|
||||
try:
|
||||
snapshot = await supervisor.inspect_plugin_config(plugin_id)
|
||||
except Exception as exc:
|
||||
logger.warning(f"插件 {plugin_id} 配置文件变更解析失败: {exc}")
|
||||
return
|
||||
|
||||
try:
|
||||
config_payload = self._load_plugin_config_for_supervisor(supervisor, plugin_id)
|
||||
delivered = await supervisor.notify_plugin_config_updated(
|
||||
plugin_id=plugin_id,
|
||||
config_data=config_payload,
|
||||
config_version="",
|
||||
config_scope="self",
|
||||
)
|
||||
if not delivered:
|
||||
logger.warning(f"插件 {plugin_id} 配置文件变更后通知失败")
|
||||
if plugin_is_loaded and snapshot.enabled:
|
||||
delivered = await supervisor.notify_plugin_config_updated(
|
||||
plugin_id=plugin_id,
|
||||
config_data=dict(snapshot.normalized_config),
|
||||
config_version="",
|
||||
config_scope="self",
|
||||
)
|
||||
if not delivered:
|
||||
logger.warning(f"插件 {plugin_id} 配置文件变更后通知失败")
|
||||
return
|
||||
|
||||
if plugin_is_loaded and not snapshot.enabled:
|
||||
reloaded = await self.reload_plugins_globally([plugin_id], reason="config_disabled")
|
||||
if not reloaded:
|
||||
logger.warning(f"插件 {plugin_id} 禁用配置已写入,但运行时卸载失败")
|
||||
return
|
||||
|
||||
if not snapshot.enabled:
|
||||
logger.info(f"插件 {plugin_id} 当前处于禁用状态,跳过自动加载")
|
||||
return
|
||||
|
||||
loaded = await self.load_plugin_globally(plugin_id, reason="config_enabled")
|
||||
if not loaded:
|
||||
logger.warning(f"插件 {plugin_id} 配置文件变更后自动加载失败")
|
||||
except Exception as exc:
|
||||
logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}")
|
||||
|
||||
|
||||
@@ -288,6 +288,8 @@ class RunnerReadyPayload(BaseModel):
|
||||
"""已完成初始化的插件列表"""
|
||||
failed_plugins: List[str] = Field(default_factory=list, description="初始化失败的插件列表")
|
||||
"""初始化失败的插件列表"""
|
||||
inactive_plugins: List[str] = Field(default_factory=list, description="当前因禁用或依赖不可用而未激活的插件列表")
|
||||
"""当前因禁用或依赖不可用而未激活的插件列表"""
|
||||
|
||||
|
||||
# ====== 配置更新 ======
|
||||
@@ -311,6 +313,32 @@ class ValidatePluginConfigPayload(BaseModel):
|
||||
"""待校验的配置内容"""
|
||||
|
||||
|
||||
class InspectPluginConfigPayload(BaseModel):
|
||||
"""plugin.inspect_config 请求 payload。"""
|
||||
|
||||
config_data: Dict[str, Any] = Field(default_factory=dict, description="可选的配置内容")
|
||||
"""可选的配置内容"""
|
||||
use_provided_config: bool = Field(default=False, description="是否优先使用请求中携带的配置内容")
|
||||
"""是否优先使用请求中携带的配置内容"""
|
||||
|
||||
|
||||
class InspectPluginConfigResultPayload(BaseModel):
|
||||
"""plugin.inspect_config 响应 payload。"""
|
||||
|
||||
success: bool = Field(description="是否解析成功")
|
||||
"""是否解析成功"""
|
||||
default_config: Dict[str, Any] = Field(default_factory=dict, description="插件默认配置")
|
||||
"""插件默认配置"""
|
||||
config_schema: Dict[str, Any] = Field(default_factory=dict, description="插件配置 Schema")
|
||||
"""插件配置 Schema"""
|
||||
normalized_config: Dict[str, Any] = Field(default_factory=dict, description="归一化后的配置内容")
|
||||
"""归一化后的配置内容"""
|
||||
changed: bool = Field(default=False, description="是否在归一化过程中自动补齐或修正了配置")
|
||||
"""是否在归一化过程中自动补齐或修正了配置"""
|
||||
enabled: bool = Field(default=True, description="插件在当前配置下是否应被视为启用")
|
||||
"""插件在当前配置下是否应被视为启用"""
|
||||
|
||||
|
||||
class ValidatePluginConfigResultPayload(BaseModel):
|
||||
"""plugin.validate_config 响应 payload。"""
|
||||
|
||||
@@ -380,6 +408,8 @@ class ReloadPluginResultPayload(BaseModel):
|
||||
"""成功完成重载的插件列表"""
|
||||
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
|
||||
"""本次已卸载的插件列表"""
|
||||
inactive_plugins: List[str] = Field(default_factory=list, description="本次处于未激活状态的插件列表")
|
||||
"""本次处于未激活状态的插件列表"""
|
||||
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
|
||||
"""重载失败的插件及原因"""
|
||||
|
||||
@@ -395,6 +425,8 @@ class ReloadPluginsResultPayload(BaseModel):
|
||||
"""成功完成重载的插件列表"""
|
||||
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
|
||||
"""本次已卸载的插件列表"""
|
||||
inactive_plugins: List[str] = Field(default_factory=list, description="本次处于未激活状态的插件列表")
|
||||
"""本次处于未激活状态的插件列表"""
|
||||
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
|
||||
"""重载失败的插件及原因"""
|
||||
|
||||
|
||||
@@ -9,8 +9,10 @@
|
||||
6. 转发插件的能力调用到 Host
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Protocol, Set, Tuple, cast
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
@@ -39,6 +41,8 @@ from src.plugin_runtime.protocol.envelope import (
|
||||
ConfigUpdatedPayload,
|
||||
Envelope,
|
||||
HealthPayload,
|
||||
InspectPluginConfigPayload,
|
||||
InspectPluginConfigResultPayload,
|
||||
InvokePayload,
|
||||
InvokeResultPayload,
|
||||
RegisterPluginPayload,
|
||||
@@ -141,6 +145,14 @@ class _ConfigAwarePlugin(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class PluginActivationStatus(str, Enum):
|
||||
"""描述插件激活结果。"""
|
||||
|
||||
LOADED = "loaded"
|
||||
INACTIVE = "inactive"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
def _install_shutdown_signal_handlers(
|
||||
mark_runner_shutting_down: Callable[[], None],
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
@@ -236,13 +248,43 @@ class PluginRunner:
|
||||
|
||||
# 4. 注入 PluginContext + 调用 on_load 生命周期钩子
|
||||
failed_plugins: Set[str] = set(self._loader.failed_plugins.keys())
|
||||
inactive_plugins: Set[str] = set()
|
||||
available_plugin_versions: Dict[str, str] = dict(self._external_available_plugins)
|
||||
for meta in plugins:
|
||||
ok = await self._activate_plugin(meta)
|
||||
if not ok:
|
||||
unsatisfied_dependencies = [
|
||||
dependency.id
|
||||
for dependency in meta.manifest.plugin_dependencies
|
||||
if dependency.id not in available_plugin_versions
|
||||
or not self._loader.manifest_validator.is_plugin_dependency_satisfied(
|
||||
dependency,
|
||||
available_plugin_versions[dependency.id],
|
||||
)
|
||||
]
|
||||
if unsatisfied_dependencies:
|
||||
if any(dependency_id in inactive_plugins for dependency_id in unsatisfied_dependencies):
|
||||
logger.info(
|
||||
f"插件 {meta.plugin_id} 依赖的插件当前未激活,跳过本次启动: {', '.join(unsatisfied_dependencies)}"
|
||||
)
|
||||
inactive_plugins.add(meta.plugin_id)
|
||||
continue
|
||||
failed_plugins.add(meta.plugin_id)
|
||||
continue
|
||||
|
||||
successful_plugins = [meta.plugin_id for meta in plugins if meta.plugin_id not in failed_plugins]
|
||||
await self._notify_ready(successful_plugins, sorted(failed_plugins))
|
||||
activation_status = await self._activate_plugin(meta)
|
||||
if activation_status == PluginActivationStatus.LOADED:
|
||||
available_plugin_versions[meta.plugin_id] = meta.version
|
||||
continue
|
||||
if activation_status == PluginActivationStatus.INACTIVE:
|
||||
inactive_plugins.add(meta.plugin_id)
|
||||
continue
|
||||
failed_plugins.add(meta.plugin_id)
|
||||
|
||||
successful_plugins = [
|
||||
meta.plugin_id
|
||||
for meta in plugins
|
||||
if meta.plugin_id not in failed_plugins and meta.plugin_id not in inactive_plugins
|
||||
]
|
||||
await self._notify_ready(successful_plugins, sorted(failed_plugins), sorted(inactive_plugins))
|
||||
|
||||
# 5. 等待直到收到关停信号
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
@@ -352,17 +394,17 @@ class PluginRunner:
|
||||
cast(_ContextAwarePlugin, instance)._set_context(ctx)
|
||||
logger.debug(f"已为插件 {plugin_id} 注入 PluginContext")
|
||||
|
||||
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[Dict[str, Any]] = None) -> None:
|
||||
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""在 Runner 侧为插件实例注入当前插件配置。
|
||||
|
||||
Args:
|
||||
meta: 插件元数据。
|
||||
config_data: 可选的配置数据;留空时自动从插件目录读取。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 归一化后的当前插件配置。
|
||||
"""
|
||||
instance = meta.instance
|
||||
if not hasattr(instance, "set_plugin_config"):
|
||||
return
|
||||
|
||||
raw_config = config_data if config_data is not None else self._load_plugin_config(meta.plugin_dir)
|
||||
plugin_config, should_persist = self._normalize_plugin_config(instance, raw_config)
|
||||
config_path = Path(meta.plugin_dir) / "config.toml"
|
||||
@@ -370,10 +412,12 @@ class PluginRunner:
|
||||
should_initialize_file = not config_path.exists() and bool(default_config)
|
||||
if should_persist or should_initialize_file:
|
||||
self._save_plugin_config(meta.plugin_dir, plugin_config)
|
||||
try:
|
||||
cast(_ConfigAwarePlugin, instance).set_plugin_config(plugin_config)
|
||||
except Exception as exc:
|
||||
logger.warning(f"插件 {meta.plugin_id} 配置注入失败: {exc}")
|
||||
if hasattr(instance, "set_plugin_config"):
|
||||
try:
|
||||
cast(_ConfigAwarePlugin, instance).set_plugin_config(plugin_config)
|
||||
except Exception as exc:
|
||||
logger.warning(f"插件 {meta.plugin_id} 配置注入失败: {exc}")
|
||||
return plugin_config
|
||||
|
||||
def _normalize_plugin_config(
|
||||
self,
|
||||
@@ -405,6 +449,33 @@ class PluginRunner:
|
||||
logger.warning(f"插件配置归一化失败,将回退为原始配置: {exc}")
|
||||
return normalized_config, False
|
||||
|
||||
@staticmethod
|
||||
def _is_plugin_enabled(config_data: Optional[Mapping[str, Any]]) -> bool:
|
||||
"""根据配置内容判断插件是否应被视为启用。
|
||||
|
||||
Args:
|
||||
config_data: 当前插件配置。
|
||||
|
||||
Returns:
|
||||
bool: 插件是否启用。
|
||||
"""
|
||||
|
||||
if not isinstance(config_data, Mapping):
|
||||
return True
|
||||
|
||||
plugin_section = config_data.get("plugin")
|
||||
if not isinstance(plugin_section, Mapping):
|
||||
return True
|
||||
|
||||
enabled_value = plugin_section.get("enabled", True)
|
||||
if isinstance(enabled_value, str):
|
||||
normalized_value = enabled_value.strip().lower()
|
||||
if normalized_value in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
if normalized_value in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
return bool(enabled_value)
|
||||
|
||||
@staticmethod
|
||||
def _save_plugin_config(plugin_dir: str, config_data: Dict[str, Any]) -> None:
|
||||
"""将插件配置写回到 ``config.toml``。
|
||||
@@ -435,6 +506,99 @@ class PluginRunner:
|
||||
|
||||
return loaded if isinstance(loaded, dict) else {}
|
||||
|
||||
def _resolve_plugin_candidate(self, plugin_id: str) -> Tuple[Optional[PluginCandidate], Optional[str]]:
|
||||
"""解析指定插件的候选目录。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[PluginCandidate], Optional[str]]: 候选插件与错误信息。
|
||||
"""
|
||||
|
||||
candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs)
|
||||
if plugin_id in duplicate_candidates:
|
||||
conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id])
|
||||
return None, f"检测到重复插件 ID: {conflict_paths}"
|
||||
|
||||
candidate = candidates.get(plugin_id)
|
||||
if candidate is None:
|
||||
return None, f"未找到插件: {plugin_id}"
|
||||
return candidate, None
|
||||
|
||||
def _resolve_plugin_meta_for_config_request(
|
||||
self,
|
||||
plugin_id: str,
|
||||
) -> Tuple[Optional[PluginMeta], bool, Optional[str]]:
|
||||
"""为配置相关请求解析插件元数据。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[PluginMeta], bool, Optional[str]]: 依次为插件元数据、
|
||||
是否为临时冷加载实例、以及错误信息。
|
||||
"""
|
||||
|
||||
loaded_meta = self._loader.get_plugin(plugin_id)
|
||||
if loaded_meta is not None:
|
||||
return loaded_meta, False, None
|
||||
|
||||
candidate, error_message = self._resolve_plugin_candidate(plugin_id)
|
||||
if candidate is None:
|
||||
return None, False, error_message
|
||||
|
||||
try:
|
||||
meta = self._loader.load_candidate(plugin_id, candidate)
|
||||
except Exception as exc:
|
||||
return None, False, str(exc)
|
||||
if meta is None:
|
||||
return None, False, "插件模块加载失败"
|
||||
return meta, True, None
|
||||
|
||||
def _inspect_plugin_config(
|
||||
self,
|
||||
meta: PluginMeta,
|
||||
*,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
use_provided_config: bool = False,
|
||||
suppress_errors: bool = True,
|
||||
) -> InspectPluginConfigResultPayload:
|
||||
"""解析插件代码定义的配置元数据。
|
||||
|
||||
Args:
|
||||
meta: 插件元数据。
|
||||
config_data: 可选的配置内容。
|
||||
use_provided_config: 是否优先使用传入的配置内容。
|
||||
suppress_errors: 是否在归一化失败时回退原始配置。
|
||||
|
||||
Returns:
|
||||
InspectPluginConfigResultPayload: 结构化解析结果。
|
||||
"""
|
||||
|
||||
raw_config = config_data if use_provided_config else self._load_plugin_config(meta.plugin_dir)
|
||||
if use_provided_config and config_data is None:
|
||||
raw_config = {}
|
||||
|
||||
normalized_config, changed = self._normalize_plugin_config(
|
||||
meta.instance,
|
||||
raw_config,
|
||||
suppress_errors=suppress_errors,
|
||||
)
|
||||
default_config = self._get_plugin_default_config(meta.instance)
|
||||
if not normalized_config and not raw_config and default_config:
|
||||
normalized_config = dict(default_config)
|
||||
changed = True
|
||||
|
||||
return InspectPluginConfigResultPayload(
|
||||
success=True,
|
||||
default_config=default_config,
|
||||
config_schema=self._get_plugin_config_schema(meta),
|
||||
normalized_config=normalized_config,
|
||||
changed=changed,
|
||||
enabled=self._is_plugin_enabled(normalized_config),
|
||||
)
|
||||
|
||||
def _register_handlers(self) -> None:
|
||||
"""注册 Host -> Runner 的方法处理器。"""
|
||||
self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke)
|
||||
@@ -448,6 +612,7 @@ class PluginRunner:
|
||||
self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown)
|
||||
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
|
||||
self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated)
|
||||
self._rpc_client.register_method("plugin.inspect_config", self._handle_inspect_plugin_config)
|
||||
self._rpc_client.register_method("plugin.validate_config", self._handle_validate_plugin_config)
|
||||
self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin)
|
||||
self._rpc_client.register_method("plugin.reload_batch", self._handle_reload_plugins)
|
||||
@@ -579,6 +744,9 @@ class PluginRunner:
|
||||
)
|
||||
if response.error:
|
||||
raise RuntimeError(response.error.get("message", "插件注册失败"))
|
||||
response_payload = response.payload if isinstance(response.payload, dict) else {}
|
||||
if not bool(response_payload.get("accepted", True)):
|
||||
raise RuntimeError(str(response_payload.get("reason", "插件注册失败")))
|
||||
logger.info(f"插件 {meta.plugin_id} 注册完成")
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -689,36 +857,40 @@ class PluginRunner:
|
||||
except Exception as exc:
|
||||
logger.error(f"插件 {meta.plugin_id} on_unload 失败: {exc}", exc_info=True)
|
||||
|
||||
async def _activate_plugin(self, meta: PluginMeta) -> bool:
|
||||
async def _activate_plugin(self, meta: PluginMeta) -> PluginActivationStatus:
|
||||
"""完成插件注入、授权、生命周期和组件注册。
|
||||
|
||||
Args:
|
||||
meta: 待激活的插件元数据。
|
||||
|
||||
Returns:
|
||||
bool: 是否激活成功。
|
||||
PluginActivationStatus: 插件激活结果。
|
||||
"""
|
||||
self._inject_context(meta.plugin_id, meta.instance)
|
||||
self._apply_plugin_config(meta)
|
||||
plugin_config = self._apply_plugin_config(meta)
|
||||
if not self._is_plugin_enabled(plugin_config):
|
||||
logger.info(f"插件 {meta.plugin_id} 已在配置中禁用,跳过激活")
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
return PluginActivationStatus.INACTIVE
|
||||
|
||||
if not await self._bootstrap_plugin(meta):
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
return False
|
||||
return PluginActivationStatus.FAILED
|
||||
|
||||
if not await self._register_plugin(meta):
|
||||
await self._invoke_plugin_on_unload(meta)
|
||||
await self._deactivate_plugin(meta)
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
return False
|
||||
return PluginActivationStatus.FAILED
|
||||
|
||||
if not await self._invoke_plugin_on_load(meta):
|
||||
await self._unregister_plugin(meta.plugin_id, reason="on_load_failed")
|
||||
await self._deactivate_plugin(meta)
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
return False
|
||||
return PluginActivationStatus.FAILED
|
||||
|
||||
self._loader.set_loaded_plugin(meta)
|
||||
return True
|
||||
return PluginActivationStatus.LOADED
|
||||
|
||||
async def _unload_plugin(self, meta: PluginMeta, reason: str, *, purge_modules: bool = True) -> None:
|
||||
"""卸载单个插件并清理 Host/Runner 两侧状态。
|
||||
@@ -879,6 +1051,7 @@ class PluginRunner:
|
||||
requested_plugin_id=plugin_id,
|
||||
reloaded_plugins=batch_result.reloaded_plugins,
|
||||
unloaded_plugins=batch_result.unloaded_plugins,
|
||||
inactive_plugins=batch_result.inactive_plugins,
|
||||
failed_plugins=batch_result.failed_plugins,
|
||||
)
|
||||
|
||||
@@ -973,6 +1146,8 @@ class PluginRunner:
|
||||
},
|
||||
}
|
||||
reloaded_plugins: List[str] = []
|
||||
inactive_plugins: List[str] = []
|
||||
inactive_plugin_ids: Set[str] = set()
|
||||
|
||||
for load_plugin_id in load_order:
|
||||
if load_plugin_id in failed_plugins:
|
||||
@@ -983,10 +1158,28 @@ class PluginRunner:
|
||||
continue
|
||||
|
||||
_, manifest, _ = candidate
|
||||
unsatisfied_dependency_ids = [
|
||||
dependency.id
|
||||
for dependency in manifest.plugin_dependencies
|
||||
if dependency.id not in available_plugins
|
||||
or not self._loader.manifest_validator.is_plugin_dependency_satisfied(
|
||||
dependency,
|
||||
available_plugins[dependency.id],
|
||||
)
|
||||
]
|
||||
if unsatisfied_dependencies := self._loader.manifest_validator.get_unsatisfied_plugin_dependencies(
|
||||
manifest,
|
||||
available_plugin_versions=available_plugins,
|
||||
):
|
||||
if load_plugin_id not in reload_root_ids and any(
|
||||
dependency_id in inactive_plugin_ids for dependency_id in unsatisfied_dependency_ids
|
||||
):
|
||||
logger.info(
|
||||
f"插件 {load_plugin_id} 的依赖当前未激活,保留为未激活状态: {', '.join(unsatisfied_dependencies)}"
|
||||
)
|
||||
inactive_plugin_ids.add(load_plugin_id)
|
||||
inactive_plugins.append(load_plugin_id)
|
||||
continue
|
||||
failed_plugins[load_plugin_id] = f"依赖未满足: {', '.join(unsatisfied_dependencies)}"
|
||||
continue
|
||||
|
||||
@@ -996,9 +1189,13 @@ class PluginRunner:
|
||||
continue
|
||||
|
||||
activated = await self._activate_plugin(meta)
|
||||
if not activated:
|
||||
if activated == PluginActivationStatus.FAILED:
|
||||
failed_plugins[load_plugin_id] = "插件初始化失败"
|
||||
continue
|
||||
if activated == PluginActivationStatus.INACTIVE:
|
||||
inactive_plugin_ids.add(load_plugin_id)
|
||||
inactive_plugins.append(load_plugin_id)
|
||||
continue
|
||||
|
||||
available_plugins[load_plugin_id] = meta.version
|
||||
reloaded_plugins.append(load_plugin_id)
|
||||
@@ -1033,7 +1230,7 @@ class PluginRunner:
|
||||
rollback_failures[rollback_plugin_id] = str(exc)
|
||||
continue
|
||||
|
||||
if not restored:
|
||||
if restored != PluginActivationStatus.LOADED:
|
||||
rollback_failures[rollback_plugin_id] = "无法重新激活旧版本"
|
||||
|
||||
return ReloadPluginsResultPayload(
|
||||
@@ -1041,29 +1238,40 @@ class PluginRunner:
|
||||
requested_plugin_ids=normalized_plugin_ids,
|
||||
reloaded_plugins=[],
|
||||
unloaded_plugins=unloaded_plugins,
|
||||
inactive_plugins=[],
|
||||
failed_plugins=self._finalize_failed_reload_messages(failed_plugins, rollback_failures),
|
||||
)
|
||||
|
||||
requested_plugin_success = all(plugin_id in reloaded_plugins for plugin_id in reload_root_ids)
|
||||
requested_plugin_success = all(
|
||||
plugin_id in reloaded_plugins or plugin_id in inactive_plugins for plugin_id in reload_root_ids
|
||||
)
|
||||
|
||||
return ReloadPluginsResultPayload(
|
||||
success=requested_plugin_success and not failed_plugins,
|
||||
requested_plugin_ids=normalized_plugin_ids,
|
||||
reloaded_plugins=reloaded_plugins,
|
||||
unloaded_plugins=unloaded_plugins,
|
||||
inactive_plugins=inactive_plugins,
|
||||
failed_plugins=failed_plugins,
|
||||
)
|
||||
|
||||
async def _notify_ready(self, loaded_plugins: List[str], failed_plugins: List[str]) -> None:
|
||||
async def _notify_ready(
|
||||
self,
|
||||
loaded_plugins: List[str],
|
||||
failed_plugins: List[str],
|
||||
inactive_plugins: List[str],
|
||||
) -> None:
|
||||
"""通知 Host 当前 Runner 已完成插件初始化。
|
||||
|
||||
Args:
|
||||
loaded_plugins: 成功初始化的插件列表。
|
||||
failed_plugins: 初始化失败的插件列表。
|
||||
inactive_plugins: 因禁用或依赖不可用而未激活的插件列表。
|
||||
"""
|
||||
payload = RunnerReadyPayload(
|
||||
loaded_plugins=loaded_plugins,
|
||||
failed_plugins=failed_plugins,
|
||||
inactive_plugins=inactive_plugins,
|
||||
)
|
||||
await self._rpc_client.send_request(
|
||||
"runner.ready",
|
||||
@@ -1289,6 +1497,44 @@ class PluginRunner:
|
||||
return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
|
||||
return envelope.make_response(payload={"acknowledged": True})
|
||||
|
||||
async def _handle_inspect_plugin_config(self, envelope: Envelope) -> Envelope:
|
||||
"""处理插件配置元数据解析请求。
|
||||
|
||||
Args:
|
||||
envelope: RPC 请求信封。
|
||||
|
||||
Returns:
|
||||
Envelope: RPC 响应信封。
|
||||
"""
|
||||
|
||||
try:
|
||||
payload = InspectPluginConfigPayload.model_validate(envelope.payload)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
plugin_id = envelope.plugin_id
|
||||
meta, is_temporary_meta, error_message = self._resolve_plugin_meta_for_config_request(plugin_id)
|
||||
if meta is None:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_PLUGIN_NOT_FOUND.value,
|
||||
error_message or f"未找到插件: {plugin_id}",
|
||||
)
|
||||
|
||||
try:
|
||||
result = self._inspect_plugin_config(
|
||||
meta,
|
||||
config_data=payload.config_data,
|
||||
use_provided_config=payload.use_provided_config,
|
||||
suppress_errors=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
finally:
|
||||
if is_temporary_meta:
|
||||
self._loader.purge_plugin_modules(plugin_id, meta.plugin_dir)
|
||||
|
||||
return envelope.make_response(payload=result.model_dump())
|
||||
|
||||
async def _handle_validate_plugin_config(self, envelope: Envelope) -> Envelope:
|
||||
"""处理插件配置校验请求。
|
||||
|
||||
@@ -1305,23 +1551,30 @@ class PluginRunner:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
plugin_id = envelope.plugin_id
|
||||
meta = self._loader.get_plugin(plugin_id)
|
||||
meta, is_temporary_meta, error_message = self._resolve_plugin_meta_for_config_request(plugin_id)
|
||||
if meta is None:
|
||||
return envelope.make_error_response(ErrorCode.E_PLUGIN_NOT_FOUND.value, f"未找到插件: {plugin_id}")
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_PLUGIN_NOT_FOUND.value,
|
||||
error_message or f"未找到插件: {plugin_id}",
|
||||
)
|
||||
|
||||
try:
|
||||
normalized_config, changed = self._normalize_plugin_config(
|
||||
meta.instance,
|
||||
payload.config_data,
|
||||
inspection_result = self._inspect_plugin_config(
|
||||
meta,
|
||||
config_data=payload.config_data,
|
||||
use_provided_config=True,
|
||||
suppress_errors=False,
|
||||
)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
finally:
|
||||
if is_temporary_meta:
|
||||
self._loader.purge_plugin_modules(plugin_id, meta.plugin_dir)
|
||||
|
||||
result = ValidatePluginConfigResultPayload(
|
||||
success=True,
|
||||
normalized_config=normalized_config,
|
||||
changed=changed,
|
||||
normalized_config=inspection_result.normalized_config,
|
||||
changed=inspection_result.changed,
|
||||
)
|
||||
return envelope.make_response(payload=result.model_dump())
|
||||
|
||||
|
||||
@@ -40,10 +40,213 @@ from src.common.utils.utils_message import MessageUtils
|
||||
from src.config.config import global_config
|
||||
from src.platform_io import DeliveryBatch, get_platform_io_manager
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
from src.plugin_runtime.hook_payloads import deserialize_session_message, serialize_session_message
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
|
||||
logger = get_logger("send_service")
|
||||
|
||||
|
||||
def register_send_service_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""注册发送服务内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
return registry.register_hook_specs(
|
||||
[
|
||||
HookSpec(
|
||||
name="send_service.after_build_message",
|
||||
description="在出站 SessionMessage 构建完成后触发,可改写消息体或取消发送。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"message": {
|
||||
"type": "object",
|
||||
"description": "待发送消息的序列化 SessionMessage。",
|
||||
},
|
||||
"stream_id": {
|
||||
"type": "string",
|
||||
"description": "目标会话 ID。",
|
||||
},
|
||||
"display_message": {
|
||||
"type": "string",
|
||||
"description": "展示层文本。",
|
||||
},
|
||||
"typing": {
|
||||
"type": "boolean",
|
||||
"description": "是否模拟打字。",
|
||||
},
|
||||
"set_reply": {
|
||||
"type": "boolean",
|
||||
"description": "是否附带引用回复。",
|
||||
},
|
||||
"storage_message": {
|
||||
"type": "boolean",
|
||||
"description": "发送成功后是否写库。",
|
||||
},
|
||||
"show_log": {
|
||||
"type": "boolean",
|
||||
"description": "是否输出发送日志。",
|
||||
},
|
||||
},
|
||||
required=[
|
||||
"message",
|
||||
"stream_id",
|
||||
"display_message",
|
||||
"typing",
|
||||
"set_reply",
|
||||
"storage_message",
|
||||
"show_log",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="send_service.before_send",
|
||||
description="在真正调用 Platform IO 发送前触发,可改写消息或取消本次发送。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"message": {
|
||||
"type": "object",
|
||||
"description": "待发送消息的序列化 SessionMessage。",
|
||||
},
|
||||
"typing": {
|
||||
"type": "boolean",
|
||||
"description": "是否模拟打字。",
|
||||
},
|
||||
"set_reply": {
|
||||
"type": "boolean",
|
||||
"description": "是否附带引用回复。",
|
||||
},
|
||||
"reply_message_id": {
|
||||
"type": "string",
|
||||
"description": "被引用消息 ID。",
|
||||
},
|
||||
"storage_message": {
|
||||
"type": "boolean",
|
||||
"description": "发送成功后是否写库。",
|
||||
},
|
||||
"show_log": {
|
||||
"type": "boolean",
|
||||
"description": "是否输出发送日志。",
|
||||
},
|
||||
},
|
||||
required=["message", "typing", "set_reply", "storage_message", "show_log"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="send_service.after_send",
|
||||
description="在发送流程结束后触发,用于观察最终发送结果。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"message": {
|
||||
"type": "object",
|
||||
"description": "本次发送消息的序列化 SessionMessage。",
|
||||
},
|
||||
"sent": {
|
||||
"type": "boolean",
|
||||
"description": "本次发送是否成功。",
|
||||
},
|
||||
"typing": {
|
||||
"type": "boolean",
|
||||
"description": "是否模拟打字。",
|
||||
},
|
||||
"set_reply": {
|
||||
"type": "boolean",
|
||||
"description": "是否附带引用回复。",
|
||||
},
|
||||
"reply_message_id": {
|
||||
"type": "string",
|
||||
"description": "被引用消息 ID。",
|
||||
},
|
||||
"storage_message": {
|
||||
"type": "boolean",
|
||||
"description": "发送成功后是否写库。",
|
||||
},
|
||||
"show_log": {
|
||||
"type": "boolean",
|
||||
"description": "是否输出发送日志。",
|
||||
},
|
||||
},
|
||||
required=["message", "sent", "typing", "set_reply", "storage_message", "show_log"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=False,
|
||||
allow_kwargs_mutation=False,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
|
||||
def _coerce_bool(value: Any, default: bool) -> bool:
|
||||
"""将任意值安全转换为布尔值。
|
||||
|
||||
Args:
|
||||
value: 待转换的值。
|
||||
default: 当值为空时使用的默认值。
|
||||
|
||||
Returns:
|
||||
bool: 转换后的布尔值。
|
||||
"""
|
||||
|
||||
if value is None:
|
||||
return default
|
||||
return bool(value)
|
||||
|
||||
|
||||
async def _invoke_send_hook(
|
||||
hook_name: str,
|
||||
message: SessionMessage,
|
||||
**kwargs: Any,
|
||||
) -> tuple[HookDispatchResult, SessionMessage]:
|
||||
"""触发携带出站消息的命名 Hook。
|
||||
|
||||
Args:
|
||||
hook_name: 目标 Hook 名称。
|
||||
message: 当前待发送消息。
|
||||
**kwargs: 需要附带的额外参数。
|
||||
|
||||
Returns:
|
||||
tuple[HookDispatchResult, SessionMessage]: Hook 聚合结果以及可能被改写后的消息对象。
|
||||
"""
|
||||
|
||||
hook_result = await _get_runtime_manager().invoke_hook(
|
||||
hook_name,
|
||||
message=serialize_session_message(message),
|
||||
**kwargs,
|
||||
)
|
||||
mutated_message = message
|
||||
raw_message = hook_result.kwargs.get("message")
|
||||
if raw_message is not None:
|
||||
try:
|
||||
mutated_message = deserialize_session_message(raw_message)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Hook {hook_name} 返回的 message 无法反序列化,已忽略: {exc}")
|
||||
return hook_result, mutated_message
|
||||
|
||||
|
||||
def _inherit_platform_io_route_metadata(target_stream: BotChatSession) -> Dict[str, object]:
|
||||
"""从目标会话继承 Platform IO 路由元数据。
|
||||
|
||||
@@ -469,6 +672,27 @@ async def _send_via_platform_io(
|
||||
Returns:
|
||||
bool: 发送成功时返回 ``True``。
|
||||
"""
|
||||
before_send_result, message = await _invoke_send_hook(
|
||||
"send_service.before_send",
|
||||
message,
|
||||
typing=typing,
|
||||
set_reply=set_reply,
|
||||
reply_message_id=reply_message_id,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
if before_send_result.aborted:
|
||||
logger.info(f"[SendService] 消息 {message.message_id} 在发送前被 Hook 中止")
|
||||
return False
|
||||
|
||||
before_kwargs = before_send_result.kwargs
|
||||
typing = _coerce_bool(before_kwargs.get("typing"), typing)
|
||||
set_reply = _coerce_bool(before_kwargs.get("set_reply"), set_reply)
|
||||
storage_message = _coerce_bool(before_kwargs.get("storage_message"), storage_message)
|
||||
show_log = _coerce_bool(before_kwargs.get("show_log"), show_log)
|
||||
raw_reply_message_id = before_kwargs.get("reply_message_id", reply_message_id)
|
||||
reply_message_id = None if raw_reply_message_id in {None, ""} else str(raw_reply_message_id)
|
||||
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
try:
|
||||
await platform_io_manager.ensure_send_pipeline_ready()
|
||||
@@ -500,6 +724,18 @@ async def _send_via_platform_io(
|
||||
logger.debug(traceback.format_exc())
|
||||
return False
|
||||
|
||||
sent = bool(delivery_batch.has_success)
|
||||
await _invoke_send_hook(
|
||||
"send_service.after_send",
|
||||
message,
|
||||
sent=sent,
|
||||
typing=typing,
|
||||
set_reply=set_reply,
|
||||
reply_message_id=reply_message_id,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
|
||||
if delivery_batch.has_success:
|
||||
if storage_message:
|
||||
_store_sent_message(message)
|
||||
@@ -606,6 +842,26 @@ async def _send_to_target(
|
||||
if outbound_message is None:
|
||||
return False
|
||||
|
||||
after_build_result, outbound_message = await _invoke_send_hook(
|
||||
"send_service.after_build_message",
|
||||
outbound_message,
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
typing=typing,
|
||||
set_reply=set_reply,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
if after_build_result.aborted:
|
||||
logger.info(f"[SendService] 消息 {outbound_message.message_id} 在构建后被 Hook 中止")
|
||||
return False
|
||||
|
||||
after_build_kwargs = after_build_result.kwargs
|
||||
typing = _coerce_bool(after_build_kwargs.get("typing"), typing)
|
||||
set_reply = _coerce_bool(after_build_kwargs.get("set_reply"), set_reply)
|
||||
storage_message = _coerce_bool(after_build_kwargs.get("storage_message"), storage_message)
|
||||
show_log = _coerce_bool(after_build_kwargs.get("show_log"), show_log)
|
||||
|
||||
sent = await send_session_message(
|
||||
outbound_message,
|
||||
typing=typing,
|
||||
|
||||
@@ -9,6 +9,7 @@ from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
from src.webui.routers.websocket.manager import websocket_manager
|
||||
|
||||
logger = get_logger("webui.logs_ws")
|
||||
router = APIRouter()
|
||||
@@ -148,24 +149,9 @@ async def broadcast_log(log_data: Dict):
|
||||
Args:
|
||||
log_data: 日志数据字典
|
||||
"""
|
||||
if not active_connections:
|
||||
return
|
||||
|
||||
# 格式化为 JSON
|
||||
message = json.dumps(log_data, ensure_ascii=False)
|
||||
|
||||
# 记录需要断开的连接
|
||||
disconnected = set()
|
||||
|
||||
# 广播到所有客户端
|
||||
for connection in active_connections:
|
||||
try:
|
||||
await connection.send_text(message)
|
||||
except Exception:
|
||||
# 发送失败,标记为断开
|
||||
disconnected.add(connection)
|
||||
|
||||
# 清理断开的连接
|
||||
if disconnected:
|
||||
active_connections.difference_update(disconnected)
|
||||
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")
|
||||
await websocket_manager.broadcast_to_topic(
|
||||
domain="logs",
|
||||
topic="main",
|
||||
event="entry",
|
||||
data={"entry": log_data},
|
||||
)
|
||||
|
||||
@@ -18,12 +18,10 @@ def get_all_routers() -> List[APIRouter]:
|
||||
from src.webui.api.replier import router as replier_router
|
||||
from src.webui.routers.chat import router as chat_router
|
||||
from src.webui.routers.knowledge import router as knowledge_router
|
||||
from src.webui.routers.websocket.logs import router as logs_router
|
||||
from src.webui.routes import router as main_router
|
||||
|
||||
return [
|
||||
main_router,
|
||||
logs_router,
|
||||
knowledge_router,
|
||||
chat_router,
|
||||
planner_router,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Tuple
|
||||
|
||||
from .routes import router
|
||||
from .support import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
|
||||
from .service import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
|
||||
|
||||
|
||||
def get_webui_chat_broadcaster() -> Tuple[ChatConnectionManager, str]:
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""本地聊天室路由 - WebUI 与麦麦直接对话。"""
|
||||
|
||||
import uuid
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import case, func
|
||||
from sqlmodel import col, select
|
||||
|
||||
@@ -13,16 +12,11 @@ from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
from .support import (
|
||||
from .service import (
|
||||
WEBUI_CHAT_GROUP_ID,
|
||||
WEBUI_CHAT_PLATFORM,
|
||||
authenticate_chat_websocket,
|
||||
chat_history,
|
||||
chat_manager,
|
||||
dispatch_chat_event,
|
||||
normalize_webui_user_id,
|
||||
resolve_initial_virtual_identity,
|
||||
send_initial_chat_state,
|
||||
)
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
@@ -113,55 +107,6 @@ async def clear_chat_history(
|
||||
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_chat(
|
||||
websocket: WebSocket,
|
||||
user_id: Optional[str] = Query(default=None),
|
||||
user_name: Optional[str] = Query(default="WebUI用户"),
|
||||
platform: Optional[str] = Query(default=None),
|
||||
person_id: Optional[str] = Query(default=None),
|
||||
group_name: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None),
|
||||
token: Optional[str] = Query(default=None),
|
||||
) -> None:
|
||||
"""WebSocket 聊天端点。"""
|
||||
if not await authenticate_chat_websocket(websocket, token):
|
||||
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
normalized_user_id = normalize_webui_user_id(user_id)
|
||||
current_user_name = user_name or "WebUI用户"
|
||||
current_virtual_config = resolve_initial_virtual_identity(platform, person_id, group_name, group_id)
|
||||
|
||||
await chat_manager.connect(websocket, session_id, normalized_user_id)
|
||||
try:
|
||||
await send_initial_chat_state(
|
||||
session_id=session_id,
|
||||
user_id=normalized_user_id,
|
||||
user_name=current_user_name,
|
||||
virtual_config=current_virtual_config,
|
||||
)
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
current_user_name, current_virtual_config = await dispatch_chat_event(
|
||||
session_id=session_id,
|
||||
session_id_prefix=session_id[:8],
|
||||
data=data,
|
||||
current_user_name=current_user_name,
|
||||
normalized_user_id=normalized_user_id,
|
||||
current_virtual_config=current_virtual_config,
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开: session={session_id}, user={normalized_user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 错误: {e}")
|
||||
finally:
|
||||
chat_manager.disconnect(session_id, normalized_user_id)
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_chat_info() -> Dict[str, object]:
|
||||
"""获取聊天室信息。"""
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""WebUI 聊天路由支持逻辑。"""
|
||||
"""WebUI 聊天运行时服务。"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, cast
|
||||
|
||||
from fastapi import WebSocket
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, delete, select
|
||||
|
||||
@@ -17,8 +17,6 @@ from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
|
||||
@@ -27,6 +25,8 @@ WEBUI_CHAT_PLATFORM = "webui"
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
WEBUI_USER_ID_PREFIX = "webui_user_"
|
||||
|
||||
AsyncMessageSender = Callable[[Dict[str, Any]], Awaitable[None]]
|
||||
|
||||
|
||||
class VirtualIdentityConfig(BaseModel):
|
||||
"""虚拟身份配置。"""
|
||||
@@ -52,13 +52,42 @@ class ChatHistoryMessage(BaseModel):
|
||||
is_bot: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatSessionConnection:
|
||||
"""逻辑聊天会话连接信息。"""
|
||||
|
||||
session_id: str
|
||||
connection_id: str
|
||||
client_session_id: str
|
||||
user_id: str
|
||||
user_name: str
|
||||
active_group_id: str
|
||||
virtual_config: Optional[VirtualIdentityConfig]
|
||||
sender: AsyncMessageSender
|
||||
|
||||
|
||||
class ChatHistoryManager:
|
||||
"""聊天历史管理器。"""
|
||||
|
||||
def __init__(self, max_messages: int = 200) -> None:
|
||||
"""初始化聊天历史管理器。
|
||||
|
||||
Args:
|
||||
max_messages: 内存中允许处理的最大消息数。
|
||||
"""
|
||||
self.max_messages = max_messages
|
||||
|
||||
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""将内部消息对象转换为前端可消费的字典。
|
||||
|
||||
Args:
|
||||
msg: 内部统一消息对象。
|
||||
group_id: 当前会话所属的群组标识。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 面向 WebUI 的消息字典。
|
||||
"""
|
||||
del group_id
|
||||
user_info = msg.message_info.user_info
|
||||
user_id = user_info.user_id or ""
|
||||
is_bot = is_bot_self(msg.platform, user_id)
|
||||
@@ -74,10 +103,27 @@ class ChatHistoryManager:
|
||||
}
|
||||
|
||||
def _resolve_session_id(self, group_id: Optional[str]) -> str:
|
||||
"""根据群组标识解析聊天会话 ID。
|
||||
|
||||
Args:
|
||||
group_id: 群组标识。
|
||||
|
||||
Returns:
|
||||
str: 内部聊天会话 ID。
|
||||
"""
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
|
||||
|
||||
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""获取指定会话的历史消息。
|
||||
|
||||
Args:
|
||||
limit: 最大返回条数。
|
||||
group_id: 群组标识。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 历史消息列表。
|
||||
"""
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
session_id = self._resolve_session_id(target_group_id)
|
||||
try:
|
||||
@@ -90,11 +136,19 @@ class ChatHistoryManager:
|
||||
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
|
||||
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载聊天记录失败: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"从数据库加载聊天记录失败: {exc}")
|
||||
return []
|
||||
|
||||
def clear_history(self, group_id: Optional[str] = None) -> int:
|
||||
"""清空指定会话的历史消息。
|
||||
|
||||
Args:
|
||||
group_id: 群组标识。
|
||||
|
||||
Returns:
|
||||
int: 被删除的消息数量。
|
||||
"""
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
session_id = self._resolve_session_id(target_group_id)
|
||||
try:
|
||||
@@ -104,66 +158,245 @@ class ChatHistoryManager:
|
||||
deleted = result.rowcount or 0
|
||||
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
|
||||
return deleted
|
||||
except Exception as e:
|
||||
logger.error(f"清空聊天记录失败: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"清空聊天记录失败: {exc}")
|
||||
return 0
|
||||
|
||||
|
||||
class ChatConnectionManager:
|
||||
"""聊天连接管理器。"""
|
||||
"""统一聊天逻辑会话管理器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.user_sessions: Dict[str, str] = {}
|
||||
"""初始化聊天逻辑会话管理器。"""
|
||||
self.active_connections: Dict[str, ChatSessionConnection] = {}
|
||||
self.client_sessions: Dict[Tuple[str, str], str] = {}
|
||||
self.connection_sessions: Dict[str, Set[str]] = {}
|
||||
self.group_sessions: Dict[str, Set[str]] = {}
|
||||
self.user_sessions: Dict[str, Set[str]] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None:
|
||||
await websocket.accept()
|
||||
self.active_connections[session_id] = websocket
|
||||
self.user_sessions[user_id] = session_id
|
||||
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
|
||||
def _bind_group(self, session_id: str, group_id: str) -> None:
|
||||
"""为会话绑定群组索引。
|
||||
|
||||
def disconnect(self, session_id: str, user_id: str) -> None:
|
||||
if session_id in self.active_connections:
|
||||
del self.active_connections[session_id]
|
||||
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
|
||||
del self.user_sessions[user_id]
|
||||
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
||||
Args:
|
||||
session_id: 内部会话 ID。
|
||||
group_id: 群组标识。
|
||||
"""
|
||||
group_session_ids = self.group_sessions.setdefault(group_id, set())
|
||||
group_session_ids.add(session_id)
|
||||
|
||||
def _unbind_group(self, session_id: str, group_id: str) -> None:
|
||||
"""移除会话与群组的索引关系。
|
||||
|
||||
Args:
|
||||
session_id: 内部会话 ID。
|
||||
group_id: 群组标识。
|
||||
"""
|
||||
group_session_ids = self.group_sessions.get(group_id)
|
||||
if group_session_ids is None:
|
||||
return
|
||||
|
||||
group_session_ids.discard(session_id)
|
||||
if not group_session_ids:
|
||||
del self.group_sessions[group_id]
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
session_id: str,
|
||||
connection_id: str,
|
||||
client_session_id: str,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
sender: AsyncMessageSender,
|
||||
) -> None:
|
||||
"""注册一个新的逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
connection_id: 物理 WebSocket 连接 ID。
|
||||
client_session_id: 前端标签页使用的会话 ID。
|
||||
user_id: 规范化后的用户 ID。
|
||||
user_name: 当前展示昵称。
|
||||
virtual_config: 当前虚拟身份配置。
|
||||
sender: 发送消息到前端的异步回调。
|
||||
"""
|
||||
existing_session_id = self.client_sessions.get((connection_id, client_session_id))
|
||||
if existing_session_id is not None:
|
||||
self.disconnect(existing_session_id)
|
||||
|
||||
active_group_id = get_current_group_id(virtual_config)
|
||||
session_connection = ChatSessionConnection(
|
||||
session_id=session_id,
|
||||
connection_id=connection_id,
|
||||
client_session_id=client_session_id,
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
active_group_id=active_group_id,
|
||||
virtual_config=virtual_config,
|
||||
sender=sender,
|
||||
)
|
||||
|
||||
self.active_connections[session_id] = session_connection
|
||||
self.client_sessions[(connection_id, client_session_id)] = session_id
|
||||
self.connection_sessions.setdefault(connection_id, set()).add(session_id)
|
||||
self.user_sessions.setdefault(user_id, set()).add(session_id)
|
||||
self._bind_group(session_id, active_group_id)
|
||||
logger.info(
|
||||
"WebUI 聊天会话已连接: session=%s, connection=%s, client_session=%s, user=%s, group=%s",
|
||||
session_id,
|
||||
connection_id,
|
||||
client_session_id,
|
||||
user_id,
|
||||
active_group_id,
|
||||
)
|
||||
|
||||
def disconnect(self, session_id: str) -> None:
|
||||
"""断开一个逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
"""
|
||||
session_connection = self.active_connections.pop(session_id, None)
|
||||
if session_connection is None:
|
||||
return
|
||||
|
||||
self.client_sessions.pop((session_connection.connection_id, session_connection.client_session_id), None)
|
||||
self._unbind_group(session_id, session_connection.active_group_id)
|
||||
|
||||
connection_session_ids = self.connection_sessions.get(session_connection.connection_id)
|
||||
if connection_session_ids is not None:
|
||||
connection_session_ids.discard(session_id)
|
||||
if not connection_session_ids:
|
||||
del self.connection_sessions[session_connection.connection_id]
|
||||
|
||||
user_session_ids = self.user_sessions.get(session_connection.user_id)
|
||||
if user_session_ids is not None:
|
||||
user_session_ids.discard(session_id)
|
||||
if not user_session_ids:
|
||||
del self.user_sessions[session_connection.user_id]
|
||||
|
||||
logger.info("WebUI 聊天会话已断开: session=%s", session_id)
|
||||
|
||||
def disconnect_connection(self, connection_id: str) -> None:
|
||||
"""断开物理连接下的全部逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
connection_id: 物理 WebSocket 连接 ID。
|
||||
"""
|
||||
session_ids = list(self.connection_sessions.get(connection_id, set()))
|
||||
for session_id in session_ids:
|
||||
self.disconnect(session_id)
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[ChatSessionConnection]:
|
||||
"""获取逻辑聊天会话信息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
|
||||
Returns:
|
||||
Optional[ChatSessionConnection]: 会话存在时返回对应信息。
|
||||
"""
|
||||
return self.active_connections.get(session_id)
|
||||
|
||||
def get_session_id(self, connection_id: str, client_session_id: str) -> Optional[str]:
|
||||
"""根据连接 ID 和前端会话 ID 查询内部会话 ID。
|
||||
|
||||
Args:
|
||||
connection_id: 物理 WebSocket 连接 ID。
|
||||
client_session_id: 前端标签页使用的会话 ID。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 找到时返回内部会话 ID。
|
||||
"""
|
||||
return self.client_sessions.get((connection_id, client_session_id))
|
||||
|
||||
def update_session_context(
|
||||
self,
|
||||
session_id: str,
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> None:
|
||||
"""更新会话上下文信息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
user_name: 最新昵称。
|
||||
virtual_config: 最新虚拟身份配置。
|
||||
"""
|
||||
session_connection = self.active_connections.get(session_id)
|
||||
if session_connection is None:
|
||||
return
|
||||
|
||||
next_group_id = get_current_group_id(virtual_config)
|
||||
if next_group_id != session_connection.active_group_id:
|
||||
self._unbind_group(session_id, session_connection.active_group_id)
|
||||
self._bind_group(session_id, next_group_id)
|
||||
session_connection.active_group_id = next_group_id
|
||||
|
||||
session_connection.user_name = user_name
|
||||
session_connection.virtual_config = virtual_config
|
||||
|
||||
async def send_message(self, session_id: str, message: Dict[str, Any]) -> None:
|
||||
if session_id in self.active_connections:
|
||||
try:
|
||||
await self.active_connections[session_id].send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
"""向指定逻辑会话发送消息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
message: 待发送的消息内容。
|
||||
"""
|
||||
session_connection = self.active_connections.get(session_id)
|
||||
if session_connection is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await session_connection.sender(message)
|
||||
except Exception as exc:
|
||||
logger.error("发送聊天消息失败: session=%s, error=%s", session_id, exc)
|
||||
|
||||
async def broadcast(self, message: Dict[str, Any]) -> None:
|
||||
"""向全部逻辑聊天会话广播消息。
|
||||
|
||||
Args:
|
||||
message: 待广播的消息内容。
|
||||
"""
|
||||
for session_id in list(self.active_connections.keys()):
|
||||
await self.send_message(session_id, message)
|
||||
|
||||
async def broadcast_to_group(self, group_id: str, message: Dict[str, Any]) -> None:
|
||||
"""向指定群组下的全部逻辑会话广播消息。
|
||||
|
||||
Args:
|
||||
group_id: 群组标识。
|
||||
message: 待广播的消息内容。
|
||||
"""
|
||||
for session_id in list(self.group_sessions.get(group_id, set())):
|
||||
await self.send_message(session_id, message)
|
||||
|
||||
|
||||
chat_history = ChatHistoryManager()
|
||||
chat_manager = ChatConnectionManager()
|
||||
|
||||
|
||||
def is_virtual_mode_enabled(virtual_config: Optional[VirtualIdentityConfig]) -> bool:
|
||||
"""判断当前是否启用了虚拟身份模式。
|
||||
|
||||
Args:
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
bool: 已启用时返回 ``True``。
|
||||
"""
|
||||
return bool(virtual_config and virtual_config.enabled)
|
||||
|
||||
|
||||
async def authenticate_chat_websocket(websocket: WebSocket, token: Optional[str]) -> bool:
|
||||
if token and verify_ws_token(token):
|
||||
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
|
||||
return True
|
||||
|
||||
if cookie_token := websocket.cookies.get("maibot_session"):
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def normalize_webui_user_id(user_id: Optional[str]) -> str:
|
||||
"""标准化 WebUI 用户 ID。
|
||||
|
||||
Args:
|
||||
user_id: 原始用户 ID。
|
||||
|
||||
Returns:
|
||||
str: 带统一前缀的用户 ID。
|
||||
"""
|
||||
if not user_id:
|
||||
return f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
|
||||
if user_id.startswith(WEBUI_USER_ID_PREFIX):
|
||||
@@ -172,12 +405,30 @@ def normalize_webui_user_id(user_id: Optional[str]) -> str:
|
||||
|
||||
|
||||
def get_person_by_person_id(person_id: str) -> Optional[PersonInfo]:
|
||||
"""根据人物 ID 查询人物信息。
|
||||
|
||||
Args:
|
||||
person_id: 人物 ID。
|
||||
|
||||
Returns:
|
||||
Optional[PersonInfo]: 查到时返回人物信息。
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||
return session.exec(statement).first()
|
||||
|
||||
|
||||
def build_virtual_identity_config(person: PersonInfo, group_id: str, group_name: str) -> VirtualIdentityConfig:
|
||||
"""根据人物信息构建虚拟身份配置。
|
||||
|
||||
Args:
|
||||
person: 人物信息对象。
|
||||
group_id: 逻辑群组 ID。
|
||||
group_name: 逻辑群组名称。
|
||||
|
||||
Returns:
|
||||
VirtualIdentityConfig: 虚拟身份配置对象。
|
||||
"""
|
||||
return VirtualIdentityConfig(
|
||||
enabled=True,
|
||||
platform=person.platform,
|
||||
@@ -195,6 +446,17 @@ def resolve_initial_virtual_identity(
|
||||
group_name: Optional[str],
|
||||
group_id: Optional[str],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
"""根据初始参数解析虚拟身份配置。
|
||||
|
||||
Args:
|
||||
platform: 平台名称。
|
||||
person_id: 人物 ID。
|
||||
group_name: 群组名称。
|
||||
group_id: 群组 ID。
|
||||
|
||||
Returns:
|
||||
Optional[VirtualIdentityConfig]: 解析成功时返回虚拟身份配置。
|
||||
"""
|
||||
if not (platform and person_id):
|
||||
return None
|
||||
|
||||
@@ -210,11 +472,14 @@ def resolve_initial_virtual_identity(
|
||||
group_name=group_name or "WebUI虚拟群聊",
|
||||
)
|
||||
logger.info(
|
||||
f"虚拟身份模式已通过 URL 参数激活: {virtual_config.user_nickname} @ {virtual_config.platform}, group_id={virtual_group_id}"
|
||||
"虚拟身份模式已通过参数激活: %s @ %s, group_id=%s",
|
||||
virtual_config.user_nickname,
|
||||
virtual_config.platform,
|
||||
virtual_group_id,
|
||||
)
|
||||
return virtual_config
|
||||
except Exception as e:
|
||||
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"通过参数配置虚拟身份失败: {exc}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -224,6 +489,17 @@ def build_session_info_message(
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Dict[str, Any]:
|
||||
"""构建会话信息消息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
user_id: 规范化后的用户 ID。
|
||||
user_name: 当前昵称。
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 会话信息消息。
|
||||
"""
|
||||
session_info_data: Dict[str, Any] = {
|
||||
"type": "session_info",
|
||||
"session_id": session_id,
|
||||
@@ -247,13 +523,41 @@ def build_session_info_message(
|
||||
|
||||
|
||||
def get_active_history_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> Optional[str]:
|
||||
"""获取当前虚拟身份对应的历史群组 ID。
|
||||
|
||||
Args:
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 虚拟身份启用时返回对应群组 ID。
|
||||
"""
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
return virtual_config.group_id
|
||||
return None
|
||||
|
||||
|
||||
def get_current_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> str:
|
||||
"""获取当前会话的有效群组 ID。
|
||||
|
||||
Args:
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
str: 当前会话应使用的群组 ID。
|
||||
"""
|
||||
return get_active_history_group_id(virtual_config) or WEBUI_CHAT_GROUP_ID
|
||||
|
||||
|
||||
def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> str:
|
||||
"""构建欢迎消息。
|
||||
|
||||
Args:
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
str: 欢迎消息文本。
|
||||
"""
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
return (
|
||||
@@ -264,6 +568,12 @@ def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> st
|
||||
|
||||
|
||||
async def send_chat_error(session_id: str, content: str) -> None:
|
||||
"""向指定会话发送错误消息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
content: 错误消息内容。
|
||||
"""
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
@@ -279,7 +589,17 @@ async def send_initial_chat_state(
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
include_welcome: bool = True,
|
||||
) -> None:
|
||||
"""向新会话发送初始化状态。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
user_id: 规范化后的用户 ID。
|
||||
user_name: 当前昵称。
|
||||
virtual_config: 虚拟身份配置。
|
||||
include_welcome: 是否发送欢迎消息。
|
||||
"""
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
build_session_info_message(
|
||||
@@ -290,30 +610,43 @@ async def send_initial_chat_state(
|
||||
),
|
||||
)
|
||||
|
||||
if history := chat_history.get_history(50, get_active_history_group_id(virtual_config)):
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": history,
|
||||
},
|
||||
)
|
||||
|
||||
history_group_id = get_active_history_group_id(virtual_config)
|
||||
history = chat_history.get_history(50, history_group_id)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": build_welcome_message(virtual_config),
|
||||
"timestamp": time.time(),
|
||||
"type": "history",
|
||||
"messages": history,
|
||||
"group_id": get_current_group_id(virtual_config),
|
||||
},
|
||||
)
|
||||
|
||||
if include_welcome:
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": build_welcome_message(virtual_config),
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def resolve_sender_identity(
|
||||
current_user_name: str,
|
||||
normalized_user_id: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Tuple[str, str]:
|
||||
"""解析当前发送者身份。
|
||||
|
||||
Args:
|
||||
current_user_name: 当前昵称。
|
||||
normalized_user_id: 规范化后的用户 ID。
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: ``(发送者昵称, 发送者用户 ID)``。
|
||||
"""
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id
|
||||
@@ -328,6 +661,19 @@ def create_message_data(
|
||||
is_at_bot: bool = True,
|
||||
virtual_config: Optional[VirtualIdentityConfig] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""构建发送给聊天核心的消息数据。
|
||||
|
||||
Args:
|
||||
content: 文本内容。
|
||||
user_id: 用户 ID。
|
||||
user_name: 用户昵称。
|
||||
message_id: 消息 ID。
|
||||
is_at_bot: 是否默认艾特机器人。
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 聊天核心可处理的消息数据。
|
||||
"""
|
||||
if message_id is None:
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
@@ -389,6 +735,18 @@ async def handle_chat_message(
|
||||
normalized_user_id: str,
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> str:
|
||||
"""处理用户发送的聊天消息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
data: 前端提交的消息数据。
|
||||
current_user_name: 当前昵称。
|
||||
normalized_user_id: 规范化后的用户 ID。
|
||||
current_virtual_config: 当前虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
str: 处理后的最新昵称。
|
||||
"""
|
||||
content = str(data.get("content", "")).strip()
|
||||
if not content:
|
||||
return current_user_name
|
||||
@@ -401,11 +759,14 @@ async def handle_chat_message(
|
||||
normalized_user_id=normalized_user_id,
|
||||
virtual_config=current_virtual_config,
|
||||
)
|
||||
target_group_id = get_current_group_id(current_virtual_config)
|
||||
|
||||
await chat_manager.broadcast(
|
||||
await chat_manager.broadcast_to_group(
|
||||
target_group_id,
|
||||
{
|
||||
"type": "user_message",
|
||||
"content": content,
|
||||
"group_id": target_group_id,
|
||||
"message_id": message_id,
|
||||
"timestamp": timestamp,
|
||||
"sender": {
|
||||
@@ -414,7 +775,7 @@ async def handle_chat_message(
|
||||
"is_bot": False,
|
||||
},
|
||||
"virtual_mode": is_virtual_mode_enabled(current_virtual_config),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
message_data = create_message_data(
|
||||
@@ -427,22 +788,37 @@ async def handle_chat_message(
|
||||
)
|
||||
|
||||
try:
|
||||
await chat_manager.broadcast({"type": "typing", "is_typing": True})
|
||||
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": True})
|
||||
await chat_bot.message_process(message_data)
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时出错: {e}")
|
||||
await send_chat_error(session_id, f"处理消息时出错: {str(e)}")
|
||||
except Exception as exc:
|
||||
logger.error(f"处理消息时出错: {exc}")
|
||||
await send_chat_error(session_id, f"处理消息时出错: {str(exc)}")
|
||||
finally:
|
||||
await chat_manager.broadcast({"type": "typing", "is_typing": False})
|
||||
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": False})
|
||||
|
||||
return next_user_name
|
||||
|
||||
|
||||
async def handle_chat_ping(session_id: str) -> None:
|
||||
"""处理聊天心跳。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
"""
|
||||
await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()})
|
||||
|
||||
|
||||
async def handle_nickname_update(session_id: str, data: Dict[str, Any], current_user_name: str) -> str:
|
||||
"""处理昵称更新请求。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
data: 前端提交的数据。
|
||||
current_user_name: 当前昵称。
|
||||
|
||||
Returns:
|
||||
str: 更新后的昵称。
|
||||
"""
|
||||
new_name = str(data.get("user_name", "")).strip()
|
||||
if not new_name:
|
||||
return current_user_name
|
||||
@@ -463,6 +839,16 @@ async def enable_virtual_identity(
|
||||
session_prefix: str,
|
||||
virtual_data: Dict[str, Any],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
"""启用虚拟身份模式。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
session_prefix: 会话前缀,用于生成默认群组 ID。
|
||||
virtual_data: 前端提交的虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Optional[VirtualIdentityConfig]: 启用成功时返回新的虚拟身份配置。
|
||||
"""
|
||||
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
|
||||
await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id")
|
||||
return None
|
||||
@@ -470,16 +856,18 @@ async def enable_virtual_identity(
|
||||
person_id_value = str(virtual_data.get("person_id"))
|
||||
try:
|
||||
person = get_person_by_person_id(person_id_value)
|
||||
if not person:
|
||||
if person is None:
|
||||
await send_chat_error(session_id, f"找不到用户: {person_id_value}")
|
||||
return None
|
||||
|
||||
custom_group_id = virtual_data.get("group_id")
|
||||
current_group_id = (
|
||||
f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
|
||||
if custom_group_id
|
||||
else f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
|
||||
)
|
||||
custom_group_id = str(virtual_data.get("group_id") or "").strip()
|
||||
if custom_group_id:
|
||||
current_group_id = custom_group_id
|
||||
if not current_group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
|
||||
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{current_group_id}"
|
||||
else:
|
||||
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
|
||||
|
||||
current_virtual_config = build_virtual_identity_config(
|
||||
person=person,
|
||||
group_id=current_group_id,
|
||||
@@ -521,13 +909,18 @@ async def enable_virtual_identity(
|
||||
},
|
||||
)
|
||||
return current_virtual_config
|
||||
except Exception as e:
|
||||
logger.error(f"设置虚拟身份失败: {e}")
|
||||
await send_chat_error(session_id, f"设置虚拟身份失败: {str(e)}")
|
||||
except Exception as exc:
|
||||
logger.error(f"设置虚拟身份失败: {exc}")
|
||||
await send_chat_error(session_id, f"设置虚拟身份失败: {str(exc)}")
|
||||
return None
|
||||
|
||||
|
||||
async def disable_virtual_identity(session_id: str) -> None:
|
||||
"""关闭虚拟身份模式。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
"""
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
@@ -560,7 +953,18 @@ async def handle_virtual_identity_update(
|
||||
data: Dict[str, Any],
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
virtual_data = cast(dict[str, Any], data.get("config", {}))
|
||||
"""处理虚拟身份切换请求。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
session_id_prefix: 会话前缀。
|
||||
data: 前端提交的数据。
|
||||
current_virtual_config: 当前虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Optional[VirtualIdentityConfig]: 更新后的虚拟身份配置。
|
||||
"""
|
||||
virtual_data = cast(Dict[str, Any], data.get("config", {}))
|
||||
if virtual_data.get("enabled"):
|
||||
next_config = await enable_virtual_identity(session_id, session_id_prefix, virtual_data)
|
||||
return next_config if next_config is not None else current_virtual_config
|
||||
@@ -577,6 +981,19 @@ async def dispatch_chat_event(
|
||||
normalized_user_id: str,
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Tuple[str, Optional[VirtualIdentityConfig]]:
|
||||
"""分发聊天事件到对应的处理器。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
session_id_prefix: 会话前缀。
|
||||
data: 前端提交的数据。
|
||||
current_user_name: 当前昵称。
|
||||
normalized_user_id: 规范化后的用户 ID。
|
||||
current_virtual_config: 当前虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Tuple[str, Optional[VirtualIdentityConfig]]: ``(最新昵称, 最新虚拟身份配置)``。
|
||||
"""
|
||||
event_type = data.get("type")
|
||||
if event_type == "message":
|
||||
next_user_name = await handle_chat_message(
|
||||
@@ -6,11 +6,13 @@ from .catalog import router as catalog_router
|
||||
from .config_routes import router as config_router
|
||||
from .management import router as management_router
|
||||
from .progress import get_progress_router, update_progress
|
||||
from .runtime_routes import router as runtime_router
|
||||
|
||||
router = APIRouter(prefix="/plugins", tags=["插件管理"])
|
||||
router.include_router(catalog_router)
|
||||
router.include_router(management_router)
|
||||
router.include_router(config_router)
|
||||
router.include_router(runtime_router)
|
||||
|
||||
set_update_progress_callback(update_progress)
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
"""插件配置相关 WebUI 路由。"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, Cookie, HTTPException
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.plugin_runtime.protocol.envelope import InspectPluginConfigResultPayload
|
||||
from src.webui.utils.toml_utils import save_toml_with_format
|
||||
|
||||
from .schemas import UpdatePluginConfigRequest, UpdatePluginRawConfigRequest
|
||||
@@ -207,6 +207,55 @@ def _build_toml_document(config_data: Dict[str, Any]) -> tomlkit.TOMLDocument:
|
||||
return tomlkit.parse(tomlkit.dumps(config_data))
|
||||
|
||||
|
||||
def _load_plugin_config_from_disk(plugin_path: Path) -> Dict[str, Any]:
|
||||
"""从磁盘读取插件配置。
|
||||
|
||||
Args:
|
||||
plugin_path: 插件目录。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 当前配置字典;文件不存在时返回空字典。
|
||||
"""
|
||||
|
||||
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
|
||||
if not config_path.exists():
|
||||
return {}
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as file_obj:
|
||||
loaded_config = tomlkit.load(file_obj).unwrap()
|
||||
return loaded_config if isinstance(loaded_config, dict) else {}
|
||||
|
||||
|
||||
async def _inspect_plugin_config_via_runtime(
|
||||
plugin_id: str,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
use_provided_config: bool = False,
|
||||
) -> InspectPluginConfigResultPayload | None:
|
||||
"""通过插件运行时解析配置元数据。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
config_data: 可选的配置内容。
|
||||
use_provided_config: 是否优先使用传入配置而不是磁盘配置。
|
||||
|
||||
Returns:
|
||||
InspectPluginConfigResultPayload | None: 运行时可用时返回解析结果,否则返回 ``None``。
|
||||
|
||||
Raises:
|
||||
ValueError: 插件运行时明确拒绝解析请求时抛出。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
runtime_manager = get_plugin_runtime_manager()
|
||||
return await runtime_manager.inspect_plugin_config(
|
||||
plugin_id,
|
||||
config_data,
|
||||
use_provided_config=use_provided_config,
|
||||
)
|
||||
|
||||
|
||||
async def _validate_plugin_config_via_runtime(plugin_id: str, config_data: Dict[str, Any]) -> Dict[str, Any] | None:
|
||||
"""通过插件运行时对配置进行校验。
|
||||
|
||||
@@ -244,27 +293,24 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
|
||||
logger.info(f"获取插件配置 Schema: {plugin_id}")
|
||||
|
||||
try:
|
||||
registration_schema = component_query_service.get_plugin_config_schema(plugin_id)
|
||||
if isinstance(registration_schema, dict) and registration_schema:
|
||||
return {"success": True, "schema": registration_schema}
|
||||
|
||||
plugin_path = find_plugin_path_by_id(plugin_id)
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
schema_json_path = resolve_plugin_file_path(plugin_path, "config_schema.json")
|
||||
if schema_json_path.exists():
|
||||
try:
|
||||
with open(schema_json_path, "r", encoding="utf-8") as file_obj:
|
||||
return {"success": True, "schema": json.load(file_obj)}
|
||||
except Exception as e:
|
||||
logger.warning(f"读取 config_schema.json 失败,回退到自动推断: {e}")
|
||||
try:
|
||||
runtime_snapshot = await _inspect_plugin_config_via_runtime(plugin_id)
|
||||
except ValueError as exc:
|
||||
logger.warning(f"插件 {plugin_id} 配置 Schema 解析失败,将回退到弱推断: {exc}")
|
||||
runtime_snapshot = None
|
||||
|
||||
current_config: Any = component_query_service.get_plugin_default_config(plugin_id) or {}
|
||||
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as file_obj:
|
||||
current_config = tomlkit.load(file_obj)
|
||||
if runtime_snapshot is not None and runtime_snapshot.config_schema:
|
||||
return {"success": True, "schema": dict(runtime_snapshot.config_schema)}
|
||||
|
||||
current_config: Any = (
|
||||
dict(runtime_snapshot.normalized_config)
|
||||
if runtime_snapshot is not None
|
||||
else _load_plugin_config_from_disk(plugin_path)
|
||||
)
|
||||
|
||||
return {"success": True, "schema": _build_schema_from_current_config(plugin_id, current_config)}
|
||||
except HTTPException:
|
||||
@@ -375,15 +421,24 @@ async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cook
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
|
||||
try:
|
||||
runtime_snapshot = await _inspect_plugin_config_via_runtime(plugin_id)
|
||||
except ValueError as exc:
|
||||
logger.warning(f"插件 {plugin_id} 配置读取失败,将回退到磁盘内容: {exc}")
|
||||
runtime_snapshot = None
|
||||
|
||||
if runtime_snapshot is not None:
|
||||
message = "配置文件不存在,已返回默认配置" if not config_path.exists() else ""
|
||||
return {
|
||||
"success": True,
|
||||
"config": dict(runtime_snapshot.normalized_config),
|
||||
"message": message,
|
||||
}
|
||||
|
||||
if not config_path.exists():
|
||||
default_config = component_query_service.get_plugin_default_config(plugin_id)
|
||||
if isinstance(default_config, dict):
|
||||
return {"success": True, "config": default_config, "message": "配置文件不存在,已返回默认配置"}
|
||||
return {"success": True, "config": {}, "message": "配置文件不存在"}
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as file_obj:
|
||||
config = tomlkit.load(file_obj)
|
||||
return {"success": True, "config": dict(config)}
|
||||
return {"success": True, "config": _load_plugin_config_from_disk(plugin_path)}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -412,6 +467,10 @@ async def update_plugin_config(
|
||||
logger.info(f"更新插件配置: {plugin_id}")
|
||||
|
||||
try:
|
||||
plugin_path = find_plugin_path_by_id(plugin_id)
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_data = request.config or {}
|
||||
if isinstance(config_data, dict):
|
||||
config_data = normalize_dotted_keys(config_data)
|
||||
@@ -419,12 +478,13 @@ async def update_plugin_config(
|
||||
if isinstance(runtime_validated_config, dict):
|
||||
config_data = runtime_validated_config
|
||||
else:
|
||||
plugin_schema = component_query_service.get_plugin_config_schema(plugin_id)
|
||||
if isinstance(plugin_schema, dict) and plugin_schema:
|
||||
_coerce_config_by_plugin_schema(plugin_schema, config_data)
|
||||
plugin_path = find_plugin_path_by_id(plugin_id)
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
runtime_snapshot = await _inspect_plugin_config_via_runtime(
|
||||
plugin_id,
|
||||
config_data,
|
||||
use_provided_config=True,
|
||||
)
|
||||
if runtime_snapshot is not None and runtime_snapshot.config_schema:
|
||||
_coerce_config_by_plugin_schema(dict(runtime_snapshot.config_schema), config_data)
|
||||
|
||||
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
|
||||
backup_path = backup_file(config_path, "backup")
|
||||
@@ -498,17 +558,29 @@ async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(N
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
|
||||
default_config = component_query_service.get_plugin_default_config(plugin_id)
|
||||
config = _build_toml_document(default_config if isinstance(default_config, dict) else {})
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as file_obj:
|
||||
config = tomlkit.load(file_obj)
|
||||
try:
|
||||
runtime_snapshot = await _inspect_plugin_config_via_runtime(plugin_id)
|
||||
except ValueError as exc:
|
||||
logger.warning(f"插件 {plugin_id} 状态切换前配置解析失败,将回退到磁盘内容: {exc}")
|
||||
runtime_snapshot = None
|
||||
|
||||
if "plugin" not in config:
|
||||
current_config = (
|
||||
dict(runtime_snapshot.normalized_config)
|
||||
if runtime_snapshot is not None
|
||||
else _load_plugin_config_from_disk(plugin_path)
|
||||
)
|
||||
config = _build_toml_document(current_config)
|
||||
|
||||
plugin_section = config.get("plugin")
|
||||
if plugin_section is None or not hasattr(plugin_section, "get"):
|
||||
config["plugin"] = tomlkit.table()
|
||||
|
||||
plugin_config = cast(Any, config["plugin"])
|
||||
current_enabled = bool(plugin_config.get("enabled", True))
|
||||
current_enabled = (
|
||||
bool(runtime_snapshot.enabled)
|
||||
if runtime_snapshot is not None
|
||||
else bool(plugin_config.get("enabled", True))
|
||||
)
|
||||
new_enabled = not current_enabled
|
||||
plugin_config["enabled"] = new_enabled
|
||||
save_toml_with_format(config, str(config_path))
|
||||
@@ -519,7 +591,7 @@ async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(N
|
||||
"success": True,
|
||||
"enabled": new_enabled,
|
||||
"message": f"插件已{status}",
|
||||
"note": "状态更改将在下次加载插件时生效",
|
||||
"note": "状态更改将自动热更新到对应插件",
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
"""插件进度实时推送支持。"""
|
||||
|
||||
from typing import Any, Dict, Optional, Set
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, Optional, Set
|
||||
|
||||
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
from src.webui.routers.websocket.manager import websocket_manager
|
||||
|
||||
logger = get_logger("webui.plugin_progress")
|
||||
|
||||
@@ -25,25 +28,29 @@ current_progress: Dict[str, Any] = {
|
||||
}
|
||||
|
||||
|
||||
def get_current_progress() -> Dict[str, Any]:
|
||||
"""获取当前插件进度快照。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 当前插件进度数据副本。
|
||||
"""
|
||||
return current_progress.copy()
|
||||
|
||||
|
||||
async def broadcast_progress(progress_data: Dict[str, Any]) -> None:
|
||||
"""向统一连接层广播插件进度更新。
|
||||
|
||||
Args:
|
||||
progress_data: 插件进度数据。
|
||||
"""
|
||||
global current_progress
|
||||
current_progress = progress_data.copy()
|
||||
|
||||
if not active_connections:
|
||||
return
|
||||
|
||||
message = json.dumps(progress_data, ensure_ascii=False)
|
||||
disconnected: Set[WebSocket] = set()
|
||||
|
||||
for websocket in active_connections:
|
||||
try:
|
||||
await websocket.send_text(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送进度更新失败: {e}")
|
||||
disconnected.add(websocket)
|
||||
|
||||
for websocket in disconnected:
|
||||
active_connections.discard(websocket)
|
||||
await websocket_manager.broadcast_to_topic(
|
||||
domain="plugin_progress",
|
||||
topic="main",
|
||||
event="update",
|
||||
data={"progress": progress_data},
|
||||
)
|
||||
|
||||
|
||||
async def update_progress(
|
||||
@@ -56,6 +63,18 @@ async def update_progress(
|
||||
total_plugins: int = 0,
|
||||
loaded_plugins: int = 0,
|
||||
) -> None:
|
||||
"""更新当前插件进度并广播。
|
||||
|
||||
Args:
|
||||
stage: 当前阶段。
|
||||
progress: 当前进度百分比。
|
||||
message: 进度说明消息。
|
||||
operation: 当前操作类型。
|
||||
error: 可选的错误信息。
|
||||
plugin_id: 当前处理的插件 ID。
|
||||
total_plugins: 总插件数量。
|
||||
loaded_plugins: 已处理插件数量。
|
||||
"""
|
||||
progress_data = {
|
||||
"operation": operation,
|
||||
"stage": stage,
|
||||
@@ -74,6 +93,12 @@ async def update_progress(
|
||||
|
||||
@router.websocket("/ws/plugin-progress")
|
||||
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
|
||||
"""旧版插件进度 WebSocket 入口。
|
||||
|
||||
Args:
|
||||
websocket: FastAPI WebSocket 对象。
|
||||
token: 可选的一次性握手 Token。
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
if token and verify_ws_token(token):
|
||||
@@ -105,17 +130,22 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
|
||||
data = await websocket.receive_text()
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
except Exception as e:
|
||||
logger.error(f"处理客户端消息时出错: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"处理客户端消息时出错: {exc}")
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
active_connections.discard(websocket)
|
||||
logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebSocket 错误: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"❌ WebSocket 错误: {exc}")
|
||||
active_connections.discard(websocket)
|
||||
|
||||
|
||||
def get_progress_router() -> APIRouter:
|
||||
"""获取旧版插件进度路由对象。
|
||||
|
||||
Returns:
|
||||
APIRouter: 插件进度路由对象。
|
||||
"""
|
||||
return router
|
||||
|
||||
28
src/webui/routers/plugin/runtime_routes.py
Normal file
28
src/webui/routers/plugin/runtime_routes.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""插件运行时相关 WebUI 路由。"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Cookie
|
||||
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
|
||||
from .schemas import HookSpecListResponse, HookSpecResponse
|
||||
from .support import require_plugin_token
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/runtime/hooks", response_model=HookSpecListResponse)
|
||||
async def list_runtime_hook_specs(maibot_session: Optional[str] = Cookie(None)) -> HookSpecListResponse:
|
||||
"""返回当前插件运行时公开的 Hook 规格清单。
|
||||
|
||||
Args:
|
||||
maibot_session: 当前 WebUI 会话令牌。
|
||||
|
||||
Returns:
|
||||
HookSpecListResponse: Hook 规格列表响应。
|
||||
"""
|
||||
|
||||
require_plugin_token(maibot_session)
|
||||
hooks = [HookSpecResponse(**hook_data) for hook_data in component_query_service.list_hook_specs()]
|
||||
return HookSpecListResponse(success=True, hooks=hooks)
|
||||
@@ -111,3 +111,19 @@ class UpdatePluginConfigRequest(BaseModel):
|
||||
|
||||
class UpdatePluginRawConfigRequest(BaseModel):
|
||||
config: str = Field(..., description="原始 TOML 配置内容")
|
||||
|
||||
|
||||
class HookSpecResponse(BaseModel):
|
||||
name: str = Field(..., description="Hook 名称")
|
||||
description: str = Field("", description="Hook 描述")
|
||||
parameters_schema: Dict[str, Any] = Field(default_factory=dict, description="Hook 参数模型")
|
||||
default_timeout_ms: int = Field(..., description="默认超时毫秒数")
|
||||
allow_blocking: bool = Field(..., description="是否允许 blocking 处理器")
|
||||
allow_observe: bool = Field(..., description="是否允许 observe 处理器")
|
||||
allow_abort: bool = Field(..., description="是否允许 abort")
|
||||
allow_kwargs_mutation: bool = Field(..., description="是否允许修改 kwargs")
|
||||
|
||||
|
||||
class HookSpecListResponse(BaseModel):
|
||||
success: bool = Field(..., description="是否成功")
|
||||
hooks: List[HookSpecResponse] = Field(default_factory=list, description="Hook 规格列表")
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""WebSocket 路由聚合导出。"""
|
||||
|
||||
from .auth import router as ws_auth_router
|
||||
from .logs import router as logs_router
|
||||
from .unified import router as unified_ws_router
|
||||
|
||||
__all__ = [
|
||||
"logs_router",
|
||||
"unified_ws_router",
|
||||
"ws_auth_router",
|
||||
]
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
"""WebSocket 日志推送路由兼容导出。"""
|
||||
|
||||
from src.webui.logs_ws import active_connections, broadcast_log, load_recent_logs, router, websocket_logs
|
||||
|
||||
__all__ = [
|
||||
"active_connections",
|
||||
"broadcast_log",
|
||||
"load_recent_logs",
|
||||
"router",
|
||||
"websocket_logs",
|
||||
]
|
||||
297
src/webui/routers/websocket/manager.py
Normal file
297
src/webui/routers/websocket/manager.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""统一 WebSocket 连接管理器。"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Set
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui.websocket")
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSocketConnection:
|
||||
"""统一 WebSocket 连接上下文。"""
|
||||
|
||||
connection_id: str
|
||||
websocket: WebSocket
|
||||
subscriptions: Set[str] = field(default_factory=set)
|
||||
chat_sessions: Dict[str, str] = field(default_factory=dict)
|
||||
send_queue: "asyncio.Queue[Optional[Dict[str, Any]]]" = field(default_factory=asyncio.Queue)
|
||||
sender_task: Optional["asyncio.Task[None]"] = None
|
||||
|
||||
|
||||
class UnifiedWebSocketManager:
|
||||
"""统一 WebSocket 连接管理器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化统一 WebSocket 连接管理器。"""
|
||||
self.connections: Dict[str, WebSocketConnection] = {}
|
||||
|
||||
def _build_subscription_key(self, domain: str, topic: str) -> str:
|
||||
"""构建订阅索引键。
|
||||
|
||||
Args:
|
||||
domain: 业务域名称。
|
||||
topic: 主题名称。
|
||||
|
||||
Returns:
|
||||
str: 订阅索引键。
|
||||
"""
|
||||
return f"{domain}:{topic}"
|
||||
|
||||
async def _sender_loop(self, connection: WebSocketConnection) -> None:
|
||||
"""串行发送指定连接的出站消息。
|
||||
|
||||
Args:
|
||||
connection: 目标连接上下文。
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
message = await connection.send_queue.get()
|
||||
if message is None:
|
||||
return
|
||||
await connection.websocket.send_json(message)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error("统一 WebSocket 发送失败: connection=%s, error=%s", connection.connection_id, exc)
|
||||
|
||||
async def connect(self, connection_id: str, websocket: WebSocket) -> WebSocketConnection:
|
||||
"""注册一个新的物理 WebSocket 连接。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
websocket: FastAPI WebSocket 对象。
|
||||
|
||||
Returns:
|
||||
WebSocketConnection: 新建的连接上下文。
|
||||
"""
|
||||
await websocket.accept()
|
||||
connection = WebSocketConnection(connection_id=connection_id, websocket=websocket)
|
||||
connection.sender_task = asyncio.create_task(self._sender_loop(connection))
|
||||
self.connections[connection_id] = connection
|
||||
return connection
|
||||
|
||||
async def disconnect(self, connection_id: str) -> None:
|
||||
"""断开并清理指定连接。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
"""
|
||||
connection = self.connections.pop(connection_id, None)
|
||||
if connection is None:
|
||||
return
|
||||
|
||||
await connection.send_queue.put(None)
|
||||
if connection.sender_task is not None:
|
||||
try:
|
||||
await connection.sender_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.debug("等待发送协程退出时出现异常: connection=%s, error=%s", connection_id, exc)
|
||||
|
||||
def get_connection(self, connection_id: str) -> Optional[WebSocketConnection]:
|
||||
"""获取指定连接上下文。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
|
||||
Returns:
|
||||
Optional[WebSocketConnection]: 找到时返回连接上下文。
|
||||
"""
|
||||
return self.connections.get(connection_id)
|
||||
|
||||
def register_chat_session(self, connection_id: str, client_session_id: str, session_id: str) -> None:
|
||||
"""登记连接下的逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
client_session_id: 前端会话 ID。
|
||||
session_id: 内部会话 ID。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return
|
||||
connection.chat_sessions[client_session_id] = session_id
|
||||
|
||||
def unregister_chat_session(self, connection_id: str, client_session_id: str) -> None:
|
||||
"""移除连接下的逻辑聊天会话登记。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
client_session_id: 前端会话 ID。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return
|
||||
connection.chat_sessions.pop(client_session_id, None)
|
||||
|
||||
def get_chat_session_id(self, connection_id: str, client_session_id: str) -> Optional[str]:
|
||||
"""查询连接下的内部聊天会话 ID。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
client_session_id: 前端会话 ID。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 找到时返回内部会话 ID。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return None
|
||||
return connection.chat_sessions.get(client_session_id)
|
||||
|
||||
def subscribe(self, connection_id: str, domain: str, topic: str) -> None:
|
||||
"""登记连接的主题订阅。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
domain: 业务域名称。
|
||||
topic: 主题名称。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return
|
||||
connection.subscriptions.add(self._build_subscription_key(domain, topic))
|
||||
|
||||
def unsubscribe(self, connection_id: str, domain: str, topic: str) -> None:
|
||||
"""移除连接的主题订阅。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
domain: 业务域名称。
|
||||
topic: 主题名称。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return
|
||||
connection.subscriptions.discard(self._build_subscription_key(domain, topic))
|
||||
|
||||
def is_subscribed(self, connection_id: str, domain: str, topic: str) -> bool:
|
||||
"""判断连接是否订阅了指定主题。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
domain: 业务域名称。
|
||||
topic: 主题名称。
|
||||
|
||||
Returns:
|
||||
bool: 已订阅时返回 ``True``。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return False
|
||||
return self._build_subscription_key(domain, topic) in connection.subscriptions
|
||||
|
||||
async def enqueue(self, connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""向指定连接的发送队列压入消息。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 待发送的消息。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return
|
||||
await connection.send_queue.put(message)
|
||||
|
||||
async def send_response(
|
||||
self,
|
||||
connection_id: str,
|
||||
request_id: Optional[str],
|
||||
ok: bool,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
error: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""发送统一响应消息。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
request_id: 请求 ID。
|
||||
ok: 请求是否成功。
|
||||
data: 成功响应数据。
|
||||
error: 失败响应数据。
|
||||
"""
|
||||
response_message: Dict[str, Any] = {
|
||||
"op": "response",
|
||||
"id": request_id,
|
||||
"ok": ok,
|
||||
}
|
||||
if data is not None:
|
||||
response_message["data"] = data
|
||||
if error is not None:
|
||||
response_message["error"] = error
|
||||
await self.enqueue(connection_id, response_message)
|
||||
|
||||
async def send_event(
|
||||
self,
|
||||
connection_id: str,
|
||||
domain: str,
|
||||
event: str,
|
||||
data: Dict[str, Any],
|
||||
session: Optional[str] = None,
|
||||
topic: Optional[str] = None,
|
||||
) -> None:
|
||||
"""发送统一事件消息。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
domain: 业务域名称。
|
||||
event: 事件名称。
|
||||
data: 事件数据。
|
||||
session: 可选的逻辑会话 ID。
|
||||
topic: 可选的主题名称。
|
||||
"""
|
||||
event_message: Dict[str, Any] = {
|
||||
"op": "event",
|
||||
"domain": domain,
|
||||
"event": event,
|
||||
"data": data,
|
||||
}
|
||||
if session is not None:
|
||||
event_message["session"] = session
|
||||
if topic is not None:
|
||||
event_message["topic"] = topic
|
||||
await self.enqueue(connection_id, event_message)
|
||||
|
||||
async def send_pong(self, connection_id: str, timestamp: float) -> None:
|
||||
"""发送心跳响应。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
timestamp: 当前时间戳。
|
||||
"""
|
||||
await self.enqueue(
|
||||
connection_id,
|
||||
{
|
||||
"op": "pong",
|
||||
"ts": timestamp,
|
||||
},
|
||||
)
|
||||
|
||||
async def broadcast_to_topic(self, domain: str, topic: str, event: str, data: Dict[str, Any]) -> None:
|
||||
"""向订阅指定主题的全部连接广播事件。
|
||||
|
||||
Args:
|
||||
domain: 业务域名称。
|
||||
topic: 主题名称。
|
||||
event: 事件名称。
|
||||
data: 事件数据。
|
||||
"""
|
||||
subscription_key = self._build_subscription_key(domain, topic)
|
||||
for connection in list(self.connections.values()):
|
||||
if subscription_key in connection.subscriptions:
|
||||
await self.send_event(
|
||||
connection.connection_id,
|
||||
domain=domain,
|
||||
event=event,
|
||||
data=data,
|
||||
topic=topic,
|
||||
)
|
||||
|
||||
|
||||
websocket_manager = UnifiedWebSocketManager()
|
||||
548
src/webui/routers/websocket/unified.py
Normal file
548
src/webui/routers/websocket/unified.py
Normal file
@@ -0,0 +1,548 @@
|
||||
"""统一 WebSocket 路由。"""
|
||||
|
||||
from typing import Any, Dict, Optional, Set, cast
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.logs_ws import load_recent_logs
|
||||
from src.webui.routers.chat.service import (
|
||||
chat_manager,
|
||||
dispatch_chat_event,
|
||||
normalize_webui_user_id,
|
||||
resolve_initial_virtual_identity,
|
||||
send_initial_chat_state,
|
||||
)
|
||||
from src.webui.routers.plugin.progress import get_current_progress
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
from src.webui.routers.websocket.manager import websocket_manager
|
||||
|
||||
logger = get_logger("webui.unified_ws")
|
||||
router = APIRouter()
|
||||
_background_tasks: Set["asyncio.Task[None]"] = set()
|
||||
|
||||
|
||||
def _build_error(code: str, message: str) -> Dict[str, Any]:
|
||||
"""构建统一错误响应体。
|
||||
|
||||
Args:
|
||||
code: 错误码。
|
||||
message: 错误描述。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 统一错误对象。
|
||||
"""
|
||||
return {
|
||||
"code": code,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
|
||||
def _get_request_data(message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""从客户端消息中提取数据字段。
|
||||
|
||||
Args:
|
||||
message: 客户端消息。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 标准化后的数据字典。
|
||||
"""
|
||||
data = message.get("data", {})
|
||||
if isinstance(data, dict):
|
||||
return cast(Dict[str, Any], data)
|
||||
return {}
|
||||
|
||||
|
||||
def _track_background_task(task: "asyncio.Task[None]") -> None:
|
||||
"""登记后台任务并在完成后自动清理。
|
||||
|
||||
Args:
|
||||
task: 后台协程任务。
|
||||
"""
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
|
||||
async def authenticate_websocket_connection(websocket: WebSocket, token: Optional[str]) -> bool:
|
||||
"""校验统一 WebSocket 连接的认证状态。
|
||||
|
||||
Args:
|
||||
websocket: FastAPI WebSocket 对象。
|
||||
token: 可选的一次性握手 Token。
|
||||
|
||||
Returns:
|
||||
bool: 认证通过时返回 ``True``。
|
||||
"""
|
||||
if token and verify_ws_token(token):
|
||||
logger.debug("统一 WebSocket 使用临时 token 认证成功")
|
||||
return True
|
||||
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
logger.debug("统一 WebSocket 使用 Cookie 认证成功")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def _handle_logs_subscribe(connection_id: str, request_id: Optional[str], data: Dict[str, Any]) -> None:
|
||||
"""处理日志域订阅请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
request_id: 请求 ID。
|
||||
data: 订阅参数。
|
||||
"""
|
||||
replay_limit = int(data.get("replay", 100) or 100)
|
||||
replay_limit = max(0, min(replay_limit, 500))
|
||||
websocket_manager.subscribe(connection_id, domain="logs", topic="main")
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"domain": "logs", "topic": "main"},
|
||||
)
|
||||
await websocket_manager.send_event(
|
||||
connection_id,
|
||||
domain="logs",
|
||||
event="snapshot",
|
||||
topic="main",
|
||||
data={"entries": load_recent_logs(limit=replay_limit)},
|
||||
)
|
||||
|
||||
|
||||
async def _handle_plugin_progress_subscribe(connection_id: str, request_id: Optional[str]) -> None:
|
||||
"""处理插件进度域订阅请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
request_id: 请求 ID。
|
||||
"""
|
||||
websocket_manager.subscribe(connection_id, domain="plugin_progress", topic="main")
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"domain": "plugin_progress", "topic": "main"},
|
||||
)
|
||||
await websocket_manager.send_event(
|
||||
connection_id,
|
||||
domain="plugin_progress",
|
||||
event="snapshot",
|
||||
topic="main",
|
||||
data={"progress": get_current_progress()},
|
||||
)
|
||||
|
||||
|
||||
async def _handle_subscribe(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理主题订阅请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
domain = str(message.get("domain") or "").strip()
|
||||
topic = str(message.get("topic") or "").strip()
|
||||
data = _get_request_data(message)
|
||||
|
||||
if domain == "logs" and topic == "main":
|
||||
await _handle_logs_subscribe(connection_id, request_id, data)
|
||||
return
|
||||
|
||||
if domain == "plugin_progress" and topic == "main":
|
||||
await _handle_plugin_progress_subscribe(connection_id, request_id)
|
||||
return
|
||||
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("unsupported_subscription", f"不支持的订阅目标: {domain}:{topic}"),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_unsubscribe(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理主题退订请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
domain = str(message.get("domain") or "").strip()
|
||||
topic = str(message.get("topic") or "").strip()
|
||||
|
||||
if not domain or not topic:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("invalid_unsubscribe", "退订请求缺少 domain 或 topic"),
|
||||
)
|
||||
return
|
||||
|
||||
websocket_manager.unsubscribe(connection_id, domain=domain, topic=topic)
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"domain": domain, "topic": topic},
|
||||
)
|
||||
|
||||
|
||||
async def _open_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""打开一个逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
client_session_id = str(message.get("session") or "").strip()
|
||||
if not client_session_id:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("missing_session", "聊天会话打开请求缺少 session"),
|
||||
)
|
||||
return
|
||||
|
||||
data = _get_request_data(message)
|
||||
normalized_user_id = normalize_webui_user_id(cast(Optional[str], data.get("user_id")))
|
||||
current_user_name = str(data.get("user_name") or "WebUI用户")
|
||||
current_virtual_config = resolve_initial_virtual_identity(
|
||||
platform=cast(Optional[str], data.get("platform")),
|
||||
person_id=cast(Optional[str], data.get("person_id")),
|
||||
group_name=cast(Optional[str], data.get("group_name")),
|
||||
group_id=cast(Optional[str], data.get("group_id")),
|
||||
)
|
||||
restore = bool(data.get("restore"))
|
||||
session_id = f"{connection_id}:{client_session_id}"
|
||||
|
||||
async def send_chat_event(chat_message: Dict[str, Any]) -> None:
|
||||
"""将聊天消息封装为统一事件并发送。
|
||||
|
||||
Args:
|
||||
chat_message: 聊天消息体。
|
||||
"""
|
||||
event_name = str(chat_message.get("type") or "message")
|
||||
await websocket_manager.send_event(
|
||||
connection_id,
|
||||
domain="chat",
|
||||
event=event_name,
|
||||
session=client_session_id,
|
||||
data=chat_message,
|
||||
)
|
||||
|
||||
await chat_manager.connect(
|
||||
session_id=session_id,
|
||||
connection_id=connection_id,
|
||||
client_session_id=client_session_id,
|
||||
user_id=normalized_user_id,
|
||||
user_name=current_user_name,
|
||||
virtual_config=current_virtual_config,
|
||||
sender=send_chat_event,
|
||||
)
|
||||
websocket_manager.register_chat_session(connection_id, client_session_id, session_id)
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"session": client_session_id, "session_id": session_id},
|
||||
)
|
||||
await send_initial_chat_state(
|
||||
session_id=session_id,
|
||||
user_id=normalized_user_id,
|
||||
user_name=current_user_name,
|
||||
virtual_config=current_virtual_config,
|
||||
include_welcome=not restore,
|
||||
)
|
||||
|
||||
|
||||
async def _close_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""关闭一个逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
client_session_id = str(message.get("session") or "").strip()
|
||||
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
|
||||
if session_id is None:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
|
||||
)
|
||||
return
|
||||
|
||||
chat_manager.disconnect(session_id)
|
||||
websocket_manager.unregister_chat_session(connection_id, client_session_id)
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"session": client_session_id},
|
||||
)
|
||||
|
||||
|
||||
async def _process_chat_message(connection_id: str, client_session_id: str, data: Dict[str, Any]) -> None:
|
||||
"""在后台处理聊天消息事件。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
client_session_id: 前端会话 ID。
|
||||
data: 客户端提交的消息数据。
|
||||
"""
|
||||
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
|
||||
if session_id is None:
|
||||
return
|
||||
|
||||
session_state = chat_manager.get_session(session_id)
|
||||
if session_state is None:
|
||||
return
|
||||
|
||||
next_user_name, next_virtual_config = await dispatch_chat_event(
|
||||
session_id=session_id,
|
||||
session_id_prefix=session_id[:8],
|
||||
data=data,
|
||||
current_user_name=session_state.user_name,
|
||||
normalized_user_id=session_state.user_id,
|
||||
current_virtual_config=session_state.virtual_config,
|
||||
)
|
||||
chat_manager.update_session_context(
|
||||
session_id=session_id,
|
||||
user_name=next_user_name,
|
||||
virtual_config=next_virtual_config,
|
||||
)
|
||||
|
||||
|
||||
async def _handle_chat_message_send(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理聊天消息发送请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
client_session_id = str(message.get("session") or "").strip()
|
||||
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
|
||||
if session_id is None:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
|
||||
)
|
||||
return
|
||||
|
||||
data = _get_request_data(message)
|
||||
payload = {
|
||||
"type": "message",
|
||||
"content": data.get("content", ""),
|
||||
"user_name": data.get("user_name", ""),
|
||||
}
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"accepted": True, "session": client_session_id},
|
||||
)
|
||||
_track_background_task(asyncio.create_task(_process_chat_message(connection_id, client_session_id, payload)))
|
||||
|
||||
|
||||
async def _handle_chat_nickname_update(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理聊天昵称更新请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
client_session_id = str(message.get("session") or "").strip()
|
||||
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
|
||||
if session_id is None:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
|
||||
)
|
||||
return
|
||||
|
||||
data = _get_request_data(message)
|
||||
session_state = chat_manager.get_session(session_id)
|
||||
if session_state is None:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
|
||||
)
|
||||
return
|
||||
|
||||
next_user_name, next_virtual_config = await dispatch_chat_event(
|
||||
session_id=session_id,
|
||||
session_id_prefix=session_id[:8],
|
||||
data={
|
||||
"type": "update_nickname",
|
||||
"user_name": data.get("user_name", ""),
|
||||
},
|
||||
current_user_name=session_state.user_name,
|
||||
normalized_user_id=session_state.user_id,
|
||||
current_virtual_config=session_state.virtual_config,
|
||||
)
|
||||
chat_manager.update_session_context(
|
||||
session_id=session_id,
|
||||
user_name=next_user_name,
|
||||
virtual_config=next_virtual_config,
|
||||
)
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"session": client_session_id, "user_name": next_user_name},
|
||||
)
|
||||
|
||||
|
||||
async def _handle_chat_call(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理聊天域调用请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
method = str(message.get("method") or "").strip()
|
||||
|
||||
if method == "session.open":
|
||||
await _open_chat_session(connection_id, message)
|
||||
return
|
||||
|
||||
if method == "session.close":
|
||||
await _close_chat_session(connection_id, message)
|
||||
return
|
||||
|
||||
if method == "message.send":
|
||||
await _handle_chat_message_send(connection_id, message)
|
||||
return
|
||||
|
||||
if method == "session.update_nickname":
|
||||
await _handle_chat_nickname_update(connection_id, message)
|
||||
return
|
||||
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("unsupported_method", f"不支持的聊天方法: {method}"),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_call(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理统一调用请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
domain = str(message.get("domain") or "").strip()
|
||||
if domain == "chat":
|
||||
await _handle_chat_call(connection_id, message)
|
||||
return
|
||||
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("unsupported_domain", f"不支持的调用域: {domain}"),
|
||||
)
|
||||
|
||||
|
||||
async def handle_client_message(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理统一 WebSocket 客户端消息。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
operation = str(message.get("op") or "").strip()
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
|
||||
if operation == "ping":
|
||||
await websocket_manager.send_pong(connection_id, time.time())
|
||||
return
|
||||
|
||||
if operation == "subscribe":
|
||||
await _handle_subscribe(connection_id, message)
|
||||
return
|
||||
|
||||
if operation == "unsubscribe":
|
||||
await _handle_unsubscribe(connection_id, message)
|
||||
return
|
||||
|
||||
if operation == "call":
|
||||
await _handle_call(connection_id, message)
|
||||
return
|
||||
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("unsupported_operation", f"不支持的操作: {operation}"),
|
||||
)
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
|
||||
"""统一 WebSocket 入口。
|
||||
|
||||
Args:
|
||||
websocket: FastAPI WebSocket 对象。
|
||||
token: 可选的一次性握手 Token。
|
||||
"""
|
||||
if not await authenticate_websocket_connection(websocket, token):
|
||||
logger.warning("统一 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
connection_id = uuid.uuid4().hex
|
||||
await websocket_manager.connect(connection_id=connection_id, websocket=websocket)
|
||||
logger.info("统一 WebSocket 客户端已连接: connection=%s", connection_id)
|
||||
await websocket_manager.send_event(
|
||||
connection_id,
|
||||
domain="system",
|
||||
event="ready",
|
||||
data={"connection_id": connection_id, "timestamp": time.time()},
|
||||
)
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw_message = await websocket.receive_json()
|
||||
if not isinstance(raw_message, dict):
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=None,
|
||||
ok=False,
|
||||
error=_build_error("invalid_message", "消息必须是 JSON 对象"),
|
||||
)
|
||||
continue
|
||||
await handle_client_message(connection_id, cast(Dict[str, Any], raw_message))
|
||||
except WebSocketDisconnect:
|
||||
logger.info("统一 WebSocket 客户端已断开: connection=%s", connection_id)
|
||||
except Exception as exc:
|
||||
logger.error(f"统一 WebSocket 处理失败: {exc}")
|
||||
finally:
|
||||
chat_manager.disconnect_connection(connection_id)
|
||||
await websocket_manager.disconnect(connection_id)
|
||||
@@ -18,11 +18,11 @@ from src.webui.routers.expression import router as expression_router
|
||||
from src.webui.routers.jargon import router as jargon_router
|
||||
from src.webui.routers.model import router as model_router
|
||||
from src.webui.routers.person import router as person_router
|
||||
from src.webui.routers.plugin import get_progress_router
|
||||
from src.webui.routers.plugin import router as plugin_router
|
||||
from src.webui.routers.statistics import router as statistics_router
|
||||
from src.webui.routers.system import router as system_router
|
||||
from src.webui.routers.websocket.auth import router as ws_auth_router
|
||||
from src.webui.routers.websocket.unified import router as unified_ws_router
|
||||
|
||||
logger = get_logger("webui.api")
|
||||
|
||||
@@ -43,14 +43,14 @@ router.include_router(jargon_router)
|
||||
router.include_router(emoji_router)
|
||||
# 注册插件管理路由
|
||||
router.include_router(plugin_router)
|
||||
# 注册插件进度 WebSocket 路由
|
||||
router.include_router(get_progress_router())
|
||||
# 注册系统控制路由
|
||||
router.include_router(system_router)
|
||||
# 注册模型列表获取路由
|
||||
router.include_router(model_router)
|
||||
# 注册 WebSocket 认证路由
|
||||
router.include_router(ws_auth_router)
|
||||
# 注册统一 WebSocket 路由
|
||||
router.include_router(unified_ws_router)
|
||||
|
||||
|
||||
class TokenVerifyRequest(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user