WebUI 前端 & 后端超级大重构

This commit is contained in:
DrSmoothl
2026-03-14 21:06:36 +08:00
parent 6ca5a2939e
commit 172615f18a
69 changed files with 3128 additions and 6581 deletions

View File

@@ -14,13 +14,13 @@ import {
import { Button } from '@/components/ui/button'
import {
Dialog,
DialogBody,
DialogContent,
DialogHeader,
DialogTitle,
} from '@/components/ui/dialog'
import { Input } from '@/components/ui/input'
import { Label } from '@/components/ui/label'
import { ScrollArea } from '@/components/ui/scroll-area'
import { useBackendConnections } from '@/hooks/useBackendConnections'
import { isElectron } from '@/lib/runtime'
import type { BackendConnection } from '@/types/electron'
@@ -78,7 +78,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) {
return (
<>
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-md sm:max-w-[425px]">
<DialogContent className="max-w-md sm:max-w-106.25">
<DialogHeader>
<DialogTitle></DialogTitle>
</DialogHeader>
@@ -88,7 +88,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) {
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
</div>
) : (
<ScrollArea className="max-h-[60vh] pr-4">
<DialogBody className="pr-4">
<div className="flex flex-col gap-3 py-4">
{backends.map((backend) => {
const isActive = backend.id === activeId
@@ -100,7 +100,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) {
}`}
>
<div className="flex flex-1 items-center gap-3 overflow-hidden">
<div className="flex-shrink-0">
<div className="shrink-0">
{isActive ? (
<Check className="h-5 w-5 text-blue-500" />
) : (
@@ -156,7 +156,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) {
)
})}
</div>
</ScrollArea>
</DialogBody>
)}
<div className="flex justify-end pt-4 border-t">
@@ -173,7 +173,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) {
{/* Edit/Add Dialog */}
<Dialog open={!!editConn} onOpenChange={(open) => !open && setEditConn(null)}>
<DialogContent className="sm:max-w-[425px]">
<DialogContent className="sm:max-w-106.25" confirmOnEnter>
<DialogHeader>
<DialogTitle>{editConn?.id ? '编辑连接' : '添加连接'}</DialogTitle>
</DialogHeader>
@@ -212,6 +212,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) {
!editConn?.url ||
!/^https?:\/\//.test(editConn.url)
}
data-dialog-action="confirm"
>
</Button>

View File

@@ -22,6 +22,7 @@ import {
import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { Badge } from '@/components/ui/badge'
import { ShortcutKbd } from '@/components/ui/kbd'
import { ScrollArea } from '@/components/ui/scroll-area'
import { Checkbox } from '@/components/ui/checkbox'
import { Tabs, TabsList, TabsTrigger } from '@/components/ui/tabs'
@@ -1689,19 +1690,19 @@ if (isCurrent) {
{/* 底部快捷键提示(桌面端) */}
<div className="hidden sm:flex items-center justify-center gap-6 px-6 py-3 border-t text-xs text-muted-foreground">
<div className="flex items-center gap-1">
<kbd className="px-2 py-1 bg-muted rounded text-xs"></kbd>
<ShortcutKbd size="sm" keys={['left']} />
<span></span>
</div>
<div className="flex items-center gap-1">
<kbd className="px-2 py-1 bg-muted rounded text-xs"></kbd>
<ShortcutKbd size="sm" keys={['right']} />
<span></span>
</div>
<div className="flex items-center gap-1">
<kbd className="px-2 py-1 bg-muted rounded text-xs"></kbd>
<ShortcutKbd size="sm" keys={['up']} />
<span></span>
</div>
<div className="flex items-center gap-1">
<kbd className="px-2 py-1 bg-muted rounded text-xs"></kbd>
<ShortcutKbd size="sm" keys={['down']} />
<span></span>
</div>
<span className="text-muted-foreground/50">|</span>

View File

@@ -1,5 +1,4 @@
import { Link } from '@tanstack/react-router'
import { BookOpen, ChevronLeft, Globe, LogOut, Menu, Moon, PieChart, Search, Server, Sun } from 'lucide-react'
import { BookOpen, ChevronLeft, Globe, LogOut, Menu, Moon, Search, Server, Sun } from 'lucide-react'
import { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
@@ -13,7 +12,7 @@ import {
DropdownMenuItem,
DropdownMenuTrigger,
} from '@/components/ui/dropdown-menu'
import { Kbd } from '@/components/ui/kbd'
import { ShortcutKbd } from '@/components/ui/kbd'
import { toggleThemeWithTransition } from '@/components/use-theme'
import { useBackground } from '@/hooks/use-background'
import { logout } from '@/lib/fetch-with-auth'
@@ -99,7 +98,7 @@ export function Header({
title={t('header.toggleConnection')}
>
<Server className="h-4 w-4" />
<span className="hidden sm:inline text-xs text-muted-foreground truncate max-w-[100px]">
<span className="hidden sm:inline text-xs text-muted-foreground truncate max-w-25">
{activeBackendName}
</span>
</Button>
@@ -107,19 +106,6 @@ export function Header({
<div className="h-6 w-px bg-border" />
</>
)}
{/* 年度总结入口 */}
<Link to="/annual-report">
<Button
variant="ghost"
size="sm"
className="gap-2 bg-gradient-to-r from-pink-500/10 to-purple-500/10 hover:from-pink-500/20 hover:to-purple-500/20 border border-pink-500/20"
title={t('header.viewAnnualSummary')}
>
<PieChart className="h-4 w-4 text-pink-500" />
<span className="hidden sm:inline bg-gradient-to-r from-pink-500 to-purple-500 bg-clip-text text-transparent font-medium">{t('header.annualSummary')}</span>
</Button>
</Link>
{/* 搜索框 */}
<button
onClick={() => onSearchOpenChange(true)}
@@ -128,9 +114,7 @@ export function Header({
>
<Search className="absolute left-3 top-1/2 h-4 w-4 -translate-y-1/2 text-muted-foreground" aria-hidden="true" />
<span className="text-sm text-muted-foreground">{t('header.searchPlaceholder')}</span>
<Kbd size="sm" className="absolute right-2 top-1/2 -translate-y-1/2">
<span className="text-xs"></span>K
</Kbd>
<ShortcutKbd size="sm" className="absolute right-2 top-1/2 -translate-y-1/2" keys={['mod', 'k']} />
</button>
{/* 搜索对话框 */}

View File

@@ -13,6 +13,7 @@ import { useAuthGuard } from '@/hooks/use-auth'
import { useBackground } from '@/hooks/use-background'
import { TitleBar } from '@/components/electron/TitleBar'
import { matchesShortcut } from '@/lib/keyboard'
import { isElectron } from '@/lib/runtime'
import { cn } from '@/lib/utils'
import { menuSections } from './constants'
@@ -49,7 +50,7 @@ export function Layout({ children }: LayoutProps) {
// 搜索快捷键监听Cmd/Ctrl + K
useEffect(() => {
const handleKeyDown = (e: KeyboardEvent) => {
if ((e.metaKey || e.ctrlKey) && e.key === 'k') {
if (matchesShortcut(e, ['mod', 'k'])) {
e.preventDefault()
setSearchOpen(true)
}
@@ -68,9 +69,8 @@ export function Layout({ children }: LayoutProps) {
}
}
const unsubscribe = router.subscribe('onResolved', () => {
const pathname = router.state.location.pathname
const pageTitle = pathToLabel[pathname] ?? 'MaiBot Dashboard'
return router.subscribe('onResolved', () => {
const pageTitle = pathToLabel[router.state.location.pathname] ?? 'MaiBot Dashboard'
const fullTitle = pageTitle === 'MaiBot Dashboard'
? 'MaiBot Dashboard'
: `${pageTitle} — MaiBot Dashboard`
@@ -90,8 +90,6 @@ export function Layout({ children }: LayoutProps) {
})
}
})
return unsubscribe
}, [router, announce, t])
// 获取实际应用的主题(处理 system 情况)

View File

@@ -6,25 +6,25 @@ export const menuSections: MenuSection[] = [
{
title: 'sidebar.groups.overview',
items: [
{ icon: Home, label: 'sidebar.menu.home', path: '/' },
{ icon: Home, label: 'sidebar.menu.home', path: '/', searchDescription: 'search.items.homeDesc' },
],
},
{
title: 'sidebar.groups.botConfig',
items: [
{ icon: FileText, label: 'sidebar.menu.botMainConfig', path: '/config/bot' },
{ icon: Server, label: 'sidebar.menu.aiModelProvider', path: '/config/modelProvider', tourId: 'sidebar-model-provider' },
{ icon: Boxes, label: 'sidebar.menu.modelManagement', path: '/config/model', tourId: 'sidebar-model-management' },
{ icon: FileText, label: 'sidebar.menu.botMainConfig', path: '/config/bot', searchDescription: 'search.items.botConfigDesc' },
{ icon: Server, label: 'sidebar.menu.aiModelProvider', path: '/config/modelProvider', searchDescription: 'search.items.modelProviderDesc', tourId: 'sidebar-model-provider' },
{ icon: Boxes, label: 'sidebar.menu.modelManagement', path: '/config/model', searchDescription: 'search.items.modelDesc', tourId: 'sidebar-model-management' },
{ icon: Sliders, label: 'sidebar.menu.adapterConfig', path: '/config/adapter' },
],
},
{
title: 'sidebar.groups.botResources',
items: [
{ icon: Smile, label: 'sidebar.menu.emojiManagement', path: '/resource/emoji' },
{ icon: MessageSquare, label: 'sidebar.menu.expressionManagement', path: '/resource/expression' },
{ icon: Hash, label: 'sidebar.menu.slangManagement', path: '/resource/jargon' },
{ icon: UserCircle, label: 'sidebar.menu.personInfo', path: '/resource/person' },
{ icon: Smile, label: 'sidebar.menu.emojiManagement', path: '/resource/emoji', searchDescription: 'search.items.emojiDesc' },
{ icon: MessageSquare, label: 'sidebar.menu.expressionManagement', path: '/resource/expression', searchDescription: 'search.items.expressionDesc' },
{ icon: Hash, label: 'sidebar.menu.slangManagement', path: '/resource/jargon', searchDescription: 'search.items.jargonDesc' },
{ icon: UserCircle, label: 'sidebar.menu.personInfo', path: '/resource/person', searchDescription: 'search.items.personDesc' },
{ icon: Network, label: 'sidebar.menu.knowledgeGraph', path: '/resource/knowledge-graph' },
{ icon: Database, label: 'sidebar.menu.knowledgeBase', path: '/resource/knowledge-base' },
],
@@ -32,10 +32,10 @@ export const menuSections: MenuSection[] = [
{
title: 'sidebar.groups.extensionsMonitor',
items: [
{ icon: Package, label: 'sidebar.menu.pluginMarket', path: '/plugins' },
{ icon: Package, label: 'sidebar.menu.pluginMarket', path: '/plugins', searchDescription: 'search.items.pluginsDesc' },
{ icon: LayoutGrid, label: 'sidebar.menu.configTemplate', path: '/config/pack-market' },
{ icon: Sliders, label: 'sidebar.menu.pluginConfig', path: '/plugin-config' },
{ icon: FileSearch, label: 'sidebar.menu.logViewer', path: '/logs' },
{ icon: FileSearch, label: 'sidebar.menu.logViewer', path: '/logs', searchDescription: 'search.items.logsDesc' },
{ icon: Activity, label: 'sidebar.menu.plannerMonitor', path: '/planner-monitor' },
{ icon: MessageSquare, label: 'sidebar.menu.localChat', path: '/chat' },
],
@@ -43,7 +43,7 @@ export const menuSections: MenuSection[] = [
{
title: 'sidebar.groups.system',
items: [
{ icon: Settings, label: 'sidebar.menu.settings', path: '/settings' },
{ icon: Settings, label: 'sidebar.menu.settings', path: '/settings', searchDescription: 'search.items.settingsDesc' },
],
},
]

View File

@@ -9,6 +9,7 @@ export interface MenuItem {
icon: ComponentType<LucideProps>
label: string
path: string
searchDescription?: string
tourId?: string
}

View File

@@ -1,16 +1,20 @@
import { useState, useCallback, useMemo } from 'react'
import { Search, FileText, Server, Boxes, Smile, MessageSquare, UserCircle, FileSearch, BarChart3, Package, Settings, Home, Hash } from 'lucide-react'
import { useState, useCallback, useEffect, useMemo, useRef } from 'react'
import { Search } from 'lucide-react'
import { useNavigate } from '@tanstack/react-router'
import { useTranslation } from 'react-i18next'
import type { LucideProps } from 'lucide-react'
import {
Dialog,
DialogBody,
DialogContent,
DialogHeader,
DialogTitle,
} from '@/components/ui/dialog'
import { Input } from '@/components/ui/input'
import { ScrollArea } from '@/components/ui/scroll-area'
import { ShortcutKbd } from '@/components/ui/kbd'
import { menuSections } from '@/components/layout/constants'
import { registeredRoutePaths } from '@/router'
import { cn } from '@/lib/utils'
interface SearchDialogProps {
@@ -19,7 +23,7 @@ interface SearchDialogProps {
}
interface SearchItem {
icon: React.ComponentType<{ className?: string }>
icon: React.ComponentType<LucideProps>
title: string
description: string
path: string
@@ -29,95 +33,37 @@ interface SearchItem {
export function SearchDialog({ open, onOpenChange }: SearchDialogProps) {
const [searchQuery, setSearchQuery] = useState('')
const [selectedIndex, setSelectedIndex] = useState(0)
const inputRef = useRef<HTMLInputElement>(null)
const navigate = useNavigate()
const { t } = useTranslation()
const searchItems: SearchItem[] = useMemo(() => [
{
icon: Home,
title: t('search.items.home'),
description: t('search.items.homeDesc'),
path: '/',
category: t('search.categories.overview'),
},
{
icon: FileText,
title: t('search.items.botConfig'),
description: t('search.items.botConfigDesc'),
path: '/config/bot',
category: t('search.categories.config'),
},
{
icon: Server,
title: t('search.items.modelProvider'),
description: t('search.items.modelProviderDesc'),
path: '/config/modelProvider',
category: t('search.categories.config'),
},
{
icon: Boxes,
title: t('search.items.model'),
description: t('search.items.modelDesc'),
path: '/config/model',
category: t('search.categories.config'),
},
{
icon: Smile,
title: t('search.items.emoji'),
description: t('search.items.emojiDesc'),
path: '/resource/emoji',
category: t('search.categories.resources'),
},
{
icon: MessageSquare,
title: t('search.items.expression'),
description: t('search.items.expressionDesc'),
path: '/resource/expression',
category: t('search.categories.resources'),
},
{
icon: UserCircle,
title: t('search.items.person'),
description: t('search.items.personDesc'),
path: '/resource/person',
category: t('search.categories.resources'),
},
{
icon: Hash,
title: t('search.items.jargon'),
description: t('search.items.jargonDesc'),
path: '/resource/jargon',
category: t('search.categories.resources'),
},
{
icon: BarChart3,
title: t('search.items.statistics'),
description: t('search.items.statisticsDesc'),
path: '/statistics',
category: t('search.categories.monitor'),
},
{
icon: Package,
title: t('search.items.plugins'),
description: t('search.items.pluginsDesc'),
path: '/plugins',
category: t('search.categories.extensions'),
},
{
icon: FileSearch,
title: t('search.items.logs'),
description: t('search.items.logsDesc'),
path: '/logs',
category: t('search.categories.monitor'),
},
{
icon: Settings,
title: t('search.items.settings'),
description: t('search.items.settingsDesc'),
path: '/settings',
category: t('search.categories.system'),
},
], [t])
useEffect(() => {
if (!open) {
return
}
const frameId = window.requestAnimationFrame(() => {
inputRef.current?.focus()
})
return () => window.cancelAnimationFrame(frameId)
}, [open])
const searchItems: SearchItem[] = useMemo(
() =>
menuSections.flatMap((section) =>
section.items
.filter((item) => registeredRoutePaths.has(item.path))
.map((item) => ({
icon: item.icon,
title: t(item.label),
description: item.searchDescription ? t(item.searchDescription) : item.path,
path: item.path,
category: t(section.title),
}))
),
[t]
)
// 过滤搜索结果
const filteredItems = searchItems.filter(
@@ -155,12 +101,13 @@ export function SearchDialog({ open, onOpenChange }: SearchDialogProps) {
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-2xl p-0 gap-0">
<DialogContent className="max-w-2xl p-0 gap-0" confirmOnEnter>
<DialogHeader className="px-4 pt-4 pb-0">
<DialogTitle className="sr-only">{t('search.title')}</DialogTitle>
<div className="relative">
<Search className="absolute left-3 top-1/2 h-5 w-5 -translate-y-1/2 text-muted-foreground" />
<Input
ref={inputRef}
value={searchQuery}
onChange={(e) => {
setSearchQuery(e.target.value)
@@ -169,13 +116,12 @@ export function SearchDialog({ open, onOpenChange }: SearchDialogProps) {
onKeyDown={handleKeyDown}
placeholder={t('search.placeholder')}
className="h-12 pl-11 text-base border-0 focus-visible:ring-0 shadow-none"
autoFocus
/>
</div>
</DialogHeader>
<div className="border-t">
<ScrollArea className="h-[400px]">
<DialogBody className="h-100" viewportClassName="px-0">
{filteredItems.length > 0 ? (
<div className="p-2">
{filteredItems.map((item, index) => {
@@ -192,7 +138,7 @@ export function SearchDialog({ open, onOpenChange }: SearchDialogProps) {
: 'hover:bg-accent/50'
)}
>
<Icon className="h-5 w-5 flex-shrink-0" />
<Icon className="h-5 w-5 shrink-0" />
<div className="flex-1 min-w-0">
<div className="font-medium text-sm">{item.title}</div>
<div className="text-xs text-muted-foreground truncate">
@@ -214,22 +160,22 @@ export function SearchDialog({ open, onOpenChange }: SearchDialogProps) {
</p>
</div>
)}
</ScrollArea>
</DialogBody>
</div>
<div className="border-t px-4 py-3 flex items-center justify-between text-xs text-muted-foreground">
<div className="flex items-center gap-4">
<span className="flex items-center gap-1">
<kbd className="px-1.5 py-0.5 bg-muted rounded border"></kbd>
<kbd className="px-1.5 py-0.5 bg-muted rounded border"></kbd>
<ShortcutKbd size="sm" keys={['up']} />
<ShortcutKbd size="sm" keys={['down']} />
{t('search.navigate')}
</span>
<span className="flex items-center gap-1">
<kbd className="px-1.5 py-0.5 bg-muted rounded border">Enter</kbd>
<ShortcutKbd size="sm" keys={['enter']} />
{t('search.select')}
</span>
<span className="flex items-center gap-1">
<kbd className="px-1.5 py-0.5 bg-muted rounded border">Esc</kbd>
<ShortcutKbd size="sm" keys={['esc']} />
{t('search.close')}
</span>
</div>

View File

@@ -24,6 +24,7 @@ import { Checkbox } from '@/components/ui/checkbox'
import { Badge } from '@/components/ui/badge'
import {
Dialog,
DialogBody,
DialogContent,
DialogDescription,
DialogFooter,
@@ -34,7 +35,6 @@ import {
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'
import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert'
import { Separator } from '@/components/ui/separator'
import { ScrollArea } from '@/components/ui/scroll-area'
import { toast } from '@/hooks/use-toast'
import {
createPack,
@@ -340,7 +340,7 @@ export function SharePackDialog({ trigger }: SharePackDialogProps) {
)}
</DialogTrigger>
<DialogContent className="max-w-2xl max-h-[85vh] flex flex-col">
<DialogContent className="max-w-2xl flex flex-col" confirmOnEnter>
<DialogHeader>
<DialogTitle className="flex items-center gap-2">
<Package className="w-5 h-5" />
@@ -353,7 +353,7 @@ export function SharePackDialog({ trigger }: SharePackDialogProps) {
</DialogDescription>
</DialogHeader>
<ScrollArea className="h-[calc(85vh-220px)] pr-4">
<DialogBody>
{loading ? (
<div className="py-8 text-center">
<Loader2 className="w-8 h-8 mx-auto animate-spin text-primary" />
@@ -639,7 +639,7 @@ export function SharePackDialog({ trigger }: SharePackDialogProps) {
)}
</>
)}
</ScrollArea>
</DialogBody>
<DialogFooter className="flex justify-between pt-4 border-t">
<div>
@@ -662,6 +662,7 @@ export function SharePackDialog({ trigger }: SharePackDialogProps) {
</Button>
{step < totalSteps ? (
<Button
data-dialog-action="confirm"
onClick={() => setStep(step + 1)}
disabled={
loading ||
@@ -671,7 +672,7 @@ export function SharePackDialog({ trigger }: SharePackDialogProps) {
</Button>
) : (
<Button onClick={handleSubmit} disabled={submitting}>
<Button data-dialog-action="confirm" onClick={handleSubmit} disabled={submitting}>
{submitting && <Loader2 className="w-4 h-4 mr-2 animate-spin" />}
</Button>

View File

@@ -1,8 +1,13 @@
import * as React from "react"
import * as DialogPrimitive from "@radix-ui/react-dialog"
import { cn } from "@/lib/utils"
import { X } from "lucide-react"
import { isEditableTarget, matchesShortcut } from "@/lib/keyboard"
import { ScrollArea } from "@/components/ui/scroll-area"
const Dialog = DialogPrimitive.Root
const DialogTrigger = DialogPrimitive.Trigger
@@ -32,22 +37,48 @@ interface DialogContentProps
preventOutsideClose?: boolean
/** 隐藏默认关闭按钮(当使用自定义关闭按钮时) */
hideCloseButton?: boolean
/** 回车触发主操作按钮 */
confirmOnEnter?: boolean
}
interface DialogBodyProps extends React.ComponentPropsWithoutRef<typeof ScrollArea> {
allowHorizontalScroll?: boolean
}
const DialogContent = React.forwardRef<
React.ElementRef<typeof DialogPrimitive.Content>,
DialogContentProps
>(({ className, children, preventOutsideClose = false, hideCloseButton = false, ...props }, ref) => (
>(({ className, children, preventOutsideClose = false, hideCloseButton = false, confirmOnEnter = false, onKeyDownCapture, ...props }, ref) => (
<DialogPortal>
<DialogOverlay />
<DialogPrimitive.Content
ref={ref}
className={cn(
"fixed left-[50%] top-[50%] z-50 grid w-full max-w-lg translate-x-[-50%] translate-y-[-50%] gap-4 border bg-background p-6 shadow-lg duration-200 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[state=closed]:slide-out-to-left-1/2 data-[state=closed]:slide-out-to-top-[48%] data-[state=open]:slide-in-from-left-1/2 data-[state=open]:slide-in-from-top-[48%] sm:rounded-lg",
"fixed left-[50%] top-[50%] z-50 grid w-[min(calc(100vw-2rem),var(--dialog-width,32rem))] max-h-[calc(100vh-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 overflow-hidden border bg-background p-6 shadow-lg duration-200 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[state=closed]:slide-out-to-left-1/2 data-[state=closed]:slide-out-to-top-[48%] data-[state=open]:slide-in-from-left-1/2 data-[state=open]:slide-in-from-top-[48%] sm:rounded-lg",
className
)}
onPointerDownOutside={preventOutsideClose ? (e) => e.preventDefault() : undefined}
onInteractOutside={preventOutsideClose ? (e) => e.preventDefault() : undefined}
onKeyDownCapture={(event) => {
onKeyDownCapture?.(event)
if (
!confirmOnEnter ||
event.defaultPrevented ||
!matchesShortcut(event, ['enter']) ||
event.nativeEvent.isComposing ||
isEditableTarget(event.target)
) {
return
}
const confirmButton = event.currentTarget.querySelector<HTMLElement>('[data-dialog-action="confirm"]:not([disabled])')
if (!confirmButton) {
return
}
event.preventDefault()
confirmButton.click()
}}
{...props}
>
{children}
@@ -62,6 +93,22 @@ const DialogContent = React.forwardRef<
))
DialogContent.displayName = DialogPrimitive.Content.displayName
const DialogBody = React.forwardRef<HTMLDivElement, DialogBodyProps>(
({ className, children, allowHorizontalScroll = false, contentClassName, scrollbars, viewportClassName, ...props }, ref) => (
<ScrollArea
ref={ref as never}
className={cn("min-h-0 flex-1", className)}
contentClassName={cn(allowHorizontalScroll && "min-w-full w-max", contentClassName)}
scrollbars={scrollbars ?? (allowHorizontalScroll ? "both" : "vertical")}
viewportClassName={cn("pr-4", viewportClassName)}
{...props}
>
{children}
</ScrollArea>
)
)
DialogBody.displayName = "DialogBody"
const DialogHeader = ({
className,
...props
@@ -125,6 +172,7 @@ export {
DialogClose,
DialogContent,
DialogHeader,
DialogBody,
DialogFooter,
DialogTitle,
DialogDescription,

View File

@@ -1,6 +1,7 @@
import * as React from "react"
import { cva, type VariantProps } from "class-variance-authority"
import { getPlatformModifierAriaLabel, getShortcutKeyLabel, type ShortcutKey } from "@/lib/keyboard"
import { cn } from "@/lib/utils"
const kbdVariants = cva(
@@ -25,6 +26,10 @@ export interface KbdProps
abbrTitle?: string
}
interface ShortcutKbdProps extends Omit<KbdProps, "children"> {
keys: ShortcutKey[]
}
const Kbd = React.forwardRef<HTMLElement, KbdProps>(
({ className, size, abbrTitle, children, ...props }, ref) => {
return (
@@ -40,4 +45,20 @@ const Kbd = React.forwardRef<HTMLElement, KbdProps>(
)
Kbd.displayName = "Kbd"
export { Kbd }
function ShortcutKbd({ keys, className, size, ...props }: ShortcutKbdProps) {
return (
<span className={cn("inline-flex items-center gap-1", className)}>
{keys.map((key) => {
const label = getShortcutKeyLabel(key)
const abbrTitle = key === 'mod' ? getPlatformModifierAriaLabel() : undefined
return (
<Kbd key={`${key}-${label}`} size={size} abbrTitle={abbrTitle} {...props}>
{label}
</Kbd>
)
})}
</span>
)
}
export { Kbd, ShortcutKbd }

View File

@@ -5,22 +5,25 @@ import { cn } from "@/lib/utils"
interface ScrollAreaProps extends React.ComponentPropsWithoutRef<typeof ScrollAreaPrimitive.Root> {
viewportRef?: React.RefObject<HTMLDivElement | null>
viewportClassName?: string
contentClassName?: string
scrollbars?: "vertical" | "horizontal" | "both"
}
const ScrollArea = React.forwardRef<
React.ElementRef<typeof ScrollAreaPrimitive.Root>,
ScrollAreaProps
>(({ className, children, viewportRef, ...props }, ref) => (
>(({ className, children, viewportRef, viewportClassName, contentClassName, scrollbars = "both", ...props }, ref) => (
<ScrollAreaPrimitive.Root
ref={ref}
className={cn("relative overflow-hidden", className)}
{...props}
>
<ScrollAreaPrimitive.Viewport ref={viewportRef} className="h-full w-full rounded-[inherit]">
{children}
<ScrollAreaPrimitive.Viewport ref={viewportRef} className={cn("h-full w-full rounded-[inherit]", viewportClassName)}>
<div className={contentClassName}>{children}</div>
</ScrollAreaPrimitive.Viewport>
<ScrollBar />
<ScrollBar orientation="horizontal" />
{scrollbars !== "horizontal" && <ScrollBar />}
{scrollbars !== "vertical" && <ScrollBar orientation="horizontal" />}
<ScrollAreaPrimitive.Corner />
</ScrollAreaPrimitive.Root>
))
@@ -36,9 +39,9 @@ const ScrollBar = React.forwardRef<
className={cn(
"flex touch-none select-none transition-colors",
orientation === "vertical" &&
"h-full w-2.5 border-l border-l-transparent p-[1px]",
"h-full w-2.5 border-l border-l-transparent p-px",
orientation === "horizontal" &&
"h-2.5 flex-col border-t border-t-transparent p-[1px]",
"h-2.5 flex-col border-t border-t-transparent p-px",
className
)}
{...props}

View File

@@ -1,136 +0,0 @@
import { fetchWithAuth } from './fetch-with-auth'
export interface TimeFootprintData {
total_online_hours: number
first_message_time: string | null
first_message_user: string | null
first_message_content: string | null
busiest_day: string | null
busiest_day_count: number
hourly_distribution: number[]
midnight_chat_count: number
is_night_owl: boolean
}
export interface SocialNetworkData {
total_groups: number
top_groups: Array<{
group_id: string
group_name: string
message_count: number
is_webui?: boolean
}>
top_users: Array<{
user_id: string
user_nickname: string
message_count: number
is_webui?: boolean
}>
at_count: number
mentioned_count: number
longest_companion_user: string | null
longest_companion_days: number
}
export interface BrainPowerData {
total_tokens: number
total_cost: number
favorite_model: string | null
favorite_model_count: number
model_distribution: Array<{
model: string
count: number
tokens: number
cost: number
}>
top_reply_models: Array<{
model: string
count: number
}>
most_expensive_cost: number
most_expensive_time: string | null
top_token_consumers: Array<{
user_id: string
cost: number
tokens: number
}>
silence_rate: number
total_actions: number
no_reply_count: number
avg_interest_value: number
max_interest_value: number
max_interest_time: string | null
avg_reasoning_length: number
max_reasoning_length: number
max_reasoning_time: string | null
}
export interface ExpressionVibeData {
top_emoji: {
id: number
path: string
description: string
usage_count: number
hash: string
} | null
top_emojis: Array<{
id: number
path: string
description: string
usage_count: number
hash: string
}>
top_expressions: Array<{
style: string
count: number
}>
rejected_expression_count: number
checked_expression_count: number
total_expressions: number
action_types: Array<{
action: string
count: number
}>
image_processed_count: number
late_night_reply: {
time: string
content: string
} | null
favorite_reply: {
content: string
count: number
} | null
}
export interface AchievementData {
new_jargon_count: number
sample_jargons: Array<{
content: string
meaning: string
count: number
}>
total_messages: number
total_replies: number
}
export interface AnnualReportData {
year: number
bot_name: string
generated_at: string
time_footprint: TimeFootprintData
social_network: SocialNetworkData
brain_power: BrainPowerData
expression_vibe: ExpressionVibeData
achievements: AchievementData
}
export async function getAnnualReport(year: number = 2025): Promise<AnnualReportData> {
const response = await fetchWithAuth(`/api/webui/annual-report/full?year=${year}`)
if (!response.ok) {
const error = await response.json()
throw new Error(error.detail || '获取年度报告失败')
}
return response.json()
}

View File

@@ -0,0 +1,93 @@
export type ShortcutKey =
| 'mod'
| 'shift'
| 'alt'
| 'enter'
| 'esc'
| 'up'
| 'down'
| 'left'
| 'right'
| string
const MAC_PLATFORMS = /(Mac|iPhone|iPod|iPad)/i
export function isMacLikePlatform(): boolean {
if (typeof navigator === 'undefined') {
return false
}
return MAC_PLATFORMS.test(navigator.platform || navigator.userAgent)
}
export function getShortcutKeyLabel(key: ShortcutKey): string {
const isMacLike = isMacLikePlatform()
const normalizedKey = key.toLowerCase()
switch (normalizedKey) {
case 'mod':
return isMacLike ? '⌘' : 'Ctrl'
case 'shift':
return isMacLike ? '⇧' : 'Shift'
case 'alt':
return isMacLike ? '⌥' : 'Alt'
case 'enter':
return isMacLike ? '↵' : 'Enter'
case 'esc':
case 'escape':
return 'Esc'
case 'up':
return '↑'
case 'down':
return '↓'
case 'left':
return '←'
case 'right':
return '→'
default:
return key.length === 1 ? key.toUpperCase() : key
}
}
export function getPlatformModifierAriaLabel(): string {
return isMacLikePlatform() ? 'Command' : 'Control'
}
export function matchesShortcut(event: KeyboardEvent | React.KeyboardEvent, keys: ShortcutKey[]): boolean {
const normalizedKeys = keys.map((key) => key.toLowerCase())
const eventKey = event.key.toLowerCase()
const modifierChecks = {
mod: isMacLikePlatform() ? event.metaKey : event.ctrlKey,
shift: event.shiftKey,
alt: event.altKey,
}
for (const key of normalizedKeys) {
if (key in modifierChecks) {
if (!modifierChecks[key as keyof typeof modifierChecks]) {
return false
}
continue
}
if (eventKey !== key) {
return false
}
}
return true
}
export function isEditableTarget(target: EventTarget | null): boolean {
if (!(target instanceof HTMLElement)) {
return false
}
return (
target.tagName === 'INPUT' ||
target.tagName === 'TEXTAREA' ||
target.isContentEditable ||
target.getAttribute('role') === 'textbox'
)
}

View File

@@ -23,9 +23,6 @@ export const STORAGE_KEYS = {
WS_MAX_RECONNECT_ATTEMPTS: 'maibot-ws-max-reconnect-attempts',
// 用户数据
// 注意ACCESS_TOKEN 已弃用,现在使用 HttpOnly Cookie 存储认证信息
// 保留此常量仅用于向后兼容和清理旧数据
ACCESS_TOKEN: 'access-token',
COMPLETED_TOURS: 'maibot-completed-tours',
CHAT_USER_ID: 'maibot_webui_user_id',
CHAT_USER_NAME: 'maibot_webui_user_name',
@@ -211,10 +208,8 @@ export function clearLocalCache(): { clearedKeys: string[]; preservedKeys: strin
const keysToRemove: string[] = []
for (let i = 0; i < localStorage.length; i++) {
const key = localStorage.key(i)
if (key) {
if (key.startsWith('maibot') || key.startsWith('accent-color') || key === 'access-token') {
keysToRemove.push(key)
}
if (key && (key.startsWith('maibot') || key.startsWith('accent-color'))) {
keysToRemove.push(key)
}
}

View File

@@ -24,7 +24,6 @@ import { PluginMirrorsPage } from './routes/plugin-mirrors'
import { PluginDetailPage } from './routes/plugin-detail'
import { ChatPage } from './routes/chat/index'
import { WebUIFeedbackSurveyPage, MaiBotFeedbackSurveyPage } from './routes/survey'
import { AnnualReportPage } from './routes/annual-report'
import PackMarketPage from './routes/config/pack-market'
import PackDetailPage from './routes/config/pack-detail'
import { Layout } from './components/layout'
@@ -241,13 +240,6 @@ const maibotFeedbackSurveyRoute = createRoute({
component: MaiBotFeedbackSurveyPage,
})
// 年度报告路由
const annualReportRoute = createRoute({
getParentRoute: () => protectedRoute,
path: '/annual-report',
component: AnnualReportPage,
})
// 404 路由
const notFoundRoute = createRoute({
getParentRoute: () => rootRoute,
@@ -284,11 +276,23 @@ const routeTree = rootRoute.addChildren([
packDetailRoute,
webuiFeedbackSurveyRoute,
maibotFeedbackSurveyRoute,
annualReportRoute,
]),
notFoundRoute,
])
type RouteNode = {
fullPath?: string
children?: RouteNode[]
}
function collectRoutePaths(node: RouteNode): string[] {
const currentPath = node.fullPath ? [node.fullPath] : []
const childPaths = node.children?.flatMap(collectRoutePaths) ?? []
return [...currentPath, ...childPaths]
}
export const registeredRoutePaths = new Set(collectRoutePaths(routeTree as RouteNode))
// 创建路由器
export const router = createRouter({
routeTree,

View File

@@ -1,883 +0,0 @@
import { useState, useRef, useEffect, useCallback } from 'react'
import { getAnnualReport, type AnnualReportData } from '@/lib/annual-report-api'
import { Card, CardContent, CardHeader, CardTitle, CardDescription } from '@/components/ui/card'
import { Skeleton } from '@/components/ui/skeleton'
import { Badge } from '@/components/ui/badge'
import { ScrollArea } from '@/components/ui/scroll-area'
import { Button } from '@/components/ui/button'
import { useToast } from '@/hooks/use-toast'
import { toPng } from 'html-to-image'
import {
BarChart,
Bar,
XAxis,
YAxis,
CartesianGrid,
Tooltip,
ResponsiveContainer,
} from 'recharts'
import {
Clock,
Users,
Brain,
Smile,
Trophy,
Calendar,
MessageSquare,
Zap,
Moon,
Sun,
AtSign,
Heart,
Image as ImageIcon,
Bot,
Download,
Loader2,
} from 'lucide-react'
import { cn } from '@/lib/utils'
// 颜色常量
const COLORS = ['#0088FE', '#00C49F', '#FFBB28', '#FF8042', '#8884d8', '#82ca9d']
// 动态比喻生成函数
function getOnlineHoursMetaphor(hours: number): string {
if (hours >= 8760) return "相当于全年无休7x24小时在线"
if (hours >= 5000) return "相当于一位全职员工的年工作时长"
if (hours >= 2000) return "相当于看完了 1000 部电影"
if (hours >= 1000) return "相当于环球飞行 80 次"
if (hours >= 500) return "相当于读完了 100 本书"
if (hours >= 100) return "相当于马拉松跑了 25 次"
return "虽然不多,但每一刻都很珍贵"
}
function getMidnightMetaphor(count: number): string {
if (count >= 1000) return "夜深人静时的知心好友"
if (count >= 500) return "午夜场的常客"
if (count >= 100) return "偶尔熬夜的小伙伴"
if (count >= 50) return "深夜有时也会陪你聊聊"
return "早睡早起,健康作息"
}
function getTokenMetaphor(tokens: number): string {
const millions = tokens / 1000000
if (millions >= 100) return "思考量堪比一座图书馆"
if (millions >= 50) return "相当于写了一部百科全书"
if (millions >= 10) return "脑细胞估计消耗了不少"
if (millions >= 1) return "也算是费了一番脑筋"
return "轻轻松松,游刃有余"
}
function getCostMetaphor(cost: number): string {
if (cost >= 1000) return "这钱够吃一年的泡面了"
if (cost >= 500) return "相当于买了一台游戏机"
if (cost >= 100) return "够请大家喝几杯奶茶"
if (cost >= 50) return "一顿火锅的钱"
if (cost >= 10) return "几杯咖啡的价格"
return "省钱小能手"
}
function getSilenceMetaphor(rate: number): string {
if (rate >= 80) return "沉默是金,惜字如金"
if (rate >= 60) return "话不多但句句到位"
if (rate >= 40) return "该说的时候才开口"
if (rate >= 20) return "能聊的都聊了"
return "话痨本痨,有问必答"
}
function getImageMetaphor(count: number): string {
if (count >= 10000) return "眼睛都快看花了"
if (count >= 5000) return "堪比专业摄影师的阅片量"
if (count >= 1000) return "看图小达人"
if (count >= 500) return "图片鉴赏家"
if (count >= 100) return "偶尔欣赏一下美图"
return "图片?有空再看"
}
function getRejectedMetaphor(count: number): string {
if (count >= 500) return "在不断的纠正中成长"
if (count >= 200) return "学习永无止境"
if (count >= 100) return "虚心接受,积极改正"
if (count >= 50) return "偶尔也会犯错"
if (count >= 10) return "表现还算不错"
return "完美表达,无需纠正"
}
function getExpensiveThinkingMetaphor(cost: number): string {
if (cost >= 1) return "这次思考的价值堪比一顿大餐!"
if (cost >= 0.5) return "为了这个问题,我可是认真思考了!"
if (cost >= 0.1) return "下了点功夫,值得的!"
if (cost >= 0.01) return "花了点小钱,但很值得"
return "小小思考,不足挂齿"
}
function getFavoriteReplyMetaphor(count: number, botName: string): string {
if (count >= 100) return "这句话简直是万能钥匙!"
if (count >= 50) return "百试不爽的经典回复"
if (count >= 20) return `${botName}的口头禅`
if (count >= 10) return "常用语录之一"
return "偶尔用用的小确幸"
}
function getNightOwlMetaphor(isNightOwl: boolean, midnightCount: number): string {
if (isNightOwl) {
if (midnightCount >= 1000) return "深夜的守护者,黑暗中的光芒"
if (midnightCount >= 500) return "月亮是我的好朋友"
if (midnightCount >= 100) return "越夜越精神,夜晚才是主场"
return "偶尔熬夜,享受宁静时光"
} else {
if (midnightCount <= 10) return "作息规律,健康生活的典范"
if (midnightCount <= 50) return "早睡早起,偶尔也会熬个夜"
return "虽然是早起鸟,但也会守候深夜"
}
}
function getBusiestDayMetaphor(count: number): string {
if (count >= 1000) return "忙到飞起,键盘都要冒烟了"
if (count >= 500) return "这天简直是话痨附体"
if (count >= 200) return "社交达人上线"
if (count >= 100) return "比平时活跃不少"
if (count >= 50) return "小忙一下"
return "还算轻松的一天"
}
export function AnnualReportPage() {
const [year] = useState(2025)
const [data, setData] = useState<AnnualReportData | null>(null)
const [isLoading, setIsLoading] = useState(true)
const [isExporting, setIsExporting] = useState(false)
const [error, setError] = useState<Error | null>(null)
const reportRef = useRef<HTMLDivElement>(null)
const { toast } = useToast()
const loadReport = useCallback(async () => {
try {
setIsLoading(true)
setError(null)
const result = await getAnnualReport(year)
setData(result)
} catch (err) {
setError(err instanceof Error ? err : new Error('获取年度报告失败'))
} finally {
setIsLoading(false)
}
}, [year])
// 导出为图片
const handleExport = useCallback(async () => {
if (!reportRef.current || !data) return
setIsExporting(true)
toast({
title: '正在生成图片',
description: '请稍候...',
})
try {
const element = reportRef.current
// 获取当前主题的背景色
const computedStyle = getComputedStyle(document.documentElement)
const backgroundColor = computedStyle.getPropertyValue('--background').trim()
? `hsl(${computedStyle.getPropertyValue('--background').trim()})`
: (document.documentElement.classList.contains('dark') ? '#0a0a0a' : '#ffffff')
// 保存原始样式
const originalWidth = element.style.width
const originalMaxWidth = element.style.maxWidth
// 临时设置固定宽度以去除左右空白
element.style.width = '1024px'
element.style.maxWidth = '1024px'
const dataUrl = await toPng(element, {
quality: 1,
pixelRatio: 2,
backgroundColor,
cacheBust: true,
filter: (node) => {
// 过滤掉导出按钮
if (node instanceof HTMLElement && node.hasAttribute('data-export-btn')) {
return false
}
return true
},
})
// 恢复原始样式
element.style.width = originalWidth
element.style.maxWidth = originalMaxWidth
// 创建下载链接
const link = document.createElement('a')
link.download = `${data.bot_name}_${data.year}_年度总结.png`
link.href = dataUrl
link.click()
toast({
title: '导出成功',
description: '年度报告已保存为图片',
})
} catch (err) {
console.error('导出图片失败:', err)
toast({
title: '导出失败',
description: '请重试',
variant: 'destructive',
})
} finally {
setIsExporting(false)
}
}, [data, toast])
useEffect(() => {
loadReport()
}, [loadReport])
if (isLoading) {
return <LoadingSkeleton />
}
if (error) {
return (
<div className="flex h-screen items-center justify-center text-red-500">
: {error.message}
</div>
)
}
if (!data) return null
return (
<ScrollArea className="h-[calc(100vh-4rem)]">
<div className="min-h-screen bg-gradient-to-b from-background to-muted/50 p-4 md:p-8 print:p-0" ref={reportRef}>
<div className="mx-auto max-w-5xl space-y-8 print:space-y-4">
{/* 头部 Hero */}
<header className="relative overflow-hidden rounded-3xl bg-primary p-8 text-primary-foreground shadow-2xl print:rounded-none print:shadow-none">
{/* 导出按钮 */}
<div className="absolute right-4 top-4 z-20 print:hidden" data-export-btn>
<Button
variant="secondary"
size="sm"
onClick={handleExport}
disabled={isExporting}
className="gap-2 bg-white/20 hover:bg-white/30 text-white border-white/30"
>
{isExporting ? (
<>
<Loader2 className="h-4 w-4 animate-spin" />
...
</>
) : (
<>
<Download className="h-4 w-4" />
</>
)}
</Button>
</div>
<div className="relative z-10 flex flex-col items-center text-center">
<Bot className="mb-4 h-16 w-16 animate-bounce" />
<h1 className="text-4xl font-bold tracking-tighter sm:text-6xl">
{data.bot_name} {data.year}
</h1>
<p className="mt-4 max-w-2xl text-lg opacity-90">
· Connection & Growth
</p>
<div className="mt-6 flex items-center gap-2 text-sm opacity-75">
<Calendar className="h-4 w-4" />
<span>: {data.generated_at}</span>
</div>
</div>
{/* 背景装饰 */}
<div className="absolute -right-20 -top-20 h-64 w-64 rounded-full bg-white/10 blur-3xl" />
<div className="absolute -bottom-20 -left-20 h-64 w-64 rounded-full bg-white/10 blur-3xl" />
</header>
{/* 维度一:时光足迹 */}
<section className="space-y-4 break-inside-avoid">
<div className="flex items-center gap-2 text-2xl font-bold text-primary">
<Clock className="h-8 w-8" />
<h2></h2>
</div>
<div className="grid gap-4 md:grid-cols-2 lg:grid-cols-4">
<StatCard
title="年度在线时长"
value={`${data.time_footprint.total_online_hours} 小时`}
description={getOnlineHoursMetaphor(data.time_footprint.total_online_hours)}
icon={<Clock className="h-4 w-4" />}
/>
<StatCard
title="最忙碌的一天"
value={data.time_footprint.busiest_day || 'N/A'}
description={getBusiestDayMetaphor(data.time_footprint.busiest_day_count)}
icon={<Calendar className="h-4 w-4" />}
/>
<StatCard
title="深夜互动 (0-4点)"
value={`${data.time_footprint.midnight_chat_count}`}
description={getMidnightMetaphor(data.time_footprint.midnight_chat_count)}
icon={<Moon className="h-4 w-4" />}
/>
<StatCard
title="作息属性"
value={data.time_footprint.is_night_owl ? '夜猫子' : '早起鸟'}
description={getNightOwlMetaphor(data.time_footprint.is_night_owl, data.time_footprint.midnight_chat_count)}
icon={data.time_footprint.is_night_owl ? <Moon className="h-4 w-4" /> : <Sun className="h-4 w-4" />}
/>
</div>
<Card className="overflow-hidden">
<CardHeader>
<CardTitle>24</CardTitle>
<CardDescription>{data.bot_name}</CardDescription>
</CardHeader>
<CardContent className="h-[300px]">
<ResponsiveContainer width="100%" height="100%">
<BarChart data={data.time_footprint.hourly_distribution.map((count: number, hour: number) => ({ hour: `${hour}`, count }))}>
<CartesianGrid strokeDasharray="3 3" vertical={false} />
<XAxis dataKey="hour" />
<YAxis />
<Tooltip
contentStyle={{ borderRadius: '8px', border: 'none', boxShadow: '0 4px 12px rgba(0,0,0,0.1)' }}
cursor={{ fill: 'transparent' }}
/>
<Bar dataKey="count" fill="hsl(var(--color-primary))" radius={[4, 4, 0, 0]} />
</BarChart>
</ResponsiveContainer>
</CardContent>
</Card>
{data.time_footprint.first_message_time && (
<Card className="bg-muted/30 border-dashed">
<CardContent className="flex flex-col items-center justify-center p-6 text-center">
<p className="text-muted-foreground mb-2">2025</p>
<div className="text-xl font-bold text-primary mb-1">{data.time_footprint.first_message_time}</div>
<p className="text-lg">
<span className="font-semibold text-foreground">{data.time_footprint.first_message_user}</span>
<span className="italic text-muted-foreground">"{data.time_footprint.first_message_content}"</span>
</p>
</CardContent>
</Card>
)}
</section>
{/* 维度二:社交网络 */}
<section className="space-y-4 break-inside-avoid">
<div className="flex items-center gap-2 text-2xl font-bold text-primary">
<Users className="h-8 w-8" />
<h2></h2>
</div>
<div className="grid gap-4 md:grid-cols-3">
<StatCard
title="社交圈子"
value={`${data.social_network.total_groups} 个群组`}
description={`${data.bot_name}加入的群组总数`}
icon={<Users className="h-4 w-4" />}
/>
<StatCard
title="被呼叫次数"
value={`${data.social_network.at_count + data.social_network.mentioned_count}`}
description="我的名字被大家频繁提起"
icon={<AtSign className="h-4 w-4" />}
/>
<StatCard
title="最长情陪伴"
value={data.social_network.longest_companion_user || 'N/A'}
description={`始终都在,已陪伴 ${data.social_network.longest_companion_days}`}
icon={<Heart className="h-4 w-4 text-red-500" />}
/>
</div>
<div className="grid gap-4 md:grid-cols-2">
<Card>
<CardHeader>
<CardTitle> TOP5</CardTitle>
</CardHeader>
<CardContent>
<div className="space-y-3">
{data.social_network.top_groups.length > 0 ? (
data.social_network.top_groups.map((group: { group_id: string; group_name: string; message_count: number; is_webui?: boolean }, index: number) => (
<div key={group.group_id} className="flex items-center justify-between">
<div className="flex items-center gap-3">
<Badge variant={index === 0 ? "default" : "secondary"} className="h-6 w-6 rounded-full p-0 flex items-center justify-center shrink-0">
{index + 1}
</Badge>
<span className="font-medium truncate max-w-[120px]">{group.group_name}</span>
{group.is_webui && (
<Badge variant="outline" className="text-xs px-1.5 py-0 h-5 bg-blue-50 text-blue-600 border-blue-200">
WebUI
</Badge>
)}
</div>
<span className="text-muted-foreground text-sm shrink-0">{group.message_count} </span>
</div>
))
) : (
<div className="text-center text-muted-foreground py-4"></div>
)}
</div>
</CardContent>
</Card>
<Card>
<CardHeader>
<CardTitle> TOP5</CardTitle>
</CardHeader>
<CardContent>
<div className="space-y-3">
{data.social_network.top_users.length > 0 ? (
data.social_network.top_users.map((user: { user_id: string; user_nickname: string; message_count: number; is_webui?: boolean }, index: number) => (
<div key={user.user_id} className="flex items-center justify-between">
<div className="flex items-center gap-3">
<Badge variant={index === 0 ? "default" : "secondary"} className="h-6 w-6 rounded-full p-0 flex items-center justify-center shrink-0">
{index + 1}
</Badge>
<span className="font-medium truncate max-w-[120px]">{user.user_nickname}</span>
{user.is_webui && (
<Badge variant="outline" className="text-xs px-1.5 py-0 h-5 bg-blue-50 text-blue-600 border-blue-200">
WebUI
</Badge>
)}
</div>
<span className="text-muted-foreground text-sm shrink-0">{user.message_count} </span>
</div>
))
) : (
<div className="text-center text-muted-foreground py-4"></div>
)}
</div>
</CardContent>
</Card>
</div>
</section>
{/* 维度三:最强大脑 */}
<section className="space-y-4 break-inside-avoid">
<div className="flex items-center gap-2 text-2xl font-bold text-primary">
<Brain className="h-8 w-8" />
<h2></h2>
</div>
<div className="grid gap-4 md:grid-cols-2 lg:grid-cols-4">
<StatCard
title="年度 Token 消耗"
value={(data.brain_power.total_tokens / 1000000).toFixed(2) + ' M'}
description={getTokenMetaphor(data.brain_power.total_tokens)}
icon={<Zap className="h-4 w-4" />}
/>
<StatCard
title="年度总花费"
value={`$${data.brain_power.total_cost.toFixed(2)}`}
description={getCostMetaphor(data.brain_power.total_cost)}
icon={<span className="font-bold">$</span>}
/>
<StatCard
title="高冷指数"
value={`${data.brain_power.silence_rate}%`}
description={getSilenceMetaphor(data.brain_power.silence_rate)}
icon={<Moon className="h-4 w-4" />}
/>
<StatCard
title="最高兴趣值"
value={data.brain_power.max_interest_value ?? 'N/A'}
description={data.brain_power.max_interest_time ? `出现在 ${data.brain_power.max_interest_time}` : '暂无数据'}
icon={<Heart className="h-4 w-4" />}
/>
</div>
<div className="grid gap-4 md:grid-cols-2">
<Card>
<CardHeader>
<CardTitle></CardTitle>
</CardHeader>
<CardContent>
<div className="space-y-3">
{data.brain_power.model_distribution.slice(0, 5).map((item: { model: string; count: number }, index: number) => {
const maxCount = data.brain_power.model_distribution[0]?.count || 1
const percentage = Math.round((item.count / maxCount) * 100)
return (
<div key={item.model} className="space-y-1">
<div className="flex justify-between text-sm">
<span className="font-medium truncate max-w-[200px]">{item.model}</span>
<span className="text-muted-foreground">{item.count.toLocaleString()} </span>
</div>
<div className="h-2 w-full overflow-hidden rounded-full bg-secondary">
<div
className="h-full transition-all duration-500"
style={{
width: `${percentage}%`,
backgroundColor: COLORS[index % COLORS.length]
}}
/>
</div>
</div>
)
})}
</div>
</CardContent>
</Card>
{/* 最喜欢的回复模型 TOP5 */}
{data.brain_power.top_reply_models && data.brain_power.top_reply_models.length > 0 && (
<Card>
<CardHeader>
<CardTitle> TOP5</CardTitle>
<CardDescription>{data.bot_name}</CardDescription>
</CardHeader>
<CardContent>
<div className="space-y-3">
{data.brain_power.top_reply_models.map((item: { model: string; count: number }, index: number) => {
const maxCount = data.brain_power.top_reply_models[0]?.count || 1
const percentage = Math.round((item.count / maxCount) * 100)
return (
<div key={item.model} className="space-y-1">
<div className="flex justify-between text-sm">
<span className="font-medium truncate max-w-[200px]">{item.model}</span>
<span className="text-muted-foreground">{item.count.toLocaleString()} </span>
</div>
<div className="h-2 w-full overflow-hidden rounded-full bg-secondary">
<div
className="h-full transition-all duration-500"
style={{
width: `${percentage}%`,
backgroundColor: COLORS[index % COLORS.length]
}}
/>
</div>
</div>
)
})}
</div>
</CardContent>
</Card>
)}
{/* 烧钱大户 - 只有有有效用户数据时才显示 */}
{data.brain_power.top_token_consumers && data.brain_power.top_token_consumers.length > 0 && (
<Card>
<CardHeader>
<CardTitle> TOP3</CardTitle>
<CardDescription> API </CardDescription>
</CardHeader>
<CardContent>
<div className="space-y-6">
{data.brain_power.top_token_consumers.map((consumer: { user_id: string; cost: number; tokens: number }) => (
<div key={consumer.user_id} className="space-y-2">
<div className="flex justify-between text-sm font-medium">
<span> {consumer.user_id}</span>
<span>${consumer.cost.toFixed(2)}</span>
</div>
<div className="h-2 w-full overflow-hidden rounded-full bg-secondary">
<div
className="h-full bg-primary transition-all duration-500"
style={{ width: `${(consumer.cost / (data.brain_power.top_token_consumers[0]?.cost || 1)) * 100}%` }}
/>
</div>
</div>
))}
</div>
</CardContent>
</Card>
)}
</div>
{/* 最昂贵的思考 & 思考深度 */}
<div className="grid gap-4 md:grid-cols-2">
<Card className="bg-gradient-to-br from-amber-50 to-orange-50 dark:from-amber-950/20 dark:to-orange-950/20">
<CardHeader>
<CardTitle className="flex items-center gap-2">
<span className="text-2xl">💰</span>
</CardTitle>
</CardHeader>
<CardContent className="text-center">
<div className="text-4xl font-bold text-amber-600 dark:text-amber-400">
${data.brain_power.most_expensive_cost.toFixed(4)}
</div>
{data.brain_power.most_expensive_time && (
<p className="mt-2 text-sm text-muted-foreground">
{data.brain_power.most_expensive_time}
</p>
)}
<p className="mt-4 text-sm text-muted-foreground">
{getExpensiveThinkingMetaphor(data.brain_power.most_expensive_cost)}
</p>
</CardContent>
</Card>
<Card className="bg-gradient-to-br from-indigo-50 to-blue-50 dark:from-indigo-950/20 dark:to-blue-950/20">
<CardHeader>
<CardTitle className="flex items-center gap-2">
<span className="text-2xl">🧠</span>
</CardTitle>
</CardHeader>
<CardContent>
<div className="grid grid-cols-2 gap-4 text-center">
<div>
<div className="text-2xl font-bold text-indigo-600 dark:text-indigo-400">
{data.brain_power.avg_reasoning_length?.toFixed(0) || 0}
</div>
<div className="text-xs text-muted-foreground"></div>
</div>
<div>
<div className="text-2xl font-bold text-blue-600 dark:text-blue-400">
{data.brain_power.max_reasoning_length?.toLocaleString() || 0}
</div>
<div className="text-xs text-muted-foreground"></div>
</div>
</div>
{data.brain_power.max_reasoning_time && (
<p className="mt-4 text-center text-xs text-muted-foreground">
{data.brain_power.max_reasoning_time}
</p>
)}
</CardContent>
</Card>
</div>
</section>
{/* 维度四:个性与表达 */}
<section className="space-y-4 break-inside-avoid">
<div className="flex items-center gap-2 text-2xl font-bold text-primary">
<Smile className="h-8 w-8" />
<h2></h2>
</div>
{/* 深夜回复 & 最喜欢的回复 */}
{(data.expression_vibe.late_night_reply || data.expression_vibe.favorite_reply) && (
<div className="grid gap-4 md:grid-cols-2">
{data.expression_vibe.late_night_reply && (
<Card className="bg-gradient-to-br from-indigo-50 to-violet-50 dark:from-indigo-950/20 dark:to-violet-950/20">
<CardHeader>
<CardTitle className="flex items-center gap-2">
<span className="text-2xl">🌙</span>
</CardTitle>
<CardDescription> {data.expression_vibe.late_night_reply.time}{data.bot_name}...</CardDescription>
</CardHeader>
<CardContent className="text-center">
<p className="text-lg italic text-muted-foreground">
"{data.expression_vibe.late_night_reply.content}"
</p>
<p className="mt-4 text-sm text-muted-foreground">
</p>
</CardContent>
</Card>
)}
{data.expression_vibe.favorite_reply && (
<Card className="bg-gradient-to-br from-rose-50 to-pink-50 dark:from-rose-950/20 dark:to-pink-950/20">
<CardHeader>
<CardTitle className="flex items-center gap-2">
<span className="text-2xl">💬</span>
</CardTitle>
<CardDescription>使 {data.expression_vibe.favorite_reply.count} </CardDescription>
</CardHeader>
<CardContent className="text-center">
<p className="text-lg font-medium text-primary">
"{data.expression_vibe.favorite_reply.content}"
</p>
<p className="mt-4 text-sm text-muted-foreground">
{getFavoriteReplyMetaphor(data.expression_vibe.favorite_reply.count, data.bot_name)}
</p>
</CardContent>
</Card>
)}
</div>
)}
<div className="grid gap-4 md:grid-cols-2">
{/* 使用最多的表情包 TOP3 */}
<Card className="bg-gradient-to-br from-pink-50 to-purple-50 dark:from-pink-950/20 dark:to-purple-950/20">
<CardHeader>
<CardTitle>使 TOP3</CardTitle>
<CardDescription></CardDescription>
</CardHeader>
<CardContent>
{data.expression_vibe.top_emojis && data.expression_vibe.top_emojis.length > 0 ? (
<div className="flex justify-center gap-4">
{data.expression_vibe.top_emojis.slice(0, 3).map((emoji: { id: number; usage_count: number }, index: number) => (
<div key={emoji.id} className="flex flex-col items-center">
<div className="relative">
<img
src={`/api/webui/emoji/${emoji.id}/thumbnail?original=true`}
alt={`TOP ${index + 1}`}
className="h-24 w-24 rounded-lg object-cover shadow-md transition-transform hover:scale-105"
/>
<Badge
className={cn(
"absolute -top-2 -right-2",
index === 0 ? "bg-yellow-500" : index === 1 ? "bg-gray-400" : "bg-amber-700"
)}
>
{index + 1}
</Badge>
</div>
<p className="mt-2 text-sm text-muted-foreground">{emoji.usage_count} </p>
</div>
))}
</div>
) : (
<div className="flex h-32 items-center justify-center text-muted-foreground"></div>
)}
</CardContent>
</Card>
<div className="space-y-4">
<Card>
<CardHeader>
<CardTitle></CardTitle>
<CardDescription>{data.bot_name}使</CardDescription>
</CardHeader>
<CardContent>
<div className="flex flex-wrap gap-2">
{data.expression_vibe.top_expressions.map((exp: { style: string; count: number }, index: number) => (
<Badge
key={exp.style}
variant="outline"
className={cn(
"px-3 py-1 text-sm",
index === 0 && "border-primary bg-primary/10 text-primary text-base px-4 py-2"
)}
>
{exp.style} ({exp.count})
</Badge>
))}
</div>
</CardContent>
</Card>
<div className="grid grid-cols-2 gap-4">
<StatCard
title="图片鉴赏"
value={`${data.expression_vibe.image_processed_count}`}
description={getImageMetaphor(data.expression_vibe.image_processed_count)}
icon={<ImageIcon className="h-4 w-4" />}
/>
<StatCard
title="成长的足迹"
value={`${data.expression_vibe.rejected_expression_count}`}
description={getRejectedMetaphor(data.expression_vibe.rejected_expression_count)}
icon={<Zap className="h-4 w-4" />}
/>
</div>
</div>
</div>
{/* 行动派 */}
{data.expression_vibe.action_types.length > 0 && (
<Card>
<CardHeader>
<CardTitle className="flex items-center gap-2">
<span className="text-2xl"></span>
</CardTitle>
<CardDescription></CardDescription>
</CardHeader>
<CardContent>
<div className="flex flex-wrap gap-3">
{data.expression_vibe.action_types.map((action: { action: string; count: number }) => (
<div
key={action.action}
className="flex items-center gap-2 rounded-full bg-primary/10 px-4 py-2"
>
<span className="font-medium text-primary">{action.action}</span>
<Badge variant="secondary">{action.count} </Badge>
</div>
))}
</div>
</CardContent>
</Card>
)}
</section>
{/* 维度五:趣味成就 */}
<section className="space-y-4 break-inside-avoid">
<div className="flex items-center gap-2 text-2xl font-bold text-primary">
<Trophy className="h-8 w-8" />
<h2></h2>
</div>
<div className="grid gap-4 md:grid-cols-3">
<Card className="col-span-1 md:col-span-2">
<CardHeader>
<CardTitle>"黑话"</CardTitle>
<CardDescription> {data.achievements.new_jargon_count} </CardDescription>
</CardHeader>
<CardContent>
<div className="flex flex-wrap gap-3">
{data.achievements.sample_jargons.map((jargon: { content: string; meaning: string; count: number }) => (
<div key={jargon.content} className="group relative rounded-lg border bg-card p-3 shadow-sm transition-all hover:shadow-md">
<div className="font-bold text-primary">{jargon.content}</div>
<div className="text-xs text-muted-foreground mt-1 line-clamp-2 max-w-[200px]">
{jargon.meaning || '暂无解释'}
</div>
</div>
))}
</div>
</CardContent>
</Card>
<Card className="flex flex-col justify-center items-center bg-primary text-primary-foreground">
<CardContent className="flex flex-col items-center justify-center p-6 text-center">
<MessageSquare className="h-12 w-12 mb-4 opacity-80" />
<div className="text-4xl font-bold mb-2">{data.achievements.total_messages.toLocaleString()}</div>
<div className="text-sm opacity-80"></div>
<div className="mt-4 text-xs opacity-60">
{data.achievements.total_replies.toLocaleString()}
</div>
</CardContent>
</Card>
</div>
</section>
{/* 底部 */}
<footer className="mt-12 text-center text-muted-foreground">
<p>MaiBot 2025 Annual Report</p>
<p className="text-sm">Generated with by MaiBot Team</p>
</footer>
</div>
</div>
</ScrollArea>
)
}
function StatCard({
title,
value,
description,
icon,
}: {
title: string
value: string | number
description: string
icon: React.ReactNode
}) {
return (
<Card>
<CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
<CardTitle className="text-sm font-medium">{title}</CardTitle>
<div className="text-muted-foreground">{icon}</div>
</CardHeader>
<CardContent>
<div className="text-2xl font-bold">{value}</div>
<p className="text-xs text-muted-foreground">{description}</p>
</CardContent>
</Card>
)
}
function LoadingSkeleton() {
return (
<div className="container mx-auto space-y-8 p-8">
<Skeleton className="h-64 w-full rounded-3xl" />
<div className="grid gap-4 md:grid-cols-4">
{[...Array(4)].map((_, i) => (
<Skeleton key={i} className="h-32 w-full" />
))}
</div>
<Skeleton className="h-96 w-full" />
</div>
)
}

View File

@@ -2,6 +2,7 @@ import { Avatar, AvatarFallback } from '@/components/ui/avatar'
import { Button } from '@/components/ui/button'
import {
Dialog,
DialogBody,
DialogContent,
DialogDescription,
DialogFooter,
@@ -54,7 +55,7 @@ export function VirtualIdentityDialog({
}: VirtualIdentityDialogProps) {
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="sm:max-w-[500px] max-h-[85vh] overflow-hidden flex flex-col">
<DialogContent className="sm:max-w-125 max-h-[85vh] overflow-hidden flex flex-col" confirmOnEnter>
<DialogHeader>
<DialogTitle className="flex items-center gap-2">
<UserCircle2 className="h-5 w-5" />
@@ -65,7 +66,7 @@ export function VirtualIdentityDialog({
</DialogDescription>
</DialogHeader>
<div className="space-y-4 flex-1 overflow-hidden flex flex-col">
<DialogBody className="space-y-4 flex-1" viewportClassName="pr-0">
{/* 平台选择 */}
<div className="space-y-2">
<Label className="flex items-center gap-2">
@@ -113,7 +114,7 @@ export function VirtualIdentityDialog({
className="pl-9"
/>
</div>
<ScrollArea className="h-[250px] border rounded-md">
<ScrollArea className="h-62.5 border rounded-md">
<div className="p-2">
{isLoadingPersons ? (
<div className="flex items-center justify-center py-8">
@@ -187,13 +188,14 @@ export function VirtualIdentityDialog({
</p>
</div>
)}
</div>
</DialogBody>
<DialogFooter className="gap-2 sm:gap-0">
<Button variant="outline" onClick={() => onOpenChange(false)}>
</Button>
<Button
<Button
data-dialog-action="confirm"
onClick={onCreateVirtualTab}
disabled={!tempVirtualConfig.platform || !tempVirtualConfig.personId}
>

View File

@@ -18,6 +18,7 @@ import {
} from '@/components/ui/alert-dialog'
import {
Dialog,
DialogBody,
DialogContent,
DialogDescription,
DialogHeader,
@@ -249,7 +250,7 @@ function RegexEditor({
</Button>
</DialogTrigger>
<DialogContent className="max-w-[95vw] sm:max-w-[900px] max-h-[90vh]">
<DialogContent className="max-w-[95vw] sm:max-w-225">
<DialogHeader>
<DialogTitle></DialogTitle>
<DialogDescription className="text-sm">
@@ -257,7 +258,7 @@ function RegexEditor({
</DialogDescription>
</DialogHeader>
<ScrollArea className="max-h-[calc(90vh-120px)]">
<DialogBody>
<Tabs value={activeTab} onValueChange={(v) => setActiveTab(v as 'build' | 'test')} className="w-full">
<TabsList className="grid w-full grid-cols-2">
<TabsTrigger value="build">🔧 </TabsTrigger>
@@ -406,7 +407,7 @@ function RegexEditor({
value={testText}
onChange={(e) => setTestText(e.target.value)}
placeholder="在此输入要测试的文本...&#10;例如:打游戏是这样的"
className="min-h-[100px] text-sm"
className="min-h-25 text-sm"
/>
</div>
@@ -444,7 +445,7 @@ function RegexEditor({
<div className="space-y-2">
<Label className="text-sm font-medium"></Label>
<ScrollArea className="h-40 rounded-md bg-muted p-3">
<div className="text-sm break-words">
<div className="text-sm wrap-break-word">
{renderHighlightedText()}
</div>
</ScrollArea>
@@ -458,7 +459,7 @@ function RegexEditor({
<div className="space-y-2">
{Object.entries(captureGroups).map(([name, value]) => (
<div key={name} className="flex items-start gap-2 text-sm">
<span className="font-mono font-semibold text-primary min-w-[80px]">[{name}]</span>
<span className="font-mono font-semibold text-primary min-w-20">[{name}]</span>
<span className="text-muted-foreground">=</span>
<span className="font-mono bg-muted px-2 py-0.5 rounded">{value}</span>
</div>
@@ -473,7 +474,7 @@ function RegexEditor({
<div className="space-y-2">
<Label className="text-sm font-medium">Reaction </Label>
<ScrollArea className="h-48 rounded-md bg-blue-50 dark:bg-blue-950/30 border border-blue-200 dark:border-blue-800 p-3">
<div className="text-sm break-words">
<div className="text-sm wrap-break-word">
{replacedReaction}
</div>
</ScrollArea>
@@ -497,7 +498,7 @@ function RegexEditor({
</div>
</TabsContent>
</Tabs>
</ScrollArea>
</DialogBody>
</DialogContent>
</Dialog>
)
@@ -628,7 +629,7 @@ export const ProcessingSection = React.memo(function ProcessingSection({
</Button>
</PopoverTrigger>
<PopoverContent className="w-[95vw] sm:w-[500px]">
<PopoverContent className="w-[95vw] sm:w-125">
<div className="space-y-2">
<h4 className="font-medium text-sm"></h4>
<ScrollArea className="h-60 rounded-md bg-muted p-3">
@@ -656,7 +657,7 @@ export const ProcessingSection = React.memo(function ProcessingSection({
</Button>
</PopoverTrigger>
<PopoverContent className="w-[95vw] sm:w-[500px]">
<PopoverContent className="w-[95vw] sm:w-125">
<div className="space-y-2">
<h4 className="font-medium text-sm"></h4>
<ScrollArea className="h-60 rounded-md bg-muted p-3">

View File

@@ -6,6 +6,7 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'
import { ScrollArea } from '@/components/ui/scroll-area'
import {
Dialog,
DialogBody,
DialogContent,
DialogDescription,
DialogFooter,
@@ -971,9 +972,10 @@ function ModelConfigPageContent() {
{/* 编辑模型对话框 */}
<Dialog open={editDialogOpen} onOpenChange={handleEditDialogClose}>
<DialogContent
className="max-w-[95vw] sm:max-w-2xl max-h-[90vh] overflow-y-auto"
className="max-w-[95vw] sm:max-w-2xl"
data-tour="model-dialog"
preventOutsideClose={tourIsRunning}
confirmOnEnter
>
<DialogHeader>
<DialogTitle>
@@ -982,6 +984,7 @@ function ModelConfigPageContent() {
<DialogDescription></DialogDescription>
</DialogHeader>
<DialogBody>
<div className="grid gap-4 py-4">
<div className="grid gap-2" data-tour="model-name-input">
<Label htmlFor="model_name" className={formErrors.name ? 'text-destructive' : ''}> *</Label>
@@ -1492,12 +1495,13 @@ function ModelConfigPageContent() {
)}
</div>
</div>
</DialogBody>
<DialogFooter>
<Button variant="outline" onClick={() => setEditDialogOpen(false)} data-tour="model-cancel-button">
</Button>
<Button onClick={handleSaveEdit} data-tour="model-save-button"></Button>
<Button data-dialog-action="confirm" onClick={handleSaveEdit} data-tour="model-save-button"></Button>
</DialogFooter>
</DialogContent>
</Dialog>

View File

@@ -3,7 +3,7 @@ import { Check, ChevronsUpDown, Copy, Eye, EyeOff } from 'lucide-react'
import { Button } from '@/components/ui/button'
import { Command, CommandEmpty, CommandGroup, CommandInput, CommandItem, CommandList } from '@/components/ui/command'
import { Dialog, DialogContent, DialogDescription, DialogFooter, DialogHeader, DialogTitle } from '@/components/ui/dialog'
import { Dialog, DialogBody, DialogContent, DialogDescription, DialogFooter, DialogHeader, DialogTitle } from '@/components/ui/dialog'
import { HelpTooltip } from '@/components/ui/help-tooltip'
import { Input } from '@/components/ui/input'
import { Label } from '@/components/ui/label'
@@ -116,9 +116,10 @@ export function ProviderForm({
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent
className="max-w-[95vw] sm:max-w-2xl max-h-[90vh] overflow-y-auto"
className="max-w-[95vw] sm:max-w-2xl"
data-tour="provider-dialog"
preventOutsideClose={tourState.isRunning}
confirmOnEnter
>
<DialogHeader>
<DialogTitle>
@@ -130,6 +131,7 @@ export function ProviderForm({
</DialogHeader>
<form onSubmit={(e) => { e.preventDefault(); handleSaveEdit(); }} autoComplete="off">
<DialogBody>
<div className="grid gap-4 py-4">
<div className="grid gap-2" data-tour="provider-template-select">
<Label htmlFor="template"></Label>
@@ -450,12 +452,13 @@ export function ProviderForm({
</div>
</div>
</div>
</DialogBody>
<DialogFooter>
<Button type="button" variant="outline" onClick={() => onOpenChange(false)} data-tour="provider-cancel-button">
</Button>
<Button type="submit" data-tour="provider-save-button"></Button>
<Button type="submit" data-dialog-action="confirm" data-tour="provider-save-button"></Button>
</DialogFooter>
</form>
</DialogContent>

View File

@@ -40,6 +40,7 @@ import { ScrollArea } from '@/components/ui/scroll-area'
import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert'
import {
Dialog,
DialogBody,
DialogContent,
DialogDescription,
DialogFooter,
@@ -575,7 +576,7 @@ function ApplyDialog({
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
<DialogContent className="max-w-2xl" confirmOnEnter>
<DialogHeader>
<DialogTitle className="flex items-center gap-2">
<Package className="w-5 h-5" />
@@ -589,6 +590,7 @@ function ApplyDialog({
</DialogDescription>
</DialogHeader>
<DialogBody>
{detectingConflicts ? (
<div className="py-8 text-center">
<Loader2 className="w-8 h-8 mx-auto animate-spin text-primary" />
@@ -831,6 +833,7 @@ function ApplyDialog({
)}
</>
)}
</DialogBody>
<DialogFooter className="flex justify-between">
<div>
@@ -845,11 +848,11 @@ function ApplyDialog({
</Button>
{step < totalSteps ? (
<Button onClick={() => setStep(step + 1)} disabled={detectingConflicts}>
<Button data-dialog-action="confirm" onClick={() => setStep(step + 1)} disabled={detectingConflicts}>
</Button>
) : (
<Button onClick={onApply} disabled={applying}>
<Button data-dialog-action="confirm" onClick={onApply} disabled={applying}>
{applying && <Loader2 className="w-4 h-4 mr-2 animate-spin" />}
</Button>

View File

@@ -29,6 +29,7 @@ import { Button } from '@/components/ui/button'
import { Checkbox } from '@/components/ui/checkbox'
import {
Dialog,
DialogBody,
DialogContent,
DialogDescription,
DialogFooter,
@@ -573,7 +574,7 @@ export function PersonManagementPage() {
variant="outline"
size="sm"
onClick={() => handleViewDetail(person)}
className="text-xs px-2 py-1 h-auto flex-shrink-0"
className="text-xs px-2 py-1 h-auto shrink-0"
>
<Eye className="h-3 w-3 mr-1" />
@@ -582,7 +583,7 @@ export function PersonManagementPage() {
variant="outline"
size="sm"
onClick={() => handleEdit(person)}
className="text-xs px-2 py-1 h-auto flex-shrink-0"
className="text-xs px-2 py-1 h-auto shrink-0"
>
<Edit className="h-3 w-3 mr-1" />
@@ -591,7 +592,7 @@ export function PersonManagementPage() {
variant="outline"
size="sm"
onClick={() => setDeleteConfirmPerson(person)}
className="text-xs px-2 py-1 h-auto flex-shrink-0 text-destructive hover:text-destructive"
className="text-xs px-2 py-1 h-auto shrink-0 text-destructive hover:text-destructive"
>
<Trash2 className="h-3 w-3 mr-1" />
@@ -771,7 +772,7 @@ function PersonDetailDialog({
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
<DialogContent className="max-w-2xl">
<DialogHeader>
<DialogTitle></DialogTitle>
<DialogDescription>
@@ -779,6 +780,7 @@ function PersonDetailDialog({
</DialogDescription>
</DialogHeader>
<DialogBody>
<div className="space-y-4">
{/* 基本信息 */}
<div className="grid grid-cols-2 gap-4">
@@ -829,6 +831,7 @@ function PersonDetailDialog({
<InfoItem icon={Clock} label="最后更新" value={formatTime(person.last_know)} />
</div>
</div>
</DialogBody>
<DialogFooter>
<Button onClick={() => onOpenChange(false)}></Button>
@@ -919,7 +922,7 @@ function PersonEditDialog({
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
<DialogContent className="max-w-2xl" confirmOnEnter>
<DialogHeader>
<DialogTitle></DialogTitle>
<DialogDescription>
@@ -927,6 +930,7 @@ function PersonEditDialog({
</DialogDescription>
</DialogHeader>
<DialogBody>
<div className="space-y-4">
<div className="grid grid-cols-2 gap-4">
<div className="space-y-2">
@@ -974,6 +978,7 @@ function PersonEditDialog({
/>
</div>
</div>
</DialogBody>
<DialogFooter>
<Button variant="outline" onClick={() => onOpenChange(false)}>

View File

@@ -10,6 +10,7 @@ import { Button } from '@/components/ui/button'
import { Checkbox } from '@/components/ui/checkbox'
import {
Dialog,
DialogBody,
DialogContent,
DialogDescription,
DialogFooter,
@@ -59,7 +60,7 @@ export function EmojiDetailDialog({
<DialogHeader>
<DialogTitle></DialogTitle>
</DialogHeader>
<ScrollArea className="max-h-[calc(90vh-8rem)] pr-4">
<DialogBody>
<div className="space-y-4">
{/* 表情包预览图 - 使用原图 */}
<div className="flex justify-center">
@@ -177,7 +178,7 @@ export function EmojiDetailDialog({
</div>
</div>
</div>
</ScrollArea>
</DialogBody>
</DialogContent>
</Dialog>
)
@@ -252,11 +253,12 @@ export function EmojiEditDialog({
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-2xl">
<DialogContent className="max-w-2xl" confirmOnEnter>
<DialogHeader>
<DialogTitle></DialogTitle>
<DialogDescription></DialogDescription>
</DialogHeader>
<DialogBody>
<div className="space-y-4">
<div>
<Label></Label>
@@ -310,11 +312,12 @@ export function EmojiEditDialog({
</div>
</div>
</div>
</DialogBody>
<DialogFooter>
<Button variant="outline" onClick={() => onOpenChange(false)}>
</Button>
<Button onClick={handleSave} disabled={saving}>
<Button data-dialog-action="confirm" onClick={handleSave} disabled={saving}>
{saving ? '保存中...' : '保存'}
</Button>
</DialogFooter>
@@ -658,7 +661,7 @@ export function EmojiUploadDialog({
<div className="flex gap-6">
{/* 预览图 */}
<div className="flex-shrink-0">
<div className="shrink-0">
<div className="w-32 h-32 rounded-lg border overflow-hidden bg-muted flex items-center justify-center">
<img
src={file.previewUrl}
@@ -764,7 +767,7 @@ export function EmojiUploadDialog({
<div className="grid grid-cols-2 gap-4">
{/* 左侧:文件卡片列表 */}
<ScrollArea className="h-[350px] pr-2">
<ScrollArea className="h-87.5 pr-2">
<div className="space-y-2">
{uploadedFiles.map((file) => {
const complete = isFileComplete(file)
@@ -782,7 +785,7 @@ export function EmojiUploadDialog({
${complete ? 'border-green-500 bg-green-50 dark:bg-green-950/20' : 'border-border hover:border-muted-foreground/50'}
`}
>
<div className="w-12 h-12 rounded border overflow-hidden bg-muted flex-shrink-0 flex items-center justify-center">
<div className="flex h-12 w-12 shrink-0 items-center justify-center overflow-hidden rounded border bg-muted">
<img
src={file.previewUrl}
alt={file.name}
@@ -796,9 +799,9 @@ export function EmojiUploadDialog({
</p>
</div>
{complete ? (
<CheckCircle2 className="h-5 w-5 text-green-500 flex-shrink-0" />
<CheckCircle2 className="h-5 w-5 shrink-0 text-green-500" />
) : (
<div className="h-5 w-5 rounded-full border-2 border-muted-foreground/30 flex-shrink-0" />
<div className="h-5 w-5 shrink-0 rounded-full border-2 border-muted-foreground/30" />
)}
</div>
)
@@ -908,7 +911,7 @@ export function EmojiUploadDialog({
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-3xl max-h-[90vh] overflow-hidden">
<DialogContent className="max-w-3xl max-h-[90vh] overflow-hidden" confirmOnEnter>
<DialogHeader>
<DialogTitle className="flex items-center gap-2">
<Upload className="h-5 w-5" />
@@ -925,11 +928,11 @@ export function EmojiUploadDialog({
</DialogDescription>
</DialogHeader>
<div className="overflow-y-auto pr-1">
<DialogBody viewportClassName="pr-1">
{step === 'select' && renderSelectStep()}
{step === 'edit-single' && renderEditSingleStep()}
{step === 'edit-multiple' && renderEditMultipleStep()}
</div>
</DialogBody>
</DialogContent>
</Dialog>
)

View File

@@ -15,6 +15,7 @@ import {
import { Button } from '@/components/ui/button'
import {
Dialog,
DialogBody,
DialogContent,
DialogDescription,
DialogFooter,
@@ -65,7 +66,7 @@ export function ExpressionDetailDialog({
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
<DialogContent className="max-w-2xl" confirmOnEnter>
<DialogHeader>
<DialogTitle></DialogTitle>
<DialogDescription>
@@ -73,6 +74,7 @@ export function ExpressionDetailDialog({
</DialogDescription>
</DialogHeader>
<DialogBody>
<div className="space-y-4">
<div className="grid grid-cols-2 gap-4">
<InfoItem label="情境" value={expression.situation} />
@@ -131,6 +133,7 @@ export function ExpressionDetailDialog({
</div>
</div>
</div>
</DialogBody>
<DialogFooter>
<Button onClick={() => onOpenChange(false)}></Button>
@@ -233,7 +236,7 @@ export function ExpressionCreateDialog({
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
<DialogContent className="max-w-2xl" confirmOnEnter>
<DialogHeader>
<DialogTitle></DialogTitle>
<DialogDescription>
@@ -241,6 +244,7 @@ export function ExpressionCreateDialog({
</DialogDescription>
</DialogHeader>
<DialogBody>
<div className="space-y-4">
<div className="grid grid-cols-2 gap-4">
<div className="space-y-2">
@@ -291,12 +295,13 @@ export function ExpressionCreateDialog({
</Select>
</div>
</div>
</DialogBody>
<DialogFooter>
<Button variant="outline" onClick={() => onOpenChange(false)}>
</Button>
<Button onClick={handleCreate} disabled={saving}>
<Button data-dialog-action="confirm" onClick={handleCreate} disabled={saving}>
{saving ? '创建中...' : '创建'}
</Button>
</DialogFooter>
@@ -371,7 +376,7 @@ export function ExpressionEditDialog({
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
<DialogContent className="max-w-2xl" confirmOnEnter>
<DialogHeader>
<DialogTitle></DialogTitle>
<DialogDescription>
@@ -379,6 +384,7 @@ export function ExpressionEditDialog({
</DialogDescription>
</DialogHeader>
<DialogBody>
<div className="space-y-4">
<div className="grid grid-cols-2 gap-4">
<div className="space-y-2">
@@ -474,12 +480,13 @@ export function ExpressionEditDialog({
</div>
</div>
</div>
</DialogBody>
<DialogFooter>
<Button variant="outline" onClick={() => onOpenChange(false)}>
</Button>
<Button onClick={handleSave} disabled={saving}>
<Button data-dialog-action="confirm" onClick={handleSave} disabled={saving}>
{saving ? '保存中...' : '保存'}
</Button>
</DialogFooter>

View File

@@ -14,6 +14,7 @@ import {
import { Button } from '@/components/ui/button'
import {
Dialog,
DialogBody,
DialogContent,
DialogDescription,
DialogFooter,
@@ -24,7 +25,6 @@ import { Badge } from '@/components/ui/badge'
import { Input } from '@/components/ui/input'
import { Label } from '@/components/ui/label'
import { MarkdownRenderer } from '@/components/markdown-renderer'
import { ScrollArea } from '@/components/ui/scroll-area'
import {
Select,
SelectContent,
@@ -92,7 +92,7 @@ export function JargonDetailDialog({
<DialogDescription></DialogDescription>
</DialogHeader>
<ScrollArea className="h-full pr-4">
<DialogBody className="h-full">
<div className="space-y-4 pb-2">
<div className="grid grid-cols-2 gap-4">
<InfoItem icon={Hash} label="记录ID" value={jargon.id.toString()} mono />
@@ -167,7 +167,7 @@ export function JargonDetailDialog({
</div>
)}
</div>
</ScrollArea>
</DialogBody>
<DialogFooter className="flex-shrink-0">
<Button onClick={() => onOpenChange(false)}></Button>
@@ -234,12 +234,13 @@ export function JargonCreateDialog({
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
<DialogContent className="max-w-2xl" confirmOnEnter>
<DialogHeader>
<DialogTitle></DialogTitle>
<DialogDescription></DialogDescription>
</DialogHeader>
<DialogBody>
<div className="space-y-4">
<div className="space-y-2">
<Label htmlFor="content">
@@ -294,10 +295,11 @@ export function JargonCreateDialog({
<Label htmlFor="is_global"></Label>
</div>
</div>
</DialogBody>
<DialogFooter>
<Button variant="outline" onClick={() => onOpenChange(false)}></Button>
<Button onClick={handleCreate} disabled={saving}>
<Button data-dialog-action="confirm" onClick={handleCreate} disabled={saving}>
{saving ? '创建中...' : '创建'}
</Button>
</DialogFooter>
@@ -366,12 +368,13 @@ export function JargonEditDialog({
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
<DialogContent className="max-w-2xl" confirmOnEnter>
<DialogHeader>
<DialogTitle></DialogTitle>
<DialogDescription></DialogDescription>
</DialogHeader>
<DialogBody>
<div className="space-y-4">
<div className="space-y-2">
<Label htmlFor="edit_content"></Label>
@@ -439,10 +442,11 @@ export function JargonEditDialog({
<Label htmlFor="edit_is_global"></Label>
</div>
</div>
</DialogBody>
<DialogFooter>
<Button variant="outline" onClick={() => onOpenChange(false)}></Button>
<Button onClick={handleSave} disabled={saving}>
<Button data-dialog-action="confirm" onClick={handleSave} disabled={saving}>
{saving ? '保存中...' : '保存'}
</Button>
</DialogFooter>

View File

@@ -2,11 +2,11 @@
import { Badge } from '@/components/ui/badge'
import {
Dialog,
DialogBody,
DialogContent,
DialogHeader,
DialogTitle,
} from '@/components/ui/dialog'
import { ScrollArea } from '@/components/ui/scroll-area'
import type { GraphNode, SelectedEdgeData } from './types'
@@ -24,7 +24,7 @@ export function NodeDetailDialog({ open, onOpenChange, selectedNodeData }: NodeD
<DialogTitle></DialogTitle>
</DialogHeader>
{selectedNodeData && (
<ScrollArea className="h-full pr-4">
<DialogBody className="h-full">
<div className="space-y-4 pb-2">
<div className="grid grid-cols-2 gap-4">
<div>
@@ -62,7 +62,7 @@ export function NodeDetailDialog({ open, onOpenChange, selectedNodeData }: NodeD
)}
</div>
</div>
</ScrollArea>
</DialogBody>
)}
</DialogContent>
</Dialog>
@@ -83,7 +83,7 @@ export function EdgeDetailDialog({ open, onOpenChange, selectedEdgeData }: EdgeD
<DialogTitle></DialogTitle>
</DialogHeader>
{selectedEdgeData && (
<ScrollArea className="flex-1 pr-4">
<DialogBody>
<div className="space-y-4">
<div className="flex items-center gap-4">
<div className="flex-1 min-w-0 p-3 bg-blue-50 dark:bg-blue-950 rounded border-2 border-blue-200 dark:border-blue-800">
@@ -114,7 +114,7 @@ export function EdgeDetailDialog({ open, onOpenChange, selectedEdgeData }: EdgeD
</div>
</div>
</div>
</ScrollArea>
</DialogBody>
)}
</DialogContent>
</Dialog>

View File

@@ -190,7 +190,6 @@ class ChatBot:
@staticmethod
def _mark_command_message(message: SessionMessage, intercept_message_level: int) -> None:
message.is_command = True
message.intercept_message_level = intercept_message_level
message.message_info.additional_config["intercept_message_level"] = intercept_message_level
@staticmethod

View File

@@ -29,7 +29,7 @@ def get_webui_chat_broadcaster():
global _webui_chat_broadcaster
if _webui_chat_broadcaster is None:
try:
from src.webui.chat_routes import chat_manager, WEBUI_CHAT_PLATFORM
from src.webui.routers.chat import WEBUI_CHAT_PLATFORM, chat_manager
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
except ImportError:

View File

@@ -1,6 +1,7 @@
import traceback
import time
import asyncio
import importlib
import random
import re
@@ -11,9 +12,9 @@ from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from maim_message import BaseMessageInfo, MessageBase, Seg
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.common.data_models.mai_message_data_model import MaiMessage
from src.chat.message_receive.message import SessionMessage
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
@@ -786,7 +787,6 @@ class DefaultReplyer:
available_actions = {}
chat_stream = self.chat_stream
chat_id = chat_stream.session_id
_is_group_chat = bool(chat_stream.group_info)
platform = chat_stream.platform
user_id = "用户ID"
@@ -795,11 +795,12 @@ class DefaultReplyer:
target = "消息"
if reply_message:
user_id = reply_message.user_info.user_id
reply_user_info = reply_message.message_info.user_info
user_id = reply_user_info.user_id
person = Person(platform=platform, user_id=user_id)
person_name = person.person_name or user_id
sender = person_name
target = reply_message.processed_plain_text
target = reply_message.processed_plain_text or ""
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
@@ -825,16 +826,17 @@ class DefaultReplyer:
person_list_short: List[Person] = []
for msg in message_list_before_short:
msg_user_info = msg.message_info.user_info
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI
if is_bot_self(msg.user_info.platform, msg.user_info.user_id):
if is_bot_self(msg.platform, msg_user_info.user_id):
continue
if (
reply_message
and reply_message.user_info.user_id == msg.user_info.user_id
and reply_message.user_info.platform == msg.user_info.platform
and reply_message.message_info.user_info.user_id == msg_user_info.user_id
and reply_message.platform == msg.platform
):
continue
person = Person(platform=msg.user_info.platform, user_id=msg.user_info.user_id)
person = Person(platform=msg.platform, user_id=msg_user_info.user_id)
if person.is_known:
person_list_short.append(person)
@@ -847,7 +849,6 @@ class DefaultReplyer:
timestamp_mode="relative",
read_mark=0.0,
show_actions=True,
long_time_notice=True,
)
# 统一黑话解释构建:根据配置选择上下文或 Planner 模式
@@ -956,7 +957,6 @@ class DefaultReplyer:
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
truncate=True,
long_time_notice=True,
)
# 获取匹配的额外prompt
@@ -1121,7 +1121,7 @@ class DefaultReplyer:
platform=self.chat_stream.platform,
message_id=message_id,
time=thinking_start_time,
user_info=UserInfo(
user_info=MaimUserInfo(
user_id=str(global_config.bot.qq_account),
user_nickname=global_config.bot.nickname,
),
@@ -1160,7 +1160,16 @@ class DefaultReplyer:
async def get_prompt_info(self, message: str, sender: str, target: str):
related_info = ""
start_time = time.time()
from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool
try:
knowledge_module = importlib.import_module("src.plugins.built_in.knowledge.lpmm_get_knowledge")
except ImportError:
logger.debug("LPMM知识库工具模块不存在跳过获取知识库内容")
return ""
search_knowledge_tool = getattr(knowledge_module, "SearchKnowledgeFromLPMMTool", None)
if search_knowledge_tool is None:
logger.debug("LPMM知识库工具未提供 SearchKnowledgeFromLPMMTool跳过获取知识库内容")
return ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 从LPMM知识库获取知识
@@ -1183,14 +1192,14 @@ class DefaultReplyer:
_, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools(
prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
tool_options=[search_knowledge_tool.get_tool_definition()],
)
# logger.info(f"工具调用提示词: {prompt}")
# logger.info(f"工具调用: {tool_calls}")
if tool_calls:
result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())
result = await self.tool_executor.execute_tool_call(tool_calls[0])
end_time = time.time()
if not result or not result.get("content"):
logger.debug("从LPMM知识库获取知识失败返回空知识...")

View File

@@ -11,9 +11,9 @@ from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from maim_message import BaseMessageInfo, MessageBase, Seg
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.common.data_models.mai_message_data_model import MaiMessage
from src.chat.message_receive.message import SessionMessage
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
@@ -636,11 +636,12 @@ class PrivateReplyer:
target = "消息"
if reply_message:
user_id = reply_message.user_info.user_id
reply_user_info = reply_message.message_info.user_info
user_id = reply_user_info.user_id
person = Person(platform=platform, user_id=user_id)
person_name = person.person_name or user_id
sender = person_name
target = reply_message.processed_plain_text
target = reply_message.processed_plain_text or ""
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
@@ -663,7 +664,6 @@ class PrivateReplyer:
timestamp_mode="relative",
read_mark=0.0,
show_actions=True,
long_time_notice=True,
)
message_list_before_short = get_messages_before_time_in_chat(
@@ -675,16 +675,17 @@ class PrivateReplyer:
person_list_short: List[Person] = []
for msg in message_list_before_short:
msg_user_info = msg.message_info.user_info
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI
if is_bot_self(msg.user_info.platform, msg.user_info.user_id):
if is_bot_self(msg.platform, msg_user_info.user_id):
continue
if (
reply_message
and reply_message.user_info.user_id == msg.user_info.user_id
and reply_message.user_info.platform == msg.user_info.platform
and reply_message.message_info.user_info.user_id == msg_user_info.user_id
and reply_message.platform == msg.platform
):
continue
person = Person(platform=msg.user_info.platform, user_id=msg.user_info.user_id)
person = Person(platform=msg.platform, user_id=msg_user_info.user_id)
if person.is_known:
person_list_short.append(person)
@@ -960,7 +961,7 @@ class PrivateReplyer:
platform=self.chat_stream.platform,
message_id=message_id,
time=thinking_start_time,
user_info=UserInfo(
user_info=MaimUserInfo(
user_id=str(global_config.bot.qq_account),
user_nickname=global_config.bot.nickname,
),

View File

@@ -76,7 +76,6 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = {
"webui.jargon": ("#d7d75f", None, False),
"webui.person": ("#87d787", None, False),
"webui.statistics": ("#af87ff", None, False),
"webui.annual_report": ("#ffaf87", None, False),
"webui.plugin_routes": ("#ffaf00", None, False),
"webui.plugin_progress": ("#ff8700", None, False),
"webui.git_mirror": ("#878787", None, False),
@@ -180,7 +179,6 @@ MODULE_ALIASES = {
"webui.jargon": "WebUI黑话",
"webui.person": "WebUI人物",
"webui.statistics": "WebUI统计",
"webui.annual_report": "WebUI年报",
"webui.plugin_routes": "WebUI插件",
"webui.plugin_progress": "WebUI插件进度",
"webui.git_mirror": "WebUI镜像",

View File

@@ -1,6 +1,5 @@
import traceback
from datetime import datetime
from types import SimpleNamespace
from typing import Any
import json
@@ -48,48 +47,8 @@ def _parse_additional_config(message: Messages) -> dict[str, Any]:
return {}
def _normalize_optional_str(value: object) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
try:
return json.dumps(value, ensure_ascii=False)
except (TypeError, ValueError):
return str(value)
def _message_to_instance(message: Messages) -> SessionMessage:
config = _parse_additional_config(message)
instance = SessionMessage.from_db_instance(message)
instance.interest_value = config.get("interest_value")
instance.key_words = _normalize_optional_str(config.get("key_words"))
instance.key_words_lite = _normalize_optional_str(config.get("key_words_lite"))
instance.reply_probability_boost = config.get("reply_probability_boost")
instance.priority_mode = _normalize_optional_str(config.get("priority_mode"))
instance.priority_info = _normalize_optional_str(config.get("priority_info"))
instance.intercept_message_level = config.get("intercept_message_level", 0)
instance.selected_expressions = _normalize_optional_str(config.get("selected_expressions"))
group_info = instance.message_info.group_info
legacy_group_info = None
if group_info:
legacy_group_info = SimpleNamespace(
group_id=group_info.group_id,
group_name=group_info.group_name,
)
instance.user_info = SimpleNamespace(
user_id=instance.message_info.user_info.user_id,
user_nickname=instance.message_info.user_info.user_nickname,
user_cardname=instance.message_info.user_info.user_cardname,
platform=instance.platform,
)
instance.chat_info = SimpleNamespace(
platform=instance.platform,
stream_id=instance.session_id,
group_info=legacy_group_info,
)
instance.time = instance.timestamp.timestamp()
return instance
return SessionMessage.from_db_instance(message)
def _coerce_datetime(value: Any) -> Any:
@@ -118,6 +77,7 @@ def _build_message_conditions(
end_time: float | None = None,
before_time: float | None = None,
after_time: float | None = None,
has_reply_to: bool | None = None,
) -> list[Any]:
conditions: list[Any] = [Messages.message_id != "notice"]
@@ -141,6 +101,10 @@ def _build_message_conditions(
conditions.append(Messages.timestamp < _coerce_datetime(before_time))
if after_time is not None:
conditions.append(Messages.timestamp > _coerce_datetime(after_time))
if has_reply_to is True:
conditions.append(col(Messages.reply_to).is_not(None))
elif has_reply_to is False:
conditions.append(col(Messages.reply_to).is_(None))
return conditions
@@ -261,6 +225,7 @@ def count_messages(
end_time: float | None = None,
before_time: float | None = None,
after_time: float | None = None,
has_reply_to: bool | None = None,
) -> int:
"""
根据提供的过滤器计算消息数量。
@@ -276,6 +241,7 @@ def count_messages(
end_time: 结束时间,闭区间上界。
before_time: 严格早于该时间。
after_time: 严格晚于该时间。
has_reply_to: 是否要求存在 reply_to 字段。
Returns:
符合条件的消息数量,如果出错则返回 0。
@@ -292,6 +258,7 @@ def count_messages(
end_time=end_time,
before_time=before_time,
after_time=after_time,
has_reply_to=has_reply_to,
)
statement = select(func.count()).select_from(Messages).where(*conditions)
with get_db_session() as session:
@@ -302,7 +269,7 @@ def count_messages(
"使用 SQLModel 计数消息失败 "
f"(session_id={session_id}, user_id={user_id}, group_id={group_id}, platform={platform}, "
f"message_id={message_id}, reply_to={reply_to}, start_time={start_time}, end_time={end_time}, "
f"before_time={before_time}, after_time={after_time}): {e}\n{traceback.format_exc()}"
f"before_time={before_time}, after_time={after_time}, has_reply_to={has_reply_to}): {e}\n{traceback.format_exc()}"
)
logger.error(log_message)
return 0

View File

@@ -25,13 +25,15 @@ def is_port_conflict_error(error: OSError) -> bool:
return "address already in use" in message or "已被占用" in message
def check_port_available(host: str, port: int) -> bool:
def check_port_available(host: str, port: int, *, allow_reuse_addr: bool = False) -> bool:
family = _detect_socket_family(host)
test_host = _normalize_test_host(host)
try:
with socket.socket(family, socket.SOCK_STREAM) as test_socket:
test_socket.settimeout(1)
if allow_reuse_addr:
test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
test_socket.bind((test_host, port))
return True
except OSError:
@@ -65,8 +67,9 @@ def assert_port_available(
service_name: str,
logger,
config_hint: Optional[str] = None,
allow_reuse_addr: bool = False,
) -> None:
if check_port_available(host=host, port=port):
if check_port_available(host=host, port=port, allow_reuse_addr=allow_reuse_addr):
return
log_port_conflict(

View File

@@ -8,13 +8,17 @@
3. 详情按需加载
"""
import json
from pathlib import Path
from typing import List, Dict, Optional
from fastapi import APIRouter, HTTPException, Query
from typing import Dict, List, Optional
import json
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
router = APIRouter(prefix="/api/planner", tags=["planner"])
from src.webui.dependencies import require_auth
router = APIRouter(prefix="/api/planner", tags=["planner"], dependencies=[Depends(require_auth)])
# 规划器日志目录
PLAN_LOG_DIR = Path("logs/plan")

View File

@@ -8,13 +8,17 @@
3. 详情按需加载
"""
import json
from pathlib import Path
from typing import List, Dict, Optional
from fastapi import APIRouter, HTTPException, Query
from typing import Dict, List, Optional
import json
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
router = APIRouter(prefix="/api/replier", tags=["replier"])
from src.webui.dependencies import require_auth
router = APIRouter(prefix="/api/replier", tags=["replier"], dependencies=[Depends(require_auth)])
# 回复器日志目录
REPLY_LOG_DIR = Path("logs/reply")

View File

@@ -52,7 +52,6 @@ def _setup_cors(app: FastAPI, port: int):
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
allow_headers=[
"Content-Type",
"Authorization",
"Accept",
"Origin",
"X-Requested-With",

View File

@@ -9,6 +9,7 @@ from .auth import (
COOKIE_NAME,
COOKIE_MAX_AGE,
get_current_token,
is_token_valid,
set_auth_cookie,
clear_auth_cookie,
verify_auth_token_from_cookie_or_header,
@@ -24,6 +25,7 @@ __all__ = [
"COOKIE_NAME",
"COOKIE_MAX_AGE",
"get_current_token",
"is_token_valid",
"set_auth_cookie",
"clear_auth_cookie",
"verify_auth_token_from_cookie_or_header",

View File

@@ -1,10 +1,8 @@
"""
WebUI 认证模块
提供统一的认证依赖,支持 Cookie 和 Header 两种方式
"""
"""WebUI 认证模块。"""
from typing import Optional
from fastapi import HTTPException, Cookie, Header, Response, Request
from fastapi import Cookie, HTTPException, Request, Response
from src.common.logger import get_logger
from src.config.config import global_config
from .security import get_token_manager
@@ -39,17 +37,13 @@ def _is_secure_environment() -> bool:
def get_current_token(
request: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> str:
"""
获取当前请求的 token优先从 Cookie 获取,其次从 Header 获取
获取当前请求的 token仅从 HttpOnly Cookie 获取。
Args:
request: FastAPI Request 对象
maibot_session: Cookie 中的 token
authorization: Authorization Header (Bearer token)
Returns:
验证通过的 token
@@ -57,24 +51,19 @@ def get_current_token(
Raises:
HTTPException: 认证失败时抛出 401 错误
"""
token = None
# 优先从 Cookie 获取
if maibot_session:
token = maibot_session
# 其次从 Header 获取(兼容旧版本)
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
if not token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
# 验证 token
token_manager = get_token_manager()
if not token_manager.verify_token(token):
if not is_token_valid(maibot_session):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return token
return maibot_session
def is_token_valid(maibot_session: Optional[str]) -> bool:
"""判断认证 token 是否存在且有效。"""
if not maibot_session:
return False
token_manager = get_token_manager()
return token_manager.verify_token(maibot_session)
def set_auth_cookie(response: Response, token: str, request: Optional[Request] = None) -> None:
@@ -150,14 +139,12 @@ def clear_auth_cookie(response: Response) -> None:
def verify_auth_token_from_cookie_or_header(
maibot_session: Optional[str] = None,
authorization: Optional[str] = None,
) -> bool:
"""
验证认证 Token支持从 Cookie 或 Header 获取
验证认证 Cookie。
Args:
maibot_session: Cookie 中的 token
authorization: Authorization header (Bearer token)
Returns:
验证成功返回 True
@@ -165,21 +152,7 @@ def verify_auth_token_from_cookie_or_header(
Raises:
HTTPException: 认证失败时抛出 401 错误
"""
token = None
# 优先从 Cookie 获取
if maibot_session:
token = maibot_session
# 其次从 Header 获取(兼容旧版本)
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
if not token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
# 验证 token
token_manager = get_token_manager()
if not token_manager.verify_token(token):
if not is_token_valid(maibot_session):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return True

View File

@@ -1,17 +1,16 @@
from typing import Optional
from fastapi import Depends, Cookie, Header, Request
from .core import get_current_token, get_token_manager, check_auth_rate_limit
from fastapi import Cookie, Depends, Request
from .core import check_auth_rate_limit, get_current_token, is_token_valid
async def require_auth(
request: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> str:
"""
FastAPI 依赖:要求有效认证
用于保护需要认证的路由,自动从 Cookie 或 Header 获取并验证 token
用于保护需要认证的路由,自动从 Cookie 获取并验证 token
Returns:
验证通过的 token
@@ -19,13 +18,12 @@ async def require_auth(
Raises:
HTTPException 401: 认证失败
"""
return get_current_token(request, maibot_session, authorization)
return get_current_token(maibot_session)
async def require_auth_with_rate_limit(
request: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
_rate_limit: None = Depends(check_auth_rate_limit),
) -> str:
"""
@@ -40,12 +38,11 @@ async def require_auth_with_rate_limit(
HTTPException 401: 认证失败
HTTPException 429: 请求过于频繁
"""
return get_current_token(request, maibot_session, authorization)
return get_current_token(maibot_session)
def get_optional_token(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Optional[str]:
"""
FastAPI 依赖:可选获取 token不验证
@@ -55,16 +52,11 @@ def get_optional_token(
Returns:
token 字符串或 None
"""
if maibot_session:
return maibot_session
if authorization and authorization.startswith("Bearer "):
return authorization.replace("Bearer ", "")
return None
return maibot_session or None
async def verify_token_optional(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""
FastAPI 依赖:可选验证 token
@@ -74,14 +66,4 @@ async def verify_token_optional(
Returns:
True 如果 token 有效,否则 False
"""
token = None
if maibot_session:
token = maibot_session
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
if not token:
return False
token_manager = get_token_manager()
return token_manager.verify_token(token)
return is_token_valid(maibot_session)

View File

@@ -1,916 +0,0 @@
"""麦麦 2025 年度总结 API 路由"""
from datetime import datetime
from typing import Any, Optional
from fastapi import APIRouter, Cookie, Depends, Header, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import desc, func
from sqlmodel import col, select
from src.common.database.database import get_db_session
from src.common.database.database_model import (
ActionRecord,
Expression,
Images,
Jargon,
Messages,
ModelUsage,
OnlineTime,
PersonInfo,
)
from src.common.logger import get_logger
from src.webui.core import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.annual_report")
router = APIRouter(prefix="/annual-report", tags=["annual-report"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# ==================== Pydantic 模型定义 ====================
class TimeFootprintData(BaseModel):
"""时光足迹数据"""
total_online_hours: float = Field(0.0, description="年度在线总时长(小时)")
first_message_time: Optional[str] = Field(None, description="初次消息时间")
first_message_user: Optional[str] = Field(None, description="初次消息用户昵称")
first_message_content: Optional[str] = Field(None, description="初次消息内容(截断)")
busiest_day: Optional[str] = Field(None, description="最忙碌的一天")
busiest_day_count: int = Field(0, description="最忙碌那天的消息数")
hourly_distribution: list[int] = Field(default_factory=lambda: [0] * 24, description="24小时活跃分布")
midnight_chat_count: int = Field(0, description="深夜(0-4点)互动次数")
is_night_owl: bool = Field(False, description="是否是夜猫子")
class SocialNetworkData(BaseModel):
"""社交网络数据"""
total_groups: int = Field(0, description="加入的群组总数")
top_groups: list[dict[str, Any]] = Field(default_factory=list, description="话痨群组TOP5")
top_users: list[dict[str, Any]] = Field(default_factory=list, description="互动最多的用户TOP5")
at_count: int = Field(0, description="被@次数")
mentioned_count: int = Field(0, description="被提及次数")
longest_companion_user: Optional[str] = Field(None, description="最长情陪伴的用户")
longest_companion_days: int = Field(0, description="陪伴天数")
class BrainPowerData(BaseModel):
"""最强大脑数据"""
total_tokens: int = Field(0, description="年度消耗Token总量")
total_cost: float = Field(0.0, description="年度总花费")
favorite_model: Optional[str] = Field(None, description="最爱用的模型")
favorite_model_count: int = Field(0, description="最爱模型的调用次数")
model_distribution: list[dict[str, Any]] = Field(default_factory=list, description="模型使用分布")
top_reply_models: list[dict[str, Any]] = Field(default_factory=list, description="最喜欢的回复模型TOP5")
most_expensive_cost: float = Field(0.0, description="最昂贵的一次思考花费")
most_expensive_time: Optional[str] = Field(None, description="最昂贵思考的时间")
top_token_consumers: list[dict[str, Any]] = Field(default_factory=list, description="烧钱大户TOP3")
silence_rate: float = Field(0.0, description="高冷指数(沉默率)")
total_actions: int = Field(0, description="总动作数")
no_reply_count: int = Field(0, description="选择沉默的次数")
avg_interest_value: float = Field(0.0, description="平均兴趣值")
max_interest_value: float = Field(0.0, description="最高兴趣值")
max_interest_time: Optional[str] = Field(None, description="最高兴趣值时间")
avg_reasoning_length: float = Field(0.0, description="平均思考长度")
max_reasoning_length: int = Field(0, description="最长思考长度")
max_reasoning_time: Optional[str] = Field(None, description="最长思考的时间")
class ExpressionVibeData(BaseModel):
"""个性与表达数据"""
top_emoji: Optional[dict[str, Any]] = Field(None, description="表情包之王")
top_emojis: list[dict[str, Any]] = Field(default_factory=list, description="TOP3表情包")
top_expressions: list[dict[str, Any]] = Field(default_factory=list, description="印象最深刻的表达风格")
rejected_expression_count: int = Field(0, description="被拒绝的表达次数")
checked_expression_count: int = Field(0, description="已检查的表达次数")
total_expressions: int = Field(0, description="表达总数")
action_types: list[dict[str, Any]] = Field(default_factory=list, description="动作类型分布")
image_processed_count: int = Field(0, description="处理的图片数量")
late_night_reply: Optional[dict[str, Any]] = Field(None, description="深夜还在回复")
favorite_reply: Optional[dict[str, Any]] = Field(None, description="最喜欢的回复")
class AchievementData(BaseModel):
"""趣味成就数据"""
new_jargon_count: int = Field(0, description="新学到的黑话数量")
sample_jargons: list[dict[str, Any]] = Field(default_factory=list, description="代表性黑话示例")
total_messages: int = Field(0, description="总消息数")
total_replies: int = Field(0, description="总回复数")
class AnnualReportData(BaseModel):
"""年度报告完整数据"""
year: int = Field(2025, description="报告年份")
bot_name: str = Field("麦麦", description="Bot名称")
generated_at: str = Field(..., description="报告生成时间")
time_footprint: TimeFootprintData = Field(default_factory=lambda: TimeFootprintData.model_construct())
social_network: SocialNetworkData = Field(default_factory=lambda: SocialNetworkData.model_construct())
brain_power: BrainPowerData = Field(default_factory=lambda: BrainPowerData.model_construct())
expression_vibe: ExpressionVibeData = Field(default_factory=lambda: ExpressionVibeData.model_construct())
achievements: AchievementData = Field(default_factory=lambda: AchievementData.model_construct())
# ==================== 辅助函数 ====================
def get_year_time_range(year: int = 2025) -> tuple[float, float]:
"""获取指定年份的时间戳范围"""
start = datetime(year, 1, 1, 0, 0, 0).timestamp()
end = datetime(year, 12, 31, 23, 59, 59).timestamp()
return start, end
def get_year_datetime_range(year: int = 2025) -> tuple[datetime, datetime]:
"""获取指定年份的 datetime 范围"""
start = datetime(year, 1, 1, 0, 0, 0)
end = datetime(year, 12, 31, 23, 59, 59)
return start, end
# ==================== 维度一:时光足迹 ====================
async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
"""获取时光足迹数据"""
data = TimeFootprintData.model_construct()
start_ts, end_ts = get_year_time_range(year)
start_dt, end_dt = get_year_datetime_range(year)
try:
# 1. 年度在线时长
with get_db_session() as session:
statement = select(OnlineTime).where(
col(OnlineTime.start_timestamp) >= start_dt,
col(OnlineTime.end_timestamp) <= end_dt,
)
online_records = session.exec(statement).all()
total_seconds = 0
for record in online_records:
try:
start = max(record.start_timestamp, start_dt)
end = min(record.end_timestamp, end_dt)
if end > start:
total_seconds += (end - start).total_seconds()
except Exception:
continue
data.total_online_hours = round(total_seconds / 3600, 2)
# 2. 初次相遇 - 年度第一条消息
with get_db_session() as session:
statement = (
select(Messages)
.where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
)
.order_by(col(Messages.timestamp).asc())
.limit(1)
)
first_msg = session.exec(statement).first()
if first_msg:
data.first_message_time = first_msg.timestamp.strftime("%Y-%m-%d %H:%M:%S")
data.first_message_user = first_msg.user_nickname or first_msg.user_id or "未知用户"
content = first_msg.processed_plain_text or first_msg.display_message or ""
data.first_message_content = content[:50] + "..." if len(content) > 50 else content
# 3. 最忙碌的一天
# 使用 SQLite 的 date 函数按日期分组
day_expr = func.date(col(Messages.timestamp))
with get_db_session() as session:
statement = (
select(
day_expr.label("day"),
func.count().label("count"),
)
.where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
)
.group_by(day_expr)
.order_by(func.count().desc())
.limit(1)
)
busiest_result = session.exec(statement).all()
if busiest_result:
data.busiest_day = busiest_result[0][0]
data.busiest_day_count = busiest_result[0][1] or 0
# 4. 昼夜节律 - 24小时活跃分布
hour_expr = func.strftime("%H", col(Messages.timestamp))
with get_db_session() as session:
statement = (
select(
hour_expr.label("hour"),
func.count().label("count"),
)
.where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
)
.group_by(hour_expr)
)
hourly_rows = session.exec(statement).all()
hourly_distribution = [0] * 24
for row in hourly_rows:
try:
hour = int(row[0] or 0)
if 0 <= hour < 24:
hourly_distribution[hour] = row[1] or 0
except (ValueError, TypeError):
continue
data.hourly_distribution = hourly_distribution
# 5. 深夜食堂 (0-4点)
data.midnight_chat_count = sum(hourly_distribution[0:5])
# 6. 判断是否夜猫子 (22点-4点活跃度 vs 6点-12点)
night_activity = sum(hourly_distribution[22:24]) + sum(hourly_distribution[0:5])
morning_activity = sum(hourly_distribution[6:13])
data.is_night_owl = night_activity > morning_activity
except Exception as e:
logger.error(f"获取时光足迹数据失败: {e}")
return data
# ==================== 维度二:社交网络 ====================
async def get_social_network(year: int = 2025) -> SocialNetworkData:
"""获取社交网络数据"""
from src.config.config import global_config
data = SocialNetworkData.model_construct()
start_ts, end_ts = get_year_time_range(year)
# 获取 bot 自身的 QQ 账号,用于过滤
bot_qq = str(global_config.bot.qq_account or "")
try:
# 1. 加入的群组总数
with get_db_session() as session:
statement = select(func.count(func.distinct(col(Messages.group_id)))).where(
col(Messages.group_id).is_not(None),
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
)
data.total_groups = int(session.exec(statement).first() or 0)
# 2. 话痨群组 TOP3
with get_db_session() as session:
statement = (
select(
col(Messages.group_id),
func.max(col(Messages.group_name)).label("group_name"),
func.count().label("count"),
)
.where(
col(Messages.group_id).is_not(None),
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
)
.group_by(col(Messages.group_id))
.order_by(func.count().desc())
.limit(5)
)
top_groups_rows = session.exec(statement).all()
data.top_groups = [
{
"group_id": row[0],
"group_name": row[1] or "未知群组",
"message_count": row[2] or 0,
"is_webui": str(row[0]).startswith("webui_"),
}
for row in top_groups_rows
]
# 3. 互动最多的用户 TOP5过滤 bot 自身)
with get_db_session() as session:
statement = (
select(
col(Messages.user_id),
func.max(col(Messages.user_nickname)).label("user_nickname"),
func.count().label("count"),
)
.where(
col(Messages.user_id).is_not(None),
col(Messages.user_id) != bot_qq,
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
)
.group_by(col(Messages.user_id))
.order_by(func.count().desc())
.limit(5)
)
top_users_rows = session.exec(statement).all()
data.top_users = [
{
"user_id": row[0],
"user_nickname": row[1] or "未知用户",
"message_count": row[2] or 0,
"is_webui": str(row[0]).startswith("webui_"),
}
for row in top_users_rows
]
# 4. 被@次数
with get_db_session() as session:
statement = select(func.count()).where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
col(Messages.is_at),
)
data.at_count = int(session.exec(statement).first() or 0)
# 5. 被提及次数
with get_db_session() as session:
statement = select(func.count()).where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
col(Messages.is_mentioned),
)
data.mentioned_count = int(session.exec(statement).first() or 0)
# 6. 最长情陪伴的用户(过滤 bot 自身)
with get_db_session() as session:
statement = select(PersonInfo).where(
col(PersonInfo.user_id) != bot_qq,
col(PersonInfo.first_known_time).is_not(None),
col(PersonInfo.last_known_time).is_not(None),
)
persons = session.exec(statement).all()
if persons:
def _companion_days(person: PersonInfo) -> float:
if not person.first_known_time or not person.last_known_time:
return 0.0
return (person.last_known_time - person.first_known_time).total_seconds()
longest = max(persons, key=_companion_days)
data.longest_companion_user = longest.person_name or longest.user_nickname or longest.user_id
data.longest_companion_days = int(_companion_days(longest) / 86400)
else:
data.longest_companion_user = None
data.longest_companion_days = 0
except Exception as e:
logger.error(f"获取社交网络数据失败: {e}")
return data
# ==================== 维度三:最强大脑 ====================
async def get_brain_power(year: int = 2025) -> BrainPowerData:
"""获取最强大脑数据"""
data = BrainPowerData.model_construct()
start_dt, end_dt = get_year_datetime_range(year)
start_ts, end_ts = get_year_time_range(year)
try:
# 1. 年度消耗 Token 总量和总花费
with get_db_session() as session:
statement = select(
func.sum(col(ModelUsage.total_tokens)).label("total_tokens"),
func.sum(col(ModelUsage.cost)).label("total_cost"),
).where(col(ModelUsage.timestamp) >= start_dt, col(ModelUsage.timestamp) <= end_dt)
result = session.exec(statement).first()
if result:
data.total_tokens = int(result[0] or 0)
data.total_cost = round(float(result[1] or 0), 4)
# 2. 最爱用的模型
with get_db_session() as session:
statement = (
select(ModelUsage)
.where(col(ModelUsage.timestamp) >= start_dt, col(ModelUsage.timestamp) <= end_dt)
.order_by(desc(col(ModelUsage.timestamp)))
)
records = session.exec(statement).all()
model_agg: dict[str, dict[str, float | int]] = {}
for record in records:
model_name = record.model_assign_name or record.model_name or "unknown"
if model_name not in model_agg:
model_agg[model_name] = {"count": 0, "tokens": 0, "cost": 0.0}
bucket = model_agg[model_name]
bucket["count"] = int(bucket["count"]) + 1
bucket["tokens"] = int(bucket["tokens"]) + int(record.total_tokens or 0)
bucket["cost"] = float(bucket["cost"]) + float(record.cost or 0.0)
model_results = sorted(
model_agg.items(),
key=lambda item: float(item[1]["count"]),
reverse=True,
)[:10]
if model_results:
data.favorite_model = model_results[0][0]
data.favorite_model_count = int(model_results[0][1]["count"])
data.model_distribution = [
{
"model": model_name,
"count": int(bucket["count"]),
"tokens": int(bucket["tokens"]),
"cost": round(float(bucket["cost"]), 4),
}
for model_name, bucket in model_results
]
# 3. 最昂贵的一次思考
if records:
expensive_record = max(records, key=lambda record: record.cost or 0.0)
data.most_expensive_cost = round(expensive_record.cost or 0.0, 4)
data.most_expensive_time = expensive_record.timestamp.strftime("%Y-%m-%d %H:%M:%S")
# 4. 烧钱大户 TOP3 (按用户,过滤 system)
consumer_agg: dict[str, dict[str, float | int]] = {}
for record in records:
user_id = record.model_api_provider_name
if not user_id or user_id == "system":
continue
if user_id not in consumer_agg:
consumer_agg[user_id] = {"cost": 0.0, "tokens": 0}
bucket = consumer_agg[user_id]
bucket["cost"] = float(bucket["cost"]) + float(record.cost or 0.0)
bucket["tokens"] = int(bucket["tokens"]) + int(record.total_tokens or 0)
data.top_token_consumers = [
{
"user_id": user_id,
"cost": round(float(bucket["cost"]), 4),
"tokens": int(bucket["tokens"]),
}
for user_id, bucket in sorted(
consumer_agg.items(),
key=lambda item: float(item[1]["cost"]),
reverse=True,
)[:3]
]
# 5. 最喜欢的回复模型 TOP5按模型的回复次数统计只统计 replyer 调用)
# 假设 replyer 调用有特定的 model_assign_name 格式或可以通过某种方式识别
reply_model_agg: dict[str, int] = {}
for record in records:
model_assign_name = record.model_assign_name or ""
if "replyer" not in model_assign_name and "回复" not in model_assign_name:
continue
model_name = model_assign_name or record.model_name or "unknown"
reply_model_agg[model_name] = reply_model_agg.get(model_name, 0) + 1
data.top_reply_models = [
{"model": model_name, "count": count}
for model_name, count in sorted(reply_model_agg.items(), key=lambda item: item[1], reverse=True)[:5]
]
# 6. 高冷指数 (沉默率) - 基于 ActionRecords
with get_db_session() as session:
statement = select(func.count()).where(
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
)
total_actions = int(session.exec(statement).first() or 0)
with get_db_session() as session:
statement = select(func.count()).where(
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
col(ActionRecord.action_name) == "no_reply",
)
no_reply_count = int(session.exec(statement).first() or 0)
data.total_actions = total_actions
data.no_reply_count = no_reply_count
data.silence_rate = round(no_reply_count / total_actions * 100, 2) if total_actions > 0 else 0
# 6. 情绪波动 (兴趣值)
data.avg_interest_value = 0.0
data.max_interest_value = 0.0
# 找到最高兴趣值的时间
if data.max_interest_value > 0:
data.max_interest_time = None
# 7. 思考深度 (基于 action_reasoning 长度)
with get_db_session() as session:
statement = select(ActionRecord).where(
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
col(ActionRecord.action_reasoning).is_not(None),
col(ActionRecord.action_reasoning) != "",
)
reasoning_records = session.exec(statement).all()
reasoning_lengths = []
max_len = 0
max_len_time = None
for record in reasoning_records:
if record.action_reasoning:
length = len(record.action_reasoning)
reasoning_lengths.append(length)
if length > max_len:
max_len = length
max_len_time = record.timestamp
if reasoning_lengths:
data.avg_reasoning_length = round(sum(reasoning_lengths) / len(reasoning_lengths), 1)
data.max_reasoning_length = max_len
if max_len_time:
data.max_reasoning_time = max_len_time.strftime("%Y-%m-%d %H:%M:%S")
except Exception as e:
logger.error(f"获取最强大脑数据失败: {e}")
return data
# ==================== 维度四:个性与表达 ====================
async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
"""获取个性与表达数据"""
from src.config.config import global_config
data = ExpressionVibeData.model_construct()
start_ts, end_ts = get_year_time_range(year)
# 获取 bot 自身的 QQ 账号,用于筛选 bot 发送的消息
bot_qq = str(global_config.bot.qq_account or "")
try:
# 1. 表情包之王 - 使用次数最多的表情包
with get_db_session() as session:
statement = select(Images).where(col(Images.is_registered)).order_by(desc(col(Images.query_count))).limit(5)
top_emojis = session.exec(statement).all()
if top_emojis:
data.top_emoji = {
"id": top_emojis[0].id,
"path": top_emojis[0].full_path,
"description": top_emojis[0].description,
"usage_count": top_emojis[0].query_count,
"hash": top_emojis[0].image_hash,
}
data.top_emojis = [
{
"id": e.id,
"path": e.full_path,
"description": e.description,
"usage_count": e.query_count,
"hash": e.image_hash,
}
for e in top_emojis
]
# 2. 百变麦麦 - 最常用的表达风格
with get_db_session() as session:
statement = (
select(Expression.style, func.sum(col(Expression.count)).label("total_count"))
.where(
col(Expression.last_active_time) >= datetime.fromtimestamp(start_ts),
col(Expression.last_active_time) <= datetime.fromtimestamp(end_ts),
)
.group_by(Expression.style)
.order_by(func.sum(col(Expression.count)).desc())
.limit(5)
)
expression_rows = session.exec(statement).all()
data.top_expressions = [{"style": row[0], "count": row[1] or 0} for row in expression_rows]
# 3. 被拒绝的表达
data.rejected_expression_count = 0
# 4. 已检查的表达
data.checked_expression_count = 0
# 5. 表达总数
with get_db_session() as session:
statement = select(func.count()).where(
col(Expression.last_active_time) >= datetime.fromtimestamp(start_ts),
col(Expression.last_active_time) <= datetime.fromtimestamp(end_ts),
)
data.total_expressions = int(session.exec(statement).first() or 0)
# 6. 动作类型分布 (过滤无意义的动作)
# 过滤掉: no_reply_until_call, make_question, no_action, wait, complete_talk, listening, block_and_ignore
excluded_actions = [
"reply",
"no_reply",
"no_reply_until_call",
"make_question",
"no_action",
"wait",
"complete_talk",
"listening",
"block_and_ignore",
]
with get_db_session() as session:
statement = (
select(ActionRecord.action_name, func.count().label("count"))
.where(
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
col(ActionRecord.action_name).not_in(excluded_actions),
)
.group_by(ActionRecord.action_name)
.order_by(func.count().desc())
.limit(10)
)
action_rows = session.exec(statement).all()
data.action_types = [{"action": row[0], "count": row[1]} for row in action_rows]
# 7. 处理的图片数量
with get_db_session() as session:
statement = select(func.count()).where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
col(Messages.is_picture),
)
data.image_processed_count = int(session.exec(statement).first() or 0)
# 8. 深夜还在回复 (0-6点最晚的10条消息中随机抽取一条)
import random
import re
def clean_message_content(content: str) -> str:
"""清理消息内容,移除回复引用等标记"""
if not content:
return ""
# 移除 [回复<xxx:xxx> 的消息:...] 格式的引用
content = re.sub(r"\[回复<[^>]+>\s*的消息[:][^\]]*\]", "", content)
# 移除 [图片] [表情] 等标记
content = re.sub(r"\[(图片|表情|语音|视频|文件)\]", "", content)
# 移除多余的空白
content = re.sub(r"\s+", " ", content).strip()
return content
# 使用 user_id 判断是否是 bot 发送的消息
with get_db_session() as session:
statement = (
select(Messages)
.where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
col(Messages.user_id) == bot_qq,
)
.order_by(desc(col(Messages.timestamp)))
.limit(200)
)
late_night_messages = session.exec(statement).all()
# 筛选出0-6点的消息
late_night_filtered = []
for msg in late_night_messages:
msg_dt = msg.timestamp
hour = msg_dt.hour
if 0 <= hour < 6: # 0点到6点
raw_content = msg.processed_plain_text or msg.display_message or ""
cleaned_content = clean_message_content(raw_content)
# 只保留有意义的内容
if cleaned_content and len(cleaned_content) > 2:
late_night_filtered.append(
{
"time": msg_dt.timestamp(),
"hour": hour,
"minute": msg_dt.minute,
"content": cleaned_content,
"datetime_str": msg_dt.strftime("%H:%M"),
}
)
if len(late_night_filtered) >= 10:
break
if late_night_filtered:
selected = random.choice(late_night_filtered)
content = selected["content"][:50] + "..." if len(selected["content"]) > 50 else selected["content"]
data.late_night_reply = {
"time": selected["datetime_str"],
"content": content,
}
# 9. 最喜欢的回复(按 action_data 统计回复内容出现次数)
from collections import Counter
import json as json_lib
with get_db_session() as session:
statement = select(ActionRecord).where(
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
col(ActionRecord.action_name) == "reply",
col(ActionRecord.action_data).is_not(None),
col(ActionRecord.action_data) != "",
)
reply_records = session.exec(statement).all()
reply_contents = []
for record in reply_records:
try:
action_data = record.action_data
if action_data:
content = None
# 尝试解析 JSON 格式
try:
parsed = json_lib.loads(action_data)
if isinstance(parsed, dict):
# 优先使用 reply_text其次使用 content
content = parsed.get("reply_text") or parsed.get("content")
elif isinstance(parsed, str):
content = parsed
except (json_lib.JSONDecodeError, TypeError):
pass
# 如果 JSON 解析失败,尝试解析 Python 字典字符串格式
# 例如: "{'reply_text': '墨白灵不知道哦'}"
if content is None:
import ast
try:
parsed = ast.literal_eval(action_data)
if isinstance(parsed, dict):
content = parsed.get("reply_text") or parsed.get("content")
elif isinstance(parsed, str):
content = parsed
except (ValueError, SyntaxError):
# 无法解析,使用原始字符串
content = action_data
# 只统计有意义的回复长度大于2
if content and len(content) > 2:
reply_contents.append(content)
except Exception:
continue
if reply_contents:
content_counter = Counter(reply_contents)
most_common = content_counter.most_common(1)
if most_common:
fav_content, fav_count = most_common[0]
# 截断过长的内容
display_content = fav_content[:50] + "..." if len(fav_content) > 50 else fav_content
data.favorite_reply = {
"content": display_content,
"count": fav_count,
}
except Exception as e:
logger.error(f"获取个性与表达数据失败: {e}")
return data
# ==================== 维度五:趣味成就 ====================
async def get_achievements(year: int = 2025) -> AchievementData:
"""获取趣味成就数据"""
data = AchievementData.model_construct()
start_ts, end_ts = get_year_time_range(year)
try:
# 1. 新学到的黑话数量
# Jargon 表没有时间字段,统计全部已确认的黑话
with get_db_session() as session:
statement = select(func.count()).where(col(Jargon.is_jargon))
data.new_jargon_count = int(session.exec(statement).first() or 0)
# 2. 代表性黑话示例
with get_db_session() as session:
statement = select(Jargon).where(col(Jargon.is_jargon)).order_by(desc(col(Jargon.count))).limit(5)
jargon_samples = session.exec(statement).all()
data.sample_jargons = [
{
"content": j.content,
"meaning": j.meaning,
"count": j.count,
}
for j in jargon_samples
]
# 3. 总消息数
with get_db_session() as session:
statement = select(func.count()).where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
)
data.total_messages = int(session.exec(statement).first() or 0)
# 4. 总回复数 (有 reply_to 的消息)
with get_db_session() as session:
statement = select(func.count()).where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
col(Messages.reply_to).is_not(None),
)
data.total_replies = int(session.exec(statement).first() or 0)
except Exception as e:
logger.error(f"获取趣味成就数据失败: {e}")
return data
# ==================== API 路由 ====================
@router.get("/full", response_model=AnnualReportData)
async def get_full_annual_report(year: int = 2025, _auth: bool = Depends(require_auth)):
"""
获取完整年度报告数据
Args:
year: 报告年份默认2025
Returns:
完整的年度报告数据
"""
try:
from src.config.config import global_config
logger.info(f"开始生成 {year} 年度报告...")
# 获取 bot 名称
bot_name = global_config.bot.nickname or "麦麦"
# 并行获取各维度数据
time_footprint = await get_time_footprint(year)
social_network = await get_social_network(year)
brain_power = await get_brain_power(year)
expression_vibe = await get_expression_vibe(year)
achievements = await get_achievements(year)
report = AnnualReportData(
year=year,
bot_name=bot_name,
generated_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
time_footprint=time_footprint,
social_network=social_network,
brain_power=brain_power,
expression_vibe=expression_vibe,
achievements=achievements,
)
logger.info(f"{year} 年度报告生成完成")
return report
except Exception as e:
logger.error(f"生成年度报告失败: {e}")
raise HTTPException(status_code=500, detail=f"生成年度报告失败: {str(e)}") from e
@router.get("/time-footprint", response_model=TimeFootprintData)
async def get_time_footprint_api(year: int = 2025, _auth: bool = Depends(require_auth)):
"""获取时光足迹数据"""
try:
return await get_time_footprint(year)
except Exception as e:
logger.error(f"获取时光足迹数据失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/social-network", response_model=SocialNetworkData)
async def get_social_network_api(year: int = 2025, _auth: bool = Depends(require_auth)):
"""获取社交网络数据"""
try:
return await get_social_network(year)
except Exception as e:
logger.error(f"获取社交网络数据失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/brain-power", response_model=BrainPowerData)
async def get_brain_power_api(year: int = 2025, _auth: bool = Depends(require_auth)):
"""获取最强大脑数据"""
try:
return await get_brain_power(year)
except Exception as e:
logger.error(f"获取最强大脑数据失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/expression-vibe", response_model=ExpressionVibeData)
async def get_expression_vibe_api(year: int = 2025, _auth: bool = Depends(require_auth)):
"""获取个性与表达数据"""
try:
return await get_expression_vibe(year)
except Exception as e:
logger.error(f"获取个性与表达数据失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/achievements", response_model=AchievementData)
async def get_achievements_api(year: int = 2025, _auth: bool = Depends(require_auth)):
"""获取趣味成就数据"""
try:
return await get_achievements(year)
except Exception as e:
logger.error(f"获取趣味成就数据失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e

View File

@@ -1,801 +0,0 @@
"""本地聊天室路由 - WebUI 与麦麦直接对话
支持两种模式:
1. WebUI 模式:使用 WebUI 平台独立身份聊天
2. 虚拟身份模式:使用真实平台用户的身份,在虚拟群聊中与麦麦对话
"""
import time
import uuid
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Cookie, Depends, Header, Query, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from sqlalchemy import case, desc, func
from sqlmodel import col, select, delete
from src.chat.message_receive.bot import chat_bot
from src.common.database.database import get_db_session
from src.common.database.database_model import Messages, PersonInfo
from src.common.logger import get_logger
from src.config.config import global_config
from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.chat")
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# WebUI 聊天的虚拟群组 ID
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
WEBUI_CHAT_PLATFORM = "webui"
# 虚拟身份模式的群 ID 前缀
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
# 固定的 WebUI 用户 ID 前缀
WEBUI_USER_ID_PREFIX = "webui_user_"
class VirtualIdentityConfig(BaseModel):
"""虚拟身份配置"""
enabled: bool = False # 是否启用虚拟身份模式
platform: Optional[str] = None # 目标平台(如 qq, discord 等)
person_id: Optional[str] = None # PersonInfo 的 person_id
user_id: Optional[str] = None # 原始平台用户 ID
user_nickname: Optional[str] = None # 用户昵称
group_id: Optional[str] = None # 虚拟群 ID自动生成或用户指定
group_name: Optional[str] = None # 虚拟群名(用户自定义)
class ChatHistoryMessage(BaseModel):
"""聊天历史消息"""
id: str
type: str # 'user' | 'bot' | 'system'
content: str
timestamp: float
sender_name: str
sender_id: Optional[str] = None
is_bot: bool = False
class ChatHistoryManager:
"""聊天历史管理器 - 使用 SQLite 数据库存储"""
def __init__(self, max_messages: int = 200):
self.max_messages = max_messages
def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]:
"""将数据库消息转换为前端格式
Args:
msg: 数据库消息对象
group_id: 群 ID用于判断是否是虚拟群
"""
# 判断是否是机器人消息
user_id = msg.user_id or ""
# 对于虚拟群,通过比较机器人 QQ 账号来判断
# 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头
if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
# 虚拟群user_id 等于机器人 QQ 账号的是机器人消息
bot_qq = str(global_config.bot.qq_account)
is_bot = user_id == bot_qq
else:
# 普通 WebUI 群:不以 webui_ 开头的是机器人消息
is_bot = not user_id.startswith("webui_") and not user_id.startswith(WEBUI_USER_ID_PREFIX)
return {
"id": msg.message_id,
"type": "bot" if is_bot else "user",
"content": msg.processed_plain_text or msg.display_message or "",
"timestamp": msg.timestamp.timestamp(),
"sender_name": msg.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
"sender_id": "bot" if is_bot else user_id,
"is_bot": is_bot,
}
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""从数据库获取最近的历史记录
Args:
limit: 获取的消息数量
group_id: 群 ID默认为 WEBUI_CHAT_GROUP_ID
"""
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
try:
# 查询指定群的消息,按时间排序
with get_db_session() as session:
statement = (
select(Messages)
.where(col(Messages.group_id) == target_group_id)
.order_by(desc(col(Messages.timestamp)))
.limit(limit)
)
messages = session.exec(statement).all()
# 转换为列表并反转(使最旧的消息在前)
# 传递 group_id 以便正确判断虚拟群中的机器人消息
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
result.reverse()
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
return result
except Exception as e:
logger.error(f"从数据库加载聊天记录失败: {e}")
return []
def clear_history(self, group_id: Optional[str] = None) -> int:
"""清空聊天历史记录
Args:
group_id: 群 ID默认清空 WebUI 默认聊天室
"""
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
try:
with get_db_session() as session:
statement = delete(Messages).where(col(Messages.group_id) == target_group_id)
result = session.exec(statement)
deleted = result.rowcount or 0
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
return deleted
except Exception as e:
logger.error(f"清空聊天记录失败: {e}")
return 0
# 全局聊天历史管理器
chat_history = ChatHistoryManager()
# 存储 WebSocket 连接
class ChatConnectionManager:
"""聊天连接管理器"""
def __init__(self):
self.active_connections: Dict[str, WebSocket] = {}
self.user_sessions: Dict[str, str] = {} # user_id -> session_id 映射
async def connect(self, websocket: WebSocket, session_id: str, user_id: str):
await websocket.accept()
self.active_connections[session_id] = websocket
self.user_sessions[user_id] = session_id
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
def disconnect(self, session_id: str, user_id: str):
if session_id in self.active_connections:
del self.active_connections[session_id]
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
del self.user_sessions[user_id]
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
async def send_message(self, session_id: str, message: dict[str, Any]):
if session_id in self.active_connections:
try:
await self.active_connections[session_id].send_json(message)
except Exception as e:
logger.error(f"发送消息失败: {e}")
async def broadcast(self, message: dict[str, Any]):
"""广播消息给所有连接"""
for session_id in list(self.active_connections.keys()):
await self.send_message(session_id, message)
chat_manager = ChatConnectionManager()
def create_message_data(
content: str,
user_id: str,
user_name: str,
message_id: Optional[str] = None,
is_at_bot: bool = True,
virtual_config: Optional[VirtualIdentityConfig] = None,
) -> Dict[str, Any]:
"""创建符合麦麦消息格式的消息数据
Args:
content: 消息内容
user_id: 用户 ID
user_name: 用户昵称
message_id: 消息 ID可选自动生成
is_at_bot: 是否 @ 机器人
virtual_config: 虚拟身份配置(可选,启用后使用真实平台身份)
"""
if message_id is None:
message_id = str(uuid.uuid4())
# 确定使用的平台、群信息和用户信息
if virtual_config and virtual_config.enabled:
# 虚拟身份模式:使用真实平台身份
platform = virtual_config.platform or WEBUI_CHAT_PLATFORM
group_id = virtual_config.group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{uuid.uuid4().hex[:8]}"
group_name = virtual_config.group_name or "WebUI虚拟群聊"
actual_user_id = virtual_config.user_id or user_id
actual_user_name = virtual_config.user_nickname or user_name
else:
# 标准 WebUI 模式
platform = WEBUI_CHAT_PLATFORM
group_id = WEBUI_CHAT_GROUP_ID
group_name = "WebUI本地聊天室"
actual_user_id = user_id
actual_user_name = user_name
return {
"message_info": {
"platform": platform,
"message_id": message_id,
"time": time.time(),
"group_info": {
"group_id": group_id,
"group_name": group_name,
"platform": platform,
},
"user_info": {
"user_id": actual_user_id,
"user_nickname": actual_user_name,
"user_cardname": actual_user_name,
"platform": platform,
},
"additional_config": {
"at_bot": is_at_bot,
},
},
"message_segment": {
"type": "seglist",
"data": [
{
"type": "text",
"data": content,
},
{
"type": "mention_bot",
"data": "1.0",
},
],
},
"raw_message": content,
"processed_plain_text": content,
}
@router.get("/history")
async def get_chat_history(
limit: int = Query(default=50, ge=1, le=200),
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
_auth: bool = Depends(require_auth),
):
"""获取聊天历史记录
所有 WebUI 用户共享同一个聊天室,因此返回所有历史记录
如果指定了 group_id则获取该虚拟群的历史记录
"""
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
history = chat_history.get_history(limit, target_group_id)
return {
"success": True,
"messages": history,
"total": len(history),
}
@router.get("/platforms")
async def get_available_platforms(_auth: bool = Depends(require_auth)):
"""获取可用平台列表
从 PersonInfo 表中获取所有已知的平台
"""
try:
# 查询所有不同的平台
with get_db_session() as session:
statement = (
select(PersonInfo.platform, func.count().label("count"))
.group_by(PersonInfo.platform)
.order_by(func.count().desc())
)
platforms = session.exec(statement).all()
result = []
for platform, count in platforms:
if platform:
result.append({"platform": platform, "count": count})
return {"success": True, "platforms": result}
except Exception as e:
logger.error(f"获取平台列表失败: {e}")
return {"success": False, "error": str(e), "platforms": []}
@router.get("/persons")
async def get_persons_by_platform(
platform: str = Query(..., description="平台名称"),
search: Optional[str] = Query(default=None, description="搜索关键词"),
limit: int = Query(default=50, ge=1, le=200),
_auth: bool = Depends(require_auth),
):
"""获取指定平台的用户列表
Args:
platform: 平台名称(如 qq, discord 等)
search: 搜索关键词匹配昵称、用户名、user_id
limit: 返回数量限制
"""
try:
# 构建查询
statement = select(PersonInfo).where(col(PersonInfo.platform) == platform)
# 搜索过滤
if search:
statement = statement.where(
(col(PersonInfo.person_name).contains(search))
| (col(PersonInfo.user_nickname).contains(search))
| (col(PersonInfo.user_id).contains(search))
)
# 按最后交互时间排序,优先显示活跃用户
statement = statement.order_by(
case((col(PersonInfo.last_known_time).is_(None), 1), else_=0),
col(PersonInfo.last_known_time).desc(),
).limit(limit)
with get_db_session() as session:
persons = session.exec(statement).all()
result = []
for person in persons:
result.append(
{
"person_id": person.person_id,
"user_id": person.user_id,
"person_name": person.person_name,
"nickname": person.user_nickname,
"is_known": person.is_known,
"platform": person.platform,
"display_name": person.person_name or person.user_nickname or person.user_id,
}
)
return {"success": True, "persons": result, "total": len(result)}
except Exception as e:
logger.error(f"获取用户列表失败: {e}")
return {"success": False, "error": str(e), "persons": []}
@router.delete("/history")
async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)):
"""清空聊天历史记录
Args:
group_id: 可选,指定要清空的群 ID默认清空 WebUI 默认聊天室
"""
deleted = chat_history.clear_history(group_id)
return {
"success": True,
"message": f"已清空 {deleted} 条聊天记录",
}
@router.websocket("/ws")
async def websocket_chat(
websocket: WebSocket,
user_id: Optional[str] = Query(default=None),
user_name: Optional[str] = Query(default="WebUI用户"),
platform: Optional[str] = Query(default=None),
person_id: Optional[str] = Query(default=None),
group_name: Optional[str] = Query(default=None),
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
token: Optional[str] = Query(default=None), # 认证 token
):
"""WebSocket 聊天端点
Args:
user_id: 用户唯一标识(由前端生成并持久化)
user_name: 用户显示昵称(可修改)
platform: 虚拟身份模式的平台(可选)
person_id: 虚拟身份模式的用户 person_id可选
group_name: 虚拟身份模式的群名(可选)
group_id: 虚拟身份模式的群 ID可选由前端生成并持久化
token: 认证 token可选也可从 Cookie 获取)
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
支持三种认证方式(按优先级):
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/api/chat/ws?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
# 生成会话 ID每次连接都是新的
session_id = str(uuid.uuid4())
# 如果没有提供 user_id生成一个新的
if not user_id:
user_id = f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
elif not user_id.startswith(WEBUI_USER_ID_PREFIX):
# 确保 user_id 有正确的前缀
user_id = f"{WEBUI_USER_ID_PREFIX}{user_id}"
# 当前会话的虚拟身份配置(可通过消息动态更新)
current_virtual_config: Optional[VirtualIdentityConfig] = None
# 如果 URL 参数中提供了虚拟身份信息,自动配置
if platform and person_id:
try:
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
person = session.exec(statement).first()
if person:
# 使用前端传递的 group_id如果没有则生成一个稳定的
virtual_group_id = group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{platform}_{person.user_id}"
current_virtual_config = VirtualIdentityConfig(
enabled=True,
platform=person.platform,
person_id=person.person_id,
user_id=person.user_id,
user_nickname=person.person_name or person.user_nickname or person.user_id,
group_id=virtual_group_id,
group_name=group_name or "WebUI虚拟群聊",
)
logger.info(
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
)
except Exception as e:
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
await chat_manager.connect(websocket, session_id, user_id)
try:
# 构建会话信息
session_info_data: dict[str, Any] = {
"type": "session_info",
"session_id": session_id,
"user_id": user_id,
"user_name": user_name,
"bot_name": global_config.bot.nickname,
}
# 如果有虚拟身份配置,添加到会话信息中
if current_virtual_config and current_virtual_config.enabled:
session_info_data["virtual_mode"] = True
session_info_data["group_id"] = current_virtual_config.group_id
session_info_data["virtual_identity"] = {
"platform": current_virtual_config.platform,
"user_id": current_virtual_config.user_id,
"user_nickname": current_virtual_config.user_nickname,
"group_name": current_virtual_config.group_name,
}
# 发送会话信息(包含用户 ID前端需要保存
await chat_manager.send_message(session_id, session_info_data)
# 发送历史记录(根据模式选择不同的群)
if current_virtual_config and current_virtual_config.enabled:
history = chat_history.get_history(50, current_virtual_config.group_id)
else:
history = chat_history.get_history(50)
if history:
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": history,
},
)
# 发送欢迎消息(不保存到历史)
if current_virtual_config and current_virtual_config.enabled:
welcome_msg = f"已以 {current_virtual_config.user_nickname} 的身份连接到「{current_virtual_config.group_name}」,开始与 {global_config.bot.nickname} 对话吧!"
else:
welcome_msg = f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!"
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": welcome_msg,
"timestamp": time.time(),
},
)
while True:
data = await websocket.receive_json()
if data.get("type") == "message":
content = data.get("content", "").strip()
if not content:
continue
# 用户可以更新昵称
current_user_name = data.get("user_name", user_name)
message_id = str(uuid.uuid4())
timestamp = time.time()
# 确定发送者信息(根据是否使用虚拟身份)
if current_virtual_config and current_virtual_config.enabled:
sender_name = current_virtual_config.user_nickname or current_user_name
sender_user_id = current_virtual_config.user_id or user_id
else:
sender_name = current_user_name
sender_user_id = user_id
# 广播用户消息给所有连接(包括发送者)
# 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库
await chat_manager.broadcast(
{
"type": "user_message",
"content": content,
"message_id": message_id,
"timestamp": timestamp,
"sender": {
"name": sender_name,
"user_id": sender_user_id,
"is_bot": False,
},
"virtual_mode": current_virtual_config.enabled if current_virtual_config else False,
}
)
# 创建麦麦消息格式
message_data = create_message_data(
content=content,
user_id=user_id,
user_name=current_user_name,
message_id=message_id,
is_at_bot=True,
virtual_config=current_virtual_config,
)
try:
# 显示正在输入状态
await chat_manager.broadcast(
{
"type": "typing",
"is_typing": True,
}
)
# 调用麦麦的消息处理
await chat_bot.message_process(message_data)
except Exception as e:
logger.error(f"处理消息时出错: {e}")
await chat_manager.send_message(
session_id,
{
"type": "error",
"content": f"处理消息时出错: {str(e)}",
"timestamp": time.time(),
},
)
finally:
await chat_manager.broadcast(
{
"type": "typing",
"is_typing": False,
}
)
elif data.get("type") == "ping":
await chat_manager.send_message(
session_id,
{
"type": "pong",
"timestamp": time.time(),
},
)
elif data.get("type") == "update_nickname":
# 允许用户更新昵称
if new_name := data.get("user_name", "").strip():
current_user_name = new_name
await chat_manager.send_message(
session_id,
{
"type": "nickname_updated",
"user_name": current_user_name,
"timestamp": time.time(),
},
)
elif data.get("type") == "set_virtual_identity":
# 设置或更新虚拟身份配置
virtual_data = data.get("config", {})
if virtual_data.get("enabled"):
# 验证必要字段
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
await chat_manager.send_message(
session_id,
{
"type": "error",
"content": "虚拟身份配置缺少必要字段: platform 和 person_id",
"timestamp": time.time(),
},
)
continue
# 获取用户信息
try:
with get_db_session() as session:
statement = (
select(PersonInfo)
.where(col(PersonInfo.person_id) == virtual_data.get("person_id"))
.limit(1)
)
person = session.exec(statement).first()
if not person:
await chat_manager.send_message(
session_id,
{
"type": "error",
"content": f"找不到用户: {virtual_data.get('person_id')}",
"timestamp": time.time(),
},
)
continue
# 生成虚拟群 ID
custom_group_id = virtual_data.get("group_id")
if custom_group_id:
group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
else:
group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_id[:8]}"
current_virtual_config = VirtualIdentityConfig(
enabled=True,
platform=person.platform,
person_id=person.person_id,
user_id=person.user_id,
user_nickname=person.person_name or person.user_nickname or person.user_id,
group_id=group_id,
group_name=virtual_data.get("group_name", "WebUI虚拟群聊"),
)
# 发送虚拟身份已激活的消息
await chat_manager.send_message(
session_id,
{
"type": "virtual_identity_set",
"config": {
"enabled": True,
"platform": current_virtual_config.platform,
"user_id": current_virtual_config.user_id,
"user_nickname": current_virtual_config.user_nickname,
"group_id": current_virtual_config.group_id,
"group_name": current_virtual_config.group_name,
},
"timestamp": time.time(),
},
)
# 加载虚拟群的历史记录
virtual_history = chat_history.get_history(50, current_virtual_config.group_id)
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": virtual_history,
"group_id": current_virtual_config.group_id,
},
)
# 发送系统消息
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": f"已切换到虚拟身份模式:以 {current_virtual_config.user_nickname} 的身份在「{current_virtual_config.group_name}」与 {global_config.bot.nickname} 对话",
"timestamp": time.time(),
},
)
except Exception as e:
logger.error(f"设置虚拟身份失败: {e}")
await chat_manager.send_message(
session_id,
{
"type": "error",
"content": f"设置虚拟身份失败: {str(e)}",
"timestamp": time.time(),
},
)
else:
# 禁用虚拟身份模式
current_virtual_config = None
await chat_manager.send_message(
session_id,
{
"type": "virtual_identity_set",
"config": {"enabled": False},
"timestamp": time.time(),
},
)
# 重新加载默认聊天室历史
default_history = chat_history.get_history(50, WEBUI_CHAT_GROUP_ID)
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": default_history,
"group_id": WEBUI_CHAT_GROUP_ID,
},
)
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": "已切换回 WebUI 独立用户模式",
"timestamp": time.time(),
},
)
except WebSocketDisconnect:
logger.info(f"WebSocket 断开: session={session_id}, user={user_id}")
except Exception as e:
logger.error(f"WebSocket 错误: {e}")
finally:
chat_manager.disconnect(session_id, user_id)
@router.get("/info")
async def get_chat_info(_auth: bool = Depends(require_auth)):
"""获取聊天室信息"""
return {
"bot_name": global_config.bot.nickname,
"platform": WEBUI_CHAT_PLATFORM,
"group_id": WEBUI_CHAT_GROUP_ID,
"active_sessions": len(chat_manager.active_connections),
}
def get_webui_chat_broadcaster() -> tuple[ChatConnectionManager, str]:
"""获取 WebUI 聊天广播器,供外部模块使用
Returns:
(chat_manager, WEBUI_CHAT_PLATFORM) 元组
"""
return (chat_manager, WEBUI_CHAT_PLATFORM)

View File

@@ -0,0 +1,18 @@
from fastapi import APIRouter
from .routes import router
from .support import ChatConnectionManager, WEBUI_CHAT_PLATFORM, chat_manager
def get_webui_chat_broadcaster() -> tuple[ChatConnectionManager, str]:
"""获取 WebUI 聊天广播器,供外部模块使用。"""
return chat_manager, WEBUI_CHAT_PLATFORM
__all__ = [
"ChatConnectionManager",
"WEBUI_CHAT_PLATFORM",
"chat_manager",
"get_webui_chat_broadcaster",
"router",
]

View File

@@ -0,0 +1,174 @@
"""本地聊天室路由 - WebUI 与麦麦直接对话。"""
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
from sqlalchemy import case, func
from sqlmodel import col, select
from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
from src.common.logger import get_logger
from src.config.config import global_config
from src.webui.dependencies import require_auth
from .support import (
WEBUI_CHAT_GROUP_ID,
WEBUI_CHAT_PLATFORM,
authenticate_chat_websocket,
chat_history,
chat_manager,
dispatch_chat_event,
normalize_webui_user_id,
resolve_initial_virtual_identity,
send_initial_chat_state,
)
logger = get_logger("webui.chat")
router = APIRouter(prefix="/api/chat", tags=["LocalChat"], dependencies=[Depends(require_auth)])
@router.get("/history")
async def get_chat_history(
limit: int = Query(default=50, ge=1, le=200),
user_id: Optional[str] = Query(default=None),
group_id: Optional[str] = Query(default=None),
) -> dict[str, object]:
"""获取聊天历史记录。"""
del user_id
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
history = chat_history.get_history(limit, target_group_id)
return {"success": True, "messages": history, "total": len(history)}
@router.get("/platforms")
async def get_available_platforms() -> dict[str, object]:
"""获取可用平台列表。"""
try:
with get_db_session() as session:
statement = (
select(PersonInfo.platform, func.count().label("count"))
.group_by(PersonInfo.platform)
.order_by(func.count().desc())
)
platforms = session.exec(statement).all()
result = [{"platform": platform, "count": count} for platform, count in platforms if platform]
return {"success": True, "platforms": result}
except Exception as e:
logger.error(f"获取平台列表失败: {e}")
return {"success": False, "error": str(e), "platforms": []}
@router.get("/persons")
async def get_persons_by_platform(
platform: str = Query(..., description="平台名称"),
search: Optional[str] = Query(default=None, description="搜索关键词"),
limit: int = Query(default=50, ge=1, le=200),
) -> dict[str, object]:
"""获取指定平台的用户列表。"""
try:
statement = select(PersonInfo).where(col(PersonInfo.platform) == platform)
if search:
statement = statement.where(
(col(PersonInfo.person_name).contains(search))
| (col(PersonInfo.user_nickname).contains(search))
| (col(PersonInfo.user_id).contains(search))
)
statement = statement.order_by(
case((col(PersonInfo.last_known_time).is_(None), 1), else_=0),
col(PersonInfo.last_known_time).desc(),
).limit(limit)
with get_db_session() as session:
persons = session.exec(statement).all()
result = [
{
"person_id": person.person_id,
"user_id": person.user_id,
"person_name": person.person_name,
"nickname": person.user_nickname,
"is_known": person.is_known,
"platform": person.platform,
"display_name": person.person_name or person.user_nickname or person.user_id,
}
for person in persons
]
return {"success": True, "persons": result, "total": len(result)}
except Exception as e:
logger.error(f"获取用户列表失败: {e}")
return {"success": False, "error": str(e), "persons": []}
@router.delete("/history")
async def clear_chat_history(
group_id: Optional[str] = Query(default=None),
) -> dict[str, object]:
"""清空聊天历史记录。"""
deleted = chat_history.clear_history(group_id)
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
@router.websocket("/ws")
async def websocket_chat(
websocket: WebSocket,
user_id: Optional[str] = Query(default=None),
user_name: Optional[str] = Query(default="WebUI用户"),
platform: Optional[str] = Query(default=None),
person_id: Optional[str] = Query(default=None),
group_name: Optional[str] = Query(default=None),
group_id: Optional[str] = Query(default=None),
token: Optional[str] = Query(default=None),
) -> None:
"""WebSocket 聊天端点。"""
if not await authenticate_chat_websocket(websocket, token):
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
session_id = str(uuid.uuid4())
normalized_user_id = normalize_webui_user_id(user_id)
current_user_name = user_name or "WebUI用户"
current_virtual_config = resolve_initial_virtual_identity(platform, person_id, group_name, group_id)
await chat_manager.connect(websocket, session_id, normalized_user_id)
try:
await send_initial_chat_state(
session_id=session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
)
while True:
data = await websocket.receive_json()
current_user_name, current_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data=data,
current_user_name=current_user_name,
normalized_user_id=normalized_user_id,
current_virtual_config=current_virtual_config,
)
except WebSocketDisconnect:
logger.info(f"WebSocket 断开: session={session_id}, user={normalized_user_id}")
except Exception as e:
logger.error(f"WebSocket 错误: {e}")
finally:
chat_manager.disconnect(session_id, normalized_user_id)
@router.get("/info")
async def get_chat_info() -> dict[str, object]:
"""获取聊天室信息。"""
return {
"bot_name": global_config.bot.nickname,
"platform": WEBUI_CHAT_PLATFORM,
"group_id": WEBUI_CHAT_GROUP_ID,
"active_sessions": len(chat_manager.active_connections),
}

View File

@@ -0,0 +1,614 @@
"""WebUI 聊天路由支持逻辑。"""
from typing import Any, Optional, cast
import time
import uuid
from fastapi import WebSocket
from pydantic import BaseModel
from sqlmodel import col, delete, select
from src.chat.message_receive.bot import chat_bot
from src.chat.message_receive.message import SessionMessage
from src.common.database.database import get_db_session
from src.common.database.database_model import Messages, PersonInfo
from src.common.logger import get_logger
from src.common.message_repository import find_messages
from src.common.utils.system_utils import is_bot_self
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.chat")
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
WEBUI_CHAT_PLATFORM = "webui"
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
WEBUI_USER_ID_PREFIX = "webui_user_"
class VirtualIdentityConfig(BaseModel):
"""虚拟身份配置。"""
enabled: bool = False
platform: Optional[str] = None
person_id: Optional[str] = None
user_id: Optional[str] = None
user_nickname: Optional[str] = None
group_id: Optional[str] = None
group_name: Optional[str] = None
class ChatHistoryMessage(BaseModel):
"""聊天历史消息。"""
id: str
type: str
content: str
timestamp: float
sender_name: str
sender_id: Optional[str] = None
is_bot: bool = False
class ChatHistoryManager:
"""聊天历史管理器。"""
def __init__(self, max_messages: int = 200) -> None:
self.max_messages = max_messages
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> dict[str, Any]:
user_info = msg.message_info.user_info
user_id = user_info.user_id or ""
is_bot = is_bot_self(user_id, msg.platform)
if not is_bot and group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
is_bot = user_id == str(global_config.bot.qq_account)
elif not is_bot:
is_bot = not user_id.startswith(WEBUI_USER_ID_PREFIX)
return {
"id": msg.message_id,
"type": "bot" if is_bot else "user",
"content": msg.processed_plain_text or msg.display_message or "",
"timestamp": msg.timestamp.timestamp(),
"sender_name": user_info.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
"sender_id": "bot" if is_bot else user_id,
"is_bot": is_bot,
}
def _resolve_session_id(self, group_id: Optional[str]) -> str:
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> list[dict[str, Any]]:
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
session_id = self._resolve_session_id(target_group_id)
try:
messages = find_messages(
session_id=session_id,
limit=limit,
limit_mode="latest",
filter_command=False,
)
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
return result
except Exception as e:
logger.error(f"从数据库加载聊天记录失败: {e}")
return []
def clear_history(self, group_id: Optional[str] = None) -> int:
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
session_id = self._resolve_session_id(target_group_id)
try:
with get_db_session() as session:
statement = delete(Messages).where(col(Messages.session_id) == session_id)
result = session.exec(statement)
deleted = result.rowcount or 0
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
return deleted
except Exception as e:
logger.error(f"清空聊天记录失败: {e}")
return 0
class ChatConnectionManager:
"""聊天连接管理器。"""
def __init__(self) -> None:
self.active_connections: dict[str, WebSocket] = {}
self.user_sessions: dict[str, str] = {}
async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None:
await websocket.accept()
self.active_connections[session_id] = websocket
self.user_sessions[user_id] = session_id
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
def disconnect(self, session_id: str, user_id: str) -> None:
if session_id in self.active_connections:
del self.active_connections[session_id]
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
del self.user_sessions[user_id]
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
async def send_message(self, session_id: str, message: dict[str, Any]) -> None:
if session_id in self.active_connections:
try:
await self.active_connections[session_id].send_json(message)
except Exception as e:
logger.error(f"发送消息失败: {e}")
async def broadcast(self, message: dict[str, Any]) -> None:
for session_id in list(self.active_connections.keys()):
await self.send_message(session_id, message)
chat_history = ChatHistoryManager()
chat_manager = ChatConnectionManager()
def is_virtual_mode_enabled(virtual_config: Optional[VirtualIdentityConfig]) -> bool:
return bool(virtual_config and virtual_config.enabled)
async def authenticate_chat_websocket(websocket: WebSocket, token: Optional[str]) -> bool:
if token and verify_ws_token(token):
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
return True
if cookie_token := websocket.cookies.get("maibot_session"):
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
return True
return False
def normalize_webui_user_id(user_id: Optional[str]) -> str:
if not user_id:
return f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
if user_id.startswith(WEBUI_USER_ID_PREFIX):
return user_id
return f"{WEBUI_USER_ID_PREFIX}{user_id}"
def get_person_by_person_id(person_id: str) -> Optional[PersonInfo]:
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
return session.exec(statement).first()
def build_virtual_identity_config(person: PersonInfo, group_id: str, group_name: str) -> VirtualIdentityConfig:
return VirtualIdentityConfig(
enabled=True,
platform=person.platform,
person_id=person.person_id,
user_id=person.user_id,
user_nickname=person.person_name or person.user_nickname or person.user_id,
group_id=group_id,
group_name=group_name,
)
def resolve_initial_virtual_identity(
platform: Optional[str],
person_id: Optional[str],
group_name: Optional[str],
group_id: Optional[str],
) -> Optional[VirtualIdentityConfig]:
if not (platform and person_id):
return None
try:
person = get_person_by_person_id(person_id)
if person is None:
return None
virtual_group_id = group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{platform}_{person.user_id}"
virtual_config = build_virtual_identity_config(
person=person,
group_id=virtual_group_id,
group_name=group_name or "WebUI虚拟群聊",
)
logger.info(
f"虚拟身份模式已通过 URL 参数激活: {virtual_config.user_nickname} @ {virtual_config.platform}, group_id={virtual_group_id}"
)
return virtual_config
except Exception as e:
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
return None
def build_session_info_message(
session_id: str,
user_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> dict[str, Any]:
session_info_data: dict[str, Any] = {
"type": "session_info",
"session_id": session_id,
"user_id": user_id,
"user_name": user_name,
"bot_name": global_config.bot.nickname,
}
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
session_info_data["virtual_mode"] = True
session_info_data["group_id"] = virtual_config.group_id
session_info_data["virtual_identity"] = {
"platform": virtual_config.platform,
"user_id": virtual_config.user_id,
"user_nickname": virtual_config.user_nickname,
"group_name": virtual_config.group_name,
}
return session_info_data
def get_active_history_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> Optional[str]:
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return virtual_config.group_id
return None
def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> str:
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return (
f"已以 {virtual_config.user_nickname} 的身份连接到「{virtual_config.group_name}」,"
f"开始与 {global_config.bot.nickname} 对话吧!"
)
return f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!"
async def send_chat_error(session_id: str, content: str) -> None:
await chat_manager.send_message(
session_id,
{
"type": "error",
"content": content,
"timestamp": time.time(),
},
)
async def send_initial_chat_state(
session_id: str,
user_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> None:
await chat_manager.send_message(
session_id,
build_session_info_message(
session_id=session_id,
user_id=user_id,
user_name=user_name,
virtual_config=virtual_config,
),
)
if history := chat_history.get_history(50, get_active_history_group_id(virtual_config)):
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": history,
},
)
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": build_welcome_message(virtual_config),
"timestamp": time.time(),
},
)
def resolve_sender_identity(
current_user_name: str,
normalized_user_id: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> tuple[str, str]:
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id
return current_user_name, normalized_user_id
def create_message_data(
content: str,
user_id: str,
user_name: str,
message_id: Optional[str] = None,
is_at_bot: bool = True,
virtual_config: Optional[VirtualIdentityConfig] = None,
) -> dict[str, Any]:
if message_id is None:
message_id = str(uuid.uuid4())
if virtual_config and virtual_config.enabled:
platform = virtual_config.platform or WEBUI_CHAT_PLATFORM
group_id = virtual_config.group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{uuid.uuid4().hex[:8]}"
group_name = virtual_config.group_name or "WebUI虚拟群聊"
actual_user_id = virtual_config.user_id or user_id
actual_user_name = virtual_config.user_nickname or user_name
else:
platform = WEBUI_CHAT_PLATFORM
group_id = WEBUI_CHAT_GROUP_ID
group_name = "WebUI本地聊天室"
actual_user_id = user_id
actual_user_name = user_name
return {
"message_info": {
"platform": platform,
"message_id": message_id,
"time": time.time(),
"group_info": {
"group_id": group_id,
"group_name": group_name,
"platform": platform,
},
"user_info": {
"user_id": actual_user_id,
"user_nickname": actual_user_name,
"user_cardname": actual_user_name,
"platform": platform,
},
"additional_config": {
"at_bot": is_at_bot,
},
},
"message_segment": {
"type": "seglist",
"data": [
{
"type": "text",
"data": content,
},
{
"type": "mention_bot",
"data": "1.0",
},
],
},
"raw_message": content,
"processed_plain_text": content,
}
async def handle_chat_message(
session_id: str,
data: dict[str, Any],
current_user_name: str,
normalized_user_id: str,
current_virtual_config: Optional[VirtualIdentityConfig],
) -> str:
content = str(data.get("content", "")).strip()
if not content:
return current_user_name
next_user_name = str(data.get("user_name", current_user_name))
message_id = str(uuid.uuid4())
timestamp = time.time()
sender_name, sender_user_id = resolve_sender_identity(
current_user_name=next_user_name,
normalized_user_id=normalized_user_id,
virtual_config=current_virtual_config,
)
await chat_manager.broadcast(
{
"type": "user_message",
"content": content,
"message_id": message_id,
"timestamp": timestamp,
"sender": {
"name": sender_name,
"user_id": sender_user_id,
"is_bot": False,
},
"virtual_mode": is_virtual_mode_enabled(current_virtual_config),
}
)
message_data = create_message_data(
content=content,
user_id=normalized_user_id,
user_name=next_user_name,
message_id=message_id,
is_at_bot=True,
virtual_config=current_virtual_config,
)
try:
await chat_manager.broadcast({"type": "typing", "is_typing": True})
await chat_bot.message_process(message_data)
except Exception as e:
logger.error(f"处理消息时出错: {e}")
await send_chat_error(session_id, f"处理消息时出错: {str(e)}")
finally:
await chat_manager.broadcast({"type": "typing", "is_typing": False})
return next_user_name
async def handle_chat_ping(session_id: str) -> None:
await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()})
async def handle_nickname_update(session_id: str, data: dict[str, Any], current_user_name: str) -> str:
new_name = str(data.get("user_name", "")).strip()
if not new_name:
return current_user_name
await chat_manager.send_message(
session_id,
{
"type": "nickname_updated",
"user_name": new_name,
"timestamp": time.time(),
},
)
return new_name
async def enable_virtual_identity(
session_id: str,
session_prefix: str,
virtual_data: dict[str, Any],
) -> Optional[VirtualIdentityConfig]:
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id")
return None
person_id_value = str(virtual_data.get("person_id"))
try:
person = get_person_by_person_id(person_id_value)
if not person:
await send_chat_error(session_id, f"找不到用户: {person_id_value}")
return None
custom_group_id = virtual_data.get("group_id")
current_group_id = (
f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
if custom_group_id
else f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
)
current_virtual_config = build_virtual_identity_config(
person=person,
group_id=current_group_id,
group_name=str(virtual_data.get("group_name", "WebUI虚拟群聊")),
)
await chat_manager.send_message(
session_id,
{
"type": "virtual_identity_set",
"config": {
"enabled": True,
"platform": current_virtual_config.platform,
"user_id": current_virtual_config.user_id,
"user_nickname": current_virtual_config.user_nickname,
"group_id": current_virtual_config.group_id,
"group_name": current_virtual_config.group_name,
},
"timestamp": time.time(),
},
)
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": chat_history.get_history(50, current_virtual_config.group_id),
"group_id": current_virtual_config.group_id,
},
)
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": (
f"已切换到虚拟身份模式:以 {current_virtual_config.user_nickname} 的身份在"
f"{current_virtual_config.group_name}」与 {global_config.bot.nickname} 对话"
),
"timestamp": time.time(),
},
)
return current_virtual_config
except Exception as e:
logger.error(f"设置虚拟身份失败: {e}")
await send_chat_error(session_id, f"设置虚拟身份失败: {str(e)}")
return None
async def disable_virtual_identity(session_id: str) -> None:
await chat_manager.send_message(
session_id,
{
"type": "virtual_identity_set",
"config": {"enabled": False},
"timestamp": time.time(),
},
)
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": chat_history.get_history(50, WEBUI_CHAT_GROUP_ID),
"group_id": WEBUI_CHAT_GROUP_ID,
},
)
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": "已切换回 WebUI 独立用户模式",
"timestamp": time.time(),
},
)
async def handle_virtual_identity_update(
session_id: str,
session_id_prefix: str,
data: dict[str, Any],
current_virtual_config: Optional[VirtualIdentityConfig],
) -> Optional[VirtualIdentityConfig]:
virtual_data = cast(dict[str, Any], data.get("config", {}))
if virtual_data.get("enabled"):
next_config = await enable_virtual_identity(session_id, session_id_prefix, virtual_data)
return next_config if next_config is not None else current_virtual_config
await disable_virtual_identity(session_id)
return None
async def dispatch_chat_event(
session_id: str,
session_id_prefix: str,
data: dict[str, Any],
current_user_name: str,
normalized_user_id: str,
current_virtual_config: Optional[VirtualIdentityConfig],
) -> tuple[str, Optional[VirtualIdentityConfig]]:
event_type = data.get("type")
if event_type == "message":
next_user_name = await handle_chat_message(
session_id=session_id,
data=data,
current_user_name=current_user_name,
normalized_user_id=normalized_user_id,
current_virtual_config=current_virtual_config,
)
return next_user_name, current_virtual_config
if event_type == "ping":
await handle_chat_ping(session_id)
return current_user_name, current_virtual_config
if event_type == "update_nickname":
next_user_name = await handle_nickname_update(session_id, data, current_user_name)
return next_user_name, current_virtual_config
if event_type == "set_virtual_identity":
next_virtual_config = await handle_virtual_identity_update(
session_id=session_id,
session_id_prefix=session_id_prefix,
data=data,
current_virtual_config=current_virtual_config,
)
return current_user_name, next_virtual_config
return current_user_name, current_virtual_config

View File

@@ -4,12 +4,13 @@
import copy
import os
import tomlkit
from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header
from typing import Any, Annotated, Optional
import tomlkit
from fastapi import APIRouter, Body, Depends, HTTPException
from src.common.logger import get_logger
from src.webui.core import verify_auth_token_from_cookie_or_header
from src.webui.dependencies import require_auth
from src.webui.utils.toml_utils import save_toml_with_format, _update_toml_doc
from src.config.config import Config, ModelConfig, CONFIG_DIR, PROJECT_ROOT
from src.config.config_base import AttributeData
@@ -49,7 +50,7 @@ SectionBody = Annotated[Any, Body()]
RawContentBody = Annotated[str, Body(embed=True)]
PathBody = Annotated[dict[str, str], Body()]
router = APIRouter(prefix="/config", tags=["config"])
router = APIRouter(prefix="/config", tags=["config"], dependencies=[Depends(require_auth)])
def _toml_to_plain_dict(obj: Any) -> Any:
@@ -59,21 +60,11 @@ def _toml_to_plain_dict(obj: Any) -> Any:
if isinstance(obj, list):
return [_toml_to_plain_dict(v) for v in obj]
return obj
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# ===== 架构获取接口 =====
@router.get("/schema/bot")
async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
async def get_bot_config_schema():
"""获取麦麦主程序配置架构"""
try:
# Config 类包含所有子配置
@@ -85,7 +76,7 @@ async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
@router.get("/schema/model")
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
async def get_model_config_schema():
"""获取模型配置架构(包含提供商和模型任务配置)"""
try:
schema = ConfigSchemaGenerator.generate_config_schema(ModelConfig)
@@ -99,7 +90,7 @@ async def get_model_config_schema(_auth: bool = Depends(require_auth)):
@router.get("/schema/section/{section_name}")
async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)):
async def get_config_section_schema(section_name: str):
"""
获取指定配置节的架构
@@ -169,7 +160,7 @@ async def get_config_section_schema(section_name: str, _auth: bool = Depends(req
@router.get("/bot")
async def get_bot_config(_auth: bool = Depends(require_auth)):
async def get_bot_config():
"""获取麦麦主程序配置"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
@@ -188,7 +179,7 @@ async def get_bot_config(_auth: bool = Depends(require_auth)):
@router.get("/model")
async def get_model_config(_auth: bool = Depends(require_auth)):
async def get_model_config():
"""获取模型配置(包含提供商和模型任务配置)"""
try:
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
@@ -210,7 +201,7 @@ async def get_model_config(_auth: bool = Depends(require_auth)):
@router.post("/bot")
async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
async def update_bot_config(config_data: ConfigBody):
"""更新麦麦主程序配置"""
try:
# 验证配置数据
@@ -233,7 +224,7 @@ async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(requi
@router.post("/model")
async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
async def update_model_config(config_data: ConfigBody):
"""更新模型配置"""
try:
# 验证配置数据
@@ -259,7 +250,7 @@ async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(req
@router.post("/bot/section/{section_name}")
async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
async def update_bot_config_section(section_name: str, section_data: SectionBody):
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
@@ -308,7 +299,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
@router.get("/bot/raw")
async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
async def get_bot_config_raw():
"""获取麦麦主程序配置的原始 TOML 内容"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
@@ -327,7 +318,7 @@ async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
@router.post("/bot/raw")
async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)):
async def update_bot_config_raw(raw_content: RawContentBody):
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
try:
# 验证 TOML 格式
@@ -357,9 +348,7 @@ async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depen
@router.post("/model/section/{section_name}")
async def update_model_config_section(
section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)
):
async def update_model_config_section(section_name: str, section_data: SectionBody):
"""更新模型配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
@@ -451,7 +440,7 @@ def _to_relative_path(path: str) -> str:
@router.get("/adapter-config/path")
async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
async def get_adapter_config_path():
"""获取保存的适配器配置文件路径"""
try:
# 从 data/webui.json 读取路径偏好
@@ -490,7 +479,7 @@ async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
@router.post("/adapter-config/path")
async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)):
async def save_adapter_config_path(data: PathBody):
"""保存适配器配置文件路径偏好"""
try:
path = data.get("path")
@@ -533,7 +522,7 @@ async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require
@router.get("/adapter-config")
async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
async def get_adapter_config(path: str):
"""从指定路径读取适配器配置文件"""
try:
if not path:
@@ -565,7 +554,7 @@ async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
@router.post("/adapter-config")
async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)):
async def save_adapter_config(data: PathBody):
"""保存适配器配置到指定路径"""
try:
path = data.get("path")

View File

@@ -0,0 +1,3 @@
from .routes import router
__all__ = ["router"]

View File

@@ -0,0 +1,140 @@
from typing import Annotated, List, Optional
from fastapi import File, Form, UploadFile
from pydantic import BaseModel
from src.common.database.database_model import Images
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
DescriptionForm = Annotated[str, Form(description="表情包描述")]
EmotionForm = Annotated[str, Form(description="情感标签,多个用逗号分隔")]
IsRegisteredForm = Annotated[bool, Form(description="是否直接注册")]
class EmojiResponse(BaseModel):
"""表情包响应"""
id: int
full_path: str
emoji_hash: str
description: str
query_count: int
is_registered: bool
is_banned: bool
emotion: Optional[str]
record_time: float
register_time: Optional[float]
last_used_time: Optional[float]
class EmojiListResponse(BaseModel):
"""表情包列表响应"""
success: bool
total: int
page: int
page_size: int
data: List[EmojiResponse]
class EmojiDetailResponse(BaseModel):
"""表情包详情响应"""
success: bool
data: EmojiResponse
class EmojiUpdateRequest(BaseModel):
"""表情包更新请求"""
description: Optional[str] = None
is_registered: Optional[bool] = None
is_banned: Optional[bool] = None
emotion: Optional[str] = None
class EmojiUpdateResponse(BaseModel):
"""表情包更新响应"""
success: bool
message: str
data: Optional[EmojiResponse] = None
class EmojiDeleteResponse(BaseModel):
"""表情包删除响应"""
success: bool
message: str
class BatchDeleteRequest(BaseModel):
"""批量删除请求"""
emoji_ids: List[int]
class BatchDeleteResponse(BaseModel):
"""批量删除响应"""
success: bool
message: str
deleted_count: int
failed_count: int
failed_ids: List[int] = []
class EmojiUploadResponse(BaseModel):
"""表情包上传响应"""
success: bool
message: str
data: Optional[EmojiResponse] = None
class ThumbnailCacheStatsResponse(BaseModel):
"""缩略图缓存统计响应"""
success: bool
cache_dir: str
total_count: int
total_size_mb: float
emoji_count: int
coverage_percent: float
class ThumbnailCleanupResponse(BaseModel):
"""缩略图清理响应"""
success: bool
message: str
cleaned_count: int
kept_count: int
class ThumbnailPreheatResponse(BaseModel):
"""缩略图预热响应"""
success: bool
message: str
generated_count: int
skipped_count: int
failed_count: int
def emoji_to_response(image: Images) -> EmojiResponse:
"""将数据库表情包模型转换为响应对象。"""
return EmojiResponse(
id=image.id if image.id is not None else 0,
full_path=image.full_path,
emoji_hash=image.image_hash,
description=image.description,
query_count=image.query_count,
is_registered=image.is_registered,
is_banned=image.is_banned,
emotion=image.emotion,
record_time=image.record_time.timestamp() if image.record_time else 0.0,
register_time=image.register_time.timestamp() if image.register_time else None,
last_used_time=image.last_used_time.timestamp() if image.last_used_time else None,
)

View File

@@ -0,0 +1,142 @@
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import os
import threading
from PIL import Image
from sqlmodel import col, select
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
from src.common.logger import get_logger
logger = get_logger("webui.emoji")
THUMBNAIL_CACHE_DIR = Path("data/emoji_thumbnails")
THUMBNAIL_SIZE = (200, 200)
THUMBNAIL_QUALITY = 80
EMOJI_REGISTERED_DIR = os.path.join("data", "emoji_registed")
_thumbnail_locks: dict[str, threading.Lock] = {}
_locks_lock = threading.Lock()
_thumbnail_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="thumbnail")
_generating_thumbnails: set[str] = set()
_generating_lock = threading.Lock()
def get_thumbnail_executor() -> ThreadPoolExecutor:
"""获取缩略图生成线程池。"""
return _thumbnail_executor
def get_generating_lock() -> threading.Lock:
"""获取缩略图生成状态锁。"""
return _generating_lock
def get_generating_thumbnails() -> set[str]:
"""获取正在生成的缩略图哈希集合。"""
return _generating_thumbnails
def _get_thumbnail_lock(file_hash: str) -> threading.Lock:
"""获取指定文件哈希的锁,用于防止并发生成同一缩略图。"""
with _locks_lock:
if file_hash not in _thumbnail_locks:
_thumbnail_locks[file_hash] = threading.Lock()
return _thumbnail_locks[file_hash]
def _background_generate_thumbnail(source_path: str, file_hash: str) -> None:
"""在线程池中后台生成缩略图。"""
try:
_generate_thumbnail(source_path, file_hash)
except Exception as e:
logger.warning(f"后台生成缩略图失败 {file_hash}: {e}")
finally:
with _generating_lock:
_generating_thumbnails.discard(file_hash)
def ensure_thumbnail_cache_dir() -> Path:
"""确保缩略图缓存目录存在。"""
_ = THUMBNAIL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
return THUMBNAIL_CACHE_DIR
def get_thumbnail_cache_path(file_hash: str) -> Path:
"""获取缩略图缓存路径。"""
return THUMBNAIL_CACHE_DIR / f"{file_hash}.webp"
def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
"""生成缩略图并保存到缓存目录。"""
ensure_thumbnail_cache_dir()
cache_path = get_thumbnail_cache_path(file_hash)
lock = _get_thumbnail_lock(file_hash)
with lock:
if cache_path.exists():
return cache_path
try:
with Image.open(source_path) as img:
if getattr(img, "n_frames", 1) > 1:
img.seek(0)
if img.mode in ("P", "PA"):
img = img.convert("RGBA")
elif img.mode == "LA":
img = img.convert("RGBA")
elif img.mode not in ("RGB", "RGBA"):
img = img.convert("RGB")
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
img.save(cache_path, "WEBP", quality=THUMBNAIL_QUALITY, method=6)
logger.debug(f"生成缩略图: {file_hash} -> {cache_path}")
except Exception as e:
logger.warning(f"生成缩略图失败 {file_hash}: {e},将返回原图")
raise
return cache_path
def generate_thumbnail(source_path: str, file_hash: str) -> Path:
"""暴露给路由层的缩略图生成函数。"""
return _generate_thumbnail(source_path, file_hash)
def background_generate_thumbnail(source_path: str, file_hash: str) -> None:
"""暴露给路由层的后台缩略图生成函数。"""
_background_generate_thumbnail(source_path, file_hash)
def cleanup_orphaned_thumbnails() -> tuple[int, int]:
"""清理孤立的缩略图缓存。"""
if not THUMBNAIL_CACHE_DIR.exists():
return 0, 0
with get_db_session() as session:
statement = select(Images.image_hash).where(col(Images.image_type) == ImageType.EMOJI)
valid_hashes = set(session.exec(statement).all())
cleaned = 0
kept = 0
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
file_hash = cache_file.stem
if file_hash not in valid_hashes:
try:
cache_file.unlink()
cleaned += 1
logger.debug(f"清理孤立缩略图: {cache_file.name}")
except Exception as e:
logger.warning(f"清理缩略图失败 {cache_file.name}: {e}")
else:
kept += 1
if cleaned > 0:
logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept}")
return cleaned, kept

View File

@@ -1,9 +1,10 @@
"""表达方式管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel
from typing import Optional, List, Dict
from datetime import datetime, timedelta
from typing import Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import case, func
from sqlmodel import col, select, delete
@@ -12,12 +13,12 @@ from src.common.logger import get_logger
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.webui.core import verify_auth_token_from_cookie_or_header
from src.webui.dependencies import require_auth
logger = get_logger("webui.expression")
# 创建路由器
router = APIRouter(prefix="/expression", tags=["Expression"])
router = APIRouter(prefix="/expression", tags=["Expression"], dependencies=[Depends(require_auth)])
class ExpressionResponse(BaseModel):
@@ -90,14 +91,6 @@ class ExpressionCreateResponse(BaseModel):
data: ExpressionResponse
def verify_auth_token(
maibot_session: Optional[str] = None,
authorization: Optional[str] = None,
) -> bool:
"""验证认证 Token支持 Cookie 和 Header"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
def expression_to_response(expression: Expression) -> ExpressionResponse:
"""将 Expression 模型转换为响应对象"""
last_active_time = expression.last_active_time.timestamp() if expression.last_active_time else 0.0
@@ -156,19 +149,14 @@ class ChatListResponse(BaseModel):
@router.get("/chats", response_model=ChatListResponse)
async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_chat_list():
"""
获取所有聊天列表(用于下拉选择)
Args:
authorization: Authorization header
Returns:
聊天列表
"""
try:
verify_auth_token(maibot_session, authorization)
chat_list = []
for session_id, session in _chat_manager.sessions.items():
chat_name = _chat_manager.get_session_name(session_id) or session_id
@@ -199,8 +187,6 @@ async def get_expression_list(
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取表达方式列表
@@ -210,14 +196,11 @@ async def get_expression_list(
page_size: 每页数量 (1-100)
search: 搜索关键词 (匹配 situation, style)
chat_id: 聊天ID筛选
authorization: Authorization header
Returns:
表达方式列表
"""
try:
verify_auth_token(maibot_session, authorization)
# 构建查询
statement = select(Expression)
@@ -264,22 +247,17 @@ async def get_expression_list(
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
async def get_expression_detail(expression_id: int):
"""
获取表达方式详细信息
Args:
expression_id: 表达方式ID
authorization: Authorization header
Returns:
表达方式详细信息
"""
try:
verify_auth_token(maibot_session, authorization)
with get_db_session() as session:
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
expression = session.exec(statement).first()
@@ -299,22 +277,17 @@ async def get_expression_detail(
@router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(
request: ExpressionCreateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
创建新的表达方式
Args:
request: 创建请求
authorization: Authorization header
Returns:
创建结果
"""
try:
verify_auth_token(maibot_session, authorization)
current_time = datetime.now()
# 创建表达方式
@@ -349,8 +322,6 @@ async def create_expression(
async def update_expression(
expression_id: int,
request: ExpressionUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新表达方式(只更新提供的字段)
@@ -358,14 +329,11 @@ async def update_expression(
Args:
expression_id: 表达方式ID
request: 更新请求(只包含需要更新的字段)
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(maibot_session, authorization)
with get_db_session() as session:
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
expression = session.exec(statement).first()
@@ -411,22 +379,17 @@ async def update_expression(
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
async def delete_expression(expression_id: int):
"""
删除表达方式
Args:
expression_id: 表达方式ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(maibot_session, authorization)
with get_db_session() as session:
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
expression = session.exec(statement).first()
@@ -461,22 +424,17 @@ class BatchDeleteRequest(BaseModel):
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
async def batch_delete_expressions(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除表达方式
Args:
request: 包含要删除的ID列表的请求
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(maibot_session, authorization)
if not request.ids:
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
@@ -506,21 +464,14 @@ async def batch_delete_expressions(
@router.get("/stats/summary")
async def get_expression_stats(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
async def get_expression_stats():
"""
获取表达方式统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(maibot_session, authorization)
with get_db_session() as session:
total = len(session.exec(select(Expression.id)).all())
@@ -569,7 +520,7 @@ class ReviewStatsResponse(BaseModel):
@router.get("/review/stats", response_model=ReviewStatsResponse)
async def get_review_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_review_stats():
"""
获取审核统计数据
@@ -577,8 +528,6 @@ async def get_review_stats(maibot_session: Optional[str] = Cookie(None), authori
审核统计数据
"""
try:
verify_auth_token(maibot_session, authorization)
with get_db_session() as session:
total = len(session.exec(select(Expression.id)).all())
unchecked = 0
@@ -620,8 +569,6 @@ async def get_review_list(
filter_type: str = Query("unchecked", description="筛选类型: unchecked/passed/rejected/all"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取待审核/已审核的表达方式列表
@@ -637,8 +584,6 @@ async def get_review_list(
表达方式列表
"""
try:
verify_auth_token(maibot_session, authorization)
statement = select(Expression)
if filter_type in {"unchecked", "passed", "rejected"}:
@@ -728,8 +673,6 @@ class BatchReviewResponse(BaseModel):
@router.post("/review/batch", response_model=BatchReviewResponse)
async def batch_review_expressions(
request: BatchReviewRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量审核表达方式
@@ -741,8 +684,6 @@ async def batch_review_expressions(
批量审核结果
"""
try:
verify_auth_token(maibot_session, authorization)
if not request.items:
raise HTTPException(status_code=400, detail="未提供要审核的表达方式")

View File

@@ -1,20 +1,21 @@
"""黑话(俚语)管理路由"""
from typing import Annotated, Any, List, Optional
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import func as fn
from sqlmodel import Session, col, delete, select
import json
from typing import Annotated, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlmodel import Session, col, delete, select
from src.common.database.database import get_db_session
from src.common.database.database_model import ChatSession, Jargon
from src.common.logger import get_logger
from src.webui.dependencies import require_auth
logger = get_logger("webui.jargon")
router = APIRouter(prefix="/jargon", tags=["Jargon"])
router = APIRouter(prefix="/jargon", tags=["Jargon"], dependencies=[Depends(require_auth)])
# ==================== 辅助函数 ====================
@@ -33,14 +34,10 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
parsed = json.loads(chat_id_str)
if isinstance(parsed, list):
# 格式: [["stream_id", user_id], ...]
stream_ids = []
for item in parsed:
if isinstance(item, list) and len(item) >= 1:
stream_ids.append(str(item[0]))
return stream_ids
else:
# 其他格式,返回原始字符串
return [chat_id_str]
return [str(item[0]) for item in parsed if isinstance(item, list) and len(item) >= 1]
# 其他格式,返回原始字符串
return [chat_id_str]
except (json.JSONDecodeError, TypeError):
# 不是有效的 JSON可能是直接的 stream_id
return [chat_id_str]
@@ -57,9 +54,7 @@ def get_display_name_for_chat_id(chat_id_str: str, session: Session) -> str:
return chat_id_str[:20]
stream_id = stream_ids[0]
chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first()
if not chat_session:
if not (chat_session := session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first()):
return stream_id[:20]
if chat_session.group_id:
@@ -180,9 +175,59 @@ class ChatListResponse(BaseModel):
# ==================== 工具函数 ====================
def parse_session_id_dict(session_id_dict_str: Optional[str]) -> dict[str, int]:
"""解析会话计数字典。"""
if not session_id_dict_str:
return {}
try:
parsed = json.loads(session_id_dict_str)
except (json.JSONDecodeError, TypeError):
return {}
if not isinstance(parsed, dict):
return {}
session_counts: dict[str, int] = {}
for session_id, count in parsed.items():
if not isinstance(session_id, str):
continue
if isinstance(count, int):
session_counts[session_id] = count
else:
try:
session_counts[session_id] = int(count)
except (TypeError, ValueError):
session_counts[session_id] = 0
return session_counts
def dump_session_id_dict(session_counts: dict[str, int]) -> str:
"""序列化会话计数字典。"""
return json.dumps(session_counts, ensure_ascii=False)
def get_primary_chat_id(session_id_dict_str: Optional[str]) -> str:
"""从会话计数字典中选出主聊天 ID。"""
if not (session_counts := parse_session_id_dict(session_id_dict_str)):
return ""
return max(session_counts.items(), key=lambda item: item[1])[0]
def has_chat_id(session_id_dict_str: Optional[str], chat_id: str) -> bool:
"""判断记录是否包含指定聊天 ID。"""
return chat_id in parse_session_id_dict(session_id_dict_str)
def build_session_id_dict_for_chat(chat_id: str, count: int = 1) -> str:
"""为单个聊天 ID 构建会话计数字典。"""
return dump_session_id_dict({chat_id: count})
def jargon_to_dict(jargon: Jargon, session: Session) -> dict[str, Any]:
"""将 Jargon ORM 对象转换为字典"""
chat_id = jargon.session_id or ""
chat_id = get_primary_chat_id(jargon.session_id_dict)
chat_name = get_display_name_for_chat_id(chat_id, session) if chat_id else None
return {
@@ -191,7 +236,7 @@ def jargon_to_dict(jargon: Jargon, session: Session) -> dict[str, Any]:
"raw_content": jargon.raw_content,
"meaning": jargon.meaning,
"chat_id": chat_id,
"stream_id": jargon.session_id,
"stream_id": chat_id or None,
"chat_name": chat_name,
"count": jargon.count,
"is_jargon": jargon.is_jargon,
@@ -215,7 +260,6 @@ async def get_jargon_list(
"""获取黑话列表"""
try:
statement = select(Jargon)
count_statement = select(fn.count()).select_from(Jargon)
if search:
search_filter = (
@@ -224,28 +268,28 @@ async def get_jargon_list(
| (col(Jargon.raw_content).contains(search))
)
statement = statement.where(search_filter)
count_statement = count_statement.where(search_filter)
if chat_id:
stream_ids = parse_chat_id_to_stream_ids(chat_id)
if stream_ids:
chat_filter = col(Jargon.session_id).contains(stream_ids[0])
else:
chat_filter = col(Jargon.session_id) == chat_id
statement = statement.where(chat_filter)
count_statement = count_statement.where(chat_filter)
if is_jargon is not None:
statement = statement.where(col(Jargon.is_jargon) == is_jargon)
count_statement = count_statement.where(col(Jargon.is_jargon) == is_jargon)
statement = statement.order_by(col(Jargon.count).desc(), col(Jargon.id).desc())
statement = statement.offset((page - 1) * page_size).limit(page_size)
with get_db_session() as session:
total = session.exec(count_statement).one()
jargons = session.exec(statement).all()
data = [jargon_to_dict(jargon, session) for jargon in jargons]
if chat_id:
stream_ids = parse_chat_id_to_stream_ids(chat_id)
chat_ids = stream_ids or [chat_id]
jargons = [
jargon
for jargon in jargons
if any(has_chat_id(jargon.session_id_dict, current_chat_id) for current_chat_id in chat_ids)
]
total = len(jargons)
offset = (page - 1) * page_size
page_jargons = jargons[offset : offset + page_size]
data = [jargon_to_dict(jargon, session) for jargon in page_jargons]
return JargonListResponse(
success=True,
@@ -265,22 +309,16 @@ async def get_chat_list():
"""获取所有有黑话记录的聊天列表"""
try:
with get_db_session() as session:
statement = select(Jargon.session_id).distinct().where(col(Jargon.session_id).is_not(None))
chat_id_list = [chat_id for chat_id in session.exec(statement).all() if chat_id]
jargons = session.exec(select(Jargon)).all()
# 用于按 stream_id 去重
seen_stream_ids: set[str] = set()
for chat_id in chat_id_list:
stream_ids = parse_chat_id_to_stream_ids(chat_id)
if stream_ids:
seen_stream_ids.add(stream_ids[0])
for jargon in jargons:
seen_stream_ids.update(parse_session_id_dict(jargon.session_id_dict).keys())
result = []
with get_db_session() as session:
for stream_id in seen_stream_ids:
chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first()
if chat_session:
if chat_session := session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first():
chat_name = str(chat_session.group_id) if chat_session.group_id else stream_id[:20]
result.append(
ChatInfoResponse(
@@ -312,30 +350,21 @@ async def get_jargon_stats():
"""获取黑话统计数据"""
try:
with get_db_session() as session:
total = session.exec(select(fn.count()).select_from(Jargon)).one()
jargons = session.exec(select(Jargon)).all()
confirmed_jargon = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon))).one()
confirmed_not_jargon = session.exec(
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(False))
).one()
pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one()
total = len(jargons)
confirmed_jargon = sum(jargon.is_jargon is True for jargon in jargons)
confirmed_not_jargon = sum(jargon.is_jargon is False for jargon in jargons)
pending = sum(jargon.is_jargon is None for jargon in jargons)
complete_count = sum(jargon.is_complete for jargon in jargons)
complete_count = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete))).one()
top_chats_counter: dict[str, int] = {}
for jargon in jargons:
for session_id in parse_session_id_dict(jargon.session_id_dict):
top_chats_counter[session_id] = top_chats_counter.get(session_id, 0) + 1
chat_count = session.exec(
select(fn.count()).select_from(
select(col(Jargon.session_id)).distinct().where(col(Jargon.session_id).is_not(None)).subquery()
)
).one()
top_chats = session.exec(
select(col(Jargon.session_id), fn.count().label("count"))
.where(col(Jargon.session_id).is_not(None))
.group_by(col(Jargon.session_id))
.order_by(fn.count().desc())
.limit(5)
).all()
top_chats_dict = {session_id: count for session_id, count in top_chats if session_id}
top_chats_dict = dict(sorted(top_chats_counter.items(), key=lambda item: item[1], reverse=True)[:5])
chat_count = len(top_chats_counter)
return JargonStatsResponse(
success=True,
@@ -360,8 +389,7 @@ async def get_jargon_detail(jargon_id: int):
"""获取黑话详情"""
try:
with get_db_session() as session:
jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
if not jargon:
if not (jargon := session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()):
raise HTTPException(status_code=404, detail="黑话不存在")
data = JargonResponse(**jargon_to_dict(jargon, session))
@@ -379,19 +407,19 @@ async def create_jargon(request: JargonCreateRequest):
"""创建黑话"""
try:
with get_db_session() as session:
existing = session.exec(
select(Jargon).where(
(col(Jargon.content) == request.content) & (col(Jargon.session_id) == request.chat_id)
)
).first()
if existing:
same_content_jargons = session.exec(select(Jargon).where(col(Jargon.content) == request.content)).all()
existing = next(
(jargon for jargon in same_content_jargons if has_chat_id(jargon.session_id_dict, request.chat_id)),
None,
)
if existing is not None:
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
jargon = Jargon(
content=request.content,
raw_content=request.raw_content,
meaning=request.meaning or "",
session_id=request.chat_id,
session_id_dict=build_session_id_dict_for_chat(request.chat_id),
count=0,
is_jargon=None,
is_complete=False,
@@ -420,13 +448,12 @@ async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
update_data = request.model_dump(exclude_unset=True)
if update_data:
if update_data := request.model_dump(exclude_unset=True):
for field, value in update_data.items():
if field == "is_global":
continue
if field == "chat_id":
jargon.session_id = value
jargon.session_id_dict = build_session_id_dict_for_chat(value, max(jargon.count, 1))
continue
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
setattr(jargon, field, value)

View File

@@ -1,18 +1,28 @@
"""知识库图谱可视化 API 路由"""
from typing import List, Optional
from fastapi import APIRouter, Query, Depends, Cookie, Header
from typing import Any, List, Optional
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel
import logging
from src.webui.core import verify_auth_token_from_cookie_or_header
from src.config.config import global_config
from src.webui.dependencies import require_auth
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"], dependencies=[Depends(require_auth)])
# 延迟初始化的轻量级 embedding store只读仅用于获取段落完整文本
_paragraph_store_cache = None
_paragraph_store_cache: Any = None
def _get_embedding_dir() -> str:
"""获取 embedding 数据目录。"""
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
root_path = os.path.abspath(os.path.join(current_dir, "..", ".."))
return os.path.join(root_path, "data/embedding")
def _get_paragraph_store():
@@ -31,17 +41,11 @@ def _get_paragraph_store():
try:
from src.chat.knowledge.embedding_store import EmbeddingStore
import os
# 获取数据路径
current_dir = os.path.dirname(os.path.abspath(__file__))
root_path = os.path.abspath(os.path.join(current_dir, "..", ".."))
embedding_dir = os.path.join(root_path, "data/embedding")
# 只加载段落 embedding store轻量级
paragraph_store = EmbeddingStore(
namespace="paragraph",
dir_path=embedding_dir,
dir_path=_get_embedding_dir(),
max_workers=1, # 只读不需要多线程
chunk_size=100,
)
@@ -74,8 +78,7 @@ def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
paragraph_item = paragraph_store.store.get(node_id)
if paragraph_item is not None:
# paragraph_item 是 EmbeddingStoreItem其 str 属性包含完整文本
content: str = getattr(paragraph_item, "str", "")
if content:
if content := getattr(paragraph_item, "str", ""):
return content, True
return None, True
except Exception as e:
@@ -83,14 +86,6 @@ def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
return None, True
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class KnowledgeNode(BaseModel):
"""知识节点"""
@@ -205,7 +200,6 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
async def get_knowledge_graph(
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
_auth: bool = Depends(require_auth),
):
"""获取知识图谱(限制节点数量)
@@ -303,7 +297,7 @@ async def get_knowledge_graph(
@router.get("/stats", response_model=KnowledgeStats)
async def get_knowledge_stats(_auth: bool = Depends(require_auth)):
async def get_knowledge_stats():
"""获取知识库统计信息
Returns:
@@ -352,7 +346,7 @@ async def get_knowledge_stats(_auth: bool = Depends(require_auth)):
@router.get("/search", response_model=List[KnowledgeNode])
async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bool = Depends(require_auth)):
async def search_knowledge_node(query: str = Query(..., min_length=1)):
"""搜索知识节点
Args:

View File

@@ -6,27 +6,18 @@
import os
import httpx
from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header
from typing import Optional
import tomlkit
from fastapi import APIRouter, Depends, HTTPException, Query
from src.common.logger import get_logger
from src.config.config import CONFIG_DIR
from src.webui.core import verify_auth_token_from_cookie_or_header
from src.webui.dependencies import require_auth
logger = get_logger("webui")
router = APIRouter(prefix="/models", tags=["models"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
router = APIRouter(prefix="/models", tags=["models"], dependencies=[Depends(require_auth)])
# 模型获取器配置
MODEL_FETCHER_CONFIG = {
# OpenAI 兼容格式的提供商
@@ -44,9 +35,7 @@ MODEL_FETCHER_CONFIG = {
def _normalize_url(url: str) -> str:
"""规范化 URL去掉尾部斜杠"""
if not url:
return ""
return url.rstrip("/")
return url.rstrip("/") if url else ""
def _parse_openai_response(data: dict) -> list[dict]:
@@ -55,18 +44,18 @@ def _parse_openai_response(data: dict) -> list[dict]:
格式: { "data": [{ "id": "gpt-4", "object": "model", ... }] }
"""
models = []
if "data" in data and isinstance(data["data"], list):
for model in data["data"]:
if isinstance(model, dict) and "id" in model:
models.append(
{
"id": model["id"],
"name": model.get("name") or model["id"],
"owned_by": model.get("owned_by", ""),
}
)
return models
if "data" not in data or not isinstance(data["data"], list):
return []
return [
{
"id": model["id"],
"name": model.get("name") or model["id"],
"owned_by": model.get("owned_by", ""),
}
for model in data["data"]
if isinstance(model, dict) and "id" in model
]
def _parse_gemini_response(data: dict) -> list[dict]:
@@ -178,11 +167,8 @@ def _get_provider_config(provider_name: str) -> Optional[dict]:
config_data = tomlkit.load(f)
providers = config_data.get("api_providers", [])
for provider in providers:
if provider.get("name") == provider_name:
return dict(provider)
return None
provider = next((provider for provider in providers if provider.get("name") == provider_name), None)
return dict(provider) if provider is not None else None
except Exception as e:
logger.error(f"读取提供商配置失败: {e}")
return None
@@ -193,7 +179,6 @@ async def get_provider_models(
provider_name: str = Query(..., description="提供商名称"),
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
_auth: bool = Depends(require_auth),
):
"""
获取指定提供商的可用模型列表
@@ -238,7 +223,6 @@ async def get_models_by_url(
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
_auth: bool = Depends(require_auth),
):
"""
通过 URL 直接获取模型列表(用于自定义提供商)
@@ -262,7 +246,6 @@ async def get_models_by_url(
async def test_provider_connection(
base_url: str = Query(..., description="提供商的基础 URL"),
api_key: Optional[str] = Query(None, description="API Key可选用于验证 Key 有效性)"),
_auth: bool = Depends(require_auth),
):
"""
测试提供商连接状态
@@ -349,7 +332,6 @@ async def test_provider_connection(
@router.post("/test-connection-by-name")
async def test_provider_connection_by_name(
provider_name: str = Query(..., description="提供商名称"),
_auth: bool = Depends(require_auth),
):
"""
通过提供商名称测试连接(从配置文件读取信息)
@@ -364,11 +346,7 @@ async def test_provider_connection_by_name(
# 查找提供商
providers = config.get("api_providers", [])
provider = None
for p in providers:
if p.get("name") == provider_name:
provider = p
break
provider = next((item for item in providers if item.get("name") == provider_name), None)
if not provider:
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
@@ -380,4 +358,4 @@ async def test_provider_connection_by_name(
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
# 调用测试接口
return await test_provider_connection(base_url=base_url, api_key=api_key if api_key else None)
return await test_provider_connection(base_url=base_url, api_key=api_key or None)

View File

@@ -1,23 +1,24 @@
"""人物信息管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel
from typing import Optional, List, Dict
from datetime import datetime
import json
from typing import Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import case
from sqlmodel import col, select, delete
from src.common.logger import get_logger
from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
from src.webui.core import verify_auth_token_from_cookie_or_header
import json
from src.common.logger import get_logger
from src.webui.dependencies import require_auth
logger = get_logger("webui.person")
# 创建路由器
router = APIRouter(prefix="/person", tags=["Person"])
router = APIRouter(prefix="/person", tags=["Person"], dependencies=[Depends(require_auth)])
class PersonInfoResponse(BaseModel):
@@ -96,14 +97,6 @@ class BatchDeleteResponse(BaseModel):
failed_ids: List[str] = []
def verify_auth_token(
maibot_session: Optional[str] = None,
authorization: Optional[str] = None,
) -> bool:
"""验证认证 Token支持 Cookie 和 Header"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
def parse_group_nick_name(group_nick_name_str: Optional[str]) -> Optional[List[Dict[str, str]]]:
"""解析群昵称 JSON 字符串"""
if not group_nick_name_str:
@@ -127,7 +120,7 @@ def person_to_response(person: PersonInfo) -> PersonInfoResponse:
platform=person.platform,
user_id=person.user_id,
nickname=person.user_nickname,
group_nick_name=parse_group_nick_name(person.group_nickname),
group_nick_name=parse_group_nick_name(person.group_cardname),
memory_points=person.memory_points,
know_times=person.know_counts,
know_since=know_since,
@@ -142,8 +135,6 @@ async def get_person_list(
search: Optional[str] = Query(None, description="搜索关键词"),
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
platform: Optional[str] = Query(None, description="平台筛选"),
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取人物信息列表
@@ -154,14 +145,11 @@ async def get_person_list(
search: 搜索关键词 (匹配 person_name, nickname, user_id)
is_known: 是否已认识筛选
platform: 平台筛选
authorization: Authorization header
Returns:
人物信息列表
"""
try:
verify_auth_token(maibot_session, authorization)
# 构建查询
statement = select(PersonInfo)
@@ -219,22 +207,17 @@ async def get_person_list(
@router.get("/{person_id}", response_model=PersonDetailResponse)
async def get_person_detail(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
async def get_person_detail(person_id: str):
"""
获取人物详细信息
Args:
person_id: 人物唯一 ID
authorization: Authorization header
Returns:
人物详细信息
"""
try:
verify_auth_token(maibot_session, authorization)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
person = session.exec(statement).first()
@@ -255,8 +238,6 @@ async def get_person_detail(
async def update_person(
person_id: str,
request: PersonUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新人物信息(只更新提供的字段)
@@ -264,14 +245,11 @@ async def update_person(
Args:
person_id: 人物唯一 ID
request: 更新请求(只包含需要更新的字段)
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(maibot_session, authorization)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
person = session.exec(statement).first()
@@ -313,22 +291,17 @@ async def update_person(
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
async def delete_person(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
async def delete_person(person_id: str):
"""
删除人物信息
Args:
person_id: 人物唯一 ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(maibot_session, authorization)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
person = session.exec(statement).first()
@@ -355,19 +328,14 @@ async def delete_person(
@router.get("/stats/summary")
async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_person_stats():
"""
获取人物信息统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(maibot_session, authorization)
with get_db_session() as session:
total = len(session.exec(select(PersonInfo.id)).all())
known = len(session.exec(select(PersonInfo.id).where(col(PersonInfo.is_known))).all())
@@ -392,22 +360,17 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori
@router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_persons(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除人物信息
Args:
request: 包含person_ids列表的请求
authorization: Authorization header
Returns:
批量删除结果
"""
try:
verify_auth_token(maibot_session, authorization)
if not request.person_ids:
raise HTTPException(status_code=400, detail="未提供要删除的人物ID")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,17 @@
from fastapi import APIRouter
from src.webui.services.git_mirror_service import set_update_progress_callback
from .catalog import router as catalog_router
from .config_routes import router as config_router
from .management import router as management_router
from .progress import get_progress_router, update_progress
router = APIRouter(prefix="/plugins", tags=["插件管理"])
router.include_router(catalog_router)
router.include_router(management_router)
router.include_router(config_router)
set_update_progress_callback(update_progress)
__all__ = ["get_progress_router", "router"]

View File

@@ -0,0 +1,205 @@
from typing import Any, Optional
import json
from fastapi import APIRouter, Cookie, HTTPException
from src.common.logger import get_logger
from src.config.config import MMC_VERSION
from src.webui.services.git_mirror_service import get_git_mirror_service
from .progress import update_progress
from .schemas import (
AddMirrorRequest,
AvailableMirrorsResponse,
CloneRepositoryRequest,
CloneRepositoryResponse,
FetchRawFileRequest,
FetchRawFileResponse,
GitStatusResponse,
MirrorConfigResponse,
UpdateMirrorRequest,
VersionResponse,
)
from .support import get_plugins_dir, parse_version, require_plugin_token, validate_safe_path
logger = get_logger("webui.plugin_routes")
router = APIRouter()
def _mirror_to_response(mirror: dict[str, Any]) -> MirrorConfigResponse:
return MirrorConfigResponse(
id=mirror["id"],
name=mirror["name"],
raw_prefix=mirror["raw_prefix"],
clone_prefix=mirror["clone_prefix"],
enabled=mirror["enabled"],
priority=mirror["priority"],
)
@router.get("/version", response_model=VersionResponse)
async def get_maimai_version() -> VersionResponse:
major, minor, patch = parse_version(MMC_VERSION)
return VersionResponse(version=MMC_VERSION, version_major=major, version_minor=minor, version_patch=patch)
@router.get("/git-status", response_model=GitStatusResponse)
async def check_git_status() -> GitStatusResponse:
service = get_git_mirror_service()
return GitStatusResponse(**service.check_git_installed())
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None)) -> AvailableMirrorsResponse:
require_plugin_token(maibot_session)
service = get_git_mirror_service()
config = service.get_mirror_config()
mirrors = [_mirror_to_response(mirror) for mirror in config.get_all_mirrors()]
return AvailableMirrorsResponse(mirrors=mirrors, default_priority=config.get_default_priority_list())
@router.post("/mirrors", response_model=MirrorConfigResponse)
async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None)) -> MirrorConfigResponse:
require_plugin_token(maibot_session)
try:
service = get_git_mirror_service()
config = service.get_mirror_config()
mirror = config.add_mirror(
mirror_id=request.id,
name=request.name,
raw_prefix=request.raw_prefix,
clone_prefix=request.clone_prefix,
enabled=request.enabled,
priority=request.priority,
)
return _mirror_to_response(mirror)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
except Exception as e:
logger.error(f"添加镜像源失败: {e}")
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
async def update_mirror(
mirror_id: str,
request: UpdateMirrorRequest,
maibot_session: Optional[str] = Cookie(None),
) -> MirrorConfigResponse:
require_plugin_token(maibot_session)
try:
service = get_git_mirror_service()
config = service.get_mirror_config()
mirror = config.update_mirror(
mirror_id=mirror_id,
name=request.name,
raw_prefix=request.raw_prefix,
clone_prefix=request.clone_prefix,
enabled=request.enabled,
priority=request.priority,
)
if mirror is None:
raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}")
return _mirror_to_response(mirror)
except HTTPException:
raise
except Exception as e:
logger.error(f"更新镜像源失败: {e}")
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.delete("/mirrors/{mirror_id}")
async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
require_plugin_token(maibot_session)
service = get_git_mirror_service()
config = service.get_mirror_config()
if not config.delete_mirror(mirror_id):
raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}")
return {"success": True, "message": f"已删除镜像源: {mirror_id}"}
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
async def fetch_raw_file(
request: FetchRawFileRequest,
maibot_session: Optional[str] = Cookie(None),
) -> FetchRawFileResponse:
require_plugin_token(maibot_session)
logger.info(f"收到获取 Raw 文件请求: {request.owner}/{request.repo}/{request.branch}/{request.file_path}")
await update_progress(
stage="loading",
progress=10,
message=f"正在获取插件列表: {request.file_path}",
total_plugins=0,
loaded_plugins=0,
)
try:
service = get_git_mirror_service()
result = await service.fetch_raw_file(
owner=request.owner,
repo=request.repo,
branch=request.branch,
file_path=request.file_path,
mirror_id=request.mirror_id,
custom_url=request.custom_url,
)
if result.get("success"):
await update_progress(
stage="loading",
progress=70,
message="正在解析插件数据...",
total_plugins=0,
loaded_plugins=0,
)
try:
data = json.loads(result.get("data", "[]"))
total = len(data) if isinstance(data, list) else 0
await update_progress(
stage="success",
progress=100,
message=f"成功加载 {total} 个插件",
total_plugins=total,
loaded_plugins=total,
)
except Exception:
await update_progress(stage="success", progress=100, message="加载完成", total_plugins=0, loaded_plugins=0)
return FetchRawFileResponse(**result)
except Exception as e:
logger.error(f"获取 Raw 文件失败: {e}")
await update_progress(stage="error", progress=0, message="加载失败", error=str(e), total_plugins=0, loaded_plugins=0)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.post("/clone", response_model=CloneRepositoryResponse)
async def clone_repository(
request: CloneRepositoryRequest,
maibot_session: Optional[str] = Cookie(None),
) -> CloneRepositoryResponse:
require_plugin_token(maibot_session)
logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}")
try:
target_path = validate_safe_path(request.target_path, get_plugins_dir())
service = get_git_mirror_service()
result = await service.clone_repository(
owner=request.owner,
repo=request.repo,
target_path=target_path,
branch=request.branch,
mirror_id=request.mirror_id,
custom_url=request.custom_url,
depth=request.depth,
)
return CloneRepositoryResponse(**result)
except Exception as e:
logger.error(f"克隆仓库失败: {e}")
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e

View File

@@ -0,0 +1,333 @@
from typing import Any, Optional, cast
import json
import tomlkit
from fastapi import APIRouter, Cookie, HTTPException
from src.common.logger import get_logger
from src.webui.utils.toml_utils import save_toml_with_format
from .schemas import UpdatePluginConfigRequest, UpdatePluginRawConfigRequest
from .support import (
backup_file,
coerce_types,
find_plugin_instance,
find_plugin_path_by_id,
normalize_dotted_keys,
require_plugin_token,
)
logger = get_logger("webui.plugin_routes")
router = APIRouter()
def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> dict[str, Any]:
schema: dict[str, Any] = {
"plugin_id": plugin_id,
"plugin_info": {
"name": plugin_id,
"version": "",
"description": "",
"author": "",
},
"sections": {},
"layout": {"type": "auto", "tabs": []},
"_note": "插件未加载,仅返回当前配置结构",
}
for section_name, section_data in current_config.items():
if not isinstance(section_data, dict):
continue
section_fields: dict[str, Any] = {}
for field_name, field_value in section_data.items():
field_type = type(field_value).__name__
ui_type = "text"
item_type = None
item_fields = None
if isinstance(field_value, bool):
ui_type = "switch"
elif isinstance(field_value, (int, float)):
ui_type = "number"
elif isinstance(field_value, list):
ui_type = "list"
if field_value:
first_item = field_value[0]
if isinstance(first_item, dict):
item_type = "object"
item_fields = {
key: {
"type": "number" if isinstance(value, (int, float)) else "string",
"label": key,
"default": "" if isinstance(value, str) else 0,
}
for key, value in first_item.items()
}
elif isinstance(first_item, (int, float)):
item_type = "number"
else:
item_type = "string"
else:
item_type = "string"
elif isinstance(field_value, dict):
ui_type = "json"
section_fields[field_name] = {
"name": field_name,
"type": field_type,
"default": field_value,
"description": field_name,
"label": field_name,
"ui_type": ui_type,
"required": False,
"hidden": False,
"disabled": False,
"order": 0,
"item_type": item_type,
"item_fields": item_fields,
"min_items": None,
"max_items": None,
"placeholder": None,
"hint": None,
"icon": None,
"example": None,
"choices": None,
"min": None,
"max": None,
"step": None,
"pattern": None,
"max_length": None,
"input_type": None,
"rows": 3,
"group": None,
"depends_on": None,
"depends_value": None,
}
schema["sections"][section_name] = {
"name": section_name,
"title": section_name,
"description": None,
"icon": None,
"collapsed": False,
"order": 0,
"fields": section_fields,
}
return schema
@router.get("/config/{plugin_id}/schema")
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"获取插件配置 Schema: {plugin_id}")
try:
plugin_instance = find_plugin_instance(plugin_id)
if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"):
return {"success": True, "schema": plugin_instance.get_webui_config_schema()}
plugin_path = find_plugin_path_by_id(plugin_id)
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
schema_json_path = plugin_path / "config_schema.json"
if schema_json_path.exists():
try:
with open(schema_json_path, "r", encoding="utf-8") as file_obj:
return {"success": True, "schema": json.load(file_obj)}
except Exception as e:
logger.warning(f"读取 config_schema.json 失败,回退到自动推断: {e}")
current_config: Any = {}
config_path = plugin_path / "config.toml"
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as file_obj:
current_config = tomlkit.load(file_obj)
return {"success": True, "schema": _build_schema_from_current_config(plugin_id, current_config)}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取插件配置 Schema 失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.get("/config/{plugin_id}/raw")
async def get_plugin_config_raw(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"获取插件原始配置: {plugin_id}")
try:
plugin_path = find_plugin_path_by_id(plugin_id)
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
if not config_path.exists():
return {"success": True, "config": "", "message": "配置文件不存在"}
with open(config_path, "r", encoding="utf-8") as file_obj:
return {"success": True, "config": file_obj.read()}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取插件原始配置失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.put("/config/{plugin_id}/raw")
async def update_plugin_config_raw(
plugin_id: str,
request: UpdatePluginRawConfigRequest,
maibot_session: Optional[str] = Cookie(None),
) -> dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"更新插件原始配置: {plugin_id}")
try:
plugin_path = find_plugin_path_by_id(plugin_id)
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
try:
tomlkit.loads(request.config)
except Exception as e:
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
backup_path = backup_file(config_path, "backup")
if backup_path is not None:
logger.info(f"已备份配置文件: {backup_path}")
with open(config_path, "w", encoding="utf-8") as file_obj:
file_obj.write(request.config)
logger.info(f"已更新插件原始配置: {plugin_id}")
return {"success": True, "message": "配置已保存", "note": "配置更改将在插件重新加载后生效"}
except HTTPException:
raise
except Exception as e:
logger.error(f"更新插件原始配置失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.get("/config/{plugin_id}")
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"获取插件配置: {plugin_id}")
try:
plugin_path = find_plugin_path_by_id(plugin_id)
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
if not config_path.exists():
return {"success": True, "config": {}, "message": "配置文件不存在"}
with open(config_path, "r", encoding="utf-8") as file_obj:
config = tomlkit.load(file_obj)
return {"success": True, "config": dict(config)}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取插件配置失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.put("/config/{plugin_id}")
async def update_plugin_config(
plugin_id: str,
request: UpdatePluginConfigRequest,
maibot_session: Optional[str] = Cookie(None),
) -> dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"更新插件配置: {plugin_id}")
try:
plugin_instance = find_plugin_instance(plugin_id)
config_data = request.config or {}
if plugin_instance and isinstance(config_data, dict):
config_data = normalize_dotted_keys(config_data)
if isinstance(plugin_instance.config_schema, dict):
coerce_types(plugin_instance.config_schema, config_data)
plugin_path = find_plugin_path_by_id(plugin_id)
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
backup_path = backup_file(config_path, "backup")
if backup_path is not None:
logger.info(f"已备份配置文件: {backup_path}")
save_toml_with_format(config_data, str(config_path))
logger.info(f"已更新插件配置: {plugin_id}")
return {"success": True, "message": "配置已保存", "note": "配置更改将在插件重新加载后生效"}
except HTTPException:
raise
except Exception as e:
logger.error(f"更新插件配置失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.post("/config/{plugin_id}/reset")
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"重置插件配置: {plugin_id}")
try:
plugin_path = find_plugin_path_by_id(plugin_id)
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
if not config_path.exists():
return {"success": True, "message": "配置文件不存在,无需重置"}
backup_path = backup_file(config_path, "reset", move_file=True)
logger.info(f"已重置插件配置: {plugin_id},备份: {backup_path}")
return {"success": True, "message": "配置已重置,下次加载插件时将使用默认配置", "backup": str(backup_path)}
except HTTPException:
raise
except Exception as e:
logger.error(f"重置插件配置失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.post("/config/{plugin_id}/toggle")
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"切换插件状态: {plugin_id}")
try:
plugin_path = find_plugin_path_by_id(plugin_id)
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
config = tomlkit.document()
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as file_obj:
config = tomlkit.load(file_obj)
if "plugin" not in config:
config["plugin"] = tomlkit.table()
plugin_config = cast(Any, config["plugin"])
current_enabled = bool(plugin_config.get("enabled", True))
new_enabled = not current_enabled
plugin_config["enabled"] = new_enabled
save_toml_with_format(config, str(config_path))
status = "启用" if new_enabled else "禁用"
logger.info(f"{status}插件: {plugin_id}")
return {"success": True, "enabled": new_enabled, "message": f"插件已{status}", "note": "状态更改将在下次加载插件时生效"}
except HTTPException:
raise
except Exception as e:
logger.error(f"切换插件状态失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e

View File

@@ -0,0 +1,302 @@
from pathlib import Path
from typing import Any, Optional
import json
from fastapi import APIRouter, Cookie, HTTPException
from src.common.logger import get_logger
from src.webui.services.git_mirror_service import get_git_mirror_service
from .progress import update_progress
from .schemas import InstallPluginRequest, UninstallPluginRequest, UpdatePluginRequest
from .support import (
find_plugin_path_by_id,
get_plugin_candidate_paths,
get_plugins_dir,
load_manifest_json,
parse_repository_url,
remove_tree,
require_plugin_token,
resolve_installed_plugin_path,
validate_plugin_id,
)
logger = get_logger("webui.plugin_routes")
router = APIRouter()
def _infer_plugin_id(folder_name: str, manifest: dict[str, Any], manifest_path: Path) -> str:
if "id" in manifest:
return str(manifest["id"])
author_name: Optional[str] = None
repo_name: Optional[str] = None
if "author" in manifest:
author_data = manifest["author"]
if isinstance(author_data, dict) and "name" in author_data:
author_name = str(author_data["name"])
elif isinstance(author_data, str):
author_name = author_data
if "repository_url" in manifest:
repo_url = str(manifest["repository_url"]).rstrip("/").removesuffix(".git")
repo_name = repo_url.split("/")[-1]
if author_name and repo_name:
plugin_id = f"{author_name}.{repo_name}"
elif author_name:
plugin_id = f"{author_name}.{folder_name}"
elif "_" in folder_name and "." not in folder_name:
plugin_id = folder_name.replace("_", ".", 1)
else:
plugin_id = folder_name
logger.info(f"为插件 {folder_name} 自动生成 ID: {plugin_id}")
manifest["id"] = plugin_id
try:
with open(manifest_path, "w", encoding="utf-8") as file_obj:
json.dump(manifest, file_obj, ensure_ascii=False, indent=2)
except Exception as write_error:
logger.warning(f"无法写入 ID 到 manifest: {write_error}")
return plugin_id
@router.post("/install")
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"收到安装插件请求: {request.plugin_id}")
plugin_id = request.plugin_id
try:
plugin_id = validate_plugin_id(request.plugin_id)
await update_progress(stage="loading", progress=5, message=f"开始安装插件: {plugin_id}", operation="install", plugin_id=plugin_id)
repo_url, owner, repo = parse_repository_url(request.repository_url)
await update_progress(stage="loading", progress=10, message=f"解析仓库信息: {owner}/{repo}", operation="install", plugin_id=plugin_id)
target_path, old_format_path = get_plugin_candidate_paths(plugin_id)
if target_path.exists() or old_format_path.exists():
await update_progress(stage="error", progress=0, message="插件已存在", operation="install", plugin_id=plugin_id, error="插件已安装,请先卸载")
raise HTTPException(status_code=400, detail="插件已安装")
await update_progress(stage="loading", progress=15, message=f"准备克隆到: {target_path}", operation="install", plugin_id=plugin_id)
service = get_git_mirror_service()
if "github.com" in repo_url:
result = await service.clone_repository(owner=owner, repo=repo, target_path=target_path, branch=request.branch, mirror_id=request.mirror_id, depth=1)
else:
result = await service.clone_repository(owner=owner, repo=repo, target_path=target_path, branch=request.branch, custom_url=repo_url, depth=1)
if not result.get("success"):
error_msg = str(result.get("error", "克隆失败"))
await update_progress(stage="error", progress=0, message="克隆仓库失败", operation="install", plugin_id=plugin_id, error=error_msg)
raise HTTPException(status_code=500, detail=error_msg)
await update_progress(stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id)
manifest_path = target_path / "_manifest.json"
if not manifest_path.exists():
remove_tree(target_path)
await update_progress(stage="error", progress=0, message="插件缺少 _manifest.json", operation="install", plugin_id=plugin_id, error="无效的插件格式")
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
await update_progress(stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id)
try:
with open(manifest_path, "r", encoding="utf-8") as file_obj:
manifest = json.load(file_obj)
for field in ["manifest_version", "name", "version", "author"]:
if field not in manifest:
raise ValueError(f"缺少必需字段: {field}")
manifest["id"] = plugin_id
with open(manifest_path, "w", encoding="utf-8") as file_obj:
json.dump(manifest, file_obj, ensure_ascii=False, indent=2)
except Exception as e:
remove_tree(target_path)
await update_progress(stage="error", progress=0, message="_manifest.json 无效", operation="install", plugin_id=plugin_id, error=str(e))
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
await update_progress(stage="success", progress=100, message=f"成功安装插件: {manifest['name']} v{manifest['version']}", operation="install", plugin_id=plugin_id)
return {"success": True, "message": "插件安装成功", "plugin_id": plugin_id, "plugin_name": manifest["name"], "version": manifest["version"], "path": str(target_path)}
except HTTPException:
raise
except Exception as e:
logger.error(f"安装插件失败: {e}", exc_info=True)
await update_progress(stage="error", progress=0, message="安装失败", operation="install", plugin_id=plugin_id, error=str(e))
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.post("/uninstall")
async def uninstall_plugin(request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"收到卸载插件请求: {request.plugin_id}")
plugin_id = request.plugin_id
try:
plugin_id = validate_plugin_id(request.plugin_id)
await update_progress(stage="loading", progress=10, message=f"开始卸载插件: {plugin_id}", operation="uninstall", plugin_id=plugin_id)
plugin_path = resolve_installed_plugin_path(plugin_id)
if plugin_path is None:
await update_progress(stage="error", progress=0, message="插件不存在", operation="uninstall", plugin_id=plugin_id, error="插件未安装或已被删除")
raise HTTPException(status_code=404, detail="插件未安装")
await update_progress(stage="loading", progress=30, message=f"正在删除插件文件: {plugin_path}", operation="uninstall", plugin_id=plugin_id)
manifest = load_manifest_json(plugin_path / "_manifest.json")
plugin_name = str(manifest.get("name", plugin_id)) if manifest is not None else plugin_id
await update_progress(stage="loading", progress=50, message=f"正在删除 {plugin_name}...", operation="uninstall", plugin_id=plugin_id)
remove_tree(plugin_path)
logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})")
await update_progress(stage="success", progress=100, message=f"成功卸载插件: {plugin_name}", operation="uninstall", plugin_id=plugin_id)
return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name}
except HTTPException:
raise
except PermissionError as e:
logger.error(f"卸载插件失败(权限错误): {e}")
await update_progress(stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error="权限不足,无法删除插件文件")
raise HTTPException(status_code=500, detail="权限不足,无法删除插件文件") from e
except Exception as e:
logger.error(f"卸载插件失败: {e}", exc_info=True)
await update_progress(stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error=str(e))
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.post("/update")
async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"收到更新插件请求: {request.plugin_id}")
plugin_id = request.plugin_id
try:
plugin_id = validate_plugin_id(request.plugin_id)
await update_progress(stage="loading", progress=5, message=f"开始更新插件: {plugin_id}", operation="update", plugin_id=plugin_id)
plugin_path = resolve_installed_plugin_path(plugin_id)
if plugin_path is None:
await update_progress(stage="error", progress=0, message="插件不存在", operation="update", plugin_id=plugin_id, error="插件未安装,请先安装")
raise HTTPException(status_code=404, detail="插件未安装")
manifest = load_manifest_json(plugin_path / "_manifest.json")
old_version = str(manifest.get("version", "unknown")) if manifest is not None else "unknown"
await update_progress(stage="loading", progress=10, message=f"当前版本: {old_version},准备更新...", operation="update", plugin_id=plugin_id)
await update_progress(stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id)
remove_tree(plugin_path)
await update_progress(stage="loading", progress=30, message="正在准备下载新版本...", operation="update", plugin_id=plugin_id)
repo_url, owner, repo = parse_repository_url(request.repository_url)
service = get_git_mirror_service()
if "github.com" in repo_url:
result = await service.clone_repository(owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, mirror_id=request.mirror_id, depth=1)
else:
result = await service.clone_repository(owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, custom_url=repo_url, depth=1)
if not result.get("success"):
error_msg = str(result.get("error", "克隆失败"))
await update_progress(stage="error", progress=0, message="下载新版本失败", operation="update", plugin_id=plugin_id, error=error_msg)
raise HTTPException(status_code=500, detail=error_msg)
await update_progress(stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id)
new_manifest_path = plugin_path / "_manifest.json"
if not new_manifest_path.exists():
remove_tree(plugin_path)
await update_progress(stage="error", progress=0, message="新版本缺少 _manifest.json", operation="update", plugin_id=plugin_id, error="无效的插件格式")
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
try:
with open(new_manifest_path, "r", encoding="utf-8") as file_obj:
new_manifest = json.load(file_obj)
new_version = str(new_manifest.get("version", "unknown"))
new_name = str(new_manifest.get("name", plugin_id))
logger.info(f"成功更新插件: {plugin_id} {old_version}{new_version}")
await update_progress(stage="success", progress=100, message=f"成功更新 {new_name}: {old_version}{new_version}", operation="update", plugin_id=plugin_id)
return {"success": True, "message": "插件更新成功", "plugin_id": plugin_id, "plugin_name": new_name, "old_version": old_version, "new_version": new_version}
except Exception as e:
remove_tree(plugin_path)
await update_progress(stage="error", progress=0, message="_manifest.json 无效", operation="update", plugin_id=plugin_id, error=str(e))
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
except HTTPException:
raise
except Exception as e:
logger.error(f"更新插件失败: {e}", exc_info=True)
await update_progress(stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e))
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.get("/installed")
async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
require_plugin_token(maibot_session)
logger.info("收到获取已安装插件列表请求")
try:
plugins_dir = get_plugins_dir()
installed_plugins: list[dict[str, Any]] = []
for plugin_path in plugins_dir.iterdir():
if not plugin_path.is_dir():
continue
folder_name = plugin_path.name
if folder_name.startswith(".") or folder_name.startswith("__"):
continue
manifest_path = plugin_path / "_manifest.json"
if not manifest_path.exists():
logger.warning(f"插件文件夹 {folder_name} 缺少 _manifest.json跳过")
continue
try:
with open(manifest_path, "r", encoding="utf-8") as file_obj:
manifest = json.load(file_obj)
if "name" not in manifest or "version" not in manifest:
logger.warning(f"插件文件夹 {folder_name} 的 _manifest.json 格式无效,跳过")
continue
plugin_id = _infer_plugin_id(folder_name, manifest, manifest_path)
installed_plugins.append({"id": plugin_id, "manifest": manifest, "path": str(plugin_path.absolute())})
except json.JSONDecodeError as e:
logger.warning(f"插件 {folder_name} 的 _manifest.json 解析失败: {e}")
except Exception as e:
logger.error(f"读取插件 {folder_name} 信息时出错: {e}")
seen_ids: dict[str, str] = {}
unique_plugins: list[dict[str, Any]] = []
duplicates: list[dict[str, Any]] = []
for plugin in installed_plugins:
plugin_id = str(plugin["id"])
plugin_path = str(plugin["path"])
if plugin_id not in seen_ids:
seen_ids[plugin_id] = plugin_path
unique_plugins.append(plugin)
else:
duplicates.append(plugin)
logger.warning(f"重复插件 {plugin_id}: 保留 {seen_ids[plugin_id]}, 跳过 {plugin_path}")
if duplicates:
logger.warning(f"共检测到 {len(duplicates)} 个重复插件已去重")
logger.info(f"找到 {len(unique_plugins)} 个已安装插件")
return {"success": True, "plugins": unique_plugins, "total": len(unique_plugins)}
except Exception as e:
logger.error(f"获取已安装插件列表失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.get("/local-readme/{plugin_id}")
async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"获取本地插件 README: {plugin_id}")
try:
plugin_path = find_plugin_path_by_id(plugin_id)
if plugin_path is None:
return {"success": False, "error": "插件未安装"}
for readme_name in ["README.md", "readme.md", "Readme.md", "README.MD"]:
readme_path = plugin_path / readme_name
if readme_path.exists():
try:
with open(readme_path, "r", encoding="utf-8") as file_obj:
readme_content = file_obj.read()
logger.info(f"成功读取本地 README: {readme_path}")
return {"success": True, "data": readme_content}
except Exception as e:
logger.warning(f"读取 {readme_path} 失败: {e}")
return {"success": False, "error": "本地未找到 README 文件"}
except Exception as e:
logger.error(f"获取本地 README 失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}

View File

@@ -1,36 +1,32 @@
"""WebSocket 插件加载进度推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Set, Dict, Any, Optional
import json
import asyncio
import json
from typing import Any, Optional, Set
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.plugin_progress")
# 创建路由器
router = APIRouter()
# 全局 WebSocket 连接池
active_connections: Set[WebSocket] = set()
# 当前加载进度状态
current_progress: Dict[str, Any] = {
"operation": "idle", # idle, fetch, install, uninstall, update
"stage": "idle", # idle, loading, success, error
"progress": 0, # 0-100
current_progress: dict[str, Any] = {
"operation": "idle",
"stage": "idle",
"progress": 0,
"message": "",
"error": None,
"plugin_id": None, # 当前操作的插件 ID
"plugin_id": None,
"total_plugins": 0,
"loaded_plugins": 0,
}
async def broadcast_progress(progress_data: Dict[str, Any]):
"""广播进度更新到所有连接的客户端"""
async def broadcast_progress(progress_data: dict[str, Any]) -> None:
global current_progress
current_progress = progress_data.copy()
@@ -38,7 +34,7 @@ async def broadcast_progress(progress_data: Dict[str, Any]):
return
message = json.dumps(progress_data, ensure_ascii=False)
disconnected = set()
disconnected: set[WebSocket] = set()
for websocket in active_connections:
try:
@@ -47,7 +43,6 @@ async def broadcast_progress(progress_data: Dict[str, Any]):
logger.error(f"发送进度更新失败: {e}")
disconnected.add(websocket)
# 移除断开的连接
for websocket in disconnected:
active_connections.discard(websocket)
@@ -57,23 +52,11 @@ async def update_progress(
progress: int,
message: str,
operation: str = "fetch",
error: str = None,
plugin_id: str = None,
error: Optional[str] = None,
plugin_id: Optional[str] = None,
total_plugins: int = 0,
loaded_plugins: int = 0,
):
"""更新并广播进度
Args:
stage: 阶段 (idle, loading, success, error)
progress: 进度百分比 (0-100)
message: 当前消息
operation: 操作类型 (fetch, install, uninstall, update)
error: 错误信息可选
plugin_id: 当前操作的插件 ID
total_plugins: 总插件数
loaded_plugins: 已加载插件数
"""
) -> None:
progress_data = {
"operation": operation,
"stage": stage,
@@ -91,25 +74,13 @@ async def update_progress(
@router.websocket("/ws/plugin-progress")
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)):
"""WebSocket 插件加载进度推送端点
客户端连接后会立即收到当前进度状态
支持三种认证方式按优先级
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/ws/plugin-progress?token=xxx
"""
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
@@ -118,7 +89,6 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
@@ -135,18 +105,13 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
try:
# 发送当前进度状态
await websocket.send_text(json.dumps(current_progress, ensure_ascii=False))
# 保持连接并处理客户端消息
while True:
try:
data = await websocket.receive_text()
# 处理客户端心跳
if data == "ping":
await websocket.send_text("pong")
except Exception as e:
logger.error(f"处理客户端消息时出错: {e}")
break
@@ -160,5 +125,4 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
def get_progress_router() -> APIRouter:
"""获取插件进度 WebSocket 路由器"""
return router
return router

View File

@@ -0,0 +1,113 @@
from typing import Any, Optional
from pydantic import BaseModel, Field
class FetchRawFileRequest(BaseModel):
owner: str = Field(..., description="仓库所有者", examples=["MaiM-with-u"])
repo: str = Field(..., description="仓库名称", examples=["plugin-repo"])
branch: str = Field(..., description="分支名称", examples=["main"])
file_path: str = Field(..., description="文件路径", examples=["plugin_details.json"])
mirror_id: Optional[str] = Field(None, description="指定镜像源 ID")
custom_url: Optional[str] = Field(None, description="自定义完整 URL")
class FetchRawFileResponse(BaseModel):
success: bool = Field(..., description="是否成功")
data: Optional[str] = Field(None, description="文件内容")
error: Optional[str] = Field(None, description="错误信息")
mirror_used: Optional[str] = Field(None, description="使用的镜像源")
attempts: int = Field(..., description="尝试次数")
url: Optional[str] = Field(None, description="实际请求的 URL")
class CloneRepositoryRequest(BaseModel):
owner: str = Field(..., description="仓库所有者", examples=["MaiM-with-u"])
repo: str = Field(..., description="仓库名称", examples=["plugin-repo"])
target_path: str = Field(..., description="目标路径(相对于插件目录)")
branch: Optional[str] = Field(None, description="分支名称", examples=["main"])
mirror_id: Optional[str] = Field(None, description="指定镜像源 ID")
custom_url: Optional[str] = Field(None, description="自定义克隆 URL")
depth: Optional[int] = Field(None, description="克隆深度(浅克隆)", ge=1)
class CloneRepositoryResponse(BaseModel):
success: bool = Field(..., description="是否成功")
path: Optional[str] = Field(None, description="克隆路径")
error: Optional[str] = Field(None, description="错误信息")
mirror_used: Optional[str] = Field(None, description="使用的镜像源")
attempts: int = Field(..., description="尝试次数")
url: Optional[str] = Field(None, description="实际克隆的 URL")
message: Optional[str] = Field(None, description="附加信息")
class MirrorConfigResponse(BaseModel):
id: str = Field(..., description="镜像源 ID")
name: str = Field(..., description="镜像源名称")
raw_prefix: str = Field(..., description="Raw 文件前缀")
clone_prefix: str = Field(..., description="克隆前缀")
enabled: bool = Field(..., description="是否启用")
priority: int = Field(..., description="优先级(数字越小优先级越高)")
class AvailableMirrorsResponse(BaseModel):
mirrors: list[MirrorConfigResponse] = Field(..., description="镜像源列表")
default_priority: list[str] = Field(..., description="默认优先级顺序ID 列表)")
class AddMirrorRequest(BaseModel):
id: str = Field(..., description="镜像源 ID", examples=["custom-mirror"])
name: str = Field(..., description="镜像源名称", examples=["自定义镜像源"])
raw_prefix: str = Field(..., description="Raw 文件前缀", examples=["https://example.com/raw"])
clone_prefix: str = Field(..., description="克隆前缀", examples=["https://example.com/clone"])
enabled: bool = Field(True, description="是否启用")
priority: Optional[int] = Field(None, description="优先级")
class UpdateMirrorRequest(BaseModel):
name: Optional[str] = Field(None, description="镜像源名称")
raw_prefix: Optional[str] = Field(None, description="Raw 文件前缀")
clone_prefix: Optional[str] = Field(None, description="克隆前缀")
enabled: Optional[bool] = Field(None, description="是否启用")
priority: Optional[int] = Field(None, description="优先级")
class GitStatusResponse(BaseModel):
installed: bool = Field(..., description="是否已安装 Git")
version: Optional[str] = Field(None, description="Git 版本号")
path: Optional[str] = Field(None, description="Git 可执行文件路径")
error: Optional[str] = Field(None, description="错误信息")
class InstallPluginRequest(BaseModel):
plugin_id: str = Field(..., description="插件 ID")
repository_url: str = Field(..., description="插件仓库 URL")
branch: Optional[str] = Field("main", description="分支名称")
mirror_id: Optional[str] = Field(None, description="指定镜像源 ID")
class VersionResponse(BaseModel):
version: str = Field(..., description="麦麦版本号")
version_major: int = Field(..., description="主版本号")
version_minor: int = Field(..., description="次版本号")
version_patch: int = Field(..., description="补丁版本号")
class UninstallPluginRequest(BaseModel):
plugin_id: str = Field(..., description="插件 ID")
class UpdatePluginRequest(BaseModel):
plugin_id: str = Field(..., description="插件 ID")
repository_url: str = Field(..., description="插件仓库 URL")
branch: Optional[str] = Field("main", description="分支名称")
mirror_id: Optional[str] = Field(None, description="指定镜像源 ID")
class UpdatePluginConfigRequest(BaseModel):
enabled: Optional[bool] = None
config: Optional[dict[str, Any]] = None
class UpdatePluginRawConfigRequest(BaseModel):
config: str = Field(..., description="原始 TOML 配置内容")

View File

@@ -0,0 +1,221 @@
from datetime import datetime
from pathlib import Path
from typing import Any, Optional, cast, get_origin
import json
import os
import re
import shutil
import stat
from fastapi import HTTPException
from src.common.logger import get_logger
from src.core.config_types import ConfigField
from src.webui.core import get_token_manager
logger = get_logger("webui.plugin_routes")
def require_plugin_token(maibot_session: Optional[str]) -> str:
token_manager = get_token_manager()
if not maibot_session or not token_manager.verify_token(maibot_session):
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
return maibot_session
def validate_safe_path(user_path: str, base_path: Path) -> Path:
base_resolved = base_path.resolve()
if any(pattern in user_path for pattern in ["..", "\x00"]):
logger.warning(f"检测到可疑路径: {user_path}")
raise HTTPException(status_code=400, detail="路径包含非法字符")
if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"):
logger.warning(f"检测到绝对路径: {user_path}")
raise HTTPException(status_code=400, detail="不允许使用绝对路径")
target_path = (base_path / user_path).resolve()
try:
target_path.relative_to(base_resolved)
except ValueError as e:
logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}")
raise HTTPException(status_code=400, detail="路径超出允许范围") from e
return target_path
def validate_plugin_id(plugin_id: str) -> str:
if not plugin_id or not plugin_id.strip():
logger.warning("非法插件 ID: 空字符串")
raise HTTPException(status_code=400, detail="插件 ID 不能为空")
for pattern in ["/", "\\", "\x00", "..", "\n", "\r", "\t"]:
if pattern in plugin_id:
logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)")
raise HTTPException(status_code=400, detail="插件 ID 包含非法字符")
if plugin_id.startswith(".") or plugin_id.endswith("."):
logger.warning(f"非法插件 ID: {plugin_id}")
raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾")
if plugin_id in {".", ".."}:
logger.warning(f"非法插件 ID: {plugin_id}")
raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名")
return plugin_id
def parse_version(version_str: str) -> tuple[int, int, int]:
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
parts = base_version.split(".")
if len(parts) < 3:
parts.extend(["0"] * (3 - len(parts)))
try:
return int(parts[0]), int(parts[1]), int(parts[2])
except (ValueError, IndexError):
logger.warning(f"无法解析版本号: {version_str},返回默认值 (0, 0, 0)")
return 0, 0, 0
def deep_merge(dst: dict[str, Any], src: dict[str, Any]) -> None:
for key, value in src.items():
if key in dst and isinstance(dst[key], dict) and isinstance(value, dict):
deep_merge(dst[key], value)
else:
dst[key] = value
def normalize_dotted_keys(obj: dict[str, Any]) -> dict[str, Any]:
result: dict[str, Any] = {}
dotted_items: list[tuple[str, Any]] = []
for key, value in obj.items():
if "." in key:
dotted_items.append((key, value))
else:
result[key] = normalize_dotted_keys(value) if isinstance(value, dict) else value
for dotted_key, value in dotted_items:
normalized_value = normalize_dotted_keys(value) if isinstance(value, dict) else value
parts = dotted_key.split(".")
if "" in parts:
logger.warning(f"键路径包含空段: '{dotted_key}'")
parts = [part for part in parts if part]
if not parts:
logger.warning(f"忽略空键路径: '{dotted_key}'")
continue
current = result
for index, part in enumerate(parts[:-1]):
if part in current and not isinstance(current[part], dict):
path_ctx = ".".join(parts[: index + 1])
logger.warning(f"键冲突:{part} 已存在且非字典,覆盖为字典以展开 {dotted_key} (路径 {path_ctx})")
current[part] = {}
current = current.setdefault(part, {})
last_part = parts[-1]
if last_part in current and isinstance(current[last_part], dict) and isinstance(normalized_value, dict):
deep_merge(current[last_part], normalized_value)
else:
current[last_part] = normalized_value
return result
def coerce_types(schema_part: dict[str, Any], config_part: dict[str, Any]) -> None:
def is_list_type(tp: Any) -> bool:
origin = get_origin(tp)
return tp is list or origin is list
for key, schema_val in schema_part.items():
if key not in config_part:
continue
value = config_part[key]
if isinstance(schema_val, ConfigField):
if is_list_type(schema_val.type) and isinstance(value, str):
config_part[key] = [item.strip() for item in value.split(",") if item.strip()]
elif isinstance(schema_val, dict) and isinstance(value, dict):
coerce_types(schema_val, value)
def find_plugin_instance(plugin_id: str) -> Optional[Any]:
from src.plugin_runtime.integration import get_plugin_runtime_manager
manager = get_plugin_runtime_manager()
for supervisor in manager.supervisors:
registered = supervisor._registered_plugins.get(plugin_id)
if registered is not None:
return registered
return None
def get_plugins_dir() -> Path:
plugins_dir = Path("plugins").resolve()
plugins_dir.mkdir(exist_ok=True)
return plugins_dir
def get_plugin_candidate_paths(plugin_id: str) -> tuple[Path, Path]:
plugins_dir = get_plugins_dir()
folder_name = plugin_id.replace(".", "_")
return validate_safe_path(folder_name, plugins_dir), validate_safe_path(plugin_id, plugins_dir)
def resolve_installed_plugin_path(plugin_id: str) -> Optional[Path]:
new_format_path, old_format_path = get_plugin_candidate_paths(plugin_id)
if new_format_path.exists():
return new_format_path
return old_format_path if old_format_path.exists() else None
def parse_repository_url(repository_url: str) -> tuple[str, str, str]:
repo_url = repository_url.rstrip("/").removesuffix(".git")
parts = repo_url.split("/")
if len(parts) < 2:
raise HTTPException(status_code=400, detail="无效的仓库 URL")
return repo_url, parts[-2], parts[-1]
def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]:
if not manifest_path.exists():
return None
try:
with open(manifest_path, "r", encoding="utf-8") as file_obj:
return cast(dict[str, Any], json.load(file_obj))
except Exception:
return None
def iter_plugin_directories() -> list[Path]:
return [path for path in get_plugins_dir().iterdir() if path.is_dir()]
def find_plugin_path_by_id(plugin_id: str) -> Optional[Path]:
for plugin_path in iter_plugin_directories():
manifest = load_manifest_json(plugin_path / "_manifest.json")
if manifest is not None and (manifest.get("id") == plugin_id or plugin_path.name == plugin_id):
return plugin_path
return None
def backup_file(file_path: Path, action: str, move_file: bool = False) -> Optional[Path]:
if not file_path.exists():
return None
backup_name = f"{file_path.name}.{action}.{datetime.now().strftime('%Y%m%d%H%M%S')}"
backup_path = file_path.parent / backup_name
if move_file:
shutil.move(file_path, backup_path)
else:
shutil.copy(file_path, backup_path)
return backup_path
def remove_tree(path: Path) -> None:
def remove_readonly(func: Any, target_path: str, _: Any) -> None:
os.chmod(target_path, stat.S_IWRITE)
func(target_path)
shutil.rmtree(path, onerror=remove_readonly)

View File

@@ -1,29 +1,22 @@
"""统计数据 API 路由"""
from datetime import datetime, timedelta
from typing import Any, Optional
from typing import Any
from fastapi import APIRouter, Cookie, Depends, Header, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import desc, func, or_
from sqlmodel import col, select
from src.common.database.database import get_db_session
from src.common.database.database_model import Messages, ModelUsage, OnlineTime
from src.common.database.database_model import ModelUsage, OnlineTime
from src.common.logger import get_logger
from src.webui.core import verify_auth_token_from_cookie_or_header
from src.common.message_repository import count_messages
from src.webui.dependencies import require_auth
logger = get_logger("webui.statistics")
router = APIRouter(prefix="/statistics", tags=["statistics"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
router = APIRouter(prefix="/statistics", tags=["statistics"], dependencies=[Depends(require_auth)])
class StatisticsSummary(BaseModel):
@@ -70,7 +63,7 @@ class DashboardData(BaseModel):
@router.get("/dashboard", response_model=DashboardData)
async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth)):
async def get_dashboard_data(hours: int = 24):
"""
获取仪表盘统计数据
@@ -159,24 +152,12 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
if end > start:
summary.online_time += (end - start).total_seconds()
# 查询消息数量 - 使用聚合优化
with get_db_session() as session:
statement = select(func.count()).where(
col(Messages.timestamp) >= start_time,
col(Messages.timestamp) <= end_time,
)
total_messages = session.exec(statement).one()
summary.total_messages = int(total_messages or 0)
# 统计回复数量
with get_db_session() as session:
statement = select(func.count()).where(
col(Messages.timestamp) >= start_time,
col(Messages.timestamp) <= end_time,
col(Messages.reply_to).is_not(None),
)
total_replies = session.exec(statement).one()
summary.total_replies = int(total_replies or 0)
summary.total_messages = count_messages(start_time=start_time.timestamp(), end_time=end_time.timestamp())
summary.total_replies = count_messages(
start_time=start_time.timestamp(),
end_time=end_time.timestamp(),
has_reply_to=True,
)
# 计算派生指标
if summary.online_time > 0:
@@ -351,7 +332,7 @@ async def _get_recent_activity(limit: int = 10) -> list[dict[str, Any]]:
@router.get("/summary")
async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)):
async def get_summary(hours: int = 24):
"""
获取统计摘要
@@ -369,7 +350,7 @@ async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)):
@router.get("/models")
async def get_model_stats(hours: int = 24, _auth: bool = Depends(require_auth)):
async def get_model_stats(hours: int = 24):
"""
获取模型统计

View File

@@ -7,28 +7,20 @@
import os
import time
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from src.config.config import MMC_VERSION
from src.common.logger import get_logger
from src.webui.core import verify_auth_token_from_cookie_or_header
from src.webui.dependencies import require_auth
router = APIRouter(prefix="/system", tags=["system"])
router = APIRouter(prefix="/system", tags=["system"], dependencies=[Depends(require_auth)])
logger = get_logger("webui_system")
# 记录启动时间
_start_time = time.time()
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class RestartResponse(BaseModel):
"""重启响应"""
@@ -46,7 +38,7 @@ class StatusResponse(BaseModel):
@router.post("/restart", response_model=RestartResponse)
async def restart_maibot(_auth: bool = Depends(require_auth)):
async def restart_maibot():
"""
重启麦麦主程序
@@ -77,7 +69,7 @@ async def restart_maibot(_auth: bool = Depends(require_auth)):
@router.get("/status", response_model=StatusResponse)
async def get_maibot_status(_auth: bool = Depends(require_auth)):
async def get_maibot_status():
"""
获取麦麦运行状态
@@ -100,7 +92,7 @@ async def get_maibot_status(_auth: bool = Depends(require_auth)):
@router.post("/reload-config")
async def reload_config(_auth: bool = Depends(require_auth)):
async def reload_config():
"""
热重载配置(不重启进程)

View File

@@ -1,9 +1,7 @@
from .logs import router as logs_router
from .plugin_progress import get_progress_router
from .auth import router as ws_auth_router
__all__ = [
"logs_router",
"get_progress_router",
"ws_auth_router",
]

View File

@@ -1,11 +1,8 @@
"""WebSocket 认证模块
"""WebSocket 认证模块"""
提供所有 WebSocket 端点统一使用的临时 token 认证机制。
临时 token 有效期 60 秒,且只能使用一次,用于解决 WebSocket 握手时 Cookie 不可用的问题。
"""
from fastapi import APIRouter, Cookie, Header
from typing import Optional
from fastapi import APIRouter, Cookie
import secrets
import time
from src.common.logger import get_logger
@@ -77,25 +74,17 @@ def verify_ws_token(temp_token: str) -> bool:
@router.get("/ws-token")
async def get_ws_token(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取 WebSocket 连接用的临时 token
此端点验证当前会话 Cookie 或 Authorization header
此端点验证当前会话 Cookie
然后返回一个临时 token 用于 WebSocket 握手认证。
临时 token 有效期 60 秒,且只能使用一次。
注意:在未认证时返回 200 状态码但 success=False避免前端因 401 刷新页面。
"""
# 获取当前 session token
session_token = None
if maibot_session:
session_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
session_token = authorization.replace("Bearer ", "")
if not session_token:
if not maibot_session:
# 返回 200 但 success=False避免前端因 401 刷新页面
# 这在登录页面是正常情况,不应该触发错误处理
logger.debug("ws-token 请求:未提供认证信息(可能在登录页面)")
@@ -103,12 +92,12 @@ async def get_ws_token(
# 验证 session token
token_manager = get_token_manager()
if not token_manager.verify_token(session_token):
if not token_manager.verify_token(maibot_session):
# 同样返回 200 但 success=False避免前端刷新
logger.debug("ws-token 请求:认证已过期")
return {"success": False, "message": "认证已过期,请重新登录", "token": None, "expires_in": 0}
# 生成临时 WebSocket token
ws_token = generate_ws_token(session_token)
ws_token = generate_ws_token(maibot_session)
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}

View File

@@ -1,177 +1,11 @@
"""WebSocket 日志推送模块"""
"""WebSocket 日志推送路由兼容导出。"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Set, Optional
import json
from pathlib import Path
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
from src.webui.logs_ws import active_connections, broadcast_log, load_recent_logs, router, websocket_logs
logger = get_logger("webui.logs_ws")
router = APIRouter()
# 全局 WebSocket 连接池
active_connections: Set[WebSocket] = set()
def load_recent_logs(limit: int = 100) -> list[dict]:
"""从日志文件中加载最近的日志
Args:
limit: 返回的最大日志条数
Returns:
日志列表
"""
logs = []
log_dir = Path("logs")
if not log_dir.exists():
return logs
# 获取所有日志文件,按修改时间排序
log_files = sorted(log_dir.glob("app_*.log.jsonl"), key=lambda f: f.stat().st_mtime, reverse=True)
# 用于生成唯一 ID 的计数器
log_counter = 0
# 从最新的文件开始读取
for log_file in log_files:
if len(logs) >= limit:
break
try:
with open(log_file, "r", encoding="utf-8") as f:
lines = f.readlines()
# 从文件末尾开始读取
for line in reversed(lines):
if len(logs) >= limit:
break
try:
log_entry = json.loads(line.strip())
# 转换为前端期望的格式
# 使用时间戳 + 计数器生成唯一 ID
timestamp_id = (
log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
)
formatted_log = {
"id": f"{timestamp_id}_{log_counter}",
"timestamp": log_entry.get("timestamp", ""),
"level": log_entry.get("level", "INFO").upper(),
"module": log_entry.get("logger_name", ""),
"message": log_entry.get("event", ""),
}
logs.append(formatted_log)
log_counter += 1
except (json.JSONDecodeError, KeyError):
continue
except Exception as e:
logger.error(f"读取日志文件失败 {log_file}: {e}")
continue
# 反转列表,使其按时间顺序排列(旧到新)
return list(reversed(logs))
@router.websocket("/ws/logs")
async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None)):
"""WebSocket 日志推送端点
客户端连接后会持续接收服务器端的日志消息
支持三种认证方式(按优先级):
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/ws/logs?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
# 连接建立后,立即发送历史日志
try:
recent_logs = load_recent_logs(limit=100)
logger.info(f"发送 {len(recent_logs)} 条历史日志到客户端")
for log_entry in recent_logs:
await websocket.send_text(json.dumps(log_entry, ensure_ascii=False))
except Exception as e:
logger.error(f"发送历史日志失败: {e}")
try:
# 保持连接,等待客户端消息或断开
while True:
# 接收客户端消息(用于心跳或控制指令)
data = await websocket.receive_text()
# 可以处理客户端的控制消息,例如:
# - "ping" -> 心跳检测
# - {"filter": "ERROR"} -> 设置日志级别过滤
if data == "ping":
await websocket.send_text("pong")
except WebSocketDisconnect:
active_connections.discard(websocket)
logger.info(f"📡 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
except Exception as e:
logger.error(f"❌ WebSocket 错误: {e}")
active_connections.discard(websocket)
async def broadcast_log(log_data: dict):
"""广播日志到所有连接的 WebSocket 客户端
Args:
log_data: 日志数据字典
"""
if not active_connections:
return
# 格式化为 JSON
message = json.dumps(log_data, ensure_ascii=False)
# 记录需要断开的连接
disconnected = set()
# 广播到所有客户端
for connection in active_connections:
try:
await connection.send_text(message)
except Exception:
# 发送失败,标记为断开
disconnected.add(connection)
# 清理断开的连接
if disconnected:
active_connections.difference_update(disconnected)
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")
__all__ = [
"active_connections",
"broadcast_log",
"load_recent_logs",
"router",
"websocket_logs",
]

View File

@@ -1,28 +1,26 @@
"""WebUI API 路由"""
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie, Depends
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from pydantic import BaseModel, Field
from typing import Optional
from src.common.logger import get_logger
from src.webui.core import (
clear_auth_cookie,
check_auth_rate_limit,
get_rate_limiter,
get_token_manager,
set_auth_cookie,
clear_auth_cookie,
get_rate_limiter,
check_auth_rate_limit,
)
from src.webui.dependencies import require_auth, verify_token_optional
from src.webui.routers.config import router as config_router
from src.webui.routers.statistics import router as statistics_router
from src.webui.routers.person import router as person_router
from src.webui.routers.expression import router as expression_router
from src.webui.routers.jargon import router as jargon_router
from src.webui.routers.emoji import router as emoji_router
from src.webui.routers.plugin import router as plugin_router
from src.webui.routers.websocket.plugin_progress import get_progress_router
from src.webui.routers.plugin import get_progress_router, router as plugin_router
from src.webui.routers.system import router as system_router
from src.webui.routers.model import router as model_router
from src.webui.routers.websocket.auth import router as ws_auth_router
from src.webui.routers.annual_report import router as annual_report_router
logger = get_logger("webui.api")
@@ -51,8 +49,6 @@ router.include_router(system_router)
router.include_router(model_router)
# 注册 WebSocket 认证路由
router.include_router(ws_auth_router)
# 注册年度报告路由
router.include_router(annual_report_router)
class TokenVerifyRequest(BaseModel):
@@ -190,9 +186,7 @@ async def logout(response: Response):
@router.get("/auth/check")
async def check_auth_status(
request: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
authenticated: bool = Depends(verify_token_optional),
):
"""
检查当前认证状态(用于前端判断是否已登录)
@@ -201,46 +195,17 @@ async def check_auth_status(
认证状态
"""
try:
token = None
# 记录请求信息用于调试
logger.debug(
f"检查认证状态 - Cookie: {maibot_session[:20] if maibot_session else 'None'}..., Authorization: {'Present' if authorization else 'None'}"
)
# 优先从 Cookie 获取
if maibot_session:
token = maibot_session
logger.debug("使用 Cookie 中的 token")
# 其次从 Header 获取
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
logger.debug("使用 Header 中的 token")
if not token:
logger.debug("未找到 token返回未认证")
return {"authenticated": False}
token_manager = get_token_manager()
is_valid = token_manager.verify_token(token)
logger.debug(f"Token 验证结果: {is_valid}")
if is_valid:
return {"authenticated": True}
else:
return {"authenticated": False}
logger.debug(f"检查认证状态,结果: {authenticated}")
return {"authenticated": authenticated}
except Exception as e:
logger.error(f"认证检查失败: {e}", exc_info=True)
return {"authenticated": False}
@router.post("/auth/update", response_model=TokenUpdateResponse)
@router.post("/auth/update", response_model=TokenUpdateResponse, dependencies=[Depends(require_auth)])
async def update_token(
request: TokenUpdateRequest,
response: Response,
req: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
更新访问令牌(需要当前有效的 token
@@ -248,28 +213,13 @@ async def update_token(
Args:
request: 包含新 token 的更新请求
response: FastAPI Response 对象
maibot_session: Cookie 中的 token
authorization: Authorization header (Bearer token)
Returns:
更新结果
"""
try:
# 验证当前 token优先 Cookie其次 Header
current_token = None
if maibot_session:
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="当前 Token 无效")
# 更新 token
success, message = token_manager.update_token(request.new_token)
@@ -285,40 +235,22 @@ async def update_token(
raise HTTPException(status_code=500, detail="Token 更新失败") from e
@router.post("/auth/regenerate", response_model=TokenRegenerateResponse)
@router.post("/auth/regenerate", response_model=TokenRegenerateResponse, dependencies=[Depends(require_auth)])
async def regenerate_token(
response: Response,
request: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
重新生成访问令牌(需要当前有效的 token
Args:
response: FastAPI Response 对象
maibot_session: Cookie 中的 token
authorization: Authorization header (Bearer token)
Returns:
新生成的 token
"""
try:
# 验证当前 token优先 Cookie其次 Header
current_token = None
if maibot_session:
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="当前 Token 无效")
# 重新生成 token
new_token = token_manager.regenerate_token()
@@ -333,38 +265,17 @@ async def regenerate_token(
raise HTTPException(status_code=500, detail="Token 重新生成失败") from e
@router.get("/setup/status", response_model=FirstSetupStatusResponse)
async def get_setup_status(
request: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
@router.get("/setup/status", response_model=FirstSetupStatusResponse, dependencies=[Depends(require_auth)])
async def get_setup_status():
"""
获取首次配置状态
Args:
maibot_session: Cookie 中的 token
authorization: Authorization header (Bearer token)
Returns:
首次配置状态
"""
try:
# 验证 token优先 Cookie其次 Header
current_token = None
if maibot_session:
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="Token 无效")
# 检查是否为首次配置
is_first = token_manager.is_first_setup()
@@ -376,38 +287,17 @@ async def get_setup_status(
raise HTTPException(status_code=500, detail="获取配置状态失败") from e
@router.post("/setup/complete", response_model=CompleteSetupResponse)
async def complete_setup(
request: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
@router.post("/setup/complete", response_model=CompleteSetupResponse, dependencies=[Depends(require_auth)])
async def complete_setup():
"""
标记首次配置完成
Args:
maibot_session: Cookie 中的 token
authorization: Authorization header (Bearer token)
Returns:
完成结果
"""
try:
# 验证 token优先 Cookie其次 Header
current_token = None
if maibot_session:
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="Token 无效")
# 标记配置完成
success = token_manager.mark_setup_completed()
@@ -419,38 +309,17 @@ async def complete_setup(
raise HTTPException(status_code=500, detail="标记配置完成失败") from e
@router.post("/setup/reset", response_model=ResetSetupResponse)
async def reset_setup(
request: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
@router.post("/setup/reset", response_model=ResetSetupResponse, dependencies=[Depends(require_auth)])
async def reset_setup():
"""
重置首次配置状态,允许重新进入配置向导
Args:
maibot_session: Cookie 中的 token
authorization: Authorization header (Bearer token)
Returns:
重置结果
"""
try:
# 验证 token优先 Cookie其次 Header
current_token = None
if maibot_session:
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="Token 无效")
# 重置配置状态
success = token_manager.reset_setup_status()

View File

@@ -49,6 +49,7 @@ class WebUIServer:
service_name="WebUI 服务器",
logger=logger,
config_hint="WEBUI_PORT (.env)",
allow_reuse_addr=True,
)
config = Config(