feat:同步本地非算法改动到上游基线

保留反馈纠错、WebUI 与运行时增强。移除不应提交的 algorithm_redesign 设计目录及其专项测试。
This commit is contained in:
A-Dawn
2026-04-16 13:57:07 +08:00
parent 6c22fdfdf9
commit 21b642d07d
10 changed files with 2244 additions and 34 deletions

View File

@@ -496,6 +496,77 @@ export interface MemoryDeleteOperationDetailPayload {
error?: string
}
/**
 * Aggregate counts of memory entities touched by one feedback-correction task.
 * All fields are optional; callers treat a missing field as 0.
 */
export interface MemoryFeedbackAffectedCountsPayload {
// Relation triples forgotten or rewritten by the correction.
relations?: number
// Paragraphs marked stale by the correction.
stale_paragraphs?: number
// Episode sources affected.
episode_sources?: number
// Person profiles whose ids were touched.
profile_person_ids?: number
// Newly written correction paragraphs.
correction_paragraphs?: number
// Relations rewritten with corrected content.
corrected_relations?: number
}
/**
 * One audit-log entry describing a single action (e.g. forget_relation)
 * performed while applying a feedback correction.
 */
export interface MemoryFeedbackActionLogPayload {
id: number
task_id: number
query_tool_id: string
// Machine-readable action code, e.g. 'forget_relation'.
action_type: string
// Hash of the entity the action targeted.
target_hash: string
reason?: string
// Entity snapshot before the action was applied.
before_payload?: Record<string, unknown>
// Entity snapshot (or delta) after the action was applied.
after_payload?: Record<string, unknown>
// Unix timestamp (seconds) — presumably; TODO confirm against backend.
created_at?: number
}
/**
 * List-view summary of a feedback-correction task as returned by
 * GET /feedback-corrections.
 */
export interface MemoryFeedbackCorrectionSummaryPayload {
task_id: number
query_tool_id: string
session_id: string
// Original user query text that triggered the correction flow.
query_text: string
query_timestamp?: number
// Lifecycle status, e.g. 'pending' | 'running' | 'applied' | 'skipped' | 'error'.
task_status: string
// Decision code, e.g. 'correct' | 'reject' | 'confirm' | 'supplement' | 'none'.
decision: string
decision_confidence: number
feedback_message_count: number
// Rollback lifecycle, e.g. 'none' | 'running' | 'rolled_back' | 'error'.
rollback_status: string
affected_counts: MemoryFeedbackAffectedCountsPayload
created_at?: number
updated_at?: number
}
/**
 * Full detail of a feedback-correction task (summary fields plus snapshots,
 * rollback bookkeeping and the per-action audit trail).
 */
export interface MemoryFeedbackCorrectionDetailTaskPayload extends MemoryFeedbackCorrectionSummaryPayload {
// Snapshot of the query and its hits at correction time.
query_snapshot?: Record<string, unknown>
// Raw decision output, including confidence.
decision_payload?: Record<string, unknown>
// Summary of what a rollback would restore/remove.
rollback_plan_summary?: Record<string, unknown>
// Result written after a rollback was executed.
rollback_result?: Record<string, unknown>
rollback_error?: string
rollback_requested_by?: string
rollback_reason?: string
rollback_requested_at?: number
rolled_back_at?: number
action_logs?: MemoryFeedbackActionLogPayload[]
}
/** Envelope for the correction list endpoint; `error` is set when `success` is false. */
export interface MemoryFeedbackCorrectionListPayload {
success: boolean
items: MemoryFeedbackCorrectionSummaryPayload[]
count?: number
error?: string
}
/** Envelope for the single-task detail endpoint; `task` may be null when not found. */
export interface MemoryFeedbackCorrectionDetailPayload {
success: boolean
task?: MemoryFeedbackCorrectionDetailTaskPayload | null
error?: string
}
/**
 * Envelope for the rollback endpoint. `already_rolled_back` flags an idempotent
 * no-op; `task` carries the refreshed detail after the rollback.
 */
export interface MemoryFeedbackCorrectionRollbackPayload {
success: boolean
already_rolled_back?: boolean
result?: Record<string, unknown>
task?: MemoryFeedbackCorrectionDetailTaskPayload | null
error?: string
}
export interface MemorySourceItemPayload {
source: string
paragraph_count?: number
@@ -610,6 +681,49 @@ export async function getMemoryDeleteOperation(
return requestJson<MemoryDeleteOperationDetailPayload>(`/delete/operations/${encodeURIComponent(operationId)}`)
}
/**
 * Fetch the feedback-correction task list.
 *
 * @param options optional filters; `limit` defaults to 50. Blank/whitespace-only
 *   filter values are treated as absent and omitted from the query string.
 * @returns the list envelope from GET /feedback-corrections.
 */
export async function getMemoryFeedbackCorrections(
  options?: {
    limit?: number
    status?: string
    rollbackStatus?: string
    query?: string
  },
): Promise<MemoryFeedbackCorrectionListPayload> {
  const params = new URLSearchParams({ limit: String(options?.limit ?? 50) })
  // Map each optional filter to its wire-format query key; only non-blank
  // values are sent, mirroring the backend's "absent means no filter" contract.
  const optionalFilters: Array<[string, string | undefined]> = [
    ['status', options?.status],
    ['rollback_status', options?.rollbackStatus],
    ['query', options?.query],
  ]
  for (const [key, raw] of optionalFilters) {
    const trimmed = raw?.trim()
    if (trimmed) {
      params.set(key, trimmed)
    }
  }
  return requestJson<MemoryFeedbackCorrectionListPayload>(`/feedback-corrections?${params.toString()}`)
}
/**
 * Fetch the full detail of one feedback-correction task.
 *
 * @param taskId numeric id of the correction task.
 */
export async function getMemoryFeedbackCorrection(
  taskId: number,
): Promise<MemoryFeedbackCorrectionDetailPayload> {
  const endpoint = `/feedback-corrections/${taskId}`
  return requestJson<MemoryFeedbackCorrectionDetailPayload>(endpoint)
}
/**
 * Request a rollback of an applied feedback correction.
 *
 * @param taskId numeric id of the correction task to roll back.
 * @param payload optional audit fields (who requested it and why), sent as JSON.
 */
export async function rollbackMemoryFeedbackCorrection(
  taskId: number,
  payload: {
    requested_by?: string
    reason?: string
  },
): Promise<MemoryFeedbackCorrectionRollbackPayload> {
  const endpoint = `/feedback-corrections/${taskId}/rollback`
  const requestInit = {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify(payload),
  }
  return requestJson<MemoryFeedbackCorrectionRollbackPayload>(endpoint, requestInit)
}
/** Fetch the list of memory sources from GET /sources. */
export async function getMemorySources(): Promise<MemorySourceListPayload> {
  const endpoint = '/sources'
  return requestJson<MemorySourceListPayload>(endpoint)
}

View File

@@ -81,9 +81,12 @@ vi.mock('@/lib/memory-api', () => ({
getMemorySources: vi.fn(),
getMemoryDeleteOperations: vi.fn(),
getMemoryDeleteOperation: vi.fn(),
getMemoryFeedbackCorrections: vi.fn(),
getMemoryFeedbackCorrection: vi.fn(),
previewMemoryDelete: vi.fn(),
executeMemoryDelete: vi.fn(),
restoreMemoryDelete: vi.fn(),
rollbackMemoryFeedbackCorrection: vi.fn(),
}))
function mockImportTask(taskId: string, status: string = 'running'): memoryApi.MemoryImportTaskPayload {
@@ -357,6 +360,82 @@ describe('KnowledgeBasePage import workflow', () => {
items: [],
},
})
vi.mocked(memoryApi.getMemoryFeedbackCorrections).mockResolvedValue({
success: true,
items: [
{
task_id: 11,
query_tool_id: 'tool-query-11',
session_id: 'session-1',
query_text: '测试用户最喜欢的颜色是什么',
query_timestamp: 1_710_000_010,
task_status: 'applied',
decision: 'correct',
decision_confidence: 0.97,
feedback_message_count: 1,
rollback_status: 'none',
affected_counts: {
relations: 1,
stale_paragraphs: 1,
episode_sources: 2,
profile_person_ids: 1,
correction_paragraphs: 1,
corrected_relations: 1,
},
created_at: 1_710_000_011,
updated_at: 1_710_000_012,
},
],
count: 1,
})
vi.mocked(memoryApi.getMemoryFeedbackCorrection).mockResolvedValue({
success: true,
task: {
task_id: 11,
query_tool_id: 'tool-query-11',
session_id: 'session-1',
query_text: '测试用户最喜欢的颜色是什么',
query_timestamp: 1_710_000_010,
task_status: 'applied',
decision: 'correct',
decision_confidence: 0.97,
feedback_message_count: 1,
rollback_status: 'none',
affected_counts: {
relations: 1,
stale_paragraphs: 1,
episode_sources: 2,
profile_person_ids: 1,
correction_paragraphs: 1,
corrected_relations: 1,
},
query_snapshot: { query: '测试用户最喜欢的颜色是什么', hits: [{ hash: 'paragraph-1' }] },
decision_payload: { decision: 'correct', confidence: 0.97 },
rollback_plan_summary: {
forgotten_relations: [{ hash: 'rel-old', subject: '测试用户', predicate: '最喜欢的颜色是', object: '蓝色' }],
corrected_write: {
paragraph_hashes: ['paragraph-new'],
corrected_relations: [{ hash: 'rel-new', subject: '测试用户', predicate: '最喜欢的颜色是', object: '绿色' }],
},
},
rollback_result: {},
action_logs: [
{
id: 1,
task_id: 11,
query_tool_id: 'tool-query-11',
action_type: 'forget_relation',
target_hash: 'rel-old',
reason: '用户明确纠正为绿色',
before_payload: { hash: 'rel-old', subject: '测试用户', predicate: '最喜欢的颜色是', object: '蓝色' },
after_payload: { is_inactive: true },
created_at: 1_710_000_013,
},
],
created_at: 1_710_000_011,
updated_at: 1_710_000_012,
},
})
vi.mocked(memoryApi.previewMemoryDelete).mockResolvedValue({
success: true,
mode: 'source',
@@ -380,6 +459,37 @@ describe('KnowledgeBasePage import workflow', () => {
deleted_source_count: 1,
} as never)
vi.mocked(memoryApi.restoreMemoryDelete).mockResolvedValue({ success: true } as never)
vi.mocked(memoryApi.rollbackMemoryFeedbackCorrection).mockResolvedValue({
success: true,
result: { restored_relation_hashes: ['rel-old'] },
task: {
task_id: 11,
query_tool_id: 'tool-query-11',
session_id: 'session-1',
query_text: '测试用户最喜欢的颜色是什么',
query_timestamp: 1_710_000_010,
task_status: 'applied',
decision: 'correct',
decision_confidence: 0.97,
feedback_message_count: 1,
rollback_status: 'rolled_back',
affected_counts: {
relations: 1,
stale_paragraphs: 1,
episode_sources: 2,
profile_person_ids: 1,
correction_paragraphs: 1,
corrected_relations: 1,
},
query_snapshot: { query: '测试用户最喜欢的颜色是什么', hits: [{ hash: 'paragraph-1' }] },
decision_payload: { decision: 'correct', confidence: 0.97 },
rollback_plan_summary: {},
rollback_result: { restored_relation_hashes: ['rel-old'] },
action_logs: [],
created_at: 1_710_000_011,
updated_at: 1_710_000_012,
},
})
vi.mocked(memoryApi.refreshMemoryRuntimeSelfCheck).mockResolvedValue({
success: true,
report: { ok: true },
@@ -619,4 +729,27 @@ describe('KnowledgeBasePage import workflow', () => {
}),
)
}, 20_000)
// End-to-end UI flow: open the 纠错历史 (correction-history) tab, verify the
// mocked correction list and detail render, then drive the rollback dialog
// and assert the rollback API is called with the expected arguments.
it('shows feedback correction history and supports rollback', async () => {
const user = userEvent.setup()
render(<KnowledgeBasePage />)
// Wait for the console shell to mount before interacting.
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
await user.click(screen.getByRole('tab', { name: '纠错历史' }))
await screen.findByText('反馈纠错历史')
// The mocked summary's query text should appear in the list.
await screen.findByText('测试用户最喜欢的颜色是什么')
// Selecting the list entry triggers a detail fetch for task 11.
await waitFor(() => expect(memoryApi.getMemoryFeedbackCorrection).toHaveBeenCalledWith(11))
await user.click(screen.getByRole('button', { name: '回退本次纠错' }))
const rollbackReason = await screen.findByLabelText('回退原因')
await user.type(rollbackReason, '人工确认回退')
await user.click(screen.getByRole('button', { name: '确认回退' }))
// The page identifies itself as 'knowledge_base' and forwards the typed reason.
await waitFor(() =>
expect(memoryApi.rollbackMemoryFeedbackCorrection).toHaveBeenCalledWith(11, {
requested_by: 'knowledge_base',
reason: '人工确认回退',
}),
)
}, 20_000)
})

View File

@@ -24,6 +24,14 @@ import { Badge } from '@/components/ui/badge'
import { Button } from '@/components/ui/button'
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'
import { Checkbox } from '@/components/ui/checkbox'
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from '@/components/ui/dialog'
import { Input } from '@/components/ui/input'
import { Label } from '@/components/ui/label'
import { Progress } from '@/components/ui/progress'
@@ -48,6 +56,8 @@ import {
createMemoryRawScanImport,
createMemoryTemporalBackfillImport,
executeMemoryDelete,
getMemoryFeedbackCorrection,
getMemoryFeedbackCorrections,
getMemoryImportPathAliases,
getMemoryImportSettings,
getMemoryImportTask,
@@ -74,6 +84,7 @@ import {
type MemoryImportTaskPayload,
previewMemoryDelete,
refreshMemoryRuntimeSelfCheck,
rollbackMemoryFeedbackCorrection,
resolveMemoryImportPath,
retryMemoryImportTask,
restoreMemoryDelete,
@@ -82,6 +93,9 @@ import {
type MemoryConfigSchemaPayload,
type MemoryDeleteExecutePayload,
type MemoryDeleteOperationPayload,
type MemoryFeedbackActionLogPayload,
type MemoryFeedbackCorrectionDetailTaskPayload,
type MemoryFeedbackCorrectionSummaryPayload,
type MemorySourceItemPayload,
type MemoryRuntimeConfigPayload,
type MemoryTaskPayload,
@@ -90,6 +104,9 @@ import {
const DELETE_OPERATION_FETCH_LIMIT = 100
const DELETE_OPERATION_PAGE_SIZE = 6
const DELETE_OPERATION_ITEM_PAGE_SIZE = 8
const FEEDBACK_CORRECTION_FETCH_LIMIT = 100
const FEEDBACK_CORRECTION_PAGE_SIZE = 6
const FEEDBACK_ACTION_LOG_PAGE_SIZE = 8
const IMPORT_CHUNK_PAGE_SIZE = 50
const RUNNING_IMPORT_STATUS = new Set(['preparing', 'running', 'cancel_requested'])
@@ -270,6 +287,90 @@ function formatDeleteOperationTime(timestamp?: number | null): string {
})
}
/**
 * Map a feedback decision code to its Chinese display label.
 * Unknown non-empty codes pass through unchanged; empty input renders '未知'.
 */
function formatFeedbackDecision(decision: string): string {
  // Map avoids prototype-key pitfalls of plain-object lookup.
  const labels = new Map<string, string>([
    ['correct', '纠正'],
    ['reject', '否定'],
    ['confirm', '确认'],
    ['supplement', '补充'],
    ['none', '无动作'],
  ])
  return labels.get(decision) ?? (decision || '未知')
}
/**
 * Map a correction task lifecycle status to its Chinese display label.
 * Unknown non-empty statuses pass through unchanged; empty input renders '未知'.
 */
function formatFeedbackTaskStatus(status: string): string {
  const labels = new Map<string, string>([
    ['pending', '待处理'],
    ['running', '处理中'],
    ['applied', '已应用'],
    ['skipped', '已跳过'],
    ['error', '失败'],
  ])
  return labels.get(status) ?? (status || '未知')
}
/**
 * Map a rollback lifecycle status to its Chinese display label.
 * Unknown non-empty statuses pass through unchanged; empty input renders '未知'.
 */
function formatFeedbackRollbackStatus(status: string): string {
  const labels = new Map<string, string>([
    ['none', '未回退'],
    ['running', '回退中'],
    ['rolled_back', '已回退'],
    ['error', '回退失败'],
  ])
  return labels.get(status) ?? (status || '未知')
}
/**
 * Pick a Badge visual variant for a task or rollback status string.
 * Success-like states render 'default', failures 'destructive',
 * in-flight states 'outline', everything else 'secondary'.
 */
function getFeedbackStatusVariant(
  status: string,
): 'default' | 'secondary' | 'destructive' | 'outline' {
  switch (status) {
    case 'applied':
    case 'rolled_back':
      return 'default'
    case 'error':
      return 'destructive'
    case 'running':
    case 'pending':
      return 'outline'
    default:
      return 'secondary'
  }
}
/**
 * Produce a one-line human-readable summary of an action-log payload.
 * Preference order: full subject/predicate/object triple → bare hash →
 * target-hash count → trimmed JSON dump of the whole payload.
 */
function summarizeFeedbackActionPayload(value: Record<string, unknown> | undefined): string {
  if (!value) {
    return ''
  }
  // Coerce-and-trim a payload field; missing fields become ''.
  const field = (key: string): string => String(value[key] ?? '').trim()
  const subject = field('subject')
  const predicate = field('predicate')
  const object = field('object')
  if (subject && predicate && object) {
    return formatDeleteRelationText(subject, predicate, object)
  }
  const hash = field('hash')
  if (hash) {
    return hash
  }
  const targets = value.target_hashes
  if (Array.isArray(targets) && targets.length > 0) {
    return `targets ${targets.length}`
  }
  // Fallback: pretty-printed JSON, clipped to keep list rows compact.
  return trimDeleteItemText(JSON.stringify(value, null, 2), 120)
}
type DeleteOperationItem = NonNullable<MemoryDeleteOperationPayload['items']>[number]
function trimDeleteItemText(value: string, maxLength: number = 140): string {
@@ -471,6 +572,20 @@ export function KnowledgeBasePage() {
const [deleteRestoring, setDeleteRestoring] = useState(false)
const [deleteResult, setDeleteResult] = useState<MemoryDeleteExecutePayload | null>(null)
const [pendingDeleteRequest, setPendingDeleteRequest] = useState<MemoryDeleteRequestPayload | null>(null)
const [feedbackCorrections, setFeedbackCorrections] = useState<MemoryFeedbackCorrectionSummaryPayload[]>([])
const [feedbackSearch, setFeedbackSearch] = useState('')
const [feedbackStatusFilter, setFeedbackStatusFilter] = useState('all')
const [feedbackRollbackFilter, setFeedbackRollbackFilter] = useState('all')
const [feedbackPage, setFeedbackPage] = useState(1)
const [selectedFeedbackTaskId, setSelectedFeedbackTaskId] = useState(0)
const [selectedFeedbackTaskDetail, setSelectedFeedbackTaskDetail] = useState<MemoryFeedbackCorrectionDetailTaskPayload | null>(null)
const [selectedFeedbackTaskLoading, setSelectedFeedbackTaskLoading] = useState(false)
const [selectedFeedbackTaskError, setSelectedFeedbackTaskError] = useState('')
const [feedbackActionLogSearch, setFeedbackActionLogSearch] = useState('')
const [feedbackActionLogPage, setFeedbackActionLogPage] = useState(1)
const [feedbackRollbackDialogOpen, setFeedbackRollbackDialogOpen] = useState(false)
const [feedbackRollbackReason, setFeedbackRollbackReason] = useState('')
const [feedbackRollingBack, setFeedbackRollingBack] = useState(false)
const [tuningObjective, setTuningObjective] = useState('precision_priority')
const [tuningIntensity, setTuningIntensity] = useState('standard')
const [tuningSampleSize, setTuningSampleSize] = useState('24')
@@ -491,6 +606,7 @@ export function KnowledgeBasePage() {
tuningTaskPayload,
sourcePayload,
deleteOperationPayload,
feedbackCorrectionPayload,
] = await Promise.all([
getMemoryConfigSchema(),
getMemoryConfig(),
@@ -503,6 +619,7 @@ export function KnowledgeBasePage() {
getMemoryTuningTasks(20),
getMemorySources(),
getMemoryDeleteOperations(DELETE_OPERATION_FETCH_LIMIT),
getMemoryFeedbackCorrections({ limit: FEEDBACK_CORRECTION_FETCH_LIMIT }),
])
setSchemaPayload(schema)
@@ -519,6 +636,7 @@ export function KnowledgeBasePage() {
setTuningTasks(tuningTaskPayload.items ?? [])
setMemorySources(sourcePayload.items ?? [])
setDeleteOperations(deleteOperationPayload.items ?? [])
setFeedbackCorrections(feedbackCorrectionPayload.items ?? [])
if (!selectedImportTaskId && (importTaskPayload.items ?? []).length > 0) {
const initialTaskId = String(importTaskPayload.items?.[0]?.task_id ?? '')
if (initialTaskId) {
@@ -1494,6 +1612,212 @@ export function KnowledgeBasePage() {
setDeleteDialogOpen(true)
}, [])
// Apply the status/rollback filters and the free-text keyword to the
// correction list. Keyword matching is case-insensitive across ids,
// query text and status fields.
const filteredFeedbackCorrections = useMemo(() => {
const keyword = feedbackSearch.trim().toLowerCase()
return feedbackCorrections.filter((item) => {
const taskStatus = String(item.task_status ?? '').trim().toLowerCase()
const rollbackStatus = String(item.rollback_status ?? '').trim().toLowerCase()
if (feedbackStatusFilter !== 'all' && taskStatus !== feedbackStatusFilter) {
return false
}
if (feedbackRollbackFilter !== 'all' && rollbackStatus !== feedbackRollbackFilter) {
return false
}
if (!keyword) {
return true
}
return [
item.query_tool_id,
item.session_id,
item.query_text,
item.decision,
item.task_status,
item.rollback_status,
]
.map((value) => String(value ?? '').toLowerCase())
.some((value) => value.includes(keyword))
})
}, [feedbackCorrections, feedbackRollbackFilter, feedbackSearch, feedbackStatusFilter])
// Always at least one page so the pager never shows "page 1 / 0".
const feedbackPageCount = Math.max(1, Math.ceil(filteredFeedbackCorrections.length / FEEDBACK_CORRECTION_PAGE_SIZE))
// Current page slice of the filtered list.
const pagedFeedbackCorrections = useMemo(() => {
const start = (feedbackPage - 1) * FEEDBACK_CORRECTION_PAGE_SIZE
return filteredFeedbackCorrections.slice(start, start + FEEDBACK_CORRECTION_PAGE_SIZE)
}, [feedbackPage, filteredFeedbackCorrections])
// Resolve the selected correction: keep the explicit selection if it is still
// visible under the current filters, otherwise fall back to the first item on
// the current page (or null when the page is empty).
const selectedFeedbackCorrection = useMemo(
() =>
filteredFeedbackCorrections.find((item) => item.task_id === selectedFeedbackTaskId)
?? pagedFeedbackCorrections[0]
?? null,
[filteredFeedbackCorrections, pagedFeedbackCorrections, selectedFeedbackTaskId],
)
// Jump back to page 1 whenever any filter input changes.
useEffect(() => {
setFeedbackPage(1)
}, [feedbackSearch, feedbackStatusFilter, feedbackRollbackFilter])
// Clamp the page index when the filtered list shrinks below the current page.
useEffect(() => {
if (feedbackPage > feedbackPageCount) {
setFeedbackPage(feedbackPageCount)
}
}, [feedbackPage, feedbackPageCount])
// Keep the stored task id in sync with the resolved selection: clear all
// selection state when nothing matches, otherwise adopt the fallback item's id.
useEffect(() => {
if (!selectedFeedbackCorrection) {
if (selectedFeedbackTaskId) {
setSelectedFeedbackTaskId(0)
}
setSelectedFeedbackTaskDetail(null)
setSelectedFeedbackTaskError('')
return
}
if (selectedFeedbackCorrection.task_id !== selectedFeedbackTaskId) {
setSelectedFeedbackTaskId(selectedFeedbackCorrection.task_id)
}
}, [selectedFeedbackCorrection, selectedFeedbackTaskId])
// Fetch the full detail for the selected task. The local `cancelled` flag
// (flipped in the cleanup) guards against a stale response overwriting state
// after the selection has changed or the component unmounted.
useEffect(() => {
const taskId = selectedFeedbackCorrection?.task_id
if (!taskId) {
setSelectedFeedbackTaskDetail(null)
setSelectedFeedbackTaskError('')
return
}
let cancelled = false
setSelectedFeedbackTaskLoading(true)
setSelectedFeedbackTaskError('')
void getMemoryFeedbackCorrection(taskId)
.then((payload) => {
if (cancelled) {
return
}
// Treat a success=false envelope or a missing task as a load failure.
if (!payload.success || !payload.task) {
setSelectedFeedbackTaskDetail(null)
setSelectedFeedbackTaskError(payload.error || '未能加载纠错任务详情')
return
}
setSelectedFeedbackTaskDetail(payload.task)
})
.catch((error) => {
if (cancelled) {
return
}
setSelectedFeedbackTaskDetail(null)
setSelectedFeedbackTaskError(error instanceof Error ? error.message : '未能加载纠错任务详情')
})
.finally(() => {
if (!cancelled) {
setSelectedFeedbackTaskLoading(false)
}
})
return () => {
cancelled = true
}
}, [selectedFeedbackCorrection?.task_id])
// Merge the list summary with the fetched detail when both refer to the same
// task (detail fields win); otherwise fall back to whichever is available.
const selectedFeedbackResolved = useMemo(() => {
if (!selectedFeedbackCorrection) {
return null
}
if (selectedFeedbackTaskDetail?.task_id === selectedFeedbackCorrection.task_id) {
return {
...selectedFeedbackCorrection,
...selectedFeedbackTaskDetail,
} satisfies MemoryFeedbackCorrectionDetailTaskPayload
}
return selectedFeedbackTaskDetail ?? selectedFeedbackCorrection
}, [selectedFeedbackCorrection, selectedFeedbackTaskDetail])
// Defensive: action_logs may be absent on the summary-only payload.
const selectedFeedbackActionLogs = Array.isArray(selectedFeedbackResolved?.action_logs)
? selectedFeedbackResolved.action_logs
: []
// Case-insensitive keyword filter over action type, hash, reason and the
// summarized before/after payloads.
const filteredFeedbackActionLogs = useMemo(() => {
const keyword = feedbackActionLogSearch.trim().toLowerCase()
if (!keyword) {
return selectedFeedbackActionLogs
}
return selectedFeedbackActionLogs.filter((item) =>
[
item.action_type,
item.target_hash,
item.reason,
summarizeFeedbackActionPayload(item.before_payload),
summarizeFeedbackActionPayload(item.after_payload),
]
.map((value) => String(value ?? '').toLowerCase())
.some((value) => value.includes(keyword)),
)
}, [feedbackActionLogSearch, selectedFeedbackActionLogs])
// Always at least one page, mirroring the list pager above.
const feedbackActionLogPageCount = Math.max(
1,
Math.ceil(filteredFeedbackActionLogs.length / FEEDBACK_ACTION_LOG_PAGE_SIZE),
)
const pagedFeedbackActionLogs = useMemo(() => {
const start = (feedbackActionLogPage - 1) * FEEDBACK_ACTION_LOG_PAGE_SIZE
return filteredFeedbackActionLogs.slice(start, start + FEEDBACK_ACTION_LOG_PAGE_SIZE)
}, [feedbackActionLogPage, filteredFeedbackActionLogs])
// Reset the action-log pager when the task or the search keyword changes.
useEffect(() => {
setFeedbackActionLogPage(1)
}, [selectedFeedbackTaskId, feedbackActionLogSearch])
// Clamp when the filtered log list shrinks below the current page.
useEffect(() => {
if (feedbackActionLogPage > feedbackActionLogPageCount) {
setFeedbackActionLogPage(feedbackActionLogPageCount)
}
}, [feedbackActionLogPage, feedbackActionLogPageCount])
// Open the rollback confirmation dialog with a blank reason field.
const openFeedbackRollbackDialog = useCallback(() => {
setFeedbackRollbackReason('')
setFeedbackRollbackDialogOpen(true)
}, [])
// Execute the rollback for the currently selected task, then refresh the
// correction list, the task detail, and the downstream memory sources and
// runtime config (which the rollback may have changed). Errors surface as a
// destructive toast; the busy flag always clears in `finally`.
const executeFeedbackRollback = useCallback(async () => {
const taskId = selectedFeedbackResolved?.task_id
if (!taskId) {
return
}
try {
setFeedbackRollingBack(true)
const payload = await rollbackMemoryFeedbackCorrection(taskId, {
requested_by: 'knowledge_base',
reason: feedbackRollbackReason.trim(),
})
if (!payload.success) {
throw new Error(payload.error || '回退失败')
}
toast({
title: payload.already_rolled_back ? '该纠错已回退' : '纠错回退成功',
description: `任务 ${taskId} 的回退结果已写入日志`,
})
setFeedbackRollbackDialogOpen(false)
// Refresh list + detail in parallel so the UI reflects the new rollback state.
const [listPayload, detailPayload] = await Promise.all([
getMemoryFeedbackCorrections({ limit: FEEDBACK_CORRECTION_FETCH_LIMIT }),
getMemoryFeedbackCorrection(taskId),
])
setFeedbackCorrections(listPayload.items ?? [])
setSelectedFeedbackTaskDetail(detailPayload.task ?? null)
// Sources/runtime are refreshed second; presumably the rollback can alter
// paragraph counts and runtime self-check state — TODO confirm.
const [sourcePayload, runtimePayload] = await Promise.all([
getMemorySources(),
getMemoryRuntimeConfig(),
])
setMemorySources(sourcePayload.items ?? [])
setRuntimeConfig(runtimePayload)
} catch (error) {
toast({
title: '纠错回退失败',
description: error instanceof Error ? error.message : '未知错误',
variant: 'destructive',
})
} finally {
setFeedbackRollingBack(false)
}
}, [feedbackRollbackReason, selectedFeedbackResolved?.task_id, toast])
const selectedOperationResolved = useMemo(() => {
if (!selectedDeleteOperation) {
return null
@@ -1776,6 +2100,9 @@ export function KnowledgeBasePage() {
<TabsTrigger value="delete" className="rounded-lg px-4 py-1.5">
</TabsTrigger>
<TabsTrigger value="feedback" className="rounded-lg px-4 py-1.5">
</TabsTrigger>
</TabsList>
<TabsContent value="overview" className="space-y-4">
@@ -3314,6 +3641,327 @@ export function KnowledgeBasePage() {
</Card>
</div>
</TabsContent>
<TabsContent value="feedback" className="space-y-4">
<div className="space-y-4">
<Card>
<CardHeader>
<CardTitle className="flex items-center gap-2">
<RotateCcw className="h-4 w-4" />
</CardTitle>
<CardDescription>
feedback correction 退
</CardDescription>
</CardHeader>
<CardContent className="space-y-4">
<div className="grid gap-3 lg:grid-cols-[minmax(0,1fr)_180px_180px]">
<Input
value={feedbackSearch}
onChange={(event) => setFeedbackSearch(event.target.value)}
placeholder="搜索 query_tool_id / session / query / reason"
/>
<Select value={feedbackStatusFilter} onValueChange={setFeedbackStatusFilter}>
<SelectTrigger>
<SelectValue placeholder="按任务状态筛选" />
</SelectTrigger>
<SelectContent>
<SelectItem value="all"></SelectItem>
<SelectItem value="applied"></SelectItem>
<SelectItem value="skipped"></SelectItem>
<SelectItem value="error"></SelectItem>
<SelectItem value="running"></SelectItem>
<SelectItem value="pending"></SelectItem>
</SelectContent>
</Select>
<Select value={feedbackRollbackFilter} onValueChange={setFeedbackRollbackFilter}>
<SelectTrigger>
<SelectValue placeholder="按回退状态筛选" />
</SelectTrigger>
<SelectContent>
<SelectItem value="all">退</SelectItem>
<SelectItem value="none">退</SelectItem>
<SelectItem value="rolled_back">退</SelectItem>
<SelectItem value="error">退</SelectItem>
<SelectItem value="running">退</SelectItem>
</SelectContent>
</Select>
</div>
<div className="flex flex-wrap items-center justify-between gap-2 text-sm text-muted-foreground">
<span> {filteredFeedbackCorrections.length} {feedbackCorrections.length} </span>
<span> {feedbackPage} / {feedbackPageCount} {FEEDBACK_CORRECTION_PAGE_SIZE} </span>
</div>
<div className="grid gap-4 xl:grid-cols-[minmax(0,0.92fr)_minmax(0,1.08fr)]">
<ScrollArea className="h-[720px] rounded-lg border">
<div className="space-y-3 p-3">
{pagedFeedbackCorrections.length > 0 ? pagedFeedbackCorrections.map((item) => {
const isSelected = selectedFeedbackCorrection?.task_id === item.task_id
return (
<button
key={item.task_id}
type="button"
onClick={() => setSelectedFeedbackTaskId(item.task_id)}
className={cn(
'w-full rounded-xl border p-4 text-left transition-colors',
isSelected
? 'border-primary bg-primary/5 shadow-sm'
: 'bg-muted/20 hover:border-primary/40 hover:bg-muted/40',
)}
>
<div className="flex flex-col gap-3">
<div className="flex flex-wrap items-center gap-2">
<Badge variant={getFeedbackStatusVariant(item.task_status)}>
{formatFeedbackTaskStatus(item.task_status)}
</Badge>
<Badge variant={getFeedbackStatusVariant(item.rollback_status)}>
{formatFeedbackRollbackStatus(item.rollback_status)}
</Badge>
<Badge variant="outline">
{formatFeedbackDecision(item.decision)}
</Badge>
</div>
<div className="text-sm font-medium break-words">
{item.query_text || '无查询文本'}
</div>
<div className="font-mono text-[11px] break-all text-muted-foreground">
{item.query_tool_id}
</div>
<div className="flex flex-wrap gap-3 text-xs text-muted-foreground">
<span> {Number(item.affected_counts?.relations ?? 0)}</span>
<span> {Number(item.affected_counts?.stale_paragraphs ?? 0)}</span>
<span>Episode {Number(item.affected_counts?.episode_sources ?? 0)}</span>
<span>Profile {Number(item.affected_counts?.profile_person_ids ?? 0)}</span>
</div>
<div className="text-xs text-muted-foreground">
{formatDeleteOperationTime(item.query_timestamp ?? item.created_at)}
</div>
</div>
</button>
)
}) : (
<div className="rounded-lg border border-dashed bg-muted/20 p-6 text-center text-sm text-muted-foreground">
</div>
)}
</div>
</ScrollArea>
<div className="rounded-xl border bg-muted/20 p-4">
{selectedFeedbackCorrection ? (
<div className="space-y-4">
<div className="flex flex-col gap-3 lg:flex-row lg:items-start lg:justify-between">
<div className="space-y-2">
<div className="flex flex-wrap items-center gap-2">
<Badge variant={getFeedbackStatusVariant(String(selectedFeedbackResolved?.task_status ?? ''))}>
{formatFeedbackTaskStatus(String(selectedFeedbackResolved?.task_status ?? ''))}
</Badge>
<Badge variant={getFeedbackStatusVariant(String(selectedFeedbackResolved?.rollback_status ?? 'none'))}>
{formatFeedbackRollbackStatus(String(selectedFeedbackResolved?.rollback_status ?? 'none'))}
</Badge>
<Badge variant="outline">
{formatFeedbackDecision(String(selectedFeedbackResolved?.decision ?? ''))}
</Badge>
</div>
<div className="text-sm font-medium break-words">
{selectedFeedbackResolved?.query_text || '无查询文本'}
</div>
<div className="font-mono text-xs break-all">
{selectedFeedbackResolved?.query_tool_id}
</div>
</div>
<Button
size="sm"
variant="outline"
onClick={openFeedbackRollbackDialog}
disabled={
String(selectedFeedbackResolved?.task_status ?? '') !== 'applied'
|| String(selectedFeedbackResolved?.rollback_status ?? 'none') === 'rolled_back'
|| feedbackRollingBack
}
>
<RotateCcw className="mr-2 h-4 w-4" />
{String(selectedFeedbackResolved?.rollback_status ?? 'none') === 'rolled_back'
? '已回退'
: '回退本次纠错'}
</Button>
</div>
<div className="grid gap-3 lg:grid-cols-4">
<div className="rounded-lg border bg-background/60 p-3">
<div className="text-xs text-muted-foreground"></div>
<div className="mt-1 text-sm break-all">{selectedFeedbackResolved?.session_id || '-'}</div>
</div>
<div className="rounded-lg border bg-background/60 p-3">
<div className="text-xs text-muted-foreground"></div>
<div className="mt-1 text-sm">{Number(selectedFeedbackResolved?.feedback_message_count ?? 0)}</div>
</div>
<div className="rounded-lg border bg-background/60 p-3">
<div className="text-xs text-muted-foreground"></div>
<div className="mt-1 text-sm">{Number(selectedFeedbackResolved?.decision_confidence ?? 0).toFixed(2)}</div>
</div>
<div className="rounded-lg border bg-background/60 p-3">
<div className="text-xs text-muted-foreground">退</div>
<div className="mt-1 text-sm">{formatDeleteOperationTime(selectedFeedbackResolved?.rolled_back_at)}</div>
</div>
</div>
{selectedFeedbackTaskLoading ? (
<div className="rounded-lg border bg-background/60 p-4 text-sm text-muted-foreground">
...
</div>
) : null}
{selectedFeedbackTaskError ? (
<Alert variant="destructive">
<AlertDescription>{selectedFeedbackTaskError}</AlertDescription>
</Alert>
) : null}
{selectedFeedbackResolved?.rollback_error ? (
<Alert variant="destructive">
<AlertDescription>{selectedFeedbackResolved.rollback_error}</AlertDescription>
</Alert>
) : null}
<div className="grid gap-4 xl:grid-cols-[minmax(0,0.95fr)_minmax(0,1.05fr)]">
<div className="space-y-2">
<div className="text-sm font-semibold"></div>
<pre className="max-h-56 overflow-auto rounded-lg border bg-background/70 p-3 text-xs break-words whitespace-pre-wrap">
{JSON.stringify(selectedFeedbackResolved?.query_snapshot ?? {}, null, 2)}
</pre>
</div>
<div className="space-y-2">
<div className="text-sm font-semibold"></div>
<pre className="max-h-56 overflow-auto rounded-lg border bg-background/70 p-3 text-xs break-words whitespace-pre-wrap">
{JSON.stringify(selectedFeedbackResolved?.decision_payload ?? {}, null, 2)}
</pre>
</div>
</div>
<div className="grid gap-4 xl:grid-cols-[minmax(0,0.95fr)_minmax(0,1.05fr)]">
<div className="space-y-2">
<div className="text-sm font-semibold">退</div>
<pre className="max-h-64 overflow-auto rounded-lg border bg-background/70 p-3 text-xs break-words whitespace-pre-wrap">
{JSON.stringify(selectedFeedbackResolved?.rollback_plan_summary ?? {}, null, 2)}
</pre>
</div>
<div className="space-y-2">
<div className="text-sm font-semibold">退</div>
<pre className="max-h-64 overflow-auto rounded-lg border bg-background/70 p-3 text-xs break-words whitespace-pre-wrap">
{JSON.stringify(selectedFeedbackResolved?.rollback_result ?? {}, null, 2)}
</pre>
</div>
</div>
<div className="space-y-2">
<div className="flex flex-col gap-3 lg:flex-row lg:items-center lg:justify-between">
<div className="text-sm font-semibold">线</div>
<div className="flex flex-col gap-3 lg:flex-row lg:items-center lg:justify-end">
<Input
value={feedbackActionLogSearch}
onChange={(event) => setFeedbackActionLogSearch(event.target.value)}
placeholder="搜索动作 / hash / 预览内容"
className="lg:w-80"
/>
<div className="text-xs text-muted-foreground">
{feedbackActionLogPage} / {feedbackActionLogPageCount} {FEEDBACK_ACTION_LOG_PAGE_SIZE}
</div>
</div>
</div>
<ScrollArea className="h-[280px] rounded-lg border bg-background/60">
<div className="space-y-2 p-3">
{pagedFeedbackActionLogs.length > 0 ? pagedFeedbackActionLogs.map((item: MemoryFeedbackActionLogPayload) => (
<div key={`${item.id}:${item.action_type}`} className="rounded-lg border bg-muted/20 p-3">
<div className="flex flex-wrap items-center gap-2">
<Badge variant="outline">{item.action_type}</Badge>
{item.target_hash ? (
<span className="font-mono text-[11px] break-all text-muted-foreground">{item.target_hash}</span>
) : null}
</div>
{item.reason ? (
<div className="mt-2 text-xs text-muted-foreground break-words">
{item.reason}
</div>
) : null}
{item.before_payload && Object.keys(item.before_payload).length > 0 ? (
<div className="mt-2 text-xs break-words">
<span className="font-medium">Before</span>
<span className="text-muted-foreground">{summarizeFeedbackActionPayload(item.before_payload)}</span>
</div>
) : null}
{item.after_payload && Object.keys(item.after_payload).length > 0 ? (
<div className="mt-1 text-xs break-words">
<span className="font-medium">After</span>
<span className="text-muted-foreground">{summarizeFeedbackActionPayload(item.after_payload)}</span>
</div>
) : null}
<div className="mt-2 text-[11px] text-muted-foreground">
{formatDeleteOperationTime(item.created_at)}
</div>
</div>
)) : (
<div className="rounded-lg border border-dashed bg-muted/20 p-6 text-center text-sm text-muted-foreground">
{selectedFeedbackActionLogs.length > 0 ? '当前筛选条件下没有动作日志' : '当前任务没有动作日志'}
</div>
)}
</div>
</ScrollArea>
<div className="flex items-center justify-between gap-2">
<Button
variant="outline"
size="sm"
onClick={() => setFeedbackActionLogPage((current) => Math.max(1, current - 1))}
disabled={feedbackActionLogPage <= 1}
>
</Button>
<div className="text-xs text-muted-foreground">hash </div>
<Button
variant="outline"
size="sm"
onClick={() => setFeedbackActionLogPage((current) => Math.min(feedbackActionLogPageCount, current + 1))}
disabled={feedbackActionLogPage >= feedbackActionLogPageCount}
>
</Button>
</div>
</div>
</div>
) : (
<div className="flex min-h-[360px] items-center justify-center rounded-lg border border-dashed bg-background/40 p-6 text-center text-sm text-muted-foreground">
</div>
)}
</div>
</div>
<div className="flex items-center justify-between gap-2">
<Button
variant="outline"
size="sm"
onClick={() => setFeedbackPage((current) => Math.max(1, current - 1))}
disabled={feedbackPage <= 1}
>
</Button>
<div className="text-xs text-muted-foreground">
query退
</div>
<Button
variant="outline"
size="sm"
onClick={() => setFeedbackPage((current) => Math.min(feedbackPageCount, current + 1))}
disabled={feedbackPage >= feedbackPageCount}
>
</Button>
</div>
</CardContent>
</Card>
</div>
</TabsContent>
</Tabs>
</div>
</div>
@@ -3332,6 +3980,52 @@ export function KnowledgeBasePage() {
onExecute={() => void executePendingDelete()}
onRestore={() => void (deleteResult?.operation_id ? restoreDeleteOperation(deleteResult.operation_id) : Promise.resolve())}
/>
<Dialog open={feedbackRollbackDialogOpen} onOpenChange={setFeedbackRollbackDialogOpen}>
<DialogContent className="max-w-lg" confirmOnEnter>
<DialogHeader>
<DialogTitle>退</DialogTitle>
<DialogDescription>
relation episode/profile
</DialogDescription>
</DialogHeader>
<div className="space-y-3">
<div className="rounded-lg border bg-muted/20 p-3 text-sm">
<div className="font-medium break-words">{selectedFeedbackResolved?.query_text || '无查询文本'}</div>
<div className="mt-1 font-mono text-[11px] break-all text-muted-foreground">
{selectedFeedbackResolved?.query_tool_id}
</div>
</div>
<div className="space-y-2">
<Label htmlFor="feedback-rollback-reason">退</Label>
<Textarea
id="feedback-rollback-reason"
value={feedbackRollbackReason}
onChange={(event) => setFeedbackRollbackReason(event.target.value)}
placeholder="可选,建议填写本次人工回退原因"
/>
</div>
</div>
<DialogFooter>
<Button variant="outline" onClick={() => setFeedbackRollbackDialogOpen(false)} disabled={feedbackRollingBack}>
</Button>
<Button onClick={() => void executeFeedbackRollback()} disabled={feedbackRollingBack}>
{feedbackRollingBack ? (
<>
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
退
</>
) : (
<>
<RotateCcw className="mr-2 h-4 w-4" />
退
</>
)}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</div>
)
}

View File

@@ -0,0 +1,740 @@
from __future__ import annotations
import asyncio
import inspect
import json
import time
import uuid
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Callable, Dict
import numpy as np
import pytest
import pytest_asyncio
from sqlalchemy.orm import sessionmaker
from sqlmodel import Session, create_engine, select
# Sentinel recording why project imports failed; None when all imports succeed.
IMPORT_ERROR: str | None = None
try:
    from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
    from src.A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
    from src.chat.heart_flow.heartflow_manager import heartflow_manager
    from src.chat.message_receive import bot as bot_module
    from src.chat.message_receive.chat_manager import chat_manager
    from src.chat.message_receive.bot import chat_bot
    from src.common.database import database as database_module
    from src.common.database.database_model import PersonInfo, ToolRecord
    from src.common.database.migrations import create_database_migration_bootstrapper
    from src.common.utils.utils_session import SessionUtils
    from src.llm_models.payload_content.tool_option import ToolCall
    from src.maisaka import reasoning_engine as reasoning_engine_module
    from src.maisaka import runtime as runtime_module
    from src.maisaka.chat_loop_service import ChatResponse
    from src.maisaka.context_messages import AssistantMessage
    from src.plugin_runtime import component_query as component_query_module
    from src.services import memory_flow_service as memory_flow_service_module
    from src.services import memory_service as memory_service_module
    from src.services.memory_service import memory_service
except SystemExit as exc:
    # Config bootstrap may call sys.exit() at import time; capture the reason
    # and stub every imported name so collection succeeds and the module skips.
    IMPORT_ERROR = f"config initialization exited during import: {exc}"
    kernel_module = None  # type: ignore[assignment]
    KernelSearchRequest = None  # type: ignore[assignment]
    SDKMemoryKernel = None  # type: ignore[assignment]
    heartflow_manager = None  # type: ignore[assignment]
    bot_module = None  # type: ignore[assignment]
    chat_manager = None  # type: ignore[assignment]
    chat_bot = None  # type: ignore[assignment]
    database_module = None  # type: ignore[assignment]
    ToolRecord = None  # type: ignore[assignment]
    PersonInfo = None  # type: ignore[assignment]
    create_database_migration_bootstrapper = None  # type: ignore[assignment]
    SessionUtils = None  # type: ignore[assignment]
    ToolCall = None  # type: ignore[assignment]
    reasoning_engine_module = None  # type: ignore[assignment]
    runtime_module = None  # type: ignore[assignment]
    ChatResponse = None  # type: ignore[assignment]
    AssistantMessage = None  # type: ignore[assignment]
    component_query_module = None  # type: ignore[assignment]
    memory_flow_service_module = None  # type: ignore[assignment]
    memory_service_module = None  # type: ignore[assignment]
    memory_service = None  # type: ignore[assignment]
# Skip the entire module when project imports were unavailable.
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
# Canonical relation query used to drive memory searches throughout this module.
RELATION_QUERY = "测试用户 和 最喜欢的颜色 有什么关系"
class _FakeEmbeddingManager:
def __init__(self, dimension: int = 8) -> None:
self.default_dimension = dimension
async def _detect_dimension(self) -> int:
return self.default_dimension
async def encode(self, text: Any) -> np.ndarray:
def _encode_one(raw: Any) -> np.ndarray:
content = str(raw or "")
vector = np.zeros(self.default_dimension, dtype=np.float32)
for index, byte in enumerate(content.encode("utf-8")):
vector[index % self.default_dimension] += float((byte % 17) + 1)
norm = float(np.linalg.norm(vector))
if norm > 0:
vector /= norm
return vector
if isinstance(text, (list, tuple)):
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
return _encode_one(text).astype(np.float32)
class _KernelBackedRuntimeManager:
    # Adapter exposing the in-process SDK memory kernel through the host
    # runtime-manager `invoke` interface used by memory_service.
    def __init__(self, kernel: SDKMemoryKernel) -> None:
        self.kernel = kernel

    async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
        """Dispatch a component call to the kernel.

        `search_memory` gets an explicit payload-to-request translation; any
        other component name is resolved as a kernel attribute and called with
        the raw payload. `timeout_ms` is accepted for interface parity but
        ignored — the in-process kernel needs no transport timeout.
        """
        del timeout_ms
        payload = args or {}
        if component_name == "search_memory":
            return await self.kernel.search_memory(
                KernelSearchRequest(
                    query=str(payload.get("query", "") or ""),
                    limit=int(payload.get("limit", 5) or 5),
                    mode=str(payload.get("mode", "hybrid") or "hybrid"),
                    chat_id=str(payload.get("chat_id", "") or ""),
                    person_id=str(payload.get("person_id", "") or ""),
                    time_start=payload.get("time_start"),
                    time_end=payload.get("time_end"),
                    respect_filter=bool(payload.get("respect_filter", True)),
                    user_id=str(payload.get("user_id", "") or ""),
                    group_id=str(payload.get("group_id", "") or ""),
                )
            )
        handler = getattr(self.kernel, component_name)
        result = handler(**payload)
        # Kernel methods may be sync or async; await only when needed.
        return await result if inspect.isawaitable(result) else result
class _NoopRuntimeManager:
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
del hook_name
return SimpleNamespace(aborted=False, kwargs=kwargs)
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """Point the project's main database module at a throwaway SQLite file.

    Disposes the current engine, builds a fresh engine/sessionmaker bound to a
    database under `tmp_path`, then monkeypatches every module-level handle so
    subsequent sessions hit the temporary database instead of the real one.
    """
    db_dir = (tmp_path / "main_db").resolve()
    db_dir.mkdir(parents=True, exist_ok=True)
    db_file = db_dir / "MaiBot.db"
    database_url = f"sqlite:///{db_file}"
    try:
        database_module.engine.dispose()
    except Exception:
        # Best effort: the previous engine may already be closed or unset.
        pass
    engine = create_engine(
        database_url,
        echo=False,
        # Allow the async test machinery to use the connection across threads.
        connect_args={"check_same_thread": False},
        pool_pre_ping=True,
    )
    session_local = sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=engine,
        class_=Session,
    )
    bootstrapper = create_database_migration_bootstrapper(engine)
    monkeypatch.setattr(database_module, "_DB_DIR", db_dir, raising=False)
    monkeypatch.setattr(database_module, "_DB_FILE", db_file, raising=False)
    monkeypatch.setattr(database_module, "DATABASE_URL", database_url, raising=False)
    monkeypatch.setattr(database_module, "engine", engine, raising=False)
    monkeypatch.setattr(database_module, "SessionLocal", session_local, raising=False)
    monkeypatch.setattr(database_module, "_migration_bootstrapper", bootstrapper, raising=False)
    # Force re-initialisation so migrations run against the fresh database.
    monkeypatch.setattr(database_module, "_db_initialized", False, raising=False)
def _build_chat_response(content: str, tool_calls: list[ToolCall]) -> ChatResponse:
    """Wrap `content` and `tool_calls` into a minimal ChatResponse.

    All token/count bookkeeping fields are zeroed; only the message body and
    the tool calls matter for the fake planner used in these tests.
    """
    return ChatResponse(
        content=content,
        tool_calls=tool_calls,
        request_messages=[],
        raw_message=AssistantMessage(
            content=content,
            timestamp=datetime.now(),
            tool_calls=tool_calls,
        ),
        selected_history_count=0,
        tool_count=len(tool_calls),
        prompt_tokens=0,
        built_message_count=0,
        completion_tokens=0,
        total_tokens=0,
        prompt_section=None,
    )
def _build_message_data(
*,
content: str,
platform: str,
user_id: str,
user_name: str,
group_id: str,
group_name: str,
) -> Dict[str, Any]:
message_id = str(uuid.uuid4())
return {
"message_info": {
"platform": platform,
"message_id": message_id,
"time": time.time(),
"group_info": {
"group_id": group_id,
"group_name": group_name,
"platform": platform,
},
"user_info": {
"user_id": user_id,
"user_nickname": user_name,
"user_cardname": user_name,
"platform": platform,
},
"additional_config": {
"at_bot": True,
},
},
"message_segment": {
"type": "seglist",
"data": [
{
"type": "text",
"data": content,
},
],
},
"raw_message": content,
"processed_plain_text": content,
}
async def _wait_until(
predicate: Callable[[], Any],
*,
timeout_seconds: float = 10.0,
interval_seconds: float = 0.05,
description: str,
) -> Any:
deadline = asyncio.get_running_loop().time() + max(0.5, float(timeout_seconds))
while asyncio.get_running_loop().time() < deadline:
value = predicate()
if inspect.isawaitable(value):
value = await value
if value:
return value
await asyncio.sleep(interval_seconds)
raise AssertionError(f"等待超时: {description}")
def _load_feedback_tasks(kernel: SDKMemoryKernel) -> list[Dict[str, Any]]:
    """Read every feedback task from the kernel's metadata store, oldest first."""
    assert kernel.metadata_store is not None
    cursor = kernel.metadata_store.get_connection().cursor()
    rows = cursor.execute(
        "SELECT query_tool_id FROM memory_feedback_tasks ORDER BY id"
    ).fetchall()
    tasks: list[Dict[str, Any]] = []
    for row in rows:
        # Re-fetch through the store API so each entry carries its full payload.
        task = kernel.metadata_store.get_feedback_task(str(row["query_tool_id"] or ""))
        if task is not None:
            tasks.append(task)
    return tasks
def _load_feedback_action_types(kernel: SDKMemoryKernel) -> list[str]:
    """Return the action_type of every feedback action log row, in insertion order."""
    assert kernel.metadata_store is not None
    cursor = kernel.metadata_store.get_connection().cursor()
    rows = cursor.execute(
        "SELECT action_type FROM memory_feedback_action_logs ORDER BY id"
    ).fetchall()
    return [str(row["action_type"] or "") for row in rows]
def _load_query_memory_tool_records(session_id: str) -> list[Dict[str, Any]]:
    """Fetch the session's `query_memory` ToolRecord rows as plain dicts, oldest first."""
    with database_module.get_db_session() as session:
        statement = (
            select(ToolRecord)
            .where(ToolRecord.session_id == session_id)
            .where(ToolRecord.tool_name == "query_memory")
            .order_by(ToolRecord.timestamp)
        )
        rows = list(session.exec(statement).all())
        # Coerce to strings so callers don't touch detached ORM instances.
        return [
            {
                "tool_id": str(row.tool_id or ""),
                "session_id": str(row.session_id or ""),
                "tool_name": str(row.tool_name or ""),
                "tool_data": str(row.tool_data or ""),
                "timestamp": row.timestamp,
            }
            for row in rows
        ]
def _seed_person_info(*, person_id: str, person_name: str, session_info: Dict[str, Any]) -> None:
    """Insert a known PersonInfo row so profile lookups resolve `person_id`."""
    with database_module.get_db_session() as session:
        session.add(
            PersonInfo(
                is_known=True,
                person_id=person_id,
                person_name=person_name,
                platform=str(session_info["platform"]),
                user_id=str(session_info["user_id"]),
                user_nickname=str(session_info["user_name"]),
                # Group card names are persisted as a JSON-encoded list.
                group_cardname=json.dumps(
                    [{"group_id": str(session_info["group_id"]), "group_cardname": person_name}],
                    ensure_ascii=False,
                ),
                know_counts=1,
                first_known_time=datetime.now(),
                last_known_time=datetime.now(),
            )
        )
        session.commit()
@pytest_asyncio.fixture
async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
    """Stand up a full in-process chat + memory-kernel environment.

    Installs a temporary main database, resets chat/heartflow state, replaces
    network-facing pieces (runtime hooks, embedding adapter, LLM classifier,
    planner) with deterministic fakes, and yields the handles the test needs.
    Teardown stops heartflow chats, shuts the kernel down, and disposes the
    temporary database engine.
    """
    _install_temp_main_database(monkeypatch, tmp_path)
    # Start from a clean chat/heartflow state.
    chat_manager.sessions.clear()
    chat_manager.last_messages.clear()
    heartflow_manager.heartflow_chat_list.clear()
    noop_runtime_manager = _NoopRuntimeManager()
    monkeypatch.setattr(bot_module.ChatBot, "_get_runtime_manager", staticmethod(lambda: noop_runtime_manager))
    # Disable command detection and LLM tool specs from plugins.
    monkeypatch.setattr(
        component_query_module.component_query_service,
        "find_command_by_text",
        lambda text: None,
    )
    monkeypatch.setattr(
        component_query_module.component_query_service,
        "get_llm_available_tool_specs",
        lambda: {},
    )
    monkeypatch.setattr(runtime_module.global_config.mcp, "enable", False, raising=False)
    # Fire the heartflow loop after every single message.
    monkeypatch.setattr(
        runtime_module.MaisakaHeartFlowChatting,
        "_get_message_trigger_threshold",
        lambda self: 1,
    )

    async def _noop_on_incoming_message(message: Any) -> None:
        del message

    # Keep the automation service from racing the test's explicit flow.
    monkeypatch.setattr(
        memory_flow_service_module.memory_automation_service,
        "on_incoming_message",
        _noop_on_incoming_message,
    )
    fake_embedding_manager = _FakeEmbeddingManager(dimension=8)

    async def _fake_runtime_self_check(
        *,
        config: Any,
        sample_text: str,
        vector_store: Any,
        embedding_manager: Any,
    ) -> Dict[str, Any]:
        del config, sample_text, vector_store, embedding_manager
        return {
            "ok": True,
            "message": "ok",
            "checked_at": time.time(),
            "encoded_dimension": fake_embedding_manager.default_dimension,
        }

    # Route all embedding work through the deterministic fake.
    monkeypatch.setattr(
        kernel_module,
        "create_embedding_api_adapter",
        lambda **kwargs: fake_embedding_manager,
    )
    monkeypatch.setattr(
        kernel_module,
        "run_embedding_runtime_self_check",
        _fake_runtime_self_check,
    )
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
            "advanced": {"enable_auto_save": False},
            "embedding": {"dimension": fake_embedding_manager.default_dimension},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )
    # Force-enable the feedback-correction pipeline with very short timings so
    # the background workers complete within the test's wait windows.
    monkeypatch.setattr(kernel, "_feedback_cfg_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_window_hours", lambda: 0.0004)
    monkeypatch.setattr(kernel, "_feedback_cfg_check_interval_seconds", lambda: 0.2)
    monkeypatch.setattr(kernel, "_feedback_cfg_batch_size", lambda: 10)
    monkeypatch.setattr(kernel, "_feedback_cfg_max_messages", lambda: 10)
    monkeypatch.setattr(kernel, "_feedback_cfg_auto_apply_threshold", lambda: 0.85)
    monkeypatch.setattr(kernel, "_feedback_cfg_prefilter_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_mark_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_hard_filter_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_profile_refresh_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_profile_force_refresh_on_read", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_episode_rebuild_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_episode_query_block_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_interval_seconds", lambda: 0.2)
    monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_batch_size", lambda: 10)
    monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_paragraph_hard_filter_enabled", True, raising=False)
    monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_episode_query_block_enabled", True, raising=False)

    async def _fake_classify_feedback(
        *,
        query_tool_id: str,
        query_text: str,
        hit_briefs: list[Dict[str, Any]],
        feedback_messages: list[str],
    ) -> Dict[str, Any]:
        # Deterministic replacement for the LLM classifier: always decide
        # "correct", targeting the first relation hit (or the first hit).
        del query_tool_id, query_text, feedback_messages
        target_hash = ""
        for item in hit_briefs:
            if str(item.get("type", "") or "").strip() == "relation":
                target_hash = str(item.get("hash", "") or "").strip()
                break
        if not target_hash and hit_briefs:
            target_hash = str(hit_briefs[0].get("hash", "") or "").strip()
        return {
            "decision": "correct",
            "confidence": 0.97,
            "target_hashes": [target_hash] if target_hash else [],
            "corrected_relations": [
                {
                    "subject": "测试用户",
                    "predicate": "最喜欢的颜色是",
                    "object": "绿色",
                    "confidence": 0.99,
                }
            ],
            "reason": "用户明确纠正为绿色",
        }

    monkeypatch.setattr(kernel, "_classify_feedback", _fake_classify_feedback)
    await kernel.initialize()

    async def _force_episode_fallback(**kwargs: Any) -> Dict[str, Any]:
        # Make LLM segmentation fail so episode building uses the fallback path.
        raise RuntimeError("force_fallback_for_test")

    monkeypatch.setattr(
        kernel.episode_service.segmentation_service,
        "segment",
        _force_episode_fallback,
    )
    # Neutralise the pending-episode batch worker; rebuilds are triggered explicitly.
    monkeypatch.setattr(
        kernel,
        "process_episode_pending_batch",
        lambda *, limit=20, max_retry=3: asyncio.sleep(0, result={"processed": 0, "episode_count": 0, "fallback_count": 0, "failed": 0}),
    )
    host_manager = _KernelBackedRuntimeManager(kernel)
    monkeypatch.setattr(memory_service_module, "a_memorix_host_service", host_manager)
    planner_calls: list[str] = []

    async def _fake_timing_gate(self, anchor_message: Any):
        del self, anchor_message
        return "continue", _build_chat_response("直接进入 planner。", []), []

    async def _fake_planner(self, *, tool_definitions: list[dict[str, Any]] | None = None) -> ChatResponse:
        # Scripted planner: issue one query_memory call per "回忆"/"再查"
        # message (tracked by message id so retries don't re-query), otherwise
        # end the turn with no_reply.
        del tool_definitions
        latest_message = self._runtime.message_cache[-1]
        latest_text = str(latest_message.processed_plain_text or "")
        planner_calls.append(latest_text)
        handled_message_ids = getattr(self._runtime, "_test_query_message_ids", None)
        if handled_message_ids is None:
            handled_message_ids = set()
            setattr(self._runtime, "_test_query_message_ids", handled_message_ids)
        if latest_message.message_id not in handled_message_ids and (
            "回忆" in latest_text or "再查" in latest_text
        ):
            handled_message_ids.add(latest_message.message_id)
            tool_call = ToolCall(
                call_id=f"query-{uuid.uuid4().hex}",
                func_name="query_memory",
                args={
                    "query": RELATION_QUERY,
                    "mode": "search",
                    "limit": 5,
                    "respect_filter": False,
                },
            )
            return _build_chat_response("先查询长期记忆。", [tool_call])
        stop_call = ToolCall(
            call_id=f"stop-{uuid.uuid4().hex}",
            func_name="no_reply",
            args={},
        )
        return _build_chat_response("当前轮次结束。", [stop_call])

    monkeypatch.setattr(
        reasoning_engine_module.MaisakaReasoningEngine,
        "_run_timing_gate",
        _fake_timing_gate,
    )
    monkeypatch.setattr(
        reasoning_engine_module.MaisakaReasoningEngine,
        "_run_interruptible_planner",
        _fake_planner,
    )
    session_info = {
        "platform": "unit_test_chat",
        "user_id": "user_feedback_flow",
        "user_name": "反馈测试用户",
        "group_id": "group_feedback_flow",
        "group_name": "反馈纠错测试群",
    }
    person_id = "person_feedback_flow"
    session_id = SessionUtils.calculate_session_id(
        session_info["platform"],
        user_id=session_info["user_id"],
        group_id=session_info["group_id"],
    )
    _seed_person_info(person_id=person_id, person_name="测试用户", session_info=session_info)
    try:
        yield {
            "kernel": kernel,
            "session_id": session_id,
            "session_info": session_info,
            "person_id": person_id,
            "planner_calls": planner_calls,
        }
    finally:
        # Teardown: stop heartflow chats, clear chat state, shut the kernel
        # down and dispose the temporary database engine (all best effort).
        for key, chat in list(heartflow_manager.heartflow_chat_list.items()):
            try:
                await chat.stop()
            except Exception:
                pass
            heartflow_manager.heartflow_chat_list.pop(key, None)
        chat_manager.sessions.clear()
        chat_manager.last_messages.clear()
        await kernel.shutdown()
        try:
            database_module.engine.dispose()
        except Exception:
            pass
@pytest.mark.asyncio
async def test_feedback_correction_real_chat_flow(chat_feedback_env) -> None:
    """End-to-end feedback correction through the real chat pipeline.

    Seeds a wrong relation (favourite colour = blue), drives a chat query, a
    user correction message ("it's green"), and verifies the correction is
    applied across relation search, person profile, episodes, and subsequent
    query_memory tool results.
    """
    kernel = chat_feedback_env["kernel"]
    session_id = chat_feedback_env["session_id"]
    session_info = chat_feedback_env["session_info"]
    person_id = chat_feedback_env["person_id"]
    # --- Seed the (wrong) "blue" relation and confirm it is visible. ---
    write_result = await memory_service.ingest_text(
        external_id=f"test:feedback-seed:{uuid.uuid4().hex}",
        source_type="chat_summary",
        text="测试用户 最喜欢的颜色是 蓝色",
        chat_id=session_id,
        relations=[
            {
                "subject": "测试用户",
                "predicate": "最喜欢的颜色是",
                "object": "蓝色",
                "confidence": 1.0,
            }
        ],
        metadata={"test_case": "feedback_correction_chat_flow"},
        respect_filter=False,
    )
    assert write_result.success is True
    pre_search = await memory_service.search(
        RELATION_QUERY,
        mode="search",
        chat_id=session_id,
        respect_filter=False,
    )
    assert pre_search.hits
    assert any("蓝色" in hit.content for hit in pre_search.hits)
    pre_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
    pre_profile_text = pre_profile.summary + "\n" + json.dumps(pre_profile.evidence, ensure_ascii=False)
    assert "蓝色" in pre_profile_text
    # Build episodes from the seed source so the episode view also says "blue".
    seed_source = f"chat_summary:{session_id}"
    rebuild_result = await kernel.rebuild_episodes_for_sources([seed_source])
    assert rebuild_result["rebuilt"] >= 1
    pre_episode = await memory_service.search(
        "蓝色",
        mode="episode",
        chat_id=session_id,
        respect_filter=False,
    )
    assert pre_episode.hits
    assert any("蓝色" in hit.content for hit in pre_episode.hits)
    # --- First chat message triggers a query_memory call and a feedback task. ---
    await chat_bot.message_process(
        _build_message_data(
            content="请帮我回忆一下,测试用户最喜欢的颜色是什么?",
            **session_info,
        )
    )
    await _wait_until(
        lambda: chat_feedback_env["planner_calls"][0] if chat_feedback_env["planner_calls"] else None,
        description="planner 收到首条聊天消息",
    )
    first_query_records = await _wait_until(
        lambda: _load_query_memory_tool_records(session_id) if _load_query_memory_tool_records(session_id) else None,
        description="首条 query_memory 工具记录生成",
    )
    assert first_query_records
    first_task = await _wait_until(
        lambda: _load_feedback_tasks(kernel)[0] if _load_feedback_tasks(kernel) else None,
        description="首个反馈任务入队",
    )
    assert first_task["status"] == "pending"
    first_hits = list((first_task.get("query_snapshot") or {}).get("hits") or [])
    assert first_hits
    assert any("蓝色" in str(item.get("content", "") or "") for item in first_hits)
    # --- User correction: colour is green, not blue. ---
    await chat_bot.message_process(
        _build_message_data(
            content="不对,测试用户最喜欢的颜色不是蓝色,是绿色。",
            **session_info,
        )
    )
    finalized_task = await _wait_until(
        lambda: (
            kernel.metadata_store.get_feedback_task(first_task["query_tool_id"])
            if kernel.metadata_store.get_feedback_task(first_task["query_tool_id"])
            and kernel.metadata_store.get_feedback_task(first_task["query_tool_id"]).get("status")
            in {"applied", "skipped", "error"}
            else None
        ),
        timeout_seconds=12.0,
        interval_seconds=0.1,
        description="反馈任务进入终态",
    )
    assert finalized_task["status"] == "applied", finalized_task
    assert finalized_task["decision_payload"]["decision"] == "correct"
    assert finalized_task["decision_payload"]["apply_result"]["applied"] is True
    corrected_hashes = list(
        (finalized_task["decision_payload"].get("apply_result") or {}).get("relation_hashes") or []
    )
    assert corrected_hashes
    corrected_hash = str(corrected_hashes[0] or "")
    # The corrected (old) relation must now be inactive.
    relation_status = kernel.metadata_store.get_relation_status_batch([corrected_hash]).get(corrected_hash, {})
    assert bool(relation_status.get("is_inactive")) is True
    # Every correction side effect should leave an action-log entry.
    action_types = _load_feedback_action_types(kernel)
    assert "classification" in action_types
    assert "forget_relation" in action_types
    assert "ingest_correction" in action_types
    assert "mark_stale_paragraph" in action_types
    assert "enqueue_episode_rebuild" in action_types
    assert "enqueue_profile_refresh" in action_types
    # --- Relation search now returns green and hides blue. ---
    direct_post_search = await memory_service.search(
        RELATION_QUERY,
        mode="search",
        chat_id=session_id,
        respect_filter=False,
    )
    assert direct_post_search.hits
    post_contents = "\n".join(hit.content for hit in direct_post_search.hits)
    assert "绿色" in post_contents
    assert "蓝色" not in post_contents
    # --- Profile refresh completes and reflects the correction. ---
    profile_refresh_request = await _wait_until(
        lambda: (
            kernel.metadata_store.get_person_profile_refresh_request(person_id)
            if kernel.metadata_store.get_person_profile_refresh_request(person_id)
            and kernel.metadata_store.get_person_profile_refresh_request(person_id).get("status") == "done"
            else None
        ),
        timeout_seconds=12.0,
        interval_seconds=0.1,
        description="人物画像刷新完成",
    )
    assert profile_refresh_request["status"] == "done"
    post_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
    post_profile_text = post_profile.summary + "\n" + json.dumps(post_profile.evidence, ensure_ascii=False)
    assert "绿色" in post_profile_text
    assert "蓝色" not in post_profile_text

    async def _latest_episode_result():
        # Truthy only once the rebuilt episodes mention green and no longer blue.
        result = await memory_service.search(
            "绿色",
            mode="episode",
            chat_id=session_id,
            respect_filter=False,
        )
        if not result.hits:
            return None
        contents = "\n".join(hit.content for hit in result.hits)
        if "绿色" in contents and "蓝色" not in contents:
            return result
        return None

    post_episode = await _wait_until(
        _latest_episode_result,
        timeout_seconds=12.0,
        interval_seconds=0.2,
        description="episode 重建后返回修正结果",
    )
    assert post_episode is not None
    stale_episode = await memory_service.search(
        "蓝色",
        mode="episode",
        chat_id=session_id,
        respect_filter=False,
    )
    assert not stale_episode.hits
    # --- A second chat query must surface only the corrected relation. ---
    await chat_bot.message_process(
        _build_message_data(
            content="再查一次,测试用户最喜欢的颜色是什么?",
            **session_info,
        )
    )
    tool_records = await _wait_until(
        lambda: (
            _load_query_memory_tool_records(session_id)
            if len(_load_query_memory_tool_records(session_id)) >= 2
            else None
        ),
        timeout_seconds=10.0,
        interval_seconds=0.1,
        description="第二次 query_memory 工具记录生成",
    )
    latest_tool_data = json.loads(str(tool_records[-1].get("tool_data") or "{}"))
    latest_hits = list((latest_tool_data.get("structured_content") or {}).get("hits") or [])
    assert latest_hits
    latest_contents = "\n".join(str(item.get("content", "") or "") for item in latest_hits)
    assert "绿色" in latest_contents
    assert "蓝色" not in latest_contents

View File

@@ -0,0 +1,396 @@
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict
import pytest
# Sentinel recording why project imports failed; None when all imports succeed.
IMPORT_ERROR: str | None = None
try:
    from src.A_memorix.core.retrieval.sparse_bm25 import SparseBM25Config, SparseBM25Index
    from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
    from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
except SystemExit as exc:
    # Config bootstrap may call sys.exit() at import time; capture the reason
    # and stub the names so collection succeeds and the module is skipped.
    IMPORT_ERROR = f"config initialization exited during import: {exc}"
    SparseBM25Config = None  # type: ignore[assignment]
    SparseBM25Index = None  # type: ignore[assignment]
    kernel_module = None  # type: ignore[assignment]
    SDKMemoryKernel = None  # type: ignore[assignment]
# Skip the entire module when project imports were unavailable.
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
@pytest.mark.asyncio
async def test_kernel_enqueue_feedback_task_delegates_to_metadata_store(monkeypatch: pytest.MonkeyPatch) -> None:
    """Enqueueing a feedback task forwards the payload to the metadata store,
    with due_at computed as query time plus the configured window."""
    captured: Dict[str, Any] = {}

    def fake_enqueue_feedback_task(**kwargs):
        # Record what the kernel passes through and echo a minimal task row.
        captured.update(kwargs)
        return {
            "id": 1,
            "query_tool_id": kwargs["query_tool_id"],
            "session_id": kwargs["session_id"],
            "query_timestamp": kwargs["query_timestamp"],
            "due_at": kwargs["due_at"],
            "query_snapshot": kwargs["query_snapshot"],
        }

    # Enable the feature with a 12-hour feedback window.
    monkeypatch.setattr(
        kernel_module,
        "global_config",
        SimpleNamespace(
            memory=SimpleNamespace(
                feedback_correction_enabled=True,
                feedback_correction_window_hours=12.0,
            )
        ),
    )
    query_time = datetime(2026, 4, 9, 10, 30, 0)
    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=fake_enqueue_feedback_task)
    payload = await kernel.enqueue_feedback_task(
        query_tool_id="tool-query-1",
        session_id="session-1",
        query_timestamp=query_time,
        structured_content={"query": "Alice 喜欢什么", "hits": [{"hash": "relation-1"}]},
    )
    assert payload["success"] is True
    assert payload["queued"] is True
    assert captured["query_tool_id"] == "tool-query-1"
    assert captured["session_id"] == "session-1"
    assert captured["query_snapshot"]["query"] == "Alice 喜欢什么"
    assert captured["query_snapshot"]["hits"] == [{"hash": "relation-1"}]
    # due_at = query timestamp + 12h window.
    assert captured["due_at"] == pytest.approx(query_time.timestamp() + 12 * 3600, rel=0, abs=1e-6)
@pytest.mark.asyncio
async def test_kernel_enqueue_feedback_task_skipped_when_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """When feedback correction is disabled, enqueueing is refused with a reason."""
    monkeypatch.setattr(
        kernel_module,
        "global_config",
        SimpleNamespace(memory=SimpleNamespace(feedback_correction_enabled=False)),
    )
    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=lambda **kwargs: kwargs)
    payload = await kernel.enqueue_feedback_task(
        query_tool_id="tool-query-2",
        session_id="session-1",
        query_timestamp=datetime.now(),
        structured_content={"hits": [{"hash": "relation-1"}]},
    )
    assert payload["success"] is False
    assert payload["reason"] == "feedback_correction_disabled"
@pytest.mark.asyncio
async def test_apply_feedback_decision_resolves_paragraph_targets() -> None:
    """A "correct" decision targeting a paragraph hash must resolve to its
    linked relations, forget them, ingest the corrected relation, mark the
    paragraph stale, and enqueue episode rebuild + profile refresh."""
    # Recorders for every side-effect channel of the apply path.
    action_logs: list[Dict[str, Any]] = []
    forgotten_hashes: list[str] = []
    ingested_payloads: list[Dict[str, Any]] = []
    stale_marks: list[Dict[str, Any]] = []
    episode_sources: list[str] = []
    profile_refresh_ids: list[str] = []
    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    # Fake metadata store: one paragraph ("paragraph-1") linked to one
    # relation ("relation-1"); everything else records calls.
    kernel.metadata_store = SimpleNamespace(
        get_paragraph_relations=lambda paragraph_hash: [
            {
                "hash": "relation-1",
                "subject": "测试用户",
                "predicate": "最喜欢的颜色是",
                "object": "蓝色",
            }
        ]
        if paragraph_hash == "paragraph-1"
        else [],
        get_relation_status_batch=lambda hashes: {
            str(hash_value): {"is_inactive": str(hash_value) in forgotten_hashes}
            for hash_value in hashes
        },
        get_paragraph_hashes_by_relation_hashes=lambda hashes: {
            "relation-1": ["paragraph-1"]
        }
        if "relation-1" in hashes
        else {},
        upsert_paragraph_stale_relation_mark=lambda **kwargs: stale_marks.append(kwargs) or kwargs,
        enqueue_episode_source_rebuild=lambda source, reason="": episode_sources.append(source) or True,
        enqueue_person_profile_refresh=lambda **kwargs: profile_refresh_ids.append(kwargs["person_id"]) or kwargs,
        get_paragraph=lambda paragraph_hash: {"hash": "paragraph-1", "source": "chat_feedback_test_seed:session-1"}
        if paragraph_hash == "paragraph-1"
        else None,
        append_feedback_action_log=lambda **kwargs: action_logs.append(kwargs),
    )
    # Stub the kernel's own config/action hooks used by the apply path.
    kernel._feedback_cfg_auto_apply_threshold = lambda: 0.85  # type: ignore[method-assign]
    kernel._apply_v5_relation_action = lambda *, action, hashes, strength=1.0: (  # type: ignore[method-assign]
        forgotten_hashes.extend([str(item) for item in hashes]),
        {"success": True, "action": action, "hashes": list(hashes), "strength": strength},
    )[1]
    kernel._feedback_cfg_paragraph_mark_enabled = lambda: True  # type: ignore[method-assign]
    kernel._feedback_cfg_episode_rebuild_enabled = lambda: True  # type: ignore[method-assign]
    kernel._feedback_cfg_profile_refresh_enabled = lambda: True  # type: ignore[method-assign]
    kernel._resolve_feedback_related_person_ids = lambda **kwargs: ["person-1"]  # type: ignore[method-assign]
    kernel._query_relation_rows_by_hashes = lambda relation_hashes, include_inactive=False: [  # type: ignore[method-assign]
        {
            "hash": "relation-1",
            "subject": "测试用户",
            "predicate": "最喜欢的颜色是",
            "object": "蓝色",
        }
    ]

    async def _fake_ingest_feedback_relations(**kwargs):
        ingested_payloads.append(kwargs)
        return {"success": True, "stored_ids": ["relation-2"]}

    kernel._ingest_feedback_relations = _fake_ingest_feedback_relations  # type: ignore[method-assign]
    payload = await kernel._apply_feedback_decision(
        task_id=1,
        query_tool_id="tool-query-1",
        session_id="session-1",
        decision={
            "decision": "correct",
            "confidence": 0.97,
            "target_hashes": ["paragraph-1"],
            "corrected_relations": [
                {
                    "subject": "测试用户",
                    "predicate": "最喜欢的颜色是",
                    "object": "绿色",
                    "confidence": 0.99,
                }
            ],
            "reason": "用户明确纠正为绿色",
        },
        hit_map={
            "paragraph-1": {
                "hash": "paragraph-1",
                "type": "paragraph",
                "content": "测试用户 最喜欢的颜色是 蓝色",
                "linked_relation_hashes": ["relation-1"],
            }
        },
    )
    assert payload["applied"] is True
    assert payload["relation_hashes"] == ["relation-1"]
    assert forgotten_hashes == ["relation-1"]
    assert ingested_payloads[0]["relation_hashes"] == ["relation-1"]
    assert payload["stale_paragraph_hashes"] == ["paragraph-1"]
    # Both the paragraph's own source and the session summary source get rebuilt.
    assert "chat_feedback_test_seed:session-1" in payload["episode_rebuild_sources"]
    assert "chat_summary:session-1" in payload["episode_rebuild_sources"]
    assert payload["profile_refresh_person_ids"] == ["person-1"]
    assert stale_marks[0]["paragraph_hash"] == "paragraph-1"
    assert {item["action_type"] for item in action_logs} == {
        "forget_relation",
        "ingest_correction",
        "mark_stale_paragraph",
        "enqueue_episode_rebuild",
        "enqueue_profile_refresh",
    }
def test_filter_active_relation_hits_removes_inactive_relations() -> None:
    """The hard filter must drop inactive relation hits, paragraphs linked
    only to inactive relations, and stale-marked paragraphs whose marked
    relation is still inactive — while keeping everything else."""
    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel._feedback_cfg_paragraph_hard_filter_enabled = lambda: True  # type: ignore[method-assign]
    kernel.metadata_store = SimpleNamespace(
        get_relation_status_batch=lambda hashes: {
            "r-active": {"is_inactive": False},
            "r-inactive": {"is_inactive": True},
            "r-para-inactive": {"is_inactive": True},
            "r-stale-active": {"is_inactive": False},
            "r-stale-inactive": {"is_inactive": True},
        },
        get_paragraph_relations=lambda paragraph_hash: (
            [{"hash": "r-para-inactive"}] if paragraph_hash == "p-inactive" else []
        ),
        get_paragraph_stale_relation_marks_batch=lambda paragraph_hashes: {
            "p-stale": [{"relation_hash": "r-stale-inactive"}],
            "p-restored": [{"relation_hash": "r-stale-active"}],
        },
    )
    hits = [
        {"hash": "r-active", "type": "relation", "content": "A 喜欢 B"},
        {"hash": "r-inactive", "type": "relation", "content": "A 讨厌 B"},
        {"hash": "p-1", "type": "paragraph", "content": "段落证据"},
        {"hash": "p-inactive", "type": "paragraph", "content": "失活段落证据"},
        {"hash": "p-stale", "type": "paragraph", "content": "被标脏段落"},
        {"hash": "p-restored", "type": "paragraph", "content": "恢复可见段落"},
    ]
    filtered = kernel._filter_active_relation_hits(hits)
    # p-restored survives because its stale-marked relation is active again.
    assert [item["hash"] for item in filtered] == ["r-active", "p-1", "p-restored"]
def test_sparse_relation_search_requests_active_only() -> None:
    """BM25 relation search must ask the metadata store for active rows only."""
    captured: Dict[str, Any] = {}

    class FakeMetadataStore:
        def fts_search_relations_bm25(self, **kwargs):
            # Record the query kwargs the index forwards.
            captured.update(kwargs)
            return []

    index = SparseBM25Index(
        metadata_store=FakeMetadataStore(),  # type: ignore[arg-type]
        config=SparseBM25Config(enabled=True, lazy_load=False),
    )
    # Pretend the index is already loaded so search hits the store directly.
    index._loaded = True
    index._conn = object()  # type: ignore[assignment]
    result = index.search_relations("测试纠错", k=5)
    assert result == []
    assert captured["include_inactive"] is False
@pytest.mark.asyncio
async def test_feedback_task_rollback_restores_snapshots_and_requeues_followups() -> None:
# Purpose: drive _rollback_feedback_task end-to-end against a fake metadata
# store and verify it (1) restores the forgotten relation from its snapshot,
# (2) reverts the corrected relation and deletes the correction paragraph,
# (3) clears stale marks, and (4) requeues episode/profile follow-ups.
# Recorders capturing every side effect the rollback performs.
action_logs: list[Dict[str, Any]] = []
queued_sources: list[str] = []
queued_profiles: list[str] = []
deleted_marks: list[tuple[str, str]] = []
deleted_paragraphs: list[str] = []
# Mutable relation state: "rel-old" was forgotten (inactive) by the original
# correction; "rel-new" is the corrected relation currently active.
relation_statuses: Dict[str, Dict[str, Any]] = {
"rel-old": {"is_inactive": True, "weight": 0.0, "is_pinned": False, "protected_until": 0.0, "last_reinforced": None, "inactive_since": 1.0},
"rel-new": {"is_inactive": False, "weight": 1.0, "is_pinned": False, "protected_until": 0.0, "last_reinforced": None, "inactive_since": None},
}
# Applied feedback task whose rollback_plan exercises every rollback branch:
# forgotten_relations (with a pre-correction snapshot), corrected_write,
# stale_marks, episode_sources, and profile_person_ids.
current_task: Dict[str, Any] = {
"id": 1,
"query_tool_id": "tool-query-rollback",
"session_id": "session-1",
"status": "applied",
"rollback_status": "none",
"query_snapshot": {"query": "测试用户最喜欢的颜色是什么"},
"decision_payload": {"decision": "correct", "confidence": 0.97},
"rollback_plan": {
"forgotten_relations": [
{
"hash": "rel-old",
"subject": "测试用户",
"predicate": "最喜欢的颜色是",
"object": "蓝色",
"before_status": {
"is_inactive": False,
"weight": 0.8,
"is_pinned": False,
"protected_until": 0.0,
"last_reinforced": None,
"inactive_since": None,
},
}
],
"corrected_write": {
"paragraph_hashes": ["paragraph-new"],
"corrected_relations": [
{
"hash": "rel-new",
"subject": "测试用户",
"predicate": "最喜欢的颜色是",
"object": "绿色",
"existed_before": False,
"before_status": {},
}
],
},
"stale_marks": [{"paragraph_hash": "paragraph-old", "relation_hash": "rel-old"}],
"episode_sources": ["chat_summary:session-1"],
"profile_person_ids": ["person-1"],
},
}
# Minimal DB connection stub; the rollback path only needs
# cursor()/execute()/commit() to succeed.
class _Conn:
def cursor(self):
return self
def execute(self, *_args, **_kwargs):
return self
def commit(self):
return None
# Fake metadata store wiring every call _rollback_feedback_task makes into
# the recorders / mutable state above.
metadata_store = SimpleNamespace(
get_feedback_task_by_id=lambda task_id: current_task if int(task_id) == 1 else None,
mark_feedback_task_rollback_running=lambda **kwargs: current_task.update({"rollback_status": "running"}) or current_task,
finalize_feedback_task_rollback=lambda **kwargs: current_task.update(
{
"rollback_status": kwargs["rollback_status"],
"rollback_result": kwargs.get("rollback_result") or {},
"rollback_error": kwargs.get("rollback_error", ""),
}
)
or current_task,
get_relation_status_batch=lambda hashes: {
hash_value: dict(relation_statuses[hash_value])
for hash_value in hashes
if hash_value in relation_statuses
},
restore_relation_status_from_snapshot=lambda hash_value, snapshot: relation_statuses.update(
{hash_value: dict(snapshot)}
)
or dict(snapshot),
append_feedback_action_log=lambda **kwargs: action_logs.append(kwargs),
mark_as_deleted=lambda hashes, type_: deleted_paragraphs.extend(list(hashes)) or len(list(hashes)),
get_paragraph=lambda paragraph_hash: {"hash": paragraph_hash, "source": "chat_summary:session-1"},
get_connection=lambda: _Conn(),
delete_external_memory_refs_by_paragraphs=lambda hashes: [
{"paragraph_hash": str(hash_value), "external_id": f"external:{hash_value}"}
for hash_value in hashes
],
update_relations_protection=lambda hashes, **kwargs: None,
mark_relations_inactive=lambda hashes, inactive_since=None: [
relation_statuses.__setitem__(
hash_value,
{
**relation_statuses.get(hash_value, {}),
"is_inactive": True,
"inactive_since": inactive_since,
},
)
for hash_value in hashes
],
delete_paragraph_stale_relation_marks=lambda marks: deleted_marks.extend(list(marks)) or len(list(marks)),
enqueue_episode_source_rebuild=lambda source, reason='': queued_sources.append(source) or True,
enqueue_person_profile_refresh=lambda **kwargs: queued_profiles.append(kwargs["person_id"]) or kwargs,
list_feedback_action_logs=lambda task_id: action_logs if int(task_id) == 1 else [],
)
# Kernel under test; heavyweight init/graph-rebuild/persist are stubbed out
# so only the rollback logic runs.
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
kernel.metadata_store = metadata_store
async def _noop_initialize() -> None:
return None
kernel.initialize = _noop_initialize # type: ignore[method-assign]
kernel._rebuild_graph_from_metadata = lambda: None # type: ignore[method-assign]
kernel._persist = lambda: None # type: ignore[method-assign]
# Execute the rollback under test.
payload = await kernel._rollback_feedback_task(
task_id=1,
requested_by="pytest",
reason="manual rollback",
)
assert payload["success"] is True
# Forgotten relation restored from snapshot; corrected relation deactivated.
assert relation_statuses["rel-old"]["is_inactive"] is False
assert relation_statuses["rel-new"]["is_inactive"] is True
# Correction paragraph removed and the stale mark cleared.
assert deleted_paragraphs == ["paragraph-new"]
assert deleted_marks == [("paragraph-old", "rel-old")]
# Follow-up rebuild/refresh queues populated and the task finalized.
assert queued_sources == ["chat_summary:session-1"]
assert queued_profiles == ["person-1"]
assert current_task["rollback_status"] == "rolled_back"
# Every rollback step left an action-log entry.
assert {item["action_type"] for item in action_logs} >= {
"rollback_restore_relation",
"rollback_revert_corrected_relation",
"rollback_delete_correction_paragraph",
"rollback_clear_stale_mark",
"rollback_enqueue_episode_rebuild",
"rollback_enqueue_profile_refresh",
}

View File

@@ -638,3 +638,36 @@ def test_delete_operation_routes(client: TestClient, monkeypatch):
assert list_response.json()["count"] == 1
assert get_response.status_code == 200
assert get_response.json()["operation"]["operation_id"] == "del-1"
def test_feedback_correction_routes(client: TestClient, monkeypatch):
    """Feedback-correction WebUI routes forward to memory_service.feedback_admin.

    Covers list, detail, and rollback endpoints and checks both the kwargs
    the route layer passes down and the payloads it returns.
    """

    async def stub_feedback_admin(*, action: str, **kwargs):
        # Each branch pins the exact kwargs expected from the route handler.
        if action == "list":
            assert kwargs == {"limit": 7, "status": "applied", "rollback_status": "none", "query": "green"}
            return {"success": True, "items": [{"task_id": 11, "query_text": "what color"}], "count": 1}
        elif action == "get":
            assert kwargs == {"task_id": 11}
            return {"success": True, "task": {"task_id": 11, "query_text": "what color", "action_logs": []}}
        elif action == "rollback":
            assert kwargs == {"task_id": 11, "requested_by": "tester", "reason": "manual revert"}
            return {"success": True, "result": {"restored_relation_hashes": ["rel-1"]}}
        raise AssertionError(action)

    monkeypatch.setattr(memory_router_module.memory_service, "feedback_admin", stub_feedback_admin)

    list_response = client.get(
        "/api/webui/memory/feedback-corrections",
        params={"limit": 7, "status": "applied", "rollback_status": "none", "query": "green"},
    )
    get_response = client.get("/api/webui/memory/feedback-corrections/11")
    rollback_response = client.post(
        "/api/webui/memory/feedback-corrections/11/rollback",
        json={"requested_by": "tester", "reason": "manual revert"},
    )

    assert list_response.status_code == 200
    assert list_response.json()["items"][0]["task_id"] == 11
    assert get_response.status_code == 200
    assert get_response.json()["task"]["task_id"] == 11
    assert rollback_response.status_code == 200
    assert rollback_response.json()["result"]["restored_relation_hashes"] == ["rel-1"]

View File

@@ -11,7 +11,7 @@ from __future__ import annotations
import asyncio
import time
from typing import Any, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import aiohttp
import numpy as np
@@ -29,6 +29,9 @@ logger = get_logger("A_Memorix.EmbeddingAPIAdapter")
class EmbeddingAPIAdapter:
"""适配宿主 embedding 请求接口。"""
_GLOBAL_DIMENSION_CACHE: Dict[str, int] = {}
_GLOBAL_TEXT_EMBEDDING_CACHE: Dict[Tuple[str, int, str], np.ndarray] = {}
def __init__(
self,
batch_size: int = 32,
@@ -232,10 +235,32 @@ class EmbeddingAPIAdapter:
logger.error(f"通过直接 Client 获取 Embedding 失败: {last_exc}")
return None
def _dimension_cache_key(self) -> str:
    """Build the process-wide cache key identifying this adapter's model config.

    The key combines the configured model name (falling back to "auto"),
    the configured default dimension, and the comma-joined candidate model
    names, separated by "|".
    """
    model_part = str(self.model_name or "auto")
    dimension_part = str(self.default_dimension)
    candidates_part = ",".join(self._resolve_candidate_model_names())
    return "|".join([model_part, dimension_part, candidates_part])
def _embedding_cache_key(self, text: str, dimensions: Optional[int]) -> Tuple[str, int, str]:
    """Cache key for one text embedding: (model key, canonical dimension, text)."""
    model_key = self._dimension_cache_key()
    canonical_dim = int(self._resolve_canonical_dimension(dimensions))
    return (model_key, canonical_dim, str(text or ""))
async def _detect_dimension(self) -> int:
if self._dimension_detected and self._dimension is not None:
return self._dimension
cache_key = self._dimension_cache_key()
cached_dimension = self._GLOBAL_DIMENSION_CACHE.get(cache_key)
if cached_dimension is not None:
self._dimension = int(cached_dimension)
self._dimension_detected = True
logger.info(f"嵌入维度命中进程缓存: {self._dimension}")
return self._dimension
logger.info("正在检测嵌入模型维度...")
try:
target_dim = self.default_dimension
@@ -251,6 +276,7 @@ class EmbeddingAPIAdapter:
)
self._dimension = detected_dim
self._dimension_detected = True
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(detected_dim)
return detected_dim
except Exception as exc:
logger.debug(f"带维度参数探测失败: {exc},尝试不带维度参数探测")
@@ -261,6 +287,7 @@ class EmbeddingAPIAdapter:
detected_dim = len(test_embedding)
self._dimension = detected_dim
self._dimension_detected = True
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(detected_dim)
logger.info(f"嵌入维度检测成功 (自然维度): {detected_dim}")
return detected_dim
logger.warning(f"嵌入维度检测失败,使用 configured_dimension: {self.default_dimension}")
@@ -269,6 +296,7 @@ class EmbeddingAPIAdapter:
self._dimension = self.default_dimension
self._dimension_detected = True
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(self.default_dimension)
return self.default_dimension
async def encode(
@@ -336,6 +364,25 @@ class EmbeddingAPIAdapter:
all_embeddings: List[np.ndarray] = []
for offset in range(0, len(texts), batch_size):
batch = texts[offset : offset + batch_size]
batch_results: List[Tuple[int, np.ndarray]] = []
uncached_items: List[Tuple[int, str]] = []
if self.enable_cache:
for index, text in enumerate(batch):
cache_key = self._embedding_cache_key(text, dimensions)
cached_vector = self._GLOBAL_TEXT_EMBEDDING_CACHE.get(cache_key)
if cached_vector is None:
uncached_items.append((index, text))
else:
batch_results.append((index, cached_vector.copy()))
else:
uncached_items = list(enumerate(batch))
if not uncached_items:
batch_results.sort(key=lambda item: item[0])
all_embeddings.extend(emb for _, emb in batch_results)
continue
semaphore = asyncio.Semaphore(self.max_concurrent)
async def encode_with_semaphore(text: str, index: int):
@@ -351,11 +398,20 @@ class EmbeddingAPIAdapter:
tasks = [
encode_with_semaphore(text, offset + index)
for index, text in enumerate(batch)
for index, text in uncached_items
]
results = await asyncio.gather(*tasks)
results.sort(key=lambda item: item[0])
all_embeddings.extend(emb for _, emb in results)
normalized_results: List[Tuple[int, np.ndarray]] = []
for batch_index, vector in results:
normalized_results.append((batch_index, vector))
if self.enable_cache:
text = batch[batch_index]
cache_key = self._embedding_cache_key(text, dimensions)
self._GLOBAL_TEXT_EMBEDDING_CACHE[cache_key] = vector.copy()
batch_results.extend(normalized_results)
batch_results.sort(key=lambda item: item[0])
all_embeddings.extend(emb for _, emb in batch_results)
return np.array(all_embeddings, dtype=np.float32)

View File

@@ -34,7 +34,7 @@ except Exception:
logger = get_logger("A_Memorix.MetadataStore")
SCHEMA_VERSION = 10
SCHEMA_VERSION = 12
RUNTIME_AUTO_MIGRATION_MIN_SCHEMA_VERSION = 9

View File

@@ -375,6 +375,30 @@ def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
"memory_feedback_tasks rollback columns missing under current schema version",
)
)
elif not has_stale_marks:
checks.append(
CheckItem(
"CP-15",
"error",
"paragraph_stale_relation_marks table missing under current schema version",
)
)
elif not has_profile_refresh_queue:
checks.append(
CheckItem(
"CP-16",
"error",
"person_profile_refresh_queue table missing under current schema version",
)
)
elif not has_feedback_rollback_status or not has_feedback_rollback_plan:
checks.append(
CheckItem(
"CP-17",
"error",
"memory_feedback_tasks rollback columns missing under current schema version",
)
)
if _sqlite_table_exists(conn, "relations"):
row = conn.execute("SELECT COUNT(*) FROM relations").fetchone()

View File

@@ -145,23 +145,23 @@ class VisualConfig(ConfigBase):
__ui_label__ = "视觉"
__ui_icon__ = "image"
planner_mode: Literal["text", "multimodal", "auto"] = Field(
default="auto",
multimodal_planner: bool = Field(
default=True,
json_schema_extra={
"x-widget": "select",
"x-icon": "git-branch",
"x-widget": "switch",
"x-icon": "image",
},
)
"""规划器模式auto根据模型信息自动选择text为纯文本模式multimodal为多模态模式"""
"""是否直接输入图片"""
replyer_mode: Literal["text", "multimodal", "auto"] = Field(
default="auto",
multimodal_replyer: bool = Field(
default=False,
json_schema_extra={
"x-widget": "select",
"x-widget": "switch",
"x-icon": "git-branch",
},
)
"""回复器模式auto根据模型信息自动选择text为纯文本模式multimodal为多模态模式"""
"""是否启用 Maisaka 多模态 replyer 生成器"""
visual_style: str = Field(
default="请用中文描述这张图片的内容。如果有文字请把文字描述概括出来请留意其主题直观感受输出为一段平文本最多30字请注意不要分点就输出一段文本",
@@ -239,12 +239,17 @@ class ChatConfig(ConfigBase):
)
"""Planner 连续被新消息打断的最大次数0 表示不启用打断"""
plan_reply_log_max_per_chat: int = Field(
default=1024,
json_schema_extra={
"x-widget": "input",
"x-icon": "file-text",
},
)
"""每个聊天流最大保存的Plan/Reply日志数量超过此数量时会自动删除最老的日志"""
group_chat_prompt: str = Field(
default="""
你正在qq群里聊天下面是群里正在聊的内容其中包含聊天记录和聊天中的图片。
回复尽量简短一些。最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。请注意把握聊天内容。
不要回复的太频繁!控制回复的频率,不要每个人的消息都回复,只回复你感兴趣的或者主动提及你的。
""",
default="你需要控制自己发言的频率,如果是一对一聊天,可以以较均匀的频率发言;如果用户较多,不要每句都回复,控制回复频率,不要回复的太频繁!控制回复的频率,不要每个人的消息都回复。",
json_schema_extra={
"x-widget": "textarea",
"x-icon": "users",
@@ -253,11 +258,7 @@ class ChatConfig(ConfigBase):
"""_wrap_群聊通用注意事项"""
private_chat_prompts: str = Field(
default="""
你正在聊天,下面是正在聊的内容,其中包含聊天记录和聊天中的图片。
回复尽量简短一些。请注意把握聊天内容。
请考虑对方的发言频率,想法,思考自己何时回复以及回复内容。
""",
default="你需要控制自己发言的频率,可以以较均匀的频率发言。",
json_schema_extra={
"x-widget": "textarea",
"x-icon": "user",
@@ -663,15 +664,6 @@ class LearningItem(ConfigBase):
)
"""是否启用jargon学习"""
advanced_chosen: bool = Field(
default=False,
json_schema_extra={
"x-widget": "switch",
"x-icon": "sparkles",
},
)
"""是否启用基于子代理的二次表达方式选择"""
class ExpressionGroup(ConfigBase):
"""表达互通组配置类,若列表为空代表全局共享"""
@@ -701,7 +693,6 @@ class ExpressionConfig(ConfigBase):
use_expression=True,
enable_learning=True,
enable_jargon_learning=True,
advanced_chosen=False,
)
],
json_schema_extra={
@@ -1573,6 +1564,35 @@ class MaiSakaConfig(ConfigBase):
)
"""MaiSaka 使用的用户名称"""
tool_filter_task_name: str = Field(
default="utils",
json_schema_extra={
"x-widget": "input",
"x-icon": "sparkles",
},
)
"""工具筛选预判使用的模型任务名"""
tool_filter_threshold: int = Field(
default=20,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "filter",
},
)
"""当可用工具总数超过该阈值时,先进行一轮工具筛选"""
tool_filter_max_keep: int = Field(
default=5,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "list-filter",
},
)
"""工具筛选阶段最多保留的非内置工具数量"""
show_image_path: bool = Field(
default=True,
json_schema_extra={