feat:同步本地非算法改动到上游基线
保留反馈纠错、WebUI 与运行时增强。\n移除不应提交的 algorithm_redesign 设计目录及其专项测试。
This commit is contained in:
@@ -496,6 +496,77 @@ export interface MemoryDeleteOperationDetailPayload {
|
||||
error?: string
|
||||
}
|
||||
|
||||
export interface MemoryFeedbackAffectedCountsPayload {
|
||||
relations?: number
|
||||
stale_paragraphs?: number
|
||||
episode_sources?: number
|
||||
profile_person_ids?: number
|
||||
correction_paragraphs?: number
|
||||
corrected_relations?: number
|
||||
}
|
||||
|
||||
export interface MemoryFeedbackActionLogPayload {
|
||||
id: number
|
||||
task_id: number
|
||||
query_tool_id: string
|
||||
action_type: string
|
||||
target_hash: string
|
||||
reason?: string
|
||||
before_payload?: Record<string, unknown>
|
||||
after_payload?: Record<string, unknown>
|
||||
created_at?: number
|
||||
}
|
||||
|
||||
export interface MemoryFeedbackCorrectionSummaryPayload {
|
||||
task_id: number
|
||||
query_tool_id: string
|
||||
session_id: string
|
||||
query_text: string
|
||||
query_timestamp?: number
|
||||
task_status: string
|
||||
decision: string
|
||||
decision_confidence: number
|
||||
feedback_message_count: number
|
||||
rollback_status: string
|
||||
affected_counts: MemoryFeedbackAffectedCountsPayload
|
||||
created_at?: number
|
||||
updated_at?: number
|
||||
}
|
||||
|
||||
export interface MemoryFeedbackCorrectionDetailTaskPayload extends MemoryFeedbackCorrectionSummaryPayload {
|
||||
query_snapshot?: Record<string, unknown>
|
||||
decision_payload?: Record<string, unknown>
|
||||
rollback_plan_summary?: Record<string, unknown>
|
||||
rollback_result?: Record<string, unknown>
|
||||
rollback_error?: string
|
||||
rollback_requested_by?: string
|
||||
rollback_reason?: string
|
||||
rollback_requested_at?: number
|
||||
rolled_back_at?: number
|
||||
action_logs?: MemoryFeedbackActionLogPayload[]
|
||||
}
|
||||
|
||||
export interface MemoryFeedbackCorrectionListPayload {
|
||||
success: boolean
|
||||
items: MemoryFeedbackCorrectionSummaryPayload[]
|
||||
count?: number
|
||||
error?: string
|
||||
}
|
||||
|
||||
export interface MemoryFeedbackCorrectionDetailPayload {
|
||||
success: boolean
|
||||
task?: MemoryFeedbackCorrectionDetailTaskPayload | null
|
||||
error?: string
|
||||
}
|
||||
|
||||
export interface MemoryFeedbackCorrectionRollbackPayload {
|
||||
success: boolean
|
||||
already_rolled_back?: boolean
|
||||
result?: Record<string, unknown>
|
||||
task?: MemoryFeedbackCorrectionDetailTaskPayload | null
|
||||
error?: string
|
||||
}
|
||||
|
||||
export interface MemorySourceItemPayload {
|
||||
source: string
|
||||
paragraph_count?: number
|
||||
@@ -610,6 +681,49 @@ export async function getMemoryDeleteOperation(
|
||||
return requestJson<MemoryDeleteOperationDetailPayload>(`/delete/operations/${encodeURIComponent(operationId)}`)
|
||||
}
|
||||
|
||||
export async function getMemoryFeedbackCorrections(
|
||||
options?: {
|
||||
limit?: number
|
||||
status?: string
|
||||
rollbackStatus?: string
|
||||
query?: string
|
||||
},
|
||||
): Promise<MemoryFeedbackCorrectionListPayload> {
|
||||
const params = new URLSearchParams({
|
||||
limit: String(options?.limit ?? 50),
|
||||
})
|
||||
if (options?.status?.trim()) {
|
||||
params.set('status', options.status.trim())
|
||||
}
|
||||
if (options?.rollbackStatus?.trim()) {
|
||||
params.set('rollback_status', options.rollbackStatus.trim())
|
||||
}
|
||||
if (options?.query?.trim()) {
|
||||
params.set('query', options.query.trim())
|
||||
}
|
||||
return requestJson<MemoryFeedbackCorrectionListPayload>(`/feedback-corrections?${params.toString()}`)
|
||||
}
|
||||
|
||||
export async function getMemoryFeedbackCorrection(
|
||||
taskId: number,
|
||||
): Promise<MemoryFeedbackCorrectionDetailPayload> {
|
||||
return requestJson<MemoryFeedbackCorrectionDetailPayload>(`/feedback-corrections/${taskId}`)
|
||||
}
|
||||
|
||||
export async function rollbackMemoryFeedbackCorrection(
|
||||
taskId: number,
|
||||
payload: {
|
||||
requested_by?: string
|
||||
reason?: string
|
||||
},
|
||||
): Promise<MemoryFeedbackCorrectionRollbackPayload> {
|
||||
return requestJson<MemoryFeedbackCorrectionRollbackPayload>(`/feedback-corrections/${taskId}/rollback`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(payload),
|
||||
})
|
||||
}
|
||||
|
||||
export async function getMemorySources(): Promise<MemorySourceListPayload> {
|
||||
return requestJson<MemorySourceListPayload>('/sources')
|
||||
}
|
||||
|
||||
@@ -81,9 +81,12 @@ vi.mock('@/lib/memory-api', () => ({
|
||||
getMemorySources: vi.fn(),
|
||||
getMemoryDeleteOperations: vi.fn(),
|
||||
getMemoryDeleteOperation: vi.fn(),
|
||||
getMemoryFeedbackCorrections: vi.fn(),
|
||||
getMemoryFeedbackCorrection: vi.fn(),
|
||||
previewMemoryDelete: vi.fn(),
|
||||
executeMemoryDelete: vi.fn(),
|
||||
restoreMemoryDelete: vi.fn(),
|
||||
rollbackMemoryFeedbackCorrection: vi.fn(),
|
||||
}))
|
||||
|
||||
function mockImportTask(taskId: string, status: string = 'running'): memoryApi.MemoryImportTaskPayload {
|
||||
@@ -357,6 +360,82 @@ describe('KnowledgeBasePage import workflow', () => {
|
||||
items: [],
|
||||
},
|
||||
})
|
||||
vi.mocked(memoryApi.getMemoryFeedbackCorrections).mockResolvedValue({
|
||||
success: true,
|
||||
items: [
|
||||
{
|
||||
task_id: 11,
|
||||
query_tool_id: 'tool-query-11',
|
||||
session_id: 'session-1',
|
||||
query_text: '测试用户最喜欢的颜色是什么',
|
||||
query_timestamp: 1_710_000_010,
|
||||
task_status: 'applied',
|
||||
decision: 'correct',
|
||||
decision_confidence: 0.97,
|
||||
feedback_message_count: 1,
|
||||
rollback_status: 'none',
|
||||
affected_counts: {
|
||||
relations: 1,
|
||||
stale_paragraphs: 1,
|
||||
episode_sources: 2,
|
||||
profile_person_ids: 1,
|
||||
correction_paragraphs: 1,
|
||||
corrected_relations: 1,
|
||||
},
|
||||
created_at: 1_710_000_011,
|
||||
updated_at: 1_710_000_012,
|
||||
},
|
||||
],
|
||||
count: 1,
|
||||
})
|
||||
vi.mocked(memoryApi.getMemoryFeedbackCorrection).mockResolvedValue({
|
||||
success: true,
|
||||
task: {
|
||||
task_id: 11,
|
||||
query_tool_id: 'tool-query-11',
|
||||
session_id: 'session-1',
|
||||
query_text: '测试用户最喜欢的颜色是什么',
|
||||
query_timestamp: 1_710_000_010,
|
||||
task_status: 'applied',
|
||||
decision: 'correct',
|
||||
decision_confidence: 0.97,
|
||||
feedback_message_count: 1,
|
||||
rollback_status: 'none',
|
||||
affected_counts: {
|
||||
relations: 1,
|
||||
stale_paragraphs: 1,
|
||||
episode_sources: 2,
|
||||
profile_person_ids: 1,
|
||||
correction_paragraphs: 1,
|
||||
corrected_relations: 1,
|
||||
},
|
||||
query_snapshot: { query: '测试用户最喜欢的颜色是什么', hits: [{ hash: 'paragraph-1' }] },
|
||||
decision_payload: { decision: 'correct', confidence: 0.97 },
|
||||
rollback_plan_summary: {
|
||||
forgotten_relations: [{ hash: 'rel-old', subject: '测试用户', predicate: '最喜欢的颜色是', object: '蓝色' }],
|
||||
corrected_write: {
|
||||
paragraph_hashes: ['paragraph-new'],
|
||||
corrected_relations: [{ hash: 'rel-new', subject: '测试用户', predicate: '最喜欢的颜色是', object: '绿色' }],
|
||||
},
|
||||
},
|
||||
rollback_result: {},
|
||||
action_logs: [
|
||||
{
|
||||
id: 1,
|
||||
task_id: 11,
|
||||
query_tool_id: 'tool-query-11',
|
||||
action_type: 'forget_relation',
|
||||
target_hash: 'rel-old',
|
||||
reason: '用户明确纠正为绿色',
|
||||
before_payload: { hash: 'rel-old', subject: '测试用户', predicate: '最喜欢的颜色是', object: '蓝色' },
|
||||
after_payload: { is_inactive: true },
|
||||
created_at: 1_710_000_013,
|
||||
},
|
||||
],
|
||||
created_at: 1_710_000_011,
|
||||
updated_at: 1_710_000_012,
|
||||
},
|
||||
})
|
||||
vi.mocked(memoryApi.previewMemoryDelete).mockResolvedValue({
|
||||
success: true,
|
||||
mode: 'source',
|
||||
@@ -380,6 +459,37 @@ describe('KnowledgeBasePage import workflow', () => {
|
||||
deleted_source_count: 1,
|
||||
} as never)
|
||||
vi.mocked(memoryApi.restoreMemoryDelete).mockResolvedValue({ success: true } as never)
|
||||
vi.mocked(memoryApi.rollbackMemoryFeedbackCorrection).mockResolvedValue({
|
||||
success: true,
|
||||
result: { restored_relation_hashes: ['rel-old'] },
|
||||
task: {
|
||||
task_id: 11,
|
||||
query_tool_id: 'tool-query-11',
|
||||
session_id: 'session-1',
|
||||
query_text: '测试用户最喜欢的颜色是什么',
|
||||
query_timestamp: 1_710_000_010,
|
||||
task_status: 'applied',
|
||||
decision: 'correct',
|
||||
decision_confidence: 0.97,
|
||||
feedback_message_count: 1,
|
||||
rollback_status: 'rolled_back',
|
||||
affected_counts: {
|
||||
relations: 1,
|
||||
stale_paragraphs: 1,
|
||||
episode_sources: 2,
|
||||
profile_person_ids: 1,
|
||||
correction_paragraphs: 1,
|
||||
corrected_relations: 1,
|
||||
},
|
||||
query_snapshot: { query: '测试用户最喜欢的颜色是什么', hits: [{ hash: 'paragraph-1' }] },
|
||||
decision_payload: { decision: 'correct', confidence: 0.97 },
|
||||
rollback_plan_summary: {},
|
||||
rollback_result: { restored_relation_hashes: ['rel-old'] },
|
||||
action_logs: [],
|
||||
created_at: 1_710_000_011,
|
||||
updated_at: 1_710_000_012,
|
||||
},
|
||||
})
|
||||
vi.mocked(memoryApi.refreshMemoryRuntimeSelfCheck).mockResolvedValue({
|
||||
success: true,
|
||||
report: { ok: true },
|
||||
@@ -619,4 +729,27 @@ describe('KnowledgeBasePage import workflow', () => {
|
||||
}),
|
||||
)
|
||||
}, 20_000)
|
||||
|
||||
it('shows feedback correction history and supports rollback', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<KnowledgeBasePage />)
|
||||
|
||||
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
|
||||
await user.click(screen.getByRole('tab', { name: '纠错历史' }))
|
||||
await screen.findByText('反馈纠错历史')
|
||||
await screen.findByText('测试用户最喜欢的颜色是什么')
|
||||
await waitFor(() => expect(memoryApi.getMemoryFeedbackCorrection).toHaveBeenCalledWith(11))
|
||||
|
||||
await user.click(screen.getByRole('button', { name: '回退本次纠错' }))
|
||||
const rollbackReason = await screen.findByLabelText('回退原因')
|
||||
await user.type(rollbackReason, '人工确认回退')
|
||||
await user.click(screen.getByRole('button', { name: '确认回退' }))
|
||||
|
||||
await waitFor(() =>
|
||||
expect(memoryApi.rollbackMemoryFeedbackCorrection).toHaveBeenCalledWith(11, {
|
||||
requested_by: 'knowledge_base',
|
||||
reason: '人工确认回退',
|
||||
}),
|
||||
)
|
||||
}, 20_000)
|
||||
})
|
||||
|
||||
@@ -24,6 +24,14 @@ import { Badge } from '@/components/ui/badge'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'
|
||||
import { Checkbox } from '@/components/ui/checkbox'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Label } from '@/components/ui/label'
|
||||
import { Progress } from '@/components/ui/progress'
|
||||
@@ -48,6 +56,8 @@ import {
|
||||
createMemoryRawScanImport,
|
||||
createMemoryTemporalBackfillImport,
|
||||
executeMemoryDelete,
|
||||
getMemoryFeedbackCorrection,
|
||||
getMemoryFeedbackCorrections,
|
||||
getMemoryImportPathAliases,
|
||||
getMemoryImportSettings,
|
||||
getMemoryImportTask,
|
||||
@@ -74,6 +84,7 @@ import {
|
||||
type MemoryImportTaskPayload,
|
||||
previewMemoryDelete,
|
||||
refreshMemoryRuntimeSelfCheck,
|
||||
rollbackMemoryFeedbackCorrection,
|
||||
resolveMemoryImportPath,
|
||||
retryMemoryImportTask,
|
||||
restoreMemoryDelete,
|
||||
@@ -82,6 +93,9 @@ import {
|
||||
type MemoryConfigSchemaPayload,
|
||||
type MemoryDeleteExecutePayload,
|
||||
type MemoryDeleteOperationPayload,
|
||||
type MemoryFeedbackActionLogPayload,
|
||||
type MemoryFeedbackCorrectionDetailTaskPayload,
|
||||
type MemoryFeedbackCorrectionSummaryPayload,
|
||||
type MemorySourceItemPayload,
|
||||
type MemoryRuntimeConfigPayload,
|
||||
type MemoryTaskPayload,
|
||||
@@ -90,6 +104,9 @@ import {
|
||||
const DELETE_OPERATION_FETCH_LIMIT = 100
|
||||
const DELETE_OPERATION_PAGE_SIZE = 6
|
||||
const DELETE_OPERATION_ITEM_PAGE_SIZE = 8
|
||||
const FEEDBACK_CORRECTION_FETCH_LIMIT = 100
|
||||
const FEEDBACK_CORRECTION_PAGE_SIZE = 6
|
||||
const FEEDBACK_ACTION_LOG_PAGE_SIZE = 8
|
||||
const IMPORT_CHUNK_PAGE_SIZE = 50
|
||||
|
||||
const RUNNING_IMPORT_STATUS = new Set(['preparing', 'running', 'cancel_requested'])
|
||||
@@ -270,6 +287,90 @@ function formatDeleteOperationTime(timestamp?: number | null): string {
|
||||
})
|
||||
}
|
||||
|
||||
function formatFeedbackDecision(decision: string): string {
|
||||
switch (decision) {
|
||||
case 'correct':
|
||||
return '纠正'
|
||||
case 'reject':
|
||||
return '否定'
|
||||
case 'confirm':
|
||||
return '确认'
|
||||
case 'supplement':
|
||||
return '补充'
|
||||
case 'none':
|
||||
return '无动作'
|
||||
default:
|
||||
return decision || '未知'
|
||||
}
|
||||
}
|
||||
|
||||
function formatFeedbackTaskStatus(status: string): string {
|
||||
switch (status) {
|
||||
case 'pending':
|
||||
return '待处理'
|
||||
case 'running':
|
||||
return '处理中'
|
||||
case 'applied':
|
||||
return '已应用'
|
||||
case 'skipped':
|
||||
return '已跳过'
|
||||
case 'error':
|
||||
return '失败'
|
||||
default:
|
||||
return status || '未知'
|
||||
}
|
||||
}
|
||||
|
||||
function formatFeedbackRollbackStatus(status: string): string {
|
||||
switch (status) {
|
||||
case 'none':
|
||||
return '未回退'
|
||||
case 'running':
|
||||
return '回退中'
|
||||
case 'rolled_back':
|
||||
return '已回退'
|
||||
case 'error':
|
||||
return '回退失败'
|
||||
default:
|
||||
return status || '未知'
|
||||
}
|
||||
}
|
||||
|
||||
function getFeedbackStatusVariant(
|
||||
status: string,
|
||||
): 'default' | 'secondary' | 'destructive' | 'outline' {
|
||||
if (status === 'applied' || status === 'rolled_back') {
|
||||
return 'default'
|
||||
}
|
||||
if (status === 'error') {
|
||||
return 'destructive'
|
||||
}
|
||||
if (status === 'running' || status === 'pending') {
|
||||
return 'outline'
|
||||
}
|
||||
return 'secondary'
|
||||
}
|
||||
|
||||
function summarizeFeedbackActionPayload(value: Record<string, unknown> | undefined): string {
|
||||
if (!value) {
|
||||
return ''
|
||||
}
|
||||
const hash = String(value.hash ?? '').trim()
|
||||
const subject = String(value.subject ?? '').trim()
|
||||
const predicate = String(value.predicate ?? '').trim()
|
||||
const object = String(value.object ?? '').trim()
|
||||
if (subject && predicate && object) {
|
||||
return formatDeleteRelationText(subject, predicate, object)
|
||||
}
|
||||
if (hash) {
|
||||
return hash
|
||||
}
|
||||
if (Array.isArray(value.target_hashes) && value.target_hashes.length > 0) {
|
||||
return `targets ${value.target_hashes.length}`
|
||||
}
|
||||
return trimDeleteItemText(JSON.stringify(value, null, 2), 120)
|
||||
}
|
||||
|
||||
type DeleteOperationItem = NonNullable<MemoryDeleteOperationPayload['items']>[number]
|
||||
|
||||
function trimDeleteItemText(value: string, maxLength: number = 140): string {
|
||||
@@ -471,6 +572,20 @@ export function KnowledgeBasePage() {
|
||||
const [deleteRestoring, setDeleteRestoring] = useState(false)
|
||||
const [deleteResult, setDeleteResult] = useState<MemoryDeleteExecutePayload | null>(null)
|
||||
const [pendingDeleteRequest, setPendingDeleteRequest] = useState<MemoryDeleteRequestPayload | null>(null)
|
||||
const [feedbackCorrections, setFeedbackCorrections] = useState<MemoryFeedbackCorrectionSummaryPayload[]>([])
|
||||
const [feedbackSearch, setFeedbackSearch] = useState('')
|
||||
const [feedbackStatusFilter, setFeedbackStatusFilter] = useState('all')
|
||||
const [feedbackRollbackFilter, setFeedbackRollbackFilter] = useState('all')
|
||||
const [feedbackPage, setFeedbackPage] = useState(1)
|
||||
const [selectedFeedbackTaskId, setSelectedFeedbackTaskId] = useState(0)
|
||||
const [selectedFeedbackTaskDetail, setSelectedFeedbackTaskDetail] = useState<MemoryFeedbackCorrectionDetailTaskPayload | null>(null)
|
||||
const [selectedFeedbackTaskLoading, setSelectedFeedbackTaskLoading] = useState(false)
|
||||
const [selectedFeedbackTaskError, setSelectedFeedbackTaskError] = useState('')
|
||||
const [feedbackActionLogSearch, setFeedbackActionLogSearch] = useState('')
|
||||
const [feedbackActionLogPage, setFeedbackActionLogPage] = useState(1)
|
||||
const [feedbackRollbackDialogOpen, setFeedbackRollbackDialogOpen] = useState(false)
|
||||
const [feedbackRollbackReason, setFeedbackRollbackReason] = useState('')
|
||||
const [feedbackRollingBack, setFeedbackRollingBack] = useState(false)
|
||||
const [tuningObjective, setTuningObjective] = useState('precision_priority')
|
||||
const [tuningIntensity, setTuningIntensity] = useState('standard')
|
||||
const [tuningSampleSize, setTuningSampleSize] = useState('24')
|
||||
@@ -491,6 +606,7 @@ export function KnowledgeBasePage() {
|
||||
tuningTaskPayload,
|
||||
sourcePayload,
|
||||
deleteOperationPayload,
|
||||
feedbackCorrectionPayload,
|
||||
] = await Promise.all([
|
||||
getMemoryConfigSchema(),
|
||||
getMemoryConfig(),
|
||||
@@ -503,6 +619,7 @@ export function KnowledgeBasePage() {
|
||||
getMemoryTuningTasks(20),
|
||||
getMemorySources(),
|
||||
getMemoryDeleteOperations(DELETE_OPERATION_FETCH_LIMIT),
|
||||
getMemoryFeedbackCorrections({ limit: FEEDBACK_CORRECTION_FETCH_LIMIT }),
|
||||
])
|
||||
|
||||
setSchemaPayload(schema)
|
||||
@@ -519,6 +636,7 @@ export function KnowledgeBasePage() {
|
||||
setTuningTasks(tuningTaskPayload.items ?? [])
|
||||
setMemorySources(sourcePayload.items ?? [])
|
||||
setDeleteOperations(deleteOperationPayload.items ?? [])
|
||||
setFeedbackCorrections(feedbackCorrectionPayload.items ?? [])
|
||||
if (!selectedImportTaskId && (importTaskPayload.items ?? []).length > 0) {
|
||||
const initialTaskId = String(importTaskPayload.items?.[0]?.task_id ?? '')
|
||||
if (initialTaskId) {
|
||||
@@ -1494,6 +1612,212 @@ export function KnowledgeBasePage() {
|
||||
setDeleteDialogOpen(true)
|
||||
}, [])
|
||||
|
||||
const filteredFeedbackCorrections = useMemo(() => {
|
||||
const keyword = feedbackSearch.trim().toLowerCase()
|
||||
return feedbackCorrections.filter((item) => {
|
||||
const taskStatus = String(item.task_status ?? '').trim().toLowerCase()
|
||||
const rollbackStatus = String(item.rollback_status ?? '').trim().toLowerCase()
|
||||
if (feedbackStatusFilter !== 'all' && taskStatus !== feedbackStatusFilter) {
|
||||
return false
|
||||
}
|
||||
if (feedbackRollbackFilter !== 'all' && rollbackStatus !== feedbackRollbackFilter) {
|
||||
return false
|
||||
}
|
||||
if (!keyword) {
|
||||
return true
|
||||
}
|
||||
return [
|
||||
item.query_tool_id,
|
||||
item.session_id,
|
||||
item.query_text,
|
||||
item.decision,
|
||||
item.task_status,
|
||||
item.rollback_status,
|
||||
]
|
||||
.map((value) => String(value ?? '').toLowerCase())
|
||||
.some((value) => value.includes(keyword))
|
||||
})
|
||||
}, [feedbackCorrections, feedbackRollbackFilter, feedbackSearch, feedbackStatusFilter])
|
||||
|
||||
const feedbackPageCount = Math.max(1, Math.ceil(filteredFeedbackCorrections.length / FEEDBACK_CORRECTION_PAGE_SIZE))
|
||||
const pagedFeedbackCorrections = useMemo(() => {
|
||||
const start = (feedbackPage - 1) * FEEDBACK_CORRECTION_PAGE_SIZE
|
||||
return filteredFeedbackCorrections.slice(start, start + FEEDBACK_CORRECTION_PAGE_SIZE)
|
||||
}, [feedbackPage, filteredFeedbackCorrections])
|
||||
|
||||
const selectedFeedbackCorrection = useMemo(
|
||||
() =>
|
||||
filteredFeedbackCorrections.find((item) => item.task_id === selectedFeedbackTaskId)
|
||||
?? pagedFeedbackCorrections[0]
|
||||
?? null,
|
||||
[filteredFeedbackCorrections, pagedFeedbackCorrections, selectedFeedbackTaskId],
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
setFeedbackPage(1)
|
||||
}, [feedbackSearch, feedbackStatusFilter, feedbackRollbackFilter])
|
||||
|
||||
useEffect(() => {
|
||||
if (feedbackPage > feedbackPageCount) {
|
||||
setFeedbackPage(feedbackPageCount)
|
||||
}
|
||||
}, [feedbackPage, feedbackPageCount])
|
||||
|
||||
useEffect(() => {
|
||||
if (!selectedFeedbackCorrection) {
|
||||
if (selectedFeedbackTaskId) {
|
||||
setSelectedFeedbackTaskId(0)
|
||||
}
|
||||
setSelectedFeedbackTaskDetail(null)
|
||||
setSelectedFeedbackTaskError('')
|
||||
return
|
||||
}
|
||||
if (selectedFeedbackCorrection.task_id !== selectedFeedbackTaskId) {
|
||||
setSelectedFeedbackTaskId(selectedFeedbackCorrection.task_id)
|
||||
}
|
||||
}, [selectedFeedbackCorrection, selectedFeedbackTaskId])
|
||||
|
||||
useEffect(() => {
|
||||
const taskId = selectedFeedbackCorrection?.task_id
|
||||
if (!taskId) {
|
||||
setSelectedFeedbackTaskDetail(null)
|
||||
setSelectedFeedbackTaskError('')
|
||||
return
|
||||
}
|
||||
|
||||
let cancelled = false
|
||||
setSelectedFeedbackTaskLoading(true)
|
||||
setSelectedFeedbackTaskError('')
|
||||
|
||||
void getMemoryFeedbackCorrection(taskId)
|
||||
.then((payload) => {
|
||||
if (cancelled) {
|
||||
return
|
||||
}
|
||||
if (!payload.success || !payload.task) {
|
||||
setSelectedFeedbackTaskDetail(null)
|
||||
setSelectedFeedbackTaskError(payload.error || '未能加载纠错任务详情')
|
||||
return
|
||||
}
|
||||
setSelectedFeedbackTaskDetail(payload.task)
|
||||
})
|
||||
.catch((error) => {
|
||||
if (cancelled) {
|
||||
return
|
||||
}
|
||||
setSelectedFeedbackTaskDetail(null)
|
||||
setSelectedFeedbackTaskError(error instanceof Error ? error.message : '未能加载纠错任务详情')
|
||||
})
|
||||
.finally(() => {
|
||||
if (!cancelled) {
|
||||
setSelectedFeedbackTaskLoading(false)
|
||||
}
|
||||
})
|
||||
|
||||
return () => {
|
||||
cancelled = true
|
||||
}
|
||||
}, [selectedFeedbackCorrection?.task_id])
|
||||
|
||||
const selectedFeedbackResolved = useMemo(() => {
|
||||
if (!selectedFeedbackCorrection) {
|
||||
return null
|
||||
}
|
||||
if (selectedFeedbackTaskDetail?.task_id === selectedFeedbackCorrection.task_id) {
|
||||
return {
|
||||
...selectedFeedbackCorrection,
|
||||
...selectedFeedbackTaskDetail,
|
||||
} satisfies MemoryFeedbackCorrectionDetailTaskPayload
|
||||
}
|
||||
return selectedFeedbackTaskDetail ?? selectedFeedbackCorrection
|
||||
}, [selectedFeedbackCorrection, selectedFeedbackTaskDetail])
|
||||
|
||||
const selectedFeedbackActionLogs = Array.isArray(selectedFeedbackResolved?.action_logs)
|
||||
? selectedFeedbackResolved.action_logs
|
||||
: []
|
||||
const filteredFeedbackActionLogs = useMemo(() => {
|
||||
const keyword = feedbackActionLogSearch.trim().toLowerCase()
|
||||
if (!keyword) {
|
||||
return selectedFeedbackActionLogs
|
||||
}
|
||||
return selectedFeedbackActionLogs.filter((item) =>
|
||||
[
|
||||
item.action_type,
|
||||
item.target_hash,
|
||||
item.reason,
|
||||
summarizeFeedbackActionPayload(item.before_payload),
|
||||
summarizeFeedbackActionPayload(item.after_payload),
|
||||
]
|
||||
.map((value) => String(value ?? '').toLowerCase())
|
||||
.some((value) => value.includes(keyword)),
|
||||
)
|
||||
}, [feedbackActionLogSearch, selectedFeedbackActionLogs])
|
||||
const feedbackActionLogPageCount = Math.max(
|
||||
1,
|
||||
Math.ceil(filteredFeedbackActionLogs.length / FEEDBACK_ACTION_LOG_PAGE_SIZE),
|
||||
)
|
||||
const pagedFeedbackActionLogs = useMemo(() => {
|
||||
const start = (feedbackActionLogPage - 1) * FEEDBACK_ACTION_LOG_PAGE_SIZE
|
||||
return filteredFeedbackActionLogs.slice(start, start + FEEDBACK_ACTION_LOG_PAGE_SIZE)
|
||||
}, [feedbackActionLogPage, filteredFeedbackActionLogs])
|
||||
|
||||
useEffect(() => {
|
||||
setFeedbackActionLogPage(1)
|
||||
}, [selectedFeedbackTaskId, feedbackActionLogSearch])
|
||||
|
||||
useEffect(() => {
|
||||
if (feedbackActionLogPage > feedbackActionLogPageCount) {
|
||||
setFeedbackActionLogPage(feedbackActionLogPageCount)
|
||||
}
|
||||
}, [feedbackActionLogPage, feedbackActionLogPageCount])
|
||||
|
||||
const openFeedbackRollbackDialog = useCallback(() => {
|
||||
setFeedbackRollbackReason('')
|
||||
setFeedbackRollbackDialogOpen(true)
|
||||
}, [])
|
||||
|
||||
const executeFeedbackRollback = useCallback(async () => {
|
||||
const taskId = selectedFeedbackResolved?.task_id
|
||||
if (!taskId) {
|
||||
return
|
||||
}
|
||||
try {
|
||||
setFeedbackRollingBack(true)
|
||||
const payload = await rollbackMemoryFeedbackCorrection(taskId, {
|
||||
requested_by: 'knowledge_base',
|
||||
reason: feedbackRollbackReason.trim(),
|
||||
})
|
||||
if (!payload.success) {
|
||||
throw new Error(payload.error || '回退失败')
|
||||
}
|
||||
toast({
|
||||
title: payload.already_rolled_back ? '该纠错已回退' : '纠错回退成功',
|
||||
description: `任务 ${taskId} 的回退结果已写入日志`,
|
||||
})
|
||||
setFeedbackRollbackDialogOpen(false)
|
||||
const [listPayload, detailPayload] = await Promise.all([
|
||||
getMemoryFeedbackCorrections({ limit: FEEDBACK_CORRECTION_FETCH_LIMIT }),
|
||||
getMemoryFeedbackCorrection(taskId),
|
||||
])
|
||||
setFeedbackCorrections(listPayload.items ?? [])
|
||||
setSelectedFeedbackTaskDetail(detailPayload.task ?? null)
|
||||
const [sourcePayload, runtimePayload] = await Promise.all([
|
||||
getMemorySources(),
|
||||
getMemoryRuntimeConfig(),
|
||||
])
|
||||
setMemorySources(sourcePayload.items ?? [])
|
||||
setRuntimeConfig(runtimePayload)
|
||||
} catch (error) {
|
||||
toast({
|
||||
title: '纠错回退失败',
|
||||
description: error instanceof Error ? error.message : '未知错误',
|
||||
variant: 'destructive',
|
||||
})
|
||||
} finally {
|
||||
setFeedbackRollingBack(false)
|
||||
}
|
||||
}, [feedbackRollbackReason, selectedFeedbackResolved?.task_id, toast])
|
||||
|
||||
const selectedOperationResolved = useMemo(() => {
|
||||
if (!selectedDeleteOperation) {
|
||||
return null
|
||||
@@ -1776,6 +2100,9 @@ export function KnowledgeBasePage() {
|
||||
<TabsTrigger value="delete" className="rounded-lg px-4 py-1.5">
|
||||
删除
|
||||
</TabsTrigger>
|
||||
<TabsTrigger value="feedback" className="rounded-lg px-4 py-1.5">
|
||||
纠错历史
|
||||
</TabsTrigger>
|
||||
</TabsList>
|
||||
|
||||
<TabsContent value="overview" className="space-y-4">
|
||||
@@ -3314,6 +3641,327 @@ export function KnowledgeBasePage() {
|
||||
</Card>
|
||||
</div>
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value="feedback" className="space-y-4">
|
||||
<div className="space-y-4">
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle className="flex items-center gap-2">
|
||||
<RotateCcw className="h-4 w-4" />
|
||||
反馈纠错历史
|
||||
</CardTitle>
|
||||
<CardDescription>
|
||||
查看 feedback correction 的判定、修改轨迹与回退结果;本期仅覆盖自动纠错任务
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-4">
|
||||
<div className="grid gap-3 lg:grid-cols-[minmax(0,1fr)_180px_180px]">
|
||||
<Input
|
||||
value={feedbackSearch}
|
||||
onChange={(event) => setFeedbackSearch(event.target.value)}
|
||||
placeholder="搜索 query_tool_id / session / query / reason"
|
||||
/>
|
||||
<Select value={feedbackStatusFilter} onValueChange={setFeedbackStatusFilter}>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="按任务状态筛选" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="all">全部任务状态</SelectItem>
|
||||
<SelectItem value="applied">已应用</SelectItem>
|
||||
<SelectItem value="skipped">已跳过</SelectItem>
|
||||
<SelectItem value="error">失败</SelectItem>
|
||||
<SelectItem value="running">处理中</SelectItem>
|
||||
<SelectItem value="pending">待处理</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<Select value={feedbackRollbackFilter} onValueChange={setFeedbackRollbackFilter}>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="按回退状态筛选" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="all">全部回退状态</SelectItem>
|
||||
<SelectItem value="none">未回退</SelectItem>
|
||||
<SelectItem value="rolled_back">已回退</SelectItem>
|
||||
<SelectItem value="error">回退失败</SelectItem>
|
||||
<SelectItem value="running">回退中</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-wrap items-center justify-between gap-2 text-sm text-muted-foreground">
|
||||
<span>当前命中 {filteredFeedbackCorrections.length} 条记录,已加载最近 {feedbackCorrections.length} 条</span>
|
||||
<span>第 {feedbackPage} / {feedbackPageCount} 页,每页显示 {FEEDBACK_CORRECTION_PAGE_SIZE} 条</span>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-4 xl:grid-cols-[minmax(0,0.92fr)_minmax(0,1.08fr)]">
|
||||
<ScrollArea className="h-[720px] rounded-lg border">
|
||||
<div className="space-y-3 p-3">
|
||||
{pagedFeedbackCorrections.length > 0 ? pagedFeedbackCorrections.map((item) => {
|
||||
const isSelected = selectedFeedbackCorrection?.task_id === item.task_id
|
||||
return (
|
||||
<button
|
||||
key={item.task_id}
|
||||
type="button"
|
||||
onClick={() => setSelectedFeedbackTaskId(item.task_id)}
|
||||
className={cn(
|
||||
'w-full rounded-xl border p-4 text-left transition-colors',
|
||||
isSelected
|
||||
? 'border-primary bg-primary/5 shadow-sm'
|
||||
: 'bg-muted/20 hover:border-primary/40 hover:bg-muted/40',
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-col gap-3">
|
||||
<div className="flex flex-wrap items-center gap-2">
|
||||
<Badge variant={getFeedbackStatusVariant(item.task_status)}>
|
||||
{formatFeedbackTaskStatus(item.task_status)}
|
||||
</Badge>
|
||||
<Badge variant={getFeedbackStatusVariant(item.rollback_status)}>
|
||||
{formatFeedbackRollbackStatus(item.rollback_status)}
|
||||
</Badge>
|
||||
<Badge variant="outline">
|
||||
{formatFeedbackDecision(item.decision)}
|
||||
</Badge>
|
||||
</div>
|
||||
<div className="text-sm font-medium break-words">
|
||||
{item.query_text || '无查询文本'}
|
||||
</div>
|
||||
<div className="font-mono text-[11px] break-all text-muted-foreground">
|
||||
{item.query_tool_id}
|
||||
</div>
|
||||
<div className="flex flex-wrap gap-3 text-xs text-muted-foreground">
|
||||
<span>关系 {Number(item.affected_counts?.relations ?? 0)}</span>
|
||||
<span>脏段落 {Number(item.affected_counts?.stale_paragraphs ?? 0)}</span>
|
||||
<span>Episode {Number(item.affected_counts?.episode_sources ?? 0)}</span>
|
||||
<span>Profile {Number(item.affected_counts?.profile_person_ids ?? 0)}</span>
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{formatDeleteOperationTime(item.query_timestamp ?? item.created_at)}
|
||||
</div>
|
||||
</div>
|
||||
</button>
|
||||
)
|
||||
}) : (
|
||||
<div className="rounded-lg border border-dashed bg-muted/20 p-6 text-center text-sm text-muted-foreground">
|
||||
当前筛选条件下没有纠错历史
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
|
||||
<div className="rounded-xl border bg-muted/20 p-4">
|
||||
{selectedFeedbackCorrection ? (
|
||||
<div className="space-y-4">
|
||||
<div className="flex flex-col gap-3 lg:flex-row lg:items-start lg:justify-between">
|
||||
<div className="space-y-2">
|
||||
<div className="flex flex-wrap items-center gap-2">
|
||||
<Badge variant={getFeedbackStatusVariant(String(selectedFeedbackResolved?.task_status ?? ''))}>
|
||||
{formatFeedbackTaskStatus(String(selectedFeedbackResolved?.task_status ?? ''))}
|
||||
</Badge>
|
||||
<Badge variant={getFeedbackStatusVariant(String(selectedFeedbackResolved?.rollback_status ?? 'none'))}>
|
||||
{formatFeedbackRollbackStatus(String(selectedFeedbackResolved?.rollback_status ?? 'none'))}
|
||||
</Badge>
|
||||
<Badge variant="outline">
|
||||
{formatFeedbackDecision(String(selectedFeedbackResolved?.decision ?? ''))}
|
||||
</Badge>
|
||||
</div>
|
||||
<div className="text-sm font-medium break-words">
|
||||
{selectedFeedbackResolved?.query_text || '无查询文本'}
|
||||
</div>
|
||||
<div className="font-mono text-xs break-all">
|
||||
{selectedFeedbackResolved?.query_tool_id}
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={openFeedbackRollbackDialog}
|
||||
disabled={
|
||||
String(selectedFeedbackResolved?.task_status ?? '') !== 'applied'
|
||||
|| String(selectedFeedbackResolved?.rollback_status ?? 'none') === 'rolled_back'
|
||||
|| feedbackRollingBack
|
||||
}
|
||||
>
|
||||
<RotateCcw className="mr-2 h-4 w-4" />
|
||||
{String(selectedFeedbackResolved?.rollback_status ?? 'none') === 'rolled_back'
|
||||
? '已回退'
|
||||
: '回退本次纠错'}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-3 lg:grid-cols-4">
|
||||
<div className="rounded-lg border bg-background/60 p-3">
|
||||
<div className="text-xs text-muted-foreground">会话</div>
|
||||
<div className="mt-1 text-sm break-all">{selectedFeedbackResolved?.session_id || '-'}</div>
|
||||
</div>
|
||||
<div className="rounded-lg border bg-background/60 p-3">
|
||||
<div className="text-xs text-muted-foreground">反馈消息数</div>
|
||||
<div className="mt-1 text-sm">{Number(selectedFeedbackResolved?.feedback_message_count ?? 0)}</div>
|
||||
</div>
|
||||
<div className="rounded-lg border bg-background/60 p-3">
|
||||
<div className="text-xs text-muted-foreground">判定置信度</div>
|
||||
<div className="mt-1 text-sm">{Number(selectedFeedbackResolved?.decision_confidence ?? 0).toFixed(2)}</div>
|
||||
</div>
|
||||
<div className="rounded-lg border bg-background/60 p-3">
|
||||
<div className="text-xs text-muted-foreground">回退时间</div>
|
||||
<div className="mt-1 text-sm">{formatDeleteOperationTime(selectedFeedbackResolved?.rolled_back_at)}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{selectedFeedbackTaskLoading ? (
|
||||
<div className="rounded-lg border bg-background/60 p-4 text-sm text-muted-foreground">
|
||||
正在加载纠错详情...
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
{selectedFeedbackTaskError ? (
|
||||
<Alert variant="destructive">
|
||||
<AlertDescription>{selectedFeedbackTaskError}</AlertDescription>
|
||||
</Alert>
|
||||
) : null}
|
||||
|
||||
{selectedFeedbackResolved?.rollback_error ? (
|
||||
<Alert variant="destructive">
|
||||
<AlertDescription>{selectedFeedbackResolved.rollback_error}</AlertDescription>
|
||||
</Alert>
|
||||
) : null}
|
||||
|
||||
<div className="grid gap-4 xl:grid-cols-[minmax(0,0.95fr)_minmax(0,1.05fr)]">
|
||||
<div className="space-y-2">
|
||||
<div className="text-sm font-semibold">查询快照</div>
|
||||
<pre className="max-h-56 overflow-auto rounded-lg border bg-background/70 p-3 text-xs break-words whitespace-pre-wrap">
|
||||
{JSON.stringify(selectedFeedbackResolved?.query_snapshot ?? {}, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<div className="text-sm font-semibold">判定结果</div>
|
||||
<pre className="max-h-56 overflow-auto rounded-lg border bg-background/70 p-3 text-xs break-words whitespace-pre-wrap">
|
||||
{JSON.stringify(selectedFeedbackResolved?.decision_payload ?? {}, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-4 xl:grid-cols-[minmax(0,0.95fr)_minmax(0,1.05fr)]">
|
||||
<div className="space-y-2">
|
||||
<div className="text-sm font-semibold">回退计划摘要</div>
|
||||
<pre className="max-h-64 overflow-auto rounded-lg border bg-background/70 p-3 text-xs break-words whitespace-pre-wrap">
|
||||
{JSON.stringify(selectedFeedbackResolved?.rollback_plan_summary ?? {}, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<div className="text-sm font-semibold">回退结果</div>
|
||||
<pre className="max-h-64 overflow-auto rounded-lg border bg-background/70 p-3 text-xs break-words whitespace-pre-wrap">
|
||||
{JSON.stringify(selectedFeedbackResolved?.rollback_result ?? {}, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<div className="flex flex-col gap-3 lg:flex-row lg:items-center lg:justify-between">
|
||||
<div className="text-sm font-semibold">动作时间线</div>
|
||||
<div className="flex flex-col gap-3 lg:flex-row lg:items-center lg:justify-end">
|
||||
<Input
|
||||
value={feedbackActionLogSearch}
|
||||
onChange={(event) => setFeedbackActionLogSearch(event.target.value)}
|
||||
placeholder="搜索动作 / hash / 预览内容"
|
||||
className="lg:w-80"
|
||||
/>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
第 {feedbackActionLogPage} / {feedbackActionLogPageCount} 页,每页 {FEEDBACK_ACTION_LOG_PAGE_SIZE} 项
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<ScrollArea className="h-[280px] rounded-lg border bg-background/60">
|
||||
<div className="space-y-2 p-3">
|
||||
{pagedFeedbackActionLogs.length > 0 ? pagedFeedbackActionLogs.map((item: MemoryFeedbackActionLogPayload) => (
|
||||
<div key={`${item.id}:${item.action_type}`} className="rounded-lg border bg-muted/20 p-3">
|
||||
<div className="flex flex-wrap items-center gap-2">
|
||||
<Badge variant="outline">{item.action_type}</Badge>
|
||||
{item.target_hash ? (
|
||||
<span className="font-mono text-[11px] break-all text-muted-foreground">{item.target_hash}</span>
|
||||
) : null}
|
||||
</div>
|
||||
{item.reason ? (
|
||||
<div className="mt-2 text-xs text-muted-foreground break-words">
|
||||
{item.reason}
|
||||
</div>
|
||||
) : null}
|
||||
{item.before_payload && Object.keys(item.before_payload).length > 0 ? (
|
||||
<div className="mt-2 text-xs break-words">
|
||||
<span className="font-medium">Before:</span>
|
||||
<span className="text-muted-foreground">{summarizeFeedbackActionPayload(item.before_payload)}</span>
|
||||
</div>
|
||||
) : null}
|
||||
{item.after_payload && Object.keys(item.after_payload).length > 0 ? (
|
||||
<div className="mt-1 text-xs break-words">
|
||||
<span className="font-medium">After:</span>
|
||||
<span className="text-muted-foreground">{summarizeFeedbackActionPayload(item.after_payload)}</span>
|
||||
</div>
|
||||
) : null}
|
||||
<div className="mt-2 text-[11px] text-muted-foreground">
|
||||
{formatDeleteOperationTime(item.created_at)}
|
||||
</div>
|
||||
</div>
|
||||
)) : (
|
||||
<div className="rounded-lg border border-dashed bg-muted/20 p-6 text-center text-sm text-muted-foreground">
|
||||
{selectedFeedbackActionLogs.length > 0 ? '当前筛选条件下没有动作日志' : '当前任务没有动作日志'}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => setFeedbackActionLogPage((current) => Math.max(1, current - 1))}
|
||||
disabled={feedbackActionLogPage <= 1}
|
||||
>
|
||||
上一页
|
||||
</Button>
|
||||
<div className="text-xs text-muted-foreground">支持按动作类型、hash 和摘要检索</div>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => setFeedbackActionLogPage((current) => Math.min(feedbackActionLogPageCount, current + 1))}
|
||||
disabled={feedbackActionLogPage >= feedbackActionLogPageCount}
|
||||
>
|
||||
下一页
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex min-h-[360px] items-center justify-center rounded-lg border border-dashed bg-background/40 p-6 text-center text-sm text-muted-foreground">
|
||||
当前没有可查看的纠错详情
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => setFeedbackPage((current) => Math.max(1, current - 1))}
|
||||
disabled={feedbackPage <= 1}
|
||||
>
|
||||
上一页
|
||||
</Button>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
支持按 query、任务状态和回退状态检索
|
||||
</div>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => setFeedbackPage((current) => Math.min(feedbackPageCount, current + 1))}
|
||||
disabled={feedbackPage >= feedbackPageCount}
|
||||
>
|
||||
下一页
|
||||
</Button>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
</div>
|
||||
</div>
|
||||
@@ -3332,6 +3980,52 @@ export function KnowledgeBasePage() {
|
||||
onExecute={() => void executePendingDelete()}
|
||||
onRestore={() => void (deleteResult?.operation_id ? restoreDeleteOperation(deleteResult.operation_id) : Promise.resolve())}
|
||||
/>
|
||||
|
||||
<Dialog open={feedbackRollbackDialogOpen} onOpenChange={setFeedbackRollbackDialogOpen}>
|
||||
<DialogContent className="max-w-lg" confirmOnEnter>
|
||||
<DialogHeader>
|
||||
<DialogTitle>回退本次纠错</DialogTitle>
|
||||
<DialogDescription>
|
||||
这会恢复旧 relation 状态、隐藏纠错写入段落,并重新触发 episode/profile 的异步修复。
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="space-y-3">
|
||||
<div className="rounded-lg border bg-muted/20 p-3 text-sm">
|
||||
<div className="font-medium break-words">{selectedFeedbackResolved?.query_text || '无查询文本'}</div>
|
||||
<div className="mt-1 font-mono text-[11px] break-all text-muted-foreground">
|
||||
{selectedFeedbackResolved?.query_tool_id}
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="feedback-rollback-reason">回退原因</Label>
|
||||
<Textarea
|
||||
id="feedback-rollback-reason"
|
||||
value={feedbackRollbackReason}
|
||||
onChange={(event) => setFeedbackRollbackReason(event.target.value)}
|
||||
placeholder="可选,建议填写本次人工回退原因"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button variant="outline" onClick={() => setFeedbackRollbackDialogOpen(false)} disabled={feedbackRollingBack}>
|
||||
取消
|
||||
</Button>
|
||||
<Button onClick={() => void executeFeedbackRollback()} disabled={feedbackRollingBack}>
|
||||
{feedbackRollingBack ? (
|
||||
<>
|
||||
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
|
||||
回退中
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<RotateCcw className="mr-2 h-4 w-4" />
|
||||
确认回退
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
740
pytests/A_memorix_test/test_feedback_correction_chat_flow.py
Normal file
740
pytests/A_memorix_test/test_feedback_correction_chat_flow.py
Normal file
@@ -0,0 +1,740 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import Session, create_engine, select
|
||||
|
||||
IMPORT_ERROR: str | None = None
|
||||
|
||||
try:
|
||||
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
||||
from src.A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
from src.chat.heart_flow.heartflow_manager import heartflow_manager
|
||||
from src.chat.message_receive import bot as bot_module
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.common.database import database as database_module
|
||||
from src.common.database.database_model import PersonInfo, ToolRecord
|
||||
from src.common.database.migrations import create_database_migration_bootstrapper
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.maisaka import reasoning_engine as reasoning_engine_module
|
||||
from src.maisaka import runtime as runtime_module
|
||||
from src.maisaka.chat_loop_service import ChatResponse
|
||||
from src.maisaka.context_messages import AssistantMessage
|
||||
from src.plugin_runtime import component_query as component_query_module
|
||||
from src.services import memory_flow_service as memory_flow_service_module
|
||||
from src.services import memory_service as memory_service_module
|
||||
from src.services.memory_service import memory_service
|
||||
except SystemExit as exc:
|
||||
IMPORT_ERROR = f"config initialization exited during import: {exc}"
|
||||
kernel_module = None # type: ignore[assignment]
|
||||
KernelSearchRequest = None # type: ignore[assignment]
|
||||
SDKMemoryKernel = None # type: ignore[assignment]
|
||||
heartflow_manager = None # type: ignore[assignment]
|
||||
bot_module = None # type: ignore[assignment]
|
||||
chat_manager = None # type: ignore[assignment]
|
||||
chat_bot = None # type: ignore[assignment]
|
||||
database_module = None # type: ignore[assignment]
|
||||
ToolRecord = None # type: ignore[assignment]
|
||||
PersonInfo = None # type: ignore[assignment]
|
||||
create_database_migration_bootstrapper = None # type: ignore[assignment]
|
||||
SessionUtils = None # type: ignore[assignment]
|
||||
ToolCall = None # type: ignore[assignment]
|
||||
reasoning_engine_module = None # type: ignore[assignment]
|
||||
runtime_module = None # type: ignore[assignment]
|
||||
ChatResponse = None # type: ignore[assignment]
|
||||
AssistantMessage = None # type: ignore[assignment]
|
||||
component_query_module = None # type: ignore[assignment]
|
||||
memory_flow_service_module = None # type: ignore[assignment]
|
||||
memory_service_module = None # type: ignore[assignment]
|
||||
memory_service = None # type: ignore[assignment]
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
|
||||
|
||||
RELATION_QUERY = "测试用户 和 最喜欢的颜色 有什么关系"
|
||||
|
||||
|
||||
class _FakeEmbeddingManager:
|
||||
def __init__(self, dimension: int = 8) -> None:
|
||||
self.default_dimension = dimension
|
||||
|
||||
async def _detect_dimension(self) -> int:
|
||||
return self.default_dimension
|
||||
|
||||
async def encode(self, text: Any) -> np.ndarray:
|
||||
def _encode_one(raw: Any) -> np.ndarray:
|
||||
content = str(raw or "")
|
||||
vector = np.zeros(self.default_dimension, dtype=np.float32)
|
||||
for index, byte in enumerate(content.encode("utf-8")):
|
||||
vector[index % self.default_dimension] += float((byte % 17) + 1)
|
||||
norm = float(np.linalg.norm(vector))
|
||||
if norm > 0:
|
||||
vector /= norm
|
||||
return vector
|
||||
|
||||
if isinstance(text, (list, tuple)):
|
||||
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
|
||||
return _encode_one(text).astype(np.float32)
|
||||
|
||||
|
||||
class _KernelBackedRuntimeManager:
|
||||
def __init__(self, kernel: SDKMemoryKernel) -> None:
|
||||
self.kernel = kernel
|
||||
|
||||
async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
|
||||
del timeout_ms
|
||||
payload = args or {}
|
||||
if component_name == "search_memory":
|
||||
return await self.kernel.search_memory(
|
||||
KernelSearchRequest(
|
||||
query=str(payload.get("query", "") or ""),
|
||||
limit=int(payload.get("limit", 5) or 5),
|
||||
mode=str(payload.get("mode", "hybrid") or "hybrid"),
|
||||
chat_id=str(payload.get("chat_id", "") or ""),
|
||||
person_id=str(payload.get("person_id", "") or ""),
|
||||
time_start=payload.get("time_start"),
|
||||
time_end=payload.get("time_end"),
|
||||
respect_filter=bool(payload.get("respect_filter", True)),
|
||||
user_id=str(payload.get("user_id", "") or ""),
|
||||
group_id=str(payload.get("group_id", "") or ""),
|
||||
)
|
||||
)
|
||||
|
||||
handler = getattr(self.kernel, component_name)
|
||||
result = handler(**payload)
|
||||
return await result if inspect.isawaitable(result) else result
|
||||
|
||||
|
||||
class _NoopRuntimeManager:
|
||||
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
|
||||
del hook_name
|
||||
return SimpleNamespace(aborted=False, kwargs=kwargs)
|
||||
|
||||
|
||||
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
db_dir = (tmp_path / "main_db").resolve()
|
||||
db_dir.mkdir(parents=True, exist_ok=True)
|
||||
db_file = db_dir / "MaiBot.db"
|
||||
database_url = f"sqlite:///{db_file}"
|
||||
|
||||
try:
|
||||
database_module.engine.dispose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
engine = create_engine(
|
||||
database_url,
|
||||
echo=False,
|
||||
connect_args={"check_same_thread": False},
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
session_local = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=engine,
|
||||
class_=Session,
|
||||
)
|
||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
monkeypatch.setattr(database_module, "_DB_DIR", db_dir, raising=False)
|
||||
monkeypatch.setattr(database_module, "_DB_FILE", db_file, raising=False)
|
||||
monkeypatch.setattr(database_module, "DATABASE_URL", database_url, raising=False)
|
||||
monkeypatch.setattr(database_module, "engine", engine, raising=False)
|
||||
monkeypatch.setattr(database_module, "SessionLocal", session_local, raising=False)
|
||||
monkeypatch.setattr(database_module, "_migration_bootstrapper", bootstrapper, raising=False)
|
||||
monkeypatch.setattr(database_module, "_db_initialized", False, raising=False)
|
||||
|
||||
|
||||
def _build_chat_response(content: str, tool_calls: list[ToolCall]) -> ChatResponse:
|
||||
return ChatResponse(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
request_messages=[],
|
||||
raw_message=AssistantMessage(
|
||||
content=content,
|
||||
timestamp=datetime.now(),
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
selected_history_count=0,
|
||||
tool_count=len(tool_calls),
|
||||
prompt_tokens=0,
|
||||
built_message_count=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
prompt_section=None,
|
||||
)
|
||||
|
||||
|
||||
def _build_message_data(
|
||||
*,
|
||||
content: str,
|
||||
platform: str,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
group_id: str,
|
||||
group_name: str,
|
||||
) -> Dict[str, Any]:
|
||||
message_id = str(uuid.uuid4())
|
||||
return {
|
||||
"message_info": {
|
||||
"platform": platform,
|
||||
"message_id": message_id,
|
||||
"time": time.time(),
|
||||
"group_info": {
|
||||
"group_id": group_id,
|
||||
"group_name": group_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"user_info": {
|
||||
"user_id": user_id,
|
||||
"user_nickname": user_name,
|
||||
"user_cardname": user_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"additional_config": {
|
||||
"at_bot": True,
|
||||
},
|
||||
},
|
||||
"message_segment": {
|
||||
"type": "seglist",
|
||||
"data": [
|
||||
{
|
||||
"type": "text",
|
||||
"data": content,
|
||||
},
|
||||
],
|
||||
},
|
||||
"raw_message": content,
|
||||
"processed_plain_text": content,
|
||||
}
|
||||
|
||||
|
||||
async def _wait_until(
|
||||
predicate: Callable[[], Any],
|
||||
*,
|
||||
timeout_seconds: float = 10.0,
|
||||
interval_seconds: float = 0.05,
|
||||
description: str,
|
||||
) -> Any:
|
||||
deadline = asyncio.get_running_loop().time() + max(0.5, float(timeout_seconds))
|
||||
while asyncio.get_running_loop().time() < deadline:
|
||||
value = predicate()
|
||||
if inspect.isawaitable(value):
|
||||
value = await value
|
||||
if value:
|
||||
return value
|
||||
await asyncio.sleep(interval_seconds)
|
||||
raise AssertionError(f"等待超时: {description}")
|
||||
|
||||
|
||||
def _load_feedback_tasks(kernel: SDKMemoryKernel) -> list[Dict[str, Any]]:
|
||||
assert kernel.metadata_store is not None
|
||||
cursor = kernel.metadata_store.get_connection().cursor()
|
||||
rows = cursor.execute(
|
||||
"SELECT query_tool_id FROM memory_feedback_tasks ORDER BY id"
|
||||
).fetchall()
|
||||
tasks: list[Dict[str, Any]] = []
|
||||
for row in rows:
|
||||
task = kernel.metadata_store.get_feedback_task(str(row["query_tool_id"] or ""))
|
||||
if task is not None:
|
||||
tasks.append(task)
|
||||
return tasks
|
||||
|
||||
|
||||
def _load_feedback_action_types(kernel: SDKMemoryKernel) -> list[str]:
|
||||
assert kernel.metadata_store is not None
|
||||
cursor = kernel.metadata_store.get_connection().cursor()
|
||||
rows = cursor.execute(
|
||||
"SELECT action_type FROM memory_feedback_action_logs ORDER BY id"
|
||||
).fetchall()
|
||||
return [str(row["action_type"] or "") for row in rows]
|
||||
|
||||
|
||||
def _load_query_memory_tool_records(session_id: str) -> list[Dict[str, Any]]:
|
||||
with database_module.get_db_session() as session:
|
||||
statement = (
|
||||
select(ToolRecord)
|
||||
.where(ToolRecord.session_id == session_id)
|
||||
.where(ToolRecord.tool_name == "query_memory")
|
||||
.order_by(ToolRecord.timestamp)
|
||||
)
|
||||
rows = list(session.exec(statement).all())
|
||||
return [
|
||||
{
|
||||
"tool_id": str(row.tool_id or ""),
|
||||
"session_id": str(row.session_id or ""),
|
||||
"tool_name": str(row.tool_name or ""),
|
||||
"tool_data": str(row.tool_data or ""),
|
||||
"timestamp": row.timestamp,
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
|
||||
def _seed_person_info(*, person_id: str, person_name: str, session_info: Dict[str, Any]) -> None:
|
||||
with database_module.get_db_session() as session:
|
||||
session.add(
|
||||
PersonInfo(
|
||||
is_known=True,
|
||||
person_id=person_id,
|
||||
person_name=person_name,
|
||||
platform=str(session_info["platform"]),
|
||||
user_id=str(session_info["user_id"]),
|
||||
user_nickname=str(session_info["user_name"]),
|
||||
group_cardname=json.dumps(
|
||||
[{"group_id": str(session_info["group_id"]), "group_cardname": person_name}],
|
||||
ensure_ascii=False,
|
||||
),
|
||||
know_counts=1,
|
||||
first_known_time=datetime.now(),
|
||||
last_known_time=datetime.now(),
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
|
||||
_install_temp_main_database(monkeypatch, tmp_path)
|
||||
|
||||
chat_manager.sessions.clear()
|
||||
chat_manager.last_messages.clear()
|
||||
heartflow_manager.heartflow_chat_list.clear()
|
||||
|
||||
noop_runtime_manager = _NoopRuntimeManager()
|
||||
monkeypatch.setattr(bot_module.ChatBot, "_get_runtime_manager", staticmethod(lambda: noop_runtime_manager))
|
||||
monkeypatch.setattr(
|
||||
component_query_module.component_query_service,
|
||||
"find_command_by_text",
|
||||
lambda text: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
component_query_module.component_query_service,
|
||||
"get_llm_available_tool_specs",
|
||||
lambda: {},
|
||||
)
|
||||
monkeypatch.setattr(runtime_module.global_config.mcp, "enable", False, raising=False)
|
||||
monkeypatch.setattr(
|
||||
runtime_module.MaisakaHeartFlowChatting,
|
||||
"_get_message_trigger_threshold",
|
||||
lambda self: 1,
|
||||
)
|
||||
|
||||
async def _noop_on_incoming_message(message: Any) -> None:
|
||||
del message
|
||||
|
||||
monkeypatch.setattr(
|
||||
memory_flow_service_module.memory_automation_service,
|
||||
"on_incoming_message",
|
||||
_noop_on_incoming_message,
|
||||
)
|
||||
|
||||
fake_embedding_manager = _FakeEmbeddingManager(dimension=8)
|
||||
|
||||
async def _fake_runtime_self_check(
|
||||
*,
|
||||
config: Any,
|
||||
sample_text: str,
|
||||
vector_store: Any,
|
||||
embedding_manager: Any,
|
||||
) -> Dict[str, Any]:
|
||||
del config, sample_text, vector_store, embedding_manager
|
||||
return {
|
||||
"ok": True,
|
||||
"message": "ok",
|
||||
"checked_at": time.time(),
|
||||
"encoded_dimension": fake_embedding_manager.default_dimension,
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
kernel_module,
|
||||
"create_embedding_api_adapter",
|
||||
lambda **kwargs: fake_embedding_manager,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
kernel_module,
|
||||
"run_embedding_runtime_self_check",
|
||||
_fake_runtime_self_check,
|
||||
)
|
||||
|
||||
kernel = SDKMemoryKernel(
|
||||
plugin_root=tmp_path / "plugin_root",
|
||||
config={
|
||||
"storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
|
||||
"advanced": {"enable_auto_save": False},
|
||||
"embedding": {"dimension": fake_embedding_manager.default_dimension},
|
||||
"memory": {"base_decay_interval_hours": 24},
|
||||
"person_profile": {"refresh_interval_minutes": 5},
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_enabled", lambda: True)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_window_hours", lambda: 0.0004)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_check_interval_seconds", lambda: 0.2)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_batch_size", lambda: 10)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_max_messages", lambda: 10)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_auto_apply_threshold", lambda: 0.85)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_prefilter_enabled", lambda: True)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_mark_enabled", lambda: True)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_hard_filter_enabled", lambda: True)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_profile_refresh_enabled", lambda: True)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_profile_force_refresh_on_read", lambda: True)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_episode_rebuild_enabled", lambda: True)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_episode_query_block_enabled", lambda: True)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_interval_seconds", lambda: 0.2)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_batch_size", lambda: 10)
|
||||
|
||||
monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_paragraph_hard_filter_enabled", True, raising=False)
|
||||
monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_episode_query_block_enabled", True, raising=False)
|
||||
|
||||
async def _fake_classify_feedback(
|
||||
*,
|
||||
query_tool_id: str,
|
||||
query_text: str,
|
||||
hit_briefs: list[Dict[str, Any]],
|
||||
feedback_messages: list[str],
|
||||
) -> Dict[str, Any]:
|
||||
del query_tool_id, query_text, feedback_messages
|
||||
target_hash = ""
|
||||
for item in hit_briefs:
|
||||
if str(item.get("type", "") or "").strip() == "relation":
|
||||
target_hash = str(item.get("hash", "") or "").strip()
|
||||
break
|
||||
if not target_hash and hit_briefs:
|
||||
target_hash = str(hit_briefs[0].get("hash", "") or "").strip()
|
||||
return {
|
||||
"decision": "correct",
|
||||
"confidence": 0.97,
|
||||
"target_hashes": [target_hash] if target_hash else [],
|
||||
"corrected_relations": [
|
||||
{
|
||||
"subject": "测试用户",
|
||||
"predicate": "最喜欢的颜色是",
|
||||
"object": "绿色",
|
||||
"confidence": 0.99,
|
||||
}
|
||||
],
|
||||
"reason": "用户明确纠正为绿色",
|
||||
}
|
||||
|
||||
monkeypatch.setattr(kernel, "_classify_feedback", _fake_classify_feedback)
|
||||
|
||||
await kernel.initialize()
|
||||
async def _force_episode_fallback(**kwargs: Any) -> Dict[str, Any]:
|
||||
raise RuntimeError("force_fallback_for_test")
|
||||
|
||||
monkeypatch.setattr(
|
||||
kernel.episode_service.segmentation_service,
|
||||
"segment",
|
||||
_force_episode_fallback,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
kernel,
|
||||
"process_episode_pending_batch",
|
||||
lambda *, limit=20, max_retry=3: asyncio.sleep(0, result={"processed": 0, "episode_count": 0, "fallback_count": 0, "failed": 0}),
|
||||
)
|
||||
|
||||
host_manager = _KernelBackedRuntimeManager(kernel)
|
||||
monkeypatch.setattr(memory_service_module, "a_memorix_host_service", host_manager)
|
||||
|
||||
planner_calls: list[str] = []
|
||||
|
||||
async def _fake_timing_gate(self, anchor_message: Any):
|
||||
del self, anchor_message
|
||||
return "continue", _build_chat_response("直接进入 planner。", []), []
|
||||
|
||||
async def _fake_planner(self, *, tool_definitions: list[dict[str, Any]] | None = None) -> ChatResponse:
|
||||
del tool_definitions
|
||||
latest_message = self._runtime.message_cache[-1]
|
||||
latest_text = str(latest_message.processed_plain_text or "")
|
||||
planner_calls.append(latest_text)
|
||||
handled_message_ids = getattr(self._runtime, "_test_query_message_ids", None)
|
||||
if handled_message_ids is None:
|
||||
handled_message_ids = set()
|
||||
setattr(self._runtime, "_test_query_message_ids", handled_message_ids)
|
||||
|
||||
if latest_message.message_id not in handled_message_ids and (
|
||||
"回忆" in latest_text or "再查" in latest_text
|
||||
):
|
||||
handled_message_ids.add(latest_message.message_id)
|
||||
tool_call = ToolCall(
|
||||
call_id=f"query-{uuid.uuid4().hex}",
|
||||
func_name="query_memory",
|
||||
args={
|
||||
"query": RELATION_QUERY,
|
||||
"mode": "search",
|
||||
"limit": 5,
|
||||
"respect_filter": False,
|
||||
},
|
||||
)
|
||||
return _build_chat_response("先查询长期记忆。", [tool_call])
|
||||
|
||||
stop_call = ToolCall(
|
||||
call_id=f"stop-{uuid.uuid4().hex}",
|
||||
func_name="no_reply",
|
||||
args={},
|
||||
)
|
||||
return _build_chat_response("当前轮次结束。", [stop_call])
|
||||
|
||||
monkeypatch.setattr(
|
||||
reasoning_engine_module.MaisakaReasoningEngine,
|
||||
"_run_timing_gate",
|
||||
_fake_timing_gate,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
reasoning_engine_module.MaisakaReasoningEngine,
|
||||
"_run_interruptible_planner",
|
||||
_fake_planner,
|
||||
)
|
||||
|
||||
session_info = {
|
||||
"platform": "unit_test_chat",
|
||||
"user_id": "user_feedback_flow",
|
||||
"user_name": "反馈测试用户",
|
||||
"group_id": "group_feedback_flow",
|
||||
"group_name": "反馈纠错测试群",
|
||||
}
|
||||
person_id = "person_feedback_flow"
|
||||
session_id = SessionUtils.calculate_session_id(
|
||||
session_info["platform"],
|
||||
user_id=session_info["user_id"],
|
||||
group_id=session_info["group_id"],
|
||||
)
|
||||
_seed_person_info(person_id=person_id, person_name="测试用户", session_info=session_info)
|
||||
|
||||
try:
|
||||
yield {
|
||||
"kernel": kernel,
|
||||
"session_id": session_id,
|
||||
"session_info": session_info,
|
||||
"person_id": person_id,
|
||||
"planner_calls": planner_calls,
|
||||
}
|
||||
finally:
|
||||
for key, chat in list(heartflow_manager.heartflow_chat_list.items()):
|
||||
try:
|
||||
await chat.stop()
|
||||
except Exception:
|
||||
pass
|
||||
heartflow_manager.heartflow_chat_list.pop(key, None)
|
||||
chat_manager.sessions.clear()
|
||||
chat_manager.last_messages.clear()
|
||||
await kernel.shutdown()
|
||||
try:
|
||||
database_module.engine.dispose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_feedback_correction_real_chat_flow(chat_feedback_env) -> None:
|
||||
kernel = chat_feedback_env["kernel"]
|
||||
session_id = chat_feedback_env["session_id"]
|
||||
session_info = chat_feedback_env["session_info"]
|
||||
person_id = chat_feedback_env["person_id"]
|
||||
|
||||
write_result = await memory_service.ingest_text(
|
||||
external_id=f"test:feedback-seed:{uuid.uuid4().hex}",
|
||||
source_type="chat_summary",
|
||||
text="测试用户 最喜欢的颜色是 蓝色",
|
||||
chat_id=session_id,
|
||||
relations=[
|
||||
{
|
||||
"subject": "测试用户",
|
||||
"predicate": "最喜欢的颜色是",
|
||||
"object": "蓝色",
|
||||
"confidence": 1.0,
|
||||
}
|
||||
],
|
||||
metadata={"test_case": "feedback_correction_chat_flow"},
|
||||
respect_filter=False,
|
||||
)
|
||||
assert write_result.success is True
|
||||
|
||||
pre_search = await memory_service.search(
|
||||
RELATION_QUERY,
|
||||
mode="search",
|
||||
chat_id=session_id,
|
||||
respect_filter=False,
|
||||
)
|
||||
assert pre_search.hits
|
||||
assert any("蓝色" in hit.content for hit in pre_search.hits)
|
||||
|
||||
pre_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
|
||||
pre_profile_text = pre_profile.summary + "\n" + json.dumps(pre_profile.evidence, ensure_ascii=False)
|
||||
assert "蓝色" in pre_profile_text
|
||||
|
||||
seed_source = f"chat_summary:{session_id}"
|
||||
rebuild_result = await kernel.rebuild_episodes_for_sources([seed_source])
|
||||
assert rebuild_result["rebuilt"] >= 1
|
||||
|
||||
pre_episode = await memory_service.search(
|
||||
"蓝色",
|
||||
mode="episode",
|
||||
chat_id=session_id,
|
||||
respect_filter=False,
|
||||
)
|
||||
assert pre_episode.hits
|
||||
assert any("蓝色" in hit.content for hit in pre_episode.hits)
|
||||
|
||||
await chat_bot.message_process(
|
||||
_build_message_data(
|
||||
content="请帮我回忆一下,测试用户最喜欢的颜色是什么?",
|
||||
**session_info,
|
||||
)
|
||||
)
|
||||
|
||||
await _wait_until(
|
||||
lambda: chat_feedback_env["planner_calls"][0] if chat_feedback_env["planner_calls"] else None,
|
||||
description="planner 收到首条聊天消息",
|
||||
)
|
||||
first_query_records = await _wait_until(
|
||||
lambda: _load_query_memory_tool_records(session_id) if _load_query_memory_tool_records(session_id) else None,
|
||||
description="首条 query_memory 工具记录生成",
|
||||
)
|
||||
assert first_query_records
|
||||
|
||||
first_task = await _wait_until(
|
||||
lambda: _load_feedback_tasks(kernel)[0] if _load_feedback_tasks(kernel) else None,
|
||||
description="首个反馈任务入队",
|
||||
)
|
||||
assert first_task["status"] == "pending"
|
||||
first_hits = list((first_task.get("query_snapshot") or {}).get("hits") or [])
|
||||
assert first_hits
|
||||
assert any("蓝色" in str(item.get("content", "") or "") for item in first_hits)
|
||||
|
||||
await chat_bot.message_process(
|
||||
_build_message_data(
|
||||
content="不对,测试用户最喜欢的颜色不是蓝色,是绿色。",
|
||||
**session_info,
|
||||
)
|
||||
)
|
||||
|
||||
finalized_task = await _wait_until(
|
||||
lambda: (
|
||||
kernel.metadata_store.get_feedback_task(first_task["query_tool_id"])
|
||||
if kernel.metadata_store.get_feedback_task(first_task["query_tool_id"])
|
||||
and kernel.metadata_store.get_feedback_task(first_task["query_tool_id"]).get("status")
|
||||
in {"applied", "skipped", "error"}
|
||||
else None
|
||||
),
|
||||
timeout_seconds=12.0,
|
||||
interval_seconds=0.1,
|
||||
description="反馈任务进入终态",
|
||||
)
|
||||
assert finalized_task["status"] == "applied", finalized_task
|
||||
assert finalized_task["decision_payload"]["decision"] == "correct"
|
||||
assert finalized_task["decision_payload"]["apply_result"]["applied"] is True
|
||||
|
||||
corrected_hashes = list(
|
||||
(finalized_task["decision_payload"].get("apply_result") or {}).get("relation_hashes") or []
|
||||
)
|
||||
assert corrected_hashes
|
||||
corrected_hash = str(corrected_hashes[0] or "")
|
||||
relation_status = kernel.metadata_store.get_relation_status_batch([corrected_hash]).get(corrected_hash, {})
|
||||
assert bool(relation_status.get("is_inactive")) is True
|
||||
|
||||
action_types = _load_feedback_action_types(kernel)
|
||||
assert "classification" in action_types
|
||||
assert "forget_relation" in action_types
|
||||
assert "ingest_correction" in action_types
|
||||
assert "mark_stale_paragraph" in action_types
|
||||
assert "enqueue_episode_rebuild" in action_types
|
||||
assert "enqueue_profile_refresh" in action_types
|
||||
|
||||
direct_post_search = await memory_service.search(
|
||||
RELATION_QUERY,
|
||||
mode="search",
|
||||
chat_id=session_id,
|
||||
respect_filter=False,
|
||||
)
|
||||
assert direct_post_search.hits
|
||||
post_contents = "\n".join(hit.content for hit in direct_post_search.hits)
|
||||
assert "绿色" in post_contents
|
||||
assert "蓝色" not in post_contents
|
||||
|
||||
profile_refresh_request = await _wait_until(
|
||||
lambda: (
|
||||
kernel.metadata_store.get_person_profile_refresh_request(person_id)
|
||||
if kernel.metadata_store.get_person_profile_refresh_request(person_id)
|
||||
and kernel.metadata_store.get_person_profile_refresh_request(person_id).get("status") == "done"
|
||||
else None
|
||||
),
|
||||
timeout_seconds=12.0,
|
||||
interval_seconds=0.1,
|
||||
description="人物画像刷新完成",
|
||||
)
|
||||
assert profile_refresh_request["status"] == "done"
|
||||
|
||||
post_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
|
||||
post_profile_text = post_profile.summary + "\n" + json.dumps(post_profile.evidence, ensure_ascii=False)
|
||||
assert "绿色" in post_profile_text
|
||||
assert "蓝色" not in post_profile_text
|
||||
|
||||
async def _latest_episode_result():
|
||||
result = await memory_service.search(
|
||||
"绿色",
|
||||
mode="episode",
|
||||
chat_id=session_id,
|
||||
respect_filter=False,
|
||||
)
|
||||
if not result.hits:
|
||||
return None
|
||||
contents = "\n".join(hit.content for hit in result.hits)
|
||||
if "绿色" in contents and "蓝色" not in contents:
|
||||
return result
|
||||
return None
|
||||
|
||||
post_episode = await _wait_until(
|
||||
_latest_episode_result,
|
||||
timeout_seconds=12.0,
|
||||
interval_seconds=0.2,
|
||||
description="episode 重建后返回修正结果",
|
||||
)
|
||||
assert post_episode is not None
|
||||
|
||||
stale_episode = await memory_service.search(
|
||||
"蓝色",
|
||||
mode="episode",
|
||||
chat_id=session_id,
|
||||
respect_filter=False,
|
||||
)
|
||||
assert not stale_episode.hits
|
||||
|
||||
await chat_bot.message_process(
|
||||
_build_message_data(
|
||||
content="再查一次,测试用户最喜欢的颜色是什么?",
|
||||
**session_info,
|
||||
)
|
||||
)
|
||||
|
||||
tool_records = await _wait_until(
|
||||
lambda: (
|
||||
_load_query_memory_tool_records(session_id)
|
||||
if len(_load_query_memory_tool_records(session_id)) >= 2
|
||||
else None
|
||||
),
|
||||
timeout_seconds=10.0,
|
||||
interval_seconds=0.1,
|
||||
description="第二次 query_memory 工具记录生成",
|
||||
)
|
||||
latest_tool_data = json.loads(str(tool_records[-1].get("tool_data") or "{}"))
|
||||
latest_hits = list((latest_tool_data.get("structured_content") or {}).get("hits") or [])
|
||||
assert latest_hits
|
||||
latest_contents = "\n".join(str(item.get("content", "") or "") for item in latest_hits)
|
||||
assert "绿色" in latest_contents
|
||||
assert "蓝色" not in latest_contents
|
||||
396
pytests/A_memorix_test/test_feedback_correction_core.py
Normal file
396
pytests/A_memorix_test/test_feedback_correction_core.py
Normal file
@@ -0,0 +1,396 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
IMPORT_ERROR: str | None = None
|
||||
|
||||
try:
|
||||
from src.A_memorix.core.retrieval.sparse_bm25 import SparseBM25Config, SparseBM25Index
|
||||
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
||||
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
|
||||
except SystemExit as exc:
|
||||
IMPORT_ERROR = f"config initialization exited during import: {exc}"
|
||||
SparseBM25Config = None # type: ignore[assignment]
|
||||
SparseBM25Index = None # type: ignore[assignment]
|
||||
kernel_module = None # type: ignore[assignment]
|
||||
SDKMemoryKernel = None # type: ignore[assignment]
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kernel_enqueue_feedback_task_delegates_to_metadata_store(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
def fake_enqueue_feedback_task(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return {
|
||||
"id": 1,
|
||||
"query_tool_id": kwargs["query_tool_id"],
|
||||
"session_id": kwargs["session_id"],
|
||||
"query_timestamp": kwargs["query_timestamp"],
|
||||
"due_at": kwargs["due_at"],
|
||||
"query_snapshot": kwargs["query_snapshot"],
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
kernel_module,
|
||||
"global_config",
|
||||
SimpleNamespace(
|
||||
memory=SimpleNamespace(
|
||||
feedback_correction_enabled=True,
|
||||
feedback_correction_window_hours=12.0,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
query_time = datetime(2026, 4, 9, 10, 30, 0)
|
||||
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
|
||||
kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=fake_enqueue_feedback_task)
|
||||
|
||||
payload = await kernel.enqueue_feedback_task(
|
||||
query_tool_id="tool-query-1",
|
||||
session_id="session-1",
|
||||
query_timestamp=query_time,
|
||||
structured_content={"query": "Alice 喜欢什么", "hits": [{"hash": "relation-1"}]},
|
||||
)
|
||||
|
||||
assert payload["success"] is True
|
||||
assert payload["queued"] is True
|
||||
assert captured["query_tool_id"] == "tool-query-1"
|
||||
assert captured["session_id"] == "session-1"
|
||||
assert captured["query_snapshot"]["query"] == "Alice 喜欢什么"
|
||||
assert captured["query_snapshot"]["hits"] == [{"hash": "relation-1"}]
|
||||
assert captured["due_at"] == pytest.approx(query_time.timestamp() + 12 * 3600, rel=0, abs=1e-6)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kernel_enqueue_feedback_task_skipped_when_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
kernel_module,
|
||||
"global_config",
|
||||
SimpleNamespace(memory=SimpleNamespace(feedback_correction_enabled=False)),
|
||||
)
|
||||
|
||||
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
|
||||
kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=lambda **kwargs: kwargs)
|
||||
|
||||
payload = await kernel.enqueue_feedback_task(
|
||||
query_tool_id="tool-query-2",
|
||||
session_id="session-1",
|
||||
query_timestamp=datetime.now(),
|
||||
structured_content={"hits": [{"hash": "relation-1"}]},
|
||||
)
|
||||
|
||||
assert payload["success"] is False
|
||||
assert payload["reason"] == "feedback_correction_disabled"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_feedback_decision_resolves_paragraph_targets() -> None:
|
||||
action_logs: list[Dict[str, Any]] = []
|
||||
forgotten_hashes: list[str] = []
|
||||
ingested_payloads: list[Dict[str, Any]] = []
|
||||
stale_marks: list[Dict[str, Any]] = []
|
||||
episode_sources: list[str] = []
|
||||
profile_refresh_ids: list[str] = []
|
||||
|
||||
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
|
||||
kernel.metadata_store = SimpleNamespace(
|
||||
get_paragraph_relations=lambda paragraph_hash: [
|
||||
{
|
||||
"hash": "relation-1",
|
||||
"subject": "测试用户",
|
||||
"predicate": "最喜欢的颜色是",
|
||||
"object": "蓝色",
|
||||
}
|
||||
]
|
||||
if paragraph_hash == "paragraph-1"
|
||||
else [],
|
||||
get_relation_status_batch=lambda hashes: {
|
||||
str(hash_value): {"is_inactive": str(hash_value) in forgotten_hashes}
|
||||
for hash_value in hashes
|
||||
},
|
||||
get_paragraph_hashes_by_relation_hashes=lambda hashes: {
|
||||
"relation-1": ["paragraph-1"]
|
||||
}
|
||||
if "relation-1" in hashes
|
||||
else {},
|
||||
upsert_paragraph_stale_relation_mark=lambda **kwargs: stale_marks.append(kwargs) or kwargs,
|
||||
enqueue_episode_source_rebuild=lambda source, reason="": episode_sources.append(source) or True,
|
||||
enqueue_person_profile_refresh=lambda **kwargs: profile_refresh_ids.append(kwargs["person_id"]) or kwargs,
|
||||
get_paragraph=lambda paragraph_hash: {"hash": "paragraph-1", "source": "chat_feedback_test_seed:session-1"}
|
||||
if paragraph_hash == "paragraph-1"
|
||||
else None,
|
||||
append_feedback_action_log=lambda **kwargs: action_logs.append(kwargs),
|
||||
)
|
||||
kernel._feedback_cfg_auto_apply_threshold = lambda: 0.85 # type: ignore[method-assign]
|
||||
kernel._apply_v5_relation_action = lambda *, action, hashes, strength=1.0: ( # type: ignore[method-assign]
|
||||
forgotten_hashes.extend([str(item) for item in hashes]),
|
||||
{"success": True, "action": action, "hashes": list(hashes), "strength": strength},
|
||||
)[1]
|
||||
kernel._feedback_cfg_paragraph_mark_enabled = lambda: True # type: ignore[method-assign]
|
||||
kernel._feedback_cfg_episode_rebuild_enabled = lambda: True # type: ignore[method-assign]
|
||||
kernel._feedback_cfg_profile_refresh_enabled = lambda: True # type: ignore[method-assign]
|
||||
kernel._resolve_feedback_related_person_ids = lambda **kwargs: ["person-1"] # type: ignore[method-assign]
|
||||
kernel._query_relation_rows_by_hashes = lambda relation_hashes, include_inactive=False: [ # type: ignore[method-assign]
|
||||
{
|
||||
"hash": "relation-1",
|
||||
"subject": "测试用户",
|
||||
"predicate": "最喜欢的颜色是",
|
||||
"object": "蓝色",
|
||||
}
|
||||
]
|
||||
|
||||
async def _fake_ingest_feedback_relations(**kwargs):
|
||||
ingested_payloads.append(kwargs)
|
||||
return {"success": True, "stored_ids": ["relation-2"]}
|
||||
|
||||
kernel._ingest_feedback_relations = _fake_ingest_feedback_relations # type: ignore[method-assign]
|
||||
|
||||
payload = await kernel._apply_feedback_decision(
|
||||
task_id=1,
|
||||
query_tool_id="tool-query-1",
|
||||
session_id="session-1",
|
||||
decision={
|
||||
"decision": "correct",
|
||||
"confidence": 0.97,
|
||||
"target_hashes": ["paragraph-1"],
|
||||
"corrected_relations": [
|
||||
{
|
||||
"subject": "测试用户",
|
||||
"predicate": "最喜欢的颜色是",
|
||||
"object": "绿色",
|
||||
"confidence": 0.99,
|
||||
}
|
||||
],
|
||||
"reason": "用户明确纠正为绿色",
|
||||
},
|
||||
hit_map={
|
||||
"paragraph-1": {
|
||||
"hash": "paragraph-1",
|
||||
"type": "paragraph",
|
||||
"content": "测试用户 最喜欢的颜色是 蓝色",
|
||||
"linked_relation_hashes": ["relation-1"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert payload["applied"] is True
|
||||
assert payload["relation_hashes"] == ["relation-1"]
|
||||
assert forgotten_hashes == ["relation-1"]
|
||||
assert ingested_payloads[0]["relation_hashes"] == ["relation-1"]
|
||||
assert payload["stale_paragraph_hashes"] == ["paragraph-1"]
|
||||
assert "chat_feedback_test_seed:session-1" in payload["episode_rebuild_sources"]
|
||||
assert "chat_summary:session-1" in payload["episode_rebuild_sources"]
|
||||
assert payload["profile_refresh_person_ids"] == ["person-1"]
|
||||
assert stale_marks[0]["paragraph_hash"] == "paragraph-1"
|
||||
assert {item["action_type"] for item in action_logs} == {
|
||||
"forget_relation",
|
||||
"ingest_correction",
|
||||
"mark_stale_paragraph",
|
||||
"enqueue_episode_rebuild",
|
||||
"enqueue_profile_refresh",
|
||||
}
|
||||
|
||||
|
||||
def test_filter_active_relation_hits_removes_inactive_relations() -> None:
|
||||
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
|
||||
kernel._feedback_cfg_paragraph_hard_filter_enabled = lambda: True # type: ignore[method-assign]
|
||||
kernel.metadata_store = SimpleNamespace(
|
||||
get_relation_status_batch=lambda hashes: {
|
||||
"r-active": {"is_inactive": False},
|
||||
"r-inactive": {"is_inactive": True},
|
||||
"r-para-inactive": {"is_inactive": True},
|
||||
"r-stale-active": {"is_inactive": False},
|
||||
"r-stale-inactive": {"is_inactive": True},
|
||||
},
|
||||
get_paragraph_relations=lambda paragraph_hash: (
|
||||
[{"hash": "r-para-inactive"}] if paragraph_hash == "p-inactive" else []
|
||||
),
|
||||
get_paragraph_stale_relation_marks_batch=lambda paragraph_hashes: {
|
||||
"p-stale": [{"relation_hash": "r-stale-inactive"}],
|
||||
"p-restored": [{"relation_hash": "r-stale-active"}],
|
||||
},
|
||||
)
|
||||
|
||||
hits = [
|
||||
{"hash": "r-active", "type": "relation", "content": "A 喜欢 B"},
|
||||
{"hash": "r-inactive", "type": "relation", "content": "A 讨厌 B"},
|
||||
{"hash": "p-1", "type": "paragraph", "content": "段落证据"},
|
||||
{"hash": "p-inactive", "type": "paragraph", "content": "失活段落证据"},
|
||||
{"hash": "p-stale", "type": "paragraph", "content": "被标脏段落"},
|
||||
{"hash": "p-restored", "type": "paragraph", "content": "恢复可见段落"},
|
||||
]
|
||||
|
||||
filtered = kernel._filter_active_relation_hits(hits)
|
||||
|
||||
assert [item["hash"] for item in filtered] == ["r-active", "p-1", "p-restored"]
|
||||
|
||||
|
||||
def test_sparse_relation_search_requests_active_only() -> None:
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
class FakeMetadataStore:
|
||||
def fts_search_relations_bm25(self, **kwargs):
|
||||
captured.update(kwargs)
|
||||
return []
|
||||
|
||||
index = SparseBM25Index(
|
||||
metadata_store=FakeMetadataStore(), # type: ignore[arg-type]
|
||||
config=SparseBM25Config(enabled=True, lazy_load=False),
|
||||
)
|
||||
index._loaded = True
|
||||
index._conn = object() # type: ignore[assignment]
|
||||
|
||||
result = index.search_relations("测试纠错", k=5)
|
||||
|
||||
assert result == []
|
||||
assert captured["include_inactive"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_feedback_task_rollback_restores_snapshots_and_requeues_followups() -> None:
|
||||
action_logs: list[Dict[str, Any]] = []
|
||||
queued_sources: list[str] = []
|
||||
queued_profiles: list[str] = []
|
||||
deleted_marks: list[tuple[str, str]] = []
|
||||
deleted_paragraphs: list[str] = []
|
||||
relation_statuses: Dict[str, Dict[str, Any]] = {
|
||||
"rel-old": {"is_inactive": True, "weight": 0.0, "is_pinned": False, "protected_until": 0.0, "last_reinforced": None, "inactive_since": 1.0},
|
||||
"rel-new": {"is_inactive": False, "weight": 1.0, "is_pinned": False, "protected_until": 0.0, "last_reinforced": None, "inactive_since": None},
|
||||
}
|
||||
current_task: Dict[str, Any] = {
|
||||
"id": 1,
|
||||
"query_tool_id": "tool-query-rollback",
|
||||
"session_id": "session-1",
|
||||
"status": "applied",
|
||||
"rollback_status": "none",
|
||||
"query_snapshot": {"query": "测试用户最喜欢的颜色是什么"},
|
||||
"decision_payload": {"decision": "correct", "confidence": 0.97},
|
||||
"rollback_plan": {
|
||||
"forgotten_relations": [
|
||||
{
|
||||
"hash": "rel-old",
|
||||
"subject": "测试用户",
|
||||
"predicate": "最喜欢的颜色是",
|
||||
"object": "蓝色",
|
||||
"before_status": {
|
||||
"is_inactive": False,
|
||||
"weight": 0.8,
|
||||
"is_pinned": False,
|
||||
"protected_until": 0.0,
|
||||
"last_reinforced": None,
|
||||
"inactive_since": None,
|
||||
},
|
||||
}
|
||||
],
|
||||
"corrected_write": {
|
||||
"paragraph_hashes": ["paragraph-new"],
|
||||
"corrected_relations": [
|
||||
{
|
||||
"hash": "rel-new",
|
||||
"subject": "测试用户",
|
||||
"predicate": "最喜欢的颜色是",
|
||||
"object": "绿色",
|
||||
"existed_before": False,
|
||||
"before_status": {},
|
||||
}
|
||||
],
|
||||
},
|
||||
"stale_marks": [{"paragraph_hash": "paragraph-old", "relation_hash": "rel-old"}],
|
||||
"episode_sources": ["chat_summary:session-1"],
|
||||
"profile_person_ids": ["person-1"],
|
||||
},
|
||||
}
|
||||
|
||||
class _Conn:
|
||||
def cursor(self):
|
||||
return self
|
||||
|
||||
def execute(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def commit(self):
|
||||
return None
|
||||
|
||||
metadata_store = SimpleNamespace(
|
||||
get_feedback_task_by_id=lambda task_id: current_task if int(task_id) == 1 else None,
|
||||
mark_feedback_task_rollback_running=lambda **kwargs: current_task.update({"rollback_status": "running"}) or current_task,
|
||||
finalize_feedback_task_rollback=lambda **kwargs: current_task.update(
|
||||
{
|
||||
"rollback_status": kwargs["rollback_status"],
|
||||
"rollback_result": kwargs.get("rollback_result") or {},
|
||||
"rollback_error": kwargs.get("rollback_error", ""),
|
||||
}
|
||||
)
|
||||
or current_task,
|
||||
get_relation_status_batch=lambda hashes: {
|
||||
hash_value: dict(relation_statuses[hash_value])
|
||||
for hash_value in hashes
|
||||
if hash_value in relation_statuses
|
||||
},
|
||||
restore_relation_status_from_snapshot=lambda hash_value, snapshot: relation_statuses.update(
|
||||
{hash_value: dict(snapshot)}
|
||||
)
|
||||
or dict(snapshot),
|
||||
append_feedback_action_log=lambda **kwargs: action_logs.append(kwargs),
|
||||
mark_as_deleted=lambda hashes, type_: deleted_paragraphs.extend(list(hashes)) or len(list(hashes)),
|
||||
get_paragraph=lambda paragraph_hash: {"hash": paragraph_hash, "source": "chat_summary:session-1"},
|
||||
get_connection=lambda: _Conn(),
|
||||
delete_external_memory_refs_by_paragraphs=lambda hashes: [
|
||||
{"paragraph_hash": str(hash_value), "external_id": f"external:{hash_value}"}
|
||||
for hash_value in hashes
|
||||
],
|
||||
update_relations_protection=lambda hashes, **kwargs: None,
|
||||
mark_relations_inactive=lambda hashes, inactive_since=None: [
|
||||
relation_statuses.__setitem__(
|
||||
hash_value,
|
||||
{
|
||||
**relation_statuses.get(hash_value, {}),
|
||||
"is_inactive": True,
|
||||
"inactive_since": inactive_since,
|
||||
},
|
||||
)
|
||||
for hash_value in hashes
|
||||
],
|
||||
delete_paragraph_stale_relation_marks=lambda marks: deleted_marks.extend(list(marks)) or len(list(marks)),
|
||||
enqueue_episode_source_rebuild=lambda source, reason='': queued_sources.append(source) or True,
|
||||
enqueue_person_profile_refresh=lambda **kwargs: queued_profiles.append(kwargs["person_id"]) or kwargs,
|
||||
list_feedback_action_logs=lambda task_id: action_logs if int(task_id) == 1 else [],
|
||||
)
|
||||
|
||||
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
|
||||
kernel.metadata_store = metadata_store
|
||||
async def _noop_initialize() -> None:
|
||||
return None
|
||||
kernel.initialize = _noop_initialize # type: ignore[method-assign]
|
||||
kernel._rebuild_graph_from_metadata = lambda: None # type: ignore[method-assign]
|
||||
kernel._persist = lambda: None # type: ignore[method-assign]
|
||||
|
||||
payload = await kernel._rollback_feedback_task(
|
||||
task_id=1,
|
||||
requested_by="pytest",
|
||||
reason="manual rollback",
|
||||
)
|
||||
|
||||
assert payload["success"] is True
|
||||
assert relation_statuses["rel-old"]["is_inactive"] is False
|
||||
assert relation_statuses["rel-new"]["is_inactive"] is True
|
||||
assert deleted_paragraphs == ["paragraph-new"]
|
||||
assert deleted_marks == [("paragraph-old", "rel-old")]
|
||||
assert queued_sources == ["chat_summary:session-1"]
|
||||
assert queued_profiles == ["person-1"]
|
||||
assert current_task["rollback_status"] == "rolled_back"
|
||||
assert {item["action_type"] for item in action_logs} >= {
|
||||
"rollback_restore_relation",
|
||||
"rollback_revert_corrected_relation",
|
||||
"rollback_delete_correction_paragraph",
|
||||
"rollback_clear_stale_mark",
|
||||
"rollback_enqueue_episode_rebuild",
|
||||
"rollback_enqueue_profile_refresh",
|
||||
}
|
||||
@@ -638,3 +638,36 @@ def test_delete_operation_routes(client: TestClient, monkeypatch):
|
||||
assert list_response.json()["count"] == 1
|
||||
assert get_response.status_code == 200
|
||||
assert get_response.json()["operation"]["operation_id"] == "del-1"
|
||||
|
||||
|
||||
def test_feedback_correction_routes(client: TestClient, monkeypatch):
|
||||
async def fake_feedback_admin(*, action: str, **kwargs):
|
||||
if action == "list":
|
||||
assert kwargs == {"limit": 7, "status": "applied", "rollback_status": "none", "query": "green"}
|
||||
return {"success": True, "items": [{"task_id": 11, "query_text": "what color"}], "count": 1}
|
||||
if action == "get":
|
||||
assert kwargs == {"task_id": 11}
|
||||
return {"success": True, "task": {"task_id": 11, "query_text": "what color", "action_logs": []}}
|
||||
if action == "rollback":
|
||||
assert kwargs == {"task_id": 11, "requested_by": "tester", "reason": "manual revert"}
|
||||
return {"success": True, "result": {"restored_relation_hashes": ["rel-1"]}}
|
||||
raise AssertionError(action)
|
||||
|
||||
monkeypatch.setattr(memory_router_module.memory_service, "feedback_admin", fake_feedback_admin)
|
||||
|
||||
list_response = client.get(
|
||||
"/api/webui/memory/feedback-corrections",
|
||||
params={"limit": 7, "status": "applied", "rollback_status": "none", "query": "green"},
|
||||
)
|
||||
get_response = client.get("/api/webui/memory/feedback-corrections/11")
|
||||
rollback_response = client.post(
|
||||
"/api/webui/memory/feedback-corrections/11/rollback",
|
||||
json={"requested_by": "tester", "reason": "manual revert"},
|
||||
)
|
||||
|
||||
assert list_response.status_code == 200
|
||||
assert list_response.json()["items"][0]["task_id"] == 11
|
||||
assert get_response.status_code == 200
|
||||
assert get_response.json()["task"]["task_id"] == 11
|
||||
assert rollback_response.status_code == 200
|
||||
assert rollback_response.json()["result"]["restored_relation_hashes"] == ["rel-1"]
|
||||
|
||||
@@ -11,7 +11,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
@@ -29,6 +29,9 @@ logger = get_logger("A_Memorix.EmbeddingAPIAdapter")
|
||||
class EmbeddingAPIAdapter:
|
||||
"""适配宿主 embedding 请求接口。"""
|
||||
|
||||
_GLOBAL_DIMENSION_CACHE: Dict[str, int] = {}
|
||||
_GLOBAL_TEXT_EMBEDDING_CACHE: Dict[Tuple[str, int, str], np.ndarray] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = 32,
|
||||
@@ -232,10 +235,32 @@ class EmbeddingAPIAdapter:
|
||||
logger.error(f"通过直接 Client 获取 Embedding 失败: {last_exc}")
|
||||
return None
|
||||
|
||||
def _dimension_cache_key(self) -> str:
|
||||
candidate_names = self._resolve_candidate_model_names()
|
||||
return "|".join(
|
||||
[
|
||||
str(self.model_name or "auto"),
|
||||
str(self.default_dimension),
|
||||
",".join(candidate_names),
|
||||
]
|
||||
)
|
||||
|
||||
def _embedding_cache_key(self, text: str, dimensions: Optional[int]) -> Tuple[str, int, str]:
|
||||
requested_dimension = self._resolve_canonical_dimension(dimensions)
|
||||
return (self._dimension_cache_key(), int(requested_dimension), str(text or ""))
|
||||
|
||||
async def _detect_dimension(self) -> int:
|
||||
if self._dimension_detected and self._dimension is not None:
|
||||
return self._dimension
|
||||
|
||||
cache_key = self._dimension_cache_key()
|
||||
cached_dimension = self._GLOBAL_DIMENSION_CACHE.get(cache_key)
|
||||
if cached_dimension is not None:
|
||||
self._dimension = int(cached_dimension)
|
||||
self._dimension_detected = True
|
||||
logger.info(f"嵌入维度命中进程缓存: {self._dimension}")
|
||||
return self._dimension
|
||||
|
||||
logger.info("正在检测嵌入模型维度...")
|
||||
try:
|
||||
target_dim = self.default_dimension
|
||||
@@ -251,6 +276,7 @@ class EmbeddingAPIAdapter:
|
||||
)
|
||||
self._dimension = detected_dim
|
||||
self._dimension_detected = True
|
||||
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(detected_dim)
|
||||
return detected_dim
|
||||
except Exception as exc:
|
||||
logger.debug(f"带维度参数探测失败: {exc},尝试不带维度参数探测")
|
||||
@@ -261,6 +287,7 @@ class EmbeddingAPIAdapter:
|
||||
detected_dim = len(test_embedding)
|
||||
self._dimension = detected_dim
|
||||
self._dimension_detected = True
|
||||
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(detected_dim)
|
||||
logger.info(f"嵌入维度检测成功 (自然维度): {detected_dim}")
|
||||
return detected_dim
|
||||
logger.warning(f"嵌入维度检测失败,使用 configured_dimension: {self.default_dimension}")
|
||||
@@ -269,6 +296,7 @@ class EmbeddingAPIAdapter:
|
||||
|
||||
self._dimension = self.default_dimension
|
||||
self._dimension_detected = True
|
||||
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(self.default_dimension)
|
||||
return self.default_dimension
|
||||
|
||||
async def encode(
|
||||
@@ -336,6 +364,25 @@ class EmbeddingAPIAdapter:
|
||||
all_embeddings: List[np.ndarray] = []
|
||||
for offset in range(0, len(texts), batch_size):
|
||||
batch = texts[offset : offset + batch_size]
|
||||
batch_results: List[Tuple[int, np.ndarray]] = []
|
||||
uncached_items: List[Tuple[int, str]] = []
|
||||
|
||||
if self.enable_cache:
|
||||
for index, text in enumerate(batch):
|
||||
cache_key = self._embedding_cache_key(text, dimensions)
|
||||
cached_vector = self._GLOBAL_TEXT_EMBEDDING_CACHE.get(cache_key)
|
||||
if cached_vector is None:
|
||||
uncached_items.append((index, text))
|
||||
else:
|
||||
batch_results.append((index, cached_vector.copy()))
|
||||
else:
|
||||
uncached_items = list(enumerate(batch))
|
||||
|
||||
if not uncached_items:
|
||||
batch_results.sort(key=lambda item: item[0])
|
||||
all_embeddings.extend(emb for _, emb in batch_results)
|
||||
continue
|
||||
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||
|
||||
async def encode_with_semaphore(text: str, index: int):
|
||||
@@ -351,11 +398,20 @@ class EmbeddingAPIAdapter:
|
||||
|
||||
tasks = [
|
||||
encode_with_semaphore(text, offset + index)
|
||||
for index, text in enumerate(batch)
|
||||
for index, text in uncached_items
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
results.sort(key=lambda item: item[0])
|
||||
all_embeddings.extend(emb for _, emb in results)
|
||||
normalized_results: List[Tuple[int, np.ndarray]] = []
|
||||
for batch_index, vector in results:
|
||||
normalized_results.append((batch_index, vector))
|
||||
if self.enable_cache:
|
||||
text = batch[batch_index]
|
||||
cache_key = self._embedding_cache_key(text, dimensions)
|
||||
self._GLOBAL_TEXT_EMBEDDING_CACHE[cache_key] = vector.copy()
|
||||
|
||||
batch_results.extend(normalized_results)
|
||||
batch_results.sort(key=lambda item: item[0])
|
||||
all_embeddings.extend(emb for _, emb in batch_results)
|
||||
|
||||
return np.array(all_embeddings, dtype=np.float32)
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ except Exception:
|
||||
logger = get_logger("A_Memorix.MetadataStore")
|
||||
|
||||
|
||||
SCHEMA_VERSION = 10
|
||||
SCHEMA_VERSION = 12
|
||||
RUNTIME_AUTO_MIGRATION_MIN_SCHEMA_VERSION = 9
|
||||
|
||||
|
||||
|
||||
@@ -375,6 +375,30 @@ def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
|
||||
"memory_feedback_tasks rollback columns missing under current schema version",
|
||||
)
|
||||
)
|
||||
elif not has_stale_marks:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-15",
|
||||
"error",
|
||||
"paragraph_stale_relation_marks table missing under current schema version",
|
||||
)
|
||||
)
|
||||
elif not has_profile_refresh_queue:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-16",
|
||||
"error",
|
||||
"person_profile_refresh_queue table missing under current schema version",
|
||||
)
|
||||
)
|
||||
elif not has_feedback_rollback_status or not has_feedback_rollback_plan:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-17",
|
||||
"error",
|
||||
"memory_feedback_tasks rollback columns missing under current schema version",
|
||||
)
|
||||
)
|
||||
|
||||
if _sqlite_table_exists(conn, "relations"):
|
||||
row = conn.execute("SELECT COUNT(*) FROM relations").fetchone()
|
||||
|
||||
@@ -145,23 +145,23 @@ class VisualConfig(ConfigBase):
|
||||
__ui_label__ = "视觉"
|
||||
__ui_icon__ = "image"
|
||||
|
||||
planner_mode: Literal["text", "multimodal", "auto"] = Field(
|
||||
default="auto",
|
||||
multimodal_planner: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "git-branch",
|
||||
"x-widget": "switch",
|
||||
"x-icon": "image",
|
||||
},
|
||||
)
|
||||
"""规划器模式,auto根据模型信息自动选择,text为纯文本模式,multimodal为多模态模式"""
|
||||
"""是否直接输入图片"""
|
||||
|
||||
replyer_mode: Literal["text", "multimodal", "auto"] = Field(
|
||||
default="auto",
|
||||
multimodal_replyer: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-widget": "switch",
|
||||
"x-icon": "git-branch",
|
||||
},
|
||||
)
|
||||
"""回复器模式,auto根据模型信息自动选择,text为纯文本模式,multimodal为多模态模式"""
|
||||
"""是否启用 Maisaka 多模态 replyer 生成器"""
|
||||
|
||||
visual_style: str = Field(
|
||||
default="请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本",
|
||||
@@ -239,12 +239,17 @@ class ChatConfig(ConfigBase):
|
||||
)
|
||||
"""Planner 连续被新消息打断的最大次数,0 表示不启用打断"""
|
||||
|
||||
plan_reply_log_max_per_chat: int = Field(
|
||||
default=1024,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "file-text",
|
||||
},
|
||||
)
|
||||
"""每个聊天流最大保存的Plan/Reply日志数量,超过此数量时会自动删除最老的日志"""
|
||||
|
||||
group_chat_prompt: str = Field(
|
||||
default="""
|
||||
你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片。
|
||||
回复尽量简短一些。最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。请注意把握聊天内容。
|
||||
不要回复的太频繁!控制回复的频率,不要每个人的消息都回复,只回复你感兴趣的或者主动提及你的。
|
||||
""",
|
||||
default="你需要控制自己发言的频率,如果是一对一聊天,可以以较均匀的频率发言;如果用户较多,不要每句都回复,控制回复频率,不要回复的太频繁!控制回复的频率,不要每个人的消息都回复。",
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "users",
|
||||
@@ -253,11 +258,7 @@ class ChatConfig(ConfigBase):
|
||||
"""_wrap_群聊通用注意事项"""
|
||||
|
||||
private_chat_prompts: str = Field(
|
||||
default="""
|
||||
你正在聊天,下面是正在聊的内容,其中包含聊天记录和聊天中的图片。
|
||||
回复尽量简短一些。请注意把握聊天内容。
|
||||
请考虑对方的发言频率,想法,思考自己何时回复以及回复内容。
|
||||
""",
|
||||
default="你需要控制自己发言的频率,可以以较均匀的频率发言。",
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "user",
|
||||
@@ -663,15 +664,6 @@ class LearningItem(ConfigBase):
|
||||
)
|
||||
"""是否启用jargon学习"""
|
||||
|
||||
advanced_chosen: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "sparkles",
|
||||
},
|
||||
)
|
||||
"""是否启用基于子代理的二次表达方式选择"""
|
||||
|
||||
|
||||
class ExpressionGroup(ConfigBase):
|
||||
"""表达互通组配置类,若列表为空代表全局共享"""
|
||||
@@ -701,7 +693,6 @@ class ExpressionConfig(ConfigBase):
|
||||
use_expression=True,
|
||||
enable_learning=True,
|
||||
enable_jargon_learning=True,
|
||||
advanced_chosen=False,
|
||||
)
|
||||
],
|
||||
json_schema_extra={
|
||||
@@ -1573,6 +1564,35 @@ class MaiSakaConfig(ConfigBase):
|
||||
)
|
||||
"""MaiSaka 使用的用户名称"""
|
||||
|
||||
tool_filter_task_name: str = Field(
|
||||
default="utils",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "sparkles",
|
||||
},
|
||||
)
|
||||
"""工具筛选预判使用的模型任务名"""
|
||||
|
||||
tool_filter_threshold: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "filter",
|
||||
},
|
||||
)
|
||||
"""当可用工具总数超过该阈值时,先进行一轮工具筛选"""
|
||||
|
||||
tool_filter_max_keep: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "list-filter",
|
||||
},
|
||||
)
|
||||
"""工具筛选阶段最多保留的非内置工具数量"""
|
||||
|
||||
show_image_path: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
|
||||
Reference in New Issue
Block a user