From 172615f18ae2e86bfe5fa54cbc006145814a33bb Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sat, 14 Mar 2026 21:06:36 +0800 Subject: [PATCH] =?UTF-8?q?WebUI=20=E5=89=8D=E7=AB=AF=20&=20=E5=90=8E?= =?UTF-8?q?=E7=AB=AF=E8=B6=85=E7=BA=A7=E5=A4=A7=E9=87=8D=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../components/electron/BackendManager.tsx | 13 +- .../src/components/expression-reviewer.tsx | 9 +- dashboard/src/components/layout/Header.tsx | 24 +- dashboard/src/components/layout/Layout.tsx | 10 +- dashboard/src/components/layout/constants.ts | 22 +- dashboard/src/components/layout/types.ts | 1 + dashboard/src/components/search-dialog.tsx | 144 +- .../src/components/share-pack-dialog.tsx | 11 +- dashboard/src/components/ui/dialog.tsx | 52 +- dashboard/src/components/ui/kbd.tsx | 23 +- dashboard/src/components/ui/scroll-area.tsx | 17 +- dashboard/src/lib/annual-report-api.ts | 136 -- dashboard/src/lib/keyboard.ts | 93 + dashboard/src/lib/settings-manager.ts | 9 +- dashboard/src/router.tsx | 22 +- dashboard/src/routes/annual-report.tsx | 883 ------- .../src/routes/chat/VirtualIdentityDialog.tsx | 12 +- .../config/bot/sections/ProcessingSection.tsx | 19 +- dashboard/src/routes/config/model.tsx | 8 +- .../config/modelProvider/ProviderForm.tsx | 9 +- dashboard/src/routes/config/pack-detail.tsx | 9 +- dashboard/src/routes/person.tsx | 15 +- .../routes/resource/emoji/EmojiDialogs.tsx | 27 +- .../resource/expression/ExpressionDialogs.tsx | 17 +- .../routes/resource/jargon/JargonDialogs.tsx | 18 +- .../resource/knowledge-graph/GraphDialogs.tsx | 10 +- src/chat/message_receive/bot.py | 1 - .../message_receive/uni_message_sender.py | 2 +- src/chat/replyer/group_generator.py | 39 +- src/chat/replyer/private_generator.py | 21 +- src/common/logger_color_and_mapping.py | 2 - src/common/message_repository.py | 53 +- src/common/utils/port_checker.py | 7 +- src/webui/api/planner.py | 12 +- src/webui/api/replier.py | 12 +- src/webui/app.py | 1 - src/webui/core/__init__.py | 2 + src/webui/core/auth.py | 61 +- src/webui/dependencies.py | 34 +- src/webui/routers/annual_report.py | 916 -------- src/webui/routers/chat.py | 801 ------- src/webui/routers/chat/__init__.py | 18 + src/webui/routers/chat/routes.py | 174 ++ src/webui/routers/chat/support.py | 614 +++++ src/webui/routers/config.py | 51 +- src/webui/routers/emoji/__init__.py | 3 + .../routers/{emoji.py => emoji/routes.py} | 792 ++----- src/webui/routers/emoji/schemas.py | 140 ++ src/webui/routers/emoji/support.py | 142 ++ src/webui/routers/expression.py | 81 +- src/webui/routers/jargon.py | 183 +- src/webui/routers/knowledge.py | 44 +- src/webui/routers/model.py | 64 +- src/webui/routers/person.py | 61 +- src/webui/routers/plugin.py | 2054 ----------------- src/webui/routers/plugin/__init__.py | 17 + src/webui/routers/plugin/catalog.py | 205 ++ src/webui/routers/plugin/config_routes.py | 333 +++ src/webui/routers/plugin/management.py | 302 +++ .../plugin_progress.py => plugin/progress.py} | 72 +- src/webui/routers/plugin/schemas.py | 113 + src/webui/routers/plugin/support.py | 221 ++ src/webui/routers/statistics.py | 49 +- src/webui/routers/system.py | 22 +- src/webui/routers/websocket/__init__.py | 2 - src/webui/routers/websocket/auth.py | 25 +- src/webui/routers/websocket/logs.py | 184 +- src/webui/routes.py | 165 +- src/webui/webui_server.py | 1 + 69 files changed, 3128 insertions(+), 6581 deletions(-) delete mode 100644 dashboard/src/lib/annual-report-api.ts create mode 100644 dashboard/src/lib/keyboard.ts delete mode 100644 dashboard/src/routes/annual-report.tsx delete mode 100644 src/webui/routers/annual_report.py delete mode 100644 src/webui/routers/chat.py create mode 100644 src/webui/routers/chat/__init__.py create mode 100644 src/webui/routers/chat/routes.py create mode 100644 src/webui/routers/chat/support.py create mode 100644 src/webui/routers/emoji/__init__.py rename src/webui/routers/{emoji.py => emoji/routes.py} (57%) create mode 100644 src/webui/routers/emoji/schemas.py create mode 100644 src/webui/routers/emoji/support.py delete mode 100644 src/webui/routers/plugin.py create mode 100644 src/webui/routers/plugin/__init__.py create mode 100644 src/webui/routers/plugin/catalog.py create mode 100644 src/webui/routers/plugin/config_routes.py create mode 100644 src/webui/routers/plugin/management.py rename src/webui/routers/{websocket/plugin_progress.py => plugin/progress.py} (64%) create mode 100644 src/webui/routers/plugin/schemas.py create mode 100644 src/webui/routers/plugin/support.py diff --git a/dashboard/src/components/electron/BackendManager.tsx b/dashboard/src/components/electron/BackendManager.tsx index caef5137..d8e6cc42 100644 --- a/dashboard/src/components/electron/BackendManager.tsx +++ b/dashboard/src/components/electron/BackendManager.tsx @@ -14,13 +14,13 @@ import { import { Button } from '@/components/ui/button' import { Dialog, + DialogBody, DialogContent, DialogHeader, DialogTitle, } from '@/components/ui/dialog' import { Input } from '@/components/ui/input' import { Label } from '@/components/ui/label' -import { ScrollArea } from '@/components/ui/scroll-area' import { useBackendConnections } from '@/hooks/useBackendConnections' import { isElectron } from '@/lib/runtime' import type { BackendConnection } from '@/types/electron' @@ -78,7 +78,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) { return ( <> - + 后端连接管理 @@ -88,7 +88,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) { ) : ( - +
{backends.map((backend) => { const isActive = backend.id === activeId @@ -100,7 +100,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) { }`} >
-
+
{isActive ? ( ) : ( @@ -156,7 +156,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) { ) })}
- + )}
@@ -173,7 +173,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) { {/* Edit/Add Dialog */} !open && setEditConn(null)}> - + {editConn?.id ? '编辑连接' : '添加连接'} @@ -212,6 +212,7 @@ export function BackendManager({ open, onOpenChange }: BackendManagerProps) { !editConn?.url || !/^https?:\/\//.test(editConn.url) } + data-dialog-action="confirm" > 保存 diff --git a/dashboard/src/components/expression-reviewer.tsx b/dashboard/src/components/expression-reviewer.tsx index 8373c840..9f3e8b27 100644 --- a/dashboard/src/components/expression-reviewer.tsx +++ b/dashboard/src/components/expression-reviewer.tsx @@ -22,6 +22,7 @@ import { import { Button } from '@/components/ui/button' import { Input } from '@/components/ui/input' import { Badge } from '@/components/ui/badge' +import { ShortcutKbd } from '@/components/ui/kbd' import { ScrollArea } from '@/components/ui/scroll-area' import { Checkbox } from '@/components/ui/checkbox' import { Tabs, TabsList, TabsTrigger } from '@/components/ui/tabs' @@ -1689,19 +1690,19 @@ if (isCurrent) { {/* 底部快捷键提示(桌面端) */}
- + 拒绝
- + 通过
- + 上一条
- + 下一条
| diff --git a/dashboard/src/components/layout/Header.tsx b/dashboard/src/components/layout/Header.tsx index d17d080e..eb99d5fc 100644 --- a/dashboard/src/components/layout/Header.tsx +++ b/dashboard/src/components/layout/Header.tsx @@ -1,5 +1,4 @@ -import { Link } from '@tanstack/react-router' -import { BookOpen, ChevronLeft, Globe, LogOut, Menu, Moon, PieChart, Search, Server, Sun } from 'lucide-react' +import { BookOpen, ChevronLeft, Globe, LogOut, Menu, Moon, Search, Server, Sun } from 'lucide-react' import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -13,7 +12,7 @@ import { DropdownMenuItem, DropdownMenuTrigger, } from '@/components/ui/dropdown-menu' -import { Kbd } from '@/components/ui/kbd' +import { ShortcutKbd } from '@/components/ui/kbd' import { toggleThemeWithTransition } from '@/components/use-theme' import { useBackground } from '@/hooks/use-background' import { logout } from '@/lib/fetch-with-auth' @@ -99,7 +98,7 @@ export function Header({ title={t('header.toggleConnection')} > - + {activeBackendName} @@ -107,19 +106,6 @@ export function Header({
)} - {/* 年度总结入口 */} - - - - {/* 搜索框 */} {/* 搜索对话框 */} diff --git a/dashboard/src/components/layout/Layout.tsx b/dashboard/src/components/layout/Layout.tsx index 058bfca7..d09f6753 100644 --- a/dashboard/src/components/layout/Layout.tsx +++ b/dashboard/src/components/layout/Layout.tsx @@ -13,6 +13,7 @@ import { useAuthGuard } from '@/hooks/use-auth' import { useBackground } from '@/hooks/use-background' import { TitleBar } from '@/components/electron/TitleBar' +import { matchesShortcut } from '@/lib/keyboard' import { isElectron } from '@/lib/runtime' import { cn } from '@/lib/utils' import { menuSections } from './constants' @@ -49,7 +50,7 @@ export function Layout({ children }: LayoutProps) { // 搜索快捷键监听(Cmd/Ctrl + K) useEffect(() => { const handleKeyDown = (e: KeyboardEvent) => { - if ((e.metaKey || e.ctrlKey) && e.key === 'k') { + if (matchesShortcut(e, ['mod', 'k'])) { e.preventDefault() setSearchOpen(true) } @@ -68,9 +69,8 @@ export function Layout({ children }: LayoutProps) { } } - const unsubscribe = router.subscribe('onResolved', () => { - const pathname = router.state.location.pathname - const pageTitle = pathToLabel[pathname] ?? 'MaiBot Dashboard' + return router.subscribe('onResolved', () => { + const pageTitle = pathToLabel[router.state.location.pathname] ?? 'MaiBot Dashboard' const fullTitle = pageTitle === 'MaiBot Dashboard' ? 'MaiBot Dashboard' : `${pageTitle} — MaiBot Dashboard` @@ -90,8 +90,6 @@ export function Layout({ children }: LayoutProps) { }) } }) - - return unsubscribe }, [router, announce, t]) // 获取实际应用的主题(处理 system 情况) diff --git a/dashboard/src/components/layout/constants.ts b/dashboard/src/components/layout/constants.ts index 1d487586..6bdcac21 100644 --- a/dashboard/src/components/layout/constants.ts +++ b/dashboard/src/components/layout/constants.ts @@ -6,25 +6,25 @@ export const menuSections: MenuSection[] = [ { title: 'sidebar.groups.overview', items: [ - { icon: Home, label: 'sidebar.menu.home', path: '/' }, + { icon: Home, label: 'sidebar.menu.home', path: '/', searchDescription: 'search.items.homeDesc' }, ], }, { title: 'sidebar.groups.botConfig', items: [ - { icon: FileText, label: 'sidebar.menu.botMainConfig', path: '/config/bot' }, - { icon: Server, label: 'sidebar.menu.aiModelProvider', path: '/config/modelProvider', tourId: 'sidebar-model-provider' }, - { icon: Boxes, label: 'sidebar.menu.modelManagement', path: '/config/model', tourId: 'sidebar-model-management' }, + { icon: FileText, label: 'sidebar.menu.botMainConfig', path: '/config/bot', searchDescription: 'search.items.botConfigDesc' }, + { icon: Server, label: 'sidebar.menu.aiModelProvider', path: '/config/modelProvider', searchDescription: 'search.items.modelProviderDesc', tourId: 'sidebar-model-provider' }, + { icon: Boxes, label: 'sidebar.menu.modelManagement', path: '/config/model', searchDescription: 'search.items.modelDesc', tourId: 'sidebar-model-management' }, { icon: Sliders, label: 'sidebar.menu.adapterConfig', path: '/config/adapter' }, ], }, { title: 'sidebar.groups.botResources', items: [ - { icon: Smile, label: 'sidebar.menu.emojiManagement', path: '/resource/emoji' }, - { icon: MessageSquare, label: 'sidebar.menu.expressionManagement', path: '/resource/expression' }, - { icon: Hash, label: 'sidebar.menu.slangManagement', path: '/resource/jargon' }, - { icon: UserCircle, label: 'sidebar.menu.personInfo', path: '/resource/person' }, + { icon: Smile, label: 'sidebar.menu.emojiManagement', path: '/resource/emoji', searchDescription: 'search.items.emojiDesc' }, + { icon: MessageSquare, label: 'sidebar.menu.expressionManagement', path: '/resource/expression', searchDescription: 'search.items.expressionDesc' }, + { icon: Hash, label: 'sidebar.menu.slangManagement', path: '/resource/jargon', searchDescription: 'search.items.jargonDesc' }, + { icon: UserCircle, label: 'sidebar.menu.personInfo', path: '/resource/person', searchDescription: 'search.items.personDesc' }, { icon: Network, label: 'sidebar.menu.knowledgeGraph', path: '/resource/knowledge-graph' }, { icon: Database, label: 'sidebar.menu.knowledgeBase', path: '/resource/knowledge-base' }, ], @@ -32,10 +32,10 @@ export const menuSections: MenuSection[] = [ { title: 'sidebar.groups.extensionsMonitor', items: [ - { icon: Package, label: 'sidebar.menu.pluginMarket', path: '/plugins' }, + { icon: Package, label: 'sidebar.menu.pluginMarket', path: '/plugins', searchDescription: 'search.items.pluginsDesc' }, { icon: LayoutGrid, label: 'sidebar.menu.configTemplate', path: '/config/pack-market' }, { icon: Sliders, label: 'sidebar.menu.pluginConfig', path: '/plugin-config' }, - { icon: FileSearch, label: 'sidebar.menu.logViewer', path: '/logs' }, + { icon: FileSearch, label: 'sidebar.menu.logViewer', path: '/logs', searchDescription: 'search.items.logsDesc' }, { icon: Activity, label: 'sidebar.menu.plannerMonitor', path: '/planner-monitor' }, { icon: MessageSquare, label: 'sidebar.menu.localChat', path: '/chat' }, ], @@ -43,7 +43,7 @@ export const menuSections: MenuSection[] = [ { title: 'sidebar.groups.system', items: [ - { icon: Settings, label: 'sidebar.menu.settings', path: '/settings' }, + { icon: Settings, label: 'sidebar.menu.settings', path: '/settings', searchDescription: 'search.items.settingsDesc' }, ], }, ] diff --git a/dashboard/src/components/layout/types.ts b/dashboard/src/components/layout/types.ts index 0be17225..33bf579b 100644 --- a/dashboard/src/components/layout/types.ts +++ b/dashboard/src/components/layout/types.ts @@ -9,6 +9,7 @@ export interface MenuItem { icon: ComponentType label: string path: string + searchDescription?: string tourId?: string } diff --git a/dashboard/src/components/search-dialog.tsx b/dashboard/src/components/search-dialog.tsx index 4e5dabbe..54f9fa12 100644 --- a/dashboard/src/components/search-dialog.tsx +++ b/dashboard/src/components/search-dialog.tsx @@ -1,16 +1,20 @@ -import { useState, useCallback, useMemo } from 'react' -import { Search, FileText, Server, Boxes, Smile, MessageSquare, UserCircle, FileSearch, BarChart3, Package, Settings, Home, Hash } from 'lucide-react' +import { useState, useCallback, useEffect, useMemo, useRef } from 'react' +import { Search } from 'lucide-react' import { useNavigate } from '@tanstack/react-router' import { useTranslation } from 'react-i18next' +import type { LucideProps } from 'lucide-react' import { Dialog, + DialogBody, DialogContent, DialogHeader, DialogTitle, } from '@/components/ui/dialog' import { Input } from '@/components/ui/input' -import { ScrollArea } from '@/components/ui/scroll-area' +import { ShortcutKbd } from '@/components/ui/kbd' +import { menuSections } from '@/components/layout/constants' +import { registeredRoutePaths } from '@/router' import { cn } from '@/lib/utils' interface SearchDialogProps { @@ -19,7 +23,7 @@ interface SearchDialogProps { } interface SearchItem { - icon: React.ComponentType<{ className?: string }> + icon: React.ComponentType title: string description: string path: string @@ -29,95 +33,37 @@ interface SearchItem { export function SearchDialog({ open, onOpenChange }: SearchDialogProps) { const [searchQuery, setSearchQuery] = useState('') const [selectedIndex, setSelectedIndex] = useState(0) + const inputRef = useRef(null) const navigate = useNavigate() const { t } = useTranslation() - const searchItems: SearchItem[] = useMemo(() => [ - { - icon: Home, - title: t('search.items.home'), - description: t('search.items.homeDesc'), - path: '/', - category: t('search.categories.overview'), - }, - { - icon: FileText, - title: t('search.items.botConfig'), - description: t('search.items.botConfigDesc'), - path: '/config/bot', - category: t('search.categories.config'), - }, - { - icon: Server, - title: t('search.items.modelProvider'), - description: t('search.items.modelProviderDesc'), - path: '/config/modelProvider', - category: t('search.categories.config'), - }, - { - icon: Boxes, - title: t('search.items.model'), - description: t('search.items.modelDesc'), - path: '/config/model', - category: t('search.categories.config'), - }, - { - icon: Smile, - title: t('search.items.emoji'), - description: t('search.items.emojiDesc'), - path: '/resource/emoji', - category: t('search.categories.resources'), - }, - { - icon: MessageSquare, - title: t('search.items.expression'), - description: t('search.items.expressionDesc'), - path: '/resource/expression', - category: t('search.categories.resources'), - }, - { - icon: UserCircle, - title: t('search.items.person'), - description: t('search.items.personDesc'), - path: '/resource/person', - category: t('search.categories.resources'), - }, - { - icon: Hash, - title: t('search.items.jargon'), - description: t('search.items.jargonDesc'), - path: '/resource/jargon', - category: t('search.categories.resources'), - }, - { - icon: BarChart3, - title: t('search.items.statistics'), - description: t('search.items.statisticsDesc'), - path: '/statistics', - category: t('search.categories.monitor'), - }, - { - icon: Package, - title: t('search.items.plugins'), - description: t('search.items.pluginsDesc'), - path: '/plugins', - category: t('search.categories.extensions'), - }, - { - icon: FileSearch, - title: t('search.items.logs'), - description: t('search.items.logsDesc'), - path: '/logs', - category: t('search.categories.monitor'), - }, - { - icon: Settings, - title: t('search.items.settings'), - description: t('search.items.settingsDesc'), - path: '/settings', - category: t('search.categories.system'), - }, - ], [t]) + useEffect(() => { + if (!open) { + return + } + + const frameId = window.requestAnimationFrame(() => { + inputRef.current?.focus() + }) + + return () => window.cancelAnimationFrame(frameId) + }, [open]) + + const searchItems: SearchItem[] = useMemo( + () => + menuSections.flatMap((section) => + section.items + .filter((item) => registeredRoutePaths.has(item.path)) + .map((item) => ({ + icon: item.icon, + title: t(item.label), + description: item.searchDescription ? t(item.searchDescription) : item.path, + path: item.path, + category: t(section.title), + })) + ), + [t] + ) // 过滤搜索结果 const filteredItems = searchItems.filter( @@ -155,12 +101,13 @@ export function SearchDialog({ open, onOpenChange }: SearchDialogProps) { return ( - + {t('search.title')}
{ setSearchQuery(e.target.value) @@ -169,13 +116,12 @@ export function SearchDialog({ open, onOpenChange }: SearchDialogProps) { onKeyDown={handleKeyDown} placeholder={t('search.placeholder')} className="h-12 pl-11 text-base border-0 focus-visible:ring-0 shadow-none" - autoFocus />
- + {filteredItems.length > 0 ? (
{filteredItems.map((item, index) => { @@ -192,7 +138,7 @@ export function SearchDialog({ open, onOpenChange }: SearchDialogProps) { : 'hover:bg-accent/50' )} > - +
{item.title}
@@ -214,22 +160,22 @@ export function SearchDialog({ open, onOpenChange }: SearchDialogProps) {

)} - +
- - + + {t('search.navigate')} - Enter + {t('search.select')} - Esc + {t('search.close')}
diff --git a/dashboard/src/components/share-pack-dialog.tsx b/dashboard/src/components/share-pack-dialog.tsx index 2ec6171a..dd20a97c 100644 --- a/dashboard/src/components/share-pack-dialog.tsx +++ b/dashboard/src/components/share-pack-dialog.tsx @@ -24,6 +24,7 @@ import { Checkbox } from '@/components/ui/checkbox' import { Badge } from '@/components/ui/badge' import { Dialog, + DialogBody, DialogContent, DialogDescription, DialogFooter, @@ -34,7 +35,6 @@ import { import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs' import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert' import { Separator } from '@/components/ui/separator' -import { ScrollArea } from '@/components/ui/scroll-area' import { toast } from '@/hooks/use-toast' import { createPack, @@ -340,7 +340,7 @@ export function SharePackDialog({ trigger }: SharePackDialogProps) { )} - + @@ -353,7 +353,7 @@ export function SharePackDialog({ trigger }: SharePackDialogProps) { - + {loading ? (
@@ -639,7 +639,7 @@ export function SharePackDialog({ trigger }: SharePackDialogProps) { )} )} - +
@@ -662,6 +662,7 @@ export function SharePackDialog({ trigger }: SharePackDialogProps) { {step < totalSteps ? ( ) : ( - diff --git a/dashboard/src/components/ui/dialog.tsx b/dashboard/src/components/ui/dialog.tsx index cd10f20d..8509d085 100644 --- a/dashboard/src/components/ui/dialog.tsx +++ b/dashboard/src/components/ui/dialog.tsx @@ -1,8 +1,13 @@ import * as React from "react" import * as DialogPrimitive from "@radix-ui/react-dialog" + import { cn } from "@/lib/utils" import { X } from "lucide-react" +import { isEditableTarget, matchesShortcut } from "@/lib/keyboard" + +import { ScrollArea } from "@/components/ui/scroll-area" + const Dialog = DialogPrimitive.Root const DialogTrigger = DialogPrimitive.Trigger @@ -32,22 +37,48 @@ interface DialogContentProps preventOutsideClose?: boolean /** 隐藏默认关闭按钮(当使用自定义关闭按钮时) */ hideCloseButton?: boolean + /** 回车触发主操作按钮 */ + confirmOnEnter?: boolean +} + +interface DialogBodyProps extends React.ComponentPropsWithoutRef { + allowHorizontalScroll?: boolean } const DialogContent = React.forwardRef< React.ElementRef, DialogContentProps ->(({ className, children, preventOutsideClose = false, hideCloseButton = false, ...props }, ref) => ( +>(({ className, children, preventOutsideClose = false, hideCloseButton = false, confirmOnEnter = false, onKeyDownCapture, ...props }, ref) => ( e.preventDefault() : undefined} onInteractOutside={preventOutsideClose ? (e) => e.preventDefault() : undefined} + onKeyDownCapture={(event) => { + onKeyDownCapture?.(event) + if ( + !confirmOnEnter || + event.defaultPrevented || + !matchesShortcut(event, ['enter']) || + event.nativeEvent.isComposing || + isEditableTarget(event.target) + ) { + return + } + + const confirmButton = event.currentTarget.querySelector('[data-dialog-action="confirm"]:not([disabled])') + if (!confirmButton) { + return + } + + event.preventDefault() + confirmButton.click() + }} {...props} > {children} @@ -62,6 +93,22 @@ const DialogContent = React.forwardRef< )) DialogContent.displayName = DialogPrimitive.Content.displayName +const DialogBody = React.forwardRef( + ({ className, children, allowHorizontalScroll = false, contentClassName, scrollbars, viewportClassName, ...props }, ref) => ( + + {children} + + ) +) +DialogBody.displayName = "DialogBody" + const DialogHeader = ({ className, ...props @@ -125,6 +172,7 @@ export { DialogClose, DialogContent, DialogHeader, + DialogBody, DialogFooter, DialogTitle, DialogDescription, diff --git a/dashboard/src/components/ui/kbd.tsx b/dashboard/src/components/ui/kbd.tsx index c8ae2d51..9ccd0bd7 100644 --- a/dashboard/src/components/ui/kbd.tsx +++ b/dashboard/src/components/ui/kbd.tsx @@ -1,6 +1,7 @@ import * as React from "react" import { cva, type VariantProps } from "class-variance-authority" +import { getPlatformModifierAriaLabel, getShortcutKeyLabel, type ShortcutKey } from "@/lib/keyboard" import { cn } from "@/lib/utils" const kbdVariants = cva( @@ -25,6 +26,10 @@ export interface KbdProps abbrTitle?: string } +interface ShortcutKbdProps extends Omit { + keys: ShortcutKey[] +} + const Kbd = React.forwardRef( ({ className, size, abbrTitle, children, ...props }, ref) => { return ( @@ -40,4 +45,20 @@ const Kbd = React.forwardRef( ) Kbd.displayName = "Kbd" -export { Kbd } +function ShortcutKbd({ keys, className, size, ...props }: ShortcutKbdProps) { + return ( + + {keys.map((key) => { + const label = getShortcutKeyLabel(key) + const abbrTitle = key === 'mod' ? getPlatformModifierAriaLabel() : undefined + return ( + + {label} + + ) + })} + + ) +} + +export { Kbd, ShortcutKbd } diff --git a/dashboard/src/components/ui/scroll-area.tsx b/dashboard/src/components/ui/scroll-area.tsx index c3e868be..34f463eb 100644 --- a/dashboard/src/components/ui/scroll-area.tsx +++ b/dashboard/src/components/ui/scroll-area.tsx @@ -5,22 +5,25 @@ import { cn } from "@/lib/utils" interface ScrollAreaProps extends React.ComponentPropsWithoutRef { viewportRef?: React.RefObject + viewportClassName?: string + contentClassName?: string + scrollbars?: "vertical" | "horizontal" | "both" } const ScrollArea = React.forwardRef< React.ElementRef, ScrollAreaProps ->(({ className, children, viewportRef, ...props }, ref) => ( +>(({ className, children, viewportRef, viewportClassName, contentClassName, scrollbars = "both", ...props }, ref) => ( - - {children} + +
{children}
- - + {scrollbars !== "horizontal" && } + {scrollbars !== "vertical" && }
)) @@ -36,9 +39,9 @@ const ScrollBar = React.forwardRef< className={cn( "flex touch-none select-none transition-colors", orientation === "vertical" && - "h-full w-2.5 border-l border-l-transparent p-[1px]", + "h-full w-2.5 border-l border-l-transparent p-px", orientation === "horizontal" && - "h-2.5 flex-col border-t border-t-transparent p-[1px]", + "h-2.5 flex-col border-t border-t-transparent p-px", className )} {...props} diff --git a/dashboard/src/lib/annual-report-api.ts b/dashboard/src/lib/annual-report-api.ts deleted file mode 100644 index c19ce2de..00000000 --- a/dashboard/src/lib/annual-report-api.ts +++ /dev/null @@ -1,136 +0,0 @@ -import { fetchWithAuth } from './fetch-with-auth' - -export interface TimeFootprintData { - total_online_hours: number - first_message_time: string | null - first_message_user: string | null - first_message_content: string | null - busiest_day: string | null - busiest_day_count: number - hourly_distribution: number[] - midnight_chat_count: number - is_night_owl: boolean -} - -export interface SocialNetworkData { - total_groups: number - top_groups: Array<{ - group_id: string - group_name: string - message_count: number - is_webui?: boolean - }> - top_users: Array<{ - user_id: string - user_nickname: string - message_count: number - is_webui?: boolean - }> - at_count: number - mentioned_count: number - longest_companion_user: string | null - longest_companion_days: number -} - -export interface BrainPowerData { - total_tokens: number - total_cost: number - favorite_model: string | null - favorite_model_count: number - model_distribution: Array<{ - model: string - count: number - tokens: number - cost: number - }> - top_reply_models: Array<{ - model: string - count: number - }> - most_expensive_cost: number - most_expensive_time: string | null - top_token_consumers: Array<{ - user_id: string - cost: number - tokens: number - }> - silence_rate: number - total_actions: number - no_reply_count: number - avg_interest_value: number - max_interest_value: number - max_interest_time: string | null - avg_reasoning_length: number - max_reasoning_length: number - max_reasoning_time: string | null -} - -export interface ExpressionVibeData { - top_emoji: { - id: number - path: string - description: string - usage_count: number - hash: string - } | null - top_emojis: Array<{ - id: number - path: string - description: string - usage_count: number - hash: string - }> - top_expressions: Array<{ - style: string - count: number - }> - rejected_expression_count: number - checked_expression_count: number - total_expressions: number - action_types: Array<{ - action: string - count: number - }> - image_processed_count: number - late_night_reply: { - time: string - content: string - } | null - favorite_reply: { - content: string - count: number - } | null -} - -export interface AchievementData { - new_jargon_count: number - sample_jargons: Array<{ - content: string - meaning: string - count: number - }> - total_messages: number - total_replies: number -} - -export interface AnnualReportData { - year: number - bot_name: string - generated_at: string - time_footprint: TimeFootprintData - social_network: SocialNetworkData - brain_power: BrainPowerData - expression_vibe: ExpressionVibeData - achievements: AchievementData -} - -export async function getAnnualReport(year: number = 2025): Promise { - const response = await fetchWithAuth(`/api/webui/annual-report/full?year=${year}`) - - if (!response.ok) { - const error = await response.json() - throw new Error(error.detail || '获取年度报告失败') - } - - return response.json() -} diff --git a/dashboard/src/lib/keyboard.ts b/dashboard/src/lib/keyboard.ts new file mode 100644 index 00000000..872a0b92 --- /dev/null +++ b/dashboard/src/lib/keyboard.ts @@ -0,0 +1,93 @@ +export type ShortcutKey = + | 'mod' + | 'shift' + | 'alt' + | 'enter' + | 'esc' + | 'up' + | 'down' + | 'left' + | 'right' + | string + +const MAC_PLATFORMS = /(Mac|iPhone|iPod|iPad)/i + +export function isMacLikePlatform(): boolean { + if (typeof navigator === 'undefined') { + return false + } + + return MAC_PLATFORMS.test(navigator.platform || navigator.userAgent) +} + +export function getShortcutKeyLabel(key: ShortcutKey): string { + const isMacLike = isMacLikePlatform() + const normalizedKey = key.toLowerCase() + + switch (normalizedKey) { + case 'mod': + return isMacLike ? '⌘' : 'Ctrl' + case 'shift': + return isMacLike ? '⇧' : 'Shift' + case 'alt': + return isMacLike ? '⌥' : 'Alt' + case 'enter': + return isMacLike ? '↵' : 'Enter' + case 'esc': + case 'escape': + return 'Esc' + case 'up': + return '↑' + case 'down': + return '↓' + case 'left': + return '←' + case 'right': + return '→' + default: + return key.length === 1 ? key.toUpperCase() : key + } +} + +export function getPlatformModifierAriaLabel(): string { + return isMacLikePlatform() ? 'Command' : 'Control' +} + +export function matchesShortcut(event: KeyboardEvent | React.KeyboardEvent, keys: ShortcutKey[]): boolean { + const normalizedKeys = keys.map((key) => key.toLowerCase()) + const eventKey = event.key.toLowerCase() + + const modifierChecks = { + mod: isMacLikePlatform() ? event.metaKey : event.ctrlKey, + shift: event.shiftKey, + alt: event.altKey, + } + + for (const key of normalizedKeys) { + if (key in modifierChecks) { + if (!modifierChecks[key as keyof typeof modifierChecks]) { + return false + } + continue + } + + if (eventKey !== key) { + return false + } + } + + return true +} + +export function isEditableTarget(target: EventTarget | null): boolean { + if (!(target instanceof HTMLElement)) { + return false + } + + return ( + target.tagName === 'INPUT' || + target.tagName === 'TEXTAREA' || + target.isContentEditable || + target.getAttribute('role') === 'textbox' + ) +} \ No newline at end of file diff --git a/dashboard/src/lib/settings-manager.ts b/dashboard/src/lib/settings-manager.ts index 3ffbb8f2..de173657 100644 --- a/dashboard/src/lib/settings-manager.ts +++ b/dashboard/src/lib/settings-manager.ts @@ -23,9 +23,6 @@ export const STORAGE_KEYS = { WS_MAX_RECONNECT_ATTEMPTS: 'maibot-ws-max-reconnect-attempts', // 用户数据 - // 注意:ACCESS_TOKEN 已弃用,现在使用 HttpOnly Cookie 存储认证信息 - // 保留此常量仅用于向后兼容和清理旧数据 - ACCESS_TOKEN: 'access-token', COMPLETED_TOURS: 'maibot-completed-tours', CHAT_USER_ID: 'maibot_webui_user_id', CHAT_USER_NAME: 'maibot_webui_user_name', @@ -211,10 +208,8 @@ export function clearLocalCache(): { clearedKeys: string[]; preservedKeys: strin const keysToRemove: string[] = [] for (let i = 0; i < localStorage.length; i++) { const key = localStorage.key(i) - if (key) { - if (key.startsWith('maibot') || key.startsWith('accent-color') || key === 'access-token') { - keysToRemove.push(key) - } + if (key && (key.startsWith('maibot') || key.startsWith('accent-color'))) { + keysToRemove.push(key) } } diff --git a/dashboard/src/router.tsx b/dashboard/src/router.tsx index 4ef072b2..26b70d8e 100644 --- a/dashboard/src/router.tsx +++ b/dashboard/src/router.tsx @@ -24,7 +24,6 @@ import { PluginMirrorsPage } from './routes/plugin-mirrors' import { PluginDetailPage } from './routes/plugin-detail' import { ChatPage } from './routes/chat/index' import { WebUIFeedbackSurveyPage, MaiBotFeedbackSurveyPage } from './routes/survey' -import { AnnualReportPage } from './routes/annual-report' import PackMarketPage from './routes/config/pack-market' import PackDetailPage from './routes/config/pack-detail' import { Layout } from './components/layout' @@ -241,13 +240,6 @@ const maibotFeedbackSurveyRoute = createRoute({ component: MaiBotFeedbackSurveyPage, }) -// 年度报告路由 -const annualReportRoute = createRoute({ - getParentRoute: () => protectedRoute, - path: '/annual-report', - component: AnnualReportPage, -}) - // 404 路由 const notFoundRoute = createRoute({ getParentRoute: () => rootRoute, @@ -284,11 +276,23 @@ const routeTree = rootRoute.addChildren([ packDetailRoute, webuiFeedbackSurveyRoute, maibotFeedbackSurveyRoute, - annualReportRoute, ]), notFoundRoute, ]) +type RouteNode = { + fullPath?: string + children?: RouteNode[] +} + +function collectRoutePaths(node: RouteNode): string[] { + const currentPath = node.fullPath ? [node.fullPath] : [] + const childPaths = node.children?.flatMap(collectRoutePaths) ?? [] + return [...currentPath, ...childPaths] +} + +export const registeredRoutePaths = new Set(collectRoutePaths(routeTree as RouteNode)) + // 创建路由器 export const router = createRouter({ routeTree, diff --git a/dashboard/src/routes/annual-report.tsx b/dashboard/src/routes/annual-report.tsx deleted file mode 100644 index 528e9496..00000000 --- a/dashboard/src/routes/annual-report.tsx +++ /dev/null @@ -1,883 +0,0 @@ -import { useState, useRef, useEffect, useCallback } from 'react' -import { getAnnualReport, type AnnualReportData } from '@/lib/annual-report-api' -import { Card, CardContent, CardHeader, CardTitle, CardDescription } from '@/components/ui/card' -import { Skeleton } from '@/components/ui/skeleton' -import { Badge } from '@/components/ui/badge' -import { ScrollArea } from '@/components/ui/scroll-area' -import { Button } from '@/components/ui/button' -import { useToast } from '@/hooks/use-toast' -import { toPng } from 'html-to-image' -import { - BarChart, - Bar, - XAxis, - YAxis, - CartesianGrid, - Tooltip, - ResponsiveContainer, -} from 'recharts' -import { - Clock, - Users, - Brain, - Smile, - Trophy, - Calendar, - MessageSquare, - Zap, - Moon, - Sun, - AtSign, - Heart, - Image as ImageIcon, - Bot, - Download, - Loader2, -} from 'lucide-react' -import { cn } from '@/lib/utils' - -// 颜色常量 -const COLORS = ['#0088FE', '#00C49F', '#FFBB28', '#FF8042', '#8884d8', '#82ca9d'] - -// 动态比喻生成函数 -function getOnlineHoursMetaphor(hours: number): string { - if (hours >= 8760) return "相当于全年无休,7x24小时在线!" - if (hours >= 5000) return "相当于一位全职员工的年工作时长" - if (hours >= 2000) return "相当于看完了 1000 部电影" - if (hours >= 1000) return "相当于环球飞行 80 次" - if (hours >= 500) return "相当于读完了 100 本书" - if (hours >= 100) return "相当于马拉松跑了 25 次" - return "虽然不多,但每一刻都很珍贵" -} - -function getMidnightMetaphor(count: number): string { - if (count >= 1000) return "夜深人静时的知心好友" - if (count >= 500) return "午夜场的常客" - if (count >= 100) return "偶尔熬夜的小伙伴" - if (count >= 50) return "深夜有时也会陪你聊聊" - return "早睡早起,健康作息" -} - -function getTokenMetaphor(tokens: number): string { - const millions = tokens / 1000000 - if (millions >= 100) return "思考量堪比一座图书馆" - if (millions >= 50) return "相当于写了一部百科全书" - if (millions >= 10) return "脑细胞估计消耗了不少" - if (millions >= 1) return "也算是费了一番脑筋" - return "轻轻松松,游刃有余" -} - -function getCostMetaphor(cost: number): string { - if (cost >= 1000) return "这钱够吃一年的泡面了" - if (cost >= 500) return "相当于买了一台游戏机" - if (cost >= 100) return "够请大家喝几杯奶茶" - if (cost >= 50) return "一顿火锅的钱" - if (cost >= 10) return "几杯咖啡的价格" - return "省钱小能手" -} - -function getSilenceMetaphor(rate: number): string { - if (rate >= 80) return "沉默是金,惜字如金" - if (rate >= 60) return "话不多但句句到位" - if (rate >= 40) return "该说的时候才开口" - if (rate >= 20) return "能聊的都聊了" - return "话痨本痨,有问必答" -} - -function getImageMetaphor(count: number): string { - if (count >= 10000) return "眼睛都快看花了" - if (count >= 5000) return "堪比专业摄影师的阅片量" - if (count >= 1000) return "看图小达人" - if (count >= 500) return "图片鉴赏家" - if (count >= 100) return "偶尔欣赏一下美图" - return "图片?有空再看" -} - -function getRejectedMetaphor(count: number): string { - if (count >= 500) return "在不断的纠正中成长" - if (count >= 200) return "学习永无止境" - if (count >= 100) return "虚心接受,积极改正" - if (count >= 50) return "偶尔也会犯错" - if (count >= 10) return "表现还算不错" - return "完美表达,无需纠正" -} - -function getExpensiveThinkingMetaphor(cost: number): string { - if (cost >= 1) return "这次思考的价值堪比一顿大餐!" - if (cost >= 0.5) return "为了这个问题,我可是认真思考了!" - if (cost >= 0.1) return "下了点功夫,值得的!" - if (cost >= 0.01) return "花了点小钱,但很值得" - return "小小思考,不足挂齿" -} - -function getFavoriteReplyMetaphor(count: number, botName: string): string { - if (count >= 100) return "这句话简直是万能钥匙!" - if (count >= 50) return "百试不爽的经典回复" - if (count >= 20) return `${botName}的口头禅` - if (count >= 10) return "常用语录之一" - return "偶尔用用的小确幸" -} - -function getNightOwlMetaphor(isNightOwl: boolean, midnightCount: number): string { - if (isNightOwl) { - if (midnightCount >= 1000) return "深夜的守护者,黑暗中的光芒" - if (midnightCount >= 500) return "月亮是我的好朋友" - if (midnightCount >= 100) return "越夜越精神,夜晚才是主场" - return "偶尔熬夜,享受宁静时光" - } else { - if (midnightCount <= 10) return "作息规律,健康生活的典范" - if (midnightCount <= 50) return "早睡早起,偶尔也会熬个夜" - return "虽然是早起鸟,但也会守候深夜" - } -} - -function getBusiestDayMetaphor(count: number): string { - if (count >= 1000) return "忙到飞起,键盘都要冒烟了" - if (count >= 500) return "这天简直是话痨附体" - if (count >= 200) return "社交达人上线" - if (count >= 100) return "比平时活跃不少" - if (count >= 50) return "小忙一下" - return "还算轻松的一天" -} - -export function AnnualReportPage() { - const [year] = useState(2025) - const [data, setData] = useState(null) - const [isLoading, setIsLoading] = useState(true) - const [isExporting, setIsExporting] = useState(false) - const [error, setError] = useState(null) - const reportRef = useRef(null) - const { toast } = useToast() - - const loadReport = useCallback(async () => { - try { - setIsLoading(true) - setError(null) - const result = await getAnnualReport(year) - setData(result) - } catch (err) { - setError(err instanceof Error ? err : new Error('获取年度报告失败')) - } finally { - setIsLoading(false) - } - }, [year]) - - // 导出为图片 - const handleExport = useCallback(async () => { - if (!reportRef.current || !data) return - - setIsExporting(true) - toast({ - title: '正在生成图片', - description: '请稍候...', - }) - - try { - const element = reportRef.current - - // 获取当前主题的背景色 - const computedStyle = getComputedStyle(document.documentElement) - const backgroundColor = computedStyle.getPropertyValue('--background').trim() - ? `hsl(${computedStyle.getPropertyValue('--background').trim()})` - : (document.documentElement.classList.contains('dark') ? '#0a0a0a' : '#ffffff') - - // 保存原始样式 - const originalWidth = element.style.width - const originalMaxWidth = element.style.maxWidth - - // 临时设置固定宽度以去除左右空白 - element.style.width = '1024px' - element.style.maxWidth = '1024px' - - const dataUrl = await toPng(element, { - quality: 1, - pixelRatio: 2, - backgroundColor, - cacheBust: true, - filter: (node) => { - // 过滤掉导出按钮 - if (node instanceof HTMLElement && node.hasAttribute('data-export-btn')) { - return false - } - return true - }, - }) - - // 恢复原始样式 - element.style.width = originalWidth - element.style.maxWidth = originalMaxWidth - - // 创建下载链接 - const link = document.createElement('a') - link.download = `${data.bot_name}_${data.year}_年度总结.png` - link.href = dataUrl - link.click() - - toast({ - title: '导出成功', - description: '年度报告已保存为图片', - }) - } catch (err) { - console.error('导出图片失败:', err) - toast({ - title: '导出失败', - description: '请重试', - variant: 'destructive', - }) - } finally { - setIsExporting(false) - } - }, [data, toast]) - - useEffect(() => { - loadReport() - }, [loadReport]) - - if (isLoading) { - return - } - - if (error) { - return ( -
- 获取年度报告失败: {error.message} -
- ) - } - - if (!data) return null - - return ( - -
-
- {/* 头部 Hero */} -
- {/* 导出按钮 */} -
- -
-
- -

- {data.bot_name} {data.year} 年度总结 -

-

- 连接与成长 · Connection & Growth -

-
- - 生成时间: {data.generated_at} -
-
- {/* 背景装饰 */} -
-
-
- - {/* 维度一:时光足迹 */} -
-
- -

时光足迹

-
-
- } - /> - } - /> - } - /> - : } - /> -
- - - - 24小时活跃时钟 - {data.bot_name}在一天中各个时段的活跃程度 - - - - ({ hour: `${hour}点`, count }))}> - - - - - - - - - - - {data.time_footprint.first_message_time && ( - - -

2025年的故事开始于

-
{data.time_footprint.first_message_time}
-

- {data.time_footprint.first_message_user} 说: - "{data.time_footprint.first_message_content}" -

-
-
- )} -
- - {/* 维度二:社交网络 */} -
-
- -

社交网络

-
-
- } - /> - } - /> - } - /> -
- -
- - - 话痨群组 TOP5 - - -
- {data.social_network.top_groups.length > 0 ? ( - data.social_network.top_groups.map((group: { group_id: string; group_name: string; message_count: number; is_webui?: boolean }, index: number) => ( -
-
- - {index + 1} - - {group.group_name} - {group.is_webui && ( - - WebUI - - )} -
- {group.message_count} 条消息 -
- )) - ) : ( -
暂无数据
- )} -
-
-
- - - 年度最佳损友 TOP5 - - -
- {data.social_network.top_users.length > 0 ? ( - data.social_network.top_users.map((user: { user_id: string; user_nickname: string; message_count: number; is_webui?: boolean }, index: number) => ( -
-
- - {index + 1} - - {user.user_nickname} - {user.is_webui && ( - - WebUI - - )} -
- {user.message_count} 次互动 -
- )) - ) : ( -
暂无数据
- )} -
-
-
-
-
- - {/* 维度三:最强大脑 */} -
-
- -

最强大脑

-
-
- } - /> - $} - /> - } - /> - } - /> -
- -
- - - 模型偏好分布 - - -
- {data.brain_power.model_distribution.slice(0, 5).map((item: { model: string; count: number }, index: number) => { - const maxCount = data.brain_power.model_distribution[0]?.count || 1 - const percentage = Math.round((item.count / maxCount) * 100) - return ( -
-
- {item.model} - {item.count.toLocaleString()} 次 -
-
-
-
-
- ) - })} -
- - - - {/* 最喜欢的回复模型 TOP5 */} - {data.brain_power.top_reply_models && data.brain_power.top_reply_models.length > 0 && ( - - - 最喜欢的回复模型 TOP5 - {data.bot_name}用来回复消息的模型偏好 - - -
- {data.brain_power.top_reply_models.map((item: { model: string; count: number }, index: number) => { - const maxCount = data.brain_power.top_reply_models[0]?.count || 1 - const percentage = Math.round((item.count / maxCount) * 100) - return ( -
-
- {item.model} - {item.count.toLocaleString()} 次 -
-
-
-
-
- ) - })} -
- - - )} - - {/* 烧钱大户 - 只有有有效用户数据时才显示 */} - {data.brain_power.top_token_consumers && data.brain_power.top_token_consumers.length > 0 && ( - - - 烧钱大户 TOP3 - 谁消耗了最多的 API 额度 - - -
- {data.brain_power.top_token_consumers.map((consumer: { user_id: string; cost: number; tokens: number }) => ( -
-
- 用户 {consumer.user_id} - ${consumer.cost.toFixed(2)} -
-
-
-
-
- ))} -
- - - )} -
- - {/* 最昂贵的思考 & 思考深度 */} -
- - - - 💰 - 最昂贵的一次思考 - - - -
- ${data.brain_power.most_expensive_cost.toFixed(4)} -
- {data.brain_power.most_expensive_time && ( -

- 发生在 {data.brain_power.most_expensive_time} -

- )} -

- {getExpensiveThinkingMetaphor(data.brain_power.most_expensive_cost)} -

-
-
- - - - - 🧠 - 思考深度 - - - -
-
-
- {data.brain_power.avg_reasoning_length?.toFixed(0) || 0} -
-
平均思考字数
-
-
-
- {data.brain_power.max_reasoning_length?.toLocaleString() || 0} -
-
最长思考字数
-
-
- {data.brain_power.max_reasoning_time && ( -

- 最深沉的思考发生在 {data.brain_power.max_reasoning_time} -

- )} -
-
-
-
- - {/* 维度四:个性与表达 */} -
-
- -

个性与表达

-
- - {/* 深夜回复 & 最喜欢的回复 */} - {(data.expression_vibe.late_night_reply || data.expression_vibe.favorite_reply) && ( -
- {data.expression_vibe.late_night_reply && ( - - - - 🌙 - 深夜还在回复 - - 凌晨 {data.expression_vibe.late_night_reply.time},{data.bot_name}还在回复... - - -

- "{data.expression_vibe.late_night_reply.content}" -

-

- 是有什么心事吗? -

-
-
- )} - - {data.expression_vibe.favorite_reply && ( - - - - 💬 - 最喜欢的回复 - - 使用了 {data.expression_vibe.favorite_reply.count} 次 - - -

- "{data.expression_vibe.favorite_reply.content}" -

-

- {getFavoriteReplyMetaphor(data.expression_vibe.favorite_reply.count, data.bot_name)} -

-
-
- )} -
- )} - -
- {/* 使用最多的表情包 TOP3 */} - - - 使用最多的表情包 TOP3 - 年度最爱的表情包们 - - - {data.expression_vibe.top_emojis && data.expression_vibe.top_emojis.length > 0 ? ( -
- {data.expression_vibe.top_emojis.slice(0, 3).map((emoji: { id: number; usage_count: number }, index: number) => ( -
-
- {`TOP - - {index + 1} - -
-

{emoji.usage_count} 次

-
- ))} -
- ) : ( -
暂无数据
- )} -
-
- -
- - - 印象最深刻的表达风格 - {data.bot_name}最常使用的表达方式 - - -
- {data.expression_vibe.top_expressions.map((exp: { style: string; count: number }, index: number) => ( - - {exp.style} ({exp.count}) - - ))} -
-
-
- -
- } - /> - } - /> -
-
-
- - {/* 行动派 */} - {data.expression_vibe.action_types.length > 0 && ( - - - - - 行动派 - - 除了聊天,我还帮大家做了这些事 - - -
- {data.expression_vibe.action_types.map((action: { action: string; count: number }) => ( -
- {action.action} - {action.count} 次 -
- ))} -
-
-
- )} -
- - {/* 维度五:趣味成就 */} -
-
- -

趣味成就

-
- -
- - - 新学到的"黑话" - 今年我学会了 {data.achievements.new_jargon_count} 个新词 - - -
- {data.achievements.sample_jargons.map((jargon: { content: string; meaning: string; count: number }) => ( -
-
{jargon.content}
-
- {jargon.meaning || '暂无解释'} -
-
- ))} -
-
-
- - - - -
{data.achievements.total_messages.toLocaleString()}
-
年度总消息数
-
- 其中回复了 {data.achievements.total_replies.toLocaleString()} 次 -
-
-
-
-
- - {/* 底部 */} -
-

MaiBot 2025 Annual Report

-

Generated with ❤️ by MaiBot Team

-
-
-
-
- ) -} - -function StatCard({ - title, - value, - description, - icon, -}: { - title: string - value: string | number - description: string - icon: React.ReactNode -}) { - return ( - - - {title} -
{icon}
-
- -
{value}
-

{description}

-
-
- ) -} - -function LoadingSkeleton() { - return ( -
- -
- {[...Array(4)].map((_, i) => ( - - ))} -
- -
- ) -} diff --git a/dashboard/src/routes/chat/VirtualIdentityDialog.tsx b/dashboard/src/routes/chat/VirtualIdentityDialog.tsx index a150a746..ab5993d8 100644 --- a/dashboard/src/routes/chat/VirtualIdentityDialog.tsx +++ b/dashboard/src/routes/chat/VirtualIdentityDialog.tsx @@ -2,6 +2,7 @@ import { Avatar, AvatarFallback } from '@/components/ui/avatar' import { Button } from '@/components/ui/button' import { Dialog, + DialogBody, DialogContent, DialogDescription, DialogFooter, @@ -54,7 +55,7 @@ export function VirtualIdentityDialog({ }: VirtualIdentityDialogProps) { return ( - + @@ -65,7 +66,7 @@ export function VirtualIdentityDialog({ -
+ {/* 平台选择 */}
- +
{isLoadingPersons ? (
@@ -187,13 +188,14 @@ export function VirtualIdentityDialog({

)} -
+
- - + 正则表达式编辑器 @@ -257,7 +258,7 @@ function RegexEditor({ - + setActiveTab(v as 'build' | 'test')} className="w-full"> 🔧 构建器 @@ -406,7 +407,7 @@ function RegexEditor({ value={testText} onChange={(e) => setTestText(e.target.value)} placeholder="在此输入要测试的文本... 例如:打游戏是这样的" - className="min-h-[100px] text-sm" + className="min-h-25 text-sm" />
@@ -444,7 +445,7 @@ function RegexEditor({
-
+
{renderHighlightedText()}
@@ -458,7 +459,7 @@ function RegexEditor({
{Object.entries(captureGroups).map(([name, value]) => (
- [{name}] + [{name}] = {value}
@@ -473,7 +474,7 @@ function RegexEditor({
-
+
{replacedReaction}
@@ -497,7 +498,7 @@ function RegexEditor({
-
+
) @@ -628,7 +629,7 @@ export const ProcessingSection = React.memo(function ProcessingSection({ 预览 - +

配置预览

@@ -656,7 +657,7 @@ export const ProcessingSection = React.memo(function ProcessingSection({ 预览 - +

配置预览

diff --git a/dashboard/src/routes/config/model.tsx b/dashboard/src/routes/config/model.tsx index 535d97de..d7a3d0cf 100644 --- a/dashboard/src/routes/config/model.tsx +++ b/dashboard/src/routes/config/model.tsx @@ -6,6 +6,7 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs' import { ScrollArea } from '@/components/ui/scroll-area' import { Dialog, + DialogBody, DialogContent, DialogDescription, DialogFooter, @@ -971,9 +972,10 @@ function ModelConfigPageContent() { {/* 编辑模型对话框 */} @@ -982,6 +984,7 @@ function ModelConfigPageContent() { 配置模型的基本信息和参数 +
@@ -1492,12 +1495,13 @@ function ModelConfigPageContent() { )}
+
- +
diff --git a/dashboard/src/routes/config/modelProvider/ProviderForm.tsx b/dashboard/src/routes/config/modelProvider/ProviderForm.tsx index 87b9204c..4e28bd17 100644 --- a/dashboard/src/routes/config/modelProvider/ProviderForm.tsx +++ b/dashboard/src/routes/config/modelProvider/ProviderForm.tsx @@ -3,7 +3,7 @@ import { Check, ChevronsUpDown, Copy, Eye, EyeOff } from 'lucide-react' import { Button } from '@/components/ui/button' import { Command, CommandEmpty, CommandGroup, CommandInput, CommandItem, CommandList } from '@/components/ui/command' -import { Dialog, DialogContent, DialogDescription, DialogFooter, DialogHeader, DialogTitle } from '@/components/ui/dialog' +import { Dialog, DialogBody, DialogContent, DialogDescription, DialogFooter, DialogHeader, DialogTitle } from '@/components/ui/dialog' import { HelpTooltip } from '@/components/ui/help-tooltip' import { Input } from '@/components/ui/input' import { Label } from '@/components/ui/label' @@ -116,9 +116,10 @@ export function ProviderForm({ return ( @@ -130,6 +131,7 @@ export function ProviderForm({
{ e.preventDefault(); handleSaveEdit(); }} autoComplete="off"> +
@@ -450,12 +452,13 @@ export function ProviderForm({
+ - + diff --git a/dashboard/src/routes/config/pack-detail.tsx b/dashboard/src/routes/config/pack-detail.tsx index a1d7b027..2fcaef4a 100644 --- a/dashboard/src/routes/config/pack-detail.tsx +++ b/dashboard/src/routes/config/pack-detail.tsx @@ -40,6 +40,7 @@ import { ScrollArea } from '@/components/ui/scroll-area' import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert' import { Dialog, + DialogBody, DialogContent, DialogDescription, DialogFooter, @@ -575,7 +576,7 @@ function ApplyDialog({ return ( - + @@ -589,6 +590,7 @@ function ApplyDialog({ + {detectingConflicts ? (
@@ -831,6 +833,7 @@ function ApplyDialog({ )} )} +
@@ -845,11 +848,11 @@ function ApplyDialog({ 取消 {step < totalSteps ? ( - ) : ( - diff --git a/dashboard/src/routes/person.tsx b/dashboard/src/routes/person.tsx index 8a1d8430..8685aa8b 100644 --- a/dashboard/src/routes/person.tsx +++ b/dashboard/src/routes/person.tsx @@ -29,6 +29,7 @@ import { Button } from '@/components/ui/button' import { Checkbox } from '@/components/ui/checkbox' import { Dialog, + DialogBody, DialogContent, DialogDescription, DialogFooter, @@ -573,7 +574,7 @@ export function PersonManagementPage() { variant="outline" size="sm" onClick={() => handleViewDetail(person)} - className="text-xs px-2 py-1 h-auto flex-shrink-0" + className="text-xs px-2 py-1 h-auto shrink-0" > 查看 @@ -582,7 +583,7 @@ export function PersonManagementPage() { variant="outline" size="sm" onClick={() => handleEdit(person)} - className="text-xs px-2 py-1 h-auto flex-shrink-0" + className="text-xs px-2 py-1 h-auto shrink-0" > 编辑 @@ -591,7 +592,7 @@ export function PersonManagementPage() { variant="outline" size="sm" onClick={() => setDeleteConfirmPerson(person)} - className="text-xs px-2 py-1 h-auto flex-shrink-0 text-destructive hover:text-destructive" + className="text-xs px-2 py-1 h-auto shrink-0 text-destructive hover:text-destructive" > 删除 @@ -771,7 +772,7 @@ function PersonDetailDialog({ return ( - + 人物详情 @@ -779,6 +780,7 @@ function PersonDetailDialog({ +
{/* 基本信息 */}
@@ -829,6 +831,7 @@ function PersonDetailDialog({
+
@@ -919,7 +922,7 @@ function PersonEditDialog({ return ( - + 编辑人物信息 @@ -927,6 +930,7 @@ function PersonEditDialog({ +
@@ -974,6 +978,7 @@ function PersonEditDialog({ />
+
- +
) @@ -252,11 +253,12 @@ export function EmojiEditDialog({ return ( - + 编辑表情包 修改表情包的情绪和状态信息 +
@@ -310,11 +312,12 @@ export function EmojiEditDialog({
+ - @@ -658,7 +661,7 @@ export function EmojiUploadDialog({
{/* 预览图 */} -
+
{/* 左侧:文件卡片列表 */} - +
{uploadedFiles.map((file) => { const complete = isFileComplete(file) @@ -782,7 +785,7 @@ export function EmojiUploadDialog({ ${complete ? 'border-green-500 bg-green-50 dark:bg-green-950/20' : 'border-border hover:border-muted-foreground/50'} `} > -
+
{file.name}
{complete ? ( - + ) : ( -
+
)}
) @@ -908,7 +911,7 @@ export function EmojiUploadDialog({ return ( - + @@ -925,11 +928,11 @@ export function EmojiUploadDialog({ -
+ {step === 'select' && renderSelectStep()} {step === 'edit-single' && renderEditSingleStep()} {step === 'edit-multiple' && renderEditMultipleStep()} -
+
) diff --git a/dashboard/src/routes/resource/expression/ExpressionDialogs.tsx b/dashboard/src/routes/resource/expression/ExpressionDialogs.tsx index 15f3c955..d6729184 100644 --- a/dashboard/src/routes/resource/expression/ExpressionDialogs.tsx +++ b/dashboard/src/routes/resource/expression/ExpressionDialogs.tsx @@ -15,6 +15,7 @@ import { import { Button } from '@/components/ui/button' import { Dialog, + DialogBody, DialogContent, DialogDescription, DialogFooter, @@ -65,7 +66,7 @@ export function ExpressionDetailDialog({ return ( - + 表达方式详情 @@ -73,6 +74,7 @@ export function ExpressionDetailDialog({ +
@@ -131,6 +133,7 @@ export function ExpressionDetailDialog({
+ @@ -233,7 +236,7 @@ export function ExpressionCreateDialog({ return ( - + 新增表达方式 @@ -241,6 +244,7 @@ export function ExpressionCreateDialog({ +
@@ -291,12 +295,13 @@ export function ExpressionCreateDialog({
+ - @@ -371,7 +376,7 @@ export function ExpressionEditDialog({ return ( - + 编辑表达方式 @@ -379,6 +384,7 @@ export function ExpressionEditDialog({ +
@@ -474,12 +480,13 @@ export function ExpressionEditDialog({
+
- diff --git a/dashboard/src/routes/resource/jargon/JargonDialogs.tsx b/dashboard/src/routes/resource/jargon/JargonDialogs.tsx index 83d8e308..92ede436 100644 --- a/dashboard/src/routes/resource/jargon/JargonDialogs.tsx +++ b/dashboard/src/routes/resource/jargon/JargonDialogs.tsx @@ -14,6 +14,7 @@ import { import { Button } from '@/components/ui/button' import { Dialog, + DialogBody, DialogContent, DialogDescription, DialogFooter, @@ -24,7 +25,6 @@ import { Badge } from '@/components/ui/badge' import { Input } from '@/components/ui/input' import { Label } from '@/components/ui/label' import { MarkdownRenderer } from '@/components/markdown-renderer' -import { ScrollArea } from '@/components/ui/scroll-area' import { Select, SelectContent, @@ -92,7 +92,7 @@ export function JargonDetailDialog({ 查看黑话的完整信息 - +
@@ -167,7 +167,7 @@ export function JargonDetailDialog({
)}
-
+ @@ -234,12 +234,13 @@ export function JargonCreateDialog({ return ( - + 新增黑话 创建新的黑话记录 +
+
- @@ -366,12 +368,13 @@ export function JargonEditDialog({ return ( - + 编辑黑话 修改黑话的信息 +
@@ -439,10 +442,11 @@ export function JargonEditDialog({
+
- diff --git a/dashboard/src/routes/resource/knowledge-graph/GraphDialogs.tsx b/dashboard/src/routes/resource/knowledge-graph/GraphDialogs.tsx index 93cf7dca..4a189cf9 100644 --- a/dashboard/src/routes/resource/knowledge-graph/GraphDialogs.tsx +++ b/dashboard/src/routes/resource/knowledge-graph/GraphDialogs.tsx @@ -2,11 +2,11 @@ import { Badge } from '@/components/ui/badge' import { Dialog, + DialogBody, DialogContent, DialogHeader, DialogTitle, } from '@/components/ui/dialog' -import { ScrollArea } from '@/components/ui/scroll-area' import type { GraphNode, SelectedEdgeData } from './types' @@ -24,7 +24,7 @@ export function NodeDetailDialog({ open, onOpenChange, selectedNodeData }: NodeD 节点详情 {selectedNodeData && ( - +
@@ -62,7 +62,7 @@ export function NodeDetailDialog({ open, onOpenChange, selectedNodeData }: NodeD )}
- + )}
@@ -83,7 +83,7 @@ export function EdgeDetailDialog({ open, onOpenChange, selectedEdgeData }: EdgeD 边详情 {selectedEdgeData && ( - +
@@ -114,7 +114,7 @@ export function EdgeDetailDialog({ open, onOpenChange, selectedEdgeData }: EdgeD
-
+ )}
diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 494bcc56..60586406 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -190,7 +190,6 @@ class ChatBot: @staticmethod def _mark_command_message(message: SessionMessage, intercept_message_level: int) -> None: message.is_command = True - message.intercept_message_level = intercept_message_level message.message_info.additional_config["intercept_message_level"] = intercept_message_level @staticmethod diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 075a695d..894af238 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -29,7 +29,7 @@ def get_webui_chat_broadcaster(): global _webui_chat_broadcaster if _webui_chat_broadcaster is None: try: - from src.webui.chat_routes import chat_manager, WEBUI_CHAT_PLATFORM + from src.webui.routers.chat import WEBUI_CHAT_PLATFORM, chat_manager _webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM) except ImportError: diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 5982e0c3..22ea9228 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -1,6 +1,7 @@ import traceback import time import asyncio +import importlib import random import re @@ -11,9 +12,9 @@ from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.llm_data_model import LLMGenerationDataModel from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from maim_message import BaseMessageInfo, MessageBase, Seg +from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo -from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo +from src.common.data_models.mai_message_data_model import MaiMessage from src.chat.message_receive.message import SessionMessage from src.chat.message_receive.chat_manager import BotChatSession from src.chat.message_receive.uni_message_sender import UniversalMessageSender @@ -786,7 +787,6 @@ class DefaultReplyer: available_actions = {} chat_stream = self.chat_stream chat_id = chat_stream.session_id - _is_group_chat = bool(chat_stream.group_info) platform = chat_stream.platform user_id = "用户ID" @@ -795,11 +795,12 @@ class DefaultReplyer: target = "消息" if reply_message: - user_id = reply_message.user_info.user_id + reply_user_info = reply_message.message_info.user_info + user_id = reply_user_info.user_id person = Person(platform=platform, user_id=user_id) person_name = person.person_name or user_id sender = person_name - target = reply_message.processed_plain_text + target = reply_message.processed_plain_text or "" target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) @@ -825,16 +826,17 @@ class DefaultReplyer: person_list_short: List[Person] = [] for msg in message_list_before_short: + msg_user_info = msg.message_info.user_info # 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI) - if is_bot_self(msg.user_info.platform, msg.user_info.user_id): + if is_bot_self(msg.platform, msg_user_info.user_id): continue if ( reply_message - and reply_message.user_info.user_id == msg.user_info.user_id - and reply_message.user_info.platform == msg.user_info.platform + and reply_message.message_info.user_info.user_id == msg_user_info.user_id + and reply_message.platform == msg.platform ): continue - person = Person(platform=msg.user_info.platform, user_id=msg.user_info.user_id) + person = Person(platform=msg.platform, user_id=msg_user_info.user_id) if person.is_known: person_list_short.append(person) @@ -847,7 +849,6 @@ class DefaultReplyer: timestamp_mode="relative", read_mark=0.0, show_actions=True, - long_time_notice=True, ) # 统一黑话解释构建:根据配置选择上下文或 Planner 模式 @@ -956,7 +957,6 @@ class DefaultReplyer: replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=True, - long_time_notice=True, ) # 获取匹配的额外prompt @@ -1121,7 +1121,7 @@ class DefaultReplyer: platform=self.chat_stream.platform, message_id=message_id, time=thinking_start_time, - user_info=UserInfo( + user_info=MaimUserInfo( user_id=str(global_config.bot.qq_account), user_nickname=global_config.bot.nickname, ), @@ -1160,7 +1160,16 @@ class DefaultReplyer: async def get_prompt_info(self, message: str, sender: str, target: str): related_info = "" start_time = time.time() - from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool + try: + knowledge_module = importlib.import_module("src.plugins.built_in.knowledge.lpmm_get_knowledge") + except ImportError: + logger.debug("LPMM知识库工具模块不存在,跳过获取知识库内容") + return "" + + search_knowledge_tool = getattr(knowledge_module, "SearchKnowledgeFromLPMMTool", None) + if search_knowledge_tool is None: + logger.debug("LPMM知识库工具未提供 SearchKnowledgeFromLPMMTool,跳过获取知识库内容") + return "" logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") # 从LPMM知识库获取知识 @@ -1183,14 +1192,14 @@ class DefaultReplyer: _, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools( prompt, model_config=model_config.model_task_config.tool_use, - tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()], + tool_options=[search_knowledge_tool.get_tool_definition()], ) # logger.info(f"工具调用提示词: {prompt}") # logger.info(f"工具调用: {tool_calls}") if tool_calls: - result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool()) + result = await self.tool_executor.execute_tool_call(tool_calls[0]) end_time = time.time() if not result or not result.get("content"): logger.debug("从LPMM知识库获取知识失败,返回空知识...") diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index cadb734d..3769dfd2 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -11,9 +11,9 @@ from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.llm_data_model import LLMGenerationDataModel from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from maim_message import BaseMessageInfo, MessageBase, Seg +from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo -from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo +from src.common.data_models.mai_message_data_model import MaiMessage from src.chat.message_receive.message import SessionMessage from src.chat.message_receive.chat_manager import BotChatSession from src.chat.message_receive.uni_message_sender import UniversalMessageSender @@ -636,11 +636,12 @@ class PrivateReplyer: target = "消息" if reply_message: - user_id = reply_message.user_info.user_id + reply_user_info = reply_message.message_info.user_info + user_id = reply_user_info.user_id person = Person(platform=platform, user_id=user_id) person_name = person.person_name or user_id sender = person_name - target = reply_message.processed_plain_text + target = reply_message.processed_plain_text or "" target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) @@ -663,7 +664,6 @@ class PrivateReplyer: timestamp_mode="relative", read_mark=0.0, show_actions=True, - long_time_notice=True, ) message_list_before_short = get_messages_before_time_in_chat( @@ -675,16 +675,17 @@ class PrivateReplyer: person_list_short: List[Person] = [] for msg in message_list_before_short: + msg_user_info = msg.message_info.user_info # 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI) - if is_bot_self(msg.user_info.platform, msg.user_info.user_id): + if is_bot_self(msg.platform, msg_user_info.user_id): continue if ( reply_message - and reply_message.user_info.user_id == msg.user_info.user_id - and reply_message.user_info.platform == msg.user_info.platform + and reply_message.message_info.user_info.user_id == msg_user_info.user_id + and reply_message.platform == msg.platform ): continue - person = Person(platform=msg.user_info.platform, user_id=msg.user_info.user_id) + person = Person(platform=msg.platform, user_id=msg_user_info.user_id) if person.is_known: person_list_short.append(person) @@ -960,7 +961,7 @@ class PrivateReplyer: platform=self.chat_stream.platform, message_id=message_id, time=thinking_start_time, - user_info=UserInfo( + user_info=MaimUserInfo( user_id=str(global_config.bot.qq_account), user_nickname=global_config.bot.nickname, ), diff --git a/src/common/logger_color_and_mapping.py b/src/common/logger_color_and_mapping.py index 803c0371..1aabafbc 100644 --- a/src/common/logger_color_and_mapping.py +++ b/src/common/logger_color_and_mapping.py @@ -76,7 +76,6 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = { "webui.jargon": ("#d7d75f", None, False), "webui.person": ("#87d787", None, False), "webui.statistics": ("#af87ff", None, False), - "webui.annual_report": ("#ffaf87", None, False), "webui.plugin_routes": ("#ffaf00", None, False), "webui.plugin_progress": ("#ff8700", None, False), "webui.git_mirror": ("#878787", None, False), @@ -180,7 +179,6 @@ MODULE_ALIASES = { "webui.jargon": "WebUI黑话", "webui.person": "WebUI人物", "webui.statistics": "WebUI统计", - "webui.annual_report": "WebUI年报", "webui.plugin_routes": "WebUI插件", "webui.plugin_progress": "WebUI插件进度", "webui.git_mirror": "WebUI镜像", diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 7b35ae07..a89b6f49 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -1,6 +1,5 @@ import traceback from datetime import datetime -from types import SimpleNamespace from typing import Any import json @@ -48,48 +47,8 @@ def _parse_additional_config(message: Messages) -> dict[str, Any]: return {} -def _normalize_optional_str(value: object) -> str | None: - if value is None: - return None - if isinstance(value, str): - return value - try: - return json.dumps(value, ensure_ascii=False) - except (TypeError, ValueError): - return str(value) - - def _message_to_instance(message: Messages) -> SessionMessage: - config = _parse_additional_config(message) - instance = SessionMessage.from_db_instance(message) - instance.interest_value = config.get("interest_value") - instance.key_words = _normalize_optional_str(config.get("key_words")) - instance.key_words_lite = _normalize_optional_str(config.get("key_words_lite")) - instance.reply_probability_boost = config.get("reply_probability_boost") - instance.priority_mode = _normalize_optional_str(config.get("priority_mode")) - instance.priority_info = _normalize_optional_str(config.get("priority_info")) - instance.intercept_message_level = config.get("intercept_message_level", 0) - instance.selected_expressions = _normalize_optional_str(config.get("selected_expressions")) - group_info = instance.message_info.group_info - legacy_group_info = None - if group_info: - legacy_group_info = SimpleNamespace( - group_id=group_info.group_id, - group_name=group_info.group_name, - ) - instance.user_info = SimpleNamespace( - user_id=instance.message_info.user_info.user_id, - user_nickname=instance.message_info.user_info.user_nickname, - user_cardname=instance.message_info.user_info.user_cardname, - platform=instance.platform, - ) - instance.chat_info = SimpleNamespace( - platform=instance.platform, - stream_id=instance.session_id, - group_info=legacy_group_info, - ) - instance.time = instance.timestamp.timestamp() - return instance + return SessionMessage.from_db_instance(message) def _coerce_datetime(value: Any) -> Any: @@ -118,6 +77,7 @@ def _build_message_conditions( end_time: float | None = None, before_time: float | None = None, after_time: float | None = None, + has_reply_to: bool | None = None, ) -> list[Any]: conditions: list[Any] = [Messages.message_id != "notice"] @@ -141,6 +101,10 @@ def _build_message_conditions( conditions.append(Messages.timestamp < _coerce_datetime(before_time)) if after_time is not None: conditions.append(Messages.timestamp > _coerce_datetime(after_time)) + if has_reply_to is True: + conditions.append(col(Messages.reply_to).is_not(None)) + elif has_reply_to is False: + conditions.append(col(Messages.reply_to).is_(None)) return conditions @@ -261,6 +225,7 @@ def count_messages( end_time: float | None = None, before_time: float | None = None, after_time: float | None = None, + has_reply_to: bool | None = None, ) -> int: """ 根据提供的过滤器计算消息数量。 @@ -276,6 +241,7 @@ def count_messages( end_time: 结束时间,闭区间上界。 before_time: 严格早于该时间。 after_time: 严格晚于该时间。 + has_reply_to: 是否要求存在 reply_to 字段。 Returns: 符合条件的消息数量,如果出错则返回 0。 @@ -292,6 +258,7 @@ def count_messages( end_time=end_time, before_time=before_time, after_time=after_time, + has_reply_to=has_reply_to, ) statement = select(func.count()).select_from(Messages).where(*conditions) with get_db_session() as session: @@ -302,7 +269,7 @@ def count_messages( "使用 SQLModel 计数消息失败 " f"(session_id={session_id}, user_id={user_id}, group_id={group_id}, platform={platform}, " f"message_id={message_id}, reply_to={reply_to}, start_time={start_time}, end_time={end_time}, " - f"before_time={before_time}, after_time={after_time}): {e}\n{traceback.format_exc()}" + f"before_time={before_time}, after_time={after_time}, has_reply_to={has_reply_to}): {e}\n{traceback.format_exc()}" ) logger.error(log_message) return 0 diff --git a/src/common/utils/port_checker.py b/src/common/utils/port_checker.py index fc5acd64..60f783ac 100644 --- a/src/common/utils/port_checker.py +++ b/src/common/utils/port_checker.py @@ -25,13 +25,15 @@ def is_port_conflict_error(error: OSError) -> bool: return "address already in use" in message or "已被占用" in message -def check_port_available(host: str, port: int) -> bool: +def check_port_available(host: str, port: int, *, allow_reuse_addr: bool = False) -> bool: family = _detect_socket_family(host) test_host = _normalize_test_host(host) try: with socket.socket(family, socket.SOCK_STREAM) as test_socket: test_socket.settimeout(1) + if allow_reuse_addr: + test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) test_socket.bind((test_host, port)) return True except OSError: @@ -65,8 +67,9 @@ def assert_port_available( service_name: str, logger, config_hint: Optional[str] = None, + allow_reuse_addr: bool = False, ) -> None: - if check_port_available(host=host, port=port): + if check_port_available(host=host, port=port, allow_reuse_addr=allow_reuse_addr): return log_port_conflict( diff --git a/src/webui/api/planner.py b/src/webui/api/planner.py index 981cc9d4..cea892dc 100644 --- a/src/webui/api/planner.py +++ b/src/webui/api/planner.py @@ -8,13 +8,17 @@ 3. 详情按需加载 """ -import json from pathlib import Path -from typing import List, Dict, Optional -from fastapi import APIRouter, HTTPException, Query +from typing import Dict, List, Optional + +import json + +from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel -router = APIRouter(prefix="/api/planner", tags=["planner"]) +from src.webui.dependencies import require_auth + +router = APIRouter(prefix="/api/planner", tags=["planner"], dependencies=[Depends(require_auth)]) # 规划器日志目录 PLAN_LOG_DIR = Path("logs/plan") diff --git a/src/webui/api/replier.py b/src/webui/api/replier.py index 0643ceb4..fe25459f 100644 --- a/src/webui/api/replier.py +++ b/src/webui/api/replier.py @@ -8,13 +8,17 @@ 3. 详情按需加载 """ -import json from pathlib import Path -from typing import List, Dict, Optional -from fastapi import APIRouter, HTTPException, Query +from typing import Dict, List, Optional + +import json + +from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel -router = APIRouter(prefix="/api/replier", tags=["replier"]) +from src.webui.dependencies import require_auth + +router = APIRouter(prefix="/api/replier", tags=["replier"], dependencies=[Depends(require_auth)]) # 回复器日志目录 REPLY_LOG_DIR = Path("logs/reply") diff --git a/src/webui/app.py b/src/webui/app.py index 6f3f7e74..feab298b 100644 --- a/src/webui/app.py +++ b/src/webui/app.py @@ -52,7 +52,6 @@ def _setup_cors(app: FastAPI, port: int): allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], allow_headers=[ "Content-Type", - "Authorization", "Accept", "Origin", "X-Requested-With", diff --git a/src/webui/core/__init__.py b/src/webui/core/__init__.py index 3124e897..d0b7c146 100644 --- a/src/webui/core/__init__.py +++ b/src/webui/core/__init__.py @@ -9,6 +9,7 @@ from .auth import ( COOKIE_NAME, COOKIE_MAX_AGE, get_current_token, + is_token_valid, set_auth_cookie, clear_auth_cookie, verify_auth_token_from_cookie_or_header, @@ -24,6 +25,7 @@ __all__ = [ "COOKIE_NAME", "COOKIE_MAX_AGE", "get_current_token", + "is_token_valid", "set_auth_cookie", "clear_auth_cookie", "verify_auth_token_from_cookie_or_header", diff --git a/src/webui/core/auth.py b/src/webui/core/auth.py index ff02b789..73693355 100644 --- a/src/webui/core/auth.py +++ b/src/webui/core/auth.py @@ -1,10 +1,8 @@ -""" -WebUI 认证模块 -提供统一的认证依赖,支持 Cookie 和 Header 两种方式 -""" +"""WebUI 认证模块。""" from typing import Optional -from fastapi import HTTPException, Cookie, Header, Response, Request + +from fastapi import Cookie, HTTPException, Request, Response from src.common.logger import get_logger from src.config.config import global_config from .security import get_token_manager @@ -39,17 +37,13 @@ def _is_secure_environment() -> bool: def get_current_token( - request: Request, maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), ) -> str: """ - 获取当前请求的 token,优先从 Cookie 获取,其次从 Header 获取 + 获取当前请求的 token,仅从 HttpOnly Cookie 获取。 Args: - request: FastAPI Request 对象 maibot_session: Cookie 中的 token - authorization: Authorization Header (Bearer token) Returns: 验证通过的 token @@ -57,24 +51,19 @@ def get_current_token( Raises: HTTPException: 认证失败时抛出 401 错误 """ - token = None - - # 优先从 Cookie 获取 - if maibot_session: - token = maibot_session - # 其次从 Header 获取(兼容旧版本) - elif authorization and authorization.startswith("Bearer "): - token = authorization.replace("Bearer ", "") - - if not token: - raise HTTPException(status_code=401, detail="未提供有效的认证信息") - - # 验证 token - token_manager = get_token_manager() - if not token_manager.verify_token(token): + if not is_token_valid(maibot_session): raise HTTPException(status_code=401, detail="Token 无效或已过期") - return token + return maibot_session + + +def is_token_valid(maibot_session: Optional[str]) -> bool: + """判断认证 token 是否存在且有效。""" + if not maibot_session: + return False + + token_manager = get_token_manager() + return token_manager.verify_token(maibot_session) def set_auth_cookie(response: Response, token: str, request: Optional[Request] = None) -> None: @@ -150,14 +139,12 @@ def clear_auth_cookie(response: Response) -> None: def verify_auth_token_from_cookie_or_header( maibot_session: Optional[str] = None, - authorization: Optional[str] = None, ) -> bool: """ - 验证认证 Token,支持从 Cookie 或 Header 获取 + 验证认证 Cookie。 Args: maibot_session: Cookie 中的 token - authorization: Authorization header (Bearer token) Returns: 验证成功返回 True @@ -165,21 +152,7 @@ def verify_auth_token_from_cookie_or_header( Raises: HTTPException: 认证失败时抛出 401 错误 """ - token = None - - # 优先从 Cookie 获取 - if maibot_session: - token = maibot_session - # 其次从 Header 获取(兼容旧版本) - elif authorization and authorization.startswith("Bearer "): - token = authorization.replace("Bearer ", "") - - if not token: - raise HTTPException(status_code=401, detail="未提供有效的认证信息") - - # 验证 token - token_manager = get_token_manager() - if not token_manager.verify_token(token): + if not is_token_valid(maibot_session): raise HTTPException(status_code=401, detail="Token 无效或已过期") return True diff --git a/src/webui/dependencies.py b/src/webui/dependencies.py index b7f14348..d29663fd 100644 --- a/src/webui/dependencies.py +++ b/src/webui/dependencies.py @@ -1,17 +1,16 @@ from typing import Optional -from fastapi import Depends, Cookie, Header, Request -from .core import get_current_token, get_token_manager, check_auth_rate_limit + +from fastapi import Cookie, Depends, Request +from .core import check_auth_rate_limit, get_current_token, is_token_valid async def require_auth( - request: Request, maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), ) -> str: """ FastAPI 依赖:要求有效认证 - 用于保护需要认证的路由,自动从 Cookie 或 Header 获取并验证 token + 用于保护需要认证的路由,自动从 Cookie 获取并验证 token Returns: 验证通过的 token @@ -19,13 +18,12 @@ async def require_auth( Raises: HTTPException 401: 认证失败 """ - return get_current_token(request, maibot_session, authorization) + return get_current_token(maibot_session) async def require_auth_with_rate_limit( request: Request, maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), _rate_limit: None = Depends(check_auth_rate_limit), ) -> str: """ @@ -40,12 +38,11 @@ async def require_auth_with_rate_limit( HTTPException 401: 认证失败 HTTPException 429: 请求过于频繁 """ - return get_current_token(request, maibot_session, authorization) + return get_current_token(maibot_session) def get_optional_token( maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), ) -> Optional[str]: """ FastAPI 依赖:可选获取 token(不验证) @@ -55,16 +52,11 @@ def get_optional_token( Returns: token 字符串或 None """ - if maibot_session: - return maibot_session - if authorization and authorization.startswith("Bearer "): - return authorization.replace("Bearer ", "") - return None + return maibot_session or None async def verify_token_optional( maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), ) -> bool: """ FastAPI 依赖:可选验证 token @@ -74,14 +66,4 @@ async def verify_token_optional( Returns: True 如果 token 有效,否则 False """ - token = None - if maibot_session: - token = maibot_session - elif authorization and authorization.startswith("Bearer "): - token = authorization.replace("Bearer ", "") - - if not token: - return False - - token_manager = get_token_manager() - return token_manager.verify_token(token) + return is_token_valid(maibot_session) diff --git a/src/webui/routers/annual_report.py b/src/webui/routers/annual_report.py deleted file mode 100644 index 2eeb8165..00000000 --- a/src/webui/routers/annual_report.py +++ /dev/null @@ -1,916 +0,0 @@ -"""麦麦 2025 年度总结 API 路由""" - -from datetime import datetime -from typing import Any, Optional - -from fastapi import APIRouter, Cookie, Depends, Header, HTTPException -from pydantic import BaseModel, Field -from sqlalchemy import desc, func -from sqlmodel import col, select - -from src.common.database.database import get_db_session -from src.common.database.database_model import ( - ActionRecord, - Expression, - Images, - Jargon, - Messages, - ModelUsage, - OnlineTime, - PersonInfo, -) -from src.common.logger import get_logger -from src.webui.core import verify_auth_token_from_cookie_or_header - -logger = get_logger("webui.annual_report") - -router = APIRouter(prefix="/annual-report", tags=["annual-report"]) - - -def require_auth( - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -) -> bool: - """认证依赖:验证用户是否已登录""" - return verify_auth_token_from_cookie_or_header(maibot_session, authorization) - - -# ==================== Pydantic 模型定义 ==================== - - -class TimeFootprintData(BaseModel): - """时光足迹数据""" - - total_online_hours: float = Field(0.0, description="年度在线总时长(小时)") - first_message_time: Optional[str] = Field(None, description="初次消息时间") - first_message_user: Optional[str] = Field(None, description="初次消息用户昵称") - first_message_content: Optional[str] = Field(None, description="初次消息内容(截断)") - busiest_day: Optional[str] = Field(None, description="最忙碌的一天") - busiest_day_count: int = Field(0, description="最忙碌那天的消息数") - hourly_distribution: list[int] = Field(default_factory=lambda: [0] * 24, description="24小时活跃分布") - midnight_chat_count: int = Field(0, description="深夜(0-4点)互动次数") - is_night_owl: bool = Field(False, description="是否是夜猫子") - - -class SocialNetworkData(BaseModel): - """社交网络数据""" - - total_groups: int = Field(0, description="加入的群组总数") - top_groups: list[dict[str, Any]] = Field(default_factory=list, description="话痨群组TOP5") - top_users: list[dict[str, Any]] = Field(default_factory=list, description="互动最多的用户TOP5") - at_count: int = Field(0, description="被@次数") - mentioned_count: int = Field(0, description="被提及次数") - longest_companion_user: Optional[str] = Field(None, description="最长情陪伴的用户") - longest_companion_days: int = Field(0, description="陪伴天数") - - -class BrainPowerData(BaseModel): - """最强大脑数据""" - - total_tokens: int = Field(0, description="年度消耗Token总量") - total_cost: float = Field(0.0, description="年度总花费") - favorite_model: Optional[str] = Field(None, description="最爱用的模型") - favorite_model_count: int = Field(0, description="最爱模型的调用次数") - model_distribution: list[dict[str, Any]] = Field(default_factory=list, description="模型使用分布") - top_reply_models: list[dict[str, Any]] = Field(default_factory=list, description="最喜欢的回复模型TOP5") - most_expensive_cost: float = Field(0.0, description="最昂贵的一次思考花费") - most_expensive_time: Optional[str] = Field(None, description="最昂贵思考的时间") - top_token_consumers: list[dict[str, Any]] = Field(default_factory=list, description="烧钱大户TOP3") - silence_rate: float = Field(0.0, description="高冷指数(沉默率)") - total_actions: int = Field(0, description="总动作数") - no_reply_count: int = Field(0, description="选择沉默的次数") - avg_interest_value: float = Field(0.0, description="平均兴趣值") - max_interest_value: float = Field(0.0, description="最高兴趣值") - max_interest_time: Optional[str] = Field(None, description="最高兴趣值时间") - avg_reasoning_length: float = Field(0.0, description="平均思考长度") - max_reasoning_length: int = Field(0, description="最长思考长度") - max_reasoning_time: Optional[str] = Field(None, description="最长思考的时间") - - -class ExpressionVibeData(BaseModel): - """个性与表达数据""" - - top_emoji: Optional[dict[str, Any]] = Field(None, description="表情包之王") - top_emojis: list[dict[str, Any]] = Field(default_factory=list, description="TOP3表情包") - top_expressions: list[dict[str, Any]] = Field(default_factory=list, description="印象最深刻的表达风格") - rejected_expression_count: int = Field(0, description="被拒绝的表达次数") - checked_expression_count: int = Field(0, description="已检查的表达次数") - total_expressions: int = Field(0, description="表达总数") - action_types: list[dict[str, Any]] = Field(default_factory=list, description="动作类型分布") - image_processed_count: int = Field(0, description="处理的图片数量") - late_night_reply: Optional[dict[str, Any]] = Field(None, description="深夜还在回复") - favorite_reply: Optional[dict[str, Any]] = Field(None, description="最喜欢的回复") - - -class AchievementData(BaseModel): - """趣味成就数据""" - - new_jargon_count: int = Field(0, description="新学到的黑话数量") - sample_jargons: list[dict[str, Any]] = Field(default_factory=list, description="代表性黑话示例") - total_messages: int = Field(0, description="总消息数") - total_replies: int = Field(0, description="总回复数") - - -class AnnualReportData(BaseModel): - """年度报告完整数据""" - - year: int = Field(2025, description="报告年份") - bot_name: str = Field("麦麦", description="Bot名称") - generated_at: str = Field(..., description="报告生成时间") - time_footprint: TimeFootprintData = Field(default_factory=lambda: TimeFootprintData.model_construct()) - social_network: SocialNetworkData = Field(default_factory=lambda: SocialNetworkData.model_construct()) - brain_power: BrainPowerData = Field(default_factory=lambda: BrainPowerData.model_construct()) - expression_vibe: ExpressionVibeData = Field(default_factory=lambda: ExpressionVibeData.model_construct()) - achievements: AchievementData = Field(default_factory=lambda: AchievementData.model_construct()) - - -# ==================== 辅助函数 ==================== - - -def get_year_time_range(year: int = 2025) -> tuple[float, float]: - """获取指定年份的时间戳范围""" - start = datetime(year, 1, 1, 0, 0, 0).timestamp() - end = datetime(year, 12, 31, 23, 59, 59).timestamp() - return start, end - - -def get_year_datetime_range(year: int = 2025) -> tuple[datetime, datetime]: - """获取指定年份的 datetime 范围""" - start = datetime(year, 1, 1, 0, 0, 0) - end = datetime(year, 12, 31, 23, 59, 59) - return start, end - - -# ==================== 维度一:时光足迹 ==================== - - -async def get_time_footprint(year: int = 2025) -> TimeFootprintData: - """获取时光足迹数据""" - data = TimeFootprintData.model_construct() - start_ts, end_ts = get_year_time_range(year) - start_dt, end_dt = get_year_datetime_range(year) - - try: - # 1. 年度在线时长 - with get_db_session() as session: - statement = select(OnlineTime).where( - col(OnlineTime.start_timestamp) >= start_dt, - col(OnlineTime.end_timestamp) <= end_dt, - ) - online_records = session.exec(statement).all() - total_seconds = 0 - for record in online_records: - try: - start = max(record.start_timestamp, start_dt) - end = min(record.end_timestamp, end_dt) - if end > start: - total_seconds += (end - start).total_seconds() - except Exception: - continue - data.total_online_hours = round(total_seconds / 3600, 2) - - # 2. 初次相遇 - 年度第一条消息 - with get_db_session() as session: - statement = ( - select(Messages) - .where( - col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), - col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), - ) - .order_by(col(Messages.timestamp).asc()) - .limit(1) - ) - first_msg = session.exec(statement).first() - if first_msg: - data.first_message_time = first_msg.timestamp.strftime("%Y-%m-%d %H:%M:%S") - data.first_message_user = first_msg.user_nickname or first_msg.user_id or "未知用户" - content = first_msg.processed_plain_text or first_msg.display_message or "" - data.first_message_content = content[:50] + "..." if len(content) > 50 else content - - # 3. 最忙碌的一天 - # 使用 SQLite 的 date 函数按日期分组 - day_expr = func.date(col(Messages.timestamp)) - with get_db_session() as session: - statement = ( - select( - day_expr.label("day"), - func.count().label("count"), - ) - .where( - col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), - col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), - ) - .group_by(day_expr) - .order_by(func.count().desc()) - .limit(1) - ) - busiest_result = session.exec(statement).all() - if busiest_result: - data.busiest_day = busiest_result[0][0] - data.busiest_day_count = busiest_result[0][1] or 0 - - # 4. 昼夜节律 - 24小时活跃分布 - hour_expr = func.strftime("%H", col(Messages.timestamp)) - with get_db_session() as session: - statement = ( - select( - hour_expr.label("hour"), - func.count().label("count"), - ) - .where( - col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), - col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), - ) - .group_by(hour_expr) - ) - hourly_rows = session.exec(statement).all() - hourly_distribution = [0] * 24 - for row in hourly_rows: - try: - hour = int(row[0] or 0) - if 0 <= hour < 24: - hourly_distribution[hour] = row[1] or 0 - except (ValueError, TypeError): - continue - data.hourly_distribution = hourly_distribution - - # 5. 深夜食堂 (0-4点) - data.midnight_chat_count = sum(hourly_distribution[0:5]) - - # 6. 判断是否夜猫子 (22点-4点活跃度 vs 6点-12点) - night_activity = sum(hourly_distribution[22:24]) + sum(hourly_distribution[0:5]) - morning_activity = sum(hourly_distribution[6:13]) - data.is_night_owl = night_activity > morning_activity - - except Exception as e: - logger.error(f"获取时光足迹数据失败: {e}") - - return data - - -# ==================== 维度二:社交网络 ==================== - - -async def get_social_network(year: int = 2025) -> SocialNetworkData: - """获取社交网络数据""" - from src.config.config import global_config - - data = SocialNetworkData.model_construct() - start_ts, end_ts = get_year_time_range(year) - - # 获取 bot 自身的 QQ 账号,用于过滤 - bot_qq = str(global_config.bot.qq_account or "") - - try: - # 1. 加入的群组总数 - with get_db_session() as session: - statement = select(func.count(func.distinct(col(Messages.group_id)))).where( - col(Messages.group_id).is_not(None), - col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), - col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), - ) - data.total_groups = int(session.exec(statement).first() or 0) - - # 2. 话痨群组 TOP3 - with get_db_session() as session: - statement = ( - select( - col(Messages.group_id), - func.max(col(Messages.group_name)).label("group_name"), - func.count().label("count"), - ) - .where( - col(Messages.group_id).is_not(None), - col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), - col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), - ) - .group_by(col(Messages.group_id)) - .order_by(func.count().desc()) - .limit(5) - ) - top_groups_rows = session.exec(statement).all() - data.top_groups = [ - { - "group_id": row[0], - "group_name": row[1] or "未知群组", - "message_count": row[2] or 0, - "is_webui": str(row[0]).startswith("webui_"), - } - for row in top_groups_rows - ] - - # 3. 互动最多的用户 TOP5(过滤 bot 自身) - with get_db_session() as session: - statement = ( - select( - col(Messages.user_id), - func.max(col(Messages.user_nickname)).label("user_nickname"), - func.count().label("count"), - ) - .where( - col(Messages.user_id).is_not(None), - col(Messages.user_id) != bot_qq, - col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), - col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), - ) - .group_by(col(Messages.user_id)) - .order_by(func.count().desc()) - .limit(5) - ) - top_users_rows = session.exec(statement).all() - data.top_users = [ - { - "user_id": row[0], - "user_nickname": row[1] or "未知用户", - "message_count": row[2] or 0, - "is_webui": str(row[0]).startswith("webui_"), - } - for row in top_users_rows - ] - - # 4. 被@次数 - with get_db_session() as session: - statement = select(func.count()).where( - col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), - col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), - col(Messages.is_at), - ) - data.at_count = int(session.exec(statement).first() or 0) - - # 5. 被提及次数 - with get_db_session() as session: - statement = select(func.count()).where( - col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), - col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), - col(Messages.is_mentioned), - ) - data.mentioned_count = int(session.exec(statement).first() or 0) - - # 6. 最长情陪伴的用户(过滤 bot 自身) - with get_db_session() as session: - statement = select(PersonInfo).where( - col(PersonInfo.user_id) != bot_qq, - col(PersonInfo.first_known_time).is_not(None), - col(PersonInfo.last_known_time).is_not(None), - ) - persons = session.exec(statement).all() - if persons: - - def _companion_days(person: PersonInfo) -> float: - if not person.first_known_time or not person.last_known_time: - return 0.0 - return (person.last_known_time - person.first_known_time).total_seconds() - - longest = max(persons, key=_companion_days) - data.longest_companion_user = longest.person_name or longest.user_nickname or longest.user_id - data.longest_companion_days = int(_companion_days(longest) / 86400) - else: - data.longest_companion_user = None - data.longest_companion_days = 0 - - except Exception as e: - logger.error(f"获取社交网络数据失败: {e}") - - return data - - -# ==================== 维度三:最强大脑 ==================== - - -async def get_brain_power(year: int = 2025) -> BrainPowerData: - """获取最强大脑数据""" - data = BrainPowerData.model_construct() - start_dt, end_dt = get_year_datetime_range(year) - start_ts, end_ts = get_year_time_range(year) - - try: - # 1. 年度消耗 Token 总量和总花费 - with get_db_session() as session: - statement = select( - func.sum(col(ModelUsage.total_tokens)).label("total_tokens"), - func.sum(col(ModelUsage.cost)).label("total_cost"), - ).where(col(ModelUsage.timestamp) >= start_dt, col(ModelUsage.timestamp) <= end_dt) - result = session.exec(statement).first() - if result: - data.total_tokens = int(result[0] or 0) - data.total_cost = round(float(result[1] or 0), 4) - - # 2. 最爱用的模型 - with get_db_session() as session: - statement = ( - select(ModelUsage) - .where(col(ModelUsage.timestamp) >= start_dt, col(ModelUsage.timestamp) <= end_dt) - .order_by(desc(col(ModelUsage.timestamp))) - ) - records = session.exec(statement).all() - - model_agg: dict[str, dict[str, float | int]] = {} - for record in records: - model_name = record.model_assign_name or record.model_name or "unknown" - if model_name not in model_agg: - model_agg[model_name] = {"count": 0, "tokens": 0, "cost": 0.0} - bucket = model_agg[model_name] - bucket["count"] = int(bucket["count"]) + 1 - bucket["tokens"] = int(bucket["tokens"]) + int(record.total_tokens or 0) - bucket["cost"] = float(bucket["cost"]) + float(record.cost or 0.0) - - model_results = sorted( - model_agg.items(), - key=lambda item: float(item[1]["count"]), - reverse=True, - )[:10] - if model_results: - data.favorite_model = model_results[0][0] - data.favorite_model_count = int(model_results[0][1]["count"]) - data.model_distribution = [ - { - "model": model_name, - "count": int(bucket["count"]), - "tokens": int(bucket["tokens"]), - "cost": round(float(bucket["cost"]), 4), - } - for model_name, bucket in model_results - ] - - # 3. 最昂贵的一次思考 - if records: - expensive_record = max(records, key=lambda record: record.cost or 0.0) - data.most_expensive_cost = round(expensive_record.cost or 0.0, 4) - data.most_expensive_time = expensive_record.timestamp.strftime("%Y-%m-%d %H:%M:%S") - - # 4. 烧钱大户 TOP3 (按用户,过滤 system) - consumer_agg: dict[str, dict[str, float | int]] = {} - for record in records: - user_id = record.model_api_provider_name - if not user_id or user_id == "system": - continue - if user_id not in consumer_agg: - consumer_agg[user_id] = {"cost": 0.0, "tokens": 0} - bucket = consumer_agg[user_id] - bucket["cost"] = float(bucket["cost"]) + float(record.cost or 0.0) - bucket["tokens"] = int(bucket["tokens"]) + int(record.total_tokens or 0) - - data.top_token_consumers = [ - { - "user_id": user_id, - "cost": round(float(bucket["cost"]), 4), - "tokens": int(bucket["tokens"]), - } - for user_id, bucket in sorted( - consumer_agg.items(), - key=lambda item: float(item[1]["cost"]), - reverse=True, - )[:3] - ] - - # 5. 最喜欢的回复模型 TOP5(按模型的回复次数统计,只统计 replyer 调用) - # 假设 replyer 调用有特定的 model_assign_name 格式或可以通过某种方式识别 - reply_model_agg: dict[str, int] = {} - for record in records: - model_assign_name = record.model_assign_name or "" - if "replyer" not in model_assign_name and "回复" not in model_assign_name: - continue - model_name = model_assign_name or record.model_name or "unknown" - reply_model_agg[model_name] = reply_model_agg.get(model_name, 0) + 1 - data.top_reply_models = [ - {"model": model_name, "count": count} - for model_name, count in sorted(reply_model_agg.items(), key=lambda item: item[1], reverse=True)[:5] - ] - - # 6. 高冷指数 (沉默率) - 基于 ActionRecords - with get_db_session() as session: - statement = select(func.count()).where( - col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts), - col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts), - ) - total_actions = int(session.exec(statement).first() or 0) - with get_db_session() as session: - statement = select(func.count()).where( - col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts), - col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts), - col(ActionRecord.action_name) == "no_reply", - ) - no_reply_count = int(session.exec(statement).first() or 0) - data.total_actions = total_actions - data.no_reply_count = no_reply_count - data.silence_rate = round(no_reply_count / total_actions * 100, 2) if total_actions > 0 else 0 - - # 6. 情绪波动 (兴趣值) - data.avg_interest_value = 0.0 - data.max_interest_value = 0.0 - - # 找到最高兴趣值的时间 - if data.max_interest_value > 0: - data.max_interest_time = None - - # 7. 思考深度 (基于 action_reasoning 长度) - with get_db_session() as session: - statement = select(ActionRecord).where( - col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts), - col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts), - col(ActionRecord.action_reasoning).is_not(None), - col(ActionRecord.action_reasoning) != "", - ) - reasoning_records = session.exec(statement).all() - reasoning_lengths = [] - max_len = 0 - max_len_time = None - for record in reasoning_records: - if record.action_reasoning: - length = len(record.action_reasoning) - reasoning_lengths.append(length) - if length > max_len: - max_len = length - max_len_time = record.timestamp - - if reasoning_lengths: - data.avg_reasoning_length = round(sum(reasoning_lengths) / len(reasoning_lengths), 1) - data.max_reasoning_length = max_len - if max_len_time: - data.max_reasoning_time = max_len_time.strftime("%Y-%m-%d %H:%M:%S") - - except Exception as e: - logger.error(f"获取最强大脑数据失败: {e}") - - return data - - -# ==================== 维度四:个性与表达 ==================== - - -async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: - """获取个性与表达数据""" - from src.config.config import global_config - - data = ExpressionVibeData.model_construct() - start_ts, end_ts = get_year_time_range(year) - - # 获取 bot 自身的 QQ 账号,用于筛选 bot 发送的消息 - bot_qq = str(global_config.bot.qq_account or "") - - try: - # 1. 表情包之王 - 使用次数最多的表情包 - with get_db_session() as session: - statement = select(Images).where(col(Images.is_registered)).order_by(desc(col(Images.query_count))).limit(5) - top_emojis = session.exec(statement).all() - if top_emojis: - data.top_emoji = { - "id": top_emojis[0].id, - "path": top_emojis[0].full_path, - "description": top_emojis[0].description, - "usage_count": top_emojis[0].query_count, - "hash": top_emojis[0].image_hash, - } - data.top_emojis = [ - { - "id": e.id, - "path": e.full_path, - "description": e.description, - "usage_count": e.query_count, - "hash": e.image_hash, - } - for e in top_emojis - ] - - # 2. 百变麦麦 - 最常用的表达风格 - with get_db_session() as session: - statement = ( - select(Expression.style, func.sum(col(Expression.count)).label("total_count")) - .where( - col(Expression.last_active_time) >= datetime.fromtimestamp(start_ts), - col(Expression.last_active_time) <= datetime.fromtimestamp(end_ts), - ) - .group_by(Expression.style) - .order_by(func.sum(col(Expression.count)).desc()) - .limit(5) - ) - expression_rows = session.exec(statement).all() - data.top_expressions = [{"style": row[0], "count": row[1] or 0} for row in expression_rows] - - # 3. 被拒绝的表达 - data.rejected_expression_count = 0 - - # 4. 已检查的表达 - data.checked_expression_count = 0 - - # 5. 表达总数 - with get_db_session() as session: - statement = select(func.count()).where( - col(Expression.last_active_time) >= datetime.fromtimestamp(start_ts), - col(Expression.last_active_time) <= datetime.fromtimestamp(end_ts), - ) - data.total_expressions = int(session.exec(statement).first() or 0) - - # 6. 动作类型分布 (过滤无意义的动作) - # 过滤掉: no_reply_until_call, make_question, no_action, wait, complete_talk, listening, block_and_ignore - excluded_actions = [ - "reply", - "no_reply", - "no_reply_until_call", - "make_question", - "no_action", - "wait", - "complete_talk", - "listening", - "block_and_ignore", - ] - with get_db_session() as session: - statement = ( - select(ActionRecord.action_name, func.count().label("count")) - .where( - col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts), - col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts), - col(ActionRecord.action_name).not_in(excluded_actions), - ) - .group_by(ActionRecord.action_name) - .order_by(func.count().desc()) - .limit(10) - ) - action_rows = session.exec(statement).all() - data.action_types = [{"action": row[0], "count": row[1]} for row in action_rows] - - # 7. 处理的图片数量 - with get_db_session() as session: - statement = select(func.count()).where( - col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), - col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), - col(Messages.is_picture), - ) - data.image_processed_count = int(session.exec(statement).first() or 0) - - # 8. 深夜还在回复 (0-6点最晚的10条消息中随机抽取一条) - import random - import re - - def clean_message_content(content: str) -> str: - """清理消息内容,移除回复引用等标记""" - if not content: - return "" - # 移除 [回复 的消息:...] 格式的引用 - content = re.sub(r"\[回复<[^>]+>\s*的消息[::][^\]]*\]", "", content) - # 移除 [图片] [表情] 等标记 - content = re.sub(r"\[(图片|表情|语音|视频|文件)\]", "", content) - # 移除多余的空白 - content = re.sub(r"\s+", " ", content).strip() - return content - - # 使用 user_id 判断是否是 bot 发送的消息 - with get_db_session() as session: - statement = ( - select(Messages) - .where( - col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), - col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), - col(Messages.user_id) == bot_qq, - ) - .order_by(desc(col(Messages.timestamp))) - .limit(200) - ) - late_night_messages = session.exec(statement).all() - # 筛选出0-6点的消息 - late_night_filtered = [] - for msg in late_night_messages: - msg_dt = msg.timestamp - hour = msg_dt.hour - if 0 <= hour < 6: # 0点到6点 - raw_content = msg.processed_plain_text or msg.display_message or "" - cleaned_content = clean_message_content(raw_content) - # 只保留有意义的内容 - if cleaned_content and len(cleaned_content) > 2: - late_night_filtered.append( - { - "time": msg_dt.timestamp(), - "hour": hour, - "minute": msg_dt.minute, - "content": cleaned_content, - "datetime_str": msg_dt.strftime("%H:%M"), - } - ) - if len(late_night_filtered) >= 10: - break - - if late_night_filtered: - selected = random.choice(late_night_filtered) - content = selected["content"][:50] + "..." if len(selected["content"]) > 50 else selected["content"] - data.late_night_reply = { - "time": selected["datetime_str"], - "content": content, - } - - # 9. 最喜欢的回复(按 action_data 统计回复内容出现次数) - from collections import Counter - import json as json_lib - - with get_db_session() as session: - statement = select(ActionRecord).where( - col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts), - col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts), - col(ActionRecord.action_name) == "reply", - col(ActionRecord.action_data).is_not(None), - col(ActionRecord.action_data) != "", - ) - reply_records = session.exec(statement).all() - - reply_contents = [] - for record in reply_records: - try: - action_data = record.action_data - if action_data: - content = None - # 尝试解析 JSON 格式 - try: - parsed = json_lib.loads(action_data) - if isinstance(parsed, dict): - # 优先使用 reply_text,其次使用 content - content = parsed.get("reply_text") or parsed.get("content") - elif isinstance(parsed, str): - content = parsed - except (json_lib.JSONDecodeError, TypeError): - pass - - # 如果 JSON 解析失败,尝试解析 Python 字典字符串格式 - # 例如: "{'reply_text': '墨白灵不知道哦'}" - if content is None: - import ast - - try: - parsed = ast.literal_eval(action_data) - if isinstance(parsed, dict): - content = parsed.get("reply_text") or parsed.get("content") - elif isinstance(parsed, str): - content = parsed - except (ValueError, SyntaxError): - # 无法解析,使用原始字符串 - content = action_data - - # 只统计有意义的回复(长度大于2) - if content and len(content) > 2: - reply_contents.append(content) - except Exception: - continue - - if reply_contents: - content_counter = Counter(reply_contents) - most_common = content_counter.most_common(1) - if most_common: - fav_content, fav_count = most_common[0] - # 截断过长的内容 - display_content = fav_content[:50] + "..." if len(fav_content) > 50 else fav_content - data.favorite_reply = { - "content": display_content, - "count": fav_count, - } - - except Exception as e: - logger.error(f"获取个性与表达数据失败: {e}") - - return data - - -# ==================== 维度五:趣味成就 ==================== - - -async def get_achievements(year: int = 2025) -> AchievementData: - """获取趣味成就数据""" - data = AchievementData.model_construct() - start_ts, end_ts = get_year_time_range(year) - - try: - # 1. 新学到的黑话数量 - # Jargon 表没有时间字段,统计全部已确认的黑话 - with get_db_session() as session: - statement = select(func.count()).where(col(Jargon.is_jargon)) - data.new_jargon_count = int(session.exec(statement).first() or 0) - - # 2. 代表性黑话示例 - with get_db_session() as session: - statement = select(Jargon).where(col(Jargon.is_jargon)).order_by(desc(col(Jargon.count))).limit(5) - jargon_samples = session.exec(statement).all() - data.sample_jargons = [ - { - "content": j.content, - "meaning": j.meaning, - "count": j.count, - } - for j in jargon_samples - ] - - # 3. 总消息数 - with get_db_session() as session: - statement = select(func.count()).where( - col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), - col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), - ) - data.total_messages = int(session.exec(statement).first() or 0) - - # 4. 总回复数 (有 reply_to 的消息) - with get_db_session() as session: - statement = select(func.count()).where( - col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), - col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), - col(Messages.reply_to).is_not(None), - ) - data.total_replies = int(session.exec(statement).first() or 0) - - except Exception as e: - logger.error(f"获取趣味成就数据失败: {e}") - - return data - - -# ==================== API 路由 ==================== - - -@router.get("/full", response_model=AnnualReportData) -async def get_full_annual_report(year: int = 2025, _auth: bool = Depends(require_auth)): - """ - 获取完整年度报告数据 - - Args: - year: 报告年份,默认2025 - - Returns: - 完整的年度报告数据 - """ - try: - from src.config.config import global_config - - logger.info(f"开始生成 {year} 年度报告...") - - # 获取 bot 名称 - bot_name = global_config.bot.nickname or "麦麦" - - # 并行获取各维度数据 - time_footprint = await get_time_footprint(year) - social_network = await get_social_network(year) - brain_power = await get_brain_power(year) - expression_vibe = await get_expression_vibe(year) - achievements = await get_achievements(year) - - report = AnnualReportData( - year=year, - bot_name=bot_name, - generated_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - time_footprint=time_footprint, - social_network=social_network, - brain_power=brain_power, - expression_vibe=expression_vibe, - achievements=achievements, - ) - - logger.info(f"{year} 年度报告生成完成") - return report - - except Exception as e: - logger.error(f"生成年度报告失败: {e}") - raise HTTPException(status_code=500, detail=f"生成年度报告失败: {str(e)}") from e - - -@router.get("/time-footprint", response_model=TimeFootprintData) -async def get_time_footprint_api(year: int = 2025, _auth: bool = Depends(require_auth)): - """获取时光足迹数据""" - try: - return await get_time_footprint(year) - except Exception as e: - logger.error(f"获取时光足迹数据失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) from e - - -@router.get("/social-network", response_model=SocialNetworkData) -async def get_social_network_api(year: int = 2025, _auth: bool = Depends(require_auth)): - """获取社交网络数据""" - try: - return await get_social_network(year) - except Exception as e: - logger.error(f"获取社交网络数据失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) from e - - -@router.get("/brain-power", response_model=BrainPowerData) -async def get_brain_power_api(year: int = 2025, _auth: bool = Depends(require_auth)): - """获取最强大脑数据""" - try: - return await get_brain_power(year) - except Exception as e: - logger.error(f"获取最强大脑数据失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) from e - - -@router.get("/expression-vibe", response_model=ExpressionVibeData) -async def get_expression_vibe_api(year: int = 2025, _auth: bool = Depends(require_auth)): - """获取个性与表达数据""" - try: - return await get_expression_vibe(year) - except Exception as e: - logger.error(f"获取个性与表达数据失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) from e - - -@router.get("/achievements", response_model=AchievementData) -async def get_achievements_api(year: int = 2025, _auth: bool = Depends(require_auth)): - """获取趣味成就数据""" - try: - return await get_achievements(year) - except Exception as e: - logger.error(f"获取趣味成就数据失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) from e diff --git a/src/webui/routers/chat.py b/src/webui/routers/chat.py deleted file mode 100644 index e1e71780..00000000 --- a/src/webui/routers/chat.py +++ /dev/null @@ -1,801 +0,0 @@ -"""本地聊天室路由 - WebUI 与麦麦直接对话 - -支持两种模式: -1. WebUI 模式:使用 WebUI 平台独立身份聊天 -2. 虚拟身份模式:使用真实平台用户的身份,在虚拟群聊中与麦麦对话 -""" - -import time -import uuid -from typing import Any, Dict, List, Optional - -from fastapi import APIRouter, Cookie, Depends, Header, Query, WebSocket, WebSocketDisconnect -from pydantic import BaseModel -from sqlalchemy import case, desc, func -from sqlmodel import col, select, delete - -from src.chat.message_receive.bot import chat_bot -from src.common.database.database import get_db_session -from src.common.database.database_model import Messages, PersonInfo -from src.common.logger import get_logger -from src.config.config import global_config -from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header -from src.webui.routers.websocket.auth import verify_ws_token - -logger = get_logger("webui.chat") - -router = APIRouter(prefix="/api/chat", tags=["LocalChat"]) - - -def require_auth( - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -) -> bool: - """认证依赖:验证用户是否已登录""" - return verify_auth_token_from_cookie_or_header(maibot_session, authorization) - - -# WebUI 聊天的虚拟群组 ID -WEBUI_CHAT_GROUP_ID = "webui_local_chat" -WEBUI_CHAT_PLATFORM = "webui" - -# 虚拟身份模式的群 ID 前缀 -VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_" - -# 固定的 WebUI 用户 ID 前缀 -WEBUI_USER_ID_PREFIX = "webui_user_" - - -class VirtualIdentityConfig(BaseModel): - """虚拟身份配置""" - - enabled: bool = False # 是否启用虚拟身份模式 - platform: Optional[str] = None # 目标平台(如 qq, discord 等) - person_id: Optional[str] = None # PersonInfo 的 person_id - user_id: Optional[str] = None # 原始平台用户 ID - user_nickname: Optional[str] = None # 用户昵称 - group_id: Optional[str] = None # 虚拟群 ID(自动生成或用户指定) - group_name: Optional[str] = None # 虚拟群名(用户自定义) - - -class ChatHistoryMessage(BaseModel): - """聊天历史消息""" - - id: str - type: str # 'user' | 'bot' | 'system' - content: str - timestamp: float - sender_name: str - sender_id: Optional[str] = None - is_bot: bool = False - - -class ChatHistoryManager: - """聊天历史管理器 - 使用 SQLite 数据库存储""" - - def __init__(self, max_messages: int = 200): - self.max_messages = max_messages - - def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]: - """将数据库消息转换为前端格式 - - Args: - msg: 数据库消息对象 - group_id: 群 ID,用于判断是否是虚拟群 - """ - # 判断是否是机器人消息 - user_id = msg.user_id or "" - - # 对于虚拟群,通过比较机器人 QQ 账号来判断 - # 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头 - if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX): - # 虚拟群:user_id 等于机器人 QQ 账号的是机器人消息 - bot_qq = str(global_config.bot.qq_account) - is_bot = user_id == bot_qq - else: - # 普通 WebUI 群:不以 webui_ 开头的是机器人消息 - is_bot = not user_id.startswith("webui_") and not user_id.startswith(WEBUI_USER_ID_PREFIX) - - return { - "id": msg.message_id, - "type": "bot" if is_bot else "user", - "content": msg.processed_plain_text or msg.display_message or "", - "timestamp": msg.timestamp.timestamp(), - "sender_name": msg.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"), - "sender_id": "bot" if is_bot else user_id, - "is_bot": is_bot, - } - - def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]: - """从数据库获取最近的历史记录 - - Args: - limit: 获取的消息数量 - group_id: 群 ID,默认为 WEBUI_CHAT_GROUP_ID - """ - target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID - try: - # 查询指定群的消息,按时间排序 - with get_db_session() as session: - statement = ( - select(Messages) - .where(col(Messages.group_id) == target_group_id) - .order_by(desc(col(Messages.timestamp))) - .limit(limit) - ) - messages = session.exec(statement).all() - - # 转换为列表并反转(使最旧的消息在前) - # 传递 group_id 以便正确判断虚拟群中的机器人消息 - result = [self._message_to_dict(msg, target_group_id) for msg in messages] - result.reverse() - - logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})") - return result - except Exception as e: - logger.error(f"从数据库加载聊天记录失败: {e}") - return [] - - def clear_history(self, group_id: Optional[str] = None) -> int: - """清空聊天历史记录 - - Args: - group_id: 群 ID,默认清空 WebUI 默认聊天室 - """ - target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID - try: - with get_db_session() as session: - statement = delete(Messages).where(col(Messages.group_id) == target_group_id) - result = session.exec(statement) - deleted = result.rowcount or 0 - logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})") - return deleted - except Exception as e: - logger.error(f"清空聊天记录失败: {e}") - return 0 - - -# 全局聊天历史管理器 -chat_history = ChatHistoryManager() - - -# 存储 WebSocket 连接 -class ChatConnectionManager: - """聊天连接管理器""" - - def __init__(self): - self.active_connections: Dict[str, WebSocket] = {} - self.user_sessions: Dict[str, str] = {} # user_id -> session_id 映射 - - async def connect(self, websocket: WebSocket, session_id: str, user_id: str): - await websocket.accept() - self.active_connections[session_id] = websocket - self.user_sessions[user_id] = session_id - logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}") - - def disconnect(self, session_id: str, user_id: str): - if session_id in self.active_connections: - del self.active_connections[session_id] - if user_id in self.user_sessions and self.user_sessions[user_id] == session_id: - del self.user_sessions[user_id] - logger.info(f"WebUI 聊天会话已断开: session={session_id}") - - async def send_message(self, session_id: str, message: dict[str, Any]): - if session_id in self.active_connections: - try: - await self.active_connections[session_id].send_json(message) - except Exception as e: - logger.error(f"发送消息失败: {e}") - - async def broadcast(self, message: dict[str, Any]): - """广播消息给所有连接""" - for session_id in list(self.active_connections.keys()): - await self.send_message(session_id, message) - - -chat_manager = ChatConnectionManager() - - -def create_message_data( - content: str, - user_id: str, - user_name: str, - message_id: Optional[str] = None, - is_at_bot: bool = True, - virtual_config: Optional[VirtualIdentityConfig] = None, -) -> Dict[str, Any]: - """创建符合麦麦消息格式的消息数据 - - Args: - content: 消息内容 - user_id: 用户 ID - user_name: 用户昵称 - message_id: 消息 ID(可选,自动生成) - is_at_bot: 是否 @ 机器人 - virtual_config: 虚拟身份配置(可选,启用后使用真实平台身份) - """ - if message_id is None: - message_id = str(uuid.uuid4()) - - # 确定使用的平台、群信息和用户信息 - if virtual_config and virtual_config.enabled: - # 虚拟身份模式:使用真实平台身份 - platform = virtual_config.platform or WEBUI_CHAT_PLATFORM - group_id = virtual_config.group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{uuid.uuid4().hex[:8]}" - group_name = virtual_config.group_name or "WebUI虚拟群聊" - actual_user_id = virtual_config.user_id or user_id - actual_user_name = virtual_config.user_nickname or user_name - else: - # 标准 WebUI 模式 - platform = WEBUI_CHAT_PLATFORM - group_id = WEBUI_CHAT_GROUP_ID - group_name = "WebUI本地聊天室" - actual_user_id = user_id - actual_user_name = user_name - - return { - "message_info": { - "platform": platform, - "message_id": message_id, - "time": time.time(), - "group_info": { - "group_id": group_id, - "group_name": group_name, - "platform": platform, - }, - "user_info": { - "user_id": actual_user_id, - "user_nickname": actual_user_name, - "user_cardname": actual_user_name, - "platform": platform, - }, - "additional_config": { - "at_bot": is_at_bot, - }, - }, - "message_segment": { - "type": "seglist", - "data": [ - { - "type": "text", - "data": content, - }, - { - "type": "mention_bot", - "data": "1.0", - }, - ], - }, - "raw_message": content, - "processed_plain_text": content, - } - - -@router.get("/history") -async def get_chat_history( - limit: int = Query(default=50, ge=1, le=200), - user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤 - group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史 - _auth: bool = Depends(require_auth), -): - """获取聊天历史记录 - - 所有 WebUI 用户共享同一个聊天室,因此返回所有历史记录 - 如果指定了 group_id,则获取该虚拟群的历史记录 - """ - target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID - history = chat_history.get_history(limit, target_group_id) - return { - "success": True, - "messages": history, - "total": len(history), - } - - -@router.get("/platforms") -async def get_available_platforms(_auth: bool = Depends(require_auth)): - """获取可用平台列表 - - 从 PersonInfo 表中获取所有已知的平台 - """ - try: - # 查询所有不同的平台 - with get_db_session() as session: - statement = ( - select(PersonInfo.platform, func.count().label("count")) - .group_by(PersonInfo.platform) - .order_by(func.count().desc()) - ) - platforms = session.exec(statement).all() - - result = [] - for platform, count in platforms: - if platform: - result.append({"platform": platform, "count": count}) - - return {"success": True, "platforms": result} - except Exception as e: - logger.error(f"获取平台列表失败: {e}") - return {"success": False, "error": str(e), "platforms": []} - - -@router.get("/persons") -async def get_persons_by_platform( - platform: str = Query(..., description="平台名称"), - search: Optional[str] = Query(default=None, description="搜索关键词"), - limit: int = Query(default=50, ge=1, le=200), - _auth: bool = Depends(require_auth), -): - """获取指定平台的用户列表 - - Args: - platform: 平台名称(如 qq, discord 等) - search: 搜索关键词(匹配昵称、用户名、user_id) - limit: 返回数量限制 - """ - try: - # 构建查询 - statement = select(PersonInfo).where(col(PersonInfo.platform) == platform) - - # 搜索过滤 - if search: - statement = statement.where( - (col(PersonInfo.person_name).contains(search)) - | (col(PersonInfo.user_nickname).contains(search)) - | (col(PersonInfo.user_id).contains(search)) - ) - - # 按最后交互时间排序,优先显示活跃用户 - statement = statement.order_by( - case((col(PersonInfo.last_known_time).is_(None), 1), else_=0), - col(PersonInfo.last_known_time).desc(), - ).limit(limit) - - with get_db_session() as session: - persons = session.exec(statement).all() - - result = [] - for person in persons: - result.append( - { - "person_id": person.person_id, - "user_id": person.user_id, - "person_name": person.person_name, - "nickname": person.user_nickname, - "is_known": person.is_known, - "platform": person.platform, - "display_name": person.person_name or person.user_nickname or person.user_id, - } - ) - - return {"success": True, "persons": result, "total": len(result)} - except Exception as e: - logger.error(f"获取用户列表失败: {e}") - return {"success": False, "error": str(e), "persons": []} - - -@router.delete("/history") -async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)): - """清空聊天历史记录 - - Args: - group_id: 可选,指定要清空的群 ID,默认清空 WebUI 默认聊天室 - """ - deleted = chat_history.clear_history(group_id) - return { - "success": True, - "message": f"已清空 {deleted} 条聊天记录", - } - - -@router.websocket("/ws") -async def websocket_chat( - websocket: WebSocket, - user_id: Optional[str] = Query(default=None), - user_name: Optional[str] = Query(default="WebUI用户"), - platform: Optional[str] = Query(default=None), - person_id: Optional[str] = Query(default=None), - group_name: Optional[str] = Query(default=None), - group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id - token: Optional[str] = Query(default=None), # 认证 token -): - """WebSocket 聊天端点 - - Args: - user_id: 用户唯一标识(由前端生成并持久化) - user_name: 用户显示昵称(可修改) - platform: 虚拟身份模式的平台(可选) - person_id: 虚拟身份模式的用户 person_id(可选) - group_name: 虚拟身份模式的群名(可选) - group_id: 虚拟身份模式的群 ID(可选,由前端生成并持久化) - token: 认证 token(可选,也可从 Cookie 获取) - - 虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置 - - 支持三种认证方式(按优先级): - 1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token) - 2. Cookie 中的 maibot_session - 3. 直接使用 session token(兼容) - - 示例:ws://host/api/chat/ws?token=xxx - """ - is_authenticated = False - - # 方式 1: 尝试验证临时 WebSocket token(推荐方式) - if token and verify_ws_token(token): - is_authenticated = True - logger.debug("聊天 WebSocket 使用临时 token 认证成功") - - # 方式 2: 尝试从 Cookie 获取 session token - if not is_authenticated: - cookie_token = websocket.cookies.get("maibot_session") - if cookie_token: - token_manager = get_token_manager() - if token_manager.verify_token(cookie_token): - is_authenticated = True - logger.debug("聊天 WebSocket 使用 Cookie 认证成功") - - # 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式) - if not is_authenticated and token: - token_manager = get_token_manager() - if token_manager.verify_token(token): - is_authenticated = True - logger.debug("聊天 WebSocket 使用 session token 认证成功") - - if not is_authenticated: - logger.warning("聊天 WebSocket 连接被拒绝:认证失败") - await websocket.close(code=4001, reason="认证失败,请重新登录") - return - - # 生成会话 ID(每次连接都是新的) - session_id = str(uuid.uuid4()) - - # 如果没有提供 user_id,生成一个新的 - if not user_id: - user_id = f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}" - elif not user_id.startswith(WEBUI_USER_ID_PREFIX): - # 确保 user_id 有正确的前缀 - user_id = f"{WEBUI_USER_ID_PREFIX}{user_id}" - - # 当前会话的虚拟身份配置(可通过消息动态更新) - current_virtual_config: Optional[VirtualIdentityConfig] = None - - # 如果 URL 参数中提供了虚拟身份信息,自动配置 - if platform and person_id: - try: - with get_db_session() as session: - statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1) - person = session.exec(statement).first() - if person: - # 使用前端传递的 group_id,如果没有则生成一个稳定的 - virtual_group_id = group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{platform}_{person.user_id}" - current_virtual_config = VirtualIdentityConfig( - enabled=True, - platform=person.platform, - person_id=person.person_id, - user_id=person.user_id, - user_nickname=person.person_name or person.user_nickname or person.user_id, - group_id=virtual_group_id, - group_name=group_name or "WebUI虚拟群聊", - ) - logger.info( - f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}" - ) - except Exception as e: - logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}") - - await chat_manager.connect(websocket, session_id, user_id) - - try: - # 构建会话信息 - session_info_data: dict[str, Any] = { - "type": "session_info", - "session_id": session_id, - "user_id": user_id, - "user_name": user_name, - "bot_name": global_config.bot.nickname, - } - - # 如果有虚拟身份配置,添加到会话信息中 - if current_virtual_config and current_virtual_config.enabled: - session_info_data["virtual_mode"] = True - session_info_data["group_id"] = current_virtual_config.group_id - session_info_data["virtual_identity"] = { - "platform": current_virtual_config.platform, - "user_id": current_virtual_config.user_id, - "user_nickname": current_virtual_config.user_nickname, - "group_name": current_virtual_config.group_name, - } - - # 发送会话信息(包含用户 ID,前端需要保存) - await chat_manager.send_message(session_id, session_info_data) - - # 发送历史记录(根据模式选择不同的群) - if current_virtual_config and current_virtual_config.enabled: - history = chat_history.get_history(50, current_virtual_config.group_id) - else: - history = chat_history.get_history(50) - if history: - await chat_manager.send_message( - session_id, - { - "type": "history", - "messages": history, - }, - ) - - # 发送欢迎消息(不保存到历史) - if current_virtual_config and current_virtual_config.enabled: - welcome_msg = f"已以 {current_virtual_config.user_nickname} 的身份连接到「{current_virtual_config.group_name}」,开始与 {global_config.bot.nickname} 对话吧!" - else: - welcome_msg = f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!" - - await chat_manager.send_message( - session_id, - { - "type": "system", - "content": welcome_msg, - "timestamp": time.time(), - }, - ) - - while True: - data = await websocket.receive_json() - - if data.get("type") == "message": - content = data.get("content", "").strip() - if not content: - continue - - # 用户可以更新昵称 - current_user_name = data.get("user_name", user_name) - - message_id = str(uuid.uuid4()) - timestamp = time.time() - - # 确定发送者信息(根据是否使用虚拟身份) - if current_virtual_config and current_virtual_config.enabled: - sender_name = current_virtual_config.user_nickname or current_user_name - sender_user_id = current_virtual_config.user_id or user_id - else: - sender_name = current_user_name - sender_user_id = user_id - - # 广播用户消息给所有连接(包括发送者) - # 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库 - await chat_manager.broadcast( - { - "type": "user_message", - "content": content, - "message_id": message_id, - "timestamp": timestamp, - "sender": { - "name": sender_name, - "user_id": sender_user_id, - "is_bot": False, - }, - "virtual_mode": current_virtual_config.enabled if current_virtual_config else False, - } - ) - - # 创建麦麦消息格式 - message_data = create_message_data( - content=content, - user_id=user_id, - user_name=current_user_name, - message_id=message_id, - is_at_bot=True, - virtual_config=current_virtual_config, - ) - - try: - # 显示正在输入状态 - await chat_manager.broadcast( - { - "type": "typing", - "is_typing": True, - } - ) - - # 调用麦麦的消息处理 - await chat_bot.message_process(message_data) - - except Exception as e: - logger.error(f"处理消息时出错: {e}") - await chat_manager.send_message( - session_id, - { - "type": "error", - "content": f"处理消息时出错: {str(e)}", - "timestamp": time.time(), - }, - ) - finally: - await chat_manager.broadcast( - { - "type": "typing", - "is_typing": False, - } - ) - - elif data.get("type") == "ping": - await chat_manager.send_message( - session_id, - { - "type": "pong", - "timestamp": time.time(), - }, - ) - - elif data.get("type") == "update_nickname": - # 允许用户更新昵称 - if new_name := data.get("user_name", "").strip(): - current_user_name = new_name - await chat_manager.send_message( - session_id, - { - "type": "nickname_updated", - "user_name": current_user_name, - "timestamp": time.time(), - }, - ) - - elif data.get("type") == "set_virtual_identity": - # 设置或更新虚拟身份配置 - virtual_data = data.get("config", {}) - if virtual_data.get("enabled"): - # 验证必要字段 - if not virtual_data.get("platform") or not virtual_data.get("person_id"): - await chat_manager.send_message( - session_id, - { - "type": "error", - "content": "虚拟身份配置缺少必要字段: platform 和 person_id", - "timestamp": time.time(), - }, - ) - continue - - # 获取用户信息 - try: - with get_db_session() as session: - statement = ( - select(PersonInfo) - .where(col(PersonInfo.person_id) == virtual_data.get("person_id")) - .limit(1) - ) - person = session.exec(statement).first() - if not person: - await chat_manager.send_message( - session_id, - { - "type": "error", - "content": f"找不到用户: {virtual_data.get('person_id')}", - "timestamp": time.time(), - }, - ) - continue - - # 生成虚拟群 ID - custom_group_id = virtual_data.get("group_id") - if custom_group_id: - group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}" - else: - group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_id[:8]}" - - current_virtual_config = VirtualIdentityConfig( - enabled=True, - platform=person.platform, - person_id=person.person_id, - user_id=person.user_id, - user_nickname=person.person_name or person.user_nickname or person.user_id, - group_id=group_id, - group_name=virtual_data.get("group_name", "WebUI虚拟群聊"), - ) - - # 发送虚拟身份已激活的消息 - await chat_manager.send_message( - session_id, - { - "type": "virtual_identity_set", - "config": { - "enabled": True, - "platform": current_virtual_config.platform, - "user_id": current_virtual_config.user_id, - "user_nickname": current_virtual_config.user_nickname, - "group_id": current_virtual_config.group_id, - "group_name": current_virtual_config.group_name, - }, - "timestamp": time.time(), - }, - ) - - # 加载虚拟群的历史记录 - virtual_history = chat_history.get_history(50, current_virtual_config.group_id) - await chat_manager.send_message( - session_id, - { - "type": "history", - "messages": virtual_history, - "group_id": current_virtual_config.group_id, - }, - ) - - # 发送系统消息 - await chat_manager.send_message( - session_id, - { - "type": "system", - "content": f"已切换到虚拟身份模式:以 {current_virtual_config.user_nickname} 的身份在「{current_virtual_config.group_name}」与 {global_config.bot.nickname} 对话", - "timestamp": time.time(), - }, - ) - - except Exception as e: - logger.error(f"设置虚拟身份失败: {e}") - await chat_manager.send_message( - session_id, - { - "type": "error", - "content": f"设置虚拟身份失败: {str(e)}", - "timestamp": time.time(), - }, - ) - else: - # 禁用虚拟身份模式 - current_virtual_config = None - await chat_manager.send_message( - session_id, - { - "type": "virtual_identity_set", - "config": {"enabled": False}, - "timestamp": time.time(), - }, - ) - - # 重新加载默认聊天室历史 - default_history = chat_history.get_history(50, WEBUI_CHAT_GROUP_ID) - await chat_manager.send_message( - session_id, - { - "type": "history", - "messages": default_history, - "group_id": WEBUI_CHAT_GROUP_ID, - }, - ) - - await chat_manager.send_message( - session_id, - { - "type": "system", - "content": "已切换回 WebUI 独立用户模式", - "timestamp": time.time(), - }, - ) - - except WebSocketDisconnect: - logger.info(f"WebSocket 断开: session={session_id}, user={user_id}") - except Exception as e: - logger.error(f"WebSocket 错误: {e}") - finally: - chat_manager.disconnect(session_id, user_id) - - -@router.get("/info") -async def get_chat_info(_auth: bool = Depends(require_auth)): - """获取聊天室信息""" - return { - "bot_name": global_config.bot.nickname, - "platform": WEBUI_CHAT_PLATFORM, - "group_id": WEBUI_CHAT_GROUP_ID, - "active_sessions": len(chat_manager.active_connections), - } - - -def get_webui_chat_broadcaster() -> tuple[ChatConnectionManager, str]: - """获取 WebUI 聊天广播器,供外部模块使用 - - Returns: - (chat_manager, WEBUI_CHAT_PLATFORM) 元组 - """ - return (chat_manager, WEBUI_CHAT_PLATFORM) diff --git a/src/webui/routers/chat/__init__.py b/src/webui/routers/chat/__init__.py new file mode 100644 index 00000000..1ee04fcf --- /dev/null +++ b/src/webui/routers/chat/__init__.py @@ -0,0 +1,18 @@ +from fastapi import APIRouter + +from .routes import router +from .support import ChatConnectionManager, WEBUI_CHAT_PLATFORM, chat_manager + + +def get_webui_chat_broadcaster() -> tuple[ChatConnectionManager, str]: + """获取 WebUI 聊天广播器,供外部模块使用。""" + return chat_manager, WEBUI_CHAT_PLATFORM + + +__all__ = [ + "ChatConnectionManager", + "WEBUI_CHAT_PLATFORM", + "chat_manager", + "get_webui_chat_broadcaster", + "router", +] \ No newline at end of file diff --git a/src/webui/routers/chat/routes.py b/src/webui/routers/chat/routes.py new file mode 100644 index 00000000..805ab3b3 --- /dev/null +++ b/src/webui/routers/chat/routes.py @@ -0,0 +1,174 @@ +"""本地聊天室路由 - WebUI 与麦麦直接对话。""" + +import uuid + +from typing import Optional + +from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect +from sqlalchemy import case, func +from sqlmodel import col, select + +from src.common.database.database import get_db_session +from src.common.database.database_model import PersonInfo +from src.common.logger import get_logger +from src.config.config import global_config +from src.webui.dependencies import require_auth + +from .support import ( + WEBUI_CHAT_GROUP_ID, + WEBUI_CHAT_PLATFORM, + authenticate_chat_websocket, + chat_history, + chat_manager, + dispatch_chat_event, + normalize_webui_user_id, + resolve_initial_virtual_identity, + send_initial_chat_state, +) + +logger = get_logger("webui.chat") + +router = APIRouter(prefix="/api/chat", tags=["LocalChat"], dependencies=[Depends(require_auth)]) + + +@router.get("/history") +async def get_chat_history( + limit: int = Query(default=50, ge=1, le=200), + user_id: Optional[str] = Query(default=None), + group_id: Optional[str] = Query(default=None), +) -> dict[str, object]: + """获取聊天历史记录。""" + del user_id + target_group_id = group_id or WEBUI_CHAT_GROUP_ID + history = chat_history.get_history(limit, target_group_id) + return {"success": True, "messages": history, "total": len(history)} + + +@router.get("/platforms") +async def get_available_platforms() -> dict[str, object]: + """获取可用平台列表。""" + try: + with get_db_session() as session: + statement = ( + select(PersonInfo.platform, func.count().label("count")) + .group_by(PersonInfo.platform) + .order_by(func.count().desc()) + ) + platforms = session.exec(statement).all() + + result = [{"platform": platform, "count": count} for platform, count in platforms if platform] + return {"success": True, "platforms": result} + except Exception as e: + logger.error(f"获取平台列表失败: {e}") + return {"success": False, "error": str(e), "platforms": []} + + +@router.get("/persons") +async def get_persons_by_platform( + platform: str = Query(..., description="平台名称"), + search: Optional[str] = Query(default=None, description="搜索关键词"), + limit: int = Query(default=50, ge=1, le=200), +) -> dict[str, object]: + """获取指定平台的用户列表。""" + try: + statement = select(PersonInfo).where(col(PersonInfo.platform) == platform) + if search: + statement = statement.where( + (col(PersonInfo.person_name).contains(search)) + | (col(PersonInfo.user_nickname).contains(search)) + | (col(PersonInfo.user_id).contains(search)) + ) + + statement = statement.order_by( + case((col(PersonInfo.last_known_time).is_(None), 1), else_=0), + col(PersonInfo.last_known_time).desc(), + ).limit(limit) + + with get_db_session() as session: + persons = session.exec(statement).all() + + result = [ + { + "person_id": person.person_id, + "user_id": person.user_id, + "person_name": person.person_name, + "nickname": person.user_nickname, + "is_known": person.is_known, + "platform": person.platform, + "display_name": person.person_name or person.user_nickname or person.user_id, + } + for person in persons + ] + return {"success": True, "persons": result, "total": len(result)} + except Exception as e: + logger.error(f"获取用户列表失败: {e}") + return {"success": False, "error": str(e), "persons": []} + + +@router.delete("/history") +async def clear_chat_history( + group_id: Optional[str] = Query(default=None), +) -> dict[str, object]: + """清空聊天历史记录。""" + deleted = chat_history.clear_history(group_id) + return {"success": True, "message": f"已清空 {deleted} 条聊天记录"} + + +@router.websocket("/ws") +async def websocket_chat( + websocket: WebSocket, + user_id: Optional[str] = Query(default=None), + user_name: Optional[str] = Query(default="WebUI用户"), + platform: Optional[str] = Query(default=None), + person_id: Optional[str] = Query(default=None), + group_name: Optional[str] = Query(default=None), + group_id: Optional[str] = Query(default=None), + token: Optional[str] = Query(default=None), +) -> None: + """WebSocket 聊天端点。""" + if not await authenticate_chat_websocket(websocket, token): + logger.warning("聊天 WebSocket 连接被拒绝:认证失败") + await websocket.close(code=4001, reason="认证失败,请重新登录") + return + + session_id = str(uuid.uuid4()) + normalized_user_id = normalize_webui_user_id(user_id) + current_user_name = user_name or "WebUI用户" + current_virtual_config = resolve_initial_virtual_identity(platform, person_id, group_name, group_id) + + await chat_manager.connect(websocket, session_id, normalized_user_id) + try: + await send_initial_chat_state( + session_id=session_id, + user_id=normalized_user_id, + user_name=current_user_name, + virtual_config=current_virtual_config, + ) + + while True: + data = await websocket.receive_json() + current_user_name, current_virtual_config = await dispatch_chat_event( + session_id=session_id, + session_id_prefix=session_id[:8], + data=data, + current_user_name=current_user_name, + normalized_user_id=normalized_user_id, + current_virtual_config=current_virtual_config, + ) + except WebSocketDisconnect: + logger.info(f"WebSocket 断开: session={session_id}, user={normalized_user_id}") + except Exception as e: + logger.error(f"WebSocket 错误: {e}") + finally: + chat_manager.disconnect(session_id, normalized_user_id) + + +@router.get("/info") +async def get_chat_info() -> dict[str, object]: + """获取聊天室信息。""" + return { + "bot_name": global_config.bot.nickname, + "platform": WEBUI_CHAT_PLATFORM, + "group_id": WEBUI_CHAT_GROUP_ID, + "active_sessions": len(chat_manager.active_connections), + } \ No newline at end of file diff --git a/src/webui/routers/chat/support.py b/src/webui/routers/chat/support.py new file mode 100644 index 00000000..507abcf5 --- /dev/null +++ b/src/webui/routers/chat/support.py @@ -0,0 +1,614 @@ +"""WebUI 聊天路由支持逻辑。""" + +from typing import Any, Optional, cast + +import time +import uuid + +from fastapi import WebSocket +from pydantic import BaseModel +from sqlmodel import col, delete, select + +from src.chat.message_receive.bot import chat_bot +from src.chat.message_receive.message import SessionMessage +from src.common.database.database import get_db_session +from src.common.database.database_model import Messages, PersonInfo +from src.common.logger import get_logger +from src.common.message_repository import find_messages +from src.common.utils.system_utils import is_bot_self +from src.common.utils.utils_session import SessionUtils +from src.config.config import global_config +from src.webui.core import get_token_manager +from src.webui.routers.websocket.auth import verify_ws_token + +logger = get_logger("webui.chat") + +WEBUI_CHAT_GROUP_ID = "webui_local_chat" +WEBUI_CHAT_PLATFORM = "webui" +VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_" +WEBUI_USER_ID_PREFIX = "webui_user_" + + +class VirtualIdentityConfig(BaseModel): + """虚拟身份配置。""" + + enabled: bool = False + platform: Optional[str] = None + person_id: Optional[str] = None + user_id: Optional[str] = None + user_nickname: Optional[str] = None + group_id: Optional[str] = None + group_name: Optional[str] = None + + +class ChatHistoryMessage(BaseModel): + """聊天历史消息。""" + + id: str + type: str + content: str + timestamp: float + sender_name: str + sender_id: Optional[str] = None + is_bot: bool = False + + +class ChatHistoryManager: + """聊天历史管理器。""" + + def __init__(self, max_messages: int = 200) -> None: + self.max_messages = max_messages + + def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> dict[str, Any]: + user_info = msg.message_info.user_info + user_id = user_info.user_id or "" + is_bot = is_bot_self(user_id, msg.platform) + + if not is_bot and group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX): + is_bot = user_id == str(global_config.bot.qq_account) + elif not is_bot: + is_bot = not user_id.startswith(WEBUI_USER_ID_PREFIX) + + return { + "id": msg.message_id, + "type": "bot" if is_bot else "user", + "content": msg.processed_plain_text or msg.display_message or "", + "timestamp": msg.timestamp.timestamp(), + "sender_name": user_info.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"), + "sender_id": "bot" if is_bot else user_id, + "is_bot": is_bot, + } + + def _resolve_session_id(self, group_id: Optional[str]) -> str: + target_group_id = group_id or WEBUI_CHAT_GROUP_ID + return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id) + + def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> list[dict[str, Any]]: + target_group_id = group_id or WEBUI_CHAT_GROUP_ID + session_id = self._resolve_session_id(target_group_id) + try: + messages = find_messages( + session_id=session_id, + limit=limit, + limit_mode="latest", + filter_command=False, + ) + result = [self._message_to_dict(msg, target_group_id) for msg in messages] + logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})") + return result + except Exception as e: + logger.error(f"从数据库加载聊天记录失败: {e}") + return [] + + def clear_history(self, group_id: Optional[str] = None) -> int: + target_group_id = group_id or WEBUI_CHAT_GROUP_ID + session_id = self._resolve_session_id(target_group_id) + try: + with get_db_session() as session: + statement = delete(Messages).where(col(Messages.session_id) == session_id) + result = session.exec(statement) + deleted = result.rowcount or 0 + logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})") + return deleted + except Exception as e: + logger.error(f"清空聊天记录失败: {e}") + return 0 + + +class ChatConnectionManager: + """聊天连接管理器。""" + + def __init__(self) -> None: + self.active_connections: dict[str, WebSocket] = {} + self.user_sessions: dict[str, str] = {} + + async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None: + await websocket.accept() + self.active_connections[session_id] = websocket + self.user_sessions[user_id] = session_id + logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}") + + def disconnect(self, session_id: str, user_id: str) -> None: + if session_id in self.active_connections: + del self.active_connections[session_id] + if user_id in self.user_sessions and self.user_sessions[user_id] == session_id: + del self.user_sessions[user_id] + logger.info(f"WebUI 聊天会话已断开: session={session_id}") + + async def send_message(self, session_id: str, message: dict[str, Any]) -> None: + if session_id in self.active_connections: + try: + await self.active_connections[session_id].send_json(message) + except Exception as e: + logger.error(f"发送消息失败: {e}") + + async def broadcast(self, message: dict[str, Any]) -> None: + for session_id in list(self.active_connections.keys()): + await self.send_message(session_id, message) + + +chat_history = ChatHistoryManager() +chat_manager = ChatConnectionManager() + + +def is_virtual_mode_enabled(virtual_config: Optional[VirtualIdentityConfig]) -> bool: + return bool(virtual_config and virtual_config.enabled) + + +async def authenticate_chat_websocket(websocket: WebSocket, token: Optional[str]) -> bool: + if token and verify_ws_token(token): + logger.debug("聊天 WebSocket 使用临时 token 认证成功") + return True + + if cookie_token := websocket.cookies.get("maibot_session"): + token_manager = get_token_manager() + if token_manager.verify_token(cookie_token): + logger.debug("聊天 WebSocket 使用 Cookie 认证成功") + return True + + return False + + +def normalize_webui_user_id(user_id: Optional[str]) -> str: + if not user_id: + return f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}" + if user_id.startswith(WEBUI_USER_ID_PREFIX): + return user_id + return f"{WEBUI_USER_ID_PREFIX}{user_id}" + + +def get_person_by_person_id(person_id: str) -> Optional[PersonInfo]: + with get_db_session() as session: + statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1) + return session.exec(statement).first() + + +def build_virtual_identity_config(person: PersonInfo, group_id: str, group_name: str) -> VirtualIdentityConfig: + return VirtualIdentityConfig( + enabled=True, + platform=person.platform, + person_id=person.person_id, + user_id=person.user_id, + user_nickname=person.person_name or person.user_nickname or person.user_id, + group_id=group_id, + group_name=group_name, + ) + + +def resolve_initial_virtual_identity( + platform: Optional[str], + person_id: Optional[str], + group_name: Optional[str], + group_id: Optional[str], +) -> Optional[VirtualIdentityConfig]: + if not (platform and person_id): + return None + + try: + person = get_person_by_person_id(person_id) + if person is None: + return None + + virtual_group_id = group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{platform}_{person.user_id}" + virtual_config = build_virtual_identity_config( + person=person, + group_id=virtual_group_id, + group_name=group_name or "WebUI虚拟群聊", + ) + logger.info( + f"虚拟身份模式已通过 URL 参数激活: {virtual_config.user_nickname} @ {virtual_config.platform}, group_id={virtual_group_id}" + ) + return virtual_config + except Exception as e: + logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}") + return None + + +def build_session_info_message( + session_id: str, + user_id: str, + user_name: str, + virtual_config: Optional[VirtualIdentityConfig], +) -> dict[str, Any]: + session_info_data: dict[str, Any] = { + "type": "session_info", + "session_id": session_id, + "user_id": user_id, + "user_name": user_name, + "bot_name": global_config.bot.nickname, + } + + if is_virtual_mode_enabled(virtual_config): + assert virtual_config is not None + session_info_data["virtual_mode"] = True + session_info_data["group_id"] = virtual_config.group_id + session_info_data["virtual_identity"] = { + "platform": virtual_config.platform, + "user_id": virtual_config.user_id, + "user_nickname": virtual_config.user_nickname, + "group_name": virtual_config.group_name, + } + + return session_info_data + + +def get_active_history_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> Optional[str]: + if is_virtual_mode_enabled(virtual_config): + assert virtual_config is not None + return virtual_config.group_id + return None + + +def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> str: + if is_virtual_mode_enabled(virtual_config): + assert virtual_config is not None + return ( + f"已以 {virtual_config.user_nickname} 的身份连接到「{virtual_config.group_name}」," + f"开始与 {global_config.bot.nickname} 对话吧!" + ) + return f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!" + + +async def send_chat_error(session_id: str, content: str) -> None: + await chat_manager.send_message( + session_id, + { + "type": "error", + "content": content, + "timestamp": time.time(), + }, + ) + + +async def send_initial_chat_state( + session_id: str, + user_id: str, + user_name: str, + virtual_config: Optional[VirtualIdentityConfig], +) -> None: + await chat_manager.send_message( + session_id, + build_session_info_message( + session_id=session_id, + user_id=user_id, + user_name=user_name, + virtual_config=virtual_config, + ), + ) + + if history := chat_history.get_history(50, get_active_history_group_id(virtual_config)): + await chat_manager.send_message( + session_id, + { + "type": "history", + "messages": history, + }, + ) + + await chat_manager.send_message( + session_id, + { + "type": "system", + "content": build_welcome_message(virtual_config), + "timestamp": time.time(), + }, + ) + + +def resolve_sender_identity( + current_user_name: str, + normalized_user_id: str, + virtual_config: Optional[VirtualIdentityConfig], +) -> tuple[str, str]: + if is_virtual_mode_enabled(virtual_config): + assert virtual_config is not None + return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id + return current_user_name, normalized_user_id + + +def create_message_data( + content: str, + user_id: str, + user_name: str, + message_id: Optional[str] = None, + is_at_bot: bool = True, + virtual_config: Optional[VirtualIdentityConfig] = None, +) -> dict[str, Any]: + if message_id is None: + message_id = str(uuid.uuid4()) + + if virtual_config and virtual_config.enabled: + platform = virtual_config.platform or WEBUI_CHAT_PLATFORM + group_id = virtual_config.group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{uuid.uuid4().hex[:8]}" + group_name = virtual_config.group_name or "WebUI虚拟群聊" + actual_user_id = virtual_config.user_id or user_id + actual_user_name = virtual_config.user_nickname or user_name + else: + platform = WEBUI_CHAT_PLATFORM + group_id = WEBUI_CHAT_GROUP_ID + group_name = "WebUI本地聊天室" + actual_user_id = user_id + actual_user_name = user_name + + return { + "message_info": { + "platform": platform, + "message_id": message_id, + "time": time.time(), + "group_info": { + "group_id": group_id, + "group_name": group_name, + "platform": platform, + }, + "user_info": { + "user_id": actual_user_id, + "user_nickname": actual_user_name, + "user_cardname": actual_user_name, + "platform": platform, + }, + "additional_config": { + "at_bot": is_at_bot, + }, + }, + "message_segment": { + "type": "seglist", + "data": [ + { + "type": "text", + "data": content, + }, + { + "type": "mention_bot", + "data": "1.0", + }, + ], + }, + "raw_message": content, + "processed_plain_text": content, + } + + +async def handle_chat_message( + session_id: str, + data: dict[str, Any], + current_user_name: str, + normalized_user_id: str, + current_virtual_config: Optional[VirtualIdentityConfig], +) -> str: + content = str(data.get("content", "")).strip() + if not content: + return current_user_name + + next_user_name = str(data.get("user_name", current_user_name)) + message_id = str(uuid.uuid4()) + timestamp = time.time() + sender_name, sender_user_id = resolve_sender_identity( + current_user_name=next_user_name, + normalized_user_id=normalized_user_id, + virtual_config=current_virtual_config, + ) + + await chat_manager.broadcast( + { + "type": "user_message", + "content": content, + "message_id": message_id, + "timestamp": timestamp, + "sender": { + "name": sender_name, + "user_id": sender_user_id, + "is_bot": False, + }, + "virtual_mode": is_virtual_mode_enabled(current_virtual_config), + } + ) + + message_data = create_message_data( + content=content, + user_id=normalized_user_id, + user_name=next_user_name, + message_id=message_id, + is_at_bot=True, + virtual_config=current_virtual_config, + ) + + try: + await chat_manager.broadcast({"type": "typing", "is_typing": True}) + await chat_bot.message_process(message_data) + except Exception as e: + logger.error(f"处理消息时出错: {e}") + await send_chat_error(session_id, f"处理消息时出错: {str(e)}") + finally: + await chat_manager.broadcast({"type": "typing", "is_typing": False}) + + return next_user_name + + +async def handle_chat_ping(session_id: str) -> None: + await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()}) + + +async def handle_nickname_update(session_id: str, data: dict[str, Any], current_user_name: str) -> str: + new_name = str(data.get("user_name", "")).strip() + if not new_name: + return current_user_name + + await chat_manager.send_message( + session_id, + { + "type": "nickname_updated", + "user_name": new_name, + "timestamp": time.time(), + }, + ) + return new_name + + +async def enable_virtual_identity( + session_id: str, + session_prefix: str, + virtual_data: dict[str, Any], +) -> Optional[VirtualIdentityConfig]: + if not virtual_data.get("platform") or not virtual_data.get("person_id"): + await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id") + return None + + person_id_value = str(virtual_data.get("person_id")) + try: + person = get_person_by_person_id(person_id_value) + if not person: + await send_chat_error(session_id, f"找不到用户: {person_id_value}") + return None + + custom_group_id = virtual_data.get("group_id") + current_group_id = ( + f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}" + if custom_group_id + else f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}" + ) + current_virtual_config = build_virtual_identity_config( + person=person, + group_id=current_group_id, + group_name=str(virtual_data.get("group_name", "WebUI虚拟群聊")), + ) + + await chat_manager.send_message( + session_id, + { + "type": "virtual_identity_set", + "config": { + "enabled": True, + "platform": current_virtual_config.platform, + "user_id": current_virtual_config.user_id, + "user_nickname": current_virtual_config.user_nickname, + "group_id": current_virtual_config.group_id, + "group_name": current_virtual_config.group_name, + }, + "timestamp": time.time(), + }, + ) + await chat_manager.send_message( + session_id, + { + "type": "history", + "messages": chat_history.get_history(50, current_virtual_config.group_id), + "group_id": current_virtual_config.group_id, + }, + ) + await chat_manager.send_message( + session_id, + { + "type": "system", + "content": ( + f"已切换到虚拟身份模式:以 {current_virtual_config.user_nickname} 的身份在" + f"「{current_virtual_config.group_name}」与 {global_config.bot.nickname} 对话" + ), + "timestamp": time.time(), + }, + ) + return current_virtual_config + except Exception as e: + logger.error(f"设置虚拟身份失败: {e}") + await send_chat_error(session_id, f"设置虚拟身份失败: {str(e)}") + return None + + +async def disable_virtual_identity(session_id: str) -> None: + await chat_manager.send_message( + session_id, + { + "type": "virtual_identity_set", + "config": {"enabled": False}, + "timestamp": time.time(), + }, + ) + await chat_manager.send_message( + session_id, + { + "type": "history", + "messages": chat_history.get_history(50, WEBUI_CHAT_GROUP_ID), + "group_id": WEBUI_CHAT_GROUP_ID, + }, + ) + await chat_manager.send_message( + session_id, + { + "type": "system", + "content": "已切换回 WebUI 独立用户模式", + "timestamp": time.time(), + }, + ) + + +async def handle_virtual_identity_update( + session_id: str, + session_id_prefix: str, + data: dict[str, Any], + current_virtual_config: Optional[VirtualIdentityConfig], +) -> Optional[VirtualIdentityConfig]: + virtual_data = cast(dict[str, Any], data.get("config", {})) + if virtual_data.get("enabled"): + next_config = await enable_virtual_identity(session_id, session_id_prefix, virtual_data) + return next_config if next_config is not None else current_virtual_config + + await disable_virtual_identity(session_id) + return None + + +async def dispatch_chat_event( + session_id: str, + session_id_prefix: str, + data: dict[str, Any], + current_user_name: str, + normalized_user_id: str, + current_virtual_config: Optional[VirtualIdentityConfig], +) -> tuple[str, Optional[VirtualIdentityConfig]]: + event_type = data.get("type") + if event_type == "message": + next_user_name = await handle_chat_message( + session_id=session_id, + data=data, + current_user_name=current_user_name, + normalized_user_id=normalized_user_id, + current_virtual_config=current_virtual_config, + ) + return next_user_name, current_virtual_config + + if event_type == "ping": + await handle_chat_ping(session_id) + return current_user_name, current_virtual_config + + if event_type == "update_nickname": + next_user_name = await handle_nickname_update(session_id, data, current_user_name) + return next_user_name, current_virtual_config + + if event_type == "set_virtual_identity": + next_virtual_config = await handle_virtual_identity_update( + session_id=session_id, + session_id_prefix=session_id_prefix, + data=data, + current_virtual_config=current_virtual_config, + ) + return current_user_name, next_virtual_config + + return current_user_name, current_virtual_config \ No newline at end of file diff --git a/src/webui/routers/config.py b/src/webui/routers/config.py index bfc47703..fa196af3 100644 --- a/src/webui/routers/config.py +++ b/src/webui/routers/config.py @@ -4,12 +4,13 @@ import copy import os -import tomlkit -from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header from typing import Any, Annotated, Optional +import tomlkit +from fastapi import APIRouter, Body, Depends, HTTPException + from src.common.logger import get_logger -from src.webui.core import verify_auth_token_from_cookie_or_header +from src.webui.dependencies import require_auth from src.webui.utils.toml_utils import save_toml_with_format, _update_toml_doc from src.config.config import Config, ModelConfig, CONFIG_DIR, PROJECT_ROOT from src.config.config_base import AttributeData @@ -49,7 +50,7 @@ SectionBody = Annotated[Any, Body()] RawContentBody = Annotated[str, Body(embed=True)] PathBody = Annotated[dict[str, str], Body()] -router = APIRouter(prefix="/config", tags=["config"]) +router = APIRouter(prefix="/config", tags=["config"], dependencies=[Depends(require_auth)]) def _toml_to_plain_dict(obj: Any) -> Any: @@ -59,21 +60,11 @@ def _toml_to_plain_dict(obj: Any) -> Any: if isinstance(obj, list): return [_toml_to_plain_dict(v) for v in obj] return obj - - -def require_auth( - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -) -> bool: - """认证依赖:验证用户是否已登录""" - return verify_auth_token_from_cookie_or_header(maibot_session, authorization) - - # ===== 架构获取接口 ===== @router.get("/schema/bot") -async def get_bot_config_schema(_auth: bool = Depends(require_auth)): +async def get_bot_config_schema(): """获取麦麦主程序配置架构""" try: # Config 类包含所有子配置 @@ -85,7 +76,7 @@ async def get_bot_config_schema(_auth: bool = Depends(require_auth)): @router.get("/schema/model") -async def get_model_config_schema(_auth: bool = Depends(require_auth)): +async def get_model_config_schema(): """获取模型配置架构(包含提供商和模型任务配置)""" try: schema = ConfigSchemaGenerator.generate_config_schema(ModelConfig) @@ -99,7 +90,7 @@ async def get_model_config_schema(_auth: bool = Depends(require_auth)): @router.get("/schema/section/{section_name}") -async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)): +async def get_config_section_schema(section_name: str): """ 获取指定配置节的架构 @@ -169,7 +160,7 @@ async def get_config_section_schema(section_name: str, _auth: bool = Depends(req @router.get("/bot") -async def get_bot_config(_auth: bool = Depends(require_auth)): +async def get_bot_config(): """获取麦麦主程序配置""" try: config_path = os.path.join(CONFIG_DIR, "bot_config.toml") @@ -188,7 +179,7 @@ async def get_bot_config(_auth: bool = Depends(require_auth)): @router.get("/model") -async def get_model_config(_auth: bool = Depends(require_auth)): +async def get_model_config(): """获取模型配置(包含提供商和模型任务配置)""" try: config_path = os.path.join(CONFIG_DIR, "model_config.toml") @@ -210,7 +201,7 @@ async def get_model_config(_auth: bool = Depends(require_auth)): @router.post("/bot") -async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)): +async def update_bot_config(config_data: ConfigBody): """更新麦麦主程序配置""" try: # 验证配置数据 @@ -233,7 +224,7 @@ async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(requi @router.post("/model") -async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)): +async def update_model_config(config_data: ConfigBody): """更新模型配置""" try: # 验证配置数据 @@ -259,7 +250,7 @@ async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(req @router.post("/bot/section/{section_name}") -async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)): +async def update_bot_config_section(section_name: str, section_data: SectionBody): """更新麦麦主程序配置的指定节(保留注释和格式)""" try: # 读取现有配置 @@ -308,7 +299,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody @router.get("/bot/raw") -async def get_bot_config_raw(_auth: bool = Depends(require_auth)): +async def get_bot_config_raw(): """获取麦麦主程序配置的原始 TOML 内容""" try: config_path = os.path.join(CONFIG_DIR, "bot_config.toml") @@ -327,7 +318,7 @@ async def get_bot_config_raw(_auth: bool = Depends(require_auth)): @router.post("/bot/raw") -async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)): +async def update_bot_config_raw(raw_content: RawContentBody): """更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)""" try: # 验证 TOML 格式 @@ -357,9 +348,7 @@ async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depen @router.post("/model/section/{section_name}") -async def update_model_config_section( - section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth) -): +async def update_model_config_section(section_name: str, section_data: SectionBody): """更新模型配置的指定节(保留注释和格式)""" try: # 读取现有配置 @@ -451,7 +440,7 @@ def _to_relative_path(path: str) -> str: @router.get("/adapter-config/path") -async def get_adapter_config_path(_auth: bool = Depends(require_auth)): +async def get_adapter_config_path(): """获取保存的适配器配置文件路径""" try: # 从 data/webui.json 读取路径偏好 @@ -490,7 +479,7 @@ async def get_adapter_config_path(_auth: bool = Depends(require_auth)): @router.post("/adapter-config/path") -async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)): +async def save_adapter_config_path(data: PathBody): """保存适配器配置文件路径偏好""" try: path = data.get("path") @@ -533,7 +522,7 @@ async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require @router.get("/adapter-config") -async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)): +async def get_adapter_config(path: str): """从指定路径读取适配器配置文件""" try: if not path: @@ -565,7 +554,7 @@ async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)): @router.post("/adapter-config") -async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)): +async def save_adapter_config(data: PathBody): """保存适配器配置到指定路径""" try: path = data.get("path") diff --git a/src/webui/routers/emoji/__init__.py b/src/webui/routers/emoji/__init__.py new file mode 100644 index 00000000..f8cdb40b --- /dev/null +++ b/src/webui/routers/emoji/__init__.py @@ -0,0 +1,3 @@ +from .routes import router + +__all__ = ["router"] \ No newline at end of file diff --git a/src/webui/routers/emoji.py b/src/webui/routers/emoji/routes.py similarity index 57% rename from src/webui/routers/emoji.py rename to src/webui/routers/emoji/routes.py index 4f882d09..6081dd14 100644 --- a/src/webui/routers/emoji.py +++ b/src/webui/routers/emoji/routes.py @@ -1,282 +1,60 @@ """表情包管理 API 路由""" -from concurrent.futures import ThreadPoolExecutor from datetime import datetime from pathlib import Path -from typing import Annotated, Any, List, Optional +from typing import Any, Optional import asyncio import hashlib import io import os -import threading -from fastapi import APIRouter, Cookie, File, Form, Header, HTTPException, Query, UploadFile +from fastapi import APIRouter, Cookie, HTTPException, Query from fastapi.responses import FileResponse, JSONResponse -from pydantic import BaseModel from PIL import Image from sqlalchemy import func from sqlmodel import col, select from src.common.database.database import get_db_session from src.common.database.database_model import Images, ImageType -from src.common.logger import get_logger -from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header +from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header as verify_auth_token -logger = get_logger("webui.emoji") +from .schemas import ( + BatchDeleteRequest, + BatchDeleteResponse, + DescriptionForm, + EmojiDeleteResponse, + EmojiDetailResponse, + EmojiFile, + EmojiFiles, + EmojiListResponse, + EmojiUpdateRequest, + EmojiUpdateResponse, + EmojiUploadResponse, + EmotionForm, + IsRegisteredForm, + ThumbnailCacheStatsResponse, + ThumbnailCleanupResponse, + ThumbnailPreheatResponse, + emoji_to_response, +) +from .support import ( + EMOJI_REGISTERED_DIR, + THUMBNAIL_CACHE_DIR, + background_generate_thumbnail, + cleanup_orphaned_thumbnails, + ensure_thumbnail_cache_dir, + generate_thumbnail, + get_generating_lock, + get_generating_thumbnails, + get_thumbnail_cache_path, + get_thumbnail_executor, + logger, +) -# ==================== 缩略图缓存配置 ==================== -# 缩略图缓存目录 -THUMBNAIL_CACHE_DIR = Path("data/emoji_thumbnails") -# 缩略图尺寸 (宽, 高) -THUMBNAIL_SIZE = (200, 200) -# 缩略图质量 (WebP 格式, 1-100) -THUMBNAIL_QUALITY = 80 -# 缓存锁,防止并发生成同一缩略图 -_thumbnail_locks: dict[str, threading.Lock] = {} -_locks_lock = threading.Lock() -# 缩略图生成专用线程池(避免阻塞事件循环) -_thumbnail_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="thumbnail") -# 正在生成中的缩略图哈希集合(防止重复提交任务) -_generating_thumbnails: set[str] = set() -_generating_lock = threading.Lock() - - -def _get_thumbnail_lock(file_hash: str) -> threading.Lock: - """获取指定文件哈希的锁,用于防止并发生成同一缩略图""" - with _locks_lock: - if file_hash not in _thumbnail_locks: - _thumbnail_locks[file_hash] = threading.Lock() - return _thumbnail_locks[file_hash] - - -def _background_generate_thumbnail(source_path: str, file_hash: str) -> None: - """ - 后台生成缩略图(在线程池中执行) - - 生成完成后自动从 generating 集合中移除 - """ - try: - _generate_thumbnail(source_path, file_hash) - except Exception as e: - logger.warning(f"后台生成缩略图失败 {file_hash}: {e}") - finally: - with _generating_lock: - _generating_thumbnails.discard(file_hash) - - -def _ensure_thumbnail_cache_dir() -> Path: - """确保缩略图缓存目录存在""" - _ = THUMBNAIL_CACHE_DIR.mkdir(parents=True, exist_ok=True) - return THUMBNAIL_CACHE_DIR - - -def _get_thumbnail_cache_path(file_hash: str) -> Path: - """获取缩略图缓存路径""" - return THUMBNAIL_CACHE_DIR / f"{file_hash}.webp" - - -def _generate_thumbnail(source_path: str, file_hash: str) -> Path: - """ - 生成缩略图并保存到缓存目录 - - Args: - source_path: 原图路径 - file_hash: 文件哈希值,用作缓存文件名 - - Returns: - 缩略图路径 - - Features: - - GIF: 提取第一帧作为缩略图 - - 所有格式统一转为 WebP - - 保持宽高比缩放 - """ - _ensure_thumbnail_cache_dir() - cache_path = _get_thumbnail_cache_path(file_hash) - - # 使用锁防止并发生成同一缩略图 - lock = _get_thumbnail_lock(file_hash) - with lock: - # 双重检查,可能在等待锁时已被其他线程生成 - if cache_path.exists(): - return cache_path - - try: - with Image.open(source_path) as img: - # GIF 处理:提取第一帧 - if getattr(img, "n_frames", 1) > 1: - img.seek(0) # 确保在第一帧 - - # 转换为 RGB/RGBA(WebP 支持透明度) - if img.mode in ("P", "PA"): - # 调色板模式转换为 RGBA 以保留透明度 - img = img.convert("RGBA") - elif img.mode == "LA": - img = img.convert("RGBA") - elif img.mode not in ("RGB", "RGBA"): - img = img.convert("RGB") - - # 创建缩略图(保持宽高比) - img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS) - - # 保存为 WebP 格式 - img.save(cache_path, "WEBP", quality=THUMBNAIL_QUALITY, method=6) - - logger.debug(f"生成缩略图: {file_hash} -> {cache_path}") - - except Exception as e: - logger.warning(f"生成缩略图失败 {file_hash}: {e},将返回原图") - # 生成失败时不创建缓存文件,下次会重试 - raise - - return cache_path - - -def cleanup_orphaned_thumbnails() -> tuple[int, int]: - """ - 清理孤立的缩略图缓存(原图已不存在的缩略图) - - Returns: - (清理数量, 保留数量) - """ - if not THUMBNAIL_CACHE_DIR.exists(): - return 0, 0 - - # 获取所有表情包的哈希值 - with get_db_session() as session: - statement = select(Images.image_hash).where(col(Images.image_type) == ImageType.EMOJI) - valid_hashes = set(session.exec(statement).all()) - - cleaned = 0 - kept = 0 - - for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"): - file_hash = cache_file.stem - if file_hash not in valid_hashes: - try: - cache_file.unlink() - cleaned += 1 - logger.debug(f"清理孤立缩略图: {cache_file.name}") - except Exception as e: - logger.warning(f"清理缩略图失败 {cache_file.name}: {e}") - else: - kept += 1 - - if cleaned > 0: - logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept} 个") - - return cleaned, kept - - -# 模块级别的类型别名(解决 B008 ruff 错误) -EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")] -EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")] -DescriptionForm = Annotated[str, Form(description="表情包描述")] -EmotionForm = Annotated[str, Form(description="情感标签,多个用逗号分隔")] -IsRegisteredForm = Annotated[bool, Form(description="是否直接注册")] - -# 创建路由器 router = APIRouter(prefix="/emoji", tags=["Emoji"]) -class EmojiResponse(BaseModel): - """表情包响应""" - - id: int - full_path: str - emoji_hash: str - description: str - query_count: int - is_registered: bool - is_banned: bool - emotion: Optional[str] # 直接返回字符串 - record_time: float - register_time: Optional[float] - last_used_time: Optional[float] - - -class EmojiListResponse(BaseModel): - """表情包列表响应""" - - success: bool - total: int - page: int - page_size: int - data: List[EmojiResponse] - - -class EmojiDetailResponse(BaseModel): - """表情包详情响应""" - - success: bool - data: EmojiResponse - - -class EmojiUpdateRequest(BaseModel): - """表情包更新请求""" - - description: Optional[str] = None - is_registered: Optional[bool] = None - is_banned: Optional[bool] = None - emotion: Optional[str] = None - - -class EmojiUpdateResponse(BaseModel): - """表情包更新响应""" - - success: bool - message: str - data: Optional[EmojiResponse] = None - - -class EmojiDeleteResponse(BaseModel): - """表情包删除响应""" - - success: bool - message: str - - -class BatchDeleteRequest(BaseModel): - """批量删除请求""" - - emoji_ids: List[int] - - -class BatchDeleteResponse(BaseModel): - """批量删除响应""" - - success: bool - message: str - deleted_count: int - failed_count: int - failed_ids: List[int] = [] - - -def verify_auth_token( - maibot_session: Optional[str] = None, - authorization: Optional[str] = None, -) -> bool: - """验证认证 Token,支持 Cookie 和 Header""" - return verify_auth_token_from_cookie_or_header(maibot_session, authorization) - - -def emoji_to_response(image: Images) -> EmojiResponse: - return EmojiResponse( - id=image.id if image.id is not None else 0, - full_path=image.full_path, - emoji_hash=image.image_hash, - description=image.description, - query_count=image.query_count, - is_registered=image.is_registered, - is_banned=image.is_banned, - emotion=image.emotion, - record_time=image.record_time.timestamp() if image.record_time else 0.0, - register_time=image.register_time.timestamp() if image.register_time else None, - last_used_time=image.last_used_time.timestamp() if image.last_used_time else None, - ) - - @router.get("/list", response_model=EmojiListResponse) async def get_emoji_list( page: int = Query(1, ge=1, description="页码"), @@ -287,45 +65,24 @@ async def get_emoji_list( sort_by: Optional[str] = Query("query_count", description="排序字段"), sort_order: Optional[str] = Query("desc", description="排序方向"), maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -): - """ - 获取表情包列表 - - Args: - page: 页码 (从 1 开始) - page_size: 每页数量 (1-100) - search: 搜索关键词 (匹配 description, emoji_hash) - is_registered: 是否已注册筛选 - is_banned: 是否被禁用筛选 - sort_by: 排序字段 (query_count, register_time, record_time, last_used_time) - sort_order: 排序方向 (asc, desc) - authorization: Authorization header - - Returns: - 表情包列表 - """ +) -> EmojiListResponse: + """获取表情包列表。""" try: - verify_auth_token(maibot_session, authorization) + verify_auth_token(maibot_session) - # 构建查询 statement = select(Images).where(col(Images.image_type) == ImageType.EMOJI) - # 搜索过滤 if search: statement = statement.where( (col(Images.description).contains(search)) | (col(Images.image_hash).contains(search)) ) - # 注册状态过滤 if is_registered is not None: statement = statement.where(col(Images.is_registered) == is_registered) - # 禁用状态过滤 if is_banned is not None: statement = statement.where(col(Images.is_banned) == is_banned) - # 排序字段映射 sort_field_map = { "usage_count": col(Images.query_count), "query_count": col(Images.query_count), @@ -333,18 +90,9 @@ async def get_emoji_list( "record_time": col(Images.record_time), "last_used_time": col(Images.last_used_time), } + sort_field = sort_field_map.get(sort_by or "query_count", col(Images.query_count)) + statement = statement.order_by(sort_field.asc() if sort_order == "asc" else sort_field.desc()) - # 获取排序字段,默认使用 usage_count - sort_key = sort_by or "query_count" - sort_field = sort_field_map.get(sort_key, col(Images.query_count)) - - # 应用排序 - if sort_order == "asc": - statement = statement.order_by(sort_field.asc()) - else: - statement = statement.order_by(sort_field.desc()) - - # 分页 offset = (page - 1) * page_size statement = statement.offset(offset).limit(page_size) @@ -362,11 +110,13 @@ async def get_emoji_list( count_statement = count_statement.where(col(Images.is_banned) == is_banned) total = session.exec(count_statement).one() - # 转换为响应对象 - data = [emoji_to_response(emoji) for emoji in emojis] - - return EmojiListResponse(success=True, total=total, page=page, page_size=page_size, data=data) - + return EmojiListResponse( + success=True, + total=total, + page=page, + page_size=page_size, + data=[emoji_to_response(emoji) for emoji in emojis], + ) except HTTPException: raise except Exception as e: @@ -375,34 +125,20 @@ async def get_emoji_list( @router.get("/{emoji_id}", response_model=EmojiDetailResponse) -async def get_emoji_detail( - emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -): - """ - 获取表情包详细信息 - - Args: - emoji_id: 表情包ID - authorization: Authorization header - - Returns: - 表情包详细信息 - """ +async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie(None)) -> EmojiDetailResponse: + """获取表情包详细信息。""" try: - verify_auth_token(maibot_session, authorization) + verify_auth_token(maibot_session) with get_db_session() as session: statement = select(Images).where( col(Images.id) == emoji_id, col(Images.image_type) == ImageType.EMOJI, ) - emoji = session.exec(statement).first() - - if not emoji: + if not (emoji := session.exec(statement).first()): raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包") return EmojiDetailResponse(success=True, data=emoji_to_response(emoji)) - except HTTPException: raise except Exception as e: @@ -415,21 +151,10 @@ async def update_emoji( emoji_id: int, request: EmojiUpdateRequest, maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -): - """ - 增量更新表情包(只更新提供的字段) - - Args: - emoji_id: 表情包ID - request: 更新请求(只包含需要更新的字段) - authorization: Authorization header - - Returns: - 更新结果 - """ +) -> EmojiUpdateResponse: + """增量更新表情包。""" try: - verify_auth_token(maibot_session, authorization) + verify_auth_token(maibot_session) with get_db_session() as session: statement = select(Images).where( @@ -441,28 +166,24 @@ async def update_emoji( if not emoji: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包") - # 只更新提供的字段 update_data = request.model_dump(exclude_unset=True) - if not update_data: raise HTTPException(status_code=400, detail="未提供任何需要更新的字段") - # 如果注册状态从 False 变为 True,记录注册时间 if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered: update_data["register_time"] = datetime.now() - # 执行更新 for field, value in update_data.items(): setattr(emoji, field, value) session.add(emoji) - logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}") return EmojiUpdateResponse( - success=True, message=f"成功更新 {len(update_data)} 个字段", data=emoji_to_response(emoji) + success=True, + message=f"成功更新 {len(update_data)} 个字段", + data=emoji_to_response(emoji), ) - except HTTPException: raise except Exception as e: @@ -471,21 +192,10 @@ async def update_emoji( @router.delete("/{emoji_id}", response_model=EmojiDeleteResponse) -async def delete_emoji( - emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -): - """ - 删除表情包 - - Args: - emoji_id: 表情包ID - authorization: Authorization header - - Returns: - 删除结果 - """ +async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None)) -> EmojiDeleteResponse: + """删除表情包。""" try: - verify_auth_token(maibot_session, authorization) + verify_auth_token(maibot_session) with get_db_session() as session: statement = select(Images).where( @@ -499,11 +209,8 @@ async def delete_emoji( emoji_hash = emoji.image_hash session.delete(emoji) - logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}") - return EmojiDeleteResponse(success=True, message=f"成功删除表情包: {emoji_hash}") - except HTTPException: raise except Exception as e: @@ -512,18 +219,10 @@ async def delete_emoji( @router.get("/stats/summary") -async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): - """ - 获取表情包统计数据 - - Args: - authorization: Authorization header - - Returns: - 统计数据 - """ +async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: + """获取表情包统计数据。""" try: - verify_auth_token(maibot_session, authorization) + verify_auth_token(maibot_session) with get_db_session() as session: total_statement = select(func.count()).select_from(Images).where(col(Images.image_type) == ImageType.EMOJI) @@ -582,7 +281,6 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz "top_used": top_used_list, }, } - except HTTPException: raise except Exception as e: @@ -591,21 +289,10 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz @router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse) -async def register_emoji( - emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -): - """ - 注册表情包(快捷操作) - - Args: - emoji_id: 表情包ID - authorization: Authorization header - - Returns: - 更新结果 - """ +async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None)) -> EmojiUpdateResponse: + """注册表情包。""" try: - verify_auth_token(maibot_session, authorization) + verify_auth_token(maibot_session) with get_db_session() as session: statement = select(Images).where( @@ -616,7 +303,6 @@ async def register_emoji( if not emoji: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包") - if emoji.is_registered: raise HTTPException(status_code=400, detail="该表情包已经注册") @@ -626,9 +312,7 @@ async def register_emoji( session.add(emoji) logger.info(f"表情包已注册: ID={emoji_id}") - return EmojiUpdateResponse(success=True, message="表情包注册成功", data=emoji_to_response(emoji)) - except HTTPException: raise except Exception as e: @@ -637,21 +321,10 @@ async def register_emoji( @router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse) -async def ban_emoji( - emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -): - """ - 禁用表情包(快捷操作) - - Args: - emoji_id: 表情包ID - authorization: Authorization header - - Returns: - 更新结果 - """ +async def ban_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None)) -> EmojiUpdateResponse: + """禁用表情包。""" try: - verify_auth_token(maibot_session, authorization) + verify_auth_token(maibot_session) with get_db_session() as session: statement = select(Images).where( @@ -668,9 +341,7 @@ async def ban_emoji( session.add(emoji) logger.info(f"表情包已禁用: ID={emoji_id}") - return EmojiUpdateResponse(success=True, message="表情包禁用成功", data=emoji_to_response(emoji)) - except HTTPException: raise except Exception as e: @@ -678,48 +349,22 @@ async def ban_emoji( raise HTTPException(status_code=500, detail=f"禁用表情包失败: {str(e)}") from e -@router.get("/{emoji_id}/thumbnail") +@router.get("/{emoji_id}/thumbnail", response_model=None) async def get_emoji_thumbnail( emoji_id: int, token: Optional[str] = Query(None, description="访问令牌"), maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), original: bool = Query(False, description="是否返回原图"), -): - """ - 获取表情包缩略图(懒加载生成 + 缓存) - - Args: - emoji_id: 表情包ID - token: 访问令牌(通过 query parameter,用于向后兼容) - maibot_session: Cookie 中的 token - authorization: Authorization header - original: 是否返回原图(用于详情页查看原图) - - Returns: - 表情包缩略图(WebP 格式)或原图 - - Features: - - 懒加载:首次请求时生成缩略图 - - 缓存:后续请求直接返回缓存 - - GIF 支持:提取第一帧作为缩略图 - - 格式统一:所有缩略图统一为 WebP 格式 - """ +) -> FileResponse | JSONResponse: + """获取表情包缩略图。""" try: token_manager = get_token_manager() is_valid = False - # 1. 优先使用 Cookie if maibot_session and token_manager.verify_token(maibot_session): is_valid = True - # 2. 其次使用 query parameter(用于向后兼容 img 标签) elif token and token_manager.verify_token(token): is_valid = True - # 3. 最后使用 Authorization header - elif authorization and authorization.startswith("Bearer "): - auth_token = authorization.replace("Bearer ", "") - if token_manager.verify_token(auth_token): - is_valid = True if not is_valid: raise HTTPException(status_code=401, detail="Token 无效或已过期") @@ -733,7 +378,6 @@ async def get_emoji_thumbnail( if not emoji: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包") - if not os.path.exists(emoji.full_path): raise HTTPException(status_code=404, detail="表情包文件不存在") @@ -749,20 +393,25 @@ async def get_emoji_thumbnail( suffix = Path(emoji.full_path).suffix.lower().lstrip(".") media_type = mime_types.get(suffix, "application/octet-stream") return FileResponse( - path=emoji.full_path, media_type=media_type, filename=f"{emoji.image_hash}.{suffix}" + path=emoji.full_path, + media_type=media_type, + filename=f"{emoji.image_hash}.{suffix}", ) - cache_path = _get_thumbnail_cache_path(emoji.image_hash) - + cache_path = get_thumbnail_cache_path(emoji.image_hash) if cache_path.exists(): return FileResponse( - path=str(cache_path), media_type="image/webp", filename=f"{emoji.image_hash}_thumb.webp" + path=str(cache_path), + media_type="image/webp", + filename=f"{emoji.image_hash}_thumb.webp", ) - with _generating_lock: - if emoji.image_hash not in _generating_thumbnails: - _generating_thumbnails.add(emoji.image_hash) - _thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.image_hash) + generating_lock = get_generating_lock() + generating_thumbnails = get_generating_thumbnails() + with generating_lock: + if emoji.image_hash not in generating_thumbnails: + generating_thumbnails.add(emoji.image_hash) + get_thumbnail_executor().submit(background_generate_thumbnail, emoji.full_path, emoji.image_hash) return JSONResponse( status_code=202, @@ -771,11 +420,8 @@ async def get_emoji_thumbnail( "message": "缩略图正在生成中,请稍后重试", "emoji_id": emoji_id, }, - headers={ - "Retry-After": "1", - }, + headers={"Retry-After": "1"}, ) - except HTTPException: raise except Exception as e: @@ -787,27 +433,17 @@ async def get_emoji_thumbnail( async def batch_delete_emojis( request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -): - """ - 批量删除表情包 - - Args: - request: 包含emoji_ids列表的请求 - authorization: Authorization header - - Returns: - 批量删除结果 - """ +) -> BatchDeleteResponse: + """批量删除表情包。""" try: - verify_auth_token(maibot_session, authorization) + verify_auth_token(maibot_session) if not request.emoji_ids: raise HTTPException(status_code=400, detail="未提供要删除的表情包ID") deleted_count = 0 failed_count = 0 - failed_ids = [] + failed_ids: list[int] = [] for emoji_id in request.emoji_ids: try: @@ -816,8 +452,7 @@ async def batch_delete_emojis( col(Images.id) == emoji_id, col(Images.image_type) == ImageType.EMOJI, ) - emoji = session.exec(statement).first() - if emoji: + if emoji := session.exec(statement).first(): session.delete(emoji) deleted_count += 1 logger.info(f"批量删除表情包: {emoji_id}") @@ -840,7 +475,6 @@ async def batch_delete_emojis( failed_count=failed_count, failed_ids=failed_ids, ) - except HTTPException: raise except Exception as e: @@ -848,18 +482,6 @@ async def batch_delete_emojis( raise HTTPException(status_code=500, detail=f"批量删除失败: {str(e)}") from e -# 表情包存储目录 -EMOJI_REGISTERED_DIR = os.path.join("data", "emoji_registed") - - -class EmojiUploadResponse(BaseModel): - """表情包上传响应""" - - success: bool - message: str - data: Optional[EmojiResponse] = None - - @router.post("/upload", response_model=EmojiUploadResponse) async def upload_emoji( file: EmojiFile, @@ -867,25 +489,11 @@ async def upload_emoji( emotion: EmotionForm = "", is_registered: IsRegisteredForm = True, maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -): - """ - 上传并注册表情包 - - Args: - file: 表情包图片文件 (支持 jpg, jpeg, png, gif, webp) - description: 表情包描述 - emotion: 情感标签,多个用逗号分隔 - is_registered: 是否直接注册,默认为 True - authorization: Authorization header - - Returns: - 上传结果和表情包信息 - """ +) -> EmojiUploadResponse: + """上传并注册表情包。""" try: - verify_auth_token(maibot_session, authorization) + verify_auth_token(maibot_session) - # 验证文件类型 if not file.content_type: raise HTTPException(status_code=400, detail="无法识别文件类型") @@ -896,26 +504,19 @@ async def upload_emoji( detail=f"不支持的文件类型: {file.content_type},支持: {', '.join(allowed_types)}", ) - # 读取文件内容 file_content = await file.read() - if not file_content: raise HTTPException(status_code=400, detail="文件内容为空") - # 验证图片并获取格式 try: with Image.open(io.BytesIO(file_content)) as img: - img_format = img.format.lower() if img.format else "png" - # 验证图片可以正常打开 img.verify() except Exception as e: raise HTTPException(status_code=400, detail=f"无效的图片文件: {str(e)}") from e - # 重新打开图片(verify后需要重新打开) with Image.open(io.BytesIO(file_content)) as img: img_format = img.format.lower() if img.format else "png" - # 计算文件哈希 emoji_hash = hashlib.md5(file_content).hexdigest() with get_db_session() as session: @@ -923,36 +524,26 @@ async def upload_emoji( col(Images.image_hash) == emoji_hash, col(Images.image_type) == ImageType.EMOJI, ) - existing_emoji = session.exec(existing_statement).first() - if existing_emoji: - raise HTTPException( - status_code=409, - detail=f"已存在相同的表情包 (ID: {existing_emoji.id})", - ) + if existing_emoji := session.exec(existing_statement).first(): + raise HTTPException(status_code=409, detail=f"已存在相同的表情包 (ID: {existing_emoji.id})") - # 确保目录存在 os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True) - # 生成文件名 timestamp = int(datetime.now().timestamp()) filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}" full_path = os.path.join(EMOJI_REGISTERED_DIR, filename) - # 如果文件已存在,添加随机后缀 counter = 1 while os.path.exists(full_path): filename = f"emoji_{timestamp}_{emoji_hash[:8]}_{counter}.{img_format}" full_path = os.path.join(EMOJI_REGISTERED_DIR, filename) counter += 1 - # 保存文件 - with open(full_path, "wb") as f: - _ = f.write(file_content) + with open(full_path, "wb") as output_file: + _ = output_file.write(file_content) logger.info(f"表情包文件已保存: {full_path}") - - # 处理情感标签 - emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else "" + emotion_str = ",".join(item.strip() for item in emotion.split(",") if item.strip()) if emotion else "" current_time = datetime.now() with get_db_session() as session: @@ -973,13 +564,11 @@ async def upload_emoji( session.flush() logger.info(f"表情包已上传并注册: ID={emoji.id}, hash={emoji_hash}") - return EmojiUploadResponse( success=True, message="表情包上传成功" + ("并已注册" if is_registered else ""), data=emoji_to_response(emoji), ) - except HTTPException: raise except Exception as e: @@ -993,22 +582,10 @@ async def batch_upload_emoji( emotion: EmotionForm = "", is_registered: IsRegisteredForm = True, maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -): - """ - 批量上传表情包 - - Args: - files: 多个表情包图片文件 - emotion: 共用的情感标签 - is_registered: 是否直接注册 - authorization: Authorization header - - Returns: - 批量上传结果 - """ +) -> dict[str, Any]: + """批量上传表情包。""" try: - verify_auth_token(maibot_session, authorization) + verify_auth_token(maibot_session) results: dict[str, Any] = { "success": True, @@ -1023,7 +600,6 @@ async def batch_upload_emoji( for file in files: try: - # 验证文件类型 if file.content_type not in allowed_types: results["failed"] += 1 results["details"].append( @@ -1035,36 +611,24 @@ async def batch_upload_emoji( ) continue - # 读取文件内容 file_content = await file.read() - if not file_content: results["failed"] += 1 results["details"].append( - { - "filename": file.filename, - "success": False, - "error": "文件内容为空", - } + {"filename": file.filename, "success": False, "error": "文件内容为空"} ) continue - # 验证图片 try: with Image.open(io.BytesIO(file_content)) as img: img_format = img.format.lower() if img.format else "png" except Exception as e: results["failed"] += 1 results["details"].append( - { - "filename": file.filename, - "success": False, - "error": f"无效的图片: {str(e)}", - } + {"filename": file.filename, "success": False, "error": f"无效的图片: {str(e)}"} ) continue - # 计算哈希 emoji_hash = hashlib.md5(file_content).hexdigest() with get_db_session() as session: @@ -1075,15 +639,10 @@ async def batch_upload_emoji( if session.exec(existing_statement).first(): results["failed"] += 1 results["details"].append( - { - "filename": file.filename, - "success": False, - "error": "已存在相同的表情包", - } + {"filename": file.filename, "success": False, "error": "已存在相同的表情包"} ) continue - # 生成文件名并保存 timestamp = int(datetime.now().timestamp()) filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}" full_path = os.path.join(EMOJI_REGISTERED_DIR, filename) @@ -1094,13 +653,12 @@ async def batch_upload_emoji( full_path = os.path.join(EMOJI_REGISTERED_DIR, filename) counter += 1 - with open(full_path, "wb") as f: - _ = f.write(file_content) - - # 处理情感标签 - emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else "" + with open(full_path, "wb") as output_file: + _ = output_file.write(file_content) + emotion_str = ",".join(item.strip() for item in emotion.split(",") if item.strip()) if emotion else "" current_time = datetime.now() + with get_db_session() as session: emoji = Images( image_type=ImageType.EMOJI, @@ -1120,26 +678,16 @@ async def batch_upload_emoji( results["uploaded"] += 1 results["details"].append( - { - "filename": file.filename, - "success": True, - "id": emoji.id, - } + {"filename": file.filename, "success": True, "id": emoji.id} ) - except Exception as e: results["failed"] += 1 results["details"].append( - { - "filename": file.filename, - "success": False, - "error": str(e), - } + {"filename": file.filename, "success": False, "error": str(e)} ) results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']} 个" return results - except HTTPException: raise except Exception as e: @@ -1147,68 +695,22 @@ async def batch_upload_emoji( raise HTTPException(status_code=500, detail=f"批量上传失败: {str(e)}") from e -# ==================== 缩略图缓存管理 API ==================== - - -class ThumbnailCacheStatsResponse(BaseModel): - """缩略图缓存统计响应""" - - success: bool - cache_dir: str - total_count: int - total_size_mb: float - emoji_count: int - coverage_percent: float - - -class ThumbnailCleanupResponse(BaseModel): - """缩略图清理响应""" - - success: bool - message: str - cleaned_count: int - kept_count: int - - -class ThumbnailPreheatResponse(BaseModel): - """缩略图预热响应""" - - success: bool - message: str - generated_count: int - skipped_count: int - failed_count: int - - @router.get("/thumbnail-cache/stats", response_model=ThumbnailCacheStatsResponse) -async def get_thumbnail_cache_stats( - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -): - """ - 获取缩略图缓存统计信息 - - Returns: - 缓存目录、缓存数量、总大小、覆盖率等统计信息 - """ +async def get_thumbnail_cache_stats(maibot_session: Optional[str] = Cookie(None)) -> ThumbnailCacheStatsResponse: + """获取缩略图缓存统计信息。""" try: - verify_auth_token(maibot_session, authorization) + verify_auth_token(maibot_session) - _ensure_thumbnail_cache_dir() - - # 统计缓存文件 + ensure_thumbnail_cache_dir() cache_files = list(THUMBNAIL_CACHE_DIR.glob("*.webp")) total_count = len(cache_files) - total_size = sum(f.stat().st_size for f in cache_files) - total_size_mb = round(total_size / (1024 * 1024), 2) + total_size_mb = round(sum(item.stat().st_size for item in cache_files) / (1024 * 1024), 2) with get_db_session() as session: count_statement = select(func.count()).select_from(Images).where(col(Images.image_type) == ImageType.EMOJI) emoji_count = session.exec(count_statement).one() - # 计算覆盖率 coverage_percent = round((total_count / emoji_count * 100) if emoji_count > 0 else 0, 1) - return ThumbnailCacheStatsResponse( success=True, cache_dir=str(THUMBNAIL_CACHE_DIR.absolute()), @@ -1217,7 +719,6 @@ async def get_thumbnail_cache_stats( emoji_count=emoji_count, coverage_percent=coverage_percent, ) - except HTTPException: raise except Exception as e: @@ -1226,28 +727,18 @@ async def get_thumbnail_cache_stats( @router.post("/thumbnail-cache/cleanup", response_model=ThumbnailCleanupResponse) -async def cleanup_thumbnail_cache( - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -): - """ - 清理孤立的缩略图缓存(原图已删除的表情包对应的缩略图) - - Returns: - 清理结果 - """ +async def cleanup_thumbnail_cache(maibot_session: Optional[str] = Cookie(None)) -> ThumbnailCleanupResponse: + """清理孤立的缩略图缓存。""" try: - verify_auth_token(maibot_session, authorization) + verify_auth_token(maibot_session) cleaned, kept = cleanup_orphaned_thumbnails() - return ThumbnailCleanupResponse( success=True, message=f"清理完成:删除 {cleaned} 个孤立缓存,保留 {kept} 个有效缓存", cleaned_count=cleaned, kept_count=kept, ) - except HTTPException: raise except Exception as e: @@ -1259,25 +750,13 @@ async def cleanup_thumbnail_cache( async def preheat_thumbnail_cache( limit: int = Query(100, ge=1, le=1000, description="最多预热数量"), maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -): - """ - 预热缩略图缓存(提前生成未缓存的缩略图) - - 优先处理使用次数高的表情包 - - Args: - limit: 最多预热数量 (1-1000) - - Returns: - 预热结果 - """ +) -> ThumbnailPreheatResponse: + """预热缩略图缓存。""" try: - verify_auth_token(maibot_session, authorization) + verify_auth_token(maibot_session) - _ensure_thumbnail_cache_dir() + ensure_thumbnail_cache_dir() - # 获取使用次数最高的表情包(未缓存的优先) with get_db_session() as session: statement = ( select(Images) @@ -1298,19 +777,17 @@ async def preheat_thumbnail_cache( if generated >= limit: break - cache_path = _get_thumbnail_cache_path(emoji.image_hash) - + cache_path = get_thumbnail_cache_path(emoji.image_hash) if cache_path.exists(): skipped += 1 continue - if not os.path.exists(emoji.full_path): failed += 1 continue try: loop = asyncio.get_event_loop() - await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.image_hash) + await loop.run_in_executor(get_thumbnail_executor(), generate_thumbnail, emoji.full_path, emoji.image_hash) generated += 1 except Exception as e: logger.warning(f"预热缩略图失败 {emoji.image_hash}: {e}") @@ -1323,7 +800,6 @@ async def preheat_thumbnail_cache( skipped_count=skipped, failed_count=failed, ) - except HTTPException: raise except Exception as e: @@ -1332,18 +808,10 @@ async def preheat_thumbnail_cache( @router.delete("/thumbnail-cache/clear", response_model=ThumbnailCleanupResponse) -async def clear_all_thumbnail_cache( - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -): - """ - 清空所有缩略图缓存(下次访问时会重新生成) - - Returns: - 清理结果 - """ +async def clear_all_thumbnail_cache(maibot_session: Optional[str] = Cookie(None)) -> ThumbnailCleanupResponse: + """清空所有缩略图缓存。""" try: - verify_auth_token(maibot_session, authorization) + verify_auth_token(maibot_session) if not THUMBNAIL_CACHE_DIR.exists(): return ThumbnailCleanupResponse( @@ -1362,16 +830,14 @@ async def clear_all_thumbnail_cache( logger.warning(f"删除缓存文件失败 {cache_file.name}: {e}") logger.info(f"已清空缩略图缓存: 删除 {cleaned} 个文件") - return ThumbnailCleanupResponse( success=True, message=f"已清空所有缩略图缓存:删除 {cleaned} 个文件", cleaned_count=cleaned, kept_count=0, ) - except HTTPException: raise except Exception as e: logger.exception(f"清空缩略图缓存失败: {e}") - raise HTTPException(status_code=500, detail=f"清空失败: {str(e)}") from e + raise HTTPException(status_code=500, detail=f"清空失败: {str(e)}") from e \ No newline at end of file diff --git a/src/webui/routers/emoji/schemas.py b/src/webui/routers/emoji/schemas.py new file mode 100644 index 00000000..4eea67c6 --- /dev/null +++ b/src/webui/routers/emoji/schemas.py @@ -0,0 +1,140 @@ +from typing import Annotated, List, Optional + +from fastapi import File, Form, UploadFile +from pydantic import BaseModel + +from src.common.database.database_model import Images + +EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")] +EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")] +DescriptionForm = Annotated[str, Form(description="表情包描述")] +EmotionForm = Annotated[str, Form(description="情感标签,多个用逗号分隔")] +IsRegisteredForm = Annotated[bool, Form(description="是否直接注册")] + + +class EmojiResponse(BaseModel): + """表情包响应""" + + id: int + full_path: str + emoji_hash: str + description: str + query_count: int + is_registered: bool + is_banned: bool + emotion: Optional[str] + record_time: float + register_time: Optional[float] + last_used_time: Optional[float] + + +class EmojiListResponse(BaseModel): + """表情包列表响应""" + + success: bool + total: int + page: int + page_size: int + data: List[EmojiResponse] + + +class EmojiDetailResponse(BaseModel): + """表情包详情响应""" + + success: bool + data: EmojiResponse + + +class EmojiUpdateRequest(BaseModel): + """表情包更新请求""" + + description: Optional[str] = None + is_registered: Optional[bool] = None + is_banned: Optional[bool] = None + emotion: Optional[str] = None + + +class EmojiUpdateResponse(BaseModel): + """表情包更新响应""" + + success: bool + message: str + data: Optional[EmojiResponse] = None + + +class EmojiDeleteResponse(BaseModel): + """表情包删除响应""" + + success: bool + message: str + + +class BatchDeleteRequest(BaseModel): + """批量删除请求""" + + emoji_ids: List[int] + + +class BatchDeleteResponse(BaseModel): + """批量删除响应""" + + success: bool + message: str + deleted_count: int + failed_count: int + failed_ids: List[int] = [] + + +class EmojiUploadResponse(BaseModel): + """表情包上传响应""" + + success: bool + message: str + data: Optional[EmojiResponse] = None + + +class ThumbnailCacheStatsResponse(BaseModel): + """缩略图缓存统计响应""" + + success: bool + cache_dir: str + total_count: int + total_size_mb: float + emoji_count: int + coverage_percent: float + + +class ThumbnailCleanupResponse(BaseModel): + """缩略图清理响应""" + + success: bool + message: str + cleaned_count: int + kept_count: int + + +class ThumbnailPreheatResponse(BaseModel): + """缩略图预热响应""" + + success: bool + message: str + generated_count: int + skipped_count: int + failed_count: int + + +def emoji_to_response(image: Images) -> EmojiResponse: + """将数据库表情包模型转换为响应对象。""" + return EmojiResponse( + id=image.id if image.id is not None else 0, + full_path=image.full_path, + emoji_hash=image.image_hash, + description=image.description, + query_count=image.query_count, + is_registered=image.is_registered, + is_banned=image.is_banned, + emotion=image.emotion, + record_time=image.record_time.timestamp() if image.record_time else 0.0, + register_time=image.register_time.timestamp() if image.register_time else None, + last_used_time=image.last_used_time.timestamp() if image.last_used_time else None, + ) \ No newline at end of file diff --git a/src/webui/routers/emoji/support.py b/src/webui/routers/emoji/support.py new file mode 100644 index 00000000..51790cd4 --- /dev/null +++ b/src/webui/routers/emoji/support.py @@ -0,0 +1,142 @@ +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +import os +import threading + +from PIL import Image +from sqlmodel import col, select + +from src.common.database.database import get_db_session +from src.common.database.database_model import Images, ImageType +from src.common.logger import get_logger + +logger = get_logger("webui.emoji") + +THUMBNAIL_CACHE_DIR = Path("data/emoji_thumbnails") +THUMBNAIL_SIZE = (200, 200) +THUMBNAIL_QUALITY = 80 +EMOJI_REGISTERED_DIR = os.path.join("data", "emoji_registed") + +_thumbnail_locks: dict[str, threading.Lock] = {} +_locks_lock = threading.Lock() +_thumbnail_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="thumbnail") +_generating_thumbnails: set[str] = set() +_generating_lock = threading.Lock() + + +def get_thumbnail_executor() -> ThreadPoolExecutor: + """获取缩略图生成线程池。""" + return _thumbnail_executor + + +def get_generating_lock() -> threading.Lock: + """获取缩略图生成状态锁。""" + return _generating_lock + + +def get_generating_thumbnails() -> set[str]: + """获取正在生成的缩略图哈希集合。""" + return _generating_thumbnails + + +def _get_thumbnail_lock(file_hash: str) -> threading.Lock: + """获取指定文件哈希的锁,用于防止并发生成同一缩略图。""" + with _locks_lock: + if file_hash not in _thumbnail_locks: + _thumbnail_locks[file_hash] = threading.Lock() + return _thumbnail_locks[file_hash] + + +def _background_generate_thumbnail(source_path: str, file_hash: str) -> None: + """在线程池中后台生成缩略图。""" + try: + _generate_thumbnail(source_path, file_hash) + except Exception as e: + logger.warning(f"后台生成缩略图失败 {file_hash}: {e}") + finally: + with _generating_lock: + _generating_thumbnails.discard(file_hash) + + +def ensure_thumbnail_cache_dir() -> Path: + """确保缩略图缓存目录存在。""" + _ = THUMBNAIL_CACHE_DIR.mkdir(parents=True, exist_ok=True) + return THUMBNAIL_CACHE_DIR + + +def get_thumbnail_cache_path(file_hash: str) -> Path: + """获取缩略图缓存路径。""" + return THUMBNAIL_CACHE_DIR / f"{file_hash}.webp" + + +def _generate_thumbnail(source_path: str, file_hash: str) -> Path: + """生成缩略图并保存到缓存目录。""" + ensure_thumbnail_cache_dir() + cache_path = get_thumbnail_cache_path(file_hash) + + lock = _get_thumbnail_lock(file_hash) + with lock: + if cache_path.exists(): + return cache_path + + try: + with Image.open(source_path) as img: + if getattr(img, "n_frames", 1) > 1: + img.seek(0) + + if img.mode in ("P", "PA"): + img = img.convert("RGBA") + elif img.mode == "LA": + img = img.convert("RGBA") + elif img.mode not in ("RGB", "RGBA"): + img = img.convert("RGB") + + img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS) + img.save(cache_path, "WEBP", quality=THUMBNAIL_QUALITY, method=6) + logger.debug(f"生成缩略图: {file_hash} -> {cache_path}") + except Exception as e: + logger.warning(f"生成缩略图失败 {file_hash}: {e},将返回原图") + raise + + return cache_path + + +def generate_thumbnail(source_path: str, file_hash: str) -> Path: + """暴露给路由层的缩略图生成函数。""" + return _generate_thumbnail(source_path, file_hash) + + +def background_generate_thumbnail(source_path: str, file_hash: str) -> None: + """暴露给路由层的后台缩略图生成函数。""" + _background_generate_thumbnail(source_path, file_hash) + + +def cleanup_orphaned_thumbnails() -> tuple[int, int]: + """清理孤立的缩略图缓存。""" + if not THUMBNAIL_CACHE_DIR.exists(): + return 0, 0 + + with get_db_session() as session: + statement = select(Images.image_hash).where(col(Images.image_type) == ImageType.EMOJI) + valid_hashes = set(session.exec(statement).all()) + + cleaned = 0 + kept = 0 + + for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"): + file_hash = cache_file.stem + if file_hash not in valid_hashes: + try: + cache_file.unlink() + cleaned += 1 + logger.debug(f"清理孤立缩略图: {cache_file.name}") + except Exception as e: + logger.warning(f"清理缩略图失败 {cache_file.name}: {e}") + else: + kept += 1 + + if cleaned > 0: + logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept} 个") + + return cleaned, kept \ No newline at end of file diff --git a/src/webui/routers/expression.py b/src/webui/routers/expression.py index 3e1fc187..f814751c 100644 --- a/src/webui/routers/expression.py +++ b/src/webui/routers/expression.py @@ -1,9 +1,10 @@ """表达方式管理 API 路由""" -from fastapi import APIRouter, HTTPException, Header, Query, Cookie -from pydantic import BaseModel -from typing import Optional, List, Dict from datetime import datetime, timedelta +from typing import Dict, List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel from sqlalchemy import case, func from sqlmodel import col, select, delete @@ -12,12 +13,12 @@ from src.common.logger import get_logger from src.common.database.database import get_db_session from src.common.database.database_model import Expression from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.webui.core import verify_auth_token_from_cookie_or_header +from src.webui.dependencies import require_auth logger = get_logger("webui.expression") # 创建路由器 -router = APIRouter(prefix="/expression", tags=["Expression"]) +router = APIRouter(prefix="/expression", tags=["Expression"], dependencies=[Depends(require_auth)]) class ExpressionResponse(BaseModel): @@ -90,14 +91,6 @@ class ExpressionCreateResponse(BaseModel): data: ExpressionResponse -def verify_auth_token( - maibot_session: Optional[str] = None, - authorization: Optional[str] = None, -) -> bool: - """验证认证 Token,支持 Cookie 和 Header""" - return verify_auth_token_from_cookie_or_header(maibot_session, authorization) - - def expression_to_response(expression: Expression) -> ExpressionResponse: """将 Expression 模型转换为响应对象""" last_active_time = expression.last_active_time.timestamp() if expression.last_active_time else 0.0 @@ -156,19 +149,14 @@ class ChatListResponse(BaseModel): @router.get("/chats", response_model=ChatListResponse) -async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def get_chat_list(): """ 获取所有聊天列表(用于下拉选择) - Args: - authorization: Authorization header - Returns: 聊天列表 """ try: - verify_auth_token(maibot_session, authorization) - chat_list = [] for session_id, session in _chat_manager.sessions.items(): chat_name = _chat_manager.get_session_name(session_id) or session_id @@ -199,8 +187,6 @@ async def get_expression_list( page_size: int = Query(20, ge=1, le=100, description="每页数量"), search: Optional[str] = Query(None, description="搜索关键词"), chat_id: Optional[str] = Query(None, description="聊天ID筛选"), - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), ): """ 获取表达方式列表 @@ -210,14 +196,11 @@ async def get_expression_list( page_size: 每页数量 (1-100) search: 搜索关键词 (匹配 situation, style) chat_id: 聊天ID筛选 - authorization: Authorization header Returns: 表达方式列表 """ try: - verify_auth_token(maibot_session, authorization) - # 构建查询 statement = select(Expression) @@ -264,22 +247,17 @@ async def get_expression_list( @router.get("/{expression_id}", response_model=ExpressionDetailResponse) -async def get_expression_detail( - expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -): +async def get_expression_detail(expression_id: int): """ 获取表达方式详细信息 Args: expression_id: 表达方式ID - authorization: Authorization header Returns: 表达方式详细信息 """ try: - verify_auth_token(maibot_session, authorization) - with get_db_session() as session: statement = select(Expression).where(col(Expression.id) == expression_id).limit(1) expression = session.exec(statement).first() @@ -299,22 +277,17 @@ async def get_expression_detail( @router.post("/", response_model=ExpressionCreateResponse) async def create_expression( request: ExpressionCreateRequest, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), ): """ 创建新的表达方式 Args: request: 创建请求 - authorization: Authorization header Returns: 创建结果 """ try: - verify_auth_token(maibot_session, authorization) - current_time = datetime.now() # 创建表达方式 @@ -349,8 +322,6 @@ async def create_expression( async def update_expression( expression_id: int, request: ExpressionUpdateRequest, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), ): """ 增量更新表达方式(只更新提供的字段) @@ -358,14 +329,11 @@ async def update_expression( Args: expression_id: 表达方式ID request: 更新请求(只包含需要更新的字段) - authorization: Authorization header Returns: 更新结果 """ try: - verify_auth_token(maibot_session, authorization) - with get_db_session() as session: statement = select(Expression).where(col(Expression.id) == expression_id).limit(1) expression = session.exec(statement).first() @@ -411,22 +379,17 @@ async def update_expression( @router.delete("/{expression_id}", response_model=ExpressionDeleteResponse) -async def delete_expression( - expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -): +async def delete_expression(expression_id: int): """ 删除表达方式 Args: expression_id: 表达方式ID - authorization: Authorization header Returns: 删除结果 """ try: - verify_auth_token(maibot_session, authorization) - with get_db_session() as session: statement = select(Expression).where(col(Expression.id) == expression_id).limit(1) expression = session.exec(statement).first() @@ -461,22 +424,17 @@ class BatchDeleteRequest(BaseModel): @router.post("/batch/delete", response_model=ExpressionDeleteResponse) async def batch_delete_expressions( request: BatchDeleteRequest, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), ): """ 批量删除表达方式 Args: request: 包含要删除的ID列表的请求 - authorization: Authorization header Returns: 删除结果 """ try: - verify_auth_token(maibot_session, authorization) - if not request.ids: raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID") @@ -506,21 +464,14 @@ async def batch_delete_expressions( @router.get("/stats/summary") -async def get_expression_stats( - maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -): +async def get_expression_stats(): """ 获取表达方式统计数据 - Args: - authorization: Authorization header - Returns: 统计数据 """ try: - verify_auth_token(maibot_session, authorization) - with get_db_session() as session: total = len(session.exec(select(Expression.id)).all()) @@ -569,7 +520,7 @@ class ReviewStatsResponse(BaseModel): @router.get("/review/stats", response_model=ReviewStatsResponse) -async def get_review_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def get_review_stats(): """ 获取审核统计数据 @@ -577,8 +528,6 @@ async def get_review_stats(maibot_session: Optional[str] = Cookie(None), authori 审核统计数据 """ try: - verify_auth_token(maibot_session, authorization) - with get_db_session() as session: total = len(session.exec(select(Expression.id)).all()) unchecked = 0 @@ -620,8 +569,6 @@ async def get_review_list( filter_type: str = Query("unchecked", description="筛选类型: unchecked/passed/rejected/all"), search: Optional[str] = Query(None, description="搜索关键词"), chat_id: Optional[str] = Query(None, description="聊天ID筛选"), - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), ): """ 获取待审核/已审核的表达方式列表 @@ -637,8 +584,6 @@ async def get_review_list( 表达方式列表 """ try: - verify_auth_token(maibot_session, authorization) - statement = select(Expression) if filter_type in {"unchecked", "passed", "rejected"}: @@ -728,8 +673,6 @@ class BatchReviewResponse(BaseModel): @router.post("/review/batch", response_model=BatchReviewResponse) async def batch_review_expressions( request: BatchReviewRequest, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), ): """ 批量审核表达方式 @@ -741,8 +684,6 @@ async def batch_review_expressions( 批量审核结果 """ try: - verify_auth_token(maibot_session, authorization) - if not request.items: raise HTTPException(status_code=400, detail="未提供要审核的表达方式") diff --git a/src/webui/routers/jargon.py b/src/webui/routers/jargon.py index 02414109..95568d48 100644 --- a/src/webui/routers/jargon.py +++ b/src/webui/routers/jargon.py @@ -1,20 +1,21 @@ """黑话(俚语)管理路由""" -from typing import Annotated, Any, List, Optional -from fastapi import APIRouter, HTTPException, Query -from pydantic import BaseModel, Field -from sqlalchemy import func as fn -from sqlmodel import Session, col, delete, select - import json +from typing import Annotated, Any, List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel, Field +from sqlmodel import Session, col, delete, select + from src.common.database.database import get_db_session from src.common.database.database_model import ChatSession, Jargon from src.common.logger import get_logger +from src.webui.dependencies import require_auth logger = get_logger("webui.jargon") -router = APIRouter(prefix="/jargon", tags=["Jargon"]) +router = APIRouter(prefix="/jargon", tags=["Jargon"], dependencies=[Depends(require_auth)]) # ==================== 辅助函数 ==================== @@ -33,14 +34,10 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]: parsed = json.loads(chat_id_str) if isinstance(parsed, list): # 格式: [["stream_id", user_id], ...] - stream_ids = [] - for item in parsed: - if isinstance(item, list) and len(item) >= 1: - stream_ids.append(str(item[0])) - return stream_ids - else: - # 其他格式,返回原始字符串 - return [chat_id_str] + return [str(item[0]) for item in parsed if isinstance(item, list) and len(item) >= 1] + + # 其他格式,返回原始字符串 + return [chat_id_str] except (json.JSONDecodeError, TypeError): # 不是有效的 JSON,可能是直接的 stream_id return [chat_id_str] @@ -57,9 +54,7 @@ def get_display_name_for_chat_id(chat_id_str: str, session: Session) -> str: return chat_id_str[:20] stream_id = stream_ids[0] - chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first() - - if not chat_session: + if not (chat_session := session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first()): return stream_id[:20] if chat_session.group_id: @@ -180,9 +175,59 @@ class ChatListResponse(BaseModel): # ==================== 工具函数 ==================== +def parse_session_id_dict(session_id_dict_str: Optional[str]) -> dict[str, int]: + """解析会话计数字典。""" + if not session_id_dict_str: + return {} + + try: + parsed = json.loads(session_id_dict_str) + except (json.JSONDecodeError, TypeError): + return {} + + if not isinstance(parsed, dict): + return {} + + session_counts: dict[str, int] = {} + for session_id, count in parsed.items(): + if not isinstance(session_id, str): + continue + if isinstance(count, int): + session_counts[session_id] = count + else: + try: + session_counts[session_id] = int(count) + except (TypeError, ValueError): + session_counts[session_id] = 0 + return session_counts + + +def dump_session_id_dict(session_counts: dict[str, int]) -> str: + """序列化会话计数字典。""" + return json.dumps(session_counts, ensure_ascii=False) + + +def get_primary_chat_id(session_id_dict_str: Optional[str]) -> str: + """从会话计数字典中选出主聊天 ID。""" + if not (session_counts := parse_session_id_dict(session_id_dict_str)): + return "" + + return max(session_counts.items(), key=lambda item: item[1])[0] + + +def has_chat_id(session_id_dict_str: Optional[str], chat_id: str) -> bool: + """判断记录是否包含指定聊天 ID。""" + return chat_id in parse_session_id_dict(session_id_dict_str) + + +def build_session_id_dict_for_chat(chat_id: str, count: int = 1) -> str: + """为单个聊天 ID 构建会话计数字典。""" + return dump_session_id_dict({chat_id: count}) + + def jargon_to_dict(jargon: Jargon, session: Session) -> dict[str, Any]: """将 Jargon ORM 对象转换为字典""" - chat_id = jargon.session_id or "" + chat_id = get_primary_chat_id(jargon.session_id_dict) chat_name = get_display_name_for_chat_id(chat_id, session) if chat_id else None return { @@ -191,7 +236,7 @@ def jargon_to_dict(jargon: Jargon, session: Session) -> dict[str, Any]: "raw_content": jargon.raw_content, "meaning": jargon.meaning, "chat_id": chat_id, - "stream_id": jargon.session_id, + "stream_id": chat_id or None, "chat_name": chat_name, "count": jargon.count, "is_jargon": jargon.is_jargon, @@ -215,7 +260,6 @@ async def get_jargon_list( """获取黑话列表""" try: statement = select(Jargon) - count_statement = select(fn.count()).select_from(Jargon) if search: search_filter = ( @@ -224,28 +268,28 @@ async def get_jargon_list( | (col(Jargon.raw_content).contains(search)) ) statement = statement.where(search_filter) - count_statement = count_statement.where(search_filter) - - if chat_id: - stream_ids = parse_chat_id_to_stream_ids(chat_id) - if stream_ids: - chat_filter = col(Jargon.session_id).contains(stream_ids[0]) - else: - chat_filter = col(Jargon.session_id) == chat_id - statement = statement.where(chat_filter) - count_statement = count_statement.where(chat_filter) if is_jargon is not None: statement = statement.where(col(Jargon.is_jargon) == is_jargon) - count_statement = count_statement.where(col(Jargon.is_jargon) == is_jargon) statement = statement.order_by(col(Jargon.count).desc(), col(Jargon.id).desc()) - statement = statement.offset((page - 1) * page_size).limit(page_size) with get_db_session() as session: - total = session.exec(count_statement).one() jargons = session.exec(statement).all() - data = [jargon_to_dict(jargon, session) for jargon in jargons] + + if chat_id: + stream_ids = parse_chat_id_to_stream_ids(chat_id) + chat_ids = stream_ids or [chat_id] + jargons = [ + jargon + for jargon in jargons + if any(has_chat_id(jargon.session_id_dict, current_chat_id) for current_chat_id in chat_ids) + ] + + total = len(jargons) + offset = (page - 1) * page_size + page_jargons = jargons[offset : offset + page_size] + data = [jargon_to_dict(jargon, session) for jargon in page_jargons] return JargonListResponse( success=True, @@ -265,22 +309,16 @@ async def get_chat_list(): """获取所有有黑话记录的聊天列表""" try: with get_db_session() as session: - statement = select(Jargon.session_id).distinct().where(col(Jargon.session_id).is_not(None)) - chat_id_list = [chat_id for chat_id in session.exec(statement).all() if chat_id] + jargons = session.exec(select(Jargon)).all() - # 用于按 stream_id 去重 seen_stream_ids: set[str] = set() - - for chat_id in chat_id_list: - stream_ids = parse_chat_id_to_stream_ids(chat_id) - if stream_ids: - seen_stream_ids.add(stream_ids[0]) + for jargon in jargons: + seen_stream_ids.update(parse_session_id_dict(jargon.session_id_dict).keys()) result = [] with get_db_session() as session: for stream_id in seen_stream_ids: - chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first() - if chat_session: + if chat_session := session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first(): chat_name = str(chat_session.group_id) if chat_session.group_id else stream_id[:20] result.append( ChatInfoResponse( @@ -312,30 +350,21 @@ async def get_jargon_stats(): """获取黑话统计数据""" try: with get_db_session() as session: - total = session.exec(select(fn.count()).select_from(Jargon)).one() + jargons = session.exec(select(Jargon)).all() - confirmed_jargon = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon))).one() - confirmed_not_jargon = session.exec( - select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(False)) - ).one() - pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one() + total = len(jargons) + confirmed_jargon = sum(jargon.is_jargon is True for jargon in jargons) + confirmed_not_jargon = sum(jargon.is_jargon is False for jargon in jargons) + pending = sum(jargon.is_jargon is None for jargon in jargons) + complete_count = sum(jargon.is_complete for jargon in jargons) - complete_count = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete))).one() + top_chats_counter: dict[str, int] = {} + for jargon in jargons: + for session_id in parse_session_id_dict(jargon.session_id_dict): + top_chats_counter[session_id] = top_chats_counter.get(session_id, 0) + 1 - chat_count = session.exec( - select(fn.count()).select_from( - select(col(Jargon.session_id)).distinct().where(col(Jargon.session_id).is_not(None)).subquery() - ) - ).one() - - top_chats = session.exec( - select(col(Jargon.session_id), fn.count().label("count")) - .where(col(Jargon.session_id).is_not(None)) - .group_by(col(Jargon.session_id)) - .order_by(fn.count().desc()) - .limit(5) - ).all() - top_chats_dict = {session_id: count for session_id, count in top_chats if session_id} + top_chats_dict = dict(sorted(top_chats_counter.items(), key=lambda item: item[1], reverse=True)[:5]) + chat_count = len(top_chats_counter) return JargonStatsResponse( success=True, @@ -360,8 +389,7 @@ async def get_jargon_detail(jargon_id: int): """获取黑话详情""" try: with get_db_session() as session: - jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first() - if not jargon: + if not (jargon := session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()): raise HTTPException(status_code=404, detail="黑话不存在") data = JargonResponse(**jargon_to_dict(jargon, session)) @@ -379,19 +407,19 @@ async def create_jargon(request: JargonCreateRequest): """创建黑话""" try: with get_db_session() as session: - existing = session.exec( - select(Jargon).where( - (col(Jargon.content) == request.content) & (col(Jargon.session_id) == request.chat_id) - ) - ).first() - if existing: + same_content_jargons = session.exec(select(Jargon).where(col(Jargon.content) == request.content)).all() + existing = next( + (jargon for jargon in same_content_jargons if has_chat_id(jargon.session_id_dict, request.chat_id)), + None, + ) + if existing is not None: raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话") jargon = Jargon( content=request.content, raw_content=request.raw_content, meaning=request.meaning or "", - session_id=request.chat_id, + session_id_dict=build_session_id_dict_for_chat(request.chat_id), count=0, is_jargon=None, is_complete=False, @@ -420,13 +448,12 @@ async def update_jargon(jargon_id: int, request: JargonUpdateRequest): if not jargon: raise HTTPException(status_code=404, detail="黑话不存在") - update_data = request.model_dump(exclude_unset=True) - if update_data: + if update_data := request.model_dump(exclude_unset=True): for field, value in update_data.items(): if field == "is_global": continue if field == "chat_id": - jargon.session_id = value + jargon.session_id_dict = build_session_id_dict_for_chat(value, max(jargon.count, 1)) continue if value is not None or field in ["meaning", "raw_content", "is_jargon"]: setattr(jargon, field, value) diff --git a/src/webui/routers/knowledge.py b/src/webui/routers/knowledge.py index e43bdb02..cb7f58d9 100644 --- a/src/webui/routers/knowledge.py +++ b/src/webui/routers/knowledge.py @@ -1,18 +1,28 @@ """知识库图谱可视化 API 路由""" -from typing import List, Optional -from fastapi import APIRouter, Query, Depends, Cookie, Header +from typing import Any, List, Optional + +from fastapi import APIRouter, Depends, Query from pydantic import BaseModel import logging -from src.webui.core import verify_auth_token_from_cookie_or_header from src.config.config import global_config +from src.webui.dependencies import require_auth logger = logging.getLogger(__name__) -router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"]) +router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"], dependencies=[Depends(require_auth)]) # 延迟初始化的轻量级 embedding store(只读,仅用于获取段落完整文本) -_paragraph_store_cache = None +_paragraph_store_cache: Any = None + + +def _get_embedding_dir() -> str: + """获取 embedding 数据目录。""" + import os + + current_dir = os.path.dirname(os.path.abspath(__file__)) + root_path = os.path.abspath(os.path.join(current_dir, "..", "..")) + return os.path.join(root_path, "data/embedding") def _get_paragraph_store(): @@ -31,17 +41,11 @@ def _get_paragraph_store(): try: from src.chat.knowledge.embedding_store import EmbeddingStore - import os - - # 获取数据路径 - current_dir = os.path.dirname(os.path.abspath(__file__)) - root_path = os.path.abspath(os.path.join(current_dir, "..", "..")) - embedding_dir = os.path.join(root_path, "data/embedding") # 只加载段落 embedding store(轻量级) paragraph_store = EmbeddingStore( namespace="paragraph", - dir_path=embedding_dir, + dir_path=_get_embedding_dir(), max_workers=1, # 只读不需要多线程 chunk_size=100, ) @@ -74,8 +78,7 @@ def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]: paragraph_item = paragraph_store.store.get(node_id) if paragraph_item is not None: # paragraph_item 是 EmbeddingStoreItem,其 str 属性包含完整文本 - content: str = getattr(paragraph_item, "str", "") - if content: + if content := getattr(paragraph_item, "str", ""): return content, True return None, True except Exception as e: @@ -83,14 +86,6 @@ def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]: return None, True -def require_auth( - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -) -> bool: - """认证依赖:验证用户是否已登录""" - return verify_auth_token_from_cookie_or_header(maibot_session, authorization) - - class KnowledgeNode(BaseModel): """知识节点""" @@ -205,7 +200,6 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph: async def get_knowledge_graph( limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"), node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"), - _auth: bool = Depends(require_auth), ): """获取知识图谱(限制节点数量) @@ -303,7 +297,7 @@ async def get_knowledge_graph( @router.get("/stats", response_model=KnowledgeStats) -async def get_knowledge_stats(_auth: bool = Depends(require_auth)): +async def get_knowledge_stats(): """获取知识库统计信息 Returns: @@ -352,7 +346,7 @@ async def get_knowledge_stats(_auth: bool = Depends(require_auth)): @router.get("/search", response_model=List[KnowledgeNode]) -async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bool = Depends(require_auth)): +async def search_knowledge_node(query: str = Query(..., min_length=1)): """搜索知识节点 Args: diff --git a/src/webui/routers/model.py b/src/webui/routers/model.py index b5ca4128..a1382a3a 100644 --- a/src/webui/routers/model.py +++ b/src/webui/routers/model.py @@ -6,27 +6,18 @@ import os import httpx -from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header from typing import Optional + import tomlkit +from fastapi import APIRouter, Depends, HTTPException, Query from src.common.logger import get_logger from src.config.config import CONFIG_DIR -from src.webui.core import verify_auth_token_from_cookie_or_header +from src.webui.dependencies import require_auth logger = get_logger("webui") -router = APIRouter(prefix="/models", tags=["models"]) - - -def require_auth( - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -) -> bool: - """认证依赖:验证用户是否已登录""" - return verify_auth_token_from_cookie_or_header(maibot_session, authorization) - - +router = APIRouter(prefix="/models", tags=["models"], dependencies=[Depends(require_auth)]) # 模型获取器配置 MODEL_FETCHER_CONFIG = { # OpenAI 兼容格式的提供商 @@ -44,9 +35,7 @@ MODEL_FETCHER_CONFIG = { def _normalize_url(url: str) -> str: """规范化 URL(去掉尾部斜杠)""" - if not url: - return "" - return url.rstrip("/") + return url.rstrip("/") if url else "" def _parse_openai_response(data: dict) -> list[dict]: @@ -55,18 +44,18 @@ def _parse_openai_response(data: dict) -> list[dict]: 格式: { "data": [{ "id": "gpt-4", "object": "model", ... }] } """ - models = [] - if "data" in data and isinstance(data["data"], list): - for model in data["data"]: - if isinstance(model, dict) and "id" in model: - models.append( - { - "id": model["id"], - "name": model.get("name") or model["id"], - "owned_by": model.get("owned_by", ""), - } - ) - return models + if "data" not in data or not isinstance(data["data"], list): + return [] + + return [ + { + "id": model["id"], + "name": model.get("name") or model["id"], + "owned_by": model.get("owned_by", ""), + } + for model in data["data"] + if isinstance(model, dict) and "id" in model + ] def _parse_gemini_response(data: dict) -> list[dict]: @@ -178,11 +167,8 @@ def _get_provider_config(provider_name: str) -> Optional[dict]: config_data = tomlkit.load(f) providers = config_data.get("api_providers", []) - for provider in providers: - if provider.get("name") == provider_name: - return dict(provider) - - return None + provider = next((provider for provider in providers if provider.get("name") == provider_name), None) + return dict(provider) if provider is not None else None except Exception as e: logger.error(f"读取提供商配置失败: {e}") return None @@ -193,7 +179,6 @@ async def get_provider_models( provider_name: str = Query(..., description="提供商名称"), parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"), endpoint: str = Query("/models", description="获取模型列表的端点"), - _auth: bool = Depends(require_auth), ): """ 获取指定提供商的可用模型列表 @@ -238,7 +223,6 @@ async def get_models_by_url( parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"), endpoint: str = Query("/models", description="获取模型列表的端点"), client_type: str = Query("openai", description="客户端类型 (openai | gemini)"), - _auth: bool = Depends(require_auth), ): """ 通过 URL 直接获取模型列表(用于自定义提供商) @@ -262,7 +246,6 @@ async def get_models_by_url( async def test_provider_connection( base_url: str = Query(..., description="提供商的基础 URL"), api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"), - _auth: bool = Depends(require_auth), ): """ 测试提供商连接状态 @@ -349,7 +332,6 @@ async def test_provider_connection( @router.post("/test-connection-by-name") async def test_provider_connection_by_name( provider_name: str = Query(..., description="提供商名称"), - _auth: bool = Depends(require_auth), ): """ 通过提供商名称测试连接(从配置文件读取信息) @@ -364,11 +346,7 @@ async def test_provider_connection_by_name( # 查找提供商 providers = config.get("api_providers", []) - provider = None - for p in providers: - if p.get("name") == provider_name: - provider = p - break + provider = next((item for item in providers if item.get("name") == provider_name), None) if not provider: raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}") @@ -380,4 +358,4 @@ async def test_provider_connection_by_name( raise HTTPException(status_code=400, detail="提供商配置缺少 base_url") # 调用测试接口 - return await test_provider_connection(base_url=base_url, api_key=api_key if api_key else None) + return await test_provider_connection(base_url=base_url, api_key=api_key or None) diff --git a/src/webui/routers/person.py b/src/webui/routers/person.py index 0b896f69..86927b0e 100644 --- a/src/webui/routers/person.py +++ b/src/webui/routers/person.py @@ -1,23 +1,24 @@ """人物信息管理 API 路由""" -from fastapi import APIRouter, HTTPException, Header, Query, Cookie -from pydantic import BaseModel -from typing import Optional, List, Dict from datetime import datetime +import json +from typing import Dict, List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel from sqlalchemy import case from sqlmodel import col, select, delete -from src.common.logger import get_logger from src.common.database.database import get_db_session from src.common.database.database_model import PersonInfo -from src.webui.core import verify_auth_token_from_cookie_or_header -import json +from src.common.logger import get_logger +from src.webui.dependencies import require_auth logger = get_logger("webui.person") # 创建路由器 -router = APIRouter(prefix="/person", tags=["Person"]) +router = APIRouter(prefix="/person", tags=["Person"], dependencies=[Depends(require_auth)]) class PersonInfoResponse(BaseModel): @@ -96,14 +97,6 @@ class BatchDeleteResponse(BaseModel): failed_ids: List[str] = [] -def verify_auth_token( - maibot_session: Optional[str] = None, - authorization: Optional[str] = None, -) -> bool: - """验证认证 Token,支持 Cookie 和 Header""" - return verify_auth_token_from_cookie_or_header(maibot_session, authorization) - - def parse_group_nick_name(group_nick_name_str: Optional[str]) -> Optional[List[Dict[str, str]]]: """解析群昵称 JSON 字符串""" if not group_nick_name_str: @@ -127,7 +120,7 @@ def person_to_response(person: PersonInfo) -> PersonInfoResponse: platform=person.platform, user_id=person.user_id, nickname=person.user_nickname, - group_nick_name=parse_group_nick_name(person.group_nickname), + group_nick_name=parse_group_nick_name(person.group_cardname), memory_points=person.memory_points, know_times=person.know_counts, know_since=know_since, @@ -142,8 +135,6 @@ async def get_person_list( search: Optional[str] = Query(None, description="搜索关键词"), is_known: Optional[bool] = Query(None, description="是否已认识筛选"), platform: Optional[str] = Query(None, description="平台筛选"), - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), ): """ 获取人物信息列表 @@ -154,14 +145,11 @@ async def get_person_list( search: 搜索关键词 (匹配 person_name, nickname, user_id) is_known: 是否已认识筛选 platform: 平台筛选 - authorization: Authorization header Returns: 人物信息列表 """ try: - verify_auth_token(maibot_session, authorization) - # 构建查询 statement = select(PersonInfo) @@ -219,22 +207,17 @@ async def get_person_list( @router.get("/{person_id}", response_model=PersonDetailResponse) -async def get_person_detail( - person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -): +async def get_person_detail(person_id: str): """ 获取人物详细信息 Args: person_id: 人物唯一 ID - authorization: Authorization header Returns: 人物详细信息 """ try: - verify_auth_token(maibot_session, authorization) - with get_db_session() as session: statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1) person = session.exec(statement).first() @@ -255,8 +238,6 @@ async def get_person_detail( async def update_person( person_id: str, request: PersonUpdateRequest, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), ): """ 增量更新人物信息(只更新提供的字段) @@ -264,14 +245,11 @@ async def update_person( Args: person_id: 人物唯一 ID request: 更新请求(只包含需要更新的字段) - authorization: Authorization header Returns: 更新结果 """ try: - verify_auth_token(maibot_session, authorization) - with get_db_session() as session: statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1) person = session.exec(statement).first() @@ -313,22 +291,17 @@ async def update_person( @router.delete("/{person_id}", response_model=PersonDeleteResponse) -async def delete_person( - person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -): +async def delete_person(person_id: str): """ 删除人物信息 Args: person_id: 人物唯一 ID - authorization: Authorization header Returns: 删除结果 """ try: - verify_auth_token(maibot_session, authorization) - with get_db_session() as session: statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1) person = session.exec(statement).first() @@ -355,19 +328,14 @@ async def delete_person( @router.get("/stats/summary") -async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def get_person_stats(): """ 获取人物信息统计数据 - Args: - authorization: Authorization header - Returns: 统计数据 """ try: - verify_auth_token(maibot_session, authorization) - with get_db_session() as session: total = len(session.exec(select(PersonInfo.id)).all()) known = len(session.exec(select(PersonInfo.id).where(col(PersonInfo.is_known))).all()) @@ -392,22 +360,17 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori @router.post("/batch/delete", response_model=BatchDeleteResponse) async def batch_delete_persons( request: BatchDeleteRequest, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), ): """ 批量删除人物信息 Args: request: 包含person_ids列表的请求 - authorization: Authorization header Returns: 批量删除结果 """ try: - verify_auth_token(maibot_session, authorization) - if not request.person_ids: raise HTTPException(status_code=400, detail="未提供要删除的人物ID") diff --git a/src/webui/routers/plugin.py b/src/webui/routers/plugin.py deleted file mode 100644 index 3a7c65ed..00000000 --- a/src/webui/routers/plugin.py +++ /dev/null @@ -1,2054 +0,0 @@ -from fastapi import APIRouter, HTTPException, Header, Cookie -from pydantic import BaseModel, Field -from typing import Optional, List, Dict, Any, get_origin -from pathlib import Path -import json -from src.common.logger import get_logger -from src.webui.utils.toml_utils import save_toml_with_format -from src.config.config import MMC_VERSION -from src.core.config_types import ConfigField -from src.webui.services.git_mirror_service import get_git_mirror_service, set_update_progress_callback -from src.webui.core import get_token_manager -from src.webui.routers.websocket.plugin_progress import update_progress - -logger = get_logger("webui.plugin_routes") - -# 创建路由器 -router = APIRouter(prefix="/plugins", tags=["插件管理"]) - -# 设置进度更新回调 -set_update_progress_callback(update_progress) - - -def get_token_from_cookie_or_header( - maibot_session: Optional[str] = None, - authorization: Optional[str] = None, -) -> Optional[str]: - """从 Cookie 或 Header 获取 token""" - # 优先从 Cookie 获取 - if maibot_session: - return maibot_session - # 其次从 Header 获取 - if authorization and authorization.startswith("Bearer "): - return authorization.replace("Bearer ", "") - return None - - -def validate_safe_path(user_path: str, base_path: Path) -> Path: - """ - 验证用户提供的路径是否安全,防止路径遍历攻击 - - Args: - user_path: 用户输入的路径(相对路径) - base_path: 允许的基础目录 - - Returns: - 安全的绝对路径 - - Raises: - HTTPException: 如果检测到路径遍历攻击 - """ - # 规范化基础路径 - base_resolved = base_path.resolve() - - # 检查用户路径是否包含可疑字符 - # 禁止: .., 绝对路径开头, 空字节等 - if any(pattern in user_path for pattern in ["..", "\x00"]): - logger.warning(f"检测到可疑路径: {user_path}") - raise HTTPException(status_code=400, detail="路径包含非法字符") - - # 检查是否为绝对路径(Windows 和 Unix) - if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"): - logger.warning(f"检测到绝对路径: {user_path}") - raise HTTPException(status_code=400, detail="不允许使用绝对路径") - - # 构建目标路径并解析 - target_path = (base_path / user_path).resolve() - - # 验证解析后的路径仍在基础目录内 - try: - target_path.relative_to(base_resolved) - except ValueError as e: - logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}") - raise HTTPException(status_code=400, detail="路径超出允许范围") from e - - return target_path - - -def validate_plugin_id(plugin_id: str) -> str: - """ - 验证插件 ID 格式是否安全 - - Args: - plugin_id: 插件 ID (支持 author.name 格式,允许中文) - - Returns: - 验证通过的插件 ID - - Raises: - HTTPException: 如果插件 ID 格式不安全 - """ - # 禁止空字符串 - if not plugin_id or not plugin_id.strip(): - logger.warning("非法插件 ID: 空字符串") - raise HTTPException(status_code=400, detail="插件 ID 不能为空") - - # 禁止危险字符: 路径分隔符、空字节、控制字符等 - dangerous_patterns = ["/", "\\", "\x00", "..", "\n", "\r", "\t"] - for pattern in dangerous_patterns: - if pattern in plugin_id: - logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)") - raise HTTPException(status_code=400, detail="插件 ID 包含非法字符") - - # 禁止以点开头或结尾(防止隐藏文件和路径问题) - if plugin_id.startswith(".") or plugin_id.endswith("."): - logger.warning(f"非法插件 ID: {plugin_id}") - raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾") - - # 禁止特殊名称 - if plugin_id in (".", ".."): - logger.warning(f"非法插件 ID: {plugin_id}") - raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名") - - return plugin_id - - -def parse_version(version_str: str) -> tuple[int, int, int]: - """ - 解析版本号字符串 - - 支持格式: - - 0.11.2 -> (0, 11, 2) - - 0.11.2.snapshot.2 -> (0, 11, 2) - - Returns: - (major, minor, patch) 三元组 - """ - # 移除 snapshot、dev、alpha、beta 等后缀(支持 - 和 . 分隔符) - import re - - # 匹配 -snapshot.X, .snapshot, -dev, .dev, -alpha, .alpha, -beta, .beta 等后缀 - base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0] - - parts = base_version.split(".") - if len(parts) < 3: - # 补齐到 3 位 - parts.extend(["0"] * (3 - len(parts))) - - try: - major = int(parts[0]) - minor = int(parts[1]) - patch = int(parts[2]) - return (major, minor, patch) - except (ValueError, IndexError): - logger.warning(f"无法解析版本号: {version_str},返回默认值 (0, 0, 0)") - return (0, 0, 0) - - -# ============ 工具函数(避免在请求内重复定义) ============ - - -def _deep_merge(dst: Dict[str, Any], src: Dict[str, Any]) -> None: - """深度合并两个字典,src 的值会覆盖或合并到 dst 中。""" - for k, v in src.items(): - if k in dst and isinstance(dst[k], dict) and isinstance(v, dict): - _deep_merge(dst[k], v) - else: - dst[k] = v - - -def normalize_dotted_keys(obj: Dict[str, Any]) -> Dict[str, Any]: - """ - 将形如 {'a.b': 1} 的键展开为嵌套结构 {'a': {'b': 1}}。 - 若遇到中间节点已存在且非字典,记录日志并覆盖为字典。 - """ - result: Dict[str, Any] = {} - dotted_items = [] - - # 先处理非点号键,避免后续展开覆盖已有结构 - for k, v in obj.items(): - if "." in k: - dotted_items.append((k, v)) - else: - result[k] = normalize_dotted_keys(v) if isinstance(v, dict) else v - - # 再处理点号键 - for dotted_key, v in dotted_items: - value = normalize_dotted_keys(v) if isinstance(v, dict) else v - parts = dotted_key.split(".") - if "" in parts: - logger.warning(f"键路径包含空段: '{dotted_key}'") - parts = [p for p in parts if p] - if not parts: - logger.warning(f"忽略空键路径: '{dotted_key}'") - continue - current = result - # 中间层 - for idx, part in enumerate(parts[:-1]): - if part in current and not isinstance(current[part], dict): - path_ctx = ".".join(parts[: idx + 1]) - logger.warning(f"键冲突:{part} 已存在且非字典,覆盖为字典以展开 {dotted_key} (路径 {path_ctx})") - current[part] = {} - current = current.setdefault(part, {}) - # 最后一层 - last_part = parts[-1] - if last_part in current and isinstance(current[last_part], dict) and isinstance(value, dict): - _deep_merge(current[last_part], value) - else: - current[last_part] = value - - return result - - -def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> None: - """ - 根据 schema 将配置中的类型纠正(目前只纠正 list-from-str)。 - """ - - def _is_list_type(tp: Any) -> bool: - origin = get_origin(tp) - return tp is list or origin is list - - for key, schema_val in schema_part.items(): - if key not in config_part: - continue - value = config_part[key] - if isinstance(schema_val, ConfigField): - if _is_list_type(schema_val.type) and isinstance(value, str): - config_part[key] = [item.strip() for item in value.split(",") if item.strip()] - elif isinstance(schema_val, dict) and isinstance(value, dict): - coerce_types(schema_val, value) - - -def find_plugin_instance(plugin_id: str) -> Optional[Any]: - """ - 按 plugin_id 查找已加载的插件信息。 - 新运行时中插件运行在子进程,无法获取实例,返回注册信息。 - """ - from src.plugin_runtime.integration import get_plugin_runtime_manager - - mgr = get_plugin_runtime_manager() - for sv in mgr.supervisors: - reg = sv._registered_plugins.get(plugin_id) - if reg is not None: - return reg - return None - - -# ============ 请求/响应模型 ============ - - -class FetchRawFileRequest(BaseModel): - """获取 Raw 文件请求""" - - owner: str = Field(..., description="仓库所有者", example="MaiM-with-u") - repo: str = Field(..., description="仓库名称", example="plugin-repo") - branch: str = Field(..., description="分支名称", example="main") - file_path: str = Field(..., description="文件路径", example="plugin_details.json") - mirror_id: Optional[str] = Field(None, description="指定镜像源 ID") - custom_url: Optional[str] = Field(None, description="自定义完整 URL") - - -class FetchRawFileResponse(BaseModel): - """获取 Raw 文件响应""" - - success: bool = Field(..., description="是否成功") - data: Optional[str] = Field(None, description="文件内容") - error: Optional[str] = Field(None, description="错误信息") - mirror_used: Optional[str] = Field(None, description="使用的镜像源") - attempts: int = Field(..., description="尝试次数") - url: Optional[str] = Field(None, description="实际请求的 URL") - - -class CloneRepositoryRequest(BaseModel): - """克隆仓库请求""" - - owner: str = Field(..., description="仓库所有者", example="MaiM-with-u") - repo: str = Field(..., description="仓库名称", example="plugin-repo") - target_path: str = Field(..., description="目标路径(相对于插件目录)") - branch: Optional[str] = Field(None, description="分支名称", example="main") - mirror_id: Optional[str] = Field(None, description="指定镜像源 ID") - custom_url: Optional[str] = Field(None, description="自定义克隆 URL") - depth: Optional[int] = Field(None, description="克隆深度(浅克隆)", ge=1) - - -class CloneRepositoryResponse(BaseModel): - """克隆仓库响应""" - - success: bool = Field(..., description="是否成功") - path: Optional[str] = Field(None, description="克隆路径") - error: Optional[str] = Field(None, description="错误信息") - mirror_used: Optional[str] = Field(None, description="使用的镜像源") - attempts: int = Field(..., description="尝试次数") - url: Optional[str] = Field(None, description="实际克隆的 URL") - message: Optional[str] = Field(None, description="附加信息") - - -class MirrorConfigResponse(BaseModel): - """镜像源配置响应""" - - id: str = Field(..., description="镜像源 ID") - name: str = Field(..., description="镜像源名称") - raw_prefix: str = Field(..., description="Raw 文件前缀") - clone_prefix: str = Field(..., description="克隆前缀") - enabled: bool = Field(..., description="是否启用") - priority: int = Field(..., description="优先级(数字越小优先级越高)") - - -class AvailableMirrorsResponse(BaseModel): - """可用镜像源列表响应""" - - mirrors: List[MirrorConfigResponse] = Field(..., description="镜像源列表") - default_priority: List[str] = Field(..., description="默认优先级顺序(ID 列表)") - - -class AddMirrorRequest(BaseModel): - """添加镜像源请求""" - - id: str = Field(..., description="镜像源 ID", example="custom-mirror") - name: str = Field(..., description="镜像源名称", example="自定义镜像源") - raw_prefix: str = Field(..., description="Raw 文件前缀", example="https://example.com/raw") - clone_prefix: str = Field(..., description="克隆前缀", example="https://example.com/clone") - enabled: bool = Field(True, description="是否启用") - priority: Optional[int] = Field(None, description="优先级") - - -class UpdateMirrorRequest(BaseModel): - """更新镜像源请求""" - - name: Optional[str] = Field(None, description="镜像源名称") - raw_prefix: Optional[str] = Field(None, description="Raw 文件前缀") - clone_prefix: Optional[str] = Field(None, description="克隆前缀") - enabled: Optional[bool] = Field(None, description="是否启用") - priority: Optional[int] = Field(None, description="优先级") - - -class GitStatusResponse(BaseModel): - """Git 安装状态响应""" - - installed: bool = Field(..., description="是否已安装 Git") - version: Optional[str] = Field(None, description="Git 版本号") - path: Optional[str] = Field(None, description="Git 可执行文件路径") - error: Optional[str] = Field(None, description="错误信息") - - -class InstallPluginRequest(BaseModel): - """安装插件请求""" - - plugin_id: str = Field(..., description="插件 ID") - repository_url: str = Field(..., description="插件仓库 URL") - branch: Optional[str] = Field("main", description="分支名称") - mirror_id: Optional[str] = Field(None, description="指定镜像源 ID") - - -class VersionResponse(BaseModel): - """麦麦版本响应""" - - version: str = Field(..., description="麦麦版本号") - version_major: int = Field(..., description="主版本号") - version_minor: int = Field(..., description="次版本号") - version_patch: int = Field(..., description="补丁版本号") - - -class UninstallPluginRequest(BaseModel): - """卸载插件请求""" - - plugin_id: str = Field(..., description="插件 ID") - - -class UpdatePluginRequest(BaseModel): - """更新插件请求""" - - plugin_id: str = Field(..., description="插件 ID") - repository_url: str = Field(..., description="插件仓库 URL") - branch: Optional[str] = Field("main", description="分支名称") - mirror_id: Optional[str] = Field(None, description="指定镜像源 ID") - - -# ============ API 路由 ============ - - -@router.get("/version", response_model=VersionResponse) -async def get_maimai_version() -> VersionResponse: - """ - 获取麦麦版本信息 - - 此接口无需认证,用于前端检查插件兼容性 - """ - major, minor, patch = parse_version(MMC_VERSION) - - return VersionResponse(version=MMC_VERSION, version_major=major, version_minor=minor, version_patch=patch) - - -@router.get("/git-status", response_model=GitStatusResponse) -async def check_git_status() -> GitStatusResponse: - """ - 检查本机 Git 安装状态 - - 此接口无需认证,用于前端快速检测是否可以使用插件安装功能 - """ - service = get_git_mirror_service() - result = service.check_git_installed() - - return GitStatusResponse(**result) - - -@router.get("/mirrors", response_model=AvailableMirrorsResponse) -async def get_available_mirrors( - maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -) -> AvailableMirrorsResponse: - """ - 获取所有可用的镜像源配置 - """ - # Token 验证 - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - service = get_git_mirror_service() - config = service.get_mirror_config() - - all_mirrors = config.get_all_mirrors() - mirrors = [ - MirrorConfigResponse( - id=m["id"], - name=m["name"], - raw_prefix=m["raw_prefix"], - clone_prefix=m["clone_prefix"], - enabled=m["enabled"], - priority=m["priority"], - ) - for m in all_mirrors - ] - - return AvailableMirrorsResponse(mirrors=mirrors, default_priority=config.get_default_priority_list()) - - -@router.post("/mirrors", response_model=MirrorConfigResponse) -async def add_mirror( - request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -) -> MirrorConfigResponse: - """ - 添加新的镜像源 - """ - # Token 验证 - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - try: - service = get_git_mirror_service() - config = service.get_mirror_config() - - mirror = config.add_mirror( - mirror_id=request.id, - name=request.name, - raw_prefix=request.raw_prefix, - clone_prefix=request.clone_prefix, - enabled=request.enabled, - priority=request.priority, - ) - - return MirrorConfigResponse( - id=mirror["id"], - name=mirror["name"], - raw_prefix=mirror["raw_prefix"], - clone_prefix=mirror["clone_prefix"], - enabled=mirror["enabled"], - priority=mirror["priority"], - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) from e - except Exception as e: - logger.error(f"添加镜像源失败: {e}") - raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e - - -@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse) -async def update_mirror( - mirror_id: str, - request: UpdateMirrorRequest, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -) -> MirrorConfigResponse: - """ - 更新镜像源配置 - """ - # Token 验证 - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - try: - service = get_git_mirror_service() - config = service.get_mirror_config() - - mirror = config.update_mirror( - mirror_id=mirror_id, - name=request.name, - raw_prefix=request.raw_prefix, - clone_prefix=request.clone_prefix, - enabled=request.enabled, - priority=request.priority, - ) - - if not mirror: - raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}") - - return MirrorConfigResponse( - id=mirror["id"], - name=mirror["name"], - raw_prefix=mirror["raw_prefix"], - clone_prefix=mirror["clone_prefix"], - enabled=mirror["enabled"], - priority=mirror["priority"], - ) - except HTTPException: - raise - except Exception as e: - logger.error(f"更新镜像源失败: {e}") - raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e - - -@router.delete("/mirrors/{mirror_id}") -async def delete_mirror( - mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -) -> Dict[str, Any]: - """ - 删除镜像源 - """ - # Token 验证 - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - service = get_git_mirror_service() - config = service.get_mirror_config() - - success = config.delete_mirror(mirror_id) - - if not success: - raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}") - - return {"success": True, "message": f"已删除镜像源: {mirror_id}"} - - -@router.post("/fetch-raw", response_model=FetchRawFileResponse) -async def fetch_raw_file( - request: FetchRawFileRequest, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -) -> FetchRawFileResponse: - """ - 获取 GitHub 仓库的 Raw 文件内容 - - 支持多镜像源自动切换和错误重试 - - 需要认证才能访问,防止被滥用作为 SSRF 跳板 - """ - # Token 验证(强制) - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - logger.info(f"收到获取 Raw 文件请求: {request.owner}/{request.repo}/{request.branch}/{request.file_path}") - - # 发送开始加载进度 - await update_progress( - stage="loading", - progress=10, - message=f"正在获取插件列表: {request.file_path}", - total_plugins=0, - loaded_plugins=0, - ) - - try: - service = get_git_mirror_service() - - # git_mirror_service 会自动推送 30%-70% 的详细镜像源尝试进度 - result = await service.fetch_raw_file( - owner=request.owner, - repo=request.repo, - branch=request.branch, - file_path=request.file_path, - mirror_id=request.mirror_id, - custom_url=request.custom_url, - ) - - if result.get("success"): - # 更新进度:成功获取 - await update_progress( - stage="loading", progress=70, message="正在解析插件数据...", total_plugins=0, loaded_plugins=0 - ) - - # 尝试解析插件数量 - try: - import json - - data = json.loads(result.get("data", "[]")) - total = len(data) if isinstance(data, list) else 0 - - # 发送成功状态 - await update_progress( - stage="success", - progress=100, - message=f"成功加载 {total} 个插件", - total_plugins=total, - loaded_plugins=total, - ) - except Exception: - # 如果解析失败,仍然发送成功状态 - await update_progress( - stage="success", progress=100, message="加载完成", total_plugins=0, loaded_plugins=0 - ) - - return FetchRawFileResponse(**result) - - except Exception as e: - logger.error(f"获取 Raw 文件失败: {e}") - - # 发送错误进度 - await update_progress( - stage="error", progress=0, message="加载失败", error=str(e), total_plugins=0, loaded_plugins=0 - ) - - raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e - - -@router.post("/clone", response_model=CloneRepositoryResponse) -async def clone_repository( - request: CloneRepositoryRequest, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -) -> CloneRepositoryResponse: - """ - 克隆 GitHub 仓库到本地 - - 支持多镜像源自动切换和错误重试 - """ - # Token 验证 - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}") - - try: - # 验证 target_path 的安全性,防止路径遍历攻击 - base_plugin_path = Path("./plugins").resolve() - base_plugin_path.mkdir(exist_ok=True) - target_path = validate_safe_path(request.target_path, base_plugin_path) - - service = get_git_mirror_service() - result = await service.clone_repository( - owner=request.owner, - repo=request.repo, - target_path=target_path, - branch=request.branch, - mirror_id=request.mirror_id, - custom_url=request.custom_url, - depth=request.depth, - ) - - return CloneRepositoryResponse(**result) - - except Exception as e: - logger.error(f"克隆仓库失败: {e}") - raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e - - -@router.post("/install") -async def install_plugin( - request: InstallPluginRequest, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -) -> Dict[str, Any]: - """ - 安装插件 - - 从 Git 仓库克隆插件到本地插件目录 - """ - # Token 验证 - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - logger.info(f"收到安装插件请求: {request.plugin_id}") - - try: - # 验证插件 ID 格式安全性 - plugin_id = validate_plugin_id(request.plugin_id) - - # 推送进度:开始安装 - await update_progress( - stage="loading", - progress=5, - message=f"开始安装插件: {plugin_id}", - operation="install", - plugin_id=plugin_id, - ) - - # 1. 解析仓库 URL - # repository_url 格式: https://github.com/owner/repo - repo_url = request.repository_url.rstrip("/") - if repo_url.endswith(".git"): - repo_url = repo_url[:-4] - - parts = repo_url.split("/") - if len(parts) < 2: - raise HTTPException(status_code=400, detail="无效的仓库 URL") - - owner = parts[-2] - repo = parts[-1] - - await update_progress( - stage="loading", - progress=10, - message=f"解析仓库信息: {owner}/{repo}", - operation="install", - plugin_id=plugin_id, - ) - - # 2. 确定插件安装路径 - plugins_dir = Path("plugins").resolve() - plugins_dir.mkdir(exist_ok=True) - - # 将插件 ID 中的点替换为下划线作为文件夹名称(避免文件系统问题) - # 例如: SengokuCola.Mute-Plugin -> SengokuCola_Mute-Plugin - folder_name = plugin_id.replace(".", "_") - # 使用安全路径验证,防止路径遍历 - target_path = validate_safe_path(folder_name, plugins_dir) - - # 检查插件是否已安装(需要检查两种格式:新格式下划线和旧格式点) - old_format_path = plugins_dir / plugin_id - if target_path.exists() or old_format_path.exists(): - await update_progress( - stage="error", - progress=0, - message="插件已存在", - operation="install", - plugin_id=plugin_id, - error="插件已安装,请先卸载", - ) - raise HTTPException(status_code=400, detail="插件已安装") - - await update_progress( - stage="loading", - progress=15, - message=f"准备克隆到: {target_path}", - operation="install", - plugin_id=plugin_id, - ) - - # 3. 克隆仓库(这里会自动推送 20%-80% 的进度) - service = get_git_mirror_service() - - # 如果是 GitHub 仓库,使用镜像源 - if "github.com" in repo_url: - result = await service.clone_repository( - owner=owner, - repo=repo, - target_path=target_path, - branch=request.branch, - mirror_id=request.mirror_id, - depth=1, # 浅克隆,节省时间和空间 - ) - else: - # 自定义仓库,直接使用 URL - result = await service.clone_repository( - owner=owner, repo=repo, target_path=target_path, branch=request.branch, custom_url=repo_url, depth=1 - ) - - if not result.get("success"): - error_msg = result.get("error", "克隆失败") - await update_progress( - stage="error", - progress=0, - message="克隆仓库失败", - operation="install", - plugin_id=plugin_id, - error=error_msg, - ) - raise HTTPException(status_code=500, detail=error_msg) - - # 4. 验证插件完整性 - await update_progress( - stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id - ) - - manifest_path = target_path / "_manifest.json" - if not manifest_path.exists(): - # 清理失败的安装 - import shutil - - shutil.rmtree(target_path, ignore_errors=True) - - await update_progress( - stage="error", - progress=0, - message="插件缺少 _manifest.json", - operation="install", - plugin_id=plugin_id, - error="无效的插件格式", - ) - raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json") - - # 5. 读取并验证 manifest - await update_progress( - stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id - ) - - try: - import json as json_module - - with open(manifest_path, "r", encoding="utf-8") as f: - manifest = json_module.load(f) - - # 基本验证 - required_fields = ["manifest_version", "name", "version", "author"] - for field in required_fields: - if field not in manifest: - raise ValueError(f"缺少必需字段: {field}") - - # 将插件 ID 写入 manifest(用于后续准确识别) - # 这样即使文件夹名称改变,也能通过 manifest 准确识别插件 - manifest["id"] = plugin_id - with open(manifest_path, "w", encoding="utf-8") as f: - json_module.dump(manifest, f, ensure_ascii=False, indent=2) - - except Exception as e: - # 清理失败的安装 - import shutil - - shutil.rmtree(target_path, ignore_errors=True) - - await update_progress( - stage="error", - progress=0, - message="_manifest.json 无效", - operation="install", - plugin_id=plugin_id, - error=str(e), - ) - raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e - - # 6. 安装成功 - await update_progress( - stage="success", - progress=100, - message=f"成功安装插件: {manifest['name']} v{manifest['version']}", - operation="install", - plugin_id=plugin_id, - ) - - return { - "success": True, - "message": "插件安装成功", - "plugin_id": plugin_id, - "plugin_name": manifest["name"], - "version": manifest["version"], - "path": str(target_path), - } - - except HTTPException: - raise - except Exception as e: - logger.error(f"安装插件失败: {e}", exc_info=True) - - await update_progress( - stage="error", - progress=0, - message="安装失败", - operation="install", - plugin_id=plugin_id, - error=str(e), - ) - - raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e - - -@router.post("/uninstall") -async def uninstall_plugin( - request: UninstallPluginRequest, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -) -> Dict[str, Any]: - """ - 卸载插件 - - 删除插件目录及其所有文件 - """ - # Token 验证 - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - logger.info(f"收到卸载插件请求: {request.plugin_id}") - - try: - # 验证插件 ID 格式安全性 - plugin_id = validate_plugin_id(request.plugin_id) - - # 推送进度:开始卸载 - await update_progress( - stage="loading", - progress=10, - message=f"开始卸载插件: {plugin_id}", - operation="uninstall", - plugin_id=plugin_id, - ) - - # 1. 检查插件是否存在(支持新旧两种格式) - plugins_dir = Path("plugins").resolve() - # 新格式:下划线 - folder_name = plugin_id.replace(".", "_") - # 使用安全路径验证 - plugin_path = validate_safe_path(folder_name, plugins_dir) - # 旧格式:点 - old_format_path = validate_safe_path(plugin_id, plugins_dir) - - # 优先使用新格式,如果不存在则尝试旧格式 - if not plugin_path.exists(): - if old_format_path.exists(): - plugin_path = old_format_path - else: - await update_progress( - stage="error", - progress=0, - message="插件不存在", - operation="uninstall", - plugin_id=plugin_id, - error="插件未安装或已被删除", - ) - raise HTTPException(status_code=404, detail="插件未安装") - - await update_progress( - stage="loading", - progress=30, - message=f"正在删除插件文件: {plugin_path}", - operation="uninstall", - plugin_id=plugin_id, - ) - - # 2. 读取插件信息(用于日志) - manifest_path = plugin_path / "_manifest.json" - plugin_name = plugin_id - - if manifest_path.exists(): - try: - import json as json_module - - with open(manifest_path, "r", encoding="utf-8") as f: - manifest = json_module.load(f) - plugin_name = manifest.get("name", plugin_id) - except Exception: - pass # 如果读取失败,使用插件 ID 作为名称 - - await update_progress( - stage="loading", - progress=50, - message=f"正在删除 {plugin_name}...", - operation="uninstall", - plugin_id=plugin_id, - ) - - # 3. 删除插件目录 - import shutil - import stat - - def remove_readonly(func, path, _): - """清除只读属性并删除文件""" - import os - - os.chmod(path, stat.S_IWRITE) - func(path) - - shutil.rmtree(plugin_path, onerror=remove_readonly) - - logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})") - - # 4. 推送成功状态 - await update_progress( - stage="success", - progress=100, - message=f"成功卸载插件: {plugin_name}", - operation="uninstall", - plugin_id=plugin_id, - ) - - return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name} - - except HTTPException: - raise - except PermissionError as e: - logger.error(f"卸载插件失败(权限错误): {e}") - - await update_progress( - stage="error", - progress=0, - message="卸载失败", - operation="uninstall", - plugin_id=plugin_id, - error="权限不足,无法删除插件文件", - ) - - raise HTTPException(status_code=500, detail="权限不足,无法删除插件文件") from e - except Exception as e: - logger.error(f"卸载插件失败: {e}", exc_info=True) - - await update_progress( - stage="error", - progress=0, - message="卸载失败", - operation="uninstall", - plugin_id=plugin_id, - error=str(e), - ) - - raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e - - -@router.post("/update") -async def update_plugin( - request: UpdatePluginRequest, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -) -> Dict[str, Any]: - """ - 更新插件 - - 删除旧版本,重新克隆新版本 - """ - # Token 验证 - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - logger.info(f"收到更新插件请求: {request.plugin_id}") - - try: - # 验证插件 ID 格式安全性 - plugin_id = validate_plugin_id(request.plugin_id) - - # 推送进度:开始更新 - await update_progress( - stage="loading", - progress=5, - message=f"开始更新插件: {plugin_id}", - operation="update", - plugin_id=plugin_id, - ) - - # 1. 检查插件是否已安装(支持新旧两种格式) - plugins_dir = Path("plugins").resolve() - # 新格式:下划线 - folder_name = plugin_id.replace(".", "_") - # 使用安全路径验证 - plugin_path = validate_safe_path(folder_name, plugins_dir) - # 旧格式:点 - old_format_path = validate_safe_path(plugin_id, plugins_dir) - - # 优先使用新格式,如果不存在则尝试旧格式 - if not plugin_path.exists(): - if old_format_path.exists(): - plugin_path = old_format_path - else: - await update_progress( - stage="error", - progress=0, - message="插件不存在", - operation="update", - plugin_id=plugin_id, - error="插件未安装,请先安装", - ) - raise HTTPException(status_code=404, detail="插件未安装") - - # 2. 读取旧版本信息 - manifest_path = plugin_path / "_manifest.json" - old_version = "unknown" - - if manifest_path.exists(): - try: - import json as json_module - - with open(manifest_path, "r", encoding="utf-8") as f: - manifest = json_module.load(f) - old_version = manifest.get("version", "unknown") - except Exception: - pass - - await update_progress( - stage="loading", - progress=10, - message=f"当前版本: {old_version},准备更新...", - operation="update", - plugin_id=plugin_id, - ) - - # 3. 删除旧版本 - await update_progress( - stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id - ) - - import shutil - import stat - - def remove_readonly(func, path, _): - """清除只读属性并删除文件""" - import os - - os.chmod(path, stat.S_IWRITE) - func(path) - - shutil.rmtree(plugin_path, onerror=remove_readonly) - - logger.info(f"已删除旧版本: {plugin_id} v{old_version}") - - # 4. 解析仓库 URL - await update_progress( - stage="loading", - progress=30, - message="正在准备下载新版本...", - operation="update", - plugin_id=plugin_id, - ) - - repo_url = request.repository_url.rstrip("/") - if repo_url.endswith(".git"): - repo_url = repo_url[:-4] - - parts = repo_url.split("/") - if len(parts) < 2: - raise HTTPException(status_code=400, detail="无效的仓库 URL") - - owner = parts[-2] - repo = parts[-1] - - # 5. 克隆新版本(这里会推送 35%-85% 的进度) - service = get_git_mirror_service() - - if "github.com" in repo_url: - result = await service.clone_repository( - owner=owner, - repo=repo, - target_path=plugin_path, - branch=request.branch, - mirror_id=request.mirror_id, - depth=1, - ) - else: - result = await service.clone_repository( - owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, custom_url=repo_url, depth=1 - ) - - if not result.get("success"): - error_msg = result.get("error", "克隆失败") - await update_progress( - stage="error", - progress=0, - message="下载新版本失败", - operation="update", - plugin_id=plugin_id, - error=error_msg, - ) - raise HTTPException(status_code=500, detail=error_msg) - - # 6. 验证新版本 - await update_progress( - stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id - ) - - new_manifest_path = plugin_path / "_manifest.json" - if not new_manifest_path.exists(): - # 清理失败的更新 - def remove_readonly(func, path, _): - """清除只读属性并删除文件""" - import os - - os.chmod(path, stat.S_IWRITE) - func(path) - - shutil.rmtree(plugin_path, onerror=remove_readonly) - - await update_progress( - stage="error", - progress=0, - message="新版本缺少 _manifest.json", - operation="update", - plugin_id=plugin_id, - error="无效的插件格式", - ) - raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json") - - # 7. 读取新版本信息 - try: - with open(new_manifest_path, "r", encoding="utf-8") as f: - new_manifest = json_module.load(f) - - new_version = new_manifest.get("version", "unknown") - new_name = new_manifest.get("name", plugin_id) - - logger.info(f"成功更新插件: {plugin_id} {old_version} → {new_version}") - - # 8. 推送成功状态 - await update_progress( - stage="success", - progress=100, - message=f"成功更新 {new_name}: {old_version} → {new_version}", - operation="update", - plugin_id=plugin_id, - ) - - return { - "success": True, - "message": "插件更新成功", - "plugin_id": plugin_id, - "plugin_name": new_name, - "old_version": old_version, - "new_version": new_version, - } - - except Exception as e: - # 清理失败的更新 - shutil.rmtree(plugin_path, ignore_errors=True) - - await update_progress( - stage="error", - progress=0, - message="_manifest.json 无效", - operation="update", - plugin_id=plugin_id, - error=str(e), - ) - raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e - - except HTTPException: - raise - except Exception as e: - logger.error(f"更新插件失败: {e}", exc_info=True) - - await update_progress( - stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e) - ) - - raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e - - -@router.get("/installed") -async def get_installed_plugins( - maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -) -> Dict[str, Any]: - """ - 获取已安装的插件列表 - - 扫描 plugins 目录,返回所有已安装插件的 ID 和基本信息 - """ - # Token 验证 - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - logger.info("收到获取已安装插件列表请求") - - try: - plugins_dir = Path("plugins") - - # 如果插件目录不存在,返回空列表 - if not plugins_dir.exists(): - logger.info("插件目录不存在,创建目录") - plugins_dir.mkdir(exist_ok=True) - return {"success": True, "plugins": []} - - installed_plugins = [] - - # 遍历插件目录 - for plugin_path in plugins_dir.iterdir(): - # 只处理目录 - if not plugin_path.is_dir(): - continue - - # 目录名(可能是下划线格式、点格式或其他格式) - folder_name = plugin_path.name - - # 跳过隐藏目录和特殊目录 - if folder_name.startswith(".") or folder_name.startswith("__"): - continue - - # 读取 _manifest.json - manifest_path = plugin_path / "_manifest.json" - - if not manifest_path.exists(): - logger.warning(f"插件文件夹 {folder_name} 缺少 _manifest.json,跳过") - continue - - try: - import json as json_module - - with open(manifest_path, "r", encoding="utf-8") as f: - manifest = json_module.load(f) - - # 基本验证 - if "name" not in manifest or "version" not in manifest: - logger.warning(f"插件文件夹 {folder_name} 的 _manifest.json 格式无效,跳过") - continue - - # 获取插件 ID(优先从 manifest,否则从文件夹名推断) - if "id" in manifest: - # 优先使用 manifest 中的 id(最准确) - plugin_id = manifest["id"] - else: - # 从 manifest 信息构建 ID - # 尝试从 author.name 和 repository_url 构建标准 ID - author_name = None - repo_name = None - - # 获取作者名 - if "author" in manifest: - if isinstance(manifest["author"], dict) and "name" in manifest["author"]: - author_name = manifest["author"]["name"] - elif isinstance(manifest["author"], str): - author_name = manifest["author"] - - # 从 repository_url 获取仓库名 - if "repository_url" in manifest: - repo_url = manifest["repository_url"].rstrip("/") - if repo_url.endswith(".git"): - repo_url = repo_url[:-4] - repo_name = repo_url.split("/")[-1] - - # 构建 ID - if author_name and repo_name: - # 标准格式: Author.RepoName - plugin_id = f"{author_name}.{repo_name}" - elif author_name: - # 如果只有作者,使用 Author.FolderName - plugin_id = f"{author_name}.{folder_name}" - else: - # 从文件夹名推断 - if "_" in folder_name and "." not in folder_name: - # 假设格式为 Author_PluginName,转换为 Author.PluginName - plugin_id = folder_name.replace("_", ".", 1) - else: - # 直接使用文件夹名 - plugin_id = folder_name - - # 将推断的 ID 写入 manifest(方便下次识别) - logger.info(f"为插件 {folder_name} 自动生成 ID: {plugin_id}") - manifest["id"] = plugin_id - try: - with open(manifest_path, "w", encoding="utf-8") as f: - json_module.dump(manifest, f, ensure_ascii=False, indent=2) - except Exception as write_error: - logger.warning(f"无法写入 ID 到 manifest: {write_error}") - - # 添加到已安装列表(返回完整的 manifest 信息) - installed_plugins.append( - { - "id": plugin_id, - "manifest": manifest, # 返回完整的 manifest 对象 - "path": str(plugin_path.absolute()), - } - ) - - except json.JSONDecodeError as e: - logger.warning(f"插件 {folder_name} 的 _manifest.json 解析失败: {e}") - continue - except Exception as e: - logger.error(f"读取插件 {folder_name} 信息时出错: {e}") - continue - - # 去重:如果有重复的 plugin_id,只保留第一个(按路径) - seen_ids = {} # 记录 ID -> 路径的映射 - unique_plugins = [] - duplicates = [] - - for plugin in installed_plugins: - plugin_id = plugin["id"] - plugin_path = plugin["path"] - - if plugin_id not in seen_ids: - seen_ids[plugin_id] = plugin_path - unique_plugins.append(plugin) - else: - duplicates.append(plugin) - first_path = seen_ids[plugin_id] - logger.warning(f"重复插件 {plugin_id}: 保留 {first_path}, 跳过 {plugin_path}") - - if duplicates: - logger.warning(f"共检测到 {len(duplicates)} 个重复插件已去重") - - logger.info(f"找到 {len(unique_plugins)} 个已安装插件") - - return {"success": True, "plugins": unique_plugins, "total": len(unique_plugins)} - - except Exception as e: - logger.error(f"获取已安装插件列表失败: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e - - -@router.get("/local-readme/{plugin_id}") -async def get_local_plugin_readme( - plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -) -> Dict[str, Any]: - """ - 获取本地已安装插件的 README 文件内容 - - Args: - plugin_id: 插件 ID - - Returns: - 包含 success 和 data(README 内容) 的字典,如果文件不存在则返回 success=False - """ - # Token 验证 - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - logger.info(f"获取本地插件 README: {plugin_id}") - - try: - plugins_dir = Path("plugins") - - # 查找插件目录 - plugin_path = None - for folder in plugins_dir.iterdir(): - if not folder.is_dir(): - continue - - manifest_path = folder / "_manifest.json" - if manifest_path.exists(): - try: - import json as json_module - - with open(manifest_path, "r", encoding="utf-8") as f: - manifest = json_module.load(f) - - # 检查是否匹配 plugin_id - if manifest.get("id") == plugin_id: - plugin_path = folder - break - except Exception: - continue - - if not plugin_path: - return {"success": False, "error": "插件未安装"} - - # 查找 README 文件(支持多种命名) - readme_files = ["README.md", "readme.md", "Readme.md", "README.MD"] - readme_content = None - - for readme_name in readme_files: - readme_path = plugin_path / readme_name - if readme_path.exists(): - try: - with open(readme_path, "r", encoding="utf-8") as f: - readme_content = f.read() - logger.info(f"成功读取本地 README: {readme_path}") - break - except Exception as e: - logger.warning(f"读取 {readme_path} 失败: {e}") - continue - - if readme_content: - return {"success": True, "data": readme_content} - else: - return {"success": False, "error": "本地未找到 README 文件"} - - except Exception as e: - logger.error(f"获取本地 README 失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - -# ============ 插件配置管理 API ============ - - -class UpdatePluginConfigRequest(BaseModel): - """更新插件配置请求""" - - config: Dict[str, Any] = Field(..., description="配置数据") - - -@router.get("/config/{plugin_id}/schema") -async def get_plugin_config_schema( - plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -) -> Dict[str, Any]: - """ - 获取插件配置 Schema - - 返回插件的完整配置 schema,包含所有 section、字段定义和布局信息。 - 用于前端动态生成配置表单。 - """ - # Token 验证 - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - logger.info(f"获取插件配置 Schema: {plugin_id}") - - try: - # 新运行时中插件运行在子进程,无法直接获取实例的 webui_config_schema - # 尝试从文件系统读取 - plugin_instance = None - - if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"): - # 从插件实例获取 schema - schema = plugin_instance.get_webui_config_schema() - return {"success": True, "schema": schema} - - # 如果插件未加载,尝试从文件系统读取 - # 查找插件目录 - plugins_dir = Path("plugins") - plugin_path = None - - for p in plugins_dir.iterdir(): - if p.is_dir(): - manifest_path = p / "_manifest.json" - if manifest_path.exists(): - try: - with open(manifest_path, "r", encoding="utf-8") as f: - manifest = json.load(f) - if manifest.get("id") == plugin_id or p.name == plugin_id: - plugin_path = p - break - except Exception: - continue - - if not plugin_path: - raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - - # 优先读取 config_schema.json(插件可选提供的富 UI 元数据) - schema_json_path = plugin_path / "config_schema.json" - if schema_json_path.exists(): - try: - with open(schema_json_path, "r", encoding="utf-8") as f: - schema = json.load(f) - return {"success": True, "schema": schema} - except Exception as e: - logger.warning(f"读取 config_schema.json 失败,回退到自动推断: {e}") - - # 读取配置文件获取当前配置 - config_path = plugin_path / "config.toml" - current_config = {} - if config_path.exists(): - import tomlkit - - with open(config_path, "r", encoding="utf-8") as f: - current_config = tomlkit.load(f) - - # 构建基础 schema(无法获取完整的 ConfigField 信息) - schema = { - "plugin_id": plugin_id, - "plugin_info": { - "name": plugin_id, - "version": "", - "description": "", - "author": "", - }, - "sections": {}, - "layout": {"type": "auto", "tabs": []}, - "_note": "插件未加载,仅返回当前配置结构", - } - - # 从当前配置推断 schema - for section_name, section_data in current_config.items(): - if isinstance(section_data, dict): - schema["sections"][section_name] = { - "name": section_name, - "title": section_name, - "description": None, - "icon": None, - "collapsed": False, - "order": 0, - "fields": {}, - } - for field_name, field_value in section_data.items(): - # 推断字段类型 - field_type = type(field_value).__name__ - ui_type = "text" - item_type = None - item_fields = None - - if isinstance(field_value, bool): - ui_type = "switch" - elif isinstance(field_value, (int, float)): - ui_type = "number" - elif isinstance(field_value, list): - ui_type = "list" - # 推断数组元素类型 - if field_value: - first_item = field_value[0] - if isinstance(first_item, dict): - item_type = "object" - # 从第一个元素推断字段结构 - item_fields = {} - for k, v in first_item.items(): - item_fields[k] = { - "type": "number" if isinstance(v, (int, float)) else "string", - "label": k, - "default": "" if isinstance(v, str) else 0, - } - elif isinstance(first_item, (int, float)): - item_type = "number" - else: - item_type = "string" - else: - item_type = "string" - elif isinstance(field_value, dict): - ui_type = "json" - - schema["sections"][section_name]["fields"][field_name] = { - "name": field_name, - "type": field_type, - "default": field_value, - "description": field_name, - "label": field_name, - "ui_type": ui_type, - "required": False, - "hidden": False, - "disabled": False, - "order": 0, - "item_type": item_type, - "item_fields": item_fields, - "min_items": None, - "max_items": None, - # 补充缺失的字段 - "placeholder": None, - "hint": None, - "icon": None, - "example": None, - "choices": None, - "min": None, - "max": None, - "step": None, - "pattern": None, - "max_length": None, - "input_type": None, - "rows": 3, - "group": None, - "depends_on": None, - "depends_value": None, - } - - return {"success": True, "schema": schema} - - except HTTPException: - raise - except Exception as e: - logger.error(f"获取插件配置 Schema 失败: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e - - -@router.get("/config/{plugin_id}/raw") -async def get_plugin_config_raw( - plugin_id: str, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -) -> Dict[str, Any]: - """ - 获取插件原始 TOML 配置文件内容 - """ - # Token 验证 - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - logger.info(f"获取插件原始配置: {plugin_id}") - - try: - # 查找插件目录 - plugins_dir = Path("plugins") - plugin_path = None - - for p in plugins_dir.iterdir(): - if p.is_dir(): - manifest_path = p / "_manifest.json" - if manifest_path.exists(): - try: - with open(manifest_path, "r", encoding="utf-8") as f: - manifest = json.load(f) - if manifest.get("id") == plugin_id or p.name == plugin_id: - plugin_path = p - break - except Exception: - continue - - if not plugin_path: - raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - - # 读取配置文件 - config_path = plugin_path / "config.toml" - if not config_path.exists(): - return {"success": True, "config": "", "message": "配置文件不存在"} - - with open(config_path, "r", encoding="utf-8") as f: - config_content = f.read() - - return {"success": True, "config": config_content} - - except HTTPException: - raise - except Exception as e: - logger.error(f"获取插件原始配置失败: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e - - -@router.put("/config/{plugin_id}/raw") -async def update_plugin_config_raw( - plugin_id: str, - request: UpdatePluginConfigRequest, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -) -> Dict[str, Any]: - """ - 更新插件原始 TOML 配置文件 - - 直接保存 TOML 字符串到配置文件。 - """ - # Token 验证 - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - logger.info(f"更新插件原始配置: {plugin_id}") - - try: - # 查找插件目录 - plugins_dir = Path("plugins") - plugin_path = None - - for p in plugins_dir.iterdir(): - if p.is_dir(): - manifest_path = p / "_manifest.json" - if manifest_path.exists(): - try: - with open(manifest_path, "r", encoding="utf-8") as f: - manifest = json.load(f) - if manifest.get("id") == plugin_id or p.name == plugin_id: - plugin_path = p - break - except Exception: - continue - - if not plugin_path: - raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - - config_path = plugin_path / "config.toml" - - # 验证 TOML 格式 - import tomlkit - - if not isinstance(request.config, str): - raise HTTPException(status_code=400, detail="配置必须是字符串格式的 TOML 内容") - - try: - tomlkit.loads(request.config) - except Exception as e: - raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e - - # 备份旧配置 - import shutil - import datetime - - if config_path.exists(): - backup_name = f"config.toml.backup.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" - backup_path = plugin_path / backup_name - shutil.copy(config_path, backup_path) - logger.info(f"已备份配置文件: {backup_path}") - - # 写入新配置 - with open(config_path, "w", encoding="utf-8") as f: - f.write(request.config) - - logger.info(f"已更新插件原始配置: {plugin_id}") - - return {"success": True, "message": "配置已保存", "note": "配置更改将在插件重新加载后生效"} - - except HTTPException: - raise - except Exception as e: - logger.error(f"更新插件原始配置失败: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e - - -@router.get("/config/{plugin_id}") -async def get_plugin_config( - plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -) -> Dict[str, Any]: - """ - 获取插件当前配置值 - - 返回插件的当前配置值。 - """ - # Token 验证 - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - logger.info(f"获取插件配置: {plugin_id}") - - try: - # 查找插件目录 - plugins_dir = Path("plugins") - plugin_path = None - - for p in plugins_dir.iterdir(): - if p.is_dir(): - manifest_path = p / "_manifest.json" - if manifest_path.exists(): - try: - with open(manifest_path, "r", encoding="utf-8") as f: - manifest = json.load(f) - if manifest.get("id") == plugin_id or p.name == plugin_id: - plugin_path = p - break - except Exception: - continue - - if not plugin_path: - raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - - # 读取配置文件 - config_path = plugin_path / "config.toml" - if not config_path.exists(): - return {"success": True, "config": {}, "message": "配置文件不存在"} - - import tomlkit - - with open(config_path, "r", encoding="utf-8") as f: - config = tomlkit.load(f) - - return {"success": True, "config": dict(config)} - - except HTTPException: - raise - except Exception as e: - logger.error(f"获取插件配置失败: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e - - -@router.put("/config/{plugin_id}") -async def update_plugin_config( - plugin_id: str, - request: UpdatePluginConfigRequest, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -) -> Dict[str, Any]: - """ - 更新插件配置 - - 保存新的配置值到插件的配置文件。 - """ - # Token 验证 - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - logger.info(f"更新插件配置: {plugin_id}") - - try: - plugin_instance = find_plugin_instance(plugin_id) - - # 纠正 WebUI 提交的数据结构(扁平键与字符串列表) - if plugin_instance and isinstance(request.config, dict): - request.config = normalize_dotted_keys(request.config) - if isinstance(plugin_instance.config_schema, dict): - coerce_types(plugin_instance.config_schema, request.config) - - # 查找插件目录 - plugins_dir = Path("plugins") - plugin_path = None - - for p in plugins_dir.iterdir(): - if p.is_dir(): - manifest_path = p / "_manifest.json" - if manifest_path.exists(): - try: - with open(manifest_path, "r", encoding="utf-8") as f: - manifest = json.load(f) - if manifest.get("id") == plugin_id or p.name == plugin_id: - plugin_path = p - break - except Exception: - continue - - if not plugin_path: - raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - - config_path = plugin_path / "config.toml" - - # 备份旧配置 - import shutil - import datetime - - if config_path.exists(): - backup_name = f"config.toml.backup.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" - backup_path = plugin_path / backup_name - shutil.copy(config_path, backup_path) - logger.info(f"已备份配置文件: {backup_path}") - - # 写入新配置(自动保留注释和格式) - save_toml_with_format(request.config, str(config_path)) - - logger.info(f"已更新插件配置: {plugin_id}") - - return {"success": True, "message": "配置已保存", "note": "配置更改将在插件重新加载后生效"} - - except HTTPException: - raise - except Exception as e: - logger.error(f"更新插件配置失败: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e - - -@router.post("/config/{plugin_id}/reset") -async def reset_plugin_config( - plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -) -> Dict[str, Any]: - """ - 重置插件配置为默认值 - - 删除当前配置文件,下次加载插件时将使用默认配置。 - """ - # Token 验证 - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - logger.info(f"重置插件配置: {plugin_id}") - - try: - # 查找插件目录 - plugins_dir = Path("plugins") - plugin_path = None - - for p in plugins_dir.iterdir(): - if p.is_dir(): - manifest_path = p / "_manifest.json" - if manifest_path.exists(): - try: - with open(manifest_path, "r", encoding="utf-8") as f: - manifest = json.load(f) - if manifest.get("id") == plugin_id or p.name == plugin_id: - plugin_path = p - break - except Exception: - continue - - if not plugin_path: - raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - - config_path = plugin_path / "config.toml" - - if not config_path.exists(): - return {"success": True, "message": "配置文件不存在,无需重置"} - - # 备份并删除 - import shutil - import datetime - - backup_name = f"config.toml.reset.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" - backup_path = plugin_path / backup_name - shutil.move(config_path, backup_path) - - logger.info(f"已重置插件配置: {plugin_id},备份: {backup_path}") - - return {"success": True, "message": "配置已重置,下次加载插件时将使用默认配置", "backup": str(backup_path)} - - except HTTPException: - raise - except Exception as e: - logger.error(f"重置插件配置失败: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e - - -@router.post("/config/{plugin_id}/toggle") -async def toggle_plugin( - plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) -) -> Dict[str, Any]: - """ - 切换插件启用状态 - - 切换插件配置中的 enabled 字段。 - """ - # Token 验证 - token = get_token_from_cookie_or_header(maibot_session, authorization) - token_manager = get_token_manager() - if not token or not token_manager.verify_token(token): - raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") - - logger.info(f"切换插件状态: {plugin_id}") - - try: - # 查找插件目录 - plugins_dir = Path("plugins") - plugin_path = None - - for p in plugins_dir.iterdir(): - if p.is_dir(): - manifest_path = p / "_manifest.json" - if manifest_path.exists(): - try: - with open(manifest_path, "r", encoding="utf-8") as f: - manifest = json.load(f) - if manifest.get("id") == plugin_id or p.name == plugin_id: - plugin_path = p - break - except Exception: - continue - - if not plugin_path: - raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") - - config_path = plugin_path / "config.toml" - - import tomlkit - - # 读取当前配置(保留注释和格式) - config = tomlkit.document() - if config_path.exists(): - with open(config_path, "r", encoding="utf-8") as f: - config = tomlkit.load(f) - - # 切换 enabled 状态 - if "plugin" not in config: - config["plugin"] = tomlkit.table() - - current_enabled = config["plugin"].get("enabled", True) - new_enabled = not current_enabled - config["plugin"]["enabled"] = new_enabled - - # 写入配置(保留注释,格式化数组) - save_toml_with_format(config, str(config_path)) - - status = "启用" if new_enabled else "禁用" - logger.info(f"已{status}插件: {plugin_id}") - - return { - "success": True, - "enabled": new_enabled, - "message": f"插件已{status}", - "note": "状态更改将在下次加载插件时生效", - } - - except HTTPException: - raise - except Exception as e: - logger.error(f"切换插件状态失败: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e diff --git a/src/webui/routers/plugin/__init__.py b/src/webui/routers/plugin/__init__.py new file mode 100644 index 00000000..1841f690 --- /dev/null +++ b/src/webui/routers/plugin/__init__.py @@ -0,0 +1,17 @@ +from fastapi import APIRouter + +from src.webui.services.git_mirror_service import set_update_progress_callback + +from .catalog import router as catalog_router +from .config_routes import router as config_router +from .management import router as management_router +from .progress import get_progress_router, update_progress + +router = APIRouter(prefix="/plugins", tags=["插件管理"]) +router.include_router(catalog_router) +router.include_router(management_router) +router.include_router(config_router) + +set_update_progress_callback(update_progress) + +__all__ = ["get_progress_router", "router"] \ No newline at end of file diff --git a/src/webui/routers/plugin/catalog.py b/src/webui/routers/plugin/catalog.py new file mode 100644 index 00000000..124c6fac --- /dev/null +++ b/src/webui/routers/plugin/catalog.py @@ -0,0 +1,205 @@ +from typing import Any, Optional + +import json + +from fastapi import APIRouter, Cookie, HTTPException + +from src.common.logger import get_logger +from src.config.config import MMC_VERSION +from src.webui.services.git_mirror_service import get_git_mirror_service + +from .progress import update_progress +from .schemas import ( + AddMirrorRequest, + AvailableMirrorsResponse, + CloneRepositoryRequest, + CloneRepositoryResponse, + FetchRawFileRequest, + FetchRawFileResponse, + GitStatusResponse, + MirrorConfigResponse, + UpdateMirrorRequest, + VersionResponse, +) +from .support import get_plugins_dir, parse_version, require_plugin_token, validate_safe_path + +logger = get_logger("webui.plugin_routes") + +router = APIRouter() + + +def _mirror_to_response(mirror: dict[str, Any]) -> MirrorConfigResponse: + return MirrorConfigResponse( + id=mirror["id"], + name=mirror["name"], + raw_prefix=mirror["raw_prefix"], + clone_prefix=mirror["clone_prefix"], + enabled=mirror["enabled"], + priority=mirror["priority"], + ) + + +@router.get("/version", response_model=VersionResponse) +async def get_maimai_version() -> VersionResponse: + major, minor, patch = parse_version(MMC_VERSION) + return VersionResponse(version=MMC_VERSION, version_major=major, version_minor=minor, version_patch=patch) + + +@router.get("/git-status", response_model=GitStatusResponse) +async def check_git_status() -> GitStatusResponse: + service = get_git_mirror_service() + return GitStatusResponse(**service.check_git_installed()) + + +@router.get("/mirrors", response_model=AvailableMirrorsResponse) +async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None)) -> AvailableMirrorsResponse: + require_plugin_token(maibot_session) + + service = get_git_mirror_service() + config = service.get_mirror_config() + mirrors = [_mirror_to_response(mirror) for mirror in config.get_all_mirrors()] + return AvailableMirrorsResponse(mirrors=mirrors, default_priority=config.get_default_priority_list()) + + +@router.post("/mirrors", response_model=MirrorConfigResponse) +async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None)) -> MirrorConfigResponse: + require_plugin_token(maibot_session) + + try: + service = get_git_mirror_service() + config = service.get_mirror_config() + mirror = config.add_mirror( + mirror_id=request.id, + name=request.name, + raw_prefix=request.raw_prefix, + clone_prefix=request.clone_prefix, + enabled=request.enabled, + priority=request.priority, + ) + return _mirror_to_response(mirror) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + logger.error(f"添加镜像源失败: {e}") + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e + + +@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse) +async def update_mirror( + mirror_id: str, + request: UpdateMirrorRequest, + maibot_session: Optional[str] = Cookie(None), +) -> MirrorConfigResponse: + require_plugin_token(maibot_session) + + try: + service = get_git_mirror_service() + config = service.get_mirror_config() + mirror = config.update_mirror( + mirror_id=mirror_id, + name=request.name, + raw_prefix=request.raw_prefix, + clone_prefix=request.clone_prefix, + enabled=request.enabled, + priority=request.priority, + ) + if mirror is None: + raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}") + return _mirror_to_response(mirror) + except HTTPException: + raise + except Exception as e: + logger.error(f"更新镜像源失败: {e}") + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e + + +@router.delete("/mirrors/{mirror_id}") +async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: + require_plugin_token(maibot_session) + + service = get_git_mirror_service() + config = service.get_mirror_config() + if not config.delete_mirror(mirror_id): + raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}") + return {"success": True, "message": f"已删除镜像源: {mirror_id}"} + + +@router.post("/fetch-raw", response_model=FetchRawFileResponse) +async def fetch_raw_file( + request: FetchRawFileRequest, + maibot_session: Optional[str] = Cookie(None), +) -> FetchRawFileResponse: + require_plugin_token(maibot_session) + logger.info(f"收到获取 Raw 文件请求: {request.owner}/{request.repo}/{request.branch}/{request.file_path}") + + await update_progress( + stage="loading", + progress=10, + message=f"正在获取插件列表: {request.file_path}", + total_plugins=0, + loaded_plugins=0, + ) + + try: + service = get_git_mirror_service() + result = await service.fetch_raw_file( + owner=request.owner, + repo=request.repo, + branch=request.branch, + file_path=request.file_path, + mirror_id=request.mirror_id, + custom_url=request.custom_url, + ) + + if result.get("success"): + await update_progress( + stage="loading", + progress=70, + message="正在解析插件数据...", + total_plugins=0, + loaded_plugins=0, + ) + try: + data = json.loads(result.get("data", "[]")) + total = len(data) if isinstance(data, list) else 0 + await update_progress( + stage="success", + progress=100, + message=f"成功加载 {total} 个插件", + total_plugins=total, + loaded_plugins=total, + ) + except Exception: + await update_progress(stage="success", progress=100, message="加载完成", total_plugins=0, loaded_plugins=0) + + return FetchRawFileResponse(**result) + except Exception as e: + logger.error(f"获取 Raw 文件失败: {e}") + await update_progress(stage="error", progress=0, message="加载失败", error=str(e), total_plugins=0, loaded_plugins=0) + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e + + +@router.post("/clone", response_model=CloneRepositoryResponse) +async def clone_repository( + request: CloneRepositoryRequest, + maibot_session: Optional[str] = Cookie(None), +) -> CloneRepositoryResponse: + require_plugin_token(maibot_session) + logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}") + + try: + target_path = validate_safe_path(request.target_path, get_plugins_dir()) + service = get_git_mirror_service() + result = await service.clone_repository( + owner=request.owner, + repo=request.repo, + target_path=target_path, + branch=request.branch, + mirror_id=request.mirror_id, + custom_url=request.custom_url, + depth=request.depth, + ) + return CloneRepositoryResponse(**result) + except Exception as e: + logger.error(f"克隆仓库失败: {e}") + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e \ No newline at end of file diff --git a/src/webui/routers/plugin/config_routes.py b/src/webui/routers/plugin/config_routes.py new file mode 100644 index 00000000..b2f4e254 --- /dev/null +++ b/src/webui/routers/plugin/config_routes.py @@ -0,0 +1,333 @@ +from typing import Any, Optional, cast + +import json +import tomlkit + +from fastapi import APIRouter, Cookie, HTTPException + +from src.common.logger import get_logger +from src.webui.utils.toml_utils import save_toml_with_format + +from .schemas import UpdatePluginConfigRequest, UpdatePluginRawConfigRequest +from .support import ( + backup_file, + coerce_types, + find_plugin_instance, + find_plugin_path_by_id, + normalize_dotted_keys, + require_plugin_token, +) + +logger = get_logger("webui.plugin_routes") + +router = APIRouter() + + +def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> dict[str, Any]: + schema: dict[str, Any] = { + "plugin_id": plugin_id, + "plugin_info": { + "name": plugin_id, + "version": "", + "description": "", + "author": "", + }, + "sections": {}, + "layout": {"type": "auto", "tabs": []}, + "_note": "插件未加载,仅返回当前配置结构", + } + + for section_name, section_data in current_config.items(): + if not isinstance(section_data, dict): + continue + section_fields: dict[str, Any] = {} + for field_name, field_value in section_data.items(): + field_type = type(field_value).__name__ + ui_type = "text" + item_type = None + item_fields = None + + if isinstance(field_value, bool): + ui_type = "switch" + elif isinstance(field_value, (int, float)): + ui_type = "number" + elif isinstance(field_value, list): + ui_type = "list" + if field_value: + first_item = field_value[0] + if isinstance(first_item, dict): + item_type = "object" + item_fields = { + key: { + "type": "number" if isinstance(value, (int, float)) else "string", + "label": key, + "default": "" if isinstance(value, str) else 0, + } + for key, value in first_item.items() + } + elif isinstance(first_item, (int, float)): + item_type = "number" + else: + item_type = "string" + else: + item_type = "string" + elif isinstance(field_value, dict): + ui_type = "json" + + section_fields[field_name] = { + "name": field_name, + "type": field_type, + "default": field_value, + "description": field_name, + "label": field_name, + "ui_type": ui_type, + "required": False, + "hidden": False, + "disabled": False, + "order": 0, + "item_type": item_type, + "item_fields": item_fields, + "min_items": None, + "max_items": None, + "placeholder": None, + "hint": None, + "icon": None, + "example": None, + "choices": None, + "min": None, + "max": None, + "step": None, + "pattern": None, + "max_length": None, + "input_type": None, + "rows": 3, + "group": None, + "depends_on": None, + "depends_value": None, + } + + schema["sections"][section_name] = { + "name": section_name, + "title": section_name, + "description": None, + "icon": None, + "collapsed": False, + "order": 0, + "fields": section_fields, + } + + return schema + + +@router.get("/config/{plugin_id}/schema") +async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: + require_plugin_token(maibot_session) + logger.info(f"获取插件配置 Schema: {plugin_id}") + + try: + plugin_instance = find_plugin_instance(plugin_id) + if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"): + return {"success": True, "schema": plugin_instance.get_webui_config_schema()} + + plugin_path = find_plugin_path_by_id(plugin_id) + if plugin_path is None: + raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") + + schema_json_path = plugin_path / "config_schema.json" + if schema_json_path.exists(): + try: + with open(schema_json_path, "r", encoding="utf-8") as file_obj: + return {"success": True, "schema": json.load(file_obj)} + except Exception as e: + logger.warning(f"读取 config_schema.json 失败,回退到自动推断: {e}") + + current_config: Any = {} + config_path = plugin_path / "config.toml" + if config_path.exists(): + with open(config_path, "r", encoding="utf-8") as file_obj: + current_config = tomlkit.load(file_obj) + + return {"success": True, "schema": _build_schema_from_current_config(plugin_id, current_config)} + except HTTPException: + raise + except Exception as e: + logger.error(f"获取插件配置 Schema 失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e + + +@router.get("/config/{plugin_id}/raw") +async def get_plugin_config_raw(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: + require_plugin_token(maibot_session) + logger.info(f"获取插件原始配置: {plugin_id}") + + try: + plugin_path = find_plugin_path_by_id(plugin_id) + if plugin_path is None: + raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") + + config_path = plugin_path / "config.toml" + if not config_path.exists(): + return {"success": True, "config": "", "message": "配置文件不存在"} + + with open(config_path, "r", encoding="utf-8") as file_obj: + return {"success": True, "config": file_obj.read()} + except HTTPException: + raise + except Exception as e: + logger.error(f"获取插件原始配置失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e + + +@router.put("/config/{plugin_id}/raw") +async def update_plugin_config_raw( + plugin_id: str, + request: UpdatePluginRawConfigRequest, + maibot_session: Optional[str] = Cookie(None), +) -> dict[str, Any]: + require_plugin_token(maibot_session) + logger.info(f"更新插件原始配置: {plugin_id}") + + try: + plugin_path = find_plugin_path_by_id(plugin_id) + if plugin_path is None: + raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") + + config_path = plugin_path / "config.toml" + try: + tomlkit.loads(request.config) + except Exception as e: + raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e + + backup_path = backup_file(config_path, "backup") + if backup_path is not None: + logger.info(f"已备份配置文件: {backup_path}") + + with open(config_path, "w", encoding="utf-8") as file_obj: + file_obj.write(request.config) + + logger.info(f"已更新插件原始配置: {plugin_id}") + return {"success": True, "message": "配置已保存", "note": "配置更改将在插件重新加载后生效"} + except HTTPException: + raise + except Exception as e: + logger.error(f"更新插件原始配置失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e + + +@router.get("/config/{plugin_id}") +async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: + require_plugin_token(maibot_session) + logger.info(f"获取插件配置: {plugin_id}") + + try: + plugin_path = find_plugin_path_by_id(plugin_id) + if plugin_path is None: + raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") + + config_path = plugin_path / "config.toml" + if not config_path.exists(): + return {"success": True, "config": {}, "message": "配置文件不存在"} + + with open(config_path, "r", encoding="utf-8") as file_obj: + config = tomlkit.load(file_obj) + return {"success": True, "config": dict(config)} + except HTTPException: + raise + except Exception as e: + logger.error(f"获取插件配置失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e + + +@router.put("/config/{plugin_id}") +async def update_plugin_config( + plugin_id: str, + request: UpdatePluginConfigRequest, + maibot_session: Optional[str] = Cookie(None), +) -> dict[str, Any]: + require_plugin_token(maibot_session) + logger.info(f"更新插件配置: {plugin_id}") + + try: + plugin_instance = find_plugin_instance(plugin_id) + config_data = request.config or {} + if plugin_instance and isinstance(config_data, dict): + config_data = normalize_dotted_keys(config_data) + if isinstance(plugin_instance.config_schema, dict): + coerce_types(plugin_instance.config_schema, config_data) + + plugin_path = find_plugin_path_by_id(plugin_id) + if plugin_path is None: + raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") + + config_path = plugin_path / "config.toml" + backup_path = backup_file(config_path, "backup") + if backup_path is not None: + logger.info(f"已备份配置文件: {backup_path}") + + save_toml_with_format(config_data, str(config_path)) + logger.info(f"已更新插件配置: {plugin_id}") + return {"success": True, "message": "配置已保存", "note": "配置更改将在插件重新加载后生效"} + except HTTPException: + raise + except Exception as e: + logger.error(f"更新插件配置失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e + + +@router.post("/config/{plugin_id}/reset") +async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: + require_plugin_token(maibot_session) + logger.info(f"重置插件配置: {plugin_id}") + + try: + plugin_path = find_plugin_path_by_id(plugin_id) + if plugin_path is None: + raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") + + config_path = plugin_path / "config.toml" + if not config_path.exists(): + return {"success": True, "message": "配置文件不存在,无需重置"} + + backup_path = backup_file(config_path, "reset", move_file=True) + logger.info(f"已重置插件配置: {plugin_id},备份: {backup_path}") + return {"success": True, "message": "配置已重置,下次加载插件时将使用默认配置", "backup": str(backup_path)} + except HTTPException: + raise + except Exception as e: + logger.error(f"重置插件配置失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e + + +@router.post("/config/{plugin_id}/toggle") +async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: + require_plugin_token(maibot_session) + logger.info(f"切换插件状态: {plugin_id}") + + try: + plugin_path = find_plugin_path_by_id(plugin_id) + if plugin_path is None: + raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}") + + config_path = plugin_path / "config.toml" + config = tomlkit.document() + if config_path.exists(): + with open(config_path, "r", encoding="utf-8") as file_obj: + config = tomlkit.load(file_obj) + + if "plugin" not in config: + config["plugin"] = tomlkit.table() + + plugin_config = cast(Any, config["plugin"]) + current_enabled = bool(plugin_config.get("enabled", True)) + new_enabled = not current_enabled + plugin_config["enabled"] = new_enabled + save_toml_with_format(config, str(config_path)) + + status = "启用" if new_enabled else "禁用" + logger.info(f"已{status}插件: {plugin_id}") + return {"success": True, "enabled": new_enabled, "message": f"插件已{status}", "note": "状态更改将在下次加载插件时生效"} + except HTTPException: + raise + except Exception as e: + logger.error(f"切换插件状态失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e \ No newline at end of file diff --git a/src/webui/routers/plugin/management.py b/src/webui/routers/plugin/management.py new file mode 100644 index 00000000..cbf2920b --- /dev/null +++ b/src/webui/routers/plugin/management.py @@ -0,0 +1,302 @@ +from pathlib import Path +from typing import Any, Optional + +import json + +from fastapi import APIRouter, Cookie, HTTPException + +from src.common.logger import get_logger +from src.webui.services.git_mirror_service import get_git_mirror_service + +from .progress import update_progress +from .schemas import InstallPluginRequest, UninstallPluginRequest, UpdatePluginRequest +from .support import ( + find_plugin_path_by_id, + get_plugin_candidate_paths, + get_plugins_dir, + load_manifest_json, + parse_repository_url, + remove_tree, + require_plugin_token, + resolve_installed_plugin_path, + validate_plugin_id, +) + +logger = get_logger("webui.plugin_routes") + +router = APIRouter() + + +def _infer_plugin_id(folder_name: str, manifest: dict[str, Any], manifest_path: Path) -> str: + if "id" in manifest: + return str(manifest["id"]) + + author_name: Optional[str] = None + repo_name: Optional[str] = None + if "author" in manifest: + author_data = manifest["author"] + if isinstance(author_data, dict) and "name" in author_data: + author_name = str(author_data["name"]) + elif isinstance(author_data, str): + author_name = author_data + + if "repository_url" in manifest: + repo_url = str(manifest["repository_url"]).rstrip("/").removesuffix(".git") + repo_name = repo_url.split("/")[-1] + + if author_name and repo_name: + plugin_id = f"{author_name}.{repo_name}" + elif author_name: + plugin_id = f"{author_name}.{folder_name}" + elif "_" in folder_name and "." not in folder_name: + plugin_id = folder_name.replace("_", ".", 1) + else: + plugin_id = folder_name + + logger.info(f"为插件 {folder_name} 自动生成 ID: {plugin_id}") + manifest["id"] = plugin_id + try: + with open(manifest_path, "w", encoding="utf-8") as file_obj: + json.dump(manifest, file_obj, ensure_ascii=False, indent=2) + except Exception as write_error: + logger.warning(f"无法写入 ID 到 manifest: {write_error}") + return plugin_id + + +@router.post("/install") +async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: + require_plugin_token(maibot_session) + logger.info(f"收到安装插件请求: {request.plugin_id}") + plugin_id = request.plugin_id + + try: + plugin_id = validate_plugin_id(request.plugin_id) + await update_progress(stage="loading", progress=5, message=f"开始安装插件: {plugin_id}", operation="install", plugin_id=plugin_id) + + repo_url, owner, repo = parse_repository_url(request.repository_url) + await update_progress(stage="loading", progress=10, message=f"解析仓库信息: {owner}/{repo}", operation="install", plugin_id=plugin_id) + + target_path, old_format_path = get_plugin_candidate_paths(plugin_id) + if target_path.exists() or old_format_path.exists(): + await update_progress(stage="error", progress=0, message="插件已存在", operation="install", plugin_id=plugin_id, error="插件已安装,请先卸载") + raise HTTPException(status_code=400, detail="插件已安装") + + await update_progress(stage="loading", progress=15, message=f"准备克隆到: {target_path}", operation="install", plugin_id=plugin_id) + service = get_git_mirror_service() + if "github.com" in repo_url: + result = await service.clone_repository(owner=owner, repo=repo, target_path=target_path, branch=request.branch, mirror_id=request.mirror_id, depth=1) + else: + result = await service.clone_repository(owner=owner, repo=repo, target_path=target_path, branch=request.branch, custom_url=repo_url, depth=1) + + if not result.get("success"): + error_msg = str(result.get("error", "克隆失败")) + await update_progress(stage="error", progress=0, message="克隆仓库失败", operation="install", plugin_id=plugin_id, error=error_msg) + raise HTTPException(status_code=500, detail=error_msg) + + await update_progress(stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id) + manifest_path = target_path / "_manifest.json" + if not manifest_path.exists(): + remove_tree(target_path) + await update_progress(stage="error", progress=0, message="插件缺少 _manifest.json", operation="install", plugin_id=plugin_id, error="无效的插件格式") + raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json") + + await update_progress(stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id) + try: + with open(manifest_path, "r", encoding="utf-8") as file_obj: + manifest = json.load(file_obj) + for field in ["manifest_version", "name", "version", "author"]: + if field not in manifest: + raise ValueError(f"缺少必需字段: {field}") + manifest["id"] = plugin_id + with open(manifest_path, "w", encoding="utf-8") as file_obj: + json.dump(manifest, file_obj, ensure_ascii=False, indent=2) + except Exception as e: + remove_tree(target_path) + await update_progress(stage="error", progress=0, message="_manifest.json 无效", operation="install", plugin_id=plugin_id, error=str(e)) + raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e + + await update_progress(stage="success", progress=100, message=f"成功安装插件: {manifest['name']} v{manifest['version']}", operation="install", plugin_id=plugin_id) + return {"success": True, "message": "插件安装成功", "plugin_id": plugin_id, "plugin_name": manifest["name"], "version": manifest["version"], "path": str(target_path)} + except HTTPException: + raise + except Exception as e: + logger.error(f"安装插件失败: {e}", exc_info=True) + await update_progress(stage="error", progress=0, message="安装失败", operation="install", plugin_id=plugin_id, error=str(e)) + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e + + +@router.post("/uninstall") +async def uninstall_plugin(request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: + require_plugin_token(maibot_session) + logger.info(f"收到卸载插件请求: {request.plugin_id}") + plugin_id = request.plugin_id + + try: + plugin_id = validate_plugin_id(request.plugin_id) + await update_progress(stage="loading", progress=10, message=f"开始卸载插件: {plugin_id}", operation="uninstall", plugin_id=plugin_id) + plugin_path = resolve_installed_plugin_path(plugin_id) + if plugin_path is None: + await update_progress(stage="error", progress=0, message="插件不存在", operation="uninstall", plugin_id=plugin_id, error="插件未安装或已被删除") + raise HTTPException(status_code=404, detail="插件未安装") + + await update_progress(stage="loading", progress=30, message=f"正在删除插件文件: {plugin_path}", operation="uninstall", plugin_id=plugin_id) + manifest = load_manifest_json(plugin_path / "_manifest.json") + plugin_name = str(manifest.get("name", plugin_id)) if manifest is not None else plugin_id + await update_progress(stage="loading", progress=50, message=f"正在删除 {plugin_name}...", operation="uninstall", plugin_id=plugin_id) + remove_tree(plugin_path) + logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})") + await update_progress(stage="success", progress=100, message=f"成功卸载插件: {plugin_name}", operation="uninstall", plugin_id=plugin_id) + return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name} + except HTTPException: + raise + except PermissionError as e: + logger.error(f"卸载插件失败(权限错误): {e}") + await update_progress(stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error="权限不足,无法删除插件文件") + raise HTTPException(status_code=500, detail="权限不足,无法删除插件文件") from e + except Exception as e: + logger.error(f"卸载插件失败: {e}", exc_info=True) + await update_progress(stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error=str(e)) + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e + + +@router.post("/update") +async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: + require_plugin_token(maibot_session) + logger.info(f"收到更新插件请求: {request.plugin_id}") + plugin_id = request.plugin_id + + try: + plugin_id = validate_plugin_id(request.plugin_id) + await update_progress(stage="loading", progress=5, message=f"开始更新插件: {plugin_id}", operation="update", plugin_id=plugin_id) + plugin_path = resolve_installed_plugin_path(plugin_id) + if plugin_path is None: + await update_progress(stage="error", progress=0, message="插件不存在", operation="update", plugin_id=plugin_id, error="插件未安装,请先安装") + raise HTTPException(status_code=404, detail="插件未安装") + + manifest = load_manifest_json(plugin_path / "_manifest.json") + old_version = str(manifest.get("version", "unknown")) if manifest is not None else "unknown" + await update_progress(stage="loading", progress=10, message=f"当前版本: {old_version},准备更新...", operation="update", plugin_id=plugin_id) + await update_progress(stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id) + remove_tree(plugin_path) + + await update_progress(stage="loading", progress=30, message="正在准备下载新版本...", operation="update", plugin_id=plugin_id) + repo_url, owner, repo = parse_repository_url(request.repository_url) + service = get_git_mirror_service() + if "github.com" in repo_url: + result = await service.clone_repository(owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, mirror_id=request.mirror_id, depth=1) + else: + result = await service.clone_repository(owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, custom_url=repo_url, depth=1) + + if not result.get("success"): + error_msg = str(result.get("error", "克隆失败")) + await update_progress(stage="error", progress=0, message="下载新版本失败", operation="update", plugin_id=plugin_id, error=error_msg) + raise HTTPException(status_code=500, detail=error_msg) + + await update_progress(stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id) + new_manifest_path = plugin_path / "_manifest.json" + if not new_manifest_path.exists(): + remove_tree(plugin_path) + await update_progress(stage="error", progress=0, message="新版本缺少 _manifest.json", operation="update", plugin_id=plugin_id, error="无效的插件格式") + raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json") + + try: + with open(new_manifest_path, "r", encoding="utf-8") as file_obj: + new_manifest = json.load(file_obj) + new_version = str(new_manifest.get("version", "unknown")) + new_name = str(new_manifest.get("name", plugin_id)) + logger.info(f"成功更新插件: {plugin_id} {old_version} → {new_version}") + await update_progress(stage="success", progress=100, message=f"成功更新 {new_name}: {old_version} → {new_version}", operation="update", plugin_id=plugin_id) + return {"success": True, "message": "插件更新成功", "plugin_id": plugin_id, "plugin_name": new_name, "old_version": old_version, "new_version": new_version} + except Exception as e: + remove_tree(plugin_path) + await update_progress(stage="error", progress=0, message="_manifest.json 无效", operation="update", plugin_id=plugin_id, error=str(e)) + raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e + except HTTPException: + raise + except Exception as e: + logger.error(f"更新插件失败: {e}", exc_info=True) + await update_progress(stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e)) + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e + + +@router.get("/installed") +async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: + require_plugin_token(maibot_session) + logger.info("收到获取已安装插件列表请求") + + try: + plugins_dir = get_plugins_dir() + installed_plugins: list[dict[str, Any]] = [] + for plugin_path in plugins_dir.iterdir(): + if not plugin_path.is_dir(): + continue + folder_name = plugin_path.name + if folder_name.startswith(".") or folder_name.startswith("__"): + continue + + manifest_path = plugin_path / "_manifest.json" + if not manifest_path.exists(): + logger.warning(f"插件文件夹 {folder_name} 缺少 _manifest.json,跳过") + continue + + try: + with open(manifest_path, "r", encoding="utf-8") as file_obj: + manifest = json.load(file_obj) + if "name" not in manifest or "version" not in manifest: + logger.warning(f"插件文件夹 {folder_name} 的 _manifest.json 格式无效,跳过") + continue + plugin_id = _infer_plugin_id(folder_name, manifest, manifest_path) + installed_plugins.append({"id": plugin_id, "manifest": manifest, "path": str(plugin_path.absolute())}) + except json.JSONDecodeError as e: + logger.warning(f"插件 {folder_name} 的 _manifest.json 解析失败: {e}") + except Exception as e: + logger.error(f"读取插件 {folder_name} 信息时出错: {e}") + + seen_ids: dict[str, str] = {} + unique_plugins: list[dict[str, Any]] = [] + duplicates: list[dict[str, Any]] = [] + for plugin in installed_plugins: + plugin_id = str(plugin["id"]) + plugin_path = str(plugin["path"]) + if plugin_id not in seen_ids: + seen_ids[plugin_id] = plugin_path + unique_plugins.append(plugin) + else: + duplicates.append(plugin) + logger.warning(f"重复插件 {plugin_id}: 保留 {seen_ids[plugin_id]}, 跳过 {plugin_path}") + + if duplicates: + logger.warning(f"共检测到 {len(duplicates)} 个重复插件已去重") + + logger.info(f"找到 {len(unique_plugins)} 个已安装插件") + return {"success": True, "plugins": unique_plugins, "total": len(unique_plugins)} + except Exception as e: + logger.error(f"获取已安装插件列表失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e + + +@router.get("/local-readme/{plugin_id}") +async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: + require_plugin_token(maibot_session) + logger.info(f"获取本地插件 README: {plugin_id}") + + try: + plugin_path = find_plugin_path_by_id(plugin_id) + if plugin_path is None: + return {"success": False, "error": "插件未安装"} + + for readme_name in ["README.md", "readme.md", "Readme.md", "README.MD"]: + readme_path = plugin_path / readme_name + if readme_path.exists(): + try: + with open(readme_path, "r", encoding="utf-8") as file_obj: + readme_content = file_obj.read() + logger.info(f"成功读取本地 README: {readme_path}") + return {"success": True, "data": readme_content} + except Exception as e: + logger.warning(f"读取 {readme_path} 失败: {e}") + + return {"success": False, "error": "本地未找到 README 文件"} + except Exception as e: + logger.error(f"获取本地 README 失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} \ No newline at end of file diff --git a/src/webui/routers/websocket/plugin_progress.py b/src/webui/routers/plugin/progress.py similarity index 64% rename from src/webui/routers/websocket/plugin_progress.py rename to src/webui/routers/plugin/progress.py index 82ead6a5..2d945df5 100644 --- a/src/webui/routers/websocket/plugin_progress.py +++ b/src/webui/routers/plugin/progress.py @@ -1,36 +1,32 @@ -"""WebSocket 插件加载进度推送模块""" - -from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query -from typing import Set, Dict, Any, Optional -import json import asyncio +import json + +from typing import Any, Optional, Set + +from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect + from src.common.logger import get_logger from src.webui.core import get_token_manager from src.webui.routers.websocket.auth import verify_ws_token logger = get_logger("webui.plugin_progress") -# 创建路由器 router = APIRouter() -# 全局 WebSocket 连接池 active_connections: Set[WebSocket] = set() - -# 当前加载进度状态 -current_progress: Dict[str, Any] = { - "operation": "idle", # idle, fetch, install, uninstall, update - "stage": "idle", # idle, loading, success, error - "progress": 0, # 0-100 +current_progress: dict[str, Any] = { + "operation": "idle", + "stage": "idle", + "progress": 0, "message": "", "error": None, - "plugin_id": None, # 当前操作的插件 ID + "plugin_id": None, "total_plugins": 0, "loaded_plugins": 0, } -async def broadcast_progress(progress_data: Dict[str, Any]): - """广播进度更新到所有连接的客户端""" +async def broadcast_progress(progress_data: dict[str, Any]) -> None: global current_progress current_progress = progress_data.copy() @@ -38,7 +34,7 @@ async def broadcast_progress(progress_data: Dict[str, Any]): return message = json.dumps(progress_data, ensure_ascii=False) - disconnected = set() + disconnected: set[WebSocket] = set() for websocket in active_connections: try: @@ -47,7 +43,6 @@ async def broadcast_progress(progress_data: Dict[str, Any]): logger.error(f"发送进度更新失败: {e}") disconnected.add(websocket) - # 移除断开的连接 for websocket in disconnected: active_connections.discard(websocket) @@ -57,23 +52,11 @@ async def update_progress( progress: int, message: str, operation: str = "fetch", - error: str = None, - plugin_id: str = None, + error: Optional[str] = None, + plugin_id: Optional[str] = None, total_plugins: int = 0, loaded_plugins: int = 0, -): - """更新并广播进度 - - Args: - stage: 阶段 (idle, loading, success, error) - progress: 进度百分比 (0-100) - message: 当前消息 - operation: 操作类型 (fetch, install, uninstall, update) - error: 错误信息(可选) - plugin_id: 当前操作的插件 ID - total_plugins: 总插件数 - loaded_plugins: 已加载插件数 - """ +) -> None: progress_data = { "operation": operation, "stage": stage, @@ -91,25 +74,13 @@ async def update_progress( @router.websocket("/ws/plugin-progress") -async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)): - """WebSocket 插件加载进度推送端点 - - 客户端连接后会立即收到当前进度状态 - 支持三种认证方式(按优先级): - 1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token) - 2. Cookie 中的 maibot_session - 3. 直接使用 session token(兼容) - - 示例:ws://host/ws/plugin-progress?token=xxx - """ +async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)) -> None: is_authenticated = False - # 方式 1: 尝试验证临时 WebSocket token(推荐方式) if token and verify_ws_token(token): is_authenticated = True logger.debug("插件进度 WebSocket 使用临时 token 认证成功") - # 方式 2: 尝试从 Cookie 获取 session token if not is_authenticated: cookie_token = websocket.cookies.get("maibot_session") if cookie_token: @@ -118,7 +89,6 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = is_authenticated = True logger.debug("插件进度 WebSocket 使用 Cookie 认证成功") - # 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式) if not is_authenticated and token: token_manager = get_token_manager() if token_manager.verify_token(token): @@ -135,18 +105,13 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}") try: - # 发送当前进度状态 await websocket.send_text(json.dumps(current_progress, ensure_ascii=False)) - # 保持连接并处理客户端消息 while True: try: data = await websocket.receive_text() - - # 处理客户端心跳 if data == "ping": await websocket.send_text("pong") - except Exception as e: logger.error(f"处理客户端消息时出错: {e}") break @@ -160,5 +125,4 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = def get_progress_router() -> APIRouter: - """获取插件进度 WebSocket 路由器""" - return router + return router \ No newline at end of file diff --git a/src/webui/routers/plugin/schemas.py b/src/webui/routers/plugin/schemas.py new file mode 100644 index 00000000..c7f6c252 --- /dev/null +++ b/src/webui/routers/plugin/schemas.py @@ -0,0 +1,113 @@ +from typing import Any, Optional + +from pydantic import BaseModel, Field + + +class FetchRawFileRequest(BaseModel): + owner: str = Field(..., description="仓库所有者", examples=["MaiM-with-u"]) + repo: str = Field(..., description="仓库名称", examples=["plugin-repo"]) + branch: str = Field(..., description="分支名称", examples=["main"]) + file_path: str = Field(..., description="文件路径", examples=["plugin_details.json"]) + mirror_id: Optional[str] = Field(None, description="指定镜像源 ID") + custom_url: Optional[str] = Field(None, description="自定义完整 URL") + + +class FetchRawFileResponse(BaseModel): + success: bool = Field(..., description="是否成功") + data: Optional[str] = Field(None, description="文件内容") + error: Optional[str] = Field(None, description="错误信息") + mirror_used: Optional[str] = Field(None, description="使用的镜像源") + attempts: int = Field(..., description="尝试次数") + url: Optional[str] = Field(None, description="实际请求的 URL") + + +class CloneRepositoryRequest(BaseModel): + owner: str = Field(..., description="仓库所有者", examples=["MaiM-with-u"]) + repo: str = Field(..., description="仓库名称", examples=["plugin-repo"]) + target_path: str = Field(..., description="目标路径(相对于插件目录)") + branch: Optional[str] = Field(None, description="分支名称", examples=["main"]) + mirror_id: Optional[str] = Field(None, description="指定镜像源 ID") + custom_url: Optional[str] = Field(None, description="自定义克隆 URL") + depth: Optional[int] = Field(None, description="克隆深度(浅克隆)", ge=1) + + +class CloneRepositoryResponse(BaseModel): + success: bool = Field(..., description="是否成功") + path: Optional[str] = Field(None, description="克隆路径") + error: Optional[str] = Field(None, description="错误信息") + mirror_used: Optional[str] = Field(None, description="使用的镜像源") + attempts: int = Field(..., description="尝试次数") + url: Optional[str] = Field(None, description="实际克隆的 URL") + message: Optional[str] = Field(None, description="附加信息") + + +class MirrorConfigResponse(BaseModel): + id: str = Field(..., description="镜像源 ID") + name: str = Field(..., description="镜像源名称") + raw_prefix: str = Field(..., description="Raw 文件前缀") + clone_prefix: str = Field(..., description="克隆前缀") + enabled: bool = Field(..., description="是否启用") + priority: int = Field(..., description="优先级(数字越小优先级越高)") + + +class AvailableMirrorsResponse(BaseModel): + mirrors: list[MirrorConfigResponse] = Field(..., description="镜像源列表") + default_priority: list[str] = Field(..., description="默认优先级顺序(ID 列表)") + + +class AddMirrorRequest(BaseModel): + id: str = Field(..., description="镜像源 ID", examples=["custom-mirror"]) + name: str = Field(..., description="镜像源名称", examples=["自定义镜像源"]) + raw_prefix: str = Field(..., description="Raw 文件前缀", examples=["https://example.com/raw"]) + clone_prefix: str = Field(..., description="克隆前缀", examples=["https://example.com/clone"]) + enabled: bool = Field(True, description="是否启用") + priority: Optional[int] = Field(None, description="优先级") + + +class UpdateMirrorRequest(BaseModel): + name: Optional[str] = Field(None, description="镜像源名称") + raw_prefix: Optional[str] = Field(None, description="Raw 文件前缀") + clone_prefix: Optional[str] = Field(None, description="克隆前缀") + enabled: Optional[bool] = Field(None, description="是否启用") + priority: Optional[int] = Field(None, description="优先级") + + +class GitStatusResponse(BaseModel): + installed: bool = Field(..., description="是否已安装 Git") + version: Optional[str] = Field(None, description="Git 版本号") + path: Optional[str] = Field(None, description="Git 可执行文件路径") + error: Optional[str] = Field(None, description="错误信息") + + +class InstallPluginRequest(BaseModel): + plugin_id: str = Field(..., description="插件 ID") + repository_url: str = Field(..., description="插件仓库 URL") + branch: Optional[str] = Field("main", description="分支名称") + mirror_id: Optional[str] = Field(None, description="指定镜像源 ID") + + +class VersionResponse(BaseModel): + version: str = Field(..., description="麦麦版本号") + version_major: int = Field(..., description="主版本号") + version_minor: int = Field(..., description="次版本号") + version_patch: int = Field(..., description="补丁版本号") + + +class UninstallPluginRequest(BaseModel): + plugin_id: str = Field(..., description="插件 ID") + + +class UpdatePluginRequest(BaseModel): + plugin_id: str = Field(..., description="插件 ID") + repository_url: str = Field(..., description="插件仓库 URL") + branch: Optional[str] = Field("main", description="分支名称") + mirror_id: Optional[str] = Field(None, description="指定镜像源 ID") + + +class UpdatePluginConfigRequest(BaseModel): + enabled: Optional[bool] = None + config: Optional[dict[str, Any]] = None + + +class UpdatePluginRawConfigRequest(BaseModel): + config: str = Field(..., description="原始 TOML 配置内容") \ No newline at end of file diff --git a/src/webui/routers/plugin/support.py b/src/webui/routers/plugin/support.py new file mode 100644 index 00000000..f0078736 --- /dev/null +++ b/src/webui/routers/plugin/support.py @@ -0,0 +1,221 @@ +from datetime import datetime +from pathlib import Path +from typing import Any, Optional, cast, get_origin + +import json +import os +import re +import shutil +import stat + +from fastapi import HTTPException + +from src.common.logger import get_logger +from src.core.config_types import ConfigField +from src.webui.core import get_token_manager + +logger = get_logger("webui.plugin_routes") + + +def require_plugin_token(maibot_session: Optional[str]) -> str: + token_manager = get_token_manager() + if not maibot_session or not token_manager.verify_token(maibot_session): + raise HTTPException(status_code=401, detail="未授权:无效的访问令牌") + return maibot_session + + +def validate_safe_path(user_path: str, base_path: Path) -> Path: + base_resolved = base_path.resolve() + if any(pattern in user_path for pattern in ["..", "\x00"]): + logger.warning(f"检测到可疑路径: {user_path}") + raise HTTPException(status_code=400, detail="路径包含非法字符") + + if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"): + logger.warning(f"检测到绝对路径: {user_path}") + raise HTTPException(status_code=400, detail="不允许使用绝对路径") + + target_path = (base_path / user_path).resolve() + try: + target_path.relative_to(base_resolved) + except ValueError as e: + logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}") + raise HTTPException(status_code=400, detail="路径超出允许范围") from e + + return target_path + + +def validate_plugin_id(plugin_id: str) -> str: + if not plugin_id or not plugin_id.strip(): + logger.warning("非法插件 ID: 空字符串") + raise HTTPException(status_code=400, detail="插件 ID 不能为空") + + for pattern in ["/", "\\", "\x00", "..", "\n", "\r", "\t"]: + if pattern in plugin_id: + logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)") + raise HTTPException(status_code=400, detail="插件 ID 包含非法字符") + + if plugin_id.startswith(".") or plugin_id.endswith("."): + logger.warning(f"非法插件 ID: {plugin_id}") + raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾") + + if plugin_id in {".", ".."}: + logger.warning(f"非法插件 ID: {plugin_id}") + raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名") + + return plugin_id + + +def parse_version(version_str: str) -> tuple[int, int, int]: + base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0] + parts = base_version.split(".") + if len(parts) < 3: + parts.extend(["0"] * (3 - len(parts))) + + try: + return int(parts[0]), int(parts[1]), int(parts[2]) + except (ValueError, IndexError): + logger.warning(f"无法解析版本号: {version_str},返回默认值 (0, 0, 0)") + return 0, 0, 0 + + +def deep_merge(dst: dict[str, Any], src: dict[str, Any]) -> None: + for key, value in src.items(): + if key in dst and isinstance(dst[key], dict) and isinstance(value, dict): + deep_merge(dst[key], value) + else: + dst[key] = value + + +def normalize_dotted_keys(obj: dict[str, Any]) -> dict[str, Any]: + result: dict[str, Any] = {} + dotted_items: list[tuple[str, Any]] = [] + + for key, value in obj.items(): + if "." in key: + dotted_items.append((key, value)) + else: + result[key] = normalize_dotted_keys(value) if isinstance(value, dict) else value + + for dotted_key, value in dotted_items: + normalized_value = normalize_dotted_keys(value) if isinstance(value, dict) else value + parts = dotted_key.split(".") + if "" in parts: + logger.warning(f"键路径包含空段: '{dotted_key}'") + parts = [part for part in parts if part] + if not parts: + logger.warning(f"忽略空键路径: '{dotted_key}'") + continue + + current = result + for index, part in enumerate(parts[:-1]): + if part in current and not isinstance(current[part], dict): + path_ctx = ".".join(parts[: index + 1]) + logger.warning(f"键冲突:{part} 已存在且非字典,覆盖为字典以展开 {dotted_key} (路径 {path_ctx})") + current[part] = {} + current = current.setdefault(part, {}) + + last_part = parts[-1] + if last_part in current and isinstance(current[last_part], dict) and isinstance(normalized_value, dict): + deep_merge(current[last_part], normalized_value) + else: + current[last_part] = normalized_value + + return result + + +def coerce_types(schema_part: dict[str, Any], config_part: dict[str, Any]) -> None: + def is_list_type(tp: Any) -> bool: + origin = get_origin(tp) + return tp is list or origin is list + + for key, schema_val in schema_part.items(): + if key not in config_part: + continue + value = config_part[key] + if isinstance(schema_val, ConfigField): + if is_list_type(schema_val.type) and isinstance(value, str): + config_part[key] = [item.strip() for item in value.split(",") if item.strip()] + elif isinstance(schema_val, dict) and isinstance(value, dict): + coerce_types(schema_val, value) + + +def find_plugin_instance(plugin_id: str) -> Optional[Any]: + from src.plugin_runtime.integration import get_plugin_runtime_manager + + manager = get_plugin_runtime_manager() + for supervisor in manager.supervisors: + registered = supervisor._registered_plugins.get(plugin_id) + if registered is not None: + return registered + return None + + +def get_plugins_dir() -> Path: + plugins_dir = Path("plugins").resolve() + plugins_dir.mkdir(exist_ok=True) + return plugins_dir + + +def get_plugin_candidate_paths(plugin_id: str) -> tuple[Path, Path]: + plugins_dir = get_plugins_dir() + folder_name = plugin_id.replace(".", "_") + return validate_safe_path(folder_name, plugins_dir), validate_safe_path(plugin_id, plugins_dir) + + +def resolve_installed_plugin_path(plugin_id: str) -> Optional[Path]: + new_format_path, old_format_path = get_plugin_candidate_paths(plugin_id) + if new_format_path.exists(): + return new_format_path + return old_format_path if old_format_path.exists() else None + + +def parse_repository_url(repository_url: str) -> tuple[str, str, str]: + repo_url = repository_url.rstrip("/").removesuffix(".git") + parts = repo_url.split("/") + if len(parts) < 2: + raise HTTPException(status_code=400, detail="无效的仓库 URL") + return repo_url, parts[-2], parts[-1] + + +def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]: + if not manifest_path.exists(): + return None + + try: + with open(manifest_path, "r", encoding="utf-8") as file_obj: + return cast(dict[str, Any], json.load(file_obj)) + except Exception: + return None + + +def iter_plugin_directories() -> list[Path]: + return [path for path in get_plugins_dir().iterdir() if path.is_dir()] + + +def find_plugin_path_by_id(plugin_id: str) -> Optional[Path]: + for plugin_path in iter_plugin_directories(): + manifest = load_manifest_json(plugin_path / "_manifest.json") + if manifest is not None and (manifest.get("id") == plugin_id or plugin_path.name == plugin_id): + return plugin_path + return None + + +def backup_file(file_path: Path, action: str, move_file: bool = False) -> Optional[Path]: + if not file_path.exists(): + return None + + backup_name = f"{file_path.name}.{action}.{datetime.now().strftime('%Y%m%d%H%M%S')}" + backup_path = file_path.parent / backup_name + if move_file: + shutil.move(file_path, backup_path) + else: + shutil.copy(file_path, backup_path) + return backup_path + + +def remove_tree(path: Path) -> None: + def remove_readonly(func: Any, target_path: str, _: Any) -> None: + os.chmod(target_path, stat.S_IWRITE) + func(target_path) + + shutil.rmtree(path, onerror=remove_readonly) \ No newline at end of file diff --git a/src/webui/routers/statistics.py b/src/webui/routers/statistics.py index fdf36f8c..b91e3244 100644 --- a/src/webui/routers/statistics.py +++ b/src/webui/routers/statistics.py @@ -1,29 +1,22 @@ """统计数据 API 路由""" from datetime import datetime, timedelta -from typing import Any, Optional +from typing import Any -from fastapi import APIRouter, Cookie, Depends, Header, HTTPException +from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel, Field from sqlalchemy import desc, func, or_ from sqlmodel import col, select from src.common.database.database import get_db_session -from src.common.database.database_model import Messages, ModelUsage, OnlineTime +from src.common.database.database_model import ModelUsage, OnlineTime from src.common.logger import get_logger -from src.webui.core import verify_auth_token_from_cookie_or_header +from src.common.message_repository import count_messages +from src.webui.dependencies import require_auth logger = get_logger("webui.statistics") -router = APIRouter(prefix="/statistics", tags=["statistics"]) - - -def require_auth( - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -) -> bool: - """认证依赖:验证用户是否已登录""" - return verify_auth_token_from_cookie_or_header(maibot_session, authorization) +router = APIRouter(prefix="/statistics", tags=["statistics"], dependencies=[Depends(require_auth)]) class StatisticsSummary(BaseModel): @@ -70,7 +63,7 @@ class DashboardData(BaseModel): @router.get("/dashboard", response_model=DashboardData) -async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth)): +async def get_dashboard_data(hours: int = 24): """ 获取仪表盘统计数据 @@ -159,24 +152,12 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S if end > start: summary.online_time += (end - start).total_seconds() - # 查询消息数量 - 使用聚合优化 - with get_db_session() as session: - statement = select(func.count()).where( - col(Messages.timestamp) >= start_time, - col(Messages.timestamp) <= end_time, - ) - total_messages = session.exec(statement).one() - summary.total_messages = int(total_messages or 0) - - # 统计回复数量 - with get_db_session() as session: - statement = select(func.count()).where( - col(Messages.timestamp) >= start_time, - col(Messages.timestamp) <= end_time, - col(Messages.reply_to).is_not(None), - ) - total_replies = session.exec(statement).one() - summary.total_replies = int(total_replies or 0) + summary.total_messages = count_messages(start_time=start_time.timestamp(), end_time=end_time.timestamp()) + summary.total_replies = count_messages( + start_time=start_time.timestamp(), + end_time=end_time.timestamp(), + has_reply_to=True, + ) # 计算派生指标 if summary.online_time > 0: @@ -351,7 +332,7 @@ async def _get_recent_activity(limit: int = 10) -> list[dict[str, Any]]: @router.get("/summary") -async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)): +async def get_summary(hours: int = 24): """ 获取统计摘要 @@ -369,7 +350,7 @@ async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)): @router.get("/models") -async def get_model_stats(hours: int = 24, _auth: bool = Depends(require_auth)): +async def get_model_stats(hours: int = 24): """ 获取模型统计 diff --git a/src/webui/routers/system.py b/src/webui/routers/system.py index ac6ab324..7bd27347 100644 --- a/src/webui/routers/system.py +++ b/src/webui/routers/system.py @@ -7,28 +7,20 @@ import os import time from datetime import datetime -from typing import Optional -from fastapi import APIRouter, HTTPException, Depends, Cookie, Header + +from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel from src.config.config import MMC_VERSION from src.common.logger import get_logger -from src.webui.core import verify_auth_token_from_cookie_or_header +from src.webui.dependencies import require_auth -router = APIRouter(prefix="/system", tags=["system"]) +router = APIRouter(prefix="/system", tags=["system"], dependencies=[Depends(require_auth)]) logger = get_logger("webui_system") # 记录启动时间 _start_time = time.time() -def require_auth( - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -) -> bool: - """认证依赖:验证用户是否已登录""" - return verify_auth_token_from_cookie_or_header(maibot_session, authorization) - - class RestartResponse(BaseModel): """重启响应""" @@ -46,7 +38,7 @@ class StatusResponse(BaseModel): @router.post("/restart", response_model=RestartResponse) -async def restart_maibot(_auth: bool = Depends(require_auth)): +async def restart_maibot(): """ 重启麦麦主程序 @@ -77,7 +69,7 @@ async def restart_maibot(_auth: bool = Depends(require_auth)): @router.get("/status", response_model=StatusResponse) -async def get_maibot_status(_auth: bool = Depends(require_auth)): +async def get_maibot_status(): """ 获取麦麦运行状态 @@ -100,7 +92,7 @@ async def get_maibot_status(_auth: bool = Depends(require_auth)): @router.post("/reload-config") -async def reload_config(_auth: bool = Depends(require_auth)): +async def reload_config(): """ 热重载配置(不重启进程) diff --git a/src/webui/routers/websocket/__init__.py b/src/webui/routers/websocket/__init__.py index 0acec62f..9dc576f2 100644 --- a/src/webui/routers/websocket/__init__.py +++ b/src/webui/routers/websocket/__init__.py @@ -1,9 +1,7 @@ from .logs import router as logs_router -from .plugin_progress import get_progress_router from .auth import router as ws_auth_router __all__ = [ "logs_router", - "get_progress_router", "ws_auth_router", ] diff --git a/src/webui/routers/websocket/auth.py b/src/webui/routers/websocket/auth.py index 74246759..a8e33e7f 100644 --- a/src/webui/routers/websocket/auth.py +++ b/src/webui/routers/websocket/auth.py @@ -1,11 +1,8 @@ -"""WebSocket 认证模块 +"""WebSocket 认证模块。""" -提供所有 WebSocket 端点统一使用的临时 token 认证机制。 -临时 token 有效期 60 秒,且只能使用一次,用于解决 WebSocket 握手时 Cookie 不可用的问题。 -""" - -from fastapi import APIRouter, Cookie, Header from typing import Optional + +from fastapi import APIRouter, Cookie import secrets import time from src.common.logger import get_logger @@ -77,25 +74,17 @@ def verify_ws_token(temp_token: str) -> bool: @router.get("/ws-token") async def get_ws_token( maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), ): """ 获取 WebSocket 连接用的临时 token - 此端点验证当前会话的 Cookie 或 Authorization header, + 此端点验证当前会话 Cookie, 然后返回一个临时 token 用于 WebSocket 握手认证。 临时 token 有效期 60 秒,且只能使用一次。 注意:在未认证时返回 200 状态码但 success=False,避免前端因 401 刷新页面。 """ - # 获取当前 session token - session_token = None - if maibot_session: - session_token = maibot_session - elif authorization and authorization.startswith("Bearer "): - session_token = authorization.replace("Bearer ", "") - - if not session_token: + if not maibot_session: # 返回 200 但 success=False,避免前端因 401 刷新页面 # 这在登录页面是正常情况,不应该触发错误处理 logger.debug("ws-token 请求:未提供认证信息(可能在登录页面)") @@ -103,12 +92,12 @@ async def get_ws_token( # 验证 session token token_manager = get_token_manager() - if not token_manager.verify_token(session_token): + if not token_manager.verify_token(maibot_session): # 同样返回 200 但 success=False,避免前端刷新 logger.debug("ws-token 请求:认证已过期") return {"success": False, "message": "认证已过期,请重新登录", "token": None, "expires_in": 0} # 生成临时 WebSocket token - ws_token = generate_ws_token(session_token) + ws_token = generate_ws_token(maibot_session) return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS} diff --git a/src/webui/routers/websocket/logs.py b/src/webui/routers/websocket/logs.py index 1d43f306..d8341263 100644 --- a/src/webui/routers/websocket/logs.py +++ b/src/webui/routers/websocket/logs.py @@ -1,177 +1,11 @@ -"""WebSocket 日志推送模块""" +"""WebSocket 日志推送路由兼容导出。""" -from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query -from typing import Set, Optional -import json -from pathlib import Path -from src.common.logger import get_logger -from src.webui.core import get_token_manager -from src.webui.routers.websocket.auth import verify_ws_token +from src.webui.logs_ws import active_connections, broadcast_log, load_recent_logs, router, websocket_logs -logger = get_logger("webui.logs_ws") -router = APIRouter() - -# 全局 WebSocket 连接池 -active_connections: Set[WebSocket] = set() - - -def load_recent_logs(limit: int = 100) -> list[dict]: - """从日志文件中加载最近的日志 - - Args: - limit: 返回的最大日志条数 - - Returns: - 日志列表 - """ - logs = [] - log_dir = Path("logs") - - if not log_dir.exists(): - return logs - - # 获取所有日志文件,按修改时间排序 - log_files = sorted(log_dir.glob("app_*.log.jsonl"), key=lambda f: f.stat().st_mtime, reverse=True) - - # 用于生成唯一 ID 的计数器 - log_counter = 0 - - # 从最新的文件开始读取 - for log_file in log_files: - if len(logs) >= limit: - break - - try: - with open(log_file, "r", encoding="utf-8") as f: - lines = f.readlines() - # 从文件末尾开始读取 - for line in reversed(lines): - if len(logs) >= limit: - break - try: - log_entry = json.loads(line.strip()) - # 转换为前端期望的格式 - # 使用时间戳 + 计数器生成唯一 ID - timestamp_id = ( - log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "") - ) - formatted_log = { - "id": f"{timestamp_id}_{log_counter}", - "timestamp": log_entry.get("timestamp", ""), - "level": log_entry.get("level", "INFO").upper(), - "module": log_entry.get("logger_name", ""), - "message": log_entry.get("event", ""), - } - logs.append(formatted_log) - log_counter += 1 - except (json.JSONDecodeError, KeyError): - continue - except Exception as e: - logger.error(f"读取日志文件失败 {log_file}: {e}") - continue - - # 反转列表,使其按时间顺序排列(旧到新) - return list(reversed(logs)) - - -@router.websocket("/ws/logs") -async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None)): - """WebSocket 日志推送端点 - - 客户端连接后会持续接收服务器端的日志消息 - 支持三种认证方式(按优先级): - 1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token) - 2. Cookie 中的 maibot_session - 3. 直接使用 session token(兼容) - - 示例:ws://host/ws/logs?token=xxx - """ - is_authenticated = False - - # 方式 1: 尝试验证临时 WebSocket token(推荐方式) - if token and verify_ws_token(token): - is_authenticated = True - logger.debug("WebSocket 使用临时 token 认证成功") - - # 方式 2: 尝试从 Cookie 获取 session token - if not is_authenticated: - cookie_token = websocket.cookies.get("maibot_session") - if cookie_token: - token_manager = get_token_manager() - if token_manager.verify_token(cookie_token): - is_authenticated = True - logger.debug("WebSocket 使用 Cookie 认证成功") - - # 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式) - if not is_authenticated and token: - token_manager = get_token_manager() - if token_manager.verify_token(token): - is_authenticated = True - logger.debug("WebSocket 使用 session token 认证成功") - - if not is_authenticated: - logger.warning("WebSocket 连接被拒绝:认证失败") - await websocket.close(code=4001, reason="认证失败,请重新登录") - return - - await websocket.accept() - active_connections.add(websocket) - logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}") - - # 连接建立后,立即发送历史日志 - try: - recent_logs = load_recent_logs(limit=100) - logger.info(f"发送 {len(recent_logs)} 条历史日志到客户端") - - for log_entry in recent_logs: - await websocket.send_text(json.dumps(log_entry, ensure_ascii=False)) - except Exception as e: - logger.error(f"发送历史日志失败: {e}") - - try: - # 保持连接,等待客户端消息或断开 - while True: - # 接收客户端消息(用于心跳或控制指令) - data = await websocket.receive_text() - - # 可以处理客户端的控制消息,例如: - # - "ping" -> 心跳检测 - # - {"filter": "ERROR"} -> 设置日志级别过滤 - if data == "ping": - await websocket.send_text("pong") - - except WebSocketDisconnect: - active_connections.discard(websocket) - logger.info(f"📡 WebSocket 客户端已断开,当前连接数: {len(active_connections)}") - except Exception as e: - logger.error(f"❌ WebSocket 错误: {e}") - active_connections.discard(websocket) - - -async def broadcast_log(log_data: dict): - """广播日志到所有连接的 WebSocket 客户端 - - Args: - log_data: 日志数据字典 - """ - if not active_connections: - return - - # 格式化为 JSON - message = json.dumps(log_data, ensure_ascii=False) - - # 记录需要断开的连接 - disconnected = set() - - # 广播到所有客户端 - for connection in active_connections: - try: - await connection.send_text(message) - except Exception: - # 发送失败,标记为断开 - disconnected.add(connection) - - # 清理断开的连接 - if disconnected: - active_connections.difference_update(disconnected) - logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接") +__all__ = [ + "active_connections", + "broadcast_log", + "load_recent_logs", + "router", + "websocket_logs", +] diff --git a/src/webui/routes.py b/src/webui/routes.py index da45cb06..c3e4c1fc 100644 --- a/src/webui/routes.py +++ b/src/webui/routes.py @@ -1,28 +1,26 @@ """WebUI API 路由""" -from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie, Depends +from fastapi import APIRouter, Depends, HTTPException, Request, Response from pydantic import BaseModel, Field -from typing import Optional from src.common.logger import get_logger from src.webui.core import ( + clear_auth_cookie, + check_auth_rate_limit, + get_rate_limiter, get_token_manager, set_auth_cookie, - clear_auth_cookie, - get_rate_limiter, - check_auth_rate_limit, ) +from src.webui.dependencies import require_auth, verify_token_optional from src.webui.routers.config import router as config_router from src.webui.routers.statistics import router as statistics_router from src.webui.routers.person import router as person_router from src.webui.routers.expression import router as expression_router from src.webui.routers.jargon import router as jargon_router from src.webui.routers.emoji import router as emoji_router -from src.webui.routers.plugin import router as plugin_router -from src.webui.routers.websocket.plugin_progress import get_progress_router +from src.webui.routers.plugin import get_progress_router, router as plugin_router from src.webui.routers.system import router as system_router from src.webui.routers.model import router as model_router from src.webui.routers.websocket.auth import router as ws_auth_router -from src.webui.routers.annual_report import router as annual_report_router logger = get_logger("webui.api") @@ -51,8 +49,6 @@ router.include_router(system_router) router.include_router(model_router) # 注册 WebSocket 认证路由 router.include_router(ws_auth_router) -# 注册年度报告路由 -router.include_router(annual_report_router) class TokenVerifyRequest(BaseModel): @@ -190,9 +186,7 @@ async def logout(response: Response): @router.get("/auth/check") async def check_auth_status( - request: Request, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), + authenticated: bool = Depends(verify_token_optional), ): """ 检查当前认证状态(用于前端判断是否已登录) @@ -201,46 +195,17 @@ async def check_auth_status( 认证状态 """ try: - token = None - - # 记录请求信息用于调试 - logger.debug( - f"检查认证状态 - Cookie: {maibot_session[:20] if maibot_session else 'None'}..., Authorization: {'Present' if authorization else 'None'}" - ) - - # 优先从 Cookie 获取 - if maibot_session: - token = maibot_session - logger.debug("使用 Cookie 中的 token") - # 其次从 Header 获取 - elif authorization and authorization.startswith("Bearer "): - token = authorization.replace("Bearer ", "") - logger.debug("使用 Header 中的 token") - - if not token: - logger.debug("未找到 token,返回未认证") - return {"authenticated": False} - - token_manager = get_token_manager() - is_valid = token_manager.verify_token(token) - logger.debug(f"Token 验证结果: {is_valid}") - - if is_valid: - return {"authenticated": True} - else: - return {"authenticated": False} + logger.debug(f"检查认证状态,结果: {authenticated}") + return {"authenticated": authenticated} except Exception as e: logger.error(f"认证检查失败: {e}", exc_info=True) return {"authenticated": False} -@router.post("/auth/update", response_model=TokenUpdateResponse) +@router.post("/auth/update", response_model=TokenUpdateResponse, dependencies=[Depends(require_auth)]) async def update_token( request: TokenUpdateRequest, response: Response, - req: Request, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), ): """ 更新访问令牌(需要当前有效的 token) @@ -248,28 +213,13 @@ async def update_token( Args: request: 包含新 token 的更新请求 response: FastAPI Response 对象 - maibot_session: Cookie 中的 token - authorization: Authorization header (Bearer token) Returns: 更新结果 """ try: - # 验证当前 token(优先 Cookie,其次 Header) - current_token = None - if maibot_session: - current_token = maibot_session - elif authorization and authorization.startswith("Bearer "): - current_token = authorization.replace("Bearer ", "") - - if not current_token: - raise HTTPException(status_code=401, detail="未提供有效的认证信息") - token_manager = get_token_manager() - if not token_manager.verify_token(current_token): - raise HTTPException(status_code=401, detail="当前 Token 无效") - # 更新 token success, message = token_manager.update_token(request.new_token) @@ -285,40 +235,22 @@ async def update_token( raise HTTPException(status_code=500, detail="Token 更新失败") from e -@router.post("/auth/regenerate", response_model=TokenRegenerateResponse) +@router.post("/auth/regenerate", response_model=TokenRegenerateResponse, dependencies=[Depends(require_auth)]) async def regenerate_token( response: Response, - request: Request, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), ): """ 重新生成访问令牌(需要当前有效的 token) Args: response: FastAPI Response 对象 - maibot_session: Cookie 中的 token - authorization: Authorization header (Bearer token) Returns: 新生成的 token """ try: - # 验证当前 token(优先 Cookie,其次 Header) - current_token = None - if maibot_session: - current_token = maibot_session - elif authorization and authorization.startswith("Bearer "): - current_token = authorization.replace("Bearer ", "") - - if not current_token: - raise HTTPException(status_code=401, detail="未提供有效的认证信息") - token_manager = get_token_manager() - if not token_manager.verify_token(current_token): - raise HTTPException(status_code=401, detail="当前 Token 无效") - # 重新生成 token new_token = token_manager.regenerate_token() @@ -333,38 +265,17 @@ async def regenerate_token( raise HTTPException(status_code=500, detail="Token 重新生成失败") from e -@router.get("/setup/status", response_model=FirstSetupStatusResponse) -async def get_setup_status( - request: Request, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -): +@router.get("/setup/status", response_model=FirstSetupStatusResponse, dependencies=[Depends(require_auth)]) +async def get_setup_status(): """ 获取首次配置状态 - Args: - maibot_session: Cookie 中的 token - authorization: Authorization header (Bearer token) - Returns: 首次配置状态 """ try: - # 验证 token(优先 Cookie,其次 Header) - current_token = None - if maibot_session: - current_token = maibot_session - elif authorization and authorization.startswith("Bearer "): - current_token = authorization.replace("Bearer ", "") - - if not current_token: - raise HTTPException(status_code=401, detail="未提供有效的认证信息") - token_manager = get_token_manager() - if not token_manager.verify_token(current_token): - raise HTTPException(status_code=401, detail="Token 无效") - # 检查是否为首次配置 is_first = token_manager.is_first_setup() @@ -376,38 +287,17 @@ async def get_setup_status( raise HTTPException(status_code=500, detail="获取配置状态失败") from e -@router.post("/setup/complete", response_model=CompleteSetupResponse) -async def complete_setup( - request: Request, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -): +@router.post("/setup/complete", response_model=CompleteSetupResponse, dependencies=[Depends(require_auth)]) +async def complete_setup(): """ 标记首次配置完成 - Args: - maibot_session: Cookie 中的 token - authorization: Authorization header (Bearer token) - Returns: 完成结果 """ try: - # 验证 token(优先 Cookie,其次 Header) - current_token = None - if maibot_session: - current_token = maibot_session - elif authorization and authorization.startswith("Bearer "): - current_token = authorization.replace("Bearer ", "") - - if not current_token: - raise HTTPException(status_code=401, detail="未提供有效的认证信息") - token_manager = get_token_manager() - if not token_manager.verify_token(current_token): - raise HTTPException(status_code=401, detail="Token 无效") - # 标记配置完成 success = token_manager.mark_setup_completed() @@ -419,38 +309,17 @@ async def complete_setup( raise HTTPException(status_code=500, detail="标记配置完成失败") from e -@router.post("/setup/reset", response_model=ResetSetupResponse) -async def reset_setup( - request: Request, - maibot_session: Optional[str] = Cookie(None), - authorization: Optional[str] = Header(None), -): +@router.post("/setup/reset", response_model=ResetSetupResponse, dependencies=[Depends(require_auth)]) +async def reset_setup(): """ 重置首次配置状态,允许重新进入配置向导 - Args: - maibot_session: Cookie 中的 token - authorization: Authorization header (Bearer token) - Returns: 重置结果 """ try: - # 验证 token(优先 Cookie,其次 Header) - current_token = None - if maibot_session: - current_token = maibot_session - elif authorization and authorization.startswith("Bearer "): - current_token = authorization.replace("Bearer ", "") - - if not current_token: - raise HTTPException(status_code=401, detail="未提供有效的认证信息") - token_manager = get_token_manager() - if not token_manager.verify_token(current_token): - raise HTTPException(status_code=401, detail="Token 无效") - # 重置配置状态 success = token_manager.reset_setup_status() diff --git a/src/webui/webui_server.py b/src/webui/webui_server.py index 95eb6546..fbc30cbf 100644 --- a/src/webui/webui_server.py +++ b/src/webui/webui_server.py @@ -49,6 +49,7 @@ class WebUIServer: service_name="WebUI 服务器", logger=logger, config_hint="WEBUI_PORT (.env)", + allow_reuse_addr=True, ) config = Config(