WebUI 前端 & 后端超级大重构
This commit is contained in:
@@ -14,13 +14,13 @@ import {
|
||||
import { Button } from '@/components/ui/button'
|
||||
import {
|
||||
Dialog,
|
||||
DialogBody,
|
||||
DialogContent,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Label } from '@/components/ui/label'
|
||||
import { ScrollArea } from '@/components/ui/scroll-area'
|
||||
import { useBackendConnections } from '@/hooks/useBackendConnections'
|
||||
import { isElectron } from '@/lib/runtime'
|
||||
import type { BackendConnection } from '@/types/electron'
|
||||
@@ -78,7 +78,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) {
|
||||
return (
|
||||
<>
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-md sm:max-w-[425px]">
|
||||
<DialogContent className="max-w-md sm:max-w-106.25">
|
||||
<DialogHeader>
|
||||
<DialogTitle>后端连接管理</DialogTitle>
|
||||
</DialogHeader>
|
||||
@@ -88,7 +88,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) {
|
||||
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
|
||||
</div>
|
||||
) : (
|
||||
<ScrollArea className="max-h-[60vh] pr-4">
|
||||
<DialogBody className="pr-4">
|
||||
<div className="flex flex-col gap-3 py-4">
|
||||
{backends.map((backend) => {
|
||||
const isActive = backend.id === activeId
|
||||
@@ -100,7 +100,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) {
|
||||
}`}
|
||||
>
|
||||
<div className="flex flex-1 items-center gap-3 overflow-hidden">
|
||||
<div className="flex-shrink-0">
|
||||
<div className="shrink-0">
|
||||
{isActive ? (
|
||||
<Check className="h-5 w-5 text-blue-500" />
|
||||
) : (
|
||||
@@ -156,7 +156,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) {
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</DialogBody>
|
||||
)}
|
||||
|
||||
<div className="flex justify-end pt-4 border-t">
|
||||
@@ -173,7 +173,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) {
|
||||
|
||||
{/* Edit/Add Dialog */}
|
||||
<Dialog open={!!editConn} onOpenChange={(open) => !open && setEditConn(null)}>
|
||||
<DialogContent className="sm:max-w-[425px]">
|
||||
<DialogContent className="sm:max-w-106.25" confirmOnEnter>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{editConn?.id ? '编辑连接' : '添加连接'}</DialogTitle>
|
||||
</DialogHeader>
|
||||
@@ -212,6 +212,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) {
|
||||
!editConn?.url ||
|
||||
!/^https?:\/\//.test(editConn.url)
|
||||
}
|
||||
data-dialog-action="confirm"
|
||||
>
|
||||
保存
|
||||
</Button>
|
||||
|
||||
@@ -22,6 +22,7 @@ import {
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Badge } from '@/components/ui/badge'
|
||||
import { ShortcutKbd } from '@/components/ui/kbd'
|
||||
import { ScrollArea } from '@/components/ui/scroll-area'
|
||||
import { Checkbox } from '@/components/ui/checkbox'
|
||||
import { Tabs, TabsList, TabsTrigger } from '@/components/ui/tabs'
|
||||
@@ -1689,19 +1690,19 @@ if (isCurrent) {
|
||||
{/* 底部快捷键提示(桌面端) */}
|
||||
<div className="hidden sm:flex items-center justify-center gap-6 px-6 py-3 border-t text-xs text-muted-foreground">
|
||||
<div className="flex items-center gap-1">
|
||||
<kbd className="px-2 py-1 bg-muted rounded text-xs">←</kbd>
|
||||
<ShortcutKbd size="sm" keys={['left']} />
|
||||
<span>拒绝</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<kbd className="px-2 py-1 bg-muted rounded text-xs">→</kbd>
|
||||
<ShortcutKbd size="sm" keys={['right']} />
|
||||
<span>通过</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<kbd className="px-2 py-1 bg-muted rounded text-xs">↑</kbd>
|
||||
<ShortcutKbd size="sm" keys={['up']} />
|
||||
<span>上一条</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<kbd className="px-2 py-1 bg-muted rounded text-xs">↓</kbd>
|
||||
<ShortcutKbd size="sm" keys={['down']} />
|
||||
<span>下一条</span>
|
||||
</div>
|
||||
<span className="text-muted-foreground/50">|</span>
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { Link } from '@tanstack/react-router'
|
||||
import { BookOpen, ChevronLeft, Globe, LogOut, Menu, Moon, PieChart, Search, Server, Sun } from 'lucide-react'
|
||||
import { BookOpen, ChevronLeft, Globe, LogOut, Menu, Moon, Search, Server, Sun } from 'lucide-react'
|
||||
import { useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
@@ -13,7 +12,7 @@ import {
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from '@/components/ui/dropdown-menu'
|
||||
import { Kbd } from '@/components/ui/kbd'
|
||||
import { ShortcutKbd } from '@/components/ui/kbd'
|
||||
import { toggleThemeWithTransition } from '@/components/use-theme'
|
||||
import { useBackground } from '@/hooks/use-background'
|
||||
import { logout } from '@/lib/fetch-with-auth'
|
||||
@@ -99,7 +98,7 @@ export function Header({
|
||||
title={t('header.toggleConnection')}
|
||||
>
|
||||
<Server className="h-4 w-4" />
|
||||
<span className="hidden sm:inline text-xs text-muted-foreground truncate max-w-[100px]">
|
||||
<span className="hidden sm:inline text-xs text-muted-foreground truncate max-w-25">
|
||||
{activeBackendName}
|
||||
</span>
|
||||
</Button>
|
||||
@@ -107,19 +106,6 @@ export function Header({
|
||||
<div className="h-6 w-px bg-border" />
|
||||
</>
|
||||
)}
|
||||
{/* 年度总结入口 */}
|
||||
<Link to="/annual-report">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="gap-2 bg-gradient-to-r from-pink-500/10 to-purple-500/10 hover:from-pink-500/20 hover:to-purple-500/20 border border-pink-500/20"
|
||||
title={t('header.viewAnnualSummary')}
|
||||
>
|
||||
<PieChart className="h-4 w-4 text-pink-500" />
|
||||
<span className="hidden sm:inline bg-gradient-to-r from-pink-500 to-purple-500 bg-clip-text text-transparent font-medium">{t('header.annualSummary')}</span>
|
||||
</Button>
|
||||
</Link>
|
||||
|
||||
{/* 搜索框 */}
|
||||
<button
|
||||
onClick={() => onSearchOpenChange(true)}
|
||||
@@ -128,9 +114,7 @@ export function Header({
|
||||
>
|
||||
<Search className="absolute left-3 top-1/2 h-4 w-4 -translate-y-1/2 text-muted-foreground" aria-hidden="true" />
|
||||
<span className="text-sm text-muted-foreground">{t('header.searchPlaceholder')}</span>
|
||||
<Kbd size="sm" className="absolute right-2 top-1/2 -translate-y-1/2">
|
||||
<span className="text-xs">⌘</span>K
|
||||
</Kbd>
|
||||
<ShortcutKbd size="sm" className="absolute right-2 top-1/2 -translate-y-1/2" keys={['mod', 'k']} />
|
||||
</button>
|
||||
|
||||
{/* 搜索对话框 */}
|
||||
|
||||
@@ -13,6 +13,7 @@ import { useAuthGuard } from '@/hooks/use-auth'
|
||||
import { useBackground } from '@/hooks/use-background'
|
||||
|
||||
import { TitleBar } from '@/components/electron/TitleBar'
|
||||
import { matchesShortcut } from '@/lib/keyboard'
|
||||
import { isElectron } from '@/lib/runtime'
|
||||
import { cn } from '@/lib/utils'
|
||||
import { menuSections } from './constants'
|
||||
@@ -49,7 +50,7 @@ export function Layout({ children }: LayoutProps) {
|
||||
// 搜索快捷键监听(Cmd/Ctrl + K)
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (e: KeyboardEvent) => {
|
||||
if ((e.metaKey || e.ctrlKey) && e.key === 'k') {
|
||||
if (matchesShortcut(e, ['mod', 'k'])) {
|
||||
e.preventDefault()
|
||||
setSearchOpen(true)
|
||||
}
|
||||
@@ -68,9 +69,8 @@ export function Layout({ children }: LayoutProps) {
|
||||
}
|
||||
}
|
||||
|
||||
const unsubscribe = router.subscribe('onResolved', () => {
|
||||
const pathname = router.state.location.pathname
|
||||
const pageTitle = pathToLabel[pathname] ?? 'MaiBot Dashboard'
|
||||
return router.subscribe('onResolved', () => {
|
||||
const pageTitle = pathToLabel[router.state.location.pathname] ?? 'MaiBot Dashboard'
|
||||
const fullTitle = pageTitle === 'MaiBot Dashboard'
|
||||
? 'MaiBot Dashboard'
|
||||
: `${pageTitle} — MaiBot Dashboard`
|
||||
@@ -90,8 +90,6 @@ export function Layout({ children }: LayoutProps) {
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return unsubscribe
|
||||
}, [router, announce, t])
|
||||
|
||||
// 获取实际应用的主题(处理 system 情况)
|
||||
|
||||
@@ -6,25 +6,25 @@ export const menuSections: MenuSection[] = [
|
||||
{
|
||||
title: 'sidebar.groups.overview',
|
||||
items: [
|
||||
{ icon: Home, label: 'sidebar.menu.home', path: '/' },
|
||||
{ icon: Home, label: 'sidebar.menu.home', path: '/', searchDescription: 'search.items.homeDesc' },
|
||||
],
|
||||
},
|
||||
{
|
||||
title: 'sidebar.groups.botConfig',
|
||||
items: [
|
||||
{ icon: FileText, label: 'sidebar.menu.botMainConfig', path: '/config/bot' },
|
||||
{ icon: Server, label: 'sidebar.menu.aiModelProvider', path: '/config/modelProvider', tourId: 'sidebar-model-provider' },
|
||||
{ icon: Boxes, label: 'sidebar.menu.modelManagement', path: '/config/model', tourId: 'sidebar-model-management' },
|
||||
{ icon: FileText, label: 'sidebar.menu.botMainConfig', path: '/config/bot', searchDescription: 'search.items.botConfigDesc' },
|
||||
{ icon: Server, label: 'sidebar.menu.aiModelProvider', path: '/config/modelProvider', searchDescription: 'search.items.modelProviderDesc', tourId: 'sidebar-model-provider' },
|
||||
{ icon: Boxes, label: 'sidebar.menu.modelManagement', path: '/config/model', searchDescription: 'search.items.modelDesc', tourId: 'sidebar-model-management' },
|
||||
{ icon: Sliders, label: 'sidebar.menu.adapterConfig', path: '/config/adapter' },
|
||||
],
|
||||
},
|
||||
{
|
||||
title: 'sidebar.groups.botResources',
|
||||
items: [
|
||||
{ icon: Smile, label: 'sidebar.menu.emojiManagement', path: '/resource/emoji' },
|
||||
{ icon: MessageSquare, label: 'sidebar.menu.expressionManagement', path: '/resource/expression' },
|
||||
{ icon: Hash, label: 'sidebar.menu.slangManagement', path: '/resource/jargon' },
|
||||
{ icon: UserCircle, label: 'sidebar.menu.personInfo', path: '/resource/person' },
|
||||
{ icon: Smile, label: 'sidebar.menu.emojiManagement', path: '/resource/emoji', searchDescription: 'search.items.emojiDesc' },
|
||||
{ icon: MessageSquare, label: 'sidebar.menu.expressionManagement', path: '/resource/expression', searchDescription: 'search.items.expressionDesc' },
|
||||
{ icon: Hash, label: 'sidebar.menu.slangManagement', path: '/resource/jargon', searchDescription: 'search.items.jargonDesc' },
|
||||
{ icon: UserCircle, label: 'sidebar.menu.personInfo', path: '/resource/person', searchDescription: 'search.items.personDesc' },
|
||||
{ icon: Network, label: 'sidebar.menu.knowledgeGraph', path: '/resource/knowledge-graph' },
|
||||
{ icon: Database, label: 'sidebar.menu.knowledgeBase', path: '/resource/knowledge-base' },
|
||||
],
|
||||
@@ -32,10 +32,10 @@ export const menuSections: MenuSection[] = [
|
||||
{
|
||||
title: 'sidebar.groups.extensionsMonitor',
|
||||
items: [
|
||||
{ icon: Package, label: 'sidebar.menu.pluginMarket', path: '/plugins' },
|
||||
{ icon: Package, label: 'sidebar.menu.pluginMarket', path: '/plugins', searchDescription: 'search.items.pluginsDesc' },
|
||||
{ icon: LayoutGrid, label: 'sidebar.menu.configTemplate', path: '/config/pack-market' },
|
||||
{ icon: Sliders, label: 'sidebar.menu.pluginConfig', path: '/plugin-config' },
|
||||
{ icon: FileSearch, label: 'sidebar.menu.logViewer', path: '/logs' },
|
||||
{ icon: FileSearch, label: 'sidebar.menu.logViewer', path: '/logs', searchDescription: 'search.items.logsDesc' },
|
||||
{ icon: Activity, label: 'sidebar.menu.plannerMonitor', path: '/planner-monitor' },
|
||||
{ icon: MessageSquare, label: 'sidebar.menu.localChat', path: '/chat' },
|
||||
],
|
||||
@@ -43,7 +43,7 @@ export const menuSections: MenuSection[] = [
|
||||
{
|
||||
title: 'sidebar.groups.system',
|
||||
items: [
|
||||
{ icon: Settings, label: 'sidebar.menu.settings', path: '/settings' },
|
||||
{ icon: Settings, label: 'sidebar.menu.settings', path: '/settings', searchDescription: 'search.items.settingsDesc' },
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
@@ -9,6 +9,7 @@ export interface MenuItem {
|
||||
icon: ComponentType<LucideProps>
|
||||
label: string
|
||||
path: string
|
||||
searchDescription?: string
|
||||
tourId?: string
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
import { useState, useCallback, useMemo } from 'react'
|
||||
import { Search, FileText, Server, Boxes, Smile, MessageSquare, UserCircle, FileSearch, BarChart3, Package, Settings, Home, Hash } from 'lucide-react'
|
||||
import { useState, useCallback, useEffect, useMemo, useRef } from 'react'
|
||||
import { Search } from 'lucide-react'
|
||||
import { useNavigate } from '@tanstack/react-router'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import type { LucideProps } from 'lucide-react'
|
||||
|
||||
import {
|
||||
Dialog,
|
||||
DialogBody,
|
||||
DialogContent,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { ScrollArea } from '@/components/ui/scroll-area'
|
||||
import { ShortcutKbd } from '@/components/ui/kbd'
|
||||
import { menuSections } from '@/components/layout/constants'
|
||||
import { registeredRoutePaths } from '@/router'
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
interface SearchDialogProps {
|
||||
@@ -19,7 +23,7 @@ interface SearchDialogProps {
|
||||
}
|
||||
|
||||
interface SearchItem {
|
||||
icon: React.ComponentType<{ className?: string }>
|
||||
icon: React.ComponentType<LucideProps>
|
||||
title: string
|
||||
description: string
|
||||
path: string
|
||||
@@ -29,95 +33,37 @@ interface SearchItem {
|
||||
export function SearchDialog({ open, onOpenChange }: SearchDialogProps) {
|
||||
const [searchQuery, setSearchQuery] = useState('')
|
||||
const [selectedIndex, setSelectedIndex] = useState(0)
|
||||
const inputRef = useRef<HTMLInputElement>(null)
|
||||
const navigate = useNavigate()
|
||||
const { t } = useTranslation()
|
||||
|
||||
const searchItems: SearchItem[] = useMemo(() => [
|
||||
{
|
||||
icon: Home,
|
||||
title: t('search.items.home'),
|
||||
description: t('search.items.homeDesc'),
|
||||
path: '/',
|
||||
category: t('search.categories.overview'),
|
||||
},
|
||||
{
|
||||
icon: FileText,
|
||||
title: t('search.items.botConfig'),
|
||||
description: t('search.items.botConfigDesc'),
|
||||
path: '/config/bot',
|
||||
category: t('search.categories.config'),
|
||||
},
|
||||
{
|
||||
icon: Server,
|
||||
title: t('search.items.modelProvider'),
|
||||
description: t('search.items.modelProviderDesc'),
|
||||
path: '/config/modelProvider',
|
||||
category: t('search.categories.config'),
|
||||
},
|
||||
{
|
||||
icon: Boxes,
|
||||
title: t('search.items.model'),
|
||||
description: t('search.items.modelDesc'),
|
||||
path: '/config/model',
|
||||
category: t('search.categories.config'),
|
||||
},
|
||||
{
|
||||
icon: Smile,
|
||||
title: t('search.items.emoji'),
|
||||
description: t('search.items.emojiDesc'),
|
||||
path: '/resource/emoji',
|
||||
category: t('search.categories.resources'),
|
||||
},
|
||||
{
|
||||
icon: MessageSquare,
|
||||
title: t('search.items.expression'),
|
||||
description: t('search.items.expressionDesc'),
|
||||
path: '/resource/expression',
|
||||
category: t('search.categories.resources'),
|
||||
},
|
||||
{
|
||||
icon: UserCircle,
|
||||
title: t('search.items.person'),
|
||||
description: t('search.items.personDesc'),
|
||||
path: '/resource/person',
|
||||
category: t('search.categories.resources'),
|
||||
},
|
||||
{
|
||||
icon: Hash,
|
||||
title: t('search.items.jargon'),
|
||||
description: t('search.items.jargonDesc'),
|
||||
path: '/resource/jargon',
|
||||
category: t('search.categories.resources'),
|
||||
},
|
||||
{
|
||||
icon: BarChart3,
|
||||
title: t('search.items.statistics'),
|
||||
description: t('search.items.statisticsDesc'),
|
||||
path: '/statistics',
|
||||
category: t('search.categories.monitor'),
|
||||
},
|
||||
{
|
||||
icon: Package,
|
||||
title: t('search.items.plugins'),
|
||||
description: t('search.items.pluginsDesc'),
|
||||
path: '/plugins',
|
||||
category: t('search.categories.extensions'),
|
||||
},
|
||||
{
|
||||
icon: FileSearch,
|
||||
title: t('search.items.logs'),
|
||||
description: t('search.items.logsDesc'),
|
||||
path: '/logs',
|
||||
category: t('search.categories.monitor'),
|
||||
},
|
||||
{
|
||||
icon: Settings,
|
||||
title: t('search.items.settings'),
|
||||
description: t('search.items.settingsDesc'),
|
||||
path: '/settings',
|
||||
category: t('search.categories.system'),
|
||||
},
|
||||
], [t])
|
||||
useEffect(() => {
|
||||
if (!open) {
|
||||
return
|
||||
}
|
||||
|
||||
const frameId = window.requestAnimationFrame(() => {
|
||||
inputRef.current?.focus()
|
||||
})
|
||||
|
||||
return () => window.cancelAnimationFrame(frameId)
|
||||
}, [open])
|
||||
|
||||
const searchItems: SearchItem[] = useMemo(
|
||||
() =>
|
||||
menuSections.flatMap((section) =>
|
||||
section.items
|
||||
.filter((item) => registeredRoutePaths.has(item.path))
|
||||
.map((item) => ({
|
||||
icon: item.icon,
|
||||
title: t(item.label),
|
||||
description: item.searchDescription ? t(item.searchDescription) : item.path,
|
||||
path: item.path,
|
||||
category: t(section.title),
|
||||
}))
|
||||
),
|
||||
[t]
|
||||
)
|
||||
|
||||
// 过滤搜索结果
|
||||
const filteredItems = searchItems.filter(
|
||||
@@ -155,12 +101,13 @@ export function SearchDialog({ open, onOpenChange }: SearchDialogProps) {
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-2xl p-0 gap-0">
|
||||
<DialogContent className="max-w-2xl p-0 gap-0" confirmOnEnter>
|
||||
<DialogHeader className="px-4 pt-4 pb-0">
|
||||
<DialogTitle className="sr-only">{t('search.title')}</DialogTitle>
|
||||
<div className="relative">
|
||||
<Search className="absolute left-3 top-1/2 h-5 w-5 -translate-y-1/2 text-muted-foreground" />
|
||||
<Input
|
||||
ref={inputRef}
|
||||
value={searchQuery}
|
||||
onChange={(e) => {
|
||||
setSearchQuery(e.target.value)
|
||||
@@ -169,13 +116,12 @@ export function SearchDialog({ open, onOpenChange }: SearchDialogProps) {
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder={t('search.placeholder')}
|
||||
className="h-12 pl-11 text-base border-0 focus-visible:ring-0 shadow-none"
|
||||
autoFocus
|
||||
/>
|
||||
</div>
|
||||
</DialogHeader>
|
||||
|
||||
<div className="border-t">
|
||||
<ScrollArea className="h-[400px]">
|
||||
<DialogBody className="h-100" viewportClassName="px-0">
|
||||
{filteredItems.length > 0 ? (
|
||||
<div className="p-2">
|
||||
{filteredItems.map((item, index) => {
|
||||
@@ -192,7 +138,7 @@ export function SearchDialog({ open, onOpenChange }: SearchDialogProps) {
|
||||
: 'hover:bg-accent/50'
|
||||
)}
|
||||
>
|
||||
<Icon className="h-5 w-5 flex-shrink-0" />
|
||||
<Icon className="h-5 w-5 shrink-0" />
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="font-medium text-sm">{item.title}</div>
|
||||
<div className="text-xs text-muted-foreground truncate">
|
||||
@@ -214,22 +160,22 @@ export function SearchDialog({ open, onOpenChange }: SearchDialogProps) {
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</ScrollArea>
|
||||
</DialogBody>
|
||||
</div>
|
||||
|
||||
<div className="border-t px-4 py-3 flex items-center justify-between text-xs text-muted-foreground">
|
||||
<div className="flex items-center gap-4">
|
||||
<span className="flex items-center gap-1">
|
||||
<kbd className="px-1.5 py-0.5 bg-muted rounded border">↑</kbd>
|
||||
<kbd className="px-1.5 py-0.5 bg-muted rounded border">↓</kbd>
|
||||
<ShortcutKbd size="sm" keys={['up']} />
|
||||
<ShortcutKbd size="sm" keys={['down']} />
|
||||
{t('search.navigate')}
|
||||
</span>
|
||||
<span className="flex items-center gap-1">
|
||||
<kbd className="px-1.5 py-0.5 bg-muted rounded border">Enter</kbd>
|
||||
<ShortcutKbd size="sm" keys={['enter']} />
|
||||
{t('search.select')}
|
||||
</span>
|
||||
<span className="flex items-center gap-1">
|
||||
<kbd className="px-1.5 py-0.5 bg-muted rounded border">Esc</kbd>
|
||||
<ShortcutKbd size="sm" keys={['esc']} />
|
||||
{t('search.close')}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
@@ -24,6 +24,7 @@ import { Checkbox } from '@/components/ui/checkbox'
|
||||
import { Badge } from '@/components/ui/badge'
|
||||
import {
|
||||
Dialog,
|
||||
DialogBody,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
@@ -34,7 +35,6 @@ import {
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'
|
||||
import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert'
|
||||
import { Separator } from '@/components/ui/separator'
|
||||
import { ScrollArea } from '@/components/ui/scroll-area'
|
||||
import { toast } from '@/hooks/use-toast'
|
||||
import {
|
||||
createPack,
|
||||
@@ -340,7 +340,7 @@ export function SharePackDialog({ trigger }: SharePackDialogProps) {
|
||||
)}
|
||||
</DialogTrigger>
|
||||
|
||||
<DialogContent className="max-w-2xl max-h-[85vh] flex flex-col">
|
||||
<DialogContent className="max-w-2xl flex flex-col" confirmOnEnter>
|
||||
<DialogHeader>
|
||||
<DialogTitle className="flex items-center gap-2">
|
||||
<Package className="w-5 h-5" />
|
||||
@@ -353,7 +353,7 @@ export function SharePackDialog({ trigger }: SharePackDialogProps) {
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<ScrollArea className="h-[calc(85vh-220px)] pr-4">
|
||||
<DialogBody>
|
||||
{loading ? (
|
||||
<div className="py-8 text-center">
|
||||
<Loader2 className="w-8 h-8 mx-auto animate-spin text-primary" />
|
||||
@@ -639,7 +639,7 @@ export function SharePackDialog({ trigger }: SharePackDialogProps) {
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</ScrollArea>
|
||||
</DialogBody>
|
||||
|
||||
<DialogFooter className="flex justify-between pt-4 border-t">
|
||||
<div>
|
||||
@@ -662,6 +662,7 @@ export function SharePackDialog({ trigger }: SharePackDialogProps) {
|
||||
</Button>
|
||||
{step < totalSteps ? (
|
||||
<Button
|
||||
data-dialog-action="confirm"
|
||||
onClick={() => setStep(step + 1)}
|
||||
disabled={
|
||||
loading ||
|
||||
@@ -671,7 +672,7 @@ export function SharePackDialog({ trigger }: SharePackDialogProps) {
|
||||
下一步
|
||||
</Button>
|
||||
) : (
|
||||
<Button onClick={handleSubmit} disabled={submitting}>
|
||||
<Button data-dialog-action="confirm" onClick={handleSubmit} disabled={submitting}>
|
||||
{submitting && <Loader2 className="w-4 h-4 mr-2 animate-spin" />}
|
||||
提交审核
|
||||
</Button>
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
import * as React from "react"
|
||||
import * as DialogPrimitive from "@radix-ui/react-dialog"
|
||||
|
||||
import { cn } from "@/lib/utils"
|
||||
import { X } from "lucide-react"
|
||||
|
||||
import { isEditableTarget, matchesShortcut } from "@/lib/keyboard"
|
||||
|
||||
import { ScrollArea } from "@/components/ui/scroll-area"
|
||||
|
||||
const Dialog = DialogPrimitive.Root
|
||||
|
||||
const DialogTrigger = DialogPrimitive.Trigger
|
||||
@@ -32,22 +37,48 @@ interface DialogContentProps
|
||||
preventOutsideClose?: boolean
|
||||
/** 隐藏默认关闭按钮(当使用自定义关闭按钮时) */
|
||||
hideCloseButton?: boolean
|
||||
/** 回车触发主操作按钮 */
|
||||
confirmOnEnter?: boolean
|
||||
}
|
||||
|
||||
interface DialogBodyProps extends React.ComponentPropsWithoutRef<typeof ScrollArea> {
|
||||
allowHorizontalScroll?: boolean
|
||||
}
|
||||
|
||||
const DialogContent = React.forwardRef<
|
||||
React.ElementRef<typeof DialogPrimitive.Content>,
|
||||
DialogContentProps
|
||||
>(({ className, children, preventOutsideClose = false, hideCloseButton = false, ...props }, ref) => (
|
||||
>(({ className, children, preventOutsideClose = false, hideCloseButton = false, confirmOnEnter = false, onKeyDownCapture, ...props }, ref) => (
|
||||
<DialogPortal>
|
||||
<DialogOverlay />
|
||||
<DialogPrimitive.Content
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"fixed left-[50%] top-[50%] z-50 grid w-full max-w-lg translate-x-[-50%] translate-y-[-50%] gap-4 border bg-background p-6 shadow-lg duration-200 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[state=closed]:slide-out-to-left-1/2 data-[state=closed]:slide-out-to-top-[48%] data-[state=open]:slide-in-from-left-1/2 data-[state=open]:slide-in-from-top-[48%] sm:rounded-lg",
|
||||
"fixed left-[50%] top-[50%] z-50 grid w-[min(calc(100vw-2rem),var(--dialog-width,32rem))] max-h-[calc(100vh-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 overflow-hidden border bg-background p-6 shadow-lg duration-200 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[state=closed]:slide-out-to-left-1/2 data-[state=closed]:slide-out-to-top-[48%] data-[state=open]:slide-in-from-left-1/2 data-[state=open]:slide-in-from-top-[48%] sm:rounded-lg",
|
||||
className
|
||||
)}
|
||||
onPointerDownOutside={preventOutsideClose ? (e) => e.preventDefault() : undefined}
|
||||
onInteractOutside={preventOutsideClose ? (e) => e.preventDefault() : undefined}
|
||||
onKeyDownCapture={(event) => {
|
||||
onKeyDownCapture?.(event)
|
||||
if (
|
||||
!confirmOnEnter ||
|
||||
event.defaultPrevented ||
|
||||
!matchesShortcut(event, ['enter']) ||
|
||||
event.nativeEvent.isComposing ||
|
||||
isEditableTarget(event.target)
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
const confirmButton = event.currentTarget.querySelector<HTMLElement>('[data-dialog-action="confirm"]:not([disabled])')
|
||||
if (!confirmButton) {
|
||||
return
|
||||
}
|
||||
|
||||
event.preventDefault()
|
||||
confirmButton.click()
|
||||
}}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
@@ -62,6 +93,22 @@ const DialogContent = React.forwardRef<
|
||||
))
|
||||
DialogContent.displayName = DialogPrimitive.Content.displayName
|
||||
|
||||
const DialogBody = React.forwardRef<HTMLDivElement, DialogBodyProps>(
|
||||
({ className, children, allowHorizontalScroll = false, contentClassName, scrollbars, viewportClassName, ...props }, ref) => (
|
||||
<ScrollArea
|
||||
ref={ref as never}
|
||||
className={cn("min-h-0 flex-1", className)}
|
||||
contentClassName={cn(allowHorizontalScroll && "min-w-full w-max", contentClassName)}
|
||||
scrollbars={scrollbars ?? (allowHorizontalScroll ? "both" : "vertical")}
|
||||
viewportClassName={cn("pr-4", viewportClassName)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</ScrollArea>
|
||||
)
|
||||
)
|
||||
DialogBody.displayName = "DialogBody"
|
||||
|
||||
const DialogHeader = ({
|
||||
className,
|
||||
...props
|
||||
@@ -125,6 +172,7 @@ export {
|
||||
DialogClose,
|
||||
DialogContent,
|
||||
DialogHeader,
|
||||
DialogBody,
|
||||
DialogFooter,
|
||||
DialogTitle,
|
||||
DialogDescription,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import * as React from "react"
|
||||
import { cva, type VariantProps } from "class-variance-authority"
|
||||
|
||||
import { getPlatformModifierAriaLabel, getShortcutKeyLabel, type ShortcutKey } from "@/lib/keyboard"
|
||||
import { cn } from "@/lib/utils"
|
||||
|
||||
const kbdVariants = cva(
|
||||
@@ -25,6 +26,10 @@ export interface KbdProps
|
||||
abbrTitle?: string
|
||||
}
|
||||
|
||||
interface ShortcutKbdProps extends Omit<KbdProps, "children"> {
|
||||
keys: ShortcutKey[]
|
||||
}
|
||||
|
||||
const Kbd = React.forwardRef<HTMLElement, KbdProps>(
|
||||
({ className, size, abbrTitle, children, ...props }, ref) => {
|
||||
return (
|
||||
@@ -40,4 +45,20 @@ const Kbd = React.forwardRef<HTMLElement, KbdProps>(
|
||||
)
|
||||
Kbd.displayName = "Kbd"
|
||||
|
||||
export { Kbd }
|
||||
function ShortcutKbd({ keys, className, size, ...props }: ShortcutKbdProps) {
|
||||
return (
|
||||
<span className={cn("inline-flex items-center gap-1", className)}>
|
||||
{keys.map((key) => {
|
||||
const label = getShortcutKeyLabel(key)
|
||||
const abbrTitle = key === 'mod' ? getPlatformModifierAriaLabel() : undefined
|
||||
return (
|
||||
<Kbd key={`${key}-${label}`} size={size} abbrTitle={abbrTitle} {...props}>
|
||||
{label}
|
||||
</Kbd>
|
||||
)
|
||||
})}
|
||||
</span>
|
||||
)
|
||||
}
|
||||
|
||||
export { Kbd, ShortcutKbd }
|
||||
|
||||
@@ -5,22 +5,25 @@ import { cn } from "@/lib/utils"
|
||||
|
||||
interface ScrollAreaProps extends React.ComponentPropsWithoutRef<typeof ScrollAreaPrimitive.Root> {
|
||||
viewportRef?: React.RefObject<HTMLDivElement | null>
|
||||
viewportClassName?: string
|
||||
contentClassName?: string
|
||||
scrollbars?: "vertical" | "horizontal" | "both"
|
||||
}
|
||||
|
||||
const ScrollArea = React.forwardRef<
|
||||
React.ElementRef<typeof ScrollAreaPrimitive.Root>,
|
||||
ScrollAreaProps
|
||||
>(({ className, children, viewportRef, ...props }, ref) => (
|
||||
>(({ className, children, viewportRef, viewportClassName, contentClassName, scrollbars = "both", ...props }, ref) => (
|
||||
<ScrollAreaPrimitive.Root
|
||||
ref={ref}
|
||||
className={cn("relative overflow-hidden", className)}
|
||||
{...props}
|
||||
>
|
||||
<ScrollAreaPrimitive.Viewport ref={viewportRef} className="h-full w-full rounded-[inherit]">
|
||||
{children}
|
||||
<ScrollAreaPrimitive.Viewport ref={viewportRef} className={cn("h-full w-full rounded-[inherit]", viewportClassName)}>
|
||||
<div className={contentClassName}>{children}</div>
|
||||
</ScrollAreaPrimitive.Viewport>
|
||||
<ScrollBar />
|
||||
<ScrollBar orientation="horizontal" />
|
||||
{scrollbars !== "horizontal" && <ScrollBar />}
|
||||
{scrollbars !== "vertical" && <ScrollBar orientation="horizontal" />}
|
||||
<ScrollAreaPrimitive.Corner />
|
||||
</ScrollAreaPrimitive.Root>
|
||||
))
|
||||
@@ -36,9 +39,9 @@ const ScrollBar = React.forwardRef<
|
||||
className={cn(
|
||||
"flex touch-none select-none transition-colors",
|
||||
orientation === "vertical" &&
|
||||
"h-full w-2.5 border-l border-l-transparent p-[1px]",
|
||||
"h-full w-2.5 border-l border-l-transparent p-px",
|
||||
orientation === "horizontal" &&
|
||||
"h-2.5 flex-col border-t border-t-transparent p-[1px]",
|
||||
"h-2.5 flex-col border-t border-t-transparent p-px",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
import { fetchWithAuth } from './fetch-with-auth'
|
||||
|
||||
export interface TimeFootprintData {
|
||||
total_online_hours: number
|
||||
first_message_time: string | null
|
||||
first_message_user: string | null
|
||||
first_message_content: string | null
|
||||
busiest_day: string | null
|
||||
busiest_day_count: number
|
||||
hourly_distribution: number[]
|
||||
midnight_chat_count: number
|
||||
is_night_owl: boolean
|
||||
}
|
||||
|
||||
export interface SocialNetworkData {
|
||||
total_groups: number
|
||||
top_groups: Array<{
|
||||
group_id: string
|
||||
group_name: string
|
||||
message_count: number
|
||||
is_webui?: boolean
|
||||
}>
|
||||
top_users: Array<{
|
||||
user_id: string
|
||||
user_nickname: string
|
||||
message_count: number
|
||||
is_webui?: boolean
|
||||
}>
|
||||
at_count: number
|
||||
mentioned_count: number
|
||||
longest_companion_user: string | null
|
||||
longest_companion_days: number
|
||||
}
|
||||
|
||||
export interface BrainPowerData {
|
||||
total_tokens: number
|
||||
total_cost: number
|
||||
favorite_model: string | null
|
||||
favorite_model_count: number
|
||||
model_distribution: Array<{
|
||||
model: string
|
||||
count: number
|
||||
tokens: number
|
||||
cost: number
|
||||
}>
|
||||
top_reply_models: Array<{
|
||||
model: string
|
||||
count: number
|
||||
}>
|
||||
most_expensive_cost: number
|
||||
most_expensive_time: string | null
|
||||
top_token_consumers: Array<{
|
||||
user_id: string
|
||||
cost: number
|
||||
tokens: number
|
||||
}>
|
||||
silence_rate: number
|
||||
total_actions: number
|
||||
no_reply_count: number
|
||||
avg_interest_value: number
|
||||
max_interest_value: number
|
||||
max_interest_time: string | null
|
||||
avg_reasoning_length: number
|
||||
max_reasoning_length: number
|
||||
max_reasoning_time: string | null
|
||||
}
|
||||
|
||||
export interface ExpressionVibeData {
|
||||
top_emoji: {
|
||||
id: number
|
||||
path: string
|
||||
description: string
|
||||
usage_count: number
|
||||
hash: string
|
||||
} | null
|
||||
top_emojis: Array<{
|
||||
id: number
|
||||
path: string
|
||||
description: string
|
||||
usage_count: number
|
||||
hash: string
|
||||
}>
|
||||
top_expressions: Array<{
|
||||
style: string
|
||||
count: number
|
||||
}>
|
||||
rejected_expression_count: number
|
||||
checked_expression_count: number
|
||||
total_expressions: number
|
||||
action_types: Array<{
|
||||
action: string
|
||||
count: number
|
||||
}>
|
||||
image_processed_count: number
|
||||
late_night_reply: {
|
||||
time: string
|
||||
content: string
|
||||
} | null
|
||||
favorite_reply: {
|
||||
content: string
|
||||
count: number
|
||||
} | null
|
||||
}
|
||||
|
||||
export interface AchievementData {
|
||||
new_jargon_count: number
|
||||
sample_jargons: Array<{
|
||||
content: string
|
||||
meaning: string
|
||||
count: number
|
||||
}>
|
||||
total_messages: number
|
||||
total_replies: number
|
||||
}
|
||||
|
||||
export interface AnnualReportData {
|
||||
year: number
|
||||
bot_name: string
|
||||
generated_at: string
|
||||
time_footprint: TimeFootprintData
|
||||
social_network: SocialNetworkData
|
||||
brain_power: BrainPowerData
|
||||
expression_vibe: ExpressionVibeData
|
||||
achievements: AchievementData
|
||||
}
|
||||
|
||||
export async function getAnnualReport(year: number = 2025): Promise<AnnualReportData> {
|
||||
const response = await fetchWithAuth(`/api/webui/annual-report/full?year=${year}`)
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.json()
|
||||
throw new Error(error.detail || '获取年度报告失败')
|
||||
}
|
||||
|
||||
return response.json()
|
||||
}
|
||||
93
dashboard/src/lib/keyboard.ts
Normal file
93
dashboard/src/lib/keyboard.ts
Normal file
@@ -0,0 +1,93 @@
|
||||
export type ShortcutKey =
|
||||
| 'mod'
|
||||
| 'shift'
|
||||
| 'alt'
|
||||
| 'enter'
|
||||
| 'esc'
|
||||
| 'up'
|
||||
| 'down'
|
||||
| 'left'
|
||||
| 'right'
|
||||
| string
|
||||
|
||||
const MAC_PLATFORMS = /(Mac|iPhone|iPod|iPad)/i
|
||||
|
||||
export function isMacLikePlatform(): boolean {
|
||||
if (typeof navigator === 'undefined') {
|
||||
return false
|
||||
}
|
||||
|
||||
return MAC_PLATFORMS.test(navigator.platform || navigator.userAgent)
|
||||
}
|
||||
|
||||
export function getShortcutKeyLabel(key: ShortcutKey): string {
|
||||
const isMacLike = isMacLikePlatform()
|
||||
const normalizedKey = key.toLowerCase()
|
||||
|
||||
switch (normalizedKey) {
|
||||
case 'mod':
|
||||
return isMacLike ? '⌘' : 'Ctrl'
|
||||
case 'shift':
|
||||
return isMacLike ? '⇧' : 'Shift'
|
||||
case 'alt':
|
||||
return isMacLike ? '⌥' : 'Alt'
|
||||
case 'enter':
|
||||
return isMacLike ? '↵' : 'Enter'
|
||||
case 'esc':
|
||||
case 'escape':
|
||||
return 'Esc'
|
||||
case 'up':
|
||||
return '↑'
|
||||
case 'down':
|
||||
return '↓'
|
||||
case 'left':
|
||||
return '←'
|
||||
case 'right':
|
||||
return '→'
|
||||
default:
|
||||
return key.length === 1 ? key.toUpperCase() : key
|
||||
}
|
||||
}
|
||||
|
||||
export function getPlatformModifierAriaLabel(): string {
|
||||
return isMacLikePlatform() ? 'Command' : 'Control'
|
||||
}
|
||||
|
||||
export function matchesShortcut(event: KeyboardEvent | React.KeyboardEvent, keys: ShortcutKey[]): boolean {
|
||||
const normalizedKeys = keys.map((key) => key.toLowerCase())
|
||||
const eventKey = event.key.toLowerCase()
|
||||
|
||||
const modifierChecks = {
|
||||
mod: isMacLikePlatform() ? event.metaKey : event.ctrlKey,
|
||||
shift: event.shiftKey,
|
||||
alt: event.altKey,
|
||||
}
|
||||
|
||||
for (const key of normalizedKeys) {
|
||||
if (key in modifierChecks) {
|
||||
if (!modifierChecks[key as keyof typeof modifierChecks]) {
|
||||
return false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if (eventKey !== key) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
export function isEditableTarget(target: EventTarget | null): boolean {
|
||||
if (!(target instanceof HTMLElement)) {
|
||||
return false
|
||||
}
|
||||
|
||||
return (
|
||||
target.tagName === 'INPUT' ||
|
||||
target.tagName === 'TEXTAREA' ||
|
||||
target.isContentEditable ||
|
||||
target.getAttribute('role') === 'textbox'
|
||||
)
|
||||
}
|
||||
@@ -23,9 +23,6 @@ export const STORAGE_KEYS = {
|
||||
WS_MAX_RECONNECT_ATTEMPTS: 'maibot-ws-max-reconnect-attempts',
|
||||
|
||||
// 用户数据
|
||||
// 注意:ACCESS_TOKEN 已弃用,现在使用 HttpOnly Cookie 存储认证信息
|
||||
// 保留此常量仅用于向后兼容和清理旧数据
|
||||
ACCESS_TOKEN: 'access-token',
|
||||
COMPLETED_TOURS: 'maibot-completed-tours',
|
||||
CHAT_USER_ID: 'maibot_webui_user_id',
|
||||
CHAT_USER_NAME: 'maibot_webui_user_name',
|
||||
@@ -211,10 +208,8 @@ export function clearLocalCache(): { clearedKeys: string[]; preservedKeys: strin
|
||||
const keysToRemove: string[] = []
|
||||
for (let i = 0; i < localStorage.length; i++) {
|
||||
const key = localStorage.key(i)
|
||||
if (key) {
|
||||
if (key.startsWith('maibot') || key.startsWith('accent-color') || key === 'access-token') {
|
||||
keysToRemove.push(key)
|
||||
}
|
||||
if (key && (key.startsWith('maibot') || key.startsWith('accent-color'))) {
|
||||
keysToRemove.push(key)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ import { PluginMirrorsPage } from './routes/plugin-mirrors'
|
||||
import { PluginDetailPage } from './routes/plugin-detail'
|
||||
import { ChatPage } from './routes/chat/index'
|
||||
import { WebUIFeedbackSurveyPage, MaiBotFeedbackSurveyPage } from './routes/survey'
|
||||
import { AnnualReportPage } from './routes/annual-report'
|
||||
import PackMarketPage from './routes/config/pack-market'
|
||||
import PackDetailPage from './routes/config/pack-detail'
|
||||
import { Layout } from './components/layout'
|
||||
@@ -241,13 +240,6 @@ const maibotFeedbackSurveyRoute = createRoute({
|
||||
component: MaiBotFeedbackSurveyPage,
|
||||
})
|
||||
|
||||
// 年度报告路由
|
||||
const annualReportRoute = createRoute({
|
||||
getParentRoute: () => protectedRoute,
|
||||
path: '/annual-report',
|
||||
component: AnnualReportPage,
|
||||
})
|
||||
|
||||
// 404 路由
|
||||
const notFoundRoute = createRoute({
|
||||
getParentRoute: () => rootRoute,
|
||||
@@ -284,11 +276,23 @@ const routeTree = rootRoute.addChildren([
|
||||
packDetailRoute,
|
||||
webuiFeedbackSurveyRoute,
|
||||
maibotFeedbackSurveyRoute,
|
||||
annualReportRoute,
|
||||
]),
|
||||
notFoundRoute,
|
||||
])
|
||||
|
||||
type RouteNode = {
|
||||
fullPath?: string
|
||||
children?: RouteNode[]
|
||||
}
|
||||
|
||||
function collectRoutePaths(node: RouteNode): string[] {
|
||||
const currentPath = node.fullPath ? [node.fullPath] : []
|
||||
const childPaths = node.children?.flatMap(collectRoutePaths) ?? []
|
||||
return [...currentPath, ...childPaths]
|
||||
}
|
||||
|
||||
export const registeredRoutePaths = new Set(collectRoutePaths(routeTree as RouteNode))
|
||||
|
||||
// 创建路由器
|
||||
export const router = createRouter({
|
||||
routeTree,
|
||||
|
||||
@@ -1,883 +0,0 @@
|
||||
import { useState, useRef, useEffect, useCallback } from 'react'
|
||||
import { getAnnualReport, type AnnualReportData } from '@/lib/annual-report-api'
|
||||
import { Card, CardContent, CardHeader, CardTitle, CardDescription } from '@/components/ui/card'
|
||||
import { Skeleton } from '@/components/ui/skeleton'
|
||||
import { Badge } from '@/components/ui/badge'
|
||||
import { ScrollArea } from '@/components/ui/scroll-area'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { useToast } from '@/hooks/use-toast'
|
||||
import { toPng } from 'html-to-image'
|
||||
import {
|
||||
BarChart,
|
||||
Bar,
|
||||
XAxis,
|
||||
YAxis,
|
||||
CartesianGrid,
|
||||
Tooltip,
|
||||
ResponsiveContainer,
|
||||
} from 'recharts'
|
||||
import {
|
||||
Clock,
|
||||
Users,
|
||||
Brain,
|
||||
Smile,
|
||||
Trophy,
|
||||
Calendar,
|
||||
MessageSquare,
|
||||
Zap,
|
||||
Moon,
|
||||
Sun,
|
||||
AtSign,
|
||||
Heart,
|
||||
Image as ImageIcon,
|
||||
Bot,
|
||||
Download,
|
||||
Loader2,
|
||||
} from 'lucide-react'
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
// 颜色常量
|
||||
const COLORS = ['#0088FE', '#00C49F', '#FFBB28', '#FF8042', '#8884d8', '#82ca9d']
|
||||
|
||||
// 动态比喻生成函数
|
||||
function getOnlineHoursMetaphor(hours: number): string {
|
||||
if (hours >= 8760) return "相当于全年无休,7x24小时在线!"
|
||||
if (hours >= 5000) return "相当于一位全职员工的年工作时长"
|
||||
if (hours >= 2000) return "相当于看完了 1000 部电影"
|
||||
if (hours >= 1000) return "相当于环球飞行 80 次"
|
||||
if (hours >= 500) return "相当于读完了 100 本书"
|
||||
if (hours >= 100) return "相当于马拉松跑了 25 次"
|
||||
return "虽然不多,但每一刻都很珍贵"
|
||||
}
|
||||
|
||||
function getMidnightMetaphor(count: number): string {
|
||||
if (count >= 1000) return "夜深人静时的知心好友"
|
||||
if (count >= 500) return "午夜场的常客"
|
||||
if (count >= 100) return "偶尔熬夜的小伙伴"
|
||||
if (count >= 50) return "深夜有时也会陪你聊聊"
|
||||
return "早睡早起,健康作息"
|
||||
}
|
||||
|
||||
function getTokenMetaphor(tokens: number): string {
|
||||
const millions = tokens / 1000000
|
||||
if (millions >= 100) return "思考量堪比一座图书馆"
|
||||
if (millions >= 50) return "相当于写了一部百科全书"
|
||||
if (millions >= 10) return "脑细胞估计消耗了不少"
|
||||
if (millions >= 1) return "也算是费了一番脑筋"
|
||||
return "轻轻松松,游刃有余"
|
||||
}
|
||||
|
||||
function getCostMetaphor(cost: number): string {
|
||||
if (cost >= 1000) return "这钱够吃一年的泡面了"
|
||||
if (cost >= 500) return "相当于买了一台游戏机"
|
||||
if (cost >= 100) return "够请大家喝几杯奶茶"
|
||||
if (cost >= 50) return "一顿火锅的钱"
|
||||
if (cost >= 10) return "几杯咖啡的价格"
|
||||
return "省钱小能手"
|
||||
}
|
||||
|
||||
function getSilenceMetaphor(rate: number): string {
|
||||
if (rate >= 80) return "沉默是金,惜字如金"
|
||||
if (rate >= 60) return "话不多但句句到位"
|
||||
if (rate >= 40) return "该说的时候才开口"
|
||||
if (rate >= 20) return "能聊的都聊了"
|
||||
return "话痨本痨,有问必答"
|
||||
}
|
||||
|
||||
function getImageMetaphor(count: number): string {
|
||||
if (count >= 10000) return "眼睛都快看花了"
|
||||
if (count >= 5000) return "堪比专业摄影师的阅片量"
|
||||
if (count >= 1000) return "看图小达人"
|
||||
if (count >= 500) return "图片鉴赏家"
|
||||
if (count >= 100) return "偶尔欣赏一下美图"
|
||||
return "图片?有空再看"
|
||||
}
|
||||
|
||||
function getRejectedMetaphor(count: number): string {
|
||||
if (count >= 500) return "在不断的纠正中成长"
|
||||
if (count >= 200) return "学习永无止境"
|
||||
if (count >= 100) return "虚心接受,积极改正"
|
||||
if (count >= 50) return "偶尔也会犯错"
|
||||
if (count >= 10) return "表现还算不错"
|
||||
return "完美表达,无需纠正"
|
||||
}
|
||||
|
||||
function getExpensiveThinkingMetaphor(cost: number): string {
|
||||
if (cost >= 1) return "这次思考的价值堪比一顿大餐!"
|
||||
if (cost >= 0.5) return "为了这个问题,我可是认真思考了!"
|
||||
if (cost >= 0.1) return "下了点功夫,值得的!"
|
||||
if (cost >= 0.01) return "花了点小钱,但很值得"
|
||||
return "小小思考,不足挂齿"
|
||||
}
|
||||
|
||||
function getFavoriteReplyMetaphor(count: number, botName: string): string {
|
||||
if (count >= 100) return "这句话简直是万能钥匙!"
|
||||
if (count >= 50) return "百试不爽的经典回复"
|
||||
if (count >= 20) return `${botName}的口头禅`
|
||||
if (count >= 10) return "常用语录之一"
|
||||
return "偶尔用用的小确幸"
|
||||
}
|
||||
|
||||
function getNightOwlMetaphor(isNightOwl: boolean, midnightCount: number): string {
|
||||
if (isNightOwl) {
|
||||
if (midnightCount >= 1000) return "深夜的守护者,黑暗中的光芒"
|
||||
if (midnightCount >= 500) return "月亮是我的好朋友"
|
||||
if (midnightCount >= 100) return "越夜越精神,夜晚才是主场"
|
||||
return "偶尔熬夜,享受宁静时光"
|
||||
} else {
|
||||
if (midnightCount <= 10) return "作息规律,健康生活的典范"
|
||||
if (midnightCount <= 50) return "早睡早起,偶尔也会熬个夜"
|
||||
return "虽然是早起鸟,但也会守候深夜"
|
||||
}
|
||||
}
|
||||
|
||||
function getBusiestDayMetaphor(count: number): string {
|
||||
if (count >= 1000) return "忙到飞起,键盘都要冒烟了"
|
||||
if (count >= 500) return "这天简直是话痨附体"
|
||||
if (count >= 200) return "社交达人上线"
|
||||
if (count >= 100) return "比平时活跃不少"
|
||||
if (count >= 50) return "小忙一下"
|
||||
return "还算轻松的一天"
|
||||
}
|
||||
|
||||
export function AnnualReportPage() {
|
||||
const [year] = useState(2025)
|
||||
const [data, setData] = useState<AnnualReportData | null>(null)
|
||||
const [isLoading, setIsLoading] = useState(true)
|
||||
const [isExporting, setIsExporting] = useState(false)
|
||||
const [error, setError] = useState<Error | null>(null)
|
||||
const reportRef = useRef<HTMLDivElement>(null)
|
||||
const { toast } = useToast()
|
||||
|
||||
const loadReport = useCallback(async () => {
|
||||
try {
|
||||
setIsLoading(true)
|
||||
setError(null)
|
||||
const result = await getAnnualReport(year)
|
||||
setData(result)
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err : new Error('获取年度报告失败'))
|
||||
} finally {
|
||||
setIsLoading(false)
|
||||
}
|
||||
}, [year])
|
||||
|
||||
// 导出为图片
|
||||
const handleExport = useCallback(async () => {
|
||||
if (!reportRef.current || !data) return
|
||||
|
||||
setIsExporting(true)
|
||||
toast({
|
||||
title: '正在生成图片',
|
||||
description: '请稍候...',
|
||||
})
|
||||
|
||||
try {
|
||||
const element = reportRef.current
|
||||
|
||||
// 获取当前主题的背景色
|
||||
const computedStyle = getComputedStyle(document.documentElement)
|
||||
const backgroundColor = computedStyle.getPropertyValue('--background').trim()
|
||||
? `hsl(${computedStyle.getPropertyValue('--background').trim()})`
|
||||
: (document.documentElement.classList.contains('dark') ? '#0a0a0a' : '#ffffff')
|
||||
|
||||
// 保存原始样式
|
||||
const originalWidth = element.style.width
|
||||
const originalMaxWidth = element.style.maxWidth
|
||||
|
||||
// 临时设置固定宽度以去除左右空白
|
||||
element.style.width = '1024px'
|
||||
element.style.maxWidth = '1024px'
|
||||
|
||||
const dataUrl = await toPng(element, {
|
||||
quality: 1,
|
||||
pixelRatio: 2,
|
||||
backgroundColor,
|
||||
cacheBust: true,
|
||||
filter: (node) => {
|
||||
// 过滤掉导出按钮
|
||||
if (node instanceof HTMLElement && node.hasAttribute('data-export-btn')) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
},
|
||||
})
|
||||
|
||||
// 恢复原始样式
|
||||
element.style.width = originalWidth
|
||||
element.style.maxWidth = originalMaxWidth
|
||||
|
||||
// 创建下载链接
|
||||
const link = document.createElement('a')
|
||||
link.download = `${data.bot_name}_${data.year}_年度总结.png`
|
||||
link.href = dataUrl
|
||||
link.click()
|
||||
|
||||
toast({
|
||||
title: '导出成功',
|
||||
description: '年度报告已保存为图片',
|
||||
})
|
||||
} catch (err) {
|
||||
console.error('导出图片失败:', err)
|
||||
toast({
|
||||
title: '导出失败',
|
||||
description: '请重试',
|
||||
variant: 'destructive',
|
||||
})
|
||||
} finally {
|
||||
setIsExporting(false)
|
||||
}
|
||||
}, [data, toast])
|
||||
|
||||
useEffect(() => {
|
||||
loadReport()
|
||||
}, [loadReport])
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingSkeleton />
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="flex h-screen items-center justify-center text-red-500">
|
||||
获取年度报告失败: {error.message}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (!data) return null
|
||||
|
||||
return (
|
||||
<ScrollArea className="h-[calc(100vh-4rem)]">
|
||||
<div className="min-h-screen bg-gradient-to-b from-background to-muted/50 p-4 md:p-8 print:p-0" ref={reportRef}>
|
||||
<div className="mx-auto max-w-5xl space-y-8 print:space-y-4">
|
||||
{/* 头部 Hero */}
|
||||
<header className="relative overflow-hidden rounded-3xl bg-primary p-8 text-primary-foreground shadow-2xl print:rounded-none print:shadow-none">
|
||||
{/* 导出按钮 */}
|
||||
<div className="absolute right-4 top-4 z-20 print:hidden" data-export-btn>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
onClick={handleExport}
|
||||
disabled={isExporting}
|
||||
className="gap-2 bg-white/20 hover:bg-white/30 text-white border-white/30"
|
||||
>
|
||||
{isExporting ? (
|
||||
<>
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
导出中...
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Download className="h-4 w-4" />
|
||||
保存图片
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
<div className="relative z-10 flex flex-col items-center text-center">
|
||||
<Bot className="mb-4 h-16 w-16 animate-bounce" />
|
||||
<h1 className="text-4xl font-bold tracking-tighter sm:text-6xl">
|
||||
{data.bot_name} {data.year} 年度总结
|
||||
</h1>
|
||||
<p className="mt-4 max-w-2xl text-lg opacity-90">
|
||||
连接与成长 · Connection & Growth
|
||||
</p>
|
||||
<div className="mt-6 flex items-center gap-2 text-sm opacity-75">
|
||||
<Calendar className="h-4 w-4" />
|
||||
<span>生成时间: {data.generated_at}</span>
|
||||
</div>
|
||||
</div>
|
||||
{/* 背景装饰 */}
|
||||
<div className="absolute -right-20 -top-20 h-64 w-64 rounded-full bg-white/10 blur-3xl" />
|
||||
<div className="absolute -bottom-20 -left-20 h-64 w-64 rounded-full bg-white/10 blur-3xl" />
|
||||
</header>
|
||||
|
||||
{/* 维度一:时光足迹 */}
|
||||
<section className="space-y-4 break-inside-avoid">
|
||||
<div className="flex items-center gap-2 text-2xl font-bold text-primary">
|
||||
<Clock className="h-8 w-8" />
|
||||
<h2>时光足迹</h2>
|
||||
</div>
|
||||
<div className="grid gap-4 md:grid-cols-2 lg:grid-cols-4">
|
||||
<StatCard
|
||||
title="年度在线时长"
|
||||
value={`${data.time_footprint.total_online_hours} 小时`}
|
||||
description={getOnlineHoursMetaphor(data.time_footprint.total_online_hours)}
|
||||
icon={<Clock className="h-4 w-4" />}
|
||||
/>
|
||||
<StatCard
|
||||
title="最忙碌的一天"
|
||||
value={data.time_footprint.busiest_day || 'N/A'}
|
||||
description={getBusiestDayMetaphor(data.time_footprint.busiest_day_count)}
|
||||
icon={<Calendar className="h-4 w-4" />}
|
||||
/>
|
||||
<StatCard
|
||||
title="深夜互动 (0-4点)"
|
||||
value={`${data.time_footprint.midnight_chat_count} 次`}
|
||||
description={getMidnightMetaphor(data.time_footprint.midnight_chat_count)}
|
||||
icon={<Moon className="h-4 w-4" />}
|
||||
/>
|
||||
<StatCard
|
||||
title="作息属性"
|
||||
value={data.time_footprint.is_night_owl ? '夜猫子' : '早起鸟'}
|
||||
description={getNightOwlMetaphor(data.time_footprint.is_night_owl, data.time_footprint.midnight_chat_count)}
|
||||
icon={data.time_footprint.is_night_owl ? <Moon className="h-4 w-4" /> : <Sun className="h-4 w-4" />}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<Card className="overflow-hidden">
|
||||
<CardHeader>
|
||||
<CardTitle>24小时活跃时钟</CardTitle>
|
||||
<CardDescription>{data.bot_name}在一天中各个时段的活跃程度</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="h-[300px]">
|
||||
<ResponsiveContainer width="100%" height="100%">
|
||||
<BarChart data={data.time_footprint.hourly_distribution.map((count: number, hour: number) => ({ hour: `${hour}点`, count }))}>
|
||||
<CartesianGrid strokeDasharray="3 3" vertical={false} />
|
||||
<XAxis dataKey="hour" />
|
||||
<YAxis />
|
||||
<Tooltip
|
||||
contentStyle={{ borderRadius: '8px', border: 'none', boxShadow: '0 4px 12px rgba(0,0,0,0.1)' }}
|
||||
cursor={{ fill: 'transparent' }}
|
||||
/>
|
||||
<Bar dataKey="count" fill="hsl(var(--color-primary))" radius={[4, 4, 0, 0]} />
|
||||
</BarChart>
|
||||
</ResponsiveContainer>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
{data.time_footprint.first_message_time && (
|
||||
<Card className="bg-muted/30 border-dashed">
|
||||
<CardContent className="flex flex-col items-center justify-center p-6 text-center">
|
||||
<p className="text-muted-foreground mb-2">2025年的故事开始于</p>
|
||||
<div className="text-xl font-bold text-primary mb-1">{data.time_footprint.first_message_time}</div>
|
||||
<p className="text-lg">
|
||||
<span className="font-semibold text-foreground">{data.time_footprint.first_message_user}</span> 说:
|
||||
<span className="italic text-muted-foreground">"{data.time_footprint.first_message_content}"</span>
|
||||
</p>
|
||||
</CardContent>
|
||||
</Card>
|
||||
)}
|
||||
</section>
|
||||
|
||||
{/* 维度二:社交网络 */}
|
||||
<section className="space-y-4 break-inside-avoid">
|
||||
<div className="flex items-center gap-2 text-2xl font-bold text-primary">
|
||||
<Users className="h-8 w-8" />
|
||||
<h2>社交网络</h2>
|
||||
</div>
|
||||
<div className="grid gap-4 md:grid-cols-3">
|
||||
<StatCard
|
||||
title="社交圈子"
|
||||
value={`${data.social_network.total_groups} 个群组`}
|
||||
description={`${data.bot_name}加入的群组总数`}
|
||||
icon={<Users className="h-4 w-4" />}
|
||||
/>
|
||||
<StatCard
|
||||
title="被呼叫次数"
|
||||
value={`${data.social_network.at_count + data.social_network.mentioned_count} 次`}
|
||||
description="我的名字被大家频繁提起"
|
||||
icon={<AtSign className="h-4 w-4" />}
|
||||
/>
|
||||
<StatCard
|
||||
title="最长情陪伴"
|
||||
value={data.social_network.longest_companion_user || 'N/A'}
|
||||
description={`始终都在,已陪伴 ${data.social_network.longest_companion_days} 天`}
|
||||
icon={<Heart className="h-4 w-4 text-red-500" />}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-4 md:grid-cols-2">
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>话痨群组 TOP5</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="space-y-3">
|
||||
{data.social_network.top_groups.length > 0 ? (
|
||||
data.social_network.top_groups.map((group: { group_id: string; group_name: string; message_count: number; is_webui?: boolean }, index: number) => (
|
||||
<div key={group.group_id} className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-3">
|
||||
<Badge variant={index === 0 ? "default" : "secondary"} className="h-6 w-6 rounded-full p-0 flex items-center justify-center shrink-0">
|
||||
{index + 1}
|
||||
</Badge>
|
||||
<span className="font-medium truncate max-w-[120px]">{group.group_name}</span>
|
||||
{group.is_webui && (
|
||||
<Badge variant="outline" className="text-xs px-1.5 py-0 h-5 bg-blue-50 text-blue-600 border-blue-200">
|
||||
WebUI
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
<span className="text-muted-foreground text-sm shrink-0">{group.message_count} 条消息</span>
|
||||
</div>
|
||||
))
|
||||
) : (
|
||||
<div className="text-center text-muted-foreground py-4">暂无数据</div>
|
||||
)}
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>年度最佳损友 TOP5</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="space-y-3">
|
||||
{data.social_network.top_users.length > 0 ? (
|
||||
data.social_network.top_users.map((user: { user_id: string; user_nickname: string; message_count: number; is_webui?: boolean }, index: number) => (
|
||||
<div key={user.user_id} className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-3">
|
||||
<Badge variant={index === 0 ? "default" : "secondary"} className="h-6 w-6 rounded-full p-0 flex items-center justify-center shrink-0">
|
||||
{index + 1}
|
||||
</Badge>
|
||||
<span className="font-medium truncate max-w-[120px]">{user.user_nickname}</span>
|
||||
{user.is_webui && (
|
||||
<Badge variant="outline" className="text-xs px-1.5 py-0 h-5 bg-blue-50 text-blue-600 border-blue-200">
|
||||
WebUI
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
<span className="text-muted-foreground text-sm shrink-0">{user.message_count} 次互动</span>
|
||||
</div>
|
||||
))
|
||||
) : (
|
||||
<div className="text-center text-muted-foreground py-4">暂无数据</div>
|
||||
)}
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
{/* 维度三:最强大脑 */}
|
||||
<section className="space-y-4 break-inside-avoid">
|
||||
<div className="flex items-center gap-2 text-2xl font-bold text-primary">
|
||||
<Brain className="h-8 w-8" />
|
||||
<h2>最强大脑</h2>
|
||||
</div>
|
||||
<div className="grid gap-4 md:grid-cols-2 lg:grid-cols-4">
|
||||
<StatCard
|
||||
title="年度 Token 消耗"
|
||||
value={(data.brain_power.total_tokens / 1000000).toFixed(2) + ' M'}
|
||||
description={getTokenMetaphor(data.brain_power.total_tokens)}
|
||||
icon={<Zap className="h-4 w-4" />}
|
||||
/>
|
||||
<StatCard
|
||||
title="年度总花费"
|
||||
value={`$${data.brain_power.total_cost.toFixed(2)}`}
|
||||
description={getCostMetaphor(data.brain_power.total_cost)}
|
||||
icon={<span className="font-bold">$</span>}
|
||||
/>
|
||||
<StatCard
|
||||
title="高冷指数"
|
||||
value={`${data.brain_power.silence_rate}%`}
|
||||
description={getSilenceMetaphor(data.brain_power.silence_rate)}
|
||||
icon={<Moon className="h-4 w-4" />}
|
||||
/>
|
||||
<StatCard
|
||||
title="最高兴趣值"
|
||||
value={data.brain_power.max_interest_value ?? 'N/A'}
|
||||
description={data.brain_power.max_interest_time ? `出现在 ${data.brain_power.max_interest_time}` : '暂无数据'}
|
||||
icon={<Heart className="h-4 w-4" />}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-4 md:grid-cols-2">
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>模型偏好分布</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="space-y-3">
|
||||
{data.brain_power.model_distribution.slice(0, 5).map((item: { model: string; count: number }, index: number) => {
|
||||
const maxCount = data.brain_power.model_distribution[0]?.count || 1
|
||||
const percentage = Math.round((item.count / maxCount) * 100)
|
||||
return (
|
||||
<div key={item.model} className="space-y-1">
|
||||
<div className="flex justify-between text-sm">
|
||||
<span className="font-medium truncate max-w-[200px]">{item.model}</span>
|
||||
<span className="text-muted-foreground">{item.count.toLocaleString()} 次</span>
|
||||
</div>
|
||||
<div className="h-2 w-full overflow-hidden rounded-full bg-secondary">
|
||||
<div
|
||||
className="h-full transition-all duration-500"
|
||||
style={{
|
||||
width: `${percentage}%`,
|
||||
backgroundColor: COLORS[index % COLORS.length]
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
{/* 最喜欢的回复模型 TOP5 */}
|
||||
{data.brain_power.top_reply_models && data.brain_power.top_reply_models.length > 0 && (
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>最喜欢的回复模型 TOP5</CardTitle>
|
||||
<CardDescription>{data.bot_name}用来回复消息的模型偏好</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="space-y-3">
|
||||
{data.brain_power.top_reply_models.map((item: { model: string; count: number }, index: number) => {
|
||||
const maxCount = data.brain_power.top_reply_models[0]?.count || 1
|
||||
const percentage = Math.round((item.count / maxCount) * 100)
|
||||
return (
|
||||
<div key={item.model} className="space-y-1">
|
||||
<div className="flex justify-between text-sm">
|
||||
<span className="font-medium truncate max-w-[200px]">{item.model}</span>
|
||||
<span className="text-muted-foreground">{item.count.toLocaleString()} 次</span>
|
||||
</div>
|
||||
<div className="h-2 w-full overflow-hidden rounded-full bg-secondary">
|
||||
<div
|
||||
className="h-full transition-all duration-500"
|
||||
style={{
|
||||
width: `${percentage}%`,
|
||||
backgroundColor: COLORS[index % COLORS.length]
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
)}
|
||||
|
||||
{/* 烧钱大户 - 只有有有效用户数据时才显示 */}
|
||||
{data.brain_power.top_token_consumers && data.brain_power.top_token_consumers.length > 0 && (
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>烧钱大户 TOP3</CardTitle>
|
||||
<CardDescription>谁消耗了最多的 API 额度</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="space-y-6">
|
||||
{data.brain_power.top_token_consumers.map((consumer: { user_id: string; cost: number; tokens: number }) => (
|
||||
<div key={consumer.user_id} className="space-y-2">
|
||||
<div className="flex justify-between text-sm font-medium">
|
||||
<span>用户 {consumer.user_id}</span>
|
||||
<span>${consumer.cost.toFixed(2)}</span>
|
||||
</div>
|
||||
<div className="h-2 w-full overflow-hidden rounded-full bg-secondary">
|
||||
<div
|
||||
className="h-full bg-primary transition-all duration-500"
|
||||
style={{ width: `${(consumer.cost / (data.brain_power.top_token_consumers[0]?.cost || 1)) * 100}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* 最昂贵的思考 & 思考深度 */}
|
||||
<div className="grid gap-4 md:grid-cols-2">
|
||||
<Card className="bg-gradient-to-br from-amber-50 to-orange-50 dark:from-amber-950/20 dark:to-orange-950/20">
|
||||
<CardHeader>
|
||||
<CardTitle className="flex items-center gap-2">
|
||||
<span className="text-2xl">💰</span>
|
||||
最昂贵的一次思考
|
||||
</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent className="text-center">
|
||||
<div className="text-4xl font-bold text-amber-600 dark:text-amber-400">
|
||||
${data.brain_power.most_expensive_cost.toFixed(4)}
|
||||
</div>
|
||||
{data.brain_power.most_expensive_time && (
|
||||
<p className="mt-2 text-sm text-muted-foreground">
|
||||
发生在 {data.brain_power.most_expensive_time}
|
||||
</p>
|
||||
)}
|
||||
<p className="mt-4 text-sm text-muted-foreground">
|
||||
{getExpensiveThinkingMetaphor(data.brain_power.most_expensive_cost)}
|
||||
</p>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
<Card className="bg-gradient-to-br from-indigo-50 to-blue-50 dark:from-indigo-950/20 dark:to-blue-950/20">
|
||||
<CardHeader>
|
||||
<CardTitle className="flex items-center gap-2">
|
||||
<span className="text-2xl">🧠</span>
|
||||
思考深度
|
||||
</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="grid grid-cols-2 gap-4 text-center">
|
||||
<div>
|
||||
<div className="text-2xl font-bold text-indigo-600 dark:text-indigo-400">
|
||||
{data.brain_power.avg_reasoning_length?.toFixed(0) || 0}
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">平均思考字数</div>
|
||||
</div>
|
||||
<div>
|
||||
<div className="text-2xl font-bold text-blue-600 dark:text-blue-400">
|
||||
{data.brain_power.max_reasoning_length?.toLocaleString() || 0}
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">最长思考字数</div>
|
||||
</div>
|
||||
</div>
|
||||
{data.brain_power.max_reasoning_time && (
|
||||
<p className="mt-4 text-center text-xs text-muted-foreground">
|
||||
最深沉的思考发生在 {data.brain_power.max_reasoning_time}
|
||||
</p>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
{/* 维度四:个性与表达 */}
|
||||
<section className="space-y-4 break-inside-avoid">
|
||||
<div className="flex items-center gap-2 text-2xl font-bold text-primary">
|
||||
<Smile className="h-8 w-8" />
|
||||
<h2>个性与表达</h2>
|
||||
</div>
|
||||
|
||||
{/* 深夜回复 & 最喜欢的回复 */}
|
||||
{(data.expression_vibe.late_night_reply || data.expression_vibe.favorite_reply) && (
|
||||
<div className="grid gap-4 md:grid-cols-2">
|
||||
{data.expression_vibe.late_night_reply && (
|
||||
<Card className="bg-gradient-to-br from-indigo-50 to-violet-50 dark:from-indigo-950/20 dark:to-violet-950/20">
|
||||
<CardHeader>
|
||||
<CardTitle className="flex items-center gap-2">
|
||||
<span className="text-2xl">🌙</span>
|
||||
深夜还在回复
|
||||
</CardTitle>
|
||||
<CardDescription>凌晨 {data.expression_vibe.late_night_reply.time},{data.bot_name}还在回复...</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="text-center">
|
||||
<p className="text-lg italic text-muted-foreground">
|
||||
"{data.expression_vibe.late_night_reply.content}"
|
||||
</p>
|
||||
<p className="mt-4 text-sm text-muted-foreground">
|
||||
是有什么心事吗?
|
||||
</p>
|
||||
</CardContent>
|
||||
</Card>
|
||||
)}
|
||||
|
||||
{data.expression_vibe.favorite_reply && (
|
||||
<Card className="bg-gradient-to-br from-rose-50 to-pink-50 dark:from-rose-950/20 dark:to-pink-950/20">
|
||||
<CardHeader>
|
||||
<CardTitle className="flex items-center gap-2">
|
||||
<span className="text-2xl">💬</span>
|
||||
最喜欢的回复
|
||||
</CardTitle>
|
||||
<CardDescription>使用了 {data.expression_vibe.favorite_reply.count} 次</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="text-center">
|
||||
<p className="text-lg font-medium text-primary">
|
||||
"{data.expression_vibe.favorite_reply.content}"
|
||||
</p>
|
||||
<p className="mt-4 text-sm text-muted-foreground">
|
||||
{getFavoriteReplyMetaphor(data.expression_vibe.favorite_reply.count, data.bot_name)}
|
||||
</p>
|
||||
</CardContent>
|
||||
</Card>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="grid gap-4 md:grid-cols-2">
|
||||
{/* 使用最多的表情包 TOP3 */}
|
||||
<Card className="bg-gradient-to-br from-pink-50 to-purple-50 dark:from-pink-950/20 dark:to-purple-950/20">
|
||||
<CardHeader>
|
||||
<CardTitle>使用最多的表情包 TOP3</CardTitle>
|
||||
<CardDescription>年度最爱的表情包们</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
{data.expression_vibe.top_emojis && data.expression_vibe.top_emojis.length > 0 ? (
|
||||
<div className="flex justify-center gap-4">
|
||||
{data.expression_vibe.top_emojis.slice(0, 3).map((emoji: { id: number; usage_count: number }, index: number) => (
|
||||
<div key={emoji.id} className="flex flex-col items-center">
|
||||
<div className="relative">
|
||||
<img
|
||||
src={`/api/webui/emoji/${emoji.id}/thumbnail?original=true`}
|
||||
alt={`TOP ${index + 1}`}
|
||||
className="h-24 w-24 rounded-lg object-cover shadow-md transition-transform hover:scale-105"
|
||||
/>
|
||||
<Badge
|
||||
className={cn(
|
||||
"absolute -top-2 -right-2",
|
||||
index === 0 ? "bg-yellow-500" : index === 1 ? "bg-gray-400" : "bg-amber-700"
|
||||
)}
|
||||
>
|
||||
{index + 1}
|
||||
</Badge>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-muted-foreground">{emoji.usage_count} 次</p>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex h-32 items-center justify-center text-muted-foreground">暂无数据</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
<div className="space-y-4">
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>印象最深刻的表达风格</CardTitle>
|
||||
<CardDescription>{data.bot_name}最常使用的表达方式</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{data.expression_vibe.top_expressions.map((exp: { style: string; count: number }, index: number) => (
|
||||
<Badge
|
||||
key={exp.style}
|
||||
variant="outline"
|
||||
className={cn(
|
||||
"px-3 py-1 text-sm",
|
||||
index === 0 && "border-primary bg-primary/10 text-primary text-base px-4 py-2"
|
||||
)}
|
||||
>
|
||||
{exp.style} ({exp.count})
|
||||
</Badge>
|
||||
))}
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<StatCard
|
||||
title="图片鉴赏"
|
||||
value={`${data.expression_vibe.image_processed_count} 张`}
|
||||
description={getImageMetaphor(data.expression_vibe.image_processed_count)}
|
||||
icon={<ImageIcon className="h-4 w-4" />}
|
||||
/>
|
||||
<StatCard
|
||||
title="成长的足迹"
|
||||
value={`${data.expression_vibe.rejected_expression_count} 次`}
|
||||
description={getRejectedMetaphor(data.expression_vibe.rejected_expression_count)}
|
||||
icon={<Zap className="h-4 w-4" />}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* 行动派 */}
|
||||
{data.expression_vibe.action_types.length > 0 && (
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle className="flex items-center gap-2">
|
||||
<span className="text-2xl">⚡</span>
|
||||
行动派
|
||||
</CardTitle>
|
||||
<CardDescription>除了聊天,我还帮大家做了这些事</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="flex flex-wrap gap-3">
|
||||
{data.expression_vibe.action_types.map((action: { action: string; count: number }) => (
|
||||
<div
|
||||
key={action.action}
|
||||
className="flex items-center gap-2 rounded-full bg-primary/10 px-4 py-2"
|
||||
>
|
||||
<span className="font-medium text-primary">{action.action}</span>
|
||||
<Badge variant="secondary">{action.count} 次</Badge>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
)}
|
||||
</section>
|
||||
|
||||
{/* 维度五:趣味成就 */}
|
||||
<section className="space-y-4 break-inside-avoid">
|
||||
<div className="flex items-center gap-2 text-2xl font-bold text-primary">
|
||||
<Trophy className="h-8 w-8" />
|
||||
<h2>趣味成就</h2>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-4 md:grid-cols-3">
|
||||
<Card className="col-span-1 md:col-span-2">
|
||||
<CardHeader>
|
||||
<CardTitle>新学到的"黑话"</CardTitle>
|
||||
<CardDescription>今年我学会了 {data.achievements.new_jargon_count} 个新词</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="flex flex-wrap gap-3">
|
||||
{data.achievements.sample_jargons.map((jargon: { content: string; meaning: string; count: number }) => (
|
||||
<div key={jargon.content} className="group relative rounded-lg border bg-card p-3 shadow-sm transition-all hover:shadow-md">
|
||||
<div className="font-bold text-primary">{jargon.content}</div>
|
||||
<div className="text-xs text-muted-foreground mt-1 line-clamp-2 max-w-[200px]">
|
||||
{jargon.meaning || '暂无解释'}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
<Card className="flex flex-col justify-center items-center bg-primary text-primary-foreground">
|
||||
<CardContent className="flex flex-col items-center justify-center p-6 text-center">
|
||||
<MessageSquare className="h-12 w-12 mb-4 opacity-80" />
|
||||
<div className="text-4xl font-bold mb-2">{data.achievements.total_messages.toLocaleString()}</div>
|
||||
<div className="text-sm opacity-80">年度总消息数</div>
|
||||
<div className="mt-4 text-xs opacity-60">
|
||||
其中回复了 {data.achievements.total_replies.toLocaleString()} 次
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
{/* 底部 */}
|
||||
<footer className="mt-12 text-center text-muted-foreground">
|
||||
<p>MaiBot 2025 Annual Report</p>
|
||||
<p className="text-sm">Generated with ❤️ by MaiBot Team</p>
|
||||
</footer>
|
||||
</div>
|
||||
</div>
|
||||
</ScrollArea>
|
||||
)
|
||||
}
|
||||
|
||||
function StatCard({
|
||||
title,
|
||||
value,
|
||||
description,
|
||||
icon,
|
||||
}: {
|
||||
title: string
|
||||
value: string | number
|
||||
description: string
|
||||
icon: React.ReactNode
|
||||
}) {
|
||||
return (
|
||||
<Card>
|
||||
<CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
|
||||
<CardTitle className="text-sm font-medium">{title}</CardTitle>
|
||||
<div className="text-muted-foreground">{icon}</div>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="text-2xl font-bold">{value}</div>
|
||||
<p className="text-xs text-muted-foreground">{description}</p>
|
||||
</CardContent>
|
||||
</Card>
|
||||
)
|
||||
}
|
||||
|
||||
function LoadingSkeleton() {
|
||||
return (
|
||||
<div className="container mx-auto space-y-8 p-8">
|
||||
<Skeleton className="h-64 w-full rounded-3xl" />
|
||||
<div className="grid gap-4 md:grid-cols-4">
|
||||
{[...Array(4)].map((_, i) => (
|
||||
<Skeleton key={i} className="h-32 w-full" />
|
||||
))}
|
||||
</div>
|
||||
<Skeleton className="h-96 w-full" />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -2,6 +2,7 @@ import { Avatar, AvatarFallback } from '@/components/ui/avatar'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import {
|
||||
Dialog,
|
||||
DialogBody,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
@@ -54,7 +55,7 @@ export function VirtualIdentityDialog({
|
||||
}: VirtualIdentityDialogProps) {
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="sm:max-w-[500px] max-h-[85vh] overflow-hidden flex flex-col">
|
||||
<DialogContent className="sm:max-w-125 max-h-[85vh] overflow-hidden flex flex-col" confirmOnEnter>
|
||||
<DialogHeader>
|
||||
<DialogTitle className="flex items-center gap-2">
|
||||
<UserCircle2 className="h-5 w-5" />
|
||||
@@ -65,7 +66,7 @@ export function VirtualIdentityDialog({
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div className="space-y-4 flex-1 overflow-hidden flex flex-col">
|
||||
<DialogBody className="space-y-4 flex-1" viewportClassName="pr-0">
|
||||
{/* 平台选择 */}
|
||||
<div className="space-y-2">
|
||||
<Label className="flex items-center gap-2">
|
||||
@@ -113,7 +114,7 @@ export function VirtualIdentityDialog({
|
||||
className="pl-9"
|
||||
/>
|
||||
</div>
|
||||
<ScrollArea className="h-[250px] border rounded-md">
|
||||
<ScrollArea className="h-62.5 border rounded-md">
|
||||
<div className="p-2">
|
||||
{isLoadingPersons ? (
|
||||
<div className="flex items-center justify-center py-8">
|
||||
@@ -187,13 +188,14 @@ export function VirtualIdentityDialog({
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</DialogBody>
|
||||
|
||||
<DialogFooter className="gap-2 sm:gap-0">
|
||||
<Button variant="outline" onClick={() => onOpenChange(false)}>
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
<Button
|
||||
data-dialog-action="confirm"
|
||||
onClick={onCreateVirtualTab}
|
||||
disabled={!tempVirtualConfig.platform || !tempVirtualConfig.personId}
|
||||
>
|
||||
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
} from '@/components/ui/alert-dialog'
|
||||
import {
|
||||
Dialog,
|
||||
DialogBody,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
@@ -249,7 +250,7 @@ function RegexEditor({
|
||||
正则编辑器
|
||||
</Button>
|
||||
</DialogTrigger>
|
||||
<DialogContent className="max-w-[95vw] sm:max-w-[900px] max-h-[90vh]">
|
||||
<DialogContent className="max-w-[95vw] sm:max-w-225">
|
||||
<DialogHeader>
|
||||
<DialogTitle>正则表达式编辑器</DialogTitle>
|
||||
<DialogDescription className="text-sm">
|
||||
@@ -257,7 +258,7 @@ function RegexEditor({
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<ScrollArea className="max-h-[calc(90vh-120px)]">
|
||||
<DialogBody>
|
||||
<Tabs value={activeTab} onValueChange={(v) => setActiveTab(v as 'build' | 'test')} className="w-full">
|
||||
<TabsList className="grid w-full grid-cols-2">
|
||||
<TabsTrigger value="build">🔧 构建器</TabsTrigger>
|
||||
@@ -406,7 +407,7 @@ function RegexEditor({
|
||||
value={testText}
|
||||
onChange={(e) => setTestText(e.target.value)}
|
||||
placeholder="在此输入要测试的文本... 例如:打游戏是这样的"
|
||||
className="min-h-[100px] text-sm"
|
||||
className="min-h-25 text-sm"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -444,7 +445,7 @@ function RegexEditor({
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium">匹配高亮</Label>
|
||||
<ScrollArea className="h-40 rounded-md bg-muted p-3">
|
||||
<div className="text-sm break-words">
|
||||
<div className="text-sm wrap-break-word">
|
||||
{renderHighlightedText()}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
@@ -458,7 +459,7 @@ function RegexEditor({
|
||||
<div className="space-y-2">
|
||||
{Object.entries(captureGroups).map(([name, value]) => (
|
||||
<div key={name} className="flex items-start gap-2 text-sm">
|
||||
<span className="font-mono font-semibold text-primary min-w-[80px]">[{name}]</span>
|
||||
<span className="font-mono font-semibold text-primary min-w-20">[{name}]</span>
|
||||
<span className="text-muted-foreground">=</span>
|
||||
<span className="font-mono bg-muted px-2 py-0.5 rounded">{value}</span>
|
||||
</div>
|
||||
@@ -473,7 +474,7 @@ function RegexEditor({
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium">Reaction 替换预览</Label>
|
||||
<ScrollArea className="h-48 rounded-md bg-blue-50 dark:bg-blue-950/30 border border-blue-200 dark:border-blue-800 p-3">
|
||||
<div className="text-sm break-words">
|
||||
<div className="text-sm wrap-break-word">
|
||||
{replacedReaction}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
@@ -497,7 +498,7 @@ function RegexEditor({
|
||||
</div>
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
</ScrollArea>
|
||||
</DialogBody>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
)
|
||||
@@ -628,7 +629,7 @@ export const ProcessingSection = React.memo(function ProcessingSection({
|
||||
预览
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="w-[95vw] sm:w-[500px]">
|
||||
<PopoverContent className="w-[95vw] sm:w-125">
|
||||
<div className="space-y-2">
|
||||
<h4 className="font-medium text-sm">配置预览</h4>
|
||||
<ScrollArea className="h-60 rounded-md bg-muted p-3">
|
||||
@@ -656,7 +657,7 @@ export const ProcessingSection = React.memo(function ProcessingSection({
|
||||
预览
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="w-[95vw] sm:w-[500px]">
|
||||
<PopoverContent className="w-[95vw] sm:w-125">
|
||||
<div className="space-y-2">
|
||||
<h4 className="font-medium text-sm">配置预览</h4>
|
||||
<ScrollArea className="h-60 rounded-md bg-muted p-3">
|
||||
|
||||
@@ -6,6 +6,7 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'
|
||||
import { ScrollArea } from '@/components/ui/scroll-area'
|
||||
import {
|
||||
Dialog,
|
||||
DialogBody,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
@@ -971,9 +972,10 @@ function ModelConfigPageContent() {
|
||||
{/* 编辑模型对话框 */}
|
||||
<Dialog open={editDialogOpen} onOpenChange={handleEditDialogClose}>
|
||||
<DialogContent
|
||||
className="max-w-[95vw] sm:max-w-2xl max-h-[90vh] overflow-y-auto"
|
||||
className="max-w-[95vw] sm:max-w-2xl"
|
||||
data-tour="model-dialog"
|
||||
preventOutsideClose={tourIsRunning}
|
||||
confirmOnEnter
|
||||
>
|
||||
<DialogHeader>
|
||||
<DialogTitle>
|
||||
@@ -982,6 +984,7 @@ function ModelConfigPageContent() {
|
||||
<DialogDescription>配置模型的基本信息和参数</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<DialogBody>
|
||||
<div className="grid gap-4 py-4">
|
||||
<div className="grid gap-2" data-tour="model-name-input">
|
||||
<Label htmlFor="model_name" className={formErrors.name ? 'text-destructive' : ''}>模型名称 *</Label>
|
||||
@@ -1492,12 +1495,13 @@ function ModelConfigPageContent() {
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</DialogBody>
|
||||
|
||||
<DialogFooter>
|
||||
<Button variant="outline" onClick={() => setEditDialogOpen(false)} data-tour="model-cancel-button">
|
||||
取消
|
||||
</Button>
|
||||
<Button onClick={handleSaveEdit} data-tour="model-save-button">保存</Button>
|
||||
<Button data-dialog-action="confirm" onClick={handleSaveEdit} data-tour="model-save-button">保存</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
@@ -3,7 +3,7 @@ import { Check, ChevronsUpDown, Copy, Eye, EyeOff } from 'lucide-react'
|
||||
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Command, CommandEmpty, CommandGroup, CommandInput, CommandItem, CommandList } from '@/components/ui/command'
|
||||
import { Dialog, DialogContent, DialogDescription, DialogFooter, DialogHeader, DialogTitle } from '@/components/ui/dialog'
|
||||
import { Dialog, DialogBody, DialogContent, DialogDescription, DialogFooter, DialogHeader, DialogTitle } from '@/components/ui/dialog'
|
||||
import { HelpTooltip } from '@/components/ui/help-tooltip'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Label } from '@/components/ui/label'
|
||||
@@ -116,9 +116,10 @@ export function ProviderForm({
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent
|
||||
className="max-w-[95vw] sm:max-w-2xl max-h-[90vh] overflow-y-auto"
|
||||
className="max-w-[95vw] sm:max-w-2xl"
|
||||
data-tour="provider-dialog"
|
||||
preventOutsideClose={tourState.isRunning}
|
||||
confirmOnEnter
|
||||
>
|
||||
<DialogHeader>
|
||||
<DialogTitle>
|
||||
@@ -130,6 +131,7 @@ export function ProviderForm({
|
||||
</DialogHeader>
|
||||
|
||||
<form onSubmit={(e) => { e.preventDefault(); handleSaveEdit(); }} autoComplete="off">
|
||||
<DialogBody>
|
||||
<div className="grid gap-4 py-4">
|
||||
<div className="grid gap-2" data-tour="provider-template-select">
|
||||
<Label htmlFor="template">提供商模板</Label>
|
||||
@@ -450,12 +452,13 @@ export function ProviderForm({
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</DialogBody>
|
||||
|
||||
<DialogFooter>
|
||||
<Button type="button" variant="outline" onClick={() => onOpenChange(false)} data-tour="provider-cancel-button">
|
||||
取消
|
||||
</Button>
|
||||
<Button type="submit" data-tour="provider-save-button">保存</Button>
|
||||
<Button type="submit" data-dialog-action="confirm" data-tour="provider-save-button">保存</Button>
|
||||
</DialogFooter>
|
||||
</form>
|
||||
</DialogContent>
|
||||
|
||||
@@ -40,6 +40,7 @@ import { ScrollArea } from '@/components/ui/scroll-area'
|
||||
import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert'
|
||||
import {
|
||||
Dialog,
|
||||
DialogBody,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
@@ -575,7 +576,7 @@ function ApplyDialog({
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
|
||||
<DialogContent className="max-w-2xl" confirmOnEnter>
|
||||
<DialogHeader>
|
||||
<DialogTitle className="flex items-center gap-2">
|
||||
<Package className="w-5 h-5" />
|
||||
@@ -589,6 +590,7 @@ function ApplyDialog({
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<DialogBody>
|
||||
{detectingConflicts ? (
|
||||
<div className="py-8 text-center">
|
||||
<Loader2 className="w-8 h-8 mx-auto animate-spin text-primary" />
|
||||
@@ -831,6 +833,7 @@ function ApplyDialog({
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</DialogBody>
|
||||
|
||||
<DialogFooter className="flex justify-between">
|
||||
<div>
|
||||
@@ -845,11 +848,11 @@ function ApplyDialog({
|
||||
取消
|
||||
</Button>
|
||||
{step < totalSteps ? (
|
||||
<Button onClick={() => setStep(step + 1)} disabled={detectingConflicts}>
|
||||
<Button data-dialog-action="confirm" onClick={() => setStep(step + 1)} disabled={detectingConflicts}>
|
||||
下一步
|
||||
</Button>
|
||||
) : (
|
||||
<Button onClick={onApply} disabled={applying}>
|
||||
<Button data-dialog-action="confirm" onClick={onApply} disabled={applying}>
|
||||
{applying && <Loader2 className="w-4 h-4 mr-2 animate-spin" />}
|
||||
应用模板
|
||||
</Button>
|
||||
|
||||
@@ -29,6 +29,7 @@ import { Button } from '@/components/ui/button'
|
||||
import { Checkbox } from '@/components/ui/checkbox'
|
||||
import {
|
||||
Dialog,
|
||||
DialogBody,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
@@ -573,7 +574,7 @@ export function PersonManagementPage() {
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => handleViewDetail(person)}
|
||||
className="text-xs px-2 py-1 h-auto flex-shrink-0"
|
||||
className="text-xs px-2 py-1 h-auto shrink-0"
|
||||
>
|
||||
<Eye className="h-3 w-3 mr-1" />
|
||||
查看
|
||||
@@ -582,7 +583,7 @@ export function PersonManagementPage() {
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => handleEdit(person)}
|
||||
className="text-xs px-2 py-1 h-auto flex-shrink-0"
|
||||
className="text-xs px-2 py-1 h-auto shrink-0"
|
||||
>
|
||||
<Edit className="h-3 w-3 mr-1" />
|
||||
编辑
|
||||
@@ -591,7 +592,7 @@ export function PersonManagementPage() {
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => setDeleteConfirmPerson(person)}
|
||||
className="text-xs px-2 py-1 h-auto flex-shrink-0 text-destructive hover:text-destructive"
|
||||
className="text-xs px-2 py-1 h-auto shrink-0 text-destructive hover:text-destructive"
|
||||
>
|
||||
<Trash2 className="h-3 w-3 mr-1" />
|
||||
删除
|
||||
@@ -771,7 +772,7 @@ function PersonDetailDialog({
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
|
||||
<DialogContent className="max-w-2xl">
|
||||
<DialogHeader>
|
||||
<DialogTitle>人物详情</DialogTitle>
|
||||
<DialogDescription>
|
||||
@@ -779,6 +780,7 @@ function PersonDetailDialog({
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<DialogBody>
|
||||
<div className="space-y-4">
|
||||
{/* 基本信息 */}
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
@@ -829,6 +831,7 @@ function PersonDetailDialog({
|
||||
<InfoItem icon={Clock} label="最后更新" value={formatTime(person.last_know)} />
|
||||
</div>
|
||||
</div>
|
||||
</DialogBody>
|
||||
|
||||
<DialogFooter>
|
||||
<Button onClick={() => onOpenChange(false)}>关闭</Button>
|
||||
@@ -919,7 +922,7 @@ function PersonEditDialog({
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
|
||||
<DialogContent className="max-w-2xl" confirmOnEnter>
|
||||
<DialogHeader>
|
||||
<DialogTitle>编辑人物信息</DialogTitle>
|
||||
<DialogDescription>
|
||||
@@ -927,6 +930,7 @@ function PersonEditDialog({
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<DialogBody>
|
||||
<div className="space-y-4">
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div className="space-y-2">
|
||||
@@ -974,6 +978,7 @@ function PersonEditDialog({
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</DialogBody>
|
||||
|
||||
<DialogFooter>
|
||||
<Button variant="outline" onClick={() => onOpenChange(false)}>
|
||||
|
||||
@@ -10,6 +10,7 @@ import { Button } from '@/components/ui/button'
|
||||
import { Checkbox } from '@/components/ui/checkbox'
|
||||
import {
|
||||
Dialog,
|
||||
DialogBody,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
@@ -59,7 +60,7 @@ export function EmojiDetailDialog({
|
||||
<DialogHeader>
|
||||
<DialogTitle>表情包详情</DialogTitle>
|
||||
</DialogHeader>
|
||||
<ScrollArea className="max-h-[calc(90vh-8rem)] pr-4">
|
||||
<DialogBody>
|
||||
<div className="space-y-4">
|
||||
{/* 表情包预览图 - 使用原图 */}
|
||||
<div className="flex justify-center">
|
||||
@@ -177,7 +178,7 @@ export function EmojiDetailDialog({
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</DialogBody>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
)
|
||||
@@ -252,11 +253,12 @@ export function EmojiEditDialog({
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-2xl">
|
||||
<DialogContent className="max-w-2xl" confirmOnEnter>
|
||||
<DialogHeader>
|
||||
<DialogTitle>编辑表情包</DialogTitle>
|
||||
<DialogDescription>修改表情包的情绪和状态信息</DialogDescription>
|
||||
</DialogHeader>
|
||||
<DialogBody>
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<Label>情绪</Label>
|
||||
@@ -310,11 +312,12 @@ export function EmojiEditDialog({
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</DialogBody>
|
||||
<DialogFooter>
|
||||
<Button variant="outline" onClick={() => onOpenChange(false)}>
|
||||
取消
|
||||
</Button>
|
||||
<Button onClick={handleSave} disabled={saving}>
|
||||
<Button data-dialog-action="confirm" onClick={handleSave} disabled={saving}>
|
||||
{saving ? '保存中...' : '保存'}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
@@ -658,7 +661,7 @@ export function EmojiUploadDialog({
|
||||
|
||||
<div className="flex gap-6">
|
||||
{/* 预览图 */}
|
||||
<div className="flex-shrink-0">
|
||||
<div className="shrink-0">
|
||||
<div className="w-32 h-32 rounded-lg border overflow-hidden bg-muted flex items-center justify-center">
|
||||
<img
|
||||
src={file.previewUrl}
|
||||
@@ -764,7 +767,7 @@ export function EmojiUploadDialog({
|
||||
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
{/* 左侧:文件卡片列表 */}
|
||||
<ScrollArea className="h-[350px] pr-2">
|
||||
<ScrollArea className="h-87.5 pr-2">
|
||||
<div className="space-y-2">
|
||||
{uploadedFiles.map((file) => {
|
||||
const complete = isFileComplete(file)
|
||||
@@ -782,7 +785,7 @@ export function EmojiUploadDialog({
|
||||
${complete ? 'border-green-500 bg-green-50 dark:bg-green-950/20' : 'border-border hover:border-muted-foreground/50'}
|
||||
`}
|
||||
>
|
||||
<div className="w-12 h-12 rounded border overflow-hidden bg-muted flex-shrink-0 flex items-center justify-center">
|
||||
<div className="flex h-12 w-12 shrink-0 items-center justify-center overflow-hidden rounded border bg-muted">
|
||||
<img
|
||||
src={file.previewUrl}
|
||||
alt={file.name}
|
||||
@@ -796,9 +799,9 @@ export function EmojiUploadDialog({
|
||||
</p>
|
||||
</div>
|
||||
{complete ? (
|
||||
<CheckCircle2 className="h-5 w-5 text-green-500 flex-shrink-0" />
|
||||
<CheckCircle2 className="h-5 w-5 shrink-0 text-green-500" />
|
||||
) : (
|
||||
<div className="h-5 w-5 rounded-full border-2 border-muted-foreground/30 flex-shrink-0" />
|
||||
<div className="h-5 w-5 shrink-0 rounded-full border-2 border-muted-foreground/30" />
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
@@ -908,7 +911,7 @@ export function EmojiUploadDialog({
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-3xl max-h-[90vh] overflow-hidden">
|
||||
<DialogContent className="max-w-3xl max-h-[90vh] overflow-hidden" confirmOnEnter>
|
||||
<DialogHeader>
|
||||
<DialogTitle className="flex items-center gap-2">
|
||||
<Upload className="h-5 w-5" />
|
||||
@@ -925,11 +928,11 @@ export function EmojiUploadDialog({
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div className="overflow-y-auto pr-1">
|
||||
<DialogBody viewportClassName="pr-1">
|
||||
{step === 'select' && renderSelectStep()}
|
||||
{step === 'edit-single' && renderEditSingleStep()}
|
||||
{step === 'edit-multiple' && renderEditMultipleStep()}
|
||||
</div>
|
||||
</DialogBody>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
)
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
import { Button } from '@/components/ui/button'
|
||||
import {
|
||||
Dialog,
|
||||
DialogBody,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
@@ -65,7 +66,7 @@ export function ExpressionDetailDialog({
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
|
||||
<DialogContent className="max-w-2xl" confirmOnEnter>
|
||||
<DialogHeader>
|
||||
<DialogTitle>表达方式详情</DialogTitle>
|
||||
<DialogDescription>
|
||||
@@ -73,6 +74,7 @@ export function ExpressionDetailDialog({
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<DialogBody>
|
||||
<div className="space-y-4">
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<InfoItem label="情境" value={expression.situation} />
|
||||
@@ -131,6 +133,7 @@ export function ExpressionDetailDialog({
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</DialogBody>
|
||||
|
||||
<DialogFooter>
|
||||
<Button onClick={() => onOpenChange(false)}>关闭</Button>
|
||||
@@ -233,7 +236,7 @@ export function ExpressionCreateDialog({
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
|
||||
<DialogContent className="max-w-2xl" confirmOnEnter>
|
||||
<DialogHeader>
|
||||
<DialogTitle>新增表达方式</DialogTitle>
|
||||
<DialogDescription>
|
||||
@@ -241,6 +244,7 @@ export function ExpressionCreateDialog({
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<DialogBody>
|
||||
<div className="space-y-4">
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div className="space-y-2">
|
||||
@@ -291,12 +295,13 @@ export function ExpressionCreateDialog({
|
||||
</Select>
|
||||
</div>
|
||||
</div>
|
||||
</DialogBody>
|
||||
|
||||
<DialogFooter>
|
||||
<Button variant="outline" onClick={() => onOpenChange(false)}>
|
||||
取消
|
||||
</Button>
|
||||
<Button onClick={handleCreate} disabled={saving}>
|
||||
<Button data-dialog-action="confirm" onClick={handleCreate} disabled={saving}>
|
||||
{saving ? '创建中...' : '创建'}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
@@ -371,7 +376,7 @@ export function ExpressionEditDialog({
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
|
||||
<DialogContent className="max-w-2xl" confirmOnEnter>
|
||||
<DialogHeader>
|
||||
<DialogTitle>编辑表达方式</DialogTitle>
|
||||
<DialogDescription>
|
||||
@@ -379,6 +384,7 @@ export function ExpressionEditDialog({
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<DialogBody>
|
||||
<div className="space-y-4">
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div className="space-y-2">
|
||||
@@ -474,12 +480,13 @@ export function ExpressionEditDialog({
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</DialogBody>
|
||||
|
||||
<DialogFooter>
|
||||
<Button variant="outline" onClick={() => onOpenChange(false)}>
|
||||
取消
|
||||
</Button>
|
||||
<Button onClick={handleSave} disabled={saving}>
|
||||
<Button data-dialog-action="confirm" onClick={handleSave} disabled={saving}>
|
||||
{saving ? '保存中...' : '保存'}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
|
||||
@@ -14,6 +14,7 @@ import {
|
||||
import { Button } from '@/components/ui/button'
|
||||
import {
|
||||
Dialog,
|
||||
DialogBody,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
@@ -24,7 +25,6 @@ import { Badge } from '@/components/ui/badge'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Label } from '@/components/ui/label'
|
||||
import { MarkdownRenderer } from '@/components/markdown-renderer'
|
||||
import { ScrollArea } from '@/components/ui/scroll-area'
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
@@ -92,7 +92,7 @@ export function JargonDetailDialog({
|
||||
<DialogDescription>查看黑话的完整信息</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<ScrollArea className="h-full pr-4">
|
||||
<DialogBody className="h-full">
|
||||
<div className="space-y-4 pb-2">
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<InfoItem icon={Hash} label="记录ID" value={jargon.id.toString()} mono />
|
||||
@@ -167,7 +167,7 @@ export function JargonDetailDialog({
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</DialogBody>
|
||||
|
||||
<DialogFooter className="flex-shrink-0">
|
||||
<Button onClick={() => onOpenChange(false)}>关闭</Button>
|
||||
@@ -234,12 +234,13 @@ export function JargonCreateDialog({
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
|
||||
<DialogContent className="max-w-2xl" confirmOnEnter>
|
||||
<DialogHeader>
|
||||
<DialogTitle>新增黑话</DialogTitle>
|
||||
<DialogDescription>创建新的黑话记录</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<DialogBody>
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="content">
|
||||
@@ -294,10 +295,11 @@ export function JargonCreateDialog({
|
||||
<Label htmlFor="is_global">设为全局黑话</Label>
|
||||
</div>
|
||||
</div>
|
||||
</DialogBody>
|
||||
|
||||
<DialogFooter>
|
||||
<Button variant="outline" onClick={() => onOpenChange(false)}>取消</Button>
|
||||
<Button onClick={handleCreate} disabled={saving}>
|
||||
<Button data-dialog-action="confirm" onClick={handleCreate} disabled={saving}>
|
||||
{saving ? '创建中...' : '创建'}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
@@ -366,12 +368,13 @@ export function JargonEditDialog({
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
|
||||
<DialogContent className="max-w-2xl" confirmOnEnter>
|
||||
<DialogHeader>
|
||||
<DialogTitle>编辑黑话</DialogTitle>
|
||||
<DialogDescription>修改黑话的信息</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<DialogBody>
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="edit_content">内容</Label>
|
||||
@@ -439,10 +442,11 @@ export function JargonEditDialog({
|
||||
<Label htmlFor="edit_is_global">全局黑话</Label>
|
||||
</div>
|
||||
</div>
|
||||
</DialogBody>
|
||||
|
||||
<DialogFooter>
|
||||
<Button variant="outline" onClick={() => onOpenChange(false)}>取消</Button>
|
||||
<Button onClick={handleSave} disabled={saving}>
|
||||
<Button data-dialog-action="confirm" onClick={handleSave} disabled={saving}>
|
||||
{saving ? '保存中...' : '保存'}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
import { Badge } from '@/components/ui/badge'
|
||||
import {
|
||||
Dialog,
|
||||
DialogBody,
|
||||
DialogContent,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { ScrollArea } from '@/components/ui/scroll-area'
|
||||
|
||||
import type { GraphNode, SelectedEdgeData } from './types'
|
||||
|
||||
@@ -24,7 +24,7 @@ export function NodeDetailDialog({ open, onOpenChange, selectedNodeData }: NodeD
|
||||
<DialogTitle>节点详情</DialogTitle>
|
||||
</DialogHeader>
|
||||
{selectedNodeData && (
|
||||
<ScrollArea className="h-full pr-4">
|
||||
<DialogBody className="h-full">
|
||||
<div className="space-y-4 pb-2">
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div>
|
||||
@@ -62,7 +62,7 @@ export function NodeDetailDialog({ open, onOpenChange, selectedNodeData }: NodeD
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</DialogBody>
|
||||
)}
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
@@ -83,7 +83,7 @@ export function EdgeDetailDialog({ open, onOpenChange, selectedEdgeData }: EdgeD
|
||||
<DialogTitle>边详情</DialogTitle>
|
||||
</DialogHeader>
|
||||
{selectedEdgeData && (
|
||||
<ScrollArea className="flex-1 pr-4">
|
||||
<DialogBody>
|
||||
<div className="space-y-4">
|
||||
<div className="flex items-center gap-4">
|
||||
<div className="flex-1 min-w-0 p-3 bg-blue-50 dark:bg-blue-950 rounded border-2 border-blue-200 dark:border-blue-800">
|
||||
@@ -114,7 +114,7 @@ export function EdgeDetailDialog({ open, onOpenChange, selectedEdgeData }: EdgeD
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</DialogBody>
|
||||
)}
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
@@ -190,7 +190,6 @@ class ChatBot:
|
||||
@staticmethod
|
||||
def _mark_command_message(message: SessionMessage, intercept_message_level: int) -> None:
|
||||
message.is_command = True
|
||||
message.intercept_message_level = intercept_message_level
|
||||
message.message_info.additional_config["intercept_message_level"] = intercept_message_level
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -29,7 +29,7 @@ def get_webui_chat_broadcaster():
|
||||
global _webui_chat_broadcaster
|
||||
if _webui_chat_broadcaster is None:
|
||||
try:
|
||||
from src.webui.chat_routes import chat_manager, WEBUI_CHAT_PLATFORM
|
||||
from src.webui.routers.chat import WEBUI_CHAT_PLATFORM, chat_manager
|
||||
|
||||
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
|
||||
except ImportError:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import traceback
|
||||
import time
|
||||
import asyncio
|
||||
import importlib
|
||||
import random
|
||||
import re
|
||||
|
||||
@@ -11,9 +12,9 @@ from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from maim_message import BaseMessageInfo, MessageBase, Seg
|
||||
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
@@ -786,7 +787,6 @@ class DefaultReplyer:
|
||||
available_actions = {}
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.session_id
|
||||
_is_group_chat = bool(chat_stream.group_info)
|
||||
platform = chat_stream.platform
|
||||
|
||||
user_id = "用户ID"
|
||||
@@ -795,11 +795,12 @@ class DefaultReplyer:
|
||||
target = "消息"
|
||||
|
||||
if reply_message:
|
||||
user_id = reply_message.user_info.user_id
|
||||
reply_user_info = reply_message.message_info.user_info
|
||||
user_id = reply_user_info.user_id
|
||||
person = Person(platform=platform, user_id=user_id)
|
||||
person_name = person.person_name or user_id
|
||||
sender = person_name
|
||||
target = reply_message.processed_plain_text
|
||||
target = reply_message.processed_plain_text or ""
|
||||
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
|
||||
@@ -825,16 +826,17 @@ class DefaultReplyer:
|
||||
|
||||
person_list_short: List[Person] = []
|
||||
for msg in message_list_before_short:
|
||||
msg_user_info = msg.message_info.user_info
|
||||
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI)
|
||||
if is_bot_self(msg.user_info.platform, msg.user_info.user_id):
|
||||
if is_bot_self(msg.platform, msg_user_info.user_id):
|
||||
continue
|
||||
if (
|
||||
reply_message
|
||||
and reply_message.user_info.user_id == msg.user_info.user_id
|
||||
and reply_message.user_info.platform == msg.user_info.platform
|
||||
and reply_message.message_info.user_info.user_id == msg_user_info.user_id
|
||||
and reply_message.platform == msg.platform
|
||||
):
|
||||
continue
|
||||
person = Person(platform=msg.user_info.platform, user_id=msg.user_info.user_id)
|
||||
person = Person(platform=msg.platform, user_id=msg_user_info.user_id)
|
||||
if person.is_known:
|
||||
person_list_short.append(person)
|
||||
|
||||
@@ -847,7 +849,6 @@ class DefaultReplyer:
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=True,
|
||||
long_time_notice=True,
|
||||
)
|
||||
|
||||
# 统一黑话解释构建:根据配置选择上下文或 Planner 模式
|
||||
@@ -956,7 +957,6 @@ class DefaultReplyer:
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=True,
|
||||
long_time_notice=True,
|
||||
)
|
||||
|
||||
# 获取匹配的额外prompt
|
||||
@@ -1121,7 +1121,7 @@ class DefaultReplyer:
|
||||
platform=self.chat_stream.platform,
|
||||
message_id=message_id,
|
||||
time=thinking_start_time,
|
||||
user_info=UserInfo(
|
||||
user_info=MaimUserInfo(
|
||||
user_id=str(global_config.bot.qq_account),
|
||||
user_nickname=global_config.bot.nickname,
|
||||
),
|
||||
@@ -1160,7 +1160,16 @@ class DefaultReplyer:
|
||||
async def get_prompt_info(self, message: str, sender: str, target: str):
|
||||
related_info = ""
|
||||
start_time = time.time()
|
||||
from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool
|
||||
try:
|
||||
knowledge_module = importlib.import_module("src.plugins.built_in.knowledge.lpmm_get_knowledge")
|
||||
except ImportError:
|
||||
logger.debug("LPMM知识库工具模块不存在,跳过获取知识库内容")
|
||||
return ""
|
||||
|
||||
search_knowledge_tool = getattr(knowledge_module, "SearchKnowledgeFromLPMMTool", None)
|
||||
if search_knowledge_tool is None:
|
||||
logger.debug("LPMM知识库工具未提供 SearchKnowledgeFromLPMMTool,跳过获取知识库内容")
|
||||
return ""
|
||||
|
||||
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||
# 从LPMM知识库获取知识
|
||||
@@ -1183,14 +1192,14 @@ class DefaultReplyer:
|
||||
_, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools(
|
||||
prompt,
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
|
||||
tool_options=[search_knowledge_tool.get_tool_definition()],
|
||||
)
|
||||
|
||||
# logger.info(f"工具调用提示词: {prompt}")
|
||||
# logger.info(f"工具调用: {tool_calls}")
|
||||
|
||||
if tool_calls:
|
||||
result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())
|
||||
result = await self.tool_executor.execute_tool_call(tool_calls[0])
|
||||
end_time = time.time()
|
||||
if not result or not result.get("content"):
|
||||
logger.debug("从LPMM知识库获取知识失败,返回空知识...")
|
||||
|
||||
@@ -11,9 +11,9 @@ from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from maim_message import BaseMessageInfo, MessageBase, Seg
|
||||
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
@@ -636,11 +636,12 @@ class PrivateReplyer:
|
||||
target = "消息"
|
||||
|
||||
if reply_message:
|
||||
user_id = reply_message.user_info.user_id
|
||||
reply_user_info = reply_message.message_info.user_info
|
||||
user_id = reply_user_info.user_id
|
||||
person = Person(platform=platform, user_id=user_id)
|
||||
person_name = person.person_name or user_id
|
||||
sender = person_name
|
||||
target = reply_message.processed_plain_text
|
||||
target = reply_message.processed_plain_text or ""
|
||||
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
|
||||
@@ -663,7 +664,6 @@ class PrivateReplyer:
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=True,
|
||||
long_time_notice=True,
|
||||
)
|
||||
|
||||
message_list_before_short = get_messages_before_time_in_chat(
|
||||
@@ -675,16 +675,17 @@ class PrivateReplyer:
|
||||
|
||||
person_list_short: List[Person] = []
|
||||
for msg in message_list_before_short:
|
||||
msg_user_info = msg.message_info.user_info
|
||||
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI)
|
||||
if is_bot_self(msg.user_info.platform, msg.user_info.user_id):
|
||||
if is_bot_self(msg.platform, msg_user_info.user_id):
|
||||
continue
|
||||
if (
|
||||
reply_message
|
||||
and reply_message.user_info.user_id == msg.user_info.user_id
|
||||
and reply_message.user_info.platform == msg.user_info.platform
|
||||
and reply_message.message_info.user_info.user_id == msg_user_info.user_id
|
||||
and reply_message.platform == msg.platform
|
||||
):
|
||||
continue
|
||||
person = Person(platform=msg.user_info.platform, user_id=msg.user_info.user_id)
|
||||
person = Person(platform=msg.platform, user_id=msg_user_info.user_id)
|
||||
if person.is_known:
|
||||
person_list_short.append(person)
|
||||
|
||||
@@ -960,7 +961,7 @@ class PrivateReplyer:
|
||||
platform=self.chat_stream.platform,
|
||||
message_id=message_id,
|
||||
time=thinking_start_time,
|
||||
user_info=UserInfo(
|
||||
user_info=MaimUserInfo(
|
||||
user_id=str(global_config.bot.qq_account),
|
||||
user_nickname=global_config.bot.nickname,
|
||||
),
|
||||
|
||||
@@ -76,7 +76,6 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = {
|
||||
"webui.jargon": ("#d7d75f", None, False),
|
||||
"webui.person": ("#87d787", None, False),
|
||||
"webui.statistics": ("#af87ff", None, False),
|
||||
"webui.annual_report": ("#ffaf87", None, False),
|
||||
"webui.plugin_routes": ("#ffaf00", None, False),
|
||||
"webui.plugin_progress": ("#ff8700", None, False),
|
||||
"webui.git_mirror": ("#878787", None, False),
|
||||
@@ -180,7 +179,6 @@ MODULE_ALIASES = {
|
||||
"webui.jargon": "WebUI黑话",
|
||||
"webui.person": "WebUI人物",
|
||||
"webui.statistics": "WebUI统计",
|
||||
"webui.annual_report": "WebUI年报",
|
||||
"webui.plugin_routes": "WebUI插件",
|
||||
"webui.plugin_progress": "WebUI插件进度",
|
||||
"webui.git_mirror": "WebUI镜像",
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import json
|
||||
@@ -48,48 +47,8 @@ def _parse_additional_config(message: Messages) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
|
||||
def _normalize_optional_str(value: object) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
except (TypeError, ValueError):
|
||||
return str(value)
|
||||
|
||||
|
||||
def _message_to_instance(message: Messages) -> SessionMessage:
|
||||
config = _parse_additional_config(message)
|
||||
instance = SessionMessage.from_db_instance(message)
|
||||
instance.interest_value = config.get("interest_value")
|
||||
instance.key_words = _normalize_optional_str(config.get("key_words"))
|
||||
instance.key_words_lite = _normalize_optional_str(config.get("key_words_lite"))
|
||||
instance.reply_probability_boost = config.get("reply_probability_boost")
|
||||
instance.priority_mode = _normalize_optional_str(config.get("priority_mode"))
|
||||
instance.priority_info = _normalize_optional_str(config.get("priority_info"))
|
||||
instance.intercept_message_level = config.get("intercept_message_level", 0)
|
||||
instance.selected_expressions = _normalize_optional_str(config.get("selected_expressions"))
|
||||
group_info = instance.message_info.group_info
|
||||
legacy_group_info = None
|
||||
if group_info:
|
||||
legacy_group_info = SimpleNamespace(
|
||||
group_id=group_info.group_id,
|
||||
group_name=group_info.group_name,
|
||||
)
|
||||
instance.user_info = SimpleNamespace(
|
||||
user_id=instance.message_info.user_info.user_id,
|
||||
user_nickname=instance.message_info.user_info.user_nickname,
|
||||
user_cardname=instance.message_info.user_info.user_cardname,
|
||||
platform=instance.platform,
|
||||
)
|
||||
instance.chat_info = SimpleNamespace(
|
||||
platform=instance.platform,
|
||||
stream_id=instance.session_id,
|
||||
group_info=legacy_group_info,
|
||||
)
|
||||
instance.time = instance.timestamp.timestamp()
|
||||
return instance
|
||||
return SessionMessage.from_db_instance(message)
|
||||
|
||||
|
||||
def _coerce_datetime(value: Any) -> Any:
|
||||
@@ -118,6 +77,7 @@ def _build_message_conditions(
|
||||
end_time: float | None = None,
|
||||
before_time: float | None = None,
|
||||
after_time: float | None = None,
|
||||
has_reply_to: bool | None = None,
|
||||
) -> list[Any]:
|
||||
conditions: list[Any] = [Messages.message_id != "notice"]
|
||||
|
||||
@@ -141,6 +101,10 @@ def _build_message_conditions(
|
||||
conditions.append(Messages.timestamp < _coerce_datetime(before_time))
|
||||
if after_time is not None:
|
||||
conditions.append(Messages.timestamp > _coerce_datetime(after_time))
|
||||
if has_reply_to is True:
|
||||
conditions.append(col(Messages.reply_to).is_not(None))
|
||||
elif has_reply_to is False:
|
||||
conditions.append(col(Messages.reply_to).is_(None))
|
||||
|
||||
return conditions
|
||||
|
||||
@@ -261,6 +225,7 @@ def count_messages(
|
||||
end_time: float | None = None,
|
||||
before_time: float | None = None,
|
||||
after_time: float | None = None,
|
||||
has_reply_to: bool | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
根据提供的过滤器计算消息数量。
|
||||
@@ -276,6 +241,7 @@ def count_messages(
|
||||
end_time: 结束时间,闭区间上界。
|
||||
before_time: 严格早于该时间。
|
||||
after_time: 严格晚于该时间。
|
||||
has_reply_to: 是否要求存在 reply_to 字段。
|
||||
|
||||
Returns:
|
||||
符合条件的消息数量,如果出错则返回 0。
|
||||
@@ -292,6 +258,7 @@ def count_messages(
|
||||
end_time=end_time,
|
||||
before_time=before_time,
|
||||
after_time=after_time,
|
||||
has_reply_to=has_reply_to,
|
||||
)
|
||||
statement = select(func.count()).select_from(Messages).where(*conditions)
|
||||
with get_db_session() as session:
|
||||
@@ -302,7 +269,7 @@ def count_messages(
|
||||
"使用 SQLModel 计数消息失败 "
|
||||
f"(session_id={session_id}, user_id={user_id}, group_id={group_id}, platform={platform}, "
|
||||
f"message_id={message_id}, reply_to={reply_to}, start_time={start_time}, end_time={end_time}, "
|
||||
f"before_time={before_time}, after_time={after_time}): {e}\n{traceback.format_exc()}"
|
||||
f"before_time={before_time}, after_time={after_time}, has_reply_to={has_reply_to}): {e}\n{traceback.format_exc()}"
|
||||
)
|
||||
logger.error(log_message)
|
||||
return 0
|
||||
|
||||
@@ -25,13 +25,15 @@ def is_port_conflict_error(error: OSError) -> bool:
|
||||
return "address already in use" in message or "已被占用" in message
|
||||
|
||||
|
||||
def check_port_available(host: str, port: int) -> bool:
|
||||
def check_port_available(host: str, port: int, *, allow_reuse_addr: bool = False) -> bool:
|
||||
family = _detect_socket_family(host)
|
||||
test_host = _normalize_test_host(host)
|
||||
|
||||
try:
|
||||
with socket.socket(family, socket.SOCK_STREAM) as test_socket:
|
||||
test_socket.settimeout(1)
|
||||
if allow_reuse_addr:
|
||||
test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
test_socket.bind((test_host, port))
|
||||
return True
|
||||
except OSError:
|
||||
@@ -65,8 +67,9 @@ def assert_port_available(
|
||||
service_name: str,
|
||||
logger,
|
||||
config_hint: Optional[str] = None,
|
||||
allow_reuse_addr: bool = False,
|
||||
) -> None:
|
||||
if check_port_available(host=host, port=port):
|
||||
if check_port_available(host=host, port=port, allow_reuse_addr=allow_reuse_addr):
|
||||
return
|
||||
|
||||
log_port_conflict(
|
||||
|
||||
@@ -8,13 +8,17 @@
|
||||
3. 详情按需加载
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/api/planner", tags=["planner"])
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
router = APIRouter(prefix="/api/planner", tags=["planner"], dependencies=[Depends(require_auth)])
|
||||
|
||||
# 规划器日志目录
|
||||
PLAN_LOG_DIR = Path("logs/plan")
|
||||
|
||||
@@ -8,13 +8,17 @@
|
||||
3. 详情按需加载
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/api/replier", tags=["replier"])
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
router = APIRouter(prefix="/api/replier", tags=["replier"], dependencies=[Depends(require_auth)])
|
||||
|
||||
# 回复器日志目录
|
||||
REPLY_LOG_DIR = Path("logs/reply")
|
||||
|
||||
@@ -52,7 +52,6 @@ def _setup_cors(app: FastAPI, port: int):
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
|
||||
allow_headers=[
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"Accept",
|
||||
"Origin",
|
||||
"X-Requested-With",
|
||||
|
||||
@@ -9,6 +9,7 @@ from .auth import (
|
||||
COOKIE_NAME,
|
||||
COOKIE_MAX_AGE,
|
||||
get_current_token,
|
||||
is_token_valid,
|
||||
set_auth_cookie,
|
||||
clear_auth_cookie,
|
||||
verify_auth_token_from_cookie_or_header,
|
||||
@@ -24,6 +25,7 @@ __all__ = [
|
||||
"COOKIE_NAME",
|
||||
"COOKIE_MAX_AGE",
|
||||
"get_current_token",
|
||||
"is_token_valid",
|
||||
"set_auth_cookie",
|
||||
"clear_auth_cookie",
|
||||
"verify_auth_token_from_cookie_or_header",
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
"""
|
||||
WebUI 认证模块
|
||||
提供统一的认证依赖,支持 Cookie 和 Header 两种方式
|
||||
"""
|
||||
"""WebUI 认证模块。"""
|
||||
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, Cookie, Header, Response, Request
|
||||
|
||||
from fastapi import Cookie, HTTPException, Request, Response
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from .security import get_token_manager
|
||||
@@ -39,17 +37,13 @@ def _is_secure_environment() -> bool:
|
||||
|
||||
|
||||
def get_current_token(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> str:
|
||||
"""
|
||||
获取当前请求的 token,优先从 Cookie 获取,其次从 Header 获取
|
||||
获取当前请求的 token,仅从 HttpOnly Cookie 获取。
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization Header (Bearer token)
|
||||
|
||||
Returns:
|
||||
验证通过的 token
|
||||
@@ -57,24 +51,19 @@ def get_current_token(
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
if not is_token_valid(maibot_session):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return token
|
||||
return maibot_session
|
||||
|
||||
|
||||
def is_token_valid(maibot_session: Optional[str]) -> bool:
|
||||
"""判断认证 token 是否存在且有效。"""
|
||||
if not maibot_session:
|
||||
return False
|
||||
|
||||
token_manager = get_token_manager()
|
||||
return token_manager.verify_token(maibot_session)
|
||||
|
||||
|
||||
def set_auth_cookie(response: Response, token: str, request: Optional[Request] = None) -> None:
|
||||
@@ -150,14 +139,12 @@ def clear_auth_cookie(response: Response) -> None:
|
||||
|
||||
def verify_auth_token_from_cookie_or_header(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
验证认证 Token,支持从 Cookie 或 Header 获取
|
||||
验证认证 Cookie。
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
验证成功返回 True
|
||||
@@ -165,21 +152,7 @@ def verify_auth_token_from_cookie_or_header(
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
if not is_token_valid(maibot_session):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return True
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
from typing import Optional
|
||||
from fastapi import Depends, Cookie, Header, Request
|
||||
from .core import get_current_token, get_token_manager, check_auth_rate_limit
|
||||
|
||||
from fastapi import Cookie, Depends, Request
|
||||
from .core import check_auth_rate_limit, get_current_token, is_token_valid
|
||||
|
||||
|
||||
async def require_auth(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> str:
|
||||
"""
|
||||
FastAPI 依赖:要求有效认证
|
||||
|
||||
用于保护需要认证的路由,自动从 Cookie 或 Header 获取并验证 token
|
||||
用于保护需要认证的路由,自动从 Cookie 获取并验证 token
|
||||
|
||||
Returns:
|
||||
验证通过的 token
|
||||
@@ -19,13 +18,12 @@ async def require_auth(
|
||||
Raises:
|
||||
HTTPException 401: 认证失败
|
||||
"""
|
||||
return get_current_token(request, maibot_session, authorization)
|
||||
return get_current_token(maibot_session)
|
||||
|
||||
|
||||
async def require_auth_with_rate_limit(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
_rate_limit: None = Depends(check_auth_rate_limit),
|
||||
) -> str:
|
||||
"""
|
||||
@@ -40,12 +38,11 @@ async def require_auth_with_rate_limit(
|
||||
HTTPException 401: 认证失败
|
||||
HTTPException 429: 请求过于频繁
|
||||
"""
|
||||
return get_current_token(request, maibot_session, authorization)
|
||||
return get_current_token(maibot_session)
|
||||
|
||||
|
||||
def get_optional_token(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
FastAPI 依赖:可选获取 token(不验证)
|
||||
@@ -55,16 +52,11 @@ def get_optional_token(
|
||||
Returns:
|
||||
token 字符串或 None
|
||||
"""
|
||||
if maibot_session:
|
||||
return maibot_session
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
return authorization.replace("Bearer ", "")
|
||||
return None
|
||||
return maibot_session or None
|
||||
|
||||
|
||||
async def verify_token_optional(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""
|
||||
FastAPI 依赖:可选验证 token
|
||||
@@ -74,14 +66,4 @@ async def verify_token_optional(
|
||||
Returns:
|
||||
True 如果 token 有效,否则 False
|
||||
"""
|
||||
token = None
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not token:
|
||||
return False
|
||||
|
||||
token_manager = get_token_manager()
|
||||
return token_manager.verify_token(token)
|
||||
return is_token_valid(maibot_session)
|
||||
|
||||
@@ -1,916 +0,0 @@
|
||||
"""麦麦 2025 年度总结 API 路由"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Cookie, Depends, Header, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import desc, func
|
||||
from sqlmodel import col, select
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import (
|
||||
ActionRecord,
|
||||
Expression,
|
||||
Images,
|
||||
Jargon,
|
||||
Messages,
|
||||
ModelUsage,
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||
|
||||
logger = get_logger("webui.annual_report")
|
||||
|
||||
router = APIRouter(prefix="/annual-report", tags=["annual-report"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# ==================== Pydantic 模型定义 ====================
|
||||
|
||||
|
||||
class TimeFootprintData(BaseModel):
|
||||
"""时光足迹数据"""
|
||||
|
||||
total_online_hours: float = Field(0.0, description="年度在线总时长(小时)")
|
||||
first_message_time: Optional[str] = Field(None, description="初次消息时间")
|
||||
first_message_user: Optional[str] = Field(None, description="初次消息用户昵称")
|
||||
first_message_content: Optional[str] = Field(None, description="初次消息内容(截断)")
|
||||
busiest_day: Optional[str] = Field(None, description="最忙碌的一天")
|
||||
busiest_day_count: int = Field(0, description="最忙碌那天的消息数")
|
||||
hourly_distribution: list[int] = Field(default_factory=lambda: [0] * 24, description="24小时活跃分布")
|
||||
midnight_chat_count: int = Field(0, description="深夜(0-4点)互动次数")
|
||||
is_night_owl: bool = Field(False, description="是否是夜猫子")
|
||||
|
||||
|
||||
class SocialNetworkData(BaseModel):
|
||||
"""社交网络数据"""
|
||||
|
||||
total_groups: int = Field(0, description="加入的群组总数")
|
||||
top_groups: list[dict[str, Any]] = Field(default_factory=list, description="话痨群组TOP5")
|
||||
top_users: list[dict[str, Any]] = Field(default_factory=list, description="互动最多的用户TOP5")
|
||||
at_count: int = Field(0, description="被@次数")
|
||||
mentioned_count: int = Field(0, description="被提及次数")
|
||||
longest_companion_user: Optional[str] = Field(None, description="最长情陪伴的用户")
|
||||
longest_companion_days: int = Field(0, description="陪伴天数")
|
||||
|
||||
|
||||
class BrainPowerData(BaseModel):
|
||||
"""最强大脑数据"""
|
||||
|
||||
total_tokens: int = Field(0, description="年度消耗Token总量")
|
||||
total_cost: float = Field(0.0, description="年度总花费")
|
||||
favorite_model: Optional[str] = Field(None, description="最爱用的模型")
|
||||
favorite_model_count: int = Field(0, description="最爱模型的调用次数")
|
||||
model_distribution: list[dict[str, Any]] = Field(default_factory=list, description="模型使用分布")
|
||||
top_reply_models: list[dict[str, Any]] = Field(default_factory=list, description="最喜欢的回复模型TOP5")
|
||||
most_expensive_cost: float = Field(0.0, description="最昂贵的一次思考花费")
|
||||
most_expensive_time: Optional[str] = Field(None, description="最昂贵思考的时间")
|
||||
top_token_consumers: list[dict[str, Any]] = Field(default_factory=list, description="烧钱大户TOP3")
|
||||
silence_rate: float = Field(0.0, description="高冷指数(沉默率)")
|
||||
total_actions: int = Field(0, description="总动作数")
|
||||
no_reply_count: int = Field(0, description="选择沉默的次数")
|
||||
avg_interest_value: float = Field(0.0, description="平均兴趣值")
|
||||
max_interest_value: float = Field(0.0, description="最高兴趣值")
|
||||
max_interest_time: Optional[str] = Field(None, description="最高兴趣值时间")
|
||||
avg_reasoning_length: float = Field(0.0, description="平均思考长度")
|
||||
max_reasoning_length: int = Field(0, description="最长思考长度")
|
||||
max_reasoning_time: Optional[str] = Field(None, description="最长思考的时间")
|
||||
|
||||
|
||||
class ExpressionVibeData(BaseModel):
|
||||
"""个性与表达数据"""
|
||||
|
||||
top_emoji: Optional[dict[str, Any]] = Field(None, description="表情包之王")
|
||||
top_emojis: list[dict[str, Any]] = Field(default_factory=list, description="TOP3表情包")
|
||||
top_expressions: list[dict[str, Any]] = Field(default_factory=list, description="印象最深刻的表达风格")
|
||||
rejected_expression_count: int = Field(0, description="被拒绝的表达次数")
|
||||
checked_expression_count: int = Field(0, description="已检查的表达次数")
|
||||
total_expressions: int = Field(0, description="表达总数")
|
||||
action_types: list[dict[str, Any]] = Field(default_factory=list, description="动作类型分布")
|
||||
image_processed_count: int = Field(0, description="处理的图片数量")
|
||||
late_night_reply: Optional[dict[str, Any]] = Field(None, description="深夜还在回复")
|
||||
favorite_reply: Optional[dict[str, Any]] = Field(None, description="最喜欢的回复")
|
||||
|
||||
|
||||
class AchievementData(BaseModel):
|
||||
"""趣味成就数据"""
|
||||
|
||||
new_jargon_count: int = Field(0, description="新学到的黑话数量")
|
||||
sample_jargons: list[dict[str, Any]] = Field(default_factory=list, description="代表性黑话示例")
|
||||
total_messages: int = Field(0, description="总消息数")
|
||||
total_replies: int = Field(0, description="总回复数")
|
||||
|
||||
|
||||
class AnnualReportData(BaseModel):
|
||||
"""年度报告完整数据"""
|
||||
|
||||
year: int = Field(2025, description="报告年份")
|
||||
bot_name: str = Field("麦麦", description="Bot名称")
|
||||
generated_at: str = Field(..., description="报告生成时间")
|
||||
time_footprint: TimeFootprintData = Field(default_factory=lambda: TimeFootprintData.model_construct())
|
||||
social_network: SocialNetworkData = Field(default_factory=lambda: SocialNetworkData.model_construct())
|
||||
brain_power: BrainPowerData = Field(default_factory=lambda: BrainPowerData.model_construct())
|
||||
expression_vibe: ExpressionVibeData = Field(default_factory=lambda: ExpressionVibeData.model_construct())
|
||||
achievements: AchievementData = Field(default_factory=lambda: AchievementData.model_construct())
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
|
||||
def get_year_time_range(year: int = 2025) -> tuple[float, float]:
|
||||
"""获取指定年份的时间戳范围"""
|
||||
start = datetime(year, 1, 1, 0, 0, 0).timestamp()
|
||||
end = datetime(year, 12, 31, 23, 59, 59).timestamp()
|
||||
return start, end
|
||||
|
||||
|
||||
def get_year_datetime_range(year: int = 2025) -> tuple[datetime, datetime]:
|
||||
"""获取指定年份的 datetime 范围"""
|
||||
start = datetime(year, 1, 1, 0, 0, 0)
|
||||
end = datetime(year, 12, 31, 23, 59, 59)
|
||||
return start, end
|
||||
|
||||
|
||||
# ==================== 维度一:时光足迹 ====================
|
||||
|
||||
|
||||
async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
|
||||
"""获取时光足迹数据"""
|
||||
data = TimeFootprintData.model_construct()
|
||||
start_ts, end_ts = get_year_time_range(year)
|
||||
start_dt, end_dt = get_year_datetime_range(year)
|
||||
|
||||
try:
|
||||
# 1. 年度在线时长
|
||||
with get_db_session() as session:
|
||||
statement = select(OnlineTime).where(
|
||||
col(OnlineTime.start_timestamp) >= start_dt,
|
||||
col(OnlineTime.end_timestamp) <= end_dt,
|
||||
)
|
||||
online_records = session.exec(statement).all()
|
||||
total_seconds = 0
|
||||
for record in online_records:
|
||||
try:
|
||||
start = max(record.start_timestamp, start_dt)
|
||||
end = min(record.end_timestamp, end_dt)
|
||||
if end > start:
|
||||
total_seconds += (end - start).total_seconds()
|
||||
except Exception:
|
||||
continue
|
||||
data.total_online_hours = round(total_seconds / 3600, 2)
|
||||
|
||||
# 2. 初次相遇 - 年度第一条消息
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Messages)
|
||||
.where(
|
||||
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
)
|
||||
.order_by(col(Messages.timestamp).asc())
|
||||
.limit(1)
|
||||
)
|
||||
first_msg = session.exec(statement).first()
|
||||
if first_msg:
|
||||
data.first_message_time = first_msg.timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||
data.first_message_user = first_msg.user_nickname or first_msg.user_id or "未知用户"
|
||||
content = first_msg.processed_plain_text or first_msg.display_message or ""
|
||||
data.first_message_content = content[:50] + "..." if len(content) > 50 else content
|
||||
|
||||
# 3. 最忙碌的一天
|
||||
# 使用 SQLite 的 date 函数按日期分组
|
||||
day_expr = func.date(col(Messages.timestamp))
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(
|
||||
day_expr.label("day"),
|
||||
func.count().label("count"),
|
||||
)
|
||||
.where(
|
||||
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
)
|
||||
.group_by(day_expr)
|
||||
.order_by(func.count().desc())
|
||||
.limit(1)
|
||||
)
|
||||
busiest_result = session.exec(statement).all()
|
||||
if busiest_result:
|
||||
data.busiest_day = busiest_result[0][0]
|
||||
data.busiest_day_count = busiest_result[0][1] or 0
|
||||
|
||||
# 4. 昼夜节律 - 24小时活跃分布
|
||||
hour_expr = func.strftime("%H", col(Messages.timestamp))
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(
|
||||
hour_expr.label("hour"),
|
||||
func.count().label("count"),
|
||||
)
|
||||
.where(
|
||||
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
)
|
||||
.group_by(hour_expr)
|
||||
)
|
||||
hourly_rows = session.exec(statement).all()
|
||||
hourly_distribution = [0] * 24
|
||||
for row in hourly_rows:
|
||||
try:
|
||||
hour = int(row[0] or 0)
|
||||
if 0 <= hour < 24:
|
||||
hourly_distribution[hour] = row[1] or 0
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
data.hourly_distribution = hourly_distribution
|
||||
|
||||
# 5. 深夜食堂 (0-4点)
|
||||
data.midnight_chat_count = sum(hourly_distribution[0:5])
|
||||
|
||||
# 6. 判断是否夜猫子 (22点-4点活跃度 vs 6点-12点)
|
||||
night_activity = sum(hourly_distribution[22:24]) + sum(hourly_distribution[0:5])
|
||||
morning_activity = sum(hourly_distribution[6:13])
|
||||
data.is_night_owl = night_activity > morning_activity
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取时光足迹数据失败: {e}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# ==================== 维度二:社交网络 ====================
|
||||
|
||||
|
||||
async def get_social_network(year: int = 2025) -> SocialNetworkData:
|
||||
"""获取社交网络数据"""
|
||||
from src.config.config import global_config
|
||||
|
||||
data = SocialNetworkData.model_construct()
|
||||
start_ts, end_ts = get_year_time_range(year)
|
||||
|
||||
# 获取 bot 自身的 QQ 账号,用于过滤
|
||||
bot_qq = str(global_config.bot.qq_account or "")
|
||||
|
||||
try:
|
||||
# 1. 加入的群组总数
|
||||
with get_db_session() as session:
|
||||
statement = select(func.count(func.distinct(col(Messages.group_id)))).where(
|
||||
col(Messages.group_id).is_not(None),
|
||||
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
)
|
||||
data.total_groups = int(session.exec(statement).first() or 0)
|
||||
|
||||
# 2. 话痨群组 TOP3
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(
|
||||
col(Messages.group_id),
|
||||
func.max(col(Messages.group_name)).label("group_name"),
|
||||
func.count().label("count"),
|
||||
)
|
||||
.where(
|
||||
col(Messages.group_id).is_not(None),
|
||||
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
)
|
||||
.group_by(col(Messages.group_id))
|
||||
.order_by(func.count().desc())
|
||||
.limit(5)
|
||||
)
|
||||
top_groups_rows = session.exec(statement).all()
|
||||
data.top_groups = [
|
||||
{
|
||||
"group_id": row[0],
|
||||
"group_name": row[1] or "未知群组",
|
||||
"message_count": row[2] or 0,
|
||||
"is_webui": str(row[0]).startswith("webui_"),
|
||||
}
|
||||
for row in top_groups_rows
|
||||
]
|
||||
|
||||
# 3. 互动最多的用户 TOP5(过滤 bot 自身)
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(
|
||||
col(Messages.user_id),
|
||||
func.max(col(Messages.user_nickname)).label("user_nickname"),
|
||||
func.count().label("count"),
|
||||
)
|
||||
.where(
|
||||
col(Messages.user_id).is_not(None),
|
||||
col(Messages.user_id) != bot_qq,
|
||||
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
)
|
||||
.group_by(col(Messages.user_id))
|
||||
.order_by(func.count().desc())
|
||||
.limit(5)
|
||||
)
|
||||
top_users_rows = session.exec(statement).all()
|
||||
data.top_users = [
|
||||
{
|
||||
"user_id": row[0],
|
||||
"user_nickname": row[1] or "未知用户",
|
||||
"message_count": row[2] or 0,
|
||||
"is_webui": str(row[0]).startswith("webui_"),
|
||||
}
|
||||
for row in top_users_rows
|
||||
]
|
||||
|
||||
# 4. 被@次数
|
||||
with get_db_session() as session:
|
||||
statement = select(func.count()).where(
|
||||
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
col(Messages.is_at),
|
||||
)
|
||||
data.at_count = int(session.exec(statement).first() or 0)
|
||||
|
||||
# 5. 被提及次数
|
||||
with get_db_session() as session:
|
||||
statement = select(func.count()).where(
|
||||
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
col(Messages.is_mentioned),
|
||||
)
|
||||
data.mentioned_count = int(session.exec(statement).first() or 0)
|
||||
|
||||
# 6. 最长情陪伴的用户(过滤 bot 自身)
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(
|
||||
col(PersonInfo.user_id) != bot_qq,
|
||||
col(PersonInfo.first_known_time).is_not(None),
|
||||
col(PersonInfo.last_known_time).is_not(None),
|
||||
)
|
||||
persons = session.exec(statement).all()
|
||||
if persons:
|
||||
|
||||
def _companion_days(person: PersonInfo) -> float:
|
||||
if not person.first_known_time or not person.last_known_time:
|
||||
return 0.0
|
||||
return (person.last_known_time - person.first_known_time).total_seconds()
|
||||
|
||||
longest = max(persons, key=_companion_days)
|
||||
data.longest_companion_user = longest.person_name or longest.user_nickname or longest.user_id
|
||||
data.longest_companion_days = int(_companion_days(longest) / 86400)
|
||||
else:
|
||||
data.longest_companion_user = None
|
||||
data.longest_companion_days = 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取社交网络数据失败: {e}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# ==================== 维度三:最强大脑 ====================
|
||||
|
||||
|
||||
async def get_brain_power(year: int = 2025) -> BrainPowerData:
|
||||
"""获取最强大脑数据"""
|
||||
data = BrainPowerData.model_construct()
|
||||
start_dt, end_dt = get_year_datetime_range(year)
|
||||
start_ts, end_ts = get_year_time_range(year)
|
||||
|
||||
try:
|
||||
# 1. 年度消耗 Token 总量和总花费
|
||||
with get_db_session() as session:
|
||||
statement = select(
|
||||
func.sum(col(ModelUsage.total_tokens)).label("total_tokens"),
|
||||
func.sum(col(ModelUsage.cost)).label("total_cost"),
|
||||
).where(col(ModelUsage.timestamp) >= start_dt, col(ModelUsage.timestamp) <= end_dt)
|
||||
result = session.exec(statement).first()
|
||||
if result:
|
||||
data.total_tokens = int(result[0] or 0)
|
||||
data.total_cost = round(float(result[1] or 0), 4)
|
||||
|
||||
# 2. 最爱用的模型
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(ModelUsage)
|
||||
.where(col(ModelUsage.timestamp) >= start_dt, col(ModelUsage.timestamp) <= end_dt)
|
||||
.order_by(desc(col(ModelUsage.timestamp)))
|
||||
)
|
||||
records = session.exec(statement).all()
|
||||
|
||||
model_agg: dict[str, dict[str, float | int]] = {}
|
||||
for record in records:
|
||||
model_name = record.model_assign_name or record.model_name or "unknown"
|
||||
if model_name not in model_agg:
|
||||
model_agg[model_name] = {"count": 0, "tokens": 0, "cost": 0.0}
|
||||
bucket = model_agg[model_name]
|
||||
bucket["count"] = int(bucket["count"]) + 1
|
||||
bucket["tokens"] = int(bucket["tokens"]) + int(record.total_tokens or 0)
|
||||
bucket["cost"] = float(bucket["cost"]) + float(record.cost or 0.0)
|
||||
|
||||
model_results = sorted(
|
||||
model_agg.items(),
|
||||
key=lambda item: float(item[1]["count"]),
|
||||
reverse=True,
|
||||
)[:10]
|
||||
if model_results:
|
||||
data.favorite_model = model_results[0][0]
|
||||
data.favorite_model_count = int(model_results[0][1]["count"])
|
||||
data.model_distribution = [
|
||||
{
|
||||
"model": model_name,
|
||||
"count": int(bucket["count"]),
|
||||
"tokens": int(bucket["tokens"]),
|
||||
"cost": round(float(bucket["cost"]), 4),
|
||||
}
|
||||
for model_name, bucket in model_results
|
||||
]
|
||||
|
||||
# 3. 最昂贵的一次思考
|
||||
if records:
|
||||
expensive_record = max(records, key=lambda record: record.cost or 0.0)
|
||||
data.most_expensive_cost = round(expensive_record.cost or 0.0, 4)
|
||||
data.most_expensive_time = expensive_record.timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# 4. 烧钱大户 TOP3 (按用户,过滤 system)
|
||||
consumer_agg: dict[str, dict[str, float | int]] = {}
|
||||
for record in records:
|
||||
user_id = record.model_api_provider_name
|
||||
if not user_id or user_id == "system":
|
||||
continue
|
||||
if user_id not in consumer_agg:
|
||||
consumer_agg[user_id] = {"cost": 0.0, "tokens": 0}
|
||||
bucket = consumer_agg[user_id]
|
||||
bucket["cost"] = float(bucket["cost"]) + float(record.cost or 0.0)
|
||||
bucket["tokens"] = int(bucket["tokens"]) + int(record.total_tokens or 0)
|
||||
|
||||
data.top_token_consumers = [
|
||||
{
|
||||
"user_id": user_id,
|
||||
"cost": round(float(bucket["cost"]), 4),
|
||||
"tokens": int(bucket["tokens"]),
|
||||
}
|
||||
for user_id, bucket in sorted(
|
||||
consumer_agg.items(),
|
||||
key=lambda item: float(item[1]["cost"]),
|
||||
reverse=True,
|
||||
)[:3]
|
||||
]
|
||||
|
||||
# 5. 最喜欢的回复模型 TOP5(按模型的回复次数统计,只统计 replyer 调用)
|
||||
# 假设 replyer 调用有特定的 model_assign_name 格式或可以通过某种方式识别
|
||||
reply_model_agg: dict[str, int] = {}
|
||||
for record in records:
|
||||
model_assign_name = record.model_assign_name or ""
|
||||
if "replyer" not in model_assign_name and "回复" not in model_assign_name:
|
||||
continue
|
||||
model_name = model_assign_name or record.model_name or "unknown"
|
||||
reply_model_agg[model_name] = reply_model_agg.get(model_name, 0) + 1
|
||||
data.top_reply_models = [
|
||||
{"model": model_name, "count": count}
|
||||
for model_name, count in sorted(reply_model_agg.items(), key=lambda item: item[1], reverse=True)[:5]
|
||||
]
|
||||
|
||||
# 6. 高冷指数 (沉默率) - 基于 ActionRecords
|
||||
with get_db_session() as session:
|
||||
statement = select(func.count()).where(
|
||||
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
)
|
||||
total_actions = int(session.exec(statement).first() or 0)
|
||||
with get_db_session() as session:
|
||||
statement = select(func.count()).where(
|
||||
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
col(ActionRecord.action_name) == "no_reply",
|
||||
)
|
||||
no_reply_count = int(session.exec(statement).first() or 0)
|
||||
data.total_actions = total_actions
|
||||
data.no_reply_count = no_reply_count
|
||||
data.silence_rate = round(no_reply_count / total_actions * 100, 2) if total_actions > 0 else 0
|
||||
|
||||
# 6. 情绪波动 (兴趣值)
|
||||
data.avg_interest_value = 0.0
|
||||
data.max_interest_value = 0.0
|
||||
|
||||
# 找到最高兴趣值的时间
|
||||
if data.max_interest_value > 0:
|
||||
data.max_interest_time = None
|
||||
|
||||
# 7. 思考深度 (基于 action_reasoning 长度)
|
||||
with get_db_session() as session:
|
||||
statement = select(ActionRecord).where(
|
||||
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
col(ActionRecord.action_reasoning).is_not(None),
|
||||
col(ActionRecord.action_reasoning) != "",
|
||||
)
|
||||
reasoning_records = session.exec(statement).all()
|
||||
reasoning_lengths = []
|
||||
max_len = 0
|
||||
max_len_time = None
|
||||
for record in reasoning_records:
|
||||
if record.action_reasoning:
|
||||
length = len(record.action_reasoning)
|
||||
reasoning_lengths.append(length)
|
||||
if length > max_len:
|
||||
max_len = length
|
||||
max_len_time = record.timestamp
|
||||
|
||||
if reasoning_lengths:
|
||||
data.avg_reasoning_length = round(sum(reasoning_lengths) / len(reasoning_lengths), 1)
|
||||
data.max_reasoning_length = max_len
|
||||
if max_len_time:
|
||||
data.max_reasoning_time = max_len_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取最强大脑数据失败: {e}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# ==================== 维度四:个性与表达 ====================
|
||||
|
||||
|
||||
async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
||||
"""获取个性与表达数据"""
|
||||
from src.config.config import global_config
|
||||
|
||||
data = ExpressionVibeData.model_construct()
|
||||
start_ts, end_ts = get_year_time_range(year)
|
||||
|
||||
# 获取 bot 自身的 QQ 账号,用于筛选 bot 发送的消息
|
||||
bot_qq = str(global_config.bot.qq_account or "")
|
||||
|
||||
try:
|
||||
# 1. 表情包之王 - 使用次数最多的表情包
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(col(Images.is_registered)).order_by(desc(col(Images.query_count))).limit(5)
|
||||
top_emojis = session.exec(statement).all()
|
||||
if top_emojis:
|
||||
data.top_emoji = {
|
||||
"id": top_emojis[0].id,
|
||||
"path": top_emojis[0].full_path,
|
||||
"description": top_emojis[0].description,
|
||||
"usage_count": top_emojis[0].query_count,
|
||||
"hash": top_emojis[0].image_hash,
|
||||
}
|
||||
data.top_emojis = [
|
||||
{
|
||||
"id": e.id,
|
||||
"path": e.full_path,
|
||||
"description": e.description,
|
||||
"usage_count": e.query_count,
|
||||
"hash": e.image_hash,
|
||||
}
|
||||
for e in top_emojis
|
||||
]
|
||||
|
||||
# 2. 百变麦麦 - 最常用的表达风格
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Expression.style, func.sum(col(Expression.count)).label("total_count"))
|
||||
.where(
|
||||
col(Expression.last_active_time) >= datetime.fromtimestamp(start_ts),
|
||||
col(Expression.last_active_time) <= datetime.fromtimestamp(end_ts),
|
||||
)
|
||||
.group_by(Expression.style)
|
||||
.order_by(func.sum(col(Expression.count)).desc())
|
||||
.limit(5)
|
||||
)
|
||||
expression_rows = session.exec(statement).all()
|
||||
data.top_expressions = [{"style": row[0], "count": row[1] or 0} for row in expression_rows]
|
||||
|
||||
# 3. 被拒绝的表达
|
||||
data.rejected_expression_count = 0
|
||||
|
||||
# 4. 已检查的表达
|
||||
data.checked_expression_count = 0
|
||||
|
||||
# 5. 表达总数
|
||||
with get_db_session() as session:
|
||||
statement = select(func.count()).where(
|
||||
col(Expression.last_active_time) >= datetime.fromtimestamp(start_ts),
|
||||
col(Expression.last_active_time) <= datetime.fromtimestamp(end_ts),
|
||||
)
|
||||
data.total_expressions = int(session.exec(statement).first() or 0)
|
||||
|
||||
# 6. 动作类型分布 (过滤无意义的动作)
|
||||
# 过滤掉: no_reply_until_call, make_question, no_action, wait, complete_talk, listening, block_and_ignore
|
||||
excluded_actions = [
|
||||
"reply",
|
||||
"no_reply",
|
||||
"no_reply_until_call",
|
||||
"make_question",
|
||||
"no_action",
|
||||
"wait",
|
||||
"complete_talk",
|
||||
"listening",
|
||||
"block_and_ignore",
|
||||
]
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(ActionRecord.action_name, func.count().label("count"))
|
||||
.where(
|
||||
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
col(ActionRecord.action_name).not_in(excluded_actions),
|
||||
)
|
||||
.group_by(ActionRecord.action_name)
|
||||
.order_by(func.count().desc())
|
||||
.limit(10)
|
||||
)
|
||||
action_rows = session.exec(statement).all()
|
||||
data.action_types = [{"action": row[0], "count": row[1]} for row in action_rows]
|
||||
|
||||
# 7. 处理的图片数量
|
||||
with get_db_session() as session:
|
||||
statement = select(func.count()).where(
|
||||
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
col(Messages.is_picture),
|
||||
)
|
||||
data.image_processed_count = int(session.exec(statement).first() or 0)
|
||||
|
||||
# 8. 深夜还在回复 (0-6点最晚的10条消息中随机抽取一条)
|
||||
import random
|
||||
import re
|
||||
|
||||
def clean_message_content(content: str) -> str:
|
||||
"""清理消息内容,移除回复引用等标记"""
|
||||
if not content:
|
||||
return ""
|
||||
# 移除 [回复<xxx:xxx> 的消息:...] 格式的引用
|
||||
content = re.sub(r"\[回复<[^>]+>\s*的消息[::][^\]]*\]", "", content)
|
||||
# 移除 [图片] [表情] 等标记
|
||||
content = re.sub(r"\[(图片|表情|语音|视频|文件)\]", "", content)
|
||||
# 移除多余的空白
|
||||
content = re.sub(r"\s+", " ", content).strip()
|
||||
return content
|
||||
|
||||
# 使用 user_id 判断是否是 bot 发送的消息
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Messages)
|
||||
.where(
|
||||
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
col(Messages.user_id) == bot_qq,
|
||||
)
|
||||
.order_by(desc(col(Messages.timestamp)))
|
||||
.limit(200)
|
||||
)
|
||||
late_night_messages = session.exec(statement).all()
|
||||
# 筛选出0-6点的消息
|
||||
late_night_filtered = []
|
||||
for msg in late_night_messages:
|
||||
msg_dt = msg.timestamp
|
||||
hour = msg_dt.hour
|
||||
if 0 <= hour < 6: # 0点到6点
|
||||
raw_content = msg.processed_plain_text or msg.display_message or ""
|
||||
cleaned_content = clean_message_content(raw_content)
|
||||
# 只保留有意义的内容
|
||||
if cleaned_content and len(cleaned_content) > 2:
|
||||
late_night_filtered.append(
|
||||
{
|
||||
"time": msg_dt.timestamp(),
|
||||
"hour": hour,
|
||||
"minute": msg_dt.minute,
|
||||
"content": cleaned_content,
|
||||
"datetime_str": msg_dt.strftime("%H:%M"),
|
||||
}
|
||||
)
|
||||
if len(late_night_filtered) >= 10:
|
||||
break
|
||||
|
||||
if late_night_filtered:
|
||||
selected = random.choice(late_night_filtered)
|
||||
content = selected["content"][:50] + "..." if len(selected["content"]) > 50 else selected["content"]
|
||||
data.late_night_reply = {
|
||||
"time": selected["datetime_str"],
|
||||
"content": content,
|
||||
}
|
||||
|
||||
# 9. 最喜欢的回复(按 action_data 统计回复内容出现次数)
|
||||
from collections import Counter
|
||||
import json as json_lib
|
||||
|
||||
with get_db_session() as session:
|
||||
statement = select(ActionRecord).where(
|
||||
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
col(ActionRecord.action_name) == "reply",
|
||||
col(ActionRecord.action_data).is_not(None),
|
||||
col(ActionRecord.action_data) != "",
|
||||
)
|
||||
reply_records = session.exec(statement).all()
|
||||
|
||||
reply_contents = []
|
||||
for record in reply_records:
|
||||
try:
|
||||
action_data = record.action_data
|
||||
if action_data:
|
||||
content = None
|
||||
# 尝试解析 JSON 格式
|
||||
try:
|
||||
parsed = json_lib.loads(action_data)
|
||||
if isinstance(parsed, dict):
|
||||
# 优先使用 reply_text,其次使用 content
|
||||
content = parsed.get("reply_text") or parsed.get("content")
|
||||
elif isinstance(parsed, str):
|
||||
content = parsed
|
||||
except (json_lib.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# 如果 JSON 解析失败,尝试解析 Python 字典字符串格式
|
||||
# 例如: "{'reply_text': '墨白灵不知道哦'}"
|
||||
if content is None:
|
||||
import ast
|
||||
|
||||
try:
|
||||
parsed = ast.literal_eval(action_data)
|
||||
if isinstance(parsed, dict):
|
||||
content = parsed.get("reply_text") or parsed.get("content")
|
||||
elif isinstance(parsed, str):
|
||||
content = parsed
|
||||
except (ValueError, SyntaxError):
|
||||
# 无法解析,使用原始字符串
|
||||
content = action_data
|
||||
|
||||
# 只统计有意义的回复(长度大于2)
|
||||
if content and len(content) > 2:
|
||||
reply_contents.append(content)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if reply_contents:
|
||||
content_counter = Counter(reply_contents)
|
||||
most_common = content_counter.most_common(1)
|
||||
if most_common:
|
||||
fav_content, fav_count = most_common[0]
|
||||
# 截断过长的内容
|
||||
display_content = fav_content[:50] + "..." if len(fav_content) > 50 else fav_content
|
||||
data.favorite_reply = {
|
||||
"content": display_content,
|
||||
"count": fav_count,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取个性与表达数据失败: {e}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# ==================== 维度五:趣味成就 ====================
|
||||
|
||||
|
||||
async def get_achievements(year: int = 2025) -> AchievementData:
|
||||
"""获取趣味成就数据"""
|
||||
data = AchievementData.model_construct()
|
||||
start_ts, end_ts = get_year_time_range(year)
|
||||
|
||||
try:
|
||||
# 1. 新学到的黑话数量
|
||||
# Jargon 表没有时间字段,统计全部已确认的黑话
|
||||
with get_db_session() as session:
|
||||
statement = select(func.count()).where(col(Jargon.is_jargon))
|
||||
data.new_jargon_count = int(session.exec(statement).first() or 0)
|
||||
|
||||
# 2. 代表性黑话示例
|
||||
with get_db_session() as session:
|
||||
statement = select(Jargon).where(col(Jargon.is_jargon)).order_by(desc(col(Jargon.count))).limit(5)
|
||||
jargon_samples = session.exec(statement).all()
|
||||
data.sample_jargons = [
|
||||
{
|
||||
"content": j.content,
|
||||
"meaning": j.meaning,
|
||||
"count": j.count,
|
||||
}
|
||||
for j in jargon_samples
|
||||
]
|
||||
|
||||
# 3. 总消息数
|
||||
with get_db_session() as session:
|
||||
statement = select(func.count()).where(
|
||||
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
)
|
||||
data.total_messages = int(session.exec(statement).first() or 0)
|
||||
|
||||
# 4. 总回复数 (有 reply_to 的消息)
|
||||
with get_db_session() as session:
|
||||
statement = select(func.count()).where(
|
||||
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||
col(Messages.reply_to).is_not(None),
|
||||
)
|
||||
data.total_replies = int(session.exec(statement).first() or 0)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取趣味成就数据失败: {e}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# ==================== API 路由 ====================
|
||||
|
||||
|
||||
@router.get("/full", response_model=AnnualReportData)
|
||||
async def get_full_annual_report(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取完整年度报告数据
|
||||
|
||||
Args:
|
||||
year: 报告年份,默认2025
|
||||
|
||||
Returns:
|
||||
完整的年度报告数据
|
||||
"""
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
logger.info(f"开始生成 {year} 年度报告...")
|
||||
|
||||
# 获取 bot 名称
|
||||
bot_name = global_config.bot.nickname or "麦麦"
|
||||
|
||||
# 并行获取各维度数据
|
||||
time_footprint = await get_time_footprint(year)
|
||||
social_network = await get_social_network(year)
|
||||
brain_power = await get_brain_power(year)
|
||||
expression_vibe = await get_expression_vibe(year)
|
||||
achievements = await get_achievements(year)
|
||||
|
||||
report = AnnualReportData(
|
||||
year=year,
|
||||
bot_name=bot_name,
|
||||
generated_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
time_footprint=time_footprint,
|
||||
social_network=social_network,
|
||||
brain_power=brain_power,
|
||||
expression_vibe=expression_vibe,
|
||||
achievements=achievements,
|
||||
)
|
||||
|
||||
logger.info(f"{year} 年度报告生成完成")
|
||||
return report
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成年度报告失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"生成年度报告失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/time-footprint", response_model=TimeFootprintData)
|
||||
async def get_time_footprint_api(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""获取时光足迹数据"""
|
||||
try:
|
||||
return await get_time_footprint(year)
|
||||
except Exception as e:
|
||||
logger.error(f"获取时光足迹数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/social-network", response_model=SocialNetworkData)
|
||||
async def get_social_network_api(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""获取社交网络数据"""
|
||||
try:
|
||||
return await get_social_network(year)
|
||||
except Exception as e:
|
||||
logger.error(f"获取社交网络数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/brain-power", response_model=BrainPowerData)
|
||||
async def get_brain_power_api(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""获取最强大脑数据"""
|
||||
try:
|
||||
return await get_brain_power(year)
|
||||
except Exception as e:
|
||||
logger.error(f"获取最强大脑数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/expression-vibe", response_model=ExpressionVibeData)
|
||||
async def get_expression_vibe_api(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""获取个性与表达数据"""
|
||||
try:
|
||||
return await get_expression_vibe(year)
|
||||
except Exception as e:
|
||||
logger.error(f"获取个性与表达数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/achievements", response_model=AchievementData)
|
||||
async def get_achievements_api(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""获取趣味成就数据"""
|
||||
try:
|
||||
return await get_achievements(year)
|
||||
except Exception as e:
|
||||
logger.error(f"获取趣味成就数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
@@ -1,801 +0,0 @@
|
||||
"""本地聊天室路由 - WebUI 与麦麦直接对话
|
||||
|
||||
支持两种模式:
|
||||
1. WebUI 模式:使用 WebUI 平台独立身份聊天
|
||||
2. 虚拟身份模式:使用真实平台用户的身份,在虚拟群聊中与麦麦对话
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Cookie, Depends, Header, Query, WebSocket, WebSocketDisconnect
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import case, desc, func
|
||||
from sqlmodel import col, select, delete
|
||||
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Messages, PersonInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
|
||||
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# WebUI 聊天的虚拟群组 ID
|
||||
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
|
||||
WEBUI_CHAT_PLATFORM = "webui"
|
||||
|
||||
# 虚拟身份模式的群 ID 前缀
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
|
||||
# 固定的 WebUI 用户 ID 前缀
|
||||
WEBUI_USER_ID_PREFIX = "webui_user_"
|
||||
|
||||
|
||||
class VirtualIdentityConfig(BaseModel):
|
||||
"""虚拟身份配置"""
|
||||
|
||||
enabled: bool = False # 是否启用虚拟身份模式
|
||||
platform: Optional[str] = None # 目标平台(如 qq, discord 等)
|
||||
person_id: Optional[str] = None # PersonInfo 的 person_id
|
||||
user_id: Optional[str] = None # 原始平台用户 ID
|
||||
user_nickname: Optional[str] = None # 用户昵称
|
||||
group_id: Optional[str] = None # 虚拟群 ID(自动生成或用户指定)
|
||||
group_name: Optional[str] = None # 虚拟群名(用户自定义)
|
||||
|
||||
|
||||
class ChatHistoryMessage(BaseModel):
|
||||
"""聊天历史消息"""
|
||||
|
||||
id: str
|
||||
type: str # 'user' | 'bot' | 'system'
|
||||
content: str
|
||||
timestamp: float
|
||||
sender_name: str
|
||||
sender_id: Optional[str] = None
|
||||
is_bot: bool = False
|
||||
|
||||
|
||||
class ChatHistoryManager:
|
||||
"""聊天历史管理器 - 使用 SQLite 数据库存储"""
|
||||
|
||||
def __init__(self, max_messages: int = 200):
|
||||
self.max_messages = max_messages
|
||||
|
||||
def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""将数据库消息转换为前端格式
|
||||
|
||||
Args:
|
||||
msg: 数据库消息对象
|
||||
group_id: 群 ID,用于判断是否是虚拟群
|
||||
"""
|
||||
# 判断是否是机器人消息
|
||||
user_id = msg.user_id or ""
|
||||
|
||||
# 对于虚拟群,通过比较机器人 QQ 账号来判断
|
||||
# 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头
|
||||
if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
|
||||
# 虚拟群:user_id 等于机器人 QQ 账号的是机器人消息
|
||||
bot_qq = str(global_config.bot.qq_account)
|
||||
is_bot = user_id == bot_qq
|
||||
else:
|
||||
# 普通 WebUI 群:不以 webui_ 开头的是机器人消息
|
||||
is_bot = not user_id.startswith("webui_") and not user_id.startswith(WEBUI_USER_ID_PREFIX)
|
||||
|
||||
return {
|
||||
"id": msg.message_id,
|
||||
"type": "bot" if is_bot else "user",
|
||||
"content": msg.processed_plain_text or msg.display_message or "",
|
||||
"timestamp": msg.timestamp.timestamp(),
|
||||
"sender_name": msg.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
|
||||
"sender_id": "bot" if is_bot else user_id,
|
||||
"is_bot": is_bot,
|
||||
}
|
||||
|
||||
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""从数据库获取最近的历史记录
|
||||
|
||||
Args:
|
||||
limit: 获取的消息数量
|
||||
group_id: 群 ID,默认为 WEBUI_CHAT_GROUP_ID
|
||||
"""
|
||||
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
||||
try:
|
||||
# 查询指定群的消息,按时间排序
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Messages)
|
||||
.where(col(Messages.group_id) == target_group_id)
|
||||
.order_by(desc(col(Messages.timestamp)))
|
||||
.limit(limit)
|
||||
)
|
||||
messages = session.exec(statement).all()
|
||||
|
||||
# 转换为列表并反转(使最旧的消息在前)
|
||||
# 传递 group_id 以便正确判断虚拟群中的机器人消息
|
||||
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
|
||||
result.reverse()
|
||||
|
||||
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载聊天记录失败: {e}")
|
||||
return []
|
||||
|
||||
def clear_history(self, group_id: Optional[str] = None) -> int:
|
||||
"""清空聊天历史记录
|
||||
|
||||
Args:
|
||||
group_id: 群 ID,默认清空 WebUI 默认聊天室
|
||||
"""
|
||||
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = delete(Messages).where(col(Messages.group_id) == target_group_id)
|
||||
result = session.exec(statement)
|
||||
deleted = result.rowcount or 0
|
||||
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
|
||||
return deleted
|
||||
except Exception as e:
|
||||
logger.error(f"清空聊天记录失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
# 全局聊天历史管理器
|
||||
chat_history = ChatHistoryManager()
|
||||
|
||||
|
||||
# 存储 WebSocket 连接
|
||||
class ChatConnectionManager:
|
||||
"""聊天连接管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.user_sessions: Dict[str, str] = {} # user_id -> session_id 映射
|
||||
|
||||
async def connect(self, websocket: WebSocket, session_id: str, user_id: str):
|
||||
await websocket.accept()
|
||||
self.active_connections[session_id] = websocket
|
||||
self.user_sessions[user_id] = session_id
|
||||
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
|
||||
|
||||
def disconnect(self, session_id: str, user_id: str):
|
||||
if session_id in self.active_connections:
|
||||
del self.active_connections[session_id]
|
||||
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
|
||||
del self.user_sessions[user_id]
|
||||
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
||||
|
||||
async def send_message(self, session_id: str, message: dict[str, Any]):
|
||||
if session_id in self.active_connections:
|
||||
try:
|
||||
await self.active_connections[session_id].send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
|
||||
async def broadcast(self, message: dict[str, Any]):
|
||||
"""广播消息给所有连接"""
|
||||
for session_id in list(self.active_connections.keys()):
|
||||
await self.send_message(session_id, message)
|
||||
|
||||
|
||||
chat_manager = ChatConnectionManager()
|
||||
|
||||
|
||||
def create_message_data(
|
||||
content: str,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
message_id: Optional[str] = None,
|
||||
is_at_bot: bool = True,
|
||||
virtual_config: Optional[VirtualIdentityConfig] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""创建符合麦麦消息格式的消息数据
|
||||
|
||||
Args:
|
||||
content: 消息内容
|
||||
user_id: 用户 ID
|
||||
user_name: 用户昵称
|
||||
message_id: 消息 ID(可选,自动生成)
|
||||
is_at_bot: 是否 @ 机器人
|
||||
virtual_config: 虚拟身份配置(可选,启用后使用真实平台身份)
|
||||
"""
|
||||
if message_id is None:
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
# 确定使用的平台、群信息和用户信息
|
||||
if virtual_config and virtual_config.enabled:
|
||||
# 虚拟身份模式:使用真实平台身份
|
||||
platform = virtual_config.platform or WEBUI_CHAT_PLATFORM
|
||||
group_id = virtual_config.group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{uuid.uuid4().hex[:8]}"
|
||||
group_name = virtual_config.group_name or "WebUI虚拟群聊"
|
||||
actual_user_id = virtual_config.user_id or user_id
|
||||
actual_user_name = virtual_config.user_nickname or user_name
|
||||
else:
|
||||
# 标准 WebUI 模式
|
||||
platform = WEBUI_CHAT_PLATFORM
|
||||
group_id = WEBUI_CHAT_GROUP_ID
|
||||
group_name = "WebUI本地聊天室"
|
||||
actual_user_id = user_id
|
||||
actual_user_name = user_name
|
||||
|
||||
return {
|
||||
"message_info": {
|
||||
"platform": platform,
|
||||
"message_id": message_id,
|
||||
"time": time.time(),
|
||||
"group_info": {
|
||||
"group_id": group_id,
|
||||
"group_name": group_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"user_info": {
|
||||
"user_id": actual_user_id,
|
||||
"user_nickname": actual_user_name,
|
||||
"user_cardname": actual_user_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"additional_config": {
|
||||
"at_bot": is_at_bot,
|
||||
},
|
||||
},
|
||||
"message_segment": {
|
||||
"type": "seglist",
|
||||
"data": [
|
||||
{
|
||||
"type": "text",
|
||||
"data": content,
|
||||
},
|
||||
{
|
||||
"type": "mention_bot",
|
||||
"data": "1.0",
|
||||
},
|
||||
],
|
||||
},
|
||||
"raw_message": content,
|
||||
"processed_plain_text": content,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_chat_history(
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
|
||||
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""获取聊天历史记录
|
||||
|
||||
所有 WebUI 用户共享同一个聊天室,因此返回所有历史记录
|
||||
如果指定了 group_id,则获取该虚拟群的历史记录
|
||||
"""
|
||||
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
||||
history = chat_history.get_history(limit, target_group_id)
|
||||
return {
|
||||
"success": True,
|
||||
"messages": history,
|
||||
"total": len(history),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/platforms")
|
||||
async def get_available_platforms(_auth: bool = Depends(require_auth)):
|
||||
"""获取可用平台列表
|
||||
|
||||
从 PersonInfo 表中获取所有已知的平台
|
||||
"""
|
||||
try:
|
||||
# 查询所有不同的平台
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(PersonInfo.platform, func.count().label("count"))
|
||||
.group_by(PersonInfo.platform)
|
||||
.order_by(func.count().desc())
|
||||
)
|
||||
platforms = session.exec(statement).all()
|
||||
|
||||
result = []
|
||||
for platform, count in platforms:
|
||||
if platform:
|
||||
result.append({"platform": platform, "count": count})
|
||||
|
||||
return {"success": True, "platforms": result}
|
||||
except Exception as e:
|
||||
logger.error(f"获取平台列表失败: {e}")
|
||||
return {"success": False, "error": str(e), "platforms": []}
|
||||
|
||||
|
||||
@router.get("/persons")
|
||||
async def get_persons_by_platform(
|
||||
platform: str = Query(..., description="平台名称"),
|
||||
search: Optional[str] = Query(default=None, description="搜索关键词"),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""获取指定平台的用户列表
|
||||
|
||||
Args:
|
||||
platform: 平台名称(如 qq, discord 等)
|
||||
search: 搜索关键词(匹配昵称、用户名、user_id)
|
||||
limit: 返回数量限制
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
statement = select(PersonInfo).where(col(PersonInfo.platform) == platform)
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
statement = statement.where(
|
||||
(col(PersonInfo.person_name).contains(search))
|
||||
| (col(PersonInfo.user_nickname).contains(search))
|
||||
| (col(PersonInfo.user_id).contains(search))
|
||||
)
|
||||
|
||||
# 按最后交互时间排序,优先显示活跃用户
|
||||
statement = statement.order_by(
|
||||
case((col(PersonInfo.last_known_time).is_(None), 1), else_=0),
|
||||
col(PersonInfo.last_known_time).desc(),
|
||||
).limit(limit)
|
||||
|
||||
with get_db_session() as session:
|
||||
persons = session.exec(statement).all()
|
||||
|
||||
result = []
|
||||
for person in persons:
|
||||
result.append(
|
||||
{
|
||||
"person_id": person.person_id,
|
||||
"user_id": person.user_id,
|
||||
"person_name": person.person_name,
|
||||
"nickname": person.user_nickname,
|
||||
"is_known": person.is_known,
|
||||
"platform": person.platform,
|
||||
"display_name": person.person_name or person.user_nickname or person.user_id,
|
||||
}
|
||||
)
|
||||
|
||||
return {"success": True, "persons": result, "total": len(result)}
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户列表失败: {e}")
|
||||
return {"success": False, "error": str(e), "persons": []}
|
||||
|
||||
|
||||
@router.delete("/history")
|
||||
async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)):
|
||||
"""清空聊天历史记录
|
||||
|
||||
Args:
|
||||
group_id: 可选,指定要清空的群 ID,默认清空 WebUI 默认聊天室
|
||||
"""
|
||||
deleted = chat_history.clear_history(group_id)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"已清空 {deleted} 条聊天记录",
|
||||
}
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_chat(
|
||||
websocket: WebSocket,
|
||||
user_id: Optional[str] = Query(default=None),
|
||||
user_name: Optional[str] = Query(default="WebUI用户"),
|
||||
platform: Optional[str] = Query(default=None),
|
||||
person_id: Optional[str] = Query(default=None),
|
||||
group_name: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
|
||||
token: Optional[str] = Query(default=None), # 认证 token
|
||||
):
|
||||
"""WebSocket 聊天端点
|
||||
|
||||
Args:
|
||||
user_id: 用户唯一标识(由前端生成并持久化)
|
||||
user_name: 用户显示昵称(可修改)
|
||||
platform: 虚拟身份模式的平台(可选)
|
||||
person_id: 虚拟身份模式的用户 person_id(可选)
|
||||
group_name: 虚拟身份模式的群名(可选)
|
||||
group_id: 虚拟身份模式的群 ID(可选,由前端生成并持久化)
|
||||
token: 认证 token(可选,也可从 Cookie 获取)
|
||||
|
||||
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
|
||||
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/api/chat/ws?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
# 生成会话 ID(每次连接都是新的)
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# 如果没有提供 user_id,生成一个新的
|
||||
if not user_id:
|
||||
user_id = f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
|
||||
elif not user_id.startswith(WEBUI_USER_ID_PREFIX):
|
||||
# 确保 user_id 有正确的前缀
|
||||
user_id = f"{WEBUI_USER_ID_PREFIX}{user_id}"
|
||||
|
||||
# 当前会话的虚拟身份配置(可通过消息动态更新)
|
||||
current_virtual_config: Optional[VirtualIdentityConfig] = None
|
||||
|
||||
# 如果 URL 参数中提供了虚拟身份信息,自动配置
|
||||
if platform and person_id:
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||
person = session.exec(statement).first()
|
||||
if person:
|
||||
# 使用前端传递的 group_id,如果没有则生成一个稳定的
|
||||
virtual_group_id = group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{platform}_{person.user_id}"
|
||||
current_virtual_config = VirtualIdentityConfig(
|
||||
enabled=True,
|
||||
platform=person.platform,
|
||||
person_id=person.person_id,
|
||||
user_id=person.user_id,
|
||||
user_nickname=person.person_name or person.user_nickname or person.user_id,
|
||||
group_id=virtual_group_id,
|
||||
group_name=group_name or "WebUI虚拟群聊",
|
||||
)
|
||||
logger.info(
|
||||
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
|
||||
|
||||
await chat_manager.connect(websocket, session_id, user_id)
|
||||
|
||||
try:
|
||||
# 构建会话信息
|
||||
session_info_data: dict[str, Any] = {
|
||||
"type": "session_info",
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
"user_name": user_name,
|
||||
"bot_name": global_config.bot.nickname,
|
||||
}
|
||||
|
||||
# 如果有虚拟身份配置,添加到会话信息中
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
session_info_data["virtual_mode"] = True
|
||||
session_info_data["group_id"] = current_virtual_config.group_id
|
||||
session_info_data["virtual_identity"] = {
|
||||
"platform": current_virtual_config.platform,
|
||||
"user_id": current_virtual_config.user_id,
|
||||
"user_nickname": current_virtual_config.user_nickname,
|
||||
"group_name": current_virtual_config.group_name,
|
||||
}
|
||||
|
||||
# 发送会话信息(包含用户 ID,前端需要保存)
|
||||
await chat_manager.send_message(session_id, session_info_data)
|
||||
|
||||
# 发送历史记录(根据模式选择不同的群)
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
history = chat_history.get_history(50, current_virtual_config.group_id)
|
||||
else:
|
||||
history = chat_history.get_history(50)
|
||||
if history:
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": history,
|
||||
},
|
||||
)
|
||||
|
||||
# 发送欢迎消息(不保存到历史)
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
welcome_msg = f"已以 {current_virtual_config.user_nickname} 的身份连接到「{current_virtual_config.group_name}」,开始与 {global_config.bot.nickname} 对话吧!"
|
||||
else:
|
||||
welcome_msg = f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!"
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": welcome_msg,
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
|
||||
if data.get("type") == "message":
|
||||
content = data.get("content", "").strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
# 用户可以更新昵称
|
||||
current_user_name = data.get("user_name", user_name)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
timestamp = time.time()
|
||||
|
||||
# 确定发送者信息(根据是否使用虚拟身份)
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
sender_name = current_virtual_config.user_nickname or current_user_name
|
||||
sender_user_id = current_virtual_config.user_id or user_id
|
||||
else:
|
||||
sender_name = current_user_name
|
||||
sender_user_id = user_id
|
||||
|
||||
# 广播用户消息给所有连接(包括发送者)
|
||||
# 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "user_message",
|
||||
"content": content,
|
||||
"message_id": message_id,
|
||||
"timestamp": timestamp,
|
||||
"sender": {
|
||||
"name": sender_name,
|
||||
"user_id": sender_user_id,
|
||||
"is_bot": False,
|
||||
},
|
||||
"virtual_mode": current_virtual_config.enabled if current_virtual_config else False,
|
||||
}
|
||||
)
|
||||
|
||||
# 创建麦麦消息格式
|
||||
message_data = create_message_data(
|
||||
content=content,
|
||||
user_id=user_id,
|
||||
user_name=current_user_name,
|
||||
message_id=message_id,
|
||||
is_at_bot=True,
|
||||
virtual_config=current_virtual_config,
|
||||
)
|
||||
|
||||
try:
|
||||
# 显示正在输入状态
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "typing",
|
||||
"is_typing": True,
|
||||
}
|
||||
)
|
||||
|
||||
# 调用麦麦的消息处理
|
||||
await chat_bot.message_process(message_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时出错: {e}")
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": f"处理消息时出错: {str(e)}",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
finally:
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "typing",
|
||||
"is_typing": False,
|
||||
}
|
||||
)
|
||||
|
||||
elif data.get("type") == "ping":
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "pong",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
elif data.get("type") == "update_nickname":
|
||||
# 允许用户更新昵称
|
||||
if new_name := data.get("user_name", "").strip():
|
||||
current_user_name = new_name
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "nickname_updated",
|
||||
"user_name": current_user_name,
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
elif data.get("type") == "set_virtual_identity":
|
||||
# 设置或更新虚拟身份配置
|
||||
virtual_data = data.get("config", {})
|
||||
if virtual_data.get("enabled"):
|
||||
# 验证必要字段
|
||||
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": "虚拟身份配置缺少必要字段: platform 和 person_id",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
# 获取用户信息
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(PersonInfo)
|
||||
.where(col(PersonInfo.person_id) == virtual_data.get("person_id"))
|
||||
.limit(1)
|
||||
)
|
||||
person = session.exec(statement).first()
|
||||
if not person:
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": f"找不到用户: {virtual_data.get('person_id')}",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
# 生成虚拟群 ID
|
||||
custom_group_id = virtual_data.get("group_id")
|
||||
if custom_group_id:
|
||||
group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
|
||||
else:
|
||||
group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_id[:8]}"
|
||||
|
||||
current_virtual_config = VirtualIdentityConfig(
|
||||
enabled=True,
|
||||
platform=person.platform,
|
||||
person_id=person.person_id,
|
||||
user_id=person.user_id,
|
||||
user_nickname=person.person_name or person.user_nickname or person.user_id,
|
||||
group_id=group_id,
|
||||
group_name=virtual_data.get("group_name", "WebUI虚拟群聊"),
|
||||
)
|
||||
|
||||
# 发送虚拟身份已激活的消息
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "virtual_identity_set",
|
||||
"config": {
|
||||
"enabled": True,
|
||||
"platform": current_virtual_config.platform,
|
||||
"user_id": current_virtual_config.user_id,
|
||||
"user_nickname": current_virtual_config.user_nickname,
|
||||
"group_id": current_virtual_config.group_id,
|
||||
"group_name": current_virtual_config.group_name,
|
||||
},
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
# 加载虚拟群的历史记录
|
||||
virtual_history = chat_history.get_history(50, current_virtual_config.group_id)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": virtual_history,
|
||||
"group_id": current_virtual_config.group_id,
|
||||
},
|
||||
)
|
||||
|
||||
# 发送系统消息
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": f"已切换到虚拟身份模式:以 {current_virtual_config.user_nickname} 的身份在「{current_virtual_config.group_name}」与 {global_config.bot.nickname} 对话",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"设置虚拟身份失败: {e}")
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": f"设置虚拟身份失败: {str(e)}",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
else:
|
||||
# 禁用虚拟身份模式
|
||||
current_virtual_config = None
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "virtual_identity_set",
|
||||
"config": {"enabled": False},
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
# 重新加载默认聊天室历史
|
||||
default_history = chat_history.get_history(50, WEBUI_CHAT_GROUP_ID)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": default_history,
|
||||
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||
},
|
||||
)
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": "已切换回 WebUI 独立用户模式",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开: session={session_id}, user={user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 错误: {e}")
|
||||
finally:
|
||||
chat_manager.disconnect(session_id, user_id)
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_chat_info(_auth: bool = Depends(require_auth)):
|
||||
"""获取聊天室信息"""
|
||||
return {
|
||||
"bot_name": global_config.bot.nickname,
|
||||
"platform": WEBUI_CHAT_PLATFORM,
|
||||
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||
"active_sessions": len(chat_manager.active_connections),
|
||||
}
|
||||
|
||||
|
||||
def get_webui_chat_broadcaster() -> tuple[ChatConnectionManager, str]:
|
||||
"""获取 WebUI 聊天广播器,供外部模块使用
|
||||
|
||||
Returns:
|
||||
(chat_manager, WEBUI_CHAT_PLATFORM) 元组
|
||||
"""
|
||||
return (chat_manager, WEBUI_CHAT_PLATFORM)
|
||||
18
src/webui/routers/chat/__init__.py
Normal file
18
src/webui/routers/chat/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .routes import router
|
||||
from .support import ChatConnectionManager, WEBUI_CHAT_PLATFORM, chat_manager
|
||||
|
||||
|
||||
def get_webui_chat_broadcaster() -> tuple[ChatConnectionManager, str]:
|
||||
"""获取 WebUI 聊天广播器,供外部模块使用。"""
|
||||
return chat_manager, WEBUI_CHAT_PLATFORM
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ChatConnectionManager",
|
||||
"WEBUI_CHAT_PLATFORM",
|
||||
"chat_manager",
|
||||
"get_webui_chat_broadcaster",
|
||||
"router",
|
||||
]
|
||||
174
src/webui/routers/chat/routes.py
Normal file
174
src/webui/routers/chat/routes.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""本地聊天室路由 - WebUI 与麦麦直接对话。"""
|
||||
|
||||
import uuid
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
|
||||
from sqlalchemy import case, func
|
||||
from sqlmodel import col, select
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
from .support import (
|
||||
WEBUI_CHAT_GROUP_ID,
|
||||
WEBUI_CHAT_PLATFORM,
|
||||
authenticate_chat_websocket,
|
||||
chat_history,
|
||||
chat_manager,
|
||||
dispatch_chat_event,
|
||||
normalize_webui_user_id,
|
||||
resolve_initial_virtual_identity,
|
||||
send_initial_chat_state,
|
||||
)
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
|
||||
router = APIRouter(prefix="/api/chat", tags=["LocalChat"], dependencies=[Depends(require_auth)])
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_chat_history(
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
user_id: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None),
|
||||
) -> dict[str, object]:
|
||||
"""获取聊天历史记录。"""
|
||||
del user_id
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
history = chat_history.get_history(limit, target_group_id)
|
||||
return {"success": True, "messages": history, "total": len(history)}
|
||||
|
||||
|
||||
@router.get("/platforms")
|
||||
async def get_available_platforms() -> dict[str, object]:
|
||||
"""获取可用平台列表。"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(PersonInfo.platform, func.count().label("count"))
|
||||
.group_by(PersonInfo.platform)
|
||||
.order_by(func.count().desc())
|
||||
)
|
||||
platforms = session.exec(statement).all()
|
||||
|
||||
result = [{"platform": platform, "count": count} for platform, count in platforms if platform]
|
||||
return {"success": True, "platforms": result}
|
||||
except Exception as e:
|
||||
logger.error(f"获取平台列表失败: {e}")
|
||||
return {"success": False, "error": str(e), "platforms": []}
|
||||
|
||||
|
||||
@router.get("/persons")
|
||||
async def get_persons_by_platform(
|
||||
platform: str = Query(..., description="平台名称"),
|
||||
search: Optional[str] = Query(default=None, description="搜索关键词"),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
) -> dict[str, object]:
|
||||
"""获取指定平台的用户列表。"""
|
||||
try:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.platform) == platform)
|
||||
if search:
|
||||
statement = statement.where(
|
||||
(col(PersonInfo.person_name).contains(search))
|
||||
| (col(PersonInfo.user_nickname).contains(search))
|
||||
| (col(PersonInfo.user_id).contains(search))
|
||||
)
|
||||
|
||||
statement = statement.order_by(
|
||||
case((col(PersonInfo.last_known_time).is_(None), 1), else_=0),
|
||||
col(PersonInfo.last_known_time).desc(),
|
||||
).limit(limit)
|
||||
|
||||
with get_db_session() as session:
|
||||
persons = session.exec(statement).all()
|
||||
|
||||
result = [
|
||||
{
|
||||
"person_id": person.person_id,
|
||||
"user_id": person.user_id,
|
||||
"person_name": person.person_name,
|
||||
"nickname": person.user_nickname,
|
||||
"is_known": person.is_known,
|
||||
"platform": person.platform,
|
||||
"display_name": person.person_name or person.user_nickname or person.user_id,
|
||||
}
|
||||
for person in persons
|
||||
]
|
||||
return {"success": True, "persons": result, "total": len(result)}
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户列表失败: {e}")
|
||||
return {"success": False, "error": str(e), "persons": []}
|
||||
|
||||
|
||||
@router.delete("/history")
|
||||
async def clear_chat_history(
|
||||
group_id: Optional[str] = Query(default=None),
|
||||
) -> dict[str, object]:
|
||||
"""清空聊天历史记录。"""
|
||||
deleted = chat_history.clear_history(group_id)
|
||||
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_chat(
|
||||
websocket: WebSocket,
|
||||
user_id: Optional[str] = Query(default=None),
|
||||
user_name: Optional[str] = Query(default="WebUI用户"),
|
||||
platform: Optional[str] = Query(default=None),
|
||||
person_id: Optional[str] = Query(default=None),
|
||||
group_name: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None),
|
||||
token: Optional[str] = Query(default=None),
|
||||
) -> None:
|
||||
"""WebSocket 聊天端点。"""
|
||||
if not await authenticate_chat_websocket(websocket, token):
|
||||
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
normalized_user_id = normalize_webui_user_id(user_id)
|
||||
current_user_name = user_name or "WebUI用户"
|
||||
current_virtual_config = resolve_initial_virtual_identity(platform, person_id, group_name, group_id)
|
||||
|
||||
await chat_manager.connect(websocket, session_id, normalized_user_id)
|
||||
try:
|
||||
await send_initial_chat_state(
|
||||
session_id=session_id,
|
||||
user_id=normalized_user_id,
|
||||
user_name=current_user_name,
|
||||
virtual_config=current_virtual_config,
|
||||
)
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
current_user_name, current_virtual_config = await dispatch_chat_event(
|
||||
session_id=session_id,
|
||||
session_id_prefix=session_id[:8],
|
||||
data=data,
|
||||
current_user_name=current_user_name,
|
||||
normalized_user_id=normalized_user_id,
|
||||
current_virtual_config=current_virtual_config,
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开: session={session_id}, user={normalized_user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 错误: {e}")
|
||||
finally:
|
||||
chat_manager.disconnect(session_id, normalized_user_id)
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_chat_info() -> dict[str, object]:
|
||||
"""获取聊天室信息。"""
|
||||
return {
|
||||
"bot_name": global_config.bot.nickname,
|
||||
"platform": WEBUI_CHAT_PLATFORM,
|
||||
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||
"active_sessions": len(chat_manager.active_connections),
|
||||
}
|
||||
614
src/webui/routers/chat/support.py
Normal file
614
src/webui/routers/chat/support.py
Normal file
@@ -0,0 +1,614 @@
|
||||
"""WebUI 聊天路由支持逻辑。"""
|
||||
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from fastapi import WebSocket
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, delete, select
|
||||
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Messages, PersonInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages
|
||||
from src.common.utils.system_utils import is_bot_self
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
|
||||
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
|
||||
WEBUI_CHAT_PLATFORM = "webui"
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
WEBUI_USER_ID_PREFIX = "webui_user_"
|
||||
|
||||
|
||||
class VirtualIdentityConfig(BaseModel):
|
||||
"""虚拟身份配置。"""
|
||||
|
||||
enabled: bool = False
|
||||
platform: Optional[str] = None
|
||||
person_id: Optional[str] = None
|
||||
user_id: Optional[str] = None
|
||||
user_nickname: Optional[str] = None
|
||||
group_id: Optional[str] = None
|
||||
group_name: Optional[str] = None
|
||||
|
||||
|
||||
class ChatHistoryMessage(BaseModel):
|
||||
"""聊天历史消息。"""
|
||||
|
||||
id: str
|
||||
type: str
|
||||
content: str
|
||||
timestamp: float
|
||||
sender_name: str
|
||||
sender_id: Optional[str] = None
|
||||
is_bot: bool = False
|
||||
|
||||
|
||||
class ChatHistoryManager:
|
||||
"""聊天历史管理器。"""
|
||||
|
||||
def __init__(self, max_messages: int = 200) -> None:
|
||||
self.max_messages = max_messages
|
||||
|
||||
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> dict[str, Any]:
|
||||
user_info = msg.message_info.user_info
|
||||
user_id = user_info.user_id or ""
|
||||
is_bot = is_bot_self(user_id, msg.platform)
|
||||
|
||||
if not is_bot and group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
|
||||
is_bot = user_id == str(global_config.bot.qq_account)
|
||||
elif not is_bot:
|
||||
is_bot = not user_id.startswith(WEBUI_USER_ID_PREFIX)
|
||||
|
||||
return {
|
||||
"id": msg.message_id,
|
||||
"type": "bot" if is_bot else "user",
|
||||
"content": msg.processed_plain_text or msg.display_message or "",
|
||||
"timestamp": msg.timestamp.timestamp(),
|
||||
"sender_name": user_info.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
|
||||
"sender_id": "bot" if is_bot else user_id,
|
||||
"is_bot": is_bot,
|
||||
}
|
||||
|
||||
def _resolve_session_id(self, group_id: Optional[str]) -> str:
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
|
||||
|
||||
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> list[dict[str, Any]]:
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
session_id = self._resolve_session_id(target_group_id)
|
||||
try:
|
||||
messages = find_messages(
|
||||
session_id=session_id,
|
||||
limit=limit,
|
||||
limit_mode="latest",
|
||||
filter_command=False,
|
||||
)
|
||||
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
|
||||
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载聊天记录失败: {e}")
|
||||
return []
|
||||
|
||||
def clear_history(self, group_id: Optional[str] = None) -> int:
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
session_id = self._resolve_session_id(target_group_id)
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = delete(Messages).where(col(Messages.session_id) == session_id)
|
||||
result = session.exec(statement)
|
||||
deleted = result.rowcount or 0
|
||||
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
|
||||
return deleted
|
||||
except Exception as e:
|
||||
logger.error(f"清空聊天记录失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
class ChatConnectionManager:
|
||||
"""聊天连接管理器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.active_connections: dict[str, WebSocket] = {}
|
||||
self.user_sessions: dict[str, str] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None:
|
||||
await websocket.accept()
|
||||
self.active_connections[session_id] = websocket
|
||||
self.user_sessions[user_id] = session_id
|
||||
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
|
||||
|
||||
def disconnect(self, session_id: str, user_id: str) -> None:
|
||||
if session_id in self.active_connections:
|
||||
del self.active_connections[session_id]
|
||||
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
|
||||
del self.user_sessions[user_id]
|
||||
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
||||
|
||||
async def send_message(self, session_id: str, message: dict[str, Any]) -> None:
|
||||
if session_id in self.active_connections:
|
||||
try:
|
||||
await self.active_connections[session_id].send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
|
||||
async def broadcast(self, message: dict[str, Any]) -> None:
|
||||
for session_id in list(self.active_connections.keys()):
|
||||
await self.send_message(session_id, message)
|
||||
|
||||
|
||||
chat_history = ChatHistoryManager()
|
||||
chat_manager = ChatConnectionManager()
|
||||
|
||||
|
||||
def is_virtual_mode_enabled(virtual_config: Optional[VirtualIdentityConfig]) -> bool:
|
||||
return bool(virtual_config and virtual_config.enabled)
|
||||
|
||||
|
||||
async def authenticate_chat_websocket(websocket: WebSocket, token: Optional[str]) -> bool:
|
||||
if token and verify_ws_token(token):
|
||||
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
|
||||
return True
|
||||
|
||||
if cookie_token := websocket.cookies.get("maibot_session"):
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def normalize_webui_user_id(user_id: Optional[str]) -> str:
|
||||
if not user_id:
|
||||
return f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
|
||||
if user_id.startswith(WEBUI_USER_ID_PREFIX):
|
||||
return user_id
|
||||
return f"{WEBUI_USER_ID_PREFIX}{user_id}"
|
||||
|
||||
|
||||
def get_person_by_person_id(person_id: str) -> Optional[PersonInfo]:
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||
return session.exec(statement).first()
|
||||
|
||||
|
||||
def build_virtual_identity_config(person: PersonInfo, group_id: str, group_name: str) -> VirtualIdentityConfig:
|
||||
return VirtualIdentityConfig(
|
||||
enabled=True,
|
||||
platform=person.platform,
|
||||
person_id=person.person_id,
|
||||
user_id=person.user_id,
|
||||
user_nickname=person.person_name or person.user_nickname or person.user_id,
|
||||
group_id=group_id,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
|
||||
def resolve_initial_virtual_identity(
|
||||
platform: Optional[str],
|
||||
person_id: Optional[str],
|
||||
group_name: Optional[str],
|
||||
group_id: Optional[str],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
if not (platform and person_id):
|
||||
return None
|
||||
|
||||
try:
|
||||
person = get_person_by_person_id(person_id)
|
||||
if person is None:
|
||||
return None
|
||||
|
||||
virtual_group_id = group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{platform}_{person.user_id}"
|
||||
virtual_config = build_virtual_identity_config(
|
||||
person=person,
|
||||
group_id=virtual_group_id,
|
||||
group_name=group_name or "WebUI虚拟群聊",
|
||||
)
|
||||
logger.info(
|
||||
f"虚拟身份模式已通过 URL 参数激活: {virtual_config.user_nickname} @ {virtual_config.platform}, group_id={virtual_group_id}"
|
||||
)
|
||||
return virtual_config
|
||||
except Exception as e:
|
||||
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def build_session_info_message(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> dict[str, Any]:
|
||||
session_info_data: dict[str, Any] = {
|
||||
"type": "session_info",
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
"user_name": user_name,
|
||||
"bot_name": global_config.bot.nickname,
|
||||
}
|
||||
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
session_info_data["virtual_mode"] = True
|
||||
session_info_data["group_id"] = virtual_config.group_id
|
||||
session_info_data["virtual_identity"] = {
|
||||
"platform": virtual_config.platform,
|
||||
"user_id": virtual_config.user_id,
|
||||
"user_nickname": virtual_config.user_nickname,
|
||||
"group_name": virtual_config.group_name,
|
||||
}
|
||||
|
||||
return session_info_data
|
||||
|
||||
|
||||
def get_active_history_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> Optional[str]:
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
return virtual_config.group_id
|
||||
return None
|
||||
|
||||
|
||||
def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> str:
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
return (
|
||||
f"已以 {virtual_config.user_nickname} 的身份连接到「{virtual_config.group_name}」,"
|
||||
f"开始与 {global_config.bot.nickname} 对话吧!"
|
||||
)
|
||||
return f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!"
|
||||
|
||||
|
||||
async def send_chat_error(session_id: str, content: str) -> None:
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": content,
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def send_initial_chat_state(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> None:
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
build_session_info_message(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
virtual_config=virtual_config,
|
||||
),
|
||||
)
|
||||
|
||||
if history := chat_history.get_history(50, get_active_history_group_id(virtual_config)):
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": history,
|
||||
},
|
||||
)
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": build_welcome_message(virtual_config),
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def resolve_sender_identity(
|
||||
current_user_name: str,
|
||||
normalized_user_id: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> tuple[str, str]:
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id
|
||||
return current_user_name, normalized_user_id
|
||||
|
||||
|
||||
def create_message_data(
|
||||
content: str,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
message_id: Optional[str] = None,
|
||||
is_at_bot: bool = True,
|
||||
virtual_config: Optional[VirtualIdentityConfig] = None,
|
||||
) -> dict[str, Any]:
|
||||
if message_id is None:
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
if virtual_config and virtual_config.enabled:
|
||||
platform = virtual_config.platform or WEBUI_CHAT_PLATFORM
|
||||
group_id = virtual_config.group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{uuid.uuid4().hex[:8]}"
|
||||
group_name = virtual_config.group_name or "WebUI虚拟群聊"
|
||||
actual_user_id = virtual_config.user_id or user_id
|
||||
actual_user_name = virtual_config.user_nickname or user_name
|
||||
else:
|
||||
platform = WEBUI_CHAT_PLATFORM
|
||||
group_id = WEBUI_CHAT_GROUP_ID
|
||||
group_name = "WebUI本地聊天室"
|
||||
actual_user_id = user_id
|
||||
actual_user_name = user_name
|
||||
|
||||
return {
|
||||
"message_info": {
|
||||
"platform": platform,
|
||||
"message_id": message_id,
|
||||
"time": time.time(),
|
||||
"group_info": {
|
||||
"group_id": group_id,
|
||||
"group_name": group_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"user_info": {
|
||||
"user_id": actual_user_id,
|
||||
"user_nickname": actual_user_name,
|
||||
"user_cardname": actual_user_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"additional_config": {
|
||||
"at_bot": is_at_bot,
|
||||
},
|
||||
},
|
||||
"message_segment": {
|
||||
"type": "seglist",
|
||||
"data": [
|
||||
{
|
||||
"type": "text",
|
||||
"data": content,
|
||||
},
|
||||
{
|
||||
"type": "mention_bot",
|
||||
"data": "1.0",
|
||||
},
|
||||
],
|
||||
},
|
||||
"raw_message": content,
|
||||
"processed_plain_text": content,
|
||||
}
|
||||
|
||||
|
||||
async def handle_chat_message(
|
||||
session_id: str,
|
||||
data: dict[str, Any],
|
||||
current_user_name: str,
|
||||
normalized_user_id: str,
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> str:
|
||||
content = str(data.get("content", "")).strip()
|
||||
if not content:
|
||||
return current_user_name
|
||||
|
||||
next_user_name = str(data.get("user_name", current_user_name))
|
||||
message_id = str(uuid.uuid4())
|
||||
timestamp = time.time()
|
||||
sender_name, sender_user_id = resolve_sender_identity(
|
||||
current_user_name=next_user_name,
|
||||
normalized_user_id=normalized_user_id,
|
||||
virtual_config=current_virtual_config,
|
||||
)
|
||||
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "user_message",
|
||||
"content": content,
|
||||
"message_id": message_id,
|
||||
"timestamp": timestamp,
|
||||
"sender": {
|
||||
"name": sender_name,
|
||||
"user_id": sender_user_id,
|
||||
"is_bot": False,
|
||||
},
|
||||
"virtual_mode": is_virtual_mode_enabled(current_virtual_config),
|
||||
}
|
||||
)
|
||||
|
||||
message_data = create_message_data(
|
||||
content=content,
|
||||
user_id=normalized_user_id,
|
||||
user_name=next_user_name,
|
||||
message_id=message_id,
|
||||
is_at_bot=True,
|
||||
virtual_config=current_virtual_config,
|
||||
)
|
||||
|
||||
try:
|
||||
await chat_manager.broadcast({"type": "typing", "is_typing": True})
|
||||
await chat_bot.message_process(message_data)
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时出错: {e}")
|
||||
await send_chat_error(session_id, f"处理消息时出错: {str(e)}")
|
||||
finally:
|
||||
await chat_manager.broadcast({"type": "typing", "is_typing": False})
|
||||
|
||||
return next_user_name
|
||||
|
||||
|
||||
async def handle_chat_ping(session_id: str) -> None:
|
||||
await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()})
|
||||
|
||||
|
||||
async def handle_nickname_update(session_id: str, data: dict[str, Any], current_user_name: str) -> str:
|
||||
new_name = str(data.get("user_name", "")).strip()
|
||||
if not new_name:
|
||||
return current_user_name
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "nickname_updated",
|
||||
"user_name": new_name,
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
return new_name
|
||||
|
||||
|
||||
async def enable_virtual_identity(
|
||||
session_id: str,
|
||||
session_prefix: str,
|
||||
virtual_data: dict[str, Any],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
|
||||
await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id")
|
||||
return None
|
||||
|
||||
person_id_value = str(virtual_data.get("person_id"))
|
||||
try:
|
||||
person = get_person_by_person_id(person_id_value)
|
||||
if not person:
|
||||
await send_chat_error(session_id, f"找不到用户: {person_id_value}")
|
||||
return None
|
||||
|
||||
custom_group_id = virtual_data.get("group_id")
|
||||
current_group_id = (
|
||||
f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
|
||||
if custom_group_id
|
||||
else f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
|
||||
)
|
||||
current_virtual_config = build_virtual_identity_config(
|
||||
person=person,
|
||||
group_id=current_group_id,
|
||||
group_name=str(virtual_data.get("group_name", "WebUI虚拟群聊")),
|
||||
)
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "virtual_identity_set",
|
||||
"config": {
|
||||
"enabled": True,
|
||||
"platform": current_virtual_config.platform,
|
||||
"user_id": current_virtual_config.user_id,
|
||||
"user_nickname": current_virtual_config.user_nickname,
|
||||
"group_id": current_virtual_config.group_id,
|
||||
"group_name": current_virtual_config.group_name,
|
||||
},
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": chat_history.get_history(50, current_virtual_config.group_id),
|
||||
"group_id": current_virtual_config.group_id,
|
||||
},
|
||||
)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": (
|
||||
f"已切换到虚拟身份模式:以 {current_virtual_config.user_nickname} 的身份在"
|
||||
f"「{current_virtual_config.group_name}」与 {global_config.bot.nickname} 对话"
|
||||
),
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
return current_virtual_config
|
||||
except Exception as e:
|
||||
logger.error(f"设置虚拟身份失败: {e}")
|
||||
await send_chat_error(session_id, f"设置虚拟身份失败: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def disable_virtual_identity(session_id: str) -> None:
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "virtual_identity_set",
|
||||
"config": {"enabled": False},
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": chat_history.get_history(50, WEBUI_CHAT_GROUP_ID),
|
||||
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||
},
|
||||
)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": "已切换回 WebUI 独立用户模式",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def handle_virtual_identity_update(
|
||||
session_id: str,
|
||||
session_id_prefix: str,
|
||||
data: dict[str, Any],
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
virtual_data = cast(dict[str, Any], data.get("config", {}))
|
||||
if virtual_data.get("enabled"):
|
||||
next_config = await enable_virtual_identity(session_id, session_id_prefix, virtual_data)
|
||||
return next_config if next_config is not None else current_virtual_config
|
||||
|
||||
await disable_virtual_identity(session_id)
|
||||
return None
|
||||
|
||||
|
||||
async def dispatch_chat_event(
|
||||
session_id: str,
|
||||
session_id_prefix: str,
|
||||
data: dict[str, Any],
|
||||
current_user_name: str,
|
||||
normalized_user_id: str,
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> tuple[str, Optional[VirtualIdentityConfig]]:
|
||||
event_type = data.get("type")
|
||||
if event_type == "message":
|
||||
next_user_name = await handle_chat_message(
|
||||
session_id=session_id,
|
||||
data=data,
|
||||
current_user_name=current_user_name,
|
||||
normalized_user_id=normalized_user_id,
|
||||
current_virtual_config=current_virtual_config,
|
||||
)
|
||||
return next_user_name, current_virtual_config
|
||||
|
||||
if event_type == "ping":
|
||||
await handle_chat_ping(session_id)
|
||||
return current_user_name, current_virtual_config
|
||||
|
||||
if event_type == "update_nickname":
|
||||
next_user_name = await handle_nickname_update(session_id, data, current_user_name)
|
||||
return next_user_name, current_virtual_config
|
||||
|
||||
if event_type == "set_virtual_identity":
|
||||
next_virtual_config = await handle_virtual_identity_update(
|
||||
session_id=session_id,
|
||||
session_id_prefix=session_id_prefix,
|
||||
data=data,
|
||||
current_virtual_config=current_virtual_config,
|
||||
)
|
||||
return current_user_name, next_virtual_config
|
||||
|
||||
return current_user_name, current_virtual_config
|
||||
@@ -4,12 +4,13 @@
|
||||
|
||||
import copy
|
||||
import os
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header
|
||||
from typing import Any, Annotated, Optional
|
||||
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||
from src.webui.dependencies import require_auth
|
||||
from src.webui.utils.toml_utils import save_toml_with_format, _update_toml_doc
|
||||
from src.config.config import Config, ModelConfig, CONFIG_DIR, PROJECT_ROOT
|
||||
from src.config.config_base import AttributeData
|
||||
@@ -49,7 +50,7 @@ SectionBody = Annotated[Any, Body()]
|
||||
RawContentBody = Annotated[str, Body(embed=True)]
|
||||
PathBody = Annotated[dict[str, str], Body()]
|
||||
|
||||
router = APIRouter(prefix="/config", tags=["config"])
|
||||
router = APIRouter(prefix="/config", tags=["config"], dependencies=[Depends(require_auth)])
|
||||
|
||||
|
||||
def _toml_to_plain_dict(obj: Any) -> Any:
|
||||
@@ -59,21 +60,11 @@ def _toml_to_plain_dict(obj: Any) -> Any:
|
||||
if isinstance(obj, list):
|
||||
return [_toml_to_plain_dict(v) for v in obj]
|
||||
return obj
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# ===== 架构获取接口 =====
|
||||
|
||||
|
||||
@router.get("/schema/bot")
|
||||
async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
|
||||
async def get_bot_config_schema():
|
||||
"""获取麦麦主程序配置架构"""
|
||||
try:
|
||||
# Config 类包含所有子配置
|
||||
@@ -85,7 +76,7 @@ async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
|
||||
|
||||
|
||||
@router.get("/schema/model")
|
||||
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
|
||||
async def get_model_config_schema():
|
||||
"""获取模型配置架构(包含提供商和模型任务配置)"""
|
||||
try:
|
||||
schema = ConfigSchemaGenerator.generate_config_schema(ModelConfig)
|
||||
@@ -99,7 +90,7 @@ async def get_model_config_schema(_auth: bool = Depends(require_auth)):
|
||||
|
||||
|
||||
@router.get("/schema/section/{section_name}")
|
||||
async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)):
|
||||
async def get_config_section_schema(section_name: str):
|
||||
"""
|
||||
获取指定配置节的架构
|
||||
|
||||
@@ -169,7 +160,7 @@ async def get_config_section_schema(section_name: str, _auth: bool = Depends(req
|
||||
|
||||
|
||||
@router.get("/bot")
|
||||
async def get_bot_config(_auth: bool = Depends(require_auth)):
|
||||
async def get_bot_config():
|
||||
"""获取麦麦主程序配置"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
@@ -188,7 +179,7 @@ async def get_bot_config(_auth: bool = Depends(require_auth)):
|
||||
|
||||
|
||||
@router.get("/model")
|
||||
async def get_model_config(_auth: bool = Depends(require_auth)):
|
||||
async def get_model_config():
|
||||
"""获取模型配置(包含提供商和模型任务配置)"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
@@ -210,7 +201,7 @@ async def get_model_config(_auth: bool = Depends(require_auth)):
|
||||
|
||||
|
||||
@router.post("/bot")
|
||||
async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
|
||||
async def update_bot_config(config_data: ConfigBody):
|
||||
"""更新麦麦主程序配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
@@ -233,7 +224,7 @@ async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(requi
|
||||
|
||||
|
||||
@router.post("/model")
|
||||
async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
|
||||
async def update_model_config(config_data: ConfigBody):
|
||||
"""更新模型配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
@@ -259,7 +250,7 @@ async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(req
|
||||
|
||||
|
||||
@router.post("/bot/section/{section_name}")
|
||||
async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
|
||||
async def update_bot_config_section(section_name: str, section_data: SectionBody):
|
||||
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
@@ -308,7 +299,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
|
||||
|
||||
|
||||
@router.get("/bot/raw")
|
||||
async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
|
||||
async def get_bot_config_raw():
|
||||
"""获取麦麦主程序配置的原始 TOML 内容"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
@@ -327,7 +318,7 @@ async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
|
||||
|
||||
|
||||
@router.post("/bot/raw")
|
||||
async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)):
|
||||
async def update_bot_config_raw(raw_content: RawContentBody):
|
||||
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
||||
try:
|
||||
# 验证 TOML 格式
|
||||
@@ -357,9 +348,7 @@ async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depen
|
||||
|
||||
|
||||
@router.post("/model/section/{section_name}")
|
||||
async def update_model_config_section(
|
||||
section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)
|
||||
):
|
||||
async def update_model_config_section(section_name: str, section_data: SectionBody):
|
||||
"""更新模型配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
@@ -451,7 +440,7 @@ def _to_relative_path(path: str) -> str:
|
||||
|
||||
|
||||
@router.get("/adapter-config/path")
|
||||
async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
|
||||
async def get_adapter_config_path():
|
||||
"""获取保存的适配器配置文件路径"""
|
||||
try:
|
||||
# 从 data/webui.json 读取路径偏好
|
||||
@@ -490,7 +479,7 @@ async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
|
||||
|
||||
|
||||
@router.post("/adapter-config/path")
|
||||
async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)):
|
||||
async def save_adapter_config_path(data: PathBody):
|
||||
"""保存适配器配置文件路径偏好"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
@@ -533,7 +522,7 @@ async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require
|
||||
|
||||
|
||||
@router.get("/adapter-config")
|
||||
async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
|
||||
async def get_adapter_config(path: str):
|
||||
"""从指定路径读取适配器配置文件"""
|
||||
try:
|
||||
if not path:
|
||||
@@ -565,7 +554,7 @@ async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
|
||||
|
||||
|
||||
@router.post("/adapter-config")
|
||||
async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)):
|
||||
async def save_adapter_config(data: PathBody):
|
||||
"""保存适配器配置到指定路径"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
|
||||
3
src/webui/routers/emoji/__init__.py
Normal file
3
src/webui/routers/emoji/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router"]
|
||||
File diff suppressed because it is too large
Load Diff
140
src/webui/routers/emoji/schemas.py
Normal file
140
src/webui/routers/emoji/schemas.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from typing import Annotated, List, Optional
|
||||
|
||||
from fastapi import File, Form, UploadFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.common.database.database_model import Images
|
||||
|
||||
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
|
||||
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
|
||||
DescriptionForm = Annotated[str, Form(description="表情包描述")]
|
||||
EmotionForm = Annotated[str, Form(description="情感标签,多个用逗号分隔")]
|
||||
IsRegisteredForm = Annotated[bool, Form(description="是否直接注册")]
|
||||
|
||||
|
||||
class EmojiResponse(BaseModel):
|
||||
"""表情包响应"""
|
||||
|
||||
id: int
|
||||
full_path: str
|
||||
emoji_hash: str
|
||||
description: str
|
||||
query_count: int
|
||||
is_registered: bool
|
||||
is_banned: bool
|
||||
emotion: Optional[str]
|
||||
record_time: float
|
||||
register_time: Optional[float]
|
||||
last_used_time: Optional[float]
|
||||
|
||||
|
||||
class EmojiListResponse(BaseModel):
|
||||
"""表情包列表响应"""
|
||||
|
||||
success: bool
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
data: List[EmojiResponse]
|
||||
|
||||
|
||||
class EmojiDetailResponse(BaseModel):
|
||||
"""表情包详情响应"""
|
||||
|
||||
success: bool
|
||||
data: EmojiResponse
|
||||
|
||||
|
||||
class EmojiUpdateRequest(BaseModel):
|
||||
"""表情包更新请求"""
|
||||
|
||||
description: Optional[str] = None
|
||||
is_registered: Optional[bool] = None
|
||||
is_banned: Optional[bool] = None
|
||||
emotion: Optional[str] = None
|
||||
|
||||
|
||||
class EmojiUpdateResponse(BaseModel):
|
||||
"""表情包更新响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[EmojiResponse] = None
|
||||
|
||||
|
||||
class EmojiDeleteResponse(BaseModel):
|
||||
"""表情包删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class BatchDeleteRequest(BaseModel):
|
||||
"""批量删除请求"""
|
||||
|
||||
emoji_ids: List[int]
|
||||
|
||||
|
||||
class BatchDeleteResponse(BaseModel):
|
||||
"""批量删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
deleted_count: int
|
||||
failed_count: int
|
||||
failed_ids: List[int] = []
|
||||
|
||||
|
||||
class EmojiUploadResponse(BaseModel):
|
||||
"""表情包上传响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[EmojiResponse] = None
|
||||
|
||||
|
||||
class ThumbnailCacheStatsResponse(BaseModel):
|
||||
"""缩略图缓存统计响应"""
|
||||
|
||||
success: bool
|
||||
cache_dir: str
|
||||
total_count: int
|
||||
total_size_mb: float
|
||||
emoji_count: int
|
||||
coverage_percent: float
|
||||
|
||||
|
||||
class ThumbnailCleanupResponse(BaseModel):
|
||||
"""缩略图清理响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
cleaned_count: int
|
||||
kept_count: int
|
||||
|
||||
|
||||
class ThumbnailPreheatResponse(BaseModel):
|
||||
"""缩略图预热响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
generated_count: int
|
||||
skipped_count: int
|
||||
failed_count: int
|
||||
|
||||
|
||||
def emoji_to_response(image: Images) -> EmojiResponse:
|
||||
"""将数据库表情包模型转换为响应对象。"""
|
||||
return EmojiResponse(
|
||||
id=image.id if image.id is not None else 0,
|
||||
full_path=image.full_path,
|
||||
emoji_hash=image.image_hash,
|
||||
description=image.description,
|
||||
query_count=image.query_count,
|
||||
is_registered=image.is_registered,
|
||||
is_banned=image.is_banned,
|
||||
emotion=image.emotion,
|
||||
record_time=image.record_time.timestamp() if image.record_time else 0.0,
|
||||
register_time=image.register_time.timestamp() if image.register_time else None,
|
||||
last_used_time=image.last_used_time.timestamp() if image.last_used_time else None,
|
||||
)
|
||||
142
src/webui/routers/emoji/support.py
Normal file
142
src/webui/routers/emoji/support.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
import threading
|
||||
|
||||
from PIL import Image
|
||||
from sqlmodel import col, select
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui.emoji")
|
||||
|
||||
THUMBNAIL_CACHE_DIR = Path("data/emoji_thumbnails")
|
||||
THUMBNAIL_SIZE = (200, 200)
|
||||
THUMBNAIL_QUALITY = 80
|
||||
EMOJI_REGISTERED_DIR = os.path.join("data", "emoji_registed")
|
||||
|
||||
_thumbnail_locks: dict[str, threading.Lock] = {}
|
||||
_locks_lock = threading.Lock()
|
||||
_thumbnail_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="thumbnail")
|
||||
_generating_thumbnails: set[str] = set()
|
||||
_generating_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_thumbnail_executor() -> ThreadPoolExecutor:
|
||||
"""获取缩略图生成线程池。"""
|
||||
return _thumbnail_executor
|
||||
|
||||
|
||||
def get_generating_lock() -> threading.Lock:
|
||||
"""获取缩略图生成状态锁。"""
|
||||
return _generating_lock
|
||||
|
||||
|
||||
def get_generating_thumbnails() -> set[str]:
|
||||
"""获取正在生成的缩略图哈希集合。"""
|
||||
return _generating_thumbnails
|
||||
|
||||
|
||||
def _get_thumbnail_lock(file_hash: str) -> threading.Lock:
|
||||
"""获取指定文件哈希的锁,用于防止并发生成同一缩略图。"""
|
||||
with _locks_lock:
|
||||
if file_hash not in _thumbnail_locks:
|
||||
_thumbnail_locks[file_hash] = threading.Lock()
|
||||
return _thumbnail_locks[file_hash]
|
||||
|
||||
|
||||
def _background_generate_thumbnail(source_path: str, file_hash: str) -> None:
|
||||
"""在线程池中后台生成缩略图。"""
|
||||
try:
|
||||
_generate_thumbnail(source_path, file_hash)
|
||||
except Exception as e:
|
||||
logger.warning(f"后台生成缩略图失败 {file_hash}: {e}")
|
||||
finally:
|
||||
with _generating_lock:
|
||||
_generating_thumbnails.discard(file_hash)
|
||||
|
||||
|
||||
def ensure_thumbnail_cache_dir() -> Path:
|
||||
"""确保缩略图缓存目录存在。"""
|
||||
_ = THUMBNAIL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return THUMBNAIL_CACHE_DIR
|
||||
|
||||
|
||||
def get_thumbnail_cache_path(file_hash: str) -> Path:
|
||||
"""获取缩略图缓存路径。"""
|
||||
return THUMBNAIL_CACHE_DIR / f"{file_hash}.webp"
|
||||
|
||||
|
||||
def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
|
||||
"""生成缩略图并保存到缓存目录。"""
|
||||
ensure_thumbnail_cache_dir()
|
||||
cache_path = get_thumbnail_cache_path(file_hash)
|
||||
|
||||
lock = _get_thumbnail_lock(file_hash)
|
||||
with lock:
|
||||
if cache_path.exists():
|
||||
return cache_path
|
||||
|
||||
try:
|
||||
with Image.open(source_path) as img:
|
||||
if getattr(img, "n_frames", 1) > 1:
|
||||
img.seek(0)
|
||||
|
||||
if img.mode in ("P", "PA"):
|
||||
img = img.convert("RGBA")
|
||||
elif img.mode == "LA":
|
||||
img = img.convert("RGBA")
|
||||
elif img.mode not in ("RGB", "RGBA"):
|
||||
img = img.convert("RGB")
|
||||
|
||||
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
|
||||
img.save(cache_path, "WEBP", quality=THUMBNAIL_QUALITY, method=6)
|
||||
logger.debug(f"生成缩略图: {file_hash} -> {cache_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"生成缩略图失败 {file_hash}: {e},将返回原图")
|
||||
raise
|
||||
|
||||
return cache_path
|
||||
|
||||
|
||||
def generate_thumbnail(source_path: str, file_hash: str) -> Path:
|
||||
"""暴露给路由层的缩略图生成函数。"""
|
||||
return _generate_thumbnail(source_path, file_hash)
|
||||
|
||||
|
||||
def background_generate_thumbnail(source_path: str, file_hash: str) -> None:
|
||||
"""暴露给路由层的后台缩略图生成函数。"""
|
||||
_background_generate_thumbnail(source_path, file_hash)
|
||||
|
||||
|
||||
def cleanup_orphaned_thumbnails() -> tuple[int, int]:
|
||||
"""清理孤立的缩略图缓存。"""
|
||||
if not THUMBNAIL_CACHE_DIR.exists():
|
||||
return 0, 0
|
||||
|
||||
with get_db_session() as session:
|
||||
statement = select(Images.image_hash).where(col(Images.image_type) == ImageType.EMOJI)
|
||||
valid_hashes = set(session.exec(statement).all())
|
||||
|
||||
cleaned = 0
|
||||
kept = 0
|
||||
|
||||
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
|
||||
file_hash = cache_file.stem
|
||||
if file_hash not in valid_hashes:
|
||||
try:
|
||||
cache_file.unlink()
|
||||
cleaned += 1
|
||||
logger.debug(f"清理孤立缩略图: {cache_file.name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清理缩略图失败 {cache_file.name}: {e}")
|
||||
else:
|
||||
kept += 1
|
||||
|
||||
if cleaned > 0:
|
||||
logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept} 个")
|
||||
|
||||
return cleaned, kept
|
||||
@@ -1,9 +1,10 @@
|
||||
"""表达方式管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sqlalchemy import case, func
|
||||
from sqlmodel import col, select, delete
|
||||
@@ -12,12 +13,12 @@ from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
logger = get_logger("webui.expression")
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/expression", tags=["Expression"])
|
||||
router = APIRouter(prefix="/expression", tags=["Expression"], dependencies=[Depends(require_auth)])
|
||||
|
||||
|
||||
class ExpressionResponse(BaseModel):
|
||||
@@ -90,14 +91,6 @@ class ExpressionCreateResponse(BaseModel):
|
||||
data: ExpressionResponse
|
||||
|
||||
|
||||
def verify_auth_token(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""验证认证 Token,支持 Cookie 和 Header"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
"""将 Expression 模型转换为响应对象"""
|
||||
last_active_time = expression.last_active_time.timestamp() if expression.last_active_time else 0.0
|
||||
@@ -156,19 +149,14 @@ class ChatListResponse(BaseModel):
|
||||
|
||||
|
||||
@router.get("/chats", response_model=ChatListResponse)
|
||||
async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def get_chat_list():
|
||||
"""
|
||||
获取所有聊天列表(用于下拉选择)
|
||||
|
||||
Args:
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
聊天列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
chat_list = []
|
||||
for session_id, session in _chat_manager.sessions.items():
|
||||
chat_name = _chat_manager.get_session_name(session_id) or session_id
|
||||
@@ -199,8 +187,6 @@ async def get_expression_list(
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取表达方式列表
|
||||
@@ -210,14 +196,11 @@ async def get_expression_list(
|
||||
page_size: 每页数量 (1-100)
|
||||
search: 搜索关键词 (匹配 situation, style)
|
||||
chat_id: 聊天ID筛选
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
表达方式列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 构建查询
|
||||
statement = select(Expression)
|
||||
|
||||
@@ -264,22 +247,17 @@ async def get_expression_list(
|
||||
|
||||
|
||||
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
||||
async def get_expression_detail(
|
||||
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_expression_detail(expression_id: int):
|
||||
"""
|
||||
获取表达方式详细信息
|
||||
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
表达方式详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
with get_db_session() as session:
|
||||
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
|
||||
expression = session.exec(statement).first()
|
||||
@@ -299,22 +277,17 @@ async def get_expression_detail(
|
||||
@router.post("/", response_model=ExpressionCreateResponse)
|
||||
async def create_expression(
|
||||
request: ExpressionCreateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
创建新的表达方式
|
||||
|
||||
Args:
|
||||
request: 创建请求
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
创建结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
current_time = datetime.now()
|
||||
|
||||
# 创建表达方式
|
||||
@@ -349,8 +322,6 @@ async def create_expression(
|
||||
async def update_expression(
|
||||
expression_id: int,
|
||||
request: ExpressionUpdateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
增量更新表达方式(只更新提供的字段)
|
||||
@@ -358,14 +329,11 @@ async def update_expression(
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
request: 更新请求(只包含需要更新的字段)
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
with get_db_session() as session:
|
||||
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
|
||||
expression = session.exec(statement).first()
|
||||
@@ -411,22 +379,17 @@ async def update_expression(
|
||||
|
||||
|
||||
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
||||
async def delete_expression(
|
||||
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def delete_expression(expression_id: int):
|
||||
"""
|
||||
删除表达方式
|
||||
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
with get_db_session() as session:
|
||||
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
|
||||
expression = session.exec(statement).first()
|
||||
@@ -461,22 +424,17 @@ class BatchDeleteRequest(BaseModel):
|
||||
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
|
||||
async def batch_delete_expressions(
|
||||
request: BatchDeleteRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量删除表达方式
|
||||
|
||||
Args:
|
||||
request: 包含要删除的ID列表的请求
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not request.ids:
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
|
||||
|
||||
@@ -506,21 +464,14 @@ async def batch_delete_expressions(
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_expression_stats(
|
||||
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_expression_stats():
|
||||
"""
|
||||
获取表达方式统计数据
|
||||
|
||||
Args:
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
with get_db_session() as session:
|
||||
total = len(session.exec(select(Expression.id)).all())
|
||||
|
||||
@@ -569,7 +520,7 @@ class ReviewStatsResponse(BaseModel):
|
||||
|
||||
|
||||
@router.get("/review/stats", response_model=ReviewStatsResponse)
|
||||
async def get_review_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def get_review_stats():
|
||||
"""
|
||||
获取审核统计数据
|
||||
|
||||
@@ -577,8 +528,6 @@ async def get_review_stats(maibot_session: Optional[str] = Cookie(None), authori
|
||||
审核统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
with get_db_session() as session:
|
||||
total = len(session.exec(select(Expression.id)).all())
|
||||
unchecked = 0
|
||||
@@ -620,8 +569,6 @@ async def get_review_list(
|
||||
filter_type: str = Query("unchecked", description="筛选类型: unchecked/passed/rejected/all"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取待审核/已审核的表达方式列表
|
||||
@@ -637,8 +584,6 @@ async def get_review_list(
|
||||
表达方式列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
statement = select(Expression)
|
||||
|
||||
if filter_type in {"unchecked", "passed", "rejected"}:
|
||||
@@ -728,8 +673,6 @@ class BatchReviewResponse(BaseModel):
|
||||
@router.post("/review/batch", response_model=BatchReviewResponse)
|
||||
async def batch_review_expressions(
|
||||
request: BatchReviewRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量审核表达方式
|
||||
@@ -741,8 +684,6 @@ async def batch_review_expressions(
|
||||
批量审核结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not request.items:
|
||||
raise HTTPException(status_code=400, detail="未提供要审核的表达方式")
|
||||
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
"""黑话(俚语)管理路由"""
|
||||
|
||||
from typing import Annotated, Any, List, Optional
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func as fn
|
||||
from sqlmodel import Session, col, delete, select
|
||||
|
||||
import json
|
||||
|
||||
from typing import Annotated, Any, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlmodel import Session, col, delete, select
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ChatSession, Jargon
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
logger = get_logger("webui.jargon")
|
||||
|
||||
router = APIRouter(prefix="/jargon", tags=["Jargon"])
|
||||
router = APIRouter(prefix="/jargon", tags=["Jargon"], dependencies=[Depends(require_auth)])
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
@@ -33,14 +34,10 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
|
||||
parsed = json.loads(chat_id_str)
|
||||
if isinstance(parsed, list):
|
||||
# 格式: [["stream_id", user_id], ...]
|
||||
stream_ids = []
|
||||
for item in parsed:
|
||||
if isinstance(item, list) and len(item) >= 1:
|
||||
stream_ids.append(str(item[0]))
|
||||
return stream_ids
|
||||
else:
|
||||
# 其他格式,返回原始字符串
|
||||
return [chat_id_str]
|
||||
return [str(item[0]) for item in parsed if isinstance(item, list) and len(item) >= 1]
|
||||
|
||||
# 其他格式,返回原始字符串
|
||||
return [chat_id_str]
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# 不是有效的 JSON,可能是直接的 stream_id
|
||||
return [chat_id_str]
|
||||
@@ -57,9 +54,7 @@ def get_display_name_for_chat_id(chat_id_str: str, session: Session) -> str:
|
||||
return chat_id_str[:20]
|
||||
|
||||
stream_id = stream_ids[0]
|
||||
chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first()
|
||||
|
||||
if not chat_session:
|
||||
if not (chat_session := session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first()):
|
||||
return stream_id[:20]
|
||||
|
||||
if chat_session.group_id:
|
||||
@@ -180,9 +175,59 @@ class ChatListResponse(BaseModel):
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
|
||||
def parse_session_id_dict(session_id_dict_str: Optional[str]) -> dict[str, int]:
|
||||
"""解析会话计数字典。"""
|
||||
if not session_id_dict_str:
|
||||
return {}
|
||||
|
||||
try:
|
||||
parsed = json.loads(session_id_dict_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return {}
|
||||
|
||||
if not isinstance(parsed, dict):
|
||||
return {}
|
||||
|
||||
session_counts: dict[str, int] = {}
|
||||
for session_id, count in parsed.items():
|
||||
if not isinstance(session_id, str):
|
||||
continue
|
||||
if isinstance(count, int):
|
||||
session_counts[session_id] = count
|
||||
else:
|
||||
try:
|
||||
session_counts[session_id] = int(count)
|
||||
except (TypeError, ValueError):
|
||||
session_counts[session_id] = 0
|
||||
return session_counts
|
||||
|
||||
|
||||
def dump_session_id_dict(session_counts: dict[str, int]) -> str:
|
||||
"""序列化会话计数字典。"""
|
||||
return json.dumps(session_counts, ensure_ascii=False)
|
||||
|
||||
|
||||
def get_primary_chat_id(session_id_dict_str: Optional[str]) -> str:
|
||||
"""从会话计数字典中选出主聊天 ID。"""
|
||||
if not (session_counts := parse_session_id_dict(session_id_dict_str)):
|
||||
return ""
|
||||
|
||||
return max(session_counts.items(), key=lambda item: item[1])[0]
|
||||
|
||||
|
||||
def has_chat_id(session_id_dict_str: Optional[str], chat_id: str) -> bool:
|
||||
"""判断记录是否包含指定聊天 ID。"""
|
||||
return chat_id in parse_session_id_dict(session_id_dict_str)
|
||||
|
||||
|
||||
def build_session_id_dict_for_chat(chat_id: str, count: int = 1) -> str:
|
||||
"""为单个聊天 ID 构建会话计数字典。"""
|
||||
return dump_session_id_dict({chat_id: count})
|
||||
|
||||
|
||||
def jargon_to_dict(jargon: Jargon, session: Session) -> dict[str, Any]:
|
||||
"""将 Jargon ORM 对象转换为字典"""
|
||||
chat_id = jargon.session_id or ""
|
||||
chat_id = get_primary_chat_id(jargon.session_id_dict)
|
||||
chat_name = get_display_name_for_chat_id(chat_id, session) if chat_id else None
|
||||
|
||||
return {
|
||||
@@ -191,7 +236,7 @@ def jargon_to_dict(jargon: Jargon, session: Session) -> dict[str, Any]:
|
||||
"raw_content": jargon.raw_content,
|
||||
"meaning": jargon.meaning,
|
||||
"chat_id": chat_id,
|
||||
"stream_id": jargon.session_id,
|
||||
"stream_id": chat_id or None,
|
||||
"chat_name": chat_name,
|
||||
"count": jargon.count,
|
||||
"is_jargon": jargon.is_jargon,
|
||||
@@ -215,7 +260,6 @@ async def get_jargon_list(
|
||||
"""获取黑话列表"""
|
||||
try:
|
||||
statement = select(Jargon)
|
||||
count_statement = select(fn.count()).select_from(Jargon)
|
||||
|
||||
if search:
|
||||
search_filter = (
|
||||
@@ -224,28 +268,28 @@ async def get_jargon_list(
|
||||
| (col(Jargon.raw_content).contains(search))
|
||||
)
|
||||
statement = statement.where(search_filter)
|
||||
count_statement = count_statement.where(search_filter)
|
||||
|
||||
if chat_id:
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id)
|
||||
if stream_ids:
|
||||
chat_filter = col(Jargon.session_id).contains(stream_ids[0])
|
||||
else:
|
||||
chat_filter = col(Jargon.session_id) == chat_id
|
||||
statement = statement.where(chat_filter)
|
||||
count_statement = count_statement.where(chat_filter)
|
||||
|
||||
if is_jargon is not None:
|
||||
statement = statement.where(col(Jargon.is_jargon) == is_jargon)
|
||||
count_statement = count_statement.where(col(Jargon.is_jargon) == is_jargon)
|
||||
|
||||
statement = statement.order_by(col(Jargon.count).desc(), col(Jargon.id).desc())
|
||||
statement = statement.offset((page - 1) * page_size).limit(page_size)
|
||||
|
||||
with get_db_session() as session:
|
||||
total = session.exec(count_statement).one()
|
||||
jargons = session.exec(statement).all()
|
||||
data = [jargon_to_dict(jargon, session) for jargon in jargons]
|
||||
|
||||
if chat_id:
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id)
|
||||
chat_ids = stream_ids or [chat_id]
|
||||
jargons = [
|
||||
jargon
|
||||
for jargon in jargons
|
||||
if any(has_chat_id(jargon.session_id_dict, current_chat_id) for current_chat_id in chat_ids)
|
||||
]
|
||||
|
||||
total = len(jargons)
|
||||
offset = (page - 1) * page_size
|
||||
page_jargons = jargons[offset : offset + page_size]
|
||||
data = [jargon_to_dict(jargon, session) for jargon in page_jargons]
|
||||
|
||||
return JargonListResponse(
|
||||
success=True,
|
||||
@@ -265,22 +309,16 @@ async def get_chat_list():
|
||||
"""获取所有有黑话记录的聊天列表"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Jargon.session_id).distinct().where(col(Jargon.session_id).is_not(None))
|
||||
chat_id_list = [chat_id for chat_id in session.exec(statement).all() if chat_id]
|
||||
jargons = session.exec(select(Jargon)).all()
|
||||
|
||||
# 用于按 stream_id 去重
|
||||
seen_stream_ids: set[str] = set()
|
||||
|
||||
for chat_id in chat_id_list:
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id)
|
||||
if stream_ids:
|
||||
seen_stream_ids.add(stream_ids[0])
|
||||
for jargon in jargons:
|
||||
seen_stream_ids.update(parse_session_id_dict(jargon.session_id_dict).keys())
|
||||
|
||||
result = []
|
||||
with get_db_session() as session:
|
||||
for stream_id in seen_stream_ids:
|
||||
chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first()
|
||||
if chat_session:
|
||||
if chat_session := session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first():
|
||||
chat_name = str(chat_session.group_id) if chat_session.group_id else stream_id[:20]
|
||||
result.append(
|
||||
ChatInfoResponse(
|
||||
@@ -312,30 +350,21 @@ async def get_jargon_stats():
|
||||
"""获取黑话统计数据"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
total = session.exec(select(fn.count()).select_from(Jargon)).one()
|
||||
jargons = session.exec(select(Jargon)).all()
|
||||
|
||||
confirmed_jargon = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon))).one()
|
||||
confirmed_not_jargon = session.exec(
|
||||
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(False))
|
||||
).one()
|
||||
pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one()
|
||||
total = len(jargons)
|
||||
confirmed_jargon = sum(jargon.is_jargon is True for jargon in jargons)
|
||||
confirmed_not_jargon = sum(jargon.is_jargon is False for jargon in jargons)
|
||||
pending = sum(jargon.is_jargon is None for jargon in jargons)
|
||||
complete_count = sum(jargon.is_complete for jargon in jargons)
|
||||
|
||||
complete_count = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete))).one()
|
||||
top_chats_counter: dict[str, int] = {}
|
||||
for jargon in jargons:
|
||||
for session_id in parse_session_id_dict(jargon.session_id_dict):
|
||||
top_chats_counter[session_id] = top_chats_counter.get(session_id, 0) + 1
|
||||
|
||||
chat_count = session.exec(
|
||||
select(fn.count()).select_from(
|
||||
select(col(Jargon.session_id)).distinct().where(col(Jargon.session_id).is_not(None)).subquery()
|
||||
)
|
||||
).one()
|
||||
|
||||
top_chats = session.exec(
|
||||
select(col(Jargon.session_id), fn.count().label("count"))
|
||||
.where(col(Jargon.session_id).is_not(None))
|
||||
.group_by(col(Jargon.session_id))
|
||||
.order_by(fn.count().desc())
|
||||
.limit(5)
|
||||
).all()
|
||||
top_chats_dict = {session_id: count for session_id, count in top_chats if session_id}
|
||||
top_chats_dict = dict(sorted(top_chats_counter.items(), key=lambda item: item[1], reverse=True)[:5])
|
||||
chat_count = len(top_chats_counter)
|
||||
|
||||
return JargonStatsResponse(
|
||||
success=True,
|
||||
@@ -360,8 +389,7 @@ async def get_jargon_detail(jargon_id: int):
|
||||
"""获取黑话详情"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
|
||||
if not jargon:
|
||||
if not (jargon := session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()):
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
data = JargonResponse(**jargon_to_dict(jargon, session))
|
||||
|
||||
@@ -379,19 +407,19 @@ async def create_jargon(request: JargonCreateRequest):
|
||||
"""创建黑话"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
existing = session.exec(
|
||||
select(Jargon).where(
|
||||
(col(Jargon.content) == request.content) & (col(Jargon.session_id) == request.chat_id)
|
||||
)
|
||||
).first()
|
||||
if existing:
|
||||
same_content_jargons = session.exec(select(Jargon).where(col(Jargon.content) == request.content)).all()
|
||||
existing = next(
|
||||
(jargon for jargon in same_content_jargons if has_chat_id(jargon.session_id_dict, request.chat_id)),
|
||||
None,
|
||||
)
|
||||
if existing is not None:
|
||||
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
|
||||
|
||||
jargon = Jargon(
|
||||
content=request.content,
|
||||
raw_content=request.raw_content,
|
||||
meaning=request.meaning or "",
|
||||
session_id=request.chat_id,
|
||||
session_id_dict=build_session_id_dict_for_chat(request.chat_id),
|
||||
count=0,
|
||||
is_jargon=None,
|
||||
is_complete=False,
|
||||
@@ -420,13 +448,12 @@ async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
if update_data:
|
||||
if update_data := request.model_dump(exclude_unset=True):
|
||||
for field, value in update_data.items():
|
||||
if field == "is_global":
|
||||
continue
|
||||
if field == "chat_id":
|
||||
jargon.session_id = value
|
||||
jargon.session_id_dict = build_session_id_dict_for_chat(value, max(jargon.count, 1))
|
||||
continue
|
||||
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
|
||||
setattr(jargon, field, value)
|
||||
|
||||
@@ -1,18 +1,28 @@
|
||||
"""知识库图谱可视化 API 路由"""
|
||||
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Query, Depends, Cookie, Header
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||
from src.config.config import global_config
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
|
||||
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"], dependencies=[Depends(require_auth)])
|
||||
|
||||
# 延迟初始化的轻量级 embedding store(只读,仅用于获取段落完整文本)
|
||||
_paragraph_store_cache = None
|
||||
_paragraph_store_cache: Any = None
|
||||
|
||||
|
||||
def _get_embedding_dir() -> str:
|
||||
"""获取 embedding 数据目录。"""
|
||||
import os
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_path = os.path.abspath(os.path.join(current_dir, "..", ".."))
|
||||
return os.path.join(root_path, "data/embedding")
|
||||
|
||||
|
||||
def _get_paragraph_store():
|
||||
@@ -31,17 +41,11 @@ def _get_paragraph_store():
|
||||
|
||||
try:
|
||||
from src.chat.knowledge.embedding_store import EmbeddingStore
|
||||
import os
|
||||
|
||||
# 获取数据路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_path = os.path.abspath(os.path.join(current_dir, "..", ".."))
|
||||
embedding_dir = os.path.join(root_path, "data/embedding")
|
||||
|
||||
# 只加载段落 embedding store(轻量级)
|
||||
paragraph_store = EmbeddingStore(
|
||||
namespace="paragraph",
|
||||
dir_path=embedding_dir,
|
||||
dir_path=_get_embedding_dir(),
|
||||
max_workers=1, # 只读不需要多线程
|
||||
chunk_size=100,
|
||||
)
|
||||
@@ -74,8 +78,7 @@ def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
|
||||
paragraph_item = paragraph_store.store.get(node_id)
|
||||
if paragraph_item is not None:
|
||||
# paragraph_item 是 EmbeddingStoreItem,其 str 属性包含完整文本
|
||||
content: str = getattr(paragraph_item, "str", "")
|
||||
if content:
|
||||
if content := getattr(paragraph_item, "str", ""):
|
||||
return content, True
|
||||
return None, True
|
||||
except Exception as e:
|
||||
@@ -83,14 +86,6 @@ def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
|
||||
return None, True
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
class KnowledgeNode(BaseModel):
|
||||
"""知识节点"""
|
||||
|
||||
@@ -205,7 +200,6 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
||||
async def get_knowledge_graph(
|
||||
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
|
||||
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""获取知识图谱(限制节点数量)
|
||||
|
||||
@@ -303,7 +297,7 @@ async def get_knowledge_graph(
|
||||
|
||||
|
||||
@router.get("/stats", response_model=KnowledgeStats)
|
||||
async def get_knowledge_stats(_auth: bool = Depends(require_auth)):
|
||||
async def get_knowledge_stats():
|
||||
"""获取知识库统计信息
|
||||
|
||||
Returns:
|
||||
@@ -352,7 +346,7 @@ async def get_knowledge_stats(_auth: bool = Depends(require_auth)):
|
||||
|
||||
|
||||
@router.get("/search", response_model=List[KnowledgeNode])
|
||||
async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bool = Depends(require_auth)):
|
||||
async def search_knowledge_node(query: str = Query(..., min_length=1)):
|
||||
"""搜索知识节点
|
||||
|
||||
Args:
|
||||
|
||||
@@ -6,27 +6,18 @@
|
||||
|
||||
import os
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header
|
||||
from typing import Optional
|
||||
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import CONFIG_DIR
|
||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["models"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["models"], dependencies=[Depends(require_auth)])
|
||||
# 模型获取器配置
|
||||
MODEL_FETCHER_CONFIG = {
|
||||
# OpenAI 兼容格式的提供商
|
||||
@@ -44,9 +35,7 @@ MODEL_FETCHER_CONFIG = {
|
||||
|
||||
def _normalize_url(url: str) -> str:
|
||||
"""规范化 URL(去掉尾部斜杠)"""
|
||||
if not url:
|
||||
return ""
|
||||
return url.rstrip("/")
|
||||
return url.rstrip("/") if url else ""
|
||||
|
||||
|
||||
def _parse_openai_response(data: dict) -> list[dict]:
|
||||
@@ -55,18 +44,18 @@ def _parse_openai_response(data: dict) -> list[dict]:
|
||||
|
||||
格式: { "data": [{ "id": "gpt-4", "object": "model", ... }] }
|
||||
"""
|
||||
models = []
|
||||
if "data" in data and isinstance(data["data"], list):
|
||||
for model in data["data"]:
|
||||
if isinstance(model, dict) and "id" in model:
|
||||
models.append(
|
||||
{
|
||||
"id": model["id"],
|
||||
"name": model.get("name") or model["id"],
|
||||
"owned_by": model.get("owned_by", ""),
|
||||
}
|
||||
)
|
||||
return models
|
||||
if "data" not in data or not isinstance(data["data"], list):
|
||||
return []
|
||||
|
||||
return [
|
||||
{
|
||||
"id": model["id"],
|
||||
"name": model.get("name") or model["id"],
|
||||
"owned_by": model.get("owned_by", ""),
|
||||
}
|
||||
for model in data["data"]
|
||||
if isinstance(model, dict) and "id" in model
|
||||
]
|
||||
|
||||
|
||||
def _parse_gemini_response(data: dict) -> list[dict]:
|
||||
@@ -178,11 +167,8 @@ def _get_provider_config(provider_name: str) -> Optional[dict]:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
providers = config_data.get("api_providers", [])
|
||||
for provider in providers:
|
||||
if provider.get("name") == provider_name:
|
||||
return dict(provider)
|
||||
|
||||
return None
|
||||
provider = next((provider for provider in providers if provider.get("name") == provider_name), None)
|
||||
return dict(provider) if provider is not None else None
|
||||
except Exception as e:
|
||||
logger.error(f"读取提供商配置失败: {e}")
|
||||
return None
|
||||
@@ -193,7 +179,6 @@ async def get_provider_models(
|
||||
provider_name: str = Query(..., description="提供商名称"),
|
||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
获取指定提供商的可用模型列表
|
||||
@@ -238,7 +223,6 @@ async def get_models_by_url(
|
||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
通过 URL 直接获取模型列表(用于自定义提供商)
|
||||
@@ -262,7 +246,6 @@ async def get_models_by_url(
|
||||
async def test_provider_connection(
|
||||
base_url: str = Query(..., description="提供商的基础 URL"),
|
||||
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
测试提供商连接状态
|
||||
@@ -349,7 +332,6 @@ async def test_provider_connection(
|
||||
@router.post("/test-connection-by-name")
|
||||
async def test_provider_connection_by_name(
|
||||
provider_name: str = Query(..., description="提供商名称"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
通过提供商名称测试连接(从配置文件读取信息)
|
||||
@@ -364,11 +346,7 @@ async def test_provider_connection_by_name(
|
||||
|
||||
# 查找提供商
|
||||
providers = config.get("api_providers", [])
|
||||
provider = None
|
||||
for p in providers:
|
||||
if p.get("name") == provider_name:
|
||||
provider = p
|
||||
break
|
||||
provider = next((item for item in providers if item.get("name") == provider_name), None)
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
|
||||
@@ -380,4 +358,4 @@ async def test_provider_connection_by_name(
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
||||
|
||||
# 调用测试接口
|
||||
return await test_provider_connection(base_url=base_url, api_key=api_key if api_key else None)
|
||||
return await test_provider_connection(base_url=base_url, api_key=api_key or None)
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
"""人物信息管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
from datetime import datetime
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sqlalchemy import case
|
||||
from sqlmodel import col, select, delete
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||
import json
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
logger = get_logger("webui.person")
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/person", tags=["Person"])
|
||||
router = APIRouter(prefix="/person", tags=["Person"], dependencies=[Depends(require_auth)])
|
||||
|
||||
|
||||
class PersonInfoResponse(BaseModel):
|
||||
@@ -96,14 +97,6 @@ class BatchDeleteResponse(BaseModel):
|
||||
failed_ids: List[str] = []
|
||||
|
||||
|
||||
def verify_auth_token(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""验证认证 Token,支持 Cookie 和 Header"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
def parse_group_nick_name(group_nick_name_str: Optional[str]) -> Optional[List[Dict[str, str]]]:
|
||||
"""解析群昵称 JSON 字符串"""
|
||||
if not group_nick_name_str:
|
||||
@@ -127,7 +120,7 @@ def person_to_response(person: PersonInfo) -> PersonInfoResponse:
|
||||
platform=person.platform,
|
||||
user_id=person.user_id,
|
||||
nickname=person.user_nickname,
|
||||
group_nick_name=parse_group_nick_name(person.group_nickname),
|
||||
group_nick_name=parse_group_nick_name(person.group_cardname),
|
||||
memory_points=person.memory_points,
|
||||
know_times=person.know_counts,
|
||||
know_since=know_since,
|
||||
@@ -142,8 +135,6 @@ async def get_person_list(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
|
||||
platform: Optional[str] = Query(None, description="平台筛选"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取人物信息列表
|
||||
@@ -154,14 +145,11 @@ async def get_person_list(
|
||||
search: 搜索关键词 (匹配 person_name, nickname, user_id)
|
||||
is_known: 是否已认识筛选
|
||||
platform: 平台筛选
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
人物信息列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 构建查询
|
||||
statement = select(PersonInfo)
|
||||
|
||||
@@ -219,22 +207,17 @@ async def get_person_list(
|
||||
|
||||
|
||||
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
||||
async def get_person_detail(
|
||||
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_person_detail(person_id: str):
|
||||
"""
|
||||
获取人物详细信息
|
||||
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
人物详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||
person = session.exec(statement).first()
|
||||
@@ -255,8 +238,6 @@ async def get_person_detail(
|
||||
async def update_person(
|
||||
person_id: str,
|
||||
request: PersonUpdateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
增量更新人物信息(只更新提供的字段)
|
||||
@@ -264,14 +245,11 @@ async def update_person(
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
request: 更新请求(只包含需要更新的字段)
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||
person = session.exec(statement).first()
|
||||
@@ -313,22 +291,17 @@ async def update_person(
|
||||
|
||||
|
||||
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
||||
async def delete_person(
|
||||
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def delete_person(person_id: str):
|
||||
"""
|
||||
删除人物信息
|
||||
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||
person = session.exec(statement).first()
|
||||
@@ -355,19 +328,14 @@ async def delete_person(
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def get_person_stats():
|
||||
"""
|
||||
获取人物信息统计数据
|
||||
|
||||
Args:
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
with get_db_session() as session:
|
||||
total = len(session.exec(select(PersonInfo.id)).all())
|
||||
known = len(session.exec(select(PersonInfo.id).where(col(PersonInfo.is_known))).all())
|
||||
@@ -392,22 +360,17 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori
|
||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||
async def batch_delete_persons(
|
||||
request: BatchDeleteRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量删除人物信息
|
||||
|
||||
Args:
|
||||
request: 包含person_ids列表的请求
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
批量删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not request.person_ids:
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的人物ID")
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
17
src/webui/routers/plugin/__init__.py
Normal file
17
src/webui/routers/plugin/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from src.webui.services.git_mirror_service import set_update_progress_callback
|
||||
|
||||
from .catalog import router as catalog_router
|
||||
from .config_routes import router as config_router
|
||||
from .management import router as management_router
|
||||
from .progress import get_progress_router, update_progress
|
||||
|
||||
router = APIRouter(prefix="/plugins", tags=["插件管理"])
|
||||
router.include_router(catalog_router)
|
||||
router.include_router(management_router)
|
||||
router.include_router(config_router)
|
||||
|
||||
set_update_progress_callback(update_progress)
|
||||
|
||||
__all__ = ["get_progress_router", "router"]
|
||||
205
src/webui/routers/plugin/catalog.py
Normal file
205
src/webui/routers/plugin/catalog.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Cookie, HTTPException
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import MMC_VERSION
|
||||
from src.webui.services.git_mirror_service import get_git_mirror_service
|
||||
|
||||
from .progress import update_progress
|
||||
from .schemas import (
|
||||
AddMirrorRequest,
|
||||
AvailableMirrorsResponse,
|
||||
CloneRepositoryRequest,
|
||||
CloneRepositoryResponse,
|
||||
FetchRawFileRequest,
|
||||
FetchRawFileResponse,
|
||||
GitStatusResponse,
|
||||
MirrorConfigResponse,
|
||||
UpdateMirrorRequest,
|
||||
VersionResponse,
|
||||
)
|
||||
from .support import get_plugins_dir, parse_version, require_plugin_token, validate_safe_path
|
||||
|
||||
logger = get_logger("webui.plugin_routes")
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _mirror_to_response(mirror: dict[str, Any]) -> MirrorConfigResponse:
|
||||
return MirrorConfigResponse(
|
||||
id=mirror["id"],
|
||||
name=mirror["name"],
|
||||
raw_prefix=mirror["raw_prefix"],
|
||||
clone_prefix=mirror["clone_prefix"],
|
||||
enabled=mirror["enabled"],
|
||||
priority=mirror["priority"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/version", response_model=VersionResponse)
|
||||
async def get_maimai_version() -> VersionResponse:
|
||||
major, minor, patch = parse_version(MMC_VERSION)
|
||||
return VersionResponse(version=MMC_VERSION, version_major=major, version_minor=minor, version_patch=patch)
|
||||
|
||||
|
||||
@router.get("/git-status", response_model=GitStatusResponse)
|
||||
async def check_git_status() -> GitStatusResponse:
|
||||
service = get_git_mirror_service()
|
||||
return GitStatusResponse(**service.check_git_installed())
|
||||
|
||||
|
||||
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
|
||||
async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None)) -> AvailableMirrorsResponse:
|
||||
require_plugin_token(maibot_session)
|
||||
|
||||
service = get_git_mirror_service()
|
||||
config = service.get_mirror_config()
|
||||
mirrors = [_mirror_to_response(mirror) for mirror in config.get_all_mirrors()]
|
||||
return AvailableMirrorsResponse(mirrors=mirrors, default_priority=config.get_default_priority_list())
|
||||
|
||||
|
||||
@router.post("/mirrors", response_model=MirrorConfigResponse)
|
||||
async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None)) -> MirrorConfigResponse:
|
||||
require_plugin_token(maibot_session)
|
||||
|
||||
try:
|
||||
service = get_git_mirror_service()
|
||||
config = service.get_mirror_config()
|
||||
mirror = config.add_mirror(
|
||||
mirror_id=request.id,
|
||||
name=request.name,
|
||||
raw_prefix=request.raw_prefix,
|
||||
clone_prefix=request.clone_prefix,
|
||||
enabled=request.enabled,
|
||||
priority=request.priority,
|
||||
)
|
||||
return _mirror_to_response(mirror)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
logger.error(f"添加镜像源失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
|
||||
async def update_mirror(
|
||||
mirror_id: str,
|
||||
request: UpdateMirrorRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
) -> MirrorConfigResponse:
|
||||
require_plugin_token(maibot_session)
|
||||
|
||||
try:
|
||||
service = get_git_mirror_service()
|
||||
config = service.get_mirror_config()
|
||||
mirror = config.update_mirror(
|
||||
mirror_id=mirror_id,
|
||||
name=request.name,
|
||||
raw_prefix=request.raw_prefix,
|
||||
clone_prefix=request.clone_prefix,
|
||||
enabled=request.enabled,
|
||||
priority=request.priority,
|
||||
)
|
||||
if mirror is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}")
|
||||
return _mirror_to_response(mirror)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新镜像源失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.delete("/mirrors/{mirror_id}")
|
||||
async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
|
||||
service = get_git_mirror_service()
|
||||
config = service.get_mirror_config()
|
||||
if not config.delete_mirror(mirror_id):
|
||||
raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}")
|
||||
return {"success": True, "message": f"已删除镜像源: {mirror_id}"}
|
||||
|
||||
|
||||
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
|
||||
async def fetch_raw_file(
|
||||
request: FetchRawFileRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
) -> FetchRawFileResponse:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"收到获取 Raw 文件请求: {request.owner}/{request.repo}/{request.branch}/{request.file_path}")
|
||||
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=10,
|
||||
message=f"正在获取插件列表: {request.file_path}",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0,
|
||||
)
|
||||
|
||||
try:
|
||||
service = get_git_mirror_service()
|
||||
result = await service.fetch_raw_file(
|
||||
owner=request.owner,
|
||||
repo=request.repo,
|
||||
branch=request.branch,
|
||||
file_path=request.file_path,
|
||||
mirror_id=request.mirror_id,
|
||||
custom_url=request.custom_url,
|
||||
)
|
||||
|
||||
if result.get("success"):
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=70,
|
||||
message="正在解析插件数据...",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0,
|
||||
)
|
||||
try:
|
||||
data = json.loads(result.get("data", "[]"))
|
||||
total = len(data) if isinstance(data, list) else 0
|
||||
await update_progress(
|
||||
stage="success",
|
||||
progress=100,
|
||||
message=f"成功加载 {total} 个插件",
|
||||
total_plugins=total,
|
||||
loaded_plugins=total,
|
||||
)
|
||||
except Exception:
|
||||
await update_progress(stage="success", progress=100, message="加载完成", total_plugins=0, loaded_plugins=0)
|
||||
|
||||
return FetchRawFileResponse(**result)
|
||||
except Exception as e:
|
||||
logger.error(f"获取 Raw 文件失败: {e}")
|
||||
await update_progress(stage="error", progress=0, message="加载失败", error=str(e), total_plugins=0, loaded_plugins=0)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/clone", response_model=CloneRepositoryResponse)
|
||||
async def clone_repository(
|
||||
request: CloneRepositoryRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
) -> CloneRepositoryResponse:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}")
|
||||
|
||||
try:
|
||||
target_path = validate_safe_path(request.target_path, get_plugins_dir())
|
||||
service = get_git_mirror_service()
|
||||
result = await service.clone_repository(
|
||||
owner=request.owner,
|
||||
repo=request.repo,
|
||||
target_path=target_path,
|
||||
branch=request.branch,
|
||||
mirror_id=request.mirror_id,
|
||||
custom_url=request.custom_url,
|
||||
depth=request.depth,
|
||||
)
|
||||
return CloneRepositoryResponse(**result)
|
||||
except Exception as e:
|
||||
logger.error(f"克隆仓库失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
333
src/webui/routers/plugin/config_routes.py
Normal file
333
src/webui/routers/plugin/config_routes.py
Normal file
@@ -0,0 +1,333 @@
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import json
|
||||
import tomlkit
|
||||
|
||||
from fastapi import APIRouter, Cookie, HTTPException
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.utils.toml_utils import save_toml_with_format
|
||||
|
||||
from .schemas import UpdatePluginConfigRequest, UpdatePluginRawConfigRequest
|
||||
from .support import (
|
||||
backup_file,
|
||||
coerce_types,
|
||||
find_plugin_instance,
|
||||
find_plugin_path_by_id,
|
||||
normalize_dotted_keys,
|
||||
require_plugin_token,
|
||||
)
|
||||
|
||||
logger = get_logger("webui.plugin_routes")
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> dict[str, Any]:
|
||||
schema: dict[str, Any] = {
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_info": {
|
||||
"name": plugin_id,
|
||||
"version": "",
|
||||
"description": "",
|
||||
"author": "",
|
||||
},
|
||||
"sections": {},
|
||||
"layout": {"type": "auto", "tabs": []},
|
||||
"_note": "插件未加载,仅返回当前配置结构",
|
||||
}
|
||||
|
||||
for section_name, section_data in current_config.items():
|
||||
if not isinstance(section_data, dict):
|
||||
continue
|
||||
section_fields: dict[str, Any] = {}
|
||||
for field_name, field_value in section_data.items():
|
||||
field_type = type(field_value).__name__
|
||||
ui_type = "text"
|
||||
item_type = None
|
||||
item_fields = None
|
||||
|
||||
if isinstance(field_value, bool):
|
||||
ui_type = "switch"
|
||||
elif isinstance(field_value, (int, float)):
|
||||
ui_type = "number"
|
||||
elif isinstance(field_value, list):
|
||||
ui_type = "list"
|
||||
if field_value:
|
||||
first_item = field_value[0]
|
||||
if isinstance(first_item, dict):
|
||||
item_type = "object"
|
||||
item_fields = {
|
||||
key: {
|
||||
"type": "number" if isinstance(value, (int, float)) else "string",
|
||||
"label": key,
|
||||
"default": "" if isinstance(value, str) else 0,
|
||||
}
|
||||
for key, value in first_item.items()
|
||||
}
|
||||
elif isinstance(first_item, (int, float)):
|
||||
item_type = "number"
|
||||
else:
|
||||
item_type = "string"
|
||||
else:
|
||||
item_type = "string"
|
||||
elif isinstance(field_value, dict):
|
||||
ui_type = "json"
|
||||
|
||||
section_fields[field_name] = {
|
||||
"name": field_name,
|
||||
"type": field_type,
|
||||
"default": field_value,
|
||||
"description": field_name,
|
||||
"label": field_name,
|
||||
"ui_type": ui_type,
|
||||
"required": False,
|
||||
"hidden": False,
|
||||
"disabled": False,
|
||||
"order": 0,
|
||||
"item_type": item_type,
|
||||
"item_fields": item_fields,
|
||||
"min_items": None,
|
||||
"max_items": None,
|
||||
"placeholder": None,
|
||||
"hint": None,
|
||||
"icon": None,
|
||||
"example": None,
|
||||
"choices": None,
|
||||
"min": None,
|
||||
"max": None,
|
||||
"step": None,
|
||||
"pattern": None,
|
||||
"max_length": None,
|
||||
"input_type": None,
|
||||
"rows": 3,
|
||||
"group": None,
|
||||
"depends_on": None,
|
||||
"depends_value": None,
|
||||
}
|
||||
|
||||
schema["sections"][section_name] = {
|
||||
"name": section_name,
|
||||
"title": section_name,
|
||||
"description": None,
|
||||
"icon": None,
|
||||
"collapsed": False,
|
||||
"order": 0,
|
||||
"fields": section_fields,
|
||||
}
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
@router.get("/config/{plugin_id}/schema")
|
||||
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"获取插件配置 Schema: {plugin_id}")
|
||||
|
||||
try:
|
||||
plugin_instance = find_plugin_instance(plugin_id)
|
||||
if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"):
|
||||
return {"success": True, "schema": plugin_instance.get_webui_config_schema()}
|
||||
|
||||
plugin_path = find_plugin_path_by_id(plugin_id)
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
schema_json_path = plugin_path / "config_schema.json"
|
||||
if schema_json_path.exists():
|
||||
try:
|
||||
with open(schema_json_path, "r", encoding="utf-8") as file_obj:
|
||||
return {"success": True, "schema": json.load(file_obj)}
|
||||
except Exception as e:
|
||||
logger.warning(f"读取 config_schema.json 失败,回退到自动推断: {e}")
|
||||
|
||||
current_config: Any = {}
|
||||
config_path = plugin_path / "config.toml"
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as file_obj:
|
||||
current_config = tomlkit.load(file_obj)
|
||||
|
||||
return {"success": True, "schema": _build_schema_from_current_config(plugin_id, current_config)}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件配置 Schema 失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/config/{plugin_id}/raw")
|
||||
async def get_plugin_config_raw(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"获取插件原始配置: {plugin_id}")
|
||||
|
||||
try:
|
||||
plugin_path = find_plugin_path_by_id(plugin_id)
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
if not config_path.exists():
|
||||
return {"success": True, "config": "", "message": "配置文件不存在"}
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as file_obj:
|
||||
return {"success": True, "config": file_obj.read()}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件原始配置失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.put("/config/{plugin_id}/raw")
|
||||
async def update_plugin_config_raw(
|
||||
plugin_id: str,
|
||||
request: UpdatePluginRawConfigRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
) -> dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"更新插件原始配置: {plugin_id}")
|
||||
|
||||
try:
|
||||
plugin_path = find_plugin_path_by_id(plugin_id)
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
try:
|
||||
tomlkit.loads(request.config)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
|
||||
|
||||
backup_path = backup_file(config_path, "backup")
|
||||
if backup_path is not None:
|
||||
logger.info(f"已备份配置文件: {backup_path}")
|
||||
|
||||
with open(config_path, "w", encoding="utf-8") as file_obj:
|
||||
file_obj.write(request.config)
|
||||
|
||||
logger.info(f"已更新插件原始配置: {plugin_id}")
|
||||
return {"success": True, "message": "配置已保存", "note": "配置更改将在插件重新加载后生效"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新插件原始配置失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/config/{plugin_id}")
|
||||
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"获取插件配置: {plugin_id}")
|
||||
|
||||
try:
|
||||
plugin_path = find_plugin_path_by_id(plugin_id)
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
if not config_path.exists():
|
||||
return {"success": True, "config": {}, "message": "配置文件不存在"}
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as file_obj:
|
||||
config = tomlkit.load(file_obj)
|
||||
return {"success": True, "config": dict(config)}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件配置失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.put("/config/{plugin_id}")
|
||||
async def update_plugin_config(
|
||||
plugin_id: str,
|
||||
request: UpdatePluginConfigRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
) -> dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"更新插件配置: {plugin_id}")
|
||||
|
||||
try:
|
||||
plugin_instance = find_plugin_instance(plugin_id)
|
||||
config_data = request.config or {}
|
||||
if plugin_instance and isinstance(config_data, dict):
|
||||
config_data = normalize_dotted_keys(config_data)
|
||||
if isinstance(plugin_instance.config_schema, dict):
|
||||
coerce_types(plugin_instance.config_schema, config_data)
|
||||
|
||||
plugin_path = find_plugin_path_by_id(plugin_id)
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
backup_path = backup_file(config_path, "backup")
|
||||
if backup_path is not None:
|
||||
logger.info(f"已备份配置文件: {backup_path}")
|
||||
|
||||
save_toml_with_format(config_data, str(config_path))
|
||||
logger.info(f"已更新插件配置: {plugin_id}")
|
||||
return {"success": True, "message": "配置已保存", "note": "配置更改将在插件重新加载后生效"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新插件配置失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/config/{plugin_id}/reset")
|
||||
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"重置插件配置: {plugin_id}")
|
||||
|
||||
try:
|
||||
plugin_path = find_plugin_path_by_id(plugin_id)
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
if not config_path.exists():
|
||||
return {"success": True, "message": "配置文件不存在,无需重置"}
|
||||
|
||||
backup_path = backup_file(config_path, "reset", move_file=True)
|
||||
logger.info(f"已重置插件配置: {plugin_id},备份: {backup_path}")
|
||||
return {"success": True, "message": "配置已重置,下次加载插件时将使用默认配置", "backup": str(backup_path)}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"重置插件配置失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/config/{plugin_id}/toggle")
|
||||
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"切换插件状态: {plugin_id}")
|
||||
|
||||
try:
|
||||
plugin_path = find_plugin_path_by_id(plugin_id)
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
config = tomlkit.document()
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as file_obj:
|
||||
config = tomlkit.load(file_obj)
|
||||
|
||||
if "plugin" not in config:
|
||||
config["plugin"] = tomlkit.table()
|
||||
|
||||
plugin_config = cast(Any, config["plugin"])
|
||||
current_enabled = bool(plugin_config.get("enabled", True))
|
||||
new_enabled = not current_enabled
|
||||
plugin_config["enabled"] = new_enabled
|
||||
save_toml_with_format(config, str(config_path))
|
||||
|
||||
status = "启用" if new_enabled else "禁用"
|
||||
logger.info(f"已{status}插件: {plugin_id}")
|
||||
return {"success": True, "enabled": new_enabled, "message": f"插件已{status}", "note": "状态更改将在下次加载插件时生效"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"切换插件状态失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
302
src/webui/routers/plugin/management.py
Normal file
302
src/webui/routers/plugin/management.py
Normal file
@@ -0,0 +1,302 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Cookie, HTTPException
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.services.git_mirror_service import get_git_mirror_service
|
||||
|
||||
from .progress import update_progress
|
||||
from .schemas import InstallPluginRequest, UninstallPluginRequest, UpdatePluginRequest
|
||||
from .support import (
|
||||
find_plugin_path_by_id,
|
||||
get_plugin_candidate_paths,
|
||||
get_plugins_dir,
|
||||
load_manifest_json,
|
||||
parse_repository_url,
|
||||
remove_tree,
|
||||
require_plugin_token,
|
||||
resolve_installed_plugin_path,
|
||||
validate_plugin_id,
|
||||
)
|
||||
|
||||
logger = get_logger("webui.plugin_routes")
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _infer_plugin_id(folder_name: str, manifest: dict[str, Any], manifest_path: Path) -> str:
|
||||
if "id" in manifest:
|
||||
return str(manifest["id"])
|
||||
|
||||
author_name: Optional[str] = None
|
||||
repo_name: Optional[str] = None
|
||||
if "author" in manifest:
|
||||
author_data = manifest["author"]
|
||||
if isinstance(author_data, dict) and "name" in author_data:
|
||||
author_name = str(author_data["name"])
|
||||
elif isinstance(author_data, str):
|
||||
author_name = author_data
|
||||
|
||||
if "repository_url" in manifest:
|
||||
repo_url = str(manifest["repository_url"]).rstrip("/").removesuffix(".git")
|
||||
repo_name = repo_url.split("/")[-1]
|
||||
|
||||
if author_name and repo_name:
|
||||
plugin_id = f"{author_name}.{repo_name}"
|
||||
elif author_name:
|
||||
plugin_id = f"{author_name}.{folder_name}"
|
||||
elif "_" in folder_name and "." not in folder_name:
|
||||
plugin_id = folder_name.replace("_", ".", 1)
|
||||
else:
|
||||
plugin_id = folder_name
|
||||
|
||||
logger.info(f"为插件 {folder_name} 自动生成 ID: {plugin_id}")
|
||||
manifest["id"] = plugin_id
|
||||
try:
|
||||
with open(manifest_path, "w", encoding="utf-8") as file_obj:
|
||||
json.dump(manifest, file_obj, ensure_ascii=False, indent=2)
|
||||
except Exception as write_error:
|
||||
logger.warning(f"无法写入 ID 到 manifest: {write_error}")
|
||||
return plugin_id
|
||||
|
||||
|
||||
@router.post("/install")
|
||||
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"收到安装插件请求: {request.plugin_id}")
|
||||
plugin_id = request.plugin_id
|
||||
|
||||
try:
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
await update_progress(stage="loading", progress=5, message=f"开始安装插件: {plugin_id}", operation="install", plugin_id=plugin_id)
|
||||
|
||||
repo_url, owner, repo = parse_repository_url(request.repository_url)
|
||||
await update_progress(stage="loading", progress=10, message=f"解析仓库信息: {owner}/{repo}", operation="install", plugin_id=plugin_id)
|
||||
|
||||
target_path, old_format_path = get_plugin_candidate_paths(plugin_id)
|
||||
if target_path.exists() or old_format_path.exists():
|
||||
await update_progress(stage="error", progress=0, message="插件已存在", operation="install", plugin_id=plugin_id, error="插件已安装,请先卸载")
|
||||
raise HTTPException(status_code=400, detail="插件已安装")
|
||||
|
||||
await update_progress(stage="loading", progress=15, message=f"准备克隆到: {target_path}", operation="install", plugin_id=plugin_id)
|
||||
service = get_git_mirror_service()
|
||||
if "github.com" in repo_url:
|
||||
result = await service.clone_repository(owner=owner, repo=repo, target_path=target_path, branch=request.branch, mirror_id=request.mirror_id, depth=1)
|
||||
else:
|
||||
result = await service.clone_repository(owner=owner, repo=repo, target_path=target_path, branch=request.branch, custom_url=repo_url, depth=1)
|
||||
|
||||
if not result.get("success"):
|
||||
error_msg = str(result.get("error", "克隆失败"))
|
||||
await update_progress(stage="error", progress=0, message="克隆仓库失败", operation="install", plugin_id=plugin_id, error=error_msg)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
||||
await update_progress(stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id)
|
||||
manifest_path = target_path / "_manifest.json"
|
||||
if not manifest_path.exists():
|
||||
remove_tree(target_path)
|
||||
await update_progress(stage="error", progress=0, message="插件缺少 _manifest.json", operation="install", plugin_id=plugin_id, error="无效的插件格式")
|
||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||
|
||||
await update_progress(stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id)
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as file_obj:
|
||||
manifest = json.load(file_obj)
|
||||
for field in ["manifest_version", "name", "version", "author"]:
|
||||
if field not in manifest:
|
||||
raise ValueError(f"缺少必需字段: {field}")
|
||||
manifest["id"] = plugin_id
|
||||
with open(manifest_path, "w", encoding="utf-8") as file_obj:
|
||||
json.dump(manifest, file_obj, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
remove_tree(target_path)
|
||||
await update_progress(stage="error", progress=0, message="_manifest.json 无效", operation="install", plugin_id=plugin_id, error=str(e))
|
||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||
|
||||
await update_progress(stage="success", progress=100, message=f"成功安装插件: {manifest['name']} v{manifest['version']}", operation="install", plugin_id=plugin_id)
|
||||
return {"success": True, "message": "插件安装成功", "plugin_id": plugin_id, "plugin_name": manifest["name"], "version": manifest["version"], "path": str(target_path)}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"安装插件失败: {e}", exc_info=True)
|
||||
await update_progress(stage="error", progress=0, message="安装失败", operation="install", plugin_id=plugin_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/uninstall")
|
||||
async def uninstall_plugin(request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"收到卸载插件请求: {request.plugin_id}")
|
||||
plugin_id = request.plugin_id
|
||||
|
||||
try:
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
await update_progress(stage="loading", progress=10, message=f"开始卸载插件: {plugin_id}", operation="uninstall", plugin_id=plugin_id)
|
||||
plugin_path = resolve_installed_plugin_path(plugin_id)
|
||||
if plugin_path is None:
|
||||
await update_progress(stage="error", progress=0, message="插件不存在", operation="uninstall", plugin_id=plugin_id, error="插件未安装或已被删除")
|
||||
raise HTTPException(status_code=404, detail="插件未安装")
|
||||
|
||||
await update_progress(stage="loading", progress=30, message=f"正在删除插件文件: {plugin_path}", operation="uninstall", plugin_id=plugin_id)
|
||||
manifest = load_manifest_json(plugin_path / "_manifest.json")
|
||||
plugin_name = str(manifest.get("name", plugin_id)) if manifest is not None else plugin_id
|
||||
await update_progress(stage="loading", progress=50, message=f"正在删除 {plugin_name}...", operation="uninstall", plugin_id=plugin_id)
|
||||
remove_tree(plugin_path)
|
||||
logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})")
|
||||
await update_progress(stage="success", progress=100, message=f"成功卸载插件: {plugin_name}", operation="uninstall", plugin_id=plugin_id)
|
||||
return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name}
|
||||
except HTTPException:
|
||||
raise
|
||||
except PermissionError as e:
|
||||
logger.error(f"卸载插件失败(权限错误): {e}")
|
||||
await update_progress(stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error="权限不足,无法删除插件文件")
|
||||
raise HTTPException(status_code=500, detail="权限不足,无法删除插件文件") from e
|
||||
except Exception as e:
|
||||
logger.error(f"卸载插件失败: {e}", exc_info=True)
|
||||
await update_progress(stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/update")
|
||||
async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"收到更新插件请求: {request.plugin_id}")
|
||||
plugin_id = request.plugin_id
|
||||
|
||||
try:
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
await update_progress(stage="loading", progress=5, message=f"开始更新插件: {plugin_id}", operation="update", plugin_id=plugin_id)
|
||||
plugin_path = resolve_installed_plugin_path(plugin_id)
|
||||
if plugin_path is None:
|
||||
await update_progress(stage="error", progress=0, message="插件不存在", operation="update", plugin_id=plugin_id, error="插件未安装,请先安装")
|
||||
raise HTTPException(status_code=404, detail="插件未安装")
|
||||
|
||||
manifest = load_manifest_json(plugin_path / "_manifest.json")
|
||||
old_version = str(manifest.get("version", "unknown")) if manifest is not None else "unknown"
|
||||
await update_progress(stage="loading", progress=10, message=f"当前版本: {old_version},准备更新...", operation="update", plugin_id=plugin_id)
|
||||
await update_progress(stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id)
|
||||
remove_tree(plugin_path)
|
||||
|
||||
await update_progress(stage="loading", progress=30, message="正在准备下载新版本...", operation="update", plugin_id=plugin_id)
|
||||
repo_url, owner, repo = parse_repository_url(request.repository_url)
|
||||
service = get_git_mirror_service()
|
||||
if "github.com" in repo_url:
|
||||
result = await service.clone_repository(owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, mirror_id=request.mirror_id, depth=1)
|
||||
else:
|
||||
result = await service.clone_repository(owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, custom_url=repo_url, depth=1)
|
||||
|
||||
if not result.get("success"):
|
||||
error_msg = str(result.get("error", "克隆失败"))
|
||||
await update_progress(stage="error", progress=0, message="下载新版本失败", operation="update", plugin_id=plugin_id, error=error_msg)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
||||
await update_progress(stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id)
|
||||
new_manifest_path = plugin_path / "_manifest.json"
|
||||
if not new_manifest_path.exists():
|
||||
remove_tree(plugin_path)
|
||||
await update_progress(stage="error", progress=0, message="新版本缺少 _manifest.json", operation="update", plugin_id=plugin_id, error="无效的插件格式")
|
||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||
|
||||
try:
|
||||
with open(new_manifest_path, "r", encoding="utf-8") as file_obj:
|
||||
new_manifest = json.load(file_obj)
|
||||
new_version = str(new_manifest.get("version", "unknown"))
|
||||
new_name = str(new_manifest.get("name", plugin_id))
|
||||
logger.info(f"成功更新插件: {plugin_id} {old_version} → {new_version}")
|
||||
await update_progress(stage="success", progress=100, message=f"成功更新 {new_name}: {old_version} → {new_version}", operation="update", plugin_id=plugin_id)
|
||||
return {"success": True, "message": "插件更新成功", "plugin_id": plugin_id, "plugin_name": new_name, "old_version": old_version, "new_version": new_version}
|
||||
except Exception as e:
|
||||
remove_tree(plugin_path)
|
||||
await update_progress(stage="error", progress=0, message="_manifest.json 无效", operation="update", plugin_id=plugin_id, error=str(e))
|
||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新插件失败: {e}", exc_info=True)
|
||||
await update_progress(stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/installed")
|
||||
async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info("收到获取已安装插件列表请求")
|
||||
|
||||
try:
|
||||
plugins_dir = get_plugins_dir()
|
||||
installed_plugins: list[dict[str, Any]] = []
|
||||
for plugin_path in plugins_dir.iterdir():
|
||||
if not plugin_path.is_dir():
|
||||
continue
|
||||
folder_name = plugin_path.name
|
||||
if folder_name.startswith(".") or folder_name.startswith("__"):
|
||||
continue
|
||||
|
||||
manifest_path = plugin_path / "_manifest.json"
|
||||
if not manifest_path.exists():
|
||||
logger.warning(f"插件文件夹 {folder_name} 缺少 _manifest.json,跳过")
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as file_obj:
|
||||
manifest = json.load(file_obj)
|
||||
if "name" not in manifest or "version" not in manifest:
|
||||
logger.warning(f"插件文件夹 {folder_name} 的 _manifest.json 格式无效,跳过")
|
||||
continue
|
||||
plugin_id = _infer_plugin_id(folder_name, manifest, manifest_path)
|
||||
installed_plugins.append({"id": plugin_id, "manifest": manifest, "path": str(plugin_path.absolute())})
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"插件 {folder_name} 的 _manifest.json 解析失败: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"读取插件 {folder_name} 信息时出错: {e}")
|
||||
|
||||
seen_ids: dict[str, str] = {}
|
||||
unique_plugins: list[dict[str, Any]] = []
|
||||
duplicates: list[dict[str, Any]] = []
|
||||
for plugin in installed_plugins:
|
||||
plugin_id = str(plugin["id"])
|
||||
plugin_path = str(plugin["path"])
|
||||
if plugin_id not in seen_ids:
|
||||
seen_ids[plugin_id] = plugin_path
|
||||
unique_plugins.append(plugin)
|
||||
else:
|
||||
duplicates.append(plugin)
|
||||
logger.warning(f"重复插件 {plugin_id}: 保留 {seen_ids[plugin_id]}, 跳过 {plugin_path}")
|
||||
|
||||
if duplicates:
|
||||
logger.warning(f"共检测到 {len(duplicates)} 个重复插件已去重")
|
||||
|
||||
logger.info(f"找到 {len(unique_plugins)} 个已安装插件")
|
||||
return {"success": True, "plugins": unique_plugins, "total": len(unique_plugins)}
|
||||
except Exception as e:
|
||||
logger.error(f"获取已安装插件列表失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/local-readme/{plugin_id}")
|
||||
async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"获取本地插件 README: {plugin_id}")
|
||||
|
||||
try:
|
||||
plugin_path = find_plugin_path_by_id(plugin_id)
|
||||
if plugin_path is None:
|
||||
return {"success": False, "error": "插件未安装"}
|
||||
|
||||
for readme_name in ["README.md", "readme.md", "Readme.md", "README.MD"]:
|
||||
readme_path = plugin_path / readme_name
|
||||
if readme_path.exists():
|
||||
try:
|
||||
with open(readme_path, "r", encoding="utf-8") as file_obj:
|
||||
readme_content = file_obj.read()
|
||||
logger.info(f"成功读取本地 README: {readme_path}")
|
||||
return {"success": True, "data": readme_content}
|
||||
except Exception as e:
|
||||
logger.warning(f"读取 {readme_path} 失败: {e}")
|
||||
|
||||
return {"success": False, "error": "本地未找到 README 文件"}
|
||||
except Exception as e:
|
||||
logger.error(f"获取本地 README 失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
@@ -1,36 +1,32 @@
|
||||
"""WebSocket 插件加载进度推送模块"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from typing import Set, Dict, Any, Optional
|
||||
import json
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from typing import Any, Optional, Set
|
||||
|
||||
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.plugin_progress")
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter()
|
||||
|
||||
# 全局 WebSocket 连接池
|
||||
active_connections: Set[WebSocket] = set()
|
||||
|
||||
# 当前加载进度状态
|
||||
current_progress: Dict[str, Any] = {
|
||||
"operation": "idle", # idle, fetch, install, uninstall, update
|
||||
"stage": "idle", # idle, loading, success, error
|
||||
"progress": 0, # 0-100
|
||||
current_progress: dict[str, Any] = {
|
||||
"operation": "idle",
|
||||
"stage": "idle",
|
||||
"progress": 0,
|
||||
"message": "",
|
||||
"error": None,
|
||||
"plugin_id": None, # 当前操作的插件 ID
|
||||
"plugin_id": None,
|
||||
"total_plugins": 0,
|
||||
"loaded_plugins": 0,
|
||||
}
|
||||
|
||||
|
||||
async def broadcast_progress(progress_data: Dict[str, Any]):
|
||||
"""广播进度更新到所有连接的客户端"""
|
||||
async def broadcast_progress(progress_data: dict[str, Any]) -> None:
|
||||
global current_progress
|
||||
current_progress = progress_data.copy()
|
||||
|
||||
@@ -38,7 +34,7 @@ async def broadcast_progress(progress_data: Dict[str, Any]):
|
||||
return
|
||||
|
||||
message = json.dumps(progress_data, ensure_ascii=False)
|
||||
disconnected = set()
|
||||
disconnected: set[WebSocket] = set()
|
||||
|
||||
for websocket in active_connections:
|
||||
try:
|
||||
@@ -47,7 +43,6 @@ async def broadcast_progress(progress_data: Dict[str, Any]):
|
||||
logger.error(f"发送进度更新失败: {e}")
|
||||
disconnected.add(websocket)
|
||||
|
||||
# 移除断开的连接
|
||||
for websocket in disconnected:
|
||||
active_connections.discard(websocket)
|
||||
|
||||
@@ -57,23 +52,11 @@ async def update_progress(
|
||||
progress: int,
|
||||
message: str,
|
||||
operation: str = "fetch",
|
||||
error: str = None,
|
||||
plugin_id: str = None,
|
||||
error: Optional[str] = None,
|
||||
plugin_id: Optional[str] = None,
|
||||
total_plugins: int = 0,
|
||||
loaded_plugins: int = 0,
|
||||
):
|
||||
"""更新并广播进度
|
||||
|
||||
Args:
|
||||
stage: 阶段 (idle, loading, success, error)
|
||||
progress: 进度百分比 (0-100)
|
||||
message: 当前消息
|
||||
operation: 操作类型 (fetch, install, uninstall, update)
|
||||
error: 错误信息(可选)
|
||||
plugin_id: 当前操作的插件 ID
|
||||
total_plugins: 总插件数
|
||||
loaded_plugins: 已加载插件数
|
||||
"""
|
||||
) -> None:
|
||||
progress_data = {
|
||||
"operation": operation,
|
||||
"stage": stage,
|
||||
@@ -91,25 +74,13 @@ async def update_progress(
|
||||
|
||||
|
||||
@router.websocket("/ws/plugin-progress")
|
||||
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)):
|
||||
"""WebSocket 插件加载进度推送端点
|
||||
|
||||
客户端连接后会立即收到当前进度状态
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/ws/plugin-progress?token=xxx
|
||||
"""
|
||||
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
|
||||
is_authenticated = False
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
@@ -118,7 +89,6 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
@@ -135,18 +105,13 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
|
||||
logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
|
||||
|
||||
try:
|
||||
# 发送当前进度状态
|
||||
await websocket.send_text(json.dumps(current_progress, ensure_ascii=False))
|
||||
|
||||
# 保持连接并处理客户端消息
|
||||
while True:
|
||||
try:
|
||||
data = await websocket.receive_text()
|
||||
|
||||
# 处理客户端心跳
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理客户端消息时出错: {e}")
|
||||
break
|
||||
@@ -160,5 +125,4 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
|
||||
|
||||
|
||||
def get_progress_router() -> APIRouter:
|
||||
"""获取插件进度 WebSocket 路由器"""
|
||||
return router
|
||||
return router
|
||||
113
src/webui/routers/plugin/schemas.py
Normal file
113
src/webui/routers/plugin/schemas.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FetchRawFileRequest(BaseModel):
|
||||
owner: str = Field(..., description="仓库所有者", examples=["MaiM-with-u"])
|
||||
repo: str = Field(..., description="仓库名称", examples=["plugin-repo"])
|
||||
branch: str = Field(..., description="分支名称", examples=["main"])
|
||||
file_path: str = Field(..., description="文件路径", examples=["plugin_details.json"])
|
||||
mirror_id: Optional[str] = Field(None, description="指定镜像源 ID")
|
||||
custom_url: Optional[str] = Field(None, description="自定义完整 URL")
|
||||
|
||||
|
||||
class FetchRawFileResponse(BaseModel):
|
||||
success: bool = Field(..., description="是否成功")
|
||||
data: Optional[str] = Field(None, description="文件内容")
|
||||
error: Optional[str] = Field(None, description="错误信息")
|
||||
mirror_used: Optional[str] = Field(None, description="使用的镜像源")
|
||||
attempts: int = Field(..., description="尝试次数")
|
||||
url: Optional[str] = Field(None, description="实际请求的 URL")
|
||||
|
||||
|
||||
class CloneRepositoryRequest(BaseModel):
|
||||
owner: str = Field(..., description="仓库所有者", examples=["MaiM-with-u"])
|
||||
repo: str = Field(..., description="仓库名称", examples=["plugin-repo"])
|
||||
target_path: str = Field(..., description="目标路径(相对于插件目录)")
|
||||
branch: Optional[str] = Field(None, description="分支名称", examples=["main"])
|
||||
mirror_id: Optional[str] = Field(None, description="指定镜像源 ID")
|
||||
custom_url: Optional[str] = Field(None, description="自定义克隆 URL")
|
||||
depth: Optional[int] = Field(None, description="克隆深度(浅克隆)", ge=1)
|
||||
|
||||
|
||||
class CloneRepositoryResponse(BaseModel):
|
||||
success: bool = Field(..., description="是否成功")
|
||||
path: Optional[str] = Field(None, description="克隆路径")
|
||||
error: Optional[str] = Field(None, description="错误信息")
|
||||
mirror_used: Optional[str] = Field(None, description="使用的镜像源")
|
||||
attempts: int = Field(..., description="尝试次数")
|
||||
url: Optional[str] = Field(None, description="实际克隆的 URL")
|
||||
message: Optional[str] = Field(None, description="附加信息")
|
||||
|
||||
|
||||
class MirrorConfigResponse(BaseModel):
|
||||
id: str = Field(..., description="镜像源 ID")
|
||||
name: str = Field(..., description="镜像源名称")
|
||||
raw_prefix: str = Field(..., description="Raw 文件前缀")
|
||||
clone_prefix: str = Field(..., description="克隆前缀")
|
||||
enabled: bool = Field(..., description="是否启用")
|
||||
priority: int = Field(..., description="优先级(数字越小优先级越高)")
|
||||
|
||||
|
||||
class AvailableMirrorsResponse(BaseModel):
|
||||
mirrors: list[MirrorConfigResponse] = Field(..., description="镜像源列表")
|
||||
default_priority: list[str] = Field(..., description="默认优先级顺序(ID 列表)")
|
||||
|
||||
|
||||
class AddMirrorRequest(BaseModel):
|
||||
id: str = Field(..., description="镜像源 ID", examples=["custom-mirror"])
|
||||
name: str = Field(..., description="镜像源名称", examples=["自定义镜像源"])
|
||||
raw_prefix: str = Field(..., description="Raw 文件前缀", examples=["https://example.com/raw"])
|
||||
clone_prefix: str = Field(..., description="克隆前缀", examples=["https://example.com/clone"])
|
||||
enabled: bool = Field(True, description="是否启用")
|
||||
priority: Optional[int] = Field(None, description="优先级")
|
||||
|
||||
|
||||
class UpdateMirrorRequest(BaseModel):
|
||||
name: Optional[str] = Field(None, description="镜像源名称")
|
||||
raw_prefix: Optional[str] = Field(None, description="Raw 文件前缀")
|
||||
clone_prefix: Optional[str] = Field(None, description="克隆前缀")
|
||||
enabled: Optional[bool] = Field(None, description="是否启用")
|
||||
priority: Optional[int] = Field(None, description="优先级")
|
||||
|
||||
|
||||
class GitStatusResponse(BaseModel):
|
||||
installed: bool = Field(..., description="是否已安装 Git")
|
||||
version: Optional[str] = Field(None, description="Git 版本号")
|
||||
path: Optional[str] = Field(None, description="Git 可执行文件路径")
|
||||
error: Optional[str] = Field(None, description="错误信息")
|
||||
|
||||
|
||||
class InstallPluginRequest(BaseModel):
|
||||
plugin_id: str = Field(..., description="插件 ID")
|
||||
repository_url: str = Field(..., description="插件仓库 URL")
|
||||
branch: Optional[str] = Field("main", description="分支名称")
|
||||
mirror_id: Optional[str] = Field(None, description="指定镜像源 ID")
|
||||
|
||||
|
||||
class VersionResponse(BaseModel):
|
||||
version: str = Field(..., description="麦麦版本号")
|
||||
version_major: int = Field(..., description="主版本号")
|
||||
version_minor: int = Field(..., description="次版本号")
|
||||
version_patch: int = Field(..., description="补丁版本号")
|
||||
|
||||
|
||||
class UninstallPluginRequest(BaseModel):
|
||||
plugin_id: str = Field(..., description="插件 ID")
|
||||
|
||||
|
||||
class UpdatePluginRequest(BaseModel):
|
||||
plugin_id: str = Field(..., description="插件 ID")
|
||||
repository_url: str = Field(..., description="插件仓库 URL")
|
||||
branch: Optional[str] = Field("main", description="分支名称")
|
||||
mirror_id: Optional[str] = Field(None, description="指定镜像源 ID")
|
||||
|
||||
|
||||
class UpdatePluginConfigRequest(BaseModel):
|
||||
enabled: Optional[bool] = None
|
||||
config: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class UpdatePluginRawConfigRequest(BaseModel):
|
||||
config: str = Field(..., description="原始 TOML 配置内容")
|
||||
221
src/webui/routers/plugin/support.py
Normal file
221
src/webui/routers/plugin/support.py
Normal file
@@ -0,0 +1,221 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, cast, get_origin
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import stat
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.core.config_types import ConfigField
|
||||
from src.webui.core import get_token_manager
|
||||
|
||||
logger = get_logger("webui.plugin_routes")
|
||||
|
||||
|
||||
def require_plugin_token(maibot_session: Optional[str]) -> str:
|
||||
token_manager = get_token_manager()
|
||||
if not maibot_session or not token_manager.verify_token(maibot_session):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
return maibot_session
|
||||
|
||||
|
||||
def validate_safe_path(user_path: str, base_path: Path) -> Path:
|
||||
base_resolved = base_path.resolve()
|
||||
if any(pattern in user_path for pattern in ["..", "\x00"]):
|
||||
logger.warning(f"检测到可疑路径: {user_path}")
|
||||
raise HTTPException(status_code=400, detail="路径包含非法字符")
|
||||
|
||||
if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"):
|
||||
logger.warning(f"检测到绝对路径: {user_path}")
|
||||
raise HTTPException(status_code=400, detail="不允许使用绝对路径")
|
||||
|
||||
target_path = (base_path / user_path).resolve()
|
||||
try:
|
||||
target_path.relative_to(base_resolved)
|
||||
except ValueError as e:
|
||||
logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}")
|
||||
raise HTTPException(status_code=400, detail="路径超出允许范围") from e
|
||||
|
||||
return target_path
|
||||
|
||||
|
||||
def validate_plugin_id(plugin_id: str) -> str:
|
||||
if not plugin_id or not plugin_id.strip():
|
||||
logger.warning("非法插件 ID: 空字符串")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能为空")
|
||||
|
||||
for pattern in ["/", "\\", "\x00", "..", "\n", "\r", "\t"]:
|
||||
if pattern in plugin_id:
|
||||
logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 包含非法字符")
|
||||
|
||||
if plugin_id.startswith(".") or plugin_id.endswith("."):
|
||||
logger.warning(f"非法插件 ID: {plugin_id}")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾")
|
||||
|
||||
if plugin_id in {".", ".."}:
|
||||
logger.warning(f"非法插件 ID: {plugin_id}")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名")
|
||||
|
||||
return plugin_id
|
||||
|
||||
|
||||
def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
|
||||
parts = base_version.split(".")
|
||||
if len(parts) < 3:
|
||||
parts.extend(["0"] * (3 - len(parts)))
|
||||
|
||||
try:
|
||||
return int(parts[0]), int(parts[1]), int(parts[2])
|
||||
except (ValueError, IndexError):
|
||||
logger.warning(f"无法解析版本号: {version_str},返回默认值 (0, 0, 0)")
|
||||
return 0, 0, 0
|
||||
|
||||
|
||||
def deep_merge(dst: dict[str, Any], src: dict[str, Any]) -> None:
|
||||
for key, value in src.items():
|
||||
if key in dst and isinstance(dst[key], dict) and isinstance(value, dict):
|
||||
deep_merge(dst[key], value)
|
||||
else:
|
||||
dst[key] = value
|
||||
|
||||
|
||||
def normalize_dotted_keys(obj: dict[str, Any]) -> dict[str, Any]:
|
||||
result: dict[str, Any] = {}
|
||||
dotted_items: list[tuple[str, Any]] = []
|
||||
|
||||
for key, value in obj.items():
|
||||
if "." in key:
|
||||
dotted_items.append((key, value))
|
||||
else:
|
||||
result[key] = normalize_dotted_keys(value) if isinstance(value, dict) else value
|
||||
|
||||
for dotted_key, value in dotted_items:
|
||||
normalized_value = normalize_dotted_keys(value) if isinstance(value, dict) else value
|
||||
parts = dotted_key.split(".")
|
||||
if "" in parts:
|
||||
logger.warning(f"键路径包含空段: '{dotted_key}'")
|
||||
parts = [part for part in parts if part]
|
||||
if not parts:
|
||||
logger.warning(f"忽略空键路径: '{dotted_key}'")
|
||||
continue
|
||||
|
||||
current = result
|
||||
for index, part in enumerate(parts[:-1]):
|
||||
if part in current and not isinstance(current[part], dict):
|
||||
path_ctx = ".".join(parts[: index + 1])
|
||||
logger.warning(f"键冲突:{part} 已存在且非字典,覆盖为字典以展开 {dotted_key} (路径 {path_ctx})")
|
||||
current[part] = {}
|
||||
current = current.setdefault(part, {})
|
||||
|
||||
last_part = parts[-1]
|
||||
if last_part in current and isinstance(current[last_part], dict) and isinstance(normalized_value, dict):
|
||||
deep_merge(current[last_part], normalized_value)
|
||||
else:
|
||||
current[last_part] = normalized_value
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def coerce_types(schema_part: dict[str, Any], config_part: dict[str, Any]) -> None:
|
||||
def is_list_type(tp: Any) -> bool:
|
||||
origin = get_origin(tp)
|
||||
return tp is list or origin is list
|
||||
|
||||
for key, schema_val in schema_part.items():
|
||||
if key not in config_part:
|
||||
continue
|
||||
value = config_part[key]
|
||||
if isinstance(schema_val, ConfigField):
|
||||
if is_list_type(schema_val.type) and isinstance(value, str):
|
||||
config_part[key] = [item.strip() for item in value.split(",") if item.strip()]
|
||||
elif isinstance(schema_val, dict) and isinstance(value, dict):
|
||||
coerce_types(schema_val, value)
|
||||
|
||||
|
||||
def find_plugin_instance(plugin_id: str) -> Optional[Any]:
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
manager = get_plugin_runtime_manager()
|
||||
for supervisor in manager.supervisors:
|
||||
registered = supervisor._registered_plugins.get(plugin_id)
|
||||
if registered is not None:
|
||||
return registered
|
||||
return None
|
||||
|
||||
|
||||
def get_plugins_dir() -> Path:
|
||||
plugins_dir = Path("plugins").resolve()
|
||||
plugins_dir.mkdir(exist_ok=True)
|
||||
return plugins_dir
|
||||
|
||||
|
||||
def get_plugin_candidate_paths(plugin_id: str) -> tuple[Path, Path]:
|
||||
plugins_dir = get_plugins_dir()
|
||||
folder_name = plugin_id.replace(".", "_")
|
||||
return validate_safe_path(folder_name, plugins_dir), validate_safe_path(plugin_id, plugins_dir)
|
||||
|
||||
|
||||
def resolve_installed_plugin_path(plugin_id: str) -> Optional[Path]:
|
||||
new_format_path, old_format_path = get_plugin_candidate_paths(plugin_id)
|
||||
if new_format_path.exists():
|
||||
return new_format_path
|
||||
return old_format_path if old_format_path.exists() else None
|
||||
|
||||
|
||||
def parse_repository_url(repository_url: str) -> tuple[str, str, str]:
|
||||
repo_url = repository_url.rstrip("/").removesuffix(".git")
|
||||
parts = repo_url.split("/")
|
||||
if len(parts) < 2:
|
||||
raise HTTPException(status_code=400, detail="无效的仓库 URL")
|
||||
return repo_url, parts[-2], parts[-1]
|
||||
|
||||
|
||||
def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]:
|
||||
if not manifest_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as file_obj:
|
||||
return cast(dict[str, Any], json.load(file_obj))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def iter_plugin_directories() -> list[Path]:
|
||||
return [path for path in get_plugins_dir().iterdir() if path.is_dir()]
|
||||
|
||||
|
||||
def find_plugin_path_by_id(plugin_id: str) -> Optional[Path]:
|
||||
for plugin_path in iter_plugin_directories():
|
||||
manifest = load_manifest_json(plugin_path / "_manifest.json")
|
||||
if manifest is not None and (manifest.get("id") == plugin_id or plugin_path.name == plugin_id):
|
||||
return plugin_path
|
||||
return None
|
||||
|
||||
|
||||
def backup_file(file_path: Path, action: str, move_file: bool = False) -> Optional[Path]:
|
||||
if not file_path.exists():
|
||||
return None
|
||||
|
||||
backup_name = f"{file_path.name}.{action}.{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
backup_path = file_path.parent / backup_name
|
||||
if move_file:
|
||||
shutil.move(file_path, backup_path)
|
||||
else:
|
||||
shutil.copy(file_path, backup_path)
|
||||
return backup_path
|
||||
|
||||
|
||||
def remove_tree(path: Path) -> None:
|
||||
def remove_readonly(func: Any, target_path: str, _: Any) -> None:
|
||||
os.chmod(target_path, stat.S_IWRITE)
|
||||
func(target_path)
|
||||
|
||||
shutil.rmtree(path, onerror=remove_readonly)
|
||||
@@ -1,29 +1,22 @@
|
||||
"""统计数据 API 路由"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Cookie, Depends, Header, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import desc, func, or_
|
||||
from sqlmodel import col, select
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Messages, ModelUsage, OnlineTime
|
||||
from src.common.database.database_model import ModelUsage, OnlineTime
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||
from src.common.message_repository import count_messages
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
logger = get_logger("webui.statistics")
|
||||
|
||||
router = APIRouter(prefix="/statistics", tags=["statistics"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
router = APIRouter(prefix="/statistics", tags=["statistics"], dependencies=[Depends(require_auth)])
|
||||
|
||||
|
||||
class StatisticsSummary(BaseModel):
|
||||
@@ -70,7 +63,7 @@ class DashboardData(BaseModel):
|
||||
|
||||
|
||||
@router.get("/dashboard", response_model=DashboardData)
|
||||
async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||
async def get_dashboard_data(hours: int = 24):
|
||||
"""
|
||||
获取仪表盘统计数据
|
||||
|
||||
@@ -159,24 +152,12 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
|
||||
if end > start:
|
||||
summary.online_time += (end - start).total_seconds()
|
||||
|
||||
# 查询消息数量 - 使用聚合优化
|
||||
with get_db_session() as session:
|
||||
statement = select(func.count()).where(
|
||||
col(Messages.timestamp) >= start_time,
|
||||
col(Messages.timestamp) <= end_time,
|
||||
)
|
||||
total_messages = session.exec(statement).one()
|
||||
summary.total_messages = int(total_messages or 0)
|
||||
|
||||
# 统计回复数量
|
||||
with get_db_session() as session:
|
||||
statement = select(func.count()).where(
|
||||
col(Messages.timestamp) >= start_time,
|
||||
col(Messages.timestamp) <= end_time,
|
||||
col(Messages.reply_to).is_not(None),
|
||||
)
|
||||
total_replies = session.exec(statement).one()
|
||||
summary.total_replies = int(total_replies or 0)
|
||||
summary.total_messages = count_messages(start_time=start_time.timestamp(), end_time=end_time.timestamp())
|
||||
summary.total_replies = count_messages(
|
||||
start_time=start_time.timestamp(),
|
||||
end_time=end_time.timestamp(),
|
||||
has_reply_to=True,
|
||||
)
|
||||
|
||||
# 计算派生指标
|
||||
if summary.online_time > 0:
|
||||
@@ -351,7 +332,7 @@ async def _get_recent_activity(limit: int = 10) -> list[dict[str, Any]]:
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||
async def get_summary(hours: int = 24):
|
||||
"""
|
||||
获取统计摘要
|
||||
|
||||
@@ -369,7 +350,7 @@ async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def get_model_stats(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||
async def get_model_stats(hours: int = 24):
|
||||
"""
|
||||
获取模型统计
|
||||
|
||||
|
||||
@@ -7,28 +7,20 @@
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from src.config.config import MMC_VERSION
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
router = APIRouter(prefix="/system", tags=["system"])
|
||||
router = APIRouter(prefix="/system", tags=["system"], dependencies=[Depends(require_auth)])
|
||||
logger = get_logger("webui_system")
|
||||
|
||||
# 记录启动时间
|
||||
_start_time = time.time()
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
class RestartResponse(BaseModel):
|
||||
"""重启响应"""
|
||||
|
||||
@@ -46,7 +38,7 @@ class StatusResponse(BaseModel):
|
||||
|
||||
|
||||
@router.post("/restart", response_model=RestartResponse)
|
||||
async def restart_maibot(_auth: bool = Depends(require_auth)):
|
||||
async def restart_maibot():
|
||||
"""
|
||||
重启麦麦主程序
|
||||
|
||||
@@ -77,7 +69,7 @@ async def restart_maibot(_auth: bool = Depends(require_auth)):
|
||||
|
||||
|
||||
@router.get("/status", response_model=StatusResponse)
|
||||
async def get_maibot_status(_auth: bool = Depends(require_auth)):
|
||||
async def get_maibot_status():
|
||||
"""
|
||||
获取麦麦运行状态
|
||||
|
||||
@@ -100,7 +92,7 @@ async def get_maibot_status(_auth: bool = Depends(require_auth)):
|
||||
|
||||
|
||||
@router.post("/reload-config")
|
||||
async def reload_config(_auth: bool = Depends(require_auth)):
|
||||
async def reload_config():
|
||||
"""
|
||||
热重载配置(不重启进程)
|
||||
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from .logs import router as logs_router
|
||||
from .plugin_progress import get_progress_router
|
||||
from .auth import router as ws_auth_router
|
||||
|
||||
__all__ = [
|
||||
"logs_router",
|
||||
"get_progress_router",
|
||||
"ws_auth_router",
|
||||
]
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
"""WebSocket 认证模块
|
||||
"""WebSocket 认证模块。"""
|
||||
|
||||
提供所有 WebSocket 端点统一使用的临时 token 认证机制。
|
||||
临时 token 有效期 60 秒,且只能使用一次,用于解决 WebSocket 握手时 Cookie 不可用的问题。
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Cookie, Header
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Cookie
|
||||
import secrets
|
||||
import time
|
||||
from src.common.logger import get_logger
|
||||
@@ -77,25 +74,17 @@ def verify_ws_token(temp_token: str) -> bool:
|
||||
@router.get("/ws-token")
|
||||
async def get_ws_token(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取 WebSocket 连接用的临时 token
|
||||
|
||||
此端点验证当前会话的 Cookie 或 Authorization header,
|
||||
此端点验证当前会话 Cookie,
|
||||
然后返回一个临时 token 用于 WebSocket 握手认证。
|
||||
临时 token 有效期 60 秒,且只能使用一次。
|
||||
|
||||
注意:在未认证时返回 200 状态码但 success=False,避免前端因 401 刷新页面。
|
||||
"""
|
||||
# 获取当前 session token
|
||||
session_token = None
|
||||
if maibot_session:
|
||||
session_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
session_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not session_token:
|
||||
if not maibot_session:
|
||||
# 返回 200 但 success=False,避免前端因 401 刷新页面
|
||||
# 这在登录页面是正常情况,不应该触发错误处理
|
||||
logger.debug("ws-token 请求:未提供认证信息(可能在登录页面)")
|
||||
@@ -103,12 +92,12 @@ async def get_ws_token(
|
||||
|
||||
# 验证 session token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(session_token):
|
||||
if not token_manager.verify_token(maibot_session):
|
||||
# 同样返回 200 但 success=False,避免前端刷新
|
||||
logger.debug("ws-token 请求:认证已过期")
|
||||
return {"success": False, "message": "认证已过期,请重新登录", "token": None, "expires_in": 0}
|
||||
|
||||
# 生成临时 WebSocket token
|
||||
ws_token = generate_ws_token(session_token)
|
||||
ws_token = generate_ws_token(maibot_session)
|
||||
|
||||
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}
|
||||
|
||||
@@ -1,177 +1,11 @@
|
||||
"""WebSocket 日志推送模块"""
|
||||
"""WebSocket 日志推送路由兼容导出。"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from typing import Set, Optional
|
||||
import json
|
||||
from pathlib import Path
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
from src.webui.logs_ws import active_connections, broadcast_log, load_recent_logs, router, websocket_logs
|
||||
|
||||
logger = get_logger("webui.logs_ws")
|
||||
router = APIRouter()
|
||||
|
||||
# 全局 WebSocket 连接池
|
||||
active_connections: Set[WebSocket] = set()
|
||||
|
||||
|
||||
def load_recent_logs(limit: int = 100) -> list[dict]:
|
||||
"""从日志文件中加载最近的日志
|
||||
|
||||
Args:
|
||||
limit: 返回的最大日志条数
|
||||
|
||||
Returns:
|
||||
日志列表
|
||||
"""
|
||||
logs = []
|
||||
log_dir = Path("logs")
|
||||
|
||||
if not log_dir.exists():
|
||||
return logs
|
||||
|
||||
# 获取所有日志文件,按修改时间排序
|
||||
log_files = sorted(log_dir.glob("app_*.log.jsonl"), key=lambda f: f.stat().st_mtime, reverse=True)
|
||||
|
||||
# 用于生成唯一 ID 的计数器
|
||||
log_counter = 0
|
||||
|
||||
# 从最新的文件开始读取
|
||||
for log_file in log_files:
|
||||
if len(logs) >= limit:
|
||||
break
|
||||
|
||||
try:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
# 从文件末尾开始读取
|
||||
for line in reversed(lines):
|
||||
if len(logs) >= limit:
|
||||
break
|
||||
try:
|
||||
log_entry = json.loads(line.strip())
|
||||
# 转换为前端期望的格式
|
||||
# 使用时间戳 + 计数器生成唯一 ID
|
||||
timestamp_id = (
|
||||
log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
|
||||
)
|
||||
formatted_log = {
|
||||
"id": f"{timestamp_id}_{log_counter}",
|
||||
"timestamp": log_entry.get("timestamp", ""),
|
||||
"level": log_entry.get("level", "INFO").upper(),
|
||||
"module": log_entry.get("logger_name", ""),
|
||||
"message": log_entry.get("event", ""),
|
||||
}
|
||||
logs.append(formatted_log)
|
||||
log_counter += 1
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"读取日志文件失败 {log_file}: {e}")
|
||||
continue
|
||||
|
||||
# 反转列表,使其按时间顺序排列(旧到新)
|
||||
return list(reversed(logs))
|
||||
|
||||
|
||||
@router.websocket("/ws/logs")
|
||||
async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None)):
|
||||
"""WebSocket 日志推送端点
|
||||
|
||||
客户端连接后会持续接收服务器端的日志消息
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/ws/logs?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
active_connections.add(websocket)
|
||||
logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
|
||||
|
||||
# 连接建立后,立即发送历史日志
|
||||
try:
|
||||
recent_logs = load_recent_logs(limit=100)
|
||||
logger.info(f"发送 {len(recent_logs)} 条历史日志到客户端")
|
||||
|
||||
for log_entry in recent_logs:
|
||||
await websocket.send_text(json.dumps(log_entry, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
logger.error(f"发送历史日志失败: {e}")
|
||||
|
||||
try:
|
||||
# 保持连接,等待客户端消息或断开
|
||||
while True:
|
||||
# 接收客户端消息(用于心跳或控制指令)
|
||||
data = await websocket.receive_text()
|
||||
|
||||
# 可以处理客户端的控制消息,例如:
|
||||
# - "ping" -> 心跳检测
|
||||
# - {"filter": "ERROR"} -> 设置日志级别过滤
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
|
||||
except WebSocketDisconnect:
|
||||
active_connections.discard(websocket)
|
||||
logger.info(f"📡 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebSocket 错误: {e}")
|
||||
active_connections.discard(websocket)
|
||||
|
||||
|
||||
async def broadcast_log(log_data: dict):
|
||||
"""广播日志到所有连接的 WebSocket 客户端
|
||||
|
||||
Args:
|
||||
log_data: 日志数据字典
|
||||
"""
|
||||
if not active_connections:
|
||||
return
|
||||
|
||||
# 格式化为 JSON
|
||||
message = json.dumps(log_data, ensure_ascii=False)
|
||||
|
||||
# 记录需要断开的连接
|
||||
disconnected = set()
|
||||
|
||||
# 广播到所有客户端
|
||||
for connection in active_connections:
|
||||
try:
|
||||
await connection.send_text(message)
|
||||
except Exception:
|
||||
# 发送失败,标记为断开
|
||||
disconnected.add(connection)
|
||||
|
||||
# 清理断开的连接
|
||||
if disconnected:
|
||||
active_connections.difference_update(disconnected)
|
||||
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")
|
||||
__all__ = [
|
||||
"active_connections",
|
||||
"broadcast_log",
|
||||
"load_recent_logs",
|
||||
"router",
|
||||
"websocket_logs",
|
||||
]
|
||||
|
||||
@@ -1,28 +1,26 @@
|
||||
"""WebUI API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie, Depends
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import (
|
||||
clear_auth_cookie,
|
||||
check_auth_rate_limit,
|
||||
get_rate_limiter,
|
||||
get_token_manager,
|
||||
set_auth_cookie,
|
||||
clear_auth_cookie,
|
||||
get_rate_limiter,
|
||||
check_auth_rate_limit,
|
||||
)
|
||||
from src.webui.dependencies import require_auth, verify_token_optional
|
||||
from src.webui.routers.config import router as config_router
|
||||
from src.webui.routers.statistics import router as statistics_router
|
||||
from src.webui.routers.person import router as person_router
|
||||
from src.webui.routers.expression import router as expression_router
|
||||
from src.webui.routers.jargon import router as jargon_router
|
||||
from src.webui.routers.emoji import router as emoji_router
|
||||
from src.webui.routers.plugin import router as plugin_router
|
||||
from src.webui.routers.websocket.plugin_progress import get_progress_router
|
||||
from src.webui.routers.plugin import get_progress_router, router as plugin_router
|
||||
from src.webui.routers.system import router as system_router
|
||||
from src.webui.routers.model import router as model_router
|
||||
from src.webui.routers.websocket.auth import router as ws_auth_router
|
||||
from src.webui.routers.annual_report import router as annual_report_router
|
||||
|
||||
logger = get_logger("webui.api")
|
||||
|
||||
@@ -51,8 +49,6 @@ router.include_router(system_router)
|
||||
router.include_router(model_router)
|
||||
# 注册 WebSocket 认证路由
|
||||
router.include_router(ws_auth_router)
|
||||
# 注册年度报告路由
|
||||
router.include_router(annual_report_router)
|
||||
|
||||
|
||||
class TokenVerifyRequest(BaseModel):
|
||||
@@ -190,9 +186,7 @@ async def logout(response: Response):
|
||||
|
||||
@router.get("/auth/check")
|
||||
async def check_auth_status(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
authenticated: bool = Depends(verify_token_optional),
|
||||
):
|
||||
"""
|
||||
检查当前认证状态(用于前端判断是否已登录)
|
||||
@@ -201,46 +195,17 @@ async def check_auth_status(
|
||||
认证状态
|
||||
"""
|
||||
try:
|
||||
token = None
|
||||
|
||||
# 记录请求信息用于调试
|
||||
logger.debug(
|
||||
f"检查认证状态 - Cookie: {maibot_session[:20] if maibot_session else 'None'}..., Authorization: {'Present' if authorization else 'None'}"
|
||||
)
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
logger.debug("使用 Cookie 中的 token")
|
||||
# 其次从 Header 获取
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
logger.debug("使用 Header 中的 token")
|
||||
|
||||
if not token:
|
||||
logger.debug("未找到 token,返回未认证")
|
||||
return {"authenticated": False}
|
||||
|
||||
token_manager = get_token_manager()
|
||||
is_valid = token_manager.verify_token(token)
|
||||
logger.debug(f"Token 验证结果: {is_valid}")
|
||||
|
||||
if is_valid:
|
||||
return {"authenticated": True}
|
||||
else:
|
||||
return {"authenticated": False}
|
||||
logger.debug(f"检查认证状态,结果: {authenticated}")
|
||||
return {"authenticated": authenticated}
|
||||
except Exception as e:
|
||||
logger.error(f"认证检查失败: {e}", exc_info=True)
|
||||
return {"authenticated": False}
|
||||
|
||||
|
||||
@router.post("/auth/update", response_model=TokenUpdateResponse)
|
||||
@router.post("/auth/update", response_model=TokenUpdateResponse, dependencies=[Depends(require_auth)])
|
||||
async def update_token(
|
||||
request: TokenUpdateRequest,
|
||||
response: Response,
|
||||
req: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
更新访问令牌(需要当前有效的 token)
|
||||
@@ -248,28 +213,13 @@ async def update_token(
|
||||
Args:
|
||||
request: 包含新 token 的更新请求
|
||||
response: FastAPI Response 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
# 验证当前 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="当前 Token 无效")
|
||||
|
||||
# 更新 token
|
||||
success, message = token_manager.update_token(request.new_token)
|
||||
|
||||
@@ -285,40 +235,22 @@ async def update_token(
|
||||
raise HTTPException(status_code=500, detail="Token 更新失败") from e
|
||||
|
||||
|
||||
@router.post("/auth/regenerate", response_model=TokenRegenerateResponse)
|
||||
@router.post("/auth/regenerate", response_model=TokenRegenerateResponse, dependencies=[Depends(require_auth)])
|
||||
async def regenerate_token(
|
||||
response: Response,
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
重新生成访问令牌(需要当前有效的 token)
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
新生成的 token
|
||||
"""
|
||||
try:
|
||||
# 验证当前 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="当前 Token 无效")
|
||||
|
||||
# 重新生成 token
|
||||
new_token = token_manager.regenerate_token()
|
||||
|
||||
@@ -333,38 +265,17 @@ async def regenerate_token(
|
||||
raise HTTPException(status_code=500, detail="Token 重新生成失败") from e
|
||||
|
||||
|
||||
@router.get("/setup/status", response_model=FirstSetupStatusResponse)
|
||||
async def get_setup_status(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
@router.get("/setup/status", response_model=FirstSetupStatusResponse, dependencies=[Depends(require_auth)])
|
||||
async def get_setup_status():
|
||||
"""
|
||||
获取首次配置状态
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
首次配置状态
|
||||
"""
|
||||
try:
|
||||
# 验证 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效")
|
||||
|
||||
# 检查是否为首次配置
|
||||
is_first = token_manager.is_first_setup()
|
||||
|
||||
@@ -376,38 +287,17 @@ async def get_setup_status(
|
||||
raise HTTPException(status_code=500, detail="获取配置状态失败") from e
|
||||
|
||||
|
||||
@router.post("/setup/complete", response_model=CompleteSetupResponse)
|
||||
async def complete_setup(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
@router.post("/setup/complete", response_model=CompleteSetupResponse, dependencies=[Depends(require_auth)])
|
||||
async def complete_setup():
|
||||
"""
|
||||
标记首次配置完成
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
完成结果
|
||||
"""
|
||||
try:
|
||||
# 验证 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效")
|
||||
|
||||
# 标记配置完成
|
||||
success = token_manager.mark_setup_completed()
|
||||
|
||||
@@ -419,38 +309,17 @@ async def complete_setup(
|
||||
raise HTTPException(status_code=500, detail="标记配置完成失败") from e
|
||||
|
||||
|
||||
@router.post("/setup/reset", response_model=ResetSetupResponse)
|
||||
async def reset_setup(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
@router.post("/setup/reset", response_model=ResetSetupResponse, dependencies=[Depends(require_auth)])
|
||||
async def reset_setup():
|
||||
"""
|
||||
重置首次配置状态,允许重新进入配置向导
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
重置结果
|
||||
"""
|
||||
try:
|
||||
# 验证 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效")
|
||||
|
||||
# 重置配置状态
|
||||
success = token_manager.reset_setup_status()
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ class WebUIServer:
|
||||
service_name="WebUI 服务器",
|
||||
logger=logger,
|
||||
config_hint="WEBUI_PORT (.env)",
|
||||
allow_reuse_addr=True,
|
||||
)
|
||||
|
||||
config = Config(
|
||||
|
||||
Reference in New Issue
Block a user