Merge pull request #1587 from A-Dawn/r-dev
feat:记忆系统新增反馈学习功能&修复聊天内容导入问题
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -371,3 +371,4 @@ packages/
|
||||
.claude/
|
||||
.omc/
|
||||
/.venv312
|
||||
/src/A_memorix/algorithm_redesign
|
||||
|
||||
@@ -496,6 +496,77 @@ export interface MemoryDeleteOperationDetailPayload {
|
||||
error?: string
|
||||
}
|
||||
|
||||
export interface MemoryFeedbackAffectedCountsPayload {
|
||||
relations?: number
|
||||
stale_paragraphs?: number
|
||||
episode_sources?: number
|
||||
profile_person_ids?: number
|
||||
correction_paragraphs?: number
|
||||
corrected_relations?: number
|
||||
}
|
||||
|
||||
export interface MemoryFeedbackActionLogPayload {
|
||||
id: number
|
||||
task_id: number
|
||||
query_tool_id: string
|
||||
action_type: string
|
||||
target_hash: string
|
||||
reason?: string
|
||||
before_payload?: Record<string, unknown>
|
||||
after_payload?: Record<string, unknown>
|
||||
created_at?: number
|
||||
}
|
||||
|
||||
export interface MemoryFeedbackCorrectionSummaryPayload {
|
||||
task_id: number
|
||||
query_tool_id: string
|
||||
session_id: string
|
||||
query_text: string
|
||||
query_timestamp?: number
|
||||
task_status: string
|
||||
decision: string
|
||||
decision_confidence: number
|
||||
feedback_message_count: number
|
||||
rollback_status: string
|
||||
affected_counts: MemoryFeedbackAffectedCountsPayload
|
||||
created_at?: number
|
||||
updated_at?: number
|
||||
}
|
||||
|
||||
export interface MemoryFeedbackCorrectionDetailTaskPayload extends MemoryFeedbackCorrectionSummaryPayload {
|
||||
query_snapshot?: Record<string, unknown>
|
||||
decision_payload?: Record<string, unknown>
|
||||
rollback_plan_summary?: Record<string, unknown>
|
||||
rollback_result?: Record<string, unknown>
|
||||
rollback_error?: string
|
||||
rollback_requested_by?: string
|
||||
rollback_reason?: string
|
||||
rollback_requested_at?: number
|
||||
rolled_back_at?: number
|
||||
action_logs?: MemoryFeedbackActionLogPayload[]
|
||||
}
|
||||
|
||||
export interface MemoryFeedbackCorrectionListPayload {
|
||||
success: boolean
|
||||
items: MemoryFeedbackCorrectionSummaryPayload[]
|
||||
count?: number
|
||||
error?: string
|
||||
}
|
||||
|
||||
export interface MemoryFeedbackCorrectionDetailPayload {
|
||||
success: boolean
|
||||
task?: MemoryFeedbackCorrectionDetailTaskPayload | null
|
||||
error?: string
|
||||
}
|
||||
|
||||
export interface MemoryFeedbackCorrectionRollbackPayload {
|
||||
success: boolean
|
||||
already_rolled_back?: boolean
|
||||
result?: Record<string, unknown>
|
||||
task?: MemoryFeedbackCorrectionDetailTaskPayload | null
|
||||
error?: string
|
||||
}
|
||||
|
||||
export interface MemorySourceItemPayload {
|
||||
source: string
|
||||
paragraph_count?: number
|
||||
@@ -610,6 +681,49 @@ export async function getMemoryDeleteOperation(
|
||||
return requestJson<MemoryDeleteOperationDetailPayload>(`/delete/operations/${encodeURIComponent(operationId)}`)
|
||||
}
|
||||
|
||||
export async function getMemoryFeedbackCorrections(
|
||||
options?: {
|
||||
limit?: number
|
||||
status?: string
|
||||
rollbackStatus?: string
|
||||
query?: string
|
||||
},
|
||||
): Promise<MemoryFeedbackCorrectionListPayload> {
|
||||
const params = new URLSearchParams({
|
||||
limit: String(options?.limit ?? 50),
|
||||
})
|
||||
if (options?.status?.trim()) {
|
||||
params.set('status', options.status.trim())
|
||||
}
|
||||
if (options?.rollbackStatus?.trim()) {
|
||||
params.set('rollback_status', options.rollbackStatus.trim())
|
||||
}
|
||||
if (options?.query?.trim()) {
|
||||
params.set('query', options.query.trim())
|
||||
}
|
||||
return requestJson<MemoryFeedbackCorrectionListPayload>(`/feedback-corrections?${params.toString()}`)
|
||||
}
|
||||
|
||||
export async function getMemoryFeedbackCorrection(
|
||||
taskId: number,
|
||||
): Promise<MemoryFeedbackCorrectionDetailPayload> {
|
||||
return requestJson<MemoryFeedbackCorrectionDetailPayload>(`/feedback-corrections/${taskId}`)
|
||||
}
|
||||
|
||||
export async function rollbackMemoryFeedbackCorrection(
|
||||
taskId: number,
|
||||
payload: {
|
||||
requested_by?: string
|
||||
reason?: string
|
||||
},
|
||||
): Promise<MemoryFeedbackCorrectionRollbackPayload> {
|
||||
return requestJson<MemoryFeedbackCorrectionRollbackPayload>(`/feedback-corrections/${taskId}/rollback`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(payload),
|
||||
})
|
||||
}
|
||||
|
||||
export async function getMemorySources(): Promise<MemorySourceListPayload> {
|
||||
return requestJson<MemorySourceListPayload>('/sources')
|
||||
}
|
||||
|
||||
@@ -81,9 +81,12 @@ vi.mock('@/lib/memory-api', () => ({
|
||||
getMemorySources: vi.fn(),
|
||||
getMemoryDeleteOperations: vi.fn(),
|
||||
getMemoryDeleteOperation: vi.fn(),
|
||||
getMemoryFeedbackCorrections: vi.fn(),
|
||||
getMemoryFeedbackCorrection: vi.fn(),
|
||||
previewMemoryDelete: vi.fn(),
|
||||
executeMemoryDelete: vi.fn(),
|
||||
restoreMemoryDelete: vi.fn(),
|
||||
rollbackMemoryFeedbackCorrection: vi.fn(),
|
||||
}))
|
||||
|
||||
function mockImportTask(taskId: string, status: string = 'running'): memoryApi.MemoryImportTaskPayload {
|
||||
@@ -357,6 +360,82 @@ describe('KnowledgeBasePage import workflow', () => {
|
||||
items: [],
|
||||
},
|
||||
})
|
||||
vi.mocked(memoryApi.getMemoryFeedbackCorrections).mockResolvedValue({
|
||||
success: true,
|
||||
items: [
|
||||
{
|
||||
task_id: 11,
|
||||
query_tool_id: 'tool-query-11',
|
||||
session_id: 'session-1',
|
||||
query_text: '测试用户最喜欢的颜色是什么',
|
||||
query_timestamp: 1_710_000_010,
|
||||
task_status: 'applied',
|
||||
decision: 'correct',
|
||||
decision_confidence: 0.97,
|
||||
feedback_message_count: 1,
|
||||
rollback_status: 'none',
|
||||
affected_counts: {
|
||||
relations: 1,
|
||||
stale_paragraphs: 1,
|
||||
episode_sources: 2,
|
||||
profile_person_ids: 1,
|
||||
correction_paragraphs: 1,
|
||||
corrected_relations: 1,
|
||||
},
|
||||
created_at: 1_710_000_011,
|
||||
updated_at: 1_710_000_012,
|
||||
},
|
||||
],
|
||||
count: 1,
|
||||
})
|
||||
vi.mocked(memoryApi.getMemoryFeedbackCorrection).mockResolvedValue({
|
||||
success: true,
|
||||
task: {
|
||||
task_id: 11,
|
||||
query_tool_id: 'tool-query-11',
|
||||
session_id: 'session-1',
|
||||
query_text: '测试用户最喜欢的颜色是什么',
|
||||
query_timestamp: 1_710_000_010,
|
||||
task_status: 'applied',
|
||||
decision: 'correct',
|
||||
decision_confidence: 0.97,
|
||||
feedback_message_count: 1,
|
||||
rollback_status: 'none',
|
||||
affected_counts: {
|
||||
relations: 1,
|
||||
stale_paragraphs: 1,
|
||||
episode_sources: 2,
|
||||
profile_person_ids: 1,
|
||||
correction_paragraphs: 1,
|
||||
corrected_relations: 1,
|
||||
},
|
||||
query_snapshot: { query: '测试用户最喜欢的颜色是什么', hits: [{ hash: 'paragraph-1' }] },
|
||||
decision_payload: { decision: 'correct', confidence: 0.97 },
|
||||
rollback_plan_summary: {
|
||||
forgotten_relations: [{ hash: 'rel-old', subject: '测试用户', predicate: '最喜欢的颜色是', object: '蓝色' }],
|
||||
corrected_write: {
|
||||
paragraph_hashes: ['paragraph-new'],
|
||||
corrected_relations: [{ hash: 'rel-new', subject: '测试用户', predicate: '最喜欢的颜色是', object: '绿色' }],
|
||||
},
|
||||
},
|
||||
rollback_result: {},
|
||||
action_logs: [
|
||||
{
|
||||
id: 1,
|
||||
task_id: 11,
|
||||
query_tool_id: 'tool-query-11',
|
||||
action_type: 'forget_relation',
|
||||
target_hash: 'rel-old',
|
||||
reason: '用户明确纠正为绿色',
|
||||
before_payload: { hash: 'rel-old', subject: '测试用户', predicate: '最喜欢的颜色是', object: '蓝色' },
|
||||
after_payload: { is_inactive: true },
|
||||
created_at: 1_710_000_013,
|
||||
},
|
||||
],
|
||||
created_at: 1_710_000_011,
|
||||
updated_at: 1_710_000_012,
|
||||
},
|
||||
})
|
||||
vi.mocked(memoryApi.previewMemoryDelete).mockResolvedValue({
|
||||
success: true,
|
||||
mode: 'source',
|
||||
@@ -380,6 +459,37 @@ describe('KnowledgeBasePage import workflow', () => {
|
||||
deleted_source_count: 1,
|
||||
} as never)
|
||||
vi.mocked(memoryApi.restoreMemoryDelete).mockResolvedValue({ success: true } as never)
|
||||
vi.mocked(memoryApi.rollbackMemoryFeedbackCorrection).mockResolvedValue({
|
||||
success: true,
|
||||
result: { restored_relation_hashes: ['rel-old'] },
|
||||
task: {
|
||||
task_id: 11,
|
||||
query_tool_id: 'tool-query-11',
|
||||
session_id: 'session-1',
|
||||
query_text: '测试用户最喜欢的颜色是什么',
|
||||
query_timestamp: 1_710_000_010,
|
||||
task_status: 'applied',
|
||||
decision: 'correct',
|
||||
decision_confidence: 0.97,
|
||||
feedback_message_count: 1,
|
||||
rollback_status: 'rolled_back',
|
||||
affected_counts: {
|
||||
relations: 1,
|
||||
stale_paragraphs: 1,
|
||||
episode_sources: 2,
|
||||
profile_person_ids: 1,
|
||||
correction_paragraphs: 1,
|
||||
corrected_relations: 1,
|
||||
},
|
||||
query_snapshot: { query: '测试用户最喜欢的颜色是什么', hits: [{ hash: 'paragraph-1' }] },
|
||||
decision_payload: { decision: 'correct', confidence: 0.97 },
|
||||
rollback_plan_summary: {},
|
||||
rollback_result: { restored_relation_hashes: ['rel-old'] },
|
||||
action_logs: [],
|
||||
created_at: 1_710_000_011,
|
||||
updated_at: 1_710_000_012,
|
||||
},
|
||||
})
|
||||
vi.mocked(memoryApi.refreshMemoryRuntimeSelfCheck).mockResolvedValue({
|
||||
success: true,
|
||||
report: { ok: true },
|
||||
@@ -619,4 +729,27 @@ describe('KnowledgeBasePage import workflow', () => {
|
||||
}),
|
||||
)
|
||||
}, 20_000)
|
||||
|
||||
it('shows feedback correction history and supports rollback', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<KnowledgeBasePage />)
|
||||
|
||||
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
|
||||
await user.click(screen.getByRole('tab', { name: '纠错历史' }))
|
||||
await screen.findByText('反馈纠错历史')
|
||||
await screen.findByText('测试用户最喜欢的颜色是什么')
|
||||
await waitFor(() => expect(memoryApi.getMemoryFeedbackCorrection).toHaveBeenCalledWith(11))
|
||||
|
||||
await user.click(screen.getByRole('button', { name: '回退本次纠错' }))
|
||||
const rollbackReason = await screen.findByLabelText('回退原因')
|
||||
await user.type(rollbackReason, '人工确认回退')
|
||||
await user.click(screen.getByRole('button', { name: '确认回退' }))
|
||||
|
||||
await waitFor(() =>
|
||||
expect(memoryApi.rollbackMemoryFeedbackCorrection).toHaveBeenCalledWith(11, {
|
||||
requested_by: 'knowledge_base',
|
||||
reason: '人工确认回退',
|
||||
}),
|
||||
)
|
||||
}, 20_000)
|
||||
})
|
||||
|
||||
@@ -24,6 +24,14 @@ import { Badge } from '@/components/ui/badge'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'
|
||||
import { Checkbox } from '@/components/ui/checkbox'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Label } from '@/components/ui/label'
|
||||
import { Progress } from '@/components/ui/progress'
|
||||
@@ -48,6 +56,8 @@ import {
|
||||
createMemoryRawScanImport,
|
||||
createMemoryTemporalBackfillImport,
|
||||
executeMemoryDelete,
|
||||
getMemoryFeedbackCorrection,
|
||||
getMemoryFeedbackCorrections,
|
||||
getMemoryImportPathAliases,
|
||||
getMemoryImportSettings,
|
||||
getMemoryImportTask,
|
||||
@@ -74,6 +84,7 @@ import {
|
||||
type MemoryImportTaskPayload,
|
||||
previewMemoryDelete,
|
||||
refreshMemoryRuntimeSelfCheck,
|
||||
rollbackMemoryFeedbackCorrection,
|
||||
resolveMemoryImportPath,
|
||||
retryMemoryImportTask,
|
||||
restoreMemoryDelete,
|
||||
@@ -82,6 +93,9 @@ import {
|
||||
type MemoryConfigSchemaPayload,
|
||||
type MemoryDeleteExecutePayload,
|
||||
type MemoryDeleteOperationPayload,
|
||||
type MemoryFeedbackActionLogPayload,
|
||||
type MemoryFeedbackCorrectionDetailTaskPayload,
|
||||
type MemoryFeedbackCorrectionSummaryPayload,
|
||||
type MemorySourceItemPayload,
|
||||
type MemoryRuntimeConfigPayload,
|
||||
type MemoryTaskPayload,
|
||||
@@ -90,6 +104,9 @@ import {
|
||||
const DELETE_OPERATION_FETCH_LIMIT = 100
|
||||
const DELETE_OPERATION_PAGE_SIZE = 6
|
||||
const DELETE_OPERATION_ITEM_PAGE_SIZE = 8
|
||||
const FEEDBACK_CORRECTION_FETCH_LIMIT = 100
|
||||
const FEEDBACK_CORRECTION_PAGE_SIZE = 6
|
||||
const FEEDBACK_ACTION_LOG_PAGE_SIZE = 8
|
||||
const IMPORT_CHUNK_PAGE_SIZE = 50
|
||||
|
||||
const RUNNING_IMPORT_STATUS = new Set(['preparing', 'running', 'cancel_requested'])
|
||||
@@ -270,6 +287,90 @@ function formatDeleteOperationTime(timestamp?: number | null): string {
|
||||
})
|
||||
}
|
||||
|
||||
function formatFeedbackDecision(decision: string): string {
|
||||
switch (decision) {
|
||||
case 'correct':
|
||||
return '纠正'
|
||||
case 'reject':
|
||||
return '否定'
|
||||
case 'confirm':
|
||||
return '确认'
|
||||
case 'supplement':
|
||||
return '补充'
|
||||
case 'none':
|
||||
return '无动作'
|
||||
default:
|
||||
return decision || '未知'
|
||||
}
|
||||
}
|
||||
|
||||
function formatFeedbackTaskStatus(status: string): string {
|
||||
switch (status) {
|
||||
case 'pending':
|
||||
return '待处理'
|
||||
case 'running':
|
||||
return '处理中'
|
||||
case 'applied':
|
||||
return '已应用'
|
||||
case 'skipped':
|
||||
return '已跳过'
|
||||
case 'error':
|
||||
return '失败'
|
||||
default:
|
||||
return status || '未知'
|
||||
}
|
||||
}
|
||||
|
||||
function formatFeedbackRollbackStatus(status: string): string {
|
||||
switch (status) {
|
||||
case 'none':
|
||||
return '未回退'
|
||||
case 'running':
|
||||
return '回退中'
|
||||
case 'rolled_back':
|
||||
return '已回退'
|
||||
case 'error':
|
||||
return '回退失败'
|
||||
default:
|
||||
return status || '未知'
|
||||
}
|
||||
}
|
||||
|
||||
function getFeedbackStatusVariant(
|
||||
status: string,
|
||||
): 'default' | 'secondary' | 'destructive' | 'outline' {
|
||||
if (status === 'applied' || status === 'rolled_back') {
|
||||
return 'default'
|
||||
}
|
||||
if (status === 'error') {
|
||||
return 'destructive'
|
||||
}
|
||||
if (status === 'running' || status === 'pending') {
|
||||
return 'outline'
|
||||
}
|
||||
return 'secondary'
|
||||
}
|
||||
|
||||
function summarizeFeedbackActionPayload(value: Record<string, unknown> | undefined): string {
|
||||
if (!value) {
|
||||
return ''
|
||||
}
|
||||
const hash = String(value.hash ?? '').trim()
|
||||
const subject = String(value.subject ?? '').trim()
|
||||
const predicate = String(value.predicate ?? '').trim()
|
||||
const object = String(value.object ?? '').trim()
|
||||
if (subject && predicate && object) {
|
||||
return formatDeleteRelationText(subject, predicate, object)
|
||||
}
|
||||
if (hash) {
|
||||
return hash
|
||||
}
|
||||
if (Array.isArray(value.target_hashes) && value.target_hashes.length > 0) {
|
||||
return `targets ${value.target_hashes.length}`
|
||||
}
|
||||
return trimDeleteItemText(JSON.stringify(value, null, 2), 120)
|
||||
}
|
||||
|
||||
type DeleteOperationItem = NonNullable<MemoryDeleteOperationPayload['items']>[number]
|
||||
|
||||
function trimDeleteItemText(value: string, maxLength: number = 140): string {
|
||||
@@ -471,6 +572,20 @@ export function KnowledgeBasePage() {
|
||||
const [deleteRestoring, setDeleteRestoring] = useState(false)
|
||||
const [deleteResult, setDeleteResult] = useState<MemoryDeleteExecutePayload | null>(null)
|
||||
const [pendingDeleteRequest, setPendingDeleteRequest] = useState<MemoryDeleteRequestPayload | null>(null)
|
||||
const [feedbackCorrections, setFeedbackCorrections] = useState<MemoryFeedbackCorrectionSummaryPayload[]>([])
|
||||
const [feedbackSearch, setFeedbackSearch] = useState('')
|
||||
const [feedbackStatusFilter, setFeedbackStatusFilter] = useState('all')
|
||||
const [feedbackRollbackFilter, setFeedbackRollbackFilter] = useState('all')
|
||||
const [feedbackPage, setFeedbackPage] = useState(1)
|
||||
const [selectedFeedbackTaskId, setSelectedFeedbackTaskId] = useState(0)
|
||||
const [selectedFeedbackTaskDetail, setSelectedFeedbackTaskDetail] = useState<MemoryFeedbackCorrectionDetailTaskPayload | null>(null)
|
||||
const [selectedFeedbackTaskLoading, setSelectedFeedbackTaskLoading] = useState(false)
|
||||
const [selectedFeedbackTaskError, setSelectedFeedbackTaskError] = useState('')
|
||||
const [feedbackActionLogSearch, setFeedbackActionLogSearch] = useState('')
|
||||
const [feedbackActionLogPage, setFeedbackActionLogPage] = useState(1)
|
||||
const [feedbackRollbackDialogOpen, setFeedbackRollbackDialogOpen] = useState(false)
|
||||
const [feedbackRollbackReason, setFeedbackRollbackReason] = useState('')
|
||||
const [feedbackRollingBack, setFeedbackRollingBack] = useState(false)
|
||||
const [tuningObjective, setTuningObjective] = useState('precision_priority')
|
||||
const [tuningIntensity, setTuningIntensity] = useState('standard')
|
||||
const [tuningSampleSize, setTuningSampleSize] = useState('24')
|
||||
@@ -491,6 +606,7 @@ export function KnowledgeBasePage() {
|
||||
tuningTaskPayload,
|
||||
sourcePayload,
|
||||
deleteOperationPayload,
|
||||
feedbackCorrectionPayload,
|
||||
] = await Promise.all([
|
||||
getMemoryConfigSchema(),
|
||||
getMemoryConfig(),
|
||||
@@ -503,6 +619,7 @@ export function KnowledgeBasePage() {
|
||||
getMemoryTuningTasks(20),
|
||||
getMemorySources(),
|
||||
getMemoryDeleteOperations(DELETE_OPERATION_FETCH_LIMIT),
|
||||
getMemoryFeedbackCorrections({ limit: FEEDBACK_CORRECTION_FETCH_LIMIT }),
|
||||
])
|
||||
|
||||
setSchemaPayload(schema)
|
||||
@@ -519,6 +636,7 @@ export function KnowledgeBasePage() {
|
||||
setTuningTasks(tuningTaskPayload.items ?? [])
|
||||
setMemorySources(sourcePayload.items ?? [])
|
||||
setDeleteOperations(deleteOperationPayload.items ?? [])
|
||||
setFeedbackCorrections(feedbackCorrectionPayload.items ?? [])
|
||||
if (!selectedImportTaskId && (importTaskPayload.items ?? []).length > 0) {
|
||||
const initialTaskId = String(importTaskPayload.items?.[0]?.task_id ?? '')
|
||||
if (initialTaskId) {
|
||||
@@ -1494,6 +1612,212 @@ export function KnowledgeBasePage() {
|
||||
setDeleteDialogOpen(true)
|
||||
}, [])
|
||||
|
||||
const filteredFeedbackCorrections = useMemo(() => {
|
||||
const keyword = feedbackSearch.trim().toLowerCase()
|
||||
return feedbackCorrections.filter((item) => {
|
||||
const taskStatus = String(item.task_status ?? '').trim().toLowerCase()
|
||||
const rollbackStatus = String(item.rollback_status ?? '').trim().toLowerCase()
|
||||
if (feedbackStatusFilter !== 'all' && taskStatus !== feedbackStatusFilter) {
|
||||
return false
|
||||
}
|
||||
if (feedbackRollbackFilter !== 'all' && rollbackStatus !== feedbackRollbackFilter) {
|
||||
return false
|
||||
}
|
||||
if (!keyword) {
|
||||
return true
|
||||
}
|
||||
return [
|
||||
item.query_tool_id,
|
||||
item.session_id,
|
||||
item.query_text,
|
||||
item.decision,
|
||||
item.task_status,
|
||||
item.rollback_status,
|
||||
]
|
||||
.map((value) => String(value ?? '').toLowerCase())
|
||||
.some((value) => value.includes(keyword))
|
||||
})
|
||||
}, [feedbackCorrections, feedbackRollbackFilter, feedbackSearch, feedbackStatusFilter])
|
||||
|
||||
const feedbackPageCount = Math.max(1, Math.ceil(filteredFeedbackCorrections.length / FEEDBACK_CORRECTION_PAGE_SIZE))
|
||||
const pagedFeedbackCorrections = useMemo(() => {
|
||||
const start = (feedbackPage - 1) * FEEDBACK_CORRECTION_PAGE_SIZE
|
||||
return filteredFeedbackCorrections.slice(start, start + FEEDBACK_CORRECTION_PAGE_SIZE)
|
||||
}, [feedbackPage, filteredFeedbackCorrections])
|
||||
|
||||
const selectedFeedbackCorrection = useMemo(
|
||||
() =>
|
||||
filteredFeedbackCorrections.find((item) => item.task_id === selectedFeedbackTaskId)
|
||||
?? pagedFeedbackCorrections[0]
|
||||
?? null,
|
||||
[filteredFeedbackCorrections, pagedFeedbackCorrections, selectedFeedbackTaskId],
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
setFeedbackPage(1)
|
||||
}, [feedbackSearch, feedbackStatusFilter, feedbackRollbackFilter])
|
||||
|
||||
useEffect(() => {
|
||||
if (feedbackPage > feedbackPageCount) {
|
||||
setFeedbackPage(feedbackPageCount)
|
||||
}
|
||||
}, [feedbackPage, feedbackPageCount])
|
||||
|
||||
useEffect(() => {
|
||||
if (!selectedFeedbackCorrection) {
|
||||
if (selectedFeedbackTaskId) {
|
||||
setSelectedFeedbackTaskId(0)
|
||||
}
|
||||
setSelectedFeedbackTaskDetail(null)
|
||||
setSelectedFeedbackTaskError('')
|
||||
return
|
||||
}
|
||||
if (selectedFeedbackCorrection.task_id !== selectedFeedbackTaskId) {
|
||||
setSelectedFeedbackTaskId(selectedFeedbackCorrection.task_id)
|
||||
}
|
||||
}, [selectedFeedbackCorrection, selectedFeedbackTaskId])
|
||||
|
||||
useEffect(() => {
|
||||
const taskId = selectedFeedbackCorrection?.task_id
|
||||
if (!taskId) {
|
||||
setSelectedFeedbackTaskDetail(null)
|
||||
setSelectedFeedbackTaskError('')
|
||||
return
|
||||
}
|
||||
|
||||
let cancelled = false
|
||||
setSelectedFeedbackTaskLoading(true)
|
||||
setSelectedFeedbackTaskError('')
|
||||
|
||||
void getMemoryFeedbackCorrection(taskId)
|
||||
.then((payload) => {
|
||||
if (cancelled) {
|
||||
return
|
||||
}
|
||||
if (!payload.success || !payload.task) {
|
||||
setSelectedFeedbackTaskDetail(null)
|
||||
setSelectedFeedbackTaskError(payload.error || '未能加载纠错任务详情')
|
||||
return
|
||||
}
|
||||
setSelectedFeedbackTaskDetail(payload.task)
|
||||
})
|
||||
.catch((error) => {
|
||||
if (cancelled) {
|
||||
return
|
||||
}
|
||||
setSelectedFeedbackTaskDetail(null)
|
||||
setSelectedFeedbackTaskError(error instanceof Error ? error.message : '未能加载纠错任务详情')
|
||||
})
|
||||
.finally(() => {
|
||||
if (!cancelled) {
|
||||
setSelectedFeedbackTaskLoading(false)
|
||||
}
|
||||
})
|
||||
|
||||
return () => {
|
||||
cancelled = true
|
||||
}
|
||||
}, [selectedFeedbackCorrection?.task_id])
|
||||
|
||||
const selectedFeedbackResolved = useMemo(() => {
|
||||
if (!selectedFeedbackCorrection) {
|
||||
return null
|
||||
}
|
||||
if (selectedFeedbackTaskDetail?.task_id === selectedFeedbackCorrection.task_id) {
|
||||
return {
|
||||
...selectedFeedbackCorrection,
|
||||
...selectedFeedbackTaskDetail,
|
||||
} satisfies MemoryFeedbackCorrectionDetailTaskPayload
|
||||
}
|
||||
return selectedFeedbackTaskDetail ?? selectedFeedbackCorrection
|
||||
}, [selectedFeedbackCorrection, selectedFeedbackTaskDetail])
|
||||
|
||||
const selectedFeedbackActionLogs = Array.isArray(selectedFeedbackResolved?.action_logs)
|
||||
? selectedFeedbackResolved.action_logs
|
||||
: []
|
||||
const filteredFeedbackActionLogs = useMemo(() => {
|
||||
const keyword = feedbackActionLogSearch.trim().toLowerCase()
|
||||
if (!keyword) {
|
||||
return selectedFeedbackActionLogs
|
||||
}
|
||||
return selectedFeedbackActionLogs.filter((item) =>
|
||||
[
|
||||
item.action_type,
|
||||
item.target_hash,
|
||||
item.reason,
|
||||
summarizeFeedbackActionPayload(item.before_payload),
|
||||
summarizeFeedbackActionPayload(item.after_payload),
|
||||
]
|
||||
.map((value) => String(value ?? '').toLowerCase())
|
||||
.some((value) => value.includes(keyword)),
|
||||
)
|
||||
}, [feedbackActionLogSearch, selectedFeedbackActionLogs])
|
||||
const feedbackActionLogPageCount = Math.max(
|
||||
1,
|
||||
Math.ceil(filteredFeedbackActionLogs.length / FEEDBACK_ACTION_LOG_PAGE_SIZE),
|
||||
)
|
||||
const pagedFeedbackActionLogs = useMemo(() => {
|
||||
const start = (feedbackActionLogPage - 1) * FEEDBACK_ACTION_LOG_PAGE_SIZE
|
||||
return filteredFeedbackActionLogs.slice(start, start + FEEDBACK_ACTION_LOG_PAGE_SIZE)
|
||||
}, [feedbackActionLogPage, filteredFeedbackActionLogs])
|
||||
|
||||
useEffect(() => {
|
||||
setFeedbackActionLogPage(1)
|
||||
}, [selectedFeedbackTaskId, feedbackActionLogSearch])
|
||||
|
||||
useEffect(() => {
|
||||
if (feedbackActionLogPage > feedbackActionLogPageCount) {
|
||||
setFeedbackActionLogPage(feedbackActionLogPageCount)
|
||||
}
|
||||
}, [feedbackActionLogPage, feedbackActionLogPageCount])
|
||||
|
||||
const openFeedbackRollbackDialog = useCallback(() => {
|
||||
setFeedbackRollbackReason('')
|
||||
setFeedbackRollbackDialogOpen(true)
|
||||
}, [])
|
||||
|
||||
const executeFeedbackRollback = useCallback(async () => {
|
||||
const taskId = selectedFeedbackResolved?.task_id
|
||||
if (!taskId) {
|
||||
return
|
||||
}
|
||||
try {
|
||||
setFeedbackRollingBack(true)
|
||||
const payload = await rollbackMemoryFeedbackCorrection(taskId, {
|
||||
requested_by: 'knowledge_base',
|
||||
reason: feedbackRollbackReason.trim(),
|
||||
})
|
||||
if (!payload.success) {
|
||||
throw new Error(payload.error || '回退失败')
|
||||
}
|
||||
toast({
|
||||
title: payload.already_rolled_back ? '该纠错已回退' : '纠错回退成功',
|
||||
description: `任务 ${taskId} 的回退结果已写入日志`,
|
||||
})
|
||||
setFeedbackRollbackDialogOpen(false)
|
||||
const [listPayload, detailPayload] = await Promise.all([
|
||||
getMemoryFeedbackCorrections({ limit: FEEDBACK_CORRECTION_FETCH_LIMIT }),
|
||||
getMemoryFeedbackCorrection(taskId),
|
||||
])
|
||||
setFeedbackCorrections(listPayload.items ?? [])
|
||||
setSelectedFeedbackTaskDetail(detailPayload.task ?? null)
|
||||
const [sourcePayload, runtimePayload] = await Promise.all([
|
||||
getMemorySources(),
|
||||
getMemoryRuntimeConfig(),
|
||||
])
|
||||
setMemorySources(sourcePayload.items ?? [])
|
||||
setRuntimeConfig(runtimePayload)
|
||||
} catch (error) {
|
||||
toast({
|
||||
title: '纠错回退失败',
|
||||
description: error instanceof Error ? error.message : '未知错误',
|
||||
variant: 'destructive',
|
||||
})
|
||||
} finally {
|
||||
setFeedbackRollingBack(false)
|
||||
}
|
||||
}, [feedbackRollbackReason, selectedFeedbackResolved?.task_id, toast])
|
||||
|
||||
const selectedOperationResolved = useMemo(() => {
|
||||
if (!selectedDeleteOperation) {
|
||||
return null
|
||||
@@ -1776,6 +2100,9 @@ export function KnowledgeBasePage() {
|
||||
<TabsTrigger value="delete" className="rounded-lg px-4 py-1.5">
|
||||
删除
|
||||
</TabsTrigger>
|
||||
<TabsTrigger value="feedback" className="rounded-lg px-4 py-1.5">
|
||||
纠错历史
|
||||
</TabsTrigger>
|
||||
</TabsList>
|
||||
|
||||
<TabsContent value="overview" className="space-y-4">
|
||||
@@ -3314,6 +3641,327 @@ export function KnowledgeBasePage() {
|
||||
</Card>
|
||||
</div>
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value="feedback" className="space-y-4">
|
||||
<div className="space-y-4">
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle className="flex items-center gap-2">
|
||||
<RotateCcw className="h-4 w-4" />
|
||||
反馈纠错历史
|
||||
</CardTitle>
|
||||
<CardDescription>
|
||||
查看 feedback correction 的判定、修改轨迹与回退结果;本期仅覆盖自动纠错任务
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-4">
|
||||
<div className="grid gap-3 lg:grid-cols-[minmax(0,1fr)_180px_180px]">
|
||||
<Input
|
||||
value={feedbackSearch}
|
||||
onChange={(event) => setFeedbackSearch(event.target.value)}
|
||||
placeholder="搜索 query_tool_id / session / query / reason"
|
||||
/>
|
||||
<Select value={feedbackStatusFilter} onValueChange={setFeedbackStatusFilter}>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="按任务状态筛选" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="all">全部任务状态</SelectItem>
|
||||
<SelectItem value="applied">已应用</SelectItem>
|
||||
<SelectItem value="skipped">已跳过</SelectItem>
|
||||
<SelectItem value="error">失败</SelectItem>
|
||||
<SelectItem value="running">处理中</SelectItem>
|
||||
<SelectItem value="pending">待处理</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<Select value={feedbackRollbackFilter} onValueChange={setFeedbackRollbackFilter}>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="按回退状态筛选" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="all">全部回退状态</SelectItem>
|
||||
<SelectItem value="none">未回退</SelectItem>
|
||||
<SelectItem value="rolled_back">已回退</SelectItem>
|
||||
<SelectItem value="error">回退失败</SelectItem>
|
||||
<SelectItem value="running">回退中</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-wrap items-center justify-between gap-2 text-sm text-muted-foreground">
|
||||
<span>当前命中 {filteredFeedbackCorrections.length} 条记录,已加载最近 {feedbackCorrections.length} 条</span>
|
||||
<span>第 {feedbackPage} / {feedbackPageCount} 页,每页显示 {FEEDBACK_CORRECTION_PAGE_SIZE} 条</span>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-4 xl:grid-cols-[minmax(0,0.92fr)_minmax(0,1.08fr)]">
|
||||
<ScrollArea className="h-[720px] rounded-lg border">
|
||||
<div className="space-y-3 p-3">
|
||||
{pagedFeedbackCorrections.length > 0 ? pagedFeedbackCorrections.map((item) => {
|
||||
const isSelected = selectedFeedbackCorrection?.task_id === item.task_id
|
||||
return (
|
||||
<button
|
||||
key={item.task_id}
|
||||
type="button"
|
||||
onClick={() => setSelectedFeedbackTaskId(item.task_id)}
|
||||
className={cn(
|
||||
'w-full rounded-xl border p-4 text-left transition-colors',
|
||||
isSelected
|
||||
? 'border-primary bg-primary/5 shadow-sm'
|
||||
: 'bg-muted/20 hover:border-primary/40 hover:bg-muted/40',
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-col gap-3">
|
||||
<div className="flex flex-wrap items-center gap-2">
|
||||
<Badge variant={getFeedbackStatusVariant(item.task_status)}>
|
||||
{formatFeedbackTaskStatus(item.task_status)}
|
||||
</Badge>
|
||||
<Badge variant={getFeedbackStatusVariant(item.rollback_status)}>
|
||||
{formatFeedbackRollbackStatus(item.rollback_status)}
|
||||
</Badge>
|
||||
<Badge variant="outline">
|
||||
{formatFeedbackDecision(item.decision)}
|
||||
</Badge>
|
||||
</div>
|
||||
<div className="text-sm font-medium break-words">
|
||||
{item.query_text || '无查询文本'}
|
||||
</div>
|
||||
<div className="font-mono text-[11px] break-all text-muted-foreground">
|
||||
{item.query_tool_id}
|
||||
</div>
|
||||
<div className="flex flex-wrap gap-3 text-xs text-muted-foreground">
|
||||
<span>关系 {Number(item.affected_counts?.relations ?? 0)}</span>
|
||||
<span>脏段落 {Number(item.affected_counts?.stale_paragraphs ?? 0)}</span>
|
||||
<span>Episode {Number(item.affected_counts?.episode_sources ?? 0)}</span>
|
||||
<span>Profile {Number(item.affected_counts?.profile_person_ids ?? 0)}</span>
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{formatDeleteOperationTime(item.query_timestamp ?? item.created_at)}
|
||||
</div>
|
||||
</div>
|
||||
</button>
|
||||
)
|
||||
}) : (
|
||||
<div className="rounded-lg border border-dashed bg-muted/20 p-6 text-center text-sm text-muted-foreground">
|
||||
当前筛选条件下没有纠错历史
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
|
||||
<div className="rounded-xl border bg-muted/20 p-4">
|
||||
{selectedFeedbackCorrection ? (
|
||||
<div className="space-y-4">
|
||||
<div className="flex flex-col gap-3 lg:flex-row lg:items-start lg:justify-between">
|
||||
<div className="space-y-2">
|
||||
<div className="flex flex-wrap items-center gap-2">
|
||||
<Badge variant={getFeedbackStatusVariant(String(selectedFeedbackResolved?.task_status ?? ''))}>
|
||||
{formatFeedbackTaskStatus(String(selectedFeedbackResolved?.task_status ?? ''))}
|
||||
</Badge>
|
||||
<Badge variant={getFeedbackStatusVariant(String(selectedFeedbackResolved?.rollback_status ?? 'none'))}>
|
||||
{formatFeedbackRollbackStatus(String(selectedFeedbackResolved?.rollback_status ?? 'none'))}
|
||||
</Badge>
|
||||
<Badge variant="outline">
|
||||
{formatFeedbackDecision(String(selectedFeedbackResolved?.decision ?? ''))}
|
||||
</Badge>
|
||||
</div>
|
||||
<div className="text-sm font-medium break-words">
|
||||
{selectedFeedbackResolved?.query_text || '无查询文本'}
|
||||
</div>
|
||||
<div className="font-mono text-xs break-all">
|
||||
{selectedFeedbackResolved?.query_tool_id}
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={openFeedbackRollbackDialog}
|
||||
disabled={
|
||||
String(selectedFeedbackResolved?.task_status ?? '') !== 'applied'
|
||||
|| String(selectedFeedbackResolved?.rollback_status ?? 'none') === 'rolled_back'
|
||||
|| feedbackRollingBack
|
||||
}
|
||||
>
|
||||
<RotateCcw className="mr-2 h-4 w-4" />
|
||||
{String(selectedFeedbackResolved?.rollback_status ?? 'none') === 'rolled_back'
|
||||
? '已回退'
|
||||
: '回退本次纠错'}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-3 lg:grid-cols-4">
|
||||
<div className="rounded-lg border bg-background/60 p-3">
|
||||
<div className="text-xs text-muted-foreground">会话</div>
|
||||
<div className="mt-1 text-sm break-all">{selectedFeedbackResolved?.session_id || '-'}</div>
|
||||
</div>
|
||||
<div className="rounded-lg border bg-background/60 p-3">
|
||||
<div className="text-xs text-muted-foreground">反馈消息数</div>
|
||||
<div className="mt-1 text-sm">{Number(selectedFeedbackResolved?.feedback_message_count ?? 0)}</div>
|
||||
</div>
|
||||
<div className="rounded-lg border bg-background/60 p-3">
|
||||
<div className="text-xs text-muted-foreground">判定置信度</div>
|
||||
<div className="mt-1 text-sm">{Number(selectedFeedbackResolved?.decision_confidence ?? 0).toFixed(2)}</div>
|
||||
</div>
|
||||
<div className="rounded-lg border bg-background/60 p-3">
|
||||
<div className="text-xs text-muted-foreground">回退时间</div>
|
||||
<div className="mt-1 text-sm">{formatDeleteOperationTime(selectedFeedbackResolved?.rolled_back_at)}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{selectedFeedbackTaskLoading ? (
|
||||
<div className="rounded-lg border bg-background/60 p-4 text-sm text-muted-foreground">
|
||||
正在加载纠错详情...
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
{selectedFeedbackTaskError ? (
|
||||
<Alert variant="destructive">
|
||||
<AlertDescription>{selectedFeedbackTaskError}</AlertDescription>
|
||||
</Alert>
|
||||
) : null}
|
||||
|
||||
{selectedFeedbackResolved?.rollback_error ? (
|
||||
<Alert variant="destructive">
|
||||
<AlertDescription>{selectedFeedbackResolved.rollback_error}</AlertDescription>
|
||||
</Alert>
|
||||
) : null}
|
||||
|
||||
<div className="grid gap-4 xl:grid-cols-[minmax(0,0.95fr)_minmax(0,1.05fr)]">
|
||||
<div className="space-y-2">
|
||||
<div className="text-sm font-semibold">查询快照</div>
|
||||
<pre className="max-h-56 overflow-auto rounded-lg border bg-background/70 p-3 text-xs break-words whitespace-pre-wrap">
|
||||
{JSON.stringify(selectedFeedbackResolved?.query_snapshot ?? {}, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<div className="text-sm font-semibold">判定结果</div>
|
||||
<pre className="max-h-56 overflow-auto rounded-lg border bg-background/70 p-3 text-xs break-words whitespace-pre-wrap">
|
||||
{JSON.stringify(selectedFeedbackResolved?.decision_payload ?? {}, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-4 xl:grid-cols-[minmax(0,0.95fr)_minmax(0,1.05fr)]">
|
||||
<div className="space-y-2">
|
||||
<div className="text-sm font-semibold">回退计划摘要</div>
|
||||
<pre className="max-h-64 overflow-auto rounded-lg border bg-background/70 p-3 text-xs break-words whitespace-pre-wrap">
|
||||
{JSON.stringify(selectedFeedbackResolved?.rollback_plan_summary ?? {}, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<div className="text-sm font-semibold">回退结果</div>
|
||||
<pre className="max-h-64 overflow-auto rounded-lg border bg-background/70 p-3 text-xs break-words whitespace-pre-wrap">
|
||||
{JSON.stringify(selectedFeedbackResolved?.rollback_result ?? {}, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<div className="flex flex-col gap-3 lg:flex-row lg:items-center lg:justify-between">
|
||||
<div className="text-sm font-semibold">动作时间线</div>
|
||||
<div className="flex flex-col gap-3 lg:flex-row lg:items-center lg:justify-end">
|
||||
<Input
|
||||
value={feedbackActionLogSearch}
|
||||
onChange={(event) => setFeedbackActionLogSearch(event.target.value)}
|
||||
placeholder="搜索动作 / hash / 预览内容"
|
||||
className="lg:w-80"
|
||||
/>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
第 {feedbackActionLogPage} / {feedbackActionLogPageCount} 页,每页 {FEEDBACK_ACTION_LOG_PAGE_SIZE} 项
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<ScrollArea className="h-[280px] rounded-lg border bg-background/60">
|
||||
<div className="space-y-2 p-3">
|
||||
{pagedFeedbackActionLogs.length > 0 ? pagedFeedbackActionLogs.map((item: MemoryFeedbackActionLogPayload) => (
|
||||
<div key={`${item.id}:${item.action_type}`} className="rounded-lg border bg-muted/20 p-3">
|
||||
<div className="flex flex-wrap items-center gap-2">
|
||||
<Badge variant="outline">{item.action_type}</Badge>
|
||||
{item.target_hash ? (
|
||||
<span className="font-mono text-[11px] break-all text-muted-foreground">{item.target_hash}</span>
|
||||
) : null}
|
||||
</div>
|
||||
{item.reason ? (
|
||||
<div className="mt-2 text-xs text-muted-foreground break-words">
|
||||
{item.reason}
|
||||
</div>
|
||||
) : null}
|
||||
{item.before_payload && Object.keys(item.before_payload).length > 0 ? (
|
||||
<div className="mt-2 text-xs break-words">
|
||||
<span className="font-medium">Before:</span>
|
||||
<span className="text-muted-foreground">{summarizeFeedbackActionPayload(item.before_payload)}</span>
|
||||
</div>
|
||||
) : null}
|
||||
{item.after_payload && Object.keys(item.after_payload).length > 0 ? (
|
||||
<div className="mt-1 text-xs break-words">
|
||||
<span className="font-medium">After:</span>
|
||||
<span className="text-muted-foreground">{summarizeFeedbackActionPayload(item.after_payload)}</span>
|
||||
</div>
|
||||
) : null}
|
||||
<div className="mt-2 text-[11px] text-muted-foreground">
|
||||
{formatDeleteOperationTime(item.created_at)}
|
||||
</div>
|
||||
</div>
|
||||
)) : (
|
||||
<div className="rounded-lg border border-dashed bg-muted/20 p-6 text-center text-sm text-muted-foreground">
|
||||
{selectedFeedbackActionLogs.length > 0 ? '当前筛选条件下没有动作日志' : '当前任务没有动作日志'}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => setFeedbackActionLogPage((current) => Math.max(1, current - 1))}
|
||||
disabled={feedbackActionLogPage <= 1}
|
||||
>
|
||||
上一页
|
||||
</Button>
|
||||
<div className="text-xs text-muted-foreground">支持按动作类型、hash 和摘要检索</div>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => setFeedbackActionLogPage((current) => Math.min(feedbackActionLogPageCount, current + 1))}
|
||||
disabled={feedbackActionLogPage >= feedbackActionLogPageCount}
|
||||
>
|
||||
下一页
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex min-h-[360px] items-center justify-center rounded-lg border border-dashed bg-background/40 p-6 text-center text-sm text-muted-foreground">
|
||||
当前没有可查看的纠错详情
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => setFeedbackPage((current) => Math.max(1, current - 1))}
|
||||
disabled={feedbackPage <= 1}
|
||||
>
|
||||
上一页
|
||||
</Button>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
支持按 query、任务状态和回退状态检索
|
||||
</div>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => setFeedbackPage((current) => Math.min(feedbackPageCount, current + 1))}
|
||||
disabled={feedbackPage >= feedbackPageCount}
|
||||
>
|
||||
下一页
|
||||
</Button>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
</div>
|
||||
</div>
|
||||
@@ -3332,6 +3980,52 @@ export function KnowledgeBasePage() {
|
||||
onExecute={() => void executePendingDelete()}
|
||||
onRestore={() => void (deleteResult?.operation_id ? restoreDeleteOperation(deleteResult.operation_id) : Promise.resolve())}
|
||||
/>
|
||||
|
||||
<Dialog open={feedbackRollbackDialogOpen} onOpenChange={setFeedbackRollbackDialogOpen}>
|
||||
<DialogContent className="max-w-lg" confirmOnEnter>
|
||||
<DialogHeader>
|
||||
<DialogTitle>回退本次纠错</DialogTitle>
|
||||
<DialogDescription>
|
||||
这会恢复旧 relation 状态、隐藏纠错写入段落,并重新触发 episode/profile 的异步修复。
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="space-y-3">
|
||||
<div className="rounded-lg border bg-muted/20 p-3 text-sm">
|
||||
<div className="font-medium break-words">{selectedFeedbackResolved?.query_text || '无查询文本'}</div>
|
||||
<div className="mt-1 font-mono text-[11px] break-all text-muted-foreground">
|
||||
{selectedFeedbackResolved?.query_tool_id}
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="feedback-rollback-reason">回退原因</Label>
|
||||
<Textarea
|
||||
id="feedback-rollback-reason"
|
||||
value={feedbackRollbackReason}
|
||||
onChange={(event) => setFeedbackRollbackReason(event.target.value)}
|
||||
placeholder="可选,建议填写本次人工回退原因"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button variant="outline" onClick={() => setFeedbackRollbackDialogOpen(false)} disabled={feedbackRollingBack}>
|
||||
取消
|
||||
</Button>
|
||||
<Button onClick={() => void executeFeedbackRollback()} disabled={feedbackRollingBack}>
|
||||
{feedbackRollingBack ? (
|
||||
<>
|
||||
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
|
||||
回退中
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<RotateCcw className="mr-2 h-4 w-4" />
|
||||
确认回退
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,418 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import pickle
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import Session, create_engine
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
IMPORT_ERROR: str | None = None
|
||||
|
||||
try:
|
||||
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
||||
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
|
||||
from src.A_memorix.core.utils import summary_importer as summary_importer_module
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.common.database import database as database_module
|
||||
from src.common.database.migrations import create_database_migration_bootstrapper
|
||||
from src.common.message_repository import count_messages
|
||||
from src.config.model_configs import TaskConfig
|
||||
from src.services import memory_flow_service as memory_flow_service_module
|
||||
from src.services import memory_service as memory_service_module
|
||||
from src.services import send_service
|
||||
except SystemExit as exc:
|
||||
IMPORT_ERROR = f"config initialization exited during import: {exc}"
|
||||
kernel_module = None # type: ignore[assignment]
|
||||
SDKMemoryKernel = None # type: ignore[assignment]
|
||||
summary_importer_module = None # type: ignore[assignment]
|
||||
BotChatSession = None # type: ignore[assignment]
|
||||
SessionMessage = None # type: ignore[assignment]
|
||||
MessageInfo = None # type: ignore[assignment]
|
||||
UserInfo = None # type: ignore[assignment]
|
||||
MessageSequence = None # type: ignore[assignment]
|
||||
TextComponent = None # type: ignore[assignment]
|
||||
database_module = None # type: ignore[assignment]
|
||||
create_database_migration_bootstrapper = None # type: ignore[assignment]
|
||||
count_messages = None # type: ignore[assignment]
|
||||
TaskConfig = None # type: ignore[assignment]
|
||||
memory_flow_service_module = None # type: ignore[assignment]
|
||||
memory_service_module = None # type: ignore[assignment]
|
||||
send_service = None # type: ignore[assignment]
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
|
||||
|
||||
|
||||
class _FakeEmbeddingManager:
|
||||
def __init__(self, dimension: int = 8) -> None:
|
||||
self.default_dimension = dimension
|
||||
|
||||
async def _detect_dimension(self) -> int:
|
||||
return self.default_dimension
|
||||
|
||||
async def encode(self, text: Any) -> np.ndarray:
|
||||
def _encode_one(raw: Any) -> np.ndarray:
|
||||
content = str(raw or "")
|
||||
vector = np.zeros(self.default_dimension, dtype=np.float32)
|
||||
for index, byte in enumerate(content.encode("utf-8")):
|
||||
vector[index % self.default_dimension] += float((byte % 17) + 1)
|
||||
norm = float(np.linalg.norm(vector))
|
||||
if norm > 0:
|
||||
vector /= norm
|
||||
return vector
|
||||
|
||||
if isinstance(text, (list, tuple)):
|
||||
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
|
||||
return _encode_one(text).astype(np.float32)
|
||||
|
||||
|
||||
class _KernelBackedRuntimeManager:
|
||||
def __init__(self, kernel: SDKMemoryKernel) -> None:
|
||||
self.kernel = kernel
|
||||
|
||||
async def invoke(
|
||||
self,
|
||||
component_name: str,
|
||||
args: Dict[str, Any] | None,
|
||||
*,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
del timeout_ms
|
||||
payload = args or {}
|
||||
handler = getattr(self.kernel, component_name)
|
||||
result = handler(**payload)
|
||||
return await result if inspect.isawaitable(result) else result
|
||||
|
||||
|
||||
class _NoopRuntimeManager:
|
||||
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
|
||||
del hook_name
|
||||
return SimpleNamespace(aborted=False, kwargs=kwargs)
|
||||
|
||||
|
||||
class _FakePlatformIOManager:
|
||||
def __init__(self) -> None:
|
||||
self.ensure_calls = 0
|
||||
|
||||
async def ensure_send_pipeline_ready(self) -> None:
|
||||
self.ensure_calls += 1
|
||||
|
||||
def build_route_key_from_message(self, message: Any) -> Any:
|
||||
del message
|
||||
return SimpleNamespace(platform="qq")
|
||||
|
||||
async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
|
||||
del message, metadata
|
||||
return SimpleNamespace(
|
||||
has_success=True,
|
||||
sent_receipts=[
|
||||
SimpleNamespace(
|
||||
driver_id="plugin.qq.sender",
|
||||
external_message_id="real-message-id",
|
||||
metadata={},
|
||||
)
|
||||
],
|
||||
failed_receipts=[],
|
||||
route_key=route_key,
|
||||
)
|
||||
|
||||
|
||||
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
db_dir = (tmp_path / "main_db").resolve()
|
||||
db_dir.mkdir(parents=True, exist_ok=True)
|
||||
db_file = db_dir / "MaiBot.db"
|
||||
database_url = f"sqlite:///{db_file}"
|
||||
|
||||
try:
|
||||
database_module.engine.dispose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
engine = create_engine(
|
||||
database_url,
|
||||
echo=False,
|
||||
connect_args={"check_same_thread": False},
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
session_local = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=engine,
|
||||
class_=Session,
|
||||
)
|
||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
monkeypatch.setattr(database_module, "_DB_DIR", db_dir, raising=False)
|
||||
monkeypatch.setattr(database_module, "_DB_FILE", db_file, raising=False)
|
||||
monkeypatch.setattr(database_module, "DATABASE_URL", database_url, raising=False)
|
||||
monkeypatch.setattr(database_module, "engine", engine, raising=False)
|
||||
monkeypatch.setattr(database_module, "SessionLocal", session_local, raising=False)
|
||||
monkeypatch.setattr(database_module, "_migration_bootstrapper", bootstrapper, raising=False)
|
||||
monkeypatch.setattr(database_module, "_db_initialized", False, raising=False)
|
||||
|
||||
|
||||
def _build_incoming_message(
|
||||
*,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
text: str,
|
||||
timestamp: datetime | None = None,
|
||||
) -> SessionMessage:
|
||||
message = SessionMessage(
|
||||
message_id="incoming-message-id",
|
||||
timestamp=timestamp or datetime.now(),
|
||||
platform="qq",
|
||||
)
|
||||
message.message_info = MessageInfo(
|
||||
user_info=UserInfo(
|
||||
user_id=user_id,
|
||||
user_nickname="测试用户",
|
||||
user_cardname="测试用户",
|
||||
),
|
||||
additional_config={},
|
||||
)
|
||||
message.raw_message = MessageSequence(components=[TextComponent(text=text)])
|
||||
message.session_id = session_id
|
||||
message.reply_to = None
|
||||
message.is_mentioned = False
|
||||
message.is_at = False
|
||||
message.is_emoji = False
|
||||
message.is_picture = False
|
||||
message.is_command = False
|
||||
message.is_notify = False
|
||||
message.processed_plain_text = text
|
||||
message.display_message = text
|
||||
message.initialized = True
|
||||
return message
|
||||
|
||||
|
||||
async def _wait_until(
|
||||
predicate: Callable[[], Any],
|
||||
*,
|
||||
timeout_seconds: float = 10.0,
|
||||
interval_seconds: float = 0.05,
|
||||
description: str,
|
||||
) -> Any:
|
||||
deadline = asyncio.get_running_loop().time() + max(0.5, float(timeout_seconds))
|
||||
while asyncio.get_running_loop().time() < deadline:
|
||||
value = predicate()
|
||||
if inspect.isawaitable(value):
|
||||
value = await value
|
||||
if value:
|
||||
return value
|
||||
await asyncio.sleep(interval_seconds)
|
||||
raise AssertionError(f"等待超时: {description}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_stream_triggers_real_chat_summary_writeback(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
_install_temp_main_database(monkeypatch, tmp_path)
|
||||
|
||||
fake_embedding_manager = _FakeEmbeddingManager()
|
||||
captured_prompts: List[str] = []
|
||||
fixed_send_timestamp = 1_777_000_000.0
|
||||
|
||||
async def _fake_runtime_self_check(**kwargs: Any) -> Dict[str, Any]:
|
||||
del kwargs
|
||||
return {
|
||||
"ok": True,
|
||||
"message": "ok",
|
||||
"configured_dimension": fake_embedding_manager.default_dimension,
|
||||
"requested_dimension": fake_embedding_manager.default_dimension,
|
||||
"vector_store_dimension": fake_embedding_manager.default_dimension,
|
||||
"detected_dimension": fake_embedding_manager.default_dimension,
|
||||
"encoded_dimension": fake_embedding_manager.default_dimension,
|
||||
"elapsed_ms": 0.0,
|
||||
"sample_text": "test",
|
||||
"checked_at": datetime.now().timestamp(),
|
||||
}
|
||||
|
||||
async def _fake_generate(request: Any) -> Any:
|
||||
captured_prompts.append(str(getattr(request, "prompt", "") or ""))
|
||||
return SimpleNamespace(
|
||||
success=True,
|
||||
completion=SimpleNamespace(
|
||||
response=json.dumps(
|
||||
{
|
||||
"summary": "这段对话记录了用户提到自己买了绿色围巾,机器人表示会记住这件事。",
|
||||
"entities": ["绿色围巾"],
|
||||
"relations": [],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
kernel_module,
|
||||
"create_embedding_api_adapter",
|
||||
lambda **kwargs: fake_embedding_manager,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
kernel_module,
|
||||
"run_embedding_runtime_self_check",
|
||||
_fake_runtime_self_check,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
summary_importer_module,
|
||||
"run_embedding_runtime_self_check",
|
||||
_fake_runtime_self_check,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
summary_importer_module.llm_api,
|
||||
"get_available_models",
|
||||
lambda: {"utils": TaskConfig(model_list=["fake-summary-model"])},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
summary_importer_module.llm_api,
|
||||
"resolve_task_name_from_model_config",
|
||||
lambda model_config: "utils",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
summary_importer_module.llm_api,
|
||||
"generate",
|
||||
_fake_generate,
|
||||
)
|
||||
monkeypatch.setattr(send_service.time, "time", lambda: fixed_send_timestamp)
|
||||
monkeypatch.setattr(summary_importer_module.time, "time", lambda: fixed_send_timestamp)
|
||||
|
||||
kernel = SDKMemoryKernel(
|
||||
plugin_root=tmp_path / "plugin_root",
|
||||
config={
|
||||
"storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
|
||||
"advanced": {"enable_auto_save": False},
|
||||
"embedding": {"dimension": fake_embedding_manager.default_dimension},
|
||||
"memory": {"base_decay_interval_hours": 24},
|
||||
"person_profile": {"refresh_interval_minutes": 5},
|
||||
"summarization": {"model_name": ["utils"]},
|
||||
},
|
||||
)
|
||||
|
||||
service = memory_flow_service_module.MemoryAutomationService()
|
||||
fake_platform_io_manager = _FakePlatformIOManager()
|
||||
|
||||
async def _fake_rebuild_episodes_for_sources(sources: List[str]) -> Dict[str, Any]:
|
||||
return {
|
||||
"rebuilt": 0,
|
||||
"items": [],
|
||||
"failures": [],
|
||||
"sources": list(sources),
|
||||
}
|
||||
|
||||
monkeypatch.setattr(kernel, "rebuild_episodes_for_sources", _fake_rebuild_episodes_for_sources)
|
||||
monkeypatch.setattr(
|
||||
memory_service_module,
|
||||
"a_memorix_host_service",
|
||||
_KernelBackedRuntimeManager(kernel),
|
||||
)
|
||||
monkeypatch.setattr(memory_flow_service_module, "memory_automation_service", service)
|
||||
monkeypatch.setattr(send_service, "_get_runtime_manager", lambda: _NoopRuntimeManager())
|
||||
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_platform_io_manager)
|
||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
||||
monkeypatch.setattr(
|
||||
send_service._chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda stream_id: (
|
||||
BotChatSession(
|
||||
session_id="test-session",
|
||||
platform="qq",
|
||||
user_id="target-user",
|
||||
group_id=None,
|
||||
)
|
||||
if stream_id == "test-session"
|
||||
else None
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
memory_flow_service_module.global_config.memory,
|
||||
"chat_summary_writeback_enabled",
|
||||
True,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
memory_flow_service_module.global_config.memory,
|
||||
"chat_summary_writeback_message_threshold",
|
||||
2,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
memory_flow_service_module.global_config.memory,
|
||||
"chat_summary_writeback_context_length",
|
||||
10,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
memory_flow_service_module.global_config.memory,
|
||||
"person_fact_writeback_enabled",
|
||||
False,
|
||||
raising=False,
|
||||
)
|
||||
|
||||
await kernel.initialize()
|
||||
|
||||
try:
|
||||
incoming_message = _build_incoming_message(
|
||||
session_id="test-session",
|
||||
user_id="target-user",
|
||||
text="我最近买了一条绿色围巾。",
|
||||
timestamp=datetime.fromtimestamp(fixed_send_timestamp) - timedelta(seconds=1),
|
||||
)
|
||||
with database_module.get_db_session() as session:
|
||||
session.add(incoming_message.to_db_instance())
|
||||
|
||||
sent_message = await send_service.text_to_stream_with_message(
|
||||
text="好的,我会记住你最近买了绿色围巾。",
|
||||
stream_id="test-session",
|
||||
storage_message=True,
|
||||
)
|
||||
|
||||
assert sent_message is not None
|
||||
assert sent_message.message_id == "real-message-id"
|
||||
assert fake_platform_io_manager.ensure_calls == 1
|
||||
assert count_messages(session_id="test-session") == 2
|
||||
|
||||
paragraphs = await _wait_until(
|
||||
lambda: kernel.metadata_store.get_paragraphs_by_source("chat_summary:test-session"),
|
||||
description="等待聊天摘要写回到 A_memorix",
|
||||
)
|
||||
|
||||
assert captured_prompts
|
||||
assert "我最近买了一条绿色围巾。" in captured_prompts[-1]
|
||||
assert "好的,我会记住你最近买了绿色围巾。" in captured_prompts[-1]
|
||||
assert any("绿色围巾" in str(item.get("content", "") or "") for item in paragraphs)
|
||||
assert any(
|
||||
int(
|
||||
(
|
||||
pickle.loads(item.get("metadata"))
|
||||
if isinstance(item.get("metadata"), (bytes, bytearray))
|
||||
else item.get("metadata")
|
||||
or {}
|
||||
).get("trigger_message_count", 0)
|
||||
or 0
|
||||
)
|
||||
== 2
|
||||
for item in paragraphs
|
||||
)
|
||||
assert service.chat_summary_writeback._states["test-session"].last_trigger_message_count == 2
|
||||
finally:
|
||||
await service.shutdown()
|
||||
await kernel.shutdown()
|
||||
try:
|
||||
database_module.engine.dispose()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -164,3 +164,28 @@ async def test_runtime_self_check_reports_requested_dimension_without_explicit_o
|
||||
assert report["detected_dimension"] == 384
|
||||
assert report["encoded_dimension"] == 384
|
||||
assert manager.encode_calls == ["A_Memorix runtime self check"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_encode_batch_keeps_batch_local_indexes_when_cache_hits_previous_batch(monkeypatch):
|
||||
adapter = EmbeddingAPIAdapter(default_dimension=4, enable_cache=True)
|
||||
adapter._dimension = 4
|
||||
adapter._dimension_detected = True
|
||||
|
||||
async def fake_detect_dimension() -> int:
|
||||
return 4
|
||||
|
||||
async def fake_get_embedding_direct(text: str, dimensions: int | None = None):
|
||||
del dimensions
|
||||
base = float(ord(str(text)[0]))
|
||||
return [base, base + 1.0, base + 2.0, base + 3.0]
|
||||
|
||||
monkeypatch.setattr(adapter, "_detect_dimension", fake_detect_dimension)
|
||||
monkeypatch.setattr(adapter, "_get_embedding_direct", fake_get_embedding_direct)
|
||||
|
||||
embeddings = await adapter.encode(["A", "B", "A", "C"], batch_size=2)
|
||||
|
||||
assert embeddings.shape == (4, 4)
|
||||
assert np.array_equal(embeddings[0], embeddings[2])
|
||||
assert embeddings[1][0] == float(ord("B"))
|
||||
assert embeddings[3][0] == float(ord("C"))
|
||||
|
||||
745
pytests/A_memorix_test/test_feedback_correction_chat_flow.py
Normal file
745
pytests/A_memorix_test/test_feedback_correction_chat_flow.py
Normal file
@@ -0,0 +1,745 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import Session, create_engine, select
|
||||
|
||||
IMPORT_ERROR: str | None = None
|
||||
|
||||
try:
|
||||
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
||||
from src.A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
from src.chat.heart_flow.heartflow_manager import heartflow_manager
|
||||
from src.chat.message_receive import bot as bot_module
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.common.database import database as database_module
|
||||
from src.common.database.database_model import PersonInfo, ToolRecord
|
||||
from src.common.database.migrations import create_database_migration_bootstrapper
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.maisaka import reasoning_engine as reasoning_engine_module
|
||||
from src.maisaka import runtime as runtime_module
|
||||
from src.maisaka.chat_loop_service import ChatResponse
|
||||
from src.maisaka.context_messages import AssistantMessage
|
||||
from src.plugin_runtime import component_query as component_query_module
|
||||
from src.services import memory_flow_service as memory_flow_service_module
|
||||
from src.services import memory_service as memory_service_module
|
||||
from src.services.memory_service import memory_service
|
||||
except SystemExit as exc:
|
||||
IMPORT_ERROR = f"config initialization exited during import: {exc}"
|
||||
kernel_module = None # type: ignore[assignment]
|
||||
KernelSearchRequest = None # type: ignore[assignment]
|
||||
SDKMemoryKernel = None # type: ignore[assignment]
|
||||
heartflow_manager = None # type: ignore[assignment]
|
||||
bot_module = None # type: ignore[assignment]
|
||||
chat_manager = None # type: ignore[assignment]
|
||||
chat_bot = None # type: ignore[assignment]
|
||||
database_module = None # type: ignore[assignment]
|
||||
ToolRecord = None # type: ignore[assignment]
|
||||
PersonInfo = None # type: ignore[assignment]
|
||||
create_database_migration_bootstrapper = None # type: ignore[assignment]
|
||||
SessionUtils = None # type: ignore[assignment]
|
||||
ToolCall = None # type: ignore[assignment]
|
||||
reasoning_engine_module = None # type: ignore[assignment]
|
||||
runtime_module = None # type: ignore[assignment]
|
||||
ChatResponse = None # type: ignore[assignment]
|
||||
AssistantMessage = None # type: ignore[assignment]
|
||||
component_query_module = None # type: ignore[assignment]
|
||||
memory_flow_service_module = None # type: ignore[assignment]
|
||||
memory_service_module = None # type: ignore[assignment]
|
||||
memory_service = None # type: ignore[assignment]
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
|
||||
|
||||
RELATION_QUERY = "测试用户 和 最喜欢的颜色 有什么关系"
|
||||
|
||||
|
||||
class _FakeEmbeddingManager:
|
||||
def __init__(self, dimension: int = 8) -> None:
|
||||
self.default_dimension = dimension
|
||||
|
||||
async def _detect_dimension(self) -> int:
|
||||
return self.default_dimension
|
||||
|
||||
async def encode(self, text: Any) -> np.ndarray:
|
||||
def _encode_one(raw: Any) -> np.ndarray:
|
||||
content = str(raw or "")
|
||||
vector = np.zeros(self.default_dimension, dtype=np.float32)
|
||||
for index, byte in enumerate(content.encode("utf-8")):
|
||||
vector[index % self.default_dimension] += float((byte % 17) + 1)
|
||||
norm = float(np.linalg.norm(vector))
|
||||
if norm > 0:
|
||||
vector /= norm
|
||||
return vector
|
||||
|
||||
if isinstance(text, (list, tuple)):
|
||||
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
|
||||
return _encode_one(text).astype(np.float32)
|
||||
|
||||
|
||||
class _KernelBackedRuntimeManager:
|
||||
def __init__(self, kernel: SDKMemoryKernel) -> None:
|
||||
self.kernel = kernel
|
||||
|
||||
async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
|
||||
del timeout_ms
|
||||
payload = args or {}
|
||||
if component_name == "search_memory":
|
||||
return await self.kernel.search_memory(
|
||||
KernelSearchRequest(
|
||||
query=str(payload.get("query", "") or ""),
|
||||
limit=int(payload.get("limit", 5) or 5),
|
||||
mode=str(payload.get("mode", "hybrid") or "hybrid"),
|
||||
chat_id=str(payload.get("chat_id", "") or ""),
|
||||
person_id=str(payload.get("person_id", "") or ""),
|
||||
time_start=payload.get("time_start"),
|
||||
time_end=payload.get("time_end"),
|
||||
respect_filter=bool(payload.get("respect_filter", True)),
|
||||
user_id=str(payload.get("user_id", "") or ""),
|
||||
group_id=str(payload.get("group_id", "") or ""),
|
||||
)
|
||||
)
|
||||
|
||||
handler = getattr(self.kernel, component_name)
|
||||
result = handler(**payload)
|
||||
return await result if inspect.isawaitable(result) else result
|
||||
|
||||
|
||||
class _NoopRuntimeManager:
|
||||
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
|
||||
del hook_name
|
||||
return SimpleNamespace(aborted=False, kwargs=kwargs)
|
||||
|
||||
|
||||
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
db_dir = (tmp_path / "main_db").resolve()
|
||||
db_dir.mkdir(parents=True, exist_ok=True)
|
||||
db_file = db_dir / "MaiBot.db"
|
||||
database_url = f"sqlite:///{db_file}"
|
||||
|
||||
try:
|
||||
database_module.engine.dispose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
engine = create_engine(
|
||||
database_url,
|
||||
echo=False,
|
||||
connect_args={"check_same_thread": False},
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
session_local = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=engine,
|
||||
class_=Session,
|
||||
)
|
||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
monkeypatch.setattr(database_module, "_DB_DIR", db_dir, raising=False)
|
||||
monkeypatch.setattr(database_module, "_DB_FILE", db_file, raising=False)
|
||||
monkeypatch.setattr(database_module, "DATABASE_URL", database_url, raising=False)
|
||||
monkeypatch.setattr(database_module, "engine", engine, raising=False)
|
||||
monkeypatch.setattr(database_module, "SessionLocal", session_local, raising=False)
|
||||
monkeypatch.setattr(database_module, "_migration_bootstrapper", bootstrapper, raising=False)
|
||||
monkeypatch.setattr(database_module, "_db_initialized", False, raising=False)
|
||||
|
||||
|
||||
def _build_chat_response(content: str, tool_calls: list[ToolCall]) -> ChatResponse:
|
||||
return ChatResponse(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
request_messages=[],
|
||||
raw_message=AssistantMessage(
|
||||
content=content,
|
||||
timestamp=datetime.now(),
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
selected_history_count=0,
|
||||
tool_count=len(tool_calls),
|
||||
prompt_tokens=0,
|
||||
built_message_count=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
prompt_section=None,
|
||||
)
|
||||
|
||||
|
||||
def _build_message_data(
|
||||
*,
|
||||
content: str,
|
||||
platform: str,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
group_id: str,
|
||||
group_name: str,
|
||||
) -> Dict[str, Any]:
|
||||
message_id = str(uuid.uuid4())
|
||||
return {
|
||||
"message_info": {
|
||||
"platform": platform,
|
||||
"message_id": message_id,
|
||||
"time": time.time(),
|
||||
"group_info": {
|
||||
"group_id": group_id,
|
||||
"group_name": group_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"user_info": {
|
||||
"user_id": user_id,
|
||||
"user_nickname": user_name,
|
||||
"user_cardname": user_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"additional_config": {
|
||||
"at_bot": True,
|
||||
},
|
||||
},
|
||||
"message_segment": {
|
||||
"type": "seglist",
|
||||
"data": [
|
||||
{
|
||||
"type": "text",
|
||||
"data": content,
|
||||
},
|
||||
],
|
||||
},
|
||||
"raw_message": content,
|
||||
"processed_plain_text": content,
|
||||
}
|
||||
|
||||
|
||||
async def _wait_until(
|
||||
predicate: Callable[[], Any],
|
||||
*,
|
||||
timeout_seconds: float = 10.0,
|
||||
interval_seconds: float = 0.05,
|
||||
description: str,
|
||||
) -> Any:
|
||||
deadline = asyncio.get_running_loop().time() + max(0.5, float(timeout_seconds))
|
||||
while asyncio.get_running_loop().time() < deadline:
|
||||
value = predicate()
|
||||
if inspect.isawaitable(value):
|
||||
value = await value
|
||||
if value:
|
||||
return value
|
||||
await asyncio.sleep(interval_seconds)
|
||||
raise AssertionError(f"等待超时: {description}")
|
||||
|
||||
|
||||
def _load_feedback_tasks(kernel: SDKMemoryKernel) -> list[Dict[str, Any]]:
|
||||
assert kernel.metadata_store is not None
|
||||
cursor = kernel.metadata_store.get_connection().cursor()
|
||||
rows = cursor.execute(
|
||||
"SELECT query_tool_id FROM memory_feedback_tasks ORDER BY id"
|
||||
).fetchall()
|
||||
tasks: list[Dict[str, Any]] = []
|
||||
for row in rows:
|
||||
task = kernel.metadata_store.get_feedback_task(str(row["query_tool_id"] or ""))
|
||||
if task is not None:
|
||||
tasks.append(task)
|
||||
return tasks
|
||||
|
||||
|
||||
def _load_feedback_action_types(kernel: SDKMemoryKernel) -> list[str]:
|
||||
assert kernel.metadata_store is not None
|
||||
cursor = kernel.metadata_store.get_connection().cursor()
|
||||
rows = cursor.execute(
|
||||
"SELECT action_type FROM memory_feedback_action_logs ORDER BY id"
|
||||
).fetchall()
|
||||
return [str(row["action_type"] or "") for row in rows]
|
||||
|
||||
|
||||
def _load_query_memory_tool_records(session_id: str) -> list[Dict[str, Any]]:
|
||||
with database_module.get_db_session() as session:
|
||||
statement = (
|
||||
select(ToolRecord)
|
||||
.where(ToolRecord.session_id == session_id)
|
||||
.where(ToolRecord.tool_name == "query_memory")
|
||||
.order_by(ToolRecord.timestamp)
|
||||
)
|
||||
rows = list(session.exec(statement).all())
|
||||
return [
|
||||
{
|
||||
"tool_id": str(row.tool_id or ""),
|
||||
"session_id": str(row.session_id or ""),
|
||||
"tool_name": str(row.tool_name or ""),
|
||||
"tool_data": str(row.tool_data or ""),
|
||||
"timestamp": row.timestamp,
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
|
||||
def _seed_person_info(*, person_id: str, person_name: str, session_info: Dict[str, Any]) -> None:
|
||||
with database_module.get_db_session() as session:
|
||||
session.add(
|
||||
PersonInfo(
|
||||
is_known=True,
|
||||
person_id=person_id,
|
||||
person_name=person_name,
|
||||
platform=str(session_info["platform"]),
|
||||
user_id=str(session_info["user_id"]),
|
||||
user_nickname=str(session_info["user_name"]),
|
||||
group_cardname=json.dumps(
|
||||
[{"group_id": str(session_info["group_id"]), "group_cardname": person_name}],
|
||||
ensure_ascii=False,
|
||||
),
|
||||
know_counts=1,
|
||||
first_known_time=datetime.now(),
|
||||
last_known_time=datetime.now(),
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
|
||||
_install_temp_main_database(monkeypatch, tmp_path)
|
||||
|
||||
chat_manager.sessions.clear()
|
||||
chat_manager.last_messages.clear()
|
||||
heartflow_manager.heartflow_chat_list.clear()
|
||||
|
||||
noop_runtime_manager = _NoopRuntimeManager()
|
||||
monkeypatch.setattr(bot_module.ChatBot, "_get_runtime_manager", staticmethod(lambda: noop_runtime_manager))
|
||||
monkeypatch.setattr(
|
||||
component_query_module.component_query_service,
|
||||
"find_command_by_text",
|
||||
lambda text: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
component_query_module.component_query_service,
|
||||
"get_llm_available_tool_specs",
|
||||
lambda: {},
|
||||
)
|
||||
monkeypatch.setattr(runtime_module.global_config.mcp, "enable", False, raising=False)
|
||||
monkeypatch.setattr(
|
||||
runtime_module.MaisakaHeartFlowChatting,
|
||||
"_get_message_trigger_threshold",
|
||||
lambda self: 1,
|
||||
)
|
||||
|
||||
async def _noop_on_incoming_message(message: Any) -> None:
|
||||
del message
|
||||
|
||||
monkeypatch.setattr(
|
||||
memory_flow_service_module.memory_automation_service,
|
||||
"on_incoming_message",
|
||||
_noop_on_incoming_message,
|
||||
)
|
||||
|
||||
fake_embedding_manager = _FakeEmbeddingManager(dimension=8)
|
||||
|
||||
async def _fake_runtime_self_check(
|
||||
*,
|
||||
config: Any,
|
||||
sample_text: str,
|
||||
vector_store: Any,
|
||||
embedding_manager: Any,
|
||||
) -> Dict[str, Any]:
|
||||
del config, sample_text, vector_store, embedding_manager
|
||||
return {
|
||||
"ok": True,
|
||||
"message": "ok",
|
||||
"checked_at": time.time(),
|
||||
"encoded_dimension": fake_embedding_manager.default_dimension,
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
kernel_module,
|
||||
"create_embedding_api_adapter",
|
||||
lambda **kwargs: fake_embedding_manager,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
kernel_module,
|
||||
"run_embedding_runtime_self_check",
|
||||
_fake_runtime_self_check,
|
||||
)
|
||||
|
||||
kernel = SDKMemoryKernel(
|
||||
plugin_root=tmp_path / "plugin_root",
|
||||
config={
|
||||
"storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
|
||||
"advanced": {"enable_auto_save": False},
|
||||
"embedding": {"dimension": fake_embedding_manager.default_dimension},
|
||||
"memory": {"base_decay_interval_hours": 24},
|
||||
"person_profile": {"refresh_interval_minutes": 5},
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_enabled", lambda: True)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_window_hours", lambda: 0.0004)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_check_interval_seconds", lambda: 0.2)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_batch_size", lambda: 10)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_max_messages", lambda: 10)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_auto_apply_threshold", lambda: 0.85)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_prefilter_enabled", lambda: True)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_mark_enabled", lambda: True)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_hard_filter_enabled", lambda: True)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_profile_refresh_enabled", lambda: True)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_profile_force_refresh_on_read", lambda: True)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_episode_rebuild_enabled", lambda: True)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_episode_query_block_enabled", lambda: True)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_interval_seconds", lambda: 0.2)
|
||||
monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_batch_size", lambda: 10)
|
||||
|
||||
monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_paragraph_hard_filter_enabled", True, raising=False)
|
||||
monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_episode_query_block_enabled", True, raising=False)
|
||||
|
||||
async def _fake_classify_feedback(
|
||||
*,
|
||||
query_tool_id: str,
|
||||
query_text: str,
|
||||
hit_briefs: list[Dict[str, Any]],
|
||||
feedback_messages: list[str],
|
||||
) -> Dict[str, Any]:
|
||||
del query_tool_id, query_text, feedback_messages
|
||||
target_hash = ""
|
||||
for item in hit_briefs:
|
||||
if str(item.get("type", "") or "").strip() == "relation":
|
||||
target_hash = str(item.get("hash", "") or "").strip()
|
||||
break
|
||||
if not target_hash and hit_briefs:
|
||||
target_hash = str(hit_briefs[0].get("hash", "") or "").strip()
|
||||
return {
|
||||
"decision": "correct",
|
||||
"confidence": 0.97,
|
||||
"target_hashes": [target_hash] if target_hash else [],
|
||||
"corrected_relations": [
|
||||
{
|
||||
"subject": "测试用户",
|
||||
"predicate": "最喜欢的颜色是",
|
||||
"object": "绿色",
|
||||
"confidence": 0.99,
|
||||
}
|
||||
],
|
||||
"reason": "用户明确纠正为绿色",
|
||||
}
|
||||
|
||||
monkeypatch.setattr(kernel, "_classify_feedback", _fake_classify_feedback)
|
||||
|
||||
await kernel.initialize()
|
||||
async def _force_episode_fallback(**kwargs: Any) -> Dict[str, Any]:
|
||||
raise RuntimeError("force_fallback_for_test")
|
||||
|
||||
monkeypatch.setattr(
|
||||
kernel.episode_service.segmentation_service,
|
||||
"segment",
|
||||
_force_episode_fallback,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
kernel,
|
||||
"process_episode_pending_batch",
|
||||
lambda *, limit=20, max_retry=3: asyncio.sleep(0, result={"processed": 0, "episode_count": 0, "fallback_count": 0, "failed": 0}),
|
||||
)
|
||||
|
||||
host_manager = _KernelBackedRuntimeManager(kernel)
|
||||
monkeypatch.setattr(memory_service_module, "a_memorix_host_service", host_manager)
|
||||
|
||||
planner_calls: list[str] = []
|
||||
|
||||
async def _fake_timing_gate(self, anchor_message: Any):
|
||||
del self, anchor_message
|
||||
return "continue", _build_chat_response("直接进入 planner。", []), [], []
|
||||
|
||||
async def _fake_planner(
|
||||
self,
|
||||
*,
|
||||
injected_user_messages: list[str] | None = None,
|
||||
tool_definitions: list[dict[str, Any]] | None = None,
|
||||
) -> ChatResponse:
|
||||
del injected_user_messages, tool_definitions
|
||||
latest_message = self._runtime.message_cache[-1]
|
||||
latest_text = str(latest_message.processed_plain_text or "")
|
||||
planner_calls.append(latest_text)
|
||||
handled_message_ids = getattr(self._runtime, "_test_query_message_ids", None)
|
||||
if handled_message_ids is None:
|
||||
handled_message_ids = set()
|
||||
setattr(self._runtime, "_test_query_message_ids", handled_message_ids)
|
||||
|
||||
if latest_message.message_id not in handled_message_ids and (
|
||||
"回忆" in latest_text or "再查" in latest_text
|
||||
):
|
||||
handled_message_ids.add(latest_message.message_id)
|
||||
tool_call = ToolCall(
|
||||
call_id=f"query-{uuid.uuid4().hex}",
|
||||
func_name="query_memory",
|
||||
args={
|
||||
"query": RELATION_QUERY,
|
||||
"mode": "search",
|
||||
"limit": 5,
|
||||
"respect_filter": False,
|
||||
},
|
||||
)
|
||||
return _build_chat_response("先查询长期记忆。", [tool_call])
|
||||
|
||||
stop_call = ToolCall(
|
||||
call_id=f"stop-{uuid.uuid4().hex}",
|
||||
func_name="no_reply",
|
||||
args={},
|
||||
)
|
||||
return _build_chat_response("当前轮次结束。", [stop_call])
|
||||
|
||||
monkeypatch.setattr(
|
||||
reasoning_engine_module.MaisakaReasoningEngine,
|
||||
"_run_timing_gate",
|
||||
_fake_timing_gate,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
reasoning_engine_module.MaisakaReasoningEngine,
|
||||
"_run_interruptible_planner",
|
||||
_fake_planner,
|
||||
)
|
||||
|
||||
session_info = {
|
||||
"platform": "unit_test_chat",
|
||||
"user_id": "user_feedback_flow",
|
||||
"user_name": "反馈测试用户",
|
||||
"group_id": "group_feedback_flow",
|
||||
"group_name": "反馈纠错测试群",
|
||||
}
|
||||
person_id = "person_feedback_flow"
|
||||
session_id = SessionUtils.calculate_session_id(
|
||||
session_info["platform"],
|
||||
user_id=session_info["user_id"],
|
||||
group_id=session_info["group_id"],
|
||||
)
|
||||
_seed_person_info(person_id=person_id, person_name="测试用户", session_info=session_info)
|
||||
|
||||
try:
|
||||
yield {
|
||||
"kernel": kernel,
|
||||
"session_id": session_id,
|
||||
"session_info": session_info,
|
||||
"person_id": person_id,
|
||||
"planner_calls": planner_calls,
|
||||
}
|
||||
finally:
|
||||
for key, chat in list(heartflow_manager.heartflow_chat_list.items()):
|
||||
try:
|
||||
await chat.stop()
|
||||
except Exception:
|
||||
pass
|
||||
heartflow_manager.heartflow_chat_list.pop(key, None)
|
||||
chat_manager.sessions.clear()
|
||||
chat_manager.last_messages.clear()
|
||||
await kernel.shutdown()
|
||||
try:
|
||||
database_module.engine.dispose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_feedback_correction_real_chat_flow(chat_feedback_env) -> None:
|
||||
kernel = chat_feedback_env["kernel"]
|
||||
session_id = chat_feedback_env["session_id"]
|
||||
session_info = chat_feedback_env["session_info"]
|
||||
person_id = chat_feedback_env["person_id"]
|
||||
|
||||
write_result = await memory_service.ingest_text(
|
||||
external_id=f"test:feedback-seed:{uuid.uuid4().hex}",
|
||||
source_type="chat_summary",
|
||||
text="测试用户 最喜欢的颜色是 蓝色",
|
||||
chat_id=session_id,
|
||||
relations=[
|
||||
{
|
||||
"subject": "测试用户",
|
||||
"predicate": "最喜欢的颜色是",
|
||||
"object": "蓝色",
|
||||
"confidence": 1.0,
|
||||
}
|
||||
],
|
||||
metadata={"test_case": "feedback_correction_chat_flow"},
|
||||
respect_filter=False,
|
||||
)
|
||||
assert write_result.success is True
|
||||
|
||||
pre_search = await memory_service.search(
|
||||
RELATION_QUERY,
|
||||
mode="search",
|
||||
chat_id=session_id,
|
||||
respect_filter=False,
|
||||
)
|
||||
assert pre_search.hits
|
||||
assert any("蓝色" in hit.content for hit in pre_search.hits)
|
||||
|
||||
pre_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
|
||||
pre_profile_text = pre_profile.summary + "\n" + json.dumps(pre_profile.evidence, ensure_ascii=False)
|
||||
assert "蓝色" in pre_profile_text
|
||||
|
||||
seed_source = f"chat_summary:{session_id}"
|
||||
rebuild_result = await kernel.rebuild_episodes_for_sources([seed_source])
|
||||
assert rebuild_result["rebuilt"] >= 1
|
||||
|
||||
pre_episode = await memory_service.search(
|
||||
"蓝色",
|
||||
mode="episode",
|
||||
chat_id=session_id,
|
||||
respect_filter=False,
|
||||
)
|
||||
assert pre_episode.hits
|
||||
assert any("蓝色" in hit.content for hit in pre_episode.hits)
|
||||
|
||||
await chat_bot.message_process(
|
||||
_build_message_data(
|
||||
content="请帮我回忆一下,测试用户最喜欢的颜色是什么?",
|
||||
**session_info,
|
||||
)
|
||||
)
|
||||
|
||||
await _wait_until(
|
||||
lambda: chat_feedback_env["planner_calls"][0] if chat_feedback_env["planner_calls"] else None,
|
||||
description="planner 收到首条聊天消息",
|
||||
)
|
||||
first_query_records = await _wait_until(
|
||||
lambda: _load_query_memory_tool_records(session_id) if _load_query_memory_tool_records(session_id) else None,
|
||||
description="首条 query_memory 工具记录生成",
|
||||
)
|
||||
assert first_query_records
|
||||
|
||||
first_task = await _wait_until(
|
||||
lambda: _load_feedback_tasks(kernel)[0] if _load_feedback_tasks(kernel) else None,
|
||||
description="首个反馈任务入队",
|
||||
)
|
||||
assert first_task["status"] == "pending"
|
||||
first_hits = list((first_task.get("query_snapshot") or {}).get("hits") or [])
|
||||
assert first_hits
|
||||
assert any("蓝色" in str(item.get("content", "") or "") for item in first_hits)
|
||||
|
||||
await chat_bot.message_process(
|
||||
_build_message_data(
|
||||
content="不对,测试用户最喜欢的颜色不是蓝色,是绿色。",
|
||||
**session_info,
|
||||
)
|
||||
)
|
||||
|
||||
finalized_task = await _wait_until(
|
||||
lambda: (
|
||||
kernel.metadata_store.get_feedback_task(first_task["query_tool_id"])
|
||||
if kernel.metadata_store.get_feedback_task(first_task["query_tool_id"])
|
||||
and kernel.metadata_store.get_feedback_task(first_task["query_tool_id"]).get("status")
|
||||
in {"applied", "skipped", "error"}
|
||||
else None
|
||||
),
|
||||
timeout_seconds=12.0,
|
||||
interval_seconds=0.1,
|
||||
description="反馈任务进入终态",
|
||||
)
|
||||
assert finalized_task["status"] == "applied", finalized_task
|
||||
assert finalized_task["decision_payload"]["decision"] == "correct"
|
||||
assert finalized_task["decision_payload"]["apply_result"]["applied"] is True
|
||||
|
||||
corrected_hashes = list(
|
||||
(finalized_task["decision_payload"].get("apply_result") or {}).get("relation_hashes") or []
|
||||
)
|
||||
assert corrected_hashes
|
||||
corrected_hash = str(corrected_hashes[0] or "")
|
||||
relation_status = kernel.metadata_store.get_relation_status_batch([corrected_hash]).get(corrected_hash, {})
|
||||
assert bool(relation_status.get("is_inactive")) is True
|
||||
|
||||
action_types = _load_feedback_action_types(kernel)
|
||||
assert "classification" in action_types
|
||||
assert "forget_relation" in action_types
|
||||
assert "ingest_correction" in action_types
|
||||
assert "mark_stale_paragraph" in action_types
|
||||
assert "enqueue_episode_rebuild" in action_types
|
||||
assert "enqueue_profile_refresh" in action_types
|
||||
|
||||
direct_post_search = await memory_service.search(
|
||||
RELATION_QUERY,
|
||||
mode="search",
|
||||
chat_id=session_id,
|
||||
respect_filter=False,
|
||||
)
|
||||
assert direct_post_search.hits
|
||||
post_contents = "\n".join(hit.content for hit in direct_post_search.hits)
|
||||
assert "绿色" in post_contents
|
||||
assert "蓝色" not in post_contents
|
||||
|
||||
profile_refresh_request = await _wait_until(
|
||||
lambda: (
|
||||
kernel.metadata_store.get_person_profile_refresh_request(person_id)
|
||||
if kernel.metadata_store.get_person_profile_refresh_request(person_id)
|
||||
and kernel.metadata_store.get_person_profile_refresh_request(person_id).get("status") == "done"
|
||||
else None
|
||||
),
|
||||
timeout_seconds=12.0,
|
||||
interval_seconds=0.1,
|
||||
description="人物画像刷新完成",
|
||||
)
|
||||
assert profile_refresh_request["status"] == "done"
|
||||
|
||||
post_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
|
||||
post_profile_text = post_profile.summary + "\n" + json.dumps(post_profile.evidence, ensure_ascii=False)
|
||||
assert "绿色" in post_profile_text
|
||||
assert "蓝色" not in post_profile_text
|
||||
|
||||
async def _latest_episode_result():
|
||||
result = await memory_service.search(
|
||||
"绿色",
|
||||
mode="episode",
|
||||
chat_id=session_id,
|
||||
respect_filter=False,
|
||||
)
|
||||
if not result.hits:
|
||||
return None
|
||||
contents = "\n".join(hit.content for hit in result.hits)
|
||||
if "绿色" in contents and "蓝色" not in contents:
|
||||
return result
|
||||
return None
|
||||
|
||||
post_episode = await _wait_until(
|
||||
_latest_episode_result,
|
||||
timeout_seconds=12.0,
|
||||
interval_seconds=0.2,
|
||||
description="episode 重建后返回修正结果",
|
||||
)
|
||||
assert post_episode is not None
|
||||
|
||||
stale_episode = await memory_service.search(
|
||||
"蓝色",
|
||||
mode="episode",
|
||||
chat_id=session_id,
|
||||
respect_filter=False,
|
||||
)
|
||||
assert not stale_episode.hits
|
||||
|
||||
await chat_bot.message_process(
|
||||
_build_message_data(
|
||||
content="再查一次,测试用户最喜欢的颜色是什么?",
|
||||
**session_info,
|
||||
)
|
||||
)
|
||||
|
||||
tool_records = await _wait_until(
|
||||
lambda: (
|
||||
_load_query_memory_tool_records(session_id)
|
||||
if len(_load_query_memory_tool_records(session_id)) >= 2
|
||||
else None
|
||||
),
|
||||
timeout_seconds=10.0,
|
||||
interval_seconds=0.1,
|
||||
description="第二次 query_memory 工具记录生成",
|
||||
)
|
||||
latest_tool_data = json.loads(str(tool_records[-1].get("tool_data") or "{}"))
|
||||
latest_hits = list((latest_tool_data.get("structured_content") or {}).get("hits") or [])
|
||||
assert latest_hits
|
||||
latest_contents = "\n".join(str(item.get("content", "") or "") for item in latest_hits)
|
||||
assert "绿色" in latest_contents
|
||||
assert "蓝色" not in latest_contents
|
||||
396
pytests/A_memorix_test/test_feedback_correction_core.py
Normal file
396
pytests/A_memorix_test/test_feedback_correction_core.py
Normal file
@@ -0,0 +1,396 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
IMPORT_ERROR: str | None = None
|
||||
|
||||
try:
|
||||
from src.A_memorix.core.retrieval.sparse_bm25 import SparseBM25Config, SparseBM25Index
|
||||
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
||||
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
|
||||
except SystemExit as exc:
|
||||
IMPORT_ERROR = f"config initialization exited during import: {exc}"
|
||||
SparseBM25Config = None # type: ignore[assignment]
|
||||
SparseBM25Index = None # type: ignore[assignment]
|
||||
kernel_module = None # type: ignore[assignment]
|
||||
SDKMemoryKernel = None # type: ignore[assignment]
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kernel_enqueue_feedback_task_delegates_to_metadata_store(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
def fake_enqueue_feedback_task(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return {
|
||||
"id": 1,
|
||||
"query_tool_id": kwargs["query_tool_id"],
|
||||
"session_id": kwargs["session_id"],
|
||||
"query_timestamp": kwargs["query_timestamp"],
|
||||
"due_at": kwargs["due_at"],
|
||||
"query_snapshot": kwargs["query_snapshot"],
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
kernel_module,
|
||||
"global_config",
|
||||
SimpleNamespace(
|
||||
memory=SimpleNamespace(
|
||||
feedback_correction_enabled=True,
|
||||
feedback_correction_window_hours=12.0,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
query_time = datetime(2026, 4, 9, 10, 30, 0)
|
||||
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
|
||||
kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=fake_enqueue_feedback_task)
|
||||
|
||||
payload = await kernel.enqueue_feedback_task(
|
||||
query_tool_id="tool-query-1",
|
||||
session_id="session-1",
|
||||
query_timestamp=query_time,
|
||||
structured_content={"query": "Alice 喜欢什么", "hits": [{"hash": "relation-1"}]},
|
||||
)
|
||||
|
||||
assert payload["success"] is True
|
||||
assert payload["queued"] is True
|
||||
assert captured["query_tool_id"] == "tool-query-1"
|
||||
assert captured["session_id"] == "session-1"
|
||||
assert captured["query_snapshot"]["query"] == "Alice 喜欢什么"
|
||||
assert captured["query_snapshot"]["hits"] == [{"hash": "relation-1"}]
|
||||
assert captured["due_at"] == pytest.approx(query_time.timestamp() + 12 * 3600, rel=0, abs=1e-6)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kernel_enqueue_feedback_task_skipped_when_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
kernel_module,
|
||||
"global_config",
|
||||
SimpleNamespace(memory=SimpleNamespace(feedback_correction_enabled=False)),
|
||||
)
|
||||
|
||||
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
|
||||
kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=lambda **kwargs: kwargs)
|
||||
|
||||
payload = await kernel.enqueue_feedback_task(
|
||||
query_tool_id="tool-query-2",
|
||||
session_id="session-1",
|
||||
query_timestamp=datetime.now(),
|
||||
structured_content={"hits": [{"hash": "relation-1"}]},
|
||||
)
|
||||
|
||||
assert payload["success"] is False
|
||||
assert payload["reason"] == "feedback_correction_disabled"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_feedback_decision_resolves_paragraph_targets() -> None:
|
||||
action_logs: list[Dict[str, Any]] = []
|
||||
forgotten_hashes: list[str] = []
|
||||
ingested_payloads: list[Dict[str, Any]] = []
|
||||
stale_marks: list[Dict[str, Any]] = []
|
||||
episode_sources: list[str] = []
|
||||
profile_refresh_ids: list[str] = []
|
||||
|
||||
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
|
||||
kernel.metadata_store = SimpleNamespace(
|
||||
get_paragraph_relations=lambda paragraph_hash: [
|
||||
{
|
||||
"hash": "relation-1",
|
||||
"subject": "测试用户",
|
||||
"predicate": "最喜欢的颜色是",
|
||||
"object": "蓝色",
|
||||
}
|
||||
]
|
||||
if paragraph_hash == "paragraph-1"
|
||||
else [],
|
||||
get_relation_status_batch=lambda hashes: {
|
||||
str(hash_value): {"is_inactive": str(hash_value) in forgotten_hashes}
|
||||
for hash_value in hashes
|
||||
},
|
||||
get_paragraph_hashes_by_relation_hashes=lambda hashes: {
|
||||
"relation-1": ["paragraph-1"]
|
||||
}
|
||||
if "relation-1" in hashes
|
||||
else {},
|
||||
upsert_paragraph_stale_relation_mark=lambda **kwargs: stale_marks.append(kwargs) or kwargs,
|
||||
enqueue_episode_source_rebuild=lambda source, reason="": episode_sources.append(source) or True,
|
||||
enqueue_person_profile_refresh=lambda **kwargs: profile_refresh_ids.append(kwargs["person_id"]) or kwargs,
|
||||
get_paragraph=lambda paragraph_hash: {"hash": "paragraph-1", "source": "chat_feedback_test_seed:session-1"}
|
||||
if paragraph_hash == "paragraph-1"
|
||||
else None,
|
||||
append_feedback_action_log=lambda **kwargs: action_logs.append(kwargs),
|
||||
)
|
||||
kernel._feedback_cfg_auto_apply_threshold = lambda: 0.85 # type: ignore[method-assign]
|
||||
kernel._apply_v5_relation_action = lambda *, action, hashes, strength=1.0: ( # type: ignore[method-assign]
|
||||
forgotten_hashes.extend([str(item) for item in hashes]),
|
||||
{"success": True, "action": action, "hashes": list(hashes), "strength": strength},
|
||||
)[1]
|
||||
kernel._feedback_cfg_paragraph_mark_enabled = lambda: True # type: ignore[method-assign]
|
||||
kernel._feedback_cfg_episode_rebuild_enabled = lambda: True # type: ignore[method-assign]
|
||||
kernel._feedback_cfg_profile_refresh_enabled = lambda: True # type: ignore[method-assign]
|
||||
kernel._resolve_feedback_related_person_ids = lambda **kwargs: ["person-1"] # type: ignore[method-assign]
|
||||
kernel._query_relation_rows_by_hashes = lambda relation_hashes, include_inactive=False: [ # type: ignore[method-assign]
|
||||
{
|
||||
"hash": "relation-1",
|
||||
"subject": "测试用户",
|
||||
"predicate": "最喜欢的颜色是",
|
||||
"object": "蓝色",
|
||||
}
|
||||
]
|
||||
|
||||
async def _fake_ingest_feedback_relations(**kwargs):
|
||||
ingested_payloads.append(kwargs)
|
||||
return {"success": True, "stored_ids": ["relation-2"]}
|
||||
|
||||
kernel._ingest_feedback_relations = _fake_ingest_feedback_relations # type: ignore[method-assign]
|
||||
|
||||
payload = await kernel._apply_feedback_decision(
|
||||
task_id=1,
|
||||
query_tool_id="tool-query-1",
|
||||
session_id="session-1",
|
||||
decision={
|
||||
"decision": "correct",
|
||||
"confidence": 0.97,
|
||||
"target_hashes": ["paragraph-1"],
|
||||
"corrected_relations": [
|
||||
{
|
||||
"subject": "测试用户",
|
||||
"predicate": "最喜欢的颜色是",
|
||||
"object": "绿色",
|
||||
"confidence": 0.99,
|
||||
}
|
||||
],
|
||||
"reason": "用户明确纠正为绿色",
|
||||
},
|
||||
hit_map={
|
||||
"paragraph-1": {
|
||||
"hash": "paragraph-1",
|
||||
"type": "paragraph",
|
||||
"content": "测试用户 最喜欢的颜色是 蓝色",
|
||||
"linked_relation_hashes": ["relation-1"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert payload["applied"] is True
|
||||
assert payload["relation_hashes"] == ["relation-1"]
|
||||
assert forgotten_hashes == ["relation-1"]
|
||||
assert ingested_payloads[0]["relation_hashes"] == ["relation-1"]
|
||||
assert payload["stale_paragraph_hashes"] == ["paragraph-1"]
|
||||
assert "chat_feedback_test_seed:session-1" in payload["episode_rebuild_sources"]
|
||||
assert "chat_summary:session-1" in payload["episode_rebuild_sources"]
|
||||
assert payload["profile_refresh_person_ids"] == ["person-1"]
|
||||
assert stale_marks[0]["paragraph_hash"] == "paragraph-1"
|
||||
assert {item["action_type"] for item in action_logs} == {
|
||||
"forget_relation",
|
||||
"ingest_correction",
|
||||
"mark_stale_paragraph",
|
||||
"enqueue_episode_rebuild",
|
||||
"enqueue_profile_refresh",
|
||||
}
|
||||
|
||||
|
||||
def test_filter_active_relation_hits_removes_inactive_relations() -> None:
|
||||
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
|
||||
kernel._feedback_cfg_paragraph_hard_filter_enabled = lambda: True # type: ignore[method-assign]
|
||||
kernel.metadata_store = SimpleNamespace(
|
||||
get_relation_status_batch=lambda hashes: {
|
||||
"r-active": {"is_inactive": False},
|
||||
"r-inactive": {"is_inactive": True},
|
||||
"r-para-inactive": {"is_inactive": True},
|
||||
"r-stale-active": {"is_inactive": False},
|
||||
"r-stale-inactive": {"is_inactive": True},
|
||||
},
|
||||
get_paragraph_relations=lambda paragraph_hash: (
|
||||
[{"hash": "r-para-inactive"}] if paragraph_hash == "p-inactive" else []
|
||||
),
|
||||
get_paragraph_stale_relation_marks_batch=lambda paragraph_hashes: {
|
||||
"p-stale": [{"relation_hash": "r-stale-inactive"}],
|
||||
"p-restored": [{"relation_hash": "r-stale-active"}],
|
||||
},
|
||||
)
|
||||
|
||||
hits = [
|
||||
{"hash": "r-active", "type": "relation", "content": "A 喜欢 B"},
|
||||
{"hash": "r-inactive", "type": "relation", "content": "A 讨厌 B"},
|
||||
{"hash": "p-1", "type": "paragraph", "content": "段落证据"},
|
||||
{"hash": "p-inactive", "type": "paragraph", "content": "失活段落证据"},
|
||||
{"hash": "p-stale", "type": "paragraph", "content": "被标脏段落"},
|
||||
{"hash": "p-restored", "type": "paragraph", "content": "恢复可见段落"},
|
||||
]
|
||||
|
||||
filtered = kernel._filter_active_relation_hits(hits)
|
||||
|
||||
assert [item["hash"] for item in filtered] == ["r-active", "p-1", "p-restored"]
|
||||
|
||||
|
||||
def test_sparse_relation_search_requests_active_only() -> None:
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
class FakeMetadataStore:
|
||||
def fts_search_relations_bm25(self, **kwargs):
|
||||
captured.update(kwargs)
|
||||
return []
|
||||
|
||||
index = SparseBM25Index(
|
||||
metadata_store=FakeMetadataStore(), # type: ignore[arg-type]
|
||||
config=SparseBM25Config(enabled=True, lazy_load=False),
|
||||
)
|
||||
index._loaded = True
|
||||
index._conn = object() # type: ignore[assignment]
|
||||
|
||||
result = index.search_relations("测试纠错", k=5)
|
||||
|
||||
assert result == []
|
||||
assert captured["include_inactive"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_feedback_task_rollback_restores_snapshots_and_requeues_followups() -> None:
|
||||
action_logs: list[Dict[str, Any]] = []
|
||||
queued_sources: list[str] = []
|
||||
queued_profiles: list[str] = []
|
||||
deleted_marks: list[tuple[str, str]] = []
|
||||
deleted_paragraphs: list[str] = []
|
||||
relation_statuses: Dict[str, Dict[str, Any]] = {
|
||||
"rel-old": {"is_inactive": True, "weight": 0.0, "is_pinned": False, "protected_until": 0.0, "last_reinforced": None, "inactive_since": 1.0},
|
||||
"rel-new": {"is_inactive": False, "weight": 1.0, "is_pinned": False, "protected_until": 0.0, "last_reinforced": None, "inactive_since": None},
|
||||
}
|
||||
current_task: Dict[str, Any] = {
|
||||
"id": 1,
|
||||
"query_tool_id": "tool-query-rollback",
|
||||
"session_id": "session-1",
|
||||
"status": "applied",
|
||||
"rollback_status": "none",
|
||||
"query_snapshot": {"query": "测试用户最喜欢的颜色是什么"},
|
||||
"decision_payload": {"decision": "correct", "confidence": 0.97},
|
||||
"rollback_plan": {
|
||||
"forgotten_relations": [
|
||||
{
|
||||
"hash": "rel-old",
|
||||
"subject": "测试用户",
|
||||
"predicate": "最喜欢的颜色是",
|
||||
"object": "蓝色",
|
||||
"before_status": {
|
||||
"is_inactive": False,
|
||||
"weight": 0.8,
|
||||
"is_pinned": False,
|
||||
"protected_until": 0.0,
|
||||
"last_reinforced": None,
|
||||
"inactive_since": None,
|
||||
},
|
||||
}
|
||||
],
|
||||
"corrected_write": {
|
||||
"paragraph_hashes": ["paragraph-new"],
|
||||
"corrected_relations": [
|
||||
{
|
||||
"hash": "rel-new",
|
||||
"subject": "测试用户",
|
||||
"predicate": "最喜欢的颜色是",
|
||||
"object": "绿色",
|
||||
"existed_before": False,
|
||||
"before_status": {},
|
||||
}
|
||||
],
|
||||
},
|
||||
"stale_marks": [{"paragraph_hash": "paragraph-old", "relation_hash": "rel-old"}],
|
||||
"episode_sources": ["chat_summary:session-1"],
|
||||
"profile_person_ids": ["person-1"],
|
||||
},
|
||||
}
|
||||
|
||||
class _Conn:
|
||||
def cursor(self):
|
||||
return self
|
||||
|
||||
def execute(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def commit(self):
|
||||
return None
|
||||
|
||||
metadata_store = SimpleNamespace(
|
||||
get_feedback_task_by_id=lambda task_id: current_task if int(task_id) == 1 else None,
|
||||
mark_feedback_task_rollback_running=lambda **kwargs: current_task.update({"rollback_status": "running"}) or current_task,
|
||||
finalize_feedback_task_rollback=lambda **kwargs: current_task.update(
|
||||
{
|
||||
"rollback_status": kwargs["rollback_status"],
|
||||
"rollback_result": kwargs.get("rollback_result") or {},
|
||||
"rollback_error": kwargs.get("rollback_error", ""),
|
||||
}
|
||||
)
|
||||
or current_task,
|
||||
get_relation_status_batch=lambda hashes: {
|
||||
hash_value: dict(relation_statuses[hash_value])
|
||||
for hash_value in hashes
|
||||
if hash_value in relation_statuses
|
||||
},
|
||||
restore_relation_status_from_snapshot=lambda hash_value, snapshot: relation_statuses.update(
|
||||
{hash_value: dict(snapshot)}
|
||||
)
|
||||
or dict(snapshot),
|
||||
append_feedback_action_log=lambda **kwargs: action_logs.append(kwargs),
|
||||
mark_as_deleted=lambda hashes, type_: deleted_paragraphs.extend(list(hashes)) or len(list(hashes)),
|
||||
get_paragraph=lambda paragraph_hash: {"hash": paragraph_hash, "source": "chat_summary:session-1"},
|
||||
get_connection=lambda: _Conn(),
|
||||
delete_external_memory_refs_by_paragraphs=lambda hashes: [
|
||||
{"paragraph_hash": str(hash_value), "external_id": f"external:{hash_value}"}
|
||||
for hash_value in hashes
|
||||
],
|
||||
update_relations_protection=lambda hashes, **kwargs: None,
|
||||
mark_relations_inactive=lambda hashes, inactive_since=None: [
|
||||
relation_statuses.__setitem__(
|
||||
hash_value,
|
||||
{
|
||||
**relation_statuses.get(hash_value, {}),
|
||||
"is_inactive": True,
|
||||
"inactive_since": inactive_since,
|
||||
},
|
||||
)
|
||||
for hash_value in hashes
|
||||
],
|
||||
delete_paragraph_stale_relation_marks=lambda marks: deleted_marks.extend(list(marks)) or len(list(marks)),
|
||||
enqueue_episode_source_rebuild=lambda source, reason='': queued_sources.append(source) or True,
|
||||
enqueue_person_profile_refresh=lambda **kwargs: queued_profiles.append(kwargs["person_id"]) or kwargs,
|
||||
list_feedback_action_logs=lambda task_id: action_logs if int(task_id) == 1 else [],
|
||||
)
|
||||
|
||||
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
|
||||
kernel.metadata_store = metadata_store
|
||||
async def _noop_initialize() -> None:
|
||||
return None
|
||||
kernel.initialize = _noop_initialize # type: ignore[method-assign]
|
||||
kernel._rebuild_graph_from_metadata = lambda: None # type: ignore[method-assign]
|
||||
kernel._persist = lambda: None # type: ignore[method-assign]
|
||||
|
||||
payload = await kernel._rollback_feedback_task(
|
||||
task_id=1,
|
||||
requested_by="pytest",
|
||||
reason="manual rollback",
|
||||
)
|
||||
|
||||
assert payload["success"] is True
|
||||
assert relation_statuses["rel-old"]["is_inactive"] is False
|
||||
assert relation_statuses["rel-new"]["is_inactive"] is True
|
||||
assert deleted_paragraphs == ["paragraph-new"]
|
||||
assert deleted_marks == [("paragraph-old", "rel-old")]
|
||||
assert queued_sources == ["chat_summary:session-1"]
|
||||
assert queued_profiles == ["person-1"]
|
||||
assert current_task["rollback_status"] == "rolled_back"
|
||||
assert {item["action_type"] for item in action_logs} >= {
|
||||
"rollback_restore_relation",
|
||||
"rollback_revert_corrected_relation",
|
||||
"rollback_delete_correction_paragraph",
|
||||
"rollback_clear_stale_mark",
|
||||
"rollback_enqueue_episode_rebuild",
|
||||
"rollback_enqueue_profile_refresh",
|
||||
}
|
||||
82
pytests/A_memorix_test/test_graph_store_persistence.py
Normal file
82
pytests/A_memorix_test/test_graph_store_persistence.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
from src.A_memorix.core.storage.graph_store import GraphStore
|
||||
except SystemExit as exc:
|
||||
GraphStore = None # type: ignore[assignment]
|
||||
IMPORT_ERROR = f"config initialization exited during import: {exc}"
|
||||
else:
|
||||
IMPORT_ERROR = None
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
|
||||
|
||||
|
||||
def _build_empty_graph_metadata() -> dict:
|
||||
return {
|
||||
"nodes": [],
|
||||
"node_to_idx": {},
|
||||
"node_attrs": {},
|
||||
"matrix_format": "csr",
|
||||
"total_nodes_added": 0,
|
||||
"total_edges_added": 0,
|
||||
"total_nodes_deleted": 0,
|
||||
"total_edges_deleted": 0,
|
||||
"edge_hash_map": {},
|
||||
}
|
||||
|
||||
|
||||
def test_graph_store_clear_save_removes_stale_adjacency(tmp_path: Path) -> None:
|
||||
data_dir = tmp_path / "graph_data"
|
||||
store = GraphStore(data_dir=data_dir)
|
||||
store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
|
||||
store.save()
|
||||
|
||||
matrix_path = data_dir / "graph_adjacency.npz"
|
||||
assert matrix_path.exists()
|
||||
|
||||
store.clear()
|
||||
store.save()
|
||||
|
||||
assert not matrix_path.exists()
|
||||
|
||||
|
||||
def test_graph_store_load_resets_stale_adjacency_when_metadata_is_empty(tmp_path: Path) -> None:
|
||||
data_dir = tmp_path / "graph_data"
|
||||
store = GraphStore(data_dir=data_dir)
|
||||
store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
|
||||
store.save()
|
||||
|
||||
metadata_path = data_dir / "graph_metadata.pkl"
|
||||
with metadata_path.open("wb") as handle:
|
||||
pickle.dump(_build_empty_graph_metadata(), handle)
|
||||
|
||||
reloaded = GraphStore(data_dir=data_dir)
|
||||
reloaded.load()
|
||||
|
||||
assert reloaded.num_nodes == 0
|
||||
assert reloaded.num_edges == 0
|
||||
assert reloaded.get_nodes() == []
|
||||
|
||||
|
||||
def test_graph_store_load_clears_stale_edge_hash_map_when_metadata_is_empty(tmp_path: Path) -> None:
|
||||
data_dir = tmp_path / "graph_data"
|
||||
store = GraphStore(data_dir=data_dir)
|
||||
store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
|
||||
store.save()
|
||||
|
||||
metadata_path = data_dir / "graph_metadata.pkl"
|
||||
empty_metadata = _build_empty_graph_metadata()
|
||||
empty_metadata["edge_hash_map"] = {(0, 1): {"rel-1"}}
|
||||
with metadata_path.open("wb") as handle:
|
||||
pickle.dump(empty_metadata, handle)
|
||||
|
||||
reloaded = GraphStore(data_dir=data_dir)
|
||||
reloaded.load()
|
||||
|
||||
assert reloaded.has_edge_hash_map() is False
|
||||
@@ -35,12 +35,30 @@ def test_legacy_learning_list_with_numeric_fourth_column_is_migrated():
|
||||
]
|
||||
|
||||
|
||||
def test_visual_multimodal_replyer_is_migrated_to_replyer_mode() -> None:
|
||||
payload = {
|
||||
"visual": {
|
||||
"multimodal_replyer": True,
|
||||
}
|
||||
}
|
||||
def test_visual_multimodal_planner_is_migrated_to_planner_mode():
|
||||
payload = {"visual": {"multimodal_planner": True}}
|
||||
|
||||
result = try_migrate_legacy_bot_config_dict(payload)
|
||||
|
||||
assert result.migrated is True
|
||||
assert "visual.multimodal_planner_moved_to_visual.planner_mode" in result.reason
|
||||
assert result.data["visual"]["planner_mode"] == "multimodal"
|
||||
assert "multimodal_planner" not in result.data["visual"]
|
||||
|
||||
|
||||
def test_chat_multimodal_planner_is_migrated_to_visual_planner_mode():
|
||||
payload = {"chat": {"multimodal_planner": True}}
|
||||
|
||||
result = try_migrate_legacy_bot_config_dict(payload)
|
||||
|
||||
assert result.migrated is True
|
||||
assert "chat.multimodal_planner_moved_to_visual.planner_mode" in result.reason
|
||||
assert result.data["visual"]["planner_mode"] == "multimodal"
|
||||
assert "multimodal_planner" not in result.data["chat"]
|
||||
|
||||
|
||||
def test_visual_multimodal_replyer_is_migrated_to_replyer_mode():
|
||||
payload = {"visual": {"multimodal_replyer": True}}
|
||||
|
||||
result = try_migrate_legacy_bot_config_dict(payload)
|
||||
|
||||
@@ -50,13 +68,8 @@ def test_visual_multimodal_replyer_is_migrated_to_replyer_mode() -> None:
|
||||
assert "multimodal_replyer" not in result.data["visual"]
|
||||
|
||||
|
||||
def test_chat_replyer_generator_type_is_migrated_to_replyer_mode() -> None:
|
||||
payload = {
|
||||
"chat": {
|
||||
"replyer_generator_type": "legacy",
|
||||
},
|
||||
"visual": {},
|
||||
}
|
||||
def test_chat_replyer_generator_type_is_migrated_to_replyer_mode():
|
||||
payload = {"chat": {"replyer_generator_type": "legacy"}}
|
||||
|
||||
result = try_migrate_legacy_bot_config_dict(payload)
|
||||
|
||||
|
||||
@@ -38,6 +38,206 @@ def test_person_fact_resolve_target_person_for_private_chat(monkeypatch):
|
||||
assert person.person_id == "qq:123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_summary_writeback_service_triggers_when_threshold_reached(monkeypatch):
|
||||
events: list[tuple[str, object]] = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
memory_flow_module,
|
||||
"global_config",
|
||||
SimpleNamespace(
|
||||
memory=SimpleNamespace(
|
||||
chat_summary_writeback_enabled=True,
|
||||
chat_summary_writeback_message_threshold=3,
|
||||
chat_summary_writeback_context_length=7,
|
||||
)
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
|
||||
|
||||
async def fake_ingest_summary(**kwargs):
|
||||
events.append(("ingest_summary", kwargs))
|
||||
return SimpleNamespace(success=True, detail="ok")
|
||||
|
||||
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
|
||||
del self, session_id, total_message_count
|
||||
return 0
|
||||
|
||||
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
|
||||
monkeypatch.setattr(
|
||||
memory_flow_module.ChatSummaryWritebackService,
|
||||
"_load_last_trigger_message_count",
|
||||
fake_load_last_trigger_message_count,
|
||||
)
|
||||
|
||||
service = memory_flow_module.ChatSummaryWritebackService()
|
||||
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
|
||||
|
||||
await service._handle_message(message)
|
||||
|
||||
assert len(events) == 1
|
||||
_, payload = events[0]
|
||||
assert payload["external_id"] == "chat_auto_summary:session-1:5"
|
||||
assert payload["chat_id"] == "session-1"
|
||||
assert payload["text"] == ""
|
||||
assert payload["metadata"]["generate_from_chat"] is True
|
||||
assert payload["metadata"]["context_length"] == 7
|
||||
assert payload["metadata"]["trigger"] == "message_threshold"
|
||||
assert payload["user_id"] == "user-1"
|
||||
assert payload["group_id"] == "group-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_summary_writeback_service_skips_when_threshold_not_reached(monkeypatch):
|
||||
called = False
|
||||
|
||||
monkeypatch.setattr(
|
||||
memory_flow_module,
|
||||
"global_config",
|
||||
SimpleNamespace(
|
||||
memory=SimpleNamespace(
|
||||
chat_summary_writeback_enabled=True,
|
||||
chat_summary_writeback_message_threshold=6,
|
||||
chat_summary_writeback_context_length=9,
|
||||
)
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
|
||||
|
||||
async def fake_ingest_summary(**kwargs):
|
||||
nonlocal called
|
||||
called = True
|
||||
return SimpleNamespace(success=True, detail="ok")
|
||||
|
||||
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
|
||||
del self, session_id, total_message_count
|
||||
return 0
|
||||
|
||||
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
|
||||
monkeypatch.setattr(
|
||||
memory_flow_module.ChatSummaryWritebackService,
|
||||
"_load_last_trigger_message_count",
|
||||
fake_load_last_trigger_message_count,
|
||||
)
|
||||
|
||||
service = memory_flow_module.ChatSummaryWritebackService()
|
||||
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
|
||||
|
||||
await service._handle_message(message)
|
||||
|
||||
assert called is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_summary_writeback_service_restores_previous_trigger_count(monkeypatch):
|
||||
events: list[tuple[str, object]] = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
memory_flow_module,
|
||||
"global_config",
|
||||
SimpleNamespace(
|
||||
memory=SimpleNamespace(
|
||||
chat_summary_writeback_enabled=True,
|
||||
chat_summary_writeback_message_threshold=3,
|
||||
chat_summary_writeback_context_length=7,
|
||||
)
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 8)
|
||||
|
||||
async def fake_ingest_summary(**kwargs):
|
||||
events.append(("ingest_summary", kwargs))
|
||||
return SimpleNamespace(success=True, detail="ok")
|
||||
|
||||
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
|
||||
del self, session_id, total_message_count
|
||||
return 5
|
||||
|
||||
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
|
||||
monkeypatch.setattr(
|
||||
memory_flow_module.ChatSummaryWritebackService,
|
||||
"_load_last_trigger_message_count",
|
||||
fake_load_last_trigger_message_count,
|
||||
)
|
||||
|
||||
service = memory_flow_module.ChatSummaryWritebackService()
|
||||
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
|
||||
|
||||
await service._handle_message(message)
|
||||
|
||||
assert len(events) == 1
|
||||
_, payload = events[0]
|
||||
assert payload["external_id"] == "chat_auto_summary:session-1:8"
|
||||
assert service._states["session-1"].last_trigger_message_count == 8
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_summary_writeback_service_falls_back_to_current_count_for_legacy_summary(monkeypatch):
|
||||
called = False
|
||||
|
||||
monkeypatch.setattr(
|
||||
memory_flow_module,
|
||||
"global_config",
|
||||
SimpleNamespace(
|
||||
memory=SimpleNamespace(
|
||||
chat_summary_writeback_enabled=True,
|
||||
chat_summary_writeback_message_threshold=3,
|
||||
chat_summary_writeback_context_length=7,
|
||||
)
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
|
||||
|
||||
async def fake_ingest_summary(**kwargs):
|
||||
nonlocal called
|
||||
called = True
|
||||
return SimpleNamespace(success=True, detail="ok")
|
||||
|
||||
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
|
||||
del self, session_id, total_message_count
|
||||
return 5
|
||||
|
||||
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
|
||||
monkeypatch.setattr(
|
||||
memory_flow_module.ChatSummaryWritebackService,
|
||||
"_load_last_trigger_message_count",
|
||||
fake_load_last_trigger_message_count,
|
||||
)
|
||||
|
||||
service = memory_flow_module.ChatSummaryWritebackService()
|
||||
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
|
||||
|
||||
await service._handle_message(message)
|
||||
|
||||
assert called is False
|
||||
assert service._states["session-1"].last_trigger_message_count == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_summary_writeback_service_loads_trigger_count_from_summary_metadata(monkeypatch):
|
||||
class FakeMetadataStore:
|
||||
@staticmethod
|
||||
def get_paragraphs_by_source(source: str):
|
||||
assert source == "chat_summary:session-1"
|
||||
return [
|
||||
{"created_at": 1.0, "metadata": {"trigger_message_count": 3}},
|
||||
{"created_at": 2.0, "metadata": {"trigger_message_count": 6}},
|
||||
]
|
||||
|
||||
class FakeRuntimeManager:
|
||||
@staticmethod
|
||||
async def _ensure_kernel():
|
||||
return SimpleNamespace(metadata_store=FakeMetadataStore())
|
||||
|
||||
monkeypatch.setattr(memory_flow_module.memory_service_module, "a_memorix_host_service", FakeRuntimeManager())
|
||||
|
||||
service = memory_flow_module.ChatSummaryWritebackService()
|
||||
|
||||
restored = await service._load_last_trigger_message_count(session_id="session-1", total_message_count=8)
|
||||
|
||||
assert restored == 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_automation_service_auto_starts_and_delegates():
|
||||
events: list[tuple[str, str]] = []
|
||||
@@ -52,15 +252,67 @@ async def test_memory_automation_service_auto_starts_and_delegates():
|
||||
async def shutdown(self):
|
||||
events.append(("shutdown", "fact"))
|
||||
|
||||
class FakeChatSummaryWriteback:
|
||||
async def start(self):
|
||||
events.append(("start", "summary"))
|
||||
|
||||
async def enqueue(self, message):
|
||||
events.append(("summary", message.session_id))
|
||||
|
||||
async def shutdown(self):
|
||||
events.append(("shutdown", "summary"))
|
||||
|
||||
service = memory_flow_module.MemoryAutomationService()
|
||||
service.fact_writeback = FakeFactWriteback()
|
||||
service.chat_summary_writeback = FakeChatSummaryWriteback()
|
||||
|
||||
await service.on_incoming_message(SimpleNamespace(session_id="session-1"))
|
||||
await service.on_message_sent(SimpleNamespace(session_id="session-1"))
|
||||
await service.shutdown()
|
||||
|
||||
assert events == [
|
||||
("start", "fact"),
|
||||
("start", "summary"),
|
||||
("sent", "session-1"),
|
||||
("summary", "session-1"),
|
||||
("shutdown", "summary"),
|
||||
("shutdown", "fact"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_automation_service_on_incoming_message_auto_starts_only():
|
||||
events: list[tuple[str, str]] = []
|
||||
|
||||
class FakeFactWriteback:
|
||||
async def start(self):
|
||||
events.append(("start", "fact"))
|
||||
|
||||
async def enqueue(self, message):
|
||||
events.append(("sent", message.session_id))
|
||||
|
||||
async def shutdown(self):
|
||||
events.append(("shutdown", "fact"))
|
||||
|
||||
class FakeChatSummaryWriteback:
|
||||
async def start(self):
|
||||
events.append(("start", "summary"))
|
||||
|
||||
async def enqueue(self, message):
|
||||
events.append(("summary", message.session_id))
|
||||
|
||||
async def shutdown(self):
|
||||
events.append(("shutdown", "summary"))
|
||||
|
||||
service = memory_flow_module.MemoryAutomationService()
|
||||
service.fact_writeback = FakeFactWriteback()
|
||||
service.chat_summary_writeback = FakeChatSummaryWriteback()
|
||||
|
||||
await service.on_incoming_message(SimpleNamespace(session_id="session-1"))
|
||||
await service.shutdown()
|
||||
|
||||
assert events == [
|
||||
("start", "fact"),
|
||||
("start", "summary"),
|
||||
("shutdown", "summary"),
|
||||
("shutdown", "fact"),
|
||||
]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""发送服务回归测试。"""
|
||||
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List
|
||||
|
||||
@@ -182,6 +183,75 @@ async def test_text_to_stream_with_message_returns_sent_message(monkeypatch: pyt
|
||||
assert stored_messages[0].message_id == "real-message-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_stream_with_message_triggers_memory_and_syncs_maisaka_history(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_manager = _FakePlatformIOManager(
|
||||
delivery_batch=SimpleNamespace(
|
||||
has_success=True,
|
||||
sent_receipts=[
|
||||
SimpleNamespace(
|
||||
driver_id="plugin.qq.sender",
|
||||
external_message_id="real-message-id",
|
||||
metadata={},
|
||||
)
|
||||
],
|
||||
failed_receipts=[],
|
||||
route_key=SimpleNamespace(platform="qq"),
|
||||
)
|
||||
)
|
||||
stored_messages: List[Any] = []
|
||||
memory_events: List[str] = []
|
||||
history_events: List[tuple[str, str]] = []
|
||||
|
||||
class FakeMemoryAutomationService:
|
||||
async def on_message_sent(self, message: Any) -> None:
|
||||
memory_events.append(str(message.message_id))
|
||||
|
||||
class FakeRuntime:
|
||||
def append_sent_message_to_chat_history(self, message: Any, *, source_kind: str = "guided_reply") -> None:
|
||||
history_events.append((str(message.message_id), source_kind))
|
||||
|
||||
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
|
||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
||||
monkeypatch.setattr(
|
||||
send_service._chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
send_service.MessageUtils,
|
||||
"store_message_to_db",
|
||||
lambda message: stored_messages.append(message),
|
||||
)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"src.services.memory_flow_service",
|
||||
SimpleNamespace(memory_automation_service=FakeMemoryAutomationService()),
|
||||
)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"src.chat.heart_flow.heartflow_manager",
|
||||
SimpleNamespace(heartflow_manager=SimpleNamespace(heartflow_chat_list={"test-session": FakeRuntime()})),
|
||||
)
|
||||
|
||||
sent_message = await send_service.text_to_stream_with_message(
|
||||
text="你好",
|
||||
stream_id="test-session",
|
||||
sync_to_maisaka_history=True,
|
||||
maisaka_source_kind="guided_reply",
|
||||
)
|
||||
|
||||
assert sent_message is not None
|
||||
assert sent_message.message_id == "real-message-id"
|
||||
assert fake_manager.ensure_calls == 1
|
||||
assert len(stored_messages) == 1
|
||||
assert stored_messages[0].message_id == "real-message-id"
|
||||
assert memory_events == ["real-message-id"]
|
||||
assert history_events == [("real-message-id", "guided_reply")]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
fake_manager = _FakePlatformIOManager(
|
||||
|
||||
@@ -82,6 +82,7 @@ def test_resolve_static_path_prefers_installed_dashboard_package(monkeypatch, tm
|
||||
def test_resolve_static_path_uses_dashboard_dist(monkeypatch, tmp_path) -> None:
|
||||
dashboard_dist = tmp_path / "dashboard" / "dist"
|
||||
dashboard_dist.mkdir(parents=True)
|
||||
(dashboard_dist / "index.html").write_text("<html></html>", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)
|
||||
|
||||
@@ -91,6 +92,26 @@ def test_resolve_static_path_uses_dashboard_dist(monkeypatch, tmp_path) -> None:
|
||||
assert resolved_path == dashboard_dist
|
||||
|
||||
|
||||
def test_resolve_static_path_falls_back_to_package_when_dashboard_dist_has_no_index(monkeypatch, tmp_path) -> None:
|
||||
dashboard_dist = tmp_path / "dashboard" / "dist"
|
||||
dashboard_dist.mkdir(parents=True)
|
||||
|
||||
package_dist = tmp_path / "site-packages" / "maibot_dashboard" / "dist"
|
||||
package_dist.mkdir(parents=True)
|
||||
|
||||
class _DashboardModule:
|
||||
@staticmethod
|
||||
def get_dist_path() -> Path:
|
||||
return package_dist
|
||||
|
||||
monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)
|
||||
|
||||
with patch.object(webui_app, "import_module", return_value=_DashboardModule()):
|
||||
resolved_path = webui_app._resolve_static_path()
|
||||
|
||||
assert resolved_path == package_dist
|
||||
|
||||
|
||||
def test_resolve_safe_static_file_path_allows_regular_static_file(tmp_path) -> None:
|
||||
static_path = tmp_path / "dist"
|
||||
asset_path = static_path / "assets" / "app.js"
|
||||
|
||||
@@ -99,13 +99,15 @@ def test_maisaka_is_host_tab_and_mcp_is_attached_to_it():
|
||||
assert mcp_schema.get("uiParent") == "maisaka"
|
||||
|
||||
|
||||
def test_maisaka_memory_query_config_fields_are_exposed():
|
||||
"""MaiSaka 长期记忆检索开关和默认条数应出现在配置 schema 中。"""
|
||||
def test_memory_query_config_fields_are_exposed():
|
||||
"""query_memory 开关和默认条数应出现在记忆配置 schema 中。"""
|
||||
schema = ConfigSchemaGenerator.generate_schema(Config)
|
||||
maisaka_schema = schema["nested"]["maisaka"]
|
||||
memory_schema = schema["nested"]["memory"]
|
||||
|
||||
enable_field = next(field for field in maisaka_schema["fields"] if field["name"] == "enable_memory_query_tool")
|
||||
limit_field = next(field for field in maisaka_schema["fields"] if field["name"] == "memory_query_default_limit")
|
||||
assert memory_schema.get("uiParent") == "emoji"
|
||||
|
||||
enable_field = next(field for field in memory_schema["fields"] if field["name"] == "enable_memory_query_tool")
|
||||
limit_field = next(field for field in memory_schema["fields"] if field["name"] == "memory_query_default_limit")
|
||||
|
||||
assert enable_field["type"] == "boolean"
|
||||
assert enable_field.get("x-widget") == "switch"
|
||||
|
||||
@@ -638,3 +638,41 @@ def test_delete_operation_routes(client: TestClient, monkeypatch):
|
||||
assert list_response.json()["count"] == 1
|
||||
assert get_response.status_code == 200
|
||||
assert get_response.json()["operation"]["operation_id"] == "del-1"
|
||||
|
||||
|
||||
def test_feedback_correction_routes(client: TestClient, monkeypatch):
|
||||
async def fake_feedback_admin(*, action: str, **kwargs):
|
||||
if action == "list":
|
||||
assert kwargs == {
|
||||
"limit": 7,
|
||||
"statuses": ["applied"],
|
||||
"rollback_statuses": ["none"],
|
||||
"query": "green",
|
||||
}
|
||||
return {"success": True, "items": [{"task_id": 11, "query_text": "what color"}], "count": 1}
|
||||
if action == "get":
|
||||
assert kwargs == {"task_id": 11}
|
||||
return {"success": True, "task": {"task_id": 11, "query_text": "what color", "action_logs": []}}
|
||||
if action == "rollback":
|
||||
assert kwargs == {"task_id": 11, "requested_by": "tester", "reason": "manual revert"}
|
||||
return {"success": True, "result": {"restored_relation_hashes": ["rel-1"]}}
|
||||
raise AssertionError(action)
|
||||
|
||||
monkeypatch.setattr(memory_router_module.memory_service, "feedback_admin", fake_feedback_admin)
|
||||
|
||||
list_response = client.get(
|
||||
"/api/webui/memory/feedback-corrections",
|
||||
params={"limit": 7, "status": "applied", "rollback_status": "none", "query": "green"},
|
||||
)
|
||||
get_response = client.get("/api/webui/memory/feedback-corrections/11")
|
||||
rollback_response = client.post(
|
||||
"/api/webui/memory/feedback-corrections/11/rollback",
|
||||
json={"requested_by": "tester", "reason": "manual revert"},
|
||||
)
|
||||
|
||||
assert list_response.status_code == 200
|
||||
assert list_response.json()["items"][0]["task_id"] == 11
|
||||
assert get_response.status_code == 200
|
||||
assert get_response.json()["task"]["task_id"] == 11
|
||||
assert rollback_response.status_code == 200
|
||||
assert rollback_response.json()["result"]["restored_relation_hashes"] == ["rel-1"]
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
- 运行时主目录由 `storage.data_dir` 决定(当前模板默认 `data/a-memorix`);
|
||||
- 部分离线脚本仍以 `data/plugins/a-dawn.a-memorix` 作为默认处理目录。
|
||||
- 修正文档中的导入示例参数,`memory_import_admin.create_paste` 的 `input_mode` 示例统一为 `text`/`json`。
|
||||
- 更新 `README.md` 关于元数据 schema 的描述,和当前代码 `SCHEMA_VERSION = 9` 保持一致。
|
||||
- 更新 `README.md` 关于元数据 schema 的描述,和当前代码 `SCHEMA_VERSION = 10` 保持一致。
|
||||
|
||||
## [2.0.0] - 2026-03-18
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# A_Memorix 配置参考 (v2.0.0)
|
||||
|
||||
本文档对应当前仓库代码(`__version__ = 2.0.0`、`SCHEMA_VERSION = 9`)。
|
||||
本文档对应当前仓库代码(`__version__ = 2.0.0`、`SCHEMA_VERSION = 10`)。
|
||||
|
||||
说明:
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
@@ -29,6 +29,9 @@ logger = get_logger("A_Memorix.EmbeddingAPIAdapter")
|
||||
class EmbeddingAPIAdapter:
|
||||
"""适配宿主 embedding 请求接口。"""
|
||||
|
||||
_GLOBAL_DIMENSION_CACHE: Dict[str, int] = {}
|
||||
_GLOBAL_TEXT_EMBEDDING_CACHE: Dict[Tuple[str, int, str], np.ndarray] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = 32,
|
||||
@@ -232,10 +235,32 @@ class EmbeddingAPIAdapter:
|
||||
logger.error(f"通过直接 Client 获取 Embedding 失败: {last_exc}")
|
||||
return None
|
||||
|
||||
def _dimension_cache_key(self) -> str:
|
||||
candidate_names = self._resolve_candidate_model_names()
|
||||
return "|".join(
|
||||
[
|
||||
str(self.model_name or "auto"),
|
||||
str(self.default_dimension),
|
||||
",".join(candidate_names),
|
||||
]
|
||||
)
|
||||
|
||||
def _embedding_cache_key(self, text: str, dimensions: Optional[int]) -> Tuple[str, int, str]:
|
||||
requested_dimension = self._resolve_canonical_dimension(dimensions)
|
||||
return (self._dimension_cache_key(), int(requested_dimension), str(text or ""))
|
||||
|
||||
async def _detect_dimension(self) -> int:
|
||||
if self._dimension_detected and self._dimension is not None:
|
||||
return self._dimension
|
||||
|
||||
cache_key = self._dimension_cache_key()
|
||||
cached_dimension = self._GLOBAL_DIMENSION_CACHE.get(cache_key)
|
||||
if cached_dimension is not None:
|
||||
self._dimension = int(cached_dimension)
|
||||
self._dimension_detected = True
|
||||
logger.info(f"嵌入维度命中进程缓存: {self._dimension}")
|
||||
return self._dimension
|
||||
|
||||
logger.info("正在检测嵌入模型维度...")
|
||||
try:
|
||||
target_dim = self.default_dimension
|
||||
@@ -251,6 +276,7 @@ class EmbeddingAPIAdapter:
|
||||
)
|
||||
self._dimension = detected_dim
|
||||
self._dimension_detected = True
|
||||
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(detected_dim)
|
||||
return detected_dim
|
||||
except Exception as exc:
|
||||
logger.debug(f"带维度参数探测失败: {exc},尝试不带维度参数探测")
|
||||
@@ -261,6 +287,7 @@ class EmbeddingAPIAdapter:
|
||||
detected_dim = len(test_embedding)
|
||||
self._dimension = detected_dim
|
||||
self._dimension_detected = True
|
||||
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(detected_dim)
|
||||
logger.info(f"嵌入维度检测成功 (自然维度): {detected_dim}")
|
||||
return detected_dim
|
||||
logger.warning(f"嵌入维度检测失败,使用 configured_dimension: {self.default_dimension}")
|
||||
@@ -269,6 +296,7 @@ class EmbeddingAPIAdapter:
|
||||
|
||||
self._dimension = self.default_dimension
|
||||
self._dimension_detected = True
|
||||
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(self.default_dimension)
|
||||
return self.default_dimension
|
||||
|
||||
async def encode(
|
||||
@@ -336,26 +364,54 @@ class EmbeddingAPIAdapter:
|
||||
all_embeddings: List[np.ndarray] = []
|
||||
for offset in range(0, len(texts), batch_size):
|
||||
batch = texts[offset : offset + batch_size]
|
||||
batch_results: List[Tuple[int, np.ndarray]] = []
|
||||
uncached_items: List[Tuple[int, str]] = []
|
||||
|
||||
if self.enable_cache:
|
||||
for index, text in enumerate(batch):
|
||||
cache_key = self._embedding_cache_key(text, dimensions)
|
||||
cached_vector = self._GLOBAL_TEXT_EMBEDDING_CACHE.get(cache_key)
|
||||
if cached_vector is None:
|
||||
uncached_items.append((index, text))
|
||||
else:
|
||||
batch_results.append((index, cached_vector.copy()))
|
||||
else:
|
||||
uncached_items = list(enumerate(batch))
|
||||
|
||||
if not uncached_items:
|
||||
batch_results.sort(key=lambda item: item[0])
|
||||
all_embeddings.extend(emb for _, emb in batch_results)
|
||||
continue
|
||||
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||
|
||||
async def encode_with_semaphore(text: str, index: int):
|
||||
async def encode_with_semaphore(text: str, batch_index: int, absolute_index: int):
|
||||
async with semaphore:
|
||||
embedding = await self._get_embedding_direct(text, dimensions=dimensions)
|
||||
if embedding is None:
|
||||
raise RuntimeError(f"文本 {index} 编码失败:embedding 返回为空")
|
||||
raise RuntimeError(f"文本 {absolute_index} 编码失败:embedding 返回为空")
|
||||
vector = self._validate_embedding_vector(
|
||||
embedding,
|
||||
source=f"文本 {index}",
|
||||
source=f"文本 {absolute_index}",
|
||||
)
|
||||
return index, vector
|
||||
return batch_index, vector
|
||||
|
||||
tasks = [
|
||||
encode_with_semaphore(text, offset + index)
|
||||
for index, text in enumerate(batch)
|
||||
encode_with_semaphore(text, index, offset + index)
|
||||
for index, text in uncached_items
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
results.sort(key=lambda item: item[0])
|
||||
all_embeddings.extend(emb for _, emb in results)
|
||||
normalized_results: List[Tuple[int, np.ndarray]] = []
|
||||
for batch_index, vector in results:
|
||||
normalized_results.append((batch_index, vector))
|
||||
if self.enable_cache:
|
||||
text = batch[batch_index]
|
||||
cache_key = self._embedding_cache_key(text, dimensions)
|
||||
self._GLOBAL_TEXT_EMBEDDING_CACHE[cache_key] = vector.copy()
|
||||
|
||||
batch_results.extend(normalized_results)
|
||||
batch_results.sort(key=lambda item: item[0])
|
||||
all_embeddings.extend(emb for _, emb in batch_results)
|
||||
|
||||
return np.array(all_embeddings, dtype=np.float32)
|
||||
|
||||
|
||||
@@ -632,7 +632,7 @@ class DualPathRetriever:
|
||||
results: List[RetrievalResult] = []
|
||||
for row in rows:
|
||||
hash_value = row["hash"]
|
||||
relation = self.metadata_store.get_relation(hash_value)
|
||||
relation = self.metadata_store.get_relation(hash_value, include_inactive=False)
|
||||
if relation is None:
|
||||
continue
|
||||
|
||||
@@ -888,8 +888,8 @@ class DualPathRetriever:
|
||||
entity_name = entity["name"]
|
||||
|
||||
related_rels = []
|
||||
related_rels.extend(self.metadata_store.get_relations(subject=entity_name))
|
||||
related_rels.extend(self.metadata_store.get_relations(object=entity_name))
|
||||
related_rels.extend(self.metadata_store.get_relations(subject=entity_name, include_inactive=False))
|
||||
related_rels.extend(self.metadata_store.get_relations(object=entity_name, include_inactive=False))
|
||||
|
||||
for rel in related_rels:
|
||||
if rel["hash"] in seen_relations:
|
||||
@@ -1280,7 +1280,7 @@ class DualPathRetriever:
|
||||
|
||||
results = []
|
||||
for hash_value, score in zip(rel_ids, rel_scores):
|
||||
relation = self.metadata_store.get_relation(hash_value)
|
||||
relation = self.metadata_store.get_relation(hash_value, include_inactive=False)
|
||||
if relation is None:
|
||||
continue
|
||||
|
||||
@@ -1378,7 +1378,7 @@ class DualPathRetriever:
|
||||
deduplicated_results.append(result)
|
||||
continue
|
||||
# 检查关系关联的段落是否已存在
|
||||
relation = self.metadata_store.get_relation(result.hash_value)
|
||||
relation = self.metadata_store.get_relation(result.hash_value, include_inactive=False)
|
||||
if relation:
|
||||
# 获取关联的段落
|
||||
para_rels = self.metadata_store.query("""
|
||||
|
||||
@@ -255,7 +255,7 @@ class GraphRelationRecallService:
|
||||
graph_hops: int,
|
||||
graph_seed_entities: Sequence[str],
|
||||
) -> Optional[GraphRelationCandidate]:
|
||||
relation = self.metadata_store.get_relation(relation_hash)
|
||||
relation = self.metadata_store.get_relation(relation_hash, include_inactive=False)
|
||||
if relation is None:
|
||||
return None
|
||||
supporting_paragraphs = self.metadata_store.get_paragraphs_by_relation(relation_hash)
|
||||
|
||||
@@ -338,6 +338,7 @@ class SparseBM25Index:
|
||||
match_query=match_query,
|
||||
limit=max(1, int(k)),
|
||||
max_doc_len=self.config.relation_max_doc_len,
|
||||
include_inactive=False,
|
||||
conn=self._conn,
|
||||
)
|
||||
out: List[Dict[str, Any]] = []
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1190,11 +1190,14 @@ class GraphStore:
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 保存邻接矩阵
|
||||
matrix_path = data_dir / "graph_adjacency.npz"
|
||||
if self._adjacency is not None:
|
||||
matrix_path = data_dir / "graph_adjacency.npz"
|
||||
with atomic_write(matrix_path, "wb") as f:
|
||||
save_npz(f, self._adjacency)
|
||||
logger.debug(f"保存邻接矩阵: {matrix_path}")
|
||||
elif matrix_path.exists():
|
||||
matrix_path.unlink()
|
||||
logger.debug(f"删除陈旧邻接矩阵: {matrix_path}")
|
||||
|
||||
# 保存元数据
|
||||
metadata = {
|
||||
@@ -1288,9 +1291,29 @@ class GraphStore:
|
||||
if self._adjacency is not None:
|
||||
adj_n = self._adjacency.shape[0]
|
||||
current_n = len(self._nodes)
|
||||
if current_n > adj_n:
|
||||
if current_n == 0:
|
||||
logger.warning("检测到空图元数据但邻接矩阵仍然存在,已重置为空图。")
|
||||
self._adjacency = None
|
||||
self._edge_hash_map = defaultdict(set)
|
||||
elif current_n > adj_n:
|
||||
logger.warning(f"检测到图存储维度不匹配: 节点数={current_n}, 矩阵大小={adj_n}. 正在自动修复...")
|
||||
self._expand_adjacency_matrix(current_n - adj_n)
|
||||
elif current_n < adj_n:
|
||||
logger.warning(
|
||||
f"检测到过期邻接矩阵: 节点数={current_n}, 矩阵大小={adj_n}. 正在重置邻接矩阵..."
|
||||
)
|
||||
if self.matrix_format == "csc":
|
||||
self._adjacency = csc_matrix((current_n, current_n), dtype=np.float32)
|
||||
else:
|
||||
self._adjacency = csr_matrix((current_n, current_n), dtype=np.float32)
|
||||
self._edge_hash_map = defaultdict(
|
||||
set,
|
||||
{
|
||||
(src_idx, dst_idx): set(hashes)
|
||||
for (src_idx, dst_idx), hashes in self._edge_hash_map.items()
|
||||
if src_idx < current_n and dst_idx < current_n
|
||||
},
|
||||
)
|
||||
|
||||
self._adjacency_dirty = True
|
||||
logger.info(
|
||||
@@ -1445,4 +1468,3 @@ class GraphStore:
|
||||
self._adjacency_dirty = True
|
||||
logger.info(f"已从 {count} 条哈希重建边哈希映射,覆盖 {len(self._edge_hash_map)} 条边")
|
||||
return count
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -18,6 +18,7 @@ from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
from .episode_segmentation_service import EpisodeSegmentationService
|
||||
from .hash import compute_hash
|
||||
@@ -528,7 +529,11 @@ class EpisodeService:
|
||||
"paragraph_count": 0,
|
||||
}
|
||||
|
||||
paragraphs = self.metadata_store.get_live_paragraphs_by_source(token)
|
||||
memory_cfg = getattr(global_config, "memory", None)
|
||||
paragraphs = self.metadata_store.get_live_paragraphs_by_source(
|
||||
token,
|
||||
exclude_stale=bool(getattr(memory_cfg, "feedback_correction_paragraph_hard_filter_enabled", True)),
|
||||
)
|
||||
if not paragraphs:
|
||||
replace_result = self.metadata_store.replace_episodes_for_source(token, [])
|
||||
return {
|
||||
|
||||
@@ -90,9 +90,9 @@ def find_paths_between_entities(
|
||||
else:
|
||||
pred = "related"
|
||||
direction = "->"
|
||||
rels = metadata_store.get_relations(subject=u, object=v)
|
||||
rels = metadata_store.get_relations(subject=u, object=v, include_inactive=False)
|
||||
if not rels:
|
||||
rels = metadata_store.get_relations(subject=v, object=u)
|
||||
rels = metadata_store.get_relations(subject=v, object=u, include_inactive=False)
|
||||
direction = "<-"
|
||||
if rels:
|
||||
best_rel = max(rels, key=lambda x: x.get("confidence", 1.0))
|
||||
@@ -162,4 +162,3 @@ def to_retrieval_results(paths: Sequence[Dict[str, Any]]) -> List[RetrievalResul
|
||||
)
|
||||
)
|
||||
return converted
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from sqlmodel import select
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.config.config import global_config
|
||||
|
||||
from ..embedding import EmbeddingAPIAdapter
|
||||
from ..retrieval import (
|
||||
@@ -285,11 +286,11 @@ class PersonProfileService:
|
||||
def _collect_relation_evidence(self, aliases: List[str], limit: int = 30) -> List[Dict[str, Any]]:
|
||||
relation_by_hash: Dict[str, Dict[str, Any]] = {}
|
||||
for alias in aliases:
|
||||
for rel in self.metadata_store.get_relations(subject=alias):
|
||||
for rel in self.metadata_store.get_relations(subject=alias, include_inactive=False):
|
||||
h = str(rel.get("hash", ""))
|
||||
if h:
|
||||
relation_by_hash[h] = rel
|
||||
for rel in self.metadata_store.get_relations(object=alias):
|
||||
for rel in self.metadata_store.get_relations(object=alias, include_inactive=False):
|
||||
h = str(rel.get("hash", ""))
|
||||
if h:
|
||||
relation_by_hash[h] = rel
|
||||
@@ -342,7 +343,53 @@ class PersonProfileService:
|
||||
"metadata": {},
|
||||
}
|
||||
)
|
||||
return evidence
|
||||
return self._filter_stale_paragraph_evidence(evidence)
|
||||
|
||||
def _filter_stale_paragraph_evidence(
|
||||
self,
|
||||
evidence: List[Dict[str, Any]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
memory_cfg = getattr(global_config, "memory", None)
|
||||
if not bool(getattr(memory_cfg, "feedback_correction_paragraph_hard_filter_enabled", True)):
|
||||
return evidence
|
||||
paragraph_hashes = [
|
||||
str(item.get("hash", "") or "").strip()
|
||||
for item in evidence
|
||||
if str(item.get("type", "") or "").strip() == "paragraph" and str(item.get("hash", "") or "").strip()
|
||||
]
|
||||
if not paragraph_hashes:
|
||||
return evidence
|
||||
|
||||
marks_by_paragraph = self.metadata_store.get_paragraph_stale_relation_marks_batch(paragraph_hashes)
|
||||
relation_hashes: List[str] = []
|
||||
seen = set()
|
||||
for marks in marks_by_paragraph.values():
|
||||
for mark in marks:
|
||||
relation_hash = str(mark.get("relation_hash", "") or "").strip()
|
||||
if not relation_hash or relation_hash in seen:
|
||||
continue
|
||||
seen.add(relation_hash)
|
||||
relation_hashes.append(relation_hash)
|
||||
status_map = self.metadata_store.get_relation_status_batch(relation_hashes) if relation_hashes else {}
|
||||
|
||||
filtered: List[Dict[str, Any]] = []
|
||||
for item in evidence:
|
||||
item_type = str(item.get("type", "") or "").strip()
|
||||
item_hash = str(item.get("hash", "") or "").strip()
|
||||
if item_type != "paragraph" or not item_hash:
|
||||
filtered.append(item)
|
||||
continue
|
||||
marks = marks_by_paragraph.get(item_hash, [])
|
||||
should_hide = any(
|
||||
status_map.get(str(mark.get("relation_hash", "") or "").strip()) is None
|
||||
or bool((status_map.get(str(mark.get("relation_hash", "") or "").strip()) or {}).get("is_inactive"))
|
||||
for mark in marks
|
||||
if str(mark.get("relation_hash", "") or "").strip()
|
||||
)
|
||||
if should_hide:
|
||||
continue
|
||||
filtered.append(item)
|
||||
return filtered
|
||||
|
||||
async def _collect_vector_evidence(
|
||||
self,
|
||||
@@ -373,7 +420,7 @@ class PersonProfileService:
|
||||
"metadata": {},
|
||||
}
|
||||
)
|
||||
return fallback[:top_k]
|
||||
return self._filter_stale_paragraph_evidence(fallback[:top_k])
|
||||
|
||||
per_alias_top_k = max(2, int(top_k / max(1, len(alias_queries))))
|
||||
seen_hash = set()
|
||||
@@ -406,7 +453,7 @@ class PersonProfileService:
|
||||
}
|
||||
)
|
||||
evidence.sort(key=lambda x: x.get("score", 0.0), reverse=True)
|
||||
return evidence[:top_k]
|
||||
return self._filter_stale_paragraph_evidence(evidence[:top_k])
|
||||
|
||||
def _build_profile_text(
|
||||
self,
|
||||
|
||||
@@ -5,12 +5,13 @@
|
||||
导入到 A_memorix 的存储组件中。
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.services import llm_service as llm_api
|
||||
@@ -222,7 +223,9 @@ class SummaryImporter:
|
||||
self,
|
||||
stream_id: str,
|
||||
context_length: Optional[int] = None,
|
||||
include_personality: Optional[bool] = None
|
||||
include_personality: Optional[bool] = None,
|
||||
time_end: Optional[float] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
从指定的聊天流中提取记录并执行总结导入
|
||||
@@ -231,6 +234,7 @@ class SummaryImporter:
|
||||
stream_id: 聊天流 ID
|
||||
context_length: 总结的历史消息条数
|
||||
include_personality: 是否包含人设
|
||||
time_end: 用于截取聊天记录的时间上界(闭区间)
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否成功, 结果消息)
|
||||
@@ -248,12 +252,13 @@ class SummaryImporter:
|
||||
include_personality = self.plugin_config.get("summarization", {}).get("include_personality", True)
|
||||
|
||||
# 2. 获取历史消息
|
||||
# 获取当前时间之前的消息
|
||||
now = time.time()
|
||||
messages = message_api.get_messages_before_time_in_chat(
|
||||
query_time_end = time.time() if time_end is None else float(time_end)
|
||||
messages = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=stream_id,
|
||||
timestamp=now,
|
||||
limit=context_length
|
||||
start_time=0.0,
|
||||
end_time=query_time_end,
|
||||
limit=context_length,
|
||||
limit_mode="latest",
|
||||
)
|
||||
|
||||
if not messages:
|
||||
@@ -323,7 +328,14 @@ class SummaryImporter:
|
||||
}
|
||||
|
||||
# 6. 执行导入
|
||||
await self._execute_import(summary_text, entities, relations, stream_id, time_meta=time_meta)
|
||||
await self._execute_import(
|
||||
summary_text,
|
||||
entities,
|
||||
relations,
|
||||
stream_id,
|
||||
time_meta=time_meta,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
# 7. 持久化
|
||||
self.vector_store.save()
|
||||
@@ -389,6 +401,7 @@ class SummaryImporter:
|
||||
relations: List[Dict[str, str]],
|
||||
stream_id: str,
|
||||
time_meta: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""将数据写入存储"""
|
||||
# 获取默认知识类型
|
||||
@@ -403,6 +416,7 @@ class SummaryImporter:
|
||||
hash_value = self.metadata_store.add_paragraph(
|
||||
content=summary,
|
||||
source=f"chat_summary:{stream_id}",
|
||||
metadata=metadata,
|
||||
knowledge_type=knowledge_type.value,
|
||||
time_meta=time_meta,
|
||||
)
|
||||
|
||||
@@ -190,6 +190,16 @@ class AMemorixHostService:
|
||||
)
|
||||
)
|
||||
|
||||
if component_name == "enqueue_feedback_task":
|
||||
return await kernel.enqueue_feedback_task(
|
||||
query_tool_id=str(payload.get("query_tool_id", "") or ""),
|
||||
session_id=str(payload.get("session_id", "") or ""),
|
||||
query_timestamp=payload.get("query_timestamp"),
|
||||
structured_content=payload.get("structured_content")
|
||||
if isinstance(payload.get("structured_content"), dict)
|
||||
else {},
|
||||
)
|
||||
|
||||
if component_name == "ingest_summary":
|
||||
return await kernel.ingest_summary(
|
||||
external_id=str(payload.get("external_id", "") or ""),
|
||||
@@ -251,6 +261,7 @@ class AMemorixHostService:
|
||||
"memory_source_admin": kernel.memory_source_admin,
|
||||
"memory_episode_admin": kernel.memory_episode_admin,
|
||||
"memory_profile_admin": kernel.memory_profile_admin,
|
||||
"memory_feedback_admin": kernel.memory_feedback_admin,
|
||||
"memory_runtime_admin": kernel.memory_runtime_admin,
|
||||
"memory_import_admin": kernel.memory_import_admin,
|
||||
"memory_tuning_admin": kernel.memory_tuning_admin,
|
||||
|
||||
@@ -62,7 +62,10 @@ if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
|
||||
|
||||
try:
|
||||
from A_memorix.core.storage import GraphStore, KnowledgeType, MetadataStore, QuantizationType, VectorStore
|
||||
from A_memorix.core.storage.metadata_store import SCHEMA_VERSION
|
||||
from A_memorix.core.storage.metadata_store import (
|
||||
RUNTIME_AUTO_MIGRATION_MIN_SCHEMA_VERSION,
|
||||
SCHEMA_VERSION,
|
||||
)
|
||||
except Exception as e: # pragma: no cover
|
||||
print(f"❌ failed to import storage modules: {e}")
|
||||
raise SystemExit(2)
|
||||
@@ -125,6 +128,14 @@ def _sqlite_table_exists(conn: sqlite3.Connection, table: str) -> bool:
|
||||
return row is not None
|
||||
|
||||
|
||||
def _sqlite_column_exists(conn: sqlite3.Connection, table: str, column: str) -> bool:
|
||||
try:
|
||||
rows = conn.execute(f"PRAGMA table_info({table})").fetchall()
|
||||
except Exception:
|
||||
return False
|
||||
return any(str(row[1] or "") == str(column or "") for row in rows)
|
||||
|
||||
|
||||
def _collect_hash_alias_conflicts(conn: sqlite3.Connection) -> Dict[str, List[str]]:
|
||||
hashes: List[str] = []
|
||||
if _sqlite_table_exists(conn, "relations"):
|
||||
@@ -152,6 +163,8 @@ def _collect_hash_alias_conflicts(conn: sqlite3.Connection) -> Dict[str, List[st
|
||||
def _collect_invalid_knowledge_types(conn: sqlite3.Connection) -> List[str]:
|
||||
if not _sqlite_table_exists(conn, "paragraphs"):
|
||||
return []
|
||||
if not _sqlite_column_exists(conn, "paragraphs", "knowledge_type"):
|
||||
return []
|
||||
|
||||
allowed = {item.value for item in KnowledgeType}
|
||||
rows = conn.execute("SELECT DISTINCT knowledge_type FROM paragraphs").fetchall()
|
||||
@@ -288,6 +301,14 @@ def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
|
||||
facts["schema_migrations_exists"] = has_schema_table
|
||||
has_paragraph_backfill = _sqlite_table_exists(conn, "paragraph_vector_backfill")
|
||||
facts["paragraph_vector_backfill_exists"] = has_paragraph_backfill
|
||||
has_stale_marks = _sqlite_table_exists(conn, "paragraph_stale_relation_marks")
|
||||
facts["paragraph_stale_relation_marks_exists"] = has_stale_marks
|
||||
has_profile_refresh_queue = _sqlite_table_exists(conn, "person_profile_refresh_queue")
|
||||
facts["person_profile_refresh_queue_exists"] = has_profile_refresh_queue
|
||||
has_feedback_rollback_status = _sqlite_column_exists(conn, "memory_feedback_tasks", "rollback_status")
|
||||
facts["memory_feedback_tasks_rollback_status_exists"] = has_feedback_rollback_status
|
||||
has_feedback_rollback_plan = _sqlite_column_exists(conn, "memory_feedback_tasks", "rollback_plan_json")
|
||||
facts["memory_feedback_tasks_rollback_plan_exists"] = has_feedback_rollback_plan
|
||||
if not has_schema_table:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
@@ -300,14 +321,28 @@ def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
|
||||
row = conn.execute("SELECT MAX(version) FROM schema_migrations").fetchone()
|
||||
version = int(row[0]) if row and row[0] is not None else 0
|
||||
facts["schema_version"] = version
|
||||
runtime_auto_migratable = (
|
||||
version < SCHEMA_VERSION
|
||||
and version >= RUNTIME_AUTO_MIGRATION_MIN_SCHEMA_VERSION
|
||||
)
|
||||
facts["schema_runtime_auto_migratable"] = runtime_auto_migratable
|
||||
if version != SCHEMA_VERSION:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-08",
|
||||
"error",
|
||||
f"schema version mismatch: current={version}, expected={SCHEMA_VERSION}",
|
||||
if runtime_auto_migratable:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-18",
|
||||
"warning",
|
||||
f"schema version behind runtime target: current={version}, expected={SCHEMA_VERSION}; runtime auto migration will handle this update",
|
||||
)
|
||||
)
|
||||
else:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-08",
|
||||
"error",
|
||||
f"schema version mismatch: current={version}, expected={SCHEMA_VERSION}",
|
||||
)
|
||||
)
|
||||
)
|
||||
elif not has_paragraph_backfill:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
@@ -316,6 +351,30 @@ def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
|
||||
"paragraph_vector_backfill table missing under current schema version",
|
||||
)
|
||||
)
|
||||
elif not has_stale_marks:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-15",
|
||||
"error",
|
||||
"paragraph_stale_relation_marks table missing under current schema version",
|
||||
)
|
||||
)
|
||||
elif not has_profile_refresh_queue:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-16",
|
||||
"error",
|
||||
"person_profile_refresh_queue table missing under current schema version",
|
||||
)
|
||||
)
|
||||
elif not has_feedback_rollback_status or not has_feedback_rollback_plan:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-17",
|
||||
"error",
|
||||
"memory_feedback_tasks rollback columns missing under current schema version",
|
||||
)
|
||||
)
|
||||
|
||||
if _sqlite_table_exists(conn, "relations"):
|
||||
row = conn.execute("SELECT COUNT(*) FROM relations").fetchone()
|
||||
@@ -616,6 +675,46 @@ def _verify_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
|
||||
"paragraph_vector_backfill table missing after migration",
|
||||
)
|
||||
)
|
||||
has_feedback_tasks = _sqlite_table_exists(conn, "memory_feedback_tasks")
|
||||
facts["memory_feedback_tasks_exists"] = bool(has_feedback_tasks)
|
||||
if not has_feedback_tasks:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-15",
|
||||
"error",
|
||||
"memory_feedback_tasks table missing after migration",
|
||||
)
|
||||
)
|
||||
has_feedback_logs = _sqlite_table_exists(conn, "memory_feedback_action_logs")
|
||||
facts["memory_feedback_action_logs_exists"] = bool(has_feedback_logs)
|
||||
if not has_feedback_logs:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-16",
|
||||
"error",
|
||||
"memory_feedback_action_logs table missing after migration",
|
||||
)
|
||||
)
|
||||
has_feedback_rollback_status = _sqlite_column_exists(conn, "memory_feedback_tasks", "rollback_status")
|
||||
facts["memory_feedback_tasks_rollback_status_exists"] = bool(has_feedback_rollback_status)
|
||||
if not has_feedback_rollback_status:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-17",
|
||||
"error",
|
||||
"memory_feedback_tasks.rollback_status missing after migration",
|
||||
)
|
||||
)
|
||||
has_feedback_rollback_plan = _sqlite_column_exists(conn, "memory_feedback_tasks", "rollback_plan_json")
|
||||
facts["memory_feedback_tasks_rollback_plan_exists"] = bool(has_feedback_rollback_plan)
|
||||
if not has_feedback_rollback_plan:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-18",
|
||||
"error",
|
||||
"memory_feedback_tasks.rollback_plan_json missing after migration",
|
||||
)
|
||||
)
|
||||
conflicts = _collect_hash_alias_conflicts(conn)
|
||||
invalid_knowledge_types = _collect_invalid_knowledge_types(conn)
|
||||
finally:
|
||||
|
||||
@@ -206,6 +206,40 @@ def _migrate_target_item_list(parent: dict[str, Any], key: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _parse_planner_mode(value: Any) -> Optional[str]:
|
||||
"""
|
||||
兼容旧 planner 配置到当前 visual.planner_mode。
|
||||
"""
|
||||
if isinstance(value, bool):
|
||||
return "multimodal" if value else "text"
|
||||
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
|
||||
normalized_value = value.strip().lower()
|
||||
if normalized_value in {"text", "multimodal", "auto"}:
|
||||
return normalized_value
|
||||
return None
|
||||
|
||||
|
||||
def _parse_replyer_mode(value: Any) -> Optional[str]:
|
||||
"""
|
||||
兼容旧 replyer 配置到当前 visual.replyer_mode。
|
||||
"""
|
||||
if isinstance(value, bool):
|
||||
return "multimodal" if value else "text"
|
||||
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
|
||||
normalized_value = value.strip().lower()
|
||||
if normalized_value in {"text", "multimodal", "auto"}:
|
||||
return normalized_value
|
||||
if normalized_value == "legacy":
|
||||
return "text"
|
||||
return None
|
||||
|
||||
|
||||
def migrate_legacy_bind_env_to_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
"""将旧版 `.env` 中的绑定地址迁移到主配置结构。"""
|
||||
|
||||
@@ -280,8 +314,16 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
migrated_any = True
|
||||
reasons.append("expression.manual_reflect_operator_id_empty")
|
||||
|
||||
chat = _as_dict(data.get("chat"))
|
||||
personality = _as_dict(data.get("personality"))
|
||||
visual = _as_dict(data.get("visual"))
|
||||
if visual is None and (
|
||||
(personality is not None and "visual_style" in personality)
|
||||
or (chat is not None and ("multimodal_planner" in chat or "replyer_generator_type" in chat))
|
||||
):
|
||||
visual = {}
|
||||
data["visual"] = visual
|
||||
|
||||
if visual is not None and personality is not None and "visual_style" in personality:
|
||||
if "visual_style" not in visual:
|
||||
visual["visual_style"] = personality["visual_style"]
|
||||
@@ -289,14 +331,41 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
migrated_any = True
|
||||
reasons.append("personality.visual_style_moved_to_visual.visual_style")
|
||||
|
||||
if visual is not None and "multimodal_planner" in visual and "planner_mode" not in visual:
|
||||
multimodal_planner = visual.pop("multimodal_planner")
|
||||
if isinstance(multimodal_planner, bool):
|
||||
visual["planner_mode"] = "multimodal" if multimodal_planner else "text"
|
||||
if visual is not None and "multimodal_planner" in visual:
|
||||
planner_mode = _parse_planner_mode(visual.get("multimodal_planner"))
|
||||
if "planner_mode" not in visual and planner_mode is not None:
|
||||
visual["planner_mode"] = planner_mode
|
||||
if "planner_mode" in visual:
|
||||
visual.pop("multimodal_planner", None)
|
||||
migrated_any = True
|
||||
reasons.append("visual.multimodal_planner_moved_to_visual.planner_mode")
|
||||
else:
|
||||
visual["multimodal_planner"] = multimodal_planner
|
||||
|
||||
if visual is not None and chat is not None and "multimodal_planner" in chat:
|
||||
planner_mode = _parse_planner_mode(chat.get("multimodal_planner"))
|
||||
if "planner_mode" not in visual and planner_mode is not None:
|
||||
visual["planner_mode"] = planner_mode
|
||||
if "planner_mode" in visual:
|
||||
chat.pop("multimodal_planner", None)
|
||||
migrated_any = True
|
||||
reasons.append("chat.multimodal_planner_moved_to_visual.planner_mode")
|
||||
|
||||
if visual is not None and "multimodal_replyer" in visual:
|
||||
replyer_mode = _parse_replyer_mode(visual.get("multimodal_replyer"))
|
||||
if "replyer_mode" not in visual and replyer_mode is not None:
|
||||
visual["replyer_mode"] = replyer_mode
|
||||
if "replyer_mode" in visual:
|
||||
visual.pop("multimodal_replyer", None)
|
||||
migrated_any = True
|
||||
reasons.append("visual.multimodal_replyer_moved_to_visual.replyer_mode")
|
||||
|
||||
if visual is not None and chat is not None and "replyer_generator_type" in chat:
|
||||
replyer_mode = _parse_replyer_mode(chat.get("replyer_generator_type"))
|
||||
if "replyer_mode" not in visual and replyer_mode is not None:
|
||||
visual["replyer_mode"] = replyer_mode
|
||||
if "replyer_mode" in visual:
|
||||
chat.pop("replyer_generator_type", None)
|
||||
migrated_any = True
|
||||
reasons.append("chat.replyer_generator_type_moved_to_visual.replyer_mode")
|
||||
|
||||
memory = _as_dict(data.get("memory"))
|
||||
if memory is not None and _migrate_target_item_list(memory, "global_memory_blacklist"):
|
||||
|
||||
@@ -149,10 +149,10 @@ class VisualConfig(ConfigBase):
|
||||
default="auto",
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "git-branch",
|
||||
"x-icon": "image",
|
||||
},
|
||||
)
|
||||
"""规划器模式,auto根据模型信息自动选择,text为纯文本模式,multimodal为多模态模式"""
|
||||
"""Planner 视觉模式:text 仅文本,multimodal 强制多模态,auto 按模型能力自动选择"""
|
||||
|
||||
replyer_mode: Literal["text", "multimodal", "auto"] = Field(
|
||||
default="auto",
|
||||
@@ -161,7 +161,7 @@ class VisualConfig(ConfigBase):
|
||||
"x-icon": "git-branch",
|
||||
},
|
||||
)
|
||||
"""回复器模式,auto根据模型信息自动选择,text为纯文本模式,multimodal为多模态模式"""
|
||||
"""Replyer 视觉模式:text 仅文本,multimodal 强制多模态,auto 按模型能力自动选择"""
|
||||
|
||||
visual_style: str = Field(
|
||||
default="请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本",
|
||||
@@ -239,12 +239,17 @@ class ChatConfig(ConfigBase):
|
||||
)
|
||||
"""Planner 连续被新消息打断的最大次数,0 表示不启用打断"""
|
||||
|
||||
plan_reply_log_max_per_chat: int = Field(
|
||||
default=1024,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "file-text",
|
||||
},
|
||||
)
|
||||
"""每个聊天流最大保存的Plan/Reply日志数量,超过此数量时会自动删除最老的日志"""
|
||||
|
||||
group_chat_prompt: str = Field(
|
||||
default="""
|
||||
你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片。
|
||||
回复尽量简短一些。最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。请注意把握聊天内容。
|
||||
不要回复的太频繁!控制回复的频率,不要每个人的消息都回复,只回复你感兴趣的或者主动提及你的。
|
||||
""",
|
||||
default="你需要控制自己发言的频率,如果是一对一聊天,可以以较均匀的频率发言;如果用户较多,不要每句都回复,控制回复频率,不要回复的太频繁!控制回复的频率,不要每个人的消息都回复。",
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "users",
|
||||
@@ -253,11 +258,7 @@ class ChatConfig(ConfigBase):
|
||||
"""_wrap_群聊通用注意事项"""
|
||||
|
||||
private_chat_prompts: str = Field(
|
||||
default="""
|
||||
你正在聊天,下面是正在聊的内容,其中包含聊天记录和聊天中的图片。
|
||||
回复尽量简短一些。请注意把握聊天内容。
|
||||
请考虑对方的发言频率,想法,思考自己何时回复以及回复内容。
|
||||
""",
|
||||
default="你需要控制自己发言的频率,可以以较均匀的频率发言。",
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "user",
|
||||
@@ -414,6 +415,228 @@ class MemoryConfig(ConfigBase):
|
||||
)
|
||||
"""Maisaka 内置长期记忆检索工具 query_memory 的默认返回条数"""
|
||||
|
||||
person_fact_writeback_enabled: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "user-round-pen",
|
||||
},
|
||||
)
|
||||
"""是否在发送回复后自动提取并写回人物事实到长期记忆"""
|
||||
|
||||
chat_summary_writeback_enabled: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "scroll-text",
|
||||
},
|
||||
)
|
||||
"""是否在 Maisaka 聊天过程中按消息窗口自动写回聊天摘要到长期记忆"""
|
||||
|
||||
chat_summary_writeback_message_threshold: int = Field(
|
||||
default=12,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "messages-square",
|
||||
},
|
||||
)
|
||||
"""自动写回聊天摘要的消息窗口阈值"""
|
||||
|
||||
chat_summary_writeback_context_length: int = Field(
|
||||
default=50,
|
||||
ge=1,
|
||||
le=500,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "rows-3",
|
||||
},
|
||||
)
|
||||
"""自动写回聊天摘要时,从聊天流中回看的消息条数"""
|
||||
|
||||
feedback_correction_enabled: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "message-circle-warning",
|
||||
},
|
||||
)
|
||||
"""是否启用反馈驱动的延迟记忆纠错任务"""
|
||||
|
||||
feedback_correction_window_hours: float = Field(
|
||||
default=12.0,
|
||||
ge=0.1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "clock-4",
|
||||
},
|
||||
)
|
||||
"""反馈窗口时长(小时),以 query_memory 执行时间为起点"""
|
||||
|
||||
feedback_correction_check_interval_minutes: int = Field(
|
||||
default=30,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "timer",
|
||||
},
|
||||
)
|
||||
"""反馈纠错定时任务轮询间隔(分钟)"""
|
||||
|
||||
feedback_correction_batch_size: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
le=200,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "list-ordered",
|
||||
},
|
||||
)
|
||||
"""反馈纠错每轮最大处理任务数"""
|
||||
|
||||
feedback_correction_auto_apply_threshold: float = Field(
|
||||
default=0.85,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
json_schema_extra={
|
||||
"x-widget": "slider",
|
||||
"x-icon": "gauge",
|
||||
"step": 0.01,
|
||||
},
|
||||
)
|
||||
"""自动应用纠错动作的最低置信度阈值"""
|
||||
|
||||
feedback_correction_max_feedback_messages: int = Field(
|
||||
default=30,
|
||||
ge=1,
|
||||
le=200,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "messages-square",
|
||||
},
|
||||
)
|
||||
"""每个纠错任务最多使用的窗口内用户反馈消息数"""
|
||||
|
||||
feedback_correction_prefilter_enabled: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "filter",
|
||||
},
|
||||
)
|
||||
"""是否启用纠错前置预筛(用于减少不必要的模型调用)"""
|
||||
|
||||
feedback_correction_paragraph_mark_enabled: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "sticky-note",
|
||||
},
|
||||
)
|
||||
"""是否为受影响 paragraph 写入已纠正旧事实标记"""
|
||||
|
||||
feedback_correction_paragraph_hard_filter_enabled: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "eye-off",
|
||||
},
|
||||
)
|
||||
"""是否在用户侧查询中硬过滤带有 stale 标记的 paragraph"""
|
||||
|
||||
feedback_correction_profile_refresh_enabled: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "user-round-search",
|
||||
},
|
||||
)
|
||||
"""是否在反馈纠错后将受影响人物画像加入刷新队列"""
|
||||
|
||||
feedback_correction_profile_force_refresh_on_read: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "refresh-ccw",
|
||||
},
|
||||
)
|
||||
"""人物画像处于脏队列时,读取是否强制刷新而不直接复用旧快照"""
|
||||
|
||||
feedback_correction_episode_rebuild_enabled: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "clapperboard",
|
||||
},
|
||||
)
|
||||
"""是否在反馈纠错后将受影响 source 加入 episode 重建队列"""
|
||||
|
||||
feedback_correction_episode_query_block_enabled: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "ban",
|
||||
},
|
||||
)
|
||||
"""episode source 处于重建队列时,是否对用户侧查询做屏蔽"""
|
||||
|
||||
feedback_correction_reconcile_interval_minutes: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "repeat",
|
||||
},
|
||||
)
|
||||
"""反馈纠错二阶段一致性后台协调任务轮询间隔(分钟)"""
|
||||
|
||||
feedback_correction_reconcile_batch_size: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
le=200,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "list-restart",
|
||||
},
|
||||
)
|
||||
"""反馈纠错二阶段一致性每轮处理 profile/episode 队列的批大小"""
|
||||
|
||||
def model_post_init(self, context: Optional[dict] = None) -> None:
|
||||
"""验证配置值"""
|
||||
if self.feedback_correction_window_hours <= 0:
|
||||
raise ValueError(
|
||||
f"feedback_correction_window_hours 必须大于0,当前值: {self.feedback_correction_window_hours}"
|
||||
)
|
||||
if self.feedback_correction_check_interval_minutes < 1:
|
||||
raise ValueError(
|
||||
"feedback_correction_check_interval_minutes 必须至少为1,"
|
||||
f"当前值: {self.feedback_correction_check_interval_minutes}"
|
||||
)
|
||||
if self.feedback_correction_batch_size < 1:
|
||||
raise ValueError(
|
||||
f"feedback_correction_batch_size 必须至少为1,当前值: {self.feedback_correction_batch_size}"
|
||||
)
|
||||
if not 0 <= self.feedback_correction_auto_apply_threshold <= 1:
|
||||
raise ValueError(
|
||||
"feedback_correction_auto_apply_threshold 必须在 [0, 1] 之间,"
|
||||
f"当前值: {self.feedback_correction_auto_apply_threshold}"
|
||||
)
|
||||
if self.feedback_correction_max_feedback_messages < 1:
|
||||
raise ValueError(
|
||||
"feedback_correction_max_feedback_messages 必须至少为1,"
|
||||
f"当前值: {self.feedback_correction_max_feedback_messages}"
|
||||
)
|
||||
if self.feedback_correction_reconcile_interval_minutes < 1:
|
||||
raise ValueError(
|
||||
"feedback_correction_reconcile_interval_minutes 必须至少为1,"
|
||||
f"当前值: {self.feedback_correction_reconcile_interval_minutes}"
|
||||
)
|
||||
if self.feedback_correction_reconcile_batch_size < 1:
|
||||
raise ValueError(
|
||||
"feedback_correction_reconcile_batch_size 必须至少为1,"
|
||||
f"当前值: {self.feedback_correction_reconcile_batch_size}"
|
||||
)
|
||||
return super().model_post_init(context)
|
||||
|
||||
|
||||
class LearningItem(ConfigBase):
|
||||
@@ -471,15 +694,6 @@ class LearningItem(ConfigBase):
|
||||
)
|
||||
"""是否启用jargon学习"""
|
||||
|
||||
advanced_chosen: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "sparkles",
|
||||
},
|
||||
)
|
||||
"""是否启用基于子代理的二次表达方式选择"""
|
||||
|
||||
|
||||
class ExpressionGroup(ConfigBase):
|
||||
"""表达互通组配置类,若列表为空代表全局共享"""
|
||||
@@ -509,7 +723,6 @@ class ExpressionConfig(ConfigBase):
|
||||
use_expression=True,
|
||||
enable_learning=True,
|
||||
enable_jargon_learning=True,
|
||||
advanced_chosen=False,
|
||||
)
|
||||
],
|
||||
json_schema_extra={
|
||||
@@ -1381,6 +1594,35 @@ class MaiSakaConfig(ConfigBase):
|
||||
)
|
||||
"""MaiSaka 使用的用户名称"""
|
||||
|
||||
tool_filter_task_name: str = Field(
|
||||
default="utils",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "sparkles",
|
||||
},
|
||||
)
|
||||
"""工具筛选预判使用的模型任务名"""
|
||||
|
||||
tool_filter_threshold: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "filter",
|
||||
},
|
||||
)
|
||||
"""当可用工具总数超过该阈值时,先进行一轮工具筛选"""
|
||||
|
||||
tool_filter_max_keep: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "list-filter",
|
||||
},
|
||||
)
|
||||
"""工具筛选阶段最多保留的非内置工具数量"""
|
||||
|
||||
show_image_path: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
|
||||
@@ -19,6 +19,7 @@ from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvo
|
||||
from src.llm_models.exceptions import ReqAbortException
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.services import database_service as database_api
|
||||
from src.services.memory_service import memory_service
|
||||
|
||||
from .builtin_tool import get_action_tool_specs
|
||||
from .builtin_tool import build_builtin_tool_handlers as build_split_builtin_tool_handlers
|
||||
@@ -1123,12 +1124,13 @@ class MaisakaReasoningEngine:
|
||||
builtin_prompt = tool_spec.build_llm_description()
|
||||
|
||||
try:
|
||||
await database_api.store_tool_info(
|
||||
tool_record_payload = self._build_tool_record_payload(invocation, result, tool_spec)
|
||||
saved_record = await database_api.store_tool_info(
|
||||
chat_stream=self._runtime.chat_stream,
|
||||
builtin_prompt=builtin_prompt,
|
||||
display_prompt=self._build_tool_display_prompt(invocation, result, tool_spec),
|
||||
tool_id=invocation.call_id,
|
||||
tool_data=self._build_tool_record_payload(invocation, result, tool_spec),
|
||||
tool_data=tool_record_payload,
|
||||
tool_name=invocation.tool_name,
|
||||
tool_reasoning=invocation.reasoning,
|
||||
)
|
||||
@@ -1136,6 +1138,28 @@ class MaisakaReasoningEngine:
|
||||
logger.exception(
|
||||
f"{self._runtime.log_prefix} 写入工具记录失败: 工具={invocation.tool_name} 调用编号={invocation.call_id}"
|
||||
)
|
||||
return
|
||||
|
||||
if invocation.tool_name == "query_memory" and isinstance(saved_record, dict):
|
||||
try:
|
||||
enqueue_payload = await memory_service.enqueue_feedback_task(
|
||||
query_tool_id=str(saved_record.get("tool_id") or invocation.call_id or "").strip(),
|
||||
session_id=str(saved_record.get("session_id") or self._runtime.chat_stream.session_id or "").strip(),
|
||||
query_timestamp=saved_record.get("timestamp"),
|
||||
structured_content=tool_record_payload.get("structured_content")
|
||||
if isinstance(tool_record_payload.get("structured_content"), dict)
|
||||
else {},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"{self._runtime.log_prefix} 反馈纠错任务入队失败: tool_call_id={invocation.call_id}"
|
||||
)
|
||||
else:
|
||||
if not bool(enqueue_payload.get("success")):
|
||||
logger.debug(
|
||||
f"{self._runtime.log_prefix} 反馈纠错任务未入队: "
|
||||
f"tool_call_id={invocation.call_id} reason={enqueue_payload.get('reason', '')}"
|
||||
)
|
||||
|
||||
def _append_tool_execution_result(self, tool_call: ToolCall, result: ToolExecutionResult) -> None:
|
||||
"""将统一工具执行结果写回 Maisaka 历史。
|
||||
@@ -1316,4 +1340,3 @@ class MaisakaReasoningEngine:
|
||||
return True, tool_result_summaries, tool_monitor_results
|
||||
|
||||
return False, tool_result_summaries, tool_monitor_results
|
||||
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, List, Optional
|
||||
import pickle
|
||||
import time
|
||||
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.services import memory_service as memory_service_module
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages
|
||||
from src.common.message_repository import count_messages, find_messages
|
||||
from src.config.config import global_config
|
||||
from src.person_info.person_info import Person, get_person_id, store_person_memory_from_answer
|
||||
from src.services.memory_service import memory_service
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("memory_flow_service")
|
||||
@@ -210,27 +217,260 @@ class PersonFactWritebackService:
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatSummaryWritebackState:
|
||||
last_trigger_message_count: int = 0
|
||||
last_trigger_time: float = 0.0
|
||||
|
||||
|
||||
class ChatSummaryWritebackService:
|
||||
def __init__(self) -> None:
|
||||
self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256)
|
||||
self._worker_task: Optional[asyncio.Task] = None
|
||||
self._stopping = False
|
||||
self._states: dict[str, ChatSummaryWritebackState] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._worker_task is not None and not self._worker_task.done():
|
||||
return
|
||||
self._stopping = False
|
||||
self._worker_task = asyncio.create_task(self._worker_loop(), name="memory_chat_summary_writeback")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self._stopping = True
|
||||
worker = self._worker_task
|
||||
self._worker_task = None
|
||||
if worker is None:
|
||||
return
|
||||
worker.cancel()
|
||||
try:
|
||||
await worker
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("关闭聊天摘要写回 worker 失败: %s", exc)
|
||||
|
||||
async def enqueue(self, message: Any) -> None:
|
||||
if not bool(getattr(global_config.memory, "chat_summary_writeback_enabled", True)):
|
||||
return
|
||||
if self._stopping:
|
||||
return
|
||||
try:
|
||||
self._queue.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning("聊天摘要写回队列已满,跳过本次触发")
|
||||
|
||||
async def _worker_loop(self) -> None:
|
||||
try:
|
||||
while not self._stopping:
|
||||
message = await self._queue.get()
|
||||
try:
|
||||
await self._handle_message(message)
|
||||
except Exception as exc:
|
||||
logger.warning("聊天摘要写回处理失败: %s", exc, exc_info=True)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
async def _handle_message(self, message: Any) -> None:
|
||||
session_id = self._resolve_session_id(message)
|
||||
if not session_id:
|
||||
return
|
||||
|
||||
total_message_count = count_messages(session_id=session_id)
|
||||
if total_message_count <= 0:
|
||||
return
|
||||
|
||||
threshold = self._message_threshold()
|
||||
state = self._states.get(session_id)
|
||||
if state is None:
|
||||
restored_count = await self._load_last_trigger_message_count(
|
||||
session_id=session_id,
|
||||
total_message_count=total_message_count,
|
||||
)
|
||||
state = ChatSummaryWritebackState(
|
||||
last_trigger_message_count=restored_count,
|
||||
last_trigger_time=time.time() if restored_count > 0 else 0.0,
|
||||
)
|
||||
self._states[session_id] = state
|
||||
pending_message_count = max(0, total_message_count - state.last_trigger_message_count)
|
||||
if pending_message_count < threshold:
|
||||
return
|
||||
|
||||
context_length = self._context_length()
|
||||
message_time = self._extract_message_timestamp(message)
|
||||
result = await memory_service.ingest_summary(
|
||||
external_id=f"chat_auto_summary:{session_id}:{total_message_count}",
|
||||
chat_id=session_id,
|
||||
text="",
|
||||
participants=[],
|
||||
time_end=message_time,
|
||||
metadata={
|
||||
"generate_from_chat": True,
|
||||
"context_length": context_length,
|
||||
"writeback_source": "memory_flow_service",
|
||||
"trigger": "message_threshold",
|
||||
"trigger_message_count": total_message_count,
|
||||
},
|
||||
respect_filter=True,
|
||||
user_id=self._extract_session_user_id(message),
|
||||
group_id=self._extract_session_group_id(message),
|
||||
)
|
||||
if not getattr(result, "success", False):
|
||||
logger.warning(
|
||||
"聊天摘要自动写回失败: session_id=%s detail=%s",
|
||||
session_id,
|
||||
getattr(result, "detail", ""),
|
||||
)
|
||||
return
|
||||
|
||||
state.last_trigger_message_count = total_message_count
|
||||
state.last_trigger_time = time.time()
|
||||
logger.info(
|
||||
"聊天摘要自动写回成功: session_id=%s trigger=%s total_messages=%s context_length=%s detail=%s",
|
||||
session_id,
|
||||
"message_threshold",
|
||||
total_message_count,
|
||||
context_length,
|
||||
getattr(result, "detail", ""),
|
||||
)
|
||||
|
||||
async def _load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
|
||||
"""从已落库的聊天摘要恢复触发游标,避免服务重启后重复摘要。"""
|
||||
try:
|
||||
runtime_manager = getattr(memory_service_module, "a_memorix_host_service", None)
|
||||
ensure_kernel = getattr(runtime_manager, "_ensure_kernel", None)
|
||||
if not callable(ensure_kernel):
|
||||
return 0
|
||||
|
||||
kernel = await ensure_kernel()
|
||||
metadata_store = getattr(kernel, "metadata_store", None)
|
||||
if metadata_store is None:
|
||||
return 0
|
||||
|
||||
paragraphs = metadata_store.get_paragraphs_by_source(f"chat_summary:{session_id}")
|
||||
if not paragraphs:
|
||||
return 0
|
||||
|
||||
latest_paragraph = max(paragraphs, key=self._paragraph_created_at)
|
||||
metadata = self._paragraph_metadata(latest_paragraph)
|
||||
trigger_message_count = self._coerce_positive_int(metadata.get("trigger_message_count"))
|
||||
if trigger_message_count > 0:
|
||||
return min(total_message_count, trigger_message_count)
|
||||
|
||||
# 兼容旧摘要数据:没有触发计数时,只能退化为对齐当前计数,
|
||||
# 至少避免重启后立刻重复写入一条相近摘要。
|
||||
return total_message_count
|
||||
except Exception as exc:
|
||||
logger.debug("恢复聊天摘要写回游标失败: session_id=%s error=%s", session_id, exc)
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def _paragraph_created_at(paragraph: dict[str, Any]) -> float:
|
||||
try:
|
||||
return float(paragraph.get("created_at") or 0.0)
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
@staticmethod
|
||||
def _paragraph_metadata(paragraph: dict[str, Any]) -> dict[str, Any]:
|
||||
metadata = paragraph.get("metadata")
|
||||
if isinstance(metadata, dict):
|
||||
return metadata
|
||||
if isinstance(metadata, (bytes, bytearray)):
|
||||
try:
|
||||
parsed = pickle.loads(metadata)
|
||||
except Exception:
|
||||
return {}
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _coerce_positive_int(value: Any) -> int:
|
||||
try:
|
||||
number = int(value or 0)
|
||||
except Exception:
|
||||
return 0
|
||||
return max(0, number)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_session_id(message: Any) -> str:
|
||||
return str(
|
||||
getattr(message, "session_id", "")
|
||||
or getattr(getattr(message, "session", None), "session_id", "")
|
||||
or ""
|
||||
).strip()
|
||||
|
||||
@staticmethod
|
||||
def _extract_session_user_id(message: Any) -> str:
|
||||
return str(
|
||||
getattr(getattr(message, "session", None), "user_id", "")
|
||||
or getattr(message, "user_id", "")
|
||||
or ""
|
||||
).strip()
|
||||
|
||||
@staticmethod
|
||||
def _extract_session_group_id(message: Any) -> str:
|
||||
return str(
|
||||
getattr(getattr(message, "session", None), "group_id", "")
|
||||
or getattr(message, "group_id", "")
|
||||
or ""
|
||||
).strip()
|
||||
|
||||
@staticmethod
|
||||
def _extract_message_timestamp(message: Any) -> float | None:
|
||||
raw_timestamp = getattr(message, "timestamp", None)
|
||||
if isinstance(raw_timestamp, datetime):
|
||||
return raw_timestamp.timestamp()
|
||||
if hasattr(raw_timestamp, "timestamp") and callable(raw_timestamp.timestamp):
|
||||
try:
|
||||
return float(raw_timestamp.timestamp())
|
||||
except Exception:
|
||||
return None
|
||||
if isinstance(raw_timestamp, (int, float)):
|
||||
return float(raw_timestamp)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _message_threshold() -> int:
|
||||
return max(1, int(getattr(global_config.memory, "chat_summary_writeback_message_threshold", 12) or 12))
|
||||
|
||||
@staticmethod
|
||||
def _context_length() -> int:
|
||||
return max(1, int(getattr(global_config.memory, "chat_summary_writeback_context_length", 50) or 50))
|
||||
|
||||
|
||||
class MemoryAutomationService:
|
||||
def __init__(self) -> None:
|
||||
self.fact_writeback = PersonFactWritebackService()
|
||||
self.chat_summary_writeback = ChatSummaryWritebackService()
|
||||
self._started = False
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._started:
|
||||
return
|
||||
await self.fact_writeback.start()
|
||||
await self.chat_summary_writeback.start()
|
||||
self._started = True
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if not self._started:
|
||||
return
|
||||
await self.chat_summary_writeback.shutdown()
|
||||
await self.fact_writeback.shutdown()
|
||||
self._started = False
|
||||
|
||||
async def on_incoming_message(self, message: Any) -> None:
|
||||
del message
|
||||
if not self._started:
|
||||
await self.start()
|
||||
|
||||
async def on_message_sent(self, message: Any) -> None:
|
||||
if not self._started:
|
||||
await self.start()
|
||||
await self.fact_writeback.enqueue(message)
|
||||
await self.chat_summary_writeback.enqueue(message)
|
||||
|
||||
|
||||
memory_automation_service = MemoryAutomationService()
|
||||
|
||||
@@ -233,6 +233,30 @@ class MemoryService:
|
||||
logger.warning("长期记忆搜索失败: %s", exc)
|
||||
return MemorySearchResult(success=False, error=str(exc))
|
||||
|
||||
async def enqueue_feedback_task(
|
||||
self,
|
||||
*,
|
||||
query_tool_id: str,
|
||||
session_id: str,
|
||||
query_timestamp: Any = None,
|
||||
structured_content: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
payload = await self._invoke(
|
||||
"enqueue_feedback_task",
|
||||
{
|
||||
"query_tool_id": str(query_tool_id or "").strip(),
|
||||
"session_id": str(session_id or "").strip(),
|
||||
"query_timestamp": query_timestamp,
|
||||
"structured_content": structured_content if isinstance(structured_content, dict) else {},
|
||||
},
|
||||
timeout_ms=10000,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("反馈纠错任务入队失败: %s", exc)
|
||||
return {"success": False, "queued": False, "reason": str(exc)}
|
||||
return payload if isinstance(payload, dict) else {"success": False, "queued": False, "reason": "invalid_payload"}
|
||||
|
||||
async def ingest_summary(
|
||||
self,
|
||||
*,
|
||||
@@ -388,6 +412,13 @@ class MemoryService:
|
||||
logger.warning("画像管理调用失败: %s", exc)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def feedback_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
|
||||
try:
|
||||
return await self._invoke_admin("memory_feedback_admin", action=action, **kwargs)
|
||||
except Exception as exc:
|
||||
logger.warning("反馈纠错管理调用失败: %s", exc)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def runtime_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
|
||||
try:
|
||||
return await self._invoke_admin("memory_runtime_admin", action=action, **kwargs)
|
||||
|
||||
@@ -205,6 +205,12 @@ def _setup_static_files(app: FastAPI):
|
||||
|
||||
|
||||
def _resolve_static_path() -> Path | None:
|
||||
# 开发环境优先允许复用仓库里的现成 dist
|
||||
base_dir = _get_project_root()
|
||||
static_path = base_dir / "dashboard" / "dist"
|
||||
if static_path.is_dir() and (static_path / "index.html").exists():
|
||||
return static_path
|
||||
|
||||
try:
|
||||
module = import_module("maibot_dashboard")
|
||||
get_dist_path = getattr(module, "get_dist_path", None)
|
||||
@@ -215,11 +221,6 @@ def _resolve_static_path() -> Path | None:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 开发环境允许复用仓库里的现成 dist,但不再在用户机器上触发任何前端自愈构建。
|
||||
base_dir = _get_project_root()
|
||||
static_path = base_dir / "dashboard" / "dist"
|
||||
if static_path.exists():
|
||||
return static_path
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -124,6 +124,11 @@ class DeletePurgeRequest(BaseModel):
|
||||
limit: int = Field(1000, ge=1, le=5000)
|
||||
|
||||
|
||||
class FeedbackRollbackRequest(BaseModel):
|
||||
requested_by: str = "webui"
|
||||
reason: str = ""
|
||||
|
||||
|
||||
def _build_import_guide_markdown(settings: dict[str, Any]) -> str:
|
||||
path_aliases_raw = settings.get("path_aliases")
|
||||
path_aliases = path_aliases_raw if isinstance(path_aliases_raw, dict) else {}
|
||||
@@ -359,6 +364,31 @@ async def _profile_delete_override(person_id: str) -> dict:
|
||||
return await memory_service.profile_admin(action="delete_override", person_id=person_id)
|
||||
|
||||
|
||||
async def _feedback_list(limit: int, status: str, rollback_status: str, query: str) -> dict:
|
||||
statuses = [item.strip() for item in str(status or "").split(",") if item.strip()]
|
||||
rollback_statuses = [item.strip() for item in str(rollback_status or "").split(",") if item.strip()]
|
||||
return await memory_service.feedback_admin(
|
||||
action="list",
|
||||
limit=limit,
|
||||
statuses=statuses,
|
||||
rollback_statuses=rollback_statuses,
|
||||
query=query,
|
||||
)
|
||||
|
||||
|
||||
async def _feedback_get(task_id: int) -> dict:
|
||||
return await memory_service.feedback_admin(action="get", task_id=task_id)
|
||||
|
||||
|
||||
async def _feedback_rollback(task_id: int, payload: FeedbackRollbackRequest) -> dict:
|
||||
return await memory_service.feedback_admin(
|
||||
action="rollback",
|
||||
task_id=task_id,
|
||||
requested_by=payload.requested_by,
|
||||
reason=payload.reason,
|
||||
)
|
||||
|
||||
|
||||
async def _runtime_save() -> dict:
|
||||
return await memory_service.runtime_admin(action="save")
|
||||
|
||||
@@ -830,6 +860,26 @@ async def delete_memory_profile_override(person_id: str):
|
||||
return await _profile_delete_override(person_id)
|
||||
|
||||
|
||||
@router.get("/feedback-corrections")
|
||||
async def list_memory_feedback_corrections(
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
status: str = Query(""),
|
||||
rollback_status: str = Query(""),
|
||||
query: str = Query(""),
|
||||
):
|
||||
return await _feedback_list(limit, status, rollback_status, query)
|
||||
|
||||
|
||||
@router.get("/feedback-corrections/{task_id}")
|
||||
async def get_memory_feedback_correction(task_id: int):
|
||||
return await _feedback_get(task_id)
|
||||
|
||||
|
||||
@router.post("/feedback-corrections/{task_id}/rollback")
|
||||
async def rollback_memory_feedback_correction(task_id: int, payload: FeedbackRollbackRequest):
|
||||
return await _feedback_rollback(task_id, payload)
|
||||
|
||||
|
||||
@router.post("/runtime/save")
|
||||
async def save_memory_runtime():
|
||||
return await _runtime_save()
|
||||
|
||||
Reference in New Issue
Block a user