Merge pull request #1587 from A-Dawn/r-dev

feat:记忆系统新增反馈学习功能&修复聊天内容导入问题
This commit is contained in:
Dawn ARC
2026-04-16 20:32:57 +08:00
committed by GitHub
37 changed files with 7156 additions and 146 deletions

1
.gitignore vendored
View File

@@ -371,3 +371,4 @@ packages/
.claude/ .claude/
.omc/ .omc/
/.venv312 /.venv312
/src/A_memorix/algorithm_redesign

View File

@@ -496,6 +496,77 @@ export interface MemoryDeleteOperationDetailPayload {
error?: string error?: string
} }
/**
 * Per-category counters carried by a feedback-correction task in its
 * `affected_counts` field.
 *
 * Every counter is optional, so consumers must handle a missing value.
 * NOTE(review): the exact meaning of each category is defined by the backend — confirm there.
 */
export interface MemoryFeedbackAffectedCountsPayload {
relations?: number
stale_paragraphs?: number
episode_sources?: number
profile_person_ids?: number
correction_paragraphs?: number
corrected_relations?: number
}
/**
 * One entry of a feedback-correction task's `action_logs` list.
 *
 * `before_payload` / `after_payload` are untyped key/value objects and may be absent.
 */
export interface MemoryFeedbackActionLogPayload {
id: number
task_id: number
query_tool_id: string
action_type: string
target_hash: string
reason?: string
before_payload?: Record<string, unknown>
after_payload?: Record<string, unknown>
// presumably a Unix timestamp — confirm the unit against the backend
created_at?: number
}
/**
 * List-level summary of a single feedback-correction task; this is the element
 * type of `MemoryFeedbackCorrectionListPayload.items`.
 *
 * NOTE(review): `task_status`, `decision` and `rollback_status` are plain strings here;
 * the set of valid values is owned by the backend — confirm before narrowing to unions.
 */
export interface MemoryFeedbackCorrectionSummaryPayload {
task_id: number
query_tool_id: string
session_id: string
query_text: string
// timestamps below are presumably Unix timestamps — confirm the unit against the backend
query_timestamp?: number
task_status: string
decision: string
decision_confidence: number
feedback_message_count: number
rollback_status: string
affected_counts: MemoryFeedbackAffectedCountsPayload
created_at?: number
updated_at?: number
}
/**
 * Full view of a feedback-correction task: every summary field plus optional
 * snapshot, decision, rollback bookkeeping and action-log data.
 *
 * All fields added here are optional, so a summary object is structurally
 * assignable to this type.
 */
export interface MemoryFeedbackCorrectionDetailTaskPayload extends MemoryFeedbackCorrectionSummaryPayload {
query_snapshot?: Record<string, unknown>
decision_payload?: Record<string, unknown>
rollback_plan_summary?: Record<string, unknown>
rollback_result?: Record<string, unknown>
rollback_error?: string
rollback_requested_by?: string
rollback_reason?: string
// presumably Unix timestamps — confirm the unit against the backend
rollback_requested_at?: number
rolled_back_at?: number
action_logs?: MemoryFeedbackActionLogPayload[]
}
/**
 * Response shape of `getMemoryFeedbackCorrections`: a success flag, the task
 * summaries, and optionally a count and an error message.
 */
export interface MemoryFeedbackCorrectionListPayload {
success: boolean
items: MemoryFeedbackCorrectionSummaryPayload[]
count?: number
error?: string
}
/**
 * Response shape of `getMemoryFeedbackCorrection`. `task` may be missing or
 * `null`, so callers need to check it in addition to `success`.
 */
export interface MemoryFeedbackCorrectionDetailPayload {
success: boolean
task?: MemoryFeedbackCorrectionDetailTaskPayload | null
error?: string
}
/**
 * Response shape of `rollbackMemoryFeedbackCorrection`.
 *
 * Besides the success flag it can carry an `already_rolled_back` marker, an
 * untyped `result` object, the task (possibly `null`) and an error message.
 */
export interface MemoryFeedbackCorrectionRollbackPayload {
success: boolean
already_rolled_back?: boolean
result?: Record<string, unknown>
task?: MemoryFeedbackCorrectionDetailTaskPayload | null
error?: string
}
export interface MemorySourceItemPayload { export interface MemorySourceItemPayload {
source: string source: string
paragraph_count?: number paragraph_count?: number
@@ -610,6 +681,49 @@ export async function getMemoryDeleteOperation(
return requestJson<MemoryDeleteOperationDetailPayload>(`/delete/operations/${encodeURIComponent(operationId)}`) return requestJson<MemoryDeleteOperationDetailPayload>(`/delete/operations/${encodeURIComponent(operationId)}`)
} }
/**
 * Requests the feedback-correction task list from `/feedback-corrections`.
 *
 * @param options Optional filters. `limit` defaults to 50 when omitted;
 *   `status`, `rollbackStatus` and `query` are trimmed and only added to the
 *   query string when the trimmed value is non-empty.
 * @returns The list payload produced by `requestJson`.
 */
export async function getMemoryFeedbackCorrections(
  options?: {
    limit?: number
    status?: string
    rollbackStatus?: string
    query?: string
  },
): Promise<MemoryFeedbackCorrectionListPayload> {
  const search = new URLSearchParams()
  search.set('limit', String(options?.limit ?? 50))

  // Query-string key paired with the raw option value, in the order they are appended.
  const optionalFilters: Array<[string, string | undefined]> = [
    ['status', options?.status],
    ['rollback_status', options?.rollbackStatus],
    ['query', options?.query],
  ]
  for (const [key, raw] of optionalFilters) {
    const trimmed = raw?.trim()
    if (trimmed) {
      search.set(key, trimmed)
    }
  }

  return requestJson<MemoryFeedbackCorrectionListPayload>(`/feedback-corrections?${search.toString()}`)
}
/**
 * Requests the detail of one feedback-correction task.
 *
 * @param taskId Numeric task id, interpolated into the request path.
 * @returns The detail payload produced by `requestJson`.
 */
export async function getMemoryFeedbackCorrection(
  taskId: number,
): Promise<MemoryFeedbackCorrectionDetailPayload> {
  const detailPath = `/feedback-corrections/${taskId}`
  return requestJson<MemoryFeedbackCorrectionDetailPayload>(detailPath)
}
/**
 * Asks the backend to roll back one feedback-correction task.
 *
 * @param taskId Numeric task id, interpolated into the request path.
 * @param payload Optional `requested_by` / `reason`, sent as the JSON body.
 * @returns The rollback payload produced by `requestJson`.
 */
export async function rollbackMemoryFeedbackCorrection(
  taskId: number,
  payload: {
    requested_by?: string
    reason?: string
  },
): Promise<MemoryFeedbackCorrectionRollbackPayload> {
  const rollbackPath = `/feedback-corrections/${taskId}/rollback`
  const serializedBody = JSON.stringify(payload)
  return requestJson<MemoryFeedbackCorrectionRollbackPayload>(rollbackPath, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: serializedBody,
  })
}
export async function getMemorySources(): Promise<MemorySourceListPayload> { export async function getMemorySources(): Promise<MemorySourceListPayload> {
return requestJson<MemorySourceListPayload>('/sources') return requestJson<MemorySourceListPayload>('/sources')
} }

View File

@@ -81,9 +81,12 @@ vi.mock('@/lib/memory-api', () => ({
getMemorySources: vi.fn(), getMemorySources: vi.fn(),
getMemoryDeleteOperations: vi.fn(), getMemoryDeleteOperations: vi.fn(),
getMemoryDeleteOperation: vi.fn(), getMemoryDeleteOperation: vi.fn(),
getMemoryFeedbackCorrections: vi.fn(),
getMemoryFeedbackCorrection: vi.fn(),
previewMemoryDelete: vi.fn(), previewMemoryDelete: vi.fn(),
executeMemoryDelete: vi.fn(), executeMemoryDelete: vi.fn(),
restoreMemoryDelete: vi.fn(), restoreMemoryDelete: vi.fn(),
rollbackMemoryFeedbackCorrection: vi.fn(),
})) }))
function mockImportTask(taskId: string, status: string = 'running'): memoryApi.MemoryImportTaskPayload { function mockImportTask(taskId: string, status: string = 'running'): memoryApi.MemoryImportTaskPayload {
@@ -357,6 +360,82 @@ describe('KnowledgeBasePage import workflow', () => {
items: [], items: [],
}, },
}) })
vi.mocked(memoryApi.getMemoryFeedbackCorrections).mockResolvedValue({
success: true,
items: [
{
task_id: 11,
query_tool_id: 'tool-query-11',
session_id: 'session-1',
query_text: '测试用户最喜欢的颜色是什么',
query_timestamp: 1_710_000_010,
task_status: 'applied',
decision: 'correct',
decision_confidence: 0.97,
feedback_message_count: 1,
rollback_status: 'none',
affected_counts: {
relations: 1,
stale_paragraphs: 1,
episode_sources: 2,
profile_person_ids: 1,
correction_paragraphs: 1,
corrected_relations: 1,
},
created_at: 1_710_000_011,
updated_at: 1_710_000_012,
},
],
count: 1,
})
vi.mocked(memoryApi.getMemoryFeedbackCorrection).mockResolvedValue({
success: true,
task: {
task_id: 11,
query_tool_id: 'tool-query-11',
session_id: 'session-1',
query_text: '测试用户最喜欢的颜色是什么',
query_timestamp: 1_710_000_010,
task_status: 'applied',
decision: 'correct',
decision_confidence: 0.97,
feedback_message_count: 1,
rollback_status: 'none',
affected_counts: {
relations: 1,
stale_paragraphs: 1,
episode_sources: 2,
profile_person_ids: 1,
correction_paragraphs: 1,
corrected_relations: 1,
},
query_snapshot: { query: '测试用户最喜欢的颜色是什么', hits: [{ hash: 'paragraph-1' }] },
decision_payload: { decision: 'correct', confidence: 0.97 },
rollback_plan_summary: {
forgotten_relations: [{ hash: 'rel-old', subject: '测试用户', predicate: '最喜欢的颜色是', object: '蓝色' }],
corrected_write: {
paragraph_hashes: ['paragraph-new'],
corrected_relations: [{ hash: 'rel-new', subject: '测试用户', predicate: '最喜欢的颜色是', object: '绿色' }],
},
},
rollback_result: {},
action_logs: [
{
id: 1,
task_id: 11,
query_tool_id: 'tool-query-11',
action_type: 'forget_relation',
target_hash: 'rel-old',
reason: '用户明确纠正为绿色',
before_payload: { hash: 'rel-old', subject: '测试用户', predicate: '最喜欢的颜色是', object: '蓝色' },
after_payload: { is_inactive: true },
created_at: 1_710_000_013,
},
],
created_at: 1_710_000_011,
updated_at: 1_710_000_012,
},
})
vi.mocked(memoryApi.previewMemoryDelete).mockResolvedValue({ vi.mocked(memoryApi.previewMemoryDelete).mockResolvedValue({
success: true, success: true,
mode: 'source', mode: 'source',
@@ -380,6 +459,37 @@ describe('KnowledgeBasePage import workflow', () => {
deleted_source_count: 1, deleted_source_count: 1,
} as never) } as never)
vi.mocked(memoryApi.restoreMemoryDelete).mockResolvedValue({ success: true } as never) vi.mocked(memoryApi.restoreMemoryDelete).mockResolvedValue({ success: true } as never)
vi.mocked(memoryApi.rollbackMemoryFeedbackCorrection).mockResolvedValue({
success: true,
result: { restored_relation_hashes: ['rel-old'] },
task: {
task_id: 11,
query_tool_id: 'tool-query-11',
session_id: 'session-1',
query_text: '测试用户最喜欢的颜色是什么',
query_timestamp: 1_710_000_010,
task_status: 'applied',
decision: 'correct',
decision_confidence: 0.97,
feedback_message_count: 1,
rollback_status: 'rolled_back',
affected_counts: {
relations: 1,
stale_paragraphs: 1,
episode_sources: 2,
profile_person_ids: 1,
correction_paragraphs: 1,
corrected_relations: 1,
},
query_snapshot: { query: '测试用户最喜欢的颜色是什么', hits: [{ hash: 'paragraph-1' }] },
decision_payload: { decision: 'correct', confidence: 0.97 },
rollback_plan_summary: {},
rollback_result: { restored_relation_hashes: ['rel-old'] },
action_logs: [],
created_at: 1_710_000_011,
updated_at: 1_710_000_012,
},
})
vi.mocked(memoryApi.refreshMemoryRuntimeSelfCheck).mockResolvedValue({ vi.mocked(memoryApi.refreshMemoryRuntimeSelfCheck).mockResolvedValue({
success: true, success: true,
report: { ok: true }, report: { ok: true },
@@ -619,4 +729,27 @@ describe('KnowledgeBasePage import workflow', () => {
}), }),
) )
}, 20_000) }, 20_000)
// End-to-end UI check for the feedback-correction tab: the history list renders,
// the detail of the listed task is requested, and confirming the rollback dialog
// calls the rollback API with the typed reason.
it('shows feedback correction history and supports rollback', async () => {
const user = userEvent.setup()
render(<KnowledgeBasePage />)
// Wait for the page heading before interacting (up to 10s).
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
// Switch to the correction-history tab and wait for its content.
await user.click(screen.getByRole('tab', { name: '纠错历史' }))
await screen.findByText('反馈纠错历史')
await screen.findByText('测试用户最喜欢的颜色是什么')
// The detail API must have been called with task id 11.
await waitFor(() => expect(memoryApi.getMemoryFeedbackCorrection).toHaveBeenCalledWith(11))
// Open the rollback dialog, type a reason, then confirm.
await user.click(screen.getByRole('button', { name: '回退本次纠错' }))
const rollbackReason = await screen.findByLabelText('回退原因')
await user.type(rollbackReason, '人工确认回退')
await user.click(screen.getByRole('button', { name: '确认回退' }))
// The page is expected to send 'knowledge_base' as requester along with the typed reason.
await waitFor(() =>
expect(memoryApi.rollbackMemoryFeedbackCorrection).toHaveBeenCalledWith(11, {
requested_by: 'knowledge_base',
reason: '人工确认回退',
}),
)
}, 20_000)
}) })

View File

@@ -24,6 +24,14 @@ import { Badge } from '@/components/ui/badge'
import { Button } from '@/components/ui/button' import { Button } from '@/components/ui/button'
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'
import { Checkbox } from '@/components/ui/checkbox' import { Checkbox } from '@/components/ui/checkbox'
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from '@/components/ui/dialog'
import { Input } from '@/components/ui/input' import { Input } from '@/components/ui/input'
import { Label } from '@/components/ui/label' import { Label } from '@/components/ui/label'
import { Progress } from '@/components/ui/progress' import { Progress } from '@/components/ui/progress'
@@ -48,6 +56,8 @@ import {
createMemoryRawScanImport, createMemoryRawScanImport,
createMemoryTemporalBackfillImport, createMemoryTemporalBackfillImport,
executeMemoryDelete, executeMemoryDelete,
getMemoryFeedbackCorrection,
getMemoryFeedbackCorrections,
getMemoryImportPathAliases, getMemoryImportPathAliases,
getMemoryImportSettings, getMemoryImportSettings,
getMemoryImportTask, getMemoryImportTask,
@@ -74,6 +84,7 @@ import {
type MemoryImportTaskPayload, type MemoryImportTaskPayload,
previewMemoryDelete, previewMemoryDelete,
refreshMemoryRuntimeSelfCheck, refreshMemoryRuntimeSelfCheck,
rollbackMemoryFeedbackCorrection,
resolveMemoryImportPath, resolveMemoryImportPath,
retryMemoryImportTask, retryMemoryImportTask,
restoreMemoryDelete, restoreMemoryDelete,
@@ -82,6 +93,9 @@ import {
type MemoryConfigSchemaPayload, type MemoryConfigSchemaPayload,
type MemoryDeleteExecutePayload, type MemoryDeleteExecutePayload,
type MemoryDeleteOperationPayload, type MemoryDeleteOperationPayload,
type MemoryFeedbackActionLogPayload,
type MemoryFeedbackCorrectionDetailTaskPayload,
type MemoryFeedbackCorrectionSummaryPayload,
type MemorySourceItemPayload, type MemorySourceItemPayload,
type MemoryRuntimeConfigPayload, type MemoryRuntimeConfigPayload,
type MemoryTaskPayload, type MemoryTaskPayload,
@@ -90,6 +104,9 @@ import {
const DELETE_OPERATION_FETCH_LIMIT = 100 const DELETE_OPERATION_FETCH_LIMIT = 100
const DELETE_OPERATION_PAGE_SIZE = 6 const DELETE_OPERATION_PAGE_SIZE = 6
const DELETE_OPERATION_ITEM_PAGE_SIZE = 8 const DELETE_OPERATION_ITEM_PAGE_SIZE = 8
const FEEDBACK_CORRECTION_FETCH_LIMIT = 100
const FEEDBACK_CORRECTION_PAGE_SIZE = 6
const FEEDBACK_ACTION_LOG_PAGE_SIZE = 8
const IMPORT_CHUNK_PAGE_SIZE = 50 const IMPORT_CHUNK_PAGE_SIZE = 50
const RUNNING_IMPORT_STATUS = new Set(['preparing', 'running', 'cancel_requested']) const RUNNING_IMPORT_STATUS = new Set(['preparing', 'running', 'cancel_requested'])
@@ -270,6 +287,90 @@ function formatDeleteOperationTime(timestamp?: number | null): string {
}) })
} }
/**
 * Maps a feedback decision code to its display label.
 *
 * Unrecognised codes are returned unchanged; an empty code yields '未知'.
 */
function formatFeedbackDecision(decision: string): string {
  const decisionLabels = new Map<string, string>([
    ['correct', '纠正'],
    ['reject', '否定'],
    ['confirm', '确认'],
    ['supplement', '补充'],
    ['none', '无动作'],
  ])
  const label = decisionLabels.get(decision)
  if (label !== undefined) {
    return label
  }
  return decision || '未知'
}
/**
 * Maps a feedback task status code to its display label.
 *
 * Unrecognised codes are returned unchanged; an empty code yields '未知'.
 */
function formatFeedbackTaskStatus(status: string): string {
  const taskStatusLabels = new Map<string, string>([
    ['pending', '待处理'],
    ['running', '处理中'],
    ['applied', '已应用'],
    ['skipped', '已跳过'],
    ['error', '失败'],
  ])
  const label = taskStatusLabels.get(status)
  if (label !== undefined) {
    return label
  }
  return status || '未知'
}
/**
 * Maps a rollback status code to its display label.
 *
 * Unrecognised codes are returned unchanged; an empty code yields '未知'.
 */
function formatFeedbackRollbackStatus(status: string): string {
  const rollbackStatusLabels = new Map<string, string>([
    ['none', '未回退'],
    ['running', '回退中'],
    ['rolled_back', '已回退'],
    ['error', '回退失败'],
  ])
  const label = rollbackStatusLabels.get(status)
  if (label !== undefined) {
    return label
  }
  return status || '未知'
}
/**
 * Picks the badge variant for a task or rollback status code.
 *
 * 'applied' / 'rolled_back' map to 'default', 'error' to 'destructive',
 * 'running' / 'pending' to 'outline', and anything else to 'secondary'.
 */
function getFeedbackStatusVariant(
  status: string,
): 'default' | 'secondary' | 'destructive' | 'outline' {
  switch (status) {
    case 'applied':
    case 'rolled_back':
      return 'default'
    case 'error':
      return 'destructive'
    case 'running':
    case 'pending':
      return 'outline'
    default:
      return 'secondary'
  }
}
/**
 * Builds a short one-line summary of an action-log payload.
 *
 * Preference order: a relation rendered from subject/predicate/object, then the
 * payload's `hash`, then a `targets N` count for a non-empty `target_hashes`
 * array, and finally the pretty-printed JSON trimmed to 120 characters.
 * Returns an empty string when no payload is given.
 */
function summarizeFeedbackActionPayload(value: Record<string, unknown> | undefined): string {
  if (!value) {
    return ''
  }

  // Missing fields are treated as empty strings.
  const asTrimmedText = (raw: unknown): string => String(raw ?? '').trim()
  const hash = asTrimmedText(value.hash)
  const subject = asTrimmedText(value.subject)
  const predicate = asTrimmedText(value.predicate)
  const object = asTrimmedText(value.object)

  const hasFullRelation = Boolean(subject && predicate && object)
  if (hasFullRelation) {
    return formatDeleteRelationText(subject, predicate, object)
  }
  if (hash) {
    return hash
  }

  const targets = value.target_hashes
  if (Array.isArray(targets) && targets.length > 0) {
    return `targets ${targets.length}`
  }

  const serialized = JSON.stringify(value, null, 2)
  return trimDeleteItemText(serialized, 120)
}
type DeleteOperationItem = NonNullable<MemoryDeleteOperationPayload['items']>[number] type DeleteOperationItem = NonNullable<MemoryDeleteOperationPayload['items']>[number]
function trimDeleteItemText(value: string, maxLength: number = 140): string { function trimDeleteItemText(value: string, maxLength: number = 140): string {
@@ -471,6 +572,20 @@ export function KnowledgeBasePage() {
const [deleteRestoring, setDeleteRestoring] = useState(false) const [deleteRestoring, setDeleteRestoring] = useState(false)
const [deleteResult, setDeleteResult] = useState<MemoryDeleteExecutePayload | null>(null) const [deleteResult, setDeleteResult] = useState<MemoryDeleteExecutePayload | null>(null)
const [pendingDeleteRequest, setPendingDeleteRequest] = useState<MemoryDeleteRequestPayload | null>(null) const [pendingDeleteRequest, setPendingDeleteRequest] = useState<MemoryDeleteRequestPayload | null>(null)
const [feedbackCorrections, setFeedbackCorrections] = useState<MemoryFeedbackCorrectionSummaryPayload[]>([])
const [feedbackSearch, setFeedbackSearch] = useState('')
const [feedbackStatusFilter, setFeedbackStatusFilter] = useState('all')
const [feedbackRollbackFilter, setFeedbackRollbackFilter] = useState('all')
const [feedbackPage, setFeedbackPage] = useState(1)
const [selectedFeedbackTaskId, setSelectedFeedbackTaskId] = useState(0)
const [selectedFeedbackTaskDetail, setSelectedFeedbackTaskDetail] = useState<MemoryFeedbackCorrectionDetailTaskPayload | null>(null)
const [selectedFeedbackTaskLoading, setSelectedFeedbackTaskLoading] = useState(false)
const [selectedFeedbackTaskError, setSelectedFeedbackTaskError] = useState('')
const [feedbackActionLogSearch, setFeedbackActionLogSearch] = useState('')
const [feedbackActionLogPage, setFeedbackActionLogPage] = useState(1)
const [feedbackRollbackDialogOpen, setFeedbackRollbackDialogOpen] = useState(false)
const [feedbackRollbackReason, setFeedbackRollbackReason] = useState('')
const [feedbackRollingBack, setFeedbackRollingBack] = useState(false)
const [tuningObjective, setTuningObjective] = useState('precision_priority') const [tuningObjective, setTuningObjective] = useState('precision_priority')
const [tuningIntensity, setTuningIntensity] = useState('standard') const [tuningIntensity, setTuningIntensity] = useState('standard')
const [tuningSampleSize, setTuningSampleSize] = useState('24') const [tuningSampleSize, setTuningSampleSize] = useState('24')
@@ -491,6 +606,7 @@ export function KnowledgeBasePage() {
tuningTaskPayload, tuningTaskPayload,
sourcePayload, sourcePayload,
deleteOperationPayload, deleteOperationPayload,
feedbackCorrectionPayload,
] = await Promise.all([ ] = await Promise.all([
getMemoryConfigSchema(), getMemoryConfigSchema(),
getMemoryConfig(), getMemoryConfig(),
@@ -503,6 +619,7 @@ export function KnowledgeBasePage() {
getMemoryTuningTasks(20), getMemoryTuningTasks(20),
getMemorySources(), getMemorySources(),
getMemoryDeleteOperations(DELETE_OPERATION_FETCH_LIMIT), getMemoryDeleteOperations(DELETE_OPERATION_FETCH_LIMIT),
getMemoryFeedbackCorrections({ limit: FEEDBACK_CORRECTION_FETCH_LIMIT }),
]) ])
setSchemaPayload(schema) setSchemaPayload(schema)
@@ -519,6 +636,7 @@ export function KnowledgeBasePage() {
setTuningTasks(tuningTaskPayload.items ?? []) setTuningTasks(tuningTaskPayload.items ?? [])
setMemorySources(sourcePayload.items ?? []) setMemorySources(sourcePayload.items ?? [])
setDeleteOperations(deleteOperationPayload.items ?? []) setDeleteOperations(deleteOperationPayload.items ?? [])
setFeedbackCorrections(feedbackCorrectionPayload.items ?? [])
if (!selectedImportTaskId && (importTaskPayload.items ?? []).length > 0) { if (!selectedImportTaskId && (importTaskPayload.items ?? []).length > 0) {
const initialTaskId = String(importTaskPayload.items?.[0]?.task_id ?? '') const initialTaskId = String(importTaskPayload.items?.[0]?.task_id ?? '')
if (initialTaskId) { if (initialTaskId) {
@@ -1494,6 +1612,212 @@ export function KnowledgeBasePage() {
setDeleteDialogOpen(true) setDeleteDialogOpen(true)
}, []) }, [])
// Client-side filter over the loaded corrections: the task-status and
// rollback-status selects must match exactly (unless set to 'all'), and the
// search keyword must occur in one of the textual summary fields.
const filteredFeedbackCorrections = useMemo(() => {
  const needle = feedbackSearch.trim().toLowerCase()
  const normalizeStatus = (raw: unknown): string => String(raw ?? '').trim().toLowerCase()

  return feedbackCorrections.filter((correction) => {
    const statusMatches =
      feedbackStatusFilter === 'all' || normalizeStatus(correction.task_status) === feedbackStatusFilter
    const rollbackMatches =
      feedbackRollbackFilter === 'all' || normalizeStatus(correction.rollback_status) === feedbackRollbackFilter
    if (!statusMatches || !rollbackMatches) {
      return false
    }
    if (needle === '') {
      return true
    }
    const searchableFields = [
      correction.query_tool_id,
      correction.session_id,
      correction.query_text,
      correction.decision,
      correction.task_status,
      correction.rollback_status,
    ]
    return searchableFields.some((field) => String(field ?? '').toLowerCase().includes(needle))
  })
}, [feedbackCorrections, feedbackRollbackFilter, feedbackSearch, feedbackStatusFilter])
// Number of pages for the filtered list; never less than 1 so the pager can show "1 / 1".
const feedbackPageCount = Math.max(
  1,
  Math.ceil(filteredFeedbackCorrections.length / FEEDBACK_CORRECTION_PAGE_SIZE),
)
// Slice of the filtered list that belongs to the current (1-based) page.
const pagedFeedbackCorrections = useMemo(() => {
  const offset = (feedbackPage - 1) * FEEDBACK_CORRECTION_PAGE_SIZE
  const end = offset + FEEDBACK_CORRECTION_PAGE_SIZE
  return filteredFeedbackCorrections.slice(offset, end)
}, [feedbackPage, filteredFeedbackCorrections])
// Effective selection: the filtered item whose id equals selectedFeedbackTaskId,
// otherwise the first item on the current page, otherwise null.
const selectedFeedbackCorrection = useMemo(() => {
  const explicitlySelected = filteredFeedbackCorrections.find(
    (correction) => correction.task_id === selectedFeedbackTaskId,
  )
  if (explicitlySelected) {
    return explicitlySelected
  }
  return pagedFeedbackCorrections[0] ?? null
}, [filteredFeedbackCorrections, pagedFeedbackCorrections, selectedFeedbackTaskId])
// Jump back to the first page whenever the search keyword or either filter changes.
useEffect(() => {
setFeedbackPage(1)
}, [feedbackSearch, feedbackStatusFilter, feedbackRollbackFilter])
// Clamp the current page when the page count shrinks below it.
useEffect(() => {
if (feedbackPage > feedbackPageCount) {
setFeedbackPage(feedbackPageCount)
}
}, [feedbackPage, feedbackPageCount])
// Keeps selectedFeedbackTaskId aligned with the derived selection.
// With nothing selectable, the id is reset to 0 and the detail and error are cleared.
// Otherwise the derived selection's task_id is adopted, which also covers the
// fallback to the first item on the current page.
useEffect(() => {
if (!selectedFeedbackCorrection) {
if (selectedFeedbackTaskId) {
setSelectedFeedbackTaskId(0)
}
setSelectedFeedbackTaskDetail(null)
setSelectedFeedbackTaskError('')
return
}
if (selectedFeedbackCorrection.task_id !== selectedFeedbackTaskId) {
setSelectedFeedbackTaskId(selectedFeedbackCorrection.task_id)
}
}, [selectedFeedbackCorrection, selectedFeedbackTaskId])
// Loads the detail of the derived selection whenever its task_id changes.
// NOTE(review): the previously loaded detail is not cleared when a new request
// starts; it stays in state until the new request resolves or fails.
useEffect(() => {
const taskId = selectedFeedbackCorrection?.task_id
if (!taskId) {
setSelectedFeedbackTaskDetail(null)
setSelectedFeedbackTaskError('')
return
}
// Set by the cleanup below so a superseded request never writes state.
let cancelled = false
setSelectedFeedbackTaskLoading(true)
setSelectedFeedbackTaskError('')
void getMemoryFeedbackCorrection(taskId)
.then((payload) => {
if (cancelled) {
return
}
// An unsuccessful payload or a missing task is surfaced as an error message.
if (!payload.success || !payload.task) {
setSelectedFeedbackTaskDetail(null)
setSelectedFeedbackTaskError(payload.error || '未能加载纠错任务详情')
return
}
setSelectedFeedbackTaskDetail(payload.task)
})
.catch((error) => {
if (cancelled) {
return
}
setSelectedFeedbackTaskDetail(null)
setSelectedFeedbackTaskError(error instanceof Error ? error.message : '未能加载纠错任务详情')
})
.finally(() => {
// Only the still-current request may turn the loading flag off.
if (!cancelled) {
setSelectedFeedbackTaskLoading(false)
}
})
return () => {
cancelled = true
}
}, [selectedFeedbackCorrection?.task_id])
// Resolved view of the selected task: the list summary merged with the loaded
// detail when (and only when) that detail belongs to the selected task.
//
// While the detail for a newly selected task is still loading, the detail held
// in state may still belong to the previously selected task. Returning it would
// make the detail panel and the rollback action operate on the wrong task, so
// on an id mismatch (or when no detail is loaded) this falls back to the
// summary of the current selection. The summary is assignable to the detail
// type because every detail-only field is optional.
const selectedFeedbackResolved = useMemo<MemoryFeedbackCorrectionDetailTaskPayload | null>(() => {
  if (!selectedFeedbackCorrection) {
    return null
  }
  if (selectedFeedbackTaskDetail?.task_id === selectedFeedbackCorrection.task_id) {
    return {
      ...selectedFeedbackCorrection,
      ...selectedFeedbackTaskDetail,
    } satisfies MemoryFeedbackCorrectionDetailTaskPayload
  }
  return selectedFeedbackCorrection
}, [selectedFeedbackCorrection, selectedFeedbackTaskDetail])
// Action logs of the resolved task; an empty list when none are available.
const selectedFeedbackActionLogs = Array.isArray(selectedFeedbackResolved?.action_logs)
  ? selectedFeedbackResolved.action_logs
  : []
// Keyword filter over the action logs. It matches against the action type,
// target hash, reason and the summaries of the before/after payloads.
const filteredFeedbackActionLogs = useMemo(() => {
  const needle = feedbackActionLogSearch.trim().toLowerCase()
  if (needle === '') {
    return selectedFeedbackActionLogs
  }
  return selectedFeedbackActionLogs.filter((log) => {
    const searchableFields = [
      log.action_type,
      log.target_hash,
      log.reason,
      summarizeFeedbackActionPayload(log.before_payload),
      summarizeFeedbackActionPayload(log.after_payload),
    ]
    return searchableFields.some((field) => String(field ?? '').toLowerCase().includes(needle))
  })
}, [feedbackActionLogSearch, selectedFeedbackActionLogs])
// Number of action-log pages for the filtered logs; never less than 1.
const feedbackActionLogPageCount = Math.max(
  1,
  Math.ceil(filteredFeedbackActionLogs.length / FEEDBACK_ACTION_LOG_PAGE_SIZE),
)
// Slice of the filtered logs that belongs to the current (1-based) page.
const pagedFeedbackActionLogs = useMemo(() => {
  const offset = (feedbackActionLogPage - 1) * FEEDBACK_ACTION_LOG_PAGE_SIZE
  const end = offset + FEEDBACK_ACTION_LOG_PAGE_SIZE
  return filteredFeedbackActionLogs.slice(offset, end)
}, [feedbackActionLogPage, filteredFeedbackActionLogs])
// Return to the first action-log page when another task is selected or the log search changes.
useEffect(() => {
setFeedbackActionLogPage(1)
}, [selectedFeedbackTaskId, feedbackActionLogSearch])
// Clamp the current action-log page when the page count shrinks below it.
useEffect(() => {
if (feedbackActionLogPage > feedbackActionLogPageCount) {
setFeedbackActionLogPage(feedbackActionLogPageCount)
}
}, [feedbackActionLogPage, feedbackActionLogPageCount])
// Opens the rollback confirmation dialog, always starting from an empty reason.
const openFeedbackRollbackDialog = useCallback((): void => {
  setFeedbackRollbackReason('')
  setFeedbackRollbackDialogOpen(true)
}, [])
// Rolls back the currently resolved feedback-correction task and then refreshes
// the data that the rollback may have changed.
//
// The rollback request and the follow-up refresh are reported separately: once
// the backend has confirmed the rollback, a failing refresh must not be shown
// as a failed rollback. The task returned by the rollback response is applied
// right away so the panel reflects the new state even if the refresh fails.
const executeFeedbackRollback = useCallback(async () => {
  const taskId = selectedFeedbackResolved?.task_id
  if (!taskId) {
    return
  }
  try {
    setFeedbackRollingBack(true)
    const payload = await rollbackMemoryFeedbackCorrection(taskId, {
      requested_by: 'knowledge_base',
      reason: feedbackRollbackReason.trim(),
    })
    if (!payload.success) {
      throw new Error(payload.error || '回退失败')
    }
    toast({
      title: payload.already_rolled_back ? '该纠错已回退' : '纠错回退成功',
      description: `任务 ${taskId} 的回退结果已写入日志`,
    })
    setFeedbackRollbackDialogOpen(false)
    if (payload.task) {
      setSelectedFeedbackTaskDetail(payload.task)
    }

    // From here on the rollback itself is done; only the refresh can fail.
    try {
      const [listPayload, detailPayload] = await Promise.all([
        getMemoryFeedbackCorrections({ limit: FEEDBACK_CORRECTION_FETCH_LIMIT }),
        getMemoryFeedbackCorrection(taskId),
      ])
      setFeedbackCorrections(listPayload.items ?? [])
      setSelectedFeedbackTaskDetail(detailPayload.task ?? null)
      const [sourcePayload, runtimePayload] = await Promise.all([
        getMemorySources(),
        getMemoryRuntimeConfig(),
      ])
      setMemorySources(sourcePayload.items ?? [])
      setRuntimeConfig(runtimePayload)
    } catch (refreshError) {
      toast({
        title: '纠错已回退,但刷新数据失败',
        description: refreshError instanceof Error ? refreshError.message : '未知错误',
        variant: 'destructive',
      })
    }
  } catch (error) {
    toast({
      title: '纠错回退失败',
      description: error instanceof Error ? error.message : '未知错误',
      variant: 'destructive',
    })
  } finally {
    setFeedbackRollingBack(false)
  }
}, [feedbackRollbackReason, selectedFeedbackResolved?.task_id, toast])
const selectedOperationResolved = useMemo(() => { const selectedOperationResolved = useMemo(() => {
if (!selectedDeleteOperation) { if (!selectedDeleteOperation) {
return null return null
@@ -1776,6 +2100,9 @@ export function KnowledgeBasePage() {
<TabsTrigger value="delete" className="rounded-lg px-4 py-1.5"> <TabsTrigger value="delete" className="rounded-lg px-4 py-1.5">
</TabsTrigger> </TabsTrigger>
<TabsTrigger value="feedback" className="rounded-lg px-4 py-1.5">
</TabsTrigger>
</TabsList> </TabsList>
<TabsContent value="overview" className="space-y-4"> <TabsContent value="overview" className="space-y-4">
@@ -3314,6 +3641,327 @@ export function KnowledgeBasePage() {
</Card> </Card>
</div> </div>
</TabsContent> </TabsContent>
<TabsContent value="feedback" className="space-y-4">
<div className="space-y-4">
<Card>
<CardHeader>
<CardTitle className="flex items-center gap-2">
<RotateCcw className="h-4 w-4" />
</CardTitle>
<CardDescription>
feedback correction 退
</CardDescription>
</CardHeader>
<CardContent className="space-y-4">
<div className="grid gap-3 lg:grid-cols-[minmax(0,1fr)_180px_180px]">
<Input
value={feedbackSearch}
onChange={(event) => setFeedbackSearch(event.target.value)}
placeholder="搜索 query_tool_id / session / query / reason"
/>
<Select value={feedbackStatusFilter} onValueChange={setFeedbackStatusFilter}>
<SelectTrigger>
<SelectValue placeholder="按任务状态筛选" />
</SelectTrigger>
<SelectContent>
<SelectItem value="all"></SelectItem>
<SelectItem value="applied"></SelectItem>
<SelectItem value="skipped"></SelectItem>
<SelectItem value="error"></SelectItem>
<SelectItem value="running"></SelectItem>
<SelectItem value="pending"></SelectItem>
</SelectContent>
</Select>
<Select value={feedbackRollbackFilter} onValueChange={setFeedbackRollbackFilter}>
<SelectTrigger>
<SelectValue placeholder="按回退状态筛选" />
</SelectTrigger>
<SelectContent>
<SelectItem value="all">退</SelectItem>
<SelectItem value="none">退</SelectItem>
<SelectItem value="rolled_back">退</SelectItem>
<SelectItem value="error">退</SelectItem>
<SelectItem value="running">退</SelectItem>
</SelectContent>
</Select>
</div>
<div className="flex flex-wrap items-center justify-between gap-2 text-sm text-muted-foreground">
<span> {filteredFeedbackCorrections.length} {feedbackCorrections.length} </span>
<span> {feedbackPage} / {feedbackPageCount} {FEEDBACK_CORRECTION_PAGE_SIZE} </span>
</div>
<div className="grid gap-4 xl:grid-cols-[minmax(0,0.92fr)_minmax(0,1.08fr)]">
<ScrollArea className="h-[720px] rounded-lg border">
<div className="space-y-3 p-3">
{pagedFeedbackCorrections.length > 0 ? pagedFeedbackCorrections.map((item) => {
const isSelected = selectedFeedbackCorrection?.task_id === item.task_id
return (
<button
key={item.task_id}
type="button"
onClick={() => setSelectedFeedbackTaskId(item.task_id)}
className={cn(
'w-full rounded-xl border p-4 text-left transition-colors',
isSelected
? 'border-primary bg-primary/5 shadow-sm'
: 'bg-muted/20 hover:border-primary/40 hover:bg-muted/40',
)}
>
<div className="flex flex-col gap-3">
<div className="flex flex-wrap items-center gap-2">
<Badge variant={getFeedbackStatusVariant(item.task_status)}>
{formatFeedbackTaskStatus(item.task_status)}
</Badge>
<Badge variant={getFeedbackStatusVariant(item.rollback_status)}>
{formatFeedbackRollbackStatus(item.rollback_status)}
</Badge>
<Badge variant="outline">
{formatFeedbackDecision(item.decision)}
</Badge>
</div>
<div className="text-sm font-medium break-words">
{item.query_text || '无查询文本'}
</div>
<div className="font-mono text-[11px] break-all text-muted-foreground">
{item.query_tool_id}
</div>
<div className="flex flex-wrap gap-3 text-xs text-muted-foreground">
<span> {Number(item.affected_counts?.relations ?? 0)}</span>
<span> {Number(item.affected_counts?.stale_paragraphs ?? 0)}</span>
<span>Episode {Number(item.affected_counts?.episode_sources ?? 0)}</span>
<span>Profile {Number(item.affected_counts?.profile_person_ids ?? 0)}</span>
</div>
<div className="text-xs text-muted-foreground">
{formatDeleteOperationTime(item.query_timestamp ?? item.created_at)}
</div>
</div>
</button>
)
}) : (
<div className="rounded-lg border border-dashed bg-muted/20 p-6 text-center text-sm text-muted-foreground">
</div>
)}
</div>
</ScrollArea>
<div className="rounded-xl border bg-muted/20 p-4">
{selectedFeedbackCorrection ? (
<div className="space-y-4">
<div className="flex flex-col gap-3 lg:flex-row lg:items-start lg:justify-between">
<div className="space-y-2">
<div className="flex flex-wrap items-center gap-2">
<Badge variant={getFeedbackStatusVariant(String(selectedFeedbackResolved?.task_status ?? ''))}>
{formatFeedbackTaskStatus(String(selectedFeedbackResolved?.task_status ?? ''))}
</Badge>
<Badge variant={getFeedbackStatusVariant(String(selectedFeedbackResolved?.rollback_status ?? 'none'))}>
{formatFeedbackRollbackStatus(String(selectedFeedbackResolved?.rollback_status ?? 'none'))}
</Badge>
<Badge variant="outline">
{formatFeedbackDecision(String(selectedFeedbackResolved?.decision ?? ''))}
</Badge>
</div>
<div className="text-sm font-medium break-words">
{selectedFeedbackResolved?.query_text || '无查询文本'}
</div>
<div className="font-mono text-xs break-all">
{selectedFeedbackResolved?.query_tool_id}
</div>
</div>
<Button
size="sm"
variant="outline"
onClick={openFeedbackRollbackDialog}
disabled={
String(selectedFeedbackResolved?.task_status ?? '') !== 'applied'
|| String(selectedFeedbackResolved?.rollback_status ?? 'none') === 'rolled_back'
|| feedbackRollingBack
}
>
<RotateCcw className="mr-2 h-4 w-4" />
{String(selectedFeedbackResolved?.rollback_status ?? 'none') === 'rolled_back'
? '已回退'
: '回退本次纠错'}
</Button>
</div>
<div className="grid gap-3 lg:grid-cols-4">
<div className="rounded-lg border bg-background/60 p-3">
<div className="text-xs text-muted-foreground"></div>
<div className="mt-1 text-sm break-all">{selectedFeedbackResolved?.session_id || '-'}</div>
</div>
<div className="rounded-lg border bg-background/60 p-3">
<div className="text-xs text-muted-foreground"></div>
<div className="mt-1 text-sm">{Number(selectedFeedbackResolved?.feedback_message_count ?? 0)}</div>
</div>
<div className="rounded-lg border bg-background/60 p-3">
<div className="text-xs text-muted-foreground"></div>
<div className="mt-1 text-sm">{Number(selectedFeedbackResolved?.decision_confidence ?? 0).toFixed(2)}</div>
</div>
<div className="rounded-lg border bg-background/60 p-3">
<div className="text-xs text-muted-foreground">退</div>
<div className="mt-1 text-sm">{formatDeleteOperationTime(selectedFeedbackResolved?.rolled_back_at)}</div>
</div>
</div>
{selectedFeedbackTaskLoading ? (
<div className="rounded-lg border bg-background/60 p-4 text-sm text-muted-foreground">
...
</div>
) : null}
{selectedFeedbackTaskError ? (
<Alert variant="destructive">
<AlertDescription>{selectedFeedbackTaskError}</AlertDescription>
</Alert>
) : null}
{selectedFeedbackResolved?.rollback_error ? (
<Alert variant="destructive">
<AlertDescription>{selectedFeedbackResolved.rollback_error}</AlertDescription>
</Alert>
) : null}
<div className="grid gap-4 xl:grid-cols-[minmax(0,0.95fr)_minmax(0,1.05fr)]">
<div className="space-y-2">
<div className="text-sm font-semibold"></div>
<pre className="max-h-56 overflow-auto rounded-lg border bg-background/70 p-3 text-xs break-words whitespace-pre-wrap">
{JSON.stringify(selectedFeedbackResolved?.query_snapshot ?? {}, null, 2)}
</pre>
</div>
<div className="space-y-2">
<div className="text-sm font-semibold"></div>
<pre className="max-h-56 overflow-auto rounded-lg border bg-background/70 p-3 text-xs break-words whitespace-pre-wrap">
{JSON.stringify(selectedFeedbackResolved?.decision_payload ?? {}, null, 2)}
</pre>
</div>
</div>
<div className="grid gap-4 xl:grid-cols-[minmax(0,0.95fr)_minmax(0,1.05fr)]">
<div className="space-y-2">
<div className="text-sm font-semibold">退</div>
<pre className="max-h-64 overflow-auto rounded-lg border bg-background/70 p-3 text-xs break-words whitespace-pre-wrap">
{JSON.stringify(selectedFeedbackResolved?.rollback_plan_summary ?? {}, null, 2)}
</pre>
</div>
<div className="space-y-2">
<div className="text-sm font-semibold">退</div>
<pre className="max-h-64 overflow-auto rounded-lg border bg-background/70 p-3 text-xs break-words whitespace-pre-wrap">
{JSON.stringify(selectedFeedbackResolved?.rollback_result ?? {}, null, 2)}
</pre>
</div>
</div>
<div className="space-y-2">
<div className="flex flex-col gap-3 lg:flex-row lg:items-center lg:justify-between">
<div className="text-sm font-semibold">线</div>
<div className="flex flex-col gap-3 lg:flex-row lg:items-center lg:justify-end">
<Input
value={feedbackActionLogSearch}
onChange={(event) => setFeedbackActionLogSearch(event.target.value)}
placeholder="搜索动作 / hash / 预览内容"
className="lg:w-80"
/>
<div className="text-xs text-muted-foreground">
{feedbackActionLogPage} / {feedbackActionLogPageCount} {FEEDBACK_ACTION_LOG_PAGE_SIZE}
</div>
</div>
</div>
<ScrollArea className="h-[280px] rounded-lg border bg-background/60">
<div className="space-y-2 p-3">
{pagedFeedbackActionLogs.length > 0 ? pagedFeedbackActionLogs.map((item: MemoryFeedbackActionLogPayload) => (
<div key={`${item.id}:${item.action_type}`} className="rounded-lg border bg-muted/20 p-3">
<div className="flex flex-wrap items-center gap-2">
<Badge variant="outline">{item.action_type}</Badge>
{item.target_hash ? (
<span className="font-mono text-[11px] break-all text-muted-foreground">{item.target_hash}</span>
) : null}
</div>
{item.reason ? (
<div className="mt-2 text-xs text-muted-foreground break-words">
{item.reason}
</div>
) : null}
{item.before_payload && Object.keys(item.before_payload).length > 0 ? (
<div className="mt-2 text-xs break-words">
<span className="font-medium">Before</span>
<span className="text-muted-foreground">{summarizeFeedbackActionPayload(item.before_payload)}</span>
</div>
) : null}
{item.after_payload && Object.keys(item.after_payload).length > 0 ? (
<div className="mt-1 text-xs break-words">
<span className="font-medium">After</span>
<span className="text-muted-foreground">{summarizeFeedbackActionPayload(item.after_payload)}</span>
</div>
) : null}
<div className="mt-2 text-[11px] text-muted-foreground">
{formatDeleteOperationTime(item.created_at)}
</div>
</div>
)) : (
<div className="rounded-lg border border-dashed bg-muted/20 p-6 text-center text-sm text-muted-foreground">
{selectedFeedbackActionLogs.length > 0 ? '当前筛选条件下没有动作日志' : '当前任务没有动作日志'}
</div>
)}
</div>
</ScrollArea>
<div className="flex items-center justify-between gap-2">
<Button
variant="outline"
size="sm"
onClick={() => setFeedbackActionLogPage((current) => Math.max(1, current - 1))}
disabled={feedbackActionLogPage <= 1}
>
</Button>
<div className="text-xs text-muted-foreground">hash </div>
<Button
variant="outline"
size="sm"
onClick={() => setFeedbackActionLogPage((current) => Math.min(feedbackActionLogPageCount, current + 1))}
disabled={feedbackActionLogPage >= feedbackActionLogPageCount}
>
</Button>
</div>
</div>
</div>
) : (
<div className="flex min-h-[360px] items-center justify-center rounded-lg border border-dashed bg-background/40 p-6 text-center text-sm text-muted-foreground">
</div>
)}
</div>
</div>
<div className="flex items-center justify-between gap-2">
<Button
variant="outline"
size="sm"
onClick={() => setFeedbackPage((current) => Math.max(1, current - 1))}
disabled={feedbackPage <= 1}
>
</Button>
<div className="text-xs text-muted-foreground">
query退
</div>
<Button
variant="outline"
size="sm"
onClick={() => setFeedbackPage((current) => Math.min(feedbackPageCount, current + 1))}
disabled={feedbackPage >= feedbackPageCount}
>
</Button>
</div>
</CardContent>
</Card>
</div>
</TabsContent>
</Tabs> </Tabs>
</div> </div>
</div> </div>
@@ -3332,6 +3980,52 @@ export function KnowledgeBasePage() {
onExecute={() => void executePendingDelete()} onExecute={() => void executePendingDelete()}
onRestore={() => void (deleteResult?.operation_id ? restoreDeleteOperation(deleteResult.operation_id) : Promise.resolve())} onRestore={() => void (deleteResult?.operation_id ? restoreDeleteOperation(deleteResult.operation_id) : Promise.resolve())}
/> />
<Dialog open={feedbackRollbackDialogOpen} onOpenChange={setFeedbackRollbackDialogOpen}>
<DialogContent className="max-w-lg" confirmOnEnter>
<DialogHeader>
<DialogTitle>退</DialogTitle>
<DialogDescription>
relation episode/profile
</DialogDescription>
</DialogHeader>
<div className="space-y-3">
<div className="rounded-lg border bg-muted/20 p-3 text-sm">
<div className="font-medium break-words">{selectedFeedbackResolved?.query_text || '无查询文本'}</div>
<div className="mt-1 font-mono text-[11px] break-all text-muted-foreground">
{selectedFeedbackResolved?.query_tool_id}
</div>
</div>
<div className="space-y-2">
<Label htmlFor="feedback-rollback-reason">退</Label>
<Textarea
id="feedback-rollback-reason"
value={feedbackRollbackReason}
onChange={(event) => setFeedbackRollbackReason(event.target.value)}
placeholder="可选,建议填写本次人工回退原因"
/>
</div>
</div>
<DialogFooter>
<Button variant="outline" onClick={() => setFeedbackRollbackDialogOpen(false)} disabled={feedbackRollingBack}>
</Button>
<Button onClick={() => void executeFeedbackRollback()} disabled={feedbackRollingBack}>
{feedbackRollingBack ? (
<>
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
退
</>
) : (
<>
<RotateCcw className="mr-2 h-4 w-4" />
退
</>
)}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</div> </div>
) )
} }

View File

@@ -0,0 +1,418 @@
from __future__ import annotations
from datetime import datetime, timedelta
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Callable, Dict, List
import asyncio
import inspect
import json
import pickle
from sqlalchemy.orm import sessionmaker
from sqlmodel import Session, create_engine
import numpy as np
import pytest
IMPORT_ERROR: str | None = None
try:
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
from src.A_memorix.core.utils import summary_importer as summary_importer_module
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from src.common.database import database as database_module
from src.common.database.migrations import create_database_migration_bootstrapper
from src.common.message_repository import count_messages
from src.config.model_configs import TaskConfig
from src.services import memory_flow_service as memory_flow_service_module
from src.services import memory_service as memory_service_module
from src.services import send_service
except SystemExit as exc:
IMPORT_ERROR = f"config initialization exited during import: {exc}"
kernel_module = None # type: ignore[assignment]
SDKMemoryKernel = None # type: ignore[assignment]
summary_importer_module = None # type: ignore[assignment]
BotChatSession = None # type: ignore[assignment]
SessionMessage = None # type: ignore[assignment]
MessageInfo = None # type: ignore[assignment]
UserInfo = None # type: ignore[assignment]
MessageSequence = None # type: ignore[assignment]
TextComponent = None # type: ignore[assignment]
database_module = None # type: ignore[assignment]
create_database_migration_bootstrapper = None # type: ignore[assignment]
count_messages = None # type: ignore[assignment]
TaskConfig = None # type: ignore[assignment]
memory_flow_service_module = None # type: ignore[assignment]
memory_service_module = None # type: ignore[assignment]
send_service = None # type: ignore[assignment]
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
class _FakeEmbeddingManager:
def __init__(self, dimension: int = 8) -> None:
self.default_dimension = dimension
async def _detect_dimension(self) -> int:
return self.default_dimension
async def encode(self, text: Any) -> np.ndarray:
def _encode_one(raw: Any) -> np.ndarray:
content = str(raw or "")
vector = np.zeros(self.default_dimension, dtype=np.float32)
for index, byte in enumerate(content.encode("utf-8")):
vector[index % self.default_dimension] += float((byte % 17) + 1)
norm = float(np.linalg.norm(vector))
if norm > 0:
vector /= norm
return vector
if isinstance(text, (list, tuple)):
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
return _encode_one(text).astype(np.float32)
class _KernelBackedRuntimeManager:
def __init__(self, kernel: SDKMemoryKernel) -> None:
self.kernel = kernel
async def invoke(
self,
component_name: str,
args: Dict[str, Any] | None,
*,
timeout_ms: int = 30000,
) -> Any:
del timeout_ms
payload = args or {}
handler = getattr(self.kernel, component_name)
result = handler(**payload)
return await result if inspect.isawaitable(result) else result
class _NoopRuntimeManager:
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
del hook_name
return SimpleNamespace(aborted=False, kwargs=kwargs)
class _FakePlatformIOManager:
def __init__(self) -> None:
self.ensure_calls = 0
async def ensure_send_pipeline_ready(self) -> None:
self.ensure_calls += 1
def build_route_key_from_message(self, message: Any) -> Any:
del message
return SimpleNamespace(platform="qq")
async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
del message, metadata
return SimpleNamespace(
has_success=True,
sent_receipts=[
SimpleNamespace(
driver_id="plugin.qq.sender",
external_message_id="real-message-id",
metadata={},
)
],
failed_receipts=[],
route_key=route_key,
)
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """Repoint ``database_module`` at a throwaway SQLite file under ``tmp_path``.

    The previously installed engine is disposed on a best-effort basis, a new
    engine, session factory and migration bootstrapper are created, and the
    module globals are swapped through ``monkeypatch``. ``_db_initialized`` is
    reset to ``False``.
    """
    db_dir = (tmp_path / "main_db").resolve()
    db_dir.mkdir(parents=True, exist_ok=True)
    db_file = db_dir / "MaiBot.db"
    database_url = f"sqlite:///{db_file}"

    # Best effort only: a failure to dispose the old engine must not abort the test setup.
    try:
        database_module.engine.dispose()
    except Exception:
        pass

    engine = create_engine(
        database_url,
        echo=False,
        connect_args={"check_same_thread": False},
        pool_pre_ping=True,
    )
    session_local = sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=engine,
        class_=Session,
    )
    bootstrapper = create_database_migration_bootstrapper(engine)

    # Insertion order matches the order in which the attributes must be patched.
    overrides = {
        "_DB_DIR": db_dir,
        "_DB_FILE": db_file,
        "DATABASE_URL": database_url,
        "engine": engine,
        "SessionLocal": session_local,
        "_migration_bootstrapper": bootstrapper,
        "_db_initialized": False,
    }
    for attribute, replacement in overrides.items():
        monkeypatch.setattr(database_module, attribute, replacement, raising=False)
def _build_incoming_message(
    *,
    session_id: str,
    user_id: str,
    text: str,
    timestamp: datetime | None = None,
) -> SessionMessage:
    """Build an initialised plain-text ``SessionMessage`` on the ``qq`` platform.

    ``timestamp`` defaults to the current local time. Every boolean message
    flag is cleared, and both the processed and display text equal ``text``.
    """
    sender = UserInfo(
        user_id=user_id,
        user_nickname="测试用户",
        user_cardname="测试用户",
    )
    message = SessionMessage(
        message_id="incoming-message-id",
        timestamp=timestamp or datetime.now(),
        platform="qq",
    )
    message.message_info = MessageInfo(user_info=sender, additional_config={})
    message.raw_message = MessageSequence(components=[TextComponent(text=text)])
    message.session_id = session_id
    message.reply_to = None
    for flag_name in (
        "is_mentioned",
        "is_at",
        "is_emoji",
        "is_picture",
        "is_command",
        "is_notify",
    ):
        setattr(message, flag_name, False)
    message.processed_plain_text = text
    message.display_message = text
    message.initialized = True
    return message
async def _wait_until(
predicate: Callable[[], Any],
*,
timeout_seconds: float = 10.0,
interval_seconds: float = 0.05,
description: str,
) -> Any:
deadline = asyncio.get_running_loop().time() + max(0.5, float(timeout_seconds))
while asyncio.get_running_loop().time() < deadline:
value = predicate()
if inspect.isawaitable(value):
value = await value
if value:
return value
await asyncio.sleep(interval_seconds)
raise AssertionError(f"等待超时: {description}")
@pytest.mark.asyncio
async def test_text_to_stream_triggers_real_chat_summary_writeback(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """Sending a bot reply must lead to a chat-summary paragraph being stored.

    The test stores one incoming user message, sends one reply through
    ``send_service.text_to_stream_with_message`` and then waits until the kernel's
    metadata store holds a paragraph whose source is ``chat_summary:test-session``.
    The embedding adapter, the runtime self check, the LLM call and the platform
    IO layer are replaced by in-process fakes.
    """
    _install_temp_main_database(monkeypatch, tmp_path)
    fake_embedding_manager = _FakeEmbeddingManager()
    # Prompts handed to the fake LLM, in call order.
    captured_prompts: List[str] = []
    fixed_send_timestamp = 1_777_000_000.0

    async def _fake_runtime_self_check(**kwargs: Any) -> Dict[str, Any]:
        # Always report a healthy embedding runtime with the fake manager's dimension.
        del kwargs
        return {
            "ok": True,
            "message": "ok",
            "configured_dimension": fake_embedding_manager.default_dimension,
            "requested_dimension": fake_embedding_manager.default_dimension,
            "vector_store_dimension": fake_embedding_manager.default_dimension,
            "detected_dimension": fake_embedding_manager.default_dimension,
            "encoded_dimension": fake_embedding_manager.default_dimension,
            "elapsed_ms": 0.0,
            "sample_text": "test",
            "checked_at": datetime.now().timestamp(),
        }

    async def _fake_generate(request: Any) -> Any:
        # Record the prompt and answer with a fixed JSON summary payload.
        captured_prompts.append(str(getattr(request, "prompt", "") or ""))
        return SimpleNamespace(
            success=True,
            completion=SimpleNamespace(
                response=json.dumps(
                    {
                        "summary": "这段对话记录了用户提到自己买了绿色围巾,机器人表示会记住这件事。",
                        "entities": ["绿色围巾"],
                        "relations": [],
                    },
                    ensure_ascii=False,
                )
            ),
        )

    monkeypatch.setattr(
        kernel_module,
        "create_embedding_api_adapter",
        lambda **kwargs: fake_embedding_manager,
    )
    monkeypatch.setattr(
        kernel_module,
        "run_embedding_runtime_self_check",
        _fake_runtime_self_check,
    )
    monkeypatch.setattr(
        summary_importer_module,
        "run_embedding_runtime_self_check",
        _fake_runtime_self_check,
    )
    monkeypatch.setattr(
        summary_importer_module.llm_api,
        "get_available_models",
        lambda: {"utils": TaskConfig(model_list=["fake-summary-model"])},
    )
    monkeypatch.setattr(
        summary_importer_module.llm_api,
        "resolve_task_name_from_model_config",
        lambda model_config: "utils",
    )
    monkeypatch.setattr(
        summary_importer_module.llm_api,
        "generate",
        _fake_generate,
    )
    # NOTE(review): if ``send_service.time`` / ``summary_importer_module.time`` are the
    # stdlib ``time`` module, these patches freeze ``time.time`` process-wide for the
    # duration of the test - confirm that is intended.
    monkeypatch.setattr(send_service.time, "time", lambda: fixed_send_timestamp)
    monkeypatch.setattr(summary_importer_module.time, "time", lambda: fixed_send_timestamp)
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
            "advanced": {"enable_auto_save": False},
            "embedding": {"dimension": fake_embedding_manager.default_dimension},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
            "summarization": {"model_name": ["utils"]},
        },
    )
    service = memory_flow_service_module.MemoryAutomationService()
    fake_platform_io_manager = _FakePlatformIOManager()

    async def _fake_rebuild_episodes_for_sources(sources: List[str]) -> Dict[str, Any]:
        # Episode rebuilding is stubbed out; it only echoes the requested sources.
        return {
            "rebuilt": 0,
            "items": [],
            "failures": [],
            "sources": list(sources),
        }

    monkeypatch.setattr(kernel, "rebuild_episodes_for_sources", _fake_rebuild_episodes_for_sources)
    monkeypatch.setattr(
        memory_service_module,
        "a_memorix_host_service",
        _KernelBackedRuntimeManager(kernel),
    )
    monkeypatch.setattr(memory_flow_service_module, "memory_automation_service", service)
    monkeypatch.setattr(send_service, "_get_runtime_manager", lambda: _NoopRuntimeManager())
    monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_platform_io_manager)
    monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
    # Only the session id "test-session" resolves to a chat session; every other id yields None.
    monkeypatch.setattr(
        send_service._chat_manager,
        "get_session_by_session_id",
        lambda stream_id: (
            BotChatSession(
                session_id="test-session",
                platform="qq",
                user_id="target-user",
                group_id=None,
            )
            if stream_id == "test-session"
            else None
        ),
    )
    # Enable chat-summary writeback with a message threshold of 2 and disable
    # person-fact writeback.
    monkeypatch.setattr(
        memory_flow_service_module.global_config.memory,
        "chat_summary_writeback_enabled",
        True,
        raising=False,
    )
    monkeypatch.setattr(
        memory_flow_service_module.global_config.memory,
        "chat_summary_writeback_message_threshold",
        2,
        raising=False,
    )
    monkeypatch.setattr(
        memory_flow_service_module.global_config.memory,
        "chat_summary_writeback_context_length",
        10,
        raising=False,
    )
    monkeypatch.setattr(
        memory_flow_service_module.global_config.memory,
        "person_fact_writeback_enabled",
        False,
        raising=False,
    )
    await kernel.initialize()
    try:
        # The incoming message is stamped one second before the frozen send time.
        incoming_message = _build_incoming_message(
            session_id="test-session",
            user_id="target-user",
            text="我最近买了一条绿色围巾。",
            timestamp=datetime.fromtimestamp(fixed_send_timestamp) - timedelta(seconds=1),
        )
        # NOTE(review): no explicit commit here - presumably get_db_session() commits
        # on exit; the count_messages == 2 assertion below depends on it.
        with database_module.get_db_session() as session:
            session.add(incoming_message.to_db_instance())
        sent_message = await send_service.text_to_stream_with_message(
            text="好的,我会记住你最近买了绿色围巾。",
            stream_id="test-session",
            storage_message=True,
        )
        assert sent_message is not None
        # The id comes from the receipt produced by _FakePlatformIOManager.send_message.
        assert sent_message.message_id == "real-message-id"
        assert fake_platform_io_manager.ensure_calls == 1
        assert count_messages(session_id="test-session") == 2
        paragraphs = await _wait_until(
            lambda: kernel.metadata_store.get_paragraphs_by_source("chat_summary:test-session"),
            description="等待聊天摘要写回到 A_memorix",
        )
        assert captured_prompts
        assert "我最近买了一条绿色围巾。" in captured_prompts[-1]
        assert "好的,我会记住你最近买了绿色围巾。" in captured_prompts[-1]
        assert any("绿色围巾" in str(item.get("content", "") or "") for item in paragraphs)
        # Paragraph metadata is unpickled when it arrives as bytes and used as-is
        # otherwise; at least one paragraph must record a trigger_message_count of 2.
        assert any(
            int(
                (
                    pickle.loads(item.get("metadata"))
                    if isinstance(item.get("metadata"), (bytes, bytearray))
                    else item.get("metadata")
                    or {}
                ).get("trigger_message_count", 0)
                or 0
            )
            == 2
            for item in paragraphs
        )
        assert service.chat_summary_writeback._states["test-session"].last_trigger_message_count == 2
    finally:
        await service.shutdown()
        await kernel.shutdown()
        # Best-effort release of the temporary database engine.
        try:
            database_module.engine.dispose()
        except Exception:
            pass

View File

@@ -164,3 +164,28 @@ async def test_runtime_self_check_reports_requested_dimension_without_explicit_o
assert report["detected_dimension"] == 384 assert report["detected_dimension"] == 384
assert report["encoded_dimension"] == 384 assert report["encoded_dimension"] == 384
assert manager.encode_calls == ["A_Memorix runtime self check"] assert manager.encode_calls == ["A_Memorix runtime self check"]
@pytest.mark.asyncio
async def test_encode_batch_keeps_batch_local_indexes_when_cache_hits_previous_batch(monkeypatch):
    """A cache hit in a later batch must not shift the rows of that batch.

    With ``batch_size=2`` the inputs split into ``["A", "B"]`` and ``["A", "C"]``;
    the second ``"A"`` was already embedded in the first batch, yet ``"C"`` still
    has to land in the last output row.
    """
    adapter = EmbeddingAPIAdapter(default_dimension=4, enable_cache=True)
    adapter._dimension = 4
    adapter._dimension_detected = True

    async def _stub_detect_dimension() -> int:
        return 4

    async def _stub_embedding(text: str, dimensions: int | None = None):
        # The embedding is four consecutive numbers starting at the first character's code point.
        del dimensions
        start = float(ord(str(text)[0]))
        return [start, start + 1.0, start + 2.0, start + 3.0]

    monkeypatch.setattr(adapter, "_detect_dimension", _stub_detect_dimension)
    monkeypatch.setattr(adapter, "_get_embedding_direct", _stub_embedding)

    vectors = await adapter.encode(["A", "B", "A", "C"], batch_size=2)

    assert vectors.shape == (4, 4)
    assert np.array_equal(vectors[0], vectors[2])
    assert vectors[1][0] == float(ord("B"))
    assert vectors[3][0] == float(ord("C"))

View File

@@ -0,0 +1,745 @@
from __future__ import annotations
import asyncio
import inspect
import json
import time
import uuid
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Callable, Dict
import numpy as np
import pytest
import pytest_asyncio
from sqlalchemy.orm import sessionmaker
from sqlmodel import Session, create_engine, select
IMPORT_ERROR: str | None = None
try:
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from src.A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from src.chat.heart_flow.heartflow_manager import heartflow_manager
from src.chat.message_receive import bot as bot_module
from src.chat.message_receive.chat_manager import chat_manager
from src.chat.message_receive.bot import chat_bot
from src.common.database import database as database_module
from src.common.database.database_model import PersonInfo, ToolRecord
from src.common.database.migrations import create_database_migration_bootstrapper
from src.common.utils.utils_session import SessionUtils
from src.llm_models.payload_content.tool_option import ToolCall
from src.maisaka import reasoning_engine as reasoning_engine_module
from src.maisaka import runtime as runtime_module
from src.maisaka.chat_loop_service import ChatResponse
from src.maisaka.context_messages import AssistantMessage
from src.plugin_runtime import component_query as component_query_module
from src.services import memory_flow_service as memory_flow_service_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import memory_service
except SystemExit as exc:
IMPORT_ERROR = f"config initialization exited during import: {exc}"
kernel_module = None # type: ignore[assignment]
KernelSearchRequest = None # type: ignore[assignment]
SDKMemoryKernel = None # type: ignore[assignment]
heartflow_manager = None # type: ignore[assignment]
bot_module = None # type: ignore[assignment]
chat_manager = None # type: ignore[assignment]
chat_bot = None # type: ignore[assignment]
database_module = None # type: ignore[assignment]
ToolRecord = None # type: ignore[assignment]
PersonInfo = None # type: ignore[assignment]
create_database_migration_bootstrapper = None # type: ignore[assignment]
SessionUtils = None # type: ignore[assignment]
ToolCall = None # type: ignore[assignment]
reasoning_engine_module = None # type: ignore[assignment]
runtime_module = None # type: ignore[assignment]
ChatResponse = None # type: ignore[assignment]
AssistantMessage = None # type: ignore[assignment]
component_query_module = None # type: ignore[assignment]
memory_flow_service_module = None # type: ignore[assignment]
memory_service_module = None # type: ignore[assignment]
memory_service = None # type: ignore[assignment]
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
RELATION_QUERY = "测试用户 和 最喜欢的颜色 有什么关系"
class _FakeEmbeddingManager:
def __init__(self, dimension: int = 8) -> None:
self.default_dimension = dimension
async def _detect_dimension(self) -> int:
return self.default_dimension
async def encode(self, text: Any) -> np.ndarray:
def _encode_one(raw: Any) -> np.ndarray:
content = str(raw or "")
vector = np.zeros(self.default_dimension, dtype=np.float32)
for index, byte in enumerate(content.encode("utf-8")):
vector[index % self.default_dimension] += float((byte % 17) + 1)
norm = float(np.linalg.norm(vector))
if norm > 0:
vector /= norm
return vector
if isinstance(text, (list, tuple)):
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
return _encode_one(text).astype(np.float32)
class _KernelBackedRuntimeManager:
def __init__(self, kernel: SDKMemoryKernel) -> None:
self.kernel = kernel
async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
del timeout_ms
payload = args or {}
if component_name == "search_memory":
return await self.kernel.search_memory(
KernelSearchRequest(
query=str(payload.get("query", "") or ""),
limit=int(payload.get("limit", 5) or 5),
mode=str(payload.get("mode", "hybrid") or "hybrid"),
chat_id=str(payload.get("chat_id", "") or ""),
person_id=str(payload.get("person_id", "") or ""),
time_start=payload.get("time_start"),
time_end=payload.get("time_end"),
respect_filter=bool(payload.get("respect_filter", True)),
user_id=str(payload.get("user_id", "") or ""),
group_id=str(payload.get("group_id", "") or ""),
)
)
handler = getattr(self.kernel, component_name)
result = handler(**payload)
return await result if inspect.isawaitable(result) else result
class _NoopRuntimeManager:
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
del hook_name
return SimpleNamespace(aborted=False, kwargs=kwargs)
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """Swap the main database of ``database_module`` for a temporary SQLite file.

    The current engine is disposed (errors ignored), then a new engine, session
    factory and migration bootstrapper bound to ``tmp_path/main_db/MaiBot.db``
    are installed through ``monkeypatch``; ``_db_initialized`` is reset to ``False``.
    """
    db_dir = (tmp_path / "main_db").resolve()
    db_dir.mkdir(parents=True, exist_ok=True)
    db_file = db_dir / "MaiBot.db"
    database_url = f"sqlite:///{db_file}"

    # Releasing the old engine is best effort; setup continues even if it fails.
    try:
        database_module.engine.dispose()
    except Exception:
        pass

    engine_options = {
        "echo": False,
        "connect_args": {"check_same_thread": False},
        "pool_pre_ping": True,
    }
    engine = create_engine(database_url, **engine_options)
    session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine, class_=Session)
    bootstrapper = create_database_migration_bootstrapper(engine)

    patches = (
        ("_DB_DIR", db_dir),
        ("_DB_FILE", db_file),
        ("DATABASE_URL", database_url),
        ("engine", engine),
        ("SessionLocal", session_local),
        ("_migration_bootstrapper", bootstrapper),
        ("_db_initialized", False),
    )
    for name, value in patches:
        monkeypatch.setattr(database_module, name, value, raising=False)
def _build_chat_response(content: str, tool_calls: list[ToolCall]) -> ChatResponse:
    """Wrap ``content`` and ``tool_calls`` in a ``ChatResponse`` with zeroed usage counters."""
    assistant_message = AssistantMessage(
        content=content,
        timestamp=datetime.now(),
        tool_calls=tool_calls,
    )
    response = ChatResponse(
        content=content,
        tool_calls=tool_calls,
        request_messages=[],
        raw_message=assistant_message,
        selected_history_count=0,
        tool_count=len(tool_calls),
        prompt_tokens=0,
        built_message_count=0,
        completion_tokens=0,
        total_tokens=0,
        prompt_section=None,
    )
    return response
def _build_message_data(
*,
content: str,
platform: str,
user_id: str,
user_name: str,
group_id: str,
group_name: str,
) -> Dict[str, Any]:
message_id = str(uuid.uuid4())
return {
"message_info": {
"platform": platform,
"message_id": message_id,
"time": time.time(),
"group_info": {
"group_id": group_id,
"group_name": group_name,
"platform": platform,
},
"user_info": {
"user_id": user_id,
"user_nickname": user_name,
"user_cardname": user_name,
"platform": platform,
},
"additional_config": {
"at_bot": True,
},
},
"message_segment": {
"type": "seglist",
"data": [
{
"type": "text",
"data": content,
},
],
},
"raw_message": content,
"processed_plain_text": content,
}
async def _wait_until(
predicate: Callable[[], Any],
*,
timeout_seconds: float = 10.0,
interval_seconds: float = 0.05,
description: str,
) -> Any:
deadline = asyncio.get_running_loop().time() + max(0.5, float(timeout_seconds))
while asyncio.get_running_loop().time() < deadline:
value = predicate()
if inspect.isawaitable(value):
value = await value
if value:
return value
await asyncio.sleep(interval_seconds)
raise AssertionError(f"等待超时: {description}")
def _load_feedback_tasks(kernel: SDKMemoryKernel) -> list[Dict[str, Any]]:
assert kernel.metadata_store is not None
cursor = kernel.metadata_store.get_connection().cursor()
rows = cursor.execute(
"SELECT query_tool_id FROM memory_feedback_tasks ORDER BY id"
).fetchall()
tasks: list[Dict[str, Any]] = []
for row in rows:
task = kernel.metadata_store.get_feedback_task(str(row["query_tool_id"] or ""))
if task is not None:
tasks.append(task)
return tasks
def _load_feedback_action_types(kernel: SDKMemoryKernel) -> list[str]:
assert kernel.metadata_store is not None
cursor = kernel.metadata_store.get_connection().cursor()
rows = cursor.execute(
"SELECT action_type FROM memory_feedback_action_logs ORDER BY id"
).fetchall()
return [str(row["action_type"] or "") for row in rows]
def _load_query_memory_tool_records(session_id: str) -> list[Dict[str, Any]]:
    """Return the ``query_memory`` tool records of one session as plain dicts, oldest first."""
    statement = (
        select(ToolRecord)
        .where(ToolRecord.session_id == session_id)
        .where(ToolRecord.tool_name == "query_memory")
        .order_by(ToolRecord.timestamp)
    )
    with database_module.get_db_session() as session:
        records: list[Dict[str, Any]] = []
        # Rows are copied into dicts while the session is still open.
        for record in list(session.exec(statement).all()):
            records.append(
                {
                    "tool_id": str(record.tool_id or ""),
                    "session_id": str(record.session_id or ""),
                    "tool_name": str(record.tool_name or ""),
                    "tool_data": str(record.tool_data or ""),
                    "timestamp": record.timestamp,
                }
            )
        return records
def _seed_person_info(*, person_id: str, person_name: str, session_info: Dict[str, Any]) -> None:
    """Insert and commit one known ``PersonInfo`` row for the user in ``session_info``.

    ``session_info`` must provide ``platform``, ``user_id``, ``user_name`` and
    ``group_id``; ``person_name`` is also stored as the card name for that group.
    """
    with database_module.get_db_session() as session:
        group_cards = [{"group_id": str(session_info["group_id"]), "group_cardname": person_name}]
        person = PersonInfo(
            is_known=True,
            person_id=person_id,
            person_name=person_name,
            platform=str(session_info["platform"]),
            user_id=str(session_info["user_id"]),
            user_nickname=str(session_info["user_name"]),
            group_cardname=json.dumps(group_cards, ensure_ascii=False),
            know_counts=1,
            first_known_time=datetime.now(),
            last_known_time=datetime.now(),
        )
        session.add(person)
        session.commit()
@pytest_asyncio.fixture
async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
    """Prepare an isolated chat + memory-kernel environment for the feedback-correction flow.

    Yields a dict with the initialised ``kernel``, the computed ``session_id``,
    the ``session_info`` used to build messages, the seeded ``person_id`` and
    ``planner_calls`` (the texts the fake planner has seen, in call order).
    On teardown the heartflow chats are stopped, the chat manager caches are
    cleared, the kernel is shut down and the database engine is disposed.
    """
    _install_temp_main_database(monkeypatch, tmp_path)
    # Start from empty global chat state.
    chat_manager.sessions.clear()
    chat_manager.last_messages.clear()
    heartflow_manager.heartflow_chat_list.clear()
    noop_runtime_manager = _NoopRuntimeManager()
    monkeypatch.setattr(bot_module.ChatBot, "_get_runtime_manager", staticmethod(lambda: noop_runtime_manager))
    # No text is treated as a command and no plugin tools are exposed to the LLM.
    monkeypatch.setattr(
        component_query_module.component_query_service,
        "find_command_by_text",
        lambda text: None,
    )
    monkeypatch.setattr(
        component_query_module.component_query_service,
        "get_llm_available_tool_specs",
        lambda: {},
    )
    monkeypatch.setattr(runtime_module.global_config.mcp, "enable", False, raising=False)
    monkeypatch.setattr(
        runtime_module.MaisakaHeartFlowChatting,
        "_get_message_trigger_threshold",
        lambda self: 1,
    )

    async def _noop_on_incoming_message(message: Any) -> None:
        # Incoming messages are not forwarded to the memory automation service.
        del message

    monkeypatch.setattr(
        memory_flow_service_module.memory_automation_service,
        "on_incoming_message",
        _noop_on_incoming_message,
    )
    fake_embedding_manager = _FakeEmbeddingManager(dimension=8)

    async def _fake_runtime_self_check(
        *,
        config: Any,
        sample_text: str,
        vector_store: Any,
        embedding_manager: Any,
    ) -> Dict[str, Any]:
        # Always report a healthy embedding runtime without touching any backend.
        del config, sample_text, vector_store, embedding_manager
        return {
            "ok": True,
            "message": "ok",
            "checked_at": time.time(),
            "encoded_dimension": fake_embedding_manager.default_dimension,
        }

    monkeypatch.setattr(
        kernel_module,
        "create_embedding_api_adapter",
        lambda **kwargs: fake_embedding_manager,
    )
    monkeypatch.setattr(
        kernel_module,
        "run_embedding_runtime_self_check",
        _fake_runtime_self_check,
    )
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
            "advanced": {"enable_auto_save": False},
            "embedding": {"dimension": fake_embedding_manager.default_dimension},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )
    # Feedback-correction settings are pinned on the kernel instance.
    # 0.0004 hours is 1.44 seconds; the two interval getters return 0.2 seconds.
    monkeypatch.setattr(kernel, "_feedback_cfg_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_window_hours", lambda: 0.0004)
    monkeypatch.setattr(kernel, "_feedback_cfg_check_interval_seconds", lambda: 0.2)
    monkeypatch.setattr(kernel, "_feedback_cfg_batch_size", lambda: 10)
    monkeypatch.setattr(kernel, "_feedback_cfg_max_messages", lambda: 10)
    monkeypatch.setattr(kernel, "_feedback_cfg_auto_apply_threshold", lambda: 0.85)
    monkeypatch.setattr(kernel, "_feedback_cfg_prefilter_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_mark_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_hard_filter_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_profile_refresh_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_profile_force_refresh_on_read", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_episode_rebuild_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_episode_query_block_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_interval_seconds", lambda: 0.2)
    monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_batch_size", lambda: 10)
    monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_paragraph_hard_filter_enabled", True, raising=False)
    monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_episode_query_block_enabled", True, raising=False)

    async def _fake_classify_feedback(
        *,
        query_tool_id: str,
        query_text: str,
        hit_briefs: list[Dict[str, Any]],
        feedback_messages: list[str],
    ) -> Dict[str, Any]:
        # Always decide "correct" with confidence 0.97. The target is the first hit of
        # type "relation", or the first hit of any type when no relation hit exists.
        del query_tool_id, query_text, feedback_messages
        target_hash = ""
        for item in hit_briefs:
            if str(item.get("type", "") or "").strip() == "relation":
                target_hash = str(item.get("hash", "") or "").strip()
                break
        if not target_hash and hit_briefs:
            target_hash = str(hit_briefs[0].get("hash", "") or "").strip()
        return {
            "decision": "correct",
            "confidence": 0.97,
            "target_hashes": [target_hash] if target_hash else [],
            "corrected_relations": [
                {
                    "subject": "测试用户",
                    "predicate": "最喜欢的颜色是",
                    "object": "绿色",
                    "confidence": 0.99,
                }
            ],
            "reason": "用户明确纠正为绿色",
        }

    monkeypatch.setattr(kernel, "_classify_feedback", _fake_classify_feedback)
    await kernel.initialize()

    async def _force_episode_fallback(**kwargs: Any) -> Dict[str, Any]:
        # Segmentation always fails; presumably this pushes the episode service onto
        # its fallback path - TODO confirm against the episode service.
        raise RuntimeError("force_fallback_for_test")

    monkeypatch.setattr(
        kernel.episode_service.segmentation_service,
        "segment",
        _force_episode_fallback,
    )
    # Pending-episode batch processing becomes an awaitable no-op that reports zero work.
    monkeypatch.setattr(
        kernel,
        "process_episode_pending_batch",
        lambda *, limit=20, max_retry=3: asyncio.sleep(0, result={"processed": 0, "episode_count": 0, "fallback_count": 0, "failed": 0}),
    )
    host_manager = _KernelBackedRuntimeManager(kernel)
    monkeypatch.setattr(memory_service_module, "a_memorix_host_service", host_manager)
    # Latest message text seen by the fake planner on each call.
    planner_calls: list[str] = []

    async def _fake_timing_gate(self, anchor_message: Any):
        # Always let the reasoning engine continue to the planner.
        del self, anchor_message
        return "continue", _build_chat_response("直接进入 planner。", []), [], []

    async def _fake_planner(
        self,
        *,
        injected_user_messages: list[str] | None = None,
        tool_definitions: list[dict[str, Any]] | None = None,
    ) -> ChatResponse:
        # Issues one query_memory tool call per message whose text contains "回忆" or
        # "再查"; every other planner round ends with a no_reply tool call.
        del injected_user_messages, tool_definitions
        latest_message = self._runtime.message_cache[-1]
        latest_text = str(latest_message.processed_plain_text or "")
        planner_calls.append(latest_text)
        # Message ids that already triggered a query are remembered on the runtime object.
        handled_message_ids = getattr(self._runtime, "_test_query_message_ids", None)
        if handled_message_ids is None:
            handled_message_ids = set()
            setattr(self._runtime, "_test_query_message_ids", handled_message_ids)
        if latest_message.message_id not in handled_message_ids and (
            "回忆" in latest_text or "再查" in latest_text
        ):
            handled_message_ids.add(latest_message.message_id)
            tool_call = ToolCall(
                call_id=f"query-{uuid.uuid4().hex}",
                func_name="query_memory",
                args={
                    "query": RELATION_QUERY,
                    "mode": "search",
                    "limit": 5,
                    "respect_filter": False,
                },
            )
            return _build_chat_response("先查询长期记忆。", [tool_call])
        stop_call = ToolCall(
            call_id=f"stop-{uuid.uuid4().hex}",
            func_name="no_reply",
            args={},
        )
        return _build_chat_response("当前轮次结束。", [stop_call])

    monkeypatch.setattr(
        reasoning_engine_module.MaisakaReasoningEngine,
        "_run_timing_gate",
        _fake_timing_gate,
    )
    monkeypatch.setattr(
        reasoning_engine_module.MaisakaReasoningEngine,
        "_run_interruptible_planner",
        _fake_planner,
    )
    # The keys match the keyword-only parameters of _build_message_data (besides content).
    session_info = {
        "platform": "unit_test_chat",
        "user_id": "user_feedback_flow",
        "user_name": "反馈测试用户",
        "group_id": "group_feedback_flow",
        "group_name": "反馈纠错测试群",
    }
    person_id = "person_feedback_flow"
    session_id = SessionUtils.calculate_session_id(
        session_info["platform"],
        user_id=session_info["user_id"],
        group_id=session_info["group_id"],
    )
    _seed_person_info(person_id=person_id, person_name="测试用户", session_info=session_info)
    try:
        yield {
            "kernel": kernel,
            "session_id": session_id,
            "session_info": session_info,
            "person_id": person_id,
            "planner_calls": planner_calls,
        }
    finally:
        # Stop every heartflow chat; a failing stop() must not prevent the remaining cleanup.
        for key, chat in list(heartflow_manager.heartflow_chat_list.items()):
            try:
                await chat.stop()
            except Exception:
                pass
            heartflow_manager.heartflow_chat_list.pop(key, None)
        chat_manager.sessions.clear()
        chat_manager.last_messages.clear()
        await kernel.shutdown()
        # Best-effort release of the temporary database engine.
        try:
            database_module.engine.dispose()
        except Exception:
            pass
@pytest.mark.asyncio
async def test_feedback_correction_real_chat_flow(chat_feedback_env) -> None:
    """Drive a chat round-trip and check that user feedback corrects stored memory.

    Flow exercised by this test:
      1. Seed a "favourite colour is blue" relation and confirm it is visible
         through search, the person profile and episode search.
      2. Send a chat message that makes the patched planner call
         ``query_memory``; a pending feedback task must be enqueued with the
         blue hit in its query snapshot.
      3. Send a correcting message ("not blue, green") and wait until the
         feedback task reaches a terminal state; it must be ``applied``.
      4. Verify the old relation is inactive, the expected action log types
         exist, and search / profile / episode results now show green only.
      5. Ask again through chat and confirm the second ``query_memory`` tool
         record also returns green only.
    """
    kernel = chat_feedback_env["kernel"]
    session_id = chat_feedback_env["session_id"]
    session_info = chat_feedback_env["session_info"]
    person_id = chat_feedback_env["person_id"]
    planner_calls = chat_feedback_env["planner_calls"]

    # Pollers below read their data source exactly once per attempt so that the
    # value that is checked is the same value that is returned, even while
    # background workers keep mutating the underlying state.
    def _first_planner_call():
        return planner_calls[0] if planner_calls else None

    def _any_query_records():
        records = _load_query_memory_tool_records(session_id)
        return records if records else None

    def _first_feedback_task():
        tasks = _load_feedback_tasks(kernel)
        return tasks[0] if tasks else None

    def _two_query_records():
        records = _load_query_memory_tool_records(session_id)
        return records if len(records) >= 2 else None

    def _profile_refresh_done():
        request = kernel.metadata_store.get_person_profile_refresh_request(person_id)
        if request and request.get("status") == "done":
            return request
        return None

    # --- 1. Seed the original (to-be-corrected) fact -------------------------
    write_result = await memory_service.ingest_text(
        external_id=f"test:feedback-seed:{uuid.uuid4().hex}",
        source_type="chat_summary",
        text="测试用户 最喜欢的颜色是 蓝色",
        chat_id=session_id,
        relations=[
            {
                "subject": "测试用户",
                "predicate": "最喜欢的颜色是",
                "object": "蓝色",
                "confidence": 1.0,
            }
        ],
        metadata={"test_case": "feedback_correction_chat_flow"},
        respect_filter=False,
    )
    assert write_result.success is True

    pre_search = await memory_service.search(
        RELATION_QUERY,
        mode="search",
        chat_id=session_id,
        respect_filter=False,
    )
    assert pre_search.hits
    assert any("蓝色" in hit.content for hit in pre_search.hits)

    pre_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
    pre_profile_text = pre_profile.summary + "\n" + json.dumps(pre_profile.evidence, ensure_ascii=False)
    assert "蓝色" in pre_profile_text

    seed_source = f"chat_summary:{session_id}"
    rebuild_result = await kernel.rebuild_episodes_for_sources([seed_source])
    assert rebuild_result["rebuilt"] >= 1

    pre_episode = await memory_service.search(
        "蓝色",
        mode="episode",
        chat_id=session_id,
        respect_filter=False,
    )
    assert pre_episode.hits
    assert any("蓝色" in hit.content for hit in pre_episode.hits)

    # --- 2. First chat turn: planner queries memory, feedback task is queued --
    await chat_bot.message_process(
        _build_message_data(
            content="请帮我回忆一下,测试用户最喜欢的颜色是什么?",
            **session_info,
        )
    )
    await _wait_until(
        _first_planner_call,
        description="planner 收到首条聊天消息",
    )
    first_query_records = await _wait_until(
        _any_query_records,
        description="首条 query_memory 工具记录生成",
    )
    assert first_query_records

    first_task = await _wait_until(
        _first_feedback_task,
        description="首个反馈任务入队",
    )
    assert first_task["status"] == "pending"
    first_hits = list((first_task.get("query_snapshot") or {}).get("hits") or [])
    assert first_hits
    assert any("蓝色" in str(item.get("content", "") or "") for item in first_hits)

    # --- 3. Second chat turn: the user corrects the remembered fact ----------
    await chat_bot.message_process(
        _build_message_data(
            content="不对,测试用户最喜欢的颜色不是蓝色,是绿色。",
            **session_info,
        )
    )

    terminal_statuses = {"applied", "skipped", "error"}

    def _finalized_feedback_task():
        task = kernel.metadata_store.get_feedback_task(first_task["query_tool_id"])
        if task and task.get("status") in terminal_statuses:
            return task
        return None

    finalized_task = await _wait_until(
        _finalized_feedback_task,
        timeout_seconds=12.0,
        interval_seconds=0.1,
        description="反馈任务进入终态",
    )
    # Passing the task as the assertion message makes a skipped/error outcome
    # self-explanatory in the failure output.
    assert finalized_task["status"] == "applied", finalized_task
    assert finalized_task["decision_payload"]["decision"] == "correct"
    assert finalized_task["decision_payload"]["apply_result"]["applied"] is True

    corrected_hashes = list(
        (finalized_task["decision_payload"].get("apply_result") or {}).get("relation_hashes") or []
    )
    assert corrected_hashes
    corrected_hash = str(corrected_hashes[0] or "")
    relation_status = kernel.metadata_store.get_relation_status_batch([corrected_hash]).get(corrected_hash, {})
    assert bool(relation_status.get("is_inactive")) is True

    action_types = _load_feedback_action_types(kernel)
    assert "classification" in action_types
    assert "forget_relation" in action_types
    assert "ingest_correction" in action_types
    assert "mark_stale_paragraph" in action_types
    assert "enqueue_episode_rebuild" in action_types
    assert "enqueue_profile_refresh" in action_types

    # --- 4. Corrected fact must replace the old one everywhere ---------------
    direct_post_search = await memory_service.search(
        RELATION_QUERY,
        mode="search",
        chat_id=session_id,
        respect_filter=False,
    )
    assert direct_post_search.hits
    post_contents = "\n".join(hit.content for hit in direct_post_search.hits)
    assert "绿色" in post_contents
    assert "蓝色" not in post_contents

    profile_refresh_request = await _wait_until(
        _profile_refresh_done,
        timeout_seconds=12.0,
        interval_seconds=0.1,
        description="人物画像刷新完成",
    )
    assert profile_refresh_request["status"] == "done"

    post_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
    post_profile_text = post_profile.summary + "\n" + json.dumps(post_profile.evidence, ensure_ascii=False)
    assert "绿色" in post_profile_text
    assert "蓝色" not in post_profile_text

    async def _latest_episode_result():
        # Only accept an episode result once it mentions green and no longer
        # mentions blue; anything else means the rebuild has not landed yet.
        result = await memory_service.search(
            "绿色",
            mode="episode",
            chat_id=session_id,
            respect_filter=False,
        )
        if not result.hits:
            return None
        contents = "\n".join(hit.content for hit in result.hits)
        if "绿色" in contents and "蓝色" not in contents:
            return result
        return None

    post_episode = await _wait_until(
        _latest_episode_result,
        timeout_seconds=12.0,
        interval_seconds=0.2,
        description="episode 重建后返回修正结果",
    )
    assert post_episode is not None

    stale_episode = await memory_service.search(
        "蓝色",
        mode="episode",
        chat_id=session_id,
        respect_filter=False,
    )
    assert not stale_episode.hits

    # --- 5. Third chat turn: the tool itself must now return the correction --
    await chat_bot.message_process(
        _build_message_data(
            content="再查一次,测试用户最喜欢的颜色是什么?",
            **session_info,
        )
    )
    tool_records = await _wait_until(
        _two_query_records,
        timeout_seconds=10.0,
        interval_seconds=0.1,
        description="第二次 query_memory 工具记录生成",
    )
    latest_tool_data = json.loads(str(tool_records[-1].get("tool_data") or "{}"))
    latest_hits = list((latest_tool_data.get("structured_content") or {}).get("hits") or [])
    assert latest_hits
    latest_contents = "\n".join(str(item.get("content", "") or "") for item in latest_hits)
    assert "绿色" in latest_contents
    assert "蓝色" not in latest_contents

View File

@@ -0,0 +1,396 @@
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict
import pytest
IMPORT_ERROR: str | None = None
try:
from src.A_memorix.core.retrieval.sparse_bm25 import SparseBM25Config, SparseBM25Index
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
except SystemExit as exc:
IMPORT_ERROR = f"config initialization exited during import: {exc}"
SparseBM25Config = None # type: ignore[assignment]
SparseBM25Index = None # type: ignore[assignment]
kernel_module = None # type: ignore[assignment]
SDKMemoryKernel = None # type: ignore[assignment]
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
@pytest.mark.asyncio
async def test_kernel_enqueue_feedback_task_delegates_to_metadata_store(monkeypatch: pytest.MonkeyPatch) -> None:
    """With feedback correction enabled, the kernel forwards the task to the metadata store.

    The due time handed to the store must equal the query timestamp plus the
    configured 12 hour correction window.
    """
    captured: Dict[str, Any] = {}
    echoed_keys = ("query_tool_id", "session_id", "query_timestamp", "due_at", "query_snapshot")

    def fake_enqueue_feedback_task(**kwargs):
        # Remember what the kernel passed and echo it back like a stored row.
        captured.update(kwargs)
        row: Dict[str, Any] = {"id": 1}
        for key in echoed_keys:
            row[key] = kwargs[key]
        return row

    memory_cfg = SimpleNamespace(
        feedback_correction_enabled=True,
        feedback_correction_window_hours=12.0,
    )
    monkeypatch.setattr(kernel_module, "global_config", SimpleNamespace(memory=memory_cfg))

    query_time = datetime(2026, 4, 9, 10, 30, 0)
    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=fake_enqueue_feedback_task)

    payload = await kernel.enqueue_feedback_task(
        query_tool_id="tool-query-1",
        session_id="session-1",
        query_timestamp=query_time,
        structured_content={"query": "Alice 喜欢什么", "hits": [{"hash": "relation-1"}]},
    )

    assert payload["success"] is True
    assert payload["queued"] is True
    assert captured["query_tool_id"] == "tool-query-1"
    assert captured["session_id"] == "session-1"
    snapshot = captured["query_snapshot"]
    assert snapshot["query"] == "Alice 喜欢什么"
    assert snapshot["hits"] == [{"hash": "relation-1"}]
    expected_due_at = query_time.timestamp() + 12 * 3600
    assert captured["due_at"] == pytest.approx(expected_due_at, rel=0, abs=1e-6)
@pytest.mark.asyncio
async def test_kernel_enqueue_feedback_task_skipped_when_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """When feedback correction is disabled, enqueueing reports failure with a reason."""
    disabled_cfg = SimpleNamespace(memory=SimpleNamespace(feedback_correction_enabled=False))
    monkeypatch.setattr(kernel_module, "global_config", disabled_cfg)

    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=lambda **kwargs: kwargs)

    request = {
        "query_tool_id": "tool-query-2",
        "session_id": "session-1",
        "query_timestamp": datetime.now(),
        "structured_content": {"hits": [{"hash": "relation-1"}]},
    }
    payload = await kernel.enqueue_feedback_task(**request)

    assert payload["success"] is False
    assert payload["reason"] == "feedback_correction_disabled"
@pytest.mark.asyncio
async def test_apply_feedback_decision_resolves_paragraph_targets() -> None:
    """A "correct" decision that targets a paragraph hash is applied to its linked relation.

    The kernel's collaborators are replaced with in-memory stubs that record
    every call, so the test can check which relation was forgotten, what was
    re-ingested, which paragraph was marked stale and which follow-up jobs
    (episode rebuild, profile refresh) were queued.
    """
    action_logs: list[Dict[str, Any]] = []
    forgotten_hashes: list[str] = []
    ingested_payloads: list[Dict[str, Any]] = []
    stale_marks: list[Dict[str, Any]] = []
    episode_sources: list[str] = []
    profile_refresh_ids: list[str] = []

    def _blue_relation_rows() -> list:
        # Fresh objects on every call, exactly like the inline literals they replace.
        return [
            {
                "hash": "relation-1",
                "subject": "测试用户",
                "predicate": "最喜欢的颜色是",
                "object": "蓝色",
            }
        ]

    def _get_paragraph_relations(paragraph_hash):
        if paragraph_hash == "paragraph-1":
            return _blue_relation_rows()
        return []

    def _get_relation_status_batch(hashes):
        statuses = {}
        for hash_value in hashes:
            key = str(hash_value)
            statuses[key] = {"is_inactive": key in forgotten_hashes}
        return statuses

    def _get_paragraph_hashes_by_relation_hashes(hashes):
        if "relation-1" in hashes:
            return {"relation-1": ["paragraph-1"]}
        return {}

    def _upsert_paragraph_stale_relation_mark(**kwargs):
        stale_marks.append(kwargs)
        return kwargs

    def _enqueue_episode_source_rebuild(source, reason=""):
        episode_sources.append(source)
        return True

    def _enqueue_person_profile_refresh(**kwargs):
        profile_refresh_ids.append(kwargs["person_id"])
        return kwargs

    def _get_paragraph(paragraph_hash):
        if paragraph_hash == "paragraph-1":
            return {"hash": "paragraph-1", "source": "chat_feedback_test_seed:session-1"}
        return None

    def _append_feedback_action_log(**kwargs):
        return action_logs.append(kwargs)

    def _apply_v5_relation_action(*, action, hashes, strength=1.0):
        # Record the forgotten hashes so later status lookups report them inactive.
        forgotten_hashes.extend([str(item) for item in hashes])
        return {"success": True, "action": action, "hashes": list(hashes), "strength": strength}

    def _query_relation_rows_by_hashes(relation_hashes, include_inactive=False):
        return _blue_relation_rows()

    async def _fake_ingest_feedback_relations(**kwargs):
        ingested_payloads.append(kwargs)
        return {"success": True, "stored_ids": ["relation-2"]}

    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = SimpleNamespace(
        get_paragraph_relations=_get_paragraph_relations,
        get_relation_status_batch=_get_relation_status_batch,
        get_paragraph_hashes_by_relation_hashes=_get_paragraph_hashes_by_relation_hashes,
        upsert_paragraph_stale_relation_mark=_upsert_paragraph_stale_relation_mark,
        enqueue_episode_source_rebuild=_enqueue_episode_source_rebuild,
        enqueue_person_profile_refresh=_enqueue_person_profile_refresh,
        get_paragraph=_get_paragraph,
        append_feedback_action_log=_append_feedback_action_log,
    )
    kernel._feedback_cfg_auto_apply_threshold = lambda: 0.85  # type: ignore[method-assign]
    kernel._apply_v5_relation_action = _apply_v5_relation_action  # type: ignore[method-assign]
    kernel._feedback_cfg_paragraph_mark_enabled = lambda: True  # type: ignore[method-assign]
    kernel._feedback_cfg_episode_rebuild_enabled = lambda: True  # type: ignore[method-assign]
    kernel._feedback_cfg_profile_refresh_enabled = lambda: True  # type: ignore[method-assign]
    kernel._resolve_feedback_related_person_ids = lambda **kwargs: ["person-1"]  # type: ignore[method-assign]
    kernel._query_relation_rows_by_hashes = _query_relation_rows_by_hashes  # type: ignore[method-assign]
    kernel._ingest_feedback_relations = _fake_ingest_feedback_relations  # type: ignore[method-assign]

    decision = {
        "decision": "correct",
        "confidence": 0.97,
        "target_hashes": ["paragraph-1"],
        "corrected_relations": [
            {
                "subject": "测试用户",
                "predicate": "最喜欢的颜色是",
                "object": "绿色",
                "confidence": 0.99,
            }
        ],
        "reason": "用户明确纠正为绿色",
    }
    hit_map = {
        "paragraph-1": {
            "hash": "paragraph-1",
            "type": "paragraph",
            "content": "测试用户 最喜欢的颜色是 蓝色",
            "linked_relation_hashes": ["relation-1"],
        }
    }
    payload = await kernel._apply_feedback_decision(
        task_id=1,
        query_tool_id="tool-query-1",
        session_id="session-1",
        decision=decision,
        hit_map=hit_map,
    )

    assert payload["applied"] is True
    assert payload["relation_hashes"] == ["relation-1"]
    assert forgotten_hashes == ["relation-1"]
    assert ingested_payloads[0]["relation_hashes"] == ["relation-1"]
    assert payload["stale_paragraph_hashes"] == ["paragraph-1"]
    rebuild_sources = payload["episode_rebuild_sources"]
    assert "chat_feedback_test_seed:session-1" in rebuild_sources
    assert "chat_summary:session-1" in rebuild_sources
    assert payload["profile_refresh_person_ids"] == ["person-1"]
    assert stale_marks[0]["paragraph_hash"] == "paragraph-1"
    logged_action_types = {item["action_type"] for item in action_logs}
    assert logged_action_types == {
        "forget_relation",
        "ingest_correction",
        "mark_stale_paragraph",
        "enqueue_episode_rebuild",
        "enqueue_profile_refresh",
    }
def test_filter_active_relation_hits_removes_inactive_relations() -> None:
    """Hits backed by inactive relations are dropped; all other hits keep their order.

    Covers relation hits, paragraphs whose own relation is inactive, and
    paragraphs carrying stale marks that point at inactive or active relations.
    """
    relation_status = {
        "r-active": {"is_inactive": False},
        "r-inactive": {"is_inactive": True},
        "r-para-inactive": {"is_inactive": True},
        "r-stale-active": {"is_inactive": False},
        "r-stale-inactive": {"is_inactive": True},
    }
    stale_marks = {
        "p-stale": [{"relation_hash": "r-stale-inactive"}],
        "p-restored": [{"relation_hash": "r-stale-active"}],
    }

    def _get_paragraph_relations(paragraph_hash):
        if paragraph_hash == "p-inactive":
            return [{"hash": "r-para-inactive"}]
        return []

    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel._feedback_cfg_paragraph_hard_filter_enabled = lambda: True  # type: ignore[method-assign]
    kernel.metadata_store = SimpleNamespace(
        get_relation_status_batch=lambda hashes: relation_status,
        get_paragraph_relations=_get_paragraph_relations,
        get_paragraph_stale_relation_marks_batch=lambda paragraph_hashes: stale_marks,
    )

    hit_specs = (
        ("r-active", "relation", "A 喜欢 B"),
        ("r-inactive", "relation", "A 讨厌 B"),
        ("p-1", "paragraph", "段落证据"),
        ("p-inactive", "paragraph", "失活段落证据"),
        ("p-stale", "paragraph", "被标脏段落"),
        ("p-restored", "paragraph", "恢复可见段落"),
    )
    hits = [
        {"hash": hit_hash, "type": hit_type, "content": content}
        for hit_hash, hit_type, content in hit_specs
    ]

    filtered = kernel._filter_active_relation_hits(hits)

    surviving_hashes = [item["hash"] for item in filtered]
    assert surviving_hashes == ["r-active", "p-1", "p-restored"]
def test_sparse_relation_search_requests_active_only() -> None:
    """The sparse BM25 relation search must ask the store to exclude inactive relations."""
    seen_kwargs: Dict[str, Any] = {}

    class FakeMetadataStore:
        """Records the keyword arguments of the FTS call and returns no rows."""

        def fts_search_relations_bm25(self, **kwargs):
            seen_kwargs.update(kwargs)
            return []

    bm25_config = SparseBM25Config(enabled=True, lazy_load=False)
    index = SparseBM25Index(
        metadata_store=FakeMetadataStore(),  # type: ignore[arg-type]
        config=bm25_config,
    )
    # Set the loaded flag and a placeholder connection directly on the index.
    index._loaded = True
    index._conn = object()  # type: ignore[assignment]

    result = index.search_relations("测试纠错", k=5)

    assert result == []
    assert seen_kwargs["include_inactive"] is False
@pytest.mark.asyncio
async def test_feedback_task_rollback_restores_snapshots_and_requeues_followups() -> None:
    """Rolling back an applied feedback task undoes every recorded step.

    The rollback plan in ``current_task`` lists one forgotten relation, one
    corrected write, one stale mark and the follow-up sources / persons. The
    stubs below record every call so the test can assert that the old relation
    is active again, the new one is inactive, the correction paragraph and the
    stale mark are removed, and the follow-up jobs are queued again.
    """
    action_logs: list[Dict[str, Any]] = []
    queued_sources: list[str] = []
    queued_profiles: list[str] = []
    deleted_marks: list[tuple[str, str]] = []
    deleted_paragraphs: list[str] = []

    relation_statuses: Dict[str, Dict[str, Any]] = {
        "rel-old": {
            "is_inactive": True,
            "weight": 0.0,
            "is_pinned": False,
            "protected_until": 0.0,
            "last_reinforced": None,
            "inactive_since": 1.0,
        },
        "rel-new": {
            "is_inactive": False,
            "weight": 1.0,
            "is_pinned": False,
            "protected_until": 0.0,
            "last_reinforced": None,
            "inactive_since": None,
        },
    }

    rollback_plan: Dict[str, Any] = {
        "forgotten_relations": [
            {
                "hash": "rel-old",
                "subject": "测试用户",
                "predicate": "最喜欢的颜色是",
                "object": "蓝色",
                "before_status": {
                    "is_inactive": False,
                    "weight": 0.8,
                    "is_pinned": False,
                    "protected_until": 0.0,
                    "last_reinforced": None,
                    "inactive_since": None,
                },
            }
        ],
        "corrected_write": {
            "paragraph_hashes": ["paragraph-new"],
            "corrected_relations": [
                {
                    "hash": "rel-new",
                    "subject": "测试用户",
                    "predicate": "最喜欢的颜色是",
                    "object": "绿色",
                    "existed_before": False,
                    "before_status": {},
                }
            ],
        },
        "stale_marks": [{"paragraph_hash": "paragraph-old", "relation_hash": "rel-old"}],
        "episode_sources": ["chat_summary:session-1"],
        "profile_person_ids": ["person-1"],
    }
    current_task: Dict[str, Any] = {
        "id": 1,
        "query_tool_id": "tool-query-rollback",
        "session_id": "session-1",
        "status": "applied",
        "rollback_status": "none",
        "query_snapshot": {"query": "测试用户最喜欢的颜色是什么"},
        "decision_payload": {"decision": "correct", "confidence": 0.97},
        "rollback_plan": rollback_plan,
    }

    class _Conn:
        """Connection stub: every cursor/execute call returns the stub itself."""

        def cursor(self):
            return self

        def execute(self, *_args, **_kwargs):
            return self

        def commit(self):
            return None

    def _get_feedback_task_by_id(task_id):
        return current_task if int(task_id) == 1 else None

    def _mark_feedback_task_rollback_running(**kwargs):
        current_task.update({"rollback_status": "running"})
        return current_task

    def _finalize_feedback_task_rollback(**kwargs):
        current_task.update(
            {
                "rollback_status": kwargs["rollback_status"],
                "rollback_result": kwargs.get("rollback_result") or {},
                "rollback_error": kwargs.get("rollback_error", ""),
            }
        )
        return current_task

    def _get_relation_status_batch(hashes):
        snapshot = {}
        for hash_value in hashes:
            if hash_value in relation_statuses:
                snapshot[hash_value] = dict(relation_statuses[hash_value])
        return snapshot

    def _restore_relation_status_from_snapshot(hash_value, snapshot):
        relation_statuses.update({hash_value: dict(snapshot)})
        return dict(snapshot)

    def _append_feedback_action_log(**kwargs):
        return action_logs.append(kwargs)

    def _mark_as_deleted(hashes, type_):
        deleted_paragraphs.extend(list(hashes))
        return len(list(hashes))

    def _get_paragraph(paragraph_hash):
        return {"hash": paragraph_hash, "source": "chat_summary:session-1"}

    def _delete_external_memory_refs_by_paragraphs(hashes):
        removed = []
        for hash_value in hashes:
            removed.append(
                {"paragraph_hash": str(hash_value), "external_id": f"external:{hash_value}"}
            )
        return removed

    def _update_relations_protection(hashes, **kwargs):
        return None

    def _mark_relations_inactive(hashes, inactive_since=None):
        outcomes = []
        for hash_value in hashes:
            relation_statuses[hash_value] = {
                **relation_statuses.get(hash_value, {}),
                "is_inactive": True,
                "inactive_since": inactive_since,
            }
            outcomes.append(None)
        return outcomes

    def _delete_paragraph_stale_relation_marks(marks):
        deleted_marks.extend(list(marks))
        return len(list(marks))

    def _enqueue_episode_source_rebuild(source, reason=''):
        queued_sources.append(source)
        return True

    def _enqueue_person_profile_refresh(**kwargs):
        queued_profiles.append(kwargs["person_id"])
        return kwargs

    def _list_feedback_action_logs(task_id):
        return action_logs if int(task_id) == 1 else []

    metadata_store = SimpleNamespace(
        get_feedback_task_by_id=_get_feedback_task_by_id,
        mark_feedback_task_rollback_running=_mark_feedback_task_rollback_running,
        finalize_feedback_task_rollback=_finalize_feedback_task_rollback,
        get_relation_status_batch=_get_relation_status_batch,
        restore_relation_status_from_snapshot=_restore_relation_status_from_snapshot,
        append_feedback_action_log=_append_feedback_action_log,
        mark_as_deleted=_mark_as_deleted,
        get_paragraph=_get_paragraph,
        get_connection=lambda: _Conn(),
        delete_external_memory_refs_by_paragraphs=_delete_external_memory_refs_by_paragraphs,
        update_relations_protection=_update_relations_protection,
        mark_relations_inactive=_mark_relations_inactive,
        delete_paragraph_stale_relation_marks=_delete_paragraph_stale_relation_marks,
        enqueue_episode_source_rebuild=_enqueue_episode_source_rebuild,
        enqueue_person_profile_refresh=_enqueue_person_profile_refresh,
        list_feedback_action_logs=_list_feedback_action_logs,
    )

    async def _noop_initialize() -> None:
        return None

    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = metadata_store
    # Replace initialization, graph rebuild and persistence with no-ops.
    kernel.initialize = _noop_initialize  # type: ignore[method-assign]
    kernel._rebuild_graph_from_metadata = lambda: None  # type: ignore[method-assign]
    kernel._persist = lambda: None  # type: ignore[method-assign]

    payload = await kernel._rollback_feedback_task(
        task_id=1,
        requested_by="pytest",
        reason="manual rollback",
    )

    assert payload["success"] is True
    assert relation_statuses["rel-old"]["is_inactive"] is False
    assert relation_statuses["rel-new"]["is_inactive"] is True
    assert deleted_paragraphs == ["paragraph-new"]
    assert deleted_marks == [("paragraph-old", "rel-old")]
    assert queued_sources == ["chat_summary:session-1"]
    assert queued_profiles == ["person-1"]
    assert current_task["rollback_status"] == "rolled_back"
    logged_action_types = {item["action_type"] for item in action_logs}
    assert logged_action_types >= {
        "rollback_restore_relation",
        "rollback_revert_corrected_relation",
        "rollback_delete_correction_paragraph",
        "rollback_clear_stale_mark",
        "rollback_enqueue_episode_rebuild",
        "rollback_enqueue_profile_refresh",
    }

View File

@@ -0,0 +1,82 @@
from __future__ import annotations
import pickle
from pathlib import Path
import pytest
try:
from src.A_memorix.core.storage.graph_store import GraphStore
except SystemExit as exc:
GraphStore = None # type: ignore[assignment]
IMPORT_ERROR = f"config initialization exited during import: {exc}"
else:
IMPORT_ERROR = None
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
def _build_empty_graph_metadata() -> dict:
return {
"nodes": [],
"node_to_idx": {},
"node_attrs": {},
"matrix_format": "csr",
"total_nodes_added": 0,
"total_edges_added": 0,
"total_nodes_deleted": 0,
"total_edges_deleted": 0,
"edge_hash_map": {},
}
def test_graph_store_clear_save_removes_stale_adjacency(tmp_path: Path) -> None:
    """Saving after ``clear()`` must delete the adjacency file written by an earlier save."""
    graph_dir = tmp_path / "graph_data"
    adjacency_file = graph_dir / "graph_adjacency.npz"

    graph = GraphStore(data_dir=graph_dir)
    graph.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
    graph.save()
    assert adjacency_file.exists()

    graph.clear()
    graph.save()
    assert not adjacency_file.exists()
def test_graph_store_load_resets_stale_adjacency_when_metadata_is_empty(tmp_path: Path) -> None:
    """Loading with empty metadata must yield an empty graph even if an adjacency file exists."""
    graph_dir = tmp_path / "graph_data"

    writer = GraphStore(data_dir=graph_dir)
    writer.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
    writer.save()

    # Overwrite the saved metadata with an empty description; the adjacency
    # file from the save above stays on disk.
    blank_metadata = _build_empty_graph_metadata()
    with (graph_dir / "graph_metadata.pkl").open("wb") as handle:
        pickle.dump(blank_metadata, handle)

    reader = GraphStore(data_dir=graph_dir)
    reader.load()

    assert reader.num_nodes == 0
    assert reader.num_edges == 0
    assert reader.get_nodes() == []
def test_graph_store_load_clears_stale_edge_hash_map_when_metadata_is_empty(tmp_path: Path) -> None:
    """An edge hash map left in otherwise empty metadata must not survive a load."""
    graph_dir = tmp_path / "graph_data"

    writer = GraphStore(data_dir=graph_dir)
    writer.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
    writer.save()

    # Metadata says "no nodes" but still carries an edge hash entry.
    inconsistent_metadata = _build_empty_graph_metadata()
    inconsistent_metadata["edge_hash_map"] = {(0, 1): {"rel-1"}}
    with (graph_dir / "graph_metadata.pkl").open("wb") as handle:
        pickle.dump(inconsistent_metadata, handle)

    reader = GraphStore(data_dir=graph_dir)
    reader.load()

    assert reader.has_edge_hash_map() is False

View File

@@ -35,12 +35,30 @@ def test_legacy_learning_list_with_numeric_fourth_column_is_migrated():
] ]
def test_visual_multimodal_replyer_is_migrated_to_replyer_mode() -> None: def test_visual_multimodal_planner_is_migrated_to_planner_mode():
payload = { payload = {"visual": {"multimodal_planner": True}}
"visual": {
"multimodal_replyer": True, result = try_migrate_legacy_bot_config_dict(payload)
}
} assert result.migrated is True
assert "visual.multimodal_planner_moved_to_visual.planner_mode" in result.reason
assert result.data["visual"]["planner_mode"] == "multimodal"
assert "multimodal_planner" not in result.data["visual"]
def test_chat_multimodal_planner_is_migrated_to_visual_planner_mode():
    """A legacy ``chat.multimodal_planner`` flag is moved to ``visual.planner_mode``."""
    legacy_config = {"chat": {"multimodal_planner": True}}

    migration = try_migrate_legacy_bot_config_dict(legacy_config)

    assert migration.migrated is True
    assert "chat.multimodal_planner_moved_to_visual.planner_mode" in migration.reason
    migrated = migration.data
    assert migrated["visual"]["planner_mode"] == "multimodal"
    assert "multimodal_planner" not in migrated["chat"]
def test_visual_multimodal_replyer_is_migrated_to_replyer_mode():
payload = {"visual": {"multimodal_replyer": True}}
result = try_migrate_legacy_bot_config_dict(payload) result = try_migrate_legacy_bot_config_dict(payload)
@@ -50,13 +68,8 @@ def test_visual_multimodal_replyer_is_migrated_to_replyer_mode() -> None:
assert "multimodal_replyer" not in result.data["visual"] assert "multimodal_replyer" not in result.data["visual"]
def test_chat_replyer_generator_type_is_migrated_to_replyer_mode() -> None: def test_chat_replyer_generator_type_is_migrated_to_replyer_mode():
payload = { payload = {"chat": {"replyer_generator_type": "legacy"}}
"chat": {
"replyer_generator_type": "legacy",
},
"visual": {},
}
result = try_migrate_legacy_bot_config_dict(payload) result = try_migrate_legacy_bot_config_dict(payload)

View File

@@ -38,6 +38,206 @@ def test_person_fact_resolve_target_person_for_private_chat(monkeypatch):
assert person.person_id == "qq:123" assert person.person_id == "qq:123"
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_triggers_when_threshold_reached(monkeypatch):
    """With 5 counted messages, threshold 3 and no earlier trigger, one summary is ingested."""
    ingested: list[dict] = []

    async def fake_ingest_summary(**kwargs):
        ingested.append(kwargs)
        return SimpleNamespace(success=True, detail="ok")

    async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
        del self, session_id, total_message_count
        return 0

    memory_cfg = SimpleNamespace(
        chat_summary_writeback_enabled=True,
        chat_summary_writeback_message_threshold=3,
        chat_summary_writeback_context_length=7,
    )
    monkeypatch.setattr(memory_flow_module, "global_config", SimpleNamespace(memory=memory_cfg))
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    monkeypatch.setattr(
        memory_flow_module.ChatSummaryWritebackService,
        "_load_last_trigger_message_count",
        fake_load_last_trigger_message_count,
    )

    service = memory_flow_module.ChatSummaryWritebackService()
    session = SimpleNamespace(user_id="user-1", group_id="group-1")
    await service._handle_message(SimpleNamespace(session_id="session-1", session=session))

    assert len(ingested) == 1
    payload = ingested[0]
    assert payload["external_id"] == "chat_auto_summary:session-1:5"
    assert payload["chat_id"] == "session-1"
    assert payload["text"] == ""
    metadata = payload["metadata"]
    assert metadata["generate_from_chat"] is True
    assert metadata["context_length"] == 7
    assert metadata["trigger"] == "message_threshold"
    assert payload["user_id"] == "user-1"
    assert payload["group_id"] == "group-1"
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_skips_when_threshold_not_reached(monkeypatch):
    """With 5 counted messages and threshold 6, no summary is ingested."""
    ingest_calls: list[dict] = []

    async def fake_ingest_summary(**kwargs):
        ingest_calls.append(kwargs)
        return SimpleNamespace(success=True, detail="ok")

    async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
        del self, session_id, total_message_count
        return 0

    memory_cfg = SimpleNamespace(
        chat_summary_writeback_enabled=True,
        chat_summary_writeback_message_threshold=6,
        chat_summary_writeback_context_length=9,
    )
    monkeypatch.setattr(memory_flow_module, "global_config", SimpleNamespace(memory=memory_cfg))
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    monkeypatch.setattr(
        memory_flow_module.ChatSummaryWritebackService,
        "_load_last_trigger_message_count",
        fake_load_last_trigger_message_count,
    )

    service = memory_flow_module.ChatSummaryWritebackService()
    session = SimpleNamespace(user_id="user-1", group_id="group-1")
    await service._handle_message(SimpleNamespace(session_id="session-1", session=session))

    assert not ingest_calls
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_restores_previous_trigger_count(monkeypatch):
    """A restored trigger count of 5 plus threshold 3 fires at 8 messages and records 8."""
    ingested: list[dict] = []

    async def fake_ingest_summary(**kwargs):
        ingested.append(kwargs)
        return SimpleNamespace(success=True, detail="ok")

    async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
        del self, session_id, total_message_count
        return 5

    memory_cfg = SimpleNamespace(
        chat_summary_writeback_enabled=True,
        chat_summary_writeback_message_threshold=3,
        chat_summary_writeback_context_length=7,
    )
    monkeypatch.setattr(memory_flow_module, "global_config", SimpleNamespace(memory=memory_cfg))
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 8)
    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    monkeypatch.setattr(
        memory_flow_module.ChatSummaryWritebackService,
        "_load_last_trigger_message_count",
        fake_load_last_trigger_message_count,
    )

    service = memory_flow_module.ChatSummaryWritebackService()
    session = SimpleNamespace(user_id="user-1", group_id="group-1")
    await service._handle_message(SimpleNamespace(session_id="session-1", session=session))

    assert len(ingested) == 1
    assert ingested[0]["external_id"] == "chat_auto_summary:session-1:8"
    assert service._states["session-1"].last_trigger_message_count == 8
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_falls_back_to_current_count_for_legacy_summary(monkeypatch):
    """When the restored trigger count equals the current count, nothing is ingested.

    The restored value (5) must still be stored as the session's last trigger count.
    """
    ingest_calls: list[dict] = []

    async def fake_ingest_summary(**kwargs):
        ingest_calls.append(kwargs)
        return SimpleNamespace(success=True, detail="ok")

    async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
        del self, session_id, total_message_count
        return 5

    memory_cfg = SimpleNamespace(
        chat_summary_writeback_enabled=True,
        chat_summary_writeback_message_threshold=3,
        chat_summary_writeback_context_length=7,
    )
    monkeypatch.setattr(memory_flow_module, "global_config", SimpleNamespace(memory=memory_cfg))
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    monkeypatch.setattr(
        memory_flow_module.ChatSummaryWritebackService,
        "_load_last_trigger_message_count",
        fake_load_last_trigger_message_count,
    )

    service = memory_flow_module.ChatSummaryWritebackService()
    session = SimpleNamespace(user_id="user-1", group_id="group-1")
    await service._handle_message(SimpleNamespace(session_id="session-1", session=session))

    assert not ingest_calls
    assert service._states["session-1"].last_trigger_message_count == 5
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_loads_trigger_count_from_summary_metadata(monkeypatch):
    """The restored trigger count is 6 when the stored summaries record counts 3 and 6."""
    stored_summaries = [
        {"created_at": 1.0, "metadata": {"trigger_message_count": 3}},
        {"created_at": 2.0, "metadata": {"trigger_message_count": 6}},
    ]

    def get_paragraphs_by_source(source: str):
        assert source == "chat_summary:session-1"
        return list(stored_summaries)

    async def _ensure_kernel():
        store = SimpleNamespace(get_paragraphs_by_source=get_paragraphs_by_source)
        return SimpleNamespace(metadata_store=store)

    monkeypatch.setattr(
        memory_flow_module.memory_service_module,
        "a_memorix_host_service",
        SimpleNamespace(_ensure_kernel=_ensure_kernel),
    )

    service = memory_flow_module.ChatSummaryWritebackService()
    restored = await service._load_last_trigger_message_count(session_id="session-1", total_message_count=8)

    assert restored == 6
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_memory_automation_service_auto_starts_and_delegates(): async def test_memory_automation_service_auto_starts_and_delegates():
events: list[tuple[str, str]] = [] events: list[tuple[str, str]] = []
@@ -52,15 +252,67 @@ async def test_memory_automation_service_auto_starts_and_delegates():
async def shutdown(self): async def shutdown(self):
events.append(("shutdown", "fact")) events.append(("shutdown", "fact"))
class FakeChatSummaryWriteback:
async def start(self):
events.append(("start", "summary"))
async def enqueue(self, message):
events.append(("summary", message.session_id))
async def shutdown(self):
events.append(("shutdown", "summary"))
service = memory_flow_module.MemoryAutomationService() service = memory_flow_module.MemoryAutomationService()
service.fact_writeback = FakeFactWriteback() service.fact_writeback = FakeFactWriteback()
service.chat_summary_writeback = FakeChatSummaryWriteback()
await service.on_incoming_message(SimpleNamespace(session_id="session-1"))
await service.on_message_sent(SimpleNamespace(session_id="session-1")) await service.on_message_sent(SimpleNamespace(session_id="session-1"))
await service.shutdown() await service.shutdown()
assert events == [ assert events == [
("start", "fact"), ("start", "fact"),
("start", "summary"),
("sent", "session-1"), ("sent", "session-1"),
("summary", "session-1"),
("shutdown", "summary"),
("shutdown", "fact"),
]
@pytest.mark.asyncio
async def test_memory_automation_service_on_incoming_message_auto_starts_only():
    """An incoming message starts both writeback workers but enqueues nothing.

    The expected event order also shows that shutdown stops the summary worker
    before the fact worker.
    """
    events: list[tuple[str, str]] = []

    def _make_recorder(label: str, enqueue_tag: str):
        """Build a fake writeback worker that logs its lifecycle calls to ``events``."""

        class _Recorder:
            async def start(self):
                events.append(("start", label))

            async def enqueue(self, message):
                events.append((enqueue_tag, message.session_id))

            async def shutdown(self):
                events.append(("shutdown", label))

        return _Recorder()

    service = memory_flow_module.MemoryAutomationService()
    service.fact_writeback = _make_recorder("fact", "sent")
    service.chat_summary_writeback = _make_recorder("summary", "summary")

    await service.on_incoming_message(SimpleNamespace(session_id="session-1"))
    await service.shutdown()

    assert events == [
        ("start", "fact"),
        ("start", "summary"),
        ("shutdown", "summary"),
        ("shutdown", "fact"),
    ]

View File

@@ -1,5 +1,6 @@
"""发送服务回归测试。""" """发送服务回归测试。"""
import sys
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any, Dict, List from typing import Any, Dict, List
@@ -182,6 +183,75 @@ async def test_text_to_stream_with_message_returns_sent_message(monkeypatch: pyt
assert stored_messages[0].message_id == "real-message-id" assert stored_messages[0].message_id == "real-message-id"
@pytest.mark.asyncio
async def test_text_to_stream_with_message_triggers_memory_and_syncs_maisaka_history(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A successful send notifies memory automation and appends to the Maisaka history.

    The platform IO manager, session lookup, DB storage, memory automation
    service and heartflow manager are all replaced by fakes. The test then
    checks that the sent message carries the external message id and that
    both the memory hook and the history hook saw it exactly once.
    """
    sent_receipt = SimpleNamespace(
        driver_id="plugin.qq.sender",
        external_message_id="real-message-id",
        metadata={},
    )
    delivery_batch = SimpleNamespace(
        has_success=True,
        sent_receipts=[sent_receipt],
        failed_receipts=[],
        route_key=SimpleNamespace(platform="qq"),
    )
    fake_manager = _FakePlatformIOManager(delivery_batch=delivery_batch)

    stored_messages: List[Any] = []
    memory_events: List[str] = []
    history_events: List[tuple[str, str]] = []

    class FakeMemoryAutomationService:
        """Records the id of every message reported as sent."""

        async def on_message_sent(self, message: Any) -> None:
            memory_events.append(str(message.message_id))

    class FakeRuntime:
        """Records (message id, source kind) for every history append."""

        def append_sent_message_to_chat_history(self, message: Any, *, source_kind: str = "guided_reply") -> None:
            history_events.append((str(message.message_id), source_kind))

    def fake_get_session(stream_id):
        # Only the session used by this test resolves to a stream.
        if stream_id == "test-session":
            return _build_private_stream()
        return None

    def fake_store_message(message):
        stored_messages.append(message)

    monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
    monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
    monkeypatch.setattr(send_service._chat_manager, "get_session_by_session_id", fake_get_session)
    monkeypatch.setattr(send_service.MessageUtils, "store_message_to_db", fake_store_message)

    # Stand-in modules are registered in sys.modules under the two dotted
    # names so imports of those names resolve to the fakes.
    fake_memory_module = SimpleNamespace(memory_automation_service=FakeMemoryAutomationService())
    monkeypatch.setitem(sys.modules, "src.services.memory_flow_service", fake_memory_module)
    fake_heartflow_module = SimpleNamespace(
        heartflow_manager=SimpleNamespace(heartflow_chat_list={"test-session": FakeRuntime()})
    )
    monkeypatch.setitem(sys.modules, "src.chat.heart_flow.heartflow_manager", fake_heartflow_module)

    sent_message = await send_service.text_to_stream_with_message(
        text="你好",
        stream_id="test-session",
        sync_to_maisaka_history=True,
        maisaka_source_kind="guided_reply",
    )

    assert sent_message is not None
    assert sent_message.message_id == "real-message-id"
    assert fake_manager.ensure_calls == 1
    assert len(stored_messages) == 1
    assert stored_messages[0].message_id == "real-message-id"
    assert memory_events == ["real-message-id"]
    assert history_events == [("real-message-id", "guided_reply")]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None: async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None:
fake_manager = _FakePlatformIOManager( fake_manager = _FakePlatformIOManager(

View File

@@ -82,6 +82,7 @@ def test_resolve_static_path_prefers_installed_dashboard_package(monkeypatch, tm
def test_resolve_static_path_uses_dashboard_dist(monkeypatch, tmp_path) -> None: def test_resolve_static_path_uses_dashboard_dist(monkeypatch, tmp_path) -> None:
dashboard_dist = tmp_path / "dashboard" / "dist" dashboard_dist = tmp_path / "dashboard" / "dist"
dashboard_dist.mkdir(parents=True) dashboard_dist.mkdir(parents=True)
(dashboard_dist / "index.html").write_text("<html></html>", encoding="utf-8")
monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path) monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)
@@ -91,6 +92,26 @@ def test_resolve_static_path_uses_dashboard_dist(monkeypatch, tmp_path) -> None:
assert resolved_path == dashboard_dist assert resolved_path == dashboard_dist
def test_resolve_static_path_falls_back_to_package_when_dashboard_dist_has_no_index(monkeypatch, tmp_path) -> None:
    """A ``dashboard/dist`` directory without ``index.html`` must not be chosen.

    The project-local dist directory is created empty, and the patched
    ``import_module`` returns a fake dashboard package that reports its own
    dist path. That packaged path is the expected result.
    """
    project_dist = tmp_path / "dashboard" / "dist"
    project_dist.mkdir(parents=True)  # intentionally left without index.html

    package_dist = tmp_path / "site-packages" / "maibot_dashboard" / "dist"
    package_dist.mkdir(parents=True)

    class _DashboardModule:
        """Minimal stand-in for the installed dashboard package."""

        @staticmethod
        def get_dist_path() -> Path:
            return package_dist

    monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)

    fake_module = _DashboardModule()
    with patch.object(webui_app, "import_module", return_value=fake_module):
        resolved_path = webui_app._resolve_static_path()

    assert resolved_path == package_dist
def test_resolve_safe_static_file_path_allows_regular_static_file(tmp_path) -> None: def test_resolve_safe_static_file_path_allows_regular_static_file(tmp_path) -> None:
static_path = tmp_path / "dist" static_path = tmp_path / "dist"
asset_path = static_path / "assets" / "app.js" asset_path = static_path / "assets" / "app.js"

View File

@@ -99,13 +99,15 @@ def test_maisaka_is_host_tab_and_mcp_is_attached_to_it():
assert mcp_schema.get("uiParent") == "maisaka" assert mcp_schema.get("uiParent") == "maisaka"
def test_maisaka_memory_query_config_fields_are_exposed(): def test_memory_query_config_fields_are_exposed():
"""MaiSaka 长期记忆检索开关和默认条数应出现在配置 schema 中。""" """query_memory 开关和默认条数应出现在记忆配置 schema 中。"""
schema = ConfigSchemaGenerator.generate_schema(Config) schema = ConfigSchemaGenerator.generate_schema(Config)
maisaka_schema = schema["nested"]["maisaka"] memory_schema = schema["nested"]["memory"]
enable_field = next(field for field in maisaka_schema["fields"] if field["name"] == "enable_memory_query_tool") assert memory_schema.get("uiParent") == "emoji"
limit_field = next(field for field in maisaka_schema["fields"] if field["name"] == "memory_query_default_limit")
enable_field = next(field for field in memory_schema["fields"] if field["name"] == "enable_memory_query_tool")
limit_field = next(field for field in memory_schema["fields"] if field["name"] == "memory_query_default_limit")
assert enable_field["type"] == "boolean" assert enable_field["type"] == "boolean"
assert enable_field.get("x-widget") == "switch" assert enable_field.get("x-widget") == "switch"

View File

@@ -638,3 +638,41 @@ def test_delete_operation_routes(client: TestClient, monkeypatch):
assert list_response.json()["count"] == 1 assert list_response.json()["count"] == 1
assert get_response.status_code == 200 assert get_response.status_code == 200
assert get_response.json()["operation"]["operation_id"] == "del-1" assert get_response.json()["operation"]["operation_id"] == "del-1"
def test_feedback_correction_routes(client: TestClient, monkeypatch):
    """The list, detail and rollback routes forward to ``memory_service.feedback_admin``.

    For every action the fake checks the exact keyword arguments it is called
    with and returns a canned payload. The HTTP responses are then checked
    against those payloads.
    """
    # action -> (keyword arguments the route must pass, canned response)
    expectations = {
        "list": (
            {
                "limit": 7,
                "statuses": ["applied"],
                "rollback_statuses": ["none"],
                "query": "green",
            },
            {"success": True, "items": [{"task_id": 11, "query_text": "what color"}], "count": 1},
        ),
        "get": (
            {"task_id": 11},
            {"success": True, "task": {"task_id": 11, "query_text": "what color", "action_logs": []}},
        ),
        "rollback": (
            {"task_id": 11, "requested_by": "tester", "reason": "manual revert"},
            {"success": True, "result": {"restored_relation_hashes": ["rel-1"]}},
        ),
    }

    async def fake_feedback_admin(*, action: str, **kwargs):
        if action not in expectations:
            raise AssertionError(action)
        expected_kwargs, canned_response = expectations[action]
        assert kwargs == expected_kwargs
        return canned_response

    monkeypatch.setattr(memory_router_module.memory_service, "feedback_admin", fake_feedback_admin)

    base_url = "/api/webui/memory/feedback-corrections"
    list_response = client.get(
        base_url,
        params={"limit": 7, "status": "applied", "rollback_status": "none", "query": "green"},
    )
    get_response = client.get(base_url + "/11")
    rollback_response = client.post(
        base_url + "/11/rollback",
        json={"requested_by": "tester", "reason": "manual revert"},
    )

    assert list_response.status_code == 200
    assert list_response.json()["items"][0]["task_id"] == 11
    assert get_response.status_code == 200
    assert get_response.json()["task"]["task_id"] == 11
    assert rollback_response.status_code == 200
    assert rollback_response.json()["result"]["restored_relation_hashes"] == ["rel-1"]

View File

@@ -8,7 +8,7 @@
- 运行时主目录由 `storage.data_dir` 决定(当前模板默认 `data/a-memorix` - 运行时主目录由 `storage.data_dir` 决定(当前模板默认 `data/a-memorix`
- 部分离线脚本仍以 `data/plugins/a-dawn.a-memorix` 作为默认处理目录。 - 部分离线脚本仍以 `data/plugins/a-dawn.a-memorix` 作为默认处理目录。
- 修正文档中的导入示例参数,`memory_import_admin.create_paste``input_mode` 示例统一为 `text`/`json` - 修正文档中的导入示例参数,`memory_import_admin.create_paste``input_mode` 示例统一为 `text`/`json`
- 更新 `README.md` 关于元数据 schema 的描述,和当前代码 `SCHEMA_VERSION = 9` 保持一致。 - 更新 `README.md` 关于元数据 schema 的描述,和当前代码 `SCHEMA_VERSION = 10` 保持一致。
## [2.0.0] - 2026-03-18 ## [2.0.0] - 2026-03-18

View File

@@ -1,6 +1,6 @@
# A_Memorix 配置参考 (v2.0.0) # A_Memorix 配置参考 (v2.0.0)
本文档对应当前仓库代码(`__version__ = 2.0.0``SCHEMA_VERSION = 9`)。 本文档对应当前仓库代码(`__version__ = 2.0.0``SCHEMA_VERSION = 10`)。
说明: 说明:

View File

@@ -11,7 +11,7 @@ from __future__ import annotations
import asyncio import asyncio
import time import time
from typing import Any, List, Optional, Union from typing import Any, Dict, List, Optional, Tuple, Union
import aiohttp import aiohttp
import numpy as np import numpy as np
@@ -29,6 +29,9 @@ logger = get_logger("A_Memorix.EmbeddingAPIAdapter")
class EmbeddingAPIAdapter: class EmbeddingAPIAdapter:
"""适配宿主 embedding 请求接口。""" """适配宿主 embedding 请求接口。"""
_GLOBAL_DIMENSION_CACHE: Dict[str, int] = {}
_GLOBAL_TEXT_EMBEDDING_CACHE: Dict[Tuple[str, int, str], np.ndarray] = {}
def __init__( def __init__(
self, self,
batch_size: int = 32, batch_size: int = 32,
@@ -232,10 +235,32 @@ class EmbeddingAPIAdapter:
logger.error(f"通过直接 Client 获取 Embedding 失败: {last_exc}") logger.error(f"通过直接 Client 获取 Embedding 失败: {last_exc}")
return None return None
def _dimension_cache_key(self) -> str:
candidate_names = self._resolve_candidate_model_names()
return "|".join(
[
str(self.model_name or "auto"),
str(self.default_dimension),
",".join(candidate_names),
]
)
def _embedding_cache_key(self, text: str, dimensions: Optional[int]) -> Tuple[str, int, str]:
requested_dimension = self._resolve_canonical_dimension(dimensions)
return (self._dimension_cache_key(), int(requested_dimension), str(text or ""))
async def _detect_dimension(self) -> int: async def _detect_dimension(self) -> int:
if self._dimension_detected and self._dimension is not None: if self._dimension_detected and self._dimension is not None:
return self._dimension return self._dimension
cache_key = self._dimension_cache_key()
cached_dimension = self._GLOBAL_DIMENSION_CACHE.get(cache_key)
if cached_dimension is not None:
self._dimension = int(cached_dimension)
self._dimension_detected = True
logger.info(f"嵌入维度命中进程缓存: {self._dimension}")
return self._dimension
logger.info("正在检测嵌入模型维度...") logger.info("正在检测嵌入模型维度...")
try: try:
target_dim = self.default_dimension target_dim = self.default_dimension
@@ -251,6 +276,7 @@ class EmbeddingAPIAdapter:
) )
self._dimension = detected_dim self._dimension = detected_dim
self._dimension_detected = True self._dimension_detected = True
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(detected_dim)
return detected_dim return detected_dim
except Exception as exc: except Exception as exc:
logger.debug(f"带维度参数探测失败: {exc},尝试不带维度参数探测") logger.debug(f"带维度参数探测失败: {exc},尝试不带维度参数探测")
@@ -261,6 +287,7 @@ class EmbeddingAPIAdapter:
detected_dim = len(test_embedding) detected_dim = len(test_embedding)
self._dimension = detected_dim self._dimension = detected_dim
self._dimension_detected = True self._dimension_detected = True
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(detected_dim)
logger.info(f"嵌入维度检测成功 (自然维度): {detected_dim}") logger.info(f"嵌入维度检测成功 (自然维度): {detected_dim}")
return detected_dim return detected_dim
logger.warning(f"嵌入维度检测失败,使用 configured_dimension: {self.default_dimension}") logger.warning(f"嵌入维度检测失败,使用 configured_dimension: {self.default_dimension}")
@@ -269,6 +296,7 @@ class EmbeddingAPIAdapter:
self._dimension = self.default_dimension self._dimension = self.default_dimension
self._dimension_detected = True self._dimension_detected = True
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(self.default_dimension)
return self.default_dimension return self.default_dimension
async def encode( async def encode(
@@ -336,26 +364,54 @@ class EmbeddingAPIAdapter:
all_embeddings: List[np.ndarray] = [] all_embeddings: List[np.ndarray] = []
for offset in range(0, len(texts), batch_size): for offset in range(0, len(texts), batch_size):
batch = texts[offset : offset + batch_size] batch = texts[offset : offset + batch_size]
batch_results: List[Tuple[int, np.ndarray]] = []
uncached_items: List[Tuple[int, str]] = []
if self.enable_cache:
for index, text in enumerate(batch):
cache_key = self._embedding_cache_key(text, dimensions)
cached_vector = self._GLOBAL_TEXT_EMBEDDING_CACHE.get(cache_key)
if cached_vector is None:
uncached_items.append((index, text))
else:
batch_results.append((index, cached_vector.copy()))
else:
uncached_items = list(enumerate(batch))
if not uncached_items:
batch_results.sort(key=lambda item: item[0])
all_embeddings.extend(emb for _, emb in batch_results)
continue
semaphore = asyncio.Semaphore(self.max_concurrent) semaphore = asyncio.Semaphore(self.max_concurrent)
async def encode_with_semaphore(text: str, index: int): async def encode_with_semaphore(text: str, batch_index: int, absolute_index: int):
async with semaphore: async with semaphore:
embedding = await self._get_embedding_direct(text, dimensions=dimensions) embedding = await self._get_embedding_direct(text, dimensions=dimensions)
if embedding is None: if embedding is None:
raise RuntimeError(f"文本 {index} 编码失败embedding 返回为空") raise RuntimeError(f"文本 {absolute_index} 编码失败embedding 返回为空")
vector = self._validate_embedding_vector( vector = self._validate_embedding_vector(
embedding, embedding,
source=f"文本 {index}", source=f"文本 {absolute_index}",
) )
return index, vector return batch_index, vector
tasks = [ tasks = [
encode_with_semaphore(text, offset + index) encode_with_semaphore(text, index, offset + index)
for index, text in enumerate(batch) for index, text in uncached_items
] ]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
results.sort(key=lambda item: item[0]) normalized_results: List[Tuple[int, np.ndarray]] = []
all_embeddings.extend(emb for _, emb in results) for batch_index, vector in results:
normalized_results.append((batch_index, vector))
if self.enable_cache:
text = batch[batch_index]
cache_key = self._embedding_cache_key(text, dimensions)
self._GLOBAL_TEXT_EMBEDDING_CACHE[cache_key] = vector.copy()
batch_results.extend(normalized_results)
batch_results.sort(key=lambda item: item[0])
all_embeddings.extend(emb for _, emb in batch_results)
return np.array(all_embeddings, dtype=np.float32) return np.array(all_embeddings, dtype=np.float32)

View File

@@ -632,7 +632,7 @@ class DualPathRetriever:
results: List[RetrievalResult] = [] results: List[RetrievalResult] = []
for row in rows: for row in rows:
hash_value = row["hash"] hash_value = row["hash"]
relation = self.metadata_store.get_relation(hash_value) relation = self.metadata_store.get_relation(hash_value, include_inactive=False)
if relation is None: if relation is None:
continue continue
@@ -888,8 +888,8 @@ class DualPathRetriever:
entity_name = entity["name"] entity_name = entity["name"]
related_rels = [] related_rels = []
related_rels.extend(self.metadata_store.get_relations(subject=entity_name)) related_rels.extend(self.metadata_store.get_relations(subject=entity_name, include_inactive=False))
related_rels.extend(self.metadata_store.get_relations(object=entity_name)) related_rels.extend(self.metadata_store.get_relations(object=entity_name, include_inactive=False))
for rel in related_rels: for rel in related_rels:
if rel["hash"] in seen_relations: if rel["hash"] in seen_relations:
@@ -1280,7 +1280,7 @@ class DualPathRetriever:
results = [] results = []
for hash_value, score in zip(rel_ids, rel_scores): for hash_value, score in zip(rel_ids, rel_scores):
relation = self.metadata_store.get_relation(hash_value) relation = self.metadata_store.get_relation(hash_value, include_inactive=False)
if relation is None: if relation is None:
continue continue
@@ -1378,7 +1378,7 @@ class DualPathRetriever:
deduplicated_results.append(result) deduplicated_results.append(result)
continue continue
# 检查关系关联的段落是否已存在 # 检查关系关联的段落是否已存在
relation = self.metadata_store.get_relation(result.hash_value) relation = self.metadata_store.get_relation(result.hash_value, include_inactive=False)
if relation: if relation:
# 获取关联的段落 # 获取关联的段落
para_rels = self.metadata_store.query(""" para_rels = self.metadata_store.query("""

View File

@@ -255,7 +255,7 @@ class GraphRelationRecallService:
graph_hops: int, graph_hops: int,
graph_seed_entities: Sequence[str], graph_seed_entities: Sequence[str],
) -> Optional[GraphRelationCandidate]: ) -> Optional[GraphRelationCandidate]:
relation = self.metadata_store.get_relation(relation_hash) relation = self.metadata_store.get_relation(relation_hash, include_inactive=False)
if relation is None: if relation is None:
return None return None
supporting_paragraphs = self.metadata_store.get_paragraphs_by_relation(relation_hash) supporting_paragraphs = self.metadata_store.get_paragraphs_by_relation(relation_hash)

View File

@@ -338,6 +338,7 @@ class SparseBM25Index:
match_query=match_query, match_query=match_query,
limit=max(1, int(k)), limit=max(1, int(k)),
max_doc_len=self.config.relation_max_doc_len, max_doc_len=self.config.relation_max_doc_len,
include_inactive=False,
conn=self._conn, conn=self._conn,
) )
out: List[Dict[str, Any]] = [] out: List[Dict[str, Any]] = []

File diff suppressed because it is too large Load Diff

View File

@@ -1190,11 +1190,14 @@ class GraphStore:
data_dir.mkdir(parents=True, exist_ok=True) data_dir.mkdir(parents=True, exist_ok=True)
# 保存邻接矩阵 # 保存邻接矩阵
matrix_path = data_dir / "graph_adjacency.npz"
if self._adjacency is not None: if self._adjacency is not None:
matrix_path = data_dir / "graph_adjacency.npz"
with atomic_write(matrix_path, "wb") as f: with atomic_write(matrix_path, "wb") as f:
save_npz(f, self._adjacency) save_npz(f, self._adjacency)
logger.debug(f"保存邻接矩阵: {matrix_path}") logger.debug(f"保存邻接矩阵: {matrix_path}")
elif matrix_path.exists():
matrix_path.unlink()
logger.debug(f"删除陈旧邻接矩阵: {matrix_path}")
# 保存元数据 # 保存元数据
metadata = { metadata = {
@@ -1288,9 +1291,29 @@ class GraphStore:
if self._adjacency is not None: if self._adjacency is not None:
adj_n = self._adjacency.shape[0] adj_n = self._adjacency.shape[0]
current_n = len(self._nodes) current_n = len(self._nodes)
if current_n > adj_n: if current_n == 0:
logger.warning("检测到空图元数据但邻接矩阵仍然存在,已重置为空图。")
self._adjacency = None
self._edge_hash_map = defaultdict(set)
elif current_n > adj_n:
logger.warning(f"检测到图存储维度不匹配: 节点数={current_n}, 矩阵大小={adj_n}. 正在自动修复...") logger.warning(f"检测到图存储维度不匹配: 节点数={current_n}, 矩阵大小={adj_n}. 正在自动修复...")
self._expand_adjacency_matrix(current_n - adj_n) self._expand_adjacency_matrix(current_n - adj_n)
elif current_n < adj_n:
logger.warning(
f"检测到过期邻接矩阵: 节点数={current_n}, 矩阵大小={adj_n}. 正在重置邻接矩阵..."
)
if self.matrix_format == "csc":
self._adjacency = csc_matrix((current_n, current_n), dtype=np.float32)
else:
self._adjacency = csr_matrix((current_n, current_n), dtype=np.float32)
self._edge_hash_map = defaultdict(
set,
{
(src_idx, dst_idx): set(hashes)
for (src_idx, dst_idx), hashes in self._edge_hash_map.items()
if src_idx < current_n and dst_idx < current_n
},
)
self._adjacency_dirty = True self._adjacency_dirty = True
logger.info( logger.info(
@@ -1445,4 +1468,3 @@ class GraphStore:
self._adjacency_dirty = True self._adjacency_dirty = True
logger.info(f"已从 {count} 条哈希重建边哈希映射,覆盖 {len(self._edge_hash_map)} 条边") logger.info(f"已从 {count} 条哈希重建边哈希映射,覆盖 {len(self._edge_hash_map)} 条边")
return count return count

File diff suppressed because it is too large Load Diff

View File

@@ -18,6 +18,7 @@ from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config
from .episode_segmentation_service import EpisodeSegmentationService from .episode_segmentation_service import EpisodeSegmentationService
from .hash import compute_hash from .hash import compute_hash
@@ -528,7 +529,11 @@ class EpisodeService:
"paragraph_count": 0, "paragraph_count": 0,
} }
paragraphs = self.metadata_store.get_live_paragraphs_by_source(token) memory_cfg = getattr(global_config, "memory", None)
paragraphs = self.metadata_store.get_live_paragraphs_by_source(
token,
exclude_stale=bool(getattr(memory_cfg, "feedback_correction_paragraph_hard_filter_enabled", True)),
)
if not paragraphs: if not paragraphs:
replace_result = self.metadata_store.replace_episodes_for_source(token, []) replace_result = self.metadata_store.replace_episodes_for_source(token, [])
return { return {

View File

@@ -90,9 +90,9 @@ def find_paths_between_entities(
else: else:
pred = "related" pred = "related"
direction = "->" direction = "->"
rels = metadata_store.get_relations(subject=u, object=v) rels = metadata_store.get_relations(subject=u, object=v, include_inactive=False)
if not rels: if not rels:
rels = metadata_store.get_relations(subject=v, object=u) rels = metadata_store.get_relations(subject=v, object=u, include_inactive=False)
direction = "<-" direction = "<-"
if rels: if rels:
best_rel = max(rels, key=lambda x: x.get("confidence", 1.0)) best_rel = max(rels, key=lambda x: x.get("confidence", 1.0))
@@ -162,4 +162,3 @@ def to_retrieval_results(paths: Sequence[Dict[str, Any]]) -> List[RetrievalResul
) )
) )
return converted return converted

View File

@@ -15,6 +15,7 @@ from sqlmodel import select
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database import get_db_session from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo from src.common.database.database_model import PersonInfo
from src.config.config import global_config
from ..embedding import EmbeddingAPIAdapter from ..embedding import EmbeddingAPIAdapter
from ..retrieval import ( from ..retrieval import (
@@ -285,11 +286,11 @@ class PersonProfileService:
def _collect_relation_evidence(self, aliases: List[str], limit: int = 30) -> List[Dict[str, Any]]: def _collect_relation_evidence(self, aliases: List[str], limit: int = 30) -> List[Dict[str, Any]]:
relation_by_hash: Dict[str, Dict[str, Any]] = {} relation_by_hash: Dict[str, Dict[str, Any]] = {}
for alias in aliases: for alias in aliases:
for rel in self.metadata_store.get_relations(subject=alias): for rel in self.metadata_store.get_relations(subject=alias, include_inactive=False):
h = str(rel.get("hash", "")) h = str(rel.get("hash", ""))
if h: if h:
relation_by_hash[h] = rel relation_by_hash[h] = rel
for rel in self.metadata_store.get_relations(object=alias): for rel in self.metadata_store.get_relations(object=alias, include_inactive=False):
h = str(rel.get("hash", "")) h = str(rel.get("hash", ""))
if h: if h:
relation_by_hash[h] = rel relation_by_hash[h] = rel
@@ -342,7 +343,53 @@ class PersonProfileService:
"metadata": {}, "metadata": {},
} }
) )
return evidence return self._filter_stale_paragraph_evidence(evidence)
def _filter_stale_paragraph_evidence(
self,
evidence: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
memory_cfg = getattr(global_config, "memory", None)
if not bool(getattr(memory_cfg, "feedback_correction_paragraph_hard_filter_enabled", True)):
return evidence
paragraph_hashes = [
str(item.get("hash", "") or "").strip()
for item in evidence
if str(item.get("type", "") or "").strip() == "paragraph" and str(item.get("hash", "") or "").strip()
]
if not paragraph_hashes:
return evidence
marks_by_paragraph = self.metadata_store.get_paragraph_stale_relation_marks_batch(paragraph_hashes)
relation_hashes: List[str] = []
seen = set()
for marks in marks_by_paragraph.values():
for mark in marks:
relation_hash = str(mark.get("relation_hash", "") or "").strip()
if not relation_hash or relation_hash in seen:
continue
seen.add(relation_hash)
relation_hashes.append(relation_hash)
status_map = self.metadata_store.get_relation_status_batch(relation_hashes) if relation_hashes else {}
filtered: List[Dict[str, Any]] = []
for item in evidence:
item_type = str(item.get("type", "") or "").strip()
item_hash = str(item.get("hash", "") or "").strip()
if item_type != "paragraph" or not item_hash:
filtered.append(item)
continue
marks = marks_by_paragraph.get(item_hash, [])
should_hide = any(
status_map.get(str(mark.get("relation_hash", "") or "").strip()) is None
or bool((status_map.get(str(mark.get("relation_hash", "") or "").strip()) or {}).get("is_inactive"))
for mark in marks
if str(mark.get("relation_hash", "") or "").strip()
)
if should_hide:
continue
filtered.append(item)
return filtered
async def _collect_vector_evidence( async def _collect_vector_evidence(
self, self,
@@ -373,7 +420,7 @@ class PersonProfileService:
"metadata": {}, "metadata": {},
} }
) )
return fallback[:top_k] return self._filter_stale_paragraph_evidence(fallback[:top_k])
per_alias_top_k = max(2, int(top_k / max(1, len(alias_queries)))) per_alias_top_k = max(2, int(top_k / max(1, len(alias_queries))))
seen_hash = set() seen_hash = set()
@@ -406,7 +453,7 @@ class PersonProfileService:
} }
) )
evidence.sort(key=lambda x: x.get("score", 0.0), reverse=True) evidence.sort(key=lambda x: x.get("score", 0.0), reverse=True)
return evidence[:top_k] return self._filter_stale_paragraph_evidence(evidence[:top_k])
def _build_profile_text( def _build_profile_text(
self, self,

View File

@@ -5,12 +5,13 @@
导入到 A_memorix 的存储组件中。 导入到 A_memorix 的存储组件中。
""" """
import time from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import json import json
import re import re
import time
import traceback import traceback
from typing import List, Dict, Any, Tuple, Optional
from pathlib import Path
from src.common.logger import get_logger from src.common.logger import get_logger
from src.services import llm_service as llm_api from src.services import llm_service as llm_api
@@ -222,7 +223,9 @@ class SummaryImporter:
self, self,
stream_id: str, stream_id: str,
context_length: Optional[int] = None, context_length: Optional[int] = None,
include_personality: Optional[bool] = None include_personality: Optional[bool] = None,
time_end: Optional[float] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Tuple[bool, str]: ) -> Tuple[bool, str]:
""" """
从指定的聊天流中提取记录并执行总结导入 从指定的聊天流中提取记录并执行总结导入
@@ -231,6 +234,7 @@ class SummaryImporter:
stream_id: 聊天流 ID stream_id: 聊天流 ID
context_length: 总结的历史消息条数 context_length: 总结的历史消息条数
include_personality: 是否包含人设 include_personality: 是否包含人设
time_end: 用于截取聊天记录的时间上界(闭区间)
Returns: Returns:
Tuple[bool, str]: (是否成功, 结果消息) Tuple[bool, str]: (是否成功, 结果消息)
@@ -248,12 +252,13 @@ class SummaryImporter:
include_personality = self.plugin_config.get("summarization", {}).get("include_personality", True) include_personality = self.plugin_config.get("summarization", {}).get("include_personality", True)
# 2. 获取历史消息 # 2. 获取历史消息
# 获取当前时间之前的消息 query_time_end = time.time() if time_end is None else float(time_end)
now = time.time() messages = message_api.get_messages_by_time_in_chat(
messages = message_api.get_messages_before_time_in_chat(
chat_id=stream_id, chat_id=stream_id,
timestamp=now, start_time=0.0,
limit=context_length end_time=query_time_end,
limit=context_length,
limit_mode="latest",
) )
if not messages: if not messages:
@@ -323,7 +328,14 @@ class SummaryImporter:
} }
# 6. 执行导入 # 6. 执行导入
await self._execute_import(summary_text, entities, relations, stream_id, time_meta=time_meta) await self._execute_import(
summary_text,
entities,
relations,
stream_id,
time_meta=time_meta,
metadata=metadata,
)
# 7. 持久化 # 7. 持久化
self.vector_store.save() self.vector_store.save()
@@ -389,6 +401,7 @@ class SummaryImporter:
relations: List[Dict[str, str]], relations: List[Dict[str, str]],
stream_id: str, stream_id: str,
time_meta: Optional[Dict[str, Any]] = None, time_meta: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None,
): ):
"""将数据写入存储""" """将数据写入存储"""
# 获取默认知识类型 # 获取默认知识类型
@@ -403,6 +416,7 @@ class SummaryImporter:
hash_value = self.metadata_store.add_paragraph( hash_value = self.metadata_store.add_paragraph(
content=summary, content=summary,
source=f"chat_summary:{stream_id}", source=f"chat_summary:{stream_id}",
metadata=metadata,
knowledge_type=knowledge_type.value, knowledge_type=knowledge_type.value,
time_meta=time_meta, time_meta=time_meta,
) )

View File

@@ -190,6 +190,16 @@ class AMemorixHostService:
) )
) )
if component_name == "enqueue_feedback_task":
return await kernel.enqueue_feedback_task(
query_tool_id=str(payload.get("query_tool_id", "") or ""),
session_id=str(payload.get("session_id", "") or ""),
query_timestamp=payload.get("query_timestamp"),
structured_content=payload.get("structured_content")
if isinstance(payload.get("structured_content"), dict)
else {},
)
if component_name == "ingest_summary": if component_name == "ingest_summary":
return await kernel.ingest_summary( return await kernel.ingest_summary(
external_id=str(payload.get("external_id", "") or ""), external_id=str(payload.get("external_id", "") or ""),
@@ -251,6 +261,7 @@ class AMemorixHostService:
"memory_source_admin": kernel.memory_source_admin, "memory_source_admin": kernel.memory_source_admin,
"memory_episode_admin": kernel.memory_episode_admin, "memory_episode_admin": kernel.memory_episode_admin,
"memory_profile_admin": kernel.memory_profile_admin, "memory_profile_admin": kernel.memory_profile_admin,
"memory_feedback_admin": kernel.memory_feedback_admin,
"memory_runtime_admin": kernel.memory_runtime_admin, "memory_runtime_admin": kernel.memory_runtime_admin,
"memory_import_admin": kernel.memory_import_admin, "memory_import_admin": kernel.memory_import_admin,
"memory_tuning_admin": kernel.memory_tuning_admin, "memory_tuning_admin": kernel.memory_tuning_admin,

View File

@@ -62,7 +62,10 @@ if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
try: try:
from A_memorix.core.storage import GraphStore, KnowledgeType, MetadataStore, QuantizationType, VectorStore from A_memorix.core.storage import GraphStore, KnowledgeType, MetadataStore, QuantizationType, VectorStore
from A_memorix.core.storage.metadata_store import SCHEMA_VERSION from A_memorix.core.storage.metadata_store import (
RUNTIME_AUTO_MIGRATION_MIN_SCHEMA_VERSION,
SCHEMA_VERSION,
)
except Exception as e: # pragma: no cover except Exception as e: # pragma: no cover
print(f"❌ failed to import storage modules: {e}") print(f"❌ failed to import storage modules: {e}")
raise SystemExit(2) raise SystemExit(2)
@@ -125,6 +128,14 @@ def _sqlite_table_exists(conn: sqlite3.Connection, table: str) -> bool:
return row is not None return row is not None
def _sqlite_column_exists(conn: sqlite3.Connection, table: str, column: str) -> bool:
try:
rows = conn.execute(f"PRAGMA table_info({table})").fetchall()
except Exception:
return False
return any(str(row[1] or "") == str(column or "") for row in rows)
def _collect_hash_alias_conflicts(conn: sqlite3.Connection) -> Dict[str, List[str]]: def _collect_hash_alias_conflicts(conn: sqlite3.Connection) -> Dict[str, List[str]]:
hashes: List[str] = [] hashes: List[str] = []
if _sqlite_table_exists(conn, "relations"): if _sqlite_table_exists(conn, "relations"):
@@ -152,6 +163,8 @@ def _collect_hash_alias_conflicts(conn: sqlite3.Connection) -> Dict[str, List[st
def _collect_invalid_knowledge_types(conn: sqlite3.Connection) -> List[str]: def _collect_invalid_knowledge_types(conn: sqlite3.Connection) -> List[str]:
if not _sqlite_table_exists(conn, "paragraphs"): if not _sqlite_table_exists(conn, "paragraphs"):
return [] return []
if not _sqlite_column_exists(conn, "paragraphs", "knowledge_type"):
return []
allowed = {item.value for item in KnowledgeType} allowed = {item.value for item in KnowledgeType}
rows = conn.execute("SELECT DISTINCT knowledge_type FROM paragraphs").fetchall() rows = conn.execute("SELECT DISTINCT knowledge_type FROM paragraphs").fetchall()
@@ -288,6 +301,14 @@ def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
facts["schema_migrations_exists"] = has_schema_table facts["schema_migrations_exists"] = has_schema_table
has_paragraph_backfill = _sqlite_table_exists(conn, "paragraph_vector_backfill") has_paragraph_backfill = _sqlite_table_exists(conn, "paragraph_vector_backfill")
facts["paragraph_vector_backfill_exists"] = has_paragraph_backfill facts["paragraph_vector_backfill_exists"] = has_paragraph_backfill
has_stale_marks = _sqlite_table_exists(conn, "paragraph_stale_relation_marks")
facts["paragraph_stale_relation_marks_exists"] = has_stale_marks
has_profile_refresh_queue = _sqlite_table_exists(conn, "person_profile_refresh_queue")
facts["person_profile_refresh_queue_exists"] = has_profile_refresh_queue
has_feedback_rollback_status = _sqlite_column_exists(conn, "memory_feedback_tasks", "rollback_status")
facts["memory_feedback_tasks_rollback_status_exists"] = has_feedback_rollback_status
has_feedback_rollback_plan = _sqlite_column_exists(conn, "memory_feedback_tasks", "rollback_plan_json")
facts["memory_feedback_tasks_rollback_plan_exists"] = has_feedback_rollback_plan
if not has_schema_table: if not has_schema_table:
checks.append( checks.append(
CheckItem( CheckItem(
@@ -300,14 +321,28 @@ def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
row = conn.execute("SELECT MAX(version) FROM schema_migrations").fetchone() row = conn.execute("SELECT MAX(version) FROM schema_migrations").fetchone()
version = int(row[0]) if row and row[0] is not None else 0 version = int(row[0]) if row and row[0] is not None else 0
facts["schema_version"] = version facts["schema_version"] = version
runtime_auto_migratable = (
version < SCHEMA_VERSION
and version >= RUNTIME_AUTO_MIGRATION_MIN_SCHEMA_VERSION
)
facts["schema_runtime_auto_migratable"] = runtime_auto_migratable
if version != SCHEMA_VERSION: if version != SCHEMA_VERSION:
checks.append( if runtime_auto_migratable:
CheckItem( checks.append(
"CP-08", CheckItem(
"error", "CP-18",
f"schema version mismatch: current={version}, expected={SCHEMA_VERSION}", "warning",
f"schema version behind runtime target: current={version}, expected={SCHEMA_VERSION}; runtime auto migration will handle this update",
)
)
else:
checks.append(
CheckItem(
"CP-08",
"error",
f"schema version mismatch: current={version}, expected={SCHEMA_VERSION}",
)
) )
)
elif not has_paragraph_backfill: elif not has_paragraph_backfill:
checks.append( checks.append(
CheckItem( CheckItem(
@@ -316,6 +351,30 @@ def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
"paragraph_vector_backfill table missing under current schema version", "paragraph_vector_backfill table missing under current schema version",
) )
) )
elif not has_stale_marks:
checks.append(
CheckItem(
"CP-15",
"error",
"paragraph_stale_relation_marks table missing under current schema version",
)
)
elif not has_profile_refresh_queue:
checks.append(
CheckItem(
"CP-16",
"error",
"person_profile_refresh_queue table missing under current schema version",
)
)
elif not has_feedback_rollback_status or not has_feedback_rollback_plan:
checks.append(
CheckItem(
"CP-17",
"error",
"memory_feedback_tasks rollback columns missing under current schema version",
)
)
if _sqlite_table_exists(conn, "relations"): if _sqlite_table_exists(conn, "relations"):
row = conn.execute("SELECT COUNT(*) FROM relations").fetchone() row = conn.execute("SELECT COUNT(*) FROM relations").fetchone()
@@ -616,6 +675,46 @@ def _verify_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
"paragraph_vector_backfill table missing after migration", "paragraph_vector_backfill table missing after migration",
) )
) )
has_feedback_tasks = _sqlite_table_exists(conn, "memory_feedback_tasks")
facts["memory_feedback_tasks_exists"] = bool(has_feedback_tasks)
if not has_feedback_tasks:
checks.append(
CheckItem(
"CP-15",
"error",
"memory_feedback_tasks table missing after migration",
)
)
has_feedback_logs = _sqlite_table_exists(conn, "memory_feedback_action_logs")
facts["memory_feedback_action_logs_exists"] = bool(has_feedback_logs)
if not has_feedback_logs:
checks.append(
CheckItem(
"CP-16",
"error",
"memory_feedback_action_logs table missing after migration",
)
)
has_feedback_rollback_status = _sqlite_column_exists(conn, "memory_feedback_tasks", "rollback_status")
facts["memory_feedback_tasks_rollback_status_exists"] = bool(has_feedback_rollback_status)
if not has_feedback_rollback_status:
checks.append(
CheckItem(
"CP-17",
"error",
"memory_feedback_tasks.rollback_status missing after migration",
)
)
has_feedback_rollback_plan = _sqlite_column_exists(conn, "memory_feedback_tasks", "rollback_plan_json")
facts["memory_feedback_tasks_rollback_plan_exists"] = bool(has_feedback_rollback_plan)
if not has_feedback_rollback_plan:
checks.append(
CheckItem(
"CP-18",
"error",
"memory_feedback_tasks.rollback_plan_json missing after migration",
)
)
conflicts = _collect_hash_alias_conflicts(conn) conflicts = _collect_hash_alias_conflicts(conn)
invalid_knowledge_types = _collect_invalid_knowledge_types(conn) invalid_knowledge_types = _collect_invalid_knowledge_types(conn)
finally: finally:

View File

@@ -206,6 +206,40 @@ def _migrate_target_item_list(parent: dict[str, Any], key: str) -> bool:
return True return True
def _parse_planner_mode(value: Any) -> Optional[str]:
"""
兼容旧 planner 配置到当前 visual.planner_mode。
"""
if isinstance(value, bool):
return "multimodal" if value else "text"
if not isinstance(value, str):
return None
normalized_value = value.strip().lower()
if normalized_value in {"text", "multimodal", "auto"}:
return normalized_value
return None
def _parse_replyer_mode(value: Any) -> Optional[str]:
"""
兼容旧 replyer 配置到当前 visual.replyer_mode。
"""
if isinstance(value, bool):
return "multimodal" if value else "text"
if not isinstance(value, str):
return None
normalized_value = value.strip().lower()
if normalized_value in {"text", "multimodal", "auto"}:
return normalized_value
if normalized_value == "legacy":
return "text"
return None
def migrate_legacy_bind_env_to_bot_config_dict(data: dict[str, Any]) -> MigrationResult: def migrate_legacy_bind_env_to_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
"""将旧版 `.env` 中的绑定地址迁移到主配置结构。""" """将旧版 `.env` 中的绑定地址迁移到主配置结构。"""
@@ -280,8 +314,16 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
migrated_any = True migrated_any = True
reasons.append("expression.manual_reflect_operator_id_empty") reasons.append("expression.manual_reflect_operator_id_empty")
chat = _as_dict(data.get("chat"))
personality = _as_dict(data.get("personality")) personality = _as_dict(data.get("personality"))
visual = _as_dict(data.get("visual")) visual = _as_dict(data.get("visual"))
if visual is None and (
(personality is not None and "visual_style" in personality)
or (chat is not None and ("multimodal_planner" in chat or "replyer_generator_type" in chat))
):
visual = {}
data["visual"] = visual
if visual is not None and personality is not None and "visual_style" in personality: if visual is not None and personality is not None and "visual_style" in personality:
if "visual_style" not in visual: if "visual_style" not in visual:
visual["visual_style"] = personality["visual_style"] visual["visual_style"] = personality["visual_style"]
@@ -289,14 +331,41 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
migrated_any = True migrated_any = True
reasons.append("personality.visual_style_moved_to_visual.visual_style") reasons.append("personality.visual_style_moved_to_visual.visual_style")
if visual is not None and "multimodal_planner" in visual and "planner_mode" not in visual: if visual is not None and "multimodal_planner" in visual:
multimodal_planner = visual.pop("multimodal_planner") planner_mode = _parse_planner_mode(visual.get("multimodal_planner"))
if isinstance(multimodal_planner, bool): if "planner_mode" not in visual and planner_mode is not None:
visual["planner_mode"] = "multimodal" if multimodal_planner else "text" visual["planner_mode"] = planner_mode
if "planner_mode" in visual:
visual.pop("multimodal_planner", None)
migrated_any = True migrated_any = True
reasons.append("visual.multimodal_planner_moved_to_visual.planner_mode") reasons.append("visual.multimodal_planner_moved_to_visual.planner_mode")
else:
visual["multimodal_planner"] = multimodal_planner if visual is not None and chat is not None and "multimodal_planner" in chat:
planner_mode = _parse_planner_mode(chat.get("multimodal_planner"))
if "planner_mode" not in visual and planner_mode is not None:
visual["planner_mode"] = planner_mode
if "planner_mode" in visual:
chat.pop("multimodal_planner", None)
migrated_any = True
reasons.append("chat.multimodal_planner_moved_to_visual.planner_mode")
if visual is not None and "multimodal_replyer" in visual:
replyer_mode = _parse_replyer_mode(visual.get("multimodal_replyer"))
if "replyer_mode" not in visual and replyer_mode is not None:
visual["replyer_mode"] = replyer_mode
if "replyer_mode" in visual:
visual.pop("multimodal_replyer", None)
migrated_any = True
reasons.append("visual.multimodal_replyer_moved_to_visual.replyer_mode")
if visual is not None and chat is not None and "replyer_generator_type" in chat:
replyer_mode = _parse_replyer_mode(chat.get("replyer_generator_type"))
if "replyer_mode" not in visual and replyer_mode is not None:
visual["replyer_mode"] = replyer_mode
if "replyer_mode" in visual:
chat.pop("replyer_generator_type", None)
migrated_any = True
reasons.append("chat.replyer_generator_type_moved_to_visual.replyer_mode")
memory = _as_dict(data.get("memory")) memory = _as_dict(data.get("memory"))
if memory is not None and _migrate_target_item_list(memory, "global_memory_blacklist"): if memory is not None and _migrate_target_item_list(memory, "global_memory_blacklist"):

View File

@@ -149,10 +149,10 @@ class VisualConfig(ConfigBase):
default="auto", default="auto",
json_schema_extra={ json_schema_extra={
"x-widget": "select", "x-widget": "select",
"x-icon": "git-branch", "x-icon": "image",
}, },
) )
"""规划器模式auto根据模型信息自动选择text为纯文本模式multimodal为多模态模式""" """Planner 视觉模式text 仅文本multimodal 强制多模态auto 按模型能力自动选择"""
replyer_mode: Literal["text", "multimodal", "auto"] = Field( replyer_mode: Literal["text", "multimodal", "auto"] = Field(
default="auto", default="auto",
@@ -161,7 +161,7 @@ class VisualConfig(ConfigBase):
"x-icon": "git-branch", "x-icon": "git-branch",
}, },
) )
"""回复器模式auto根据模型信息自动选择text为纯文本模式multimodal为多模态模式""" """Replyer 视觉模式text 仅文本multimodal 强制多模态auto 按模型能力自动选择"""
visual_style: str = Field( visual_style: str = Field(
default="请用中文描述这张图片的内容。如果有文字请把文字描述概括出来请留意其主题直观感受输出为一段平文本最多30字请注意不要分点就输出一段文本", default="请用中文描述这张图片的内容。如果有文字请把文字描述概括出来请留意其主题直观感受输出为一段平文本最多30字请注意不要分点就输出一段文本",
@@ -239,12 +239,17 @@ class ChatConfig(ConfigBase):
) )
"""Planner 连续被新消息打断的最大次数0 表示不启用打断""" """Planner 连续被新消息打断的最大次数0 表示不启用打断"""
plan_reply_log_max_per_chat: int = Field(
default=1024,
json_schema_extra={
"x-widget": "input",
"x-icon": "file-text",
},
)
"""每个聊天流最大保存的Plan/Reply日志数量超过此数量时会自动删除最老的日志"""
group_chat_prompt: str = Field( group_chat_prompt: str = Field(
default=""" default="你需要控制自己发言的频率,如果是一对一聊天,可以以较均匀的频率发言;如果用户较多,不要每句都回复,控制回复频率,不要回复的太频繁!控制回复的频率,不要每个人的消息都回复。",
你正在qq群里聊天下面是群里正在聊的内容其中包含聊天记录和聊天中的图片。
回复尽量简短一些。最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。请注意把握聊天内容。
不要回复的太频繁!控制回复的频率,不要每个人的消息都回复,只回复你感兴趣的或者主动提及你的。
""",
json_schema_extra={ json_schema_extra={
"x-widget": "textarea", "x-widget": "textarea",
"x-icon": "users", "x-icon": "users",
@@ -253,11 +258,7 @@ class ChatConfig(ConfigBase):
"""_wrap_群聊通用注意事项""" """_wrap_群聊通用注意事项"""
private_chat_prompts: str = Field( private_chat_prompts: str = Field(
default=""" default="你需要控制自己发言的频率,可以以较均匀的频率发言。",
你正在聊天,下面是正在聊的内容,其中包含聊天记录和聊天中的图片。
回复尽量简短一些。请注意把握聊天内容。
请考虑对方的发言频率,想法,思考自己何时回复以及回复内容。
""",
json_schema_extra={ json_schema_extra={
"x-widget": "textarea", "x-widget": "textarea",
"x-icon": "user", "x-icon": "user",
@@ -414,6 +415,228 @@ class MemoryConfig(ConfigBase):
) )
"""Maisaka 内置长期记忆检索工具 query_memory 的默认返回条数""" """Maisaka 内置长期记忆检索工具 query_memory 的默认返回条数"""
person_fact_writeback_enabled: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "user-round-pen",
},
)
"""是否在发送回复后自动提取并写回人物事实到长期记忆"""
chat_summary_writeback_enabled: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "scroll-text",
},
)
"""是否在 Maisaka 聊天过程中按消息窗口自动写回聊天摘要到长期记忆"""
chat_summary_writeback_message_threshold: int = Field(
default=12,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "messages-square",
},
)
"""自动写回聊天摘要的消息窗口阈值"""
chat_summary_writeback_context_length: int = Field(
default=50,
ge=1,
le=500,
json_schema_extra={
"x-widget": "input",
"x-icon": "rows-3",
},
)
"""自动写回聊天摘要时,从聊天流中回看的消息条数"""
feedback_correction_enabled: bool = Field(
default=False,
json_schema_extra={
"x-widget": "switch",
"x-icon": "message-circle-warning",
},
)
"""是否启用反馈驱动的延迟记忆纠错任务"""
feedback_correction_window_hours: float = Field(
default=12.0,
ge=0.1,
json_schema_extra={
"x-widget": "input",
"x-icon": "clock-4",
},
)
"""反馈窗口时长(小时),以 query_memory 执行时间为起点"""
feedback_correction_check_interval_minutes: int = Field(
default=30,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "timer",
},
)
"""反馈纠错定时任务轮询间隔(分钟)"""
feedback_correction_batch_size: int = Field(
default=20,
ge=1,
le=200,
json_schema_extra={
"x-widget": "input",
"x-icon": "list-ordered",
},
)
"""反馈纠错每轮最大处理任务数"""
feedback_correction_auto_apply_threshold: float = Field(
default=0.85,
ge=0.0,
le=1.0,
json_schema_extra={
"x-widget": "slider",
"x-icon": "gauge",
"step": 0.01,
},
)
"""自动应用纠错动作的最低置信度阈值"""
feedback_correction_max_feedback_messages: int = Field(
default=30,
ge=1,
le=200,
json_schema_extra={
"x-widget": "input",
"x-icon": "messages-square",
},
)
"""每个纠错任务最多使用的窗口内用户反馈消息数"""
feedback_correction_prefilter_enabled: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "filter",
},
)
"""是否启用纠错前置预筛(用于减少不必要的模型调用)"""
feedback_correction_paragraph_mark_enabled: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "sticky-note",
},
)
"""是否为受影响 paragraph 写入已纠正旧事实标记"""
feedback_correction_paragraph_hard_filter_enabled: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "eye-off",
},
)
"""是否在用户侧查询中硬过滤带有 stale 标记的 paragraph"""
feedback_correction_profile_refresh_enabled: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "user-round-search",
},
)
"""是否在反馈纠错后将受影响人物画像加入刷新队列"""
feedback_correction_profile_force_refresh_on_read: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "refresh-ccw",
},
)
"""人物画像处于脏队列时,读取是否强制刷新而不直接复用旧快照"""
feedback_correction_episode_rebuild_enabled: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "clapperboard",
},
)
"""是否在反馈纠错后将受影响 source 加入 episode 重建队列"""
feedback_correction_episode_query_block_enabled: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "ban",
},
)
"""episode source 处于重建队列时,是否对用户侧查询做屏蔽"""
feedback_correction_reconcile_interval_minutes: int = Field(
default=5,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "repeat",
},
)
"""反馈纠错二阶段一致性后台协调任务轮询间隔(分钟)"""
feedback_correction_reconcile_batch_size: int = Field(
default=20,
ge=1,
le=200,
json_schema_extra={
"x-widget": "input",
"x-icon": "list-restart",
},
)
"""反馈纠错二阶段一致性每轮处理 profile/episode 队列的批大小"""
def model_post_init(self, context: Optional[dict] = None) -> None:
    """Validate the feedback-correction settings after model construction.

    Each check raises ValueError naming the offending field and its current
    value; when every check passes, control is handed to the parent
    ``model_post_init``.
    """
    window_hours = self.feedback_correction_window_hours
    if window_hours <= 0:
        raise ValueError(f"feedback_correction_window_hours 必须大于0当前值: {window_hours}")

    check_interval = self.feedback_correction_check_interval_minutes
    if check_interval < 1:
        raise ValueError(
            "feedback_correction_check_interval_minutes 必须至少为1" + f"当前值: {check_interval}"
        )

    batch_size = self.feedback_correction_batch_size
    if batch_size < 1:
        raise ValueError(f"feedback_correction_batch_size 必须至少为1当前值: {batch_size}")

    auto_apply = self.feedback_correction_auto_apply_threshold
    if not 0 <= auto_apply <= 1:
        raise ValueError(
            "feedback_correction_auto_apply_threshold 必须在 [0, 1] 之间," + f"当前值: {auto_apply}"
        )

    max_messages = self.feedback_correction_max_feedback_messages
    if max_messages < 1:
        raise ValueError(
            "feedback_correction_max_feedback_messages 必须至少为1" + f"当前值: {max_messages}"
        )

    reconcile_interval = self.feedback_correction_reconcile_interval_minutes
    if reconcile_interval < 1:
        raise ValueError(
            "feedback_correction_reconcile_interval_minutes 必须至少为1" + f"当前值: {reconcile_interval}"
        )

    reconcile_batch = self.feedback_correction_reconcile_batch_size
    if reconcile_batch < 1:
        raise ValueError(
            "feedback_correction_reconcile_batch_size 必须至少为1" + f"当前值: {reconcile_batch}"
        )

    return super().model_post_init(context)
class LearningItem(ConfigBase): class LearningItem(ConfigBase):
@@ -471,15 +694,6 @@ class LearningItem(ConfigBase):
) )
"""是否启用jargon学习""" """是否启用jargon学习"""
advanced_chosen: bool = Field(
default=False,
json_schema_extra={
"x-widget": "switch",
"x-icon": "sparkles",
},
)
"""是否启用基于子代理的二次表达方式选择"""
class ExpressionGroup(ConfigBase): class ExpressionGroup(ConfigBase):
"""表达互通组配置类,若列表为空代表全局共享""" """表达互通组配置类,若列表为空代表全局共享"""
@@ -509,7 +723,6 @@ class ExpressionConfig(ConfigBase):
use_expression=True, use_expression=True,
enable_learning=True, enable_learning=True,
enable_jargon_learning=True, enable_jargon_learning=True,
advanced_chosen=False,
) )
], ],
json_schema_extra={ json_schema_extra={
@@ -1381,6 +1594,35 @@ class MaiSakaConfig(ConfigBase):
) )
"""MaiSaka 使用的用户名称""" """MaiSaka 使用的用户名称"""
tool_filter_task_name: str = Field(
default="utils",
json_schema_extra={
"x-widget": "input",
"x-icon": "sparkles",
},
)
"""工具筛选预判使用的模型任务名"""
tool_filter_threshold: int = Field(
default=20,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "filter",
},
)
"""当可用工具总数超过该阈值时,先进行一轮工具筛选"""
tool_filter_max_keep: int = Field(
default=5,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "list-filter",
},
)
"""工具筛选阶段最多保留的非内置工具数量"""
show_image_path: bool = Field( show_image_path: bool = Field(
default=True, default=True,
json_schema_extra={ json_schema_extra={

View File

@@ -19,6 +19,7 @@ from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvo
from src.llm_models.exceptions import ReqAbortException from src.llm_models.exceptions import ReqAbortException
from src.llm_models.payload_content.tool_option import ToolCall from src.llm_models.payload_content.tool_option import ToolCall
from src.services import database_service as database_api from src.services import database_service as database_api
from src.services.memory_service import memory_service
from .builtin_tool import get_action_tool_specs from .builtin_tool import get_action_tool_specs
from .builtin_tool import build_builtin_tool_handlers as build_split_builtin_tool_handlers from .builtin_tool import build_builtin_tool_handlers as build_split_builtin_tool_handlers
@@ -1123,12 +1124,13 @@ class MaisakaReasoningEngine:
builtin_prompt = tool_spec.build_llm_description() builtin_prompt = tool_spec.build_llm_description()
try: try:
await database_api.store_tool_info( tool_record_payload = self._build_tool_record_payload(invocation, result, tool_spec)
saved_record = await database_api.store_tool_info(
chat_stream=self._runtime.chat_stream, chat_stream=self._runtime.chat_stream,
builtin_prompt=builtin_prompt, builtin_prompt=builtin_prompt,
display_prompt=self._build_tool_display_prompt(invocation, result, tool_spec), display_prompt=self._build_tool_display_prompt(invocation, result, tool_spec),
tool_id=invocation.call_id, tool_id=invocation.call_id,
tool_data=self._build_tool_record_payload(invocation, result, tool_spec), tool_data=tool_record_payload,
tool_name=invocation.tool_name, tool_name=invocation.tool_name,
tool_reasoning=invocation.reasoning, tool_reasoning=invocation.reasoning,
) )
@@ -1136,6 +1138,28 @@ class MaisakaReasoningEngine:
logger.exception( logger.exception(
f"{self._runtime.log_prefix} 写入工具记录失败: 工具={invocation.tool_name} 调用编号={invocation.call_id}" f"{self._runtime.log_prefix} 写入工具记录失败: 工具={invocation.tool_name} 调用编号={invocation.call_id}"
) )
return
if invocation.tool_name == "query_memory" and isinstance(saved_record, dict):
try:
enqueue_payload = await memory_service.enqueue_feedback_task(
query_tool_id=str(saved_record.get("tool_id") or invocation.call_id or "").strip(),
session_id=str(saved_record.get("session_id") or self._runtime.chat_stream.session_id or "").strip(),
query_timestamp=saved_record.get("timestamp"),
structured_content=tool_record_payload.get("structured_content")
if isinstance(tool_record_payload.get("structured_content"), dict)
else {},
)
except Exception:
logger.exception(
f"{self._runtime.log_prefix} 反馈纠错任务入队失败: tool_call_id={invocation.call_id}"
)
else:
if not bool(enqueue_payload.get("success")):
logger.debug(
f"{self._runtime.log_prefix} 反馈纠错任务未入队: "
f"tool_call_id={invocation.call_id} reason={enqueue_payload.get('reason', '')}"
)
def _append_tool_execution_result(self, tool_call: ToolCall, result: ToolExecutionResult) -> None: def _append_tool_execution_result(self, tool_call: ToolCall, result: ToolExecutionResult) -> None:
"""将统一工具执行结果写回 Maisaka 历史。 """将统一工具执行结果写回 Maisaka 历史。
@@ -1316,4 +1340,3 @@ class MaisakaReasoningEngine:
return True, tool_result_summaries, tool_monitor_results return True, tool_result_summaries, tool_monitor_results
return False, tool_result_summaries, tool_monitor_results return False, tool_result_summaries, tool_monitor_results

View File

@@ -1,16 +1,23 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
from typing import Any, List, Optional
import asyncio import asyncio
import json import json
from typing import Any, List, Optional import pickle
import time
from json_repair import repair_json from json_repair import repair_json
from src.services import memory_service as memory_service_module
from src.chat.utils.utils import is_bot_self from src.chat.utils.utils import is_bot_self
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.message_repository import find_messages from src.common.message_repository import count_messages, find_messages
from src.config.config import global_config from src.config.config import global_config
from src.person_info.person_info import Person, get_person_id, store_person_memory_from_answer from src.person_info.person_info import Person, get_person_id, store_person_memory_from_answer
from src.services.memory_service import memory_service
from src.services.llm_service import LLMServiceClient from src.services.llm_service import LLMServiceClient
logger = get_logger("memory_flow_service") logger = get_logger("memory_flow_service")
@@ -210,27 +217,260 @@ class PersonFactWritebackService:
return False return False
@dataclass
class ChatSummaryWritebackState:
    """Per-session cursor remembering when a chat summary was last written."""

    # Total message count of the session at the last successful trigger.
    last_trigger_message_count: int = 0
    # Wall-clock time (time.time()) of that trigger; 0.0 when never triggered.
    last_trigger_time: float = 0.0
class ChatSummaryWritebackService:
    """Queue-driven worker that writes automatic chat summaries to memory.

    Messages are handed in through ``enqueue``; a single asyncio task drains
    the queue and, once a session has accumulated at least the configured
    number of new messages since the last summary, asks ``memory_service``
    to ingest a summary generated from the chat.
    """

    def __init__(self) -> None:
        self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256)
        self._worker_task: Optional[asyncio.Task] = None
        self._stopping = False
        # session_id -> trigger cursor for that session
        self._states: dict[str, ChatSummaryWritebackState] = {}

    async def start(self) -> None:
        """Spawn the worker task unless one is already running."""
        current = self._worker_task
        if current is not None and not current.done():
            return
        self._stopping = False
        self._worker_task = asyncio.create_task(self._worker_loop(), name="memory_chat_summary_writeback")

    async def shutdown(self) -> None:
        """Cancel the worker task and wait for it to finish."""
        self._stopping = True
        worker, self._worker_task = self._worker_task, None
        if worker is None:
            return
        worker.cancel()
        try:
            await worker
        except asyncio.CancelledError:
            # Expected outcome of cancelling the worker.
            pass
        except Exception as exc:
            logger.warning("关闭聊天摘要写回 worker 失败: %s", exc)

    async def enqueue(self, message: Any) -> None:
        """Queue a message for processing; dropped when disabled, stopping or full."""
        enabled = bool(getattr(global_config.memory, "chat_summary_writeback_enabled", True))
        if not enabled or self._stopping:
            return
        try:
            self._queue.put_nowait(message)
        except asyncio.QueueFull:
            logger.warning("聊天摘要写回队列已满,跳过本次触发")

    async def _worker_loop(self) -> None:
        """Drain the queue until stopped; per-message failures are logged, not raised."""
        while not self._stopping:
            message = await self._queue.get()
            try:
                await self._handle_message(message)
            except Exception as exc:
                logger.warning("聊天摘要写回处理失败: %s", exc, exc_info=True)
            finally:
                self._queue.task_done()

    async def _handle_message(self, message: Any) -> None:
        """Write a summary for the message's session once the threshold is reached."""
        session_id = self._resolve_session_id(message)
        if not session_id:
            return

        total = count_messages(session_id=session_id)
        if total <= 0:
            return

        threshold = self._message_threshold()
        state = self._states.get(session_id)
        if state is None:
            # First time this session is seen in this process: recover the cursor.
            restored = await self._load_last_trigger_message_count(
                session_id=session_id,
                total_message_count=total,
            )
            state = ChatSummaryWritebackState(
                last_trigger_message_count=restored,
                last_trigger_time=time.time() if restored > 0 else 0.0,
            )
            self._states[session_id] = state

        if max(0, total - state.last_trigger_message_count) < threshold:
            return

        context_length = self._context_length()
        message_time = self._extract_message_timestamp(message)
        summary_metadata = {
            "generate_from_chat": True,
            "context_length": context_length,
            "writeback_source": "memory_flow_service",
            "trigger": "message_threshold",
            "trigger_message_count": total,
        }
        result = await memory_service.ingest_summary(
            external_id=f"chat_auto_summary:{session_id}:{total}",
            chat_id=session_id,
            text="",
            participants=[],
            time_end=message_time,
            metadata=summary_metadata,
            respect_filter=True,
            user_id=self._extract_session_user_id(message),
            group_id=self._extract_session_group_id(message),
        )

        detail = getattr(result, "detail", "")
        if not getattr(result, "success", False):
            # Cursor is left untouched so a later message can retry.
            logger.warning("聊天摘要自动写回失败: session_id=%s detail=%s", session_id, detail)
            return

        state.last_trigger_message_count = total
        state.last_trigger_time = time.time()
        logger.info(
            "聊天摘要自动写回成功: session_id=%s trigger=%s total_messages=%s context_length=%s detail=%s",
            session_id,
            "message_threshold",
            total,
            context_length,
            detail,
        )

    async def _load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
        """Recover the trigger cursor from persisted chat summaries.

        Prevents an immediate duplicate summary after a restart. Returns 0
        when nothing can be recovered or when any step fails.
        """
        try:
            host_service = getattr(memory_service_module, "a_memorix_host_service", None)
            kernel_factory = getattr(host_service, "_ensure_kernel", None)
            if not callable(kernel_factory):
                return 0
            kernel = await kernel_factory()
            store = getattr(kernel, "metadata_store", None)
            if store is None:
                return 0
            summaries = store.get_paragraphs_by_source(f"chat_summary:{session_id}")
            if not summaries:
                return 0
            newest = max(summaries, key=self._paragraph_created_at)
            recorded = self._coerce_positive_int(
                self._paragraph_metadata(newest).get("trigger_message_count")
            )
            if recorded > 0:
                return min(total_message_count, recorded)
            # Older summaries carry no trigger count: align with the current
            # total so a near-identical summary is not written right away.
            return total_message_count
        except Exception as exc:
            logger.debug("恢复聊天摘要写回游标失败: session_id=%s error=%s", session_id, exc)
            return 0

    @staticmethod
    def _paragraph_created_at(paragraph: dict[str, Any]) -> float:
        """Return the paragraph's ``created_at`` as float, or 0.0 if unusable."""
        try:
            return float(paragraph.get("created_at") or 0.0)
        except Exception:
            return 0.0

    @staticmethod
    def _paragraph_metadata(paragraph: dict[str, Any]) -> dict[str, Any]:
        """Return the paragraph metadata as a dict, decoding pickled bytes if needed."""
        raw = paragraph.get("metadata")
        if isinstance(raw, dict):
            return raw
        if not isinstance(raw, (bytes, bytearray)):
            return {}
        try:
            # NOTE(review): pickle.loads executes arbitrary code on crafted
            # input; this is only safe while the metadata bytes are written
            # exclusively by this application's own store.
            decoded = pickle.loads(raw)
        except Exception:
            return {}
        return decoded if isinstance(decoded, dict) else {}

    @staticmethod
    def _coerce_positive_int(value: Any) -> int:
        """Convert ``value`` to a non-negative int; invalid input becomes 0."""
        try:
            return max(0, int(value or 0))
        except Exception:
            return 0

    @staticmethod
    def _resolve_session_id(message: Any) -> str:
        """Read the session id from the message, falling back to ``message.session``."""
        candidate = getattr(message, "session_id", "")
        if not candidate:
            candidate = getattr(getattr(message, "session", None), "session_id", "")
        return str(candidate or "").strip()

    @staticmethod
    def _extract_session_user_id(message: Any) -> str:
        """Read the user id, preferring ``message.session`` over the message itself."""
        candidate = getattr(getattr(message, "session", None), "user_id", "")
        if not candidate:
            candidate = getattr(message, "user_id", "")
        return str(candidate or "").strip()

    @staticmethod
    def _extract_session_group_id(message: Any) -> str:
        """Read the group id, preferring ``message.session`` over the message itself."""
        candidate = getattr(getattr(message, "session", None), "group_id", "")
        if not candidate:
            candidate = getattr(message, "group_id", "")
        return str(candidate or "").strip()

    @staticmethod
    def _extract_message_timestamp(message: Any) -> float | None:
        """Return the message timestamp as epoch seconds, or None if unavailable."""
        raw = getattr(message, "timestamp", None)
        if isinstance(raw, datetime):
            return raw.timestamp()
        # Duck-typed datetime-like objects exposing a callable ``timestamp``.
        to_epoch = getattr(raw, "timestamp", None)
        if callable(to_epoch):
            try:
                return float(to_epoch())
            except Exception:
                return None
        if isinstance(raw, (int, float)):
            return float(raw)
        return None

    @staticmethod
    def _message_threshold() -> int:
        """Configured number of new messages required to trigger a summary (min 1)."""
        configured = getattr(global_config.memory, "chat_summary_writeback_message_threshold", 12)
        return max(1, int(configured or 12))

    @staticmethod
    def _context_length() -> int:
        """Configured number of messages to look back when summarising (min 1)."""
        configured = getattr(global_config.memory, "chat_summary_writeback_context_length", 50)
        return max(1, int(configured or 50))
class MemoryAutomationService: class MemoryAutomationService:
def __init__(self) -> None: def __init__(self) -> None:
self.fact_writeback = PersonFactWritebackService() self.fact_writeback = PersonFactWritebackService()
self.chat_summary_writeback = ChatSummaryWritebackService()
self._started = False self._started = False
async def start(self) -> None: async def start(self) -> None:
if self._started: if self._started:
return return
await self.fact_writeback.start() await self.fact_writeback.start()
await self.chat_summary_writeback.start()
self._started = True self._started = True
async def shutdown(self) -> None: async def shutdown(self) -> None:
if not self._started: if not self._started:
return return
await self.chat_summary_writeback.shutdown()
await self.fact_writeback.shutdown() await self.fact_writeback.shutdown()
self._started = False self._started = False
async def on_incoming_message(self, message: Any) -> None:
    """Make sure the automation services are running; the message itself is unused."""
    del message
    if self._started:
        return
    await self.start()
async def on_message_sent(self, message: Any) -> None: async def on_message_sent(self, message: Any) -> None:
if not self._started: if not self._started:
await self.start() await self.start()
await self.fact_writeback.enqueue(message) await self.fact_writeback.enqueue(message)
await self.chat_summary_writeback.enqueue(message)
memory_automation_service = MemoryAutomationService() memory_automation_service = MemoryAutomationService()

View File

@@ -233,6 +233,30 @@ class MemoryService:
logger.warning("长期记忆搜索失败: %s", exc) logger.warning("长期记忆搜索失败: %s", exc)
return MemorySearchResult(success=False, error=str(exc)) return MemorySearchResult(success=False, error=str(exc))
async def enqueue_feedback_task(
    self,
    *,
    query_tool_id: str,
    session_id: str,
    query_timestamp: Any = None,
    structured_content: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Submit a feedback-correction task for a ``query_memory`` tool call.

    String ids are normalised (None becomes "", surrounding whitespace is
    stripped) and a non-dict ``structured_content`` is replaced by ``{}``
    before the "enqueue_feedback_task" invocation is made.

    Returns:
        The invocation payload when it is a dict. On any exception, or when
        the payload is not a dict, a ``{"success": False, "queued": False,
        "reason": ...}`` dict is returned instead of raising.
    """
    try:
        request_args = {
            "query_tool_id": str(query_tool_id or "").strip(),
            "session_id": str(session_id or "").strip(),
            "query_timestamp": query_timestamp,
            "structured_content": structured_content if isinstance(structured_content, dict) else {},
        }
        payload = await self._invoke("enqueue_feedback_task", request_args, timeout_ms=10000)
    except Exception as exc:
        logger.warning("反馈纠错任务入队失败: %s", exc)
        return {"success": False, "queued": False, "reason": str(exc)}

    if isinstance(payload, dict):
        return payload
    return {"success": False, "queued": False, "reason": "invalid_payload"}
async def ingest_summary( async def ingest_summary(
self, self,
*, *,
@@ -388,6 +412,13 @@ class MemoryService:
logger.warning("画像管理调用失败: %s", exc) logger.warning("画像管理调用失败: %s", exc)
return {"success": False, "error": str(exc)} return {"success": False, "error": str(exc)}
async def feedback_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
    """Forward a feedback-correction admin action to "memory_feedback_admin".

    Exceptions are logged and converted to ``{"success": False, "error": ...}``
    rather than propagated.
    """
    try:
        response = await self._invoke_admin("memory_feedback_admin", action=action, **kwargs)
    except Exception as exc:
        logger.warning("反馈纠错管理调用失败: %s", exc)
        return {"success": False, "error": str(exc)}
    else:
        return response
async def runtime_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: async def runtime_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
try: try:
return await self._invoke_admin("memory_runtime_admin", action=action, **kwargs) return await self._invoke_admin("memory_runtime_admin", action=action, **kwargs)

View File

@@ -205,6 +205,12 @@ def _setup_static_files(app: FastAPI):
def _resolve_static_path() -> Path | None: def _resolve_static_path() -> Path | None:
# 开发环境优先允许复用仓库里的现成 dist
base_dir = _get_project_root()
static_path = base_dir / "dashboard" / "dist"
if static_path.is_dir() and (static_path / "index.html").exists():
return static_path
try: try:
module = import_module("maibot_dashboard") module = import_module("maibot_dashboard")
get_dist_path = getattr(module, "get_dist_path", None) get_dist_path = getattr(module, "get_dist_path", None)
@@ -215,11 +221,6 @@ def _resolve_static_path() -> Path | None:
except Exception: except Exception:
pass pass
# 开发环境允许复用仓库里的现成 dist但不再在用户机器上触发任何前端自愈构建。
base_dir = _get_project_root()
static_path = base_dir / "dashboard" / "dist"
if static_path.exists():
return static_path
return None return None

View File

@@ -124,6 +124,11 @@ class DeletePurgeRequest(BaseModel):
limit: int = Field(1000, ge=1, le=5000) limit: int = Field(1000, ge=1, le=5000)
class FeedbackRollbackRequest(BaseModel):
    """Request body for rolling back a feedback correction task."""

    # Who asked for the rollback; defaults to "webui" when the field is omitted.
    requested_by: str = "webui"
    # Free-text reason for the rollback; empty when omitted.
    reason: str = ""
def _build_import_guide_markdown(settings: dict[str, Any]) -> str: def _build_import_guide_markdown(settings: dict[str, Any]) -> str:
path_aliases_raw = settings.get("path_aliases") path_aliases_raw = settings.get("path_aliases")
path_aliases = path_aliases_raw if isinstance(path_aliases_raw, dict) else {} path_aliases = path_aliases_raw if isinstance(path_aliases_raw, dict) else {}
@@ -359,6 +364,31 @@ async def _profile_delete_override(person_id: str) -> dict:
return await memory_service.profile_admin(action="delete_override", person_id=person_id) return await memory_service.profile_admin(action="delete_override", person_id=person_id)
async def _feedback_list(limit: int, status: str, rollback_status: str, query: str) -> dict:
statuses = [item.strip() for item in str(status or "").split(",") if item.strip()]
rollback_statuses = [item.strip() for item in str(rollback_status or "").split(",") if item.strip()]
return await memory_service.feedback_admin(
action="list",
limit=limit,
statuses=statuses,
rollback_statuses=rollback_statuses,
query=query,
)
async def _feedback_get(task_id: int) -> dict:
return await memory_service.feedback_admin(action="get", task_id=task_id)
async def _feedback_rollback(task_id: int, payload: FeedbackRollbackRequest) -> dict:
    """Request a rollback of a feedback correction task.

    Args:
        task_id: Identifier of the task to roll back.
        payload: Request body; its ``requested_by`` and ``reason`` fields are
            forwarded to the admin call.

    Returns:
        The result of ``memory_service.feedback_admin(action="rollback", ...)``.
    """
    return await memory_service.feedback_admin(
        action="rollback",
        task_id=task_id,
        requested_by=payload.requested_by,
        reason=payload.reason,
    )
async def _runtime_save() -> dict: async def _runtime_save() -> dict:
return await memory_service.runtime_admin(action="save") return await memory_service.runtime_admin(action="save")
@@ -830,6 +860,26 @@ async def delete_memory_profile_override(person_id: str):
return await _profile_delete_override(person_id) return await _profile_delete_override(person_id)
@router.get("/feedback-corrections")
async def list_memory_feedback_corrections(
    limit: int = Query(50, ge=1, le=200),
    status: str = Query(""),
    rollback_status: str = Query(""),
    query: str = Query(""),
):
    """GET handler that lists feedback correction tasks.

    ``limit`` is constrained to 1..200 (default 50). ``status``,
    ``rollback_status`` and ``query`` default to empty strings and are handed
    to ``_feedback_list`` unchanged.
    """
    return await _feedback_list(limit, status, rollback_status, query)
@router.get("/feedback-corrections/{task_id}")
async def get_memory_feedback_correction(task_id: int):
    """GET handler returning one feedback correction task via ``_feedback_get``."""
    return await _feedback_get(task_id)
@router.post("/feedback-corrections/{task_id}/rollback")
async def rollback_memory_feedback_correction(task_id: int, payload: FeedbackRollbackRequest):
    """POST handler that delegates a rollback request to ``_feedback_rollback``."""
    return await _feedback_rollback(task_id, payload)
@router.post("/runtime/save") @router.post("/runtime/save")
async def save_memory_runtime(): async def save_memory_runtime():
return await _runtime_save() return await _runtime_save()