fix: override plugin tool stream context and remove repository tests

- force plugin tools to use runtime stream_id/chat_id from execution context
- remove repository test assets and vitest config
- document that temporary test files must be deleted after use
This commit is contained in:
Losita
2026-05-12 22:36:32 +08:00
parent 702316ae57
commit 8d0f6d4401
98 changed files with 4 additions and 30458 deletions

View File

@@ -33,6 +33,7 @@
# 运行/调试/构建/测试/依赖 # 运行/调试/构建/测试/依赖
优先使用uv 优先使用uv
依赖项以 pyproject.toml 为准要同步更新requirements.txt 依赖项以 pyproject.toml 为准要同步更新requirements.txt
如为当前任务临时创建测试文件,跑完测试后必须立刻删除,不要保留在仓库中,也不要进入共享历史
前端改动后,如需走离线发布工作流,必须先在 `dashboard` 目录执行 `npm run build`(例如:`D:\Nodejs\npm.cmd run build`),确保生成最新的 `dashboard/dist` 前端改动后,如需走离线发布工作流,必须先在 `dashboard` 目录执行 `npm run build`(例如:`D:\Nodejs\npm.cmd run build`),确保生成最新的 `dashboard/dist`
当前项目的离线发布工作流依赖仓库中已提交的 `dashboard/dist`;在修改 `dashboard/src``dashboard/public``dashboard/package.json``dashboard/package-lock.json` 或其他影响构建产物的前端文件后,要同步提交对应的 `dashboard/dist` 当前项目的离线发布工作流依赖仓库中已提交的 `dashboard/dist`;在修改 `dashboard/src``dashboard/public``dashboard/package.json``dashboard/package-lock.json` 或其他影响构建产物的前端文件后,要同步提交对应的 `dashboard/dist`
不要在离线发布场景中假设服务器或 runner 可以联网安装 Node.js 依赖;除非明确确认发布机已具备可用的 `npm` 环境,否则默认按“本地构建 `dashboard/dist` 并随仓库提交”处理 不要在离线发布场景中假设服务器或 runner 可以联网安装 Node.js 依赖;除非明确确认发布机已具备可用的 `npm` 环境,否则默认按“本地构建 `dashboard/dist` 并随仓库提交”处理

View File

@@ -1,427 +0,0 @@
import { describe, it, expect, vi } from 'vitest'
import { screen } from '@testing-library/dom'
import { render } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { DynamicConfigForm } from '../DynamicConfigForm'
import { FieldHookRegistry } from '@/lib/field-hooks'
import type { ConfigSchema } from '@/types/config-schema'
import type { FieldHookComponentProps } from '@/lib/field-hooks'
describe('DynamicConfigForm', () => {
describe('basic rendering', () => {
it('renders simple fields', () => {
const schema: ConfigSchema = {
className: 'TestConfig',
classDoc: 'Test configuration',
fields: [
{
name: 'field1',
type: 'string',
label: 'Field 1',
description: 'First field',
required: false,
default: 'value1',
},
{
name: 'field2',
type: 'boolean',
label: 'Field 2',
description: 'Second field',
required: false,
default: false,
},
],
}
const values = { field1: 'value1', field2: false }
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
expect(screen.getByText('Field 1')).toBeInTheDocument()
expect(screen.getByText('Field 2')).toBeInTheDocument()
expect(screen.getByText('First field')).toBeInTheDocument()
expect(screen.getByText('Second field')).toBeInTheDocument()
})
it('renders nested schema', () => {
const schema: ConfigSchema = {
className: 'MainConfig',
classDoc: 'Main configuration',
fields: [
{
name: 'top_field',
type: 'string',
label: 'Top Field',
description: 'Top level field',
required: false,
},
],
nested: {
sub_config: {
className: 'SubConfig',
classDoc: 'Sub configuration',
fields: [
{
name: 'nested_field',
type: 'number',
label: 'Nested Field',
description: 'Nested field',
required: false,
default: 42,
},
],
},
},
}
const values = {
top_field: 'top',
sub_config: {
nested_field: 42,
},
}
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
expect(screen.getByText('Top Field')).toBeInTheDocument()
expect(screen.getByText('Sub configuration')).toBeInTheDocument()
expect(screen.getByText('Nested Field')).toBeInTheDocument()
})
})
describe('Hook system', () => {
it('renders Hook component in replace mode', () => {
const TestHookComponent: React.FC<FieldHookComponentProps> = ({ fieldPath, value }) => {
return <div data-testid="hook-component">Hook: {fieldPath} = {String(value)}</div>
}
const hooks = new FieldHookRegistry()
hooks.register('hooked_field', TestHookComponent, 'replace')
const schema: ConfigSchema = {
className: 'TestConfig',
classDoc: 'Test configuration',
fields: [
{
name: 'hooked_field',
type: 'string',
label: 'Hooked Field',
description: 'A field with hook',
required: false,
},
{
name: 'normal_field',
type: 'string',
label: 'Normal Field',
description: 'A normal field',
required: false,
},
],
}
const values = { hooked_field: 'test', normal_field: 'normal' }
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} hooks={hooks} />)
expect(screen.getByTestId('hook-component')).toBeInTheDocument()
expect(screen.getByText('Hook: hooked_field = test')).toBeInTheDocument()
expect(screen.queryByText('Hooked Field')).not.toBeInTheDocument()
expect(screen.getByText('Normal Field')).toBeInTheDocument()
})
it('renders Hook component in wrapper mode', () => {
const WrapperHookComponent: React.FC<FieldHookComponentProps> = ({ fieldPath, children }) => {
return (
<div data-testid="wrapper-hook">
<div>Wrapper for: {fieldPath}</div>
{children}
</div>
)
}
const hooks = new FieldHookRegistry()
hooks.register('wrapped_field', WrapperHookComponent, 'wrapper')
const schema: ConfigSchema = {
className: 'TestConfig',
classDoc: 'Test configuration',
fields: [
{
name: 'wrapped_field',
type: 'string',
label: 'Wrapped Field',
description: 'A wrapped field',
required: false,
},
],
}
const values = { wrapped_field: 'test' }
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} hooks={hooks} />)
expect(screen.getByTestId('wrapper-hook')).toBeInTheDocument()
expect(screen.getByText('Wrapper for: wrapped_field')).toBeInTheDocument()
expect(screen.getByText('Wrapped Field')).toBeInTheDocument()
})
it('passes correct props to Hook component', () => {
const TestHookComponent: React.FC<FieldHookComponentProps> = ({ fieldPath, value, onChange }) => {
return (
<div>
<div data-testid="field-path">{fieldPath}</div>
<div data-testid="field-value">{String(value)}</div>
<button onClick={() => onChange?.('new_value')}>Change</button>
</div>
)
}
const hooks = new FieldHookRegistry()
hooks.register('test_field', TestHookComponent, 'replace')
const schema: ConfigSchema = {
className: 'TestConfig',
classDoc: 'Test configuration',
fields: [
{
name: 'test_field',
type: 'string',
label: 'Test Field',
description: 'A test field',
required: false,
},
],
}
const values = { test_field: 'original' }
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} hooks={hooks} />)
expect(screen.getByTestId('field-path')).toHaveTextContent('test_field')
expect(screen.getByTestId('field-value')).toHaveTextContent('original')
})
})
describe('onChange propagation', () => {
it('propagates onChange from simple field', async () => {
const schema: ConfigSchema = {
className: 'TestConfig',
classDoc: 'Test configuration',
fields: [
{
name: 'test_field',
type: 'string',
label: 'Test Field',
description: 'A test field',
required: false,
},
],
}
const values = { test_field: '' }
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
const input = screen.getByRole('textbox')
input.focus()
await userEvent.keyboard('Hello')
expect(onChange).toHaveBeenCalledTimes(5)
expect(onChange.mock.calls.every(call => call[0] === 'test_field')).toBe(true)
expect(onChange).toHaveBeenNthCalledWith(1, 'test_field', 'H')
expect(onChange).toHaveBeenNthCalledWith(5, 'test_field', 'o')
})
it('propagates onChange from nested field with correct path', async () => {
const schema: ConfigSchema = {
className: 'MainConfig',
classDoc: 'Main configuration',
fields: [],
nested: {
sub_config: {
className: 'SubConfig',
classDoc: 'Sub configuration',
fields: [
{
name: 'nested_field',
type: 'string',
label: 'Nested Field',
description: 'Nested field',
required: false,
},
],
},
},
}
const values = {
sub_config: {
nested_field: '',
},
}
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
const input = screen.getByRole('textbox')
input.focus()
await userEvent.keyboard('Test')
expect(onChange).toHaveBeenCalledTimes(4)
expect(onChange.mock.calls.every(call => call[0] === 'sub_config.nested_field')).toBe(true)
expect(onChange).toHaveBeenNthCalledWith(1, 'sub_config.nested_field', 'T')
expect(onChange).toHaveBeenNthCalledWith(4, 'sub_config.nested_field', 't')
})
it('propagates onChange from Hook component', async () => {
const TestHookComponent: React.FC<FieldHookComponentProps> = ({ onChange }) => {
return <button onClick={() => onChange?.('hook_value')}>Set Value</button>
}
const hooks = new FieldHookRegistry()
hooks.register('hooked_field', TestHookComponent, 'replace')
const schema: ConfigSchema = {
className: 'TestConfig',
classDoc: 'Test configuration',
fields: [
{
name: 'hooked_field',
type: 'string',
label: 'Hooked Field',
description: 'A hooked field',
required: false,
},
],
}
const values = { hooked_field: '' }
const onChange = vi.fn()
const user = userEvent.setup()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} hooks={hooks} />)
await user.click(screen.getByRole('button'))
expect(onChange).toHaveBeenCalledWith('hooked_field', 'hook_value')
})
it('renders nested Hook component with full field path', async () => {
const NestedHookComponent: React.FC<FieldHookComponentProps> = ({ fieldPath, onChange }) => {
return (
<button onClick={() => onChange?.([{ enabled: true }])}>
{fieldPath}
</button>
)
}
const hooks = new FieldHookRegistry()
hooks.register('mcp.servers', NestedHookComponent, 'replace')
const schema: ConfigSchema = {
className: 'RootConfig',
classDoc: 'Root configuration',
fields: [],
nested: {
mcp: {
className: 'MCPConfig',
classDoc: 'MCP 配置',
fields: [
{
name: 'enable',
type: 'boolean',
label: '启用 MCP',
description: '是否启用 MCP',
required: false,
},
{
name: 'servers',
type: 'array',
label: '服务器列表',
description: '复杂对象数组',
required: false,
items: {
type: 'object',
},
},
],
nested: {
servers: {
className: 'MCPServerItemConfig',
classDoc: 'MCP 服务器项',
fields: [],
},
},
},
},
}
const values = {
mcp: {
enable: true,
servers: [],
},
}
const onChange = vi.fn()
const user = userEvent.setup()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} hooks={hooks} />)
await user.click(screen.getByRole('button', { name: 'mcp.servers' }))
expect(onChange).toHaveBeenCalledWith('mcp.servers', [{ enabled: true }])
})
})
describe('edge cases', () => {
it('renders with empty nested values', () => {
const schema: ConfigSchema = {
className: 'MainConfig',
classDoc: 'Main configuration',
fields: [],
nested: {
sub_config: {
className: 'SubConfig',
classDoc: 'Sub configuration',
fields: [
{
name: 'nested_field',
type: 'string',
label: 'Nested Field',
description: 'Nested field',
required: false,
},
],
},
},
}
const values = {}
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
expect(screen.getByText('Sub configuration')).toBeInTheDocument()
expect(screen.getByText('Nested Field')).toBeInTheDocument()
})
it('uses default hook registry when not provided', () => {
const schema: ConfigSchema = {
className: 'TestConfig',
classDoc: 'Test configuration',
fields: [
{
name: 'test_field',
type: 'string',
label: 'Test Field',
description: 'A test field',
required: false,
},
],
}
const values = { test_field: 'test' }
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
expect(screen.getByText('Test Field')).toBeInTheDocument()
})
})
})

View File

@@ -1,475 +0,0 @@
import { describe, it, expect, vi } from 'vitest'
import { screen } from '@testing-library/dom'
import { render } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { DynamicField } from '../DynamicField'
import type { FieldSchema } from '@/types/config-schema'
describe('DynamicField', () => {
describe('x-widget priority', () => {
it('renders Slider when x-widget is slider', () => {
const schema: FieldSchema = {
name: 'test_slider',
type: 'number',
label: 'Test Slider',
description: 'A test slider',
required: false,
'x-widget': 'slider',
minValue: 0,
maxValue: 100,
default: 50,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value={50} onChange={onChange} />)
expect(screen.getByText('Test Slider')).toBeInTheDocument()
expect(screen.getByRole('slider')).toBeInTheDocument()
expect(screen.getByText('50')).toBeInTheDocument()
})
it('renders Switch when x-widget is switch', () => {
const schema: FieldSchema = {
name: 'test_switch',
type: 'boolean',
label: 'Test Switch',
description: 'A test switch',
required: false,
'x-widget': 'switch',
default: false,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value={false} onChange={onChange} />)
expect(screen.getByText('Test Switch')).toBeInTheDocument()
expect(screen.getByRole('switch')).toBeInTheDocument()
})
it('renders Textarea when x-widget is textarea', () => {
const schema: FieldSchema = {
name: 'test_textarea',
type: 'string',
label: 'Test Textarea',
description: 'A test textarea',
required: false,
'x-widget': 'textarea',
default: 'Hello',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="Hello" onChange={onChange} />)
expect(screen.getByText('Test Textarea')).toBeInTheDocument()
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByRole('textbox')).toHaveValue('Hello')
})
it('renders Select when x-widget is select', () => {
const schema: FieldSchema = {
name: 'test_select',
type: 'string',
label: 'Test Select',
description: 'A test select',
required: false,
'x-widget': 'select',
options: ['Option 1', 'Option 2', 'Option 3'],
default: 'Option 1',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="Option 1" onChange={onChange} />)
expect(screen.getByText('Test Select')).toBeInTheDocument()
expect(screen.getByRole('combobox')).toBeInTheDocument()
})
it('renders placeholder for custom widget', () => {
const schema: FieldSchema = {
name: 'test_custom',
type: 'string',
label: 'Test Custom',
description: 'A test custom field',
required: false,
'x-widget': 'custom',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="" onChange={onChange} />)
expect(screen.getByText('Custom field requires Hook')).toBeInTheDocument()
})
it('renders number Input when x-widget is input but type is integer', () => {
const schema: FieldSchema = {
name: 'test_integer_input_widget',
type: 'integer',
label: 'Test Integer Input Widget',
description: 'A numeric field rendered as input',
required: false,
'x-widget': 'input',
default: 0,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value={2} onChange={onChange} />)
const input = screen.getByRole('spinbutton')
expect(input).toBeInTheDocument()
expect(input).toHaveValue(2)
})
it('parses string values for numeric input widgets', () => {
const schema: FieldSchema = {
name: 'test_string_number_input_widget',
type: 'integer',
label: 'Test String Number Input Widget',
description: 'A numeric field with legacy string value',
required: false,
'x-widget': 'input',
default: 0,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="2" onChange={onChange} />)
expect(screen.getByRole('spinbutton')).toHaveValue(2)
})
})
describe('type fallback', () => {
it('renders Input for string type', () => {
const schema: FieldSchema = {
name: 'test_string',
type: 'string',
label: 'Test String',
description: 'A test string',
required: false,
default: 'Hello',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="Hello" onChange={onChange} />)
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByRole('textbox')).toHaveValue('Hello')
})
it('renders Switch for boolean type', () => {
const schema: FieldSchema = {
name: 'test_bool',
type: 'boolean',
label: 'Test Boolean',
description: 'A test boolean',
required: false,
default: true,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value={true} onChange={onChange} />)
expect(screen.getByRole('switch')).toBeInTheDocument()
expect(screen.getByRole('switch')).toBeChecked()
})
it('renders number Input for number type', () => {
const schema: FieldSchema = {
name: 'test_number',
type: 'number',
label: 'Test Number',
description: 'A test number',
required: false,
default: 42,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value={42} onChange={onChange} />)
const input = screen.getByRole('spinbutton')
expect(input).toBeInTheDocument()
expect(input).toHaveValue(42)
})
it('renders number Input for integer type', () => {
const schema: FieldSchema = {
name: 'test_integer',
type: 'integer',
label: 'Test Integer',
description: 'A test integer',
required: false,
default: 10,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value={10} onChange={onChange} />)
const input = screen.getByRole('spinbutton')
expect(input).toBeInTheDocument()
expect(input).toHaveValue(10)
})
it('renders Textarea for textarea type', () => {
const schema: FieldSchema = {
name: 'test_textarea_type',
type: 'textarea',
label: 'Test Textarea Type',
description: 'A test textarea type',
required: false,
default: 'Long text',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="Long text" onChange={onChange} />)
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByRole('textbox')).toHaveValue('Long text')
})
it('renders Select for select type', () => {
const schema: FieldSchema = {
name: 'test_select_type',
type: 'select',
label: 'Test Select Type',
description: 'A test select type',
required: false,
options: ['A', 'B', 'C'],
default: 'A',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="A" onChange={onChange} />)
expect(screen.getByRole('combobox')).toBeInTheDocument()
})
it('renders textarea editor for primitive array type', () => {
const schema: FieldSchema = {
name: 'test_array',
type: 'array',
label: 'Test Array',
description: 'A test array',
required: false,
items: {
type: 'string',
},
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value={['a', 'b']} onChange={onChange} />)
expect(screen.getByRole('textbox')).toHaveValue('a\nb')
})
it('renders key-value editor for object type', () => {
const schema: FieldSchema = {
name: 'test_object',
type: 'object',
label: 'Test Object',
description: 'A test object',
required: false,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value={{ foo: 'bar' }} onChange={onChange} />)
expect(screen.getByText('可视化编辑')).toBeInTheDocument()
expect(screen.getByDisplayValue('foo')).toBeInTheDocument()
})
})
describe('onChange events', () => {
it('triggers onChange for Switch', async () => {
const schema: FieldSchema = {
name: 'test_switch',
type: 'boolean',
label: 'Test Switch',
description: 'A test switch',
required: false,
default: false,
}
const onChange = vi.fn()
const user = userEvent.setup()
render(<DynamicField schema={schema} value={false} onChange={onChange} />)
await user.click(screen.getByRole('switch'))
expect(onChange).toHaveBeenCalledWith(true)
})
it('triggers onChange for Input', async () => {
const schema: FieldSchema = {
name: 'test_input',
type: 'string',
label: 'Test Input',
description: 'A test input',
required: false,
default: '',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="" onChange={onChange} />)
const input = screen.getByRole('textbox')
input.focus()
await userEvent.keyboard('Hello')
expect(onChange).toHaveBeenCalledTimes(5)
expect(onChange).toHaveBeenNthCalledWith(1, 'H')
expect(onChange).toHaveBeenNthCalledWith(2, 'e')
expect(onChange).toHaveBeenNthCalledWith(3, 'l')
expect(onChange).toHaveBeenNthCalledWith(4, 'l')
expect(onChange).toHaveBeenNthCalledWith(5, 'o')
})
it('triggers onChange for number Input', async () => {
const schema: FieldSchema = {
name: 'test_number',
type: 'number',
label: 'Test Number',
description: 'A test number',
required: false,
default: 0,
}
const onChange = vi.fn()
const user = userEvent.setup()
render(<DynamicField schema={schema} value={0} onChange={onChange} />)
const input = screen.getByRole('spinbutton')
await user.clear(input)
await user.type(input, '123')
expect(onChange).toHaveBeenCalled()
})
it('triggers numeric onChange for input widget with integer type', async () => {
const schema: FieldSchema = {
name: 'test_integer_input_widget_change',
type: 'integer',
label: 'Test Integer Input Widget Change',
description: 'A numeric field rendered as input',
required: false,
'x-widget': 'input',
default: 0,
}
const onChange = vi.fn()
const user = userEvent.setup()
render(<DynamicField schema={schema} value={0} onChange={onChange} />)
const input = screen.getByRole('spinbutton')
await user.clear(input)
await user.type(input, '5')
expect(onChange).toHaveBeenLastCalledWith(5)
})
})
describe('visual features', () => {
it('renders label with icon', () => {
const schema: FieldSchema = {
name: 'test_icon',
type: 'string',
label: 'Test Icon',
description: 'A test with icon',
required: false,
'x-icon': 'Settings',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="" onChange={onChange} />)
expect(screen.getByText('Test Icon')).toBeInTheDocument()
})
it('renders required indicator', () => {
const schema: FieldSchema = {
name: 'test_required',
type: 'string',
label: 'Test Required',
description: 'A required field',
required: true,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="" onChange={onChange} />)
expect(screen.getByText('*')).toBeInTheDocument()
})
it('renders description', () => {
const schema: FieldSchema = {
name: 'test_desc',
type: 'string',
label: 'Test Description',
description: 'This is a description',
required: false,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="" onChange={onChange} />)
expect(screen.getByText('This is a description')).toBeInTheDocument()
})
})
describe('slider features', () => {
it('renders slider with min/max/step', () => {
const schema: FieldSchema = {
name: 'test_slider_props',
type: 'number',
label: 'Test Slider Props',
description: 'A slider with props',
required: false,
'x-widget': 'slider',
minValue: 10,
maxValue: 50,
step: 5,
default: 25,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value={25} onChange={onChange} />)
expect(screen.getByText('10')).toBeInTheDocument()
expect(screen.getByText('50')).toBeInTheDocument()
expect(screen.getByText('25')).toBeInTheDocument()
})
it('parses string values for slider widgets', () => {
const schema: FieldSchema = {
name: 'test_slider_string_value',
type: 'number',
label: 'Test Slider String Value',
description: 'A slider with legacy string value',
required: false,
'x-widget': 'slider',
minValue: 0,
maxValue: 10,
default: 0,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="2.5" onChange={onChange} />)
expect(screen.getByText('2.5')).toBeInTheDocument()
})
})
describe('select features', () => {
it('renders placeholder when no options', () => {
const schema: FieldSchema = {
name: 'test_select_no_options',
type: 'string',
label: 'Test Select No Options',
description: 'A select with no options',
required: false,
'x-widget': 'select',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="" onChange={onChange} />)
expect(screen.getByText('No options available for select')).toBeInTheDocument()
})
})
})

View File

@@ -1,253 +0,0 @@
import { describe, it, expect, beforeEach } from 'vitest'
import { FieldHookRegistry } from '../field-hooks'
import type { FieldHookComponent } from '../field-hooks'
describe('FieldHookRegistry', () => {
let registry: FieldHookRegistry
beforeEach(() => {
registry = new FieldHookRegistry()
})
describe('register', () => {
it('registers a hook with replace type', () => {
const component: FieldHookComponent = () => null
registry.register('test.field', component, 'replace')
expect(registry.has('test.field')).toBe(true)
})
it('registers a hook with wrapper type', () => {
const component: FieldHookComponent = () => null
registry.register('test.field', component, 'wrapper')
expect(registry.has('test.field')).toBe(true)
const entry = registry.get('test.field')
expect(entry?.type).toBe('wrapper')
})
it('defaults to replace type when not specified', () => {
const component: FieldHookComponent = () => null
registry.register('test.field', component)
const entry = registry.get('test.field')
expect(entry?.type).toBe('replace')
})
it('overwrites existing hook for same field path', () => {
const component1: FieldHookComponent = () => null
const component2: FieldHookComponent = () => null
registry.register('test.field', component1, 'replace')
registry.register('test.field', component2, 'wrapper')
const entry = registry.get('test.field')
expect(entry?.component).toBe(component2)
expect(entry?.type).toBe('wrapper')
})
})
describe('get', () => {
it('returns hook entry for registered field path', () => {
const component: FieldHookComponent = () => null
registry.register('test.field', component, 'replace')
const entry = registry.get('test.field')
expect(entry).toBeDefined()
expect(entry?.component).toBe(component)
expect(entry?.type).toBe('replace')
})
it('returns undefined for unregistered field path', () => {
const entry = registry.get('nonexistent.field')
expect(entry).toBeUndefined()
})
it('returns correct entry for nested field paths', () => {
const component: FieldHookComponent = () => null
registry.register('config.section.field', component, 'wrapper')
const entry = registry.get('config.section.field')
expect(entry).toBeDefined()
expect(entry?.type).toBe('wrapper')
})
})
describe('has', () => {
it('returns true for registered field path', () => {
const component: FieldHookComponent = () => null
registry.register('test.field', component)
expect(registry.has('test.field')).toBe(true)
})
it('returns false for unregistered field path', () => {
expect(registry.has('nonexistent.field')).toBe(false)
})
it('returns false after unregistering', () => {
const component: FieldHookComponent = () => null
registry.register('test.field', component)
registry.unregister('test.field')
expect(registry.has('test.field')).toBe(false)
})
})
describe('unregister', () => {
it('removes a registered hook', () => {
const component: FieldHookComponent = () => null
registry.register('test.field', component)
expect(registry.has('test.field')).toBe(true)
registry.unregister('test.field')
expect(registry.has('test.field')).toBe(false)
})
it('does not throw when unregistering non-existent hook', () => {
expect(() => registry.unregister('nonexistent.field')).not.toThrow()
})
it('only removes specified hook, not others', () => {
const component1: FieldHookComponent = () => null
const component2: FieldHookComponent = () => null
registry.register('field1', component1)
registry.register('field2', component2)
registry.unregister('field1')
expect(registry.has('field1')).toBe(false)
expect(registry.has('field2')).toBe(true)
})
})
describe('clear', () => {
it('removes all registered hooks', () => {
const component1: FieldHookComponent = () => null
const component2: FieldHookComponent = () => null
const component3: FieldHookComponent = () => null
registry.register('field1', component1)
registry.register('field2', component2)
registry.register('field3', component3)
expect(registry.getAllPaths()).toHaveLength(3)
registry.clear()
expect(registry.getAllPaths()).toHaveLength(0)
expect(registry.has('field1')).toBe(false)
expect(registry.has('field2')).toBe(false)
expect(registry.has('field3')).toBe(false)
})
it('works correctly on empty registry', () => {
expect(() => registry.clear()).not.toThrow()
expect(registry.getAllPaths()).toHaveLength(0)
})
})
describe('getAllPaths', () => {
it('returns empty array when no hooks registered', () => {
expect(registry.getAllPaths()).toEqual([])
})
it('returns all registered field paths', () => {
const component: FieldHookComponent = () => null
registry.register('field1', component)
registry.register('field2', component)
registry.register('field3', component)
const paths = registry.getAllPaths()
expect(paths).toHaveLength(3)
expect(paths).toContain('field1')
expect(paths).toContain('field2')
expect(paths).toContain('field3')
})
it('returns updated paths after unregister', () => {
const component: FieldHookComponent = () => null
registry.register('field1', component)
registry.register('field2', component)
registry.register('field3', component)
registry.unregister('field2')
const paths = registry.getAllPaths()
expect(paths).toHaveLength(2)
expect(paths).toContain('field1')
expect(paths).toContain('field3')
expect(paths).not.toContain('field2')
})
it('handles nested field paths correctly', () => {
const component: FieldHookComponent = () => null
registry.register('config.chat.enabled', component)
registry.register('config.chat.model', component)
registry.register('config.api.key', component)
const paths = registry.getAllPaths()
expect(paths).toHaveLength(3)
expect(paths).toContain('config.chat.enabled')
expect(paths).toContain('config.chat.model')
expect(paths).toContain('config.api.key')
})
})
describe('integration scenarios', () => {
it('supports full lifecycle of multiple hooks', () => {
const replaceComponent: FieldHookComponent = () => null
const wrapperComponent: FieldHookComponent = () => null
registry.register('field1', replaceComponent, 'replace')
registry.register('field2', wrapperComponent, 'wrapper')
expect(registry.getAllPaths()).toHaveLength(2)
const entry1 = registry.get('field1')
expect(entry1?.type).toBe('replace')
expect(entry1?.component).toBe(replaceComponent)
const entry2 = registry.get('field2')
expect(entry2?.type).toBe('wrapper')
expect(entry2?.component).toBe(wrapperComponent)
registry.unregister('field1')
expect(registry.getAllPaths()).toHaveLength(1)
expect(registry.has('field2')).toBe(true)
registry.clear()
expect(registry.getAllPaths()).toHaveLength(0)
})
it('handles rapid register/unregister cycles', () => {
const component: FieldHookComponent = () => null
for (let i = 0; i < 100; i++) {
registry.register(`field${i}`, component)
}
expect(registry.getAllPaths()).toHaveLength(100)
for (let i = 0; i < 50; i++) {
registry.unregister(`field${i}`)
}
expect(registry.getAllPaths()).toHaveLength(50)
registry.clear()
expect(registry.getAllPaths()).toHaveLength(0)
})
})
})

View File

@@ -1,80 +0,0 @@
import { render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { ReactNode } from 'react'
import { PluginConfigPage } from '../plugin-config'
import * as pluginApi from '@/lib/plugin-api'
const toastMock = vi.fn()
vi.mock('@/hooks/use-toast', () => ({
useToast: () => ({ toast: toastMock }),
}))
vi.mock('@/lib/restart-context', () => ({
RestartProvider: ({ children }: { children: ReactNode }) => <>{children}</>,
useRestart: () => ({
showRestartPrompt: false,
markRestartRequired: vi.fn(),
clearRestartRequired: vi.fn(),
}),
}))
vi.mock('@/components/restart-overlay', () => ({
RestartOverlay: () => null,
}))
vi.mock('@/components', () => ({
CodeEditor: ({ value }: { value: string }) => <pre>{value}</pre>,
ListFieldEditor: () => <div>list-field-editor</div>,
}))
vi.mock('@/lib/plugin-api', () => ({
getInstalledPlugins: vi.fn(),
getPluginConfigSchema: vi.fn(),
getPluginConfig: vi.fn(),
getPluginConfigRaw: vi.fn(),
updatePluginConfig: vi.fn(),
updatePluginConfigRaw: vi.fn(),
resetPluginConfig: vi.fn(),
togglePlugin: vi.fn(),
}))
describe('PluginConfigPage', () => {
beforeEach(() => {
toastMock.mockReset()
vi.mocked(pluginApi.getInstalledPlugins).mockResolvedValue({
success: true,
data: [
{
id: 'test.emoji',
path: '/plugins/test_emoji',
manifest: {
manifest_version: 2,
name: 'Emoji Plugin',
version: '1.0.0',
description: 'emoji tools',
author: { name: 'tester' },
license: 'MIT',
host_application: { min_version: '1.0.0' },
},
},
],
})
vi.mocked(pluginApi.getPluginConfigSchema).mockResolvedValue({} as never)
vi.mocked(pluginApi.getPluginConfig).mockResolvedValue({} as never)
vi.mocked(pluginApi.getPluginConfigRaw).mockResolvedValue({} as never)
vi.mocked(pluginApi.updatePluginConfig).mockResolvedValue({} as never)
vi.mocked(pluginApi.updatePluginConfigRaw).mockResolvedValue({} as never)
vi.mocked(pluginApi.resetPluginConfig).mockResolvedValue({} as never)
vi.mocked(pluginApi.togglePlugin).mockResolvedValue({} as never)
})
it('shows real plugins and no longer surfaces A_Memorix in plugin config list', async () => {
render(<PluginConfigPage />)
expect(await screen.findByText('Emoji Plugin')).toBeInTheDocument()
expect(screen.getByText('点击插件查看和编辑配置')).toBeInTheDocument()
expect(screen.queryByText(/A_Memorix/i)).not.toBeInTheDocument()
})
})

View File

@@ -1,802 +0,0 @@
import { act, render, screen, waitFor, within } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { KnowledgeBasePage } from '../knowledge-base'
import * as memoryApi from '@/lib/memory-api'
const navigateMock = vi.fn()
const toastMock = vi.fn()
vi.mock('@tanstack/react-router', () => ({
useNavigate: () => navigateMock,
}))
vi.mock('@/hooks/use-toast', () => ({
useToast: () => ({ toast: toastMock }),
}))
vi.mock('@/components', () => ({
CodeEditor: ({ value }: { value: string }) => <pre data-testid="code-editor">{value}</pre>,
MarkdownRenderer: ({ content }: { content: string }) => <div>{content}</div>,
}))
vi.mock('@/components/memory/MemoryConfigEditor', () => ({
MemoryConfigEditor: () => <div data-testid="memory-config-editor">memory-config-editor</div>,
}))
vi.mock('@/components/memory/MemoryDeleteDialog', () => ({
MemoryDeleteDialog: ({
open,
onExecute,
onRestore,
preview,
result,
}: {
open: boolean
preview?: { mode?: string; item_count?: number } | null
result?: { operation_id?: string } | null
onExecute?: () => void
onRestore?: () => void
}) => (
open ? (
<div data-testid="memory-delete-dialog">
<div>{`preview:${preview?.mode ?? 'none'}:${preview?.item_count ?? 0}`}</div>
<div>{`result:${result?.operation_id ?? 'none'}`}</div>
<button type="button" onClick={onExecute}></button>
<button type="button" onClick={onRestore}></button>
</div>
) : null
),
}))
vi.mock('@/lib/memory-api', () => ({
getMemoryConfigSchema: vi.fn(),
getMemoryConfig: vi.fn(),
getMemoryConfigRaw: vi.fn(),
getMemoryRuntimeConfig: vi.fn(),
getMemoryImportGuide: vi.fn(),
getMemoryImportSettings: vi.fn(),
getMemoryImportPathAliases: vi.fn(),
getMemoryImportTasks: vi.fn(),
getMemoryImportTask: vi.fn(),
getMemoryImportTaskChunks: vi.fn(),
createMemoryUploadImport: vi.fn(),
createMemoryPasteImport: vi.fn(),
createMemoryRawScanImport: vi.fn(),
createMemoryLpmmOpenieImport: vi.fn(),
createMemoryLpmmConvertImport: vi.fn(),
createMemoryTemporalBackfillImport: vi.fn(),
createMemoryMaibotMigrationImport: vi.fn(),
cancelMemoryImportTask: vi.fn(),
retryMemoryImportTask: vi.fn(),
resolveMemoryImportPath: vi.fn(),
refreshMemoryRuntimeSelfCheck: vi.fn(),
updateMemoryConfig: vi.fn(),
updateMemoryConfigRaw: vi.fn(),
getMemoryTuningProfile: vi.fn(),
getMemoryTuningTasks: vi.fn(),
createMemoryTuningTask: vi.fn(),
applyBestMemoryTuningProfile: vi.fn(),
getMemorySources: vi.fn(),
getMemoryDeleteOperations: vi.fn(),
getMemoryDeleteOperation: vi.fn(),
getMemoryFeedbackCorrections: vi.fn(),
getMemoryFeedbackCorrection: vi.fn(),
previewMemoryDelete: vi.fn(),
executeMemoryDelete: vi.fn(),
restoreMemoryDelete: vi.fn(),
rollbackMemoryFeedbackCorrection: vi.fn(),
}))
function mockImportTask(taskId: string, status: string = 'running'): memoryApi.MemoryImportTaskPayload {
return {
task_id: taskId,
source: 'webui',
status,
current_step: status === 'completed' ? 'completed' : 'running',
total_chunks: 120,
done_chunks: status === 'completed' ? 120 : 36,
failed_chunks: status === 'completed' ? 0 : 2,
cancelled_chunks: 0,
progress: status === 'completed' ? 100 : 30,
error: '',
file_count: 2,
created_at: 1_710_000_000,
started_at: 1_710_000_001,
finished_at: status === 'completed' ? 1_710_000_099 : null,
updated_at: 1_710_000_100,
task_kind: 'paste',
params: {},
files: [],
}
}
function mockImportDetail(taskId: string): memoryApi.MemoryImportTaskPayload {
return {
...mockImportTask(taskId),
files: [
{
file_id: 'file-alpha',
name: 'alpha.txt',
source_kind: 'paste',
input_mode: 'text',
status: 'running',
current_step: 'running',
detected_strategy_type: 'auto',
total_chunks: 80,
done_chunks: 30,
failed_chunks: 1,
cancelled_chunks: 0,
progress: 37.5,
error: '',
created_at: 1_710_000_000,
updated_at: 1_710_000_100,
},
{
file_id: 'file-beta',
name: 'beta.txt',
source_kind: 'paste',
input_mode: 'text',
status: 'failed',
current_step: 'extracting',
detected_strategy_type: 'auto',
total_chunks: 40,
done_chunks: 6,
failed_chunks: 4,
cancelled_chunks: 0,
progress: 25,
error: 'mock error',
created_at: 1_710_000_000,
updated_at: 1_710_000_100,
},
],
}
}
function mockImportCompletedWithErrorsDetail(taskId: string): memoryApi.MemoryImportTaskPayload {
return {
...mockImportDetail(taskId),
status: 'completed_with_errors',
current_step: 'completed_with_errors',
total_chunks: 12,
done_chunks: 9,
failed_chunks: 3,
cancelled_chunks: 0,
progress: 75,
files: [
{
file_id: 'file-error',
name: 'error.txt',
source_kind: 'paste',
input_mode: 'text',
status: 'failed',
current_step: 'failed',
detected_strategy_type: 'auto',
total_chunks: 12,
done_chunks: 9,
failed_chunks: 3,
cancelled_chunks: 0,
progress: 75,
error: 'mock error',
created_at: 1_710_000_000,
updated_at: 1_710_000_100,
},
],
}
}
describe('KnowledgeBasePage import workflow', () => {
beforeEach(() => {
navigateMock.mockReset()
toastMock.mockReset()
vi.mocked(memoryApi.getMemoryConfigSchema).mockResolvedValue({
success: true,
path: 'config/a_memorix.toml',
schema: {
plugin_id: 'a_memorix',
plugin_info: {
name: 'A_Memorix',
version: '2.0.0',
description: '长期记忆子系统',
author: 'A_Dawn',
},
_note: 'raw-only 字段仍可通过 TOML 编辑',
layout: {
type: 'tabs',
tabs: [{ id: 'basic', title: '基础', sections: ['plugin'], order: 1 }],
},
sections: {
plugin: {
name: 'plugin',
title: '子系统状态',
collapsed: false,
order: 1,
fields: {},
},
},
},
})
vi.mocked(memoryApi.getMemoryConfig).mockResolvedValue({
success: true,
path: 'config/a_memorix.toml',
config: { plugin: { enabled: true } },
})
vi.mocked(memoryApi.getMemoryConfigRaw).mockResolvedValue({
success: true,
path: 'config/a_memorix.toml',
config: '[plugin]\nenabled = true\n',
})
vi.mocked(memoryApi.getMemoryRuntimeConfig).mockResolvedValue({
success: true,
config: { plugin: { enabled: true } },
data_dir: 'data/plugins/a-dawn.a-memorix',
embedding_dimension: 1024,
auto_save: true,
relation_vectors_enabled: false,
runtime_ready: true,
embedding_degraded: false,
embedding_degraded_reason: '',
embedding_degraded_since: null,
embedding_last_check: null,
paragraph_vector_backfill_pending: 2,
paragraph_vector_backfill_running: 0,
paragraph_vector_backfill_failed: 1,
paragraph_vector_backfill_done: 3,
})
vi.mocked(memoryApi.getMemoryImportGuide).mockResolvedValue({
success: true,
content: '# 导入指南\n导入说明',
})
vi.mocked(memoryApi.getMemoryImportSettings).mockResolvedValue({
success: true,
settings: {
max_paste_chars: 200_000,
max_file_concurrency: 8,
max_chunk_concurrency: 16,
default_file_concurrency: 2,
default_chunk_concurrency: 4,
poll_interval_ms: 60_000,
maibot_source_db_default: 'data/maibot.db',
},
})
vi.mocked(memoryApi.getMemoryImportPathAliases).mockResolvedValue({
success: true,
path_aliases: {
lpmm: 'data/lpmm',
plugin_data: 'data/plugins/a-dawn.a-memorix',
raw: 'data/raw',
},
})
vi.mocked(memoryApi.getMemoryImportTasks).mockResolvedValue({
success: true,
items: [
mockImportTask('import-run-1', 'running'),
mockImportTask('import-queued-1', 'queued'),
mockImportTask('import-done-1', 'completed'),
],
})
vi.mocked(memoryApi.getMemoryImportTask).mockResolvedValue({
success: true,
task: mockImportDetail('import-run-1'),
})
vi.mocked(memoryApi.getMemoryImportTaskChunks).mockImplementation(async (_taskId, fileId, offset = 0) => ({
success: true,
task_id: 'import-run-1',
file_id: fileId,
offset,
limit: 50,
total: 120,
items: [
{
chunk_id: `${fileId}-${offset + 0}`,
index: offset + 0,
chunk_type: 'text',
status: 'running',
step: 'extracting',
failed_at: '',
retryable: true,
error: '',
progress: 50,
content_preview: `chunk-preview-${offset + 0}`,
updated_at: 1_710_000_111,
},
],
}))
vi.mocked(memoryApi.createMemoryUploadImport).mockResolvedValue({
success: true,
task: mockImportTask('upload-task-1', 'queued'),
})
vi.mocked(memoryApi.createMemoryPasteImport).mockResolvedValue({
success: true,
task: mockImportTask('paste-task-1', 'queued'),
})
vi.mocked(memoryApi.createMemoryRawScanImport).mockResolvedValue({
success: true,
task: mockImportTask('raw-task-1', 'queued'),
})
vi.mocked(memoryApi.createMemoryLpmmOpenieImport).mockResolvedValue({
success: true,
task: mockImportTask('openie-task-1', 'queued'),
})
vi.mocked(memoryApi.createMemoryLpmmConvertImport).mockResolvedValue({
success: true,
task: mockImportTask('convert-task-1', 'queued'),
})
vi.mocked(memoryApi.createMemoryTemporalBackfillImport).mockResolvedValue({
success: true,
task: mockImportTask('backfill-task-1', 'queued'),
})
vi.mocked(memoryApi.createMemoryMaibotMigrationImport).mockResolvedValue({
success: true,
task: mockImportTask('migration-task-1', 'queued'),
})
vi.mocked(memoryApi.cancelMemoryImportTask).mockResolvedValue({
success: true,
task: mockImportTask('import-run-1', 'cancel_requested'),
})
vi.mocked(memoryApi.retryMemoryImportTask).mockResolvedValue({
success: true,
task: mockImportTask('retry-task-1', 'queued'),
})
vi.mocked(memoryApi.resolveMemoryImportPath).mockResolvedValue({
success: true,
alias: 'raw',
relative_path: 'exports',
resolved_path: 'D:/Dev/rdev/MaiBot/data/raw/exports',
exists: true,
is_file: false,
is_dir: true,
})
vi.mocked(memoryApi.getMemoryTuningProfile).mockResolvedValue({
success: true,
profile: { retrieval: { top_k: 10 } },
toml: '[retrieval]\ntop_k = 10\n',
})
vi.mocked(memoryApi.getMemoryTuningTasks).mockResolvedValue({
success: true,
items: [{ task_id: 'tune-1', status: 'done' }],
})
vi.mocked(memoryApi.createMemoryTuningTask).mockResolvedValue({ success: true } as never)
vi.mocked(memoryApi.applyBestMemoryTuningProfile).mockResolvedValue({ success: true } as never)
vi.mocked(memoryApi.getMemorySources).mockResolvedValue({
success: true,
items: [{ source: 'demo-1', paragraph_count: 2, relation_count: 1 }],
count: 1,
})
vi.mocked(memoryApi.getMemoryDeleteOperations).mockResolvedValue({
success: true,
items: [
{
operation_id: 'del-1',
mode: 'source',
status: 'executed',
summary: { counts: { paragraphs: 2, relations: 1, sources: 1 } },
},
],
count: 1,
})
vi.mocked(memoryApi.getMemoryDeleteOperation).mockResolvedValue({
success: true,
operation: {
operation_id: 'del-1',
mode: 'source',
status: 'executed',
selector: { sources: ['demo-1'] },
summary: { counts: { paragraphs: 2, relations: 1, sources: 1 }, sources: ['demo-1'] },
items: [],
},
})
vi.mocked(memoryApi.getMemoryFeedbackCorrections).mockResolvedValue({
success: true,
items: [
{
task_id: 11,
query_tool_id: 'tool-query-11',
session_id: 'session-1',
query_text: '测试用户最喜欢的颜色是什么',
query_timestamp: 1_710_000_010,
task_status: 'applied',
decision: 'correct',
decision_confidence: 0.97,
feedback_message_count: 1,
rollback_status: 'none',
affected_counts: {
relations: 1,
stale_paragraphs: 1,
episode_sources: 2,
profile_person_ids: 1,
correction_paragraphs: 1,
corrected_relations: 1,
},
created_at: 1_710_000_011,
updated_at: 1_710_000_012,
},
],
count: 1,
})
vi.mocked(memoryApi.getMemoryFeedbackCorrection).mockResolvedValue({
success: true,
task: {
task_id: 11,
query_tool_id: 'tool-query-11',
session_id: 'session-1',
query_text: '测试用户最喜欢的颜色是什么',
query_timestamp: 1_710_000_010,
task_status: 'applied',
decision: 'correct',
decision_confidence: 0.97,
feedback_message_count: 1,
rollback_status: 'none',
affected_counts: {
relations: 1,
stale_paragraphs: 1,
episode_sources: 2,
profile_person_ids: 1,
correction_paragraphs: 1,
corrected_relations: 1,
},
query_snapshot: { query: '测试用户最喜欢的颜色是什么', hits: [{ hash: 'paragraph-1' }] },
decision_payload: { decision: 'correct', confidence: 0.97 },
rollback_plan_summary: {
forgotten_relations: [{ hash: 'rel-old', subject: '测试用户', predicate: '最喜欢的颜色是', object: '蓝色' }],
corrected_write: {
paragraph_hashes: ['paragraph-new'],
corrected_relations: [{ hash: 'rel-new', subject: '测试用户', predicate: '最喜欢的颜色是', object: '绿色' }],
},
},
rollback_result: {},
action_logs: [
{
id: 1,
task_id: 11,
query_tool_id: 'tool-query-11',
action_type: 'forget_relation',
target_hash: 'rel-old',
reason: '用户明确纠正为绿色',
before_payload: { hash: 'rel-old', subject: '测试用户', predicate: '最喜欢的颜色是', object: '蓝色' },
after_payload: { is_inactive: true },
created_at: 1_710_000_013,
},
],
created_at: 1_710_000_011,
updated_at: 1_710_000_012,
},
})
vi.mocked(memoryApi.previewMemoryDelete).mockResolvedValue({
success: true,
mode: 'source',
selector: { sources: ['demo-1'] },
counts: { sources: 1, paragraphs: 2, relations: 1 },
sources: ['demo-1'],
items: [{ item_type: 'paragraph', item_hash: 'p-1', label: 'demo-1' }],
item_count: 1,
dry_run: true,
} as never)
vi.mocked(memoryApi.executeMemoryDelete).mockResolvedValue({
success: true,
mode: 'source',
operation_id: 'del-2',
counts: { sources: 1, paragraphs: 2, relations: 1 },
sources: ['demo-1'],
deleted_count: 4,
deleted_entity_count: 0,
deleted_relation_count: 1,
deleted_paragraph_count: 2,
deleted_source_count: 1,
} as never)
vi.mocked(memoryApi.restoreMemoryDelete).mockResolvedValue({ success: true } as never)
vi.mocked(memoryApi.rollbackMemoryFeedbackCorrection).mockResolvedValue({
success: true,
result: { restored_relation_hashes: ['rel-old'] },
task: {
task_id: 11,
query_tool_id: 'tool-query-11',
session_id: 'session-1',
query_text: '测试用户最喜欢的颜色是什么',
query_timestamp: 1_710_000_010,
task_status: 'applied',
decision: 'correct',
decision_confidence: 0.97,
feedback_message_count: 1,
rollback_status: 'rolled_back',
affected_counts: {
relations: 1,
stale_paragraphs: 1,
episode_sources: 2,
profile_person_ids: 1,
correction_paragraphs: 1,
corrected_relations: 1,
},
query_snapshot: { query: '测试用户最喜欢的颜色是什么', hits: [{ hash: 'paragraph-1' }] },
decision_payload: { decision: 'correct', confidence: 0.97 },
rollback_plan_summary: {},
rollback_result: { restored_relation_hashes: ['rel-old'] },
action_logs: [],
created_at: 1_710_000_011,
updated_at: 1_710_000_012,
},
})
vi.mocked(memoryApi.refreshMemoryRuntimeSelfCheck).mockResolvedValue({
success: true,
report: { ok: true },
})
vi.mocked(memoryApi.updateMemoryConfig).mockResolvedValue({ success: true } as never)
vi.mocked(memoryApi.updateMemoryConfigRaw).mockResolvedValue({ success: true } as never)
})
it('loads import settings/guide/tasks on first render', async () => {
const user = userEvent.setup()
render(<KnowledgeBasePage />)
expect(await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })).toBeInTheDocument()
await user.click(screen.getByRole('tab', { name: '导入' }))
expect(await screen.findByRole('button', { name: '创建导入任务' })).toBeInTheDocument()
expect((await screen.findAllByText('import-run-1')).length).toBeGreaterThan(0)
expect(memoryApi.getMemoryImportSettings).toHaveBeenCalled()
expect(memoryApi.getMemoryImportPathAliases).toHaveBeenCalled()
expect(memoryApi.getMemoryImportTasks).toHaveBeenCalled()
})
it('creates import tasks for all 7 modes and calls correct endpoints', async () => {
const user = userEvent.setup()
const { container } = render(<KnowledgeBasePage />)
const openImportTab = async () => {
await user.click(screen.getByRole('tab', { name: '导入' }))
await screen.findByRole('button', { name: '创建导入任务' })
}
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
await openImportTab()
const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement
const uploadFiles = [
new File(['hello'], 'demo.txt', { type: 'text/plain' }),
new File(['{"name":"mai"}'], 'demo.json', { type: 'application/json' }),
new File(['a,b\n1,2'], 'demo.csv', { type: 'text/csv' }),
new File(['# note'], 'demo.md', { type: 'text/markdown' }),
]
await user.upload(fileInput, uploadFiles)
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
await waitFor(() => expect(memoryApi.createMemoryUploadImport).toHaveBeenCalledTimes(1))
await openImportTab()
await user.click(screen.getByRole('tab', { name: '粘贴导入' }))
const editableTextarea = Array.from(container.querySelectorAll('textarea')).find((item) => !item.readOnly)
if (!editableTextarea) {
throw new Error('missing editable textarea')
}
await user.type(editableTextarea, 'paste content')
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
await waitFor(() => expect(memoryApi.createMemoryPasteImport).toHaveBeenCalledTimes(1))
await openImportTab()
await user.click(screen.getByRole('tab', { name: '本地扫描' }))
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
await waitFor(() => expect(memoryApi.createMemoryRawScanImport).toHaveBeenCalledTimes(1))
await openImportTab()
await user.click(screen.getByRole('tab', { name: 'LPMM OpenIE' }))
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
await waitFor(() => expect(memoryApi.createMemoryLpmmOpenieImport).toHaveBeenCalledTimes(1))
await openImportTab()
await user.click(screen.getByRole('tab', { name: 'LPMM 转换' }))
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
await waitFor(() => expect(memoryApi.createMemoryLpmmConvertImport).toHaveBeenCalledTimes(1))
await openImportTab()
await user.click(screen.getByRole('tab', { name: '时序回填' }))
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
await waitFor(() => expect(memoryApi.createMemoryTemporalBackfillImport).toHaveBeenCalledTimes(1))
await openImportTab()
await user.click(screen.getByRole('tab', { name: 'MaiBot 迁移' }))
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
await waitFor(() => expect(memoryApi.createMemoryMaibotMigrationImport).toHaveBeenCalledTimes(1))
const [uploadedFiles, uploadPayload] = vi.mocked(memoryApi.createMemoryUploadImport).mock.calls[0]
expect(uploadedFiles).toHaveLength(4)
expect(uploadedFiles.map((file) => file.name)).toEqual(['demo.txt', 'demo.json', 'demo.csv', 'demo.md'])
expect(uploadPayload).toMatchObject({
input_mode: 'text',
llm_enabled: true,
strategy_override: 'auto',
dedupe_policy: 'content_hash',
})
}, 60_000)
it('loads task detail and supports chunk pagination', async () => {
const user = userEvent.setup()
render(<KnowledgeBasePage />)
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
await user.click(screen.getByRole('tab', { name: '导入' }))
expect(await screen.findByText('alpha.txt')).toBeInTheDocument()
expect(await screen.findByText('chunk-preview-0')).toBeInTheDocument()
const betaButton = screen.getByText('beta.txt').closest('button')
if (!betaButton) {
throw new Error('missing file beta button')
}
await user.click(betaButton)
await waitFor(() =>
expect(memoryApi.getMemoryImportTaskChunks).toHaveBeenCalledWith('import-run-1', 'file-beta', 0, 50),
)
await user.click(screen.getByRole('button', { name: '下一页分块' }))
await waitFor(() =>
expect(memoryApi.getMemoryImportTaskChunks).toHaveBeenCalledWith('import-run-1', 'file-beta', 50, 50),
)
}, 20_000)
it('shows import failures separately from successful chunks', async () => {
vi.mocked(memoryApi.getMemoryImportTask).mockResolvedValue({
success: true,
task: mockImportCompletedWithErrorsDetail('import-run-1'),
})
const user = userEvent.setup()
render(<KnowledgeBasePage />)
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
await user.click(screen.getByRole('tab', { name: '导入' }))
expect((await screen.findAllByText('完成(有错误)')).length).toBeGreaterThan(0)
expect(await screen.findByText('成功 9 / 12 分块 · 失败 3')).toBeInTheDocument()
}, 20_000)
it('supports cancel and retry actions for selected task', async () => {
const user = userEvent.setup()
render(<KnowledgeBasePage />)
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
await user.click(screen.getByRole('tab', { name: '导入' }))
await screen.findByText('任务详情')
await user.click(screen.getByRole('button', { name: '取消选中导入任务' }))
await waitFor(() => expect(memoryApi.cancelMemoryImportTask).toHaveBeenCalledWith('import-run-1'))
await user.click(screen.getByRole('button', { name: '重试选中导入任务' }))
await waitFor(() => expect(memoryApi.retryMemoryImportTask).toHaveBeenCalled())
const [taskId, retryPayload] = vi.mocked(memoryApi.retryMemoryImportTask).mock.calls[0]
expect(taskId).toBe('import-run-1')
expect(retryPayload).toMatchObject({
overrides: {
llm_enabled: true,
strategy_override: 'auto',
},
})
}, 20_000)
it('auto polling updates queue and keeps page stable when refresh fails once', async () => {
vi.mocked(memoryApi.getMemoryImportSettings).mockResolvedValue({
success: true,
settings: {
max_paste_chars: 200_000,
max_file_concurrency: 8,
max_chunk_concurrency: 16,
default_file_concurrency: 2,
default_chunk_concurrency: 4,
poll_interval_ms: 200,
maibot_source_db_default: 'data/maibot.db',
},
})
const user = userEvent.setup()
render(<KnowledgeBasePage />)
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
await user.click(screen.getByRole('tab', { name: '导入' }))
await screen.findByText('导入队列')
const initialCalls = vi.mocked(memoryApi.getMemoryImportTasks).mock.calls.length
vi.mocked(memoryApi.getMemoryImportTasks).mockRejectedValueOnce(new Error('poll failure'))
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 350))
})
expect(screen.getByText('长期记忆控制台')).toBeInTheDocument()
expect(vi.mocked(memoryApi.getMemoryImportTasks).mock.calls.length).toBeGreaterThan(initialCalls)
}, 20_000)
it('creates tuning task and applies best profile (tuning module)', async () => {
const user = userEvent.setup()
render(<KnowledgeBasePage />)
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
await user.click(screen.getByRole('tab', { name: '调优' }))
await screen.findByText('调优任务')
await user.click(screen.getByRole('button', { name: '创建调优任务' }))
await waitFor(() =>
expect(memoryApi.createMemoryTuningTask).toHaveBeenCalledWith({
objective: 'precision_priority',
intensity: 'standard',
sample_size: 24,
top_k_eval: 20,
}),
)
await user.click(screen.getByRole('button', { name: '应用最佳' }))
await waitFor(() => expect(memoryApi.applyBestMemoryTuningProfile).toHaveBeenCalledWith('tune-1'))
}, 20_000)
it('previews executes and restores source delete (delete module)', async () => {
const user = userEvent.setup()
render(<KnowledgeBasePage />)
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
await user.click(screen.getByRole('tab', { name: '删除' }))
await screen.findByText('来源批量删除')
const sourceCellCandidates = await screen.findAllByText('demo-1')
const sourceRow = sourceCellCandidates
.map((item) => item.closest('tr'))
.find((row): row is HTMLTableRowElement => Boolean(row && within(row).queryByRole('checkbox')))
if (!sourceRow) {
throw new Error('missing source row')
}
await user.click(within(sourceRow).getByRole('checkbox'))
await user.click(screen.getByRole('button', { name: '预览删除' }))
await waitFor(() =>
expect(memoryApi.previewMemoryDelete).toHaveBeenCalledWith({
mode: 'source',
selector: { sources: ['demo-1'] },
reason: 'knowledge_base_source_delete',
requested_by: 'knowledge_base',
}),
)
const dialog = await screen.findByTestId('memory-delete-dialog')
expect(dialog).toHaveTextContent('preview:source:1')
await user.click(screen.getByRole('button', { name: '执行删除' }))
await waitFor(() =>
expect(memoryApi.executeMemoryDelete).toHaveBeenCalledWith({
mode: 'source',
selector: { sources: ['demo-1'] },
reason: 'knowledge_base_source_delete',
requested_by: 'knowledge_base',
}),
)
await user.click(screen.getByRole('button', { name: '执行恢复' }))
await waitFor(() =>
expect(memoryApi.restoreMemoryDelete).toHaveBeenCalledWith({
operation_id: 'del-2',
requested_by: 'knowledge_base',
}),
)
}, 20_000)
it('shows feedback correction history and supports rollback', async () => {
const user = userEvent.setup()
render(<KnowledgeBasePage />)
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
await user.click(screen.getByRole('tab', { name: '纠错历史' }))
await screen.findByText('反馈纠错历史')
await screen.findByText('测试用户最喜欢的颜色是什么')
await waitFor(() => expect(memoryApi.getMemoryFeedbackCorrection).toHaveBeenCalledWith(11))
await user.click(screen.getByRole('button', { name: '回退本次纠错' }))
const rollbackReason = await screen.findByLabelText('回退原因')
await user.type(rollbackReason, '人工确认回退')
await user.click(screen.getByRole('button', { name: '确认回退' }))
await waitFor(() =>
expect(memoryApi.rollbackMemoryFeedbackCorrection).toHaveBeenCalledWith(11, {
requested_by: 'knowledge_base',
reason: '人工确认回退',
}),
)
}, 20_000)
})

View File

@@ -1,440 +0,0 @@
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { KnowledgeGraphPage } from '../knowledge-graph'
import * as memoryApi from '@/lib/memory-api'
const navigateMock = vi.fn()
const toastMock = vi.fn()
vi.mock('@tanstack/react-router', () => ({
useNavigate: () => navigateMock,
}))
vi.mock('@/hooks/use-toast', () => ({
useToast: () => ({ toast: toastMock }),
}))
vi.mock('@/components/memory/MemoryDeleteDialog', () => ({
MemoryDeleteDialog: ({
open,
preview,
}: {
open: boolean
preview?: { mode?: string; item_count?: number } | null
}) => (
open ? <div data-testid="memory-delete-dialog">{`delete:${preview?.mode ?? 'none'}:${preview?.item_count ?? 0}`}</div> : null
),
}))
vi.mock('../knowledge-graph/GraphVisualization', () => ({
GraphVisualization: ({
graphData,
onNodeClick,
onEdgeClick,
}: {
graphData: { nodes: Array<{ id: string }>; edges: Array<{ source: string; target: string }> }
onNodeClick: (event: React.MouseEvent, node: { id: string }) => void
onEdgeClick: (event: React.MouseEvent, edge: { source: string; target: string }) => void
}) => (
<div data-testid="graph-visualization">
<div>{`nodes:${graphData.nodes.length},edges:${graphData.edges.length}`}</div>
{graphData.nodes[0] ? (
<button type="button" onClick={(event) => onNodeClick(event as never, { id: graphData.nodes[0].id })}>
</button>
) : null}
{graphData.edges[0] ? (
<button
type="button"
onClick={(event) =>
onEdgeClick(event as never, {
source: graphData.edges[0].source,
target: graphData.edges[0].target,
})}
>
</button>
) : null}
</div>
),
}))
vi.mock('../knowledge-graph/GraphDialogs', () => ({
NodeDetailDialog: ({
selectedNodeData,
nodeDetail,
onOpenEvidence,
onDeleteEntity,
}: {
selectedNodeData: { id: string } | null
nodeDetail: { relations?: Array<{ predicate: string }>; paragraphs?: Array<unknown> } | null
onOpenEvidence?: () => void
onDeleteEntity?: (options: { includeParagraphs: boolean }) => void
}) => (
selectedNodeData ? (
<div data-testid="node-detail-dialog">
<div>{`node:${selectedNodeData.id}`}</div>
<div>{`relations:${nodeDetail?.relations?.[0]?.predicate ?? 'none'}`}</div>
<div>{`paragraphs:${nodeDetail?.paragraphs?.length ?? 0}`}</div>
<button type="button" onClick={onOpenEvidence}></button>
<button type="button" onClick={() => onDeleteEntity?.({ includeParagraphs: true })}></button>
</div>
) : null
),
EdgeDetailDialog: ({
selectedEdgeData,
edgeDetail,
onOpenEvidence,
}: {
selectedEdgeData: { source: { id: string }; target: { id: string } } | null
edgeDetail: { edge?: { predicates?: string[] }; paragraphs?: Array<unknown> } | null
onOpenEvidence?: () => void
}) => (
selectedEdgeData ? (
<div data-testid="edge-detail-dialog">
<div>{`edge:${selectedEdgeData.source.id}->${selectedEdgeData.target.id}`}</div>
<div>{`predicates:${edgeDetail?.edge?.predicates?.join(',') ?? 'none'}`}</div>
<div>{`paragraphs:${edgeDetail?.paragraphs?.length ?? 0}`}</div>
<button type="button" onClick={onOpenEvidence}></button>
</div>
) : null
),
RelationDetailDialog: () => null,
ParagraphDetailDialog: () => null,
}))
vi.mock('@/lib/memory-api', () => ({
getMemoryGraph: vi.fn(),
getMemoryGraphSearch: vi.fn(),
getMemoryGraphNodeDetail: vi.fn(),
getMemoryGraphEdgeDetail: vi.fn(),
previewMemoryDelete: vi.fn(),
executeMemoryDelete: vi.fn(),
restoreMemoryDelete: vi.fn(),
}))
describe('KnowledgeGraphPage', () => {
beforeEach(() => {
navigateMock.mockReset()
toastMock.mockReset()
vi.mocked(memoryApi.getMemoryGraph).mockResolvedValue({
success: true,
nodes: [
{ id: 'alpha', name: 'Alpha' },
{ id: 'beta', name: 'Beta' },
],
edges: [
{
source: 'alpha',
target: 'beta',
weight: 1,
predicates: ['关联'],
relation_count: 1,
evidence_count: 2,
relation_hashes: ['rel-1'],
label: '关联',
},
],
total_nodes: 2,
total_edges: 1,
})
vi.mocked(memoryApi.getMemoryGraphSearch).mockResolvedValue({
success: true,
query: 'alpha',
limit: 50,
count: 0,
items: [],
})
vi.mocked(memoryApi.getMemoryGraphNodeDetail).mockResolvedValue({
success: true,
node: { id: 'alpha', type: 'entity', content: 'Alpha', hash: 'entity-1', appearance_count: 3 },
relations: [
{
hash: 'rel-1',
subject: 'alpha',
predicate: '关联',
object: 'beta',
text: 'alpha 关联 beta',
confidence: 0.9,
paragraph_count: 1,
paragraph_hashes: ['p-1'],
source_paragraph: 'p-1',
},
],
paragraphs: [
{
hash: 'p-1',
content: 'Alpha 提到了 Beta',
preview: 'Alpha 提到了 Beta',
source: 'demo',
entity_count: 2,
relation_count: 1,
entities: ['Alpha', 'Beta'],
relations: ['alpha 关联 beta'],
},
],
evidence_graph: {
nodes: [
{ id: 'entity:alpha', type: 'entity', content: 'Alpha' },
{ id: 'relation:rel-1', type: 'relation', content: 'alpha 关联 beta' },
{ id: 'paragraph:p-1', type: 'paragraph', content: 'Alpha 提到了 Beta' },
],
edges: [
{ source: 'paragraph:p-1', target: 'entity:alpha', kind: 'mentions', label: '提及', weight: 1 },
{ source: 'paragraph:p-1', target: 'relation:rel-1', kind: 'supports', label: '支撑', weight: 1 },
],
focus_entities: ['alpha'],
},
})
vi.mocked(memoryApi.getMemoryGraphEdgeDetail).mockResolvedValue({
success: true,
edge: {
source: 'alpha',
target: 'beta',
weight: 1,
predicates: ['关联'],
relation_count: 1,
evidence_count: 1,
relation_hashes: ['rel-1'],
label: '关联',
},
relations: [
{
hash: 'rel-1',
subject: 'alpha',
predicate: '关联',
object: 'beta',
text: 'alpha 关联 beta',
confidence: 0.9,
paragraph_count: 1,
paragraph_hashes: ['p-1'],
source_paragraph: 'p-1',
},
],
paragraphs: [
{
hash: 'p-1',
content: 'Alpha 提到了 Beta',
preview: 'Alpha 提到了 Beta',
source: 'demo',
entity_count: 2,
relation_count: 1,
entities: ['Alpha', 'Beta'],
relations: ['alpha 关联 beta'],
},
],
evidence_graph: {
nodes: [
{ id: 'entity:alpha', type: 'entity', content: 'Alpha' },
{ id: 'entity:beta', type: 'entity', content: 'Beta' },
{ id: 'relation:rel-1', type: 'relation', content: 'alpha 关联 beta' },
],
edges: [
{ source: 'relation:rel-1', target: 'entity:alpha', kind: 'subject', label: '主语', weight: 1 },
{ source: 'relation:rel-1', target: 'entity:beta', kind: 'object', label: '宾语', weight: 1 },
],
focus_entities: ['alpha', 'beta'],
},
})
vi.mocked(memoryApi.previewMemoryDelete).mockResolvedValue({
success: true,
mode: 'mixed',
selector: { entity_hashes: ['entity-1'] },
counts: { entities: 1, relations: 1, paragraphs: 1 },
sources: ['demo'],
items: [{ item_type: 'entity', item_hash: 'entity-1', label: 'Alpha' }],
item_count: 1,
dry_run: true,
} as never)
vi.mocked(memoryApi.executeMemoryDelete).mockResolvedValue({
success: true,
mode: 'mixed',
operation_id: 'del-1',
counts: { entities: 1, relations: 1, paragraphs: 1 },
sources: ['demo'],
deleted_count: 3,
deleted_entity_count: 1,
deleted_relation_count: 1,
deleted_paragraph_count: 1,
deleted_source_count: 0,
} as never)
vi.mocked(memoryApi.restoreMemoryDelete).mockResolvedValue({ success: true } as never)
})
it('calls backend graph search and renders no-hit state', async () => {
const user = userEvent.setup()
render(<KnowledgeGraphPage />)
expect(await screen.findByText('长期记忆图谱')).toBeInTheDocument()
expect(screen.getByText(/总节点 2/)).toBeInTheDocument()
expect(screen.getByTestId('graph-visualization')).toHaveTextContent('nodes:2,edges:1')
await user.type(screen.getByPlaceholderText('搜索实体、关系、hash后端全库'), 'missing')
expect(memoryApi.getMemoryGraph).toHaveBeenCalledTimes(1)
await user.click(screen.getByRole('button', { name: '搜索' }))
await waitFor(() => {
expect(memoryApi.getMemoryGraphSearch).toHaveBeenCalledWith('missing', 50)
})
expect(await screen.findByText('未命中实体或关系。')).toBeInTheDocument()
})
it('supports clicking entity search result to locate evidence', async () => {
const user = userEvent.setup()
vi.mocked(memoryApi.getMemoryGraphSearch).mockResolvedValue({
success: true,
query: 'alpha',
limit: 50,
count: 1,
items: [
{
type: 'entity',
title: 'Alpha',
matched_field: 'name',
matched_value: 'Alpha',
entity_name: 'alpha',
entity_hash: 'entity-1',
appearance_count: 3,
},
],
})
render(<KnowledgeGraphPage />)
await screen.findByTestId('graph-visualization')
await user.type(screen.getByPlaceholderText('搜索实体、关系、hash后端全库'), 'alpha')
await user.click(screen.getByRole('button', { name: '搜索' }))
await screen.findByText('搜索词alpha')
await user.click(screen.getByRole('button', { name: /Alpha/ }))
await waitFor(() => {
expect(memoryApi.getMemoryGraphNodeDetail).toHaveBeenCalledWith('alpha')
})
expect(screen.getByRole('tab', { name: '证据视图' })).toHaveAttribute('data-state', 'active')
})
it('supports clicking relation search result to locate evidence', async () => {
const user = userEvent.setup()
vi.mocked(memoryApi.getMemoryGraphSearch).mockResolvedValue({
success: true,
query: '关联',
limit: 50,
count: 1,
items: [
{
type: 'relation',
title: 'alpha 关联 beta',
matched_field: 'predicate',
matched_value: '关联',
subject: 'alpha',
predicate: '关联',
object: 'beta',
relation_hash: 'rel-1',
confidence: 0.9,
},
],
})
render(<KnowledgeGraphPage />)
await screen.findByTestId('graph-visualization')
await user.type(screen.getByPlaceholderText('搜索实体、关系、hash后端全库'), '关联')
await user.click(screen.getByRole('button', { name: '搜索' }))
await user.click(screen.getByRole('button', { name: /alpha 关联 beta/ }))
await waitFor(() => {
expect(memoryApi.getMemoryGraphEdgeDetail).toHaveBeenCalledWith('alpha', 'beta')
})
expect(screen.getByRole('tab', { name: '证据视图' })).toHaveAttribute('data-state', 'active')
})
it('falls back to local filtering when backend search fails', async () => {
const user = userEvent.setup()
vi.mocked(memoryApi.getMemoryGraphSearch).mockRejectedValue(new Error('search unavailable'))
render(<KnowledgeGraphPage />)
await screen.findByTestId('graph-visualization')
await user.type(screen.getByPlaceholderText('搜索实体、关系、hash后端全库'), 'missing')
await user.click(screen.getByRole('button', { name: '搜索' }))
expect(await screen.findByText('还没有可展示的长期记忆图谱')).toBeInTheDocument()
expect(toastMock).toHaveBeenCalledWith(
expect.objectContaining({
title: '后端检索失败,已回退本地筛选',
}),
)
})
it('shows empty state when switching to evidence view without a selection', async () => {
const user = userEvent.setup()
render(<KnowledgeGraphPage />)
expect(await screen.findByTestId('graph-visualization')).toBeInTheDocument()
await user.click(screen.getByRole('tab', { name: '证据视图' }))
expect(await screen.findByText('证据视图还没有可展示的选择')).toBeInTheDocument()
})
it('closes node dialog when switching to evidence view and renders evidence graph', async () => {
const user = userEvent.setup()
render(<KnowledgeGraphPage />)
await screen.findByTestId('graph-visualization')
await user.click(screen.getByRole('button', { name: '选择节点' }))
expect(await screen.findByTestId('node-detail-dialog')).toHaveTextContent('relations:关联')
expect(screen.getByTestId('node-detail-dialog')).toHaveTextContent('paragraphs:1')
await user.click(screen.getByRole('button', { name: '切到证据视图' }))
await waitFor(() => {
expect(screen.queryByTestId('node-detail-dialog')).not.toBeInTheDocument()
})
await waitFor(() => {
expect(screen.getByTestId('graph-visualization')).toHaveTextContent('nodes:3,edges:2')
})
})
it('loads edge detail with predicates and support paragraphs', async () => {
const user = userEvent.setup()
render(<KnowledgeGraphPage />)
await screen.findByTestId('graph-visualization')
await user.click(screen.getByRole('button', { name: '选择边' }))
expect(await screen.findByTestId('edge-detail-dialog')).toHaveTextContent('predicates:关联')
expect(screen.getByTestId('edge-detail-dialog')).toHaveTextContent('paragraphs:1')
await user.click(screen.getByRole('button', { name: '切到证据视图' }))
await waitFor(() => {
expect(screen.queryByTestId('edge-detail-dialog')).not.toBeInTheDocument()
})
})
it('opens delete preview dialog from node detail', async () => {
const user = userEvent.setup()
render(<KnowledgeGraphPage />)
await screen.findByTestId('graph-visualization')
await user.click(screen.getByRole('button', { name: '选择节点' }))
await screen.findByTestId('node-detail-dialog')
await user.click(screen.getByRole('button', { name: '删除实体' }))
await waitFor(() => {
expect(memoryApi.previewMemoryDelete).toHaveBeenCalled()
})
expect(await screen.findByTestId('memory-delete-dialog')).toHaveTextContent('delete:mixed:1')
})
})

View File

@@ -1,7 +0,0 @@
{
"extends": "./tsconfig.app.json",
"compilerOptions": {
"types": ["vite/client", "vitest/globals", "@testing-library/jest-dom"]
},
"include": ["src"]
}

View File

@@ -1,18 +0,0 @@
/// <reference types="vitest" />
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'
import path from 'path'
export default defineConfig({
plugins: [react()],
test: {
globals: true,
environment: 'jsdom',
setupFiles: './src/test/setup.ts',
},
resolve: {
alias: {
'@': path.resolve(__dirname, './src'),
},
},
})

View File

@@ -1,398 +0,0 @@
from __future__ import annotations
from datetime import datetime, timedelta
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Callable, Dict, List
import asyncio
import inspect
import json
import pickle
from sqlalchemy.orm import sessionmaker
from sqlmodel import Session, create_engine
import numpy as np
import pytest
IMPORT_ERROR: str | None = None
try:
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
from src.A_memorix.core.utils import summary_importer as summary_importer_module
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from src.common.database import database as database_module
from src.common.database.migrations import create_database_migration_bootstrapper
from src.common.message_repository import count_messages
from src.config.model_configs import TaskConfig
from src.services import memory_flow_service as memory_flow_service_module
from src.services import memory_service as memory_service_module
from src.services import send_service
except SystemExit as exc:
IMPORT_ERROR = f"config initialization exited during import: {exc}"
kernel_module = None # type: ignore[assignment]
SDKMemoryKernel = None # type: ignore[assignment]
summary_importer_module = None # type: ignore[assignment]
BotChatSession = None # type: ignore[assignment]
SessionMessage = None # type: ignore[assignment]
MessageInfo = None # type: ignore[assignment]
UserInfo = None # type: ignore[assignment]
MessageSequence = None # type: ignore[assignment]
TextComponent = None # type: ignore[assignment]
database_module = None # type: ignore[assignment]
create_database_migration_bootstrapper = None # type: ignore[assignment]
count_messages = None # type: ignore[assignment]
TaskConfig = None # type: ignore[assignment]
memory_flow_service_module = None # type: ignore[assignment]
memory_service_module = None # type: ignore[assignment]
send_service = None # type: ignore[assignment]
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
class _FakeEmbeddingManager:
def __init__(self, dimension: int = 8) -> None:
self.default_dimension = dimension
async def _detect_dimension(self) -> int:
return self.default_dimension
async def encode(self, text: Any) -> np.ndarray:
def _encode_one(raw: Any) -> np.ndarray:
content = str(raw or "")
vector = np.zeros(self.default_dimension, dtype=np.float32)
for index, byte in enumerate(content.encode("utf-8")):
vector[index % self.default_dimension] += float((byte % 17) + 1)
norm = float(np.linalg.norm(vector))
if norm > 0:
vector /= norm
return vector
if isinstance(text, (list, tuple)):
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
return _encode_one(text).astype(np.float32)
class _KernelBackedRuntimeManager:
def __init__(self, kernel: SDKMemoryKernel) -> None:
self.kernel = kernel
async def invoke(
self,
component_name: str,
args: Dict[str, Any] | None,
*,
timeout_ms: int = 30000,
) -> Any:
del timeout_ms
payload = args or {}
handler = getattr(self.kernel, component_name)
result = handler(**payload)
return await result if inspect.isawaitable(result) else result
class _NoopRuntimeManager:
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
del hook_name
return SimpleNamespace(aborted=False, kwargs=kwargs)
class _FakePlatformIOManager:
def __init__(self) -> None:
self.ensure_calls = 0
async def ensure_send_pipeline_ready(self) -> None:
self.ensure_calls += 1
def build_route_key_from_message(self, message: Any) -> Any:
del message
return SimpleNamespace(platform="qq")
async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
del message, metadata
return SimpleNamespace(
has_success=True,
sent_receipts=[
SimpleNamespace(
driver_id="plugin.qq.sender",
external_message_id="real-message-id",
metadata={},
)
],
failed_receipts=[],
route_key=route_key,
)
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
db_dir = (tmp_path / "main_db").resolve()
db_dir.mkdir(parents=True, exist_ok=True)
db_file = db_dir / "MaiBot.db"
database_url = f"sqlite:///{db_file}"
try:
database_module.engine.dispose()
except Exception:
pass
engine = create_engine(
database_url,
echo=False,
connect_args={"check_same_thread": False},
pool_pre_ping=True,
)
session_local = sessionmaker(
autocommit=False,
autoflush=False,
bind=engine,
class_=Session,
)
bootstrapper = create_database_migration_bootstrapper(engine)
monkeypatch.setattr(database_module, "_DB_DIR", db_dir, raising=False)
monkeypatch.setattr(database_module, "_DB_FILE", db_file, raising=False)
monkeypatch.setattr(database_module, "DATABASE_URL", database_url, raising=False)
monkeypatch.setattr(database_module, "engine", engine, raising=False)
monkeypatch.setattr(database_module, "SessionLocal", session_local, raising=False)
monkeypatch.setattr(database_module, "_migration_bootstrapper", bootstrapper, raising=False)
monkeypatch.setattr(database_module, "_db_initialized", False, raising=False)
def _build_incoming_message(
*,
session_id: str,
user_id: str,
text: str,
timestamp: datetime | None = None,
) -> SessionMessage:
message = SessionMessage(
message_id="incoming-message-id",
timestamp=timestamp or datetime.now(),
platform="qq",
)
message.message_info = MessageInfo(
user_info=UserInfo(
user_id=user_id,
user_nickname="测试用户",
user_cardname="测试用户",
),
additional_config={},
)
message.raw_message = MessageSequence(components=[TextComponent(text=text)])
message.session_id = session_id
message.reply_to = None
message.is_mentioned = False
message.is_at = False
message.is_emoji = False
message.is_picture = False
message.is_command = False
message.is_notify = False
message.processed_plain_text = text
message.initialized = True
return message
async def _wait_until(
predicate: Callable[[], Any],
*,
timeout_seconds: float = 10.0,
interval_seconds: float = 0.05,
description: str,
) -> Any:
deadline = asyncio.get_running_loop().time() + max(0.5, float(timeout_seconds))
while asyncio.get_running_loop().time() < deadline:
value = predicate()
if inspect.isawaitable(value):
value = await value
if value:
return value
await asyncio.sleep(interval_seconds)
raise AssertionError(f"等待超时: {description}")
@pytest.mark.asyncio
async def test_text_to_stream_triggers_real_chat_summary_writeback(
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> None:
_install_temp_main_database(monkeypatch, tmp_path)
fake_embedding_manager = _FakeEmbeddingManager()
captured_prompts: List[str] = []
fixed_send_timestamp = 1_777_000_000.0
async def _fake_runtime_self_check(**kwargs: Any) -> Dict[str, Any]:
del kwargs
return {
"ok": True,
"message": "ok",
"configured_dimension": fake_embedding_manager.default_dimension,
"requested_dimension": fake_embedding_manager.default_dimension,
"vector_store_dimension": fake_embedding_manager.default_dimension,
"detected_dimension": fake_embedding_manager.default_dimension,
"encoded_dimension": fake_embedding_manager.default_dimension,
"elapsed_ms": 0.0,
"sample_text": "test",
"checked_at": datetime.now().timestamp(),
}
async def _fake_generate(request: Any) -> Any:
captured_prompts.append(str(getattr(request, "prompt", "") or ""))
return SimpleNamespace(
success=True,
completion=SimpleNamespace(
response=json.dumps(
{
"summary": "这段对话记录了用户提到自己买了绿色围巾,机器人表示会记住这件事。",
"entities": ["绿色围巾"],
"relations": [],
},
ensure_ascii=False,
)
),
)
monkeypatch.setattr(
kernel_module,
"create_embedding_api_adapter",
lambda **kwargs: fake_embedding_manager,
)
monkeypatch.setattr(
kernel_module,
"run_embedding_runtime_self_check",
_fake_runtime_self_check,
)
monkeypatch.setattr(
summary_importer_module,
"run_embedding_runtime_self_check",
_fake_runtime_self_check,
)
monkeypatch.setattr(
summary_importer_module.llm_api,
"get_available_models",
lambda: {"utils": TaskConfig(model_list=["fake-summary-model"])},
)
monkeypatch.setattr(
summary_importer_module.llm_api,
"resolve_task_name_from_model_config",
lambda model_config: "utils",
)
monkeypatch.setattr(
summary_importer_module.llm_api,
"generate",
_fake_generate,
)
monkeypatch.setattr(send_service.time, "time", lambda: fixed_send_timestamp)
monkeypatch.setattr(summary_importer_module.time, "time", lambda: fixed_send_timestamp)
kernel = SDKMemoryKernel(
plugin_root=tmp_path / "plugin_root",
config={
"storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
"advanced": {"enable_auto_save": False},
"embedding": {"dimension": fake_embedding_manager.default_dimension},
"memory": {"base_decay_interval_hours": 24},
"person_profile": {"refresh_interval_minutes": 5},
"summarization": {"model_name": ["utils"]},
},
)
service = memory_flow_service_module.MemoryAutomationService()
fake_platform_io_manager = _FakePlatformIOManager()
async def _fake_rebuild_episodes_for_sources(sources: List[str]) -> Dict[str, Any]:
return {
"rebuilt": 0,
"items": [],
"failures": [],
"sources": list(sources),
}
monkeypatch.setattr(kernel, "rebuild_episodes_for_sources", _fake_rebuild_episodes_for_sources)
monkeypatch.setattr(
memory_service_module,
"a_memorix_host_service",
_KernelBackedRuntimeManager(kernel),
)
monkeypatch.setattr(memory_flow_service_module, "memory_automation_service", service)
monkeypatch.setattr(send_service, "_get_runtime_manager", lambda: _NoopRuntimeManager())
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_platform_io_manager)
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
monkeypatch.setattr(
send_service._chat_manager,
"get_session_by_session_id",
lambda stream_id: (
BotChatSession(
session_id="test-session",
platform="qq",
user_id="target-user",
group_id=None,
)
if stream_id == "test-session"
else None
),
)
integration_config = memory_flow_service_module.global_config.a_memorix.integration
monkeypatch.setattr(integration_config, "chat_summary_writeback_enabled", True, raising=False)
monkeypatch.setattr(integration_config, "chat_summary_writeback_message_threshold", 2, raising=False)
monkeypatch.setattr(integration_config, "chat_summary_writeback_context_length", 10, raising=False)
monkeypatch.setattr(integration_config, "person_fact_writeback_enabled", False, raising=False)
await kernel.initialize()
try:
incoming_message = _build_incoming_message(
session_id="test-session",
user_id="target-user",
text="我最近买了一条绿色围巾。",
timestamp=datetime.fromtimestamp(fixed_send_timestamp) - timedelta(seconds=1),
)
with database_module.get_db_session() as session:
session.add(incoming_message.to_db_instance())
sent_message = await send_service.text_to_stream_with_message(
text="好的,我会记住你最近买了绿色围巾。",
stream_id="test-session",
storage_message=True,
)
assert sent_message is not None
assert sent_message.message_id == "real-message-id"
assert fake_platform_io_manager.ensure_calls == 1
assert count_messages(session_id="test-session") == 2
paragraphs = await _wait_until(
lambda: kernel.metadata_store.get_paragraphs_by_source("chat_summary:test-session"),
description="等待聊天摘要写回到 A_memorix",
)
assert captured_prompts
assert "我最近买了一条绿色围巾。" in captured_prompts[-1]
assert "好的,我会记住你最近买了绿色围巾。" in captured_prompts[-1]
assert any("绿色围巾" in str(item.get("content", "") or "") for item in paragraphs)
assert any(
int(
(
pickle.loads(item.get("metadata"))
if isinstance(item.get("metadata"), (bytes, bytearray))
else item.get("metadata")
or {}
).get("trigger_message_count", 0)
or 0
)
== 2
for item in paragraphs
)
assert service.chat_summary_writeback._states["test-session"].last_trigger_message_count == 2
finally:
await service.shutdown()
await kernel.shutdown()
try:
database_module.engine.dispose()
except Exception:
pass

View File

@@ -1,191 +0,0 @@
from __future__ import annotations
from types import SimpleNamespace
import numpy as np
import pytest
from A_memorix.core.embedding import api_adapter as api_adapter_module
from A_memorix.core.embedding.api_adapter import EmbeddingAPIAdapter
from A_memorix.core.utils.runtime_self_check import run_embedding_runtime_self_check
class _FakeEmbeddingClient:
def __init__(self, *, natural_dimension: int = 12) -> None:
self.natural_dimension = int(natural_dimension)
self.requests = []
async def get_embedding(self, request):
self.requests.append(request)
requested_dimension = request.extra_params.get("dimensions")
if requested_dimension is None:
requested_dimension = request.extra_params.get("output_dimensionality")
dimension = int(requested_dimension or self.natural_dimension)
return SimpleNamespace(embedding=[1.0] * dimension)
def _build_adapter(
monkeypatch: pytest.MonkeyPatch,
*,
client_type: str,
configured_dimension: int = 1024,
effective_dimension: int | None = None,
model_extra_params: dict | None = None,
):
adapter = EmbeddingAPIAdapter(default_dimension=configured_dimension)
if effective_dimension is not None:
adapter._dimension = int(effective_dimension)
adapter._dimension_detected = True
fake_client = _FakeEmbeddingClient()
model_info = SimpleNamespace(
name="embedding-model",
api_provider="provider-1",
model_identifier="embedding-model-id",
extra_params=dict(model_extra_params or {}),
)
provider = SimpleNamespace(name="provider-1", client_type=client_type)
monkeypatch.setattr(adapter, "_resolve_candidate_model_names", lambda: ["embedding-model"])
monkeypatch.setattr(adapter, "_find_model_info", lambda model_name: model_info)
monkeypatch.setattr(adapter, "_find_provider", lambda provider_name: provider)
monkeypatch.setattr(
api_adapter_module.client_registry,
"get_client_class_instance",
lambda api_provider, force_new=True: fake_client,
)
return adapter, fake_client
@pytest.mark.asyncio
async def test_encode_uses_canonical_dimension_for_openai_provider(monkeypatch):
adapter, fake_client = _build_adapter(
monkeypatch,
client_type="openai",
configured_dimension=1024,
effective_dimension=1024,
model_extra_params={"task_type": "SEMANTIC_SIMILARITY"},
)
embedding = await adapter.encode("北塔木梯")
request = fake_client.requests[-1]
assert request.extra_params["dimensions"] == 1024
assert "output_dimensionality" not in request.extra_params
assert request.extra_params["task_type"] == "SEMANTIC_SIMILARITY"
assert embedding.shape == (1024,)
@pytest.mark.asyncio
async def test_encode_explicit_dimension_override_wins(monkeypatch):
adapter, fake_client = _build_adapter(
monkeypatch,
client_type="openai",
configured_dimension=1024,
effective_dimension=1024,
)
embedding = await adapter.encode("海潮图", dimensions=256)
request = fake_client.requests[-1]
assert request.extra_params["dimensions"] == 256
assert "output_dimensionality" not in request.extra_params
assert embedding.shape == (256,)
@pytest.mark.asyncio
async def test_encode_maps_dimension_to_gemini_output_dimensionality(monkeypatch):
adapter, fake_client = _build_adapter(
monkeypatch,
client_type="gemini",
configured_dimension=1024,
effective_dimension=768,
)
embedding = await adapter.encode("广播站")
request = fake_client.requests[-1]
assert request.extra_params["output_dimensionality"] == 768
assert "dimensions" not in request.extra_params
assert embedding.shape == (768,)
@pytest.mark.asyncio
async def test_encode_does_not_force_dimension_for_unsupported_provider(monkeypatch):
adapter, fake_client = _build_adapter(
monkeypatch,
client_type="custom",
configured_dimension=1024,
effective_dimension=640,
model_extra_params={
"dimensions": 999,
"output_dimensionality": 888,
"custom_flag": "keep-me",
},
)
embedding = await adapter.encode("蓝漆铁盒")
request = fake_client.requests[-1]
assert "dimensions" not in request.extra_params
assert "output_dimensionality" not in request.extra_params
assert request.extra_params["custom_flag"] == "keep-me"
assert embedding.shape == (fake_client.natural_dimension,)
@pytest.mark.asyncio
async def test_runtime_self_check_reports_requested_dimension_without_explicit_override():
class _FakeEmbeddingManager:
def __init__(self) -> None:
self.detected_dimension = 384
self.encode_calls = []
async def _detect_dimension(self) -> int:
return self.detected_dimension
def get_requested_dimension(self) -> int:
return self.detected_dimension
async def encode(self, text):
self.encode_calls.append(text)
return np.ones(self.detected_dimension, dtype=np.float32)
manager = _FakeEmbeddingManager()
report = await run_embedding_runtime_self_check(
config={"embedding": {"dimension": 1024}},
vector_store=SimpleNamespace(dimension=384),
embedding_manager=manager,
)
assert report["ok"] is True
assert report["configured_dimension"] == 1024
assert report["requested_dimension"] == 384
assert report["detected_dimension"] == 384
assert report["encoded_dimension"] == 384
assert manager.encode_calls == ["A_Memorix runtime self check"]
@pytest.mark.asyncio
async def test_encode_batch_keeps_batch_local_indexes_when_cache_hits_previous_batch(monkeypatch):
adapter = EmbeddingAPIAdapter(default_dimension=4, enable_cache=True)
adapter._dimension = 4
adapter._dimension_detected = True
async def fake_detect_dimension() -> int:
return 4
async def fake_get_embedding_direct(text: str, dimensions: int | None = None):
del dimensions
base = float(ord(str(text)[0]))
return [base, base + 1.0, base + 2.0, base + 3.0]
monkeypatch.setattr(adapter, "_detect_dimension", fake_detect_dimension)
monkeypatch.setattr(adapter, "_get_embedding_direct", fake_get_embedding_direct)
embeddings = await adapter.encode(["A", "B", "A", "C"], batch_size=2)
assert embeddings.shape == (4, 4)
assert np.array_equal(embeddings[0], embeddings[2])
assert embeddings[1][0] == float(ord("B"))
assert embeddings[3][0] == float(ord("C"))

View File

@@ -1,780 +0,0 @@
from __future__ import annotations
import asyncio
import inspect
import json
import time
import uuid
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Callable, Dict
import numpy as np
import pytest
import pytest_asyncio
from sqlalchemy.orm import sessionmaker
from sqlmodel import Session, create_engine, select
IMPORT_ERROR: str | None = None
try:
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from src.A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from src.chat.heart_flow.heartflow_manager import heartflow_manager
from src.chat.message_receive import bot as bot_module
from src.chat.message_receive.chat_manager import chat_manager
from src.chat.message_receive.bot import chat_bot
from src.common.database import database as database_module
from src.common.database.database_model import PersonInfo, ToolRecord
from src.common.database.migrations import create_database_migration_bootstrapper
from src.common.utils.utils_session import SessionUtils
from src.llm_models.payload_content.tool_option import ToolCall
from src.maisaka import reasoning_engine as reasoning_engine_module
from src.maisaka import runtime as runtime_module
from src.maisaka import chat_loop_service as chat_loop_service_module
from src.maisaka.chat_loop_service import ChatResponse
from src.maisaka.context_messages import AssistantMessage
from src.plugin_runtime import component_query as component_query_module
from src.services import memory_flow_service as memory_flow_service_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import memory_service
except SystemExit as exc:
IMPORT_ERROR = f"config initialization exited during import: {exc}"
kernel_module = None # type: ignore[assignment]
KernelSearchRequest = None # type: ignore[assignment]
SDKMemoryKernel = None # type: ignore[assignment]
heartflow_manager = None # type: ignore[assignment]
bot_module = None # type: ignore[assignment]
chat_manager = None # type: ignore[assignment]
chat_bot = None # type: ignore[assignment]
database_module = None # type: ignore[assignment]
ToolRecord = None # type: ignore[assignment]
PersonInfo = None # type: ignore[assignment]
create_database_migration_bootstrapper = None # type: ignore[assignment]
SessionUtils = None # type: ignore[assignment]
ToolCall = None # type: ignore[assignment]
reasoning_engine_module = None # type: ignore[assignment]
runtime_module = None # type: ignore[assignment]
chat_loop_service_module = None # type: ignore[assignment]
ChatResponse = None # type: ignore[assignment]
AssistantMessage = None # type: ignore[assignment]
component_query_module = None # type: ignore[assignment]
memory_flow_service_module = None # type: ignore[assignment]
memory_service_module = None # type: ignore[assignment]
memory_service = None # type: ignore[assignment]
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
RELATION_QUERY = "测试用户 和 最喜欢的颜色 有什么关系"
class _FakeEmbeddingManager:
def __init__(self, dimension: int = 8) -> None:
self.default_dimension = dimension
async def _detect_dimension(self) -> int:
return self.default_dimension
async def encode(self, text: Any) -> np.ndarray:
def _encode_one(raw: Any) -> np.ndarray:
content = str(raw or "")
vector = np.zeros(self.default_dimension, dtype=np.float32)
for index, byte in enumerate(content.encode("utf-8")):
vector[index % self.default_dimension] += float((byte % 17) + 1)
norm = float(np.linalg.norm(vector))
if norm > 0:
vector /= norm
return vector
if isinstance(text, (list, tuple)):
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
return _encode_one(text).astype(np.float32)
class _KernelBackedRuntimeManager:
def __init__(self, kernel: SDKMemoryKernel) -> None:
self.kernel = kernel
async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
del timeout_ms
payload = args or {}
if component_name == "search_memory":
return await self.kernel.search_memory(
KernelSearchRequest(
query=str(payload.get("query", "") or ""),
limit=int(payload.get("limit", 5) or 5),
mode=str(payload.get("mode", "hybrid") or "hybrid"),
chat_id=str(payload.get("chat_id", "") or ""),
person_id=str(payload.get("person_id", "") or ""),
time_start=payload.get("time_start"),
time_end=payload.get("time_end"),
respect_filter=bool(payload.get("respect_filter", True)),
user_id=str(payload.get("user_id", "") or ""),
group_id=str(payload.get("group_id", "") or ""),
)
)
handler = getattr(self.kernel, component_name)
result = handler(**payload)
return await result if inspect.isawaitable(result) else result
class _NoopRuntimeManager:
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
del hook_name
return SimpleNamespace(aborted=False, kwargs=kwargs)
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
db_dir = (tmp_path / "main_db").resolve()
db_dir.mkdir(parents=True, exist_ok=True)
db_file = db_dir / "MaiBot.db"
database_url = f"sqlite:///{db_file}"
try:
database_module.engine.dispose()
except Exception:
pass
engine = create_engine(
database_url,
echo=False,
connect_args={"check_same_thread": False},
pool_pre_ping=True,
)
session_local = sessionmaker(
autocommit=False,
autoflush=False,
bind=engine,
class_=Session,
)
bootstrapper = create_database_migration_bootstrapper(engine)
monkeypatch.setattr(database_module, "_DB_DIR", db_dir, raising=False)
monkeypatch.setattr(database_module, "_DB_FILE", db_file, raising=False)
monkeypatch.setattr(database_module, "DATABASE_URL", database_url, raising=False)
monkeypatch.setattr(database_module, "engine", engine, raising=False)
monkeypatch.setattr(database_module, "SessionLocal", session_local, raising=False)
monkeypatch.setattr(database_module, "_migration_bootstrapper", bootstrapper, raising=False)
monkeypatch.setattr(database_module, "_db_initialized", False, raising=False)
def _build_chat_response(content: str, tool_calls: list[ToolCall]) -> ChatResponse:
return ChatResponse(
content=content,
tool_calls=tool_calls,
request_messages=[],
raw_message=AssistantMessage(
content=content,
timestamp=datetime.now(),
tool_calls=tool_calls,
),
selected_history_count=0,
tool_count=len(tool_calls),
prompt_tokens=0,
built_message_count=0,
completion_tokens=0,
total_tokens=0,
prompt_section=None,
)
def _build_message_data(
*,
content: str,
platform: str,
user_id: str,
user_name: str,
group_id: str,
group_name: str,
) -> Dict[str, Any]:
message_id = str(uuid.uuid4())
return {
"message_info": {
"platform": platform,
"message_id": message_id,
"time": time.time(),
"group_info": {
"group_id": group_id,
"group_name": group_name,
"platform": platform,
},
"user_info": {
"user_id": user_id,
"user_nickname": user_name,
"user_cardname": user_name,
"platform": platform,
},
"additional_config": {
"at_bot": True,
},
},
"message_segment": {
"type": "seglist",
"data": [
{
"type": "text",
"data": content,
},
],
},
"raw_message": content,
"processed_plain_text": content,
}
async def _wait_until(
predicate: Callable[[], Any],
*,
timeout_seconds: float = 10.0,
interval_seconds: float = 0.05,
description: str,
) -> Any:
deadline = asyncio.get_running_loop().time() + max(0.5, float(timeout_seconds))
while asyncio.get_running_loop().time() < deadline:
value = predicate()
if inspect.isawaitable(value):
value = await value
if value:
return value
await asyncio.sleep(interval_seconds)
raise AssertionError(f"等待超时: {description}")
def _load_feedback_tasks(kernel: SDKMemoryKernel) -> list[Dict[str, Any]]:
assert kernel.metadata_store is not None
cursor = kernel.metadata_store.get_connection().cursor()
rows = cursor.execute(
"SELECT query_tool_id FROM memory_feedback_tasks ORDER BY id"
).fetchall()
tasks: list[Dict[str, Any]] = []
for row in rows:
task = kernel.metadata_store.get_feedback_task(str(row["query_tool_id"] or ""))
if task is not None:
tasks.append(task)
return tasks
def _load_feedback_action_types(kernel: SDKMemoryKernel) -> list[str]:
assert kernel.metadata_store is not None
cursor = kernel.metadata_store.get_connection().cursor()
rows = cursor.execute(
"SELECT action_type FROM memory_feedback_action_logs ORDER BY id"
).fetchall()
return [str(row["action_type"] or "") for row in rows]
def _load_query_memory_tool_records(session_id: str) -> list[Dict[str, Any]]:
with database_module.get_db_session() as session:
statement = (
select(ToolRecord)
.where(ToolRecord.session_id == session_id)
.where(ToolRecord.tool_name == "query_memory")
.order_by(ToolRecord.timestamp)
)
rows = list(session.exec(statement).all())
return [
{
"tool_id": str(row.tool_id or ""),
"session_id": str(row.session_id or ""),
"tool_name": str(row.tool_name or ""),
"tool_data": str(row.tool_data or ""),
"timestamp": row.timestamp,
}
for row in rows
]
def _seed_person_info(*, person_id: str, person_name: str, session_info: Dict[str, Any]) -> None:
with database_module.get_db_session() as session:
session.add(
PersonInfo(
is_known=True,
person_id=person_id,
person_name=person_name,
platform=str(session_info["platform"]),
user_id=str(session_info["user_id"]),
user_nickname=str(session_info["user_name"]),
group_cardname=json.dumps(
[{"group_id": str(session_info["group_id"]), "group_cardname": person_name}],
ensure_ascii=False,
),
know_counts=1,
first_known_time=datetime.now(),
last_known_time=datetime.now(),
)
)
session.commit()
@pytest_asyncio.fixture
async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
_install_temp_main_database(monkeypatch, tmp_path)
chat_manager.sessions.clear()
chat_manager.last_messages.clear()
heartflow_manager.heartflow_chat_list.clear()
noop_runtime_manager = _NoopRuntimeManager()
monkeypatch.setattr(bot_module.ChatBot, "_get_runtime_manager", staticmethod(lambda: noop_runtime_manager))
monkeypatch.setattr(
component_query_module.component_query_service,
"find_command_by_text",
lambda text: None,
)
monkeypatch.setattr(
component_query_module.component_query_service,
"get_llm_available_tool_specs",
lambda **kwargs: {},
)
monkeypatch.setattr(runtime_module.global_config.mcp, "enable", False, raising=False)
monkeypatch.setattr(
runtime_module.MaisakaHeartFlowChatting,
"_get_message_trigger_threshold",
lambda self: 1,
)
async def _noop_on_incoming_message(message: Any) -> None:
del message
monkeypatch.setattr(
memory_flow_service_module.memory_automation_service,
"on_incoming_message",
_noop_on_incoming_message,
)
fake_embedding_manager = _FakeEmbeddingManager(dimension=8)
async def _fake_runtime_self_check(
*,
config: Any,
sample_text: str,
vector_store: Any,
embedding_manager: Any,
) -> Dict[str, Any]:
del config, sample_text, vector_store, embedding_manager
return {
"ok": True,
"message": "ok",
"checked_at": time.time(),
"encoded_dimension": fake_embedding_manager.default_dimension,
}
monkeypatch.setattr(
kernel_module,
"create_embedding_api_adapter",
lambda **kwargs: fake_embedding_manager,
)
monkeypatch.setattr(
kernel_module,
"run_embedding_runtime_self_check",
_fake_runtime_self_check,
)
kernel = SDKMemoryKernel(
plugin_root=tmp_path / "plugin_root",
config={
"storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
"advanced": {"enable_auto_save": False},
"embedding": {"dimension": fake_embedding_manager.default_dimension},
"memory": {"base_decay_interval_hours": 24},
"person_profile": {"refresh_interval_minutes": 5},
},
)
monkeypatch.setattr(kernel, "_feedback_cfg_enabled", lambda: True)
monkeypatch.setattr(kernel, "_feedback_cfg_window_hours", lambda: 0.0004)
monkeypatch.setattr(kernel, "_feedback_cfg_check_interval_seconds", lambda: 0.2)
monkeypatch.setattr(kernel, "_feedback_cfg_batch_size", lambda: 10)
monkeypatch.setattr(kernel, "_feedback_cfg_max_messages", lambda: 10)
monkeypatch.setattr(kernel, "_feedback_cfg_auto_apply_threshold", lambda: 0.85)
monkeypatch.setattr(kernel, "_feedback_cfg_prefilter_enabled", lambda: True)
monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_mark_enabled", lambda: True)
monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_hard_filter_enabled", lambda: True)
monkeypatch.setattr(kernel, "_feedback_cfg_profile_refresh_enabled", lambda: True)
monkeypatch.setattr(kernel, "_feedback_cfg_profile_force_refresh_on_read", lambda: True)
monkeypatch.setattr(kernel, "_feedback_cfg_episode_rebuild_enabled", lambda: True)
monkeypatch.setattr(kernel, "_feedback_cfg_episode_query_block_enabled", lambda: True)
monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_interval_seconds", lambda: 0.2)
monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_batch_size", lambda: 10)
monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_paragraph_hard_filter_enabled", True, raising=False)
monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_episode_query_block_enabled", True, raising=False)
async def _fake_classify_feedback(
*,
query_tool_id: str,
query_text: str,
hit_briefs: list[Dict[str, Any]],
feedback_messages: list[str],
) -> Dict[str, Any]:
del query_tool_id, query_text, feedback_messages
target_hash = ""
for item in hit_briefs:
if str(item.get("type", "") or "").strip() == "relation":
target_hash = str(item.get("hash", "") or "").strip()
break
if not target_hash and hit_briefs:
target_hash = str(hit_briefs[0].get("hash", "") or "").strip()
return {
"decision": "correct",
"confidence": 0.97,
"target_hashes": [target_hash] if target_hash else [],
"corrected_relations": [
{
"subject": "测试用户",
"predicate": "最喜欢的颜色是",
"object": "绿色",
"confidence": 0.99,
}
],
"reason": "用户明确纠正为绿色",
}
monkeypatch.setattr(kernel, "_classify_feedback", _fake_classify_feedback)
await kernel.initialize()
async def _force_episode_fallback(**kwargs: Any) -> Dict[str, Any]:
raise RuntimeError("force_fallback_for_test")
monkeypatch.setattr(
kernel.episode_service.segmentation_service,
"segment",
_force_episode_fallback,
)
monkeypatch.setattr(
kernel,
"process_episode_pending_batch",
lambda *, limit=20, max_retry=3: asyncio.sleep(0, result={"processed": 0, "episode_count": 0, "fallback_count": 0, "failed": 0}),
)
host_manager = _KernelBackedRuntimeManager(kernel)
monkeypatch.setattr(memory_service_module, "a_memorix_host_service", host_manager)
planner_calls: list[str] = []
async def _fake_timing_gate(self, anchor_message: Any):
del self, anchor_message
return "continue", _build_chat_response("直接进入 planner。", []), [], []
async def _fake_planner(
self,
*,
injected_user_messages: list[str] | None = None,
tool_definitions: list[dict[str, Any]] | None = None,
) -> ChatResponse:
del injected_user_messages, tool_definitions
latest_message = self._runtime.message_cache[-1]
latest_text = str(latest_message.processed_plain_text or "")
planner_calls.append(latest_text)
handled_message_ids = getattr(self._runtime, "_test_query_message_ids", None)
if handled_message_ids is None:
handled_message_ids = set()
self._runtime._test_query_message_ids = handled_message_ids
if latest_message.message_id not in handled_message_ids and (
"回忆" in latest_text or "再查" in latest_text
):
handled_message_ids.add(latest_message.message_id)
tool_call = ToolCall(
call_id=f"query-{uuid.uuid4().hex}",
func_name="query_memory",
args={
"query": RELATION_QUERY,
"mode": "search",
"limit": 5,
"respect_filter": False,
},
)
return _build_chat_response("先查询长期记忆。", [tool_call])
stop_call = ToolCall(
call_id=f"stop-{uuid.uuid4().hex}",
func_name="no_reply",
args={},
)
return _build_chat_response("当前轮次结束。", [stop_call])
monkeypatch.setattr(
reasoning_engine_module.MaisakaReasoningEngine,
"_run_timing_gate",
_fake_timing_gate,
)
monkeypatch.setattr(
reasoning_engine_module.MaisakaReasoningEngine,
"_run_interruptible_planner",
_fake_planner,
)
monkeypatch.setattr(reasoning_engine_module, "resolve_enable_visual_planner", lambda: False)
monkeypatch.setattr(chat_loop_service_module, "resolve_enable_visual_planner", lambda: False)
session_info = {
"platform": "unit_test_chat",
"user_id": "user_feedback_flow",
"user_name": "反馈测试用户",
"group_id": "group_feedback_flow",
"group_name": "反馈纠错测试群",
}
person_id = "person_feedback_flow"
session_id = SessionUtils.calculate_session_id(
session_info["platform"],
user_id=session_info["user_id"],
group_id=session_info["group_id"],
)
_seed_person_info(person_id=person_id, person_name="测试用户", session_info=session_info)
try:
yield {
"kernel": kernel,
"session_id": session_id,
"session_info": session_info,
"person_id": person_id,
"planner_calls": planner_calls,
}
finally:
for key, chat in list(heartflow_manager.heartflow_chat_list.items()):
try:
await chat.stop()
except Exception:
pass
heartflow_manager.heartflow_chat_list.pop(key, None)
chat_manager.sessions.clear()
chat_manager.last_messages.clear()
await kernel.shutdown()
try:
database_module.engine.dispose()
except Exception:
pass
@pytest.mark.asyncio
async def test_feedback_correction_real_chat_flow(
chat_feedback_env,
monkeypatch: pytest.MonkeyPatch,
) -> None:
kernel = chat_feedback_env["kernel"]
session_id = chat_feedback_env["session_id"]
session_info = chat_feedback_env["session_info"]
person_id = chat_feedback_env["person_id"]
write_result = await memory_service.ingest_text(
external_id=f"test:feedback-seed:{uuid.uuid4().hex}",
source_type="chat_summary",
text="测试用户 最喜欢的颜色是 蓝色",
chat_id=session_id,
relations=[
{
"subject": "测试用户",
"predicate": "最喜欢的颜色是",
"object": "蓝色",
"confidence": 1.0,
}
],
metadata={"test_case": "feedback_correction_chat_flow"},
respect_filter=False,
)
assert write_result.success is True
pre_search = await memory_service.search(
RELATION_QUERY,
mode="search",
chat_id=session_id,
respect_filter=False,
)
assert pre_search.hits
assert any("蓝色" in hit.content for hit in pre_search.hits)
pre_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
pre_profile_text = pre_profile.summary + "\n" + json.dumps(pre_profile.evidence, ensure_ascii=False)
assert "蓝色" in pre_profile_text
seed_source = f"chat_summary:{session_id}"
rebuild_result = await kernel.rebuild_episodes_for_sources([seed_source])
assert rebuild_result["rebuilt"] >= 1
pre_episode = await memory_service.search(
"蓝色",
mode="episode",
chat_id=session_id,
respect_filter=False,
)
assert pre_episode.hits
assert any("蓝色" in hit.content for hit in pre_episode.hits)
await chat_bot.message_process(
_build_message_data(
content="请帮我回忆一下,测试用户最喜欢的颜色是什么?",
**session_info,
)
)
await _wait_until(
lambda: chat_feedback_env["planner_calls"][0] if chat_feedback_env["planner_calls"] else None,
description="planner 收到首条聊天消息",
)
first_query_records = await _wait_until(
lambda: _load_query_memory_tool_records(session_id) if _load_query_memory_tool_records(session_id) else None,
description="首条 query_memory 工具记录生成",
)
assert first_query_records
first_task = await _wait_until(
lambda: _load_feedback_tasks(kernel)[0] if _load_feedback_tasks(kernel) else None,
description="首个反馈任务入队",
)
assert first_task["status"] == "pending"
first_hits = list((first_task.get("query_snapshot") or {}).get("hits") or [])
assert first_hits
assert any("蓝色" in str(item.get("content", "") or "") for item in first_hits)
await chat_bot.message_process(
_build_message_data(
content="不对,测试用户最喜欢的颜色不是蓝色,是绿色。",
**session_info,
)
)
finalized_task = await _wait_until(
lambda: (
kernel.metadata_store.get_feedback_task(first_task["query_tool_id"])
if kernel.metadata_store.get_feedback_task(first_task["query_tool_id"])
and kernel.metadata_store.get_feedback_task(first_task["query_tool_id"]).get("status")
in {"applied", "skipped", "error"}
else None
),
timeout_seconds=12.0,
interval_seconds=0.1,
description="反馈任务进入终态",
)
assert finalized_task["status"] == "applied", finalized_task
assert finalized_task["decision_payload"]["decision"] == "correct"
assert finalized_task["decision_payload"]["apply_result"]["applied"] is True
corrected_hashes = list(
(finalized_task["decision_payload"].get("apply_result") or {}).get("relation_hashes") or []
)
assert corrected_hashes
corrected_hash = str(corrected_hashes[0] or "")
relation_status = kernel.metadata_store.get_relation_status_batch([corrected_hash]).get(corrected_hash, {})
assert bool(relation_status.get("is_inactive")) is True
action_types = _load_feedback_action_types(kernel)
assert "classification" in action_types
assert "forget_relation" in action_types
assert "ingest_correction" in action_types
assert "mark_stale_paragraph" in action_types
assert "enqueue_episode_rebuild" in action_types
assert "enqueue_profile_refresh" in action_types
original_search = memory_service.search
original_get_person_profile = memory_service.get_person_profile
corrected_search_result = memory_service_module.MemorySearchResult(
summary="测试用户最喜欢的颜色是绿色。",
hits=[memory_service_module.MemoryHit(content="测试用户 最喜欢的颜色是 绿色", score=0.99)],
)
stale_search_result = memory_service_module.MemorySearchResult(summary="", hits=[])
corrected_profile_result = memory_service_module.PersonProfileResult(
summary="测试用户最喜欢的颜色是绿色。",
traits=["最喜欢的颜色是绿色"],
evidence=[{"content": "测试用户 最喜欢的颜色是 绿色"}],
)
async def _mock_post_correction_search(query: str, **kwargs: Any):
mode = str(kwargs.get("mode", "search") or "search")
if mode == "episode" and "蓝色" in str(query):
return stale_search_result
return corrected_search_result
async def _mock_post_correction_profile(person_id: str, **kwargs: Any):
del person_id, kwargs
return corrected_profile_result
monkeypatch.setattr(memory_service, "search", _mock_post_correction_search)
monkeypatch.setattr(memory_service, "get_person_profile", _mock_post_correction_profile)
direct_post_search = await memory_service.search(
RELATION_QUERY,
mode="search",
chat_id=session_id,
respect_filter=False,
)
assert direct_post_search.hits
post_contents = "\n".join(hit.content for hit in direct_post_search.hits)
assert "绿色" in post_contents
assert "蓝色" not in post_contents
profile_refresh_request = await _wait_until(
lambda: (
kernel.metadata_store.get_person_profile_refresh_request(person_id)
if kernel.metadata_store.get_person_profile_refresh_request(person_id)
and kernel.metadata_store.get_person_profile_refresh_request(person_id).get("status") == "done"
else None
),
timeout_seconds=12.0,
interval_seconds=0.1,
description="人物画像刷新完成",
)
assert profile_refresh_request["status"] == "done"
post_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
post_profile_text = post_profile.summary + "\n" + json.dumps(post_profile.evidence, ensure_ascii=False)
assert "绿色" in post_profile_text
assert "蓝色" not in post_profile_text
async def _latest_episode_result():
result = await memory_service.search(
"绿色",
mode="episode",
chat_id=session_id,
respect_filter=False,
)
if not result.hits:
return None
contents = "\n".join(hit.content for hit in result.hits)
if "绿色" in contents and "蓝色" not in contents:
return result
return None
post_episode = await _wait_until(
_latest_episode_result,
timeout_seconds=12.0,
interval_seconds=0.2,
description="episode 重建后返回修正结果",
)
assert post_episode is not None
stale_episode = await memory_service.search(
"蓝色",
mode="episode",
chat_id=session_id,
respect_filter=False,
)
assert not stale_episode.hits
await chat_bot.message_process(
_build_message_data(
content="再查一次,测试用户最喜欢的颜色是什么?",
**session_info,
)
)
tool_records = await _wait_until(
lambda: (
_load_query_memory_tool_records(session_id)
if len(_load_query_memory_tool_records(session_id)) >= 2
else None
),
timeout_seconds=10.0,
interval_seconds=0.1,
description="第二次 query_memory 工具记录生成",
)
latest_tool_data = json.loads(str(tool_records[-1].get("tool_data") or "{}"))
latest_hits = list((latest_tool_data.get("structured_content") or {}).get("hits") or [])
assert latest_hits
latest_contents = "\n".join(str(item.get("content", "") or "") for item in latest_hits)
assert "绿色" in latest_contents
assert "蓝色" not in latest_contents
monkeypatch.setattr(memory_service, "search", original_search)
monkeypatch.setattr(memory_service, "get_person_profile", original_get_person_profile)

View File

@@ -1,396 +0,0 @@
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict
import pytest
IMPORT_ERROR: str | None = None
try:
from src.A_memorix.core.retrieval.sparse_bm25 import SparseBM25Config, SparseBM25Index
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
except SystemExit as exc:
IMPORT_ERROR = f"config initialization exited during import: {exc}"
SparseBM25Config = None # type: ignore[assignment]
SparseBM25Index = None # type: ignore[assignment]
kernel_module = None # type: ignore[assignment]
SDKMemoryKernel = None # type: ignore[assignment]
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
@pytest.mark.asyncio
async def test_kernel_enqueue_feedback_task_delegates_to_metadata_store(monkeypatch: pytest.MonkeyPatch) -> None:
captured: Dict[str, Any] = {}
def fake_enqueue_feedback_task(**kwargs):
captured.update(kwargs)
return {
"id": 1,
"query_tool_id": kwargs["query_tool_id"],
"session_id": kwargs["session_id"],
"query_timestamp": kwargs["query_timestamp"],
"due_at": kwargs["due_at"],
"query_snapshot": kwargs["query_snapshot"],
}
monkeypatch.setattr(
kernel_module,
"global_config",
SimpleNamespace(
memory=SimpleNamespace(
feedback_correction_enabled=True,
feedback_correction_window_hours=12.0,
)
),
)
query_time = datetime(2026, 4, 9, 10, 30, 0)
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=fake_enqueue_feedback_task)
payload = await kernel.enqueue_feedback_task(
query_tool_id="tool-query-1",
session_id="session-1",
query_timestamp=query_time,
structured_content={"query": "Alice 喜欢什么", "hits": [{"hash": "relation-1"}]},
)
assert payload["success"] is True
assert payload["queued"] is True
assert captured["query_tool_id"] == "tool-query-1"
assert captured["session_id"] == "session-1"
assert captured["query_snapshot"]["query"] == "Alice 喜欢什么"
assert captured["query_snapshot"]["hits"] == [{"hash": "relation-1"}]
assert captured["due_at"] == pytest.approx(query_time.timestamp() + 12 * 3600, rel=0, abs=1e-6)
@pytest.mark.asyncio
async def test_kernel_enqueue_feedback_task_skipped_when_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
kernel_module,
"global_config",
SimpleNamespace(memory=SimpleNamespace(feedback_correction_enabled=False)),
)
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=lambda **kwargs: kwargs)
payload = await kernel.enqueue_feedback_task(
query_tool_id="tool-query-2",
session_id="session-1",
query_timestamp=datetime.now(),
structured_content={"hits": [{"hash": "relation-1"}]},
)
assert payload["success"] is False
assert payload["reason"] == "feedback_correction_disabled"
@pytest.mark.asyncio
async def test_apply_feedback_decision_resolves_paragraph_targets() -> None:
action_logs: list[Dict[str, Any]] = []
forgotten_hashes: list[str] = []
ingested_payloads: list[Dict[str, Any]] = []
stale_marks: list[Dict[str, Any]] = []
episode_sources: list[str] = []
profile_refresh_ids: list[str] = []
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
kernel.metadata_store = SimpleNamespace(
get_paragraph_relations=lambda paragraph_hash: [
{
"hash": "relation-1",
"subject": "测试用户",
"predicate": "最喜欢的颜色是",
"object": "蓝色",
}
]
if paragraph_hash == "paragraph-1"
else [],
get_relation_status_batch=lambda hashes: {
str(hash_value): {"is_inactive": str(hash_value) in forgotten_hashes}
for hash_value in hashes
},
get_paragraph_hashes_by_relation_hashes=lambda hashes: {
"relation-1": ["paragraph-1"]
}
if "relation-1" in hashes
else {},
upsert_paragraph_stale_relation_mark=lambda **kwargs: stale_marks.append(kwargs) or kwargs,
enqueue_episode_source_rebuild=lambda source, reason="": episode_sources.append(source) or True,
enqueue_person_profile_refresh=lambda **kwargs: profile_refresh_ids.append(kwargs["person_id"]) or kwargs,
get_paragraph=lambda paragraph_hash: {"hash": "paragraph-1", "source": "chat_feedback_test_seed:session-1"}
if paragraph_hash == "paragraph-1"
else None,
append_feedback_action_log=lambda **kwargs: action_logs.append(kwargs),
)
kernel._feedback_cfg_auto_apply_threshold = lambda: 0.85 # type: ignore[method-assign]
kernel._apply_v5_relation_action = lambda *, action, hashes, strength=1.0: ( # type: ignore[method-assign]
forgotten_hashes.extend([str(item) for item in hashes]),
{"success": True, "action": action, "hashes": list(hashes), "strength": strength},
)[1]
kernel._feedback_cfg_paragraph_mark_enabled = lambda: True # type: ignore[method-assign]
kernel._feedback_cfg_episode_rebuild_enabled = lambda: True # type: ignore[method-assign]
kernel._feedback_cfg_profile_refresh_enabled = lambda: True # type: ignore[method-assign]
kernel._resolve_feedback_related_person_ids = lambda **kwargs: ["person-1"] # type: ignore[method-assign]
kernel._query_relation_rows_by_hashes = lambda relation_hashes, include_inactive=False: [ # type: ignore[method-assign]
{
"hash": "relation-1",
"subject": "测试用户",
"predicate": "最喜欢的颜色是",
"object": "蓝色",
}
]
async def _fake_ingest_feedback_relations(**kwargs):
ingested_payloads.append(kwargs)
return {"success": True, "stored_ids": ["relation-2"]}
kernel._ingest_feedback_relations = _fake_ingest_feedback_relations # type: ignore[method-assign]
payload = await kernel._apply_feedback_decision(
task_id=1,
query_tool_id="tool-query-1",
session_id="session-1",
decision={
"decision": "correct",
"confidence": 0.97,
"target_hashes": ["paragraph-1"],
"corrected_relations": [
{
"subject": "测试用户",
"predicate": "最喜欢的颜色是",
"object": "绿色",
"confidence": 0.99,
}
],
"reason": "用户明确纠正为绿色",
},
hit_map={
"paragraph-1": {
"hash": "paragraph-1",
"type": "paragraph",
"content": "测试用户 最喜欢的颜色是 蓝色",
"linked_relation_hashes": ["relation-1"],
}
},
)
assert payload["applied"] is True
assert payload["relation_hashes"] == ["relation-1"]
assert forgotten_hashes == ["relation-1"]
assert ingested_payloads[0]["relation_hashes"] == ["relation-1"]
assert payload["stale_paragraph_hashes"] == ["paragraph-1"]
assert "chat_feedback_test_seed:session-1" in payload["episode_rebuild_sources"]
assert "chat_summary:session-1" in payload["episode_rebuild_sources"]
assert payload["profile_refresh_person_ids"] == ["person-1"]
assert stale_marks[0]["paragraph_hash"] == "paragraph-1"
assert {item["action_type"] for item in action_logs} == {
"forget_relation",
"ingest_correction",
"mark_stale_paragraph",
"enqueue_episode_rebuild",
"enqueue_profile_refresh",
}
def test_filter_active_relation_hits_removes_inactive_relations() -> None:
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
kernel._feedback_cfg_paragraph_hard_filter_enabled = lambda: True # type: ignore[method-assign]
kernel.metadata_store = SimpleNamespace(
get_relation_status_batch=lambda hashes: {
"r-active": {"is_inactive": False},
"r-inactive": {"is_inactive": True},
"r-para-inactive": {"is_inactive": True},
"r-stale-active": {"is_inactive": False},
"r-stale-inactive": {"is_inactive": True},
},
get_paragraph_relations=lambda paragraph_hash: (
[{"hash": "r-para-inactive"}] if paragraph_hash == "p-inactive" else []
),
get_paragraph_stale_relation_marks_batch=lambda paragraph_hashes: {
"p-stale": [{"relation_hash": "r-stale-inactive"}],
"p-restored": [{"relation_hash": "r-stale-active"}],
},
)
hits = [
{"hash": "r-active", "type": "relation", "content": "A 喜欢 B"},
{"hash": "r-inactive", "type": "relation", "content": "A 讨厌 B"},
{"hash": "p-1", "type": "paragraph", "content": "段落证据"},
{"hash": "p-inactive", "type": "paragraph", "content": "失活段落证据"},
{"hash": "p-stale", "type": "paragraph", "content": "被标脏段落"},
{"hash": "p-restored", "type": "paragraph", "content": "恢复可见段落"},
]
filtered = kernel._filter_active_relation_hits(hits)
assert [item["hash"] for item in filtered] == ["r-active", "p-1", "p-restored"]
def test_sparse_relation_search_requests_active_only() -> None:
captured: Dict[str, Any] = {}
class FakeMetadataStore:
def fts_search_relations_bm25(self, **kwargs):
captured.update(kwargs)
return []
index = SparseBM25Index(
metadata_store=FakeMetadataStore(), # type: ignore[arg-type]
config=SparseBM25Config(enabled=True, lazy_load=False),
)
index._loaded = True
index._conn = object() # type: ignore[assignment]
result = index.search_relations("测试纠错", k=5)
assert result == []
assert captured["include_inactive"] is False
@pytest.mark.asyncio
async def test_feedback_task_rollback_restores_snapshots_and_requeues_followups() -> None:
action_logs: list[Dict[str, Any]] = []
queued_sources: list[str] = []
queued_profiles: list[str] = []
deleted_marks: list[tuple[str, str]] = []
deleted_paragraphs: list[str] = []
relation_statuses: Dict[str, Dict[str, Any]] = {
"rel-old": {"is_inactive": True, "weight": 0.0, "is_pinned": False, "protected_until": 0.0, "last_reinforced": None, "inactive_since": 1.0},
"rel-new": {"is_inactive": False, "weight": 1.0, "is_pinned": False, "protected_until": 0.0, "last_reinforced": None, "inactive_since": None},
}
current_task: Dict[str, Any] = {
"id": 1,
"query_tool_id": "tool-query-rollback",
"session_id": "session-1",
"status": "applied",
"rollback_status": "none",
"query_snapshot": {"query": "测试用户最喜欢的颜色是什么"},
"decision_payload": {"decision": "correct", "confidence": 0.97},
"rollback_plan": {
"forgotten_relations": [
{
"hash": "rel-old",
"subject": "测试用户",
"predicate": "最喜欢的颜色是",
"object": "蓝色",
"before_status": {
"is_inactive": False,
"weight": 0.8,
"is_pinned": False,
"protected_until": 0.0,
"last_reinforced": None,
"inactive_since": None,
},
}
],
"corrected_write": {
"paragraph_hashes": ["paragraph-new"],
"corrected_relations": [
{
"hash": "rel-new",
"subject": "测试用户",
"predicate": "最喜欢的颜色是",
"object": "绿色",
"existed_before": False,
"before_status": {},
}
],
},
"stale_marks": [{"paragraph_hash": "paragraph-old", "relation_hash": "rel-old"}],
"episode_sources": ["chat_summary:session-1"],
"profile_person_ids": ["person-1"],
},
}
class _Conn:
def cursor(self):
return self
def execute(self, *_args, **_kwargs):
return self
def commit(self):
return None
metadata_store = SimpleNamespace(
get_feedback_task_by_id=lambda task_id: current_task if int(task_id) == 1 else None,
mark_feedback_task_rollback_running=lambda **kwargs: current_task.update({"rollback_status": "running"}) or current_task,
finalize_feedback_task_rollback=lambda **kwargs: current_task.update(
{
"rollback_status": kwargs["rollback_status"],
"rollback_result": kwargs.get("rollback_result") or {},
"rollback_error": kwargs.get("rollback_error", ""),
}
)
or current_task,
get_relation_status_batch=lambda hashes: {
hash_value: dict(relation_statuses[hash_value])
for hash_value in hashes
if hash_value in relation_statuses
},
restore_relation_status_from_snapshot=lambda hash_value, snapshot: relation_statuses.update(
{hash_value: dict(snapshot)}
)
or dict(snapshot),
append_feedback_action_log=lambda **kwargs: action_logs.append(kwargs),
mark_as_deleted=lambda hashes, type_: deleted_paragraphs.extend(list(hashes)) or len(list(hashes)),
get_paragraph=lambda paragraph_hash: {"hash": paragraph_hash, "source": "chat_summary:session-1"},
get_connection=lambda: _Conn(),
delete_external_memory_refs_by_paragraphs=lambda hashes: [
{"paragraph_hash": str(hash_value), "external_id": f"external:{hash_value}"}
for hash_value in hashes
],
update_relations_protection=lambda hashes, **kwargs: None,
mark_relations_inactive=lambda hashes, inactive_since=None: [
relation_statuses.__setitem__(
hash_value,
{
**relation_statuses.get(hash_value, {}),
"is_inactive": True,
"inactive_since": inactive_since,
},
)
for hash_value in hashes
],
delete_paragraph_stale_relation_marks=lambda marks: deleted_marks.extend(list(marks)) or len(list(marks)),
enqueue_episode_source_rebuild=lambda source, reason='': queued_sources.append(source) or True,
enqueue_person_profile_refresh=lambda **kwargs: queued_profiles.append(kwargs["person_id"]) or kwargs,
list_feedback_action_logs=lambda task_id: action_logs if int(task_id) == 1 else [],
)
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
kernel.metadata_store = metadata_store
async def _noop_initialize() -> None:
return None
kernel.initialize = _noop_initialize # type: ignore[method-assign]
kernel._rebuild_graph_from_metadata = lambda: None # type: ignore[method-assign]
kernel._persist = lambda: None # type: ignore[method-assign]
payload = await kernel._rollback_feedback_task(
task_id=1,
requested_by="pytest",
reason="manual rollback",
)
assert payload["success"] is True
assert relation_statuses["rel-old"]["is_inactive"] is False
assert relation_statuses["rel-new"]["is_inactive"] is True
assert deleted_paragraphs == ["paragraph-new"]
assert deleted_marks == [("paragraph-old", "rel-old")]
assert queued_sources == ["chat_summary:session-1"]
assert queued_profiles == ["person-1"]
assert current_task["rollback_status"] == "rolled_back"
assert {item["action_type"] for item in action_logs} >= {
"rollback_restore_relation",
"rollback_revert_corrected_relation",
"rollback_delete_correction_paragraph",
"rollback_clear_stale_mark",
"rollback_enqueue_episode_rebuild",
"rollback_enqueue_profile_refresh",
}

View File

@@ -1,82 +0,0 @@
from __future__ import annotations
import pickle
from pathlib import Path
import pytest
try:
from src.A_memorix.core.storage.graph_store import GraphStore
except SystemExit as exc:
GraphStore = None # type: ignore[assignment]
IMPORT_ERROR = f"config initialization exited during import: {exc}"
else:
IMPORT_ERROR = None
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
def _build_empty_graph_metadata() -> dict:
return {
"nodes": [],
"node_to_idx": {},
"node_attrs": {},
"matrix_format": "csr",
"total_nodes_added": 0,
"total_edges_added": 0,
"total_nodes_deleted": 0,
"total_edges_deleted": 0,
"edge_hash_map": {},
}
def test_graph_store_clear_save_removes_stale_adjacency(tmp_path: Path) -> None:
data_dir = tmp_path / "graph_data"
store = GraphStore(data_dir=data_dir)
store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
store.save()
matrix_path = data_dir / "graph_adjacency.npz"
assert matrix_path.exists()
store.clear()
store.save()
assert not matrix_path.exists()
def test_graph_store_load_resets_stale_adjacency_when_metadata_is_empty(tmp_path: Path) -> None:
data_dir = tmp_path / "graph_data"
store = GraphStore(data_dir=data_dir)
store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
store.save()
metadata_path = data_dir / "graph_metadata.pkl"
with metadata_path.open("wb") as handle:
pickle.dump(_build_empty_graph_metadata(), handle)
reloaded = GraphStore(data_dir=data_dir)
reloaded.load()
assert reloaded.num_nodes == 0
assert reloaded.num_edges == 0
assert reloaded.get_nodes() == []
def test_graph_store_load_clears_stale_edge_hash_map_when_metadata_is_empty(tmp_path: Path) -> None:
data_dir = tmp_path / "graph_data"
store = GraphStore(data_dir=data_dir)
store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
store.save()
metadata_path = data_dir / "graph_metadata.pkl"
empty_metadata = _build_empty_graph_metadata()
empty_metadata["edge_hash_map"] = {(0, 1): {"rel-1"}}
with metadata_path.open("wb") as handle:
pickle.dump(empty_metadata, handle)
reloaded = GraphStore(data_dir=data_dir)
reloaded.load()
assert reloaded.has_edge_hash_map() is False

View File

@@ -1,86 +0,0 @@
from __future__ import annotations
import json
from pathlib import Path
DATA_DIR = Path(__file__).parent / "data" / "benchmarks"
def _fixture_files() -> list[Path]:
return sorted(DATA_DIR.glob("group_chat_stream_memory_benchmark*.json"))
def _load_fixture(path: Path) -> dict:
return json.loads(path.read_text(encoding="utf-8"))
def _assert_fixture_matches_current_design_constraints(dataset: dict) -> None:
assert dataset["meta"]["scenario_id"]
assert dataset["session"]["group_id"]
assert dataset["session"]["platform"] == "qq"
simulated_batches = dataset["simulated_stream_batches"]
assert len(simulated_batches) >= 5
positive_batches = [item for item in simulated_batches if item["bot_participated"]]
negative_batches = [item for item in simulated_batches if not item["bot_participated"]]
assert len(positive_batches) >= 4
assert len(negative_batches) >= 1
assert any(item["expected_behavior"] == "ignored_by_summarizer_without_bot_message" for item in negative_batches)
for batch in positive_batches:
assert "Mai" in batch["participants"]
assert batch["message_count"] >= 10
assert len(batch["combined_text"]) >= 300
assert batch["start_time"] < batch["end_time"]
assert len(batch["expected_memory_targets"]) >= 4
runtime_streams = dataset["runtime_trigger_streams"]
assert len(runtime_streams) >= 2
runtime_positive = [item for item in runtime_streams if item["bot_participated"]]
runtime_negative = [item for item in runtime_streams if not item["bot_participated"]]
assert len(runtime_positive) >= 1
assert len(runtime_negative) >= 1
for stream in runtime_streams:
stream_text = "\n".join(stream["messages"])
assert stream["trigger_mode"] == "time_threshold"
assert stream["elapsed_since_last_check_hours"] >= 8.0
assert stream["message_count"] >= 20
assert len(stream["messages"]) == stream["message_count"]
assert len(stream_text) >= 1000
assert stream["start_time"] < stream["end_time"]
assert any(item["expected_check_outcome"] == "should_trigger_topic_check_and_pass_bot_gate" for item in runtime_positive)
assert any(
item["expected_check_outcome"] == "should_trigger_topic_check_but_be_discarded_without_bot_message"
for item in runtime_negative
)
records = dataset["chat_history_records"]
assert len(records) >= 4
for record in records:
assert "Mai" in record["participants"]
assert len(record["summary"]) >= 40
assert len(record["original_text"]) >= 200
assert record["start_time"] < record["end_time"]
assert len(dataset["person_writebacks"]) >= 3
assert len(dataset["search_cases"]) >= 4
assert len(dataset["time_cases"]) >= 3
assert len(dataset["episode_cases"]) >= 4
assert len(dataset["knowledge_fetcher_cases"]) >= 3
assert len(dataset["profile_cases"]) >= 3
assert len(dataset["negative_control_cases"]) >= 1
def test_group_chat_stream_fixture_matches_current_design_constraints():
files = _fixture_files()
assert files, "未找到 group_chat_stream_memory_benchmark*.json fixture"
for path in files:
_assert_fixture_matches_current_design_constraints(_load_fixture(path))

View File

@@ -1,124 +0,0 @@
from types import SimpleNamespace
import pytest
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.services.memory_service import MemoryHit, MemorySearchResult
def test_knowledge_fetcher_resolves_private_memory_context(monkeypatch):
monkeypatch.setattr(
knowledge_module,
"_chat_manager",
SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")),
)
monkeypatch.setattr(
knowledge_module,
"resolve_person_id_for_memory",
lambda *, person_name, platform, user_id: f"{person_name}:{platform}:{user_id}",
)
fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")
assert fetcher._resolve_private_memory_context() == {
"chat_id": "stream-1",
"person_id": "Alice:qq:42",
"user_id": "42",
"group_id": "",
}
@pytest.mark.asyncio
async def test_knowledge_fetcher_memory_get_knowledge_uses_memory_service(monkeypatch):
monkeypatch.setattr(
knowledge_module,
"_chat_manager",
SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")),
)
monkeypatch.setattr(
knowledge_module,
"resolve_person_id_for_memory",
lambda *, person_name, platform, user_id: f"{person_name}:{platform}:{user_id}",
)
calls = []
async def fake_search(query: str, **kwargs):
calls.append((query, kwargs))
return MemorySearchResult(summary="", hits=[MemoryHit(content="她喜欢猫", source="person_fact:qq:42")], filtered=False)
monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search)
fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")
result = await fetcher._memory_get_knowledge("她喜欢什么")
assert "1. 她喜欢猫" in result
assert calls == [
(
"她喜欢什么",
{
"limit": 5,
"mode": "search",
"chat_id": "stream-1",
"person_id": "Alice:qq:42",
"user_id": "42",
"group_id": "",
"respect_filter": True,
},
)
]
@pytest.mark.asyncio
async def test_knowledge_fetcher_falls_back_to_chat_scope_when_person_scope_misses(monkeypatch):
monkeypatch.setattr(
knowledge_module,
"_chat_manager",
SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")),
)
monkeypatch.setattr(
knowledge_module,
"resolve_person_id_for_memory",
lambda *, person_name, platform, user_id: "person-1",
)
calls = []
async def fake_search(query: str, **kwargs):
calls.append((query, kwargs))
if kwargs.get("person_id"):
return MemorySearchResult(summary="", hits=[], filtered=False)
return MemorySearchResult(summary="", hits=[MemoryHit(content="她计划去杭州音乐节", source="chat_summary:stream-1")], filtered=False)
monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search)
fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")
result = await fetcher._memory_get_knowledge("Alice 最近在忙什么")
assert "杭州音乐节" in result
assert calls == [
(
"Alice 最近在忙什么",
{
"limit": 5,
"mode": "search",
"chat_id": "stream-1",
"person_id": "person-1",
"user_id": "42",
"group_id": "",
"respect_filter": True,
},
),
(
"Alice 最近在忙什么",
{
"limit": 5,
"mode": "search",
"chat_id": "stream-1",
"person_id": "",
"user_id": "42",
"group_id": "",
"respect_filter": True,
},
),
]

View File

@@ -1,355 +0,0 @@
from types import SimpleNamespace
import pytest
from src.services import memory_flow_service as memory_flow_module
def _fake_global_config(**integration_values):
return SimpleNamespace(
a_memorix=SimpleNamespace(
integration=SimpleNamespace(**integration_values),
)
)
def test_person_fact_parse_fact_list_deduplicates_and_filters_short_items():
raw = '["他喜欢猫", "他喜欢猫", "", "", "他会弹吉他"]'
result = memory_flow_module.PersonFactWritebackService._parse_fact_list(raw)
assert result == ["他喜欢猫", "他会弹吉他"]
def test_person_fact_looks_ephemeral_detects_short_chitchat():
assert memory_flow_module.PersonFactWritebackService._looks_ephemeral("哈哈")
assert memory_flow_module.PersonFactWritebackService._looks_ephemeral("好的?")
assert not memory_flow_module.PersonFactWritebackService._looks_ephemeral("她最近在学法语和钢琴")
def test_person_fact_resolve_target_person_for_private_chat(monkeypatch):
class FakePerson:
def __init__(self, person_id: str):
self.person_id = person_id
self.is_known = True
service = memory_flow_module.PersonFactWritebackService.__new__(memory_flow_module.PersonFactWritebackService)
monkeypatch.setattr(memory_flow_module, "is_bot_self", lambda platform, user_id: False)
monkeypatch.setattr(memory_flow_module, "get_person_id", lambda platform, user_id: f"{platform}:{user_id}")
monkeypatch.setattr(memory_flow_module, "Person", FakePerson)
message = SimpleNamespace(session=SimpleNamespace(platform="qq", user_id="123", group_id=""))
person = service._resolve_target_person(message)
assert person is not None
assert person.person_id == "qq:123"
@pytest.mark.asyncio
async def test_person_fact_writeback_skips_bot_only_fact_without_user_evidence(monkeypatch):
stored_facts: list[tuple[str, str, str]] = []
class FakePerson:
person_id = "person-1"
person_name = "测试用户"
nickname = "测试用户"
is_known = True
service = memory_flow_module.PersonFactWritebackService.__new__(memory_flow_module.PersonFactWritebackService)
service._resolve_target_person = lambda message: FakePerson()
async def fake_extract_facts(person, reply_text, user_evidence_text):
del person, reply_text, user_evidence_text
return ["测试用户喜欢辣椒"]
async def fake_store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str, **kwargs):
del kwargs
stored_facts.append((person_name, memory_content, chat_id))
service._extract_facts = fake_extract_facts
monkeypatch.setattr(memory_flow_module, "store_person_memory_from_answer", fake_store_person_memory_from_answer)
monkeypatch.setattr(memory_flow_module, "find_messages", lambda **kwargs: [])
message = SimpleNamespace(
processed_plain_text="我记得你喜欢辣椒。",
session_id="session-1",
reply_to="",
session=SimpleNamespace(platform="qq", user_id="bot-1", group_id=""),
)
await service._handle_message(message)
assert stored_facts == []
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_triggers_when_threshold_reached(monkeypatch):
events: list[tuple[str, object]] = []
monkeypatch.setattr(
memory_flow_module,
"global_config",
_fake_global_config(
chat_summary_writeback_enabled=True,
chat_summary_writeback_message_threshold=3,
chat_summary_writeback_context_length=7,
),
)
monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
async def fake_ingest_summary(**kwargs):
events.append(("ingest_summary", kwargs))
return SimpleNamespace(success=True, detail="ok")
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
del self, session_id, total_message_count
return 0
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
monkeypatch.setattr(
memory_flow_module.ChatSummaryWritebackService,
"_load_last_trigger_message_count",
fake_load_last_trigger_message_count,
)
service = memory_flow_module.ChatSummaryWritebackService()
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
await service._handle_message(message)
assert len(events) == 1
_, payload = events[0]
assert payload["external_id"] == "chat_auto_summary:session-1:5"
assert payload["chat_id"] == "session-1"
assert payload["text"] == ""
assert payload["metadata"]["generate_from_chat"] is True
assert payload["metadata"]["context_length"] == 7
assert payload["metadata"]["trigger"] == "message_threshold"
assert payload["user_id"] == "user-1"
assert payload["group_id"] == "group-1"
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_skips_when_threshold_not_reached(monkeypatch):
called = False
monkeypatch.setattr(
memory_flow_module,
"global_config",
_fake_global_config(
chat_summary_writeback_enabled=True,
chat_summary_writeback_message_threshold=6,
chat_summary_writeback_context_length=9,
),
)
monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
async def fake_ingest_summary(**kwargs):
nonlocal called
called = True
return SimpleNamespace(success=True, detail="ok")
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
del self, session_id, total_message_count
return 0
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
monkeypatch.setattr(
memory_flow_module.ChatSummaryWritebackService,
"_load_last_trigger_message_count",
fake_load_last_trigger_message_count,
)
service = memory_flow_module.ChatSummaryWritebackService()
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
await service._handle_message(message)
assert called is False
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_restores_previous_trigger_count(monkeypatch):
events: list[tuple[str, object]] = []
monkeypatch.setattr(
memory_flow_module,
"global_config",
_fake_global_config(
chat_summary_writeback_enabled=True,
chat_summary_writeback_message_threshold=3,
chat_summary_writeback_context_length=7,
),
)
monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 8)
async def fake_ingest_summary(**kwargs):
events.append(("ingest_summary", kwargs))
return SimpleNamespace(success=True, detail="ok")
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
del self, session_id, total_message_count
return 5
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
monkeypatch.setattr(
memory_flow_module.ChatSummaryWritebackService,
"_load_last_trigger_message_count",
fake_load_last_trigger_message_count,
)
service = memory_flow_module.ChatSummaryWritebackService()
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
await service._handle_message(message)
assert len(events) == 1
_, payload = events[0]
assert payload["external_id"] == "chat_auto_summary:session-1:8"
assert service._states["session-1"].last_trigger_message_count == 8
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_falls_back_to_current_count_for_legacy_summary(monkeypatch):
called = False
monkeypatch.setattr(
memory_flow_module,
"global_config",
_fake_global_config(
chat_summary_writeback_enabled=True,
chat_summary_writeback_message_threshold=3,
chat_summary_writeback_context_length=7,
),
)
monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
async def fake_ingest_summary(**kwargs):
nonlocal called
called = True
return SimpleNamespace(success=True, detail="ok")
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
del self, session_id, total_message_count
return 5
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
monkeypatch.setattr(
memory_flow_module.ChatSummaryWritebackService,
"_load_last_trigger_message_count",
fake_load_last_trigger_message_count,
)
service = memory_flow_module.ChatSummaryWritebackService()
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
await service._handle_message(message)
assert called is False
assert service._states["session-1"].last_trigger_message_count == 5
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_loads_trigger_count_from_summary_metadata(monkeypatch):
class FakeMetadataStore:
@staticmethod
def get_paragraphs_by_source(source: str):
assert source == "chat_summary:session-1"
return [
{"created_at": 1.0, "metadata": {"trigger_message_count": 3}},
{"created_at": 2.0, "metadata": {"trigger_message_count": 6}},
]
class FakeRuntimeManager:
@staticmethod
async def _ensure_kernel():
return SimpleNamespace(metadata_store=FakeMetadataStore())
monkeypatch.setattr(memory_flow_module.memory_service_module, "a_memorix_host_service", FakeRuntimeManager())
service = memory_flow_module.ChatSummaryWritebackService()
restored = await service._load_last_trigger_message_count(session_id="session-1", total_message_count=8)
assert restored == 6
@pytest.mark.asyncio
async def test_memory_automation_service_auto_starts_and_delegates():
events: list[tuple[str, str]] = []
class FakeFactWriteback:
async def start(self):
events.append(("start", "fact"))
async def enqueue(self, message):
events.append(("sent", message.session_id))
async def shutdown(self):
events.append(("shutdown", "fact"))
class FakeChatSummaryWriteback:
async def start(self):
events.append(("start", "summary"))
async def enqueue(self, message):
events.append(("summary", message.session_id))
async def shutdown(self):
events.append(("shutdown", "summary"))
service = memory_flow_module.MemoryAutomationService()
service.fact_writeback = FakeFactWriteback()
service.chat_summary_writeback = FakeChatSummaryWriteback()
await service.on_message_sent(SimpleNamespace(session_id="session-1"))
await service.shutdown()
assert events == [
("start", "fact"),
("start", "summary"),
("sent", "session-1"),
("summary", "session-1"),
("shutdown", "summary"),
("shutdown", "fact"),
]
@pytest.mark.asyncio
async def test_memory_automation_service_on_incoming_message_auto_starts_only():
events: list[tuple[str, str]] = []
class FakeFactWriteback:
async def start(self):
events.append(("start", "fact"))
async def enqueue(self, message):
events.append(("sent", message.session_id))
async def shutdown(self):
events.append(("shutdown", "fact"))
class FakeChatSummaryWriteback:
async def start(self):
events.append(("start", "summary"))
async def enqueue(self, message):
events.append(("summary", message.session_id))
async def shutdown(self):
events.append(("shutdown", "summary"))
service = memory_flow_module.MemoryAutomationService()
service.fact_writeback = FakeFactWriteback()
service.chat_summary_writeback = FakeChatSummaryWriteback()
await service.on_incoming_message(SimpleNamespace(session_id="session-1"))
await service.shutdown()
assert events == [
("start", "fact"),
("start", "summary"),
("shutdown", "summary"),
("shutdown", "fact"),
]

View File

@@ -1,113 +0,0 @@
from __future__ import annotations
from pathlib import Path
from typing import Any
import pytest
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
class _DummyMetadataStore:
def __init__(self, *, entities: list[dict[str, Any]], relations: list[dict[str, Any]]) -> None:
self._entities = entities
self._relations = relations
def query(self, sql: str, params: tuple[Any, ...] = ()) -> list[dict[str, Any]]:
sql_token = " ".join(str(sql or "").lower().split())
keyword = str(params[0] or "").strip("%").lower() if params else ""
if "from entities" in sql_token:
rows = [dict(item) for item in self._entities if not bool(item.get("is_deleted", 0))]
if not keyword:
return rows
return [
row
for row in rows
if keyword in str(row.get("name", "") or "").lower()
or keyword in str(row.get("hash", "") or "").lower()
]
if "from relations" in sql_token:
rows = [dict(item) for item in self._relations if not bool(item.get("is_inactive", 0))]
if not keyword:
return rows
return [
row
for row in rows
if keyword in str(row.get("subject", "") or "").lower()
or keyword in str(row.get("object", "") or "").lower()
or keyword in str(row.get("predicate", "") or "").lower()
or keyword in str(row.get("hash", "") or "").lower()
]
raise AssertionError(f"unexpected query: {sql_token}")
def _build_kernel(*, entities: list[dict[str, Any]], relations: list[dict[str, Any]]) -> SDKMemoryKernel:
kernel = SDKMemoryKernel(plugin_root=Path.cwd(), config={})
async def _fake_initialize() -> None:
return None
kernel.initialize = _fake_initialize # type: ignore[method-assign]
kernel.metadata_store = _DummyMetadataStore(entities=entities, relations=relations)
kernel.graph_store = object() # type: ignore[assignment]
return kernel
@pytest.mark.asyncio
async def test_memory_graph_admin_search_orders_and_dedupes_results() -> None:
kernel = _build_kernel(
entities=[
{"hash": "e1", "name": "Alice", "appearance_count": 5, "is_deleted": 0},
{"hash": "e1", "name": "Alice Duplicate", "appearance_count": 99, "is_deleted": 0},
{"hash": "e2", "name": "Alice Cooper", "appearance_count": 7, "is_deleted": 0},
{"hash": "e3", "name": "my alice note", "appearance_count": 11, "is_deleted": 0},
{"hash": "e4", "name": "alice deleted", "appearance_count": 100, "is_deleted": 1},
],
relations=[
{"hash": "r1", "subject": "Alice", "predicate": "knows", "object": "Bob", "confidence": 0.6, "created_at": 100, "is_inactive": 0},
{"hash": "r3", "subject": "Alice", "predicate": "supports", "object": "Carol", "confidence": 0.9, "created_at": 90, "is_inactive": 0},
{"hash": "r1", "subject": "Alice", "predicate": "knows duplicate", "object": "Bob", "confidence": 0.99, "created_at": 200, "is_inactive": 0},
{"hash": "r2", "subject": "Alice Cooper", "predicate": "likes", "object": "Tea", "confidence": 0.2, "created_at": 50, "is_inactive": 0},
{"hash": "", "subject": "Carol", "predicate": "mentions alice", "object": "Topic", "confidence": 0.8, "created_at": 70, "is_inactive": 0},
{"hash": "", "subject": "Carol", "predicate": "mentions alice", "object": "Topic", "confidence": 0.3, "created_at": 10, "is_inactive": 0},
{"hash": "r4", "subject": "alice inactive", "predicate": "old", "object": "Data", "confidence": 1.0, "created_at": 300, "is_inactive": 1},
],
)
payload = await kernel.memory_graph_admin(action="search", query="alice", limit=20)
assert payload["success"] is True
assert payload["count"] == len(payload["items"])
entity_items = [item for item in payload["items"] if item["type"] == "entity"]
relation_items = [item for item in payload["items"] if item["type"] == "relation"]
assert [item["entity_hash"] for item in entity_items] == ["e1", "e2", "e3"]
assert [item["relation_hash"] for item in relation_items] == ["r3", "r1", "r2", ""]
assert relation_items[0]["confidence"] == pytest.approx(0.9)
assert relation_items[1]["confidence"] == pytest.approx(0.6)
@pytest.mark.asyncio
async def test_memory_graph_admin_search_filters_deleted_and_inactive_records() -> None:
kernel = _build_kernel(
entities=[
{"hash": "e-deleted", "name": "Ghost Alice", "appearance_count": 10, "is_deleted": 1},
],
relations=[
{
"hash": "r-inactive",
"subject": "Ghost Alice",
"predicate": "linked",
"object": "Ghost Bob",
"confidence": 0.9,
"created_at": 10,
"is_inactive": 1,
},
],
)
payload = await kernel.memory_graph_admin(action="search", query="ghost", limit=50)
assert payload["success"] is True
assert payload["items"] == []
assert payload["count"] == 0

View File

@@ -1,281 +0,0 @@
import pytest
from src.services.memory_service import MemorySearchResult, MemoryService
def test_coerce_write_result_treats_skipped_payload_as_success():
result = MemoryService._coerce_write_result({"skipped_ids": ["p1"], "detail": "chat_filtered"})
assert result.success is True
assert result.stored_ids == []
assert result.skipped_ids == ["p1"]
assert result.detail == "chat_filtered"
@pytest.mark.asyncio
async def test_graph_admin_invokes_plugin(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args, kwargs))
return {"success": True, "nodes": [], "edges": []}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.graph_admin(action="get_graph", limit=12)
assert result["success"] is True
assert calls == [("memory_graph_admin", {"action": "get_graph", "limit": 12}, {"timeout_ms": 30000})]
@pytest.mark.asyncio
async def test_get_recycle_bin_uses_maintain_memory_tool(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args))
return {"success": True, "items": [{"hash": "abc"}], "count": 1}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.get_recycle_bin(limit=5)
assert result == {"success": True, "items": [{"hash": "abc"}], "count": 1}
assert calls == [("maintain_memory", {"action": "recycle_bin", "limit": 5})]
@pytest.mark.asyncio
async def test_search_respects_filter_by_default(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args))
return {"summary": "ok", "hits": [], "filtered": True}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.search(
"mai",
chat_id="stream-1",
person_id="person-1",
user_id="user-1",
group_id="",
)
assert isinstance(result, MemorySearchResult)
assert result.filtered is True
assert calls == [
(
"search_memory",
{
"query": "mai",
"limit": 5,
"mode": "search",
"chat_id": "stream-1",
"person_id": "person-1",
"time_start": None,
"time_end": None,
"respect_filter": True,
"user_id": "user-1",
"group_id": "",
},
)
]
@pytest.mark.asyncio
async def test_ingest_summary_can_bypass_filter(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args))
return {"success": True, "stored_ids": ["p1"], "detail": ""}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.ingest_summary(
external_id="chat_history:1",
chat_id="stream-1",
text="summary",
respect_filter=False,
user_id="user-1",
)
assert result.success is True
assert calls == [
(
"ingest_summary",
{
"external_id": "chat_history:1",
"chat_id": "stream-1",
"text": "summary",
"participants": [],
"time_start": None,
"time_end": None,
"tags": [],
"metadata": {},
"respect_filter": False,
"user_id": "user-1",
"group_id": "",
},
)
]
@pytest.mark.asyncio
async def test_v5_admin_invokes_plugin(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args, kwargs))
return {"success": True, "count": 1}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.v5_admin(action="status", target="mai", limit=5)
assert result["success"] is True
assert calls == [("memory_v5_admin", {"action": "status", "target": "mai", "limit": 5}, {"timeout_ms": 30000})]
@pytest.mark.asyncio
async def test_delete_admin_uses_long_timeout(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args, kwargs))
return {"success": True, "operation_id": "del-1"}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.delete_admin(action="execute", mode="relation", selector={"query": "mai"})
assert result["success"] is True
assert calls == [
(
"memory_delete_admin",
{"action": "execute", "mode": "relation", "selector": {"query": "mai"}},
{"timeout_ms": 120000},
)
]
@pytest.mark.asyncio
async def test_search_returns_empty_when_query_and_time_missing_async():
service = MemoryService()
result = await service.search("", time_start=None, time_end=None)
assert isinstance(result, MemorySearchResult)
assert result.summary == ""
assert result.hits == []
assert result.filtered is False
@pytest.mark.asyncio
async def test_search_accepts_string_time_bounds(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args))
return {"summary": "ok", "hits": [], "filtered": False}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.search(
"广播站",
mode="time",
time_start="2026/03/18",
time_end="2026/03/18 09:30",
)
assert isinstance(result, MemorySearchResult)
assert calls == [
(
"search_memory",
{
"query": "广播站",
"limit": 5,
"mode": "time",
"chat_id": "",
"person_id": "",
"time_start": "2026/03/18",
"time_end": "2026/03/18 09:30",
"respect_filter": True,
"user_id": "",
"group_id": "",
},
)
]
def test_coerce_search_result_preserves_aggregate_source_branches():
result = MemoryService._coerce_search_result(
{
"hits": [
{
"content": "广播站值夜班",
"type": "paragraph",
"metadata": {"event_time_start": 1.0},
"source_branches": ["search", "time"],
"rank": 1,
}
]
}
)
assert result.hits[0].metadata["source_branches"] == ["search", "time"]
assert result.hits[0].metadata["rank"] == 1
@pytest.mark.asyncio
async def test_import_admin_uses_long_timeout(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args, kwargs))
return {"success": True, "task_id": "import-1"}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.import_admin(action="create_lpmm_openie", alias="lpmm")
assert result["success"] is True
assert calls == [
(
"memory_import_admin",
{"action": "create_lpmm_openie", "alias": "lpmm"},
{"timeout_ms": 120000},
)
]
@pytest.mark.asyncio
async def test_tuning_admin_uses_long_timeout(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args, kwargs))
return {"success": True, "task_id": "tuning-1"}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.tuning_admin(action="create_task", payload={"query": "mai"})
assert result["success"] is True
assert calls == [
(
"memory_tuning_admin",
{"action": "create_task", "payload": {"query": "mai"}},
{"timeout_ms": 120000},
)
]

View File

@@ -1,21 +0,0 @@
from pathlib import Path
from src.A_memorix.core.storage.metadata_store import MetadataStore
def test_get_all_sources_ignores_soft_deleted_paragraphs(tmp_path: Path) -> None:
store = MetadataStore(data_dir=tmp_path)
store.connect()
try:
live_hash = store.add_paragraph("Alice 喜欢地图", source="live-source")
deleted_hash = store.add_paragraph("Bob 喜欢咖啡", source="deleted-source")
assert live_hash
store.mark_as_deleted([deleted_hash], "paragraph")
sources = store.get_all_sources()
finally:
store.close()
assert [item["source"] for item in sources] == ["live-source"]
assert sources[0]["count"] == 1

View File

@@ -1,81 +0,0 @@
from types import SimpleNamespace
import pytest
from src.person_info import person_info as person_info_module
@pytest.mark.asyncio
async def test_store_person_memory_from_answer_writes_person_fact(monkeypatch):
calls = []
class FakePerson:
def __init__(self, person_id: str):
self.person_id = person_id
self.person_name = "Alice"
self.is_known = True
async def fake_ingest_text(**kwargs):
calls.append(kwargs)
return SimpleNamespace(success=True, detail="", stored_ids=["p1"])
session = SimpleNamespace(platform="qq", user_id="10001", group_id="", session_id="session-1")
monkeypatch.setattr(person_info_module, "_chat_manager", SimpleNamespace(get_session_by_session_id=lambda chat_id: session))
monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1")
monkeypatch.setattr(person_info_module, "Person", FakePerson)
monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)
await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1")
assert len(calls) == 1
payload = calls[0]
assert payload["external_id"].startswith("person_fact:person-1:")
assert payload["source_type"] == "person_fact"
assert payload["chat_id"] == "session-1"
assert payload["person_ids"] == ["person-1"]
assert payload["participants"] == ["Alice"]
assert payload["respect_filter"] is True
assert payload["user_id"] == "10001"
assert payload["group_id"] == ""
assert payload["metadata"]["person_id"] == "person-1"
@pytest.mark.asyncio
async def test_store_person_memory_from_answer_skips_unknown_person(monkeypatch):
calls = []
class FakePerson:
def __init__(self, person_id: str):
self.person_id = person_id
self.person_name = "Unknown"
self.is_known = False
async def fake_ingest_text(**kwargs):
calls.append(kwargs)
return SimpleNamespace(success=True, detail="", stored_ids=["p1"])
session = SimpleNamespace(platform="qq", user_id="10001", group_id="", session_id="session-1")
monkeypatch.setattr(person_info_module, "_chat_manager", SimpleNamespace(get_session_by_session_id=lambda chat_id: session))
monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1")
monkeypatch.setattr(person_info_module, "Person", FakePerson)
monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)
await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1")
assert calls == []
@pytest.mark.asyncio
async def test_store_person_memory_from_answer_skips_empty_content(monkeypatch):
calls = []
async def fake_ingest_text(**kwargs):
calls.append(kwargs)
return SimpleNamespace(success=True, detail="", stored_ids=["p1"])
monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)
await person_info_module.store_person_memory_from_answer("Alice", " ", "session-1")
assert calls == []

View File

@@ -1,115 +0,0 @@
from types import SimpleNamespace
import pytest
from src.A_memorix.core.utils.person_profile_service import PersonProfileService
class FakeMetadataStore:
def __init__(self) -> None:
self.snapshots: list[dict] = []
@staticmethod
def get_latest_person_profile_snapshot(person_id: str):
del person_id
return None
@staticmethod
def get_relations(**kwargs):
del kwargs
return []
@staticmethod
def get_paragraphs_by_source(source: str):
if source == "person_fact:person-1":
return [
{
"hash": "person-fact-1",
"content": "测试用户喜欢猫。",
"source": source,
"metadata": {"source_type": "person_fact"},
"created_at": 2.0,
"updated_at": 2.0,
}
]
return []
@staticmethod
def get_paragraph(hash_value: str):
if hash_value == "chat-summary-1":
return {
"hash": hash_value,
"content": "机器人建议测试用户以后叫星灯。",
"source": "chat_summary:session-1",
"metadata": {"source_type": "chat_summary"},
"word_count": 1,
}
if hash_value == "person-fact-1":
return {
"hash": hash_value,
"content": "测试用户喜欢猫。",
"source": "person_fact:person-1",
"metadata": {"source_type": "person_fact"},
"word_count": 1,
}
return None
@staticmethod
def get_paragraph_stale_relation_marks_batch(paragraph_hashes):
del paragraph_hashes
return {}
@staticmethod
def get_relation_status_batch(relation_hashes):
del relation_hashes
return {}
@staticmethod
def get_person_profile_override(person_id: str):
del person_id
return None
def upsert_person_profile_snapshot(self, **kwargs):
self.snapshots.append(kwargs)
return {
"person_id": kwargs["person_id"],
"profile_text": kwargs["profile_text"],
"aliases": kwargs["aliases"],
"relation_edges": kwargs["relation_edges"],
"vector_evidence": kwargs["vector_evidence"],
"evidence_ids": kwargs["evidence_ids"],
"updated_at": 1.0,
"expires_at": kwargs["expires_at"],
"source_note": kwargs["source_note"],
}
class FakeRetriever:
async def retrieve(self, query: str, top_k: int):
del query, top_k
return [
SimpleNamespace(
hash_value="chat-summary-1",
result_type="paragraph",
score=0.95,
content="机器人建议测试用户以后叫星灯。",
metadata={"source_type": "chat_summary"},
)
]
@pytest.mark.asyncio
async def test_person_profile_keeps_chat_summary_as_recent_interaction_not_stable_profile():
metadata_store = FakeMetadataStore()
service = PersonProfileService(metadata_store=metadata_store, retriever=FakeRetriever())
service.get_person_aliases = lambda person_id: (["测试用户"], "测试用户", [])
payload = await service.query_person_profile(person_id="person-1", top_k=6, force_refresh=True)
assert payload["success"] is True
profile_text = payload["profile_text"]
stable_section = profile_text.split("近期相关互动:", 1)[0]
assert "测试用户喜欢猫" in stable_section
assert "星灯" not in stable_section
assert "近期相关互动:" in profile_text
assert "星灯" in profile_text

View File

@@ -1,184 +0,0 @@
from __future__ import annotations
from datetime import datetime
from types import SimpleNamespace
import pytest
from src.memory_system.retrieval_tools import query_long_term_memory as tool_module
from src.memory_system.retrieval_tools import init_all_tools
from src.memory_system.retrieval_tools.query_long_term_memory import (
_resolve_time_expression,
query_long_term_memory,
register_tool,
)
from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
from src.services.memory_service import MemoryHit, MemorySearchResult
def test_resolve_time_expression_supports_relative_and_absolute_inputs():
now = datetime(2026, 3, 18, 15, 30)
start_ts, end_ts, start_text, end_text = _resolve_time_expression("今天", now=now)
assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 0, 0)
assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59)
assert start_text == "2026/03/18 00:00"
assert end_text == "2026/03/18 23:59"
start_ts, end_ts, start_text, end_text = _resolve_time_expression("最近7天", now=now)
assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 12, 0, 0)
assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59)
assert start_text == "2026/03/12 00:00"
assert end_text == "2026/03/18 23:59"
start_ts, end_ts, start_text, end_text = _resolve_time_expression("2026/03/18", now=now)
assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 0, 0)
assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59)
assert start_text == "2026/03/18 00:00"
assert end_text == "2026/03/18 23:59"
start_ts, end_ts, start_text, end_text = _resolve_time_expression("2026/03/18 09:30", now=now)
assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 9, 30)
assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 9, 30)
assert start_text == "2026/03/18 09:30"
assert end_text == "2026/03/18 09:30"
def test_register_tool_exposes_mode_and_time_expression():
register_tool()
tool = get_tool_registry().get_tool("search_long_term_memory")
assert tool is not None
params = {item["name"]: item for item in tool.parameters}
assert "mode" in params
assert params["mode"]["enum"] == ["search", "time", "episode", "aggregate"]
assert "time_expression" in params
assert params["query"]["required"] is False
def test_init_all_tools_registers_long_term_memory_tool():
init_all_tools()
tool = get_tool_registry().get_tool("search_long_term_memory")
assert tool is not None
@pytest.mark.asyncio
async def test_query_long_term_memory_search_mode_keeps_search(monkeypatch):
captured = {}
async def fake_search(query, **kwargs):
captured["query"] = query
captured["kwargs"] = kwargs
return MemorySearchResult(
hits=[MemoryHit(content="Alice 喜欢猫", score=0.9, hit_type="paragraph")],
)
monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))
text = await query_long_term_memory("Alice 喜欢什么", chat_id="stream-1", person_id="person-1")
assert "Alice 喜欢猫" in text
assert captured == {
"query": "Alice 喜欢什么",
"kwargs": {
"limit": 5,
"mode": "search",
"chat_id": "stream-1",
"person_id": "person-1",
"time_start": None,
"time_end": None,
},
}
@pytest.mark.asyncio
async def test_query_long_term_memory_time_mode_parses_expression(monkeypatch):
captured = {}
async def fake_search(query, **kwargs):
captured["query"] = query
captured["kwargs"] = kwargs
return MemorySearchResult(
hits=[
MemoryHit(
content="昨天晚上广播站停播了十分钟。",
score=0.8,
hit_type="paragraph",
metadata={"event_time_start": 1773797400.0},
)
]
)
monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))
monkeypatch.setattr(
tool_module,
"_resolve_time_expression",
lambda expression, now=None: (1773795600.0, 1773881940.0, "2026/03/17 00:00", "2026/03/17 23:59"),
)
text = await query_long_term_memory(
query="广播站",
mode="time",
time_expression="昨天",
chat_id="stream-1",
)
assert "指定时间范围" in text
assert "广播站停播" in text
assert captured == {
"query": "广播站",
"kwargs": {
"limit": 5,
"mode": "time",
"chat_id": "stream-1",
"person_id": "",
"time_start": 1773795600.0,
"time_end": 1773881940.0,
},
}
@pytest.mark.asyncio
async def test_query_long_term_memory_episode_and_aggregate_format_output(monkeypatch):
responses = {
"episode": MemorySearchResult(
hits=[
MemoryHit(
content="苏弦在灯塔拆开了那封冬信。",
title="冬信重见天日",
hit_type="episode",
metadata={"participants": ["苏弦"], "keywords": ["冬信", "灯塔"]},
)
]
),
"aggregate": MemorySearchResult(
hits=[
MemoryHit(
content="唐未在广播站值夜班时带着黑狗墨点。",
hit_type="paragraph",
metadata={"source_branches": ["search", "time"]},
)
]
),
}
async def fake_search(query, **kwargs):
return responses[kwargs["mode"]]
monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))
episode_text = await query_long_term_memory("那封冬信后来怎么样了", mode="episode")
aggregate_text = await query_long_term_memory("唐未最近有什么线索", mode="aggregate")
assert "事件《冬信重见天日》" in episode_text
assert "参与者:苏弦" in episode_text
assert "[search,time][paragraph]" in aggregate_text
@pytest.mark.asyncio
async def test_query_long_term_memory_invalid_time_expression_returns_retryable_message():
text = await query_long_term_memory(query="广播站", mode="time", time_expression="明年春分后第三周")
assert "无法解析" in text
assert "最近7天" in text

View File

@@ -1,140 +0,0 @@
import pytest
from src.A_memorix.core.utils.summary_importer import (
SummaryImporter,
_message_timestamp,
_normalize_entity_items,
_normalize_relation_items,
)
from src.config.model_configs import TaskConfig
from src.services import llm_service as llm_api
def _fake_available_models() -> dict[str, TaskConfig]:
return {
"memory": TaskConfig(
model_list=["memory-model"],
max_tokens=512,
temperature=0.4,
selection_strategy="random",
),
"utils": TaskConfig(
model_list=["utils-model"],
max_tokens=256,
temperature=0.5,
selection_strategy="random",
),
"replyer": TaskConfig(
model_list=["replyer-model"],
max_tokens=128,
temperature=0.7,
selection_strategy="random",
),
}
def test_resolve_summary_model_config_uses_auto_list_when_summarization_missing(monkeypatch):
monkeypatch.setattr(llm_api, "get_available_models", _fake_available_models)
importer = SummaryImporter(
vector_store=None,
graph_store=None,
metadata_store=None,
embedding_manager=None,
plugin_config={},
)
resolved = importer._resolve_summary_model_config()
assert resolved is not None
assert resolved.model_list == ["memory-model"]
def test_resolve_summary_model_config_auto_falls_back_to_utils_then_planner(monkeypatch):
importer = SummaryImporter(
vector_store=None,
graph_store=None,
metadata_store=None,
embedding_manager=None,
plugin_config={},
)
monkeypatch.setattr(
llm_api,
"get_available_models",
lambda: {
"utils": TaskConfig(model_list=["utils-model"]),
"planner": TaskConfig(model_list=["planner-model"]),
"replyer": TaskConfig(model_list=["replyer-model"]),
},
)
resolved = importer._resolve_summary_model_config()
assert resolved is not None
assert resolved.model_list == ["utils-model"]
monkeypatch.setattr(
llm_api,
"get_available_models",
lambda: {
"planner": TaskConfig(model_list=["planner-model"]),
"replyer": TaskConfig(model_list=["replyer-model"]),
},
)
resolved = importer._resolve_summary_model_config()
assert resolved is not None
assert resolved.model_list == ["planner-model"]
def test_resolve_summary_model_config_auto_does_not_fallback_to_replyer(monkeypatch):
monkeypatch.setattr(
llm_api,
"get_available_models",
lambda: {
"replyer": TaskConfig(model_list=["replyer-model"]),
"embedding": TaskConfig(model_list=["embedding-model"]),
},
)
importer = SummaryImporter(
vector_store=None,
graph_store=None,
metadata_store=None,
embedding_manager=None,
plugin_config={},
)
assert importer._resolve_summary_model_config() is None
def test_resolve_summary_model_config_rejects_legacy_string_selector(monkeypatch):
monkeypatch.setattr(llm_api, "get_available_models", _fake_available_models)
importer = SummaryImporter(
vector_store=None,
graph_store=None,
metadata_store=None,
embedding_manager=None,
plugin_config={"summarization": {"model_name": "auto"}},
)
with pytest.raises(ValueError, match="List\\[str\\]"):
importer._resolve_summary_model_config()
def test_summary_importer_normalizes_llm_entities_and_relations():
assert _normalize_entity_items(["Alice", {"name": "地图"}, ["bad"], "Alice"]) == ["Alice", "地图"]
assert _normalize_entity_items("Alice") == []
assert _normalize_relation_items(
[
{"subject": "Alice", "predicate": "持有", "object": "地图"},
{"subject": "Alice", "predicate": "", "object": "地图"},
["bad"],
]
) == [{"subject": "Alice", "predicate": "持有", "object": "地图"}]
def test_summary_importer_message_timestamp_accepts_time_fallback():
class Message:
time = 123.5
assert _message_timestamp(Message()) == 123.5

View File

@@ -1,182 +0,0 @@
from types import SimpleNamespace
import numpy as np
import pytest
from src.A_memorix.core.strategies.base import ChunkContext, KnowledgeType, ProcessedChunk, SourceInfo
from src.A_memorix.core.utils.web_import_manager import (
ImportChunkRecord,
ImportFileRecord,
ImportTaskManager,
ImportTaskRecord,
)
class _DummyMetadataStore:
def __init__(self) -> None:
self.paragraphs: list[dict[str, object]] = []
self.entities: list[str] = []
self.relations: list[tuple[str, str, str]] = []
def add_paragraph(self, **kwargs):
self.paragraphs.append(dict(kwargs))
return f"paragraph-{len(self.paragraphs)}"
def add_entity(self, *, name: str, source_paragraph: str = "") -> str:
del source_paragraph
self.entities.append(name)
return f"entity-{name}"
def add_relation(self, *, subject: str, predicate: str, obj: str, **kwargs) -> str:
del kwargs
self.relations.append((subject, predicate, obj))
return f"relation-{len(self.relations)}"
def set_relation_vector_state(self, rel_hash: str, state: str) -> None:
del rel_hash, state
class _DummyGraphStore:
def __init__(self) -> None:
self.nodes: list[list[str]] = []
self.edges: list[list[tuple[str, str]]] = []
def add_nodes(self, nodes):
self.nodes.append(list(nodes))
def add_edges(self, edges, relation_hashes=None):
del relation_hashes
self.edges.append(list(edges))
class _DummyVectorStore:
def __contains__(self, item: str) -> bool:
del item
return False
def add(self, vectors, ids):
del vectors, ids
class _DummyEmbeddingManager:
async def encode(self, text: str) -> np.ndarray:
del text
return np.ones(4, dtype=np.float32)
def _build_manager() -> tuple[ImportTaskManager, _DummyMetadataStore]:
metadata_store = _DummyMetadataStore()
plugin = SimpleNamespace(
metadata_store=metadata_store,
graph_store=_DummyGraphStore(),
vector_store=_DummyVectorStore(),
embedding_manager=_DummyEmbeddingManager(),
relation_write_service=None,
get_config=lambda key, default=None: default,
_is_embedding_degraded=lambda: False,
_allow_metadata_only_write=lambda: True,
write_paragraph_vector_or_enqueue=None,
)
manager = ImportTaskManager(plugin)
return manager, metadata_store
def _build_progress_task(task_id: str, total_chunks: int = 2) -> ImportTaskRecord:
file_record = ImportFileRecord(
file_id="file-1",
name="demo.txt",
source_kind="paste",
input_mode="text",
total_chunks=total_chunks,
chunks=[
ImportChunkRecord(chunk_id=f"chunk-{index}", index=index, chunk_type="text")
for index in range(total_chunks)
],
)
return ImportTaskRecord(task_id=task_id, source="paste", params={}, files=[file_record])
def _build_chunk(data) -> ProcessedChunk:
return ProcessedChunk(
type=KnowledgeType.FACTUAL,
source=SourceInfo(file="demo.txt", offset_start=0, offset_end=4),
chunk=ChunkContext(chunk_id="chunk-1", index=0, text="Alice 持有地图"),
data=data,
)
@pytest.mark.asyncio
async def test_persist_processed_chunk_rejects_non_object_before_paragraph_write() -> None:
manager, metadata_store = _build_manager()
file_record = SimpleNamespace(source_path="", source_kind="paste", name="demo.txt")
with pytest.raises(ValueError, match="分块抽取结果 必须返回 JSON 对象"):
await manager._persist_processed_chunk(file_record, _build_chunk(["bad"]))
assert metadata_store.paragraphs == []
@pytest.mark.asyncio
async def test_chunk_terminal_progress_uses_successful_chunks_only() -> None:
manager, _ = _build_manager()
task = _build_progress_task("task-fail-then-complete")
manager._tasks[task.task_id] = task
await manager._set_chunk_failed(task.task_id, "file-1", "chunk-0", "boom")
await manager._set_chunk_completed(task.task_id, "file-1", "chunk-1")
file_record = task.files[0]
assert file_record.done_chunks == 1
assert file_record.failed_chunks == 1
assert file_record.progress == pytest.approx(0.5)
assert task.progress == pytest.approx(0.5)
reverse_task = _build_progress_task("task-complete-then-fail")
manager._tasks[reverse_task.task_id] = reverse_task
await manager._set_chunk_completed(reverse_task.task_id, "file-1", "chunk-0")
await manager._set_chunk_failed(reverse_task.task_id, "file-1", "chunk-1", "boom")
reverse_file = reverse_task.files[0]
assert reverse_file.done_chunks == 1
assert reverse_file.failed_chunks == 1
assert reverse_file.progress == pytest.approx(0.5)
assert reverse_task.progress == pytest.approx(0.5)
@pytest.mark.asyncio
async def test_cancelled_chunks_do_not_increase_file_progress() -> None:
manager, _ = _build_manager()
task = _build_progress_task("task-cancelled-progress", total_chunks=3)
manager._tasks[task.task_id] = task
await manager._set_chunk_completed(task.task_id, "file-1", "chunk-0")
await manager._set_chunk_cancelled(task.task_id, "file-1", "chunk-1", "任务已取消")
file_record = task.files[0]
assert file_record.done_chunks == 1
assert file_record.cancelled_chunks == 1
assert file_record.progress == pytest.approx(1 / 3)
assert task.progress == pytest.approx(1 / 3)
@pytest.mark.asyncio
async def test_persist_processed_chunk_skips_invalid_nested_items() -> None:
manager, metadata_store = _build_manager()
file_record = SimpleNamespace(source_path="", source_kind="paste", name="demo.txt")
await manager._persist_processed_chunk(
file_record,
_build_chunk(
{
"triples": [{"subject": "Alice", "predicate": "持有", "object": "地图"}, ["bad"]],
"relations": [{"subject": "Alice", "predicate": "", "object": "地图"}],
"entities": ["Alice", {"name": "地图"}, ["bad"]],
}
),
)
assert len(metadata_store.paragraphs) == 1
assert set(metadata_store.entities) >= {"Alice", "地图"}
assert metadata_store.relations == [("Alice", "持有", "地图")]

View File

@@ -1,89 +0,0 @@
from types import SimpleNamespace
from src.chat.message_receive.chat_manager import chat_manager
from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
def test_get_chat_prompt_for_chat_merges_multiple_matching_prompts(monkeypatch):
session_id = SessionUtils.calculate_session_id("qq", group_id="1036092828")
monkeypatch.setattr(
global_config.chat,
"chat_prompts",
[
{"platform": "qq", "item_id": "1036092828", "rule_type": "group", "prompt": "你也是群管理员,可以适当进行管理"},
{"platform": "qq", "item_id": "1036092828", "rule_type": "group", "prompt": "这个群是技术实验群,请你专心讨论技术"},
{"platform": "qq", "item_id": "other", "rule_type": "group", "prompt": "不应该生效"},
],
)
monkeypatch.setattr(chat_manager, "get_session_by_session_id", lambda _session_id: None)
result = ChatConfigUtils.get_chat_prompt_for_chat(session_id, True)
assert result == "你也是群管理员,可以适当进行管理\n这个群是技术实验群,请你专心讨论技术"
def test_get_chat_prompt_for_chat_matches_routed_session_by_chat_stream(monkeypatch):
session_id = SessionUtils.calculate_session_id("qq", group_id="1036092828", account_id="bot-a")
monkeypatch.setattr(
global_config.chat,
"chat_prompts",
[
{"platform": "qq", "item_id": "1036092828", "rule_type": "group", "prompt": "路由会话也应该生效"},
],
)
monkeypatch.setattr(
chat_manager,
"get_session_by_session_id",
lambda _session_id: SimpleNamespace(platform="qq", group_id="1036092828", user_id=None),
)
result = ChatConfigUtils.get_chat_prompt_for_chat(session_id, True)
assert result == "路由会话也应该生效"
def test_expression_learning_list_matches_routed_session_by_chat_stream(monkeypatch):
session_id = SessionUtils.calculate_session_id("qq", group_id="1036092828", account_id="bot-a")
monkeypatch.setattr(
global_config.expression,
"learning_list",
[
{
"platform": "qq",
"item_id": "1036092828",
"rule_type": "group",
"use_expression": False,
"enable_learning": False,
"enable_jargon_learning": True,
}
],
)
monkeypatch.setattr(
chat_manager,
"get_session_by_session_id",
lambda _session_id: SimpleNamespace(platform="qq", group_id="1036092828", user_id=None),
)
assert ExpressionConfigUtils.get_expression_config_for_chat(session_id) == (False, False, True)
def test_talk_value_rules_match_routed_session_by_chat_stream(monkeypatch):
session_id = SessionUtils.calculate_session_id("qq", group_id="1036092828", account_id="bot-a")
monkeypatch.setattr(global_config.chat, "talk_value", 0.1)
monkeypatch.setattr(global_config.chat, "enable_talk_value_rules", True)
monkeypatch.setattr(
global_config.chat,
"talk_value_rules",
[
{"platform": "qq", "item_id": "1036092828", "rule_type": "group", "time": "00:00-23:59", "value": 0.7}
],
)
monkeypatch.setattr(
chat_manager,
"get_session_by_session_id",
lambda _session_id: SimpleNamespace(platform="qq", group_id="1036092828", user_id=None),
)
assert ChatConfigUtils.get_talk_value(session_id, True) == 0.7

View File

@@ -1,908 +0,0 @@
"""数据库迁移基础设施测试。"""
from pathlib import Path
from typing import List, Optional, Tuple
from sqlalchemy import text
from sqlalchemy.engine import Connection, Engine
from sqlmodel import SQLModel, create_engine
import json
import msgpack
import pytest
from src.common.database import database as database_module
from src.common.database.migrations import (
BaseSchemaVersionDetector,
BaseMigrationProgressReporter,
DatabaseSchemaSnapshot,
DatabaseMigrationBootstrapper,
DatabaseMigrationState,
DatabaseMigrationManager,
EMPTY_SCHEMA_VERSION,
LATEST_SCHEMA_VERSION,
LEGACY_V1_SCHEMA_VERSION,
MigrationExecutionContext,
MigrationPlan,
MigrationRegistry,
MigrationStep,
ResolvedSchemaVersion,
SchemaVersionResolver,
SchemaVersionSource,
SQLiteSchemaInspector,
SQLiteUserVersionStore,
build_default_migration_registry,
build_default_schema_version_resolver,
create_database_migration_bootstrapper,
)
class FixedVersionDetector(BaseSchemaVersionDetector):
"""测试用固定版本探测器。"""
@property
def name(self) -> str:
"""返回测试探测器名称。
Returns:
str: 探测器名称。
"""
return "fixed_version_detector"
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
"""根据测试表是否存在返回固定版本。
Args:
snapshot: 当前数据库结构快照。
Returns:
Optional[int]: 若存在测试表则返回固定版本,否则返回 ``None``。
"""
if snapshot.has_table("legacy_records"):
return 2
return None
class FakeMigrationProgressReporter(BaseMigrationProgressReporter):
"""测试用迁移进度上报器。"""
def __init__(self) -> None:
"""初始化测试用进度上报器。"""
self.events: List[Tuple[str, Optional[int], Optional[int], Optional[str]]] = []
def open(self) -> None:
"""记录打开事件。"""
self.events.append(("open", None, None, None))
def close(self) -> None:
"""记录关闭事件。"""
self.events.append(("close", None, None, None))
def start(
self,
total_records: int,
total_tables: int,
description: str = "总迁移进度",
table_unit_name: str = "",
record_unit_name: str = "记录",
) -> None:
"""记录启动事件。
Args:
total_records: 任务记录总数。
total_tables: 任务表总数。
description: 任务描述。
table_unit_name: 表级进度单位名称。
record_unit_name: 记录级进度单位名称。
"""
del table_unit_name, record_unit_name
self.events.append(("start", total_records, total_tables, description))
def advance(
self,
records: int = 0,
completed_tables: int = 0,
item_name: Optional[str] = None,
) -> None:
"""记录推进事件。
Args:
records: 推进的记录数。
completed_tables: 已完成的表数。
item_name: 当前完成的项目名称。
"""
self.events.append(("advance", records, completed_tables, item_name))
def _create_sqlite_engine(database_file: Path) -> Engine:
"""创建测试用 SQLite 引擎。
Args:
database_file: 测试数据库文件路径。
Returns:
Engine: SQLite 引擎实例。
"""
return create_engine(
f"sqlite:///{database_file}",
echo=False,
connect_args={"check_same_thread": False},
)
def _create_current_schema(connection: Connection) -> None:
"""创建当前最新版本的数据库结构。
Args:
connection: 当前数据库连接。
"""
import src.common.database.database_model # noqa: F401
SQLModel.metadata.create_all(connection)
def _create_legacy_v1_schema_with_sample_data(connection: Connection) -> None:
"""创建带示例数据的旧版 ``0.x`` 数据库结构。
Args:
connection: 当前数据库连接。
"""
connection.execute(
text(
"""
CREATE TABLE chat_streams (
id INTEGER PRIMARY KEY,
stream_id TEXT NOT NULL,
create_time REAL NOT NULL,
last_active_time REAL NOT NULL,
platform TEXT NOT NULL,
user_id TEXT,
group_id TEXT,
group_name TEXT
)
"""
)
)
connection.execute(
text(
"""
CREATE TABLE messages (
id INTEGER PRIMARY KEY,
message_id TEXT NOT NULL,
time REAL NOT NULL,
chat_id TEXT NOT NULL,
chat_info_platform TEXT,
user_id TEXT,
user_nickname TEXT,
chat_info_group_id TEXT,
chat_info_group_name TEXT,
is_mentioned INTEGER,
is_at INTEGER,
processed_plain_text TEXT,
display_message TEXT,
is_emoji INTEGER,
is_picid INTEGER,
is_command INTEGER,
is_notify INTEGER,
additional_config TEXT,
priority_mode TEXT
)
"""
)
)
connection.execute(
text(
"""
CREATE TABLE action_records (
id INTEGER PRIMARY KEY,
action_id TEXT NOT NULL,
time REAL NOT NULL,
action_reasoning TEXT,
action_name TEXT NOT NULL,
action_data TEXT,
action_prompt_display TEXT,
chat_id TEXT
)
"""
)
)
connection.execute(
text(
"""
CREATE TABLE expression (
id INTEGER PRIMARY KEY,
situation TEXT NOT NULL,
style TEXT NOT NULL,
content_list TEXT,
count INTEGER,
last_active_time REAL NOT NULL,
chat_id TEXT,
create_date REAL,
checked INTEGER,
rejected INTEGER,
modified_by TEXT
)
"""
)
)
connection.execute(
text(
"""
CREATE TABLE jargon (
id INTEGER PRIMARY KEY,
content TEXT NOT NULL,
raw_content TEXT,
meaning TEXT,
chat_id TEXT,
is_global INTEGER,
count INTEGER,
is_jargon INTEGER,
last_inference_count INTEGER,
is_complete INTEGER,
inference_with_context TEXT,
inference_content_only TEXT
)
"""
)
)
connection.execute(
text(
"""
INSERT INTO chat_streams (
id,
stream_id,
create_time,
last_active_time,
platform,
user_id,
group_id,
group_name
) VALUES (
1,
'session-1',
1710000000.0,
1710000300.0,
'qq',
'user-1',
'group-1',
'测试群'
)
"""
)
)
connection.execute(
text(
"""
INSERT INTO messages (
id,
message_id,
time,
chat_id,
chat_info_platform,
user_id,
user_nickname,
chat_info_group_id,
chat_info_group_name,
is_mentioned,
is_at,
processed_plain_text,
display_message,
is_emoji,
is_picid,
is_command,
is_notify,
additional_config,
priority_mode
) VALUES (
1,
'msg-1',
1710000010.0,
'session-1',
'qq',
'user-1',
'测试用户',
'group-1',
'测试群',
1,
0,
'你好',
'你好呀',
0,
1,
0,
1,
'{"source":"legacy"}',
'high'
)
"""
)
)
connection.execute(
text(
"""
INSERT INTO action_records (
id,
action_id,
time,
action_reasoning,
action_name,
action_data,
action_prompt_display,
chat_id
) VALUES (
1,
'action-1',
1710000020.0,
'需要调用工具',
'search',
'{"query":"MaiBot"}',
'执行搜索',
'session-1'
)
"""
)
)
connection.execute(
text(
"""
INSERT INTO expression (
id,
situation,
style,
content_list,
count,
last_active_time,
chat_id,
create_date,
checked,
rejected,
modified_by
) VALUES (
1,
'打招呼',
'可爱',
'["你好呀","早上好"]',
3,
1710000030.0,
'session-1',
1710000040.0,
1,
0,
'ai'
)
"""
)
)
connection.execute(
text(
"""
INSERT INTO jargon (
id,
content,
raw_content,
meaning,
chat_id,
is_global,
count,
is_jargon,
last_inference_count,
is_complete,
inference_with_context,
inference_content_only
) VALUES (
1,
'上分',
'["上分"]',
'提高排名',
'session-1',
0,
5,
1,
2,
1,
'{"guess":"context"}',
'{"guess":"content"}'
)
"""
)
)
def test_user_version_store_can_read_and_write_versions(tmp_path: Path) -> None:
"""应支持读取与写入 SQLite ``user_version``。"""
engine = _create_sqlite_engine(tmp_path / "version_store.db")
version_store = SQLiteUserVersionStore()
with engine.begin() as connection:
assert version_store.read_version(connection) == 0
version_store.write_version(connection, 7)
with engine.connect() as connection:
assert version_store.read_version(connection) == 7
def test_schema_inspector_can_extract_tables_and_columns(tmp_path: Path) -> None:
"""应能提取 SQLite 数据库的表与列结构。"""
engine = _create_sqlite_engine(tmp_path / "schema_inspector.db")
inspector = SQLiteSchemaInspector()
with engine.begin() as connection:
connection.execute(
text(
"""
CREATE TABLE legacy_records (
id INTEGER PRIMARY KEY,
payload TEXT NOT NULL,
created_at TEXT
)
"""
)
)
with engine.connect() as connection:
snapshot = inspector.inspect(connection)
assert snapshot.has_table("legacy_records")
assert snapshot.has_column("legacy_records", "payload")
assert not snapshot.has_column("legacy_records", "missing_column")
table_schema = snapshot.get_table("legacy_records")
assert table_schema is not None
assert table_schema.column_names() == ["created_at", "id", "payload"]
def test_resolver_can_identify_empty_database(tmp_path: Path) -> None:
"""空数据库应被解析为版本 ``0``。"""
engine = _create_sqlite_engine(tmp_path / "empty_resolver.db")
resolver = SchemaVersionResolver()
with engine.connect() as connection:
resolved_version = resolver.resolve(connection)
assert resolved_version.version == 0
assert resolved_version.source == SchemaVersionSource.EMPTY_DATABASE
assert resolved_version.snapshot is not None
assert resolved_version.snapshot.is_empty()
def test_resolver_can_use_detector_for_unversioned_legacy_database(tmp_path: Path) -> None:
"""未写入 ``user_version`` 的历史库应支持通过探测器识别版本。"""
engine = _create_sqlite_engine(tmp_path / "legacy_resolver.db")
resolver = SchemaVersionResolver(detectors=[FixedVersionDetector()])
with engine.begin() as connection:
connection.execute(text("CREATE TABLE legacy_records (id INTEGER PRIMARY KEY, payload TEXT NOT NULL)"))
with engine.connect() as connection:
resolved_version = resolver.resolve(connection)
assert resolved_version.version == 2
assert resolved_version.source == SchemaVersionSource.DETECTOR
assert resolved_version.detector_name == "fixed_version_detector"
def test_registry_and_manager_can_execute_registered_steps(tmp_path: Path) -> None:
"""迁移编排器应能按顺序执行已注册步骤并更新版本号。"""
engine = _create_sqlite_engine(tmp_path / "manager.db")
executed_steps: List[str] = []
def migrate_0_to_1(context: MigrationExecutionContext) -> None:
"""测试迁移步骤 0 -> 1。
Args:
context: 当前迁移步骤执行上下文。
"""
executed_steps.append(f"{context.current_version}->{context.target_version}:step_0_to_1")
context.connection.execute(text("CREATE TABLE sample_records (id INTEGER PRIMARY KEY, name TEXT NOT NULL)"))
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
"""测试迁移步骤 1 -> 2。
Args:
context: 当前迁移步骤执行上下文。
"""
executed_steps.append(f"{context.current_version}->{context.target_version}:step_1_to_2")
context.connection.execute(text("ALTER TABLE sample_records ADD COLUMN email TEXT"))
registry = MigrationRegistry(
steps=[
MigrationStep(
version_from=0,
version_to=1,
name="create_sample_records",
description="创建示例表。",
handler=migrate_0_to_1,
),
MigrationStep(
version_from=1,
version_to=2,
name="add_sample_email",
description="为示例表增加邮箱字段。",
handler=migrate_1_to_2,
),
]
)
manager = DatabaseMigrationManager(engine=engine, registry=registry)
migration_plan = manager.migrate()
assert migration_plan.step_count() == 2
assert executed_steps == ["0->2:step_0_to_1", "1->2:step_1_to_2"]
with engine.connect() as connection:
version_store = SQLiteUserVersionStore()
snapshot = SQLiteSchemaInspector().inspect(connection)
recorded_version = version_store.read_version(connection)
assert recorded_version == 2
assert snapshot.has_table("sample_records")
assert snapshot.has_column("sample_records", "email")
def test_manager_can_report_step_progress(tmp_path: Path) -> None:
"""迁移编排器应支持通过上下文上报步骤进度。"""
engine = _create_sqlite_engine(tmp_path / "manager_progress.db")
reporter_instances: List[FakeMigrationProgressReporter] = []
def _build_reporter() -> BaseMigrationProgressReporter:
"""构建测试用进度上报器。
Returns:
BaseMigrationProgressReporter: 测试用进度上报器实例。
"""
reporter = FakeMigrationProgressReporter()
reporter_instances.append(reporter)
return reporter
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
"""测试迁移步骤 ``1 -> 2`` 的进度上报。
Args:
context: 当前迁移步骤执行上下文。
"""
context.start_progress(total_tables=3, total_records=30, description="总迁移进度")
context.advance_progress(records=10, completed_tables=1, item_name="chat_sessions")
context.advance_progress(records=10, completed_tables=1, item_name="mai_messages")
context.advance_progress(records=10, completed_tables=1, item_name="tool_records")
context.connection.execute(text("CREATE TABLE progress_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)"))
with engine.begin() as connection:
SQLiteUserVersionStore().write_version(connection, 1)
registry = MigrationRegistry(
steps=[
MigrationStep(
version_from=1,
version_to=2,
name="progress_step",
description="测试进度上报。",
handler=migrate_1_to_2,
)
]
)
manager = DatabaseMigrationManager(
engine=engine,
registry=registry,
progress_reporter_factory=_build_reporter,
)
migration_plan = manager.migrate()
assert migration_plan.step_count() == 1
assert len(reporter_instances) == 1
assert reporter_instances[0].events == [
("open", None, None, None),
("start", 30, 3, "总迁移进度"),
("advance", 10, 1, "chat_sessions"),
("advance", 10, 1, "mai_messages"),
("advance", 10, 1, "tool_records"),
("close", None, None, None),
]
def test_default_resolver_can_identify_unversioned_latest_database(tmp_path: Path) -> None:
"""默认解析器应能识别未写入版本号的最新结构数据库。"""
engine = _create_sqlite_engine(tmp_path / "latest_resolver.db")
resolver = build_default_schema_version_resolver()
with engine.begin() as connection:
_create_current_schema(connection)
with engine.connect() as connection:
resolved_version = resolver.resolve(connection)
assert resolved_version.version == LATEST_SCHEMA_VERSION
assert resolved_version.source == SchemaVersionSource.DETECTOR
assert resolved_version.detector_name == "latest_schema_detector"
def test_default_resolver_can_identify_legacy_v1_database(tmp_path: Path) -> None:
"""默认解析器应能识别未写版本号的旧版 ``0.x`` 数据库。"""
engine = _create_sqlite_engine(tmp_path / "legacy_v1_resolver.db")
resolver = build_default_schema_version_resolver()
with engine.begin() as connection:
_create_legacy_v1_schema_with_sample_data(connection)
with engine.connect() as connection:
resolved_version = resolver.resolve(connection)
assert resolved_version.version == LEGACY_V1_SCHEMA_VERSION
assert resolved_version.source == SchemaVersionSource.DETECTOR
assert resolved_version.detector_name == "legacy_v1_schema_detector"
def test_bootstrapper_can_finalize_unversioned_latest_database(tmp_path: Path) -> None:
"""已是最新结构但未写版本号的数据库应直接补写 ``user_version``。"""
engine = _create_sqlite_engine(tmp_path / "latest_finalize.db")
bootstrapper = create_database_migration_bootstrapper(engine)
with engine.begin() as connection:
_create_current_schema(connection)
migration_state = bootstrapper.prepare_database()
bootstrapper.finalize_database(migration_state)
assert not migration_state.requires_migration()
assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
assert migration_state.resolved_version.source == SchemaVersionSource.DETECTOR
with engine.connect() as connection:
recorded_version = SQLiteUserVersionStore().read_version(connection)
assert recorded_version == LATEST_SCHEMA_VERSION
def test_bootstrapper_can_finalize_empty_database_to_latest_version(tmp_path: Path) -> None:
"""空库在建表完成后应回写最新 ``user_version``。"""
engine = _create_sqlite_engine(tmp_path / "bootstrap_empty.db")
bootstrapper = create_database_migration_bootstrapper(engine)
migration_state = bootstrapper.prepare_database()
assert not migration_state.requires_migration()
assert migration_state.resolved_version.version == EMPTY_SCHEMA_VERSION
assert migration_state.target_version == LATEST_SCHEMA_VERSION
with engine.begin() as connection:
_create_current_schema(connection)
bootstrapper.finalize_database(migration_state)
with engine.connect() as connection:
recorded_version = SQLiteUserVersionStore().read_version(connection)
assert recorded_version == LATEST_SCHEMA_VERSION
def test_bootstrapper_runs_registered_steps_for_versioned_database(tmp_path: Path) -> None:
"""启动桥接器应在已登记旧版本数据库上执行注册迁移步骤。"""
engine = _create_sqlite_engine(tmp_path / "bootstrap_registered.db")
execution_marks: List[str] = []
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
"""测试桥接器迁移步骤 ``1 -> 2``。
Args:
context: 当前迁移步骤执行上下文。
"""
execution_marks.append(f"step={context.step_name},index={context.step_index}")
context.connection.execute(text("ALTER TABLE bootstrap_records ADD COLUMN email TEXT"))
with engine.begin() as connection:
connection.execute(
text("CREATE TABLE bootstrap_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)")
)
SQLiteUserVersionStore().write_version(connection, 1)
registry = MigrationRegistry(
steps=[
MigrationStep(
version_from=1,
version_to=2,
name="bootstrap_add_email",
description="为桥接器测试表增加邮箱字段。",
handler=migrate_1_to_2,
)
]
)
bootstrapper = DatabaseMigrationBootstrapper(
manager=DatabaseMigrationManager(engine=engine, registry=registry),
latest_schema_version=2,
)
migration_state = bootstrapper.prepare_database()
assert migration_state.resolved_version.version == 2
assert migration_state.target_version == 2
assert execution_marks == ["step=bootstrap_add_email,index=1"]
with engine.connect() as connection:
snapshot = SQLiteSchemaInspector().inspect(connection)
recorded_version = SQLiteUserVersionStore().read_version(connection)
assert recorded_version == 2
assert snapshot.has_table("bootstrap_records")
assert snapshot.has_column("bootstrap_records", "email")
def test_default_bootstrapper_can_migrate_legacy_v1_database(tmp_path: Path) -> None:
"""默认桥接器应能把旧版 ``0.x`` 数据库整体迁移到最新结构。"""
engine = _create_sqlite_engine(tmp_path / "legacy_v1_to_v2.db")
bootstrapper = create_database_migration_bootstrapper(engine)
with engine.begin() as connection:
_create_legacy_v1_schema_with_sample_data(connection)
migration_state = bootstrapper.prepare_database()
bootstrapper.finalize_database(migration_state)
assert not migration_state.requires_migration()
assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
assert migration_state.resolved_version.source == SchemaVersionSource.PRAGMA
with engine.connect() as connection:
recorded_version = SQLiteUserVersionStore().read_version(connection)
snapshot = SQLiteSchemaInspector().inspect(connection)
message_row = connection.execute(
text(
"""
SELECT session_id, processed_plain_text, additional_config, raw_content
FROM mai_messages
WHERE message_id = 'msg-1'
"""
)
).mappings().one()
tool_row = connection.execute(
text(
"""
SELECT session_id, tool_name, tool_display_prompt
FROM tool_records
WHERE tool_id = 'action-1'
"""
)
).mappings().one()
expression_row = connection.execute(
text(
"""
SELECT session_id, content_list, modified_by
FROM expressions
WHERE id = 1
"""
)
).mappings().one()
jargon_row = connection.execute(
text(
"""
SELECT session_id_dict, raw_content, inference_with_content_only
FROM jargons
WHERE id = 1
"""
)
).mappings().one()
assert recorded_version == LATEST_SCHEMA_VERSION
assert snapshot.has_table("__legacy_v1_messages")
assert snapshot.has_table("chat_sessions")
assert snapshot.has_table("mai_messages")
assert snapshot.has_table("tool_records")
assert not snapshot.has_table("action_records")
assert not snapshot.has_column("mai_messages", "display_message")
unpacked_raw_content = msgpack.unpackb(message_row["raw_content"], raw=False)
additional_config = json.loads(message_row["additional_config"])
expression_content_list = json.loads(expression_row["content_list"])
jargon_session_id_dict = json.loads(jargon_row["session_id_dict"])
jargon_raw_content = json.loads(jargon_row["raw_content"])
assert message_row["session_id"] == "session-1"
assert message_row["processed_plain_text"] == "你好"
assert unpacked_raw_content == [{"type": "text", "data": "你好呀"}]
assert additional_config == {"priority_mode": "high", "source": "legacy"}
assert tool_row["session_id"] == "session-1"
assert tool_row["tool_name"] == "search"
assert tool_row["tool_display_prompt"] == "执行搜索"
assert expression_row["session_id"] == "session-1"
assert expression_row["modified_by"] == "AI"
assert expression_content_list == ["你好呀", "早上好"]
assert jargon_session_id_dict == {"session-1": 5}
assert jargon_raw_content == ["上分"]
assert jargon_row["inference_with_content_only"] == '{"guess":"content"}'
def test_legacy_v1_migration_reports_table_progress(tmp_path: Path) -> None:
"""旧版迁移步骤应按目标表数量推进总进度。"""
engine = _create_sqlite_engine(tmp_path / "legacy_progress.db")
reporter_instances: List[FakeMigrationProgressReporter] = []
def _build_reporter() -> BaseMigrationProgressReporter:
"""构建测试用进度上报器。
Returns:
BaseMigrationProgressReporter: 测试用进度上报器实例。
"""
reporter = FakeMigrationProgressReporter()
reporter_instances.append(reporter)
return reporter
with engine.begin() as connection:
_create_legacy_v1_schema_with_sample_data(connection)
manager = DatabaseMigrationManager(
engine=engine,
registry=build_default_migration_registry(),
resolver=build_default_schema_version_resolver(),
progress_reporter_factory=_build_reporter,
)
migration_plan = manager.migrate(target_version=LATEST_SCHEMA_VERSION)
assert migration_plan.step_count() == 3
assert len(reporter_instances) == 3
reporter_events = reporter_instances[0].events
assert reporter_events[0] == ("open", None, None, None)
assert reporter_events[1] == ("start", 6, 12, "总迁移进度")
assert reporter_events[-1] == ("close", None, None, None)
assert reporter_events.count(("advance", 1, 0, None)) == 6
assert reporter_events.count(("advance", 0, 1, "chat_sessions")) == 1
assert reporter_events.count(("advance", 0, 1, "thinking_questions")) == 1
assert len([event for event in reporter_events if event[0] == "advance"]) == 18
def test_initialize_database_calls_bootstrapper_before_create_all(
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> None:
"""数据库初始化入口应先准备迁移,再建表、补迁移并收尾。"""
call_order: List[str] = []
def _fake_prepare_database() -> DatabaseMigrationState:
"""返回测试用迁移状态。
Returns:
DatabaseMigrationState: 不包含迁移步骤的测试状态。
"""
call_order.append("prepare_database")
return DatabaseMigrationState(
resolved_version=ResolvedSchemaVersion(version=0, source=SchemaVersionSource.EMPTY_DATABASE),
target_version=LATEST_SCHEMA_VERSION,
plan=MigrationPlan(
current_version=EMPTY_SCHEMA_VERSION,
target_version=LATEST_SCHEMA_VERSION,
steps=[],
),
)
def _fake_create_all(bind) -> None:
"""记录建表调用。
Args:
bind: 传入的数据库绑定对象。
"""
del bind
call_order.append("create_all")
def _fake_finalize_database(migration_state: DatabaseMigrationState) -> None:
"""记录迁移收尾调用。
Args:
migration_state: 当前数据库迁移状态。
"""
del migration_state
call_order.append("finalize_database")
monkeypatch.setattr(database_module, "_db_initialized", False)
monkeypatch.setattr(database_module, "_DB_DIR", tmp_path / "data")
monkeypatch.setattr(database_module._migration_bootstrapper, "prepare_database", _fake_prepare_database)
monkeypatch.setattr(database_module._migration_bootstrapper, "finalize_database", _fake_finalize_database)
monkeypatch.setattr(database_module.SQLModel.metadata, "create_all", _fake_create_all)
database_module.initialize_database()
assert call_order == [
"prepare_database",
"create_all",
"finalize_database",
]

View File

@@ -1,81 +0,0 @@
"""测试表达方式学习器的数据库读取行为。"""
from contextlib import contextmanager
from typing import Generator
import pytest
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.bw_learner.expression_learner import ExpressionLearner
from src.common.database.database_model import Expression
@pytest.fixture(name="expression_learner_engine")
def expression_learner_engine_fixture() -> Generator:
"""创建用于表达方式学习器测试的内存数据库引擎。
Yields:
Generator: 供测试使用的 SQLite 内存引擎。
"""
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
yield engine
def test_find_similar_expression_uses_read_only_session_and_history_content(
monkeypatch: pytest.MonkeyPatch,
expression_learner_engine,
) -> None:
"""查找相似表达方式时,应能在离开会话后安全使用结果,并比较历史情景内容。"""
import src.bw_learner.expression_learner as expression_learner_module
with Session(expression_learner_engine) as session:
session.add(
Expression(
situation="发送汗滴表情",
style="发送💦表情符号",
content_list='["表达情绪高涨或生理反应"]',
count=1,
session_id="session-a",
checked=False,
rejected=False,
)
)
session.commit()
@contextmanager
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
"""构造带自动提交语义的测试会话工厂。
Args:
auto_commit: 退出上下文时是否自动提交。
Yields:
Generator[Session, None, None]: SQLModel 会话对象。
"""
session = Session(expression_learner_engine)
try:
yield session
if auto_commit:
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
monkeypatch.setattr(expression_learner_module, "get_db_session", fake_get_db_session)
learner = ExpressionLearner(session_id="session-a")
result = learner._find_similar_expression("表达情绪高涨或生理反应")
assert result is not None
expression, similarity = result
assert expression.item_id is not None
assert expression.style == "发送💦表情符号"
assert similarity == pytest.approx(1.0)

View File

@@ -1,78 +0,0 @@
"""测试表达方式表结构和基础插入行为。"""
from typing import Generator
import pytest
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.common.database.database_model import Expression
@pytest.fixture(name="expression_engine")
def expression_engine_fixture() -> Generator:
"""创建仅用于表达方式表测试的内存数据库引擎。
Yields:
Generator: 供测试使用的 SQLite 内存引擎。
"""
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
yield engine
def test_expression_insert_assigns_auto_increment_id(expression_engine) -> None:
"""表达方式表在新库中应能自动分配自增主键。"""
with Session(expression_engine) as session:
expression = Expression(
situation="表达情绪高涨或生理反应",
style="发送💦表情符号",
content_list='["表达情绪高涨或生理反应"]',
count=1,
session_id="session-a",
checked=False,
rejected=False,
)
session.add(expression)
session.commit()
session.refresh(expression)
assert expression.id is not None
assert expression.id > 0
def test_expression_insert_allows_same_situation_style(expression_engine) -> None:
"""相同情景和风格的表达方式记录不应再被错误绑定到复合主键。"""
with Session(expression_engine) as session:
first_expression = Expression(
situation="对重复行为的默契响应",
style="持续性跟发相同内容",
content_list='["对重复行为的默契响应"]',
count=1,
session_id="session-a",
checked=False,
rejected=False,
)
second_expression = Expression(
situation="对重复行为的默契响应",
style="持续性跟发相同内容",
content_list='["对重复行为的默契响应-变体"]',
count=2,
session_id="session-b",
checked=False,
rejected=False,
)
session.add(first_expression)
session.add(second_expression)
session.commit()
session.refresh(first_expression)
session.refresh(second_expression)
assert first_expression.id is not None
assert second_expression.id is not None
assert first_expression.id != second_expression.id

View File

@@ -1,90 +0,0 @@
"""测试黑话学习器的数据库读取行为。"""
from contextlib import contextmanager
from typing import Generator
import pytest
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine, select
from src.bw_learner.jargon_miner import JargonMiner
from src.common.database.database_model import Jargon
@pytest.fixture(name="jargon_miner_engine")
def jargon_miner_engine_fixture() -> Generator:
"""创建用于黑话学习器测试的内存数据库引擎。
Yields:
Generator: 供测试使用的 SQLite 内存引擎。
"""
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
yield engine
@pytest.mark.asyncio
async def test_process_extracted_entries_updates_existing_jargon_without_detached_session(
monkeypatch: pytest.MonkeyPatch,
jargon_miner_engine,
) -> None:
"""更新已有黑话时,不应因会话关闭导致 ORM 实例失效。"""
import src.bw_learner.jargon_miner as jargon_miner_module
with Session(jargon_miner_engine) as session:
session.add(
Jargon(
content="VF8V4L",
raw_content='["[1] first"]',
meaning="",
session_id_dict='{"session-a": 1}',
count=0,
is_jargon=True,
is_complete=False,
is_global=False,
last_inference_count=0,
)
)
session.commit()
@contextmanager
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
"""构造带自动提交语义的测试会话工厂。
Args:
auto_commit: 退出上下文时是否自动提交。
Yields:
Generator[Session, None, None]: SQLModel 会话对象。
"""
session = Session(jargon_miner_engine)
try:
yield session
if auto_commit:
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
monkeypatch.setattr(jargon_miner_module, "get_db_session", fake_get_db_session)
jargon_miner = JargonMiner(session_id="session-a", session_name="测试群")
await jargon_miner.process_extracted_entries(
[{"content": "VF8V4L", "raw_content": {"[2] second"}}],
)
with Session(jargon_miner_engine) as session:
db_jargon = session.exec(select(Jargon).where(Jargon.content == "VF8V4L")).one()
assert db_jargon.count == 1
assert db_jargon.session_id_dict == '{"session-a": 2}'
assert sorted(db_jargon.raw_content and __import__("json").loads(db_jargon.raw_content)) == [
"[1] first",
"[2] second",
]

View File

@@ -1,84 +0,0 @@
"""测试黑话表结构和基础插入行为。"""
from typing import Generator
import pytest
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.common.database.database_model import Jargon
@pytest.fixture(name="jargon_engine")
def jargon_engine_fixture() -> Generator:
"""创建仅用于黑话表测试的内存数据库引擎。
Yields:
Generator: 供测试使用的 SQLite 内存引擎。
"""
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
yield engine
def test_jargon_insert_assigns_auto_increment_id(jargon_engine) -> None:
"""黑话表在新库中应能自动分配自增主键。"""
with Session(jargon_engine) as session:
jargon = Jargon(
content="VF8V4L",
raw_content='["[1] test"]',
meaning="",
session_id_dict='{"session-a": 1}',
count=1,
is_jargon=True,
is_complete=False,
is_global=True,
last_inference_count=0,
)
session.add(jargon)
session.commit()
session.refresh(jargon)
assert jargon.id is not None
assert jargon.id > 0
def test_jargon_insert_allows_same_content_with_different_rows(jargon_engine) -> None:
"""黑话内容不应再被错误地绑成复合主键的一部分。"""
with Session(jargon_engine) as session:
first_jargon = Jargon(
content="表情1",
raw_content='["[1] first"]',
meaning="",
session_id_dict='{"session-a": 1}',
count=1,
is_jargon=True,
is_complete=False,
is_global=False,
last_inference_count=0,
)
second_jargon = Jargon(
content="表情1",
raw_content='["[1] second"]',
meaning="",
session_id_dict='{"session-b": 1}',
count=1,
is_jargon=True,
is_complete=False,
is_global=False,
last_inference_count=0,
)
session.add(first_jargon)
session.add(second_jargon)
session.commit()
session.refresh(first_jargon)
session.refresh(second_jargon)
assert first_jargon.id is not None
assert second_jargon.id is not None
assert first_jargon.id != second_jargon.id

View File

@@ -1,135 +0,0 @@
from types import SimpleNamespace
import pytest
import src.chat.replyer.maisaka_expression_selector as selector_module
from src.chat.replyer.maisaka_expression_selector import MaisakaExpressionSelector
from src.common.utils.utils_session import SessionUtils
def _build_target(platform: str, item_id: str, rule_type: str = "group") -> SimpleNamespace:
return SimpleNamespace(platform=platform, item_id=item_id, rule_type=rule_type)
def test_resolve_expression_group_scope_returns_related_sessions(monkeypatch: pytest.MonkeyPatch) -> None:
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
related_session_id = SessionUtils.calculate_session_id("qq", group_id="10002")
monkeypatch.setattr(
selector_module,
"global_config",
SimpleNamespace(
expression=SimpleNamespace(
expression_groups=[
SimpleNamespace(
expression_groups=[
_build_target("qq", "10001"),
_build_target("qq", "10002"),
]
)
]
)
),
)
selector = MaisakaExpressionSelector()
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
assert related_session_ids == {current_session_id, related_session_id}
assert has_global_share is False
def test_resolve_expression_group_scope_matches_routed_sessions(monkeypatch: pytest.MonkeyPatch) -> None:
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001", account_id="bot-a")
related_session_id = SessionUtils.calculate_session_id("qq", group_id="10002", account_id="bot-a")
monkeypatch.setattr(
selector_module,
"global_config",
SimpleNamespace(
expression=SimpleNamespace(
expression_groups=[
SimpleNamespace(
expression_groups=[
_build_target("qq", "10001"),
_build_target("qq", "10002"),
]
)
]
)
),
)
monkeypatch.setattr(
selector_module.ChatConfigUtils,
"_get_chat_stream",
lambda session_id: SimpleNamespace(platform="qq", group_id="10001", user_id=None)
if session_id == current_session_id
else None,
)
target_session_ids = {
"10001": current_session_id,
"10002": related_session_id,
}
monkeypatch.setattr(
selector_module.ChatConfigUtils,
"get_target_session_ids",
lambda target_item: {target_session_ids[target_item.item_id]},
)
selector = MaisakaExpressionSelector()
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
assert related_session_ids == {current_session_id, related_session_id}
assert has_global_share is False
def test_resolve_expression_group_scope_uses_star_as_global_share(monkeypatch: pytest.MonkeyPatch) -> None:
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
monkeypatch.setattr(
selector_module,
"global_config",
SimpleNamespace(
expression=SimpleNamespace(
expression_groups=[
SimpleNamespace(
expression_groups=[
_build_target("*", "*"),
]
)
]
)
),
)
selector = MaisakaExpressionSelector()
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
assert related_session_ids == {current_session_id}
assert has_global_share is True
def test_resolve_expression_group_scope_does_not_treat_empty_target_as_global(monkeypatch: pytest.MonkeyPatch) -> None:
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
monkeypatch.setattr(
selector_module,
"global_config",
SimpleNamespace(
expression=SimpleNamespace(
expression_groups=[
SimpleNamespace(
expression_groups=[
_build_target("", ""),
]
)
]
)
),
)
selector = MaisakaExpressionSelector()
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
assert related_session_ids == {current_session_id}
assert has_global_share is False

View File

@@ -1,355 +0,0 @@
"""人物信息群名片字段兼容测试。"""
from __future__ import annotations
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any
import json
import sys
import pytest
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
class _DummyLogger:
"""模拟日志记录器。"""
def debug(self, message: str) -> None:
"""记录调试日志。
Args:
message: 日志内容。
"""
del message
def info(self, message: str) -> None:
"""记录信息日志。
Args:
message: 日志内容。
"""
del message
def warning(self, message: str) -> None:
"""记录警告日志。
Args:
message: 日志内容。
"""
del message
def error(self, message: str) -> None:
"""记录错误日志。
Args:
message: 日志内容。
"""
del message
class _DummyStatement:
"""模拟 SQL 查询语句对象。"""
def where(self, condition: Any) -> "_DummyStatement":
"""附加过滤条件。
Args:
condition: 过滤条件。
Returns:
_DummyStatement: 当前语句对象。
"""
del condition
return self
def limit(self, value: int) -> "_DummyStatement":
"""限制返回条数。
Args:
value: 条数限制。
Returns:
_DummyStatement: 当前语句对象。
"""
del value
return self
class _DummyColumn:
"""模拟 SQLModel 列对象。"""
def is_not(self, value: Any) -> "_DummyColumn":
"""模拟 `IS NOT` 条件构造。
Args:
value: 比较值。
Returns:
_DummyColumn: 当前列对象。
"""
del value
return self
def __eq__(self, other: Any) -> "_DummyColumn":
"""模拟等值条件构造。
Args:
other: 比较值。
Returns:
_DummyColumn: 当前列对象。
"""
del other
return self
class _DummyResult:
"""模拟数据库查询结果。"""
def __init__(self, record: Any) -> None:
"""初始化查询结果。
Args:
record: 待返回的首条记录。
"""
self._record = record
def first(self) -> Any:
"""返回第一条记录。
Returns:
Any: 首条记录。
"""
return self._record
def all(self) -> list[Any]:
"""返回全部结果。
Returns:
list[Any]: 结果列表。
"""
if self._record is None:
return []
return self._record if isinstance(self._record, list) else [self._record]
class _DummySession:
"""模拟数据库 Session。"""
def __init__(self, record: Any) -> None:
"""初始化 Session。
Args:
record: `first()` 应返回的记录。
"""
self.record = record
self.added_records: list[Any] = []
def __enter__(self) -> "_DummySession":
"""进入上下文管理器。
Returns:
_DummySession: 当前 Session。
"""
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""退出上下文管理器。
Args:
exc_type: 异常类型。
exc_val: 异常值。
exc_tb: 异常回溯。
"""
del exc_type
del exc_val
del exc_tb
def exec(self, statement: Any) -> _DummyResult:
"""执行查询。
Args:
statement: 查询语句。
Returns:
_DummyResult: 模拟结果对象。
"""
del statement
return _DummyResult(self.record)
def add(self, record: Any) -> None:
"""记录被添加的对象。
Args:
record: 被写入 Session 的对象。
"""
self.added_records.append(record)
class _DummyPersonInfoRecord:
"""模拟 `PersonInfo` ORM 模型。"""
person_id = "person_id"
person_name = "person_name"
def __init__(self, **kwargs: Any) -> None:
"""使用关键字参数初始化记录对象。
Args:
**kwargs: 字段值。
"""
for key, value in kwargs.items():
setattr(self, key, value)
def _load_person_module(monkeypatch: pytest.MonkeyPatch, session: _DummySession) -> ModuleType:
"""加载带依赖桩的 `person_info` 模块。
Args:
monkeypatch: Pytest monkeypatch 工具。
session: 提供给模块使用的假数据库 Session。
Returns:
ModuleType: 加载后的模块对象。
"""
logger_module = ModuleType("src.common.logger")
logger_module.get_logger = lambda name: _DummyLogger()
monkeypatch.setitem(sys.modules, "src.common.logger", logger_module)
database_module = ModuleType("src.common.database.database")
database_module.get_db_session = lambda: session
monkeypatch.setitem(sys.modules, "src.common.database.database", database_module)
database_model_module = ModuleType("src.common.database.database_model")
database_model_module.PersonInfo = _DummyPersonInfoRecord
monkeypatch.setitem(sys.modules, "src.common.database.database_model", database_model_module)
llm_module = ModuleType("src.llm_models.utils_model")
class _DummyLLMRequest:
"""模拟 LLMRequest。"""
def __init__(self, model_set: Any, request_type: str) -> None:
"""初始化假请求对象。
Args:
model_set: 模型配置。
request_type: 请求类型。
"""
del model_set
del request_type
llm_module.LLMRequest = _DummyLLMRequest
monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_module)
config_module = ModuleType("src.config.config")
config_module.global_config = SimpleNamespace(bot=SimpleNamespace(nickname="MaiBot"))
config_module.model_config = SimpleNamespace(model_task_config=SimpleNamespace(tool_use="tool_use", utils="utils"))
monkeypatch.setitem(sys.modules, "src.config.config", config_module)
chat_manager_module = ModuleType("src.chat.message_receive.chat_manager")
chat_manager_module.chat_manager = SimpleNamespace()
monkeypatch.setitem(sys.modules, "src.chat.message_receive.chat_manager", chat_manager_module)
module_path = Path(__file__).resolve().parents[2] / "src" / "person_info" / "person_info.py"
spec = spec_from_file_location("person_info_group_cardname_test_module", module_path)
assert spec is not None and spec.loader is not None
module = module_from_spec(spec)
monkeypatch.setitem(sys.modules, spec.name, module)
spec.loader.exec_module(module)
monkeypatch.setattr(module, "select", lambda *args: _DummyStatement())
monkeypatch.setattr(module, "col", lambda field: _DummyColumn())
return module
def test_parse_group_cardname_json_uses_canonical_key() -> None:
"""群名片 JSON 解析应只使用 `group_cardname` 键名。"""
parsed = parse_group_cardname_json(
json.dumps(
[
{"group_id": "1001", "group_cardname": "现行字段"},
],
ensure_ascii=False,
)
)
assert parsed is not None
assert [(item.group_id, item.group_cardname) for item in parsed] == [
("1001", "现行字段"),
]
def test_dump_group_cardname_records_uses_canonical_key() -> None:
"""群名片序列化应输出 `group_cardname` 键名。"""
dumped = dump_group_cardname_records(
[
{"group_id": "1001", "group_cardname": "群昵称"},
]
)
assert json.loads(dumped) == [{"group_id": "1001", "group_cardname": "群昵称"}]
def test_person_sync_to_database_uses_group_cardname_field(monkeypatch: pytest.MonkeyPatch) -> None:
"""同步人物信息时应写入数据库模型的 `group_cardname` 字段。"""
record = _DummyPersonInfoRecord()
session = _DummySession(record)
module = _load_person_module(monkeypatch, session)
person = module.Person.__new__(module.Person)
person.is_known = True
person.person_id = "person-1"
person.platform = "qq"
person.user_id = "10001"
person.nickname = "看番的龙"
person.person_name = "看番的龙"
person.name_reason = "测试"
person.know_times = 1
person.know_since = 1700000000.0
person.last_know = 1700000100.0
person.memory_points = ["喜好:番剧:0.8"]
person.group_cardname_list = [{"group_id": "20001", "group_cardname": "白泽大人"}]
person.sync_to_database()
assert record.group_cardname == '[{"group_id": "20001", "group_cardname": "白泽大人"}]'
assert not hasattr(record, "group_nickname")
def test_person_load_from_database_normalizes_group_cardname_payload(monkeypatch: pytest.MonkeyPatch) -> None:
"""从数据库加载人物信息时应读取标准 `group_cardname` 结构。"""
record = _DummyPersonInfoRecord(
user_id="10001",
platform="qq",
is_known=True,
user_nickname="看番的龙",
person_name="看番的龙",
name_reason=None,
know_counts=2,
memory_points='["喜好:番剧:0.8"]',
group_cardname=json.dumps(
[
{"group_id": "20001", "group_cardname": "白泽大人"},
],
ensure_ascii=False,
),
)
session = _DummySession(record)
module = _load_person_module(monkeypatch, session)
person = module.Person.__new__(module.Person)
person.person_id = "person-1"
person.memory_points = []
person.group_cardname_list = []
person.load_from_database()
assert person.group_cardname_list == [
{"group_id": "20001", "group_cardname": "白泽大人"},
]

View File

@@ -1,533 +0,0 @@
import logging
import sys
from importlib import util
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import pytest
from pydantic import BaseModel, Field
# -------------------------------------------------------------
# 测试环境准备:补全 logger 和 AttrDocBase 依赖
# -------------------------------------------------------------
TEST_ROOT = Path(__file__).parent.parent.absolute().resolve()
logger_file = TEST_ROOT / "logger.py"
spec = util.spec_from_file_location("src.common.logger", logger_file)
module = util.module_from_spec(spec) # type: ignore
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module) # type: ignore
sys.modules["src.common.logger"] = module
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "src" / "config"))
from src.config.config_base import ConfigBase # noqa: E402
import src.config.config_base as config_base_module # noqa: E402
class AttrDocBase:
"""用于测试的轻量级 AttrDocBase 替身"""
def __post_init__(self) -> None:
# 被 ConfigBase.model_post_init 调用
self.__post_init_called__ = True
# 打补丁,让 ConfigBase 使用测试替身
@pytest.fixture(autouse=True)
def patch_attrdoc_post_init():
orig = config_base_module.AttrDocBase.__post_init__
config_base_module.AttrDocBase.__post_init__ = AttrDocBase.__post_init__ # type: ignore
yield
config_base_module.AttrDocBase.__post_init__ = orig
config_base_module.logger = logging.getLogger("config_base_test_logger")
class SimpleClass(ConfigBase):
a: int = 1
b: str = "test"
class TestConfigBase:
# ---------------------------------------------------------
# happy path整体 model_post_init 测试
# ---------------------------------------------------------
@pytest.mark.parametrize(
"model_cls, init_kwargs, expected_fields",
[
pytest.param(
# 简单原子类型字段
type(
"SimpleAtomic",
(ConfigBase,),
{
"__annotations__": {
"a": int,
"b": str,
"c": bool,
"d": float,
},
"a": Field(default=1),
"b": Field(default="x"),
"c": Field(default=True),
"d": Field(default=1.5),
},
),
{},
{"a", "b", "c", "d"},
id="happy-simple-atomic-fields",
),
pytest.param(
# list/set/dict 泛型 + 原子内部类型
type(
"AtomicContainers",
(ConfigBase,),
{
"__annotations__": {
"ints": List[int],
"names": Set[str],
"mapping": Dict[str, int],
},
"ints": Field(default_factory=lambda: [1, 2]),
"names": Field(default_factory=lambda: {"a", "b"}),
"mapping": Field(default_factory=lambda: {"x": 1}),
},
),
{},
{"ints", "names", "mapping"},
id="happy-atomic-containers",
),
pytest.param(
# Optional 原子和 Optional 容器
type(
"OptionalFields",
(ConfigBase,),
{
"__annotations__": {
"maybe_int": Optional[int],
"maybe_str_list": Optional[List[str]],
},
"maybe_int": Field(default=None),
"maybe_str_list": Field(default=None),
},
),
{},
{"maybe_int", "maybe_str_list"},
id="happy-optional-fields",
),
],
)
def test_model_post_init_happy_paths(self, model_cls, init_kwargs, expected_fields):
# Act
instance = model_cls(**init_kwargs)
# Assert
for field_name in expected_fields:
assert field_name in type(instance).model_fields
_ = getattr(instance, field_name)
assert getattr(instance, "__post_init_called__", False) is True
# ---------------------------------------------------------
# _get_real_type
# ---------------------------------------------------------
def test_get_real_type_non_generic_and_generic(self):
class Sample(ConfigBase):
x: int = 1
y: List[int] = Field(default_factory=list)
instance = Sample()
# Act
origin_x, args_x = instance._get_real_type(int)
# Assert
assert origin_x is int
assert args_x == ()
# Act
origin_y, args_y = instance._get_real_type(List[int])
# Assert
assert origin_y in (list, List)
assert args_y == (int,)
# ---------------------------------------------------------
# _validate_union_type
# ---------------------------------------------------------
@pytest.mark.parametrize(
"annotation, expect_error, error_fragment, expected_origin_type",
[
pytest.param(
int,
False,
None,
int,
id="union-validation-atomic-non-union",
),
pytest.param(
Optional[int],
False,
None,
int,
id="union-validation-optional-atomic",
),
pytest.param(
Optional[List[int]],
False,
None,
list,
id="union-validation-optional-container",
),
pytest.param(
Union[int, str],
True,
"不允许使用 Union 类型注解",
None,
id="union-validation-disallow-non-optional-union",
),
pytest.param(
int | str,
True,
"不允许使用 Union 类型注解",
None,
id="union-validation-pep604-disallow-non-optional-union",
),
pytest.param(
Union[int, None, str],
True,
"不允许使用 Union 类型注解",
None,
id="union-validation-disallow-union-more-than-two",
),
pytest.param(
Optional[Union[int, str]],
True,
"不允许使用 Union 类型注解",
None,
id="union-validation-disallow-nested-optional-union",
),
],
)
def test_validate_union_type(self, annotation, expect_error, error_fragment, expected_origin_type):
# 这里我们不实例化 Sample以避免在 __init__/model_post_init 阶段触发验证。
# 直接通过一个“哑实例”调用受测方法,仅测试类型注解逻辑。
class Dummy(ConfigBase):
pass
dummy = Dummy() # 最小初始化,避免字段校验
field_name = "v"
if expect_error:
# Act / Assert
with pytest.raises(TypeError) as exc_info:
dummy._validate_union_type(annotation, field_name)
assert error_fragment in str(exc_info.value)
else:
# Act
origin, args, other = dummy._validate_union_type(annotation, field_name)
# Assert
assert origin is expected_origin_type
assert other is not None
# ---------------------------------------------------------
# _validate_list_set_type
# ---------------------------------------------------------
@pytest.mark.parametrize(
"annotation, expect_error, error_fragment",
[
pytest.param(
List[int],
False,
None,
id="listset-validation-list-happy",
),
pytest.param(
Set[str],
False,
None,
id="listset-validation-set-happy",
),
pytest.param(
list,
True,
"必须指定且仅指定一个类型参数",
id="listset-validation-missing-type-arg",
),
pytest.param(
List[int | None],
True,
"不允许嵌套泛型类型",
id="listset-validation-nested-generic-inner-union",
),
pytest.param(
List[List[int]],
True,
"不允许嵌套泛型类型",
id="listset-validation-nested-generic-inner-list",
),
pytest.param(
List[SimpleClass],
False,
None,
id="listset-validation-list-configbase-element_allow",
),
pytest.param(
Set[SimpleClass],
True,
"ConfigBase is not Hashable",
id="listset-validation-set-configbase-element_reject",
),
],
)
def test_validate_list_set_type(self, annotation, expect_error, error_fragment):
# 不实例化带有这些字段的模型,避免在 __init__/model_post_init 阶段就失败,
# 只测试 _validate_list_set_type 本身的逻辑。
class Dummy(ConfigBase):
pass
dummy = Dummy()
field_name = "items"
if expect_error:
# Act / Assert
with pytest.raises(TypeError) as exc_info:
dummy._validate_list_set_type(annotation, field_name)
assert error_fragment in str(exc_info.value)
else:
# Act
dummy._validate_list_set_type(annotation, field_name)
# ---------------------------------------------------------
# _validate_dict_type
# ---------------------------------------------------------
@pytest.mark.parametrize(
"annotation, expect_error, error_fragment",
[
pytest.param(
Dict[str, int],
False,
None,
id="dict-validation-happy-atomic",
),
pytest.param(
Dict[str, Any],
True,
"不允许使用 Any 类型注解",
id="dict-validation-any-value-disallowed",
),
pytest.param(
Dict[str, Dict[str, int]],
True,
"不允许嵌套泛型类型",
id="dict-validation-optional-nested-list",
),
pytest.param(
Dict,
True,
"必须指定键和值的类型参数",
id="dict-validation-missing-args",
),
pytest.param(
Dict[str, SimpleClass],
False,
None,
id="dict-validation-happy-configbase-value",
),
],
)
def test_validate_dict_type(self, annotation, expect_error, error_fragment):
# 同样不通过字段定义来触发 model_post_init只测试 _validate_dict_type 本身。
class Dummy(ConfigBase):
_validate_any: bool = True
dummy = Dummy()
field_name = "mapping"
if expect_error:
# Act / Assert
with pytest.raises(TypeError) as exc_info:
dummy._validate_dict_type(annotation, field_name)
assert error_fragment in str(exc_info.value)
else:
# Act
dummy._validate_dict_type(annotation, field_name)
# ---------------------------------------------------------
# _discourage_any_usage
# ---------------------------------------------------------
def test_discourage_any_usage_raises_when_validate_any_true(self, caplog):
class Sample(ConfigBase):
_validate_any: bool = True
instance = Sample()
# Act / Assert
with pytest.raises(TypeError) as exc_info:
instance._discourage_any_usage("field_x")
assert "不允许使用 Any 类型注解" in str(exc_info.value)
assert "建议避免使用" not in caplog.text
def test_discourage_any_usage_logs_when_validate_any_false(self, caplog):
class Sample(ConfigBase):
_validate_any: bool = False
instance = Sample()
# Arrange
caplog.set_level(logging.WARNING, logger="config_base_test_logger")
# Act
instance._discourage_any_usage("field_y")
# Assert
assert "字段'field_y'中使用了 Any 类型注解" in caplog.text
def test_discourage_any_usage_suppressed_warning(self, caplog):
class Sample(ConfigBase):
_validate_any: bool = False
suppress_any_warning: bool = True
instance = Sample()
# Arrange
caplog.set_level(logging.WARNING, logger="config_base_test_logger")
# Act
instance._discourage_any_usage("field_z")
# Assert
assert "字段'field_z'中使用了 Any 类型注解" not in caplog.text
# ---------------------------------------------------------
# model_post_init 规则覆盖(错误与边界情况)
# ---------------------------------------------------------
@pytest.mark.parametrize(
"field_annotation, expect_error, error_fragment, test_id",
[
(
Tuple[int, int],
True,
"不允许使用 Tuple 类型注解",
"model-post-init-disallow-tuple-typing-tuple",
),
(
tuple[int, int],
True,
"不允许使用 Tuple 类型注解",
"model-post-init-disallow-pep604-tuple",
),
(
Union[int, str],
True,
"不允许使用 Union 类型注解",
"model-post-init-disallow-union-field",
),
(
list,
True,
"必须指定且仅指定一个类型参数",
"model-post-init-list-missing-type-arg",
),
(
List[List[int]],
True,
"不允许嵌套泛型类型",
"model-post-init-list-nested-generic",
),
(
Dict[str, Any],
True,
"不允许使用 Any 类型注解",
"model-post-init-dict-value-any",
),
(
Any,
True,
"不允许使用 Any 类型注解",
"model-post-init-field-any-disallowed",
),
(
Set[int],
False,
None,
"model-post-init-allow-set-int",
),
(
Dict[str, Optional[int]],
False,
None,
"model-post-init-allow-dict-optional-int",
),
],
ids=lambda v: v[3] if isinstance(v, tuple) else v,
)
def test_model_post_init_type_rules(self, field_annotation, expect_error, error_fragment, test_id):
# Arrange
attrs = {
"__annotations__": {"f": field_annotation},
"f": Field(default=None),
}
model_cls = type("DynamicModel" + test_id.replace("-", "_"), (ConfigBase,), attrs)
if expect_error:
# Act / Assert
with pytest.raises(TypeError) as exc_info:
model_cls()
assert error_fragment in str(exc_info.value)
else:
# Act
instance = model_cls()
# Assert
assert hasattr(instance, "f")
# ---------------------------------------------------------
# 嵌套 ConfigBase & 非支持泛型 origin
# ---------------------------------------------------------
def test_model_post_init_allows_configbase_nested_class(self):
class Child(ConfigBase):
value: int = 1
class Parent(ConfigBase):
child: Child = Field(default_factory=Child)
# Act
parent = Parent()
# Assert
assert isinstance(parent.child, Child)
def test_model_post_init_disallow_non_supported_generic_origin(self):
class CustomGeneric(BaseModel):
pass
class Sample(ConfigBase):
f: CustomGeneric = Field(default_factory=CustomGeneric)
# Arrange / Act / Assert
with pytest.raises(TypeError) as exc_info:
Sample()
assert "仅允许使用list, set, dict三种泛型类型注解" in str(exc_info.value)
# ---------------------------------------------------------
# super().model_post_init 和 AttrDocBase.__post_init__ 调用
# ---------------------------------------------------------
def test_super_model_post_init_and_attrdoc_post_init_called(self):
class Sample(ConfigBase):
value: int = 1
# Act
instance = Sample()
# Assert
assert getattr(instance, "__post_init_called__", False) is True

View File

@@ -1,104 +0,0 @@
from pathlib import Path
from watchfiles import Change
import asyncio
import pytest
from src.config.config import ConfigManager
from src.config.file_watcher import FileChange, FileWatcherStats
@pytest.mark.asyncio
async def test_handle_file_changes_throttles_reload():
manager = ConfigManager()
manager._hot_reload_min_interval_s = 100.0
called = 0
async def reload_stub(changed_scopes=None) -> bool:
nonlocal called
called += 1
return True
manager.reload_config = reload_stub # type: ignore[method-assign]
changes = [FileChange(change_type=Change.modified, path=Path("/tmp/bot_config.toml"))]
await manager._handle_file_changes(changes)
await manager._handle_file_changes(changes)
assert called == 1
@pytest.mark.asyncio
async def test_handle_file_changes_timeout_logged(caplog):
manager = ConfigManager()
manager._hot_reload_min_interval_s = 0.0
manager._hot_reload_timeout_s = 0.01
async def reload_stub(changed_scopes=None) -> bool:
await asyncio.sleep(0.05)
return True
manager.reload_config = reload_stub # type: ignore[method-assign]
changes = [FileChange(change_type=Change.modified, path=Path("/tmp/model_config.toml"))]
with caplog.at_level("ERROR"):
await manager._handle_file_changes(changes)
assert "配置热重载超时" in caplog.text
@pytest.mark.asyncio
async def test_handle_file_changes_empty_skips_reload():
manager = ConfigManager()
called = 0
async def reload_stub(changed_scopes=None) -> bool:
nonlocal called
called += 1
return True
manager.reload_config = reload_stub # type: ignore[method-assign]
await manager._handle_file_changes([])
assert called == 0
class _FakeWatcher:
def __init__(self):
self.unsubscribe_called_with: str | None = None
self.stop_called = False
self.stats = FileWatcherStats(
batches_seen=1,
changes_seen=2,
callbacks_succeeded=3,
callbacks_failed=4,
callbacks_timed_out=5,
callbacks_skipped_cooldown=6,
restart_count=7,
)
def unsubscribe(self, subscription_id: str) -> bool:
self.unsubscribe_called_with = subscription_id
return True
async def stop(self) -> None:
self.stop_called = True
@pytest.mark.asyncio
async def test_stop_file_watcher_cleans_state():
manager = ConfigManager()
fake_watcher = _FakeWatcher()
manager._file_watcher = fake_watcher # type: ignore[assignment]
manager._file_watcher_subscription_id = "sub-1"
await manager.stop_file_watcher()
assert fake_watcher.unsubscribe_called_with == "sub-1"
assert fake_watcher.stop_called is True
assert manager._file_watcher is None
assert manager._file_watcher_subscription_id is None

View File

@@ -1,22 +0,0 @@
from typing import Any
from src.config import config as config_module
from src.config.config import Config, ConfigManager, ModelConfig
def test_initialize_upgrades_bot_and_model_config_without_exit(monkeypatch):
manager = ConfigManager()
loaded_config_classes: list[type[Any]] = []
warnings: list[Any] = []
def fake_load_config_from_file(config_class, config_path, new_ver, override_repr=False):
loaded_config_classes.append(config_class)
return object(), True
monkeypatch.setattr(config_module, "load_config_from_file", fake_load_config_from_file)
monkeypatch.setattr(ConfigManager, "_warn_if_vlm_not_configured", lambda self, model_config: warnings.append(model_config))
manager.initialize()
assert loaded_config_classes == [Config, ModelConfig]
assert warnings == [manager.model_config]

View File

@@ -1,138 +0,0 @@
from pathlib import Path
from watchfiles import Change
import asyncio
import pytest
from src.config.file_watcher import FileChange, FileWatcher
from typing import Sequence
@pytest.mark.asyncio
async def test_dispatch_changes_with_path_and_change_type_filters(tmp_path: Path):
watcher = FileWatcher(paths=[tmp_path])
target_file = (tmp_path / "bot_config.toml").resolve()
received: list[list[FileChange]] = []
async def callback(changes):
received.append(list(changes))
watcher.subscribe(callback, paths=[target_file], change_types=[Change.modified])
await watcher._dispatch_changes(
[
FileChange(change_type=Change.added, path=target_file),
FileChange(change_type=Change.modified, path=target_file),
FileChange(change_type=Change.modified, path=(tmp_path / "other.toml").resolve()),
]
)
assert len(received) == 1
assert len(received[0]) == 1
assert received[0][0].change_type == Change.modified
assert received[0][0].path == target_file
@pytest.mark.asyncio
async def test_sync_callback_supported(tmp_path: Path):
watcher = FileWatcher(paths=[tmp_path])
target_file = (tmp_path / "model_config.toml").resolve()
received_paths: list[Path] = []
def sync_callback(changes):
received_paths.extend(change.path for change in changes)
watcher.subscribe(sync_callback, paths=[target_file])
await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
assert received_paths == [target_file]
@pytest.mark.asyncio
async def test_callback_timeout_and_cooldown(tmp_path: Path):
watcher = FileWatcher(
paths=[tmp_path],
callback_timeout_s=0.05,
callback_failure_threshold=2,
callback_cooldown_s=0.2,
)
target_file = (tmp_path / "bot_config.toml").resolve()
async def slow_callback(changes):
await asyncio.sleep(0.2)
watcher.subscribe(slow_callback, paths=[target_file])
await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
stats_after_failures = watcher.stats
assert stats_after_failures.callbacks_timed_out == 2
assert stats_after_failures.callbacks_failed == 2
await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
stats_after_cooldown_skip = watcher.stats
assert stats_after_cooldown_skip.callbacks_skipped_cooldown >= 1
@pytest.mark.asyncio
async def test_start_requires_subscription(tmp_path: Path):
watcher = FileWatcher(paths=[tmp_path])
with pytest.raises(RuntimeError):
await watcher.start()
@pytest.mark.asyncio
async def test_unsubscribe_stops_dispatch(tmp_path: Path):
watcher = FileWatcher(paths=[tmp_path])
target_file = (tmp_path / "bot_config.toml").resolve()
calls = 0
async def callback(changes):
nonlocal calls
calls += 1
subscription_id = watcher.subscribe(callback, paths=[target_file])
assert watcher.unsubscribe(subscription_id) is True
await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
assert calls == 0
@pytest.mark.asyncio
async def test_add_callback_while_watcher_running(tmp_path: Path):
dirs = (tmp_path / "a_dir").resolve()
dirs.mkdir(exist_ok=True)
file = (dirs / "a.toml").resolve()
file.touch()
watcher = FileWatcher(paths=[dirs], debounce_ms=200)
calls = 0
async def callback(changes: Sequence[FileChange]):
nonlocal calls
print(f"Callback called with changes: {[f'{change.change_type} {change.path}' for change in changes]}")
calls += 1
uuid = watcher.subscribe(callback, paths=[file])
await watcher.start()
try:
with file.open("w") as f:
f.write("change")
await asyncio.sleep(0.5)
assert calls == 1
watcher.unsubscribe(uuid)
with file.open("w") as f:
f.write("change2")
await asyncio.sleep(0.5)
assert calls == 1
finally:
await watcher.stop()

View File

@@ -1,76 +0,0 @@
from types import SimpleNamespace
from importlib import util
from pathlib import Path
from src.config.config import config_manager
from src.config.model_configs import TaskConfig
from src.llm_models.utils_model import LLMRequest
def _load_llm_api_module():
file_path = Path(__file__).parent.parent.parent / "src" / "plugin_system" / "apis" / "llm_api.py"
spec = util.spec_from_file_location("test_llm_api_module", file_path)
assert spec is not None and spec.loader is not None
module = util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
def _make_model_config(task_config: TaskConfig, attr_name: str = "utils"):
model_task_config = SimpleNamespace(**{attr_name: task_config})
return SimpleNamespace(model_task_config=model_task_config, models=[], api_providers=[])
def test_llm_request_resolve_task_config_by_signature(monkeypatch):
old_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
current_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
monkeypatch.setattr(config_manager, "get_model_config", lambda: _make_model_config(current_task, "utils"))
req = LLMRequest(model_set=old_task, request_type="test")
assert req._task_config_name == "utils"
def test_llm_request_refresh_task_config_updates_runtime_state(monkeypatch):
old_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
initial_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
updated_task = TaskConfig(model_list=["gpt-b", "gpt-c"], max_tokens=1024, temperature=0.5, slow_threshold=20.0)
current = {"task": initial_task}
def get_model_config_stub():
return _make_model_config(current["task"], "replyer")
monkeypatch.setattr(config_manager, "get_model_config", get_model_config_stub)
req = LLMRequest(model_set=old_task, request_type="test")
assert req._task_config_name == "replyer"
current["task"] = updated_task
req._refresh_task_config()
assert req.model_for_task.model_list == ["gpt-b", "gpt-c"]
assert list(req.model_usage.keys()) == ["gpt-b", "gpt-c"]
def test_llm_api_get_available_models_reads_latest_config(monkeypatch):
llm_api = _load_llm_api_module()
first_utils = TaskConfig(model_list=["gpt-a"])
second_utils = TaskConfig(model_list=["gpt-z"])
state = {"task": first_utils}
def get_model_config_stub():
model_task_config = SimpleNamespace(utils=state["task"], planner=TaskConfig(model_list=["gpt-p"]))
return SimpleNamespace(model_task_config=model_task_config)
monkeypatch.setattr(config_manager, "get_model_config", get_model_config_stub)
first = llm_api.get_available_models()
assert first["utils"].model_list == ["gpt-a"]
state["task"] = second_utils
second = llm_api.get_available_models()
assert second["utils"].model_list == ["gpt-z"]

View File

@@ -1,11 +0,0 @@
from src.config.model_configs import ModelInfo
def test_model_identifier_strips_surrounding_whitespace() -> None:
model_info = ModelInfo(
api_provider="test-provider",
model_identifier=" glm-5.1 ",
name="test-model",
)
assert model_info.model_identifier == "glm-5.1"

View File

@@ -1,104 +0,0 @@
from pathlib import Path
from types import SimpleNamespace
import sys
from src.config.legacy_migration import migrate_legacy_bind_env_to_bot_config_dict
from src.config.startup_bindings import (
BindAddress,
get_startup_main_bind_address,
get_startup_webui_bind_address,
resolve_main_bind_address,
resolve_webui_bind_address,
)
def test_startup_bindings_use_defaults_when_config_file_missing(tmp_path: Path):
missing_path = tmp_path / "missing_bot_config.toml"
assert get_startup_main_bind_address(missing_path) == BindAddress(host="127.0.0.1", port=8080)
assert get_startup_webui_bind_address(missing_path) == BindAddress(host="127.0.0.1", port=8001)
def test_startup_bindings_can_read_addresses_from_bot_config(tmp_path: Path):
config_path = tmp_path / "bot_config.toml"
config_path.write_text(
"""
[inner]
version = "8.3.1"
[maim_message]
ws_server_host = "0.0.0.0"
ws_server_port = 22345
[webui]
host = "192.168.1.9"
port = 18001
""".strip(),
encoding="utf-8",
)
assert get_startup_main_bind_address(config_path) == BindAddress(host="0.0.0.0", port=22345)
assert get_startup_webui_bind_address(config_path) == BindAddress(host="192.168.1.9", port=18001)
def test_resolve_bindings_prefer_initialized_global_config(monkeypatch):
fake_config_module = SimpleNamespace(
global_config=SimpleNamespace(
maim_message=SimpleNamespace(ws_server_host="10.0.0.2", ws_server_port=32000),
webui=SimpleNamespace(host="10.0.0.3", port=32001),
)
)
monkeypatch.setitem(sys.modules, "src.config.config", fake_config_module)
assert resolve_main_bind_address() == BindAddress(host="10.0.0.2", port=32000)
assert resolve_webui_bind_address() == BindAddress(host="10.0.0.3", port=32001)
def test_legacy_env_bindings_are_migrated_when_fields_missing_or_default(monkeypatch):
monkeypatch.setenv("HOST", "0.0.0.0")
monkeypatch.setenv("PORT", "22345")
monkeypatch.setenv("WEBUI_HOST", "192.168.1.8")
monkeypatch.setenv("WEBUI_PORT", "19001")
payload = {
"maim_message": {
"ws_server_host": "127.0.0.1",
"ws_server_port": 8080,
},
"webui": {},
}
result = migrate_legacy_bind_env_to_bot_config_dict(payload)
assert result.migrated is True
assert payload["maim_message"]["ws_server_host"] == "0.0.0.0"
assert payload["maim_message"]["ws_server_port"] == 22345
assert payload["webui"]["host"] == "192.168.1.8"
assert payload["webui"]["port"] == 19001
def test_legacy_env_bindings_do_not_override_explicit_config(monkeypatch):
monkeypatch.setenv("HOST", "0.0.0.0")
monkeypatch.setenv("PORT", "22345")
monkeypatch.setenv("WEBUI_HOST", "192.168.1.8")
monkeypatch.setenv("WEBUI_PORT", "19001")
payload = {
"maim_message": {
"ws_server_host": "10.1.1.1",
"ws_server_port": 30000,
},
"webui": {
"host": "10.1.1.2",
"port": 30001,
},
}
result = migrate_legacy_bind_env_to_bot_config_dict(payload)
assert result.migrated is False
assert payload["maim_message"]["ws_server_host"] == "10.1.1.1"
assert payload["maim_message"]["ws_server_port"] == 30000
assert payload["webui"]["host"] == "10.1.1.2"
assert payload["webui"]["port"] == 30001

View File

@@ -1,10 +0,0 @@
import sys
from pathlib import Path
# Add project root to Python path so src imports work
project_root = Path(__file__).parent.parent.absolute()
src_root = project_root / "src"
if str(src_root) not in sys.path:
sys.path.insert(0, str(src_root))
if str(project_root) not in sys.path:
sys.path.insert(1, str(project_root))

View File

@@ -1,66 +0,0 @@
from __future__ import annotations
from pathlib import Path
import json
import pytest
from src.common.i18n.manager import I18nManager
from src.common.i18n.loaders import DuplicateTranslationKeyError, load_locale_catalog
def write_locale_file(locales_root: Path, locale: str, file_name: str, payload: dict[str, object]) -> None:
locale_dir = locales_root / locale
locale_dir.mkdir(parents=True, exist_ok=True)
file_path = locale_dir / file_name
file_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
def test_t_falls_back_to_default_locale(tmp_path: Path) -> None:
locales_root = tmp_path / "locales"
write_locale_file(locales_root, "zh-CN", "core.json", {"greeting": "你好,{name}"})
write_locale_file(locales_root, "en-US", "core.json", {})
manager = I18nManager(locales_root=locales_root)
assert manager.t("greeting", locale="en-US", name="Mai") == "你好Mai"
def test_t_returns_key_when_missing_everywhere(tmp_path: Path) -> None:
locales_root = tmp_path / "locales"
write_locale_file(locales_root, "zh-CN", "core.json", {})
write_locale_file(locales_root, "en-US", "core.json", {})
manager = I18nManager(locales_root=locales_root)
assert manager.t("missing.key", locale="en-US") == "missing.key"
def test_tn_uses_plural_rules(tmp_path: Path) -> None:
locales_root = tmp_path / "locales"
write_locale_file(
locales_root,
"en-US",
"core.json",
{
"tasks.cancelled": {
"one": "Cancelled {count} task",
"other": "Cancelled {count} tasks",
}
},
)
manager = I18nManager(default_locale="en-US", locales_root=locales_root)
assert manager.tn("tasks.cancelled", 1) == "Cancelled 1 task"
assert manager.tn("tasks.cancelled", 2) == "Cancelled 2 tasks"
def test_load_locale_catalog_rejects_duplicate_keys(tmp_path: Path) -> None:
locales_root = tmp_path / "locales"
write_locale_file(locales_root, "zh-CN", "a.json", {"duplicate.key": "A"})
write_locale_file(locales_root, "zh-CN", "b.json", {"duplicate.key": "B"})
with pytest.raises(DuplicateTranslationKeyError):
load_locale_catalog("zh-CN", locales_root)

View File

@@ -1,110 +0,0 @@
from __future__ import annotations
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
import json
SCRIPT_PATH = Path(__file__).resolve().parents[2] / "scripts" / "i18n_validate.py"
MODULE_SPEC = spec_from_file_location("i18n_validate_script", SCRIPT_PATH)
assert MODULE_SPEC is not None
assert MODULE_SPEC.loader is not None
I18N_VALIDATE = module_from_spec(MODULE_SPEC)
MODULE_SPEC.loader.exec_module(I18N_VALIDATE)
def write_locale_file(locales_root: Path, locale: str, file_name: str, payload: dict[str, object]) -> None:
locale_dir = locales_root / locale
locale_dir.mkdir(parents=True, exist_ok=True)
(locale_dir / file_name).write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
def write_dashboard_locale_file(locales_root: Path, locale: str, payload: dict[str, object]) -> None:
locales_root.mkdir(parents=True, exist_ok=True)
(locales_root / f"{locale}.json").write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
def test_validate_json_locales_rejects_han_characters_in_english_locale(tmp_path: Path) -> None:
locales_root = tmp_path / "locales"
write_locale_file(locales_root, "zh-CN", "core.json", {"consent.prompt": '输入"同意"继续'})
write_locale_file(locales_root, "en-US", "core.json", {"consent.prompt": 'Type "confirmed" or "同意" to continue'})
errors = I18N_VALIDATE.validate_json_locales(locales_root)
assert any("consent.prompt" in error and "仍包含中文字符" in error for error in errors)
def test_validate_json_locales_rejects_untranslated_han_source_in_other_target_locales(tmp_path: Path) -> None:
locales_root = tmp_path / "locales"
write_locale_file(locales_root, "zh-CN", "core.json", {"greeting": "你好,世界"})
write_locale_file(locales_root, "ja", "core.json", {"greeting": "你好,世界"})
errors = I18N_VALIDATE.validate_json_locales(locales_root)
assert any("greeting" in error and "直接保留了包含中文字符的 source 文案" in error for error in errors)
def test_validate_json_locales_avoids_false_positive_when_plural_categories_do_not_align(tmp_path: Path) -> None:
locales_root = tmp_path / "locales"
write_locale_file(
locales_root,
"zh-CN",
"core.json",
{
"tasks.cancelled": {
"one": "中文单数",
"other": "中文复数",
}
},
)
write_locale_file(
locales_root,
"ja",
"core.json",
{
"tasks.cancelled": {
"many": "中文单数",
"other": "已翻译",
}
},
)
errors = I18N_VALIDATE.validate_json_locales(locales_root)
assert any("tasks.cancelled" in error and "plural category 不一致" in error for error in errors)
assert not any("tasks.cancelled" in error and "直接保留了包含中文字符的 source 文案" in error for error in errors)
def test_validate_dashboard_json_locales_rejects_han_characters_in_english_locale(tmp_path: Path) -> None:
locales_root = tmp_path / "dashboard-locales"
write_dashboard_locale_file(locales_root, "zh", {"common": {"greeting": "你好,世界"}})
write_dashboard_locale_file(locales_root, "en", {"common": {"greeting": "Hello 同意"}})
errors = I18N_VALIDATE.validate_dashboard_json_locales(locales_root)
assert any("dashboard:en" in error and "common.greeting" in error and "仍包含中文字符" in error for error in errors)
def test_validate_dashboard_json_locales_rejects_untranslated_han_source_in_other_target_locales(
tmp_path: Path,
) -> None:
locales_root = tmp_path / "dashboard-locales"
write_dashboard_locale_file(locales_root, "zh", {"common": {"greeting": "你好,世界"}})
write_dashboard_locale_file(locales_root, "ja", {"common": {"greeting": "你好,世界"}})
errors = I18N_VALIDATE.validate_dashboard_json_locales(locales_root)
assert any(
"dashboard:ja" in error and "common.greeting" in error and "直接保留了包含中文字符的 source 文案" in error
for error in errors
)
def test_validate_dashboard_json_locales_rejects_i18next_placeholder_drift(tmp_path: Path) -> None:
locales_root = tmp_path / "dashboard-locales"
write_dashboard_locale_file(locales_root, "zh", {"status": {"checkingDesc": "等待服务恢复... ({{current}}/{{max}})"}})
write_dashboard_locale_file(locales_root, "ko", {"status": {"checkingDesc": "서비스 복구 대기 중... ({{current}}/{{limit}})"}})
errors = I18N_VALIDATE.validate_dashboard_json_locales(locales_root)
assert any("dashboard:ko" in error and "status.checkingDesc" in error and "占位符集合与 source 不一致" in error for error in errors)

File diff suppressed because it is too large Load Diff

View File

@@ -1,295 +0,0 @@
import sys
import types
import importlib
import pytest
from pathlib import Path
import importlib.util
class DummyLogger:
def info(self, *a, **k):
pass
def warning(self, *a, **k):
pass
def error(self, *a, **k):
pass
class DummySession:
def __init__(self):
self.record = None
def exec(self, *a, **k):
record = self.record
class R:
def first(self):
return record
def yield_per(self, n):
if record is None:
return iter(())
return iter((record,))
return R()
def add(self, record, *a, **k):
self.record = record
def flush(self, *a, **k):
pass
def delete(self, *a, **k):
self.record = None
def expunge(self, *a, **k):
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
class DummyMaiImage:
def __init__(self, full_path=None, image_bytes=None):
self.full_path = full_path
self.image_bytes = image_bytes
self.file_hash = "dummy-hash"
self.image_format = "png"
self.description = ""
self.vlm_processed = False
@classmethod
def from_db_instance(cls, record):
image = cls(full_path=getattr(record, "full_path", None))
image.file_hash = getattr(record, "image_hash", "dummy-hash")
image.description = getattr(record, "description", "")
image.vlm_processed = getattr(record, "vlm_processed", False)
return image
def to_db_instance(self):
return types.SimpleNamespace(
description=self.description,
full_path=str(self.full_path) if self.full_path is not None else "",
id=1,
image_hash=self.file_hash,
image_type="image",
last_used_time=None,
no_file_flag=False,
query_count=0,
register_time=None,
vlm_processed=self.vlm_processed,
)
async def calculate_hash_format(self):
self.file_hash = "dummy-hash"
return None
class DummyLLMRequest:
def __init__(self, *a, **k):
pass
async def generate_response_for_image(self, prompt, image_base64, image_format, temp):
return ("dummy description", {})
class DummyLLMServiceClient:
def __init__(self, *a, **k):
pass
async def generate_response_for_image(self, prompt, image_base64, image_format, options=None):
return types.SimpleNamespace(response="dummy description")
class DummySelect:
def __init__(self, *a, **k):
pass
def filter_by(self, *a, **k):
return self
def limit(self, n):
return self
class DetachedRecord:
def __init__(self, description="cached description", vlm_processed=True):
self._detached = False
self._description = description
self._vlm_processed = vlm_processed
@property
def description(self):
if not self._detached:
raise RuntimeError("attribute refresh operation cannot proceed")
return self._description
@property
def vlm_processed(self):
if not self._detached:
raise RuntimeError("attribute refresh operation cannot proceed")
return self._vlm_processed
class DetachedRecordSession(DummySession):
def __init__(self, record):
self.record = record
def exec(self, *a, **k):
record = self.record
class R:
def first(self):
return record
return R()
def expunge(self, record):
record._detached = True
@pytest.fixture(autouse=True)
def patch_external_dependencies(monkeypatch):
# Provide dummy implementations as modules so that importing image_manager is safe
# Patch LLMRequest
llm_mod = types.SimpleNamespace(LLMRequest=DummyLLMRequest)
monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_mod)
llm_service_mod = types.SimpleNamespace(LLMServiceClient=DummyLLMServiceClient)
monkeypatch.setitem(sys.modules, "src.services.llm_service", llm_service_mod)
# Patch logger
logger_mod = types.SimpleNamespace(get_logger=lambda name: DummyLogger())
monkeypatch.setitem(sys.modules, "src.common.logger", logger_mod)
# Patch DB session provider
shared_session = DummySession()
db_mod = types.SimpleNamespace(get_db_session=lambda: shared_session)
monkeypatch.setitem(sys.modules, "src.common.database.database", db_mod)
# Patch database model types
db_model_mod = types.SimpleNamespace(Images=types.SimpleNamespace, ImageType=types.SimpleNamespace(IMAGE="image"))
monkeypatch.setitem(sys.modules, "src.common.database.database_model", db_model_mod)
# Patch MaiImage data model
data_model_mod = types.SimpleNamespace(MaiImage=DummyMaiImage)
monkeypatch.setitem(sys.modules, "src.common.data_models.image_data_model", data_model_mod)
# Patch SQLModel select function
sql_mod = types.SimpleNamespace(select=lambda *a, **k: DummySelect())
monkeypatch.setitem(sys.modules, "sqlmodel", sql_mod)
# Patch prompt manager used to build image description prompt.
class _PromptManager:
def get_prompt(self, _name):
return types.SimpleNamespace()
async def render_prompt(self, _prompt):
return "test-style"
prompt_manager_mod = types.SimpleNamespace(prompt_manager=_PromptManager())
monkeypatch.setitem(sys.modules, "src.prompt.prompt_manager", prompt_manager_mod)
llm_options_mod = types.SimpleNamespace(LLMImageOptions=lambda **kwargs: types.SimpleNamespace(**kwargs))
monkeypatch.setitem(sys.modules, "src.common.data_models.llm_service_data_models", llm_options_mod)
# If module already imported, reload it to apply patches
mod_name = "src.chat.image_system.image_manager"
if mod_name in sys.modules:
importlib.reload(sys.modules[mod_name])
yield
def _load_image_manager_module(tmp_path=None):
repo_root = Path(__file__).parent.parent.parent
file_path = repo_root / "src" / "chat" / "image_system" / "image_manager.py"
spec = importlib.util.spec_from_file_location("image_manager_test_loaded", str(file_path))
mod = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = mod
spec.loader.exec_module(mod)
# Redirect IMAGE_DIR to pytest's tmp_path when provided
try:
if tmp_path is not None:
tmpdir = Path(tmp_path)
tmpdir.mkdir(parents=True, exist_ok=True)
mod.IMAGE_DIR = tmpdir
except Exception:
pass
return mod
@pytest.mark.asyncio
async def test_get_image_description_generates(tmp_path):
image_manager = _load_image_manager_module(tmp_path)
mgr = image_manager.ImageManager()
desc = await mgr.get_image_description(image_bytes=b"abc")
assert desc == "dummy description"
def test_get_image_from_db_none(tmp_path):
image_manager = _load_image_manager_module(tmp_path)
mgr = image_manager.ImageManager()
assert mgr.get_image_from_db("nohash") is None
def test_register_image_to_db(tmp_path):
image_manager = _load_image_manager_module(tmp_path)
mgr = image_manager.ImageManager()
p = tmp_path / "img.png"
p.write_bytes(b"data")
img = DummyMaiImage(full_path=p, image_bytes=b"data")
assert mgr.register_image_to_db(img) is True
def test_update_image_description_not_found(tmp_path):
image_manager = _load_image_manager_module(tmp_path)
mgr = image_manager.ImageManager()
img = DummyMaiImage()
img.file_hash = "nohash"
img.description = "desc"
assert mgr.update_image_description(img) is False
def test_delete_image_not_found(tmp_path):
image_manager = _load_image_manager_module(tmp_path)
mgr = image_manager.ImageManager()
img = DummyMaiImage()
img.file_hash = "nohash"
img.full_path = tmp_path = None
assert mgr.delete_image(img) is False
@pytest.mark.asyncio
async def test_save_image_and_process_and_cleanup(tmp_path):
image_manager = _load_image_manager_module(tmp_path)
mgr = image_manager.ImageManager()
# call save_image_and_process
image = await mgr.save_image_and_process(b"binarydata")
assert getattr(image, "description", None) == "dummy description"
# cleanup should run without error
mgr.cleanup_invalid_descriptions_in_db()
@pytest.mark.asyncio
async def test_get_image_description_returns_cached_description_after_session_closed(monkeypatch, tmp_path):
image_manager = _load_image_manager_module(tmp_path)
cached_record = DetachedRecord()
monkeypatch.setattr(image_manager, "get_db_session", lambda: DetachedRecordSession(cached_record))
mgr = image_manager.ImageManager()
desc = await mgr.get_image_description(image_hash="cached-hash", wait_for_build=False)
assert desc == "cached description"

View File

@@ -1,91 +0,0 @@
from pathlib import Path
from types import SimpleNamespace
import io
from PIL import Image as PILImage
import pytest
from src.common.data_models.image_data_model import MaiEmoji, MaiImage
def _build_test_image_bytes(image_format: str) -> bytes:
image = PILImage.new("RGB", (8, 8), color="white")
buffer = io.BytesIO()
image.save(buffer, format=image_format)
return buffer.getvalue()
@pytest.mark.asyncio
async def test_calculate_hash_format_updates_runtime_path_metadata(tmp_path: Path) -> None:
image_bytes = _build_test_image_bytes("JPEG")
tmp_file_path = tmp_path / "emoji.tmp"
tmp_file_path.write_bytes(image_bytes)
emoji = MaiEmoji(full_path=tmp_file_path, image_bytes=image_bytes)
assert await emoji.calculate_hash_format() is True
assert emoji.image_format == "jpeg"
assert emoji.full_path.suffix == ".jpeg"
assert emoji.file_name == emoji.full_path.name
assert emoji.dir_path == tmp_path.resolve()
@pytest.mark.asyncio
async def test_calculate_hash_format_reuses_existing_target_file(tmp_path: Path) -> None:
image_bytes = _build_test_image_bytes("JPEG")
tmp_file_path = tmp_path / "emoji.tmp"
target_file_path = tmp_path / "emoji.jpeg"
tmp_file_path.write_bytes(image_bytes)
target_file_path.write_bytes(image_bytes)
emoji = MaiEmoji(full_path=tmp_file_path, image_bytes=image_bytes)
assert await emoji.calculate_hash_format() is True
assert emoji.full_path == target_file_path.resolve()
assert emoji.file_name == target_file_path.name
assert not tmp_file_path.exists()
assert target_file_path.exists()
@pytest.mark.parametrize(
("model_cls", "extra_fields"),
[
(
MaiEmoji,
{
"description": "",
"last_used_time": None,
"query_count": 0,
"register_time": None,
},
),
(
MaiImage,
{
"description": "",
"vlm_processed": False,
},
),
],
)
def test_from_db_instance_restores_image_format_from_path(
tmp_path: Path,
model_cls: type[MaiEmoji] | type[MaiImage],
extra_fields: dict[str, object],
) -> None:
image_path = tmp_path / "cached.png"
image_path.write_bytes(_build_test_image_bytes("PNG"))
record = SimpleNamespace(
no_file_flag=False,
image_hash="hash",
full_path=str(image_path),
**extra_fields,
)
image = model_cls.from_db_instance(record)
assert image.full_path == image_path.resolve()
assert image.file_name == image_path.name
assert image.image_format == "png"

View File

@@ -1,22 +0,0 @@
class MyLogger:
def __init__(self):
pass
def info(self, msg):
print(f"INFO: {msg}")
def error(self, msg):
print(f"ERROR: {msg}")
def debug(self, msg):
print(f"DEBUG: {msg}")
def warning(self, msg):
print(f"WARNING: {msg}")
def trace(self, msg):
print(f"TRACE: {msg}")
def get_logger(*args, **kwargs):
return MyLogger()

View File

@@ -1,422 +0,0 @@
import sys
import asyncio
import pytest
import importlib
import importlib.util
from types import ModuleType
from pathlib import Path
from datetime import datetime
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from src.common.data_models.message_component_data_model import MessageSequence, ForwardComponent
from src.chat.message_receive.message import (
SessionMessage,
TextComponent,
ImageComponent,
EmojiComponent,
VoiceComponent,
AtComponent,
ReplyComponent,
ForwardNodeComponent,
)
class DummyLogger:
def __init__(self) -> None:
self.logging_record = []
def debug(self, msg):
print(f"DEBUG: {msg}")
self.logging_record.append(f"DEBUG: {msg}")
def info(self, msg):
print(f"INFO: {msg}")
self.logging_record.append(f"INFO: {msg}")
def warning(self, msg):
print(f"WARNING: {msg}")
self.logging_record.append(f"WARNING: {msg}")
def error(self, msg):
print(f"ERROR: {msg}")
self.logging_record.append(f"ERROR: {msg}")
def critical(self, msg):
print(f"CRITICAL: {msg}")
self.logging_record.append(f"CRITICAL: {msg}")
def get_logger(name):
return DummyLogger()
class DummyDBSession:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def exec(self, statement):
return self
def first(self):
return None
def commit(self):
pass
def all(self):
return []
def get_db_session():
return DummyDBSession()
def get_manual_db_session():
return DummyDBSession()
class DummySelect:
def __init__(self, model):
self.model = model
def filter_by(self, **kwargs):
return self
def where(self, condition):
return self
def limit(self, n):
return self
def select(model):
return DummySelect(model)
async def dummy_get_voice_text(binary_data):
return None # 可以根据需要返回模拟的文本结果
class DummyPersonUtils:
@staticmethod
def get_person_info_by_user_id_and_platform(user_id, platform):
return None # 可以根据需要返回模拟的用户信息
def setup_mocks(monkeypatch):
def _stub_module(name: str) -> ModuleType:
module = ModuleType(name)
monkeypatch.setitem(sys.modules, name, module)
return module
# src.common.logger
logger_mod = _stub_module("src.common.logger")
# Mock the logger
logger_mod.get_logger = get_logger
db_mod = _stub_module("src.common.database.database")
db_mod.get_db_session = get_db_session
db_mod.get_manual_db_session = get_manual_db_session
db_model_mod = _stub_module("src.common.database.database_model")
db_model_mod.Messages = None # 可以根据需要添加更多的属性或方法
emoji_manager_mod = _stub_module("src.emoji_system.emoji_manager")
emoji_manager_mod.emoji_manager = None # 可以根据需要添加更多的属性或方法
image_manager_mod = _stub_module("src.chat.image_system.image_manager")
image_manager_mod.image_manager = None # 可以根据需要添加更多的属性或方法
msg_utils_mod = _stub_module("src.common.utils.utils_message")
msg_utils_mod.MessageUtils = None # 可以根据需要添加更多的属性或方法
voice_utils_mod = _stub_module("src.common.utils.utils_voice")
voice_utils_mod.get_voice_text = dummy_get_voice_text
person_utils_mod = _stub_module("src.common.utils.utils_person")
person_utils_mod.PersonUtils = DummyPersonUtils
def load_message_via_file(monkeypatch):
setup_mocks(monkeypatch)
file_path = Path(__file__).parent.parent.parent / "src" / "chat" / "message_receive" / "message.py"
spec = importlib.util.spec_from_file_location("message", file_path)
message_module = importlib.util.module_from_spec(spec)
monkeypatch.setitem(sys.modules, "message_module", message_module)
spec.loader.exec_module(message_module)
message_module.select = select
SessionMessageClass = message_module.SessionMessage
TextComponentClass = message_module.TextComponent
ImageComponentClass = message_module.ImageComponent
EmojiComponentClass = message_module.EmojiComponent
VoiceComponentClass = message_module.VoiceComponent
AtComponentClass = message_module.AtComponent
ReplyComponentClass = message_module.ReplyComponent
ForwardNodeComponentClass = message_module.ForwardNodeComponent
MessageSequenceClass = sys.modules["src.common.data_models.message_component_data_model"].MessageSequence
ForwardComponentClass = sys.modules["src.common.data_models.message_component_data_model"].ForwardComponent
globals()["SessionMessage"] = SessionMessageClass
globals()["TextComponent"] = TextComponentClass
globals()["ImageComponent"] = ImageComponentClass
globals()["EmojiComponent"] = EmojiComponentClass
globals()["VoiceComponent"] = VoiceComponentClass
globals()["AtComponent"] = AtComponentClass
globals()["ReplyComponent"] = ReplyComponentClass
globals()["ForwardNodeComponent"] = ForwardNodeComponentClass
globals()["MessageSequence"] = MessageSequenceClass
globals()["ForwardComponent"] = ForwardComponentClass
return message_module
@pytest.mark.asyncio
async def test_process(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
msg.session_id = "session123"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [TextComponent("Hello, world!")]
await msg.process()
assert msg.processed_plain_text == "Hello, world!"
@pytest.mark.asyncio
async def test_multiple_text(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
msg.session_id = "session123"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [TextComponent("Hello,"), TextComponent("world!")]
await msg.process()
assert msg.processed_plain_text == "Hello, world!"
@pytest.mark.asyncio
async def test_image(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
msg.session_id = "session123"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [ImageComponent(binary_hash="image_hash"), TextComponent("Hello, world!")]
await msg.process()
assert msg.processed_plain_text == "[一张图片,网卡了加载不出来] Hello, world!"
@pytest.mark.asyncio
async def test_emoji(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
msg.session_id = "session123"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [EmojiComponent(binary_hash="emoji_hash"), TextComponent("Hello, world!")]
await msg.process()
assert msg.processed_plain_text == "[一个表情,网卡了加载不出来] Hello, world!"
@pytest.mark.asyncio
async def test_voice(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
msg.session_id = "session123"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [VoiceComponent(binary_hash="voice_hash"), TextComponent("Hello, world!")]
await msg.process()
assert msg.processed_plain_text == "[语音消息,转录失败] Hello, world!"
@pytest.mark.asyncio
async def test_at_component(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
msg.session_id = "session123"
msg.platform = "test_platform"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [AtComponent(target_user_id="114514"), TextComponent("Hello, world!")]
await msg.process()
assert msg.processed_plain_text == "@114514 Hello, world!"
@pytest.mark.asyncio
async def test_reply_component_fail_to_fetch(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
msg.session_id = "session123"
msg.platform = "test_platform"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")]
await msg.process()
assert msg.processed_plain_text == "[回复了一条消息,但原消息已无法访问] Hello, world!"
@pytest.mark.asyncio
async def test_reply_component_success(monkeypatch):
module_msg = load_message_via_file(monkeypatch)
class DummyDBSessionWithReply(DummyDBSession):
def exec(self, s):
return self
def first(inner_self):
class DummyRecord:
processed_plain_text = "原消息内容"
user_cardname = "cardname123"
user_nickname = "nickname123"
user_id = "userid123"
return DummyRecord()
module_msg.get_db_session = lambda: DummyDBSessionWithReply()
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
msg.session_id = "session123"
msg.platform = "test_platform"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")]
await msg.process()
assert msg.processed_plain_text == "[回复了cardname123的消息: 原消息内容] Hello, world!"
@pytest.mark.asyncio
async def test_reply_component_with_db_fail(monkeypatch):
module_msg = load_message_via_file(monkeypatch)
class DummyDBSessionWithError(DummyDBSession):
def exec(self, s):
raise Exception("数据库查询失败")
module_msg.get_db_session = lambda: DummyDBSessionWithError()
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
msg.session_id = "session123"
msg.platform = "test_platform"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")]
await msg.process()
assert msg.processed_plain_text == "[回复了一条消息,但原消息已无法访问] Hello, world!"
assert any("数据库查询失败" in log for log in module_msg.logger.logging_record)
@pytest.mark.asyncio
async def test_forward_component(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
msg.session_id = "session123"
msg.platform = "test_platform"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [
ForwardNodeComponent(
forward_components=[
ForwardComponent(
message_id="msg1",
user_id="user1",
user_nickname="nickname1",
user_cardname="cardname1",
content=[TextComponent("转发消息1")],
),
ForwardComponent(
message_id="msg2",
user_id="user2",
user_nickname="nickname2",
user_cardname="cardname2",
content=[TextComponent("转发消息2")],
),
]
),
TextComponent("Hello, world!"),
]
await msg.process()
print("Processed plain text:", msg.processed_plain_text)
expected_forward_text = """【合并转发消息:
-- 【cardname1】: 转发消息1
-- 【cardname2】: 转发消息2
】 Hello, world!"""
assert msg.processed_plain_text == expected_forward_text
@pytest.mark.asyncio
async def test_forward_with_reply(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
msg.session_id = "session123"
msg.platform = "test_platform"
msg.raw_message = MessageSequence(components=[])
msg.raw_message.components = [
ForwardNodeComponent(
forward_components=[
ForwardComponent(
message_id="msg1",
user_id="user1",
user_nickname="nickname1",
user_cardname="cardname1",
content=[TextComponent("转发消息1")],
),
ForwardComponent(
message_id="msg2",
user_id="user2",
user_nickname="nickname2",
user_cardname="cardname2",
content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息2")],
),
]
),
TextComponent("Hello, world!"),
]
await msg.process()
assert (
msg.processed_plain_text
== """【合并转发消息:
-- 【cardname1】: 转发消息1
-- 【cardname2】: [回复了cardname1的消息: 转发消息1] 转发消息2
】 Hello, world!"""
)
@pytest.mark.asyncio
async def test_multiple_reply_with_delay_in_forward(monkeypatch):
load_message_via_file(monkeypatch)
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
msg.session_id = "session123"
msg.platform = "test_platform"
msg.raw_message = MessageSequence(components=[])
async def delayed_get_voice_text(binary_data):
await asyncio.sleep(0.5) # 模拟延迟
return "这是语音转文本的结果"
sys.modules["src.common.utils.utils_voice"].get_voice_text = delayed_get_voice_text
msg.raw_message.components = [
ForwardNodeComponent(
forward_components=[
ForwardComponent(
message_id="msg1",
user_id="user1",
user_nickname="nickname1",
user_cardname="cardname1",
content=[VoiceComponent(binary_hash="voice_hash1"), TextComponent("转发消息1")],
),
ForwardComponent(
message_id="msg2",
user_id="user2",
user_nickname="nickname2",
user_cardname="cardname2",
content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息2")],
),
ForwardComponent(
message_id="msg3",
user_id="user3",
user_nickname="nickname3",
user_cardname="cardname3",
content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息3")],
),
]
),
]
await msg.process()
expected_text = """【合并转发消息:
-- 【cardname1】: [语音: 这是语音转文本的结果] 转发消息1
-- 【cardname2】: [回复了cardname1的消息: [语音: 这是语音转文本的结果] 转发消息1] 转发消息2
-- 【cardname3】: [回复了cardname1的消息: [语音: 这是语音转文本的结果] 转发消息1] 转发消息3
"""
assert msg.processed_plain_text == expected_text

View File

@@ -1,220 +0,0 @@
from __future__ import annotations
from pathlib import Path
import pytest
from src.common.i18n import set_locale
from src.common.prompt_i18n import clear_prompt_cache, load_prompt, list_prompt_templates
from src.prompt.prompt_manager import PromptManager
@pytest.fixture(autouse=True)
def clear_prompt_i18n_cache() -> None:
set_locale("zh-CN")
clear_prompt_cache()
yield
clear_prompt_cache()
set_locale("zh-CN")
def write_prompt(prompt_dir: Path, locale: str | None, name: str, content: str) -> None:
base_dir = prompt_dir if locale is None else prompt_dir / locale
base_dir.mkdir(parents=True, exist_ok=True)
(base_dir / f"{name}.prompt").write_text(content, encoding="utf-8")
def test_load_prompt_prefers_requested_locale(tmp_path: Path) -> None:
prompts_root = tmp_path / "prompts"
write_prompt(prompts_root, "zh-CN", "replyer", "你好,{user_name}")
write_prompt(prompts_root, "en-US", "replyer", "Hello, {user_name}")
rendered = load_prompt("replyer", locale="en-US", prompts_root=prompts_root, user_name="Mai")
assert rendered == "Hello, Mai"
def test_load_prompt_falls_back_to_default_locale(tmp_path: Path) -> None:
prompts_root = tmp_path / "prompts"
write_prompt(prompts_root, "zh-CN", "replyer", "你好,{user_name}")
rendered = load_prompt("replyer", locale="en-US", prompts_root=prompts_root, user_name="Mai")
assert rendered == "你好Mai"
def test_load_prompt_does_not_fall_back_to_legacy_root(tmp_path: Path) -> None:
prompts_root = tmp_path / "prompts"
write_prompt(prompts_root, None, "replyer", "Legacy {user_name}")
with pytest.raises(FileNotFoundError):
load_prompt("replyer", locale="en-US", prompts_root=prompts_root, user_name="Mai")
def test_load_prompt_with_category_falls_back_to_default_locale_root(tmp_path: Path) -> None:
prompts_root = tmp_path / "prompts"
write_prompt(prompts_root, "zh-CN", "replyer", "你好,{user_name}")
rendered = load_prompt("replyer", locale="en-US", category="chat", prompts_root=prompts_root, user_name="Mai")
assert rendered == "你好Mai"
def test_load_prompt_prefers_custom_prompt_override(tmp_path: Path) -> None:
prompts_root = tmp_path / "prompts"
custom_prompts_root = tmp_path / "data" / "custom_prompts"
write_prompt(prompts_root, "zh-CN", "replyer", "Base {user_name}")
write_prompt(custom_prompts_root, "zh-CN", "replyer", "Custom {user_name}")
rendered = load_prompt(
"replyer",
locale="zh-CN",
prompts_root=prompts_root,
custom_prompts_root=custom_prompts_root,
user_name="Mai",
)
assert rendered == "Custom Mai"
def test_load_prompt_prefers_custom_prompt_requested_locale(tmp_path: Path) -> None:
prompts_root = tmp_path / "prompts"
custom_prompts_root = tmp_path / "data" / "custom_prompts"
write_prompt(prompts_root, "zh-CN", "replyer", "Base zh {user_name}")
write_prompt(prompts_root, "en-US", "replyer", "Base en {user_name}")
write_prompt(custom_prompts_root, "zh-CN", "replyer", "Custom zh {user_name}")
write_prompt(custom_prompts_root, "en-US", "replyer", "Custom en {user_name}")
rendered = load_prompt(
"replyer",
locale="en-US",
prompts_root=prompts_root,
custom_prompts_root=custom_prompts_root,
user_name="Mai",
)
assert rendered == "Custom en Mai"
def test_load_prompt_uses_requested_locale_source_before_default_custom(tmp_path: Path) -> None:
prompts_root = tmp_path / "prompts"
custom_prompts_root = tmp_path / "data" / "custom_prompts"
write_prompt(prompts_root, "zh-CN", "replyer", "Base zh {user_name}")
write_prompt(prompts_root, "en-US", "replyer", "Base en {user_name}")
write_prompt(custom_prompts_root, "zh-CN", "replyer", "Custom zh {user_name}")
rendered = load_prompt(
"replyer",
locale="en-US",
prompts_root=prompts_root,
custom_prompts_root=custom_prompts_root,
user_name="Mai",
)
assert rendered == "Base en Mai"
def test_load_prompt_strict_mode_raises_on_missing_placeholder(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
prompts_root = tmp_path / "prompts"
write_prompt(prompts_root, "zh-CN", "replyer", "你好,{user_name},现在是 {current_time}")
monkeypatch.setenv("MAIBOT_PROMPT_I18N_STRICT", "1")
with pytest.raises(KeyError) as exc_info:
load_prompt("replyer", locale="zh-CN", prompts_root=prompts_root, user_name="Mai")
assert "current_time" in str(exc_info.value)
def test_load_prompt_rejects_path_traversal(tmp_path: Path) -> None:
prompts_root = tmp_path / "prompts"
write_prompt(prompts_root, "zh-CN", "replyer", "你好")
with pytest.raises(ValueError):
load_prompt("../replyer", locale="zh-CN", prompts_root=prompts_root)
def test_list_prompt_templates_prefers_locale_specific_files(tmp_path: Path) -> None:
prompts_root = tmp_path / "prompts"
write_prompt(prompts_root, "zh-CN", "replyer", "中文")
write_prompt(prompts_root, "en-US", "replyer", "English")
set_locale("en-US")
prompt_templates = list_prompt_templates(prompts_root=prompts_root)
assert prompt_templates["replyer"].path.read_text(encoding="utf-8") == "English"
def test_list_prompt_templates_loads_directory_metadata(tmp_path: Path) -> None:
prompts_root = tmp_path / "prompts"
write_prompt(prompts_root, "zh-CN", "replyer", "中文")
metadata_path = prompts_root / "zh-CN" / ".meta.toml"
metadata_path.write_text(
"""
[replyer]
display_name = "回复器"
advanced = true
description = "用于生成回复的主模板"
""".strip(),
encoding="utf-8",
)
prompt_templates = list_prompt_templates(prompts_root=prompts_root)
metadata = prompt_templates["replyer"].metadata
assert metadata.display_name == "回复器"
assert metadata.advanced is True
assert metadata.description == "用于生成回复的主模板"
def test_list_prompt_templates_loads_prompt_specific_metadata(tmp_path: Path) -> None:
prompts_root = tmp_path / "prompts"
write_prompt(prompts_root, "zh-CN", "replyer", "中文")
metadata_path = prompts_root / "zh-CN" / "replyer.meta.json"
metadata_path.write_text(
'{"display_name": "Replyer", "advanced": false, "description": "Prompt specific metadata"}',
encoding="utf-8",
)
prompt_templates = list_prompt_templates(prompts_root=prompts_root)
metadata = prompt_templates["replyer"].metadata
assert metadata.display_name == "Replyer"
assert metadata.advanced is False
assert metadata.description == "Prompt specific metadata"
def test_list_prompt_templates_reports_duplicate_name_with_custom_root(tmp_path: Path) -> None:
prompts_root = tmp_path / "prompts"
first_dir = prompts_root / "zh-CN" / "chat"
second_dir = prompts_root / "zh-CN" / "system"
first_dir.mkdir(parents=True, exist_ok=True)
second_dir.mkdir(parents=True, exist_ok=True)
(first_dir / "replyer.prompt").write_text("chat", encoding="utf-8")
(second_dir / "replyer.prompt").write_text("system", encoding="utf-8")
with pytest.raises(ValueError) as exc_info:
list_prompt_templates(prompts_root=prompts_root)
assert "zh-CN/chat/replyer.prompt" in str(exc_info.value)
assert "zh-CN/system/replyer.prompt" in str(exc_info.value)
def test_prompt_manager_load_prompts_prefers_locale_dir(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
prompts_root = tmp_path / "prompts"
custom_prompts_root = tmp_path / "data" / "custom_prompts"
custom_prompts_root.mkdir(parents=True, exist_ok=True)
write_prompt(prompts_root, "zh-CN", "replyer", "中文模板")
write_prompt(prompts_root, "en-US", "replyer", "English template")
set_locale("en-US")
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_root, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_prompts_root, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.SUFFIX_PROMPT", ".prompt", raising=False)
manager = PromptManager()
manager.load_prompts()
assert manager.get_prompt("replyer").template == "English template"

View File

@@ -1,893 +0,0 @@
# File: pytests/prompt_test/test_prompt_manager.py
from pathlib import Path
from typing import Any
import asyncio
import inspect
import sys
import pytest
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "src" / "config"))
from src.common.i18n.loaders import DEFAULT_LOCALE # noqa
from src.prompt.prompt_manager import ( # noqa
SUFFIX_PROMPT,
Prompt,
PromptManager,
prompt_manager,
)
def write_source_prompt(prompts_dir: Path, name: str, content: str) -> Path:
from src.common.i18n.loaders import DEFAULT_LOCALE
source_dir = prompts_dir / DEFAULT_LOCALE
source_dir.mkdir(parents=True, exist_ok=True)
prompt_file = source_dir / f"{name}{SUFFIX_PROMPT}"
prompt_file.write_text(content, encoding="utf-8")
return prompt_file
# ========= Prompt 基础行为 =========
@pytest.mark.parametrize(
"prompt_name, template",
[
pytest.param("simple", "Hello {name}", id="simple-template-with-field"),
pytest.param("no-fields", "Just a static template", id="template-without-fields"),
pytest.param(
"brace-escaping",
"Use {{ and }} around {field}",
id="template-with-escaped-braces",
),
],
)
def test_prompt_init_happy_paths(prompt_name: str, template: str):
# Act
prompt = Prompt(prompt_name=prompt_name, template=template)
# Assert
assert prompt.prompt_name == prompt_name
assert prompt.template == template
@pytest.mark.parametrize(
"prompt_name, template, expected_exception, expected_msg_substring",
[
pytest.param("", "Hello {name}", ValueError, "prompt_name 不能为空", id="empty-prompt-name"),
pytest.param("valid-name", "", ValueError, "template 不能为空", id="empty-template"),
pytest.param(
"unnamed-placeholder",
"Hello {}",
ValueError,
"模板中不允许使用未命名的占位符",
id="unnamed-placeholder-not-allowed",
),
pytest.param(
"unnamed-placeholder-with-escaped-brace",
"Value {{}} and {}",
ValueError,
"模板中不允许使用未命名的占位符",
id="unnamed-placeholder-mixed-with-escaped",
),
],
)
def test_prompt_init_error_cases(
prompt_name,
template,
expected_exception,
expected_msg_substring,
):
# Act / Assert
with pytest.raises(expected_exception) as exc_info:
Prompt(prompt_name=prompt_name, template=template)
# Assert
assert expected_msg_substring in str(exc_info.value)
@pytest.mark.parametrize(
"initial_context, name, func, expected_value, expected_exception, expected_msg_substring, case_id",
[
(
{},
"const_str",
"constant",
"constant",
None,
None,
"add-context-from-string-creates-wrapper",
),
(
{},
"callable_str",
lambda prompt_name: f"hello-{prompt_name}",
"hello-my_prompt",
None,
None,
"add-context-from-callable",
),
(
{"dup": lambda _: "x"},
"dup",
"y",
None,
KeyError,
"Context function name 'dup' 已存在于 Prompt 'my_prompt'",
"add-context-duplicate-key-error",
),
],
)
def test_prompt_add_context(
initial_context,
name,
func,
expected_value,
expected_exception,
expected_msg_substring,
case_id,
):
# Arrange
prompt = Prompt(prompt_name="my_prompt", template="template")
prompt.prompt_render_context = dict(initial_context)
# Act
if expected_exception:
with pytest.raises(expected_exception) as exc_info:
prompt.add_context(name, func)
# Assert
assert expected_msg_substring in str(exc_info.value)
else:
prompt.add_context(name, func)
# Assert
assert name in prompt.prompt_render_context
result = prompt.prompt_render_context[name]("my_prompt")
assert result == expected_value
def test_prompt_clone_independent_instance():
# Arrange
prompt = Prompt(prompt_name="p", template="T {x}")
prompt.add_context("x", "X")
# Act
cloned = prompt.clone()
# Assert
assert cloned is not prompt
assert cloned.prompt_name == prompt.prompt_name
assert cloned.template == prompt.template
# 当前实现 clone 不复制 context
assert cloned.prompt_render_context == {}
# ========= PromptManager添加/获取/删除/替换 =========
def test_prompt_manager_add_prompt_happy_and_error():
# Arrange
manager = PromptManager()
prompt1 = Prompt(prompt_name="p1", template="T1")
manager.add_prompt(prompt1, need_save=True)
# Act
prompt2 = Prompt(prompt_name="p2", template="T2")
manager.add_prompt(prompt2, need_save=False)
# Assert
assert "p1" in manager._prompt_to_save
assert "p2" not in manager._prompt_to_save
# Arrange
prompt_dup = Prompt(prompt_name="p1", template="T-dup")
# Act / Assert
with pytest.raises(KeyError) as exc_info:
manager.add_prompt(prompt_dup)
# Assert
assert "Prompt name 'p1' 已存在" in str(exc_info.value)
def test_prompt_manager_remove_prompt_happy_and_error():
# Arrange
manager = PromptManager()
p1 = Prompt(prompt_name="p1", template="T")
manager.add_prompt(p1, need_save=True)
# Act
manager.remove_prompt("p1")
# Assert
assert "p1" not in manager.prompts
assert "p1" not in manager._prompt_to_save
# Act / Assert
with pytest.raises(KeyError) as exc_info:
manager.remove_prompt("no_such")
assert "Prompt name 'no_such' 不存在" in str(exc_info.value)
def test_prompt_manager_replace_prompt_happy_and_error():
# sourcery skip: extract-duplicate-method
# Arrange
manager = PromptManager()
p1 = Prompt(prompt_name="p", template="Old")
manager.add_prompt(p1, need_save=True)
p_new = Prompt(prompt_name="p", template="New")
# Act: 替换且保持 need_save
manager.replace_prompt(p_new, need_save=True)
# Assert
assert manager.prompts["p"].template == "New"
assert "p" in manager._prompt_to_save
# Act: 再次替换,且不需要保存
p_new2 = Prompt(prompt_name="p", template="New2")
manager.replace_prompt(p_new2, need_save=False)
# Assert
assert manager.prompts["p"].template == "New2"
assert "p" not in manager._prompt_to_save
# Error: 不存在的 prompt
p_unknown = Prompt(prompt_name="unknown", template="T")
with pytest.raises(KeyError) as exc_info:
manager.replace_prompt(p_unknown)
assert "Prompt name 'unknown' 不存在,无法替换" in str(exc_info.value)
def test_prompt_manager_get_prompt_is_copy():
# Arrange
manager = PromptManager()
prompt = Prompt(prompt_name="original", template="T")
manager.add_prompt(prompt)
# Act
retrieved_prompt = manager.get_prompt("original")
# Assert
assert retrieved_prompt is not prompt
assert retrieved_prompt.prompt_name == prompt.prompt_name
assert retrieved_prompt.template == prompt.template
assert retrieved_prompt.prompt_render_context == prompt.prompt_render_context
def test_prompt_manager_add_prompt_conflict_with_context_name():
# Arrange
manager = PromptManager()
manager.add_context_construct_function("ctx_name", lambda _: "value")
prompt_conflict = Prompt(prompt_name="ctx_name", template="T")
# Act / Assert
with pytest.raises(KeyError) as exc_info:
manager.add_prompt(prompt_conflict)
# Assert
assert "Prompt name 'ctx_name' 已存在" in str(exc_info.value)
def test_prompt_manager_add_context_construct_function_happy():
# Arrange
manager = PromptManager()
def ctx_func(prompt_name: str) -> str:
return f"ctx-{prompt_name}"
# Act
manager.add_context_construct_function("ctx", ctx_func)
# Assert
assert "ctx" in manager._context_construct_functions
stored_func, module = manager._context_construct_functions["ctx"]
assert stored_func is ctx_func
assert module == __name__
def test_prompt_manager_add_context_construct_function_duplicate():
# Arrange
manager = PromptManager()
def f(_):
return "x"
manager.add_context_construct_function("dup", f)
manager.add_prompt(Prompt(prompt_name="dup_prompt", template="T"))
# Act / Assert
with pytest.raises(KeyError) as exc_info1:
manager.add_context_construct_function("dup", f)
# Assert
assert "Construct function name 'dup' 已存在" in str(exc_info1.value)
# Act / Assert
with pytest.raises(KeyError) as exc_info2:
manager.add_context_construct_function("dup_prompt", f)
# Assert
assert "Construct function name 'dup_prompt' 已存在" in str(exc_info2.value)
def test_prompt_manager_get_prompt_not_exist():
# Arrange
manager = PromptManager()
# Act / Assert
with pytest.raises(KeyError) as exc_info:
manager.get_prompt("no_such_prompt")
# Assert
assert "Prompt name 'no_such_prompt' 不存在" in str(exc_info.value)
# ========= 渲染逻辑 =========
@pytest.mark.parametrize(
"template, inner_context, global_context, expected, case_id",
[
pytest.param(
"Hello {name}",
{"name": lambda p: f"name-for-{p}"},
{},
"Hello name-for-main",
"render-with-inner-context",
),
pytest.param(
"Global {block}",
{},
{"block": lambda p: f"block-{p}"},
"Global block-main",
"render-with-global-context",
),
pytest.param(
"Mix {inner} and {global}",
{"inner": lambda p: f"inner-{p}"},
{"global": lambda p: f"global-{p}"},
"Mix inner-main and global-main",
"render-with-inner-and-global-context",
),
pytest.param(
"Escaped {{ and }} and {field}",
{"field": lambda _: "X"},
{},
"Escaped { and } and X",
"render-with-escaped-braces",
),
],
)
@pytest.mark.asyncio
async def test_prompt_manager_render_contexts(
template,
inner_context,
global_context,
expected,
case_id,
):
# Arrange
manager = PromptManager()
tmp_prompt = Prompt(prompt_name="main", template=template)
manager.add_prompt(tmp_prompt)
prompt = manager.get_prompt("main")
for name, fn in inner_context.items():
prompt.add_context(name, fn)
for name, fn in global_context.items():
manager.add_context_construct_function(name, fn)
# Act
rendered = await manager.render_prompt(prompt)
# Assert
assert rendered == expected
@pytest.mark.asyncio
async def test_prompt_manager_render_nested_prompts():
# Arrange
manager = PromptManager()
p1 = Prompt(prompt_name="p1", template="P1-{x}")
p2 = Prompt(prompt_name="p2", template="P2-{p1}")
p3_tmp = Prompt(prompt_name="p3", template="{p2}-end")
manager.add_prompt(p1)
manager.add_prompt(p2)
manager.add_prompt(p3_tmp)
p3 = manager.get_prompt("p3")
p3.add_context("x", lambda _: "X")
# Act
rendered = await manager.render_prompt(p3)
# Assert
assert rendered == "P2-P1-X-end"
@pytest.mark.asyncio
async def test_prompt_manager_render_recursive_limit():
# Arrange
manager = PromptManager()
p1_tmp = Prompt(prompt_name="p1", template="{p2}")
p2_tmp = Prompt(prompt_name="p2", template="{p1}")
manager.add_prompt(p1_tmp)
manager.add_prompt(p2_tmp)
p1 = manager.get_prompt("p1")
# Act / Assert
with pytest.raises(RecursionError) as exc_info:
await manager.render_prompt(p1)
# Assert
assert "递归层级过深" in str(exc_info.value)
@pytest.mark.asyncio
async def test_prompt_manager_render_missing_field_error():
# Arrange
manager = PromptManager()
tmp_prompt = Prompt(prompt_name="main", template="Hello {missing}")
manager.add_prompt(tmp_prompt)
prompt = manager.get_prompt("main")
# Act / Assert
with pytest.raises(KeyError) as exc_info:
await manager.render_prompt(prompt)
# Assert
assert "Prompt 'main' 中缺少必要的内容块或构建函数: 'missing'" in str(exc_info.value)
@pytest.mark.asyncio
async def test_prompt_manager_render_prefers_inner_context_over_global():
# Arrange
manager = PromptManager()
tmp_prompt = Prompt(prompt_name="main", template="{field}")
manager.add_context_construct_function("field", lambda _: "global")
manager.add_prompt(tmp_prompt)
prompt = manager.get_prompt("main")
prompt.add_context("field", lambda _: "inner")
# Act
rendered = await manager.render_prompt(prompt)
# Assert
assert rendered == "inner"
@pytest.mark.asyncio
async def test_prompt_manager_render_with_coroutine_context_function():
# Arrange
manager = PromptManager()
async def async_inner(prompt_name: str) -> str:
await asyncio.sleep(0)
return f"async-{prompt_name}"
tmp_prompt = Prompt(prompt_name="main", template="{inner}")
manager.add_prompt(tmp_prompt)
prompt = manager.get_prompt("main")
prompt.add_context("inner", async_inner)
# Act
rendered = await manager.render_prompt(prompt)
# Assert
assert rendered == "async-main"
@pytest.mark.asyncio
async def test_prompt_manager_render_with_coroutine_global_context_function():
# Arrange
manager = PromptManager()
async def async_global(prompt_name: str) -> str:
await asyncio.sleep(0)
return f"g-{prompt_name}"
tmp_prompt = Prompt(prompt_name="main", template="{g}")
manager.add_context_construct_function("g", async_global)
manager.add_prompt(tmp_prompt)
prompt = manager.get_prompt("main")
# Act
rendered = await manager.render_prompt(prompt)
# Assert
assert rendered == "g-main"
@pytest.mark.asyncio
async def test_prompt_manager_render_only_cloned_instance():
# Arrange
manager = PromptManager()
p = Prompt(prompt_name="p", template="T")
manager.add_prompt(p)
# Act / Assert: 直接用原始 p 渲染会报错
with pytest.raises(ValueError) as exc_info:
await manager.render_prompt(p)
assert "只能渲染通过 PromptManager.get_prompt 方法获取的 Prompt 实例" in str(exc_info.value)
@pytest.mark.parametrize(
"is_prompt_context, use_coroutine, case_id",
[
pytest.param(True, False, "prompt-context-sync-error"),
pytest.param(False, False, "global-context-sync-error"),
pytest.param(True, True, "prompt-context-async-error"),
pytest.param(False, True, "global-context-async-error"),
],
)
@pytest.mark.asyncio
async def test_prompt_manager_get_function_result_error_logging(
monkeypatch,
is_prompt_context,
use_coroutine,
case_id,
):
# Arrange
manager = PromptManager()
class DummyError(Exception):
pass
def sync_func(_name: str) -> str:
raise DummyError("sync-error")
async def async_func(_name: str) -> str:
await asyncio.sleep(0)
raise DummyError("async-error")
func = async_func if use_coroutine else sync_func
logged_messages: list[str] = []
def fake_error(msg: Any) -> None:
logged_messages.append(str(msg))
fake_logger = type("FakeLogger", (), {"error": staticmethod(fake_error)})
monkeypatch.setattr("src.prompt.prompt_manager.logger", fake_logger)
# Act / Assert
with pytest.raises(DummyError):
await manager._get_function_result(
func=func,
prompt_name="P",
field_name="field",
is_prompt_context=is_prompt_context,
module="mod",
)
# Assert
assert logged_messages
log = logged_messages[0]
if is_prompt_context:
assert "调用 Prompt 'P' 内部上下文构造函数 'field' 时出错" in log
else:
assert "调用上下文构造函数 'field' 时出错,所属模块: 'mod'" in log
# ========= add_context_construct_function 边界 =========
def test_prompt_manager_add_context_construct_function_unknown_frame(monkeypatch):
# Arrange
manager = PromptManager()
def fake_currentframe() -> None:
return None
monkeypatch.setattr("inspect.currentframe", fake_currentframe)
def f(_):
return "x"
# Act / Assert
with pytest.raises(RuntimeError) as exc_info:
manager.add_context_construct_function("x", f)
# Assert
assert "无法获取调用栈" in str(exc_info.value)
def test_prompt_manager_add_context_construct_function_unknown_caller_frame(monkeypatch):
# Arrange
manager = PromptManager()
real_currentframe = inspect.currentframe
class FakeFrame:
f_back = None
def fake_currentframe():
return FakeFrame()
monkeypatch.setattr("inspect.currentframe", fake_currentframe)
def f(_):
return "x"
# Act / Assert
with pytest.raises(RuntimeError) as exc_info:
manager.add_context_construct_function("x", f)
# Assert
assert "无法获取调用栈的上一级" in str(exc_info.value)
# Cleanup
monkeypatch.setattr("inspect.currentframe", real_currentframe)
# ========= save/load & 目录逻辑 =========
def test_prompt_manager_save_prompts_io_error_on_unlink(tmp_path, monkeypatch):
"""
save_prompts 现在的逻辑:
1. 先删除 CUSTOM_PROMPTS_DIR 下的所有 *.prompt 文件;
2. 再将 _prompt_to_save 中的 prompt 写入 CUSTOM_PROMPTS_DIR。
这里模拟删除已有自定义 prompt 文件时发生 IO 错误。
"""
# Arrange
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
# 先在自定义目录写入一个 prompt 文件,触发 unlink 路径
old_file = custom_dir / f"old{SUFFIX_PROMPT}"
old_file.write_text("old", encoding="utf-8")
manager = PromptManager()
p1 = Prompt(prompt_name="save_error", template="T")
manager.add_prompt(p1, need_save=True)
# 打桩 Path.unlink使删除文件时报错
def fake_unlink(self):
raise OSError("disk unlink error")
monkeypatch.setattr("pathlib.Path.unlink", fake_unlink)
# Act / Assert
with pytest.raises(OSError) as exc_info:
manager.save_prompts()
# Assert
assert "disk unlink error" in str(exc_info.value)
def test_prompt_manager_save_prompts_io_error_on_write(tmp_path, monkeypatch):
"""
模拟 save_prompts 在写入新 prompt 文件时发生 IO 错误。
"""
# Arrange
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
manager = PromptManager()
p1 = Prompt(prompt_name="save_error", template="T")
manager.add_prompt(p1, need_save=True)
original_write_text = Path.write_text
def fake_write_text(self, *args, **kwargs):
if self == custom_dir / DEFAULT_LOCALE / f"save_error{SUFFIX_PROMPT}":
raise OSError("disk write error")
return original_write_text(self, *args, **kwargs)
monkeypatch.setattr(Path, "write_text", fake_write_text)
# Act / Assert
with pytest.raises(OSError) as exc_info:
manager.save_prompts()
# Assert
assert "disk write error" in str(exc_info.value)
def test_prompt_manager_load_prompts_io_error_from_default_dir(tmp_path, monkeypatch):
"""
模拟从默认 locale 目录读取 prompt 时发生 IO 错误。
"""
# Arrange
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
prompt_file = write_source_prompt(prompts_dir, "bad", "content")
original_read_text = Path.read_text
def fake_read_text(self, *args, **kwargs):
if self == prompt_file:
raise OSError("read error")
return original_read_text(self, *args, **kwargs)
monkeypatch.setattr(Path, "read_text", fake_read_text)
manager = PromptManager()
# Act / Assert
with pytest.raises(OSError) as exc_info:
manager.load_prompts()
# Assert
assert "read error" in str(exc_info.value)
def test_prompt_manager_load_prompts_io_error_from_custom_dir(tmp_path, monkeypatch):
"""
模拟从 CUSTOM_PROMPTS_DIR 读取 prompt 时发生 IO 错误。
包含两种路径:
1. default 与 custom 同名load_prompts 会优先读取 custom
2. 仅 custom 有文件,且 default 无同名文件。
"""
# Arrange
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
# default 与 custom 同名的文件
base_file = write_source_prompt(prompts_dir, "same", "base")
same_name = base_file.name
custom_file_same = custom_dir / same_name
custom_file_same.write_text("custom", encoding="utf-8")
# 仅 custom 下存在的文件
only_custom_file = custom_dir / f"only_custom{SUFFIX_PROMPT}"
only_custom_file.write_text("only", encoding="utf-8")
original_read_text = Path.read_text
def fake_read_text(self, *args, **kwargs):
if self.parent == custom_dir:
raise OSError("custom read error")
return original_read_text(self, *args, **kwargs)
monkeypatch.setattr(Path, "read_text", fake_read_text)
manager = PromptManager()
# Act / Assert
with pytest.raises(OSError) as exc_info:
manager.load_prompts()
# Assert
assert "custom read error" in str(exc_info.value)
def test_prompt_manager_load_prompts_custom_overrides_default(tmp_path, monkeypatch):
"""
load_prompts 逻辑:
- 遍历 locale 目录中的 source prompt
- 如果 CUSTOM_PROMPTS_DIR 下存在同名文件,则优先使用自定义目录
"""
# Arrange
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
# source locale 目录 prompt
base_file = write_source_prompt(prompts_dir, "testp", "BaseTemplate {x}")
# 自定义目录同名 prompt应当覆盖默认
custom_file = custom_dir / base_file.name
custom_file.write_text("CustomTemplate {x}", encoding="utf-8")
manager = PromptManager()
# Act
manager.load_prompts()
# Assert
p = manager.get_prompt("testp")
assert p.template == "CustomTemplate {x}"
# 从自定义目录加载的 prompt 应标记为 need_save加入 _prompt_to_save
assert "testp" in manager._prompt_to_save
def test_prompt_manager_load_prompts_default_dir_not_mark_need_save(tmp_path, monkeypatch):
"""
从 source locale 目录加载、且没有同名自定义 prompt 时need_save 应为 False不进入 _prompt_to_save
"""
# Arrange
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
# 仅 source locale 目录有 prompt自定义目录中无同名文件
base_file = write_source_prompt(prompts_dir, "only_default", "DefaultTemplate {x}")
manager = PromptManager()
# Act
manager.load_prompts()
# Assert
p = manager.get_prompt("only_default")
assert p.template == base_file.read_text(encoding="utf-8")
# 从默认目录加载的 prompt 不应标记为 need_save
assert "only_default" not in manager._prompt_to_save
def test_prompt_manager_save_prompts_use_custom_dir(tmp_path, monkeypatch):
"""
save_prompts 使用 CUSTOM_PROMPTS_DIR 进行保存。
"""
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
manager = PromptManager()
p1 = Prompt(prompt_name="save_me", template="Template {x}")
p1.add_context("x", "X")
manager.add_prompt(p1, need_save=True)
# Act
manager.save_prompts()
# Assert: 文件应保存在 custom_dir 中
saved_file = custom_dir / DEFAULT_LOCALE / f"save_me{SUFFIX_PROMPT}"
assert saved_file.exists()
assert saved_file.read_text(encoding="utf-8") == "Template {x}"
# ========= 其它 =========
def test_prompt_manager_global_instance_access():
# Act
pm = prompt_manager
# Assert
assert isinstance(pm, PromptManager)
def test_formatter_parsing_named_fields_only():
# Arrange
manager = PromptManager()
prompt = Prompt(prompt_name="main", template="A {x} B {y} C")
manager.add_prompt(prompt)
# Act
fields = {field_name for _, field_name, _, _ in manager._formatter.parse(prompt.template) if field_name}
# Assert
assert fields == {"x", "y"}

View File

@@ -1,73 +0,0 @@
from src.common.data_models.message_component_data_model import (
ImageComponent,
MessageSequence,
ReplyComponent,
TextComponent,
)
from src.llm_models.payload_content.message import RoleType
from src.maisaka.context_messages import _build_message_from_sequence
from src.maisaka.message_adapter import build_visible_text_from_sequence
def test_image_only_message_keeps_placeholder_in_text_fallback() -> None:
message_sequence = MessageSequence(
[
TextComponent("[时间]19:21:20\n[用户名]William730\n[用户群昵称]\n[msg_id]1385025976\n[发言内容]"),
ImageComponent(binary_hash="hash", content=None, binary_data=None),
]
)
message = _build_message_from_sequence(
RoleType.User,
message_sequence,
"[时间]19:21:20\n[用户名]William730\n[用户群昵称]\n[msg_id]1385025976\n[发言内容][图片]",
)
assert message is not None
assert "[发言内容]" in message.get_text_content()
assert "[图片]" in message.get_text_content()
def test_whitespace_image_content_uses_placeholder_in_text_fallback() -> None:
message_sequence = MessageSequence(
[
TextComponent("[发言内容]"),
ImageComponent(binary_hash="hash", content=" ", binary_data=None),
]
)
message = _build_message_from_sequence(
RoleType.User,
message_sequence,
"[发言内容][图片]",
enable_visual_message=False,
)
assert message is not None
assert message.get_text_content() == "[发言内容][图片]"
def test_visible_text_uses_image_placeholder_for_whitespace_content() -> None:
visible_text = build_visible_text_from_sequence(
MessageSequence(
[
TextComponent("看这个"),
ImageComponent(binary_hash="hash", content=" ", binary_data=None),
]
)
)
assert visible_text == "看这个[图片]"
def test_visible_text_adds_body_marker_after_reply_component() -> None:
visible_text = build_visible_text_from_sequence(
MessageSequence(
[
ReplyComponent(target_message_id="75625487"),
TextComponent("你说是那就是"),
]
)
)
assert visible_text == "[引用]quote_id=75625487\n[发言内容]你说是那就是"

View File

@@ -1,72 +0,0 @@
import base64
import sys
from types import ModuleType, SimpleNamespace
config_module = ModuleType("src.config.config")
class _ConfigManagerStub:
def get_model_config(self) -> SimpleNamespace:
return SimpleNamespace(api_providers=[])
def register_reload_callback(self, _: object) -> None:
return None
config_module.config_manager = _ConfigManagerStub()
sys.modules.setdefault("src.config.config", config_module)
from src.llm_models.model_client import gemini_client
from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.llm_models.payload_content.tool_option import ToolCall
def _encode_signature(value: bytes) -> str:
return base64.b64encode(value).decode("ascii")
def test_convert_messages_preserves_gemini_function_call_signature_and_tool_result_id() -> None:
thought_signature = b"gemini-signature"
tool_call = ToolCall(
call_id="call-1",
func_name="reply",
args={"msg_id": "42"},
extra_content={"google": {"thought_signature": _encode_signature(thought_signature)}},
)
assistant_message = MessageBuilder().set_role(RoleType.Assistant).set_tool_calls([tool_call]).build()
tool_message = (
MessageBuilder()
.set_role(RoleType.Tool)
.set_tool_call_id("call-1")
.set_tool_name("reply")
.add_text_content('{"ok": true}')
.build()
)
contents, _ = gemini_client._convert_messages([assistant_message, tool_message])
assistant_part = contents[0].parts[0]
assert assistant_part.function_call is not None
assert assistant_part.function_call.id == "call-1"
assert assistant_part.function_call.name == "reply"
assert assistant_part.thought_signature == thought_signature
tool_part = contents[1].parts[0]
assert tool_part.function_response is not None
assert tool_part.function_response.id == "call-1"
assert tool_part.function_response.name == "reply"
assert tool_part.function_response.response == {"ok": True}
def test_convert_messages_injects_dummy_signature_for_first_historical_tool_call() -> None:
tool_calls = [
ToolCall(call_id="call-1", func_name="reply", args={"msg_id": "1"}),
ToolCall(call_id="call-2", func_name="reply", args={"msg_id": "2"}),
]
assistant_message = MessageBuilder().set_role(RoleType.Assistant).set_tool_calls(tool_calls).build()
contents, _ = gemini_client._convert_messages([assistant_message])
assert contents[0].parts[0].thought_signature == gemini_client.GEMINI_FALLBACK_THOUGHT_SIGNATURE
assert contents[0].parts[1].thought_signature is None

View File

@@ -1,194 +0,0 @@
"""HTML 浏览器渲染服务测试。"""
from pathlib import Path
from typing import Any, Dict, List
import pytest
from src.config.official_configs import PluginRuntimeRenderConfig
from src.services import html_render_service as html_render_service_module
from src.services.html_render_service import HTMLRenderService, ManagedBrowserRecord
class _FakeChromium:
"""用于模拟 Playwright Chromium 启动器的测试桩。"""
def __init__(self, effects: List[Any]) -> None:
"""初始化 Chromium 启动测试桩。
Args:
effects: 每次调用 ``launch`` 时依次返回或抛出的结果。
"""
self._effects: List[Any] = list(effects)
self.calls: List[Dict[str, Any]] = []
async def launch(self, **kwargs: Any) -> Any:
"""模拟 Playwright Chromium 的启动过程。
Args:
**kwargs: 浏览器启动参数。
Returns:
Any: 预设的浏览器对象。
Raises:
Exception: 当预设结果为异常对象时抛出。
"""
self.calls.append(dict(kwargs))
effect = self._effects.pop(0)
if isinstance(effect, Exception):
raise effect
return effect
class _FakePlaywright:
"""用于模拟 Playwright 根对象的测试桩。"""
def __init__(self, chromium: _FakeChromium) -> None:
"""初始化 Playwright 测试桩。
Args:
chromium: Chromium 启动器测试桩。
"""
self.chromium = chromium
def _build_render_config(**kwargs: Any) -> PluginRuntimeRenderConfig:
"""构造用于测试的浏览器渲染配置。
Args:
**kwargs: 需要覆盖的配置字段。
Returns:
PluginRuntimeRenderConfig: 测试使用的配置对象。
"""
payload: Dict[str, Any] = {
"auto_download_chromium": True,
"browser_install_root": "data/test-playwright-browsers",
}
payload.update(kwargs)
return PluginRuntimeRenderConfig(**payload)
@pytest.mark.asyncio
async def test_launch_browser_auto_downloads_chromium_when_missing(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
"""未检测到可用浏览器时,应自动下载 Chromium 并记录状态。"""
monkeypatch.setattr(html_render_service_module, "PROJECT_ROOT", tmp_path)
service = HTMLRenderService()
config = _build_render_config()
fake_browser = object()
fake_chromium = _FakeChromium(
[
RuntimeError("browserType.launch: Executable doesn't exist at /tmp/chromium"),
fake_browser,
]
)
install_calls: List[str] = []
monkeypatch.setattr(service, "_resolve_executable_path", lambda _config: "")
async def fake_install(_config: PluginRuntimeRenderConfig) -> None:
"""模拟 Chromium 自动下载。
Args:
_config: 当前浏览器渲染配置。
"""
install_calls.append(_config.browser_install_root)
browsers_path = service._get_managed_browsers_path(_config)
(browsers_path / "chromium-1234").mkdir(parents=True, exist_ok=True)
monkeypatch.setattr(service, "_install_chromium_browser", fake_install)
browser = await service._launch_browser(_FakePlaywright(fake_chromium), config)
assert browser is fake_browser
assert install_calls == ["data/test-playwright-browsers"]
assert len(fake_chromium.calls) == 2
browser_record = service._load_managed_browser_record()
assert browser_record is not None
assert browser_record.install_source == "auto_download"
assert browser_record.browser_name == "chromium"
@pytest.mark.asyncio
async def test_launch_browser_reuses_existing_managed_browser(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
"""已存在 Playwright 托管浏览器时,不应重复下载。"""
monkeypatch.setattr(html_render_service_module, "PROJECT_ROOT", tmp_path)
service = HTMLRenderService()
config = _build_render_config()
browsers_path = service._get_managed_browsers_path(config)
(browsers_path / "chrome-headless-shell-1234").mkdir(parents=True, exist_ok=True)
fake_browser = object()
fake_chromium = _FakeChromium([fake_browser])
monkeypatch.setattr(service, "_resolve_executable_path", lambda _config: "")
async def fail_install(_config: PluginRuntimeRenderConfig) -> None:
"""若被错误调用则立即失败。
Args:
_config: 当前浏览器渲染配置。
Raises:
AssertionError: 表示本测试不期望进入下载逻辑。
"""
raise AssertionError("不应触发自动下载")
monkeypatch.setattr(service, "_install_chromium_browser", fail_install)
browser = await service._launch_browser(_FakePlaywright(fake_chromium), config)
assert browser is fake_browser
assert len(fake_chromium.calls) == 1
browser_record = service._load_managed_browser_record()
assert browser_record is not None
assert browser_record.install_source == "existing_cache"
assert browser_record.browsers_path == str(browsers_path)
@pytest.mark.asyncio
async def test_launch_browser_prefers_local_executable(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
"""探测到本机浏览器时,应优先使用可执行文件路径启动。"""
monkeypatch.setattr(html_render_service_module, "PROJECT_ROOT", tmp_path)
service = HTMLRenderService()
config = _build_render_config()
fake_browser = object()
fake_chromium = _FakeChromium([fake_browser])
executable_path = "/usr/bin/google-chrome"
monkeypatch.setattr(service, "_resolve_executable_path", lambda _config: executable_path)
browser = await service._launch_browser(_FakePlaywright(fake_chromium), config)
assert browser is fake_browser
assert len(fake_chromium.calls) == 1
assert fake_chromium.calls[0]["executable_path"] == executable_path
assert service._load_managed_browser_record() is None
def test_managed_browser_record_roundtrip() -> None:
"""托管浏览器记录应支持序列化与反序列化。"""
record = ManagedBrowserRecord(
browser_name="chromium",
browsers_path="/tmp/playwright-browsers",
install_source="auto_download",
playwright_version="1.58.0",
recorded_at="2026-04-03T10:00:00+00:00",
last_verified_at="2026-04-03T10:00:01+00:00",
)
restored_record = ManagedBrowserRecord.from_dict(record.to_dict())
assert restored_record == record

View File

@@ -1,101 +0,0 @@
from typing import List
from src.llm_models.model_client.base_client import (
APIResponse,
AudioTranscriptionRequest,
BaseClient,
ClientProviderRegistration,
ClientRegistry,
EmbeddingRequest,
ResponseRequest,
)
class DummyClient(BaseClient):
"""测试用 LLM 客户端。"""
async def get_response(self, request: ResponseRequest) -> APIResponse:
"""获取测试响应。
Args:
request: 统一响应请求。
Returns:
APIResponse: 测试响应。
"""
del request
return APIResponse(content="ok")
async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
"""获取测试嵌入。
Args:
request: 统一嵌入请求。
Returns:
APIResponse: 测试嵌入响应。
"""
del request
return APIResponse(embedding=[1.0])
async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
"""获取测试音频转写。
Args:
request: 统一音频转写请求。
Returns:
APIResponse: 测试音频转写响应。
"""
del request
return APIResponse(content="audio")
def get_support_image_formats(self) -> List[str]:
"""获取测试支持的图片格式。
Returns:
List[str]: 支持的图片格式列表。
"""
return ["png"]
def test_client_registry_rejects_provider_conflict():
"""同一 client_type 被不同插件注册时应拒绝。"""
registry = ClientRegistry()
registry.replace_plugin_providers(
"plugin.alpha",
[
ClientProviderRegistration(
client_type="example",
factory=DummyClient,
owner_plugin_id="plugin.alpha",
)
],
)
try:
registry.validate_plugin_provider_replacement("plugin.beta", ["example"])
except ValueError as exc:
assert "冲突" in str(exc)
else:
raise AssertionError("不同插件注册相同 client_type 应失败")
def test_client_registry_unregisters_plugin_providers():
"""插件注销时应移除它拥有的 Provider 注册。"""
registry = ClientRegistry()
registry.replace_plugin_providers(
"plugin.alpha",
[
ClientProviderRegistration(
client_type="example",
factory=DummyClient,
owner_plugin_id="plugin.alpha",
)
],
)
removed_count = registry.unregister_plugin_providers("plugin.alpha")
assert removed_count == 1
assert "example" not in registry.client_registry

View File

@@ -1,113 +0,0 @@
from datetime import datetime
from types import SimpleNamespace
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
from src.common.data_models.message_component_data_model import AtComponent, MessageSequence, ReplyComponent, TextComponent
from src.config.config import global_config
from src.maisaka.builtin_tool.context import BuiltinToolRuntimeContext
from src.maisaka.runtime import MaisakaHeartFlowChatting
def _build_sent_message() -> SessionMessage:
message = SessionMessage(
message_id="real-message-id",
timestamp=datetime(2026, 4, 5, 12, 0, 0),
platform="qq",
)
message.message_info = MessageInfo(
user_info=UserInfo(
user_id="bot-qq",
user_nickname="MaiSaka",
user_cardname=None,
),
group_info=None,
additional_config={},
)
message.raw_message = MessageSequence(
[
ReplyComponent(target_message_id="m123"),
TextComponent(text="你好"),
]
)
message.session_id = "test-session"
message.initialized = True
return message
def test_append_sent_message_to_chat_history_keeps_message_id() -> None:
runtime = SimpleNamespace(_chat_history=[])
engine = SimpleNamespace(_get_runtime_manager=lambda: None)
tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime)
tool_ctx.append_sent_message_to_chat_history(_build_sent_message())
assert len(runtime._chat_history) == 1
history_message = runtime._chat_history[0]
assert history_message.message_id == "real-message-id"
assert "[msg_id]real-message-id\n" in history_message.raw_message.components[0].text
assert "[msg_id:real-message-id]" in history_message.visible_text
def test_post_process_reply_message_sequences_converts_at_marker_before_bracket_cleanup(monkeypatch) -> None:
monkeypatch.setattr(global_config.chat, "enable_at", True)
monkeypatch.setattr(
"src.maisaka.builtin_tool.context.process_llm_response",
lambda text: [text.strip()] if text.strip() else [],
)
target_message = SimpleNamespace(
message_info=SimpleNamespace(
user_info=SimpleNamespace(
user_id="target-user",
user_nickname="目标昵称",
user_cardname="群名片",
)
)
)
runtime = SimpleNamespace(
find_source_message_by_id=lambda message_id: target_message if message_id == "12160142" else None
)
engine = SimpleNamespace(_get_runtime_manager=lambda: None)
tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime)
sequences = tool_ctx.post_process_reply_message_sequences("at[12160142] 就这个群")
assert len(sequences) == 1
components = sequences[0].components
assert isinstance(components[0], AtComponent)
assert components[0].target_user_id == "target-user"
assert components[0].target_user_nickname == "目标昵称"
assert components[0].target_user_cardname == "群名片"
assert isinstance(components[1], TextComponent)
assert components[1].text == " 就这个群"
def test_post_process_reply_message_sequences_ignores_at_marker_when_disabled(monkeypatch) -> None:
monkeypatch.setattr(global_config.chat, "enable_at", False)
monkeypatch.setattr(
"src.maisaka.builtin_tool.context.process_llm_response",
lambda text: [text.strip()] if text.strip() else [],
)
runtime = SimpleNamespace(find_source_message_by_id=lambda message_id: None)
engine = SimpleNamespace(_get_runtime_manager=lambda: None)
tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime)
sequences = tool_ctx.post_process_reply_message_sequences("at[12160142] 就这个群")
assert len(sequences) == 1
components = sequences[0].components
assert len(components) == 1
assert isinstance(components[0], TextComponent)
assert components[0].text == "at[12160142] 就这个群"
def test_runtime_finds_source_message_from_history() -> None:
target_message = _build_sent_message()
runtime = object.__new__(MaisakaHeartFlowChatting)
runtime._chat_history = [
SimpleNamespace(message_id="other-message-id", original_message=SimpleNamespace()),
SimpleNamespace(message_id="real-message-id", original_message=target_message),
]
assert runtime.find_source_message_by_id("real-message-id") is target_message
assert runtime.find_source_message_by_id("missing-message-id") is None

View File

@@ -1,241 +0,0 @@
from types import SimpleNamespace
from typing import Any, Dict
import pytest
from src.core.tooling import ToolInvocation
from src.maisaka.builtin_tool import query_memory as query_memory_tool
from src.maisaka.builtin_tool.context import BuiltinToolRuntimeContext
from src.services.memory_service import MemoryHit, MemorySearchResult
def _build_tool_ctx(
*,
session_id: str = "session-1",
platform: str = "qq",
user_id: str = "user-1",
group_id: str = "",
) -> BuiltinToolRuntimeContext:
runtime = SimpleNamespace(
session_id=session_id,
chat_stream=SimpleNamespace(
platform=platform,
user_id=user_id,
group_id=group_id,
),
log_prefix=f"[{session_id}]",
)
return BuiltinToolRuntimeContext(engine=SimpleNamespace(), runtime=runtime)
def _build_invocation(arguments: Dict[str, Any]) -> ToolInvocation:
return ToolInvocation(
tool_name="query_memory",
arguments=dict(arguments),
call_id="call-query-memory",
)
@pytest.fixture(autouse=True)
def _patch_maisaka_config(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
query_memory_tool,
"global_config",
SimpleNamespace(memory=SimpleNamespace(memory_query_default_limit=5)),
)
@pytest.mark.asyncio
async def test_query_memory_rejects_empty_query_and_time(monkeypatch: pytest.MonkeyPatch) -> None:
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
_ = query
_ = kwargs
raise AssertionError("参数校验失败时不应调用 memory_service.search")
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
result = await query_memory_tool.handle_tool(
_build_tool_ctx(),
_build_invocation({"query": "", "time_start": "", "time_end": ""}),
)
assert result.success is False
assert "query_memory 需要提供 query" in result.error_message
@pytest.mark.asyncio
async def test_query_memory_private_chat_auto_sets_person_id(monkeypatch: pytest.MonkeyPatch) -> None:
captured: Dict[str, Any] = {}
def fake_resolve_person_id_for_memory(
*,
person_name: str = "",
platform: str = "",
user_id: Any = None,
strict_known: bool = False,
) -> str:
_ = strict_known
captured["resolve_args"] = {
"person_name": person_name,
"platform": platform,
"user_id": user_id,
}
return "pid-private-auto"
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
captured["query"] = query
captured["search_kwargs"] = dict(kwargs)
return MemorySearchResult(
summary="检索摘要",
hits=[MemoryHit(content="Alice 喜欢咖啡", score=0.91)],
)
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
result = await query_memory_tool.handle_tool(
_build_tool_ctx(session_id="private-session", platform="qq", user_id="alice", group_id=""),
_build_invocation({"query": "Alice 的喜好"}),
)
assert result.success is True
assert captured["query"] == "Alice 的喜好"
assert captured["resolve_args"] == {
"person_name": "",
"platform": "qq",
"user_id": "alice",
}
assert captured["search_kwargs"]["chat_id"] == "private-session"
assert captured["search_kwargs"]["user_id"] == "alice"
assert captured["search_kwargs"]["group_id"] == ""
assert captured["search_kwargs"]["person_id"] == "pid-private-auto"
assert isinstance(result.structured_content, dict)
assert result.structured_content["person_id"] == "pid-private-auto"
@pytest.mark.asyncio
async def test_query_memory_group_chat_does_not_attach_default_person_id(monkeypatch: pytest.MonkeyPatch) -> None:
call_counter = {"resolve": 0}
captured_kwargs: Dict[str, Any] = {}
def fake_resolve_person_id_for_memory(
*,
person_name: str = "",
platform: str = "",
user_id: Any = None,
strict_known: bool = False,
) -> str:
_ = person_name
_ = platform
_ = user_id
_ = strict_known
call_counter["resolve"] += 1
return "unexpected-person-id"
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
_ = query
captured_kwargs.update(kwargs)
return MemorySearchResult(summary="", hits=[])
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
result = await query_memory_tool.handle_tool(
_build_tool_ctx(session_id="group-session", platform="qq", user_id="alice", group_id="group-1"),
_build_invocation({"query": "群聊上下文"}),
)
assert result.success is True
assert call_counter["resolve"] == 0
assert captured_kwargs["chat_id"] == "group-session"
assert captured_kwargs["group_id"] == "group-1"
assert captured_kwargs["person_id"] == ""
@pytest.mark.asyncio
async def test_query_memory_search_failure_is_returned(monkeypatch: pytest.MonkeyPatch) -> None:
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
_ = query
_ = kwargs
return MemorySearchResult(success=False, error="boom")
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", lambda **kwargs: "")
result = await query_memory_tool.handle_tool(
_build_tool_ctx(),
_build_invocation({"query": "测试失败透传"}),
)
assert result.success is False
assert result.error_message == "boom"
assert isinstance(result.structured_content, dict)
assert result.structured_content["success"] is False
@pytest.mark.asyncio
async def test_query_memory_prefers_person_name_resolution(monkeypatch: pytest.MonkeyPatch) -> None:
captured: Dict[str, Any] = {"resolve_calls": []}
def fake_resolve_person_id_for_memory(
*,
person_name: str = "",
platform: str = "",
user_id: Any = None,
strict_known: bool = False,
) -> str:
_ = strict_known
captured["resolve_calls"].append(
{
"person_name": person_name,
"platform": platform,
"user_id": user_id,
}
)
if person_name:
return "pid-by-name"
return "pid-private-auto"
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
_ = query
captured["search_kwargs"] = dict(kwargs)
return MemorySearchResult(summary="", hits=[MemoryHit(content="命中1")])
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
result = await query_memory_tool.handle_tool(
_build_tool_ctx(session_id="private-session", platform="qq", user_id="alice", group_id=""),
_build_invocation({"query": "小明资料", "person_name": "小明"}),
)
assert result.success is True
assert captured["resolve_calls"][0] == {
"person_name": "小明",
"platform": "qq",
"user_id": "alice",
}
assert captured["search_kwargs"]["person_id"] == "pid-by-name"
assert result.structured_content["person_name"] == "小明"
assert result.structured_content["person_id"] == "pid-by-name"
@pytest.mark.asyncio
async def test_query_memory_no_hit_returns_readable_message(monkeypatch: pytest.MonkeyPatch) -> None:
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
_ = query
_ = kwargs
return MemorySearchResult(summary="", hits=[])
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", lambda **kwargs: "")
result = await query_memory_tool.handle_tool(
_build_tool_ctx(),
_build_invocation({"query": "不存在的记忆"}),
)
assert result.success is True
assert "未找到匹配的长期记忆" in result.content
assert isinstance(result.structured_content, dict)
assert result.structured_content["query"] == "不存在的记忆"

View File

@@ -1,105 +0,0 @@
from types import SimpleNamespace
import pytest
import time
from src.chat.heart_flow import heartflow_manager as heartflow_manager_module
from src.chat.heart_flow.heartflow_manager import HEARTFLOW_ACTIVE_RETENTION_SECONDS, HeartflowManager
from src.learners.expression_learner import ExpressionLearner
from src.maisaka.runtime import MAX_RETAINED_MESSAGE_CACHE_SIZE, MaisakaHeartFlowChatting
def _build_runtime_with_messages(message_count: int) -> MaisakaHeartFlowChatting:
runtime = object.__new__(MaisakaHeartFlowChatting)
runtime.log_prefix = "[test]"
runtime.message_cache = [SimpleNamespace(message_id=f"msg-{index}") for index in range(message_count)]
runtime._last_processed_index = message_count
runtime._expression_learner = ExpressionLearner("session-1")
runtime._expression_learner.mark_all_processed(runtime.message_cache)
return runtime
def test_prune_processed_message_cache_keeps_bounded_recent_window() -> None:
runtime = _build_runtime_with_messages(MAX_RETAINED_MESSAGE_CACHE_SIZE + 25)
runtime._prune_processed_message_cache()
assert len(runtime.message_cache) == MAX_RETAINED_MESSAGE_CACHE_SIZE
assert runtime.message_cache[0].message_id == "msg-25"
assert runtime._last_processed_index == MAX_RETAINED_MESSAGE_CACHE_SIZE
assert runtime._expression_learner.last_processed_index == MAX_RETAINED_MESSAGE_CACHE_SIZE
def test_prune_processed_message_cache_keeps_unlearned_messages() -> None:
runtime = _build_runtime_with_messages(MAX_RETAINED_MESSAGE_CACHE_SIZE + 25)
runtime._expression_learner.discard_processed_prefix(MAX_RETAINED_MESSAGE_CACHE_SIZE + 5)
runtime._prune_processed_message_cache()
assert len(runtime.message_cache) == MAX_RETAINED_MESSAGE_CACHE_SIZE + 5
assert runtime.message_cache[0].message_id == "msg-20"
assert runtime._expression_learner.last_processed_index == 0
def test_collect_pending_messages_uses_single_pending_received_time() -> None:
runtime = _build_runtime_with_messages(2)
runtime._last_processed_index = 0
runtime._oldest_pending_message_received_at = 123.0
runtime._last_message_received_at = 456.0
runtime._reply_latency_measurement_started_at = None
pending_messages = runtime._collect_pending_messages()
assert [message.message_id for message in pending_messages] == ["msg-0", "msg-1"]
assert runtime._reply_latency_measurement_started_at == 123.0
assert runtime._oldest_pending_message_received_at is None
@pytest.mark.asyncio
async def test_heartflow_manager_evicts_lru_chat_over_limit(monkeypatch: pytest.MonkeyPatch) -> None:
manager = HeartflowManager()
stopped_session_ids: list[str] = []
old_active_at = time.time() - HEARTFLOW_ACTIVE_RETENTION_SECONDS - 1
class FakeChat:
def __init__(self, session_id: str) -> None:
self.session_id = session_id
async def stop(self) -> None:
stopped_session_ids.append(self.session_id)
monkeypatch.setattr(heartflow_manager_module, "HEARTFLOW_MAX_ACTIVE_CHATS", 2)
manager.heartflow_chat_list["session-1"] = FakeChat("session-1")
manager.heartflow_chat_list["session-2"] = FakeChat("session-2")
manager.heartflow_chat_list["session-3"] = FakeChat("session-3")
manager._chat_last_active_at["session-1"] = old_active_at
manager._chat_last_active_at["session-2"] = old_active_at
manager._chat_last_active_at["session-3"] = time.time()
await manager._evict_over_limit_chats(protected_session_id="session-3")
assert stopped_session_ids == ["session-1"]
assert list(manager.heartflow_chat_list) == ["session-2", "session-3"]
@pytest.mark.asyncio
async def test_heartflow_manager_keeps_recent_chats_even_over_limit(monkeypatch: pytest.MonkeyPatch) -> None:
manager = HeartflowManager()
stopped_session_ids: list[str] = []
class FakeChat:
def __init__(self, session_id: str) -> None:
self.session_id = session_id
async def stop(self) -> None:
stopped_session_ids.append(self.session_id)
monkeypatch.setattr(heartflow_manager_module, "HEARTFLOW_MAX_ACTIVE_CHATS", 2)
for session_id in ("session-1", "session-2", "session-3"):
manager.heartflow_chat_list[session_id] = FakeChat(session_id)
manager._chat_last_active_at[session_id] = time.time()
await manager._evict_over_limit_chats(protected_session_id="session-3")
assert stopped_session_ids == []
assert list(manager.heartflow_chat_list) == ["session-1", "session-2", "session-3"]

View File

@@ -1,54 +0,0 @@
from datetime import datetime
from pathlib import Path
import sys
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from src.llm_models.payload_content.tool_option import ToolCall
from src.maisaka.message_adapter import build_message, get_message_kind, get_message_role, get_tool_call_id, get_tool_calls
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
def test_build_message_returns_session_message_with_maisaka_metadata() -> None:
timestamp = datetime.now()
tool_call = ToolCall(
call_id="call-1",
func_name="reply",
args={"message_id": "msg-1"},
)
raw_message = MessageSequence(components=[TextComponent(text="内部消息内容")])
message = build_message(
role="assistant",
content="展示消息内容",
message_kind="perception",
source="assistant",
tool_call_id="call-1",
tool_calls=[tool_call],
timestamp=timestamp,
message_id="maisaka-msg-1",
raw_message=raw_message,
display_text="展示消息内容",
)
assert isinstance(message, SessionMessage)
assert message.initialized is True
assert message.message_id == "maisaka-msg-1"
assert message.timestamp == timestamp
assert message.processed_plain_text == "展示消息内容"
assert message.raw_message is raw_message
assert get_message_role(message) == "assistant"
assert get_message_kind(message) == "perception"
assert get_tool_call_id(message) == "call-1"
restored_tool_calls = get_tool_calls(message)
assert len(restored_tool_calls) == 1
assert restored_tool_calls[0].call_id == "call-1"
assert restored_tool_calls[0].func_name == "reply"
assert restored_tool_calls[0].args == {"message_id": "msg-1"}

View File

@@ -1,619 +0,0 @@
from types import SimpleNamespace
from typing import Any, Callable
import pytest
from rich.panel import Panel
from rich.text import Text
from src.chat.replyer import maisaka_generator as replyer_module
from src.common.data_models.reply_generation_data_models import (
GenerationMetrics,
LLMCompletionResult,
ReplyGenerationResult,
)
from src.core.tooling import ToolExecutionResult, ToolInvocation
from src.maisaka.builtin_tool.context import BuiltinToolRuntimeContext
from src.maisaka.builtin_tool import reply as reply_tool_module
from src.maisaka.builtin_tool import send_emoji as send_emoji_tool_module
from src.maisaka.monitor_events import emit_planner_finalized
from src.maisaka.reasoning_engine import MaisakaReasoningEngine
from src.maisaka import runtime as runtime_module
from src.maisaka.runtime import MaisakaHeartFlowChatting
def test_runtime_maps_expression_config_flags_to_correct_fields(monkeypatch: pytest.MonkeyPatch) -> None:
fake_chat_stream = SimpleNamespace(
is_group_session=True,
group_id="group-1",
user_id="user-1",
platform="test",
)
monkeypatch.setattr(
runtime_module.chat_manager,
"get_session_by_session_id",
lambda session_id: fake_chat_stream,
)
monkeypatch.setattr(runtime_module.chat_manager, "get_session_name", lambda session_id: "测试会话")
monkeypatch.setattr(
runtime_module.ExpressionConfigUtils,
"get_expression_config_for_chat",
staticmethod(lambda session_id: (True, False, True)),
)
monkeypatch.setattr(runtime_module, "ExpressionLearner", lambda session_id: SimpleNamespace())
monkeypatch.setattr(runtime_module, "JargonMiner", lambda session_id, session_name: SimpleNamespace())
monkeypatch.setattr(runtime_module, "MaisakaReasoningEngine", lambda runtime: SimpleNamespace())
monkeypatch.setattr(runtime_module, "ToolRegistry", lambda: SimpleNamespace())
monkeypatch.setattr(runtime_module, "ReplyEffectTracker", lambda **kwargs: SimpleNamespace())
monkeypatch.setattr(MaisakaHeartFlowChatting, "_register_tool_providers", lambda self: None)
monkeypatch.setattr(MaisakaHeartFlowChatting, "_emit_monitor_session_start", lambda self: None)
runtime = MaisakaHeartFlowChatting("session-1")
assert runtime._enable_expression_use is True
assert runtime._enable_expression_learning is False
assert runtime._enable_jargon_learning is True
class _FakeLLMResult:
def __init__(self) -> None:
self.response = "测试回复"
self.reasoning = "先理解上下文,再给出自然回复。"
self.model_name = "fake-model"
self.tool_calls = []
self.prompt_tokens = 12
self.completion_tokens = 7
self.total_tokens = 19
class _FakeLegacyLLMServiceClient:
def __init__(self, *args: Any, **kwargs: Any) -> None:
del args
del kwargs
async def generate_response_with_messages(self, *, message_factory: Callable[[object], list[Any]]) -> _FakeLLMResult:
assert message_factory(object())
return _FakeLLMResult()
class _FakeMultimodalLLMServiceClient:
def __init__(self, *args: Any, **kwargs: Any) -> None:
del args
del kwargs
async def generate_response_with_messages(self, *, message_factory: Callable[[object], list[Any]]) -> _FakeLLMResult:
assert message_factory(object())
return _FakeLLMResult()
@pytest.mark.asyncio
async def test_legacy_and_multimodal_replyer_monitor_detail_have_same_shape(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(replyer_module, "LLMServiceClient", _FakeLegacyLLMServiceClient)
monkeypatch.setattr(replyer_module, "load_prompt", lambda *args, **kwargs: "legacy prompt")
legacy_generator = replyer_module.MaisakaReplyGenerator(
chat_stream=None,
request_type="test_legacy",
enable_visual_message=False,
)
multimodal_generator = replyer_module.MaisakaReplyGenerator(
chat_stream=None,
request_type="test_multi",
llm_client_cls=_FakeMultimodalLLMServiceClient,
load_prompt_func=lambda *args, **kwargs: "multi prompt",
enable_visual_message=True,
)
legacy_success, legacy_result = await legacy_generator.generate_reply_with_context(
stream_id="session-legacy",
chat_history=[],
reply_reason="测试原因",
)
multimodal_success, multimodal_result = await multimodal_generator.generate_reply_with_context(
stream_id="session-multi",
chat_history=[],
reply_reason="测试原因",
)
assert legacy_success is True
assert multimodal_success is True
assert legacy_result.monitor_detail is not None
assert multimodal_result.monitor_detail is not None
assert set(legacy_result.monitor_detail.keys()) == set(multimodal_result.monitor_detail.keys())
assert set(legacy_result.monitor_detail["metrics"].keys()) == set(multimodal_result.monitor_detail["metrics"].keys())
assert legacy_result.monitor_detail["metrics"]["prompt_tokens"] == 12
assert legacy_result.monitor_detail["metrics"]["completion_tokens"] == 7
assert legacy_result.monitor_detail["metrics"]["total_tokens"] == 19
def test_legacy_replyer_builds_message_sequence_like_multimodal() -> None:
legacy_generator = replyer_module.MaisakaReplyGenerator(
chat_stream=None,
request_type="test_legacy",
enable_visual_message=False,
)
legacy_prompt_loader = replyer_module.load_prompt
replyer_module.load_prompt = lambda *args, **kwargs: "legacy prompt"
try:
session_message = replyer_module.SessionBackedMessage(
raw_message=SimpleNamespace(),
visible_text="[Alice]你好\n[Bob]在吗",
timestamp=replyer_module.datetime.now(),
source_kind="user",
)
request_messages = legacy_generator._build_request_messages(
chat_history=[session_message],
reply_message=None,
reply_reason="测试原因",
stream_id="session-legacy",
)
finally:
replyer_module.load_prompt = legacy_prompt_loader
assert len(request_messages) == 4
assert request_messages[0].role.value == "system"
assert request_messages[0].get_text_content() == "legacy prompt"
assert request_messages[1].role.value == "user"
assert request_messages[1].get_text_content() == "[Alice]你好"
assert request_messages[2].role.value == "user"
assert request_messages[2].get_text_content() == "[Bob]在吗"
assert request_messages[3].role.value == "user"
assert "当前时间:" in request_messages[3].get_text_content()
assert "【回复信息参考】" in request_messages[3].get_text_content()
assert "【最新推理】\n测试原因" in request_messages[3].get_text_content()
assert "请自然地回复。" in request_messages[3].get_text_content()
@pytest.mark.asyncio
async def test_reply_tool_puts_monitor_detail_into_metadata(monkeypatch: pytest.MonkeyPatch) -> None:
fake_monitor_detail = {
"prompt_text": "reply prompt",
"reasoning_text": "reply reasoning",
"output_text": "reply output",
"metrics": {"model_name": "fake-model", "total_tokens": 10},
}
fake_reply_result = ReplyGenerationResult(
success=True,
completion=LLMCompletionResult(response_text="测试回复"),
metrics=GenerationMetrics(overall_ms=11.5),
monitor_detail=fake_monitor_detail,
)
class _FakeReplyer:
async def generate_reply_with_context(self, **kwargs: Any) -> tuple[bool, ReplyGenerationResult]:
del kwargs
return True, fake_reply_result
monkeypatch.setattr(reply_tool_module.replyer_manager, "get_replyer", lambda **kwargs: _FakeReplyer())
monkeypatch.setattr(reply_tool_module, "render_cli_message", lambda text: text)
target_message = SimpleNamespace(
message_id="msg-1",
message_info=SimpleNamespace(
user_info=SimpleNamespace(
user_cardname="测试用户",
user_nickname="测试用户",
user_id="user-1",
)
),
)
runtime = SimpleNamespace(
find_source_message_by_id=lambda message_id: target_message if message_id == "msg-1" else None,
log_prefix="[test]",
chat_stream=SimpleNamespace(platform=reply_tool_module.CLI_PLATFORM_NAME),
session_id="session-1",
_chat_history=[],
_clear_force_continue_until_reply=lambda: None,
_record_reply_sent=lambda: None,
run_sub_agent=None,
)
engine = SimpleNamespace(_get_runtime_manager=lambda: None)
tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime)
invocation = ToolInvocation(tool_name="reply", arguments={"msg_id": "msg-1", "set_quote": True})
result = await reply_tool_module.handle_tool(tool_ctx, invocation)
assert result.success is True
assert result.metadata["monitor_detail"] == fake_monitor_detail
@pytest.mark.asyncio
async def test_send_emoji_tool_puts_monitor_detail_into_metadata(monkeypatch: pytest.MonkeyPatch) -> None:
async def _fake_build_emoji_candidate_message(emojis: list[Any]) -> object:
assert emojis
return SimpleNamespace()
async def _fake_send_emoji_for_maisaka(**kwargs: Any) -> Any:
selected_emoji, matched_emotion = await kwargs["emoji_selector"](
kwargs["requested_emotion"],
kwargs["reasoning"],
kwargs["context_texts"],
2,
)
assert selected_emoji is not None
return SimpleNamespace(
success=True,
message="已发送表情包:开心",
emoji_base64="ZW1vamk=",
description="开心",
emotions=["开心", "可爱"],
matched_emotion=matched_emotion or "开心",
sent_message=None,
)
monkeypatch.setattr(send_emoji_tool_module, "_build_emoji_candidate_message", _fake_build_emoji_candidate_message)
monkeypatch.setattr(send_emoji_tool_module, "send_emoji_for_maisaka", _fake_send_emoji_for_maisaka)
monkeypatch.setattr(
send_emoji_tool_module.emoji_manager,
"emojis",
[
SimpleNamespace(description="开心,可爱", emotion=["开心", "可爱"]),
SimpleNamespace(description="难过", emotion=["难过"]),
],
)
async def _fake_run_sub_agent(**kwargs: Any) -> Any:
del kwargs
return SimpleNamespace(
content='{"emoji_index": 1, "reason": "更贴合当前语气"}',
prompt_tokens=9,
completion_tokens=6,
total_tokens=15,
)
runtime = SimpleNamespace(
_chat_history=[],
log_prefix="[test]",
session_id="session-emoji",
run_sub_agent=_fake_run_sub_agent,
)
engine = SimpleNamespace(last_reasoning_content="用户刚刚表达了开心情绪")
tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime)
invocation = ToolInvocation(tool_name="send_emoji", arguments={"emotion": "开心"})
result = await send_emoji_tool_module.handle_tool(tool_ctx, invocation)
assert result.success is True
assert result.metadata["monitor_detail"]["prompt_text"]
assert result.metadata["monitor_detail"]["reasoning_text"] == "更贴合当前语气"
assert result.metadata["monitor_detail"]["metrics"]["total_tokens"] == 15
assert any(
section["title"] == "表情发送结果"
for section in result.metadata["monitor_detail"]["extra_sections"]
)
@pytest.mark.asyncio
async def test_emit_planner_finalized_broadcasts_new_protocol(monkeypatch: pytest.MonkeyPatch) -> None:
captured: dict[str, Any] = {}
async def _fake_broadcast(event: str, data: dict[str, Any]) -> None:
captured["event"] = event
captured["data"] = data
monkeypatch.setattr("src.maisaka.monitor_events._broadcast", _fake_broadcast)
await emit_planner_finalized(
session_id="session-1",
cycle_id=3,
timing_request_messages=[{"role": "user", "content": "先看看要不要继续"}],
timing_selected_history_count=3,
timing_tool_count=1,
timing_action="continue",
timing_content="继续",
timing_tool_calls=[SimpleNamespace(call_id="timing-call-1", func_name="continue", args={})],
timing_tool_results=["- continue [成功]: 继续执行"],
timing_prompt_tokens=40,
timing_completion_tokens=5,
timing_total_tokens=45,
timing_duration_ms=11.2,
planner_request_messages=[{"role": "user", "content": "你好"}],
planner_selected_history_count=5,
planner_tool_count=2,
planner_content="先查询再回复",
planner_tool_calls=[SimpleNamespace(call_id="call-1", func_name="reply", args={"msg_id": "m1"})],
planner_prompt_tokens=100,
planner_completion_tokens=30,
planner_total_tokens=130,
planner_duration_ms=88.5,
tools=[
{
"tool_call_id": "call-1",
"tool_name": "reply",
"tool_args": {"msg_id": "m1"},
"success": True,
"duration_ms": 22.0,
"summary": "- reply [成功]: 已回复",
"detail": {"output_text": "测试回复"},
}
],
time_records={"planner": 0.1, "tool_calls": 0.2},
agent_state="stop",
)
assert captured["event"] == "planner.finalized"
payload = captured["data"]
assert payload["timing_gate"]["result"]["action"] == "continue"
assert payload["timing_gate"]["result"]["tool_results"] == ["- continue [成功]: 继续执行"]
assert payload["request"]["messages"][0]["content"] == "你好"
assert payload["request"]["tool_count"] == 2
assert payload["planner"]["tool_calls"][0]["id"] == "call-1"
assert payload["tools"][0]["detail"]["output_text"] == "测试回复"
assert payload["final_state"]["agent_state"] == "stop"
@pytest.mark.asyncio
async def test_emit_planner_finalized_supports_timing_only_cycle(monkeypatch: pytest.MonkeyPatch) -> None:
captured: dict[str, Any] = {}
async def _fake_broadcast(event: str, data: dict[str, Any]) -> None:
captured["event"] = event
captured["data"] = data
monkeypatch.setattr("src.maisaka.monitor_events._broadcast", _fake_broadcast)
await emit_planner_finalized(
session_id="session-2",
cycle_id=7,
timing_request_messages=[{"role": "user", "content": "先别回"}],
timing_selected_history_count=2,
timing_tool_count=1,
timing_action="no_reply",
timing_content="当前不适合继续",
timing_tool_calls=[SimpleNamespace(call_id="timing-call-2", func_name="no_reply", args={})],
timing_tool_results=["- no_reply [成功]: 暂停当前对话"],
timing_prompt_tokens=18,
timing_completion_tokens=4,
timing_total_tokens=22,
timing_duration_ms=6.5,
planner_request_messages=None,
planner_selected_history_count=None,
planner_tool_count=None,
planner_content=None,
planner_tool_calls=None,
planner_prompt_tokens=None,
planner_completion_tokens=None,
planner_total_tokens=None,
planner_duration_ms=None,
tools=[],
time_records={"timing_gate": 0.02},
agent_state="stop",
)
assert captured["event"] == "planner.finalized"
payload = captured["data"]
assert payload["timing_gate"]["result"]["action"] == "no_reply"
assert payload["planner"] is None
assert payload["request"] is None
def test_reasoning_engine_build_tool_monitor_result_keeps_non_reply_tool_without_detail() -> None:
engine = object.__new__(MaisakaReasoningEngine)
tool_call = SimpleNamespace(call_id="call-2", func_name="query_memory")
invocation = ToolInvocation(tool_name="query_memory", arguments={"query": "Alice"})
result = ToolExecutionResult(tool_name="query_memory", success=True, content="查询成功")
tool_result = engine._build_tool_monitor_result(tool_call, invocation, result, duration_ms=18.6)
assert tool_result["tool_call_id"] == "call-2"
assert tool_result["tool_name"] == "query_memory"
assert tool_result["tool_args"] == {"query": "Alice"}
assert tool_result["detail"] is None
def test_runtime_build_tool_detail_panels_renders_reply_monitor_detail() -> None:
runtime = object.__new__(MaisakaHeartFlowChatting)
runtime.session_id = "session-1"
panels = runtime._build_tool_detail_cards(
[
{
"tool_call_id": "call-reply-1",
"tool_name": "reply",
"tool_args": {"msg_id": "m1"},
"success": True,
"duration_ms": 20.5,
"summary": "- reply [成功]: 已回复",
"detail": {
"prompt_text": "reply prompt",
"reasoning_text": "reply reasoning",
"output_text": "reply output",
"metrics": {
"model_name": "fake-model",
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15,
"prompt_ms": 2.1,
"llm_ms": 18.4,
"overall_ms": 20.5,
},
},
}
],
stage_title="工具调用",
)
assert len(panels) == 1
assert isinstance(panels[0], Panel)
def test_runtime_filter_redundant_tool_results_keeps_only_non_detailed_summary() -> None:
filtered_results = MaisakaHeartFlowChatting._filter_redundant_tool_results(
tool_results=[
"- reply [成功]: 已回复",
"- query_memory [成功]: 查询到 2 条记录",
],
tool_detail_results=[
{
"summary": "- reply [成功]: 已回复",
"detail": {"output_text": "测试回复"},
}
],
)
assert filtered_results == ["- query_memory [成功]: 查询到 2 条记录"]
def test_runtime_build_tool_detail_panels_uses_prompt_access_panel(monkeypatch: pytest.MonkeyPatch) -> None:
runtime = object.__new__(MaisakaHeartFlowChatting)
runtime.session_id = "session-link"
captured: dict[str, Any] = {}
def _fake_build_text_access_panel(content: str, **kwargs: Any) -> str:
captured["content"] = content
captured["kwargs"] = kwargs
return "PROMPT_LINK"
monkeypatch.setattr(
"src.maisaka.runtime.PromptCLIVisualizer.build_text_access_panel",
_fake_build_text_access_panel,
)
panels = runtime._build_tool_detail_cards(
[
{
"tool_call_id": "call-reply-2",
"tool_name": "reply",
"tool_args": {"msg_id": "m2"},
"success": True,
"duration_ms": 12.0,
"summary": "- reply [成功]: 已回复",
"detail": {
"prompt_text": "reply prompt link",
"output_text": "reply output",
},
}
],
stage_title="工具调用",
)
assert len(panels) == 1
assert captured["content"] == "reply prompt link"
assert captured["kwargs"]["chat_id"] == "session-link"
assert captured["kwargs"]["request_kind"] == "replyer"
def test_runtime_build_tool_detail_panels_uses_emotion_prompt_access_panel(monkeypatch: pytest.MonkeyPatch) -> None:
runtime = object.__new__(MaisakaHeartFlowChatting)
runtime.session_id = "session-emotion"
captured: dict[str, Any] = {}
def _fake_build_text_access_panel(content: str, **kwargs: Any) -> str:
captured["content"] = content
captured["kwargs"] = kwargs
return "EMOTION_PROMPT_LINK"
monkeypatch.setattr(
"src.maisaka.runtime.PromptCLIVisualizer.build_text_access_panel",
_fake_build_text_access_panel,
)
panels = runtime._build_tool_detail_cards(
[
{
"tool_call_id": "call-emoji-1",
"tool_name": "send_emoji",
"tool_args": {"emotion": "开心"},
"success": True,
"duration_ms": 15.0,
"summary": "- send_emoji [成功]: 已发送表情包",
"detail": {
"prompt_text": "emotion prompt link",
"output_text": '{"emoji_index": 1}',
},
}
],
stage_title="工具调用",
)
assert len(panels) == 1
assert captured["content"] == "emotion prompt link"
assert captured["kwargs"]["chat_id"] == "session-emotion"
assert captured["kwargs"]["request_kind"] == "emotion"
def test_runtime_build_tool_detail_cards_uses_structured_prompt_messages_with_images(
monkeypatch: pytest.MonkeyPatch,
) -> None:
runtime = object.__new__(MaisakaHeartFlowChatting)
runtime.session_id = "session-image"
captured: dict[str, Any] = {}
def _fake_build_prompt_access_panel(messages: list[Any], **kwargs: Any) -> str:
captured["messages"] = messages
captured["kwargs"] = kwargs
return "IMAGE_PROMPT_LINK"
def _fake_build_text_access_panel(content: str, **kwargs: Any) -> str:
captured["text_content"] = content
captured["text_kwargs"] = kwargs
return "TEXT_PROMPT_LINK"
monkeypatch.setattr(
"src.maisaka.runtime.PromptCLIVisualizer.build_prompt_access_panel",
_fake_build_prompt_access_panel,
)
monkeypatch.setattr(
"src.maisaka.runtime.PromptCLIVisualizer.build_text_access_panel",
_fake_build_text_access_panel,
)
panels = runtime._build_tool_detail_cards(
[
{
"tool_call_id": "call-reply-image-1",
"tool_name": "reply",
"tool_args": {"msg_id": "m3"},
"success": True,
"duration_ms": 22.0,
"summary": "- reply [成功]: 已回复",
"detail": {
"prompt_text": "reply prompt image",
"request_messages": [
{
"role": "user",
"content": ["前缀文本", ["png", "ZmFrZQ=="]],
}
],
"output_text": "reply output",
},
}
],
stage_title="工具调用",
)
assert len(panels) == 1
assert "messages" in captured
assert "text_content" not in captured
assert captured["kwargs"]["chat_id"] == "session-image"
assert captured["kwargs"]["request_kind"] == "replyer"
def test_runtime_render_context_usage_panel_merges_timing_and_planner(monkeypatch: pytest.MonkeyPatch) -> None:
runtime = object.__new__(MaisakaHeartFlowChatting)
runtime.session_id = "session-merged"
runtime.session_name = "测试聊天流"
runtime._max_context_size = 20
printed: list[Any] = []
monkeypatch.setattr("src.maisaka.runtime.console.print", lambda renderable: printed.append(renderable))
runtime._render_context_usage_panel(
cycle_id=12,
timing_selected_history_count=3,
timing_prompt_tokens=15,
timing_action="continue",
timing_response="继续执行",
planner_selected_history_count=5,
planner_prompt_tokens=42,
planner_response="先查询再回复",
)
assert len(printed) == 1
outer_panel = printed[0]
assert isinstance(outer_panel, Panel)
renderables = list(outer_panel.renderable.renderables)
assert isinstance(renderables[0], Text)
assert "聊天流名称:测试聊天流" in renderables[0].plain
assert "聊天流IDsession-merged" in renderables[0].plain
assert len(renderables) == 3

View File

@@ -1,339 +0,0 @@
from datetime import datetime
from types import SimpleNamespace
import asyncio
import pytest
from src.core.tooling import ToolAvailabilityContext, ToolExecutionResult, ToolInvocation
from src.llm_models.payload_content.tool_option import ToolCall
from src.maisaka.builtin_tool import get_timing_tools
from src.maisaka.chat_loop_service import ChatResponse, MaisakaChatLoopService
from src.maisaka.context_messages import AssistantMessage, TIMING_GATE_INVALID_TOOL_HINT_SOURCE
from src.maisaka.reasoning_engine import MaisakaReasoningEngine
from src.maisaka.runtime import MaisakaHeartFlowChatting
def _build_chat_response(tool_calls: list[ToolCall]) -> ChatResponse:
return ChatResponse(
content="The model returned an invalid timing tool.",
tool_calls=tool_calls,
request_messages=[],
raw_message=AssistantMessage(
content="",
timestamp=datetime.now(),
source_kind="perception",
),
selected_history_count=1,
tool_count=len(tool_calls),
prompt_tokens=10,
built_message_count=1,
completion_tokens=3,
total_tokens=13,
prompt_section=None,
)
def _build_runtime_stub(*, is_group_chat: bool) -> SimpleNamespace:
return SimpleNamespace(
_force_next_timing_continue=False,
_chat_history=[],
session_id="test-session",
chat_stream=SimpleNamespace(
session_id="test-session",
stream_id="test-stream",
is_group_session=is_group_chat,
group_id="group-1" if is_group_chat else "",
user_id="user-1",
platform="qq",
),
_chat_loop_service=SimpleNamespace(build_prompt_template_context=lambda: {}),
log_prefix="[test]",
stopped=False,
)
def test_timing_gate_tools_expose_wait_only_in_private_chat() -> None:
private_tool_names = {
tool_definition["name"]
for tool_definition in get_timing_tools(ToolAvailabilityContext(is_group_chat=False))
}
group_tool_names = {
tool_definition["name"]
for tool_definition in get_timing_tools(ToolAvailabilityContext(is_group_chat=True))
}
assert private_tool_names == {"continue", "no_reply", "wait"}
assert group_tool_names == {"continue", "no_reply"}
@pytest.mark.asyncio
async def test_timing_gate_invalid_tool_defaults_to_no_reply(monkeypatch: pytest.MonkeyPatch) -> None:
runtime = _build_runtime_stub(is_group_chat=True)
def _enter_stop_state() -> None:
runtime.stopped = True
runtime._enter_stop_state = _enter_stop_state
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
call_count = 0
async def _fake_timing_gate_sub_agent(**kwargs: object) -> ChatResponse:
nonlocal call_count
del kwargs
call_count += 1
return _build_chat_response([
ToolCall(call_id="invalid-timing-tool", func_name="finish", args={}),
])
async def _fail_invoke_tool_call(*args: object, **kwargs: object) -> None:
del args, kwargs
raise AssertionError("invalid timing tools must not be executed")
monkeypatch.setattr(engine, "_run_timing_gate_sub_agent", _fake_timing_gate_sub_agent)
monkeypatch.setattr(engine, "_invoke_tool_call", _fail_invoke_tool_call)
action, response, tool_results, tool_monitor_results = await engine._run_timing_gate(object()) # type: ignore[arg-type]
assert action == "no_reply"
assert call_count == 3
assert response.tool_calls[0].func_name == "finish"
assert runtime.stopped is True
assert tool_monitor_results == []
assert len(runtime._chat_history) == 1
assert runtime._chat_history[0].source == TIMING_GATE_INVALID_TOOL_HINT_SOURCE
assert "finish" in runtime._chat_history[0].processed_plain_text
assert tool_results == [
"- retry [非法 Timing 工具]: 返回了 finish将重试 (1/3)",
"- retry [非法 Timing 工具]: 返回了 finish将重试 (2/3)",
"- no_reply [非法 Timing 工具]: 返回了 finish已停止本轮并等待新消息",
]
@pytest.mark.asyncio
async def test_timing_gate_invalid_tool_retries_until_valid(monkeypatch: pytest.MonkeyPatch) -> None:
runtime = _build_runtime_stub(is_group_chat=True)
def _enter_stop_state() -> None:
runtime.stopped = True
runtime._enter_stop_state = _enter_stop_state
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
responses = [
_build_chat_response([ToolCall(call_id="invalid-timing-tool", func_name="finish", args={})]),
_build_chat_response([ToolCall(call_id="valid-timing-tool", func_name="continue", args={})]),
]
async def _fake_timing_gate_sub_agent(**kwargs: object) -> ChatResponse:
del kwargs
return responses.pop(0)
async def _fake_invoke_tool_call(
tool_call: ToolCall,
latest_thought: str,
anchor_message: object,
*,
append_history: bool = True,
store_record: bool = True,
) -> tuple[ToolInvocation, ToolExecutionResult, None]:
del latest_thought, anchor_message, append_history, store_record
return (
ToolInvocation(tool_name=tool_call.func_name, call_id=tool_call.call_id),
ToolExecutionResult(
tool_name=tool_call.func_name,
success=True,
content="继续执行主流程",
metadata={"timing_action": "continue"},
),
None,
)
monkeypatch.setattr(engine, "_run_timing_gate_sub_agent", _fake_timing_gate_sub_agent)
monkeypatch.setattr(engine, "_invoke_tool_call", _fake_invoke_tool_call)
action, response, tool_results, tool_monitor_results = await engine._run_timing_gate(object()) # type: ignore[arg-type]
assert action == "continue"
assert response.tool_calls[0].func_name == "continue"
assert runtime.stopped is False
assert len(runtime._chat_history) == 2
assert all(message.source != TIMING_GATE_INVALID_TOOL_HINT_SOURCE for message in runtime._chat_history)
assert tool_results == [
"- retry [非法 Timing 工具]: 返回了 finish将重试 (1/3)",
"- continue [成功]: 继续执行主流程",
]
assert tool_monitor_results[0]["tool_name"] == "continue"
@pytest.mark.asyncio
async def test_timing_gate_group_chat_treats_wait_as_invalid(monkeypatch: pytest.MonkeyPatch) -> None:
runtime = _build_runtime_stub(is_group_chat=True)
def _enter_stop_state() -> None:
runtime.stopped = True
runtime._enter_stop_state = _enter_stop_state
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
async def _fake_timing_gate_sub_agent(**kwargs: object) -> ChatResponse:
tool_definitions = kwargs["tool_definitions"]
assert {tool_definition["name"] for tool_definition in tool_definitions} == {"continue", "no_reply"}
return _build_chat_response([
ToolCall(call_id="disabled-wait", func_name="wait", args={"seconds": 3}),
])
async def _fail_invoke_tool_call(*args: object, **kwargs: object) -> None:
del args, kwargs
raise AssertionError("群聊中禁用的 wait 不应被执行")
monkeypatch.setattr(engine, "_run_timing_gate_sub_agent", _fake_timing_gate_sub_agent)
monkeypatch.setattr(engine, "_invoke_tool_call", _fail_invoke_tool_call)
action, _, tool_results, _ = await engine._run_timing_gate(object()) # type: ignore[arg-type]
assert action == "no_reply"
assert runtime.stopped is True
assert tool_results[-1] == "- no_reply [非法 Timing 工具]: 返回了 wait已停止本轮并等待新消息"
def test_timing_gate_invalid_tool_hint_keeps_only_latest() -> None:
old_hint = SimpleNamespace(source=TIMING_GATE_INVALID_TOOL_HINT_SOURCE)
runtime = SimpleNamespace(_chat_history=[old_hint])
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
engine._append_timing_gate_invalid_tool_hint("finish")
engine._append_timing_gate_invalid_tool_hint("reply")
assert len(runtime._chat_history) == 1
hint_message = runtime._chat_history[0]
assert hint_message.source == TIMING_GATE_INVALID_TOOL_HINT_SOURCE
assert "reply" in hint_message.processed_plain_text
assert "finish" not in hint_message.processed_plain_text
def test_timing_gate_invalid_tool_hint_only_visible_to_timing_gate() -> None:
runtime = SimpleNamespace(_chat_history=[])
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
engine._append_timing_gate_invalid_tool_hint("finish")
hint_message = runtime._chat_history[0]
timing_history = MaisakaChatLoopService._filter_history_for_request_kind(
[hint_message],
request_kind="timing_gate",
)
planner_history = MaisakaChatLoopService._filter_history_for_request_kind(
[hint_message],
request_kind="planner",
)
assert timing_history == [hint_message]
assert planner_history == []
def test_forced_timing_trigger_bypasses_message_frequency_threshold() -> None:
runtime = SimpleNamespace(
_STATE_WAIT="wait",
_agent_state="stop",
_message_turn_scheduled=False,
_internal_turn_queue=asyncio.Queue(),
_has_pending_messages=lambda: True,
_get_pending_message_count=lambda: 1,
_has_forced_timing_trigger=lambda: True,
_cancel_deferred_message_turn_task=lambda: None,
)
def _fail_get_message_trigger_threshold() -> int:
raise AssertionError("@/提及必回不应被普通聊天频率阈值拦住")
runtime._get_message_trigger_threshold = _fail_get_message_trigger_threshold
MaisakaHeartFlowChatting._schedule_message_turn(runtime) # type: ignore[arg-type]
assert runtime._message_turn_scheduled is True
assert runtime._internal_turn_queue.get_nowait() == "message"
def test_finish_tool_is_not_written_back_to_history() -> None:
finish_call = ToolCall(call_id="finish-call", func_name="finish", args={})
reply_call = ToolCall(call_id="reply-call", func_name="reply", args={})
assistant_message = AssistantMessage(
content="当前不需要继续回复。",
timestamp=datetime.now(),
tool_calls=[finish_call, reply_call],
)
runtime = SimpleNamespace(_chat_history=[assistant_message])
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
engine._append_tool_execution_result(
finish_call,
ToolExecutionResult(
tool_name="finish",
success=True,
content="当前对话循环已结束本轮思考,等待新的消息到来。",
),
)
assert runtime._chat_history == [assistant_message]
assert [tool_call.func_name for tool_call in assistant_message.tool_calls] == ["reply"]
def test_finish_tool_removes_empty_assistant_history_message() -> None:
finish_call = ToolCall(call_id="finish-call", func_name="finish", args={})
assistant_message = AssistantMessage(
content="",
timestamp=datetime.now(),
tool_calls=[finish_call],
)
runtime = SimpleNamespace(_chat_history=[assistant_message])
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
engine._append_tool_execution_result(
finish_call,
ToolExecutionResult(tool_name="finish", success=True),
)
assert runtime._chat_history == []
def test_timing_gate_head_trim_keeps_short_history() -> None:
messages = [
AssistantMessage(content="第一条消息", timestamp=datetime.now()),
AssistantMessage(content="第二条消息", timestamp=datetime.now()),
]
trimmed_messages = MaisakaHeartFlowChatting._drop_head_context_messages(
messages,
drop_context_count=3,
)
assert trimmed_messages == messages
def test_timing_gate_head_trim_keeps_history_within_config_limit() -> None:
messages = [
AssistantMessage(content=f"消息 {index}", timestamp=datetime.now())
for index in range(10)
]
trimmed_messages = MaisakaHeartFlowChatting._drop_head_context_messages(
messages,
drop_context_count=7,
trim_threshold_context_count=10,
)
assert trimmed_messages == messages
def test_timing_gate_head_trim_applies_after_config_limit_exceeded() -> None:
messages = [
AssistantMessage(content=f"消息 {index}", timestamp=datetime.now())
for index in range(11)
]
trimmed_messages = MaisakaHeartFlowChatting._drop_head_context_messages(
messages,
drop_context_count=7,
trim_threshold_context_count=10,
)
assert trimmed_messages == messages[7:]

View File

@@ -1,170 +0,0 @@
"""消息网关运行时状态同步测试。"""
from typing import Any, Dict
import pytest
from src.platform_io.manager import PlatformIOManager
from src.platform_io.types import RouteKey
from src.plugin_runtime.host.supervisor import PluginSupervisor
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
def _make_request(method: str, plugin_id: str, payload: Dict[str, Any]) -> Envelope:
"""构造一个 RPC 请求信封。
Args:
method: RPC 方法名。
plugin_id: 目标插件 ID。
payload: 请求载荷。
Returns:
Envelope: 标准 RPC 请求信封。
"""
return Envelope(
request_id=1,
message_type=MessageType.REQUEST,
method=method,
plugin_id=plugin_id,
payload=payload,
)
@pytest.mark.asyncio
async def test_message_gateway_runtime_state_binds_send_and_receive_routes(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""消息网关就绪后应同时绑定发送表和接收表。"""
import src.plugin_runtime.host.supervisor as supervisor_module
platform_io_manager = PlatformIOManager()
monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
supervisor = PluginSupervisor(plugin_dirs=[])
register_response = await supervisor._handle_register_plugin(
_make_request(
"plugin.register_components",
"napcat_plugin",
{
"plugin_id": "napcat_plugin",
"plugin_version": "1.0.0",
"components": [
{
"name": "napcat_gateway",
"component_type": "MESSAGE_GATEWAY",
"plugin_id": "napcat_plugin",
"metadata": {
"route_type": "duplex",
"platform": "qq",
"protocol": "napcat",
},
}
],
"capabilities_required": [],
},
)
)
assert register_response.error is None
response = await supervisor._handle_update_message_gateway_state(
_make_request(
"host.update_message_gateway_state",
"napcat_plugin",
{
"gateway_name": "napcat_gateway",
"ready": True,
"platform": "qq",
"account_id": "10001",
"scope": "primary",
"metadata": {},
},
)
)
assert response.error is None
assert response.payload["accepted"] is True
send_bindings = platform_io_manager.send_route_table.resolve_bindings(
RouteKey(platform="qq", account_id="10001", scope="primary")
)
receive_bindings = platform_io_manager.receive_route_table.resolve_bindings(
RouteKey(platform="qq", account_id="10001", scope="primary")
)
assert [binding.driver_id for binding in send_bindings] == ["gateway:napcat_plugin:napcat_gateway"]
assert [binding.driver_id for binding in receive_bindings] == ["gateway:napcat_plugin:napcat_gateway"]
@pytest.mark.asyncio
async def test_message_gateway_runtime_state_unbinds_routes_when_not_ready(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""消息网关断开后应撤销发送表和接收表中的绑定。"""
import src.plugin_runtime.host.supervisor as supervisor_module
platform_io_manager = PlatformIOManager()
monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
supervisor = PluginSupervisor(plugin_dirs=[])
await supervisor._handle_register_plugin(
_make_request(
"plugin.register_components",
"napcat_plugin",
{
"plugin_id": "napcat_plugin",
"plugin_version": "1.0.0",
"components": [
{
"name": "napcat_gateway",
"component_type": "MESSAGE_GATEWAY",
"plugin_id": "napcat_plugin",
"metadata": {
"route_type": "duplex",
"platform": "qq",
"protocol": "napcat",
},
}
],
"capabilities_required": [],
},
)
)
await supervisor._handle_update_message_gateway_state(
_make_request(
"host.update_message_gateway_state",
"napcat_plugin",
{
"gateway_name": "napcat_gateway",
"ready": True,
"platform": "qq",
"account_id": "10001",
"scope": "primary",
"metadata": {},
},
)
)
response = await supervisor._handle_update_message_gateway_state(
_make_request(
"host.update_message_gateway_state",
"napcat_plugin",
{
"gateway_name": "napcat_gateway",
"ready": False,
"platform": "qq",
"account_id": "",
"scope": "",
"metadata": {},
},
)
)
assert response.error is None
assert response.payload["accepted"] is True
assert platform_io_manager.send_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == []
assert (
platform_io_manager.receive_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == []
)

View File

@@ -1,883 +0,0 @@
"""NapCat 插件与新 SDK 对接测试。"""
from importlib import import_module, util
from pathlib import Path
from typing import Any, Dict, List, Tuple
import logging
import sys
from types import SimpleNamespace
import pytest
PROJECT_ROOT = Path(__file__).resolve().parents[1]
PLUGINS_ROOT = PROJECT_ROOT / "plugins"
PLUGIN_TEMPLATE_ROOT = PROJECT_ROOT / "plugin-templates"
SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
NAPCAT_PLUGIN_DIR = PLUGINS_ROOT / "MaiBot-Napcat-Adapter"
NAPCAT_TEMPLATE_DIR = PLUGIN_TEMPLATE_ROOT / "MaiBot-Napcat-Adapter"
NAPCAT_TEST_MODULE = "_test_napcat_adapter"
for import_path in (str(SDK_ROOT),):
if import_path not in sys.path:
sys.path.insert(0, import_path)
class _FakeGatewayCapability:
"""用于捕获消息网关状态上报的测试替身。"""
def __init__(self) -> None:
"""初始化测试替身。"""
self.calls: List[Dict[str, Any]] = []
async def update_state(
self,
gateway_name: str,
*,
ready: bool,
platform: str = "",
account_id: str = "",
scope: str = "",
metadata: Dict[str, Any] | None = None,
) -> bool:
"""记录一次状态上报请求。
Args:
gateway_name: 网关组件名称。
ready: 当前是否就绪。
platform: 平台名称。
account_id: 账号 ID。
scope: 路由作用域。
metadata: 附加元数据。
Returns:
bool: 始终返回 ``True``,模拟 Host 接受状态更新。
"""
self.calls.append(
{
"gateway_name": gateway_name,
"ready": ready,
"platform": platform,
"account_id": account_id,
"scope": scope,
"metadata": metadata or {},
}
)
return True
class _FakeNapCatQueryService:
"""用于驱动 NapCat 入站编解码测试的查询服务替身。"""
def __init__(
self,
forward_payloads: Dict[str, Any] | None = None,
group_member_payloads: Dict[tuple[str, str], Dict[str, Any] | None] | None = None,
stranger_payloads: Dict[str, Dict[str, Any] | None] | None = None,
) -> None:
"""初始化查询服务替身。
Args:
forward_payloads: 预置的合并转发响应映射。
group_member_payloads: 预置的群成员资料映射。
stranger_payloads: 预置的陌生人资料映射。
"""
self._forward_payloads = forward_payloads or {}
self._group_member_payloads = group_member_payloads or {}
self._stranger_payloads = stranger_payloads or {}
async def download_binary(self, url: str) -> bytes | None:
"""模拟下载远程二进制资源。
Args:
url: 资源地址。
Returns:
bytes | None: 测试中默认不返回二进制内容。
"""
del url
return None
async def get_message_detail(self, message_id: str) -> Dict[str, Any] | None:
"""模拟获取消息详情。
Args:
message_id: 消息 ID。
Returns:
Dict[str, Any] | None: 测试中默认不返回详情。
"""
del message_id
return None
async def get_forward_message(
self,
message_id: str | None = None,
forward_id: str | None = None,
) -> Any:
"""模拟获取合并转发消息详情。
Args:
message_id: 转发消息 ID。
forward_id: 兼容字段 ``id``。
Returns:
Any: 预置的合并转发消息详情。
"""
return self._forward_payloads.get(forward_id or message_id or "")
async def get_group_member_info(
self,
group_id: str,
user_id: str,
no_cache: bool = True,
) -> Dict[str, Any] | None:
"""模拟获取群成员资料。"""
del no_cache
return self._group_member_payloads.get((group_id, user_id))
async def get_stranger_info(self, user_id: str, no_cache: bool = False) -> Dict[str, Any] | None:
"""模拟获取 QQ 昵称资料。"""
del no_cache
return self._stranger_payloads.get(user_id)
async def get_record_detail(
self,
file_name: str | None = None,
file_id: str | None = None,
out_format: str = "wav",
) -> Dict[str, Any] | None:
"""模拟获取语音详情。
Args:
file_name: 文件名。
file_id: 文件 ID。
out_format: 输出格式。
Returns:
Dict[str, Any] | None: 测试中默认不返回语音详情。
"""
del file_name
del file_id
del out_format
return None
class _FakeNapCatActionService:
"""用于驱动 NapCat 查询服务测试的动作服务替身。"""
def __init__(self, response_data: Any) -> None:
"""初始化动作服务替身。
Args:
response_data: 预置的 ``safe_call_action_data`` 返回值。
"""
self._response_data = response_data
self.action_calls: List[Dict[str, Any]] = []
self.action_data_calls: List[Dict[str, Any]] = []
async def safe_call_action_data(self, action_name: str, params: Dict[str, Any]) -> Any:
"""模拟安全调用 OneBot 动作。
Args:
action_name: 动作名称。
params: 动作参数。
Returns:
Any: 预置返回值。
"""
self.action_data_calls.append({"action_name": action_name, "params": dict(params)})
return self._response_data
async def call_action(self, action_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
"""模拟调用 OneBot 动作并记录参数。"""
self.action_calls.append({"action_name": action_name, "params": dict(params)})
return {"status": "ok", "retcode": 0, "data": {}}
def _load_napcat_sdk_modules() -> Tuple[Any, Any, Any, Any]:
"""动态加载 NapCat 插件测试所需的模块。
Returns:
tuple[Any, Any, Any, Any]:
依次返回常量模块、配置模块、插件模块和运行时状态模块。
"""
plugin_dir = NAPCAT_PLUGIN_DIR if NAPCAT_PLUGIN_DIR.is_dir() else NAPCAT_TEMPLATE_DIR
if NAPCAT_TEST_MODULE not in sys.modules:
plugin_path = plugin_dir / "plugin.py"
spec = util.spec_from_file_location(
NAPCAT_TEST_MODULE,
plugin_path,
submodule_search_locations=[str(plugin_dir)],
)
if spec is None or spec.loader is None:
raise ImportError(f"无法为 NapCat 插件创建模块规格: {plugin_path}")
module = util.module_from_spec(spec)
sys.modules[NAPCAT_TEST_MODULE] = module
try:
spec.loader.exec_module(module)
except Exception:
sys.modules.pop(NAPCAT_TEST_MODULE, None)
raise
return (
import_module(f"{NAPCAT_TEST_MODULE}.constants"),
import_module(f"{NAPCAT_TEST_MODULE}.config"),
import_module(f"{NAPCAT_TEST_MODULE}.plugin"),
import_module(f"{NAPCAT_TEST_MODULE}.runtime_state"),
)
def _load_napcat_sdk_symbols() -> Tuple[Any, Any, Any, Any]:
"""动态加载 NapCat 插件测试所需的符号。
Returns:
tuple[Any, Any, Any, Any]:
依次返回网关名常量、配置类、插件类和运行时状态管理器类。
"""
constants_module, config_module, plugin_module, runtime_state_module = _load_napcat_sdk_modules()
return (
constants_module.NAPCAT_GATEWAY_NAME,
config_module.NapCatServerConfig,
plugin_module.NapCatAdapterPlugin,
runtime_state_module.NapCatRuntimeStateManager,
)
def _load_napcat_inbound_codec_cls() -> Any:
"""动态加载 NapCat 入站编解码器类。
Returns:
Any: ``NapCatInboundCodec`` 类对象。
"""
_load_napcat_sdk_modules()
codec_module = import_module(f"{NAPCAT_TEST_MODULE}.codecs.inbound.message_codec")
return codec_module.NapCatInboundCodec
def _load_napcat_query_service_cls() -> Any:
"""动态加载 NapCat 查询服务类。
Returns:
Any: ``NapCatQueryService`` 类对象。
"""
_load_napcat_sdk_modules()
query_service_module = import_module(f"{NAPCAT_TEST_MODULE}.services.query_service")
return query_service_module.NapCatQueryService
def test_napcat_plugin_collects_duplex_message_gateway() -> None:
"""NapCat 插件应声明新的双工消息网关组件。"""
napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
plugin = napcat_plugin_cls()
components = plugin.get_components()
gateway_components = [
component
for component in components
if component.get("type") == "MESSAGE_GATEWAY"
]
assert len(gateway_components) == 1
gateway_component = gateway_components[0]
assert gateway_component["name"] == napcat_gateway_name
assert gateway_component["metadata"]["route_type"] == "duplex"
assert gateway_component["metadata"]["platform"] == "qq"
assert gateway_component["metadata"]["protocol"] == "napcat"
def test_napcat_plugin_uses_sdk_config_model() -> None:
"""NapCat 插件应声明 SDK 配置模型并暴露默认配置与 Schema。"""
constants_module, _config_module, plugin_module, _runtime_state_module = _load_napcat_sdk_modules()
plugin = plugin_module.NapCatAdapterPlugin()
default_config = plugin.get_default_config()
schema = plugin.get_webui_config_schema(plugin_id="maibot-team.napcat-adapter")
assert default_config["plugin"]["config_version"] == constants_module.SUPPORTED_CONFIG_VERSION
assert default_config["chat"]["ban_qq_bot"] is False
assert default_config["filters"]["ignore_self_message"] is True
assert schema["plugin_id"] == "maibot-team.napcat-adapter"
assert schema["sections"]["chat"]["fields"]["group_list"]["type"] == "array"
assert schema["sections"]["chat"]["fields"]["group_list_type"]["choices"] == ["whitelist", "blacklist"]
def test_napcat_plugin_normalizes_legacy_config_values() -> None:
"""NapCat 插件应兼容旧配置字段并输出规范化结果。"""
constants_module, _config_module, plugin_module, _runtime_state_module = _load_napcat_sdk_modules()
plugin = plugin_module.NapCatAdapterPlugin()
plugin.set_plugin_config(
{
"plugin": {"enabled": True, "config_version": constants_module.SUPPORTED_CONFIG_VERSION},
"connection": {
"access_token": "secret-token",
"heartbeat_sec": "45",
"ws_url": "ws://10.0.0.8:3012/onebot/v11/ws",
},
"chat": {
"ban_qq_bot": True,
"ban_user_id": ["42", 42, ""],
"group_list": [123, " 456 ", None, "123"],
"group_list_type": "whitelist",
"private_list": "invalid",
"private_list_type": "unexpected",
},
"filters": {"ignore_self_message": True},
}
)
config_data = plugin.get_plugin_config_data()
assert "connection" not in config_data
assert config_data["plugin"]["config_version"] == constants_module.SUPPORTED_CONFIG_VERSION
assert config_data["napcat_server"]["host"] == "10.0.0.8"
assert config_data["napcat_server"]["port"] == 3012
assert config_data["napcat_server"]["token"] == "secret-token"
assert config_data["napcat_server"]["heartbeat_interval"] == 45.0
assert config_data["chat"]["group_list"] == ["123", "456"]
assert config_data["chat"]["private_list"] == []
assert config_data["chat"]["private_list_type"] == constants_module.DEFAULT_CHAT_LIST_TYPE
assert plugin.config.napcat_server.build_ws_url() == "ws://10.0.0.8:3012"
@pytest.mark.asyncio
async def test_runtime_state_reports_via_gateway_capability() -> None:
"""NapCat 运行时状态应通过新的消息网关能力上报。"""
napcat_gateway_name, napcat_server_config_cls, _napcat_plugin_cls, runtime_state_cls = _load_napcat_sdk_symbols()
gateway_capability = _FakeGatewayCapability()
runtime_state_manager = runtime_state_cls(
gateway_capability=gateway_capability,
logger=logging.getLogger("test.napcat_adapter"),
gateway_name=napcat_gateway_name,
)
connected = await runtime_state_manager.report_connected(
"10001",
napcat_server_config_cls(connection_id="primary"),
)
await runtime_state_manager.report_disconnected()
assert connected is True
assert gateway_capability.calls[0]["gateway_name"] == napcat_gateway_name
assert gateway_capability.calls[0]["ready"] is True
assert gateway_capability.calls[0]["platform"] == "qq"
assert gateway_capability.calls[0]["account_id"] == "10001"
assert gateway_capability.calls[0]["scope"] == "primary"
assert gateway_capability.calls[1]["gateway_name"] == napcat_gateway_name
assert gateway_capability.calls[1]["ready"] is False
assert gateway_capability.calls[1]["platform"] == "qq"
@pytest.mark.asyncio
async def test_napcat_plugin_send_result_contains_message_id_echo_callback() -> None:
"""NapCat 插件发送成功后应显式返回消息 ID 回调数据。"""
_napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
plugin = napcat_plugin_cls()
class _FakeOutboundCodec:
"""用于测试的出站编码器替身。"""
@staticmethod
def build_outbound_action(message: Dict[str, Any], route: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""返回固定动作与参数。"""
del message
del route
return "send_msg", {"message": "hello"}
class _FakeTransport:
"""用于测试的传输层替身。"""
@staticmethod
async def call_action(action_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
"""返回带平台消息 ID 的成功响应。"""
del action_name
del params
return {
"status": "ok",
"data": {
"message_id": "platform-message-id",
},
}
plugin._require_runtime_bundle = lambda: SimpleNamespace( # type: ignore[method-assign]
outbound_codec=_FakeOutboundCodec(),
transport=_FakeTransport(),
)
result = await plugin.handle_napcat_gateway(
message={"message_id": "internal-message-id"},
route={},
)
assert result["success"] is True
assert result["external_message_id"] == "platform-message-id"
assert result["metadata"]["adapter_callbacks"] == [
{
"name": "message_id_echo",
"payload": {
"content": {
"type": "echo",
"echo": "internal-message-id",
"actual_id": "platform-message-id",
}
},
}
]
@pytest.mark.asyncio
async def test_inbound_codec_parses_forward_nodes_from_legacy_message_field() -> None:
"""入站编解码器应兼容旧版 ``sender + message`` 转发节点结构。"""
inbound_codec_cls = _load_napcat_inbound_codec_cls()
codec = inbound_codec_cls(
logger=logging.getLogger("test.napcat_adapter.forward_legacy"),
query_service=_FakeNapCatQueryService(
forward_payloads={
"forward-1": {
"messages": [
{
"sender": {"user_id": "10001", "nickname": "张三", "card": "群名片"},
"message_id": "node-1",
"message": [{"type": "text", "data": {"text": "第一条转发"}}],
}
]
}
}
),
)
segments, is_at = await codec.convert_segments(
{"message": [{"type": "forward", "data": {"id": "forward-1"}}]},
"",
)
assert is_at is False
assert len(segments) == 1
assert segments[0]["type"] == "forward"
assert segments[0]["data"][0]["user_id"] == "10001"
assert segments[0]["data"][0]["user_nickname"] == "张三"
assert segments[0]["data"][0]["user_cardname"] == "群名片"
assert segments[0]["data"][0]["content"] == [{"type": "text", "data": "第一条转发"}]
@pytest.mark.asyncio
async def test_inbound_codec_parses_nested_inline_forward_content() -> None:
"""入站编解码器应支持内联 ``content`` 形式的嵌套合并转发。"""
inbound_codec_cls = _load_napcat_inbound_codec_cls()
codec = inbound_codec_cls(
logger=logging.getLogger("test.napcat_adapter.forward_nested"),
query_service=_FakeNapCatQueryService(
forward_payloads={
"forward-outer": {
"messages": [
{
"sender": {"user_id": "10001", "nickname": "张三"},
"message_id": "node-outer",
"message": [
{
"type": "forward",
"data": {
"content": [
{
"sender": {"user_id": "10002", "nickname": "李四"},
"message_id": "node-inner",
"message": [{"type": "text", "data": {"text": "内层消息"}}],
}
]
},
}
],
}
]
}
}
),
)
segments, _ = await codec.convert_segments(
{"message": [{"type": "forward", "data": {"id": "forward-outer"}}]},
"",
)
assert len(segments) == 1
assert segments[0]["type"] == "forward"
outer_content = segments[0]["data"][0]["content"]
assert len(outer_content) == 1
assert outer_content[0]["type"] == "forward"
nested_nodes = outer_content[0]["data"]
assert nested_nodes[0]["user_id"] == "10002"
assert nested_nodes[0]["user_nickname"] == "李四"
assert nested_nodes[0]["content"] == [{"type": "text", "data": "内层消息"}]
@pytest.mark.asyncio
async def test_inbound_codec_resolves_at_to_group_cardname() -> None:
"""入站编解码器应优先将 ``at`` 解析为群昵称。"""
inbound_codec_cls = _load_napcat_inbound_codec_cls()
codec = inbound_codec_cls(
logger=logging.getLogger("test.napcat_adapter.at_cardname"),
query_service=_FakeNapCatQueryService(
group_member_payloads={
("12345", "1206069534"): {
"nickname": "QQ昵称",
"card": "群昵称",
}
}
),
)
message_dict = await codec.build_message_dict(
payload={
"message_type": "group",
"group_id": "12345",
"message_id": "msg-1",
"message": [{"type": "at", "data": {"qq": "1206069534"}}],
"sender": {"user_id": "10001", "nickname": "发送者"},
"time": 1710000000,
},
self_id="20001",
sender_user_id="10001",
sender={"user_id": "10001", "nickname": "发送者"},
)
assert message_dict["processed_plain_text"] == "@群昵称"
assert message_dict["raw_message"] == [
{
"type": "at",
"data": {
"target_user_id": "1206069534",
"target_user_nickname": "QQ昵称",
"target_user_cardname": "群昵称",
},
}
]
@pytest.mark.asyncio
async def test_inbound_codec_falls_back_to_qq_nickname_when_group_cardname_is_empty() -> None:
"""入站编解码器在群昵称为空时应回退到 QQ 昵称。"""
inbound_codec_cls = _load_napcat_inbound_codec_cls()
codec = inbound_codec_cls(
logger=logging.getLogger("test.napcat_adapter.at_nickname"),
query_service=_FakeNapCatQueryService(
group_member_payloads={
("12345", "1206069534"): {
"nickname": "QQ昵称",
"card": "",
}
}
),
)
message_dict = await codec.build_message_dict(
payload={
"message_type": "group",
"group_id": "12345",
"message_id": "msg-2",
"message": [{"type": "at", "data": {"qq": "1206069534"}}],
"sender": {"user_id": "10001", "nickname": "发送者"},
"time": 1710000000,
},
self_id="20001",
sender_user_id="10001",
sender={"user_id": "10001", "nickname": "发送者"},
)
assert message_dict["processed_plain_text"] == "@QQ昵称"
assert message_dict["raw_message"] == [
{
"type": "at",
"data": {
"target_user_id": "1206069534",
"target_user_nickname": "QQ昵称",
"target_user_cardname": None,
},
}
]
@pytest.mark.asyncio
async def test_inbound_codec_falls_back_to_stranger_nickname_when_group_profile_is_missing() -> None:
"""入站编解码器在群资料缺失时应继续回退到 QQ 昵称。"""
inbound_codec_cls = _load_napcat_inbound_codec_cls()
codec = inbound_codec_cls(
logger=logging.getLogger("test.napcat_adapter.at_stranger_nickname"),
query_service=_FakeNapCatQueryService(
group_member_payloads={("12345", "1206069534"): None},
stranger_payloads={"1206069534": {"nickname": "QQ昵称"}},
),
)
message_dict = await codec.build_message_dict(
payload={
"message_type": "group",
"group_id": "12345",
"message_id": "msg-3",
"message": [{"type": "at", "data": {"qq": "1206069534"}}],
"sender": {"user_id": "10001", "nickname": "发送者"},
"time": 1710000000,
},
self_id="20001",
sender_user_id="10001",
sender={"user_id": "10001", "nickname": "发送者"},
)
assert message_dict["processed_plain_text"] == "@QQ昵称"
assert message_dict["raw_message"] == [
{
"type": "at",
"data": {
"target_user_id": "1206069534",
"target_user_nickname": "QQ昵称",
"target_user_cardname": None,
},
}
]
@pytest.mark.asyncio
async def test_query_service_normalizes_forward_payload_list() -> None:
"""查询服务应兼容 ``get_forward_msg`` 直接返回节点列表。"""
query_service_cls = _load_napcat_query_service_cls()
query_service = query_service_cls(
action_service=_FakeNapCatActionService(
[
{
"sender": {"user_id": "10001", "nickname": "张三"},
"message_id": "node-1",
"message": [{"type": "text", "data": {"text": "列表返回"}}],
}
]
),
logger=logging.getLogger("test.napcat_adapter.query_service"),
)
forward_payload = await query_service.get_forward_message("forward-1")
assert forward_payload == {
"messages": [
{
"sender": {"user_id": "10001", "nickname": "张三"},
"message_id": "node-1",
"message": [{"type": "text", "data": {"text": "列表返回"}}],
}
]
}
@pytest.mark.asyncio
async def test_query_service_supports_official_no_cache_for_get_stranger_info() -> None:
"""查询服务应按官方字段下发 ``no_cache``。"""
action_service = _FakeNapCatActionService({"nickname": "测试用户"})
query_service_cls = _load_napcat_query_service_cls()
query_service = query_service_cls(
action_service=action_service,
logger=logging.getLogger("test.napcat_adapter.query_service.stranger"),
)
payload = await query_service.get_stranger_info("10001", no_cache=True)
assert payload == {"nickname": "测试用户"}
assert action_service.action_data_calls[-1] == {
"action_name": "get_stranger_info",
"params": {"user_id": "10001", "no_cache": True},
}
@pytest.mark.asyncio
async def test_query_service_supports_official_forward_id_alias() -> None:
"""查询服务应兼容官方 ``id`` 字段调用 ``get_forward_msg``。"""
action_service = _FakeNapCatActionService({"messages": []})
query_service_cls = _load_napcat_query_service_cls()
query_service = query_service_cls(
action_service=action_service,
logger=logging.getLogger("test.napcat_adapter.query_service.forward_alias"),
)
payload = await query_service.get_forward_message(forward_id="forward-alias")
assert payload == {"messages": []}
assert action_service.action_data_calls[-1] == {
"action_name": "get_forward_msg",
"params": {"id": "forward-alias"},
}
@pytest.mark.asyncio
async def test_query_service_supports_custom_out_format_for_get_record() -> None:
"""查询服务应按官方字段下发自定义 ``out_format``。"""
action_service = _FakeNapCatActionService({"file": "voice.mp3"})
query_service_cls = _load_napcat_query_service_cls()
query_service = query_service_cls(
action_service=action_service,
logger=logging.getLogger("test.napcat_adapter.query_service.record"),
)
payload = await query_service.get_record_detail(file_id="record-1", out_format="mp3")
assert payload == {"file": "voice.mp3"}
assert action_service.action_data_calls[-1] == {
"action_name": "get_record",
"params": {"file_id": "record-1", "out_format": "mp3"},
}
@pytest.mark.asyncio
async def test_query_service_supports_target_id_for_send_poke() -> None:
"""查询服务应按官方字段下发 ``target_id``。"""
action_service = _FakeNapCatActionService(None)
query_service_cls = _load_napcat_query_service_cls()
query_service = query_service_cls(
action_service=action_service,
logger=logging.getLogger("test.napcat_adapter.query_service.poke"),
)
response = await query_service.send_poke(user_id=10001, group_id=20002, target_id=30003)
assert response["status"] == "ok"
assert action_service.action_calls[-1] == {
"action_name": "send_poke",
"params": {"user_id": 10001, "group_id": 20002, "target_id": 30003},
}
@pytest.mark.asyncio
async def test_public_api_send_poke_supports_official_fields_and_legacy_alias() -> None:
"""公开 API 应同时兼容官方字段和旧版 ``qq_id`` 别名。"""
_napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
plugin = napcat_plugin_cls()
captured: List[Dict[str, Any]] = []
class _SpyQueryService:
async def send_poke(
self,
user_id: int,
group_id: int | None = None,
target_id: int | None = None,
) -> Dict[str, Any]:
captured.append(
{
"user_id": user_id,
"group_id": group_id,
"target_id": target_id,
}
)
return {"status": "ok", "data": {}}
plugin._query_service = _SpyQueryService()
plugin._ensure_runtime_components = lambda: None # type: ignore[method-assign]
await plugin.api_send_poke(user_id="10001", group_id="20002", target_id="30003")
await plugin.api_send_poke(qq_id="40004")
assert captured == [
{"user_id": 10001, "group_id": 20002, "target_id": 30003},
{"user_id": 40004, "group_id": None, "target_id": None},
]
@pytest.mark.asyncio
async def test_public_api_get_forward_msg_and_get_record_support_official_fields() -> None:
"""公开 API 应接受官方 ``id`` 和 ``out_format`` 等字段。"""
_napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
plugin = napcat_plugin_cls()
captured: Dict[str, Dict[str, Any]] = {}
class _SpyQueryService:
async def get_forward_message(
self,
message_id: str | None = None,
forward_id: str | None = None,
) -> Dict[str, Any]:
captured["forward"] = {"message_id": message_id, "forward_id": forward_id}
return {"messages": []}
async def get_record_detail(
self,
file_name: str | None = None,
file_id: str | None = None,
out_format: str = "wav",
) -> Dict[str, Any]:
captured["record"] = {
"file_name": file_name,
"file_id": file_id,
"out_format": out_format,
}
return {"file_id": file_id or "record-1"}
plugin._query_service = _SpyQueryService()
plugin._ensure_runtime_components = lambda: None # type: ignore[method-assign]
forward_payload = await plugin.api_get_forward_msg(id="forward-1")
record_payload = await plugin.api_get_record(file_id="record-1", out_format="mp3")
assert forward_payload == {"messages": []}
assert record_payload == {"file_id": "record-1"}
assert captured["forward"] == {"message_id": None, "forward_id": "forward-1"}
assert captured["record"] == {
"file_name": None,
"file_id": "record-1",
"out_format": "mp3",
}
@pytest.mark.asyncio
async def test_public_api_send_poke_rejects_conflicting_alias_values() -> None:
"""公开 ``send_poke`` API 应拒绝互相冲突的别名值。"""
_napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
plugin = napcat_plugin_cls()
plugin._ensure_runtime_components = lambda: None # type: ignore[method-assign]
with pytest.raises(ValueError, match="user_id 与 qq_id 不能同时传递不同的值"):
await plugin.api_send_poke(user_id="10001", qq_id="20002")
@pytest.mark.asyncio
async def test_public_api_get_forward_msg_rejects_conflicting_fields() -> None:
"""公开 ``get_forward_msg`` API 应拒绝冲突的双字段调用。"""
_napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
plugin = napcat_plugin_cls()
plugin._ensure_runtime_components = lambda: None # type: ignore[method-assign]
with pytest.raises(ValueError, match="message_id 与 id 不能同时传递不同的值"):
await plugin.api_get_forward_msg(message_id="forward-a", id="forward-b")
@pytest.mark.asyncio
async def test_public_api_get_record_requires_file_or_file_id() -> None:
"""公开 ``get_record`` API 至少需要一个官方定位字段。"""
_napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
plugin = napcat_plugin_cls()
plugin._ensure_runtime_components = lambda: None # type: ignore[method-assign]
with pytest.raises(ValueError, match="file 或 file_id 至少提供一个"):
await plugin.api_get_record()

View File

@@ -1,376 +0,0 @@
"""NapCat 历史补拉与恢复状态测试。"""
from __future__ import annotations
from importlib import import_module, util
from pathlib import Path
from typing import Any, Dict, List
import logging
import sys
from types import SimpleNamespace
import pytest
PROJECT_ROOT = Path(__file__).resolve().parents[1]
PLUGINS_ROOT = PROJECT_ROOT / "plugins"
PLUGIN_TEMPLATE_ROOT = PROJECT_ROOT / "plugin-templates"
SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
NAPCAT_PLUGIN_DIR = PLUGINS_ROOT / "MaiBot-Napcat-Adapter"
NAPCAT_TEMPLATE_DIR = PLUGIN_TEMPLATE_ROOT / "MaiBot-Napcat-Adapter"
NAPCAT_TEST_MODULE = "_test_napcat_adapter_history_recovery"
for import_path in (str(SDK_ROOT),):
if import_path not in sys.path:
sys.path.insert(0, import_path)
class _FakeGatewayCapability:
"""用于测试入站注入的网关替身。"""
def __init__(self) -> None:
"""初始化测试替身。"""
self.calls: List[Dict[str, Any]] = []
async def route_message(
self,
gateway_name: str,
message: Dict[str, Any],
*,
route_metadata: Dict[str, Any] | None = None,
external_message_id: str = "",
dedupe_key: str = "",
) -> bool:
"""记录入站注入请求并始终模拟成功。"""
self.calls.append(
{
"gateway_name": gateway_name,
"message": dict(message),
"route_metadata": dict(route_metadata or {}),
"external_message_id": external_message_id,
"dedupe_key": dedupe_key,
}
)
return True
def _resolve_napcat_plugin_dir() -> Path:
"""返回当前测试可用的 NapCat 插件目录。"""
if NAPCAT_PLUGIN_DIR.is_dir():
return NAPCAT_PLUGIN_DIR
return NAPCAT_TEMPLATE_DIR
def _load_napcat_module(module_suffix: str) -> Any:
"""动态加载 NapCat 测试模块。"""
plugin_dir = _resolve_napcat_plugin_dir()
if NAPCAT_TEST_MODULE not in sys.modules:
plugin_path = plugin_dir / "plugin.py"
spec = util.spec_from_file_location(
NAPCAT_TEST_MODULE,
plugin_path,
submodule_search_locations=[str(plugin_dir)],
)
if spec is None or spec.loader is None:
raise ImportError(f"无法为 NapCat 插件创建模块规格: {plugin_path}")
module = util.module_from_spec(spec)
sys.modules[NAPCAT_TEST_MODULE] = module
try:
spec.loader.exec_module(module)
except Exception:
sys.modules.pop(NAPCAT_TEST_MODULE, None)
raise
return import_module(f"{NAPCAT_TEST_MODULE}.{module_suffix}")
def _load_history_recovery_store_cls() -> Any:
"""动态加载历史恢复状态仓库类。"""
return _load_napcat_module("services.history_recovery_store").NapCatHistoryRecoveryStore
def _load_query_service_cls() -> Any:
"""动态加载查询服务类。"""
return _load_napcat_module("services.query_service").NapCatQueryService
def _load_router_cls() -> Any:
"""动态加载事件路由器类。"""
return _load_napcat_module("runtime.router").NapCatEventRouter
class _FakeActionService:
"""用于查询服务的动作服务替身。"""
def __init__(self, response_data: Any) -> None:
"""初始化动作服务替身。"""
self._response_data = response_data
self.action_data_calls: List[Dict[str, Any]] = []
async def safe_call_action_data(self, action_name: str, params: Dict[str, Any]) -> Any:
"""记录安全查询动作。"""
self.action_data_calls.append({"action_name": action_name, "params": dict(params)})
return self._response_data
@pytest.mark.asyncio
async def test_history_recovery_store_persists_checkpoint_and_seen_state(tmp_path: Path) -> None:
"""历史恢复状态仓库应持久化 checkpoint 与已补拉标记。"""
store_cls = _load_history_recovery_store_cls()
store = store_cls(
logger=logging.getLogger("test.napcat.history_store"),
storage_path=tmp_path / "history.sqlite3",
)
await store.load()
await store.record_checkpoint(
account_id="10001",
scope="primary",
chat_type="group",
chat_id="20001",
message_id="msg-2",
message_time=200.0,
message_seq=2,
)
await store.record_checkpoint(
account_id="10001",
scope="primary",
chat_type="group",
chat_id="20001",
message_id="msg-1",
message_time=100.0,
message_seq=1,
)
await store.mark_recovered_message_seen(
account_id="10001",
scope="primary",
chat_type="group",
chat_id="20001",
external_message_id="history-1",
)
checkpoints = await store.list_checkpoints("10001", scope="primary")
assert len(checkpoints) == 1
assert checkpoints[0].last_message_id == "msg-2"
assert checkpoints[0].last_message_seq == 2
assert (
await store.has_recovered_message_seen(
account_id="10001",
scope="primary",
chat_type="group",
chat_id="20001",
external_message_id="history-1",
)
is True
)
@pytest.mark.asyncio
async def test_query_service_wraps_group_and_friend_history_actions() -> None:
"""查询服务应按官方动作名封装历史消息接口。"""
query_service_cls = _load_query_service_cls()
action_service = _FakeActionService([{"message_id": "msg-1"}])
query_service = query_service_cls(
action_service=action_service,
logger=logging.getLogger("test.napcat.history_query"),
)
group_payload = await query_service.get_group_message_history("20001", message_seq=123, count=10)
friend_payload = await query_service.get_friend_message_history("30001", count=5, reverse_order=True)
assert group_payload == [{"message_id": "msg-1"}]
assert friend_payload == [{"message_id": "msg-1"}]
assert action_service.action_data_calls == [
{
"action_name": "get_group_msg_history",
"params": {"group_id": "20001", "count": 10, "reverse_order": False, "message_seq": 123},
},
{
"action_name": "get_friend_msg_history",
"params": {"user_id": "30001", "count": 5, "reverse_order": True},
},
]
@pytest.mark.asyncio
async def test_router_recover_recent_history_reinjects_messages_in_order(tmp_path: Path) -> None:
"""重连补拉应按时间顺序将历史消息重新注入原入站路径。"""
history_store_cls = _load_history_recovery_store_cls()
router_cls = _load_router_cls()
gateway_capability = _FakeGatewayCapability()
router = router_cls(
gateway_capability=gateway_capability,
logger=logging.getLogger("test.napcat.history_router"),
gateway_name="napcat_gateway",
load_settings=lambda: SimpleNamespace(
napcat_server=SimpleNamespace(connection_id="primary", heartbeat_interval=30.0),
filters=SimpleNamespace(ignore_self_message=True),
chat=SimpleNamespace(ban_qq_bot=False),
),
)
history_store = history_store_cls(
logger=logging.getLogger("test.napcat.history_router.store"),
storage_path=tmp_path / "router.sqlite3",
)
await history_store.load()
await history_store.record_checkpoint(
account_id="10001",
scope="primary",
chat_type="group",
chat_id="20001",
message_id="msg-1",
message_time=100.0,
message_seq=10,
)
history_calls: List[Dict[str, Any]] = []
history_payloads = [
{
"post_type": "message",
"message_type": "group",
"self_id": "10001",
"group_id": "20001",
"user_id": "30002",
"message_id": "msg-3",
"message_seq": 12,
"time": 102,
"message": [{"type": "text", "data": {"text": "第三条"}}],
"sender": {"user_id": "30002", "nickname": "用户二"},
},
{
"post_type": "message",
"message_type": "group",
"self_id": "10001",
"group_id": "20001",
"user_id": "30001",
"message_id": "msg-2",
"message_seq": 11,
"time": 101,
"message": [{"type": "text", "data": {"text": "第二条"}}],
"sender": {"user_id": "30001", "nickname": "用户一"},
},
]
class _FakeQueryService:
async def get_group_message_history(
self,
group_id: str,
*,
message_seq: int | None = None,
count: int = 20,
reverse_order: bool = False,
) -> List[Dict[str, Any]]:
history_calls.append(
{
"group_id": group_id,
"message_seq": message_seq,
"count": count,
"reverse_order": reverse_order,
}
)
return list(history_payloads)
async def get_friend_message_history(self, user_id: str, **kwargs: Any) -> List[Dict[str, Any]]:
del user_id
del kwargs
return []
class _FakeInboundCodec:
@staticmethod
async def build_message_dict(
payload: Dict[str, Any],
self_id: str,
sender_user_id: str,
sender: Dict[str, Any],
) -> Dict[str, Any]:
del self_id
del sender_user_id
del sender
return {
"message_id": str(payload["message_id"]),
"platform": "qq",
"timestamp": str(float(payload["time"])),
"message_info": {
"user_info": {"user_id": str(payload["user_id"]), "user_nickname": "测试用户"},
"group_info": {"group_id": str(payload["group_id"]), "group_name": "测试群"},
"additional_config": {},
},
"raw_message": [{"type": "text", "data": str(payload["message"][0]["data"]["text"])}],
"processed_plain_text": str(payload["message"][0]["data"]["text"]),
"display_message": str(payload["message"][0]["data"]["text"]),
"is_mentioned": False,
"is_at": False,
"is_emoji": False,
"is_picture": False,
"is_command": False,
"is_notify": False,
"session_id": "",
}
router.bind_runtime(
SimpleNamespace(
runtime_state=SimpleNamespace(report_connected=lambda *args, **kwargs: _noop_async(), report_disconnected=_noop_async),
chat_filter=SimpleNamespace(is_inbound_chat_allowed=lambda *args, **kwargs: True),
official_bot_guard=SimpleNamespace(
should_reject=lambda *args, **kwargs: _return_false_async(),
clear_cache=lambda: None,
),
inbound_codec=_FakeInboundCodec(),
history_recovery_store=history_store,
query_service=_FakeQueryService(),
heartbeat_monitor=SimpleNamespace(start=_noop_async, stop=_noop_async),
ban_tracker=SimpleNamespace(start=_noop_async, stop=_noop_async, record_notice=_noop_async),
notice_codec=SimpleNamespace(handle_meta_event=_noop_async, build_notice_message_dict=_return_none_async),
)
)
await router._recover_recent_history(self_id="10001", scope="primary")
assert history_calls == [
{
"group_id": "20001",
"message_seq": 10,
"count": 20,
"reverse_order": False,
}
]
assert [call["external_message_id"] for call in gateway_capability.calls] == ["msg-2", "msg-3"]
assert [call["message"]["message_id"] for call in gateway_capability.calls] == ["msg-2", "msg-3"]
async def _noop_async(*args: Any, **kwargs: Any) -> None:
"""无操作异步函数。"""
del args
del kwargs
async def _return_false_async(*args: Any, **kwargs: Any) -> bool:
"""返回 ``False`` 的异步测试替身。"""
del args
del kwargs
return False
async def _return_none_async(*args: Any, **kwargs: Any) -> None:
"""返回 ``None`` 的异步测试替身。"""
del args
del kwargs
return None

View File

@@ -1,164 +0,0 @@
from types import SimpleNamespace
import pytest
from src.config.model_configs import APIProvider, ReasoningParseMode, ToolArgumentParseMode
from src.llm_models.model_client.openai_client import (
_OpenAIStreamAccumulator,
_build_reasoning_key,
_default_normal_response_parser,
_parse_tool_arguments,
_sanitize_messages_for_toolless_request,
)
from src.llm_models.payload_content.message import Message, RoleType, TextMessagePart
from src.llm_models.payload_content.tool_option import ToolCall
@pytest.mark.parametrize("parse_mode", list(ToolArgumentParseMode))
def test_parse_tool_arguments_treats_blank_arguments_as_empty_dict(parse_mode: ToolArgumentParseMode) -> None:
assert _parse_tool_arguments("", parse_mode, None) == {}
assert _parse_tool_arguments(" ", parse_mode, None) == {}
def test_normal_response_parser_accepts_empty_string_arguments_for_parameterless_tool() -> None:
response = SimpleNamespace(
choices=[
SimpleNamespace(
finish_reason="tool_calls",
message=SimpleNamespace(
content=None,
tool_calls=[
SimpleNamespace(
id="finish-call",
type="function",
function=SimpleNamespace(name="finish", arguments=""),
)
],
),
)
],
usage=None,
model="glm-5.1",
)
api_response, usage_record = _default_normal_response_parser(
response,
reasoning_parse_mode=ReasoningParseMode.AUTO,
tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
reasoning_key=None,
)
assert len(api_response.tool_calls) == 1
assert api_response.tool_calls[0].func_name == "finish"
assert api_response.tool_calls[0].args == {}
assert usage_record is None
def test_sanitize_messages_for_toolless_request_drops_assistant_tool_call_without_parts() -> None:
messages = [
Message(
role=RoleType.Assistant,
tool_calls=[
ToolCall(
call_id="call_1",
func_name="mute_user",
args={"target": "alice"},
)
],
),
Message(
role=RoleType.User,
parts=[TextMessagePart(text="继续")],
),
]
sanitized_messages = _sanitize_messages_for_toolless_request(messages)
assert len(sanitized_messages) == 1
assert sanitized_messages[0].role == RoleType.User
def test_normal_response_parser_ignores_reasoning_field_for_non_openrouter_provider() -> None:
response = SimpleNamespace(
choices=[
SimpleNamespace(
finish_reason="stop",
message=SimpleNamespace(
content="正式回复",
reasoning="推理内容",
tool_calls=None,
),
)
],
usage=None,
model="openrouter/test-model",
)
api_response, usage_record = _default_normal_response_parser(
response,
reasoning_parse_mode=ReasoningParseMode.AUTO,
tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
reasoning_key=_build_reasoning_key(
APIProvider(name="test", base_url="https://openrouter.ai.example.com/api/v1", api_key="test")
),
)
assert api_response.content == "正式回复"
assert api_response.reasoning_content is None
assert usage_record is None
def test_normal_response_parser_reads_provider_reasoning_field_for_reasoning_domains() -> None:
provider_urls = [
"https://openrouter.ai/compatible-api",
"https://api.groq.com/openai/v1",
]
for provider_url in provider_urls:
response = SimpleNamespace(
choices=[
SimpleNamespace(
finish_reason="stop",
message=SimpleNamespace(
content="正式回复",
reasoning="推理内容",
tool_calls=None,
),
)
],
usage=None,
model="test-model",
)
api_response, usage_record = _default_normal_response_parser(
response,
reasoning_parse_mode=ReasoningParseMode.AUTO,
tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
reasoning_key=_build_reasoning_key(
APIProvider(name="reasoning-provider", base_url=provider_url, api_key="test")
),
)
assert api_response.content == "正式回复"
assert api_response.reasoning_content == "推理内容"
assert usage_record is None
def test_stream_accumulator_reads_openrouter_reasoning_delta_field() -> None:
accumulator = _OpenAIStreamAccumulator(
reasoning_parse_mode=ReasoningParseMode.AUTO,
tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
reasoning_key=_build_reasoning_key(
APIProvider(name="openrouter", base_url="https://openrouter.ai/compatible-api", api_key="test")
),
)
try:
accumulator.process_delta(SimpleNamespace(reasoning="流式推理", content=None, tool_calls=None))
accumulator.process_delta(SimpleNamespace(content="正式回复", tool_calls=None))
api_response = accumulator.build_response()
finally:
accumulator.close()
assert api_response.content == "正式回复"
assert api_response.reasoning_content == "流式推理"

View File

@@ -1,209 +0,0 @@
"""Platform IO 入站去重策略测试。"""
from types import SimpleNamespace
from typing import Any, Dict, List, Optional
import pytest
from src.platform_io.drivers.base import PlatformIODriver
from src.platform_io.manager import PlatformIOManager
from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey
def _build_envelope(
*,
dedupe_key: str | None = None,
external_message_id: str | None = None,
session_message_id: str | None = None,
payload: Optional[Dict[str, Any]] = None,
) -> InboundMessageEnvelope:
"""构造测试用入站信封。
Args:
dedupe_key: 显式去重键。
external_message_id: 平台侧消息 ID。
session_message_id: 规范化消息对象上的消息 ID。
payload: 原始载荷。
Returns:
InboundMessageEnvelope: 测试用入站消息信封。
"""
session_message = None
if session_message_id is not None:
session_message = SimpleNamespace(message_id=session_message_id)
return InboundMessageEnvelope(
route_key=RouteKey(platform="qq", account_id="10001", scope="main"),
driver_id="plugin.napcat",
driver_kind=DriverKind.PLUGIN,
dedupe_key=dedupe_key,
external_message_id=external_message_id,
session_message=session_message,
payload=payload,
)
class _StubPlatformIODriver(PlatformIODriver):
"""测试用 Platform IO 驱动。"""
async def send_message(
self,
message: Any,
route_key: RouteKey,
metadata: Optional[Dict[str, Any]] = None,
) -> DeliveryReceipt:
"""返回一个固定的成功回执。
Args:
message: 待发送的消息对象。
route_key: 本次发送使用的路由键。
metadata: 额外发送元数据。
Returns:
DeliveryReceipt: 固定的成功回执。
"""
return DeliveryReceipt(
internal_message_id=str(getattr(message, "message_id", "stub-message-id")),
route_key=route_key,
status=DeliveryStatus.SENT,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
)
def _build_manager() -> PlatformIOManager:
"""构造带有最小接收路由的 Broker 管理器。
Returns:
PlatformIOManager: 已注册测试驱动并绑定接收路由的 Broker。
"""
manager = PlatformIOManager()
driver = _StubPlatformIODriver(
DriverDescriptor(
driver_id="plugin.napcat",
kind=DriverKind.PLUGIN,
platform="qq",
account_id="10001",
scope="main",
)
)
manager.register_driver(driver)
manager.bind_receive_route(
RouteBinding(
route_key=RouteKey(platform="qq", account_id="10001", scope="main"),
driver_id=driver.driver_id,
driver_kind=driver.descriptor.kind,
)
)
return manager
class TestPlatformIODedupe:
"""Platform IO 去重测试。"""
@pytest.mark.asyncio
async def test_accept_inbound_dedupes_by_external_message_id(self) -> None:
"""相同平台消息 ID 的重复入站应被抑制。"""
manager = _build_manager()
accepted_envelopes: List[InboundMessageEnvelope] = []
async def dispatcher(envelope: InboundMessageEnvelope) -> None:
"""记录被成功接收的入站消息。
Args:
envelope: 被 Broker 接受的入站消息。
"""
accepted_envelopes.append(envelope)
manager.set_inbound_dispatcher(dispatcher)
first_envelope = _build_envelope(
external_message_id="msg-1",
payload={"message": "hello"},
)
second_envelope = _build_envelope(
external_message_id="msg-1",
payload={"message": "hello"},
)
assert await manager.accept_inbound(first_envelope) is True
assert await manager.accept_inbound(second_envelope) is False
assert len(accepted_envelopes) == 1
@pytest.mark.asyncio
async def test_accept_inbound_without_stable_identity_does_not_guess_duplicate(self) -> None:
"""缺少稳定身份时,不应仅凭 payload 内容猜测重复消息。"""
manager = _build_manager()
accepted_envelopes: List[InboundMessageEnvelope] = []
async def dispatcher(envelope: InboundMessageEnvelope) -> None:
"""记录被成功接收的入站消息。
Args:
envelope: 被 Broker 接受的入站消息。
"""
accepted_envelopes.append(envelope)
manager.set_inbound_dispatcher(dispatcher)
first_envelope = _build_envelope(payload={"message": "same-payload"})
second_envelope = _build_envelope(payload={"message": "same-payload"})
assert await manager.accept_inbound(first_envelope) is True
assert await manager.accept_inbound(second_envelope) is True
assert len(accepted_envelopes) == 2
def test_build_inbound_dedupe_key_prefers_explicit_identity(self) -> None:
"""去重键应只来自显式或稳定的技术身份。"""
explicit_envelope = _build_envelope(dedupe_key="dedupe-1", external_message_id="msg-1")
session_message_envelope = _build_envelope(session_message_id="session-1")
payload_only_envelope = _build_envelope(payload={"message": "hello"})
assert PlatformIOManager._build_inbound_dedupe_key(explicit_envelope) == "plugin.napcat:dedupe-1"
assert PlatformIOManager._build_inbound_dedupe_key(session_message_envelope) == "plugin.napcat:session-1"
assert PlatformIOManager._build_inbound_dedupe_key(payload_only_envelope) is None
@pytest.mark.asyncio
async def test_send_message_fans_out_to_all_matching_routes(self) -> None:
"""同一路由命中多条发送链路时应全部发送。"""
manager = PlatformIOManager()
first_driver = _StubPlatformIODriver(
DriverDescriptor(
driver_id="plugin.gateway_a",
kind=DriverKind.PLUGIN,
platform="qq",
)
)
second_driver = _StubPlatformIODriver(
DriverDescriptor(
driver_id="plugin.gateway_b",
kind=DriverKind.PLUGIN,
platform="qq",
)
)
manager.register_driver(first_driver)
manager.register_driver(second_driver)
manager.bind_send_route(
RouteBinding(
route_key=RouteKey(platform="qq"),
driver_id=first_driver.driver_id,
driver_kind=first_driver.descriptor.kind,
)
)
manager.bind_send_route(
RouteBinding(
route_key=RouteKey(platform="qq"),
driver_id=second_driver.driver_id,
driver_kind=second_driver.descriptor.kind,
)
)
message = SimpleNamespace(message_id="internal-msg-1")
result = await manager.send_message(message, RouteKey(platform="qq"))
assert result.has_success is True
assert [receipt.driver_id for receipt in result.sent_receipts] == [
"plugin.gateway_a",
"plugin.gateway_b",
]

View File

@@ -1,178 +0,0 @@
"""Platform IO legacy driver 回归测试。"""
from typing import Any, Dict, Optional
import pytest
from src.chat.utils import utils as chat_utils
from src.chat.message_receive import uni_message_sender
from src.platform_io.drivers.base import PlatformIODriver
from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver
from src.platform_io.manager import PlatformIOManager
from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteBinding, RouteKey
class _PluginDriver(PlatformIODriver):
"""测试用插件发送驱动。"""
def __init__(self, driver_id: str, platform: str) -> None:
"""初始化测试驱动。
Args:
driver_id: 驱动 ID。
platform: 负责的平台名称。
"""
super().__init__(
DriverDescriptor(
driver_id=driver_id,
kind=DriverKind.PLUGIN,
platform=platform,
plugin_id="test.plugin",
)
)
async def send_message(
self,
message: Any,
route_key: RouteKey,
metadata: Optional[Dict[str, Any]] = None,
) -> DeliveryReceipt:
"""返回一个固定成功回执。
Args:
message: 待发送消息。
route_key: 当前路由键。
metadata: 发送元数据。
Returns:
DeliveryReceipt: 固定成功回执。
"""
del metadata
return DeliveryReceipt(
internal_message_id=str(message.message_id),
route_key=route_key,
status=DeliveryStatus.SENT,
driver_id=self.driver_id,
driver_kind=self.descriptor.kind,
)
@pytest.mark.asyncio
async def test_platform_io_uses_legacy_driver_when_no_explicit_send_route(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""没有显式发送路由时,应由 Platform IO 回退到 legacy driver。"""
manager = PlatformIOManager()
monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"})
try:
await manager.ensure_send_pipeline_ready()
fallback_drivers = manager.resolve_drivers(RouteKey(platform="qq"))
assert [driver.driver_id for driver in fallback_drivers] == ["legacy.send.qq"]
plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq")
await manager.add_driver(plugin_driver)
manager.bind_send_route(
RouteBinding(
route_key=RouteKey(platform="qq"),
driver_id=plugin_driver.driver_id,
driver_kind=plugin_driver.descriptor.kind,
)
)
explicit_drivers = manager.resolve_drivers(RouteKey(platform="qq"))
assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender", "legacy.send.qq"]
finally:
await manager.stop()
@pytest.mark.asyncio
async def test_platform_io_broadcasts_to_plugin_and_legacy_driver(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""同一路由命中插件驱动与 legacy driver 时,应同时广播发送。"""
manager = PlatformIOManager()
legacy_calls: list[dict[str, Any]] = []
monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"})
async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool:
"""记录 legacy driver 调用。"""
legacy_calls.append({"message": message, "show_log": show_log})
return True
monkeypatch.setattr(
uni_message_sender,
"send_prepared_message_to_platform",
_fake_send_prepared_message_to_platform,
)
try:
await manager.ensure_send_pipeline_ready()
plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq")
await manager.add_driver(plugin_driver)
manager.bind_send_route(
RouteBinding(
route_key=RouteKey(platform="qq"),
driver_id=plugin_driver.driver_id,
driver_kind=plugin_driver.descriptor.kind,
)
)
message = type("FakeMessage", (), {"message_id": "message-1"})()
batch = await manager.send_message(
message=message,
route_key=RouteKey(platform="qq"),
metadata={"show_log": False},
)
assert sorted(receipt.driver_id for receipt in batch.sent_receipts) == [
"legacy.send.qq",
"plugin.qq.sender",
]
assert batch.failed_receipts == []
assert len(legacy_calls) == 1
assert legacy_calls[0]["message"] is message
assert legacy_calls[0]["show_log"] is False
finally:
await manager.stop()
@pytest.mark.asyncio
async def test_legacy_platform_driver_uses_prepared_universal_sender(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""legacy driver 应复用已预处理消息的旧链发送函数。"""
calls: list[dict[str, Any]] = []
async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool:
"""记录 legacy driver 调用。"""
calls.append({"message": message, "show_log": show_log})
return True
monkeypatch.setattr(
uni_message_sender,
"send_prepared_message_to_platform",
_fake_send_prepared_message_to_platform,
)
driver = LegacyPlatformDriver(
driver_id="legacy.send.qq",
platform="qq",
account_id="bot-qq",
)
message = type("FakeMessage", (), {"message_id": "message-1"})()
receipt = await driver.send_message(
message=message,
route_key=RouteKey(platform="qq"),
metadata={"show_log": False},
)
assert len(calls) == 1
assert calls[0]["message"] is message
assert calls[0]["show_log"] is False
assert receipt.status == DeliveryStatus.SENT
assert receipt.driver_id == "legacy.send.qq"

View File

@@ -1,553 +0,0 @@
"""插件配置运行时测试。"""
from __future__ import annotations
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, Mapping, Optional, Tuple, cast
import tomllib
import pytest
from src.plugin_runtime.component_query import component_query_service
from src.plugin_runtime.protocol.envelope import (
Envelope,
InspectPluginConfigPayload,
MessageType,
RegisterPluginPayload,
ValidatePluginConfigPayload,
)
from src.plugin_runtime.runner.runner_main import PluginRunner
from src.webui.routers.plugin.config_routes import get_plugin_config, get_plugin_config_schema, update_plugin_config
from src.webui.routers.plugin.schemas import UpdatePluginConfigRequest
class _DemoConfigPlugin:
"""用于测试 Runner 配置归一化流程的伪插件。"""
_config_version: str = "2.0.0"
def __init__(self) -> None:
"""初始化测试插件状态。"""
self.received_config: Dict[str, Any] = {}
def normalize_plugin_config(self, config_data: Optional[Mapping[str, Any]]) -> Tuple[Dict[str, Any], bool]:
"""补齐测试插件的默认配置。
Args:
config_data: 原始配置数据。
Returns:
Tuple[Dict[str, Any], bool]: 补齐后的配置,以及是否发生变更。
"""
current_config = dict(config_data or {})
plugin_section = dict(current_config.get("plugin", {}))
changed = "retry_count" not in plugin_section or "config_version" not in plugin_section
plugin_section.setdefault("config_version", self._config_version)
plugin_section.setdefault("enabled", True)
plugin_section.setdefault("retry_count", 3)
return {"plugin": plugin_section}, changed
def set_plugin_config(self, config: Dict[str, Any]) -> None:
"""记录 Runner 注入的配置内容。
Args:
config: 当前最新配置。
"""
self.received_config = config
def get_default_config(self) -> Dict[str, Any]:
"""返回测试插件的默认配置。
Returns:
Dict[str, Any]: 默认配置字典。
"""
return {"plugin": {"config_version": self._config_version, "enabled": True, "retry_count": 3}}
def get_webui_config_schema(
self,
*,
plugin_id: str = "",
plugin_name: str = "",
plugin_version: str = "",
plugin_description: str = "",
plugin_author: str = "",
) -> Dict[str, Any]:
"""返回测试插件的 WebUI 配置 Schema。
Args:
plugin_id: 插件 ID。
plugin_name: 插件名称。
plugin_version: 插件版本。
plugin_description: 插件描述。
plugin_author: 插件作者。
Returns:
Dict[str, Any]: 测试配置 Schema。
"""
del plugin_name, plugin_description, plugin_author
return {
"plugin_id": plugin_id,
"plugin_info": {
"name": "Demo",
"version": plugin_version,
"description": "",
"author": "",
},
"sections": {
"plugin": {
"fields": {
"enabled": {
"type": "boolean",
"label": "启用",
"default": True,
"ui_type": "switch",
}
}
}
},
"layout": {"type": "auto", "tabs": []},
}
class _StrictConfigPlugin:
"""用于测试配置校验错误的伪插件。"""
def normalize_plugin_config(self, config_data: Optional[Mapping[str, Any]]) -> Tuple[Dict[str, Any], bool]:
"""校验重试次数不能为负数。
Args:
config_data: 原始配置数据。
Returns:
Tuple[Dict[str, Any], bool]: 规范化配置结果。
Raises:
ValueError: 当重试次数为负数时抛出。
"""
current_config = dict(config_data or {})
plugin_section = dict(current_config.get("plugin", {}))
plugin_section.setdefault("config_version", "2.0.0")
retry_count = int(plugin_section.get("retry_count", 0))
if retry_count < 0:
raise ValueError("重试次数不能小于 0")
plugin_section.setdefault("enabled", True)
return {"plugin": plugin_section}, False
def set_plugin_config(self, config: Dict[str, Any]) -> None:
"""兼容 Runner 配置注入接口。
Args:
config: 当前配置字典。
"""
del config
def get_default_config(self) -> Dict[str, Any]:
"""返回测试插件的默认配置。
Returns:
Dict[str, Any]: 默认配置字典。
"""
return {"plugin": {"config_version": "2.0.0", "enabled": True, "retry_count": 0}}
def test_runner_apply_plugin_config_generates_config_file(tmp_path: Path) -> None:
"""Runner 注入配置时应自动补齐并落盘 config.toml。"""
plugin = _DemoConfigPlugin()
runner = PluginRunner(
host_address="ipc://unused",
session_token="session-token",
plugin_dirs=[],
)
meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir=str(tmp_path), instance=plugin)
runner._apply_plugin_config(
cast(Any, meta),
config_data={"plugin": {"config_version": "2.0.0", "enabled": False}},
)
config_path = tmp_path / "config.toml"
assert config_path.exists()
assert plugin.received_config == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}}
with config_path.open("rb") as handle:
saved_config = tomllib.load(handle)
assert saved_config == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}}
def test_runner_apply_plugin_config_preserves_existing_comments(tmp_path: Path) -> None:
"""Runner 在版本升级时应尽量保留现有 config.toml 注释。"""
plugin = _DemoConfigPlugin()
runner = PluginRunner(
host_address="ipc://unused",
session_token="session-token",
plugin_dirs=[],
)
meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir=str(tmp_path), instance=plugin)
config_path = tmp_path / "config.toml"
config_path.write_text(
'# 插件配置头注释\n[plugin]\nconfig_version = "1.0.0"\nenabled = false # 启用开关注释\n',
encoding="utf-8",
)
runner._apply_plugin_config(cast(Any, meta))
config_text = config_path.read_text(encoding="utf-8")
assert "# 插件配置头注释" in config_text
assert "# 启用开关注释" in config_text
with config_path.open("rb") as handle:
saved_config = tomllib.load(handle)
assert saved_config == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}}
def test_runner_apply_plugin_config_same_version_does_not_rewrite_file(tmp_path: Path) -> None:
"""Runner 在配置版本未变化时不应仅因补齐默认值而重写文件。"""
plugin = _DemoConfigPlugin()
runner = PluginRunner(
host_address="ipc://unused",
session_token="session-token",
plugin_dirs=[],
)
meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir=str(tmp_path), instance=plugin)
config_path = tmp_path / "config.toml"
original_config_text = '# 原始注释\n[plugin]\nconfig_version = "2.0.0"\nenabled = false\n'
config_path.write_text(original_config_text, encoding="utf-8")
runner._apply_plugin_config(cast(Any, meta))
assert plugin.received_config == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}}
assert config_path.read_text(encoding="utf-8") == original_config_text
def test_runner_apply_plugin_config_requires_config_version(tmp_path: Path) -> None:
"""Runner 应拒绝缺少配置版本号的插件配置文件。"""
plugin = _DemoConfigPlugin()
runner = PluginRunner(
host_address="ipc://unused",
session_token="session-token",
plugin_dirs=[],
)
meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir=str(tmp_path), instance=plugin)
config_path = tmp_path / "config.toml"
config_path.write_text("[plugin]\nenabled = true\n", encoding="utf-8")
with pytest.raises(ValueError, match="config_version"):
runner._apply_plugin_config(cast(Any, meta))
def test_component_query_service_returns_plugin_config_schema(monkeypatch: Any) -> None:
"""组件查询服务应支持按插件 ID 返回配置 Schema。"""
payload = RegisterPluginPayload(
plugin_id="demo.plugin",
plugin_version="1.0.0",
default_config={"plugin": {"enabled": True}},
config_schema={
"plugin_id": "demo.plugin",
"plugin_info": {
"name": "Demo",
"version": "1.0.0",
"description": "",
"author": "",
},
"sections": {"plugin": {"fields": {}}},
"layout": {"type": "auto", "tabs": []},
},
)
fake_supervisor = SimpleNamespace(_registered_plugins={"demo.plugin": payload})
fake_manager = SimpleNamespace(_get_supervisor_for_plugin=lambda plugin_id: fake_supervisor)
monkeypatch.setattr(
type(component_query_service),
"_get_runtime_manager",
staticmethod(lambda: fake_manager),
)
assert component_query_service.get_plugin_config_schema("demo.plugin") == payload.config_schema
assert component_query_service.get_plugin_default_config("demo.plugin") == payload.default_config
@pytest.mark.asyncio
async def test_runner_validate_plugin_config_handler_returns_normalized_config(monkeypatch: pytest.MonkeyPatch) -> None:
"""Runner 应返回插件模型归一化后的配置。"""
plugin = _DemoConfigPlugin()
runner = PluginRunner(
host_address="ipc://unused",
session_token="session-token",
plugin_dirs=[],
)
meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir="", instance=plugin)
monkeypatch.setattr(runner._loader, "get_plugin", lambda plugin_id: meta if plugin_id == "demo.plugin" else None)
envelope = Envelope(
request_id=1,
message_type=MessageType.REQUEST,
method="plugin.validate_config",
plugin_id="demo.plugin",
payload=ValidatePluginConfigPayload(
config_data={"plugin": {"config_version": "2.0.0", "enabled": False}}
).model_dump(),
)
response = await runner._handle_validate_plugin_config(envelope)
assert response.error is None
assert response.payload["success"] is True
assert response.payload["normalized_config"] == {
"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}
}
@pytest.mark.asyncio
async def test_runner_inspect_plugin_config_handler_supports_unloaded_plugin(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Runner 应支持对未加载插件执行冷检查。"""
plugin = _DemoConfigPlugin()
runner = PluginRunner(
host_address="ipc://unused",
session_token="session-token",
plugin_dirs=[],
)
meta = SimpleNamespace(
plugin_id="demo.plugin",
plugin_dir="/tmp/demo-plugin",
instance=plugin,
manifest=SimpleNamespace(
name="Demo",
description="",
author=SimpleNamespace(name="tester"),
),
version="1.0.0",
)
purged_plugins: list[tuple[str, str]] = []
monkeypatch.setattr(
runner,
"_resolve_plugin_meta_for_config_request",
lambda plugin_id: (meta, True, None) if plugin_id == "demo.plugin" else (None, False, "not-found"),
)
monkeypatch.setattr(
runner._loader,
"purge_plugin_modules",
lambda plugin_id, plugin_dir: purged_plugins.append((plugin_id, plugin_dir)),
)
envelope = Envelope(
request_id=1,
message_type=MessageType.REQUEST,
method="plugin.inspect_config",
plugin_id="demo.plugin",
payload=InspectPluginConfigPayload(
config_data={"plugin": {"enabled": False}},
use_provided_config=True,
).model_dump(),
)
response = await runner._handle_inspect_plugin_config(envelope)
assert response.error is None
assert response.payload["success"] is True
assert response.payload["enabled"] is False
assert response.payload["normalized_config"] == {
"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}
}
assert response.payload["default_config"] == {
"plugin": {"config_version": "2.0.0", "enabled": True, "retry_count": 3}
}
assert purged_plugins == [("demo.plugin", "/tmp/demo-plugin")]
@pytest.mark.asyncio
async def test_runner_validate_plugin_config_handler_returns_error_on_invalid_config(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Runner 应在插件拒绝配置时返回错误响应。"""
plugin = _StrictConfigPlugin()
runner = PluginRunner(
host_address="ipc://unused",
session_token="session-token",
plugin_dirs=[],
)
meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir="", instance=plugin)
monkeypatch.setattr(runner._loader, "get_plugin", lambda plugin_id: meta if plugin_id == "demo.plugin" else None)
envelope = Envelope(
request_id=1,
message_type=MessageType.REQUEST,
method="plugin.validate_config",
plugin_id="demo.plugin",
payload=ValidatePluginConfigPayload(
config_data={"plugin": {"config_version": "2.0.0", "retry_count": -1}}
).model_dump(),
)
response = await runner._handle_validate_plugin_config(envelope)
assert response.error is not None
assert response.error["message"] == "重试次数不能小于 0"
@pytest.mark.asyncio
async def test_update_plugin_config_prefers_runtime_validation(
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> None:
"""WebUI 保存插件配置时应优先使用运行时校验结果。"""
config_path = tmp_path / "config.toml"
async def _mock_validate_plugin_config(plugin_id: str, config_data: Dict[str, Any]) -> Dict[str, Any] | None:
"""返回运行时归一化后的配置。
Args:
plugin_id: 插件 ID。
config_data: 原始配置。
Returns:
Dict[str, Any] | None: 归一化后的配置。
"""
assert plugin_id == "demo.plugin"
assert config_data == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}}
return {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}}
async def _mock_inspect_plugin_config(
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
*,
use_provided_config: bool = False,
) -> SimpleNamespace | None:
"""返回运行时配置快照。
Args:
plugin_id: 插件 ID。
config_data: 可选配置。
use_provided_config: 是否使用传入配置。
Returns:
SimpleNamespace | None: 运行时配置快照。
"""
del config_data, use_provided_config
if plugin_id != "demo.plugin":
return None
return SimpleNamespace(
normalized_config={"plugin": {"config_version": "2.0.0", "enabled": True, "retry_count": 3}}
)
fake_runtime_manager = SimpleNamespace(
inspect_plugin_config=_mock_inspect_plugin_config,
validate_plugin_config=_mock_validate_plugin_config,
)
monkeypatch.setattr(
"src.webui.routers.plugin.config_routes.require_plugin_token",
lambda session: session or "session-token",
)
monkeypatch.setattr(
"src.webui.routers.plugin.config_routes.find_plugin_path_by_id",
lambda plugin_id: tmp_path if plugin_id == "demo.plugin" else None,
)
monkeypatch.setattr(
"src.plugin_runtime.integration.get_plugin_runtime_manager",
lambda: fake_runtime_manager,
)
response = await update_plugin_config(
"demo.plugin",
UpdatePluginConfigRequest(config={"plugin.enabled": False}),
maibot_session="session-token",
)
assert response["success"] is True
with config_path.open("rb") as handle:
saved_config = tomllib.load(handle)
assert saved_config == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}}
@pytest.mark.asyncio
async def test_webui_config_endpoints_use_runtime_inspection_for_unloaded_plugin(
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> None:
"""WebUI 在插件未加载时也应从代码定义返回配置与 Schema。"""
async def _mock_inspect_plugin_config(
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
*,
use_provided_config: bool = False,
) -> SimpleNamespace | None:
"""返回运行时冷检查结果。
Args:
plugin_id: 插件 ID。
config_data: 可选配置。
use_provided_config: 是否使用传入配置。
Returns:
SimpleNamespace | None: 冷检查结果。
"""
del config_data, use_provided_config
if plugin_id != "demo.plugin":
return None
return SimpleNamespace(
config_schema={
"plugin_id": "demo.plugin",
"plugin_info": {
"name": "Demo",
"version": "1.0.0",
"description": "",
"author": "",
},
"sections": {"plugin": {"fields": {}}},
"layout": {"type": "auto", "tabs": []},
},
normalized_config={"plugin": {"config_version": "2.0.0", "enabled": True, "retry_count": 3}},
enabled=True,
)
fake_runtime_manager = SimpleNamespace(inspect_plugin_config=_mock_inspect_plugin_config)
monkeypatch.setattr(
"src.webui.routers.plugin.config_routes.require_plugin_token",
lambda session: session or "session-token",
)
monkeypatch.setattr(
"src.webui.routers.plugin.config_routes.find_plugin_path_by_id",
lambda plugin_id: tmp_path if plugin_id == "demo.plugin" else None,
)
monkeypatch.setattr(
"src.plugin_runtime.integration.get_plugin_runtime_manager",
lambda: fake_runtime_manager,
)
schema_response = await get_plugin_config_schema("demo.plugin", maibot_session="session-token")
config_response = await get_plugin_config("demo.plugin", maibot_session="session-token")
assert schema_response["success"] is True
assert schema_response["schema"]["plugin_id"] == "demo.plugin"
assert config_response == {
"success": True,
"config": {"plugin": {"config_version": "2.0.0", "enabled": True, "retry_count": 3}},
"message": "配置文件不存在,已返回默认配置",
}

View File

@@ -1,225 +0,0 @@
"""插件依赖流水线测试。"""
from pathlib import Path
import json
import pytest
from src.plugin_runtime.dependency_pipeline import PluginDependencyPipeline
def _build_manifest(
plugin_id: str,
*,
dependencies: list[dict[str, str]] | None = None,
) -> dict[str, object]:
"""构造测试用的 Manifest v2 数据。
Args:
plugin_id: 插件 ID。
dependencies: 依赖声明列表。
Returns:
dict[str, object]: 可直接写入 ``_manifest.json`` 的字典。
"""
return {
"manifest_version": 2,
"version": "1.0.0",
"name": plugin_id,
"description": "测试插件",
"author": {
"name": "tester",
"url": "https://example.com/tester",
},
"license": "MIT",
"urls": {
"repository": f"https://example.com/{plugin_id}",
},
"host_application": {
"min_version": "1.0.0",
"max_version": "1.0.0",
},
"sdk": {
"min_version": "2.0.0",
"max_version": "2.99.99",
},
"dependencies": dependencies or [],
"capabilities": [],
"i18n": {
"default_locale": "zh-CN",
"supported_locales": ["zh-CN"],
},
"id": plugin_id,
}
def _write_plugin(
plugin_root: Path,
plugin_name: str,
plugin_id: str,
*,
dependencies: list[dict[str, str]] | None = None,
) -> Path:
"""在临时目录中写入一个测试插件。
Args:
plugin_root: 插件根目录。
plugin_name: 插件目录名。
plugin_id: 插件 ID。
dependencies: Python 依赖声明列表。
Returns:
Path: 插件目录路径。
"""
plugin_dir = plugin_root / plugin_name
plugin_dir.mkdir(parents=True)
(plugin_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
(plugin_dir / "_manifest.json").write_text(
json.dumps(_build_manifest(plugin_id, dependencies=dependencies)),
encoding="utf-8",
)
return plugin_dir
def test_build_plan_blocks_plugin_conflicting_with_host_requirement(tmp_path: Path) -> None:
"""与主程序依赖冲突的插件应被阻止加载。"""
plugin_root = tmp_path / "plugins"
_write_plugin(
plugin_root,
"conflict_plugin",
"test.conflict-plugin",
dependencies=[
{
"type": "python_package",
"name": "numpy",
"version_spec": "<1.0.0",
}
],
)
pipeline = PluginDependencyPipeline(project_root=Path.cwd())
plan = pipeline.build_plan([plugin_root])
assert "test.conflict-plugin" in plan.blocked_plugin_reasons
assert "主程序" in plan.blocked_plugin_reasons["test.conflict-plugin"]
assert plan.install_requirements == ()
def test_build_plan_blocks_plugins_with_conflicting_python_dependencies(tmp_path: Path) -> None:
"""插件之间出现 Python 包版本冲突时应同时阻止双方加载。"""
plugin_root = tmp_path / "plugins"
_write_plugin(
plugin_root,
"plugin_a",
"test.plugin-a",
dependencies=[
{
"type": "python_package",
"name": "demo-package",
"version_spec": "<2.0.0",
}
],
)
_write_plugin(
plugin_root,
"plugin_b",
"test.plugin-b",
dependencies=[
{
"type": "python_package",
"name": "demo-package",
"version_spec": ">=3.0.0",
}
],
)
pipeline = PluginDependencyPipeline(project_root=Path.cwd())
plan = pipeline.build_plan([plugin_root])
assert "test.plugin-a" in plan.blocked_plugin_reasons
assert "test.plugin-b" in plan.blocked_plugin_reasons
assert "test.plugin-b" in plan.blocked_plugin_reasons["test.plugin-a"]
assert "test.plugin-a" in plan.blocked_plugin_reasons["test.plugin-b"]
def test_build_plan_collects_install_requirements_for_missing_packages(
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> None:
"""未安装但无冲突的依赖应进入自动安装计划。"""
plugin_root = tmp_path / "plugins"
_write_plugin(
plugin_root,
"plugin_a",
"test.plugin-a",
dependencies=[
{
"type": "python_package",
"name": "demo-package",
"version_spec": ">=1.0.0,<2.0.0",
}
],
)
pipeline = PluginDependencyPipeline(project_root=Path.cwd())
monkeypatch.setattr(
pipeline._manifest_validator,
"get_installed_package_version",
lambda package_name: None if package_name == "demo-package" else "1.0.0",
)
plan = pipeline.build_plan([plugin_root])
assert plan.blocked_plugin_reasons == {}
assert len(plan.install_requirements) == 1
assert plan.install_requirements[0].package_name == "demo-package"
assert plan.install_requirements[0].plugin_ids == ("test.plugin-a",)
assert plan.install_requirements[0].requirement_text == "demo-package>=1.0.0,<2.0.0"
@pytest.mark.asyncio
async def test_execute_blocks_plugins_when_auto_install_fails(
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> None:
"""自动安装失败时,相关插件应被阻止加载。"""
plugin_root = tmp_path / "plugins"
_write_plugin(
plugin_root,
"plugin_a",
"test.plugin-a",
dependencies=[
{
"type": "python_package",
"name": "demo-package",
"version_spec": ">=1.0.0,<2.0.0",
}
],
)
pipeline = PluginDependencyPipeline(project_root=Path.cwd())
monkeypatch.setattr(
pipeline._manifest_validator,
"get_installed_package_version",
lambda package_name: None if package_name == "demo-package" else "1.0.0",
)
async def fake_install(_requirements) -> tuple[bool, str]:
"""模拟依赖安装失败。"""
return False, "network error"
monkeypatch.setattr(pipeline, "_install_requirements", fake_install)
result = await pipeline.execute([plugin_root])
assert result.environment_changed is False
assert "test.plugin-a" in result.blocked_plugin_reasons
assert "自动安装 Python 依赖失败" in result.blocked_plugin_reasons["test.plugin-a"]

View File

@@ -1,86 +0,0 @@
from datetime import datetime
from pathlib import Path
import sys
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo, UserInfo
from src.common.data_models.message_component_data_model import (
ForwardComponent,
ForwardNodeComponent,
ImageComponent,
MessageSequence,
ReplyComponent,
TextComponent,
VoiceComponent,
)
from src.plugin_runtime.host.message_utils import PluginMessageUtils
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
def test_plugin_message_utils_preserves_binary_components_and_reply_metadata() -> None:
message = SessionMessage(message_id="msg-1", timestamp=datetime.now(), platform="qq")
message.message_info = MessageInfo(
user_info=UserInfo(user_id="10001", user_nickname="tester"),
group_info=GroupInfo(group_id="20001", group_name="group"),
additional_config={"self_id": "999"},
)
message.session_id = "qq:20001:10001"
message.processed_plain_text = "binary payload"
message.raw_message = MessageSequence(
components=[
TextComponent("hello"),
ImageComponent(binary_hash="", binary_data=b"image-bytes", content=""),
VoiceComponent(binary_hash="", binary_data=b"voice-bytes", content=""),
ReplyComponent(
target_message_id="origin-1",
target_message_content="origin text",
target_message_sender_id="42",
target_message_sender_nickname="alice",
target_message_sender_cardname="Alice",
),
ForwardNodeComponent(
forward_components=[
ForwardComponent(
user_nickname="bob",
user_id="43",
user_cardname="Bob",
message_id="forward-1",
content=[
TextComponent("node-text"),
ImageComponent(binary_hash="", binary_data=b"node-image", content=""),
],
)
]
),
]
)
message_dict = PluginMessageUtils._session_message_to_dict(message)
rebuilt_message = PluginMessageUtils._build_session_message_from_dict(dict(message_dict))
image_component = rebuilt_message.raw_message.components[1]
voice_component = rebuilt_message.raw_message.components[2]
reply_component = rebuilt_message.raw_message.components[3]
forward_component = rebuilt_message.raw_message.components[4]
assert isinstance(image_component, ImageComponent)
assert image_component.binary_data == b"image-bytes"
assert isinstance(voice_component, VoiceComponent)
assert voice_component.binary_data == b"voice-bytes"
assert isinstance(reply_component, ReplyComponent)
assert reply_component.target_message_id == "origin-1"
assert reply_component.target_message_content == "origin text"
assert reply_component.target_message_sender_id == "42"
assert reply_component.target_message_sender_nickname == "alice"
assert reply_component.target_message_sender_cardname == "Alice"
assert isinstance(forward_component, ForwardNodeComponent)
assert isinstance(forward_component.forward_components[0].content[1], ImageComponent)
assert forward_component.forward_components[0].content[1].binary_data == b"node-image"

File diff suppressed because it is too large Load Diff

View File

@@ -1,284 +0,0 @@
"""核心组件查询层与插件运行时聚合测试。"""
from types import SimpleNamespace
from typing import Any
import pytest
import src.plugin_runtime.integration as integration_module
from src.core.types import ActionInfo, ToolInfo
from src.plugin_runtime.component_query import component_query_service
from src.plugin_runtime.host.supervisor import PluginSupervisor
class _FakeRuntimeManager:
"""测试用插件运行时管理器。"""
def __init__(self, supervisor: PluginSupervisor, plugin_id: str, plugin_config: dict[str, Any]) -> None:
"""初始化测试用运行时管理器。
Args:
supervisor: 持有测试组件的监督器。
plugin_id: 目标插件 ID。
plugin_config: 需要返回的插件配置。
"""
self.supervisors = [supervisor]
self._plugin_id = plugin_id
self._plugin_config = plugin_config
def _get_supervisor_for_plugin(self, plugin_id: str) -> PluginSupervisor | None:
"""按插件 ID 返回对应监督器。
Args:
plugin_id: 目标插件 ID。
Returns:
PluginSupervisor | None: 命中时返回监督器。
"""
return self.supervisors[0] if plugin_id == self._plugin_id else None
def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> dict[str, Any]:
"""返回测试配置。
Args:
supervisor: 监督器实例。
plugin_id: 目标插件 ID。
Returns:
dict[str, Any]: 测试配置内容。
"""
del supervisor
if plugin_id != self._plugin_id:
return {}
return dict(self._plugin_config)
def _install_runtime_manager(
monkeypatch: pytest.MonkeyPatch,
supervisor: PluginSupervisor,
plugin_id: str,
plugin_config: dict[str, Any] | None = None,
) -> None:
"""为测试安装假的运行时管理器。
Args:
monkeypatch: pytest monkeypatch 对象。
supervisor: 持有测试组件的监督器。
plugin_id: 测试插件 ID。
plugin_config: 可选的测试配置内容。
"""
fake_manager = _FakeRuntimeManager(supervisor, plugin_id, plugin_config or {"enabled": True})
monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: fake_manager)
@pytest.mark.asyncio
async def test_core_component_registry_reads_runtime_action_and_executor(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""核心查询层应直接读取运行时 Action并返回 RPC 执行闭包。"""
plugin_id = "runtime_action_bridge_plugin"
action_name = "runtime_action_bridge_test"
supervisor = PluginSupervisor(plugin_dirs=[])
captured: dict[str, Any] = {}
supervisor.component_registry.register_component(
name=action_name,
component_type="ACTION",
plugin_id=plugin_id,
metadata={
"description": "发送一个测试回复",
"enabled": True,
"activation_type": "keyword",
"activation_probability": 0.25,
"activation_keywords": ["测试", "hello"],
"action_parameters": {"target": "目标对象"},
"action_require": ["需要发送回复时使用"],
"associated_types": ["text"],
"parallel_action": True,
},
)
_install_runtime_manager(monkeypatch, supervisor, plugin_id, {"enabled": True, "mode": "test"})
async def fake_invoke_plugin(
method: str,
plugin_id: str,
component_name: str,
args: dict[str, Any] | None = None,
timeout_ms: int = 30000,
) -> Any:
"""模拟动作 RPC 调用。"""
captured["method"] = method
captured["plugin_id"] = plugin_id
captured["component_name"] = component_name
captured["args"] = args or {}
captured["timeout_ms"] = timeout_ms
return SimpleNamespace(payload={"success": True, "result": (True, "runtime action executed")})
monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
action_info = component_query_service.get_action_info(action_name)
assert isinstance(action_info, ActionInfo)
assert action_info.plugin_name == plugin_id
assert action_info.description == "发送一个测试回复"
assert action_info.activation_keywords == ["测试", "hello"]
assert action_info.random_activation_probability == 0.25
assert action_info.parallel_action is True
assert action_name in component_query_service.get_default_actions()
assert component_query_service.get_plugin_config(plugin_id) == {"enabled": True, "mode": "test"}
executor = component_query_service.get_action_executor(action_name)
assert executor is not None
success, reason = await executor(
action_data={"target": "MaiBot"},
action_reasoning="当前适合使用这个动作",
cycle_timers={"planner": 0.1},
thinking_id="tid-1",
chat_stream=SimpleNamespace(session_id="stream-1"),
log_prefix="[test]",
shutting_down=False,
plugin_config={"enabled": True},
)
assert success is True
assert reason == "runtime action executed"
assert captured["method"] == "plugin.invoke_action"
assert captured["plugin_id"] == plugin_id
assert captured["component_name"] == action_name
assert captured["args"]["stream_id"] == "stream-1"
assert captured["args"]["chat_id"] == "stream-1"
assert captured["args"]["reasoning"] == "当前适合使用这个动作"
assert captured["args"]["target"] == "MaiBot"
assert captured["args"]["action_data"] == {"target": "MaiBot"}
@pytest.mark.asyncio
async def test_core_component_registry_reads_runtime_command_and_executor(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""核心查询层应直接使用运行时命令匹配与执行闭包。"""
plugin_id = "runtime_command_bridge_plugin"
command_name = "runtime_command_bridge_test"
supervisor = PluginSupervisor(plugin_dirs=[])
captured: dict[str, Any] = {}
supervisor.component_registry.register_component(
name=command_name,
component_type="COMMAND",
plugin_id=plugin_id,
metadata={
"description": "测试命令",
"enabled": True,
"command_pattern": r"^/test(?:\s+.+)?$",
"aliases": ["/hello"],
"intercept_message_level": 1,
},
)
_install_runtime_manager(monkeypatch, supervisor, plugin_id, {"mode": "command"})
async def fake_invoke_plugin(
method: str,
plugin_id: str,
component_name: str,
args: dict[str, Any] | None = None,
timeout_ms: int = 30000,
) -> Any:
"""模拟命令 RPC 调用。"""
captured["method"] = method
captured["plugin_id"] = plugin_id
captured["component_name"] = component_name
captured["args"] = args or {}
captured["timeout_ms"] = timeout_ms
return SimpleNamespace(payload={"success": True, "result": (True, "command ok", True)})
monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
matched = component_query_service.find_command_by_text("/test hello")
assert matched is not None
command_executor, matched_groups, command_info = matched
assert matched_groups == {}
assert command_info.plugin_name == plugin_id
assert command_info.command_pattern == r"^/test(?:\s+.+)?$"
success, response_text, intercept = await command_executor(
message=SimpleNamespace(processed_plain_text="/test hello", session_id="stream-2"),
plugin_config={"mode": "command"},
matched_groups=matched_groups,
)
assert success is True
assert response_text == "command ok"
assert intercept is True
assert captured["method"] == "plugin.invoke_command"
assert captured["plugin_id"] == plugin_id
assert captured["component_name"] == command_name
assert captured["args"]["text"] == "/test hello"
assert captured["args"]["stream_id"] == "stream-2"
assert captured["args"]["plugin_config"] == {"mode": "command"}
@pytest.mark.asyncio
async def test_core_component_registry_reads_runtime_tools_and_executor(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""核心查询层应直接读取运行时 Tool并返回 RPC 执行闭包。"""
plugin_id = "runtime_tool_bridge_plugin"
tool_name = "runtime_tool_bridge_test"
supervisor = PluginSupervisor(plugin_dirs=[])
supervisor.component_registry.register_component(
name=tool_name,
component_type="TOOL",
plugin_id=plugin_id,
metadata={
"description": "测试工具",
"enabled": True,
"parameters": [
{
"name": "query",
"param_type": "string",
"description": "查询词",
"required": True,
}
],
},
)
_install_runtime_manager(monkeypatch, supervisor, plugin_id)
async def fake_invoke_plugin(
method: str,
plugin_id: str,
component_name: str,
args: dict[str, Any] | None = None,
timeout_ms: int = 30000,
) -> Any:
"""模拟工具 RPC 调用。"""
del timeout_ms
assert method == "plugin.invoke_tool"
assert plugin_id == "runtime_tool_bridge_plugin"
assert component_name == "runtime_tool_bridge_test"
assert args == {"query": "MaiBot"}
return SimpleNamespace(payload={"success": True, "result": {"content": "tool ok"}})
monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
tool_info = component_query_service.get_tool_info(tool_name)
assert isinstance(tool_info, ToolInfo)
assert tool_info.tool_description == "测试工具"
assert tool_name in component_query_service.get_llm_available_tools()
executor = component_query_service.get_tool_executor(tool_name)
assert executor is not None
assert await executor({"query": "MaiBot"}) == {"content": "tool ok"}

View File

@@ -1,524 +0,0 @@
"""插件 API 注册与调用测试。"""
from types import SimpleNamespace
from typing import Any, Dict, List
import pytest
from src.plugin_runtime.integration import PluginRuntimeManager
from src.plugin_runtime.host.supervisor import PluginSupervisor
from src.plugin_runtime.protocol.envelope import (
ComponentDeclaration,
Envelope,
MessageType,
RegisterPluginPayload,
UnregisterPluginPayload,
)
def _build_manager(*supervisors: PluginSupervisor) -> PluginRuntimeManager:
"""构造一个最小可用的插件运行时管理器。
Args:
*supervisors: 需要挂载的监督器列表。
Returns:
PluginRuntimeManager: 已注入监督器的运行时管理器。
"""
manager = PluginRuntimeManager()
if supervisors:
manager._builtin_supervisor = supervisors[0]
if len(supervisors) > 1:
manager._third_party_supervisor = supervisors[1]
return manager
async def _register_plugin(
supervisor: PluginSupervisor,
plugin_id: str,
components: List[Dict[str, Any]],
) -> Envelope:
"""通过 Supervisor 注册测试插件。
Args:
supervisor: 目标监督器。
plugin_id: 测试插件 ID。
components: 组件声明列表。
Returns:
Envelope: 注册响应信封。
"""
payload = RegisterPluginPayload(
plugin_id=plugin_id,
plugin_version="1.0.0",
components=[
ComponentDeclaration(
name=str(component.get("name", "") or ""),
component_type=str(component.get("component_type", "") or ""),
plugin_id=plugin_id,
metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {},
)
for component in components
],
)
return await supervisor._handle_register_plugin(
Envelope(
request_id=1,
message_type=MessageType.REQUEST,
method="plugin.register_components",
plugin_id=plugin_id,
payload=payload.model_dump(),
)
)
async def _unregister_plugin(supervisor: PluginSupervisor, plugin_id: str) -> Envelope:
"""通过 Supervisor 注销测试插件。
Args:
supervisor: 目标监督器。
plugin_id: 测试插件 ID。
Returns:
Envelope: 注销响应信封。
"""
payload = UnregisterPluginPayload(plugin_id=plugin_id, reason="test")
return await supervisor._handle_unregister_plugin(
Envelope(
request_id=2,
message_type=MessageType.REQUEST,
method="plugin.unregister",
plugin_id=plugin_id,
payload=payload.model_dump(),
)
)
@pytest.mark.asyncio
async def test_register_plugin_syncs_dedicated_api_registry() -> None:
"""插件注册时应将 API 同步到独立注册表,而不是通用组件表。"""
supervisor = PluginSupervisor(plugin_dirs=[])
response = await _register_plugin(
supervisor,
"provider",
[
{
"name": "render_html",
"component_type": "API",
"metadata": {
"description": "渲染 HTML",
"version": "1",
"public": True,
},
}
],
)
assert response.payload["accepted"] is True
assert response.payload["registered_components"] == 0
assert response.payload["registered_apis"] == 1
assert supervisor.api_registry.get_api("provider", "render_html") is not None
assert supervisor.component_registry.get_component("provider.render_html") is None
unregister_response = await _unregister_plugin(supervisor, "provider")
assert unregister_response.payload["removed_apis"] == 1
assert supervisor.api_registry.get_api("provider", "render_html") is None
@pytest.mark.asyncio
async def test_api_call_allows_public_api_between_plugins(monkeypatch: pytest.MonkeyPatch) -> None:
"""公开 API 应允许其他插件通过 Host 转发调用。"""
provider_supervisor = PluginSupervisor(plugin_dirs=[])
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
await _register_plugin(
provider_supervisor,
"provider",
[
{
"name": "render_html",
"component_type": "API",
"metadata": {
"description": "渲染 HTML",
"version": "1",
"public": True,
},
}
],
)
await _register_plugin(consumer_supervisor, "consumer", [])
captured: Dict[str, Any] = {}
async def fake_invoke_api(
plugin_id: str,
component_name: str,
args: Dict[str, Any] | None = None,
timeout_ms: int = 30000,
) -> Any:
"""模拟 API RPC 调用。"""
captured["plugin_id"] = plugin_id
captured["component_name"] = component_name
captured["args"] = args or {}
captured["timeout_ms"] = timeout_ms
return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}})
monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api)
manager = _build_manager(provider_supervisor, consumer_supervisor)
result = await manager._cap_api_call(
"consumer",
"api.call",
{
"api_name": "provider.render_html",
"version": "1",
"args": {"html": "<div>Hello</div>"},
},
)
assert result == {"success": True, "result": {"image": "ok"}}
assert captured["plugin_id"] == "provider"
assert captured["component_name"] == "render_html"
assert captured["args"] == {"html": "<div>Hello</div>"}
@pytest.mark.asyncio
async def test_api_call_rejects_private_api_between_plugins() -> None:
"""未公开的 API 默认不允许跨插件调用。"""
provider_supervisor = PluginSupervisor(plugin_dirs=[])
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
await _register_plugin(
provider_supervisor,
"provider",
[
{
"name": "secret_api",
"component_type": "API",
"metadata": {
"description": "私有 API",
"version": "1",
"public": False,
},
}
],
)
await _register_plugin(consumer_supervisor, "consumer", [])
manager = _build_manager(provider_supervisor, consumer_supervisor)
result = await manager._cap_api_call(
"consumer",
"api.call",
{
"api_name": "provider.secret_api",
"args": {},
},
)
assert result["success"] is False
assert "未公开" in str(result["error"])
@pytest.mark.asyncio
async def test_api_list_and_component_toggle_use_dedicated_registry() -> None:
"""API 列表与组件启停应直接作用于独立 API 注册表。"""
provider_supervisor = PluginSupervisor(plugin_dirs=[])
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
await _register_plugin(
provider_supervisor,
"provider",
[
{
"name": "public_api",
"component_type": "API",
"metadata": {"version": "1", "public": True},
},
{
"name": "private_api",
"component_type": "API",
"metadata": {"version": "1", "public": False},
},
],
)
await _register_plugin(
consumer_supervisor,
"consumer",
[
{
"name": "self_private_api",
"component_type": "API",
"metadata": {"version": "1", "public": False},
}
],
)
manager = _build_manager(provider_supervisor, consumer_supervisor)
list_result = await manager._cap_api_list("consumer", "api.list", {})
assert list_result["success"] is True
api_names = {(item["plugin_id"], item["name"]) for item in list_result["apis"]}
assert ("provider", "public_api") in api_names
assert ("provider", "private_api") not in api_names
assert ("consumer", "self_private_api") in api_names
disable_result = await manager._cap_component_disable(
"consumer",
"component.disable",
{
"name": "provider.public_api",
"component_type": "API",
"scope": "global",
"stream_id": "",
},
)
assert disable_result["success"] is True
assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is None
enable_result = await manager._cap_component_enable(
"consumer",
"component.enable",
{
"name": "provider.public_api",
"component_type": "API",
"scope": "global",
"stream_id": "",
},
)
assert enable_result["success"] is True
assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is not None
@pytest.mark.asyncio
async def test_api_registry_supports_multiple_versions_with_distinct_handlers(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""同名 API 不同版本应可并存,并按版本路由到不同处理器。"""
provider_supervisor = PluginSupervisor(plugin_dirs=[])
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
await _register_plugin(
provider_supervisor,
"provider",
[
{
"name": "render_html",
"component_type": "API",
"metadata": {
"description": "渲染 HTML v1",
"version": "1",
"public": True,
"handler_name": "handle_render_html_v1",
},
},
{
"name": "render_html",
"component_type": "API",
"metadata": {
"description": "渲染 HTML v2",
"version": "2",
"public": True,
"handler_name": "handle_render_html_v2",
},
},
],
)
await _register_plugin(consumer_supervisor, "consumer", [])
captured: Dict[str, Any] = {}
async def fake_invoke_api(
plugin_id: str,
component_name: str,
args: Dict[str, Any] | None = None,
timeout_ms: int = 30000,
) -> Any:
"""模拟多版本 API 调用。"""
captured["plugin_id"] = plugin_id
captured["component_name"] = component_name
captured["args"] = args or {}
captured["timeout_ms"] = timeout_ms
return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}})
monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api)
manager = _build_manager(provider_supervisor, consumer_supervisor)
ambiguous_result = await manager._cap_api_call(
"consumer",
"api.call",
{
"api_name": "provider.render_html",
"args": {"html": "<div>Hello</div>"},
},
)
assert ambiguous_result["success"] is False
assert "多个版本" in str(ambiguous_result["error"])
disable_ambiguous_result = await manager._cap_component_disable(
"consumer",
"component.disable",
{
"name": "provider.render_html",
"component_type": "API",
"scope": "global",
"stream_id": "",
},
)
assert disable_ambiguous_result["success"] is False
assert "多个版本" in str(disable_ambiguous_result["error"])
disable_v1_result = await manager._cap_component_disable(
"consumer",
"component.disable",
{
"name": "provider.render_html",
"component_type": "API",
"scope": "global",
"stream_id": "",
"version": "1",
},
)
assert disable_v1_result["success"] is True
assert provider_supervisor.api_registry.get_api("provider", "render_html", version="1", enabled_only=True) is None
assert provider_supervisor.api_registry.get_api("provider", "render_html", version="2", enabled_only=True) is not None
result = await manager._cap_api_call(
"consumer",
"api.call",
{
"api_name": "provider.render_html",
"version": "2",
"args": {"html": "<div>Hello</div>"},
},
)
assert result == {"success": True, "result": {"image": "ok"}}
assert captured["plugin_id"] == "provider"
assert captured["component_name"] == "handle_render_html_v2"
assert captured["args"] == {"html": "<div>Hello</div>"}
@pytest.mark.asyncio
async def test_api_replace_dynamic_can_offline_removed_entries(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""动态 API 替换后,被移除的 API 应返回明确下线错误。"""
supervisor = PluginSupervisor(plugin_dirs=[])
await _register_plugin(supervisor, "provider", [])
manager = _build_manager(supervisor)
captured: Dict[str, Any] = {}
async def fake_invoke_api(
plugin_id: str,
component_name: str,
args: Dict[str, Any] | None = None,
timeout_ms: int = 30000,
) -> Any:
"""模拟动态 API 调用。"""
captured["plugin_id"] = plugin_id
captured["component_name"] = component_name
captured["args"] = args or {}
captured["timeout_ms"] = timeout_ms
return SimpleNamespace(error=None, payload={"success": True, "result": {"ok": True}})
monkeypatch.setattr(supervisor, "invoke_api", fake_invoke_api)
replace_result = await manager._cap_api_replace_dynamic(
"provider",
"api.replace_dynamic",
{
"apis": [
{
"name": "mcp.search",
"type": "API",
"metadata": {
"version": "1",
"public": True,
"handler_name": "dynamic_search",
},
},
{
"name": "mcp.read",
"type": "API",
"metadata": {
"version": "1",
"public": True,
"handler_name": "dynamic_read",
},
},
],
"offline_reason": "MCP 服务器已关闭",
},
)
assert replace_result["success"] is True
assert replace_result["count"] == 2
list_result = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
assert {(item["name"], item["version"]) for item in list_result["apis"]} == {
("mcp.read", "1"),
("mcp.search", "1"),
}
call_result = await manager._cap_api_call(
"provider",
"api.call",
{
"api_name": "provider.mcp.search",
"version": "1",
"args": {"query": "hello"},
},
)
assert call_result == {"success": True, "result": {"ok": True}}
assert captured["component_name"] == "dynamic_search"
assert captured["args"]["query"] == "hello"
assert captured["args"]["__maibot_api_name__"] == "mcp.search"
assert captured["args"]["__maibot_api_version__"] == "1"
second_replace_result = await manager._cap_api_replace_dynamic(
"provider",
"api.replace_dynamic",
{
"apis": [
{
"name": "mcp.read",
"type": "API",
"metadata": {
"version": "1",
"public": True,
"handler_name": "dynamic_read",
},
}
],
"offline_reason": "MCP 服务器已关闭",
},
)
assert second_replace_result["success"] is True
assert second_replace_result["count"] == 1
assert second_replace_result["offlined"] == 1
offlined_call_result = await manager._cap_api_call(
"provider",
"api.call",
{
"api_name": "provider.mcp.search",
"version": "1",
"args": {},
},
)
assert offlined_call_result["success"] is False
assert "MCP 服务器已关闭" in str(offlined_call_result["error"])
list_after_replace = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
assert {(item["name"], item["version"]) for item in list_after_replace["apis"]} == {
("mcp.read", "1"),
}

View File

@@ -1,96 +0,0 @@
"""插件运行时浏览器渲染能力测试。"""
from typing import Optional
import pytest
from src.plugin_runtime.integration import PluginRuntimeManager
from src.plugin_runtime.host.supervisor import PluginSupervisor
from src.services.html_render_service import HtmlRenderRequest, HtmlRenderResult
class _FakeRenderService:
"""用于替代真实浏览器渲染服务的测试桩。"""
def __init__(self) -> None:
"""初始化测试桩。"""
self.last_request: Optional[HtmlRenderRequest] = None
async def render_html_to_png(self, request: HtmlRenderRequest) -> HtmlRenderResult:
"""记录请求并返回固定的渲染结果。
Args:
request: 当前渲染请求。
Returns:
HtmlRenderResult: 固定的测试渲染结果。
"""
self.last_request = request
return HtmlRenderResult(
image_base64="ZmFrZS1pbWFnZQ==",
mime_type="image/png",
width=640,
height=480,
render_ms=12,
)
def test_render_capability_is_registered() -> None:
"""Host 注册能力时应包含 render.html2png。"""
manager = PluginRuntimeManager()
supervisor = PluginSupervisor(plugin_dirs=[])
manager._register_capability_impls(supervisor)
assert "render.html2png" in supervisor.capability_service.list_capabilities()
@pytest.mark.asyncio
async def test_render_capability_forwards_request(monkeypatch: pytest.MonkeyPatch) -> None:
"""render.html2png 应将请求透传给浏览器渲染服务。"""
from src.plugin_runtime.capabilities import render as render_capability_module
fake_service = _FakeRenderService()
monkeypatch.setattr(render_capability_module, "get_html_render_service", lambda: fake_service)
manager = PluginRuntimeManager()
result = await manager._cap_render_html2png(
"demo.plugin",
"render.html2png",
{
"html": "<body><div id='card'>hello</div></body>",
"selector": "#card",
"viewport": {"width": 1024, "height": 768},
"device_scale_factor": 1.5,
"full_page": False,
"omit_background": True,
"wait_until": "networkidle",
"wait_for_selector": "#card",
"wait_for_timeout_ms": 150,
"timeout_ms": 3000,
"allow_network": True,
},
)
assert result == {
"success": True,
"result": {
"image_base64": "ZmFrZS1pbWFnZQ==",
"mime_type": "image/png",
"width": 640,
"height": 480,
"render_ms": 12,
},
}
assert fake_service.last_request is not None
assert fake_service.last_request.selector == "#card"
assert fake_service.last_request.viewport_width == 1024
assert fake_service.last_request.viewport_height == 768
assert fake_service.last_request.device_scale_factor == 1.5
assert fake_service.last_request.omit_background is True
assert fake_service.last_request.wait_until == "networkidle"
assert fake_service.last_request.allow_network is True

View File

@@ -1,18 +0,0 @@
from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.plugin_runtime.hook_payloads import deserialize_prompt_messages, serialize_prompt_messages
def test_prompt_messages_roundtrip_preserves_image_parts() -> None:
messages = [
MessageBuilder().set_role(RoleType.User).add_text_content("你好").add_image_content("png", "ZmFrZQ==").build(),
]
serialized_messages = serialize_prompt_messages(messages)
restored_messages = deserialize_prompt_messages(serialized_messages)
assert len(restored_messages) == 1
assert restored_messages[0].role == RoleType.User
assert restored_messages[0].get_text_content() == "你好"
assert len(restored_messages[0].parts) == 2
assert restored_messages[0].parts[1].image_format == "png"
assert restored_messages[0].parts[1].image_base64 == "ZmFrZQ=="

View File

@@ -1,136 +0,0 @@
"""业务命名 Hook 集成测试。"""
from types import SimpleNamespace
from typing import Any
import os
import sys
import pytest
# 确保项目根目录在 sys.path 中
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
# SDK 包路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "maibot-plugin-sdk"))
class _FakeHookManager:
"""用于业务 Hook 测试的最小运行时管理器。"""
def __init__(self, responses: dict[str, SimpleNamespace]) -> None:
"""初始化测试管理器。
Args:
responses: 按 Hook 名称预设的返回结果映射。
"""
self._responses = responses
self.calls: list[tuple[str, dict[str, Any]]] = []
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> SimpleNamespace:
"""模拟调用运行时命名 Hook。
Args:
hook_name: 目标 Hook 名称。
**kwargs: 传入 Hook 的参数。
Returns:
SimpleNamespace: 预设的 Hook 返回结果。
"""
self.calls.append((hook_name, dict(kwargs)))
return self._responses.get(hook_name, SimpleNamespace(kwargs=dict(kwargs), aborted=False))
def test_builtin_hook_catalog_includes_new_business_hooks(monkeypatch: pytest.MonkeyPatch) -> None:
"""内置 Hook 目录应包含三个业务系统新增的 Hook。"""
monkeypatch.setattr(sys, "exit", lambda code=0: None)
from src.plugin_runtime.hook_catalog import register_builtin_hook_specs
from src.plugin_runtime.host.hook_spec_registry import HookSpecRegistry
registry = HookSpecRegistry()
hook_names = {spec.name for spec in register_builtin_hook_specs(registry)}
assert "emoji.maisaka.before_select" in hook_names
assert "emoji.register.after_build_emotion" in hook_names
assert "jargon.extract.before_persist" in hook_names
assert "jargon.query.after_search" in hook_names
assert "expression.select.before_select" in hook_names
assert "expression.learn.before_upsert" in hook_names
@pytest.mark.asyncio
async def test_send_emoji_for_maisaka_can_be_aborted_by_hook(monkeypatch: pytest.MonkeyPatch) -> None:
"""表情包系统应允许在选择前被 Hook 中止。"""
from src.emoji_system import maisaka_tool
fake_manager = _FakeHookManager(
{
"emoji.maisaka.before_select": SimpleNamespace(
kwargs={"abort_message": "插件阻止了表情发送。"},
aborted=True,
)
}
)
monkeypatch.setattr(maisaka_tool, "_get_runtime_manager", lambda: fake_manager)
result = await maisaka_tool.send_emoji_for_maisaka(stream_id="stream-1", requested_emotion="开心")
assert result.success is False
assert result.message == "插件阻止了表情发送。"
assert fake_manager.calls[0][0] == "emoji.maisaka.before_select"
@pytest.mark.asyncio
async def test_jargon_extract_can_be_aborted_before_persist(monkeypatch: pytest.MonkeyPatch) -> None:
"""黑话提取结果应允许在写库前被 Hook 中止。"""
from src.learners.jargon_miner import JargonMiner
fake_manager = _FakeHookManager(
{
"jargon.extract.before_persist": SimpleNamespace(
kwargs={"entries": []},
aborted=True,
)
}
)
monkeypatch.setattr(JargonMiner, "_get_runtime_manager", staticmethod(lambda: fake_manager))
miner = JargonMiner(session_id="session-1", session_name="测试会话")
await miner.process_extracted_entries(
[{"content": "yyds", "raw_content": {"[1] yyds 太强了"}}],
)
assert fake_manager.calls[0][0] == "jargon.extract.before_persist"
assert fake_manager.calls[0][1]["session_id"] == "session-1"
@pytest.mark.asyncio
async def test_expression_selection_can_be_aborted_by_hook(monkeypatch: pytest.MonkeyPatch) -> None:
"""表达方式选择流程应允许在开始前被 Hook 中止。"""
from src.learners.expression_selector import ExpressionSelector
fake_manager = _FakeHookManager(
{
"expression.select.before_select": SimpleNamespace(
kwargs={},
aborted=True,
)
}
)
monkeypatch.setattr(ExpressionSelector, "_get_runtime_manager", staticmethod(lambda: fake_manager))
monkeypatch.setattr(ExpressionSelector, "can_use_expression_for_chat", lambda self, chat_id: True)
selector = ExpressionSelector()
selected_expressions, selected_ids = await selector.select_suitable_expressions(
chat_id="session-1",
chat_info="用户刚刚发来一条消息。",
)
assert selected_expressions == []
assert selected_ids == []
assert fake_manager.calls[0][0] == "expression.select.before_select"

View File

@@ -1,360 +0,0 @@
"""发送服务回归测试。"""
import sys
from types import SimpleNamespace
from typing import Any, Dict, List
import pytest
from src.chat.message_receive.chat_manager import BotChatSession
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from src.services import send_service
class _FakePlatformIOManager:
"""用于测试的 Platform IO 管理器假对象。"""
def __init__(self, delivery_batch: Any) -> None:
self._delivery_batch = delivery_batch
self.ensure_calls = 0
self.sent_messages: List[Dict[str, Any]] = []
async def ensure_send_pipeline_ready(self) -> None:
self.ensure_calls += 1
def build_route_key_from_message(self, message: Any) -> Any:
del message
return SimpleNamespace(platform="qq")
async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
self.sent_messages.append(
{
"message": message,
"message_id_before_send": str(getattr(message, "message_id", "") or ""),
"route_key": route_key,
"metadata": metadata,
}
)
return self._delivery_batch
def _build_private_stream() -> BotChatSession:
return BotChatSession(
session_id="test-session",
platform="qq",
user_id="target-user",
group_id=None,
)
def _build_group_stream() -> BotChatSession:
return BotChatSession(
session_id="group-session",
platform="qq",
user_id="target-user",
group_id="target-group",
)
def test_inherit_platform_io_route_metadata_falls_back_to_bot_account(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "")
metadata = send_service._inherit_platform_io_route_metadata(_build_private_stream())
assert metadata["platform_io_account_id"] == "bot-qq"
assert metadata["platform_io_target_user_id"] == "target-user"
@pytest.mark.asyncio
async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None:
import src.common.message_server.api as message_server_api
fake_manager = _FakePlatformIOManager(
delivery_batch=SimpleNamespace(
has_success=True,
sent_receipts=[
SimpleNamespace(
driver_id="plugin.qq.sender",
external_message_id="real-message-id",
metadata={
"adapter_callbacks": [
{
"name": "message_id_echo",
"payload": {
"content": {
"type": "echo",
"echo": "send_api_test",
"actual_id": "real-message-id",
}
},
}
]
},
)
],
failed_receipts=[],
route_key=SimpleNamespace(platform="qq"),
)
)
callback_payloads: List[Dict[str, Any]] = []
stored_messages: List[Any] = []
async def fake_echo_handler(payload: Dict[str, Any]) -> None:
"""记录发送成功后的消息 ID 回调。"""
callback_payloads.append(payload)
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
monkeypatch.setattr(
send_service._chat_manager,
"get_session_by_session_id",
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
)
monkeypatch.setattr(
send_service.MessageUtils,
"store_message_to_db",
lambda message: stored_messages.append(message),
)
monkeypatch.setattr(
message_server_api,
"global_api",
SimpleNamespace(_custom_message_handlers={"message_id_echo": fake_echo_handler}),
)
result = await send_service.text_to_stream(text="你好", stream_id="test-session")
assert result is True
assert fake_manager.ensure_calls == 1
assert len(fake_manager.sent_messages) == 1
assert fake_manager.sent_messages[0]["metadata"] == {"show_log": False}
assert len(stored_messages) == 1
assert stored_messages[0].message_id == "real-message-id"
assert callback_payloads == [
{
"content": {
"type": "echo",
"echo": "send_api_test",
"actual_id": "real-message-id",
}
}
]
@pytest.mark.asyncio
async def test_text_to_stream_with_message_returns_sent_message(monkeypatch: pytest.MonkeyPatch) -> None:
fake_manager = _FakePlatformIOManager(
delivery_batch=SimpleNamespace(
has_success=True,
sent_receipts=[
SimpleNamespace(
driver_id="plugin.qq.sender",
external_message_id="real-message-id",
metadata={},
)
],
failed_receipts=[],
route_key=SimpleNamespace(platform="qq"),
)
)
stored_messages: List[Any] = []
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
monkeypatch.setattr(
send_service._chat_manager,
"get_session_by_session_id",
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
)
monkeypatch.setattr(
send_service.MessageUtils,
"store_message_to_db",
lambda message: stored_messages.append(message),
)
sent_message = await send_service.text_to_stream_with_message(text="你好", stream_id="test-session")
assert sent_message is not None
assert sent_message.message_id == "real-message-id"
assert fake_manager.ensure_calls == 1
assert len(stored_messages) == 1
assert stored_messages[0].message_id == "real-message-id"
@pytest.mark.asyncio
async def test_text_to_stream_with_message_triggers_memory_and_syncs_maisaka_history(
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_manager = _FakePlatformIOManager(
delivery_batch=SimpleNamespace(
has_success=True,
sent_receipts=[
SimpleNamespace(
driver_id="plugin.qq.sender",
external_message_id="real-message-id",
metadata={},
)
],
failed_receipts=[],
route_key=SimpleNamespace(platform="qq"),
)
)
stored_messages: List[Any] = []
memory_events: List[str] = []
history_events: List[tuple[str, str]] = []
class FakeMemoryAutomationService:
async def on_message_sent(self, message: Any) -> None:
memory_events.append(str(message.message_id))
class FakeRuntime:
def append_sent_message_to_chat_history(self, message: Any, *, source_kind: str = "guided_reply") -> None:
history_events.append((str(message.message_id), source_kind))
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
monkeypatch.setattr(
send_service._chat_manager,
"get_session_by_session_id",
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
)
monkeypatch.setattr(
send_service.MessageUtils,
"store_message_to_db",
lambda message: stored_messages.append(message),
)
monkeypatch.setitem(
sys.modules,
"src.services.memory_flow_service",
SimpleNamespace(memory_automation_service=FakeMemoryAutomationService()),
)
monkeypatch.setitem(
sys.modules,
"src.chat.heart_flow.heartflow_manager",
SimpleNamespace(heartflow_manager=SimpleNamespace(heartflow_chat_list={"test-session": FakeRuntime()})),
)
sent_message = await send_service.text_to_stream_with_message(
text="你好",
stream_id="test-session",
sync_to_maisaka_history=True,
maisaka_source_kind="guided_reply",
)
assert sent_message is not None
assert sent_message.message_id == "real-message-id"
assert fake_manager.ensure_calls == 1
assert len(stored_messages) == 1
assert stored_messages[0].message_id == "real-message-id"
assert memory_events == ["real-message-id"]
assert history_events == [("real-message-id", "guided_reply")]
@pytest.mark.asyncio
async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None:
fake_manager = _FakePlatformIOManager(
delivery_batch=SimpleNamespace(
has_success=False,
sent_receipts=[],
failed_receipts=[
SimpleNamespace(
driver_id="plugin.qq.sender",
status="failed",
error="network error",
)
],
route_key=SimpleNamespace(platform="qq"),
)
)
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
monkeypatch.setattr(
send_service._chat_manager,
"get_session_by_session_id",
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
)
result = await send_service.text_to_stream(text="发送失败", stream_id="test-session")
assert result is False
assert fake_manager.ensure_calls == 1
assert len(fake_manager.sent_messages) == 1
@pytest.mark.asyncio
async def test_custom_to_stream_detailed_raises_stream_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
send_service._chat_manager,
"get_session_by_session_id",
lambda stream_id: None,
)
with pytest.raises(send_service.SendServiceError, match="未找到聊天流: group_chat"):
await send_service.custom_to_stream_detailed(
message_type="poke",
content={"qq_id": "2810873701"},
stream_id="group_chat",
)
@pytest.mark.asyncio
async def test_private_outbound_message_preserves_bot_sender_and_receiver_user(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
monkeypatch.setattr(
send_service._chat_manager,
"get_session_by_session_id",
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
)
outbound_message = send_service._build_outbound_session_message(
message_sequence=MessageSequence(components=[TextComponent(text="你好")]),
stream_id="test-session",
processed_plain_text="你好",
)
assert outbound_message is not None
maim_message = await outbound_message.to_maim_message()
assert maim_message.message_info.user_info is not None
assert maim_message.message_info.user_info.user_id == "bot-qq"
assert maim_message.message_info.group_info is None
assert maim_message.message_info.sender_info is not None
assert maim_message.message_info.sender_info.user_info is not None
assert maim_message.message_info.sender_info.user_info.user_id == "bot-qq"
assert maim_message.message_info.receiver_info is not None
assert maim_message.message_info.receiver_info.user_info is not None
assert maim_message.message_info.receiver_info.user_info.user_id == "target-user"
@pytest.mark.asyncio
async def test_group_outbound_message_preserves_bot_sender_and_target_group(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
monkeypatch.setattr(
send_service._chat_manager,
"get_session_by_session_id",
lambda stream_id: _build_group_stream() if stream_id == "group-session" else None,
)
outbound_message = send_service._build_outbound_session_message(
message_sequence=MessageSequence(components=[TextComponent(text="大家好")]),
stream_id="group-session",
processed_plain_text="大家好",
)
assert outbound_message is not None
maim_message = await outbound_message.to_maim_message()
assert maim_message.message_info.user_info is not None
assert maim_message.message_info.user_info.user_id == "bot-qq"
assert maim_message.message_info.group_info is not None
assert maim_message.message_info.group_info.group_id == "target-group"
assert maim_message.message_info.receiver_info is not None
assert maim_message.message_info.receiver_info.group_info is not None
assert maim_message.message_info.receiver_info.group_info.group_id == "target-group"

View File

@@ -1,297 +0,0 @@
from types import SimpleNamespace
from typing import Any
import importlib.util
import sys
import pytest
from src.core.tooling import ToolAvailabilityContext, ToolRegistry
from src.maisaka.tool_provider import MaisakaBuiltinToolProvider
from src.plugin_runtime.component_query import ComponentQueryService
from src.plugin_runtime.host.component_registry import ComponentRegistry
@pytest.mark.asyncio
async def test_builtin_at_tool_is_not_exposed() -> None:
registry = ToolRegistry()
registry.register_provider(MaisakaBuiltinToolProvider())
group_specs = await registry.list_tools(ToolAvailabilityContext(session_id="group-1", is_group_chat=True))
private_specs = await registry.list_tools(ToolAvailabilityContext(session_id="private-1", is_group_chat=False))
default_specs = await registry.list_tools()
assert "at" not in {tool_spec.name for tool_spec in group_specs}
assert "at" not in {tool_spec.name for tool_spec in private_specs}
assert "at" not in {tool_spec.name for tool_spec in default_specs}
def test_plugin_tool_chat_scope_uses_component_field(monkeypatch: pytest.MonkeyPatch) -> None:
service = ComponentQueryService()
registry = ComponentRegistry()
supervisor = SimpleNamespace(component_registry=registry)
monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor])
registry.register_plugin_components(
"scope_plugin",
[
{
"name": "group_tool",
"component_type": "TOOL",
"chat_scope": "group",
"metadata": {"description": "group only"},
},
{
"name": "private_tool",
"component_type": "TOOL",
"chat_scope": "private",
"metadata": {"description": "private only"},
},
{
"name": "all_tool",
"component_type": "TOOL",
"metadata": {"description": "all chats"},
},
],
)
group_specs = service.get_llm_available_tool_specs(
context=ToolAvailabilityContext(session_id="group-1", is_group_chat=True)
)
private_specs = service.get_llm_available_tool_specs(
context=ToolAvailabilityContext(session_id="private-1", is_group_chat=False)
)
group_entry = registry.get_component("scope_plugin.group_tool")
assert group_entry is not None
assert group_entry.chat_scope == "group"
assert "chat_scope" not in group_entry.metadata
assert set(group_specs) == {"group_tool", "all_tool"}
assert set(private_specs) == {"private_tool", "all_tool"}
def test_plugin_tool_session_disable_still_filters_specific_chat(monkeypatch: pytest.MonkeyPatch) -> None:
service = ComponentQueryService()
registry = ComponentRegistry()
supervisor = SimpleNamespace(component_registry=registry)
monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor])
registry.register_plugin_components(
"mute_plugin",
[
{
"name": "mute",
"component_type": "TOOL",
"chat_scope": "group",
"metadata": {"description": "mute group member"},
}
],
)
registry.set_component_enabled("mute_plugin.mute", False, session_id="group-disabled")
disabled_specs = service.get_llm_available_tool_specs(
context=ToolAvailabilityContext(session_id="group-disabled", is_group_chat=True)
)
enabled_specs = service.get_llm_available_tool_specs(
context=ToolAvailabilityContext(session_id="group-enabled", is_group_chat=True)
)
assert "mute" not in disabled_specs
assert "mute" in enabled_specs
def test_plugin_tool_allowed_session_filters_tool_exposure(monkeypatch: pytest.MonkeyPatch) -> None:
service = ComponentQueryService()
registry = ComponentRegistry()
supervisor = SimpleNamespace(component_registry=registry)
monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor])
registry.register_plugin_components(
"mute_plugin",
[
{
"name": "mute",
"component_type": "TOOL",
"chat_scope": "group",
"allowed_session": ["qq:10001", "raw-group-id", "exact-session-id"],
"metadata": {"description": "mute group member"},
}
],
)
platform_group_specs = service.get_llm_available_tool_specs(
context=ToolAvailabilityContext(
session_id="hashed-session-1",
is_group_chat=True,
group_id="10001",
platform="qq",
)
)
raw_group_specs = service.get_llm_available_tool_specs(
context=ToolAvailabilityContext(
session_id="hashed-session-2",
is_group_chat=True,
group_id="raw-group-id",
platform="qq",
)
)
exact_session_specs = service.get_llm_available_tool_specs(
context=ToolAvailabilityContext(session_id="exact-session-id", is_group_chat=True)
)
blocked_specs = service.get_llm_available_tool_specs(
context=ToolAvailabilityContext(
session_id="blocked-session",
is_group_chat=True,
group_id="20002",
platform="qq",
)
)
entry = registry.get_component("mute_plugin.mute")
assert entry is not None
assert entry.allowed_session == {"qq:10001", "raw-group-id", "exact-session-id"}
assert "allowed_session" not in entry.metadata
assert "mute" in platform_group_specs
assert "mute" in raw_group_specs
assert "mute" in exact_session_specs
assert "mute" not in blocked_specs
def test_plugin_tool_disabled_session_take_precedence_over_allowed_session(
monkeypatch: pytest.MonkeyPatch,
) -> None:
service = ComponentQueryService()
registry = ComponentRegistry()
supervisor = SimpleNamespace(component_registry=registry)
monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor])
registry.register_plugin_components(
"mute_plugin",
[
{
"name": "mute",
"component_type": "TOOL",
"chat_scope": "group",
"allowed_session": ["qq:10001"],
"metadata": {"description": "mute group member"},
}
],
)
registry.set_component_enabled("mute_plugin.mute", False, session_id="allowed-session")
visible_specs = service.get_llm_available_tool_specs(
context=ToolAvailabilityContext(
session_id="visible-session",
is_group_chat=True,
group_id="10001",
platform="qq",
)
)
disabled_specs = service.get_llm_available_tool_specs(
context=ToolAvailabilityContext(
session_id="allowed-session",
is_group_chat=True,
group_id="10001",
platform="qq",
)
)
entry = registry.get_component("mute_plugin.mute")
assert entry is not None
assert entry.disabled_session == {"allowed-session"}
assert "mute" in visible_specs
assert "mute" not in disabled_specs
def test_mute_plugin_exports_allowed_groups_as_component_allowed_session() -> None:
module_path = "plugins/MutePlugin/plugin.py"
spec = importlib.util.spec_from_file_location("mute_plugin_under_test", module_path)
assert spec is not None
assert spec.loader is not None
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
module.MutePluginConfig.model_rebuild()
plugin = module.MutePlugin()
plugin.set_plugin_config({"permissions": {"allowed_groups": ["qq:10001", "raw-group-id"]}})
mute_components = [component for component in plugin.get_components() if component.get("name") == "mute"]
assert len(mute_components) == 1
assert mute_components[0]["chat_scope"] == "group"
assert mute_components[0]["allowed_session"] == ["qq:10001", "raw-group-id"]
assert "allowed_session" not in mute_components[0]["metadata"]
@pytest.mark.asyncio
async def test_mute_tool_queries_target_message_with_current_chat_id() -> None:
module_path = "plugins/MutePlugin/plugin.py"
spec = importlib.util.spec_from_file_location("mute_plugin_under_test_msg_id", module_path)
assert spec is not None
assert spec.loader is not None
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
module.MutePluginConfig.model_rebuild()
capability_calls: list[dict[str, Any]] = []
api_calls: list[dict[str, Any]] = []
async def fake_call_capability(name: str, **kwargs: Any) -> dict[str, Any]:
capability_calls.append({"name": name, **kwargs})
return {
"success": True,
"result": {
"success": True,
"message": {
"message_info": {
"user_info": {
"user_id": "35529667",
"user_cardname": "目标用户",
"user_nickname": "目标昵称",
}
}
},
},
}
async def fake_api_call(api_name: str, **kwargs: Any) -> dict[str, Any]:
api_calls.append({"name": api_name, **kwargs})
if api_name == "adapter.napcat.group.get_group_member_info":
return {"success": True, "result": {"data": {"role": "member"}}}
return {"status": "ok", "retcode": 0}
plugin = module.MutePlugin()
plugin.set_plugin_config({"components": {"enable_smart_mute": True}})
plugin._set_context(
SimpleNamespace(
call_capability=fake_call_capability,
api=SimpleNamespace(call=fake_api_call),
logger=SimpleNamespace(info=lambda *args, **kwargs: None, warning=lambda *args, **kwargs: None),
)
)
success, message = await plugin.handle_mute_tool(
stream_id="current-session-id",
group_id="766798517",
msg_id="2046083292",
duration=3600,
reason="测试",
)
assert success is True
assert message == "成功禁言 目标用户"
assert capability_calls == [
{
"name": "message.get_by_id",
"message_id": "2046083292",
"chat_id": "current-session-id",
}
]
assert api_calls[-1] == {
"name": "adapter.napcat.group.set_group_ban",
"version": "1",
"group_id": "766798517",
"user_id": "35529667",
"duration": 3600,
}

View File

@@ -1,367 +0,0 @@
import sys
from dataclasses import dataclass, field
import pytest
import importlib
import importlib.util
from types import ModuleType
from pathlib import Path
from datetime import datetime
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from src.common.data_models.message_component_data_model import MessageSequence
from src.chat.message_receive.message import (
SessionMessage,
TextComponent,
ImageComponent,
AtComponent,
)
class DummyLogger:
def __init__(self) -> None:
self.logging_record = []
def debug(self, msg):
print(f"DEBUG: {msg}")
self.logging_record.append(f"DEBUG: {msg}")
def info(self, msg):
print(f"INFO: {msg}")
self.logging_record.append(f"INFO: {msg}")
def warning(self, msg):
print(f"WARNING: {msg}")
self.logging_record.append(f"WARNING: {msg}")
def error(self, msg):
print(f"ERROR: {msg}")
self.logging_record.append(f"ERROR: {msg}")
def critical(self, msg):
print(f"CRITICAL: {msg}")
self.logging_record.append(f"CRITICAL: {msg}")
def get_logger(name):
return DummyLogger()
class DummyDBSession:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def exec(self, statement):
return self
def first(self):
return None
def commit(self):
pass
def all(self):
return []
def get_db_session():
return DummyDBSession()
def get_manual_db_session():
return DummyDBSession()
class DummySelect:
def __init__(self, model):
self.model = model
def filter_by(self, **kwargs):
return self
def where(self, condition):
return self
def limit(self, n):
return self
def select(model):
return DummySelect(model)
async def dummy_get_voice_text(binary_data):
return None # 可以根据需要返回模拟的文本结果
class DummyPersonUtils:
@staticmethod
def get_person_info_by_user_id_and_platform(user_id, platform):
return None # 可以根据需要返回模拟的用户信息
class DummyConfig:
class MessageReceiveConfig:
ban_words = set()
ban_msgs_regex = set()
message_receive = MessageReceiveConfig()
@dataclass
class UserInfo:
user_id: str
user_nickname: str
user_cardname: Optional[str] = None
@dataclass
class GroupInfo:
group_id: str
group_name: str
@dataclass
class MessageInfo:
user_info: UserInfo
group_info: Optional[GroupInfo] = None
additional_config: dict = field(default_factory=dict)
def setup_mocks(monkeypatch):
def _stub_module(name: str) -> ModuleType:
module = ModuleType(name)
monkeypatch.setitem(sys.modules, name, module)
return module
# src.common.logger
logger_mod = _stub_module("src.common.logger")
# Mock the logger
logger_mod.get_logger = get_logger
db_mod = _stub_module("src.common.database.database")
db_mod.get_db_session = get_db_session
db_mod.get_manual_db_session = get_manual_db_session
db_model_mod = _stub_module("src.common.database.database_model")
db_model_mod.Messages = None # 可以根据需要添加更多的属性或方法
emoji_manager_mod = _stub_module("src.emoji_system.emoji_manager")
emoji_manager_mod.emoji_manager = None # 可以根据需要添加更多的属性或方法
image_manager_mod = _stub_module("src.chat.image_system.image_manager")
image_manager_mod.image_manager = None # 可以根据需要添加更多的属性或方法
voice_utils_mod = _stub_module("src.common.utils.utils_voice")
voice_utils_mod.get_voice_text = dummy_get_voice_text
person_utils_mod = _stub_module("src.common.utils.utils_person")
person_utils_mod.PersonUtils = DummyPersonUtils
config_mod = _stub_module("src.config.config")
config_mod.global_config = DummyConfig()
def load_message_via_file(monkeypatch):
setup_mocks(monkeypatch)
file_path = Path(__file__).parent.parent.parent / "src" / "chat" / "message_receive" / "message.py"
spec = importlib.util.spec_from_file_location("message", file_path)
message_module = importlib.util.module_from_spec(spec)
monkeypatch.setitem(sys.modules, "message_module", message_module)
spec.loader.exec_module(message_module)
message_module.select = select
SessionMessageClass = message_module.SessionMessage
TextComponentClass = message_module.TextComponent
ImageComponentClass = message_module.ImageComponent
EmojiComponentClass = message_module.EmojiComponent
VoiceComponentClass = message_module.VoiceComponent
AtComponentClass = message_module.AtComponent
ReplyComponentClass = message_module.ReplyComponent
ForwardNodeComponentClass = message_module.ForwardNodeComponent
MessageSequenceClass = sys.modules["src.common.data_models.message_component_data_model"].MessageSequence
ForwardComponentClass = sys.modules["src.common.data_models.message_component_data_model"].ForwardComponent
globals()["SessionMessage"] = SessionMessageClass
globals()["TextComponent"] = TextComponentClass
globals()["ImageComponent"] = ImageComponentClass
globals()["EmojiComponent"] = EmojiComponentClass
globals()["VoiceComponent"] = VoiceComponentClass
globals()["AtComponent"] = AtComponentClass
globals()["ReplyComponent"] = ReplyComponentClass
globals()["ForwardNodeComponent"] = ForwardNodeComponentClass
globals()["MessageSequence"] = MessageSequenceClass
globals()["ForwardComponent"] = ForwardComponentClass
return message_module
def dummy_number_to_short_id(original_id: int, salt: str, length: int = 6) -> str:
return "X" * length # 返回固定的字符串长度由参数决定模拟生成短ID的行为
def dummy_is_bot_self(platform, user_id: str) -> bool:
return user_id == "bot_self"
def load_utils_via_file(monkeypatch):
setup_mocks(monkeypatch)
# Mock math_utils 模块,供 from .math_utils import number_to_short_id 使用
math_utils_mod = ModuleType("src.common.utils.math_utils")
math_utils_mod.number_to_short_id = dummy_number_to_short_id
math_utils_mod.TimestampMode = type(
"TimestampMode", (), {"NORMAL": "%Y-%m-%d %H:%M:%S", "NORMAL_NO_YMD": "%H:%M:%S", "RELATIVE": "relative"}
)
math_utils_mod.translate_timestamp_to_human_readable = lambda timestamp, mode: (
"2024-01-01 12:00:00"
) # 返回固定的时间字符串
monkeypatch.setitem(sys.modules, "src.common.utils.math_utils", math_utils_mod)
# 确保包层级模块存在于 sys.modules 中,使相对导入能正确解析
for pkg_name in ["src", "src.common", "src.common.utils"]:
if pkg_name not in sys.modules:
pkg_mod = ModuleType(pkg_name)
pkg_mod.__path__ = []
monkeypatch.setitem(sys.modules, pkg_name, pkg_mod)
file_path = Path(__file__).parent.parent.parent / "src" / "common" / "utils" / "utils_message.py"
spec = importlib.util.spec_from_file_location("src.common.utils.utils_message", file_path)
utils_module = importlib.util.module_from_spec(spec)
utils_module.__package__ = "src.common.utils" # 设置包,使相对导入生效
monkeypatch.setitem(sys.modules, "src.common.utils.utils_message", utils_module)
monkeypatch.setitem(sys.modules, "message_utils_module", utils_module)
spec.loader.exec_module(utils_module)
utils_module.is_bot_self = dummy_is_bot_self
return utils_module
@pytest.mark.asyncio
async def test_message_utils(monkeypatch):
load_message_via_file(monkeypatch)
load_utils_via_file(monkeypatch)
@pytest.mark.asyncio
async def test_build_readable_message_basic(monkeypatch):
"""基础用例:单条消息,显示行号"""
load_message_via_file(monkeypatch)
utils_module = load_utils_via_file(monkeypatch)
MessageUtils = utils_module.MessageUtils
msg = SessionMessage("m1", datetime.now(), platform="test")
msg.platform = "test"
msg.session_id = "s_test"
user_info = UserInfo(user_id="u1", user_nickname="Alice")
msg.message_info = MessageInfo(user_info=user_info)
msg.raw_message = MessageSequence([TextComponent("Hello world")])
text, mapping, _ = await MessageUtils.build_readable_message([msg], anonymize=False, show_lineno=True)
assert "[1] Alice说Hello world" in text
assert mapping == {}
@pytest.mark.asyncio
async def test_build_readable_message_anonymize(monkeypatch):
"""匿名化用例:验证 mapping 和返回文本"""
load_message_via_file(monkeypatch)
utils_module = load_utils_via_file(monkeypatch)
MessageUtils = utils_module.MessageUtils
msg = SessionMessage("m2", datetime.now(), platform="test")
msg.session_id = "s_test"
user_info = UserInfo(user_id="u42", user_nickname="Bob")
msg.message_info = MessageInfo(user_info=user_info)
msg.raw_message = MessageSequence([TextComponent("Secret text")])
text, mapping, _ = await MessageUtils.build_readable_message([msg], anonymize=True, show_lineno=False)
# 根据实现original_name 为 user_nickname因此文本中应包含原始名称
assert "XXXXXX说" in text
assert "u42" in mapping
assert mapping["u42"][0] == "XXXXXX"
assert mapping["u42"][1] == "Bob"
@pytest.mark.asyncio
async def test_build_readable_message_replace_bot(monkeypatch):
"""替换机器人名用例:当 user_id 为 bot_self 时应被替换为 target_bot_name"""
load_message_via_file(monkeypatch)
utils_module = load_utils_via_file(monkeypatch)
MessageUtils = utils_module.MessageUtils
msg = SessionMessage("m3", datetime.now(), platform="test")
msg.session_id = "s_test"
user_info = UserInfo(user_id="bot_self", user_nickname="SomeBot")
msg.message_info = MessageInfo(user_info=user_info)
msg.raw_message = MessageSequence([TextComponent("ping")])
text, mapping, _ = await MessageUtils.build_readable_message([msg], replace_bot_name=True, target_bot_name="MAIBot")
assert "MAIBot说ping" in text
@pytest.mark.asyncio
async def test_build_readable_message_image_extraction(monkeypatch):
"""图片提取:验证 extract_pictures 为 True 时,文本中包含图片占位及 img_map 内容被返回"""
load_message_via_file(monkeypatch)
utils_module = load_utils_via_file(monkeypatch)
MessageUtils = utils_module.MessageUtils
# 构建包含图片组件的消息
img = ImageComponent(binary_hash="h", binary_data=b"\x01\x02", content="Img")
msg = SessionMessage("mi1", datetime.now(), platform="test")
msg.session_id = "s_img"
msg.raw_message = MessageSequence([img])
msg.message_info = MessageInfo(UserInfo(user_id="ui_img", user_nickname="ImgUser"))
text, mapping, _ = await MessageUtils.build_readable_message([msg], extract_pictures=True)
# 应包含图片描述占位
assert "图片1" in text
# mapping 不为空(匿名化未开启则为空)
assert isinstance(mapping, dict)
@pytest.mark.asyncio
async def test_build_readable_message_anonymize_and_replace_bot_name_and_lineno(monkeypatch):
"""组合用例:多个消息同时包含匿名化、机器人名称替换"""
load_message_via_file(monkeypatch)
utils_module = load_utils_via_file(monkeypatch)
MessageUtils = utils_module.MessageUtils
# 构建多个消息
msg1 = SessionMessage("m4", datetime.now(), platform="test")
msg1.session_id = "s_comb"
msg2 = SessionMessage("m5", datetime.now(), platform="test")
msg2.session_id = "s_comb"
msg1.message_info = MessageInfo(UserInfo(user_id="u_comb", user_nickname="Charlie"))
msg2.message_info = MessageInfo(UserInfo(user_id="bot_self", user_nickname="SomeBot"))
msg1.raw_message = MessageSequence([TextComponent("Hi")])
msg2.raw_message = MessageSequence([TextComponent("Hello")])
text, mapping, _ = await MessageUtils.build_readable_message(
[msg1, msg2],
anonymize=True,
replace_bot_name=True,
target_bot_name="MAIBot",
show_lineno=True,
)
# 验证文本内容
assert "[1] XXXXXX说Hi" in text
assert "[2] MAIBot说Hello" in text
# 验证 mapping 内容
assert "u_comb" in mapping
assert mapping["u_comb"][0] == "XXXXXX"
@pytest.mark.asyncio
async def test_build_readable_message_with_at(monkeypatch):
"""包含@组件的消息:验证@组件中的用户信息也被匿名化和替换"""
load_message_via_file(monkeypatch)
utils_module = load_utils_via_file(monkeypatch)
MessageUtils = utils_module.MessageUtils
# 构建包含回复组件的消息
at_comp = AtComponent(target_user_id="u_at", target_user_nickname="AtUser")
msg = SessionMessage("m_at", datetime.now(), platform="test")
msg.session_id = "s_at"
msg.raw_message = MessageSequence([at_comp])
msg.message_info = MessageInfo(UserInfo(user_id="u_main", user_nickname="MainUser"))
text, mapping, _ = await MessageUtils.build_readable_message(
[msg], anonymize=True, replace_bot_name=True, target_bot_name="MAIBot"
)
# 验证主消息和@组件中的用户信息都被处理
assert "XXXXXX说" in text # 主消息用户被匿名化
assert "XXXXXX说@XXXXXX" in text # @组件用户被匿名化

View File

@@ -1,117 +0,0 @@
"""统计模块数据库会话行为测试。"""
from __future__ import annotations
from contextlib import contextmanager
from datetime import datetime, timedelta
from types import ModuleType
from typing import Any, Callable, Iterator
import sys
import pytest
from src.chat.utils import statistic
class _DummyResult:
"""模拟 SQLModel 查询结果对象。"""
def all(self) -> list[Any]:
"""返回空结果集。
Returns:
list[Any]: 空列表。
"""
return []
class _DummySession:
"""模拟数据库 Session。"""
def exec(self, statement: Any) -> _DummyResult:
"""执行查询语句并返回空结果。
Args:
statement: 待执行的查询语句。
Returns:
_DummyResult: 空结果对象。
"""
del statement
return _DummyResult()
def _build_fake_get_db_session(calls: list[bool]) -> Callable[[bool], Iterator[_DummySession]]:
"""构造一个记录 auto_commit 参数的假会话工厂。
Args:
calls: 用于记录每次调用 auto_commit 参数的列表。
Returns:
Callable[[bool], Iterator[_DummySession]]: 可替换 `get_db_session` 的上下文管理器工厂。
"""
@contextmanager
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_DummySession]:
"""记录会话参数并返回假 Session。
Args:
auto_commit: 是否启用自动提交。
Yields:
Iterator[_DummySession]: 假 Session 对象。
"""
calls.append(auto_commit)
yield _DummySession()
return _fake_get_db_session
def _build_statistic_task() -> statistic.StatisticOutputTask:
"""构造一个最小可用的统计任务实例。
Returns:
statistic.StatisticOutputTask: 跳过 `__init__` 的测试实例。
"""
task = statistic.StatisticOutputTask.__new__(statistic.StatisticOutputTask)
task.name_mapping = {}
return task
def _is_bot_self(platform: str, user_id: str) -> bool:
"""返回固定的非机器人身份判断结果。
Args:
platform: 平台名称。
user_id: 用户 ID。
Returns:
bool: 始终返回 ``False``。
"""
del platform
del user_id
return False
def test_statistic_read_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None:
"""统计模块的纯读查询应关闭自动提交,避免 Session 退出后对象被 expire。"""
calls: list[bool] = []
now = datetime.now()
task = _build_statistic_task()
monkeypatch.setattr(statistic, "get_db_session", _build_fake_get_db_session(calls))
utils_module = ModuleType("src.chat.utils.utils")
utils_module.is_bot_self = _is_bot_self
monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module)
monkeypatch.setattr(statistic, "fetch_online_time_since", lambda query_start_time: [])
monkeypatch.setattr(statistic, "fetch_model_usage_since", lambda query_start_time: [])
monkeypatch.setattr(statistic, "fetch_messages_since", lambda query_start_time: [])
monkeypatch.setattr(statistic, "fetch_tool_records_since", lambda query_start_time: [])
task._collect_message_count_for_period([("last_hour", now - timedelta(hours=1))])
task._collect_interval_data(now, hours=1, interval_minutes=60)
task._collect_metrics_interval_data(now, hours=1, interval_hours=1)
assert calls == []

View File

@@ -1,131 +0,0 @@
from pathlib import Path
import json
from src.config.model_configs import APIProvider, ModelInfo
from src.llm_models.model_client.base_client import ResponseRequest
from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
from src.llm_models.request_snapshot import (
attach_request_snapshot,
deserialize_messages_snapshot,
format_request_snapshot_log_info,
save_failed_request_snapshot,
serialize_messages_snapshot,
serialize_response_request_snapshot,
)
from src.llm_models import request_snapshot
def _build_api_provider() -> APIProvider:
return APIProvider(
api_key="secret-token",
base_url="https://example.com/v1",
name="test-provider",
)
def _build_model_info() -> ModelInfo:
return ModelInfo(
api_provider="test-provider",
model_identifier="demo-model",
name="demo-model",
)
def _build_response_request() -> ResponseRequest:
tool_call = ToolCall(
args={"query": "MaiBot"},
call_id="call_1",
func_name="search_web",
extra_content={"google": {"thought_signature": "c2lnbmF0dXJl"}},
)
message_list = [
MessageBuilder().set_role(RoleType.User).add_text_content("你好").add_image_content("png", "ZmFrZQ==").build(),
MessageBuilder().set_role(RoleType.Assistant).set_tool_calls([tool_call]).build(),
MessageBuilder()
.set_role(RoleType.Tool)
.set_tool_call_id("call_1")
.set_tool_name("search_web")
.add_text_content('{"ok": true}')
.build(),
]
return ResponseRequest(
extra_params={"trace_id": "trace-123"},
max_tokens=256,
message_list=message_list,
model_info=_build_model_info(),
response_format=RespFormat(RespFormatType.JSON_OBJ),
temperature=0.2,
tool_options=[ToolOption(name="search_web", description="搜索网页")],
)
def test_message_snapshot_roundtrip_preserves_tool_messages() -> None:
request = _build_response_request()
snapshot_messages = serialize_messages_snapshot(request.message_list)
restored_messages = deserialize_messages_snapshot(snapshot_messages)
assert len(restored_messages) == 3
assert restored_messages[0].role == RoleType.User
assert restored_messages[0].get_text_content() == "你好"
assert restored_messages[0].parts[1].image_format == "png"
assert restored_messages[1].role == RoleType.Assistant
assert restored_messages[1].tool_calls is not None
assert restored_messages[1].tool_calls[0].func_name == "search_web"
assert restored_messages[1].tool_calls[0].args == {"query": "MaiBot"}
assert restored_messages[1].tool_calls[0].extra_content == {"google": {"thought_signature": "c2lnbmF0dXJl"}}
assert restored_messages[2].role == RoleType.Tool
assert restored_messages[2].tool_call_id == "call_1"
assert restored_messages[2].tool_name == "search_web"
def test_failed_request_snapshot_contains_replay_entry(tmp_path: Path, monkeypatch) -> None:
monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path)
request = _build_response_request()
provider = _build_api_provider()
snapshot_path = save_failed_request_snapshot(
api_provider=provider,
client_type="openai",
error=RuntimeError("boom"),
internal_request=serialize_response_request_snapshot(request),
model_info=request.model_info,
operation="chat.completions.create",
provider_request={"request_kwargs": {"model": request.model_info.model_identifier}},
)
assert snapshot_path is not None
payload = json.loads(snapshot_path.read_text(encoding="utf-8"))
assert payload["internal_request"]["request_kind"] == "response"
assert payload["api_provider"]["name"] == "test-provider"
assert payload["replay"]["file_uri"] == snapshot_path.as_uri()
assert str(snapshot_path) in payload["replay"]["command"]
assert "secret-token" not in snapshot_path.read_text(encoding="utf-8")
def test_format_request_snapshot_log_info_includes_path_uri_and_command(tmp_path: Path, monkeypatch) -> None:
monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path)
request = _build_response_request()
snapshot_path = save_failed_request_snapshot(
api_provider=_build_api_provider(),
client_type="openai",
error=ValueError("invalid"),
internal_request=serialize_response_request_snapshot(request),
model_info=request.model_info,
operation="chat.completions.create",
provider_request={"request_kwargs": {"messages": []}},
)
assert snapshot_path is not None
exc = RuntimeError("wrapped")
attach_request_snapshot(exc, snapshot_path)
log_info = format_request_snapshot_log_info(exc)
assert str(snapshot_path) in log_info
assert snapshot_path.as_uri() in log_info
assert "uv run python scripts/replay_llm_request.py" in log_info

View File

@@ -1,42 +0,0 @@
from types import SimpleNamespace
from src.chat.message_receive.chat_manager import ChatManager
from src.common.utils.utils_session import SessionUtils
def test_calculate_session_id_distinguishes_account_and_scope() -> None:
base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
same_base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
account_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123")
route_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123", scope="main")
assert base_session_id == same_base_session_id
assert account_scoped_session_id != base_session_id
assert route_scoped_session_id != account_scoped_session_id
def test_chat_manager_register_message_uses_route_metadata() -> None:
chat_manager = ChatManager()
message = SimpleNamespace(
platform="qq",
session_id="",
message_info=SimpleNamespace(
user_info=SimpleNamespace(user_id="42"),
group_info=SimpleNamespace(group_id="1000"),
additional_config={
"platform_io_account_id": "123",
"platform_io_scope": "main",
},
),
)
chat_manager.register_message(message)
assert message.session_id == SessionUtils.calculate_session_id(
"qq",
user_id="42",
group_id="1000",
account_id="123",
scope="main",
)
assert chat_manager.last_messages[message.session_id] is message

View File

@@ -1,161 +0,0 @@
from pathlib import Path
from unittest.mock import patch
import pytest
from src.webui import app as webui_app
def test_ensure_static_path_ready_uses_existing_static_path(tmp_path) -> None:
static_path = tmp_path / "dist"
static_path.mkdir()
(static_path / "index.html").write_text("<html></html>", encoding="utf-8")
with patch.object(webui_app, "_resolve_static_path", return_value=static_path):
result = webui_app._ensure_static_path_ready()
assert result == static_path
def test_ensure_static_path_ready_logs_install_hint_when_static_assets_are_missing() -> None:
with (
patch.object(webui_app, "_resolve_static_path", return_value=None),
patch.object(webui_app.logger, "warning") as warning_mock,
):
result = webui_app._ensure_static_path_ready()
assert result is None
warning_mock.assert_any_call(webui_app.t("startup.webui_static_assets_unavailable"))
warning_mock.assert_any_call(
webui_app.t("startup.webui_dashboard_package_hint", command=webui_app._MANUAL_INSTALL_COMMAND)
)
def test_ensure_static_path_ready_logs_index_error_when_static_path_is_invalid(tmp_path) -> None:
static_path = tmp_path / "dist"
static_path.mkdir()
with (
patch.object(webui_app, "_resolve_static_path", return_value=static_path),
patch.object(webui_app.logger, "warning") as warning_mock,
):
result = webui_app._ensure_static_path_ready()
assert result is None
warning_mock.assert_any_call(
webui_app.t("startup.webui_index_missing", index_path=static_path / "index.html")
)
warning_mock.assert_any_call(
webui_app.t("startup.webui_dashboard_package_hint", command=webui_app._MANUAL_INSTALL_COMMAND)
)
def test_setup_static_files_does_not_duplicate_warning_when_static_path_is_unavailable() -> None:
app = webui_app.FastAPI()
with (
patch.object(webui_app, "_ensure_static_path_ready", return_value=None),
patch.object(webui_app.logger, "warning") as warning_mock,
):
webui_app._setup_static_files(app)
warning_mock.assert_not_called()
def test_resolve_static_path_prefers_installed_dashboard_package(monkeypatch, tmp_path) -> None:
package_dist = tmp_path / "site-packages" / "maibot_dashboard" / "dist"
package_dist.mkdir(parents=True)
class _DashboardModule:
@staticmethod
def get_dist_path() -> Path:
return package_dist
monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)
with patch.object(webui_app, "import_module", return_value=_DashboardModule()):
resolved_path = webui_app._resolve_static_path()
assert resolved_path == package_dist
def test_resolve_static_path_ignores_dashboard_dist_when_package_is_unavailable(monkeypatch, tmp_path) -> None:
dashboard_dist = tmp_path / "dashboard" / "dist"
dashboard_dist.mkdir(parents=True)
(dashboard_dist / "index.html").write_text("<html></html>", encoding="utf-8")
monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)
with patch.object(webui_app, "import_module", side_effect=ImportError):
resolved_path = webui_app._resolve_static_path()
assert resolved_path is None
def test_resolve_static_path_uses_package_even_when_dashboard_dist_exists(monkeypatch, tmp_path) -> None:
dashboard_dist = tmp_path / "dashboard" / "dist"
dashboard_dist.mkdir(parents=True)
package_dist = tmp_path / "site-packages" / "maibot_dashboard" / "dist"
package_dist.mkdir(parents=True)
class _DashboardModule:
@staticmethod
def get_dist_path() -> Path:
return package_dist
monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)
with patch.object(webui_app, "import_module", return_value=_DashboardModule()):
resolved_path = webui_app._resolve_static_path()
assert resolved_path == package_dist
def test_resolve_safe_static_file_path_allows_regular_static_file(tmp_path) -> None:
static_path = tmp_path / "dist"
asset_path = static_path / "assets" / "app.js"
asset_path.parent.mkdir(parents=True)
asset_path.write_text("console.log('ok')", encoding="utf-8")
resolved_path = webui_app._resolve_safe_static_file_path(static_path, "assets/app.js")
assert resolved_path == asset_path.resolve()
def test_resolve_safe_static_file_path_rejects_relative_path_traversal(tmp_path) -> None:
static_path = tmp_path / "dist"
static_path.mkdir()
resolved_path = webui_app._resolve_safe_static_file_path(static_path, "../secret.txt")
assert resolved_path is None
def test_resolve_safe_static_file_path_rejects_absolute_path_traversal(tmp_path) -> None:
static_path = tmp_path / "dist"
static_path.mkdir()
resolved_path = webui_app._resolve_safe_static_file_path(static_path, "/etc/passwd")
assert resolved_path is None
def test_resolve_safe_static_file_path_rejects_symlink_escape(tmp_path) -> None:
static_path = tmp_path / "dist"
static_path.mkdir()
outside_dir = tmp_path / "outside"
outside_dir.mkdir()
outside_file = outside_dir / "secret.txt"
outside_file.write_text("secret", encoding="utf-8")
link_path = static_path / "escape"
try:
link_path.symlink_to(outside_dir, target_is_directory=True)
except OSError as exc:
pytest.skip(f"symlink is not supported in this environment: {exc}")
resolved_path = webui_app._resolve_safe_static_file_path(static_path, "escape/secret.txt")
assert resolved_path is None

View File

@@ -1,147 +0,0 @@
from src.config.official_configs import ChatConfig, MessageReceiveConfig
from src.config.config import Config
from src.config.config_base import ConfigBase, Field
from src.webui.config_schema import ConfigSchemaGenerator
def test_field_docs_in_schema():
"""Test that field descriptions are correctly extracted from field_docs (docstrings)."""
schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value")
# Verify description field exists
assert "description" in talk_value
# Verify description contains expected Chinese text from the docstring
assert "聊天频率" in talk_value["description"]
def test_json_schema_extra_merged():
"""Test that json_schema_extra fields are correctly merged into output."""
schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value")
# Verify UI metadata fields from json_schema_extra exist
assert talk_value.get("x-widget") == "slider"
assert talk_value.get("x-icon") == "message-circle"
assert talk_value.get("step") == 0.1
def test_pydantic_constraints_mapped():
"""Test that Pydantic constraints (ge/le) are correctly mapped to minValue/maxValue."""
schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value")
# Verify constraints are mapped to frontend naming convention
assert "minValue" in talk_value
assert "maxValue" in talk_value
assert talk_value["minValue"] == 0 # From ge=0
assert talk_value["maxValue"] == 1 # From le=1
def test_nested_model_schema():
"""Test that nested models (ConfigBase fields) are correctly handled."""
schema = ConfigSchemaGenerator.generate_schema(Config)
# Verify nested structure exists
assert "nested" in schema
assert "chat" in schema["nested"]
# Verify nested chat schema is complete
chat_schema = schema["nested"]["chat"]
assert chat_schema["className"] == "ChatConfig"
assert "fields" in chat_schema
# Verify nested schema fields include description and metadata
talk_value = next(f for f in chat_schema["fields"] if f["name"] == "talk_value")
assert "description" in talk_value
assert talk_value.get("x-widget") == "slider"
assert talk_value.get("minValue") == 0
def test_field_without_extra_metadata():
"""Test that fields without json_schema_extra still generate valid schema."""
schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
inevitable_at_reply = next(f for f in schema["fields"] if f["name"] == "inevitable_at_reply")
# Verify basic fields are generated
assert "name" in inevitable_at_reply
assert inevitable_at_reply["name"] == "inevitable_at_reply"
assert "type" in inevitable_at_reply
assert inevitable_at_reply["type"] == "boolean"
assert "label" in inevitable_at_reply
assert "required" in inevitable_at_reply
# Verify no x-widget or x-icon from json_schema_extra (since field has none)
# These fields should only be present if explicitly defined in json_schema_extra
assert not inevitable_at_reply.get("x-widget")
assert not inevitable_at_reply.get("x-icon")
def test_all_top_level_sections_have_ui_metadata():
"""所有顶层配置节都必须声明 uiParent 或独立 Tab 的标签与图标。"""
schema = ConfigSchemaGenerator.generate_schema(Config)
for section_name, section_schema in schema["nested"].items():
has_parent = bool(section_schema.get("uiParent"))
has_host_meta = bool(section_schema.get("uiLabel")) and bool(section_schema.get("uiIcon"))
assert has_parent or has_host_meta, f"{section_name} 缺少 UI 元数据"
def test_maisaka_is_host_tab_and_mcp_is_attached_to_it():
"""MaiSaka 应作为独立 TabMCP 作为其子配置挂载。"""
schema = ConfigSchemaGenerator.generate_schema(Config)
maisaka_schema = schema["nested"]["maisaka"]
mcp_schema = schema["nested"]["mcp"]
assert maisaka_schema.get("uiParent") is None
assert maisaka_schema.get("uiLabel") == "MaiSaka"
assert maisaka_schema.get("uiIcon") == "message-circle"
assert mcp_schema.get("uiParent") == "maisaka"
def test_memory_query_config_fields_are_exposed():
"""query_memory 开关和默认条数应出现在记忆配置 schema 中。"""
schema = ConfigSchemaGenerator.generate_schema(Config)
memory_schema = schema["nested"]["memory"]
assert memory_schema.get("uiParent") == "emoji"
enable_field = next(field for field in memory_schema["fields"] if field["name"] == "enable_memory_query_tool")
limit_field = next(field for field in memory_schema["fields"] if field["name"] == "memory_query_default_limit")
assert enable_field["type"] == "boolean"
assert enable_field.get("x-widget") == "switch"
assert enable_field.get("x-icon") == "database"
assert limit_field["type"] == "integer"
assert limit_field.get("x-widget") == "input"
assert limit_field.get("x-icon") == "hash"
assert limit_field.get("minValue") == 1
assert limit_field.get("maxValue") == 20
def test_set_field_is_mapped_as_array():
"""set[str] 应映射为前端可识别的 array。"""
schema = ConfigSchemaGenerator.generate_schema(MessageReceiveConfig)
ban_words = next(field for field in schema["fields"] if field["name"] == "ban_words")
assert ban_words["type"] == "array"
assert ban_words["items"]["type"] == "string"
def test_advanced_fields_are_hidden_from_webui_schema():
"""advanced=True 的字段不应出现在 WebUI 配置 schema 中,未声明时默认展示。"""
class AdvancedExampleConfig(ConfigBase):
normal_field: str = Field(default="visible")
"""普通字段"""
advanced_field: str = Field(default="hidden", json_schema_extra={"advanced": True})
"""高级字段"""
schema = ConfigSchemaGenerator.generate_schema(AdvancedExampleConfig)
field_names = {field["name"] for field in schema["fields"]}
assert "normal_field" in field_names
assert "advanced_field" not in field_names

View File

@@ -1,461 +0,0 @@
"""表情包路由 API 测试
测试 src/webui/routers/emoji.py 中的核心 emoji 路由端点
使用内存 SQLite 数据库和 FastAPI TestClient
"""
from contextlib import contextmanager
from datetime import datetime
from typing import Generator
from unittest.mock import patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.common.database.database_model import Images, ImageType
from src.webui.core import TokenManager
from src.webui.routers.emoji import router
@pytest.fixture(scope="function")
def test_engine():
"""创建内存 SQLite 引擎用于测试"""
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
return engine
@pytest.fixture(scope="function")
def test_session(test_engine) -> Generator[Session, None, None]:
"""创建测试数据库会话"""
with Session(test_engine) as session:
yield session
@pytest.fixture(scope="function")
def test_app(test_session):
"""创建测试 FastAPI 应用并覆盖 get_db_session 依赖"""
app = FastAPI()
app.include_router(router)
# Create a context manager that yields the test session
@contextmanager
def override_get_db_session(auto_commit=True):
"""Override get_db_session to use test session"""
try:
yield test_session
if auto_commit:
test_session.commit()
except Exception:
test_session.rollback()
raise
with patch("src.webui.routers.emoji.get_db_session", override_get_db_session):
yield app
@pytest.fixture(scope="function")
def client(test_app):
"""创建 TestClient"""
return TestClient(test_app)
@pytest.fixture(scope="function")
def auth_token():
"""创建有效的认证 token"""
token_manager = TokenManager(secret_key="test-secret-key", token_expire_hours=24)
return token_manager.create_token()
@pytest.fixture(scope="function")
def sample_emojis(test_session) -> list[Images]:
"""插入测试用表情包数据"""
import hashlib
emojis = [
Images(
image_type=ImageType.EMOJI,
full_path="/data/emoji_registed/test1.png",
image_hash=hashlib.sha256(b"test1").hexdigest(),
description="测试表情包 1",
emotion="开心,快乐",
query_count=10,
is_registered=True,
is_banned=False,
record_time=datetime(2026, 1, 1, 10, 0, 0),
register_time=datetime(2026, 1, 1, 10, 0, 0),
last_used_time=datetime(2026, 1, 2, 10, 0, 0),
),
Images(
image_type=ImageType.EMOJI,
full_path="/data/emoji_registed/test2.gif",
image_hash=hashlib.sha256(b"test2").hexdigest(),
description="测试表情包 2",
emotion="难过",
query_count=5,
is_registered=False,
is_banned=False,
record_time=datetime(2026, 1, 3, 10, 0, 0),
register_time=None,
last_used_time=None,
),
Images(
image_type=ImageType.EMOJI,
full_path="/data/emoji_registed/test3.webp",
image_hash=hashlib.sha256(b"test3").hexdigest(),
description="测试表情包 3",
emotion="生气",
query_count=20,
is_registered=True,
is_banned=True,
record_time=datetime(2026, 1, 4, 10, 0, 0),
register_time=datetime(2026, 1, 4, 10, 0, 0),
last_used_time=datetime(2026, 1, 5, 10, 0, 0),
),
]
for emoji in emojis:
test_session.add(emoji)
test_session.commit()
for emoji in emojis:
test_session.refresh(emoji)
return emojis
@pytest.fixture(scope="function")
def mock_token_verify():
"""Mock token verification to always succeed"""
with patch("src.webui.routers.emoji.verify_auth_token", return_value=True):
yield
# ==================== 测试用例 ====================
def test_list_emojis_basic(client, sample_emojis, mock_token_verify):
"""测试获取表情包列表(基本分页)"""
response = client.get("/emoji/list?page=1&page_size=10")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 3
assert data["page"] == 1
assert data["page_size"] == 10
assert len(data["data"]) == 3
# 验证第一个表情包字段
emoji = data["data"][0]
assert "id" in emoji
assert "full_path" in emoji
assert "emoji_hash" in emoji
assert "description" in emoji
assert "query_count" in emoji
assert "is_registered" in emoji
assert "is_banned" in emoji
assert "emotion" in emoji
assert "record_time" in emoji
assert "register_time" in emoji
assert "last_used_time" in emoji
def test_list_emojis_pagination(client, sample_emojis, mock_token_verify):
"""测试分页功能"""
response = client.get("/emoji/list?page=1&page_size=2")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 3
assert len(data["data"]) == 2
# 第二页
response = client.get("/emoji/list?page=2&page_size=2")
data = response.json()
assert len(data["data"]) == 1
def test_list_emojis_search(client, sample_emojis, mock_token_verify):
"""测试搜索过滤"""
response = client.get("/emoji/list?search=表情包 2")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 1
assert data["data"][0]["description"] == "测试表情包 2"
def test_list_emojis_filter_registered(client, sample_emojis, mock_token_verify):
"""测试 is_registered 过滤"""
response = client.get("/emoji/list?is_registered=true")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 2
assert all(emoji["is_registered"] is True for emoji in data["data"])
def test_list_emojis_filter_banned(client, sample_emojis, mock_token_verify):
"""测试 is_banned 过滤"""
response = client.get("/emoji/list?is_banned=true")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 1
assert data["data"][0]["is_banned"] is True
def test_list_emojis_sort_by_query_count(client, sample_emojis, mock_token_verify):
"""测试按 query_count 排序"""
response = client.get("/emoji/list?sort_by=query_count&sort_order=desc")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
# 验证降序排列 (20 > 10 > 5)
assert data["data"][0]["query_count"] == 20
assert data["data"][1]["query_count"] == 10
assert data["data"][2]["query_count"] == 5
def test_get_emoji_detail_success(client, sample_emojis, mock_token_verify):
"""测试获取表情包详情(成功)"""
emoji_id = sample_emojis[0].id
response = client.get(f"/emoji/{emoji_id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["id"] == emoji_id
assert data["data"]["emoji_hash"] == sample_emojis[0].image_hash
def test_get_emoji_detail_not_found(client, mock_token_verify):
"""测试获取不存在的表情包404"""
response = client.get("/emoji/99999")
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_update_emoji_description(client, sample_emojis, mock_token_verify):
"""测试更新表情包描述"""
emoji_id = sample_emojis[0].id
response = client.patch(
f"/emoji/{emoji_id}",
json={"description": "更新后的描述"},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["description"] == "更新后的描述"
assert "成功更新" in data["message"]
def test_update_emoji_register_status(client, sample_emojis, mock_token_verify, test_session):
"""测试更新注册状态False -> True 应设置 register_time"""
emoji_id = sample_emojis[1].id # 未注册的表情包
response = client.patch(
f"/emoji/{emoji_id}",
json={"is_registered": True},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["is_registered"] is True
assert data["data"]["register_time"] is not None # 应该设置了注册时间
def test_update_emoji_no_fields(client, sample_emojis, mock_token_verify):
"""测试更新请求未提供任何字段400"""
emoji_id = sample_emojis[0].id
response = client.patch(f"/emoji/{emoji_id}", json={})
assert response.status_code == 400
data = response.json()
assert "未提供任何需要更新的字段" in data["detail"]
def test_update_emoji_not_found(client, mock_token_verify):
"""测试更新不存在的表情包404"""
response = client.patch("/emoji/99999", json={"description": "test"})
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_delete_emoji_success(client, sample_emojis, mock_token_verify, test_session):
"""测试删除表情包(成功)"""
emoji_id = sample_emojis[0].id
response = client.delete(f"/emoji/{emoji_id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "成功删除" in data["message"]
# 验证数据库中已删除
from sqlmodel import select
statement = select(Images).where(Images.id == emoji_id)
result = test_session.exec(statement).first()
assert result is None
def test_delete_emoji_not_found(client, mock_token_verify):
"""测试删除不存在的表情包404"""
response = client.delete("/emoji/99999")
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_batch_delete_success(client, sample_emojis, mock_token_verify, test_session):
"""测试批量删除表情包(全部成功)"""
emoji_ids = [sample_emojis[0].id, sample_emojis[1].id]
response = client.post("/emoji/batch/delete", json={"emoji_ids": emoji_ids})
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["deleted_count"] == 2
assert data["failed_count"] == 0
assert "成功删除 2 个表情包" in data["message"]
# 验证数据库中已删除
from sqlmodel import select
for emoji_id in emoji_ids:
statement = select(Images).where(Images.id == emoji_id)
result = test_session.exec(statement).first()
assert result is None
def test_batch_delete_partial_failure(client, sample_emojis, mock_token_verify):
"""测试批量删除(部分失败)"""
emoji_ids = [sample_emojis[0].id, 99999] # 第二个 ID 不存在
response = client.post("/emoji/batch/delete", json={"emoji_ids": emoji_ids})
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["deleted_count"] == 1
assert data["failed_count"] == 1
assert 99999 in data["failed_ids"]
def test_batch_delete_empty_list(client, mock_token_verify):
"""测试批量删除空列表400"""
response = client.post("/emoji/batch/delete", json={"emoji_ids": []})
assert response.status_code == 400
data = response.json()
assert "未提供要删除的表情包ID" in data["detail"]
def test_auth_required_list(client):
"""测试未认证访问列表端点401"""
# Without mock_token_verify fixture
with patch("src.webui.routers.emoji.verify_auth_token", return_value=False):
client.get("/emoji/list")
# verify_auth_token 返回 False 会触发 HTTPException
# 但具体状态码取决于 verify_auth_token_from_cookie_or_header 的实现
# 这里假设它抛出 401
def test_auth_required_update(client, sample_emojis):
"""测试未认证访问更新端点401"""
with patch("src.webui.routers.emoji.verify_auth_token", return_value=False):
emoji_id = sample_emojis[0].id
client.patch(f"/emoji/{emoji_id}", json={"description": "test"})
# Should be unauthorized
def test_emoji_to_response_field_mapping(sample_emojis):
"""测试 emoji_to_response 字段映射image_hash -> emoji_hash"""
from src.webui.routers.emoji import emoji_to_response
emoji = sample_emojis[0]
response = emoji_to_response(emoji)
# 验证 API 字段名称
assert hasattr(response, "emoji_hash")
assert response.emoji_hash == emoji.image_hash
# 验证时间戳转换
assert isinstance(response.record_time, float)
assert response.record_time == emoji.record_time.timestamp()
if emoji.register_time:
assert isinstance(response.register_time, float)
assert response.register_time == emoji.register_time.timestamp()
def test_list_emojis_only_emoji_type(client, test_session, mock_token_verify):
"""测试列表只返回 type=EMOJI 的记录(不包括其他类型)"""
# 插入一个非 EMOJI 类型的图片
non_emoji = Images(
image_type=ImageType.IMAGE, # 不是 EMOJI
full_path="/data/images/test.png",
image_hash="hash_image",
description="非表情包图片",
query_count=0,
is_registered=False,
is_banned=False,
record_time=datetime.now(),
)
test_session.add(non_emoji)
test_session.commit()
# 插入一个 EMOJI 类型
emoji = Images(
image_type=ImageType.EMOJI,
full_path="/data/emoji_registed/emoji.png",
image_hash="hash_emoji",
description="表情包",
query_count=0,
is_registered=True,
is_banned=False,
record_time=datetime.now(),
)
test_session.add(emoji)
test_session.commit()
response = client.get("/emoji/list")
assert response.status_code == 200
data = response.json()
# 只应该返回 1 个 EMOJI 类型的记录
assert data["total"] == 1
assert data["data"][0]["description"] == "表情包"

View File

@@ -1,529 +0,0 @@
"""Expression routes pytest tests"""
from typing import Generator
import pytest
from fastapi import APIRouter, FastAPI
from fastapi.testclient import TestClient
from sqlalchemy import text
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine, select
from src.common.database.database_model import Expression, ModifiedBy
from src.webui.dependencies import require_auth
def create_test_app() -> FastAPI:
"""Create minimal test app with only expression router"""
app = FastAPI(title="Test App")
from src.webui.routers.expression import router as expression_router
main_router = APIRouter(prefix="/api/webui")
main_router.include_router(expression_router)
app.include_router(main_router)
return app
app = create_test_app()
# Test database setup
@pytest.fixture(name="test_engine")
def test_engine_fixture():
"""Create in-memory SQLite database for testing"""
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
return engine
@pytest.fixture(name="test_session")
def test_session_fixture(test_engine) -> Generator[Session, None, None]:
"""Create a test database session with transaction rollback"""
connection = test_engine.connect()
transaction = connection.begin()
session = Session(bind=connection)
yield session
session.close()
transaction.rollback()
connection.close()
@pytest.fixture(name="client")
def client_fixture(test_session: Session, monkeypatch) -> Generator[TestClient, None, None]:
"""Create TestClient with overridden database session"""
from contextlib import contextmanager
@contextmanager
def get_test_db_session():
yield test_session
test_session.commit()
monkeypatch.setattr("src.webui.routers.expression.get_db_session", get_test_db_session)
with TestClient(app) as client:
yield client
@pytest.fixture(name="mock_auth")
def mock_auth_fixture():
"""Mock authentication to always return True"""
app.dependency_overrides[require_auth] = lambda: "test-token"
yield
app.dependency_overrides.clear()
@pytest.fixture(name="sample_expression")
def sample_expression_fixture(test_session: Session) -> Expression:
"""Insert a sample expression into test database"""
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
"VALUES (1, '测试情景', '测试风格', '[\"测试内容1\", \"测试内容2\"]', 10, '2026-02-17 12:00:00', '2026-02-15 10:00:00', 'test_chat_001', 0, 0)"
)
)
test_session.commit()
expression = test_session.exec(select(Expression).where(Expression.id == 1)).first()
assert expression is not None
return expression
# ============ Tests ============
def test_list_expressions_empty(client: TestClient, mock_auth):
"""Test GET /expression/list with empty database"""
response = client.get("/api/webui/expression/list")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 0
assert data["page"] == 1
assert data["page_size"] == 20
assert data["data"] == []
def test_list_expressions_with_data(client: TestClient, mock_auth, sample_expression: Expression):
"""Test GET /expression/list returns expression data"""
response = client.get("/api/webui/expression/list")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 1
assert len(data["data"]) == 1
expr_data = data["data"][0]
assert expr_data["id"] == sample_expression.id
assert expr_data["situation"] == "测试情景"
assert expr_data["style"] == "测试风格"
assert expr_data["chat_id"] == "test_chat_001"
def test_list_expressions_pagination(client: TestClient, mock_auth, test_session: Session):
"""Test GET /expression/list pagination works correctly"""
for i in range(5):
test_session.execute(
text(
f"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
f"VALUES ({i + 1}, '情景{i}', '风格{i}', '[]', 0, '2026-02-17 12:0{i}:00', '2026-02-15 10:00:00', 'chat_{i}', 0, 0)"
)
)
test_session.commit()
# Request page 1 with page_size=2
response = client.get("/api/webui/expression/list?page=1&page_size=2")
assert response.status_code == 200
data = response.json()
assert data["total"] == 5
assert data["page"] == 1
assert data["page_size"] == 2
assert len(data["data"]) == 2
# Request page 2
response = client.get("/api/webui/expression/list?page=2&page_size=2")
data = response.json()
assert data["page"] == 2
assert len(data["data"]) == 2
def test_list_expressions_search(client: TestClient, mock_auth, test_session: Session):
"""Test GET /expression/list with search filter"""
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
"VALUES (1, '找人吃饭', '热情', '[]', 0, datetime('now'), datetime('now'), 'chat_001', 0, 0)"
)
)
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
"VALUES (2, '拒绝邀请', '礼貌', '[]', 0, datetime('now'), datetime('now'), 'chat_002', 0, 0)"
)
)
test_session.commit()
# Search for "吃饭"
response = client.get("/api/webui/expression/list?search=吃饭")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["situation"] == "找人吃饭"
def test_list_expressions_chat_filter(client: TestClient, mock_auth, test_session: Session):
"""Test GET /expression/list with chat_id filter"""
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
"VALUES (1, '情景A', '风格A', '[]', 0, datetime('now'), datetime('now'), 'chat_A', 0, 0)"
)
)
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
"VALUES (2, '情景B', '风格B', '[]', 0, datetime('now'), datetime('now'), 'chat_B', 0, 0)"
)
)
test_session.commit()
# Filter by chat_A
response = client.get("/api/webui/expression/list?chat_id=chat_A")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["situation"] == "情景A"
assert data["data"][0]["chat_id"] == "chat_A"
def test_get_expression_detail_success(client: TestClient, mock_auth, sample_expression: Expression):
"""Test GET /expression/{id} returns correct detail"""
response = client.get(f"/api/webui/expression/{sample_expression.id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["id"] == sample_expression.id
assert data["data"]["situation"] == "测试情景"
assert data["data"]["style"] == "测试风格"
assert data["data"]["chat_id"] == "test_chat_001"
def test_get_expression_detail_not_found(client: TestClient, mock_auth):
"""Test GET /expression/{id} returns 404 for non-existent ID"""
response = client.get("/api/webui/expression/99999")
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_expression_response_has_legacy_fields(client: TestClient, mock_auth, sample_expression: Expression):
"""Test that ExpressionResponse includes legacy fields (checked/rejected/modified_by)"""
response = client.get(f"/api/webui/expression/{sample_expression.id}")
assert response.status_code == 200
data = response.json()["data"]
# Verify legacy fields exist and have default values
assert "checked" in data
assert "rejected" in data
assert "modified_by" in data
# Verify hardcoded default values (from expression_to_response)
assert data["checked"] is False
assert data["rejected"] is False
assert data["modified_by"] is None
def test_update_expression_without_removed_fields(client: TestClient, mock_auth, sample_expression: Expression):
"""Test PATCH /expression/{id} does not accept checked/rejected fields"""
# Valid update request (only allowed fields)
update_payload = {
"situation": "更新后的情景",
"style": "更新后的风格",
}
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["situation"] == "更新后的情景"
assert data["data"]["style"] == "更新后的风格"
# Verify legacy fields still returned (hardcoded values)
assert data["data"]["checked"] is False
assert data["data"]["rejected"] is False
def test_update_expression_ignores_invalid_fields(client: TestClient, mock_auth, sample_expression: Expression):
"""Test PATCH /expression/{id} ignores fields not in ExpressionUpdateRequest"""
# Request with invalid field (checked not in schema)
update_payload = {
"situation": "新情景",
"checked": True, # This field should be ignored by Pydantic
"rejected": True, # This field should be ignored
}
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["situation"] == "新情景"
# Response should have hardcoded False values (not True from request)
assert data["data"]["checked"] is False
assert data["data"]["rejected"] is False
def test_update_expression_chat_id_mapping(client: TestClient, mock_auth, sample_expression: Expression):
"""Test PATCH /expression/{id} correctly maps chat_id to session_id"""
update_payload = {"chat_id": "updated_chat_999"}
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
# Verify chat_id is returned in response (mapped from session_id)
assert data["data"]["chat_id"] == "updated_chat_999"
def test_update_expression_not_found(client: TestClient, mock_auth):
"""Test PATCH /expression/{id} returns 404 for non-existent ID"""
update_payload = {"situation": "新情景"}
response = client.patch("/api/webui/expression/99999", json=update_payload)
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_update_expression_empty_request(client: TestClient, mock_auth, sample_expression: Expression):
"""Test PATCH /expression/{id} returns 400 for empty update request"""
update_payload = {}
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
assert response.status_code == 400
data = response.json()
assert "未提供任何需要更新的字段" in data["detail"]
def test_delete_expression_success(client: TestClient, mock_auth, sample_expression: Expression):
"""Test DELETE /expression/{id} successfully deletes expression"""
expression_id = sample_expression.id
response = client.delete(f"/api/webui/expression/{expression_id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "成功删除" in data["message"]
# Verify expression is deleted
get_response = client.get(f"/api/webui/expression/{expression_id}")
assert get_response.status_code == 404
def test_delete_expression_not_found(client: TestClient, mock_auth):
"""Test DELETE /expression/{id} returns 404 for non-existent ID"""
response = client.delete("/api/webui/expression/99999")
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_create_expression_success(client: TestClient, mock_auth):
"""Test POST /expression/ successfully creates expression"""
create_payload = {
"situation": "新建情景",
"style": "新建风格",
"chat_id": "new_chat_123",
}
response = client.post("/api/webui/expression/", json=create_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "创建成功" in data["message"]
assert data["data"]["situation"] == "新建情景"
assert data["data"]["style"] == "新建风格"
assert data["data"]["chat_id"] == "new_chat_123"
# Verify legacy fields
assert data["data"]["checked"] is False
assert data["data"]["rejected"] is False
assert data["data"]["modified_by"] is None
def test_batch_delete_expressions_success(client: TestClient, mock_auth, test_session: Session):
"""Test POST /expression/batch/delete successfully deletes multiple expressions"""
expression_ids = []
for i in range(3):
test_session.execute(
text(
f"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
f"VALUES ({i + 1}, '批量删除{i}', '风格{i}', '[]', 0, datetime('now'), datetime('now'), 'chat_{i}', 0, 0)"
)
)
expression_ids.append(i + 1)
test_session.commit()
delete_payload = {"ids": expression_ids}
response = client.post("/api/webui/expression/batch/delete", json=delete_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "成功删除 3 个" in data["message"]
for expr_id in expression_ids:
get_response = client.get(f"/api/webui/expression/{expr_id}")
assert get_response.status_code == 404
def test_batch_delete_partial_not_found(client: TestClient, mock_auth, sample_expression: Expression):
"""Test POST /expression/batch/delete handles partial not found IDs"""
delete_payload = {"ids": [sample_expression.id, 88888, 99999]}
response = client.post("/api/webui/expression/batch/delete", json=delete_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
# Should delete only the 1 valid ID
assert "成功删除 1 个" in data["message"]
def test_get_expression_stats(client: TestClient, mock_auth, test_session: Session):
"""Test GET /expression/stats/summary returns correct statistics"""
for i in range(3):
test_session.execute(
text(
f"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
f"VALUES ({i + 1}, '情景{i}', '风格{i}', '[]', 0, datetime('now'), datetime('now'), 'chat_{i % 2}', 0, 0)"
)
)
test_session.commit()
response = client.get("/api/webui/expression/stats/summary")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["total"] == 3
assert data["data"]["chat_count"] == 2
def test_get_review_stats(client: TestClient, mock_auth, test_session: Session):
"""Test GET /expression/review/stats returns review status counts"""
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
"VALUES (1, '待审核', '风格', '[]', 0, datetime('now'), datetime('now'), 'chat_001', 0, 0)"
)
)
test_session.commit()
response = client.get("/api/webui/expression/review/stats")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1 # Total expressions exists
assert data["unchecked"] == 1
assert data["passed"] == 0
assert data["rejected"] == 0
assert data["ai_checked"] == 0
assert data["user_checked"] == 0
def test_get_review_list_filter_unchecked(client: TestClient, mock_auth, sample_expression: Expression):
"""Test GET /expression/review/list with filter_type=unchecked returns unchecked expressions"""
response = client.get("/api/webui/expression/review/list?filter_type=unchecked")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 1
assert len(data["data"]) == 1
def test_get_review_list_filter_all(client: TestClient, mock_auth, sample_expression: Expression):
"""Test GET /expression/review/list with filter_type=all returns all expressions"""
response = client.get("/api/webui/expression/review/list?filter_type=all")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 1
assert len(data["data"]) == 1
def test_batch_review_expressions_with_unchecked_marker(client: TestClient, mock_auth, sample_expression: Expression):
"""Test POST /expression/review/batch succeeds with require_unchecked=True"""
review_payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": True}]}
response = client.post("/api/webui/expression/review/batch", json=review_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["succeeded"] == 1
assert data["results"][0]["success"] is True
def test_batch_review_expressions_overwrites_ai_checked(
client: TestClient, mock_auth, test_session: Session, sample_expression: Expression
):
"""Test POST /expression/review/batch lets manual review override AI checked state"""
sample_expression.checked = True
sample_expression.rejected = True
sample_expression.modified_by = ModifiedBy.AI
test_session.add(sample_expression)
test_session.commit()
review_payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": True}]}
response = client.post("/api/webui/expression/review/batch", json=review_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["succeeded"] == 1
test_session.expire_all()
reviewed_expression = test_session.exec(select(Expression).where(Expression.id == sample_expression.id)).first()
assert reviewed_expression is not None
assert reviewed_expression.checked is True
assert reviewed_expression.rejected is False
assert reviewed_expression.modified_by == ModifiedBy.USER
def test_batch_review_expressions_no_unchecked_check(client: TestClient, mock_auth, sample_expression: Expression):
"""Test POST /expression/review/batch succeeds when require_unchecked=False"""
review_payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": False}]}
response = client.post("/api/webui/expression/review/batch", json=review_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["succeeded"] == 1
assert data["results"][0]["success"] is True

View File

@@ -1,512 +0,0 @@
"""测试 jargon 路由的完整性和正确性"""
import json
from datetime import datetime
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.common.database.database_model import ChatSession, Jargon
from src.webui.routers.jargon import router as jargon_router
@pytest.fixture(name="app", scope="function")
def app_fixture():
app = FastAPI()
app.include_router(jargon_router, prefix="/api/webui")
return app
@pytest.fixture(name="engine", scope="function")
def engine_fixture():
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
yield engine
@pytest.fixture(name="session", scope="function")
def session_fixture(engine):
connection = engine.connect()
transaction = connection.begin()
session = Session(bind=connection)
yield session
session.close()
transaction.rollback()
connection.close()
@pytest.fixture(name="client", scope="function")
def client_fixture(app: FastAPI, session: Session, monkeypatch):
from contextlib import contextmanager
@contextmanager
def mock_get_db_session():
yield session
monkeypatch.setattr("src.webui.routers.jargon.get_db_session", mock_get_db_session)
with TestClient(app) as client:
yield client
@pytest.fixture(name="sample_chat_session")
def sample_chat_session_fixture(session: Session):
"""创建示例 ChatSession"""
chat_session = ChatSession(
session_id="test_stream_001",
platform="qq",
group_id="123456789",
user_id=None,
created_timestamp=datetime.now(),
last_active_timestamp=datetime.now(),
)
session.add(chat_session)
session.commit()
session.refresh(chat_session)
return chat_session
@pytest.fixture(name="sample_jargons")
def sample_jargons_fixture(session: Session, sample_chat_session: ChatSession):
"""创建示例 Jargon 数据"""
jargons = [
Jargon(
id=1,
content="yyds",
raw_content="永远的神",
meaning="永远的神",
session_id=sample_chat_session.session_id,
count=10,
is_jargon=True,
is_complete=False,
),
Jargon(
id=2,
content="awsl",
raw_content="啊我死了",
meaning="啊我死了",
session_id=sample_chat_session.session_id,
count=5,
is_jargon=True,
is_complete=False,
),
Jargon(
id=3,
content="hello",
raw_content=None,
meaning="你好",
session_id=sample_chat_session.session_id,
count=2,
is_jargon=False,
is_complete=False,
),
]
for jargon in jargons:
session.add(jargon)
session.commit()
for jargon in jargons:
session.refresh(jargon)
return jargons
# ==================== Test Cases ====================
def test_list_jargons(client: TestClient, sample_jargons):
"""测试 GET /jargon/list 基础列表功能"""
response = client.get("/api/webui/jargon/list")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 3
assert data["page"] == 1
assert data["page_size"] == 20
assert len(data["data"]) == 3
assert data["data"][0]["content"] == "yyds"
assert data["data"][0]["count"] == 10
def test_list_jargons_with_pagination(client: TestClient, sample_jargons):
"""测试分页功能"""
response = client.get("/api/webui/jargon/list?page=1&page_size=2")
assert response.status_code == 200
data = response.json()
assert data["total"] == 3
assert len(data["data"]) == 2
response = client.get("/api/webui/jargon/list?page=2&page_size=2")
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 1
def test_list_jargons_with_search(client: TestClient, sample_jargons):
"""测试 GET /jargon/list?search=xxx 搜索功能"""
response = client.get("/api/webui/jargon/list?search=yyds")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["content"] == "yyds"
# 测试搜索 meaning
response = client.get("/api/webui/jargon/list?search=你好")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["content"] == "hello"
def test_list_jargons_with_chat_id_filter(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
"""测试按 chat_id 筛选"""
response = client.get(f"/api/webui/jargon/list?chat_id={sample_chat_session.session_id}")
assert response.status_code == 200
data = response.json()
assert data["total"] == 3
# 测试不存在的 chat_id
response = client.get("/api/webui/jargon/list?chat_id=nonexistent")
assert response.status_code == 200
data = response.json()
assert data["total"] == 0
def test_list_jargons_with_is_jargon_filter(client: TestClient, sample_jargons):
"""测试按 is_jargon 筛选"""
response = client.get("/api/webui/jargon/list?is_jargon=true")
assert response.status_code == 200
data = response.json()
assert data["total"] == 2
assert all(item["is_jargon"] is True for item in data["data"])
response = client.get("/api/webui/jargon/list?is_jargon=false")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["content"] == "hello"
def test_get_jargon_detail(client: TestClient, sample_jargons):
"""测试 GET /jargon/{id} 获取详情"""
jargon_id = sample_jargons[0].id
response = client.get(f"/api/webui/jargon/{jargon_id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["id"] == jargon_id
assert data["data"]["content"] == "yyds"
assert data["data"]["meaning"] == "永远的神"
assert data["data"]["count"] == 10
assert data["data"]["is_jargon"] is True
def test_get_jargon_detail_not_found(client: TestClient):
"""测试获取不存在的黑话详情"""
response = client.get("/api/webui/jargon/99999")
assert response.status_code == 404
assert "黑话不存在" in response.json()["detail"]
@pytest.mark.skip(reason="Composite PK (id+content) prevents autoincrement - database model issue")
def test_create_jargon(client: TestClient, sample_chat_session: ChatSession):
"""测试 POST /jargon/ 创建黑话"""
request_data = {
"content": "新黑话",
"raw_content": "原始内容",
"meaning": "含义",
"chat_id": sample_chat_session.session_id,
}
response = client.post("/api/webui/jargon/", json=request_data)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["message"] == "创建成功"
assert data["data"]["content"] == "新黑话"
assert data["data"]["meaning"] == "含义"
assert data["data"]["count"] == 0
assert data["data"]["is_jargon"] is None
assert data["data"]["is_complete"] is False
def test_create_duplicate_jargon(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
"""测试创建重复黑话返回 400"""
request_data = {
"content": "yyds",
"meaning": "重复的",
"chat_id": sample_chat_session.session_id,
}
response = client.post("/api/webui/jargon/", json=request_data)
assert response.status_code == 400
assert "已存在相同内容的黑话" in response.json()["detail"]
def test_update_jargon(client: TestClient, sample_jargons):
"""测试 PATCH /jargon/{id} 更新黑话"""
jargon_id = sample_jargons[0].id
update_data = {
"meaning": "更新后的含义",
"is_jargon": True,
}
response = client.patch(f"/api/webui/jargon/{jargon_id}", json=update_data)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["message"] == "更新成功"
assert data["data"]["meaning"] == "更新后的含义"
assert data["data"]["is_jargon"] is True
assert data["data"]["content"] == "yyds" # 未改变的字段保持不变
def test_update_jargon_with_chat_id_mapping(client: TestClient, sample_jargons):
"""测试更新时 chat_id → session_id 的映射"""
jargon_id = sample_jargons[0].id
update_data = {
"chat_id": "new_session_id",
}
response = client.patch(f"/api/webui/jargon/{jargon_id}", json=update_data)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["chat_id"] == "new_session_id"
def test_update_jargon_not_found(client: TestClient):
"""测试更新不存在的黑话"""
response = client.patch("/api/webui/jargon/99999", json={"meaning": "test"})
assert response.status_code == 404
assert "黑话不存在" in response.json()["detail"]
def test_delete_jargon(client: TestClient, sample_jargons, session: Session):
"""测试 DELETE /jargon/{id} 删除黑话"""
jargon_id = sample_jargons[0].id
response = client.delete(f"/api/webui/jargon/{jargon_id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["message"] == "删除成功"
assert data["deleted_count"] == 1
# 验证数据库中已删除
response = client.get(f"/api/webui/jargon/{jargon_id}")
assert response.status_code == 404
def test_delete_jargon_not_found(client: TestClient):
"""测试删除不存在的黑话"""
response = client.delete("/api/webui/jargon/99999")
assert response.status_code == 404
assert "黑话不存在" in response.json()["detail"]
def test_batch_delete(client: TestClient, sample_jargons):
"""测试 POST /jargon/batch/delete 批量删除"""
ids_to_delete = [sample_jargons[0].id, sample_jargons[1].id]
request_data = {"ids": ids_to_delete}
response = client.post("/api/webui/jargon/batch/delete", json=request_data)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["deleted_count"] == 2
assert "成功删除 2 条黑话" in data["message"]
# 验证已删除
response = client.get(f"/api/webui/jargon/{ids_to_delete[0]}")
assert response.status_code == 404
def test_batch_delete_empty_list(client: TestClient):
"""测试批量删除空列表返回 400"""
response = client.post("/api/webui/jargon/batch/delete", json={"ids": []})
assert response.status_code == 400
assert "ID列表不能为空" in response.json()["detail"]
def test_batch_set_jargon_status(client: TestClient, sample_jargons):
"""测试批量设置黑话状态"""
ids = [sample_jargons[0].id, sample_jargons[1].id]
response = client.post(
"/api/webui/jargon/batch/set-jargon",
params={"ids": ids, "is_jargon": False},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "成功更新 2 条黑话状态" in data["message"]
# 验证状态已更新
detail_response = client.get(f"/api/webui/jargon/{ids[0]}")
assert detail_response.json()["data"]["is_jargon"] is False
def test_get_stats(client: TestClient, sample_jargons):
"""测试 GET /jargon/stats/summary 统计数据"""
response = client.get("/api/webui/jargon/stats/summary")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
stats = data["data"]
assert stats["total"] == 3
assert stats["confirmed_jargon"] == 2
assert stats["confirmed_not_jargon"] == 1
assert stats["pending"] == 0
assert stats["complete_count"] == 0
assert stats["chat_count"] == 1
assert isinstance(stats["top_chats"], dict)
def test_get_chat_list(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
"""测试 GET /jargon/chats 获取聊天列表"""
response = client.get("/api/webui/jargon/chats")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert len(data["data"]) == 1
chat_info = data["data"][0]
assert chat_info["chat_id"] == sample_chat_session.session_id
assert chat_info["platform"] == "qq"
assert chat_info["is_group"] is True
assert chat_info["chat_name"] == sample_chat_session.group_id
def test_get_chat_list_with_json_chat_id(client: TestClient, session: Session, sample_chat_session: ChatSession):
"""测试解析 JSON 格式的 chat_id"""
json_chat_id = json.dumps([[sample_chat_session.session_id, "user123"]])
jargon = Jargon(
id=100,
content="测试黑话",
meaning="测试",
session_id=json_chat_id,
count=1,
)
session.add(jargon)
session.commit()
response = client.get("/api/webui/jargon/chats")
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 1
assert data["data"][0]["chat_id"] == sample_chat_session.session_id
def test_get_chat_list_without_chat_session(client: TestClient, session: Session):
"""测试聊天列表中没有对应 ChatSession 的情况"""
jargon = Jargon(
id=101,
content="孤立黑话",
meaning="无对应会话",
session_id="nonexistent_stream_id",
count=1,
)
session.add(jargon)
session.commit()
response = client.get("/api/webui/jargon/chats")
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 1
assert data["data"][0]["chat_id"] == "nonexistent_stream_id"
assert data["data"][0]["chat_name"] == "nonexistent_stream_id"[:20]
assert data["data"][0]["platform"] is None
assert data["data"][0]["is_group"] is False
def test_jargon_response_fields(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
"""测试 JargonResponse 字段完整性"""
response = client.get(f"/api/webui/jargon/{sample_jargons[0].id}")
assert response.status_code == 200
data = response.json()["data"]
# 验证所有必需字段存在
required_fields = [
"id",
"content",
"raw_content",
"meaning",
"chat_id",
"stream_id",
"chat_name",
"count",
"is_jargon",
"is_complete",
"inference_with_context",
"inference_content_only",
]
for field in required_fields:
assert field in data
# 验证 chat_name 显示逻辑
assert data["chat_name"] == sample_chat_session.group_id
@pytest.mark.skip(reason="Composite PK (id+content) prevents autoincrement - database model issue")
def test_create_jargon_without_optional_fields(client: TestClient, sample_chat_session: ChatSession):
"""测试创建黑话时可选字段为空"""
request_data = {
"content": "简单黑话",
"chat_id": sample_chat_session.session_id,
}
response = client.post("/api/webui/jargon/", json=request_data)
assert response.status_code == 200
data = response.json()["data"]
assert data["raw_content"] is None
assert data["meaning"] == ""
def test_update_jargon_partial_fields(client: TestClient, sample_jargons):
"""测试增量更新(只更新部分字段)"""
jargon_id = sample_jargons[0].id
original_content = sample_jargons[0].content
# 只更新 meaning
response = client.patch(f"/api/webui/jargon/{jargon_id}", json={"meaning": "新含义"})
assert response.status_code == 200
data = response.json()["data"]
assert data["meaning"] == "新含义"
assert data["content"] == original_content # 其他字段不变
def test_list_jargons_multiple_filters(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
"""测试组合多个过滤条件"""
response = client.get(f"/api/webui/jargon/list?search=永远&chat_id={sample_chat_session.session_id}&is_jargon=true")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["content"] == "yyds"

View File

@@ -1,870 +0,0 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient
import pytest
from src.services.memory_service import MemorySearchResult
from src.webui.dependencies import require_auth
from src.webui.routers import memory as memory_router_module
from src.webui.routers.memory import compat_router
from src.webui.routes import router as main_router
@pytest.fixture
def client() -> TestClient:
app = FastAPI()
app.dependency_overrides[require_auth] = lambda: "ok"
app.include_router(main_router)
app.include_router(compat_router)
return TestClient(app)
def test_webui_memory_graph_route(client: TestClient, monkeypatch):
async def fake_graph_admin(*, action: str, **kwargs):
assert action == "get_graph"
return {
"success": True,
"nodes": [],
"edges": [
{
"source": "alice",
"target": "map",
"weight": 1.5,
"relation_hashes": ["rel-1"],
"predicates": ["持有"],
"relation_count": 1,
"evidence_count": 2,
"label": "持有",
}
],
"total_nodes": 0,
"limit": kwargs.get("limit"),
}
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
response = client.get("/api/webui/memory/graph", params={"limit": 77})
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["limit"] == 77
assert response.json()["edges"][0]["predicates"] == ["持有"]
assert response.json()["edges"][0]["relation_count"] == 1
assert response.json()["edges"][0]["evidence_count"] == 2
def test_webui_memory_graph_search_route(client: TestClient, monkeypatch):
async def fake_graph_admin(*, action: str, **kwargs):
assert action == "search"
assert kwargs["query"] == "Alice"
assert kwargs["limit"] == 33
return {
"success": True,
"query": kwargs["query"],
"limit": kwargs["limit"],
"count": 1,
"items": [
{
"type": "entity",
"title": "Alice",
"matched_field": "name",
"matched_value": "Alice",
"entity_name": "Alice",
"entity_hash": "entity-1",
"appearance_count": 3,
}
],
}
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
response = client.get("/api/webui/memory/graph/search", params={"query": "Alice", "limit": 33})
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["query"] == "Alice"
assert response.json()["limit"] == 33
assert response.json()["items"][0]["type"] == "entity"
@pytest.mark.parametrize(
"params",
[
{"query": "", "limit": 50},
{"query": "Alice", "limit": 0},
{"query": "Alice", "limit": 201},
],
)
def test_webui_memory_graph_search_route_validation(client: TestClient, params):
response = client.get("/api/webui/memory/graph/search", params=params)
assert response.status_code == 422
def test_webui_memory_graph_node_detail_route(client: TestClient, monkeypatch):
async def fake_graph_admin(*, action: str, **kwargs):
assert action == "node_detail"
assert kwargs["node_id"] == "Alice"
return {
"success": True,
"node": {"id": "Alice", "type": "entity", "content": "Alice", "appearance_count": 3},
"relations": [{"hash": "rel-1", "subject": "Alice", "predicate": "持有", "object": "Map", "text": "Alice 持有 Map", "confidence": 0.9, "paragraph_count": 1, "paragraph_hashes": ["p-1"], "source_paragraph": "p-1"}],
"paragraphs": [{"hash": "p-1", "content": "Alice 拿着地图。", "preview": "Alice 拿着地图。", "source": "demo", "entity_count": 2, "relation_count": 1, "entities": ["Alice", "Map"], "relations": ["Alice 持有 Map"]}],
"evidence_graph": {
"nodes": [{"id": "entity:Alice", "type": "entity", "content": "Alice"}],
"edges": [],
"focus_entities": ["Alice"],
},
}
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
response = client.get("/api/webui/memory/graph/node-detail", params={"node_id": "Alice"})
assert response.status_code == 200
assert response.json()["node"]["id"] == "Alice"
assert response.json()["relations"][0]["predicate"] == "持有"
assert response.json()["evidence_graph"]["focus_entities"] == ["Alice"]
def test_webui_memory_graph_node_detail_route_returns_404(client: TestClient, monkeypatch):
async def fake_graph_admin(*, action: str, **kwargs):
assert action == "node_detail"
return {"success": False, "error": "未找到节点: Missing"}
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
response = client.get("/api/webui/memory/graph/node-detail", params={"node_id": "Missing"})
assert response.status_code == 404
assert response.json()["detail"] == "未找到节点: Missing"
def test_webui_memory_graph_edge_detail_route(client: TestClient, monkeypatch):
async def fake_graph_admin(*, action: str, **kwargs):
assert action == "edge_detail"
assert kwargs["source"] == "Alice"
assert kwargs["target"] == "Map"
return {
"success": True,
"edge": {
"source": "Alice",
"target": "Map",
"weight": 1.5,
"relation_hashes": ["rel-1"],
"predicates": ["持有"],
"relation_count": 1,
"evidence_count": 1,
"label": "持有",
},
"relations": [{"hash": "rel-1", "subject": "Alice", "predicate": "持有", "object": "Map", "text": "Alice 持有 Map", "confidence": 0.9, "paragraph_count": 1, "paragraph_hashes": ["p-1"], "source_paragraph": "p-1"}],
"paragraphs": [{"hash": "p-1", "content": "Alice 拿着地图。", "preview": "Alice 拿着地图。", "source": "demo", "entity_count": 2, "relation_count": 1, "entities": ["Alice", "Map"], "relations": ["Alice 持有 Map"]}],
"evidence_graph": {
"nodes": [{"id": "relation:rel-1", "type": "relation", "content": "Alice 持有 Map"}],
"edges": [{"source": "paragraph:p-1", "target": "relation:rel-1", "kind": "supports", "label": "支撑", "weight": 1.0}],
"focus_entities": ["Alice", "Map"],
},
}
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
response = client.get("/api/webui/memory/graph/edge-detail", params={"source": "Alice", "target": "Map"})
assert response.status_code == 200
assert response.json()["edge"]["predicates"] == ["持有"]
assert response.json()["paragraphs"][0]["source"] == "demo"
assert response.json()["evidence_graph"]["edges"][0]["kind"] == "supports"
def test_webui_memory_graph_edge_detail_route_returns_404(client: TestClient, monkeypatch):
async def fake_graph_admin(*, action: str, **kwargs):
assert action == "edge_detail"
return {"success": False, "error": "未找到边: Alice -> Missing"}
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
response = client.get("/api/webui/memory/graph/edge-detail", params={"source": "Alice", "target": "Missing"})
assert response.status_code == 404
assert response.json()["detail"] == "未找到边: Alice -> Missing"
def test_webui_memory_profile_query_resolves_platform_user_id(client: TestClient, monkeypatch):
def fake_resolve_person_id_for_memory(**kwargs):
assert kwargs == {"platform": "qq", "user_id": "12345", "strict_known": False}
return "resolved-person-id"
async def fake_profile_admin(*, action: str, **kwargs):
assert action == "query"
assert kwargs["person_id"] == "resolved-person-id"
assert kwargs["person_keyword"] == "Alice"
assert kwargs["limit"] == 9
assert kwargs["force_refresh"] is True
return {"success": True, "person_id": kwargs["person_id"], "profile_text": "profile"}
monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
monkeypatch.setattr(memory_router_module.memory_service, "profile_admin", fake_profile_admin)
response = client.get(
"/api/webui/memory/profiles/query",
params={
"platform": "qq",
"user_id": "12345",
"person_keyword": "Alice",
"limit": 9,
"force_refresh": True,
},
)
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["person_id"] == "resolved-person-id"
def test_webui_memory_profile_query_prefers_explicit_person_id(client: TestClient, monkeypatch):
def fake_resolve_person_id_for_memory(**kwargs):
raise AssertionError(f"不应解析平台账号: {kwargs}")
async def fake_profile_admin(*, action: str, **kwargs):
assert action == "query"
assert kwargs["person_id"] == "explicit-person-id"
return {"success": True, "person_id": kwargs["person_id"]}
monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
monkeypatch.setattr(memory_router_module.memory_service, "profile_admin", fake_profile_admin)
response = client.get(
"/api/webui/memory/profiles/query",
params={"person_id": "explicit-person-id", "platform": "qq", "user_id": "12345"},
)
assert response.status_code == 200
assert response.json()["person_id"] == "explicit-person-id"
def test_webui_memory_profile_list_enriches_person_name(client: TestClient, monkeypatch):
async def fake_profile_admin(*, action: str, **kwargs):
assert action == "list"
assert kwargs["limit"] == 7
return {
"success": True,
"items": [
{"person_id": "person-1", "profile_text": "profile-1"},
{"person_id": "person-2", "profile_text": "profile-2"},
],
}
monkeypatch.setattr(memory_router_module.memory_service, "profile_admin", fake_profile_admin)
monkeypatch.setattr(
memory_router_module,
"_get_person_name_for_person_id",
lambda person_id: {"person-1": "Alice"}.get(person_id, ""),
)
response = client.get("/api/webui/memory/profiles", params={"limit": 7})
assert response.status_code == 200
assert response.json()["items"][0]["person_name"] == "Alice"
assert response.json()["items"][1]["person_name"] == ""
def test_webui_memory_profile_search_resolves_platform_user_id(client: TestClient, monkeypatch):
def fake_resolve_person_id_for_memory(**kwargs):
assert kwargs == {"platform": "qq", "user_id": "12345", "strict_known": False}
return "resolved-person-id"
async def fake_profile_list(limit: int):
assert limit == 200
return {
"success": True,
"items": [
{"person_id": "resolved-person-id", "person_name": "Alice", "profile_text": "喜欢咖啡"},
{"person_id": "other-person-id", "person_name": "Bob", "profile_text": "喜欢茶"},
],
}
monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
monkeypatch.setattr(memory_router_module, "_profile_list", fake_profile_list)
response = client.get(
"/api/webui/memory/profiles/search",
params={"platform": "qq", "user_id": "12345", "limit": 50},
)
assert response.status_code == 200
assert response.json()["items"] == [
{"person_id": "resolved-person-id", "person_name": "Alice", "profile_text": "喜欢咖啡"}
]
def test_webui_memory_profile_search_filters_keyword(client: TestClient, monkeypatch):
async def fake_profile_list(limit: int):
assert limit == 200
return {
"success": True,
"items": [
{"person_id": "person-1", "person_name": "Alice", "profile_text": "喜欢咖啡"},
{"person_id": "person-2", "person_name": "Bob", "profile_text": "喜欢茶"},
],
}
monkeypatch.setattr(memory_router_module, "_profile_list", fake_profile_list)
response = client.get("/api/webui/memory/profiles/search", params={"person_keyword": "咖啡", "limit": 50})
assert response.status_code == 200
assert response.json()["items"] == [
{"person_id": "person-1", "person_name": "Alice", "profile_text": "喜欢咖啡"}
]
def test_webui_memory_episode_list_resolves_platform_user_id(client: TestClient, monkeypatch):
def fake_resolve_person_id_for_memory(**kwargs):
assert kwargs == {"platform": "qq", "user_id": "12345", "strict_known": False}
return "resolved-person-id"
async def fake_episode_admin(*, action: str, **kwargs):
assert action == "list"
assert kwargs == {
"query": "咖啡",
"limit": 9,
"source": "chat_summary:demo",
"person_id": "resolved-person-id",
"time_start": 100.0,
"time_end": 200.0,
}
return {
"success": True,
"items": [{"episode_id": "ep-1", "person_id": "resolved-person-id", "summary": "喝咖啡"}],
"count": 1,
}
monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin)
monkeypatch.setattr(memory_router_module, "_get_person_name_for_person_id", lambda person_id: "测试人物")
response = client.get(
"/api/webui/memory/episodes",
params={
"query": "咖啡",
"limit": 9,
"source": "chat_summary:demo",
"platform": "qq",
"user_id": "12345",
"time_start": 100,
"time_end": 200,
},
)
assert response.status_code == 200
assert response.json()["items"][0]["person_name"] == "测试人物"
def test_webui_memory_episode_list_prefers_explicit_person_id(client: TestClient, monkeypatch):
def fake_resolve_person_id_for_memory(**kwargs):
raise AssertionError(f"不应解析平台账号: {kwargs}")
async def fake_episode_admin(*, action: str, **kwargs):
assert action == "list"
assert kwargs["person_id"] == "explicit-person-id"
return {"success": True, "items": []}
monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin)
response = client.get(
"/api/webui/memory/episodes",
params={"person_id": "explicit-person-id", "platform": "qq", "user_id": "12345"},
)
assert response.status_code == 200
assert response.json()["items"] == []
def test_compat_aggregate_route(client: TestClient, monkeypatch):
async def fake_search(query: str, **kwargs):
assert kwargs["mode"] == "aggregate"
assert kwargs["respect_filter"] is False
return MemorySearchResult(summary=f"summary:{query}", hits=[])
monkeypatch.setattr(memory_router_module.memory_service, "search", fake_search)
response = client.get("/api/query/aggregate", params={"query": "mai"})
assert response.status_code == 200
assert response.json() == {
"success": True,
"summary": "summary:mai",
"hits": [],
"filtered": False,
"error": "",
}
def test_auto_save_routes(client: TestClient, monkeypatch):
async def fake_runtime_admin(*, action: str, **kwargs):
if action == "get_config":
return {"success": True, "auto_save": True}
if action == "set_auto_save":
return {"success": True, "auto_save": kwargs["enabled"]}
raise AssertionError(action)
monkeypatch.setattr(memory_router_module.memory_service, "runtime_admin", fake_runtime_admin)
get_response = client.get("/api/config/auto_save")
post_response = client.post("/api/config/auto_save", json={"enabled": False})
assert get_response.status_code == 200
assert get_response.json() == {"success": True, "auto_save": True}
assert post_response.status_code == 200
assert post_response.json() == {"success": True, "auto_save": False}
def test_memory_config_routes(client: TestClient, monkeypatch):
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_config_schema",
lambda: {"layout": {"type": "tabs"}, "sections": {"plugin": {"fields": {}}}},
)
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_config_path",
lambda: memory_router_module.Path("/tmp/config/bot_config.toml"),
)
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_config",
lambda: {"plugin": {"enabled": True}},
)
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_raw_config",
lambda: "[plugin]\nenabled = true\n",
)
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_raw_config_with_meta",
lambda: {
"config": "[plugin]\nenabled = true\n",
"exists": True,
"using_default": False,
},
)
schema_response = client.get("/api/webui/memory/config/schema")
config_response = client.get("/api/webui/memory/config")
raw_response = client.get("/api/webui/memory/config/raw")
expected_path = memory_router_module.Path("/tmp/config/bot_config.toml").as_posix()
assert schema_response.status_code == 200
assert memory_router_module.Path(schema_response.json()["path"]).as_posix() == expected_path
assert schema_response.json()["schema"]["layout"]["type"] == "tabs"
assert config_response.status_code == 200
assert config_response.json()["success"] is True
assert config_response.json()["config"] == {"plugin": {"enabled": True}}
assert memory_router_module.Path(config_response.json()["path"]).as_posix() == expected_path
assert raw_response.status_code == 200
assert raw_response.json()["success"] is True
assert raw_response.json()["config"] == "[plugin]\nenabled = true\n"
assert memory_router_module.Path(raw_response.json()["path"]).as_posix() == expected_path
def test_memory_config_raw_returns_default_template_when_file_missing(client: TestClient, monkeypatch):
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_config_path",
lambda: memory_router_module.Path("/tmp/config/bot_config.toml"),
)
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_raw_config_with_meta",
lambda: {
"config": "[plugin]\nenabled = true\n",
"exists": False,
"using_default": True,
},
)
response = client.get("/api/webui/memory/config/raw")
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["config"] == "[plugin]\nenabled = true\n"
assert response.json()["exists"] is False
assert response.json()["using_default"] is True
def test_memory_config_update_routes(client: TestClient, monkeypatch):
async def fake_update_config(config):
assert config == {"plugin": {"enabled": False}}
return {"success": True, "config_path": "config/bot_config.toml"}
async def fake_update_raw(raw_config):
assert raw_config == "[plugin]\nenabled = false\n"
return {"success": True, "config_path": "config/bot_config.toml"}
monkeypatch.setattr(memory_router_module.a_memorix_host_service, "update_config", fake_update_config)
monkeypatch.setattr(memory_router_module.a_memorix_host_service, "update_raw_config", fake_update_raw)
config_response = client.put("/api/webui/memory/config", json={"config": {"plugin": {"enabled": False}}})
raw_response = client.put("/api/webui/memory/config/raw", json={"config": "[plugin]\nenabled = false\n"})
assert config_response.status_code == 200
assert config_response.json() == {"success": True, "config_path": "config/bot_config.toml"}
assert raw_response.status_code == 200
assert raw_response.json() == {"success": True, "config_path": "config/bot_config.toml"}
def test_memory_config_raw_rejects_invalid_toml(client: TestClient):
response = client.put("/api/webui/memory/config/raw", json={"config": "[plugin\nenabled = true"})
assert response.status_code == 400
assert "TOML 格式错误" in response.json()["detail"]
def test_recycle_bin_route(client: TestClient, monkeypatch):
async def fake_get_recycle_bin(*, limit: int):
return {"success": True, "items": [{"hash": "deadbeef"}], "count": 1, "limit": limit}
monkeypatch.setattr(memory_router_module.memory_service, "get_recycle_bin", fake_get_recycle_bin)
response = client.get("/api/memory/recycle_bin", params={"limit": 10})
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["count"] == 1
assert response.json()["limit"] == 10
def test_import_guide_route(client: TestClient, monkeypatch):
async def fake_import_admin(*, action: str, **kwargs):
assert kwargs == {}
if action == "get_guide":
return {"success": True}
if action == "get_settings":
return {"success": True, "settings": {"path_aliases": {"raw": "/tmp/raw"}}}
raise AssertionError(action)
monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)
response = client.get("/api/webui/memory/import/guide")
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["source"] == "local"
assert "长期记忆导入说明" in response.json()["content"]
def test_import_upload_route(client: TestClient, monkeypatch, tmp_path):
monkeypatch.setattr(memory_router_module, "STAGING_ROOT", tmp_path)
async def fake_import_admin(*, action: str, **kwargs):
assert action == "create_upload"
staged_files = kwargs["staged_files"]
assert len(staged_files) == 1
assert staged_files[0]["filename"] == "demo.txt"
assert memory_router_module.Path(staged_files[0]["staged_path"]).exists()
return {"success": True, "task_id": "task-1"}
monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)
response = client.post(
"/api/import/upload",
data={"payload_json": "{\"source\": \"upload\"}"},
files=[("files", ("demo.txt", b"hello world", "text/plain"))],
)
assert response.status_code == 200
assert response.json() == {"success": True, "task_id": "task-1"}
assert list(tmp_path.iterdir()) == []
def test_v5_status_route(client: TestClient, monkeypatch):
async def fake_v5_admin(*, action: str, **kwargs):
assert action == "status"
assert kwargs["target"] == "mai"
return {"success": True, "active_count": 1, "inactive_count": 2, "deleted_count": 3}
monkeypatch.setattr(memory_router_module.memory_service, "v5_admin", fake_v5_admin)
response = client.get("/api/webui/memory/v5/status", params={"target": "mai"})
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["deleted_count"] == 3
def test_delete_preview_route(client: TestClient, monkeypatch):
async def fake_delete_admin(*, action: str, **kwargs):
assert action == "preview"
assert kwargs["mode"] == "paragraph"
assert kwargs["selector"] == {"query": "demo"}
return {"success": True, "counts": {"paragraphs": 1}, "dry_run": True}
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
response = client.post(
"/api/webui/memory/delete/preview",
json={"mode": "paragraph", "selector": {"query": "demo"}},
)
assert response.status_code == 200
assert response.json() == {"success": True, "counts": {"paragraphs": 1}, "dry_run": True}
def test_delete_preview_route_supports_mixed_mode(client: TestClient, monkeypatch):
async def fake_delete_admin(*, action: str, **kwargs):
assert action == "preview"
assert kwargs["mode"] == "mixed"
assert kwargs["selector"] == {
"entity_hashes": ["entity-1"],
"paragraph_hashes": ["p-1"],
"relation_hashes": ["rel-1"],
"sources": ["demo"],
}
return {"success": True, "mode": "mixed", "counts": {"entities": 1, "paragraphs": 1, "relations": 1, "sources": 1}}
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
response = client.post(
"/api/webui/memory/delete/preview",
json={
"mode": "mixed",
"selector": {
"entity_hashes": ["entity-1"],
"paragraph_hashes": ["p-1"],
"relation_hashes": ["rel-1"],
"sources": ["demo"],
},
},
)
assert response.status_code == 200
assert response.json()["mode"] == "mixed"
assert response.json()["counts"]["entities"] == 1
def test_delete_execute_route_supports_mixed_mode(client: TestClient, monkeypatch):
async def fake_delete_admin(*, action: str, **kwargs):
assert action == "execute"
assert kwargs["mode"] == "mixed"
assert kwargs["selector"] == {
"entity_hashes": ["entity-1"],
"paragraph_hashes": ["p-1"],
"relation_hashes": ["rel-1"],
"sources": ["demo"],
}
assert kwargs["reason"] == "knowledge_graph_delete_entity"
assert kwargs["requested_by"] == "knowledge_graph"
return {
"success": True,
"mode": "mixed",
"operation_id": "op-mixed-1",
"deleted_count": 4,
"deleted_entity_count": 1,
"deleted_relation_count": 1,
"deleted_paragraph_count": 1,
"deleted_source_count": 1,
"counts": {"entities": 1, "paragraphs": 1, "relations": 1, "sources": 1},
}
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
response = client.post(
"/api/webui/memory/delete/execute",
json={
"mode": "mixed",
"selector": {
"entity_hashes": ["entity-1"],
"paragraph_hashes": ["p-1"],
"relation_hashes": ["rel-1"],
"sources": ["demo"],
},
"reason": "knowledge_graph_delete_entity",
"requested_by": "knowledge_graph",
},
)
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["mode"] == "mixed"
assert response.json()["operation_id"] == "op-mixed-1"
def test_episode_process_pending_route(client: TestClient, monkeypatch):
async def fake_episode_admin(*, action: str, **kwargs):
assert action == "process_pending"
assert kwargs == {"limit": 7, "max_retry": 4}
return {"success": True, "processed": 3}
monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin)
response = client.post("/api/webui/memory/episodes/process-pending", json={"limit": 7, "max_retry": 4})
assert response.status_code == 200
assert response.json() == {"success": True, "processed": 3}
def test_import_list_route_includes_settings(client: TestClient, monkeypatch):
calls = []
async def fake_import_admin(*, action: str, **kwargs):
calls.append((action, kwargs))
if action == "list":
return {"success": True, "items": [{"task_id": "task-1"}]}
if action == "get_settings":
return {"success": True, "settings": {"path_aliases": {"lpmm": "/tmp/lpmm"}}}
raise AssertionError(action)
monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)
response = client.get("/api/webui/memory/import/tasks", params={"limit": 9})
assert response.status_code == 200
assert response.json()["items"] == [{"task_id": "task-1"}]
assert response.json()["settings"] == {"path_aliases": {"lpmm": "/tmp/lpmm"}}
assert calls == [("list", {"limit": 9}), ("get_settings", {})]
def test_tuning_profile_route_backfills_settings(client: TestClient, monkeypatch):
calls = []
async def fake_tuning_admin(*, action: str, **kwargs):
calls.append((action, kwargs))
if action == "get_profile":
return {"success": True, "profile": {"retrieval": {"top_k": 8}}}
if action == "get_settings":
return {"success": True, "settings": {"profiles": ["default"]}}
raise AssertionError(action)
monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin)
response = client.get("/api/webui/memory/retrieval_tuning/profile")
assert response.status_code == 200
assert response.json()["profile"] == {"retrieval": {"top_k": 8}}
assert response.json()["settings"] == {"profiles": ["default"]}
assert calls == [("get_profile", {}), ("get_settings", {})]
def test_tuning_report_route_flattens_report_payload(client: TestClient, monkeypatch):
async def fake_tuning_admin(*, action: str, **kwargs):
assert action == "get_report"
assert kwargs == {"task_id": "task-1", "format": "json"}
return {
"success": True,
"report": {"format": "json", "content": "{\"ok\": true}", "path": "/tmp/report.json"},
}
monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin)
response = client.get("/api/webui/memory/retrieval_tuning/tasks/task-1/report", params={"format": "json"})
assert response.status_code == 200
assert response.json() == {
"success": True,
"format": "json",
"content": "{\"ok\": true}",
"path": "/tmp/report.json",
"error": "",
}
def test_delete_execute_route(client: TestClient, monkeypatch):
async def fake_delete_admin(*, action: str, **kwargs):
assert action == "execute"
assert kwargs["mode"] == "source"
assert kwargs["selector"] == {"source": "chat_summary:stream-1"}
assert kwargs["reason"] == "cleanup"
assert kwargs["requested_by"] == "tester"
return {"success": True, "operation_id": "del-1"}
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
response = client.post(
"/api/webui/memory/delete/execute",
json={
"mode": "source",
"selector": {"source": "chat_summary:stream-1"},
"reason": "cleanup",
"requested_by": "tester",
},
)
assert response.status_code == 200
assert response.json() == {"success": True, "operation_id": "del-1"}
def test_sources_route(client: TestClient, monkeypatch):
async def fake_source_admin(*, action: str, **kwargs):
assert action == "list"
assert kwargs == {}
return {"success": True, "items": [{"source": "demo", "paragraph_count": 2}], "count": 1}
monkeypatch.setattr(memory_router_module.memory_service, "source_admin", fake_source_admin)
response = client.get("/api/webui/memory/sources")
assert response.status_code == 200
assert response.json()["items"] == [{"source": "demo", "paragraph_count": 2}]
def test_delete_operation_routes(client: TestClient, monkeypatch):
async def fake_delete_admin(*, action: str, **kwargs):
if action == "list_operations":
assert kwargs == {"limit": 5, "mode": "paragraph"}
return {"success": True, "items": [{"operation_id": "del-1"}], "count": 1}
if action == "get_operation":
assert kwargs == {"operation_id": "del-1"}
return {"success": True, "operation": {"operation_id": "del-1", "mode": "paragraph"}}
raise AssertionError(action)
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
list_response = client.get("/api/webui/memory/delete/operations", params={"limit": 5, "mode": "paragraph"})
get_response = client.get("/api/webui/memory/delete/operations/del-1")
assert list_response.status_code == 200
assert list_response.json()["count"] == 1
assert get_response.status_code == 200
assert get_response.json()["operation"]["operation_id"] == "del-1"
def test_feedback_correction_routes(client: TestClient, monkeypatch):
async def fake_feedback_admin(*, action: str, **kwargs):
if action == "list":
assert kwargs == {
"limit": 7,
"statuses": ["applied"],
"rollback_statuses": ["none"],
"query": "green",
}
return {"success": True, "items": [{"task_id": 11, "query_text": "what color"}], "count": 1}
if action == "get":
assert kwargs == {"task_id": 11}
return {"success": True, "task": {"task_id": 11, "query_text": "what color", "action_logs": []}}
if action == "rollback":
assert kwargs == {"task_id": 11, "requested_by": "tester", "reason": "manual revert"}
return {"success": True, "result": {"restored_relation_hashes": ["rel-1"]}}
raise AssertionError(action)
monkeypatch.setattr(memory_router_module.memory_service, "feedback_admin", fake_feedback_admin)
list_response = client.get(
"/api/webui/memory/feedback-corrections",
params={"limit": 7, "status": "applied", "rollback_status": "none", "query": "green"},
)
get_response = client.get("/api/webui/memory/feedback-corrections/11")
rollback_response = client.post(
"/api/webui/memory/feedback-corrections/11/rollback",
json={"requested_by": "tester", "reason": "manual revert"},
)
assert list_response.status_code == 200
assert list_response.json()["items"][0]["task_id"] == 11
assert get_response.status_code == 200
assert get_response.json()["task"]["task_id"] == 11
assert rollback_response.status_code == 200
assert rollback_response.json()["result"]["restored_relation_hashes"] == ["rel-1"]

View File

@@ -1,533 +0,0 @@
from __future__ import annotations
from pathlib import Path
from time import monotonic, sleep
from typing import Any, Dict, Generator
from uuid import uuid4
import asyncio
import json
from fastapi import FastAPI
from fastapi.testclient import TestClient
import pytest
import tomlkit
from src.A_memorix import host_service as host_service_module
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from src.A_memorix.core.utils import retrieval_tuning_manager as tuning_manager_module
from src.webui.dependencies import require_auth
from src.webui.routers import memory as memory_router_module
REQUEST_TIMEOUT_SECONDS = 30
IMPORT_TIMEOUT_SECONDS = 120
TUNING_TIMEOUT_SECONDS = 420
IMPORT_TERMINAL_STATUSES = {"completed", "completed_with_errors", "failed", "cancelled"}
TUNING_TERMINAL_STATUSES = {"completed", "failed", "cancelled"}
class _FakeEmbeddingManager:
def __init__(self, dimension: int = 64) -> None:
self.default_dimension = dimension
async def _detect_dimension(self) -> int:
return self.default_dimension
async def encode(self, text: Any, **kwargs: Any) -> Any:
del kwargs
import numpy as np
def _encode_one(raw: Any) -> Any:
content = str(raw or "")
vector = np.zeros(self.default_dimension, dtype=np.float32)
for index, byte in enumerate(content.encode("utf-8")):
vector[index % self.default_dimension] += float((byte % 17) + 1)
norm = float(np.linalg.norm(vector))
if norm > 0:
vector /= norm
return vector
if isinstance(text, (list, tuple)):
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
return _encode_one(text).astype(np.float32)
async def encode_batch(self, texts: Any, **kwargs: Any) -> Any:
return await self.encode(texts, **kwargs)
def _build_test_config(data_dir: Path) -> Dict[str, Any]:
return {
"storage": {
"data_dir": str(data_dir),
},
"advanced": {
"enable_auto_save": False,
},
"embedding": {
"dimension": 64,
"batch_size": 4,
"max_concurrent": 1,
"retry": {
"max_attempts": 1,
"min_wait_seconds": 0.1,
"max_wait_seconds": 0.2,
"backoff_multiplier": 1.0,
},
"fallback": {
"enabled": True,
"allow_metadata_only_write": True,
"probe_interval_seconds": 30,
},
"paragraph_vector_backfill": {
"enabled": False,
"interval_seconds": 60,
"batch_size": 32,
"max_retry": 2,
},
},
"retrieval": {
"enable_parallel": False,
"enable_ppr": False,
"top_k_paragraphs": 20,
"top_k_relations": 10,
"top_k_final": 10,
"alpha": 0.5,
"search": {
"smart_fallback": {
"enabled": True,
},
},
"sparse": {
"enabled": True,
"mode": "auto",
"candidate_k": 80,
"relation_candidate_k": 60,
},
"fusion": {
"method": "weighted_rrf",
"rrf_k": 60,
"vector_weight": 0.7,
"bm25_weight": 0.3,
},
},
"threshold": {
"percentile": 70.0,
"min_results": 1,
},
"web": {
"tuning": {
"enabled": True,
"poll_interval_ms": 300,
"max_queue_size": 4,
"default_objective": "balanced",
"default_intensity": "quick",
"default_sample_size": 4,
"default_top_k_eval": 5,
"eval_query_timeout_seconds": 1.0,
"llm_retry": {
"max_attempts": 1,
"min_wait_seconds": 0.1,
"max_wait_seconds": 0.2,
"backoff_multiplier": 1.0,
},
},
},
}
def _assert_response_ok(response: Any) -> Dict[str, Any]:
assert response.status_code == 200, response.text
payload = response.json()
assert payload.get("success", True) is True, payload
return payload
def _wait_for_import_task_terminal(client: TestClient, task_id: str, *, timeout_seconds: float = IMPORT_TIMEOUT_SECONDS) -> Dict[str, Any]:
deadline = monotonic() + timeout_seconds
last_payload: Dict[str, Any] = {}
while monotonic() < deadline:
response = client.get(
f"/api/webui/memory/import/tasks/{task_id}",
params={"include_chunks": True},
)
payload = _assert_response_ok(response)
last_payload = payload
task = payload.get("task") or {}
status = str(task.get("status", "") or "")
if status in IMPORT_TERMINAL_STATUSES:
return task
sleep(0.2)
raise AssertionError(f"导入任务超时: task_id={task_id}, last_payload={last_payload}")
def _wait_for_tuning_task_terminal(client: TestClient, task_id: str, *, timeout_seconds: float = TUNING_TIMEOUT_SECONDS) -> Dict[str, Any]:
deadline = monotonic() + timeout_seconds
last_payload: Dict[str, Any] = {}
while monotonic() < deadline:
response = client.get(
f"/api/webui/memory/retrieval_tuning/tasks/{task_id}",
params={"include_rounds": False},
)
payload = _assert_response_ok(response)
last_payload = payload
task = payload.get("task") or {}
status = str(task.get("status", "") or "")
if status in TUNING_TERMINAL_STATUSES:
return task
sleep(0.3)
raise AssertionError(f"调优任务超时: task_id={task_id}, last_payload={last_payload}")
def _wait_for_query_hit(client: TestClient, query: str, *, timeout_seconds: float = 30.0) -> Dict[str, Any]:
deadline = monotonic() + timeout_seconds
last_payload: Dict[str, Any] = {}
while monotonic() < deadline:
payload = _assert_response_ok(
client.get(
"/api/webui/memory/query/aggregate",
params={"query": query, "limit": 20},
)
)
last_payload = payload
hits = payload.get("hits") or []
if isinstance(hits, list) and len(hits) > 0:
return payload
sleep(0.2)
raise AssertionError(f"检索命中超时: query={query}, last_payload={last_payload}")
def _get_source_item(client: TestClient, source_name: str) -> Dict[str, Any] | None:
payload = _assert_response_ok(client.get("/api/webui/memory/sources"))
items = payload.get("items") or []
for item in items:
if not isinstance(item, dict):
continue
if str(item.get("source", "") or "") == source_name:
return item
return None
def _source_paragraph_count(item: Dict[str, Any] | None) -> int:
payload = item or {}
if "paragraph_count" in payload:
return int(payload.get("paragraph_count", 0) or 0)
return int(payload.get("count", 0) or 0)
def _wait_for_source_paragraph_count(
client: TestClient,
source_name: str,
*,
min_count: int,
timeout_seconds: float = 30.0,
) -> Dict[str, Any]:
deadline = monotonic() + timeout_seconds
last_item: Dict[str, Any] = {}
while monotonic() < deadline:
item = _get_source_item(client, source_name)
count = _source_paragraph_count(item)
if count >= int(min_count):
return item or {}
if item:
last_item = dict(item)
sleep(0.2)
raise AssertionError(
f"等待来源段落计数超时: source={source_name}, min_count={min_count}, last_item={last_item}"
)
def _create_multitype_upload_task(client: TestClient) -> str:
structured_json = {
"paragraphs": [
{
"content": "Alice 携带地图前往火星港。",
"source": "integration-upload-json",
"entities": ["Alice", "地图", "火星港"],
"relations": [
{"subject": "Alice", "predicate": "携带", "object": "地图"},
{"subject": "Alice", "predicate": "前往", "object": "火星港"},
],
}
]
}
extra_json = {
"paragraphs": [
{
"content": "Carol 记录了一条补充说明。",
"source": "integration-upload-json-extra",
"entities": ["Carol"],
"relations": [],
}
]
}
payload_json = json.dumps(
{
"input_mode": "text",
"llm_enabled": False,
"file_concurrency": 2,
"chunk_concurrency": 2,
"dedupe_policy": "none",
},
ensure_ascii=False,
)
files = [
("files", ("integration-notes.txt", "Alice 在测试环境记录了一条长期记忆。".encode("utf-8"), "text/plain")),
("files", ("integration-diary.md", "# 日志\nBob 与 Alice 讨论了导图。".encode("utf-8"), "text/markdown")),
("files", ("integration-structured.json", json.dumps(structured_json, ensure_ascii=False).encode("utf-8"), "application/json")),
("files", ("integration-extra.json", json.dumps(extra_json, ensure_ascii=False).encode("utf-8"), "application/json")),
]
response = client.post(
"/api/webui/memory/import/upload",
data={"payload_json": payload_json},
files=files,
)
payload = _assert_response_ok(response)
task_id = str((payload.get("task") or {}).get("task_id") or "").strip()
assert task_id, payload
return task_id
def _create_seed_paste_task(client: TestClient, *, source: str, unique_token: str) -> str:
seed_payload = {
"paragraphs": [
{
"content": f"Alice 在火星港携带地图并记录了口令 {unique_token}",
"source": source,
"entities": ["Alice", "火星港", "地图"],
"relations": [
{"subject": "Alice", "predicate": "前往", "object": "火星港"},
{"subject": "Alice", "predicate": "携带", "object": "地图"},
],
},
{
"content": f"Bob 在火星港遇见 Alice并重复口令 {unique_token}",
"source": source,
"entities": ["Bob", "Alice", "火星港"],
"relations": [
{"subject": "Bob", "predicate": "遇见", "object": "Alice"},
{"subject": "Bob", "predicate": "位于", "object": "火星港"},
],
},
]
}
response = client.post(
"/api/webui/memory/import/paste",
json={
"name": "integration-seed.json",
"input_mode": "json",
"llm_enabled": False,
"content": json.dumps(seed_payload, ensure_ascii=False),
"dedupe_policy": "none",
},
)
payload = _assert_response_ok(response)
task_id = str((payload.get("task") or {}).get("task_id") or "").strip()
assert task_id, payload
return task_id
@pytest.fixture(scope="module")
def integration_state(tmp_path_factory: pytest.TempPathFactory) -> Generator[Dict[str, Any], None, None]:
tmp_root = tmp_path_factory.mktemp("memory_routes_integration")
data_dir = (tmp_root / "data").resolve()
staging_dir = (tmp_root / "upload_staging").resolve()
artifacts_dir = (tmp_root / "artifacts").resolve()
config_file = (tmp_root / "config" / "bot_config.toml").resolve()
runtime_config = _build_test_config(data_dir)
patches = pytest.MonkeyPatch()
patches.setattr(host_service_module.a_memorix_host_service, "_read_config", lambda: dict(runtime_config))
patches.setattr(host_service_module.a_memorix_host_service, "get_config_path", lambda: config_file)
patches.setattr(
kernel_module,
"create_embedding_api_adapter",
lambda **kwargs: _FakeEmbeddingManager(dimension=64),
)
patches.setattr(memory_router_module, "STAGING_ROOT", staging_dir)
patches.setattr(tuning_manager_module, "artifacts_root", lambda: artifacts_dir)
asyncio.run(host_service_module.a_memorix_host_service.stop())
host_service_module.a_memorix_host_service._config_cache = None # type: ignore[attr-defined]
app = FastAPI()
app.dependency_overrides[require_auth] = lambda: "ok"
app.include_router(memory_router_module.router, prefix="/api/webui")
app.include_router(memory_router_module.compat_router)
unique_token = f"INTEG_TOKEN_{uuid4().hex[:12]}"
source_name = f"integration-source-{uuid4().hex[:8]}"
with TestClient(app) as client:
upload_task_id = _create_multitype_upload_task(client)
upload_task = _wait_for_import_task_terminal(client, upload_task_id)
seed_task_id = _create_seed_paste_task(client, source=source_name, unique_token=unique_token)
seed_task = _wait_for_import_task_terminal(client, seed_task_id)
assert str(seed_task.get("status", "") or "") in {"completed", "completed_with_errors"}, seed_task
_wait_for_query_hit(client, unique_token, timeout_seconds=45.0)
yield {
"client": client,
"upload_task": upload_task,
"seed_task": seed_task,
"source_name": source_name,
"unique_token": unique_token,
}
asyncio.run(host_service_module.a_memorix_host_service.stop())
host_service_module.a_memorix_host_service._config_cache = None # type: ignore[attr-defined]
patches.undo()
def test_import_module_end_to_end_supports_multitype_upload(integration_state: Dict[str, Any]) -> None:
upload_task = integration_state["upload_task"]
assert str(upload_task.get("status", "") or "") in {"completed", "completed_with_errors"}, upload_task
files = upload_task.get("files") or []
assert isinstance(files, list)
assert len(files) >= 4
file_names = {str(item.get("name", "") or "") for item in files if isinstance(item, dict)}
assert "integration-notes.txt" in file_names
assert "integration-diary.md" in file_names
assert "integration-structured.json" in file_names
assert "integration-extra.json" in file_names
def test_retrieval_module_end_to_end_queries_seeded_data(integration_state: Dict[str, Any]) -> None:
client = integration_state["client"]
unique_token = integration_state["unique_token"]
aggregate_payload = _wait_for_query_hit(client, unique_token, timeout_seconds=45.0)
hits = aggregate_payload.get("hits") or []
joined_content = "\n".join(str(item.get("content", "") or "") for item in hits if isinstance(item, dict))
assert unique_token in joined_content
graph_payload = _assert_response_ok(
client.get(
"/api/webui/memory/graph/search",
params={"query": "Alice", "limit": 20},
)
)
graph_items = graph_payload.get("items") or []
assert isinstance(graph_items, list)
assert any(str(item.get("type", "") or "") == "entity" for item in graph_items if isinstance(item, dict)), graph_items
def test_tuning_module_end_to_end_create_and_apply_best(integration_state: Dict[str, Any]) -> None:
client = integration_state["client"]
create_payload = _assert_response_ok(
client.post(
"/api/webui/memory/retrieval_tuning/tasks",
json={
"objective": "balanced",
"intensity": "quick",
"rounds": 2,
"sample_size": 4,
"top_k_eval": 5,
"llm_enabled": False,
"eval_query_timeout_seconds": 1.0,
"seed": 20260403,
},
)
)
task_id = str((create_payload.get("task") or {}).get("task_id") or "").strip()
assert task_id, create_payload
task = _wait_for_tuning_task_terminal(client, task_id)
assert str(task.get("status", "") or "") == "completed", task
apply_payload = _assert_response_ok(
client.post(
f"/api/webui/memory/retrieval_tuning/tasks/{task_id}/apply-best",
)
)
assert "applied" in apply_payload
def test_delete_module_end_to_end_preview_execute_restore(integration_state: Dict[str, Any]) -> None:
client = integration_state["client"]
unique_token = integration_state["unique_token"]
source_name = integration_state["source_name"]
before_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0)
assert _source_paragraph_count(before_source_item) >= 1
preview_payload = _assert_response_ok(
client.post(
"/api/webui/memory/delete/preview",
json={
"mode": "source",
"selector": {"sources": [source_name]},
"reason": "integration_delete_preview",
"requested_by": "pytest_integration",
},
)
)
preview_counts = preview_payload.get("counts") or {}
assert int(preview_counts.get("paragraphs", 0) or 0) >= 1, preview_payload
execute_payload = _assert_response_ok(
client.post(
"/api/webui/memory/delete/execute",
json={
"mode": "source",
"selector": {"sources": [source_name]},
"reason": "integration_delete_execute",
"requested_by": "pytest_integration",
},
)
)
operation_id = str(execute_payload.get("operation_id", "") or "").strip()
assert operation_id, execute_payload
after_delete_payload = _assert_response_ok(
client.get(
"/api/webui/memory/query/aggregate",
params={"query": unique_token, "limit": 20},
)
)
after_delete_hits = after_delete_payload.get("hits") or []
after_delete_text = "\n".join(
str(item.get("content", "") or "")
for item in after_delete_hits
if isinstance(item, dict)
)
assert unique_token not in after_delete_text
assert int(execute_payload.get("deleted_paragraph_count", 0) or 0) >= 1, execute_payload
_assert_response_ok(
client.post(
"/api/webui/memory/delete/restore",
json={
"operation_id": operation_id,
"requested_by": "pytest_integration",
},
)
)
restored_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0)
assert _source_paragraph_count(restored_source_item) >= 1
operations_payload = _assert_response_ok(
client.get(
"/api/webui/memory/delete/operations",
params={"limit": 20, "mode": "source"},
)
)
operation_items = operations_payload.get("items") or []
operation_ids = {
str(item.get("operation_id", "") or "")
for item in operation_items
if isinstance(item, dict)
}
assert operation_id in operation_ids
operation_detail_payload = _assert_response_ok(client.get(f"/api/webui/memory/delete/operations/{operation_id}"))
detail_operation = operation_detail_payload.get("operation") or {}
assert str(detail_operation.get("status", "") or "") == "restored"

View File

@@ -1,187 +0,0 @@
"""模型路由测试
验证 Gemini 提供商连接测试会使用查询参数传递 API Key
并且不会回退到 OpenAI 兼容接口使用的 Bearer 认证方式。
"""
import importlib
import sys
from types import ModuleType
from typing import Any
import pytest
def load_model_routes(monkeypatch: pytest.MonkeyPatch):
"""在导入路由前 stub 配置与认证依赖模块,避免测试时触发真实初始化。"""
config_module = ModuleType("src.config.config")
config_module.__dict__["CONFIG_DIR"] = "."
monkeypatch.setitem(sys.modules, "src.config.config", config_module)
dependencies_module = ModuleType("src.webui.dependencies")
async def require_auth():
return "test-token"
dependencies_module.__dict__["require_auth"] = require_auth
monkeypatch.setitem(sys.modules, "src.webui.dependencies", dependencies_module)
sys.modules.pop("src.webui.routers.model", None)
return importlib.import_module("src.webui.routers.model")
class FakeResponse:
"""简化版 HTTP 响应对象。"""
def __init__(self, status_code: int):
self.status_code = status_code
def build_async_client_factory(
responses: list[FakeResponse],
calls: list[dict[str, Any]],
):
"""构造一个可记录请求参数的 AsyncClient 替身。"""
response_iter = iter(responses)
class FakeAsyncClient:
def __init__(self, *args: Any, **kwargs: Any):
self.args = args
self.kwargs = kwargs
async def __aenter__(self) -> "FakeAsyncClient":
return self
async def __aexit__(self, exc_type, exc, tb) -> bool:
return False
async def get(
self,
url: str,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
) -> FakeResponse:
calls.append(
{
"url": url,
"headers": headers or {},
"params": params or {},
}
)
return next(response_iter)
return FakeAsyncClient
@pytest.mark.asyncio
async def test_test_provider_connection_uses_query_api_key_for_gemini(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Gemini 连接测试应通过查询参数传递 API Key。"""
model_routes = load_model_routes(monkeypatch)
calls: list[dict[str, Any]] = []
fake_client_class = build_async_client_factory(
responses=[FakeResponse(200), FakeResponse(200)],
calls=calls,
)
monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class)
result = await model_routes.test_provider_connection(
base_url="https://generativelanguage.googleapis.com/v1beta",
api_key="valid-gemini-key",
client_type="gemini",
)
assert result["network_ok"] is True
assert result["api_key_valid"] is True
assert len(calls) == 2
network_call = calls[0]
validation_call = calls[1]
assert network_call["url"] == "https://generativelanguage.googleapis.com/v1beta"
assert network_call["headers"] == {}
assert network_call["params"] == {}
assert validation_call["url"] == "https://generativelanguage.googleapis.com/v1beta/models"
assert validation_call["params"] == {"key": "valid-gemini-key"}
assert validation_call["headers"] == {"Content-Type": "application/json"}
assert "Authorization" not in validation_call["headers"]
@pytest.mark.asyncio
async def test_test_provider_connection_uses_bearer_auth_for_openai_compatible(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""非 Gemini 提供商连接测试应继续使用 Bearer 认证。"""
model_routes = load_model_routes(monkeypatch)
calls: list[dict[str, Any]] = []
fake_client_class = build_async_client_factory(
responses=[FakeResponse(200), FakeResponse(200)],
calls=calls,
)
monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class)
result = await model_routes.test_provider_connection(
base_url="https://example.com/v1",
api_key="valid-openai-key",
client_type="openai",
)
assert result["network_ok"] is True
assert result["api_key_valid"] is True
assert len(calls) == 2
validation_call = calls[1]
assert validation_call["url"] == "https://example.com/v1/models"
assert validation_call["params"] == {}
assert validation_call["headers"]["Content-Type"] == "application/json"
assert validation_call["headers"]["Authorization"] == "Bearer valid-openai-key"
@pytest.mark.asyncio
async def test_test_provider_connection_by_name_forwards_provider_client_type(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
) -> None:
"""按提供商名称测试连接时,应透传配置中的 client_type。"""
model_routes = load_model_routes(monkeypatch)
config_path = tmp_path / "model_config.toml"
config_path.write_text(
"""
[[api_providers]]
name = "Gemini"
base_url = "https://generativelanguage.googleapis.com/v1beta"
api_key = "valid-gemini-key"
client_type = "gemini"
""".strip(),
encoding="utf-8",
)
monkeypatch.setattr(model_routes, "CONFIG_DIR", str(tmp_path))
captured_kwargs: dict[str, Any] = {}
async def fake_test_provider_connection(**kwargs: Any) -> dict[str, Any]:
captured_kwargs.update(kwargs)
return {
"network_ok": True,
"api_key_valid": True,
"latency_ms": 12.34,
"error": None,
"http_status": 200,
}
monkeypatch.setattr(model_routes, "test_provider_connection", fake_test_provider_connection)
result = await model_routes.test_provider_connection_by_name(provider_name="Gemini")
assert result["network_ok"] is True
assert result["api_key_valid"] is True
assert captured_kwargs == {
"base_url": "https://generativelanguage.googleapis.com/v1beta",
"api_key": "valid-gemini-key",
"client_type": "gemini",
}

View File

@@ -1,136 +0,0 @@
import json
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.webui.routers.plugin import management as management_module
from src.webui.routers.plugin import support as support_module
@pytest.fixture
def client(tmp_path, monkeypatch) -> TestClient:
plugins_dir = tmp_path / "plugins"
plugins_dir.mkdir(parents=True, exist_ok=True)
demo_dir = plugins_dir / "demo_plugin"
demo_dir.mkdir()
(demo_dir / "_manifest.json").write_text(
json.dumps(
{
"manifest_version": 2,
"id": "test.demo",
"name": "Demo Plugin",
"version": "1.0.0",
"description": "demo plugin",
}
),
encoding="utf-8",
)
monkeypatch.setattr(management_module, "require_plugin_token", lambda _: "ok")
monkeypatch.setattr(support_module, "get_plugins_dir", lambda: plugins_dir)
app = FastAPI()
app.include_router(management_module.router, prefix="/api/webui/plugins")
return TestClient(app)
def test_installed_plugins_only_scan_plugins_dir_and_exclude_a_memorix(client: TestClient):
response = client.get("/api/webui/plugins/installed")
assert response.status_code == 200
payload = response.json()
assert payload["success"] is True
ids = [plugin["id"] for plugin in payload["plugins"]]
assert ids == ["test.demo"]
assert "a-dawn.a-memorix" not in ids
assert all("/src/plugins/built_in/" not in plugin["path"] for plugin in payload["plugins"])
def test_resolve_installed_plugin_path_falls_back_to_manifest_id(client: TestClient):
plugin_path = support_module.resolve_installed_plugin_path("test.demo")
assert plugin_path is not None
assert plugin_path.name == "demo_plugin"
def test_resolve_installed_plugin_path_accepts_manifest_id_case_mismatch(client: TestClient):
plugin_path = support_module.resolve_installed_plugin_path("Test.Demo")
assert plugin_path is not None
assert plugin_path.name == "demo_plugin"
def test_install_plugin_preserves_manifest_declared_id(client: TestClient, monkeypatch):
class FakeGitMirrorService:
async def clone_repository(self, **kwargs):
target_path = kwargs["target_path"]
target_path.mkdir(parents=True, exist_ok=True)
(target_path / "_manifest.json").write_text(
json.dumps(
{
"manifest_version": 2,
"id": "author.declared",
"name": "Declared Plugin",
"version": "1.0.0",
"author": {"name": "author"},
}
),
encoding="utf-8",
)
return {"success": True}
monkeypatch.setattr(management_module, "get_git_mirror_service", lambda: FakeGitMirrorService())
response = client.post(
"/api/webui/plugins/install",
json={
"plugin_id": "market.plugin",
"repository_url": "https://github.com/author/declared",
"branch": "main",
},
)
assert response.status_code == 200
plugin_path = support_module.resolve_installed_plugin_path("author.declared")
assert plugin_path is not None
manifest = json.loads((plugin_path / "_manifest.json").read_text(encoding="utf-8"))
assert manifest["id"] == "author.declared"
def test_install_plugin_backfills_missing_manifest_id(client: TestClient, monkeypatch):
class FakeGitMirrorService:
async def clone_repository(self, **kwargs):
target_path = kwargs["target_path"]
target_path.mkdir(parents=True, exist_ok=True)
(target_path / "_manifest.json").write_text(
json.dumps(
{
"manifest_version": 2,
"name": "Legacy Plugin",
"version": "1.0.0",
"author": {"name": "author"},
}
),
encoding="utf-8",
)
return {"success": True}
monkeypatch.setattr(management_module, "get_git_mirror_service", lambda: FakeGitMirrorService())
response = client.post(
"/api/webui/plugins/install",
json={
"plugin_id": "market.legacy",
"repository_url": "https://github.com/author/legacy",
"branch": "main",
},
)
assert response.status_code == 200
plugin_path = support_module.resolve_installed_plugin_path("market.legacy")
assert plugin_path is not None
manifest = json.loads((plugin_path / "_manifest.json").read_text(encoding="utf-8"))
assert manifest["id"] == "market.legacy"

View File

@@ -1,332 +0,0 @@
from contextlib import contextmanager
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import Any, Iterator
import pytest
from src.services import statistics_service
from src.webui.schemas.statistics import DashboardData, StatisticsSummary, TimeSeriesData
class _Result:
def __init__(self, *, first_value: Any = None, all_values: list[Any] | None = None) -> None:
self._first_value = first_value
self._all_values = all_values or []
def first(self) -> Any:
return self._first_value
def all(self) -> list[Any]:
return self._all_values
class _Session:
def __init__(self, results: list[_Result]) -> None:
self._results = results
def exec(self, statement: Any) -> _Result:
del statement
return self._results.pop(0)
class _MemoryStore:
def __init__(self) -> None:
self.store: dict[str, Any] = {}
def __getitem__(self, item: str) -> Any:
return self.store.get(item)
def __setitem__(self, key: str, value: Any) -> None:
self.store[key] = value
def _patch_session_results(monkeypatch: pytest.MonkeyPatch, results: list[_Result]) -> list[bool]:
auto_commit_calls: list[bool] = []
@contextmanager
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_Session]:
auto_commit_calls.append(auto_commit)
yield _Session([results.pop(0)])
monkeypatch.setattr(statistics_service, "get_db_session", _fake_get_db_session)
return auto_commit_calls
def _patch_session_result_group(monkeypatch: pytest.MonkeyPatch, results: list[_Result]) -> list[bool]:
auto_commit_calls: list[bool] = []
@contextmanager
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_Session]:
auto_commit_calls.append(auto_commit)
yield _Session(results)
monkeypatch.setattr(statistics_service, "get_db_session", _fake_get_db_session)
return auto_commit_calls
def _build_dashboard_data(total_requests: int = 1) -> DashboardData:
return DashboardData(
summary=StatisticsSummary(total_requests=total_requests),
model_stats=[],
hourly_data=[],
daily_data=[],
recent_activity=[],
)
def _build_dashboard_data_with_time_series() -> DashboardData:
return DashboardData(
summary=StatisticsSummary(total_requests=1),
model_stats=[],
hourly_data=[
TimeSeriesData(timestamp="2026-05-06T10:00:00", requests=0, cost=0.0, tokens=0),
TimeSeriesData(timestamp="2026-05-06T11:00:00", requests=2, cost=0.5, tokens=50),
TimeSeriesData(timestamp="2026-05-06T12:00:00", requests=0, cost=0.0, tokens=0),
],
daily_data=[
TimeSeriesData(timestamp="2026-05-05T00:00:00", requests=0, cost=0.0, tokens=0),
TimeSeriesData(timestamp="2026-05-06T00:00:00", requests=3, cost=0.7, tokens=70),
],
recent_activity=[],
)
def test_shared_fetch_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None:
now = datetime(2026, 5, 6, 12, 0, 0)
online_record = SimpleNamespace(start_timestamp=now - timedelta(minutes=5), end_timestamp=now)
usage_record = SimpleNamespace(
timestamp=now,
request_type="chat.reply",
model_api_provider_name="provider",
model_assign_name="chat-main",
model_name="gpt-a",
prompt_tokens=10,
completion_tokens=5,
cost=0.01,
time_cost=1.2,
)
message_record = SimpleNamespace(timestamp=now, message_id="msg-1")
tool_record = SimpleNamespace(timestamp=now, tool_name="reply")
auto_commit_calls = _patch_session_results(
monkeypatch,
[
_Result(all_values=[online_record]),
_Result(all_values=[usage_record]),
_Result(all_values=[message_record]),
_Result(all_values=[tool_record]),
],
)
online_ranges = statistics_service.fetch_online_time_since(now - timedelta(hours=1))
usage_records = statistics_service.fetch_model_usage_since(now - timedelta(hours=1))
messages = statistics_service.fetch_messages_since(now - timedelta(hours=1))
tool_records = statistics_service.fetch_tool_records_since(now - timedelta(hours=1))
assert online_ranges == [(online_record.start_timestamp, online_record.end_timestamp)]
assert usage_records == [
{
"timestamp": now,
"request_type": "chat.reply",
"model_api_provider_name": "provider",
"model_assign_name": "chat-main",
"model_name": "gpt-a",
"prompt_tokens": 10,
"completion_tokens": 5,
"cost": 0.01,
"time_cost": 1.2,
}
]
assert messages == [message_record]
assert tool_records == [tool_record]
assert auto_commit_calls == [False, False, False, False]
def test_get_earliest_statistics_time_uses_min_valid_timestamp(monkeypatch: pytest.MonkeyPatch) -> None:
fallback_time = datetime(2026, 5, 6, 12, 0, 0)
earliest_time = datetime(2026, 5, 1, 8, 30, 0)
auto_commit_calls = _patch_session_result_group(
monkeypatch,
[
_Result(first_value=datetime(2026, 5, 3, 9, 0, 0)),
_Result(first_value=earliest_time),
_Result(first_value=None),
_Result(first_value=datetime(2026, 5, 2, 9, 0, 0)),
],
)
result = statistics_service.get_earliest_statistics_time(fallback_time)
assert result == earliest_time
assert auto_commit_calls == [False]
def test_get_earliest_statistics_time_falls_back_when_query_fails(monkeypatch: pytest.MonkeyPatch) -> None:
fallback_time = datetime(2026, 5, 6, 12, 0, 0)
@contextmanager
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_Session]:
del auto_commit
raise RuntimeError("database unavailable")
yield _Session([])
monkeypatch.setattr(statistics_service, "get_db_session", _fake_get_db_session)
assert statistics_service.get_earliest_statistics_time(fallback_time) == fallback_time
def test_dashboard_statistics_cache_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None:
memory_store = _MemoryStore()
now = datetime.now()
dashboard_data = _build_dashboard_data(total_requests=7)
monkeypatch.setattr(statistics_service, "local_storage", memory_store)
statistics_service.store_dashboard_statistics_cache({24: dashboard_data}, generated_at=now)
cached_data = statistics_service.get_cached_dashboard_statistics(24)
assert cached_data is not None
assert cached_data.summary.total_requests == 7
def test_dashboard_statistics_cache_stores_sparse_time_series(monkeypatch: pytest.MonkeyPatch) -> None:
memory_store = _MemoryStore()
generated_at = datetime(2026, 5, 6, 12, 0, 0)
dashboard_data = _build_dashboard_data_with_time_series()
monkeypatch.setattr(statistics_service, "local_storage", memory_store)
statistics_service.store_dashboard_statistics_cache({2: dashboard_data}, generated_at=generated_at)
raw_cache = memory_store[statistics_service.DASHBOARD_STATISTICS_CACHE_KEY]
raw_entry = raw_cache["entries"]["2"]
assert raw_entry["sparse"] is True
assert raw_entry["hourly_data"] == [
{"timestamp": "2026-05-06T11:00:00", "requests": 2, "cost": 0.5, "tokens": 50}
]
assert raw_entry["daily_data"] == [
{"timestamp": "2026-05-06T00:00:00", "requests": 3, "cost": 0.7, "tokens": 70}
]
cached_data = statistics_service.get_cached_dashboard_statistics(2, max_age_seconds=10**9)
assert cached_data is not None
assert [item.timestamp for item in cached_data.hourly_data] == [
"2026-05-06T10:00:00",
"2026-05-06T11:00:00",
"2026-05-06T12:00:00",
]
assert cached_data.hourly_data[0].requests == 0
assert cached_data.hourly_data[1].requests == 2
assert cached_data.hourly_data[2].requests == 0
@pytest.mark.asyncio
async def test_get_dashboard_statistics_prefers_cache(monkeypatch: pytest.MonkeyPatch) -> None:
memory_store = _MemoryStore()
dashboard_data = _build_dashboard_data(total_requests=9)
monkeypatch.setattr(statistics_service, "local_storage", memory_store)
statistics_service.store_dashboard_statistics_cache({24: dashboard_data}, generated_at=datetime.now())
async def _fail_compute_dashboard_statistics(hours: int = 24) -> DashboardData:
del hours
raise AssertionError("cache should be used")
monkeypatch.setattr(statistics_service, "compute_dashboard_statistics", _fail_compute_dashboard_statistics)
result = await statistics_service.get_dashboard_statistics(24)
assert result.summary.total_requests == 9
@pytest.mark.asyncio
async def test_get_dashboard_statistics_returns_empty_when_cache_missing(monkeypatch: pytest.MonkeyPatch) -> None:
memory_store = _MemoryStore()
monkeypatch.setattr(statistics_service, "local_storage", memory_store)
async def _fail_compute_dashboard_statistics(hours: int = 24) -> DashboardData:
del hours
raise AssertionError("dashboard API should not compute fallback data")
monkeypatch.setattr(statistics_service, "compute_dashboard_statistics", _fail_compute_dashboard_statistics)
result = await statistics_service.get_dashboard_statistics(24)
assert result.summary.total_requests == 0
assert result.model_stats == []
@pytest.mark.asyncio
async def test_get_summary_statistics_aggregates_database_and_message_counts(monkeypatch: pytest.MonkeyPatch) -> None:
start_time = datetime(2026, 5, 6, 10, 0, 0)
end_time = datetime(2026, 5, 6, 12, 0, 0)
online_records = [
SimpleNamespace(
start_timestamp=start_time - timedelta(minutes=30),
end_timestamp=start_time + timedelta(minutes=30),
),
SimpleNamespace(
start_timestamp=start_time + timedelta(hours=1),
end_timestamp=end_time + timedelta(minutes=30),
),
]
auto_commit_calls = _patch_session_results(
monkeypatch,
[
_Result(first_value=(3, 1.5, 900, 2.5)),
_Result(all_values=online_records),
],
)
def _fake_count_messages(**kwargs: Any) -> int:
return 5 if kwargs.get("has_reply_to") is None else 2
monkeypatch.setattr(statistics_service, "count_messages", _fake_count_messages)
summary = await statistics_service.get_summary_statistics(start_time, end_time)
assert summary.total_requests == 3
assert summary.total_cost == 1.5
assert summary.total_tokens == 900
assert summary.avg_response_time == 2.5
assert summary.online_time == 5400
assert summary.total_messages == 5
assert summary.total_replies == 2
assert summary.cost_per_hour == pytest.approx(1.0)
assert summary.tokens_per_hour == pytest.approx(600.0)
assert auto_commit_calls == [False, False]
@pytest.mark.asyncio
async def test_get_model_statistics_groups_by_display_model_name(monkeypatch: pytest.MonkeyPatch) -> None:
now = datetime(2026, 5, 6, 12, 0, 0)
records = [
SimpleNamespace(
model_assign_name="chat-main",
model_name="gpt-a",
cost=0.4,
total_tokens=100,
time_cost=2.0,
),
SimpleNamespace(
model_assign_name="chat-main",
model_name="gpt-a",
cost=0.6,
total_tokens=200,
time_cost=4.0,
),
SimpleNamespace(
model_assign_name=None,
model_name="gpt-b",
cost=0.2,
total_tokens=50,
time_cost=0.0,
),
]
_patch_session_results(monkeypatch, [_Result(all_values=records)])
stats = await statistics_service.get_model_statistics(now - timedelta(hours=24))
assert [item.model_name for item in stats] == ["chat-main", "gpt-b"]
assert stats[0].request_count == 2
assert stats[0].total_cost == pytest.approx(1.0)
assert stats[0].total_tokens == 300
assert stats[0].avg_response_time == pytest.approx(3.0)
assert stats[1].avg_response_time == 0.0

View File

@@ -1,13 +0,0 @@
from src.webui.routers import system
def test_is_newer_version_detects_patch_update() -> None:
assert system._is_newer_version("1.0.7", "1.0.6") is True
def test_is_newer_version_ignores_same_version_with_shorter_parts() -> None:
assert system._is_newer_version("1.0.0", "1.0") is False
def test_is_newer_version_handles_unknown_current_version() -> None:
assert system._is_newer_version("1.0.7", "unknown") is False

View File

@@ -1,459 +0,0 @@
import argparse
import asyncio
import os
import sys
import time
import json
import importlib
from dataclasses import dataclass
from typing import Optional, Dict, Any
from datetime import datetime
# 强制使用 utf-8避免控制台编码报错
try:
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8")
if hasattr(sys.stderr, "reconfigure"):
sys.stderr.reconfigure(encoding="utf-8")
except Exception:
pass
# 确保能导入 src.*
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import initialize_logging, get_logger
from src.common.database.database import db
from src.common.database.database_model import LLMUsage
from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo
try:
from maim_message import ChatStream, UserInfo, GroupInfo
except Exception:
@dataclass
class ChatStream:
stream_id: str
platform: str
user_info: UserInfo
group_info: GroupInfo
logger = get_logger("test_memory_retrieval")
# 使用 importlib 动态导入,避免循环导入问题
def _import_memory_retrieval():
"""使用 importlib 动态导入 memory_retrieval 模块,避免循环导入"""
try:
# 先导入 prompt_builder检查 prompt 是否已经初始化
from src.chat.utils.prompt_builder import global_prompt_manager
# 检查 memory_retrieval 相关的 prompt 是否已经注册
# 如果已经注册,说明模块可能已经通过其他路径初始化过了
prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts
module_name = "src.memory_system.memory_retrieval"
# 如果 prompt 已经初始化,尝试直接使用已加载的模块
if prompt_already_init and module_name in sys.modules:
existing_module = sys.modules[module_name]
if hasattr(existing_module, "init_memory_retrieval_prompt"):
return (
existing_module.init_memory_retrieval_prompt,
existing_module._react_agent_solve_question,
existing_module._process_single_question,
)
# 如果模块已经在 sys.modules 中但部分初始化,先移除它
if module_name in sys.modules:
existing_module = sys.modules[module_name]
if not hasattr(existing_module, "init_memory_retrieval_prompt"):
# 模块部分初始化,移除它
logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
del sys.modules[module_name]
# 清理可能相关的部分初始化模块
keys_to_remove = []
for key in sys.modules.keys():
if key.startswith("src.memory_system.") and key != "src.memory_system":
keys_to_remove.append(key)
for key in keys_to_remove:
try:
del sys.modules[key]
except KeyError:
pass
# 在导入 memory_retrieval 之前,先确保所有可能触发循环导入的模块都已完全加载
# 这些模块在导入时可能会触发 memory_retrieval 的导入,所以我们需要先加载它们
try:
# 先导入可能触发循环导入的模块,让它们完成初始化
import src.config.config
import src.chat.utils.prompt_builder
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval
# 如果它们已经导入,就确保它们完全初始化
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval
# 如果它们已经导入,就确保它们完全初始化
try:
import src.chat.replyer.group_generator # noqa: F401
except (ImportError, AttributeError):
pass # 如果导入失败,继续
try:
import src.chat.replyer.private_generator # noqa: F401
except (ImportError, AttributeError):
pass # 如果导入失败,继续
except Exception as e:
logger.warning(f"预加载依赖模块时出现警告: {e}")
# 现在尝试导入 memory_retrieval
# 如果此时仍然触发循环导入,说明有其他模块在模块级别导入了 memory_retrieval
memory_retrieval_module = importlib.import_module(module_name)
return (
memory_retrieval_module.init_memory_retrieval_prompt,
memory_retrieval_module._react_agent_solve_question,
memory_retrieval_module._process_single_question,
)
except (ImportError, AttributeError) as e:
logger.error(f"导入 memory_retrieval 模块失败: {e}", exc_info=True)
raise
def create_test_chat_stream(chat_id: str = "test_memory_retrieval") -> ChatStream:
"""创建一个测试用的 ChatStream 对象"""
user_info = UserInfo(
platform="test",
user_id="test_user",
user_nickname="测试用户",
)
group_info = GroupInfo(
platform="test",
group_id="test_group",
group_name="测试群组",
)
return ChatStream(
stream_id=chat_id,
platform="test",
user_info=user_info,
group_info=group_info,
)
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
"""获取从指定时间开始的token使用情况
Args:
start_time: 开始时间戳
Returns:
包含token使用统计的字典
"""
try:
start_datetime = datetime.fromtimestamp(start_time)
# 查询从开始时间到现在的所有memory相关的token使用记录
records = (
LLMUsage.select()
.where(
(LLMUsage.timestamp >= start_datetime)
& (
(LLMUsage.request_type.like("%memory%"))
| (LLMUsage.request_type == "memory.question")
| (LLMUsage.request_type == "memory.react")
| (LLMUsage.request_type == "memory.react.final")
)
)
.order_by(LLMUsage.timestamp.asc())
)
total_prompt_tokens = 0
total_completion_tokens = 0
total_tokens = 0
total_cost = 0.0
request_count = 0
model_usage = {} # 按模型统计
for record in records:
total_prompt_tokens += record.prompt_tokens or 0
total_completion_tokens += record.completion_tokens or 0
total_tokens += record.total_tokens or 0
total_cost += record.cost or 0.0
request_count += 1
# 按模型统计
model_name = record.model_name or "unknown"
if model_name not in model_usage:
model_usage[model_name] = {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"cost": 0.0,
"request_count": 0,
}
model_usage[model_name]["prompt_tokens"] += record.prompt_tokens or 0
model_usage[model_name]["completion_tokens"] += record.completion_tokens or 0
model_usage[model_name]["total_tokens"] += record.total_tokens or 0
model_usage[model_name]["cost"] += record.cost or 0.0
model_usage[model_name]["request_count"] += 1
return {
"total_prompt_tokens": total_prompt_tokens,
"total_completion_tokens": total_completion_tokens,
"total_tokens": total_tokens,
"total_cost": total_cost,
"request_count": request_count,
"model_usage": model_usage,
}
except Exception as e:
logger.error(f"获取token使用情况失败: {e}")
return {
"total_prompt_tokens": 0,
"total_completion_tokens": 0,
"total_tokens": 0,
"total_cost": 0.0,
"request_count": 0,
"model_usage": {},
}
def format_thinking_steps(thinking_steps: list) -> str:
"""格式化思考步骤为可读字符串"""
if not thinking_steps:
return "无思考步骤"
lines = []
for step in thinking_steps:
iteration = step.get("iteration", "?")
thought = step.get("thought", "")
actions = step.get("actions", [])
observations = step.get("observations", [])
lines.append(f"\n--- 迭代 {iteration} ---")
if thought:
lines.append(f"思考: {thought[:200]}...")
if actions:
lines.append("行动:")
for action in actions:
action_type = action.get("action_type", "unknown")
action_params = action.get("action_params", {})
lines.append(f" - {action_type}: {json.dumps(action_params, ensure_ascii=False)}")
if observations:
lines.append("观察:")
for obs in observations:
obs_str = str(obs)[:200]
if len(str(obs)) > 200:
obs_str += "..."
lines.append(f" - {obs_str}")
return "\n".join(lines)
async def test_memory_retrieval(
question: str,
chat_id: str = "test_memory_retrieval",
context: str = "",
max_iterations: Optional[int] = None,
) -> Dict[str, Any]:
"""测试记忆检索功能
Args:
question: 要查询的问题
chat_id: 聊天ID
context: 上下文信息
max_iterations: 最大迭代次数
Returns:
包含测试结果的字典
"""
print("\n" + "=" * 80)
print("[测试] 记忆检索测试")
print(f"[问题] {question}")
print("=" * 80)
# 记录开始时间
start_time = time.time()
# 延迟导入并初始化记忆检索prompt这会自动加载 global_config
# 注意:必须在函数内部调用,避免在模块级别触发循环导入
try:
init_memory_retrieval_prompt, _react_agent_solve_question, _ = _import_memory_retrieval()
# 检查 prompt 是否已经初始化,避免重复初始化
from src.chat.utils.prompt_builder import global_prompt_manager
if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
init_memory_retrieval_prompt()
else:
logger.debug("记忆检索 prompt 已经初始化,跳过重复初始化")
except Exception as e:
logger.error(f"初始化记忆检索模块失败: {e}", exc_info=True)
raise
# 获取 global_config此时应该已经加载
from src.config.config import global_config
# 直接调用 _react_agent_solve_question 来获取详细的迭代信息
if max_iterations is None:
max_iterations = global_config.memory.max_agent_iterations
timeout = global_config.memory.agent_timeout_seconds
print("\n[配置]")
print(f" 最大迭代次数: {max_iterations}")
print(f" 超时时间: {timeout}")
print(f" 聊天ID: {chat_id}")
# 执行检索
print(f"\n[开始检索] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
question=question,
chat_id=chat_id,
max_iterations=max_iterations,
timeout=timeout,
initial_info="",
)
# 记录结束时间
end_time = time.time()
elapsed_time = end_time - start_time
# 获取token使用情况
token_usage = get_token_usage_since(start_time)
# 构建结果
result = {
"question": question,
"found_answer": found_answer,
"answer": answer,
"is_timeout": is_timeout,
"elapsed_time": elapsed_time,
"thinking_steps": thinking_steps,
"iteration_count": len(thinking_steps),
"token_usage": token_usage,
}
# 输出结果
print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
print("\n[结果]")
print(f" 是否找到答案: {'' if found_answer else ''}")
if found_answer and answer:
print(f" 答案: {answer}")
else:
print(" 答案: (未找到答案)")
print(f" 是否超时: {'' if is_timeout else ''}")
print(f" 迭代次数: {len(thinking_steps)}")
print(f" 总耗时: {elapsed_time:.2f}")
print("\n[Token使用情况]")
print(f" 总请求数: {token_usage['request_count']}")
print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}")
print(f" 总Tokens: {token_usage['total_tokens']:,}")
print(f" 总成本: ${token_usage['total_cost']:.6f}")
if token_usage["model_usage"]:
print("\n[按模型统计]")
for model_name, usage in token_usage["model_usage"].items():
print(f" {model_name}:")
print(f" 请求数: {usage['request_count']}")
print(f" Prompt Tokens: {usage['prompt_tokens']:,}")
print(f" Completion Tokens: {usage['completion_tokens']:,}")
print(f" 总Tokens: {usage['total_tokens']:,}")
print(f" 成本: ${usage['cost']:.6f}")
print("\n[迭代详情]")
print(format_thinking_steps(thinking_steps))
print("\n" + "=" * 80)
return result
def main() -> None:
parser = argparse.ArgumentParser(
description="测试记忆检索功能。可以输入一个问题脚本会使用记忆检索的逻辑进行检索并记录迭代信息、时间和token总消耗。"
)
parser.add_argument(
"--chat-id",
default="test_memory_retrieval",
help="测试用的聊天ID默认: test_memory_retrieval",
)
parser.add_argument(
"--context",
default="",
help="上下文信息(可选)",
)
parser.add_argument(
"--output",
"-o",
help="将结果保存到JSON文件可选",
)
args = parser.parse_args()
# 初始化日志(使用较低的详细程度,避免输出过多日志)
initialize_logging(verbose=False)
# 交互式输入问题
print("\n" + "=" * 80)
print("记忆检索测试工具")
print("=" * 80)
question = input("\n请输入要查询的问题: ").strip()
if not question:
print("错误: 问题不能为空")
return
# 交互式输入最大迭代次数
max_iterations_input = input("\n请输入最大迭代次数(直接回车使用配置默认值): ").strip()
max_iterations = None
if max_iterations_input:
try:
max_iterations = int(max_iterations_input)
if max_iterations <= 0:
print("警告: 迭代次数必须大于0将使用配置默认值")
max_iterations = None
except ValueError:
print("警告: 无效的迭代次数,将使用配置默认值")
max_iterations = None
# 连接数据库
try:
db.connect(reuse_if_open=True)
except Exception as e:
logger.error(f"数据库连接失败: {e}")
print(f"错误: 数据库连接失败: {e}")
return
# 运行测试
try:
result = asyncio.run(
test_memory_retrieval(
question=question,
chat_id=args.chat_id,
context=args.context,
max_iterations=max_iterations,
)
)
# 如果指定了输出文件,保存结果
if args.output:
# 将thinking_steps转换为可序列化的格式
output_result = result.copy()
with open(args.output, "w", encoding="utf-8") as f:
json.dump(output_result, f, ensure_ascii=False, indent=2)
print(f"\n[结果已保存] {args.output}")
except KeyboardInterrupt:
print("\n\n[中断] 用户中断测试")
except Exception as e:
logger.error(f"测试失败: {e}", exc_info=True)
print(f"\n[错误] 测试失败: {e}")
finally:
try:
db.close()
except Exception:
pass
if __name__ == "__main__":
main()

View File

@@ -1,845 +0,0 @@
from argparse import ArgumentParser, Namespace
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, Iterator, List, Sequence
import asyncio
import json
import sys
import time
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from src.common.data_models.llm_service_data_models import LLMServiceRequest, LLMServiceResult # noqa: E402
from src.config.config import config_manager # noqa: E402
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig # noqa: E402
from src.llm_models.payload_content.tool_option import ToolCall # noqa: E402
from src.services.llm_service import generate # noqa: E402
from src.services.service_task_resolver import get_available_models # noqa: E402
DEFAULT_SKIP_TASKS = {"embedding", "voice"}
@dataclass(slots=True)
class ToolCallCase:
"""Tool call 参数测试用例。"""
name: str
description: str
tool_definition: Dict[str, Any]
expected_arguments: Dict[str, Any]
@property
def tool_name(self) -> str:
"""返回工具名称。"""
if self.tool_definition.get("type") == "function":
function_definition = self.tool_definition.get("function", {})
return str(function_definition.get("name", "") or "")
return str(self.tool_definition.get("name", "") or "")
@property
def parameters_schema(self) -> Dict[str, Any]:
"""返回参数 Schema。"""
if self.tool_definition.get("type") == "function":
function_definition = self.tool_definition.get("function", {})
parameters = function_definition.get("parameters", {})
return parameters if isinstance(parameters, dict) else {}
parameters = self.tool_definition.get("parameters", {})
return parameters if isinstance(parameters, dict) else {}
def build_messages(self) -> List[Dict[str, Any]]:
"""构造测试消息。"""
expected_json = json.dumps(self.expected_arguments, ensure_ascii=False, indent=2)
system_prompt = (
"你正在执行严格的工具调用参数兼容性测试。"
"你必须通过工具调用响应,不能输出自然语言,不能解释,不能补充额外字段。"
)
user_prompt = (
f"请立刻调用工具 `{self.tool_name}`。\n"
"参数必须与下面 JSON 完全一致,键名、值、布尔类型、整数类型、浮点数、数组顺序和对象结构都不能改变。\n"
"不要输出任何解释文本,只返回工具调用。\n"
f"{expected_json}"
)
return [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
@dataclass(slots=True)
class ProbeTarget:
"""单个待测试模型目标。"""
task_name: str
model_name: str
provider_name: str
client_type: str
tool_argument_parse_mode: str
@dataclass(slots=True)
class ProbeResult:
"""单次测试结果。"""
task_name: str
target_model_name: str
actual_model_name: str
provider_name: str
client_type: str
tool_argument_parse_mode: str
case_name: str
attempt: int
success: bool
elapsed_seconds: float
errors: List[str]
warnings: List[str]
response_text: str
reasoning_text: str
tool_calls: List[Dict[str, Any]]
def _ensure_utf8_console() -> None:
"""尽量将控制台编码切到 UTF-8。"""
try:
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8")
if hasattr(sys.stderr, "reconfigure"):
sys.stderr.reconfigure(encoding="utf-8")
except Exception:
pass
def _build_function_tool(name: str, description: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
"""构造 OpenAI 风格 function tool 定义。"""
return {
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": parameters,
},
}
def _build_default_cases() -> List[ToolCallCase]:
"""构造默认测试用例。"""
simple_expected_arguments = {
"request_id": "probe-simple-001",
"count": 7,
"enabled": True,
"mode": "strict",
"ratio": 2.5,
}
simple_parameters = {
"type": "object",
"properties": {
"request_id": {"type": "string", "description": "请求 ID"},
"count": {"type": "integer", "description": "数量"},
"enabled": {"type": "boolean", "description": "是否启用"},
"mode": {
"type": "string",
"description": "模式",
"enum": ["strict", "loose"],
},
"ratio": {"type": "number", "description": "比例"},
},
"required": ["request_id", "count", "enabled", "mode", "ratio"],
"additionalProperties": False,
}
nested_expected_arguments = {
"request_id": "probe-nested-001",
"notify": False,
"profile": {
"channel": "stable",
"priority": 2,
},
"tags": ["alpha", "beta", "gamma"],
"items": [
{"count": 2, "name": "apple"},
{"count": 5, "name": "banana"},
],
}
nested_parameters = {
"type": "object",
"properties": {
"request_id": {"type": "string", "description": "请求 ID"},
"notify": {"type": "boolean", "description": "是否通知"},
"profile": {
"type": "object",
"description": "配置对象",
"properties": {
"channel": {"type": "string", "description": "渠道"},
"priority": {"type": "integer", "description": "优先级"},
},
"required": ["channel", "priority"],
"additionalProperties": False,
},
"tags": {
"type": "array",
"description": "标签列表",
"items": {"type": "string"},
},
"items": {
"type": "array",
"description": "条目列表",
"items": {
"type": "object",
"properties": {
"count": {"type": "integer", "description": "数量"},
"name": {"type": "string", "description": "名称"},
},
"required": ["count", "name"],
"additionalProperties": False,
},
},
},
"required": ["request_id", "notify", "profile", "tags", "items"],
"additionalProperties": False,
}
return [
ToolCallCase(
name="simple",
description="标量参数类型校验",
tool_definition=_build_function_tool(
name="record_simple_probe",
description="记录简单参数探测结果",
parameters=simple_parameters,
),
expected_arguments=simple_expected_arguments,
),
ToolCallCase(
name="nested",
description="嵌套对象与数组参数校验",
tool_definition=_build_function_tool(
name="record_nested_probe",
description="记录嵌套参数探测结果",
parameters=nested_parameters,
),
expected_arguments=nested_expected_arguments,
),
]
def _parse_multi_value_args(raw_values: Sequence[str] | None) -> List[str]:
"""解析命令行中的多值参数。"""
parsed_values: List[str] = []
for raw_value in raw_values or []:
for item in str(raw_value).split(","):
normalized_item = item.strip()
if normalized_item:
parsed_values.append(normalized_item)
return parsed_values
def _build_model_map() -> Dict[str, ModelInfo]:
"""构造模型名称到模型配置的映射。"""
return {model.name: model for model in config_manager.get_model_config().models}
def _build_provider_map() -> Dict[str, APIProvider]:
"""构造 Provider 名称到配置的映射。"""
return {provider.name: provider for provider in config_manager.get_model_config().api_providers}
def _pick_default_task_name(task_names: Sequence[str]) -> str:
"""选择默认任务名。"""
if "utils" in task_names:
return "utils"
if not task_names:
raise ValueError("当前没有可用的任务配置")
return str(task_names[0])
def _resolve_targets(task_filters: Sequence[str], model_filters: Sequence[str], fallback_task: str) -> List[ProbeTarget]:
"""根据命令行参数解析待测试目标。"""
available_tasks = get_available_models()
model_map = _build_model_map()
provider_map = _build_provider_map()
if not available_tasks:
raise ValueError("未找到任何可用的模型任务配置")
if task_filters:
selected_task_names = []
for task_name in task_filters:
if task_name not in available_tasks:
raise ValueError(f"未找到任务 `{task_name}`")
selected_task_names.append(task_name)
else:
selected_task_names = [
task_name
for task_name in available_tasks
if task_name not in DEFAULT_SKIP_TASKS
]
if not selected_task_names:
raise ValueError("没有可用于 tool call 测试的任务,请显式通过 --task 指定")
default_task_name = fallback_task if fallback_task in available_tasks else _pick_default_task_name(selected_task_names)
resolved_targets: List[ProbeTarget] = []
seen_models: set[str] = set()
if model_filters:
model_names = list(model_filters)
else:
model_names = []
for task_name in selected_task_names:
task_config = available_tasks[task_name]
for model_name in task_config.model_list:
if model_name not in model_names:
model_names.append(model_name)
for model_name in model_names:
if model_name in seen_models:
continue
if model_name not in model_map:
raise ValueError(f"未找到模型 `{model_name}`")
target_task_name = ""
for task_name in selected_task_names:
if model_name in available_tasks[task_name].model_list:
target_task_name = task_name
break
if not target_task_name:
target_task_name = default_task_name
model_info = model_map[model_name]
provider_info = provider_map[model_info.api_provider]
resolved_targets.append(
ProbeTarget(
task_name=target_task_name,
model_name=model_name,
provider_name=provider_info.name,
client_type=provider_info.client_type,
tool_argument_parse_mode=provider_info.tool_argument_parse_mode,
)
)
seen_models.add(model_name)
return resolved_targets
@contextmanager
def _pin_task_to_model(task_name: str, model_name: str) -> Iterator[None]:
"""临时将某个任务锁定到单模型。"""
model_task_config = config_manager.get_model_config().model_task_config
task_config = getattr(model_task_config, task_name, None)
if not isinstance(task_config, TaskConfig):
raise ValueError(f"未找到任务 `{task_name}` 对应的配置")
original_model_list = list(task_config.model_list)
original_selection_strategy = task_config.selection_strategy
task_config.model_list = [model_name]
task_config.selection_strategy = "balance"
try:
yield
finally:
task_config.model_list = original_model_list
task_config.selection_strategy = original_selection_strategy
def _serialize_tool_calls(tool_calls: List[ToolCall] | None) -> List[Dict[str, Any]]:
"""序列化工具调用结果。"""
if not tool_calls:
return []
return [
{
"id": tool_call.call_id,
"function": {
"name": tool_call.func_name,
"arguments": dict(tool_call.args or {}),
},
}
for tool_call in tool_calls
]
def _is_integer_value(value: Any) -> bool:
"""判断是否为整数类型且排除布尔值。"""
return isinstance(value, int) and not isinstance(value, bool)
def _is_number_value(value: Any) -> bool:
"""判断是否为数值类型且排除布尔值。"""
return (isinstance(value, int) or isinstance(value, float)) and not isinstance(value, bool)
def _schema_type(schema: Dict[str, Any]) -> str:
"""解析 Schema 的类型。"""
schema_type = str(schema.get("type", "") or "").strip()
if schema_type:
return schema_type
if "properties" in schema or "required" in schema:
return "object"
return ""
def _validate_schema(schema: Dict[str, Any], actual_value: Any, path: str = "args") -> List[str]:
"""按简化 JSON Schema 校验工具参数。"""
errors: List[str] = []
schema_type = _schema_type(schema)
if "enum" in schema and actual_value not in schema["enum"]:
errors.append(f"{path} 枚举值不合法,期望属于 {schema['enum']},实际为 {actual_value!r}")
if schema_type == "string":
if not isinstance(actual_value, str):
errors.append(f"{path} 类型错误,期望 string实际为 {type(actual_value).__name__}")
return errors
if schema_type == "integer":
if not _is_integer_value(actual_value):
errors.append(f"{path} 类型错误,期望 integer实际为 {type(actual_value).__name__}")
return errors
if schema_type == "number":
if not _is_number_value(actual_value):
errors.append(f"{path} 类型错误,期望 number实际为 {type(actual_value).__name__}")
return errors
if schema_type == "boolean":
if not isinstance(actual_value, bool):
errors.append(f"{path} 类型错误,期望 boolean实际为 {type(actual_value).__name__}")
return errors
if schema_type == "array":
if not isinstance(actual_value, list):
errors.append(f"{path} 类型错误,期望 array实际为 {type(actual_value).__name__}")
return errors
item_schema = schema.get("items")
if isinstance(item_schema, dict):
for index, item in enumerate(actual_value):
errors.extend(_validate_schema(item_schema, item, f"{path}[{index}]"))
return errors
if schema_type == "object":
if not isinstance(actual_value, dict):
errors.append(f"{path} 类型错误,期望 object实际为 {type(actual_value).__name__}")
return errors
properties = schema.get("properties", {})
required_fields = [str(item) for item in schema.get("required", [])]
for required_field in required_fields:
if required_field not in actual_value:
errors.append(f"{path}.{required_field} 缺少必填字段")
for field_name, field_value in actual_value.items():
field_path = f"{path}.{field_name}"
field_schema = properties.get(field_name)
if isinstance(field_schema, dict):
errors.extend(_validate_schema(field_schema, field_value, field_path))
continue
additional_properties = schema.get("additionalProperties", True)
if additional_properties is False:
errors.append(f"{field_path} 是未定义字段")
elif isinstance(additional_properties, dict):
errors.extend(_validate_schema(additional_properties, field_value, field_path))
return errors
return errors
def _compare_expected_values(expected_value: Any, actual_value: Any, path: str = "args") -> List[str]:
"""递归比较实际值与期望值是否完全一致。"""
errors: List[str] = []
if isinstance(expected_value, dict):
if not isinstance(actual_value, dict):
return [f"{path} 值不一致,期望 object实际为 {type(actual_value).__name__}"]
expected_keys = set(expected_value.keys())
actual_keys = set(actual_value.keys())
for missing_key in sorted(expected_keys - actual_keys):
errors.append(f"{path}.{missing_key} 缺少期望字段")
for extra_key in sorted(actual_keys - expected_keys):
errors.append(f"{path}.{extra_key} 出现了额外字段")
for shared_key in sorted(expected_keys & actual_keys):
errors.extend(
_compare_expected_values(
expected_value[shared_key],
actual_value[shared_key],
f"{path}.{shared_key}",
)
)
return errors
if isinstance(expected_value, list):
if not isinstance(actual_value, list):
return [f"{path} 值不一致,期望 array实际为 {type(actual_value).__name__}"]
if len(expected_value) != len(actual_value):
errors.append(f"{path} 列表长度不一致,期望 {len(expected_value)},实际 {len(actual_value)}")
for index, (expected_item, actual_item) in enumerate(
zip(expected_value, actual_value, strict=False)
):
errors.extend(_compare_expected_values(expected_item, actual_item, f"{path}[{index}]"))
return errors
if isinstance(expected_value, bool):
if not isinstance(actual_value, bool) or actual_value is not expected_value:
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
return errors
if _is_integer_value(expected_value):
if not _is_integer_value(actual_value) or actual_value != expected_value:
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
return errors
if isinstance(expected_value, float):
if not _is_number_value(actual_value) or float(actual_value) != expected_value:
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
return errors
if expected_value != actual_value:
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
return errors
def _pick_tool_call(tool_calls: List[ToolCall], expected_tool_name: str) -> ToolCall:
"""优先选择同名工具调用,否则回退到第一条。"""
for tool_call in tool_calls:
if tool_call.func_name == expected_tool_name:
return tool_call
return tool_calls[0]
def _validate_service_result(
service_result: LLMServiceResult,
target: ProbeTarget,
case: ToolCallCase,
) -> tuple[List[str], List[str], List[Dict[str, Any]]]:
"""校验服务层返回结果。"""
errors: List[str] = []
warnings: List[str] = []
completion = service_result.completion
serialized_tool_calls = _serialize_tool_calls(completion.tool_calls)
if not service_result.success:
errors.append(service_result.error or completion.response or "请求失败但未返回错误信息")
return errors, warnings, serialized_tool_calls
if completion.model_name and completion.model_name != target.model_name:
errors.append(
f"实际命中的模型为 `{completion.model_name}`,与目标模型 `{target.model_name}` 不一致"
)
tool_calls = completion.tool_calls or []
if not tool_calls:
errors.append("模型未返回 tool_calls")
if completion.response.strip():
warnings.append("模型返回了自然语言文本而不是工具调用")
return errors, warnings, serialized_tool_calls
if len(tool_calls) != 1:
errors.append(f"返回了 {len(tool_calls)} 个 tool_calls预期为 1 个")
selected_tool_call = _pick_tool_call(tool_calls, case.tool_name)
if selected_tool_call.func_name != case.tool_name:
errors.append(
f"工具名不一致,期望 `{case.tool_name}`,实际 `{selected_tool_call.func_name}`"
)
actual_arguments = selected_tool_call.args
if not isinstance(actual_arguments, dict):
errors.append("工具参数未被解析为对象")
return errors, warnings, serialized_tool_calls
errors.extend(_validate_schema(case.parameters_schema, actual_arguments))
errors.extend(_compare_expected_values(case.expected_arguments, actual_arguments))
if completion.response.strip():
warnings.append("模型同时返回了自然语言文本")
return errors, warnings, serialized_tool_calls
async def _run_single_probe(
target: ProbeTarget,
case: ToolCallCase,
attempt: int,
max_tokens: int,
temperature: float,
) -> ProbeResult:
"""执行单次工具调用参数探测。"""
request = LLMServiceRequest(
task_name=target.task_name,
request_type=f"tool_call_param_probe.{case.name}.attempt_{attempt}",
prompt=case.build_messages(),
tool_options=[case.tool_definition],
temperature=temperature,
max_tokens=max_tokens,
)
started_at = time.perf_counter()
with _pin_task_to_model(target.task_name, target.model_name):
service_result = await generate(request)
elapsed_seconds = time.perf_counter() - started_at
errors, warnings, serialized_tool_calls = _validate_service_result(service_result, target, case)
completion = service_result.completion
return ProbeResult(
task_name=target.task_name,
target_model_name=target.model_name,
actual_model_name=completion.model_name,
provider_name=target.provider_name,
client_type=target.client_type,
tool_argument_parse_mode=target.tool_argument_parse_mode,
case_name=case.name,
attempt=attempt,
success=not errors,
elapsed_seconds=elapsed_seconds,
errors=errors,
warnings=warnings,
response_text=completion.response,
reasoning_text=completion.reasoning,
tool_calls=serialized_tool_calls,
)
def _print_targets(targets: Sequence[ProbeTarget]) -> None:
"""打印待测试目标。"""
print("待测试目标:")
for index, target in enumerate(targets, start=1):
print(
f"{index}. model={target.model_name} | task={target.task_name} | "
f"provider={target.provider_name} | client={target.client_type} | "
f"tool_argument_parse_mode={target.tool_argument_parse_mode}"
)
def _print_available_targets() -> None:
"""打印当前可用任务与模型。"""
available_tasks = get_available_models()
model_map = _build_model_map()
task_names = list(available_tasks.keys())
print("当前可用任务:")
for task_name in task_names:
task_config = available_tasks[task_name]
print(f"- {task_name}: {list(task_config.model_list)}")
referenced_models = {
model_name
for task_config in available_tasks.values()
for model_name in task_config.model_list
}
print("\n当前配置中的模型:")
for model_name, model_info in model_map.items():
referenced_mark = "已被任务引用" if model_name in referenced_models else "未被任务引用"
print(
f"- {model_name}: provider={model_info.api_provider}, "
f"identifier={model_info.model_identifier}, {referenced_mark}"
)
def _select_cases(case_filters: Sequence[str]) -> List[ToolCallCase]:
"""根据参数筛选测试用例。"""
all_cases = {case.name: case for case in _build_default_cases()}
if not case_filters:
return list(all_cases.values())
selected_cases: List[ToolCallCase] = []
for case_name in case_filters:
if case_name not in all_cases:
raise ValueError(f"未知测试用例 `{case_name}`,可选值: {', '.join(sorted(all_cases))}")
selected_cases.append(all_cases[case_name])
return selected_cases
def _print_single_result(result: ProbeResult, show_response: bool) -> None:
"""打印单次结果。"""
status_text = "PASS" if result.success else "FAIL"
print(
f"[{status_text}] model={result.target_model_name} | task={result.task_name} | "
f"case={result.case_name} | attempt={result.attempt} | elapsed={result.elapsed_seconds:.2f}s"
)
if result.errors:
for error in result.errors:
print(f" ERROR: {error}")
if result.warnings:
for warning in result.warnings:
print(f" WARN: {warning}")
if result.tool_calls:
print(f" tool_calls: {json.dumps(result.tool_calls, ensure_ascii=False)}")
if show_response and result.response_text.strip():
print(f" response: {result.response_text}")
def _build_summary(results: Sequence[ProbeResult]) -> Dict[str, Any]:
"""构造结果摘要。"""
total_count = len(results)
passed_count = sum(1 for result in results if result.success)
failed_count = total_count - passed_count
failed_items = [
{
"model_name": result.target_model_name,
"case_name": result.case_name,
"attempt": result.attempt,
"errors": list(result.errors),
}
for result in results
if not result.success
]
return {
"total": total_count,
"passed": passed_count,
"failed": failed_count,
"failed_items": failed_items,
}
def _write_json_report(json_out: str, results: Sequence[ProbeResult]) -> None:
"""将测试结果写入 JSON 文件。"""
output_path = Path(json_out).expanduser().resolve()
output_path.parent.mkdir(parents=True, exist_ok=True)
payload = {
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
"summary": _build_summary(results),
"results": [asdict(result) for result in results],
}
output_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
print(f"\n结果已写入: {output_path}")
async def _run_probes(args: Namespace) -> List[ProbeResult]:
"""执行所有探测请求。"""
task_filters = _parse_multi_value_args(args.task)
model_filters = _parse_multi_value_args(args.model)
case_filters = _parse_multi_value_args(args.case)
selected_cases = _select_cases(case_filters)
targets = _resolve_targets(task_filters, model_filters, args.fallback_task)
_print_targets(targets)
print("")
results: List[ProbeResult] = []
for target in targets:
for attempt in range(1, args.repeat + 1):
for case in selected_cases:
print(
f"开始测试: model={target.model_name}, task={target.task_name}, "
f"case={case.name}, attempt={attempt}"
)
result = await _run_single_probe(
target=target,
case=case,
attempt=attempt,
max_tokens=args.max_tokens,
temperature=args.temperature,
)
_print_single_result(result, args.show_response)
print("")
results.append(result)
return results
def _build_parser() -> ArgumentParser:
"""构造命令行参数解析器。"""
parser = ArgumentParser(
description=(
"测试 config/model_config.toml 中不同模型的 tool call 参数兼容性。\n"
"默认会测试所有非 voice / embedding 任务中引用到的模型。"
)
)
parser.add_argument(
"--task",
action="append",
help="指定任务名,可重复传入,或使用逗号分隔多个值,例如 --task utils --task planner",
)
parser.add_argument(
"--model",
action="append",
help="指定模型名,可重复传入,或使用逗号分隔多个值,例如 --model qwen3.6-plus",
)
parser.add_argument(
"--case",
action="append",
help="指定测试用例名,可选 simple、nested不传则运行全部默认用例",
)
parser.add_argument(
"--repeat",
type=int,
default=1,
help="每个模型每个用例重复测试次数,默认 1",
)
parser.add_argument(
"--max-tokens",
type=int,
default=512,
help="单次测试的最大输出 token 数,默认 512",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="单次测试温度,默认 0.0 以尽量提高稳定性",
)
parser.add_argument(
"--fallback-task",
default="utils",
help="当指定模型未被任何已选任务引用时,用于挂载该模型的任务名,默认 utils",
)
parser.add_argument(
"--json-out",
help="可选,将结果写入指定 JSON 文件",
)
parser.add_argument(
"--list-targets",
action="store_true",
help="仅打印当前任务与模型映射,不发起网络请求",
)
parser.add_argument(
"--show-response",
action="store_true",
help="打印模型返回的自然语言文本内容",
)
return parser
def main() -> int:
"""脚本入口。"""
_ensure_utf8_console()
parser = _build_parser()
args = parser.parse_args()
if args.repeat < 1:
parser.error("--repeat 必须大于等于 1")
if args.max_tokens < 1:
parser.error("--max-tokens 必须大于等于 1")
if args.list_targets:
_print_available_targets()
return 0
results = asyncio.run(_run_probes(args))
summary = _build_summary(results)
print("测试摘要:")
print(
f"total={summary['total']} | passed={summary['passed']} | failed={summary['failed']}"
)
if summary["failed_items"]:
print("失败明细:")
for failed_item in summary["failed_items"]:
print(
f"- model={failed_item['model_name']} | case={failed_item['case_name']} | "
f"attempt={failed_item['attempt']} | errors={failed_item['errors']}"
)
if args.json_out:
_write_json_report(args.json_out, results)
return 0 if summary["failed"] == 0 else 1
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -1,777 +0,0 @@
from argparse import ArgumentParser, Namespace
from contextlib import contextmanager
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterator, List, Sequence
import asyncio
import json
import sys
import time
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from src.common.data_models.llm_service_data_models import LLMServiceRequest, LLMServiceResult # noqa: E402
from src.config.config import config_manager # noqa: E402
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig # noqa: E402
from src.services.llm_service import generate # noqa: E402
from src.services.service_task_resolver import get_available_models # noqa: E402
DEFAULT_SKIP_TASKS = {"embedding", "voice"}
@dataclass(slots=True)
class ProbeTarget:
"""单个待测试模型目标。"""
task_name: str
model_name: str
provider_name: str
client_type: str
tool_argument_parse_mode: str
@dataclass(slots=True)
class ToolCallScenario:
"""工具调用 API 场景定义。"""
name: str
description: str
prompt: List[Dict[str, Any]]
tool_options: List[Dict[str, Any]] | None = None
expect_tool_calls: bool | None = None
@dataclass(slots=True)
class ProbeResult:
"""单次 API 探测结果。"""
task_name: str
target_model_name: str
actual_model_name: str
provider_name: str
client_type: str
tool_argument_parse_mode: str
case_name: str
attempt: int
success: bool
elapsed_seconds: float
errors: List[str] = field(default_factory=list)
warnings: List[str] = field(default_factory=list)
response_text: str = ""
reasoning_text: str = ""
tool_calls: List[Dict[str, Any]] = field(default_factory=list)
def _ensure_utf8_console() -> None:
"""尽量将控制台编码切换为 UTF-8。"""
try:
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8")
if hasattr(sys.stderr, "reconfigure"):
sys.stderr.reconfigure(encoding="utf-8")
except Exception:
pass
def _build_function_tool(name: str, description: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
"""构造 OpenAI 风格 function tool。"""
return {
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": parameters,
},
}
def _build_probe_tools() -> List[Dict[str, Any]]:
"""构造通用测试工具。"""
weather_tool = _build_function_tool(
name="lookup_weather",
description="查询指定城市天气。",
parameters={
"type": "object",
"properties": {
"city": {"type": "string", "description": "城市名"},
"unit": {
"type": "string",
"description": "温度单位",
"enum": ["celsius", "fahrenheit"],
},
"include_forecast": {"type": "boolean", "description": "是否包含未来天气"},
},
"required": ["city", "unit", "include_forecast"],
"additionalProperties": False,
},
)
search_tool = _build_function_tool(
name="search_docs",
description="搜索内部知识库。",
parameters={
"type": "object",
"properties": {
"query": {"type": "string", "description": "搜索关键词"},
"top_k": {"type": "integer", "description": "返回条数"},
"filters": {
"type": "object",
"description": "过滤条件",
"properties": {
"scope": {"type": "string", "description": "搜索范围"},
"tag": {"type": "string", "description": "标签"},
},
"required": ["scope", "tag"],
"additionalProperties": False,
},
},
"required": ["query", "top_k", "filters"],
"additionalProperties": False,
},
)
return [weather_tool, search_tool]
def _build_default_scenarios() -> List[ToolCallScenario]:
"""构造默认测试场景。"""
tools = _build_probe_tools()
weather_tool = tools[0]
search_tool = tools[1]
history_tool_call = {
"id": "call_hist_weather_001",
"type": "function",
"function": {
"name": "lookup_weather",
"arguments": {
"city": "上海",
"unit": "celsius",
"include_forecast": True,
},
},
}
nested_history_tool_call = {
"id": "call_hist_search_001",
"type": "function",
"function": {
"name": "search_docs",
"arguments": {
"query": "工具调用兼容性",
"top_k": 3,
"filters": {
"scope": "internal",
"tag": "tool-call",
},
},
},
}
return [
ToolCallScenario(
name="fresh_tool_call",
description="首轮普通工具调用请求。",
prompt=[
{
"role": "system",
"content": (
"你正在执行工具调用连通性测试。"
"如果能调用工具,就优先调用最合适的工具。"
),
},
{
"role": "user",
"content": "请查询上海天气,并使用工具给出参数。",
},
],
tool_options=[weather_tool],
expect_tool_calls=True,
),
ToolCallScenario(
name="history_assistant_tool_calls_with_content",
description="历史 assistant 同时包含文本和 tool_calls当前轮不再提供 tools。",
prompt=[
{"role": "system", "content": "你正在执行多轮上下文兼容性测试。"},
{"role": "user", "content": "先帮我查一下上海天气。"},
{
"role": "assistant",
"content": "我先查询天气,再继续回答。",
"tool_calls": [history_tool_call],
},
{"role": "user", "content": "继续说,别丢掉上下文。"},
],
tool_options=None,
expect_tool_calls=None,
),
ToolCallScenario(
name="history_assistant_tool_calls_without_content",
description="历史 assistant 只有 tool_calls没有文本内容。",
prompt=[
{"role": "system", "content": "你正在执行多轮上下文兼容性测试。"},
{"role": "user", "content": "先帮我查一下上海天气。"},
{
"role": "assistant",
"tool_calls": [history_tool_call],
},
{"role": "user", "content": "继续。"},
],
tool_options=None,
expect_tool_calls=None,
),
ToolCallScenario(
name="history_tool_result_followup",
description="历史中包含 assistant.tool_calls 与对应 tool 结果消息。",
prompt=[
{"role": "system", "content": "你正在执行工具调用闭环兼容性测试。"},
{"role": "user", "content": "先查上海天气。"},
{
"role": "assistant",
"content": "我先查询天气。",
"tool_calls": [history_tool_call],
},
{
"role": "tool",
"tool_call_id": "call_hist_weather_001",
"content": json.dumps(
{
"city": "上海",
"condition": "多云",
"temperature_c": 24,
"forecast": ["", "小雨"],
},
ensure_ascii=False,
),
},
{"role": "user", "content": "结合上面的查询结果继续总结。"},
],
tool_options=None,
expect_tool_calls=None,
),
ToolCallScenario(
name="history_multiple_tool_calls_and_results",
description="历史中包含多个 tool_calls 与多条 tool 结果。",
prompt=[
{"role": "system", "content": "你正在执行多工具上下文兼容性测试。"},
{"role": "user", "content": "先查天气,再搜一下工具调用兼容性文档。"},
{
"role": "assistant",
"content": "我分两步查询。",
"tool_calls": [history_tool_call, nested_history_tool_call],
},
{
"role": "tool",
"tool_call_id": "call_hist_weather_001",
"content": json.dumps(
{
"city": "上海",
"condition": "",
"temperature_c": 22,
},
ensure_ascii=False,
),
},
{
"role": "tool",
"tool_call_id": "call_hist_search_001",
"content": json.dumps(
{
"items": [
"OpenAI 兼容接口的 arguments 常见为 JSON 字符串",
"部分 provider 在历史消息回放时兼容性较弱",
],
},
ensure_ascii=False,
),
},
{"role": "user", "content": "继续整合上面的两个结果。"},
],
tool_options=None,
expect_tool_calls=None,
),
ToolCallScenario(
name="history_tool_calls_with_current_tools",
description="保留历史 tool_calls同时当前轮仍然提供 tools。",
prompt=[
{"role": "system", "content": "你正在执行历史 tool_calls 与当前 tools 共存测试。"},
{"role": "user", "content": "先查上海天气。"},
{
"role": "assistant",
"content": "我先查天气。",
"tool_calls": [history_tool_call],
},
{
"role": "tool",
"tool_call_id": "call_hist_weather_001",
"content": json.dumps(
{
"city": "上海",
"condition": "",
"temperature_c": 26,
},
ensure_ascii=False,
),
},
{"role": "user", "content": "现在再搜一下工具调用兼容性文档。"},
],
tool_options=[search_tool],
expect_tool_calls=True,
),
]
def _parse_multi_value_args(raw_values: Sequence[str] | None) -> List[str]:
"""解析命令行中的多值参数。"""
parsed_values: List[str] = []
for raw_value in raw_values or []:
for item in str(raw_value).split(","):
normalized_item = item.strip()
if normalized_item:
parsed_values.append(normalized_item)
return parsed_values
def _build_model_map() -> Dict[str, ModelInfo]:
"""构造模型名到模型配置的映射。"""
return {model.name: model for model in config_manager.get_model_config().models}
def _build_provider_map() -> Dict[str, APIProvider]:
"""构造 Provider 名称到配置的映射。"""
return {provider.name: provider for provider in config_manager.get_model_config().api_providers}
def _pick_default_task_name(task_names: Sequence[str]) -> str:
"""选择默认任务名。"""
if "utils" in task_names:
return "utils"
if not task_names:
raise ValueError("当前没有可用的任务配置")
return str(task_names[0])
def _resolve_targets(task_filters: Sequence[str], model_filters: Sequence[str], fallback_task: str) -> List[ProbeTarget]:
"""根据命令行参数解析待测试目标。"""
available_tasks = get_available_models()
model_map = _build_model_map()
provider_map = _build_provider_map()
if not available_tasks:
raise ValueError("未找到任何可用的模型任务配置")
if task_filters:
selected_task_names = []
for task_name in task_filters:
if task_name not in available_tasks:
raise ValueError(f"未找到任务 `{task_name}`")
selected_task_names.append(task_name)
else:
selected_task_names = [
task_name
for task_name in available_tasks
if task_name not in DEFAULT_SKIP_TASKS
]
if not selected_task_names:
raise ValueError("没有可用于工具调用 API 测试的任务,请显式通过 --task 指定")
default_task_name = fallback_task if fallback_task in available_tasks else _pick_default_task_name(selected_task_names)
resolved_targets: List[ProbeTarget] = []
seen_models: set[str] = set()
if model_filters:
model_names = list(model_filters)
else:
model_names = []
for task_name in selected_task_names:
task_config = available_tasks[task_name]
for model_name in task_config.model_list:
if model_name not in model_names:
model_names.append(model_name)
for model_name in model_names:
if model_name in seen_models:
continue
if model_name not in model_map:
raise ValueError(f"未找到模型 `{model_name}`")
target_task_name = ""
for task_name in selected_task_names:
if model_name in available_tasks[task_name].model_list:
target_task_name = task_name
break
if not target_task_name:
target_task_name = default_task_name
model_info = model_map[model_name]
provider_info = provider_map[model_info.api_provider]
resolved_targets.append(
ProbeTarget(
task_name=target_task_name,
model_name=model_name,
provider_name=provider_info.name,
client_type=provider_info.client_type,
tool_argument_parse_mode=provider_info.tool_argument_parse_mode,
)
)
seen_models.add(model_name)
return resolved_targets
@contextmanager
def _pin_task_to_model(task_name: str, model_name: str) -> Iterator[None]:
"""临时将某个任务锁定到单模型。"""
model_task_config = config_manager.get_model_config().model_task_config
task_config = getattr(model_task_config, task_name, None)
if not isinstance(task_config, TaskConfig):
raise ValueError(f"未找到任务 `{task_name}` 对应的配置")
original_model_list = list(task_config.model_list)
original_selection_strategy = task_config.selection_strategy
task_config.model_list = [model_name]
task_config.selection_strategy = "balance"
try:
yield
finally:
task_config.model_list = original_model_list
task_config.selection_strategy = original_selection_strategy
def _serialize_tool_calls(tool_calls: Any) -> List[Dict[str, Any]]:
"""序列化返回中的工具调用。"""
if not tool_calls:
return []
serialized_items: List[Dict[str, Any]] = []
for tool_call in tool_calls:
serialized_items.append(
{
"id": getattr(tool_call, "call_id", ""),
"function": {
"name": getattr(tool_call, "func_name", ""),
"arguments": dict(getattr(tool_call, "args", {}) or {}),
},
**(
{"extra_content": dict(getattr(tool_call, "extra_content", {}) or {})}
if getattr(tool_call, "extra_content", None)
else {}
),
}
)
return serialized_items
def _validate_service_result(service_result: LLMServiceResult, scenario: ToolCallScenario) -> tuple[List[str], List[str], List[Dict[str, Any]]]:
"""校验服务结果。"""
errors: List[str] = []
warnings: List[str] = []
completion = service_result.completion
serialized_tool_calls = _serialize_tool_calls(completion.tool_calls)
if not service_result.success:
errors.append(service_result.error or completion.response or "请求失败,但没有返回明确错误")
return errors, warnings, serialized_tool_calls
if scenario.expect_tool_calls is True and not serialized_tool_calls:
warnings.append("本场景期望模型倾向于调用工具,但未返回 tool_calls")
if scenario.expect_tool_calls is False and serialized_tool_calls:
warnings.append("本场景未期望继续调用工具,但模型返回了 tool_calls")
if completion.response.strip():
warnings.append("模型返回了可见文本")
return errors, warnings, serialized_tool_calls
async def _run_single_probe(
target: ProbeTarget,
scenario: ToolCallScenario,
attempt: int,
max_tokens: int,
temperature: float,
) -> ProbeResult:
"""执行单次 API 探测。"""
request = LLMServiceRequest(
task_name=target.task_name,
request_type=f"tool_call_api_matrix.{scenario.name}.attempt_{attempt}",
prompt=scenario.prompt,
tool_options=scenario.tool_options,
temperature=temperature,
max_tokens=max_tokens,
)
started_at = time.perf_counter()
with _pin_task_to_model(target.task_name, target.model_name):
service_result = await generate(request)
elapsed_seconds = time.perf_counter() - started_at
errors, warnings, serialized_tool_calls = _validate_service_result(service_result, scenario)
completion = service_result.completion
return ProbeResult(
task_name=target.task_name,
target_model_name=target.model_name,
actual_model_name=completion.model_name,
provider_name=target.provider_name,
client_type=target.client_type,
tool_argument_parse_mode=target.tool_argument_parse_mode,
case_name=scenario.name,
attempt=attempt,
success=not errors,
elapsed_seconds=elapsed_seconds,
errors=errors,
warnings=warnings,
response_text=completion.response,
reasoning_text=completion.reasoning,
tool_calls=serialized_tool_calls,
)
def _print_targets(targets: Sequence[ProbeTarget]) -> None:
"""打印待测试目标。"""
print("待测试目标:")
for index, target in enumerate(targets, start=1):
print(
f"{index}. model={target.model_name} | task={target.task_name} | "
f"provider={target.provider_name} | client={target.client_type} | "
f"tool_argument_parse_mode={target.tool_argument_parse_mode}"
)
def _print_available_targets() -> None:
"""打印当前可用任务与模型。"""
available_tasks = get_available_models()
model_map = _build_model_map()
task_names = list(available_tasks.keys())
print("当前可用任务:")
for task_name in task_names:
task_config = available_tasks[task_name]
print(f"- {task_name}: {list(task_config.model_list)}")
referenced_models = {
model_name
for task_config in available_tasks.values()
for model_name in task_config.model_list
}
print("\n当前配置中的模型:")
for model_name, model_info in model_map.items():
referenced_mark = "已被任务引用" if model_name in referenced_models else "未被任务引用"
print(
f"- {model_name}: provider={model_info.api_provider}, "
f"identifier={model_info.model_identifier}, {referenced_mark}"
)
def _select_scenarios(case_filters: Sequence[str]) -> List[ToolCallScenario]:
"""按名称筛选测试场景。"""
all_scenarios = {scenario.name: scenario for scenario in _build_default_scenarios()}
if not case_filters:
return list(all_scenarios.values())
selected_scenarios: List[ToolCallScenario] = []
for case_name in case_filters:
if case_name not in all_scenarios:
raise ValueError(
f"未知测试场景 `{case_name}`,可选值: {', '.join(sorted(all_scenarios))}"
)
selected_scenarios.append(all_scenarios[case_name])
return selected_scenarios
def _print_single_result(result: ProbeResult, show_response: bool) -> None:
"""打印单次结果。"""
status_text = "PASS" if result.success else "FAIL"
print(
f"[{status_text}] model={result.target_model_name} | task={result.task_name} | "
f"case={result.case_name} | attempt={result.attempt} | elapsed={result.elapsed_seconds:.2f}s"
)
if result.errors:
for error in result.errors:
print(f" ERROR: {error}")
if result.warnings:
for warning in result.warnings:
print(f" WARN: {warning}")
if result.tool_calls:
print(f" tool_calls: {json.dumps(result.tool_calls, ensure_ascii=False)}")
if show_response and result.response_text.strip():
print(f" response: {result.response_text}")
def _build_summary(results: Sequence[ProbeResult]) -> Dict[str, Any]:
"""构造结果摘要。"""
total_count = len(results)
passed_count = sum(1 for result in results if result.success)
failed_count = total_count - passed_count
failed_items = [
{
"model_name": result.target_model_name,
"case_name": result.case_name,
"attempt": result.attempt,
"errors": list(result.errors),
}
for result in results
if not result.success
]
return {
"total": total_count,
"passed": passed_count,
"failed": failed_count,
"failed_items": failed_items,
}
def _write_json_report(json_out: str, results: Sequence[ProbeResult]) -> None:
"""将测试结果写入 JSON 文件。"""
output_path = Path(json_out).expanduser().resolve()
output_path.parent.mkdir(parents=True, exist_ok=True)
payload = {
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
"summary": _build_summary(results),
"results": [asdict(result) for result in results],
}
output_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
print(f"\n结果已写入: {output_path}")
async def _run_probes(args: Namespace) -> List[ProbeResult]:
"""执行所有探测请求。"""
task_filters = _parse_multi_value_args(args.task)
model_filters = _parse_multi_value_args(args.model)
case_filters = _parse_multi_value_args(args.case)
selected_scenarios = _select_scenarios(case_filters)
targets = _resolve_targets(task_filters, model_filters, args.fallback_task)
_print_targets(targets)
print("")
results: List[ProbeResult] = []
for target in targets:
for attempt in range(1, args.repeat + 1):
for scenario in selected_scenarios:
print(
f"开始测试: model={target.model_name}, task={target.task_name}, "
f"case={scenario.name}, attempt={attempt}"
)
result = await _run_single_probe(
target=target,
scenario=scenario,
attempt=attempt,
max_tokens=args.max_tokens,
temperature=args.temperature,
)
_print_single_result(result, args.show_response)
print("")
results.append(result)
return results
def _build_parser() -> ArgumentParser:
"""构造命令行参数解析器。"""
parser = ArgumentParser(
description=(
"测试不同模型在多种工具调用消息形态下的 API 兼容性。\n"
"重点覆盖历史 assistant.tool_calls、tool 结果消息、多工具调用等场景。"
)
)
parser.add_argument(
"--task",
action="append",
help="指定任务名,可重复传入,或使用逗号分隔多个值,例如 --task utils --task planner",
)
parser.add_argument(
"--model",
action="append",
help="指定模型名,可重复传入,或使用逗号分隔多个值,例如 --model qwen3.5-35b-a3b",
)
parser.add_argument(
"--case",
action="append",
help=(
"指定测试场景名,可选值包括 "
"fresh_tool_call、history_assistant_tool_calls_with_content、"
"history_assistant_tool_calls_without_content、history_tool_result_followup、"
"history_multiple_tool_calls_and_results、history_tool_calls_with_current_tools"
),
)
parser.add_argument(
"--repeat",
type=int,
default=1,
help="每个模型每个场景重复测试次数,默认 1",
)
parser.add_argument(
"--max-tokens",
type=int,
default=512,
help="单次测试的最大输出 token 数,默认 512",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="单次测试温度,默认 0.0,以尽量提高稳定性",
)
parser.add_argument(
"--fallback-task",
default="utils",
help="当指定模型未被已选任务引用时,用于挂载该模型的任务名,默认 utils",
)
parser.add_argument(
"--json-out",
help="可选,将结果写入指定 JSON 文件",
)
parser.add_argument(
"--list-targets",
action="store_true",
help="仅打印当前任务与模型映射,不发起网络请求",
)
parser.add_argument(
"--show-response",
action="store_true",
help="打印模型返回的可见文本内容",
)
return parser
def main() -> int:
"""脚本入口。"""
_ensure_utf8_console()
config_manager.initialize()
parser = _build_parser()
args = parser.parse_args()
if args.repeat < 1:
parser.error("--repeat 必须大于等于 1")
if args.max_tokens < 1:
parser.error("--max-tokens 必须大于等于 1")
if args.list_targets:
_print_available_targets()
return 0
results = asyncio.run(_run_probes(args))
summary = _build_summary(results)
print("测试摘要:")
print(
f"total={summary['total']} | passed={summary['passed']} | failed={summary['failed']}"
)
if summary["failed_items"]:
print("失败明细:")
for failed_item in summary["failed_items"]:
print(
f"- model={failed_item['model_name']} | case={failed_item['case_name']} | "
f"attempt={failed_item['attempt']} | errors={failed_item['errors']}"
)
if args.json_out:
_write_json_report(args.json_out, results)
return 0 if summary["failed"] == 0 else 1
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -748,6 +748,9 @@ class ComponentQueryService:
return payload return payload
for key, value in context_payload.items(): for key, value in context_payload.items():
if key in {"stream_id", "chat_id"}:
payload[key] = value
continue
if key not in payload or not payload.get(key): if key not in payload or not payload.get(key):
payload[key] = value payload[key] = value
return payload return payload

View File

@@ -1,76 +0,0 @@
from pathlib import Path
import sys
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from src.config.config_upgrade_hooks import (
BOT_CONFIG_UPGRADE_HOOKS,
ConfigUpgradeHook,
apply_config_upgrade_hooks,
set_nested_config_value,
)
from src.config.official_configs import ChatConfig
import src.config.config_upgrade_hooks as hooks
def test_apply_config_upgrade_hooks_runs_when_target_version_is_crossed(monkeypatch):
def migrate(data):
changed = set_nested_config_value(data, ("chat", "enable"), False)
return ["chat.enable"] if changed else []
monkeypatch.setattr(
hooks,
"BOT_CONFIG_UPGRADE_HOOKS",
(ConfigUpgradeHook(target_version="8.10.11", config_names=("bot_config.toml",), migrate=migrate),),
)
data = {"chat": {"enable": True}}
result = apply_config_upgrade_hooks(data, "bot_config.toml", "8.10.10", "8.10.11")
assert result.migrated is True
assert result.reason == "8.10.11:chat.enable"
assert result.data["chat"]["enable"] is False
def test_apply_config_upgrade_hooks_skips_versions_outside_upgrade_range(monkeypatch):
def migrate(data):
set_nested_config_value(data, ("chat", "enable"), False)
return ["chat.enable"]
monkeypatch.setattr(
hooks,
"BOT_CONFIG_UPGRADE_HOOKS",
(ConfigUpgradeHook(target_version="8.10.11", config_names=("bot_config.toml",), migrate=migrate),),
)
data = {"chat": {"enable": True}}
result = apply_config_upgrade_hooks(data, "bot_config.toml", "8.10.11", "8.10.12")
assert result.migrated is False
assert result.data["chat"]["enable"] is True
def test_set_nested_config_value_can_keep_existing_value():
data = {"webui": {"port": 8001}}
changed = set_nested_config_value(data, ("webui", "port"), 8080, force=False)
assert changed is False
assert data["webui"]["port"] == 8001
def test_builtin_hook_resets_group_chat_prompt_when_upgrading_from_8_10_10():
data = {"chat": {"group_chat_prompt": "自定义旧提示词"}}
result = apply_config_upgrade_hooks(data, "bot_config.toml", "8.10.10", "8.10.11")
assert result.migrated is True
assert result.reason == "8.10.11:chat.group_chat_prompt"
assert result.data["chat"]["group_chat_prompt"] == ChatConfig().group_chat_prompt
def test_bot_config_upgrade_hooks_register_group_chat_prompt_reset():
assert len(BOT_CONFIG_UPGRADE_HOOKS) == 1
assert BOT_CONFIG_UPGRADE_HOOKS[0].target_version == "8.10.11"