feat:完善长期记忆控制台导入链路与联调测试

summary:\n- 扩展长期记忆控制台导入、调优与删除相关 UI/接口,补充中文化展示与任务细粒度状态管理\n- 强化 memory API 与后端路由能力,补齐导入任务、图谱检索、配置与运行态相关字段\n- 新增与增强前后端测试,覆盖导入多文件类型、检索、调优、删除及图谱查询关键路径

description:\n- dashboard: 重构 knowledge-base 页面与 memory-api,统一任务队列、分块分页、来源删除恢复、调优闭环交互\n- backend: 扩展 webui memory 路由与 A_Memorix 内核检索逻辑,完善服务侧能力与配置 schema\n- tests: 增加 webui 集成测试和 kernel 单测,提升导入/检索/调优/删除全流程回归保障
This commit is contained in:
DawnARC
2026-04-03 19:50:08 +08:00
parent eac5495d00
commit da95b06f96
18 changed files with 4045 additions and 299 deletions

View File

@@ -53,6 +53,31 @@ export interface MemoryGraphPayload {
total_edges: number
}
export interface MemoryGraphSearchItem {
type: 'entity' | 'relation'
title: string
matched_field: string
matched_value: string
entity_name?: string
entity_hash?: string
appearance_count?: number
subject?: string
predicate?: string
object?: string
relation_hash?: string
confidence?: number
created_at?: number
}
export interface MemoryGraphSearchPayload {
success: boolean
query: string
limit: number
count: number
items: MemoryGraphSearchItem[]
error?: string
}
export interface MemoryGraphRelationDetailPayload {
hash: string
subject: string
@@ -185,6 +210,8 @@ export interface MemoryRawConfigPayload {
success: boolean
config: string
path: string
exists?: boolean
using_default?: boolean
}
export interface MemoryConfigSchemaPayload {
@@ -198,7 +225,7 @@ export interface MemoryImportGuidePayload {
content: string
source?: string
path?: string
settings?: Record<string, unknown>
settings?: MemoryImportSettings
}
export interface MemoryTaskPayload {
@@ -217,6 +244,158 @@ export interface MemoryTaskListPayload {
settings?: Record<string, unknown>
}
export type MemoryImportInputMode = 'text' | 'json'
export type MemoryImportTaskKind =
| 'upload'
| 'paste'
| 'raw_scan'
| 'lpmm_openie'
| 'lpmm_convert'
| 'temporal_backfill'
| 'maibot_migration'
export interface MemoryImportSettings {
max_queue_size?: number
max_files_per_task?: number
max_file_size_mb?: number
max_paste_chars?: number
default_file_concurrency?: number
default_chunk_concurrency?: number
max_file_concurrency?: number
max_chunk_concurrency?: number
poll_interval_ms?: number
maibot_source_db_default?: string
maibot_target_data_dir?: string
path_aliases?: Record<string, string>
llm_retry?: Record<string, number>
convert_enable_staging_switch?: boolean
convert_keep_backup_count?: number
}
export interface MemoryImportSettingsPayload {
success: boolean
settings: MemoryImportSettings
}
export interface MemoryImportPathAliasesPayload {
success: boolean
path_aliases: Record<string, string>
}
export interface MemoryImportResolvePathPayload {
success?: boolean
alias: string
relative_path: string
resolved_path: string
exists: boolean
is_file: boolean
is_dir: boolean
error?: string
}
export interface MemoryImportChunkPayload {
chunk_id: string
index: number
chunk_type: string
status: string
step: string
failed_at: string
retryable: boolean
error: string
progress: number
content_preview: string
updated_at: number
}
export interface MemoryImportFilePayload {
file_id: string
name: string
source_kind: string
input_mode: MemoryImportInputMode
status: string
current_step: string
detected_strategy_type: string
total_chunks: number
done_chunks: number
failed_chunks: number
cancelled_chunks: number
progress: number
error: string
created_at: number
updated_at: number
source_path?: string
content_hash?: string
retry_chunk_indexes?: number[]
retry_mode?: string
chunks?: MemoryImportChunkPayload[]
}
export interface MemoryImportRetrySummary {
chunk_retry_files?: number
chunk_retry_chunks?: number
file_fallback_files?: number
skipped_files?: number
parent_task_id?: string
skipped_details?: Array<Record<string, string>>
}
export interface MemoryImportTaskPayload extends MemoryTaskPayload {
task_id: string
source: string
status: string
current_step: string
total_chunks: number
done_chunks: number
failed_chunks: number
cancelled_chunks: number
progress: number
error: string
file_count: number
created_at: number
started_at?: number | null
finished_at?: number | null
updated_at: number
task_kind?: MemoryImportTaskKind | string
schema_detected?: string
artifact_paths?: Record<string, string>
rollback_info?: Record<string, unknown>
retry_parent_task_id?: string
retry_summary?: MemoryImportRetrySummary
params?: Record<string, unknown>
files?: MemoryImportFilePayload[]
}
export interface MemoryImportTaskListPayload {
success: boolean
items: MemoryImportTaskPayload[]
count?: number
settings?: MemoryImportSettings
}
export interface MemoryImportTaskDetailPayload {
success: boolean
task?: MemoryImportTaskPayload
error?: string
}
export interface MemoryImportChunkListPayload {
success: boolean
task_id?: string
file_id?: string
offset?: number
limit?: number
total?: number
items?: MemoryImportChunkPayload[]
error?: string
}
export interface MemoryImportActionPayload {
success: boolean
task?: MemoryImportTaskPayload
error?: string
}
export interface MemoryTuningProfilePayload {
success: boolean
profile?: Record<string, unknown>
@@ -335,6 +514,17 @@ export async function getMemoryGraph(limit: number = 120): Promise<MemoryGraphPa
return requestJson<MemoryGraphPayload>(`/graph?limit=${limit}`)
}
export async function getMemoryGraphSearch(
query: string,
limit: number = 50,
): Promise<MemoryGraphSearchPayload> {
const params = new URLSearchParams({
query,
limit: String(limit),
})
return requestJson<MemoryGraphSearchPayload>(`/graph/search?${params.toString()}`)
}
export async function getMemoryGraphNodeDetail(
nodeId: string,
options?: {
@@ -466,16 +656,120 @@ export async function getMemoryImportGuide(): Promise<MemoryImportGuidePayload>
return requestJson<MemoryImportGuidePayload>('/import/guide')
}
export async function getMemoryImportSettings(): Promise<Record<string, unknown>> {
return requestJson('/import/settings')
export async function getMemoryImportSettings(): Promise<MemoryImportSettingsPayload> {
return requestJson<MemoryImportSettingsPayload>('/import/settings')
}
export async function getMemoryImportTasks(limit: number = 20): Promise<MemoryTaskListPayload> {
return requestJson<MemoryTaskListPayload>(`/import/tasks?limit=${limit}`)
export async function getMemoryImportPathAliases(): Promise<MemoryImportPathAliasesPayload> {
return requestJson<MemoryImportPathAliasesPayload>('/import/path-aliases')
}
export async function createMemoryPasteImport(payload: Record<string, unknown>): Promise<{ success: boolean; task?: MemoryTaskPayload }> {
return requestJson('/import/paste', {
export async function resolveMemoryImportPath(payload: {
alias: string
relative_path?: string
must_exist?: boolean
}): Promise<MemoryImportResolvePathPayload> {
return requestJson<MemoryImportResolvePathPayload>('/import/resolve-path', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(payload),
})
}
export async function getMemoryImportTasks(limit: number = 20): Promise<MemoryImportTaskListPayload> {
return requestJson<MemoryImportTaskListPayload>(`/import/tasks?limit=${limit}`)
}
export async function getMemoryImportTask(taskId: string, includeChunks: boolean = false): Promise<MemoryImportTaskDetailPayload> {
return requestJson<MemoryImportTaskDetailPayload>(
`/import/tasks/${encodeURIComponent(taskId)}?include_chunks=${includeChunks ? 'true' : 'false'}`,
)
}
export async function getMemoryImportTaskChunks(
taskId: string,
fileId: string,
offset: number = 0,
limit: number = 50,
): Promise<MemoryImportChunkListPayload> {
return requestJson<MemoryImportChunkListPayload>(
`/import/tasks/${encodeURIComponent(taskId)}/chunks/${encodeURIComponent(fileId)}?offset=${offset}&limit=${limit}`,
)
}
export async function createMemoryUploadImport(files: File[], payload: Record<string, unknown>): Promise<MemoryImportActionPayload> {
const formData = new FormData()
files.forEach((file) => {
formData.append('files', file)
})
formData.append('payload_json', JSON.stringify(payload))
return requestJson<MemoryImportActionPayload>('/import/upload', {
method: 'POST',
body: formData,
})
}
export async function createMemoryPasteImport(payload: Record<string, unknown>): Promise<MemoryImportActionPayload> {
return requestJson<MemoryImportActionPayload>('/import/paste', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(payload),
})
}
export async function createMemoryRawScanImport(payload: Record<string, unknown>): Promise<MemoryImportActionPayload> {
return requestJson<MemoryImportActionPayload>('/import/raw-scan', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(payload),
})
}
export async function createMemoryLpmmOpenieImport(payload: Record<string, unknown>): Promise<MemoryImportActionPayload> {
return requestJson<MemoryImportActionPayload>('/import/lpmm-openie', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(payload),
})
}
export async function createMemoryLpmmConvertImport(payload: Record<string, unknown>): Promise<MemoryImportActionPayload> {
return requestJson<MemoryImportActionPayload>('/import/lpmm-convert', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(payload),
})
}
export async function createMemoryTemporalBackfillImport(payload: Record<string, unknown>): Promise<MemoryImportActionPayload> {
return requestJson<MemoryImportActionPayload>('/import/temporal-backfill', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(payload),
})
}
export async function createMemoryMaibotMigrationImport(payload: Record<string, unknown>): Promise<MemoryImportActionPayload> {
return requestJson<MemoryImportActionPayload>('/import/maibot-migration', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(payload),
})
}
export async function cancelMemoryImportTask(taskId: string): Promise<MemoryImportActionPayload> {
return requestJson<MemoryImportActionPayload>(`/import/tasks/${encodeURIComponent(taskId)}/cancel`, {
method: 'POST',
})
}
export async function retryMemoryImportTask(
taskId: string,
payload: {
overrides?: Record<string, unknown>
} = {},
): Promise<MemoryImportActionPayload> {
return requestJson<MemoryImportActionPayload>(`/import/tasks/${encodeURIComponent(taskId)}/retry`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(payload),

View File

@@ -1,4 +1,4 @@
import { render, screen } from '@testing-library/react'
import { act, render, screen, waitFor, within } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { beforeEach, describe, expect, it, vi } from 'vitest'
@@ -28,12 +28,25 @@ vi.mock('@/components/memory/MemoryConfigEditor', () => ({
vi.mock('@/components/memory/MemoryDeleteDialog', () => ({
MemoryDeleteDialog: ({
open,
onExecute,
onRestore,
preview,
result,
}: {
open: boolean
preview?: { mode?: string; item_count?: number } | null
result?: { operation_id?: string } | null
onExecute?: () => void
onRestore?: () => void
}) => (
open ? <div data-testid="memory-delete-dialog">{`delete:${preview?.mode ?? 'none'}:${preview?.item_count ?? 0}`}</div> : null
open ? (
<div data-testid="memory-delete-dialog">
<div>{`preview:${preview?.mode ?? 'none'}:${preview?.item_count ?? 0}`}</div>
<div>{`result:${result?.operation_id ?? 'none'}`}</div>
<button type="button" onClick={onExecute}></button>
<button type="button" onClick={onRestore}></button>
</div>
) : null
),
}))
@@ -41,26 +54,104 @@ vi.mock('@/lib/memory-api', () => ({
getMemoryConfigSchema: vi.fn(),
getMemoryConfig: vi.fn(),
getMemoryConfigRaw: vi.fn(),
getMemoryDeleteOperation: vi.fn(),
getMemoryRuntimeConfig: vi.fn(),
getMemoryImportGuide: vi.fn(),
getMemoryImportSettings: vi.fn(),
getMemoryImportPathAliases: vi.fn(),
getMemoryImportTasks: vi.fn(),
getMemoryTuningProfile: vi.fn(),
getMemoryTuningTasks: vi.fn(),
getMemorySources: vi.fn(),
getMemoryDeleteOperations: vi.fn(),
getMemoryImportTask: vi.fn(),
getMemoryImportTaskChunks: vi.fn(),
createMemoryUploadImport: vi.fn(),
createMemoryPasteImport: vi.fn(),
createMemoryRawScanImport: vi.fn(),
createMemoryLpmmOpenieImport: vi.fn(),
createMemoryLpmmConvertImport: vi.fn(),
createMemoryTemporalBackfillImport: vi.fn(),
createMemoryMaibotMigrationImport: vi.fn(),
cancelMemoryImportTask: vi.fn(),
retryMemoryImportTask: vi.fn(),
resolveMemoryImportPath: vi.fn(),
refreshMemoryRuntimeSelfCheck: vi.fn(),
updateMemoryConfig: vi.fn(),
updateMemoryConfigRaw: vi.fn(),
createMemoryPasteImport: vi.fn(),
getMemoryTuningProfile: vi.fn(),
getMemoryTuningTasks: vi.fn(),
createMemoryTuningTask: vi.fn(),
applyBestMemoryTuningProfile: vi.fn(),
getMemorySources: vi.fn(),
getMemoryDeleteOperations: vi.fn(),
getMemoryDeleteOperation: vi.fn(),
previewMemoryDelete: vi.fn(),
executeMemoryDelete: vi.fn(),
restoreMemoryDelete: vi.fn(),
}))
describe('KnowledgeBasePage', () => {
function mockImportTask(taskId: string, status: string = 'running'): memoryApi.MemoryImportTaskPayload {
return {
task_id: taskId,
source: 'webui',
status,
current_step: status === 'completed' ? 'completed' : 'running',
total_chunks: 120,
done_chunks: status === 'completed' ? 120 : 36,
failed_chunks: status === 'completed' ? 0 : 2,
cancelled_chunks: 0,
progress: status === 'completed' ? 100 : 30,
error: '',
file_count: 2,
created_at: 1_710_000_000,
started_at: 1_710_000_001,
finished_at: status === 'completed' ? 1_710_000_099 : null,
updated_at: 1_710_000_100,
task_kind: 'paste',
params: {},
files: [],
}
}
function mockImportDetail(taskId: string): memoryApi.MemoryImportTaskPayload {
return {
...mockImportTask(taskId),
files: [
{
file_id: 'file-alpha',
name: 'alpha.txt',
source_kind: 'paste',
input_mode: 'text',
status: 'running',
current_step: 'running',
detected_strategy_type: 'auto',
total_chunks: 80,
done_chunks: 30,
failed_chunks: 1,
cancelled_chunks: 0,
progress: 37.5,
error: '',
created_at: 1_710_000_000,
updated_at: 1_710_000_100,
},
{
file_id: 'file-beta',
name: 'beta.txt',
source_kind: 'paste',
input_mode: 'text',
status: 'failed',
current_step: 'extracting',
detected_strategy_type: 'auto',
total_chunks: 40,
done_chunks: 6,
failed_chunks: 4,
cancelled_chunks: 0,
progress: 25,
error: 'mock error',
created_at: 1_710_000_000,
updated_at: 1_710_000_100,
},
],
}
}
describe('KnowledgeBasePage import workflow', () => {
beforeEach(() => {
navigateMock.mockReset()
toastMock.mockReset()
@@ -119,14 +210,113 @@ describe('KnowledgeBasePage', () => {
paragraph_vector_backfill_failed: 1,
paragraph_vector_backfill_done: 3,
})
vi.mocked(memoryApi.getMemoryImportGuide).mockResolvedValue({
success: true,
content: '# 导入指南\n导入说明',
})
vi.mocked(memoryApi.getMemoryImportSettings).mockResolvedValue({
success: true,
settings: {
max_paste_chars: 200_000,
max_file_concurrency: 8,
max_chunk_concurrency: 16,
default_file_concurrency: 2,
default_chunk_concurrency: 4,
poll_interval_ms: 60_000,
maibot_source_db_default: 'data/maibot.db',
},
})
vi.mocked(memoryApi.getMemoryImportPathAliases).mockResolvedValue({
success: true,
path_aliases: {
lpmm: 'data/lpmm',
plugin_data: 'data/plugins/a-dawn.a-memorix',
raw: 'data/raw',
},
})
vi.mocked(memoryApi.getMemoryImportTasks).mockResolvedValue({
success: true,
items: [{ task_id: 'import-1', status: 'done', mode: 'text' }],
items: [
mockImportTask('import-run-1', 'running'),
mockImportTask('import-queued-1', 'queued'),
mockImportTask('import-done-1', 'completed'),
],
})
vi.mocked(memoryApi.getMemoryImportTask).mockResolvedValue({
success: true,
task: mockImportDetail('import-run-1'),
})
vi.mocked(memoryApi.getMemoryImportTaskChunks).mockImplementation(async (_taskId, fileId, offset = 0) => ({
success: true,
task_id: 'import-run-1',
file_id: fileId,
offset,
limit: 50,
total: 120,
items: [
{
chunk_id: `${fileId}-${offset + 0}`,
index: offset + 0,
chunk_type: 'text',
status: 'running',
step: 'extracting',
failed_at: '',
retryable: true,
error: '',
progress: 50,
content_preview: `chunk-preview-${offset + 0}`,
updated_at: 1_710_000_111,
},
],
}))
vi.mocked(memoryApi.createMemoryUploadImport).mockResolvedValue({
success: true,
task: mockImportTask('upload-task-1', 'queued'),
})
vi.mocked(memoryApi.createMemoryPasteImport).mockResolvedValue({
success: true,
task: mockImportTask('paste-task-1', 'queued'),
})
vi.mocked(memoryApi.createMemoryRawScanImport).mockResolvedValue({
success: true,
task: mockImportTask('raw-task-1', 'queued'),
})
vi.mocked(memoryApi.createMemoryLpmmOpenieImport).mockResolvedValue({
success: true,
task: mockImportTask('openie-task-1', 'queued'),
})
vi.mocked(memoryApi.createMemoryLpmmConvertImport).mockResolvedValue({
success: true,
task: mockImportTask('convert-task-1', 'queued'),
})
vi.mocked(memoryApi.createMemoryTemporalBackfillImport).mockResolvedValue({
success: true,
task: mockImportTask('backfill-task-1', 'queued'),
})
vi.mocked(memoryApi.createMemoryMaibotMigrationImport).mockResolvedValue({
success: true,
task: mockImportTask('migration-task-1', 'queued'),
})
vi.mocked(memoryApi.cancelMemoryImportTask).mockResolvedValue({
success: true,
task: mockImportTask('import-run-1', 'cancel_requested'),
})
vi.mocked(memoryApi.retryMemoryImportTask).mockResolvedValue({
success: true,
task: mockImportTask('retry-task-1', 'queued'),
})
vi.mocked(memoryApi.resolveMemoryImportPath).mockResolvedValue({
success: true,
alias: 'raw',
relative_path: 'exports',
resolved_path: 'D:/Dev/rdev/MaiBot/data/raw/exports',
exists: true,
is_file: false,
is_dir: true,
})
vi.mocked(memoryApi.getMemoryTuningProfile).mockResolvedValue({
success: true,
profile: { retrieval: { top_k: 10 } },
@@ -136,13 +326,13 @@ describe('KnowledgeBasePage', () => {
success: true,
items: [{ task_id: 'tune-1', status: 'done' }],
})
vi.mocked(memoryApi.createMemoryTuningTask).mockResolvedValue({ success: true } as never)
vi.mocked(memoryApi.applyBestMemoryTuningProfile).mockResolvedValue({ success: true } as never)
vi.mocked(memoryApi.getMemorySources).mockResolvedValue({
success: true,
items: [
{ source: 'demo-1', paragraph_count: 2, relation_count: 1 },
{ source: 'demo-2', paragraph_count: 1, relation_count: 0 },
],
count: 2,
items: [{ source: 'demo-1', paragraph_count: 2, relation_count: 1 }],
count: 1,
})
vi.mocked(memoryApi.getMemoryDeleteOperations).mockResolvedValue({
success: true,
@@ -164,31 +354,9 @@ describe('KnowledgeBasePage', () => {
status: 'executed',
selector: { sources: ['demo-1'] },
summary: { counts: { paragraphs: 2, relations: 1, sources: 1 }, sources: ['demo-1'] },
items: [
{
item_type: 'paragraph',
item_hash: 'p-1',
item_key: 'paragraph:p-1',
payload: { paragraph: { source: 'demo-1', content: '这是用于测试删除详情展示的段落内容。' } },
},
],
items: [],
},
})
vi.mocked(memoryApi.refreshMemoryRuntimeSelfCheck).mockResolvedValue({
success: true,
report: { ok: true },
})
vi.mocked(memoryApi.updateMemoryConfig).mockResolvedValue({
success: true,
config_path: 'config/a_memorix.toml',
} as never)
vi.mocked(memoryApi.updateMemoryConfigRaw).mockResolvedValue({
success: true,
config_path: 'config/a_memorix.toml',
} as never)
vi.mocked(memoryApi.createMemoryPasteImport).mockResolvedValue({ success: true } as never)
vi.mocked(memoryApi.createMemoryTuningTask).mockResolvedValue({ success: true } as never)
vi.mocked(memoryApi.applyBestMemoryTuningProfile).mockResolvedValue({ success: true } as never)
vi.mocked(memoryApi.previewMemoryDelete).mockResolvedValue({
success: true,
mode: 'source',
@@ -212,59 +380,243 @@ describe('KnowledgeBasePage', () => {
deleted_source_count: 1,
} as never)
vi.mocked(memoryApi.restoreMemoryDelete).mockResolvedValue({ success: true } as never)
vi.mocked(memoryApi.refreshMemoryRuntimeSelfCheck).mockResolvedValue({
success: true,
report: { ok: true },
})
vi.mocked(memoryApi.updateMemoryConfig).mockResolvedValue({ success: true } as never)
vi.mocked(memoryApi.updateMemoryConfigRaw).mockResolvedValue({ success: true } as never)
})
it('renders long-term memory console and key tabs', async () => {
it('loads import settings/guide/tasks on first render', async () => {
const user = userEvent.setup()
render(<KnowledgeBasePage />)
expect(await screen.findByText('长期记忆控制台')).toBeInTheDocument()
expect(screen.getByText(/config\/a_memorix\.toml/)).toBeInTheDocument()
expect(screen.getByText('运行就绪')).toBeInTheDocument()
await user.click(screen.getByRole('tab', { name: '配置' }))
expect(await screen.findByTestId('memory-config-editor')).toBeInTheDocument()
expect(await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })).toBeInTheDocument()
await user.click(screen.getByRole('tab', { name: '导入' }))
expect(await screen.findByText(/导入说明/)).toBeInTheDocument()
expect(screen.getByText('import-1')).toBeInTheDocument()
expect(await screen.findByRole('button', { name: '创建导入任务' })).toBeInTheDocument()
expect((await screen.findAllByText('import-run-1')).length).toBeGreaterThan(0)
expect(memoryApi.getMemoryImportSettings).toHaveBeenCalled()
expect(memoryApi.getMemoryImportPathAliases).toHaveBeenCalled()
expect(memoryApi.getMemoryImportTasks).toHaveBeenCalled()
})
it('creates import tasks for all 7 modes and calls correct endpoints', async () => {
const user = userEvent.setup()
const { container } = render(<KnowledgeBasePage />)
const openImportTab = async () => {
await user.click(screen.getByRole('tab', { name: '导入' }))
await screen.findByRole('button', { name: '创建导入任务' })
}
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
await openImportTab()
const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement
const uploadFiles = [
new File(['hello'], 'demo.txt', { type: 'text/plain' }),
new File(['{"name":"mai"}'], 'demo.json', { type: 'application/json' }),
new File(['a,b\n1,2'], 'demo.csv', { type: 'text/csv' }),
new File(['# note'], 'demo.md', { type: 'text/markdown' }),
]
await user.upload(fileInput, uploadFiles)
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
await waitFor(() => expect(memoryApi.createMemoryUploadImport).toHaveBeenCalledTimes(1))
await openImportTab()
await user.click(screen.getByRole('tab', { name: '粘贴导入' }))
const editableTextarea = Array.from(container.querySelectorAll('textarea')).find((item) => !item.readOnly)
if (!editableTextarea) {
throw new Error('missing editable textarea')
}
await user.type(editableTextarea, 'paste content')
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
await waitFor(() => expect(memoryApi.createMemoryPasteImport).toHaveBeenCalledTimes(1))
await openImportTab()
await user.click(screen.getByRole('tab', { name: '本地扫描' }))
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
await waitFor(() => expect(memoryApi.createMemoryRawScanImport).toHaveBeenCalledTimes(1))
await openImportTab()
await user.click(screen.getByRole('tab', { name: 'LPMM OpenIE' }))
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
await waitFor(() => expect(memoryApi.createMemoryLpmmOpenieImport).toHaveBeenCalledTimes(1))
await openImportTab()
await user.click(screen.getByRole('tab', { name: 'LPMM 转换' }))
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
await waitFor(() => expect(memoryApi.createMemoryLpmmConvertImport).toHaveBeenCalledTimes(1))
await openImportTab()
await user.click(screen.getByRole('tab', { name: '时序回填' }))
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
await waitFor(() => expect(memoryApi.createMemoryTemporalBackfillImport).toHaveBeenCalledTimes(1))
await openImportTab()
await user.click(screen.getByRole('tab', { name: 'MaiBot 迁移' }))
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
await waitFor(() => expect(memoryApi.createMemoryMaibotMigrationImport).toHaveBeenCalledTimes(1))
const [uploadedFiles, uploadPayload] = vi.mocked(memoryApi.createMemoryUploadImport).mock.calls[0]
expect(uploadedFiles).toHaveLength(4)
expect(uploadedFiles.map((file) => file.name)).toEqual(['demo.txt', 'demo.json', 'demo.csv', 'demo.md'])
expect(uploadPayload).toMatchObject({
input_mode: 'text',
llm_enabled: true,
strategy_override: 'auto',
dedupe_policy: 'content_hash',
})
}, 60_000)
it('loads task detail and supports chunk pagination', async () => {
const user = userEvent.setup()
render(<KnowledgeBasePage />)
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
await user.click(screen.getByRole('tab', { name: '导入' }))
expect(await screen.findByText('alpha.txt')).toBeInTheDocument()
expect(await screen.findByText('chunk-preview-0')).toBeInTheDocument()
const betaButton = screen.getByText('beta.txt').closest('button')
if (!betaButton) {
throw new Error('missing file beta button')
}
await user.click(betaButton)
await waitFor(() =>
expect(memoryApi.getMemoryImportTaskChunks).toHaveBeenCalledWith('import-run-1', 'file-beta', 0, 50),
)
await user.click(screen.getByRole('button', { name: '下一页分块' }))
await waitFor(() =>
expect(memoryApi.getMemoryImportTaskChunks).toHaveBeenCalledWith('import-run-1', 'file-beta', 50, 50),
)
}, 20_000)
it('supports cancel and retry actions for selected task', async () => {
const user = userEvent.setup()
render(<KnowledgeBasePage />)
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
await user.click(screen.getByRole('tab', { name: '导入' }))
await screen.findByText('任务详情')
await user.click(screen.getByRole('button', { name: '取消选中导入任务' }))
await waitFor(() => expect(memoryApi.cancelMemoryImportTask).toHaveBeenCalledWith('import-run-1'))
await user.click(screen.getByRole('button', { name: '重试选中导入任务' }))
await waitFor(() => expect(memoryApi.retryMemoryImportTask).toHaveBeenCalled())
const [taskId, retryPayload] = vi.mocked(memoryApi.retryMemoryImportTask).mock.calls[0]
expect(taskId).toBe('import-run-1')
expect(retryPayload).toMatchObject({
overrides: {
llm_enabled: true,
strategy_override: 'auto',
},
})
}, 20_000)
it('auto polling updates queue and keeps page stable when refresh fails once', async () => {
vi.mocked(memoryApi.getMemoryImportSettings).mockResolvedValue({
success: true,
settings: {
max_paste_chars: 200_000,
max_file_concurrency: 8,
max_chunk_concurrency: 16,
default_file_concurrency: 2,
default_chunk_concurrency: 4,
poll_interval_ms: 200,
maibot_source_db_default: 'data/maibot.db',
},
})
const user = userEvent.setup()
render(<KnowledgeBasePage />)
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
await user.click(screen.getByRole('tab', { name: '导入' }))
await screen.findByText('导入队列')
const initialCalls = vi.mocked(memoryApi.getMemoryImportTasks).mock.calls.length
vi.mocked(memoryApi.getMemoryImportTasks).mockRejectedValueOnce(new Error('poll failure'))
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 350))
})
expect(screen.getByText('长期记忆控制台')).toBeInTheDocument()
expect(vi.mocked(memoryApi.getMemoryImportTasks).mock.calls.length).toBeGreaterThan(initialCalls)
}, 20_000)
it('creates tuning task and applies best profile (tuning module)', async () => {
const user = userEvent.setup()
render(<KnowledgeBasePage />)
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
await user.click(screen.getByRole('tab', { name: '调优' }))
expect(await screen.findByText('tune-1')).toBeInTheDocument()
expect(screen.getByRole('button', { name: '应用最佳' })).toBeInTheDocument()
})
await screen.findByText('调优任务')
it('shows delete tab and opens source delete preview', async () => {
await user.click(screen.getByRole('button', { name: '创建调优任务' }))
await waitFor(() =>
expect(memoryApi.createMemoryTuningTask).toHaveBeenCalledWith({
objective: 'precision_priority',
intensity: 'standard',
sample_size: 24,
top_k_eval: 20,
}),
)
await user.click(screen.getByRole('button', { name: '应用最佳' }))
await waitFor(() => expect(memoryApi.applyBestMemoryTuningProfile).toHaveBeenCalledWith('tune-1'))
}, 20_000)
it('previews executes and restores source delete (delete module)', async () => {
const user = userEvent.setup()
render(<KnowledgeBasePage />)
expect(await screen.findByText('长期记忆控制台')).toBeInTheDocument()
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
await user.click(screen.getByRole('tab', { name: '删除' }))
await screen.findByText('来源批量删除')
expect(await screen.findByText('来源批量删除')).toBeInTheDocument()
expect(screen.getAllByText('demo-1').length).toBeGreaterThan(0)
expect(screen.getAllByText('del-1').length).toBeGreaterThan(0)
expect(screen.getByText('恢复这次删除')).toBeInTheDocument()
const sourceCellCandidates = await screen.findAllByText('demo-1')
const sourceRow = sourceCellCandidates
.map((item) => item.closest('tr'))
.find((row): row is HTMLTableRowElement => Boolean(row && within(row).queryByRole('checkbox')))
if (!sourceRow) {
throw new Error('missing source row')
}
await user.click(within(sourceRow).getByRole('checkbox'))
await user.click(screen.getAllByRole('checkbox')[0])
await user.click(screen.getByRole('button', { name: '预览删除' }))
await waitFor(() =>
expect(memoryApi.previewMemoryDelete).toHaveBeenCalledWith({
mode: 'source',
selector: { sources: ['demo-1'] },
reason: 'knowledge_base_source_delete',
requested_by: 'knowledge_base',
}),
)
expect(await screen.findByTestId('memory-delete-dialog')).toHaveTextContent('delete:source:1')
})
const dialog = await screen.findByTestId('memory-delete-dialog')
expect(dialog).toHaveTextContent('preview:source:1')
it('loads selected delete operation detail items from detail endpoint', async () => {
const user = userEvent.setup()
await user.click(screen.getByRole('button', { name: '执行删除' }))
await waitFor(() =>
expect(memoryApi.executeMemoryDelete).toHaveBeenCalledWith({
mode: 'source',
selector: { sources: ['demo-1'] },
reason: 'knowledge_base_source_delete',
requested_by: 'knowledge_base',
}),
)
render(<KnowledgeBasePage />)
expect(await screen.findByText('长期记忆控制台')).toBeInTheDocument()
await user.click(screen.getByRole('tab', { name: '删除' }))
expect(await screen.findByText('删除操作恢复')).toBeInTheDocument()
expect(await screen.findByText('paragraph')).toBeInTheDocument()
expect(screen.getByText('p-1')).toBeInTheDocument()
expect(screen.getByText('这是用于测试删除详情展示的段落内容。')).toBeInTheDocument()
})
await user.click(screen.getByRole('button', { name: '执行恢复' }))
await waitFor(() =>
expect(memoryApi.restoreMemoryDelete).toHaveBeenCalledWith({
operation_id: 'del-2',
requested_by: 'knowledge_base',
}),
)
}, 20_000)
})

View File

@@ -107,6 +107,7 @@ vi.mock('../knowledge-graph/GraphDialogs', () => ({
vi.mock('@/lib/memory-api', () => ({
getMemoryGraph: vi.fn(),
getMemoryGraphSearch: vi.fn(),
getMemoryGraphNodeDetail: vi.fn(),
getMemoryGraphEdgeDetail: vi.fn(),
previewMemoryDelete: vi.fn(),
@@ -139,6 +140,13 @@ describe('KnowledgeGraphPage', () => {
total_nodes: 2,
total_edges: 1,
})
vi.mocked(memoryApi.getMemoryGraphSearch).mockResolvedValue({
success: true,
query: 'alpha',
limit: 50,
count: 0,
items: [],
})
vi.mocked(memoryApi.getMemoryGraphNodeDetail).mockResolvedValue({
success: true,
node: { id: 'alpha', type: 'entity', content: 'Alpha', hash: 'entity-1', appearance_count: 3 },
@@ -255,7 +263,7 @@ describe('KnowledgeGraphPage', () => {
vi.mocked(memoryApi.restoreMemoryDelete).mockResolvedValue({ success: true } as never)
})
it('renders graph summary and supports empty-result filtering', async () => {
it('calls backend graph search and renders no-hit state', async () => {
const user = userEvent.setup()
render(<KnowledgeGraphPage />)
@@ -264,11 +272,102 @@ describe('KnowledgeGraphPage', () => {
expect(screen.getByText(/总节点 2/)).toBeInTheDocument()
expect(screen.getByTestId('graph-visualization')).toHaveTextContent('nodes:2,edges:1')
await user.type(screen.getByPlaceholderText('筛选实体名称、节点 ID 或边标签'), 'missing')
await user.type(screen.getByPlaceholderText('搜索实体、关系、hash后端全库'), 'missing')
expect(memoryApi.getMemoryGraph).toHaveBeenCalledTimes(1)
await user.click(screen.getByRole('button', { name: '筛选' }))
await user.click(screen.getByRole('button', { name: '搜索' }))
await waitFor(() => {
expect(memoryApi.getMemoryGraphSearch).toHaveBeenCalledWith('missing', 50)
})
expect(await screen.findByText('未命中实体或关系。')).toBeInTheDocument()
})
it('supports clicking entity search result to locate evidence', async () => {
const user = userEvent.setup()
vi.mocked(memoryApi.getMemoryGraphSearch).mockResolvedValue({
success: true,
query: 'alpha',
limit: 50,
count: 1,
items: [
{
type: 'entity',
title: 'Alpha',
matched_field: 'name',
matched_value: 'Alpha',
entity_name: 'alpha',
entity_hash: 'entity-1',
appearance_count: 3,
},
],
})
render(<KnowledgeGraphPage />)
await screen.findByTestId('graph-visualization')
await user.type(screen.getByPlaceholderText('搜索实体、关系、hash后端全库'), 'alpha')
await user.click(screen.getByRole('button', { name: '搜索' }))
await screen.findByText('搜索词alpha')
await user.click(screen.getByRole('button', { name: /Alpha/ }))
await waitFor(() => {
expect(memoryApi.getMemoryGraphNodeDetail).toHaveBeenCalledWith('alpha')
})
expect(screen.getByRole('tab', { name: '证据视图' })).toHaveAttribute('data-state', 'active')
})
it('supports clicking relation search result to locate evidence', async () => {
const user = userEvent.setup()
vi.mocked(memoryApi.getMemoryGraphSearch).mockResolvedValue({
success: true,
query: '关联',
limit: 50,
count: 1,
items: [
{
type: 'relation',
title: 'alpha 关联 beta',
matched_field: 'predicate',
matched_value: '关联',
subject: 'alpha',
predicate: '关联',
object: 'beta',
relation_hash: 'rel-1',
confidence: 0.9,
},
],
})
render(<KnowledgeGraphPage />)
await screen.findByTestId('graph-visualization')
await user.type(screen.getByPlaceholderText('搜索实体、关系、hash后端全库'), '关联')
await user.click(screen.getByRole('button', { name: '搜索' }))
await user.click(screen.getByRole('button', { name: /alpha 关联 beta/ }))
await waitFor(() => {
expect(memoryApi.getMemoryGraphEdgeDetail).toHaveBeenCalledWith('alpha', 'beta')
})
expect(screen.getByRole('tab', { name: '证据视图' })).toHaveAttribute('data-state', 'active')
})
it('falls back to local filtering when backend search fails', async () => {
const user = userEvent.setup()
vi.mocked(memoryApi.getMemoryGraphSearch).mockRejectedValue(new Error('search unavailable'))
render(<KnowledgeGraphPage />)
await screen.findByTestId('graph-visualization')
await user.type(screen.getByPlaceholderText('搜索实体、关系、hash后端全库'), 'missing')
await user.click(screen.getByRole('button', { name: '搜索' }))
expect(await screen.findByText('还没有可展示的长期记忆图谱')).toBeInTheDocument()
expect(toastMock).toHaveBeenCalledWith(
expect.objectContaining({
title: '后端检索失败,已回退本地筛选',
}),
)
})
it('shows empty state when switching to evidence view without a selection', async () => {

File diff suppressed because it is too large Load Diff

View File

@@ -22,6 +22,7 @@ import {
getMemoryGraph,
getMemoryGraphEdgeDetail,
getMemoryGraphNodeDetail,
getMemoryGraphSearch,
previewMemoryDelete,
restoreMemoryDelete,
type MemoryDeleteExecutePayload,
@@ -34,6 +35,7 @@ import {
type MemoryGraphParagraphDetailPayload,
type MemoryGraphPayload,
type MemoryGraphRelationDetailPayload,
type MemoryGraphSearchItem,
} from '@/lib/memory-api'
import {
@@ -211,6 +213,9 @@ export function KnowledgeGraphPage() {
const [nodeLimit, setNodeLimit] = useState('120')
const [searchInput, setSearchInput] = useState('')
const [appliedSearchQuery, setAppliedSearchQuery] = useState('')
const [searchLoading, setSearchLoading] = useState(false)
const [searchResults, setSearchResults] = useState<MemoryGraphSearchItem[]>([])
const [searchFallbackMode, setSearchFallbackMode] = useState(false)
const [viewMode, setViewMode] = useState<GraphViewMode>('entity')
const [fullGraph, setFullGraph] = useState<GraphData>({ nodes: [], edges: [] })
const [graphData, setGraphData] = useState<GraphData>({ nodes: [], edges: [] })
@@ -258,9 +263,12 @@ export function KnowledgeGraphPage() {
setLoading(true)
const payload = await getMemoryGraph(Number(nodeLimit))
const nextGraph = toEntityGraphData(payload)
const visibleGraph = searchFallbackMode && appliedSearchQuery
? filterGraphData(nextGraph, appliedSearchQuery)
: nextGraph
setGraphMeta(payload)
setFullGraph(nextGraph)
setGraphData(filterGraphData(nextGraph, appliedSearchQuery))
setGraphData(visibleGraph)
setEvidenceGraph({ nodes: [], edges: [] })
resetDetailSelections()
if (!options?.silent) {
@@ -278,21 +286,54 @@ export function KnowledgeGraphPage() {
} finally {
setLoading(false)
}
}, [appliedSearchQuery, nodeLimit, resetDetailSelections, toast])
}, [appliedSearchQuery, nodeLimit, resetDetailSelections, searchFallbackMode, toast])
useEffect(() => {
void loadGraph({ silent: true })
}, [loadGraph])
const handleSearch = useCallback(() => {
const handleSearch = useCallback(async () => {
const nextQuery = searchInput.trim()
if (!nextQuery) {
setAppliedSearchQuery('')
setSearchFallbackMode(false)
setSearchResults([])
setGraphData(fullGraph)
toast({
title: '已重置筛选',
description: `当前显示 ${fullGraph.nodes.length} 个节点、${fullGraph.edges.length} 条关系`,
})
return
}
setSearchLoading(true)
setAppliedSearchQuery(nextQuery)
const filtered = filterGraphData(fullGraph, nextQuery)
setGraphData(filtered)
toast({
title: nextQuery ? '筛选完成' : '已重置筛选',
description: `当前显示 ${filtered.nodes.length} 个节点、${filtered.edges.length} 条关系`,
})
try {
const payload = await getMemoryGraphSearch(nextQuery, 50)
if (!payload.success) {
throw new Error(payload.error || '图谱检索失败')
}
const items = Array.isArray(payload.items) ? payload.items : []
setSearchResults(items)
setSearchFallbackMode(false)
setGraphData(fullGraph)
toast({
title: '全库检索完成',
description: `命中 ${payload.count ?? items.length} 条结果`,
})
} catch (error) {
const filtered = filterGraphData(fullGraph, nextQuery)
setSearchResults([])
setSearchFallbackMode(true)
setGraphData(filtered)
toast({
title: '后端检索失败,已回退本地筛选',
description: `仅当前已加载范围(${filtered.nodes.length} 个节点、${filtered.edges.length} 条关系)`,
variant: 'destructive',
})
} finally {
setSearchLoading(false)
}
}, [fullGraph, searchInput, toast])
const stats = useMemo(
@@ -397,21 +438,41 @@ export function KnowledgeGraphPage() {
}
}, [closeDeleteDialog, deleteResult?.operation_id, loadGraph, toast])
const handleNodeClick = useCallback(async (_: React.MouseEvent, node: Node) => {
const selected = graphData.nodes.find((item) => item.id === node.id)
setSelectedNodeData(selected ?? null)
const openNodeDetail = useCallback(async (
nodeId: string,
options?: { locateInEvidence?: boolean },
) => {
const nodeToken = String(nodeId || '').trim()
if (!nodeToken) {
return
}
const selected = graphData.nodes.find((item) => item.id === nodeToken)
if (options?.locateInEvidence) {
setSelectedNodeData(null)
} else {
setSelectedNodeData(
selected ?? {
id: nodeToken,
type: 'entity',
content: nodeToken,
metadata: {},
},
)
}
setSelectedEdgeData(null)
setEdgeDetail(null)
setSelectedRelationDetail(null)
setSelectedRelationMetadata(null)
setSelectedParagraphDetail(null)
if (!selected) {
return
}
setSelectedParagraphMetadata(null)
try {
setDetailLoading(true)
const detail = await getMemoryGraphNodeDetail(selected.id)
const detail = await getMemoryGraphNodeDetail(nodeToken)
setNodeDetail(detail)
setEvidenceGraph(toEvidenceGraphData(detail.evidence_graph))
if (options?.locateInEvidence) {
setViewMode('evidence')
}
} catch (error) {
toast({
title: '加载节点详情失败',
@@ -423,27 +484,62 @@ export function KnowledgeGraphPage() {
}
}, [graphData.nodes, toast])
const handleEdgeClick = useCallback(async (_: React.MouseEvent, edge: Edge) => {
const sourceNode = graphData.nodes.find((nodeItem) => nodeItem.id === edge.source)
const targetNode = graphData.nodes.find((nodeItem) => nodeItem.id === edge.target)
const edgeData = graphData.edges.find((item) => item.source === edge.source && item.target === edge.target)
if (!sourceNode || !targetNode || !edgeData) {
const openEdgeDetail = useCallback(async (
source: string,
target: string,
options?: { locateInEvidence?: boolean },
) => {
const sourceToken = String(source || '').trim()
const targetToken = String(target || '').trim()
if (!sourceToken || !targetToken) {
return
}
setSelectedNodeData(null)
setNodeDetail(null)
setSelectedRelationDetail(null)
setSelectedRelationMetadata(null)
setSelectedParagraphDetail(null)
setSelectedEdgeData({
source: sourceNode,
target: targetNode,
edge: edgeData,
})
setSelectedParagraphMetadata(null)
if (options?.locateInEvidence) {
setSelectedEdgeData(null)
} else {
const sourceNode = graphData.nodes.find((nodeItem) => nodeItem.id === sourceToken) ?? {
id: sourceToken,
type: 'entity' as const,
content: sourceToken,
metadata: {},
}
const targetNode = graphData.nodes.find((nodeItem) => nodeItem.id === targetToken) ?? {
id: targetToken,
type: 'entity' as const,
content: targetToken,
metadata: {},
}
const edgeData = graphData.edges.find((item) => item.source === sourceToken && item.target === targetToken) ?? {
source: sourceToken,
target: targetToken,
weight: 1,
kind: 'relation' as const,
label: '',
relationHashes: [],
predicates: [],
relationCount: 0,
evidenceCount: 0,
}
setSelectedEdgeData({
source: sourceNode,
target: targetNode,
edge: edgeData,
})
}
try {
setDetailLoading(true)
const detail = await getMemoryGraphEdgeDetail(edge.source, edge.target)
const detail = await getMemoryGraphEdgeDetail(sourceToken, targetToken)
setEdgeDetail(detail)
setEvidenceGraph(toEvidenceGraphData(detail.evidence_graph))
if (options?.locateInEvidence) {
setViewMode('evidence')
}
} catch (error) {
toast({
title: '加载关系详情失败',
@@ -455,6 +551,36 @@ export function KnowledgeGraphPage() {
}
}, [graphData.edges, graphData.nodes, toast])
const handleNodeClick = useCallback((_: React.MouseEvent, node: Node) => {
void openNodeDetail(node.id)
}, [openNodeDetail])
const handleEdgeClick = useCallback((_: React.MouseEvent, edge: Edge) => {
void openEdgeDetail(edge.source, edge.target)
}, [openEdgeDetail])
const handleSearchResultClick = useCallback((item: MemoryGraphSearchItem) => {
if (item.type === 'entity') {
const entityName = String(item.entity_name ?? item.title ?? '').trim()
if (!entityName) {
return
}
void openNodeDetail(entityName, { locateInEvidence: true })
return
}
const source = String(item.subject ?? '').trim()
const target = String(item.object ?? '').trim()
if (!source || !target) {
toast({
title: '结果缺少定位信息',
description: '该关系记录没有可用的 subject/object无法定位。',
variant: 'destructive',
})
return
}
void openEdgeDetail(source, target, { locateInEvidence: true })
}, [openEdgeDetail, openNodeDetail, toast])
const handleEvidenceNodeClick = useCallback(async (_: React.MouseEvent, node: Node) => {
const selected = evidenceGraph.nodes.find((item) => item.id === node.id)
if (!selected) {
@@ -640,12 +766,12 @@ export function KnowledgeGraphPage() {
<Input
value={searchInput}
onChange={(event) => setSearchInput(event.target.value)}
onKeyDown={(event) => event.key === 'Enter' && handleSearch()}
placeholder="筛选实体名称、节点 ID 或边标签"
onKeyDown={(event) => event.key === 'Enter' && void handleSearch()}
placeholder="搜索实体、关系、hash后端全库"
/>
<Button onClick={handleSearch} variant="secondary">
<Button onClick={() => void handleSearch()} variant="secondary" disabled={searchLoading}>
<Search className="mr-2 h-4 w-4" />
{searchLoading ? '检索中' : '搜索'}
</Button>
</div>
@@ -678,6 +804,48 @@ export function KnowledgeGraphPage() {
<TabsTrigger value="evidence"></TabsTrigger>
</TabsList>
</Tabs>
{appliedSearchQuery ? (
<div className="rounded-lg border bg-background/80 p-3">
<div className="flex flex-wrap items-center justify-between gap-2">
<div className="text-sm font-medium">
{appliedSearchQuery}
</div>
<Badge variant={searchFallbackMode ? 'destructive' : 'secondary'}>
{searchFallbackMode ? '仅当前已加载范围' : `全库命中 ${searchResults.length}`}
</Badge>
</div>
{searchFallbackMode ? (
<p className="mt-2 text-sm text-muted-foreground">
</p>
) : searchResults.length <= 0 ? (
<p className="mt-2 text-sm text-muted-foreground"></p>
) : (
<div className="mt-3 max-h-56 space-y-2 overflow-auto pr-1">
{searchResults.map((item, index) => (
<button
key={`${item.type}-${item.entity_hash ?? item.relation_hash ?? `${item.title}-${index}`}`}
type="button"
className="w-full rounded-md border bg-card px-3 py-2 text-left transition hover:bg-accent/40"
onClick={() => handleSearchResultClick(item)}
>
<div className="flex items-center gap-2">
<Badge variant="outline">{item.type === 'entity' ? '实体' : '关系'}</Badge>
<span className="truncate text-sm font-medium">{item.title || '(无标题结果)'}</span>
</div>
<p className="mt-1 text-xs text-muted-foreground">
{item.matched_field} = {item.matched_value}
{item.type === 'entity'
? ` · appearance=${item.appearance_count ?? 0}`
: ` · confidence=${Number(item.confidence ?? 0).toFixed(2)}`}
</p>
</button>
))}
</div>
)}
</div>
) : null}
</div>
</div>

View File

@@ -0,0 +1,113 @@
from __future__ import annotations
from pathlib import Path
from typing import Any
import pytest
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
class _DummyMetadataStore:
def __init__(self, *, entities: list[dict[str, Any]], relations: list[dict[str, Any]]) -> None:
self._entities = entities
self._relations = relations
def query(self, sql: str, params: tuple[Any, ...] = ()) -> list[dict[str, Any]]:
sql_token = " ".join(str(sql or "").lower().split())
keyword = str(params[0] or "").strip("%").lower() if params else ""
if "from entities" in sql_token:
rows = [dict(item) for item in self._entities if not bool(item.get("is_deleted", 0))]
if not keyword:
return rows
return [
row
for row in rows
if keyword in str(row.get("name", "") or "").lower()
or keyword in str(row.get("hash", "") or "").lower()
]
if "from relations" in sql_token:
rows = [dict(item) for item in self._relations if not bool(item.get("is_inactive", 0))]
if not keyword:
return rows
return [
row
for row in rows
if keyword in str(row.get("subject", "") or "").lower()
or keyword in str(row.get("object", "") or "").lower()
or keyword in str(row.get("predicate", "") or "").lower()
or keyword in str(row.get("hash", "") or "").lower()
]
raise AssertionError(f"unexpected query: {sql_token}")
def _build_kernel(*, entities: list[dict[str, Any]], relations: list[dict[str, Any]]) -> SDKMemoryKernel:
kernel = SDKMemoryKernel(plugin_root=Path.cwd(), config={})
async def _fake_initialize() -> None:
return None
kernel.initialize = _fake_initialize # type: ignore[method-assign]
kernel.metadata_store = _DummyMetadataStore(entities=entities, relations=relations)
kernel.graph_store = object() # type: ignore[assignment]
return kernel
@pytest.mark.asyncio
async def test_memory_graph_admin_search_orders_and_dedupes_results() -> None:
kernel = _build_kernel(
entities=[
{"hash": "e1", "name": "Alice", "appearance_count": 5, "is_deleted": 0},
{"hash": "e1", "name": "Alice Duplicate", "appearance_count": 99, "is_deleted": 0},
{"hash": "e2", "name": "Alice Cooper", "appearance_count": 7, "is_deleted": 0},
{"hash": "e3", "name": "my alice note", "appearance_count": 11, "is_deleted": 0},
{"hash": "e4", "name": "alice deleted", "appearance_count": 100, "is_deleted": 1},
],
relations=[
{"hash": "r1", "subject": "Alice", "predicate": "knows", "object": "Bob", "confidence": 0.6, "created_at": 100, "is_inactive": 0},
{"hash": "r3", "subject": "Alice", "predicate": "supports", "object": "Carol", "confidence": 0.9, "created_at": 90, "is_inactive": 0},
{"hash": "r1", "subject": "Alice", "predicate": "knows duplicate", "object": "Bob", "confidence": 0.99, "created_at": 200, "is_inactive": 0},
{"hash": "r2", "subject": "Alice Cooper", "predicate": "likes", "object": "Tea", "confidence": 0.2, "created_at": 50, "is_inactive": 0},
{"hash": "", "subject": "Carol", "predicate": "mentions alice", "object": "Topic", "confidence": 0.8, "created_at": 70, "is_inactive": 0},
{"hash": "", "subject": "Carol", "predicate": "mentions alice", "object": "Topic", "confidence": 0.3, "created_at": 10, "is_inactive": 0},
{"hash": "r4", "subject": "alice inactive", "predicate": "old", "object": "Data", "confidence": 1.0, "created_at": 300, "is_inactive": 1},
],
)
payload = await kernel.memory_graph_admin(action="search", query="alice", limit=20)
assert payload["success"] is True
assert payload["count"] == len(payload["items"])
entity_items = [item for item in payload["items"] if item["type"] == "entity"]
relation_items = [item for item in payload["items"] if item["type"] == "relation"]
assert [item["entity_hash"] for item in entity_items] == ["e1", "e2", "e3"]
assert [item["relation_hash"] for item in relation_items] == ["r3", "r1", "r2", ""]
assert relation_items[0]["confidence"] == pytest.approx(0.9)
assert relation_items[1]["confidence"] == pytest.approx(0.6)
@pytest.mark.asyncio
async def test_memory_graph_admin_search_filters_deleted_and_inactive_records() -> None:
kernel = _build_kernel(
entities=[
{"hash": "e-deleted", "name": "Ghost Alice", "appearance_count": 10, "is_deleted": 1},
],
relations=[
{
"hash": "r-inactive",
"subject": "Ghost Alice",
"predicate": "linked",
"object": "Ghost Bob",
"confidence": 0.9,
"created_at": 10,
"is_inactive": 1,
},
],
)
payload = await kernel.memory_graph_admin(action="search", query="ghost", limit=50)
assert payload["success"] is True
assert payload["items"] == []
assert payload["count"] == 0

View File

@@ -52,6 +52,54 @@ def test_webui_memory_graph_route(client: TestClient, monkeypatch):
assert response.json()["edges"][0]["evidence_count"] == 2
def test_webui_memory_graph_search_route(client: TestClient, monkeypatch):
async def fake_graph_admin(*, action: str, **kwargs):
assert action == "search"
assert kwargs["query"] == "Alice"
assert kwargs["limit"] == 33
return {
"success": True,
"query": kwargs["query"],
"limit": kwargs["limit"],
"count": 1,
"items": [
{
"type": "entity",
"title": "Alice",
"matched_field": "name",
"matched_value": "Alice",
"entity_name": "Alice",
"entity_hash": "entity-1",
"appearance_count": 3,
}
],
}
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
response = client.get("/api/webui/memory/graph/search", params={"query": "Alice", "limit": 33})
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["query"] == "Alice"
assert response.json()["limit"] == 33
assert response.json()["items"][0]["type"] == "entity"
@pytest.mark.parametrize(
"params",
[
{"query": "", "limit": 50},
{"query": "Alice", "limit": 0},
{"query": "Alice", "limit": 201},
],
)
def test_webui_memory_graph_search_route_validation(client: TestClient, params):
response = client.get("/api/webui/memory/graph/search", params=params)
assert response.status_code == 422
def test_webui_memory_graph_node_detail_route(client: TestClient, monkeypatch):
async def fake_graph_admin(*, action: str, **kwargs):
assert action == "node_detail"
@@ -200,28 +248,59 @@ def test_memory_config_routes(client: TestClient, monkeypatch):
"get_raw_config",
lambda: "[plugin]\nenabled = true\n",
)
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_raw_config_with_meta",
lambda: {
"config": "[plugin]\nenabled = true\n",
"exists": True,
"using_default": False,
},
)
schema_response = client.get("/api/webui/memory/config/schema")
config_response = client.get("/api/webui/memory/config")
raw_response = client.get("/api/webui/memory/config/raw")
expected_path = memory_router_module.Path("/tmp/config/a_memorix.toml").as_posix()
assert schema_response.status_code == 200
assert schema_response.json()["path"] == "/tmp/config/a_memorix.toml"
assert memory_router_module.Path(schema_response.json()["path"]).as_posix() == expected_path
assert schema_response.json()["schema"]["layout"]["type"] == "tabs"
assert config_response.status_code == 200
assert config_response.json() == {
"success": True,
"config": {"plugin": {"enabled": True}},
"path": "/tmp/config/a_memorix.toml",
}
assert config_response.json()["success"] is True
assert config_response.json()["config"] == {"plugin": {"enabled": True}}
assert memory_router_module.Path(config_response.json()["path"]).as_posix() == expected_path
assert raw_response.status_code == 200
assert raw_response.json() == {
"success": True,
"config": "[plugin]\nenabled = true\n",
"path": "/tmp/config/a_memorix.toml",
}
assert raw_response.json()["success"] is True
assert raw_response.json()["config"] == "[plugin]\nenabled = true\n"
assert memory_router_module.Path(raw_response.json()["path"]).as_posix() == expected_path
def test_memory_config_raw_returns_default_template_when_file_missing(client: TestClient, monkeypatch):
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_config_path",
lambda: memory_router_module.Path("/tmp/config/a_memorix.toml"),
)
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_raw_config_with_meta",
lambda: {
"config": "[plugin]\nenabled = true\n",
"exists": False,
"using_default": True,
},
)
response = client.get("/api/webui/memory/config/raw")
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["config"] == "[plugin]\nenabled = true\n"
assert response.json()["exists"] is False
assert response.json()["using_default"] is True
def test_memory_config_update_routes(client: TestClient, monkeypatch):

View File

@@ -0,0 +1,499 @@
from __future__ import annotations
from pathlib import Path
from time import monotonic, sleep
from typing import Any, Dict, Generator
from uuid import uuid4
import asyncio
import json
from fastapi import FastAPI
from fastapi.testclient import TestClient
import pytest
import tomlkit
from src.A_memorix import host_service as host_service_module
from src.A_memorix.core.utils import retrieval_tuning_manager as tuning_manager_module
from src.webui.dependencies import require_auth
from src.webui.routers import memory as memory_router_module
REQUEST_TIMEOUT_SECONDS = 30
IMPORT_TIMEOUT_SECONDS = 120
TUNING_TIMEOUT_SECONDS = 420
IMPORT_TERMINAL_STATUSES = {"completed", "completed_with_errors", "failed", "cancelled"}
TUNING_TERMINAL_STATUSES = {"completed", "failed", "cancelled"}
def _build_test_config(data_dir: Path) -> Dict[str, Any]:
return {
"storage": {
"data_dir": str(data_dir),
},
"advanced": {
"enable_auto_save": False,
},
"embedding": {
"dimension": 64,
"batch_size": 4,
"max_concurrent": 1,
"retry": {
"max_attempts": 1,
"min_wait_seconds": 0.1,
"max_wait_seconds": 0.2,
"backoff_multiplier": 1.0,
},
"fallback": {
"enabled": True,
"allow_metadata_only_write": True,
"probe_interval_seconds": 30,
},
"paragraph_vector_backfill": {
"enabled": False,
"interval_seconds": 60,
"batch_size": 32,
"max_retry": 2,
},
},
"retrieval": {
"enable_parallel": False,
"enable_ppr": False,
"top_k_paragraphs": 20,
"top_k_relations": 10,
"top_k_final": 10,
"alpha": 0.5,
"search": {
"smart_fallback": {
"enabled": True,
},
},
"sparse": {
"enabled": True,
"mode": "auto",
"candidate_k": 80,
"relation_candidate_k": 60,
},
"fusion": {
"method": "weighted_rrf",
"rrf_k": 60,
"vector_weight": 0.7,
"bm25_weight": 0.3,
},
},
"threshold": {
"percentile": 70.0,
"min_results": 1,
},
"web": {
"tuning": {
"enabled": True,
"poll_interval_ms": 300,
"max_queue_size": 4,
"default_objective": "balanced",
"default_intensity": "quick",
"default_sample_size": 4,
"default_top_k_eval": 5,
"eval_query_timeout_seconds": 1.0,
"llm_retry": {
"max_attempts": 1,
"min_wait_seconds": 0.1,
"max_wait_seconds": 0.2,
"backoff_multiplier": 1.0,
},
},
},
}
def _assert_response_ok(response: Any) -> Dict[str, Any]:
assert response.status_code == 200, response.text
payload = response.json()
assert payload.get("success", True) is True, payload
return payload
def _wait_for_import_task_terminal(client: TestClient, task_id: str, *, timeout_seconds: float = IMPORT_TIMEOUT_SECONDS) -> Dict[str, Any]:
deadline = monotonic() + timeout_seconds
last_payload: Dict[str, Any] = {}
while monotonic() < deadline:
response = client.get(
f"/api/webui/memory/import/tasks/{task_id}",
params={"include_chunks": True},
)
payload = _assert_response_ok(response)
last_payload = payload
task = payload.get("task") or {}
status = str(task.get("status", "") or "")
if status in IMPORT_TERMINAL_STATUSES:
return task
sleep(0.2)
raise AssertionError(f"导入任务超时: task_id={task_id}, last_payload={last_payload}")
def _wait_for_tuning_task_terminal(client: TestClient, task_id: str, *, timeout_seconds: float = TUNING_TIMEOUT_SECONDS) -> Dict[str, Any]:
deadline = monotonic() + timeout_seconds
last_payload: Dict[str, Any] = {}
while monotonic() < deadline:
response = client.get(
f"/api/webui/memory/retrieval_tuning/tasks/{task_id}",
params={"include_rounds": False},
)
payload = _assert_response_ok(response)
last_payload = payload
task = payload.get("task") or {}
status = str(task.get("status", "") or "")
if status in TUNING_TERMINAL_STATUSES:
return task
sleep(0.3)
raise AssertionError(f"调优任务超时: task_id={task_id}, last_payload={last_payload}")
def _wait_for_query_hit(client: TestClient, query: str, *, timeout_seconds: float = 30.0) -> Dict[str, Any]:
deadline = monotonic() + timeout_seconds
last_payload: Dict[str, Any] = {}
while monotonic() < deadline:
payload = _assert_response_ok(
client.get(
"/api/webui/memory/query/aggregate",
params={"query": query, "limit": 20},
)
)
last_payload = payload
hits = payload.get("hits") or []
if isinstance(hits, list) and len(hits) > 0:
return payload
sleep(0.2)
raise AssertionError(f"检索命中超时: query={query}, last_payload={last_payload}")
def _get_source_item(client: TestClient, source_name: str) -> Dict[str, Any] | None:
payload = _assert_response_ok(client.get("/api/webui/memory/sources"))
items = payload.get("items") or []
for item in items:
if not isinstance(item, dict):
continue
if str(item.get("source", "") or "") == source_name:
return item
return None
def _source_paragraph_count(item: Dict[str, Any] | None) -> int:
payload = item or {}
if "paragraph_count" in payload:
return int(payload.get("paragraph_count", 0) or 0)
return int(payload.get("count", 0) or 0)
def _wait_for_source_paragraph_count(
client: TestClient,
source_name: str,
*,
min_count: int,
timeout_seconds: float = 30.0,
) -> Dict[str, Any]:
deadline = monotonic() + timeout_seconds
last_item: Dict[str, Any] = {}
while monotonic() < deadline:
item = _get_source_item(client, source_name)
count = _source_paragraph_count(item)
if count >= int(min_count):
return item or {}
if item:
last_item = dict(item)
sleep(0.2)
raise AssertionError(
f"等待来源段落计数超时: source={source_name}, min_count={min_count}, last_item={last_item}"
)
def _create_multitype_upload_task(client: TestClient) -> str:
structured_json = {
"paragraphs": [
{
"content": "Alice 携带地图前往火星港。",
"source": "integration-upload-json",
"entities": ["Alice", "地图", "火星港"],
"relations": [
{"subject": "Alice", "predicate": "携带", "object": "地图"},
{"subject": "Alice", "predicate": "前往", "object": "火星港"},
],
}
]
}
extra_json = {
"paragraphs": [
{
"content": "Carol 记录了一条补充说明。",
"source": "integration-upload-json-extra",
"entities": ["Carol"],
"relations": [],
}
]
}
payload_json = json.dumps(
{
"input_mode": "text",
"llm_enabled": False,
"file_concurrency": 2,
"chunk_concurrency": 2,
"dedupe_policy": "none",
},
ensure_ascii=False,
)
files = [
("files", ("integration-notes.txt", "Alice 在测试环境记录了一条长期记忆。".encode("utf-8"), "text/plain")),
("files", ("integration-diary.md", "# 日志\nBob 与 Alice 讨论了导图。".encode("utf-8"), "text/markdown")),
("files", ("integration-structured.json", json.dumps(structured_json, ensure_ascii=False).encode("utf-8"), "application/json")),
("files", ("integration-extra.json", json.dumps(extra_json, ensure_ascii=False).encode("utf-8"), "application/json")),
]
response = client.post(
"/api/webui/memory/import/upload",
data={"payload_json": payload_json},
files=files,
)
payload = _assert_response_ok(response)
task_id = str((payload.get("task") or {}).get("task_id") or "").strip()
assert task_id, payload
return task_id
def _create_seed_paste_task(client: TestClient, *, source: str, unique_token: str) -> str:
seed_payload = {
"paragraphs": [
{
"content": f"Alice 在火星港携带地图并记录了口令 {unique_token}",
"source": source,
"entities": ["Alice", "火星港", "地图"],
"relations": [
{"subject": "Alice", "predicate": "前往", "object": "火星港"},
{"subject": "Alice", "predicate": "携带", "object": "地图"},
],
},
{
"content": f"Bob 在火星港遇见 Alice并重复口令 {unique_token}",
"source": source,
"entities": ["Bob", "Alice", "火星港"],
"relations": [
{"subject": "Bob", "predicate": "遇见", "object": "Alice"},
{"subject": "Bob", "predicate": "位于", "object": "火星港"},
],
},
]
}
response = client.post(
"/api/webui/memory/import/paste",
json={
"name": "integration-seed.json",
"input_mode": "json",
"llm_enabled": False,
"content": json.dumps(seed_payload, ensure_ascii=False),
"dedupe_policy": "none",
},
)
payload = _assert_response_ok(response)
task_id = str((payload.get("task") or {}).get("task_id") or "").strip()
assert task_id, payload
return task_id
@pytest.fixture(scope="module")
def integration_state(tmp_path_factory: pytest.TempPathFactory) -> Generator[Dict[str, Any], None, None]:
tmp_root = tmp_path_factory.mktemp("memory_routes_integration")
data_dir = (tmp_root / "data").resolve()
staging_dir = (tmp_root / "upload_staging").resolve()
artifacts_dir = (tmp_root / "artifacts").resolve()
config_file = (tmp_root / "config" / "a_memorix.toml").resolve()
config_file.parent.mkdir(parents=True, exist_ok=True)
config_file.write_text(tomlkit.dumps(_build_test_config(data_dir)), encoding="utf-8")
patches = pytest.MonkeyPatch()
patches.setattr(host_service_module, "config_path", lambda: config_file)
patches.setattr(memory_router_module, "STAGING_ROOT", staging_dir)
patches.setattr(tuning_manager_module, "artifacts_root", lambda: artifacts_dir)
asyncio.run(host_service_module.a_memorix_host_service.stop())
host_service_module.a_memorix_host_service._config_cache = None # type: ignore[attr-defined]
app = FastAPI()
app.dependency_overrides[require_auth] = lambda: "ok"
app.include_router(memory_router_module.router, prefix="/api/webui")
app.include_router(memory_router_module.compat_router)
unique_token = f"INTEG_TOKEN_{uuid4().hex[:12]}"
source_name = f"integration-source-{uuid4().hex[:8]}"
with TestClient(app) as client:
upload_task_id = _create_multitype_upload_task(client)
upload_task = _wait_for_import_task_terminal(client, upload_task_id)
seed_task_id = _create_seed_paste_task(client, source=source_name, unique_token=unique_token)
seed_task = _wait_for_import_task_terminal(client, seed_task_id)
assert str(seed_task.get("status", "") or "") in {"completed", "completed_with_errors"}, seed_task
_wait_for_query_hit(client, unique_token, timeout_seconds=45.0)
yield {
"client": client,
"upload_task": upload_task,
"seed_task": seed_task,
"source_name": source_name,
"unique_token": unique_token,
}
asyncio.run(host_service_module.a_memorix_host_service.stop())
host_service_module.a_memorix_host_service._config_cache = None # type: ignore[attr-defined]
patches.undo()
def test_import_module_end_to_end_supports_multitype_upload(integration_state: Dict[str, Any]) -> None:
upload_task = integration_state["upload_task"]
assert str(upload_task.get("status", "") or "") in {"completed", "completed_with_errors"}, upload_task
files = upload_task.get("files") or []
assert isinstance(files, list)
assert len(files) >= 4
file_names = {str(item.get("name", "") or "") for item in files if isinstance(item, dict)}
assert "integration-notes.txt" in file_names
assert "integration-diary.md" in file_names
assert "integration-structured.json" in file_names
assert "integration-extra.json" in file_names
def test_retrieval_module_end_to_end_queries_seeded_data(integration_state: Dict[str, Any]) -> None:
client = integration_state["client"]
unique_token = integration_state["unique_token"]
aggregate_payload = _wait_for_query_hit(client, unique_token, timeout_seconds=45.0)
hits = aggregate_payload.get("hits") or []
joined_content = "\n".join(str(item.get("content", "") or "") for item in hits if isinstance(item, dict))
assert unique_token in joined_content
graph_payload = _assert_response_ok(
client.get(
"/api/webui/memory/graph/search",
params={"query": "Alice", "limit": 20},
)
)
graph_items = graph_payload.get("items") or []
assert isinstance(graph_items, list)
assert any(str(item.get("type", "") or "") == "entity" for item in graph_items if isinstance(item, dict)), graph_items
def test_tuning_module_end_to_end_create_and_apply_best(integration_state: Dict[str, Any]) -> None:
client = integration_state["client"]
create_payload = _assert_response_ok(
client.post(
"/api/webui/memory/retrieval_tuning/tasks",
json={
"objective": "balanced",
"intensity": "quick",
"rounds": 2,
"sample_size": 4,
"top_k_eval": 5,
"llm_enabled": False,
"eval_query_timeout_seconds": 1.0,
"seed": 20260403,
},
)
)
task_id = str((create_payload.get("task") or {}).get("task_id") or "").strip()
assert task_id, create_payload
task = _wait_for_tuning_task_terminal(client, task_id)
assert str(task.get("status", "") or "") == "completed", task
apply_payload = _assert_response_ok(
client.post(
f"/api/webui/memory/retrieval_tuning/tasks/{task_id}/apply-best",
)
)
assert "applied" in apply_payload
def test_delete_module_end_to_end_preview_execute_restore(integration_state: Dict[str, Any]) -> None:
client = integration_state["client"]
unique_token = integration_state["unique_token"]
source_name = integration_state["source_name"]
before_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0)
assert _source_paragraph_count(before_source_item) >= 1
preview_payload = _assert_response_ok(
client.post(
"/api/webui/memory/delete/preview",
json={
"mode": "source",
"selector": {"sources": [source_name]},
"reason": "integration_delete_preview",
"requested_by": "pytest_integration",
},
)
)
preview_counts = preview_payload.get("counts") or {}
assert int(preview_counts.get("paragraphs", 0) or 0) >= 1, preview_payload
execute_payload = _assert_response_ok(
client.post(
"/api/webui/memory/delete/execute",
json={
"mode": "source",
"selector": {"sources": [source_name]},
"reason": "integration_delete_execute",
"requested_by": "pytest_integration",
},
)
)
operation_id = str(execute_payload.get("operation_id", "") or "").strip()
assert operation_id, execute_payload
after_delete_payload = _assert_response_ok(
client.get(
"/api/webui/memory/query/aggregate",
params={"query": unique_token, "limit": 20},
)
)
after_delete_hits = after_delete_payload.get("hits") or []
after_delete_text = "\n".join(
str(item.get("content", "") or "")
for item in after_delete_hits
if isinstance(item, dict)
)
assert unique_token not in after_delete_text
assert int(execute_payload.get("deleted_paragraph_count", 0) or 0) >= 1, execute_payload
_assert_response_ok(
client.post(
"/api/webui/memory/delete/restore",
json={
"operation_id": operation_id,
"requested_by": "pytest_integration",
},
)
)
restored_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0)
assert _source_paragraph_count(restored_source_item) >= 1
operations_payload = _assert_response_ok(
client.get(
"/api/webui/memory/delete/operations",
params={"limit": 20, "mode": "source"},
)
)
operation_items = operations_payload.get("items") or []
operation_ids = {
str(item.get("operation_id", "") or "")
for item in operation_items
if isinstance(item, dict)
}
assert operation_id in operation_ids
operation_detail_payload = _assert_response_ok(client.get(f"/api/webui/memory/delete/operations/{operation_id}"))
detail_operation = operation_detail_payload.get("operation") or {}
assert str(detail_operation.get("status", "") or "") == "restored"

View File

@@ -98,7 +98,7 @@
"data_dir": {
"name": "data_dir",
"type": "string",
"default": "data/plugins/a-dawn.a-memorix",
"default": "data/a-memorix",
"description": "数据目录",
"label": "数据目录",
"ui_type": "text",
@@ -107,7 +107,7 @@
"disabled": false,
"order": 1,
"hint": "相对路径按 MaiBot 仓库根目录解析,建议保持默认外置目录。",
"placeholder": "data/plugins/a-dawn.a-memorix",
"placeholder": "data/a-memorix",
"choices": null
}
}

View File

@@ -1317,6 +1317,11 @@ class SDKMemoryKernel:
act = str(action or "").strip().lower()
if act == "get_graph":
return {"success": True, **self._serialize_graph(limit=max(1, int(kwargs.get("limit", 200) or 200)))}
if act == "search":
return self._search_graph(
query=str(kwargs.get("query", "") or "").strip(),
limit=max(1, min(200, int(kwargs.get("limit", 50) or 50))),
)
if act == "node_detail":
detail = self._build_graph_node_detail(
node_id=str(kwargs.get("node_id", "") or kwargs.get("node", "") or "").strip(),
@@ -2275,6 +2280,179 @@ class SDKMemoryKernel:
"total_edges": int(self.graph_store.num_edges),
}
@staticmethod
def _graph_search_match_rank(value: str, keyword: str) -> Optional[int]:
token = str(value or "").strip().lower()
if not token or not keyword:
return None
if token == keyword:
return 0
if token.startswith(keyword):
return 1
if keyword in token:
return 2
return None
@classmethod
def _pick_graph_search_match(
cls,
fields: Sequence[tuple[str, str]],
keyword: str,
) -> Optional[tuple[str, str, int]]:
best_match: Optional[tuple[str, str, int]] = None
for field, raw_value in fields:
value = str(raw_value or "").strip()
if not value:
continue
rank = cls._graph_search_match_rank(value, keyword)
if rank is None:
continue
if best_match is None or rank < best_match[2]:
best_match = (field, value, rank)
return best_match
def _search_graph(self, *, query: str, limit: int) -> Dict[str, Any]:
assert self.metadata_store is not None
token = str(query or "").strip()
normalized_query = token.lower()
safe_limit = max(1, int(limit or 50))
if not token:
return {
"success": False,
"query": token,
"limit": safe_limit,
"count": 0,
"items": [],
"error": "query 不能为空",
}
like_keyword = f"%{normalized_query}%"
entity_rows = self.metadata_store.query(
"""
SELECT hash, name, appearance_count, created_at
FROM entities
WHERE (is_deleted IS NULL OR is_deleted = 0)
AND (
LOWER(COALESCE(name, '')) LIKE ?
OR LOWER(COALESCE(hash, '')) LIKE ?
)
""",
(like_keyword, like_keyword),
)
relation_rows = self.metadata_store.query(
"""
SELECT hash, subject, predicate, object, confidence, created_at
FROM relations
WHERE (is_inactive IS NULL OR is_inactive = 0)
AND (
LOWER(COALESCE(subject, '')) LIKE ?
OR LOWER(COALESCE(object, '')) LIKE ?
OR LOWER(COALESCE(predicate, '')) LIKE ?
OR LOWER(COALESCE(hash, '')) LIKE ?
)
""",
(like_keyword, like_keyword, like_keyword, like_keyword),
)
entity_items: List[Dict[str, Any]] = []
seen_entity_keys: set[str] = set()
for row in entity_rows:
name = str(row.get("name", "") or "").strip()
hash_value = str(row.get("hash", "") or "").strip()
match = self._pick_graph_search_match(
[("name", name), ("hash", hash_value)],
normalized_query,
)
if match is None:
continue
dedupe_key = hash_value or f"name:{name.lower()}"
if dedupe_key in seen_entity_keys:
continue
seen_entity_keys.add(dedupe_key)
matched_field, matched_value, rank = match
entity_items.append(
{
"type": "entity",
"title": name or hash_value,
"matched_field": matched_field,
"matched_value": matched_value,
"entity_name": name or hash_value,
"entity_hash": hash_value,
"appearance_count": int(row.get("appearance_count", 0) or 0),
"_rank": rank,
}
)
relation_items: List[Dict[str, Any]] = []
seen_relation_keys: set[str] = set()
for row in relation_rows:
subject = str(row.get("subject", "") or "").strip()
predicate = str(row.get("predicate", "") or "").strip()
obj = str(row.get("object", "") or "").strip()
relation_hash = str(row.get("hash", "") or "").strip()
match = self._pick_graph_search_match(
[
("subject", subject),
("object", obj),
("predicate", predicate),
("hash", relation_hash),
],
normalized_query,
)
if match is None:
continue
dedupe_key = relation_hash or f"{subject.lower()}|{predicate.lower()}|{obj.lower()}"
if dedupe_key in seen_relation_keys:
continue
seen_relation_keys.add(dedupe_key)
matched_field, matched_value, rank = match
relation_items.append(
{
"type": "relation",
"title": self._format_relation_text(subject, predicate, obj),
"matched_field": matched_field,
"matched_value": matched_value,
"subject": subject,
"predicate": predicate,
"object": obj,
"relation_hash": relation_hash,
"confidence": float(row.get("confidence", 0.0) or 0.0),
"created_at": float(row.get("created_at", 0.0) or 0.0),
"_rank": rank,
}
)
items = entity_items + relation_items
items.sort(
key=lambda item: (
int(item["_rank"]) if item.get("_rank") is not None else 99,
0 if str(item.get("type", "") or "") == "entity" else 1,
-int(item.get("appearance_count", 0) or 0)
if str(item.get("type", "") or "") == "entity"
else -float(item.get("confidence", 0.0) or 0.0),
0.0 if str(item.get("type", "") or "") == "entity" else -float(item.get("created_at", 0.0) or 0.0),
str(item.get("entity_name", item.get("subject", "")) or "").lower(),
str(item.get("predicate", "") or "").lower(),
str(item.get("object", "") or "").lower(),
str(item.get("entity_hash", item.get("relation_hash", "")) or "").lower(),
)
)
normalized_items: List[Dict[str, Any]] = []
for item in items[:safe_limit]:
normalized = dict(item)
normalized.pop("_rank", None)
normalized_items.append(normalized)
return {
"success": True,
"query": token,
"limit": safe_limit,
"count": len(normalized_items),
"items": normalized_items,
}
@staticmethod
def _dedupe_strings(values: Iterable[Any]) -> List[str]:
deduped: List[str] = []

View File

@@ -277,6 +277,7 @@ class EpisodeSegmentationService:
model_config, model_label = self._resolve_model_config()
if model_config is None:
raise RuntimeError("episode segmentation model unavailable")
task_name = llm_api.resolve_task_name_from_model_config(model_config, preferred_task_name=model_label)
prompt = self._build_prompt(
source=source,
@@ -284,11 +285,17 @@ class EpisodeSegmentationService:
window_end=window_end,
paragraphs=paragraphs,
)
success, response, _, _ = await llm_api.generate_with_model(
prompt=prompt,
model_config=model_config,
request_type="A_Memorix.EpisodeSegmentation",
result = await llm_api.generate(
llm_api.LLMServiceRequest(
task_name=task_name,
request_type="A_Memorix.EpisodeSegmentation",
prompt=prompt,
temperature=getattr(model_config, "temperature", None),
max_tokens=getattr(model_config, "max_tokens", None),
)
)
success = bool(result.success)
response = str(result.completion.response or "")
if not success or not response:
raise RuntimeError("llm_generate_failed")

View File

@@ -1306,6 +1306,7 @@ class RetrievalTuningManager:
model_cfg = await self._select_llm_model()
if model_cfg is None:
raise RuntimeError("no_llm_model")
task_name = llm_api.resolve_task_name_from_model_config(model_cfg)
retry = self._llm_retry_cfg()
max_attempts = int(retry["max_attempts"])
@@ -1316,11 +1317,17 @@ class RetrievalTuningManager:
last_error: Optional[Exception] = None
for idx in range(max_attempts):
try:
success, response, _, _ = await llm_api.generate_with_model(
prompt=prompt,
model_config=model_cfg,
request_type=request_type,
result = await llm_api.generate(
llm_api.LLMServiceRequest(
task_name=task_name,
request_type=request_type,
prompt=prompt,
temperature=getattr(model_cfg, "temperature", None),
max_tokens=getattr(model_cfg, "max_tokens", None),
)
)
success = bool(result.success)
response = str(result.completion.response or "")
if not success:
raise RuntimeError("llm_generation_failed")
text = str(response or "").strip()

View File

@@ -280,15 +280,22 @@ class SummaryImporter:
model_config_to_use = self._resolve_summary_model_config()
if model_config_to_use is None:
return False, "未找到可用的总结模型配置"
task_name_to_use = llm_api.resolve_task_name_from_model_config(model_config_to_use)
logger.info(f"正在为流 {stream_id} 执行总结,消息条数: {len(messages)}")
logger.info(f"总结模型候选列表: {model_config_to_use.model_list}")
success, response, _, _ = await llm_api.generate_with_model(
prompt=prompt,
model_config=model_config_to_use,
request_type="A_Memorix.ChatSummarization"
result = await llm_api.generate(
llm_api.LLMServiceRequest(
task_name=task_name_to_use,
request_type="A_Memorix.ChatSummarization",
prompt=prompt,
temperature=getattr(model_config_to_use, "temperature", None),
max_tokens=getattr(model_config_to_use, "max_tokens", None),
)
)
success = bool(result.success)
response = str(result.completion.response or "")
if not success or not response:
return False, "LLM 生成总结失败"

View File

@@ -3165,14 +3165,21 @@ class ImportTaskManager:
async def _llm_call(self, prompt: str, model_config: Any) -> Dict[str, Any]:
cfg = self._llm_retry_config()
retries = int(cfg["retries"])
task_name = llm_api.resolve_task_name_from_model_config(model_config)
last_error: Optional[Exception] = None
for attempt in range(retries + 1):
try:
success, response, _, _ = await llm_api.generate_with_model(
prompt=prompt,
model_config=model_config,
request_type="A_Memorix.WebImport",
result = await llm_api.generate(
llm_api.LLMServiceRequest(
task_name=task_name,
request_type="A_Memorix.WebImport",
prompt=prompt,
temperature=getattr(model_config, "temperature", None),
max_tokens=getattr(model_config, "max_tokens", None),
)
)
success = bool(result.success)
response = str(result.completion.response or "")
if not success or not response:
raise RuntimeError("LLM 生成失败")

View File

@@ -88,11 +88,59 @@ class AMemorixHostService:
def get_config(self) -> Dict[str, Any]:
return dict(self._read_config())
def get_raw_config(self) -> str:
def _build_default_config(self) -> Dict[str, Any]:
schema = self.get_config_schema()
sections = schema.get("sections") if isinstance(schema, dict) else None
if not isinstance(sections, dict):
return {}
defaults: Dict[str, Any] = {}
for section_name, section_payload in sections.items():
if not isinstance(section_payload, dict):
continue
fields = section_payload.get("fields")
if not isinstance(fields, dict):
continue
section_parts = [part for part in str(section_name or "").split(".") if part]
if not section_parts:
continue
section_target: Dict[str, Any] = defaults
for part in section_parts:
nested = section_target.get(part)
if not isinstance(nested, dict):
nested = {}
section_target[part] = nested
section_target = nested
for field_name, field_payload in fields.items():
if not isinstance(field_payload, dict) or "default" not in field_payload:
continue
section_target[str(field_name)] = _to_builtin_data(field_payload.get("default"))
return defaults
def get_raw_config_with_meta(self) -> Dict[str, Any]:
path = self.get_config_path()
if not path.exists():
return ""
return path.read_text(encoding="utf-8")
if path.exists():
return {
"config": path.read_text(encoding="utf-8"),
"exists": True,
"using_default": False,
}
default_config = self._build_default_config()
default_raw = tomlkit.dumps(default_config) if default_config else ""
return {
"config": default_raw,
"exists": False,
"using_default": True,
}
def get_raw_config(self) -> str:
payload = self.get_raw_config_with_meta()
return str(payload.get("config", "") or "")
async def update_raw_config(self, raw_config: str) -> Dict[str, Any]:
tomlkit.loads(raw_config)
@@ -231,16 +279,18 @@ class AMemorixHostService:
path = self.get_config_path()
if not path.exists():
self._config_cache = {}
return {}
defaults = self._build_default_config()
self._config_cache = defaults
return dict(defaults)
try:
with path.open("r", encoding="utf-8") as handle:
loaded = tomlkit.load(handle)
except Exception as exc:
logger.warning("读取 A_Memorix 配置失败 %s: %s", path, exc)
self._config_cache = {}
return {}
defaults = self._build_default_config()
self._config_cache = defaults
return dict(defaults)
self._config_cache = _to_builtin_data(loaded) if isinstance(loaded, dict) else {}
return dict(self._config_cache)

View File

@@ -560,11 +560,18 @@ Chat paragraph:
)
async def _llm_call(self, prompt: str, model_config: Any) -> Dict:
"""Generic LLM Caller"""
success, response, _, _ = await llm_api.generate_with_model(
prompt=prompt,
model_config=model_config,
request_type="Script.ProcessKnowledge"
task_name = llm_api.resolve_task_name_from_model_config(model_config)
result = await llm_api.generate(
llm_api.LLMServiceRequest(
task_name=task_name,
request_type="Script.ProcessKnowledge",
prompt=prompt,
temperature=getattr(model_config, "temperature", None),
max_tokens=getattr(model_config, "max_tokens", None),
)
)
success = bool(result.success)
response = str(result.completion.response or "")
if success:
txt = response.strip()
if "```" in txt:

View File

@@ -230,6 +230,61 @@ def resolve_task_name(task_name: str = "") -> str:
return normalized_task_name
def resolve_task_name_from_model_config(model_config: Any, preferred_task_name: str = "") -> str:
"""根据旧版 `TaskConfig` 风格参数解析可用任务名。
该方法用于兼容仍以 `model_config` 传参的调用方:
1. 优先使用显式给出的 `preferred_task_name`
2. 其次匹配对象同一性;
3. 再尝试按 `model_list` 精确匹配;
4. 最后按 `model_list` 中首个命中的模型进行近似映射。
Args:
model_config: 旧调用方持有的任务配置对象。
preferred_task_name: 候选任务名(可选)。
Returns:
str: 可用于 `LLMServiceRequest.task_name` 的任务名。
Raises:
RuntimeError: 当前没有可用模型配置。
ValueError: 无法解析任何可用任务名时抛出。
"""
models = get_available_models()
if not models:
raise RuntimeError("没有可用的模型配置")
normalized_preferred = str(preferred_task_name or "").strip()
if normalized_preferred and normalized_preferred in models:
return normalized_preferred
for task_name, task_cfg in models.items():
if task_cfg is model_config:
return task_name
requested_model_list_raw = getattr(model_config, "model_list", [])
requested_model_list = [str(item).strip() for item in (requested_model_list_raw or []) if str(item).strip()]
if requested_model_list:
for task_name, task_cfg in models.items():
candidate_list = [str(item).strip() for item in getattr(task_cfg, "model_list", []) if str(item).strip()]
if candidate_list == requested_model_list:
return task_name
for requested_model in requested_model_list:
for task_name, task_cfg in models.items():
candidate_list = [str(item).strip() for item in getattr(task_cfg, "model_list", []) if str(item).strip()]
if requested_model in candidate_list:
logger.info(
"[LLMService] 旧版 model_config 未命中任务配置,"
f"按模型 `{requested_model}` 近似映射到任务 `{task_name}`"
)
return task_name
if normalized_preferred:
logger.warning(f"[LLMService] 无法映射旧版 model_config回退默认任务: preferred={normalized_preferred}")
return resolve_task_name("")
def _normalize_role(role_name: str) -> RoleType:
"""将原始角色字符串转换为内部角色枚举。

View File

@@ -168,6 +168,10 @@ async def _graph_get(limit: int) -> dict:
return await memory_service.graph_admin(action="get_graph", limit=limit)
async def _graph_search(query: str, limit: int) -> dict:
return await memory_service.graph_admin(action="search", query=query, limit=limit)
async def _graph_get_node_detail(
node_id: str,
*,
@@ -390,9 +394,20 @@ async def _memory_config_get() -> dict:
async def _memory_config_get_raw() -> dict:
raw_payload_getter = getattr(a_memorix_host_service, "get_raw_config_with_meta", None)
if callable(raw_payload_getter):
raw_payload = raw_payload_getter()
else:
raw_payload = {
"config": a_memorix_host_service.get_raw_config(),
"exists": bool(a_memorix_host_service.get_config_path().exists()),
"using_default": False,
}
return {
"success": True,
"config": a_memorix_host_service.get_raw_config(),
"config": str(raw_payload.get("config", "") or ""),
"exists": bool(raw_payload.get("exists", False)),
"using_default": bool(raw_payload.get("using_default", False)),
"path": str(a_memorix_host_service.get_config_path()),
}
@@ -649,6 +664,14 @@ async def get_memory_graph(limit: int = Query(200, ge=1, le=5000)):
return await _graph_get(limit)
@router.get("/graph/search")
async def search_memory_graph(
query: str = Query(..., min_length=1),
limit: int = Query(50, ge=1, le=200),
):
return await _graph_search(query, limit)
@router.get("/graph/node-detail")
async def get_memory_graph_node_detail(
node_id: str = Query(..., min_length=1),