fix: override plugin tool stream context and remove repository tests
- force plugin tools to use runtime stream_id/chat_id from execution context - remove repository test assets and vitest config - document that temporary test files must be deleted after use
This commit is contained in:
@@ -33,6 +33,7 @@
|
|||||||
# 运行/调试/构建/测试/依赖
|
# 运行/调试/构建/测试/依赖
|
||||||
优先使用uv
|
优先使用uv
|
||||||
依赖项以 pyproject.toml 为准,要同步更新requirements.txt
|
依赖项以 pyproject.toml 为准,要同步更新requirements.txt
|
||||||
|
如为当前任务临时创建测试文件,跑完测试后必须立刻删除,不要保留在仓库中,也不要进入共享历史
|
||||||
前端改动后,如需走离线发布工作流,必须先在 `dashboard` 目录执行 `npm run build`(例如:`D:\Nodejs\npm.cmd run build`),确保生成最新的 `dashboard/dist`
|
前端改动后,如需走离线发布工作流,必须先在 `dashboard` 目录执行 `npm run build`(例如:`D:\Nodejs\npm.cmd run build`),确保生成最新的 `dashboard/dist`
|
||||||
当前项目的离线发布工作流依赖仓库中已提交的 `dashboard/dist`;在修改 `dashboard/src`、`dashboard/public`、`dashboard/package.json`、`dashboard/package-lock.json` 或其他影响构建产物的前端文件后,要同步提交对应的 `dashboard/dist`
|
当前项目的离线发布工作流依赖仓库中已提交的 `dashboard/dist`;在修改 `dashboard/src`、`dashboard/public`、`dashboard/package.json`、`dashboard/package-lock.json` 或其他影响构建产物的前端文件后,要同步提交对应的 `dashboard/dist`
|
||||||
不要在离线发布场景中假设服务器或 runner 可以联网安装 Node.js 依赖;除非明确确认发布机已具备可用的 `npm` 环境,否则默认按“本地构建 `dashboard/dist` 并随仓库提交”处理
|
不要在离线发布场景中假设服务器或 runner 可以联网安装 Node.js 依赖;除非明确确认发布机已具备可用的 `npm` 环境,否则默认按“本地构建 `dashboard/dist` 并随仓库提交”处理
|
||||||
|
|||||||
@@ -1,427 +0,0 @@
|
|||||||
import { describe, it, expect, vi } from 'vitest'
|
|
||||||
import { screen } from '@testing-library/dom'
|
|
||||||
import { render } from '@testing-library/react'
|
|
||||||
import userEvent from '@testing-library/user-event'
|
|
||||||
|
|
||||||
import { DynamicConfigForm } from '../DynamicConfigForm'
|
|
||||||
import { FieldHookRegistry } from '@/lib/field-hooks'
|
|
||||||
import type { ConfigSchema } from '@/types/config-schema'
|
|
||||||
import type { FieldHookComponentProps } from '@/lib/field-hooks'
|
|
||||||
|
|
||||||
describe('DynamicConfigForm', () => {
|
|
||||||
describe('basic rendering', () => {
|
|
||||||
it('renders simple fields', () => {
|
|
||||||
const schema: ConfigSchema = {
|
|
||||||
className: 'TestConfig',
|
|
||||||
classDoc: 'Test configuration',
|
|
||||||
fields: [
|
|
||||||
{
|
|
||||||
name: 'field1',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Field 1',
|
|
||||||
description: 'First field',
|
|
||||||
required: false,
|
|
||||||
default: 'value1',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: 'field2',
|
|
||||||
type: 'boolean',
|
|
||||||
label: 'Field 2',
|
|
||||||
description: 'Second field',
|
|
||||||
required: false,
|
|
||||||
default: false,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
const values = { field1: 'value1', field2: false }
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByText('Field 1')).toBeInTheDocument()
|
|
||||||
expect(screen.getByText('Field 2')).toBeInTheDocument()
|
|
||||||
expect(screen.getByText('First field')).toBeInTheDocument()
|
|
||||||
expect(screen.getByText('Second field')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('renders nested schema', () => {
|
|
||||||
const schema: ConfigSchema = {
|
|
||||||
className: 'MainConfig',
|
|
||||||
classDoc: 'Main configuration',
|
|
||||||
fields: [
|
|
||||||
{
|
|
||||||
name: 'top_field',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Top Field',
|
|
||||||
description: 'Top level field',
|
|
||||||
required: false,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
nested: {
|
|
||||||
sub_config: {
|
|
||||||
className: 'SubConfig',
|
|
||||||
classDoc: 'Sub configuration',
|
|
||||||
fields: [
|
|
||||||
{
|
|
||||||
name: 'nested_field',
|
|
||||||
type: 'number',
|
|
||||||
label: 'Nested Field',
|
|
||||||
description: 'Nested field',
|
|
||||||
required: false,
|
|
||||||
default: 42,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
const values = {
|
|
||||||
top_field: 'top',
|
|
||||||
sub_config: {
|
|
||||||
nested_field: 42,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByText('Top Field')).toBeInTheDocument()
|
|
||||||
expect(screen.getByText('Sub configuration')).toBeInTheDocument()
|
|
||||||
expect(screen.getByText('Nested Field')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('Hook system', () => {
|
|
||||||
it('renders Hook component in replace mode', () => {
|
|
||||||
const TestHookComponent: React.FC<FieldHookComponentProps> = ({ fieldPath, value }) => {
|
|
||||||
return <div data-testid="hook-component">Hook: {fieldPath} = {String(value)}</div>
|
|
||||||
}
|
|
||||||
|
|
||||||
const hooks = new FieldHookRegistry()
|
|
||||||
hooks.register('hooked_field', TestHookComponent, 'replace')
|
|
||||||
|
|
||||||
const schema: ConfigSchema = {
|
|
||||||
className: 'TestConfig',
|
|
||||||
classDoc: 'Test configuration',
|
|
||||||
fields: [
|
|
||||||
{
|
|
||||||
name: 'hooked_field',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Hooked Field',
|
|
||||||
description: 'A field with hook',
|
|
||||||
required: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: 'normal_field',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Normal Field',
|
|
||||||
description: 'A normal field',
|
|
||||||
required: false,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
const values = { hooked_field: 'test', normal_field: 'normal' }
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} hooks={hooks} />)
|
|
||||||
|
|
||||||
expect(screen.getByTestId('hook-component')).toBeInTheDocument()
|
|
||||||
expect(screen.getByText('Hook: hooked_field = test')).toBeInTheDocument()
|
|
||||||
expect(screen.queryByText('Hooked Field')).not.toBeInTheDocument()
|
|
||||||
expect(screen.getByText('Normal Field')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('renders Hook component in wrapper mode', () => {
|
|
||||||
const WrapperHookComponent: React.FC<FieldHookComponentProps> = ({ fieldPath, children }) => {
|
|
||||||
return (
|
|
||||||
<div data-testid="wrapper-hook">
|
|
||||||
<div>Wrapper for: {fieldPath}</div>
|
|
||||||
{children}
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
const hooks = new FieldHookRegistry()
|
|
||||||
hooks.register('wrapped_field', WrapperHookComponent, 'wrapper')
|
|
||||||
|
|
||||||
const schema: ConfigSchema = {
|
|
||||||
className: 'TestConfig',
|
|
||||||
classDoc: 'Test configuration',
|
|
||||||
fields: [
|
|
||||||
{
|
|
||||||
name: 'wrapped_field',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Wrapped Field',
|
|
||||||
description: 'A wrapped field',
|
|
||||||
required: false,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
const values = { wrapped_field: 'test' }
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} hooks={hooks} />)
|
|
||||||
|
|
||||||
expect(screen.getByTestId('wrapper-hook')).toBeInTheDocument()
|
|
||||||
expect(screen.getByText('Wrapper for: wrapped_field')).toBeInTheDocument()
|
|
||||||
expect(screen.getByText('Wrapped Field')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('passes correct props to Hook component', () => {
|
|
||||||
const TestHookComponent: React.FC<FieldHookComponentProps> = ({ fieldPath, value, onChange }) => {
|
|
||||||
return (
|
|
||||||
<div>
|
|
||||||
<div data-testid="field-path">{fieldPath}</div>
|
|
||||||
<div data-testid="field-value">{String(value)}</div>
|
|
||||||
<button onClick={() => onChange?.('new_value')}>Change</button>
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
const hooks = new FieldHookRegistry()
|
|
||||||
hooks.register('test_field', TestHookComponent, 'replace')
|
|
||||||
|
|
||||||
const schema: ConfigSchema = {
|
|
||||||
className: 'TestConfig',
|
|
||||||
classDoc: 'Test configuration',
|
|
||||||
fields: [
|
|
||||||
{
|
|
||||||
name: 'test_field',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Test Field',
|
|
||||||
description: 'A test field',
|
|
||||||
required: false,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
const values = { test_field: 'original' }
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} hooks={hooks} />)
|
|
||||||
|
|
||||||
expect(screen.getByTestId('field-path')).toHaveTextContent('test_field')
|
|
||||||
expect(screen.getByTestId('field-value')).toHaveTextContent('original')
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('onChange propagation', () => {
|
|
||||||
it('propagates onChange from simple field', async () => {
|
|
||||||
const schema: ConfigSchema = {
|
|
||||||
className: 'TestConfig',
|
|
||||||
classDoc: 'Test configuration',
|
|
||||||
fields: [
|
|
||||||
{
|
|
||||||
name: 'test_field',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Test Field',
|
|
||||||
description: 'A test field',
|
|
||||||
required: false,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
const values = { test_field: '' }
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
|
|
||||||
|
|
||||||
const input = screen.getByRole('textbox')
|
|
||||||
input.focus()
|
|
||||||
await userEvent.keyboard('Hello')
|
|
||||||
|
|
||||||
expect(onChange).toHaveBeenCalledTimes(5)
|
|
||||||
expect(onChange.mock.calls.every(call => call[0] === 'test_field')).toBe(true)
|
|
||||||
expect(onChange).toHaveBeenNthCalledWith(1, 'test_field', 'H')
|
|
||||||
expect(onChange).toHaveBeenNthCalledWith(5, 'test_field', 'o')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('propagates onChange from nested field with correct path', async () => {
|
|
||||||
const schema: ConfigSchema = {
|
|
||||||
className: 'MainConfig',
|
|
||||||
classDoc: 'Main configuration',
|
|
||||||
fields: [],
|
|
||||||
nested: {
|
|
||||||
sub_config: {
|
|
||||||
className: 'SubConfig',
|
|
||||||
classDoc: 'Sub configuration',
|
|
||||||
fields: [
|
|
||||||
{
|
|
||||||
name: 'nested_field',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Nested Field',
|
|
||||||
description: 'Nested field',
|
|
||||||
required: false,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
const values = {
|
|
||||||
sub_config: {
|
|
||||||
nested_field: '',
|
|
||||||
},
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
|
|
||||||
|
|
||||||
const input = screen.getByRole('textbox')
|
|
||||||
input.focus()
|
|
||||||
await userEvent.keyboard('Test')
|
|
||||||
|
|
||||||
expect(onChange).toHaveBeenCalledTimes(4)
|
|
||||||
expect(onChange.mock.calls.every(call => call[0] === 'sub_config.nested_field')).toBe(true)
|
|
||||||
expect(onChange).toHaveBeenNthCalledWith(1, 'sub_config.nested_field', 'T')
|
|
||||||
expect(onChange).toHaveBeenNthCalledWith(4, 'sub_config.nested_field', 't')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('propagates onChange from Hook component', async () => {
|
|
||||||
const TestHookComponent: React.FC<FieldHookComponentProps> = ({ onChange }) => {
|
|
||||||
return <button onClick={() => onChange?.('hook_value')}>Set Value</button>
|
|
||||||
}
|
|
||||||
|
|
||||||
const hooks = new FieldHookRegistry()
|
|
||||||
hooks.register('hooked_field', TestHookComponent, 'replace')
|
|
||||||
|
|
||||||
const schema: ConfigSchema = {
|
|
||||||
className: 'TestConfig',
|
|
||||||
classDoc: 'Test configuration',
|
|
||||||
fields: [
|
|
||||||
{
|
|
||||||
name: 'hooked_field',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Hooked Field',
|
|
||||||
description: 'A hooked field',
|
|
||||||
required: false,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
const values = { hooked_field: '' }
|
|
||||||
const onChange = vi.fn()
|
|
||||||
const user = userEvent.setup()
|
|
||||||
|
|
||||||
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} hooks={hooks} />)
|
|
||||||
|
|
||||||
await user.click(screen.getByRole('button'))
|
|
||||||
|
|
||||||
expect(onChange).toHaveBeenCalledWith('hooked_field', 'hook_value')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('renders nested Hook component with full field path', async () => {
|
|
||||||
const NestedHookComponent: React.FC<FieldHookComponentProps> = ({ fieldPath, onChange }) => {
|
|
||||||
return (
|
|
||||||
<button onClick={() => onChange?.([{ enabled: true }])}>
|
|
||||||
{fieldPath}
|
|
||||||
</button>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
const hooks = new FieldHookRegistry()
|
|
||||||
hooks.register('mcp.servers', NestedHookComponent, 'replace')
|
|
||||||
|
|
||||||
const schema: ConfigSchema = {
|
|
||||||
className: 'RootConfig',
|
|
||||||
classDoc: 'Root configuration',
|
|
||||||
fields: [],
|
|
||||||
nested: {
|
|
||||||
mcp: {
|
|
||||||
className: 'MCPConfig',
|
|
||||||
classDoc: 'MCP 配置',
|
|
||||||
fields: [
|
|
||||||
{
|
|
||||||
name: 'enable',
|
|
||||||
type: 'boolean',
|
|
||||||
label: '启用 MCP',
|
|
||||||
description: '是否启用 MCP',
|
|
||||||
required: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: 'servers',
|
|
||||||
type: 'array',
|
|
||||||
label: '服务器列表',
|
|
||||||
description: '复杂对象数组',
|
|
||||||
required: false,
|
|
||||||
items: {
|
|
||||||
type: 'object',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
nested: {
|
|
||||||
servers: {
|
|
||||||
className: 'MCPServerItemConfig',
|
|
||||||
classDoc: 'MCP 服务器项',
|
|
||||||
fields: [],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
const values = {
|
|
||||||
mcp: {
|
|
||||||
enable: true,
|
|
||||||
servers: [],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
const user = userEvent.setup()
|
|
||||||
|
|
||||||
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} hooks={hooks} />)
|
|
||||||
|
|
||||||
await user.click(screen.getByRole('button', { name: 'mcp.servers' }))
|
|
||||||
|
|
||||||
expect(onChange).toHaveBeenCalledWith('mcp.servers', [{ enabled: true }])
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('edge cases', () => {
|
|
||||||
it('renders with empty nested values', () => {
|
|
||||||
const schema: ConfigSchema = {
|
|
||||||
className: 'MainConfig',
|
|
||||||
classDoc: 'Main configuration',
|
|
||||||
fields: [],
|
|
||||||
nested: {
|
|
||||||
sub_config: {
|
|
||||||
className: 'SubConfig',
|
|
||||||
classDoc: 'Sub configuration',
|
|
||||||
fields: [
|
|
||||||
{
|
|
||||||
name: 'nested_field',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Nested Field',
|
|
||||||
description: 'Nested field',
|
|
||||||
required: false,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
const values = {}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByText('Sub configuration')).toBeInTheDocument()
|
|
||||||
expect(screen.getByText('Nested Field')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('uses default hook registry when not provided', () => {
|
|
||||||
const schema: ConfigSchema = {
|
|
||||||
className: 'TestConfig',
|
|
||||||
classDoc: 'Test configuration',
|
|
||||||
fields: [
|
|
||||||
{
|
|
||||||
name: 'test_field',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Test Field',
|
|
||||||
description: 'A test field',
|
|
||||||
required: false,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
const values = { test_field: 'test' }
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByText('Test Field')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
@@ -1,475 +0,0 @@
|
|||||||
import { describe, it, expect, vi } from 'vitest'
|
|
||||||
import { screen } from '@testing-library/dom'
|
|
||||||
import { render } from '@testing-library/react'
|
|
||||||
import userEvent from '@testing-library/user-event'
|
|
||||||
|
|
||||||
import { DynamicField } from '../DynamicField'
|
|
||||||
import type { FieldSchema } from '@/types/config-schema'
|
|
||||||
|
|
||||||
describe('DynamicField', () => {
|
|
||||||
describe('x-widget priority', () => {
|
|
||||||
it('renders Slider when x-widget is slider', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_slider',
|
|
||||||
type: 'number',
|
|
||||||
label: 'Test Slider',
|
|
||||||
description: 'A test slider',
|
|
||||||
required: false,
|
|
||||||
'x-widget': 'slider',
|
|
||||||
minValue: 0,
|
|
||||||
maxValue: 100,
|
|
||||||
default: 50,
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value={50} onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByText('Test Slider')).toBeInTheDocument()
|
|
||||||
expect(screen.getByRole('slider')).toBeInTheDocument()
|
|
||||||
expect(screen.getByText('50')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('renders Switch when x-widget is switch', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_switch',
|
|
||||||
type: 'boolean',
|
|
||||||
label: 'Test Switch',
|
|
||||||
description: 'A test switch',
|
|
||||||
required: false,
|
|
||||||
'x-widget': 'switch',
|
|
||||||
default: false,
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value={false} onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByText('Test Switch')).toBeInTheDocument()
|
|
||||||
expect(screen.getByRole('switch')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('renders Textarea when x-widget is textarea', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_textarea',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Test Textarea',
|
|
||||||
description: 'A test textarea',
|
|
||||||
required: false,
|
|
||||||
'x-widget': 'textarea',
|
|
||||||
default: 'Hello',
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value="Hello" onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByText('Test Textarea')).toBeInTheDocument()
|
|
||||||
expect(screen.getByRole('textbox')).toBeInTheDocument()
|
|
||||||
expect(screen.getByRole('textbox')).toHaveValue('Hello')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('renders Select when x-widget is select', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_select',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Test Select',
|
|
||||||
description: 'A test select',
|
|
||||||
required: false,
|
|
||||||
'x-widget': 'select',
|
|
||||||
options: ['Option 1', 'Option 2', 'Option 3'],
|
|
||||||
default: 'Option 1',
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value="Option 1" onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByText('Test Select')).toBeInTheDocument()
|
|
||||||
expect(screen.getByRole('combobox')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('renders placeholder for custom widget', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_custom',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Test Custom',
|
|
||||||
description: 'A test custom field',
|
|
||||||
required: false,
|
|
||||||
'x-widget': 'custom',
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value="" onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByText('Custom field requires Hook')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('renders number Input when x-widget is input but type is integer', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_integer_input_widget',
|
|
||||||
type: 'integer',
|
|
||||||
label: 'Test Integer Input Widget',
|
|
||||||
description: 'A numeric field rendered as input',
|
|
||||||
required: false,
|
|
||||||
'x-widget': 'input',
|
|
||||||
default: 0,
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value={2} onChange={onChange} />)
|
|
||||||
|
|
||||||
const input = screen.getByRole('spinbutton')
|
|
||||||
expect(input).toBeInTheDocument()
|
|
||||||
expect(input).toHaveValue(2)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('parses string values for numeric input widgets', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_string_number_input_widget',
|
|
||||||
type: 'integer',
|
|
||||||
label: 'Test String Number Input Widget',
|
|
||||||
description: 'A numeric field with legacy string value',
|
|
||||||
required: false,
|
|
||||||
'x-widget': 'input',
|
|
||||||
default: 0,
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value="2" onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByRole('spinbutton')).toHaveValue(2)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('type fallback', () => {
|
|
||||||
it('renders Input for string type', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_string',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Test String',
|
|
||||||
description: 'A test string',
|
|
||||||
required: false,
|
|
||||||
default: 'Hello',
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value="Hello" onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByRole('textbox')).toBeInTheDocument()
|
|
||||||
expect(screen.getByRole('textbox')).toHaveValue('Hello')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('renders Switch for boolean type', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_bool',
|
|
||||||
type: 'boolean',
|
|
||||||
label: 'Test Boolean',
|
|
||||||
description: 'A test boolean',
|
|
||||||
required: false,
|
|
||||||
default: true,
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value={true} onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByRole('switch')).toBeInTheDocument()
|
|
||||||
expect(screen.getByRole('switch')).toBeChecked()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('renders number Input for number type', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_number',
|
|
||||||
type: 'number',
|
|
||||||
label: 'Test Number',
|
|
||||||
description: 'A test number',
|
|
||||||
required: false,
|
|
||||||
default: 42,
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value={42} onChange={onChange} />)
|
|
||||||
|
|
||||||
const input = screen.getByRole('spinbutton')
|
|
||||||
expect(input).toBeInTheDocument()
|
|
||||||
expect(input).toHaveValue(42)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('renders number Input for integer type', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_integer',
|
|
||||||
type: 'integer',
|
|
||||||
label: 'Test Integer',
|
|
||||||
description: 'A test integer',
|
|
||||||
required: false,
|
|
||||||
default: 10,
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value={10} onChange={onChange} />)
|
|
||||||
|
|
||||||
const input = screen.getByRole('spinbutton')
|
|
||||||
expect(input).toBeInTheDocument()
|
|
||||||
expect(input).toHaveValue(10)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('renders Textarea for textarea type', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_textarea_type',
|
|
||||||
type: 'textarea',
|
|
||||||
label: 'Test Textarea Type',
|
|
||||||
description: 'A test textarea type',
|
|
||||||
required: false,
|
|
||||||
default: 'Long text',
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value="Long text" onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByRole('textbox')).toBeInTheDocument()
|
|
||||||
expect(screen.getByRole('textbox')).toHaveValue('Long text')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('renders Select for select type', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_select_type',
|
|
||||||
type: 'select',
|
|
||||||
label: 'Test Select Type',
|
|
||||||
description: 'A test select type',
|
|
||||||
required: false,
|
|
||||||
options: ['A', 'B', 'C'],
|
|
||||||
default: 'A',
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value="A" onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByRole('combobox')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('renders textarea editor for primitive array type', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_array',
|
|
||||||
type: 'array',
|
|
||||||
label: 'Test Array',
|
|
||||||
description: 'A test array',
|
|
||||||
required: false,
|
|
||||||
items: {
|
|
||||||
type: 'string',
|
|
||||||
},
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value={['a', 'b']} onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByRole('textbox')).toHaveValue('a\nb')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('renders key-value editor for object type', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_object',
|
|
||||||
type: 'object',
|
|
||||||
label: 'Test Object',
|
|
||||||
description: 'A test object',
|
|
||||||
required: false,
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value={{ foo: 'bar' }} onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByText('可视化编辑')).toBeInTheDocument()
|
|
||||||
expect(screen.getByDisplayValue('foo')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('onChange events', () => {
|
|
||||||
it('triggers onChange for Switch', async () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_switch',
|
|
||||||
type: 'boolean',
|
|
||||||
label: 'Test Switch',
|
|
||||||
description: 'A test switch',
|
|
||||||
required: false,
|
|
||||||
default: false,
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
const user = userEvent.setup()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value={false} onChange={onChange} />)
|
|
||||||
|
|
||||||
await user.click(screen.getByRole('switch'))
|
|
||||||
expect(onChange).toHaveBeenCalledWith(true)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('triggers onChange for Input', async () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_input',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Test Input',
|
|
||||||
description: 'A test input',
|
|
||||||
required: false,
|
|
||||||
default: '',
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value="" onChange={onChange} />)
|
|
||||||
|
|
||||||
const input = screen.getByRole('textbox')
|
|
||||||
input.focus()
|
|
||||||
await userEvent.keyboard('Hello')
|
|
||||||
|
|
||||||
expect(onChange).toHaveBeenCalledTimes(5)
|
|
||||||
expect(onChange).toHaveBeenNthCalledWith(1, 'H')
|
|
||||||
expect(onChange).toHaveBeenNthCalledWith(2, 'e')
|
|
||||||
expect(onChange).toHaveBeenNthCalledWith(3, 'l')
|
|
||||||
expect(onChange).toHaveBeenNthCalledWith(4, 'l')
|
|
||||||
expect(onChange).toHaveBeenNthCalledWith(5, 'o')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('triggers onChange for number Input', async () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_number',
|
|
||||||
type: 'number',
|
|
||||||
label: 'Test Number',
|
|
||||||
description: 'A test number',
|
|
||||||
required: false,
|
|
||||||
default: 0,
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
const user = userEvent.setup()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value={0} onChange={onChange} />)
|
|
||||||
|
|
||||||
const input = screen.getByRole('spinbutton')
|
|
||||||
await user.clear(input)
|
|
||||||
await user.type(input, '123')
|
|
||||||
expect(onChange).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('triggers numeric onChange for input widget with integer type', async () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_integer_input_widget_change',
|
|
||||||
type: 'integer',
|
|
||||||
label: 'Test Integer Input Widget Change',
|
|
||||||
description: 'A numeric field rendered as input',
|
|
||||||
required: false,
|
|
||||||
'x-widget': 'input',
|
|
||||||
default: 0,
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
const user = userEvent.setup()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value={0} onChange={onChange} />)
|
|
||||||
|
|
||||||
const input = screen.getByRole('spinbutton')
|
|
||||||
await user.clear(input)
|
|
||||||
await user.type(input, '5')
|
|
||||||
expect(onChange).toHaveBeenLastCalledWith(5)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('visual features', () => {
|
|
||||||
it('renders label with icon', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_icon',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Test Icon',
|
|
||||||
description: 'A test with icon',
|
|
||||||
required: false,
|
|
||||||
'x-icon': 'Settings',
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value="" onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByText('Test Icon')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('renders required indicator', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_required',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Test Required',
|
|
||||||
description: 'A required field',
|
|
||||||
required: true,
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value="" onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByText('*')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('renders description', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_desc',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Test Description',
|
|
||||||
description: 'This is a description',
|
|
||||||
required: false,
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value="" onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByText('This is a description')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('slider features', () => {
|
|
||||||
it('renders slider with min/max/step', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_slider_props',
|
|
||||||
type: 'number',
|
|
||||||
label: 'Test Slider Props',
|
|
||||||
description: 'A slider with props',
|
|
||||||
required: false,
|
|
||||||
'x-widget': 'slider',
|
|
||||||
minValue: 10,
|
|
||||||
maxValue: 50,
|
|
||||||
step: 5,
|
|
||||||
default: 25,
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value={25} onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByText('10')).toBeInTheDocument()
|
|
||||||
expect(screen.getByText('50')).toBeInTheDocument()
|
|
||||||
expect(screen.getByText('25')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('parses string values for slider widgets', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_slider_string_value',
|
|
||||||
type: 'number',
|
|
||||||
label: 'Test Slider String Value',
|
|
||||||
description: 'A slider with legacy string value',
|
|
||||||
required: false,
|
|
||||||
'x-widget': 'slider',
|
|
||||||
minValue: 0,
|
|
||||||
maxValue: 10,
|
|
||||||
default: 0,
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value="2.5" onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByText('2.5')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('select features', () => {
|
|
||||||
it('renders placeholder when no options', () => {
|
|
||||||
const schema: FieldSchema = {
|
|
||||||
name: 'test_select_no_options',
|
|
||||||
type: 'string',
|
|
||||||
label: 'Test Select No Options',
|
|
||||||
description: 'A select with no options',
|
|
||||||
required: false,
|
|
||||||
'x-widget': 'select',
|
|
||||||
}
|
|
||||||
const onChange = vi.fn()
|
|
||||||
|
|
||||||
render(<DynamicField schema={schema} value="" onChange={onChange} />)
|
|
||||||
|
|
||||||
expect(screen.getByText('No options available for select')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
@@ -1,253 +0,0 @@
|
|||||||
import { describe, it, expect, beforeEach } from 'vitest'
|
|
||||||
|
|
||||||
import { FieldHookRegistry } from '../field-hooks'
|
|
||||||
import type { FieldHookComponent } from '../field-hooks'
|
|
||||||
|
|
||||||
describe('FieldHookRegistry', () => {
|
|
||||||
let registry: FieldHookRegistry
|
|
||||||
|
|
||||||
beforeEach(() => {
|
|
||||||
registry = new FieldHookRegistry()
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('register', () => {
|
|
||||||
it('registers a hook with replace type', () => {
|
|
||||||
const component: FieldHookComponent = () => null
|
|
||||||
|
|
||||||
registry.register('test.field', component, 'replace')
|
|
||||||
|
|
||||||
expect(registry.has('test.field')).toBe(true)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('registers a hook with wrapper type', () => {
|
|
||||||
const component: FieldHookComponent = () => null
|
|
||||||
|
|
||||||
registry.register('test.field', component, 'wrapper')
|
|
||||||
|
|
||||||
expect(registry.has('test.field')).toBe(true)
|
|
||||||
const entry = registry.get('test.field')
|
|
||||||
expect(entry?.type).toBe('wrapper')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('defaults to replace type when not specified', () => {
|
|
||||||
const component: FieldHookComponent = () => null
|
|
||||||
|
|
||||||
registry.register('test.field', component)
|
|
||||||
|
|
||||||
const entry = registry.get('test.field')
|
|
||||||
expect(entry?.type).toBe('replace')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('overwrites existing hook for same field path', () => {
|
|
||||||
const component1: FieldHookComponent = () => null
|
|
||||||
const component2: FieldHookComponent = () => null
|
|
||||||
|
|
||||||
registry.register('test.field', component1, 'replace')
|
|
||||||
registry.register('test.field', component2, 'wrapper')
|
|
||||||
|
|
||||||
const entry = registry.get('test.field')
|
|
||||||
expect(entry?.component).toBe(component2)
|
|
||||||
expect(entry?.type).toBe('wrapper')
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('get', () => {
|
|
||||||
it('returns hook entry for registered field path', () => {
|
|
||||||
const component: FieldHookComponent = () => null
|
|
||||||
|
|
||||||
registry.register('test.field', component, 'replace')
|
|
||||||
|
|
||||||
const entry = registry.get('test.field')
|
|
||||||
expect(entry).toBeDefined()
|
|
||||||
expect(entry?.component).toBe(component)
|
|
||||||
expect(entry?.type).toBe('replace')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('returns undefined for unregistered field path', () => {
|
|
||||||
const entry = registry.get('nonexistent.field')
|
|
||||||
expect(entry).toBeUndefined()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('returns correct entry for nested field paths', () => {
|
|
||||||
const component: FieldHookComponent = () => null
|
|
||||||
|
|
||||||
registry.register('config.section.field', component, 'wrapper')
|
|
||||||
|
|
||||||
const entry = registry.get('config.section.field')
|
|
||||||
expect(entry).toBeDefined()
|
|
||||||
expect(entry?.type).toBe('wrapper')
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('has', () => {
|
|
||||||
it('returns true for registered field path', () => {
|
|
||||||
const component: FieldHookComponent = () => null
|
|
||||||
|
|
||||||
registry.register('test.field', component)
|
|
||||||
|
|
||||||
expect(registry.has('test.field')).toBe(true)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('returns false for unregistered field path', () => {
|
|
||||||
expect(registry.has('nonexistent.field')).toBe(false)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('returns false after unregistering', () => {
|
|
||||||
const component: FieldHookComponent = () => null
|
|
||||||
|
|
||||||
registry.register('test.field', component)
|
|
||||||
registry.unregister('test.field')
|
|
||||||
|
|
||||||
expect(registry.has('test.field')).toBe(false)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('unregister', () => {
|
|
||||||
it('removes a registered hook', () => {
|
|
||||||
const component: FieldHookComponent = () => null
|
|
||||||
|
|
||||||
registry.register('test.field', component)
|
|
||||||
expect(registry.has('test.field')).toBe(true)
|
|
||||||
|
|
||||||
registry.unregister('test.field')
|
|
||||||
expect(registry.has('test.field')).toBe(false)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('does not throw when unregistering non-existent hook', () => {
|
|
||||||
expect(() => registry.unregister('nonexistent.field')).not.toThrow()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('only removes specified hook, not others', () => {
|
|
||||||
const component1: FieldHookComponent = () => null
|
|
||||||
const component2: FieldHookComponent = () => null
|
|
||||||
|
|
||||||
registry.register('field1', component1)
|
|
||||||
registry.register('field2', component2)
|
|
||||||
|
|
||||||
registry.unregister('field1')
|
|
||||||
|
|
||||||
expect(registry.has('field1')).toBe(false)
|
|
||||||
expect(registry.has('field2')).toBe(true)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('clear', () => {
|
|
||||||
it('removes all registered hooks', () => {
|
|
||||||
const component1: FieldHookComponent = () => null
|
|
||||||
const component2: FieldHookComponent = () => null
|
|
||||||
const component3: FieldHookComponent = () => null
|
|
||||||
|
|
||||||
registry.register('field1', component1)
|
|
||||||
registry.register('field2', component2)
|
|
||||||
registry.register('field3', component3)
|
|
||||||
|
|
||||||
expect(registry.getAllPaths()).toHaveLength(3)
|
|
||||||
|
|
||||||
registry.clear()
|
|
||||||
|
|
||||||
expect(registry.getAllPaths()).toHaveLength(0)
|
|
||||||
expect(registry.has('field1')).toBe(false)
|
|
||||||
expect(registry.has('field2')).toBe(false)
|
|
||||||
expect(registry.has('field3')).toBe(false)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('works correctly on empty registry', () => {
|
|
||||||
expect(() => registry.clear()).not.toThrow()
|
|
||||||
expect(registry.getAllPaths()).toHaveLength(0)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('getAllPaths', () => {
|
|
||||||
it('returns empty array when no hooks registered', () => {
|
|
||||||
expect(registry.getAllPaths()).toEqual([])
|
|
||||||
})
|
|
||||||
|
|
||||||
it('returns all registered field paths', () => {
|
|
||||||
const component: FieldHookComponent = () => null
|
|
||||||
|
|
||||||
registry.register('field1', component)
|
|
||||||
registry.register('field2', component)
|
|
||||||
registry.register('field3', component)
|
|
||||||
|
|
||||||
const paths = registry.getAllPaths()
|
|
||||||
expect(paths).toHaveLength(3)
|
|
||||||
expect(paths).toContain('field1')
|
|
||||||
expect(paths).toContain('field2')
|
|
||||||
expect(paths).toContain('field3')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('returns updated paths after unregister', () => {
|
|
||||||
const component: FieldHookComponent = () => null
|
|
||||||
|
|
||||||
registry.register('field1', component)
|
|
||||||
registry.register('field2', component)
|
|
||||||
registry.register('field3', component)
|
|
||||||
|
|
||||||
registry.unregister('field2')
|
|
||||||
|
|
||||||
const paths = registry.getAllPaths()
|
|
||||||
expect(paths).toHaveLength(2)
|
|
||||||
expect(paths).toContain('field1')
|
|
||||||
expect(paths).toContain('field3')
|
|
||||||
expect(paths).not.toContain('field2')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('handles nested field paths correctly', () => {
|
|
||||||
const component: FieldHookComponent = () => null
|
|
||||||
|
|
||||||
registry.register('config.chat.enabled', component)
|
|
||||||
registry.register('config.chat.model', component)
|
|
||||||
registry.register('config.api.key', component)
|
|
||||||
|
|
||||||
const paths = registry.getAllPaths()
|
|
||||||
expect(paths).toHaveLength(3)
|
|
||||||
expect(paths).toContain('config.chat.enabled')
|
|
||||||
expect(paths).toContain('config.chat.model')
|
|
||||||
expect(paths).toContain('config.api.key')
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('integration scenarios', () => {
|
|
||||||
it('supports full lifecycle of multiple hooks', () => {
|
|
||||||
const replaceComponent: FieldHookComponent = () => null
|
|
||||||
const wrapperComponent: FieldHookComponent = () => null
|
|
||||||
|
|
||||||
registry.register('field1', replaceComponent, 'replace')
|
|
||||||
registry.register('field2', wrapperComponent, 'wrapper')
|
|
||||||
|
|
||||||
expect(registry.getAllPaths()).toHaveLength(2)
|
|
||||||
|
|
||||||
const entry1 = registry.get('field1')
|
|
||||||
expect(entry1?.type).toBe('replace')
|
|
||||||
expect(entry1?.component).toBe(replaceComponent)
|
|
||||||
|
|
||||||
const entry2 = registry.get('field2')
|
|
||||||
expect(entry2?.type).toBe('wrapper')
|
|
||||||
expect(entry2?.component).toBe(wrapperComponent)
|
|
||||||
|
|
||||||
registry.unregister('field1')
|
|
||||||
expect(registry.getAllPaths()).toHaveLength(1)
|
|
||||||
expect(registry.has('field2')).toBe(true)
|
|
||||||
|
|
||||||
registry.clear()
|
|
||||||
expect(registry.getAllPaths()).toHaveLength(0)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('handles rapid register/unregister cycles', () => {
|
|
||||||
const component: FieldHookComponent = () => null
|
|
||||||
|
|
||||||
for (let i = 0; i < 100; i++) {
|
|
||||||
registry.register(`field${i}`, component)
|
|
||||||
}
|
|
||||||
expect(registry.getAllPaths()).toHaveLength(100)
|
|
||||||
|
|
||||||
for (let i = 0; i < 50; i++) {
|
|
||||||
registry.unregister(`field${i}`)
|
|
||||||
}
|
|
||||||
expect(registry.getAllPaths()).toHaveLength(50)
|
|
||||||
|
|
||||||
registry.clear()
|
|
||||||
expect(registry.getAllPaths()).toHaveLength(0)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
import { render, screen } from '@testing-library/react'
|
|
||||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
|
||||||
import type { ReactNode } from 'react'
|
|
||||||
|
|
||||||
import { PluginConfigPage } from '../plugin-config'
|
|
||||||
import * as pluginApi from '@/lib/plugin-api'
|
|
||||||
|
|
||||||
const toastMock = vi.fn()
|
|
||||||
|
|
||||||
vi.mock('@/hooks/use-toast', () => ({
|
|
||||||
useToast: () => ({ toast: toastMock }),
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('@/lib/restart-context', () => ({
|
|
||||||
RestartProvider: ({ children }: { children: ReactNode }) => <>{children}</>,
|
|
||||||
useRestart: () => ({
|
|
||||||
showRestartPrompt: false,
|
|
||||||
markRestartRequired: vi.fn(),
|
|
||||||
clearRestartRequired: vi.fn(),
|
|
||||||
}),
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('@/components/restart-overlay', () => ({
|
|
||||||
RestartOverlay: () => null,
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('@/components', () => ({
|
|
||||||
CodeEditor: ({ value }: { value: string }) => <pre>{value}</pre>,
|
|
||||||
ListFieldEditor: () => <div>list-field-editor</div>,
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('@/lib/plugin-api', () => ({
|
|
||||||
getInstalledPlugins: vi.fn(),
|
|
||||||
getPluginConfigSchema: vi.fn(),
|
|
||||||
getPluginConfig: vi.fn(),
|
|
||||||
getPluginConfigRaw: vi.fn(),
|
|
||||||
updatePluginConfig: vi.fn(),
|
|
||||||
updatePluginConfigRaw: vi.fn(),
|
|
||||||
resetPluginConfig: vi.fn(),
|
|
||||||
togglePlugin: vi.fn(),
|
|
||||||
}))
|
|
||||||
|
|
||||||
describe('PluginConfigPage', () => {
|
|
||||||
beforeEach(() => {
|
|
||||||
toastMock.mockReset()
|
|
||||||
vi.mocked(pluginApi.getInstalledPlugins).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
data: [
|
|
||||||
{
|
|
||||||
id: 'test.emoji',
|
|
||||||
path: '/plugins/test_emoji',
|
|
||||||
manifest: {
|
|
||||||
manifest_version: 2,
|
|
||||||
name: 'Emoji Plugin',
|
|
||||||
version: '1.0.0',
|
|
||||||
description: 'emoji tools',
|
|
||||||
author: { name: 'tester' },
|
|
||||||
license: 'MIT',
|
|
||||||
host_application: { min_version: '1.0.0' },
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
})
|
|
||||||
vi.mocked(pluginApi.getPluginConfigSchema).mockResolvedValue({} as never)
|
|
||||||
vi.mocked(pluginApi.getPluginConfig).mockResolvedValue({} as never)
|
|
||||||
vi.mocked(pluginApi.getPluginConfigRaw).mockResolvedValue({} as never)
|
|
||||||
vi.mocked(pluginApi.updatePluginConfig).mockResolvedValue({} as never)
|
|
||||||
vi.mocked(pluginApi.updatePluginConfigRaw).mockResolvedValue({} as never)
|
|
||||||
vi.mocked(pluginApi.resetPluginConfig).mockResolvedValue({} as never)
|
|
||||||
vi.mocked(pluginApi.togglePlugin).mockResolvedValue({} as never)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('shows real plugins and no longer surfaces A_Memorix in plugin config list', async () => {
|
|
||||||
render(<PluginConfigPage />)
|
|
||||||
|
|
||||||
expect(await screen.findByText('Emoji Plugin')).toBeInTheDocument()
|
|
||||||
expect(screen.getByText('点击插件查看和编辑配置')).toBeInTheDocument()
|
|
||||||
expect(screen.queryByText(/A_Memorix/i)).not.toBeInTheDocument()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
@@ -1,802 +0,0 @@
|
|||||||
import { act, render, screen, waitFor, within } from '@testing-library/react'
|
|
||||||
import userEvent from '@testing-library/user-event'
|
|
||||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
|
||||||
|
|
||||||
import { KnowledgeBasePage } from '../knowledge-base'
|
|
||||||
import * as memoryApi from '@/lib/memory-api'
|
|
||||||
|
|
||||||
const navigateMock = vi.fn()
|
|
||||||
const toastMock = vi.fn()
|
|
||||||
|
|
||||||
vi.mock('@tanstack/react-router', () => ({
|
|
||||||
useNavigate: () => navigateMock,
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('@/hooks/use-toast', () => ({
|
|
||||||
useToast: () => ({ toast: toastMock }),
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('@/components', () => ({
|
|
||||||
CodeEditor: ({ value }: { value: string }) => <pre data-testid="code-editor">{value}</pre>,
|
|
||||||
MarkdownRenderer: ({ content }: { content: string }) => <div>{content}</div>,
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('@/components/memory/MemoryConfigEditor', () => ({
|
|
||||||
MemoryConfigEditor: () => <div data-testid="memory-config-editor">memory-config-editor</div>,
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('@/components/memory/MemoryDeleteDialog', () => ({
|
|
||||||
MemoryDeleteDialog: ({
|
|
||||||
open,
|
|
||||||
onExecute,
|
|
||||||
onRestore,
|
|
||||||
preview,
|
|
||||||
result,
|
|
||||||
}: {
|
|
||||||
open: boolean
|
|
||||||
preview?: { mode?: string; item_count?: number } | null
|
|
||||||
result?: { operation_id?: string } | null
|
|
||||||
onExecute?: () => void
|
|
||||||
onRestore?: () => void
|
|
||||||
}) => (
|
|
||||||
open ? (
|
|
||||||
<div data-testid="memory-delete-dialog">
|
|
||||||
<div>{`preview:${preview?.mode ?? 'none'}:${preview?.item_count ?? 0}`}</div>
|
|
||||||
<div>{`result:${result?.operation_id ?? 'none'}`}</div>
|
|
||||||
<button type="button" onClick={onExecute}>执行删除</button>
|
|
||||||
<button type="button" onClick={onRestore}>执行恢复</button>
|
|
||||||
</div>
|
|
||||||
) : null
|
|
||||||
),
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('@/lib/memory-api', () => ({
|
|
||||||
getMemoryConfigSchema: vi.fn(),
|
|
||||||
getMemoryConfig: vi.fn(),
|
|
||||||
getMemoryConfigRaw: vi.fn(),
|
|
||||||
getMemoryRuntimeConfig: vi.fn(),
|
|
||||||
getMemoryImportGuide: vi.fn(),
|
|
||||||
getMemoryImportSettings: vi.fn(),
|
|
||||||
getMemoryImportPathAliases: vi.fn(),
|
|
||||||
getMemoryImportTasks: vi.fn(),
|
|
||||||
getMemoryImportTask: vi.fn(),
|
|
||||||
getMemoryImportTaskChunks: vi.fn(),
|
|
||||||
createMemoryUploadImport: vi.fn(),
|
|
||||||
createMemoryPasteImport: vi.fn(),
|
|
||||||
createMemoryRawScanImport: vi.fn(),
|
|
||||||
createMemoryLpmmOpenieImport: vi.fn(),
|
|
||||||
createMemoryLpmmConvertImport: vi.fn(),
|
|
||||||
createMemoryTemporalBackfillImport: vi.fn(),
|
|
||||||
createMemoryMaibotMigrationImport: vi.fn(),
|
|
||||||
cancelMemoryImportTask: vi.fn(),
|
|
||||||
retryMemoryImportTask: vi.fn(),
|
|
||||||
resolveMemoryImportPath: vi.fn(),
|
|
||||||
refreshMemoryRuntimeSelfCheck: vi.fn(),
|
|
||||||
updateMemoryConfig: vi.fn(),
|
|
||||||
updateMemoryConfigRaw: vi.fn(),
|
|
||||||
getMemoryTuningProfile: vi.fn(),
|
|
||||||
getMemoryTuningTasks: vi.fn(),
|
|
||||||
createMemoryTuningTask: vi.fn(),
|
|
||||||
applyBestMemoryTuningProfile: vi.fn(),
|
|
||||||
getMemorySources: vi.fn(),
|
|
||||||
getMemoryDeleteOperations: vi.fn(),
|
|
||||||
getMemoryDeleteOperation: vi.fn(),
|
|
||||||
getMemoryFeedbackCorrections: vi.fn(),
|
|
||||||
getMemoryFeedbackCorrection: vi.fn(),
|
|
||||||
previewMemoryDelete: vi.fn(),
|
|
||||||
executeMemoryDelete: vi.fn(),
|
|
||||||
restoreMemoryDelete: vi.fn(),
|
|
||||||
rollbackMemoryFeedbackCorrection: vi.fn(),
|
|
||||||
}))
|
|
||||||
|
|
||||||
function mockImportTask(taskId: string, status: string = 'running'): memoryApi.MemoryImportTaskPayload {
|
|
||||||
return {
|
|
||||||
task_id: taskId,
|
|
||||||
source: 'webui',
|
|
||||||
status,
|
|
||||||
current_step: status === 'completed' ? 'completed' : 'running',
|
|
||||||
total_chunks: 120,
|
|
||||||
done_chunks: status === 'completed' ? 120 : 36,
|
|
||||||
failed_chunks: status === 'completed' ? 0 : 2,
|
|
||||||
cancelled_chunks: 0,
|
|
||||||
progress: status === 'completed' ? 100 : 30,
|
|
||||||
error: '',
|
|
||||||
file_count: 2,
|
|
||||||
created_at: 1_710_000_000,
|
|
||||||
started_at: 1_710_000_001,
|
|
||||||
finished_at: status === 'completed' ? 1_710_000_099 : null,
|
|
||||||
updated_at: 1_710_000_100,
|
|
||||||
task_kind: 'paste',
|
|
||||||
params: {},
|
|
||||||
files: [],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function mockImportDetail(taskId: string): memoryApi.MemoryImportTaskPayload {
|
|
||||||
return {
|
|
||||||
...mockImportTask(taskId),
|
|
||||||
files: [
|
|
||||||
{
|
|
||||||
file_id: 'file-alpha',
|
|
||||||
name: 'alpha.txt',
|
|
||||||
source_kind: 'paste',
|
|
||||||
input_mode: 'text',
|
|
||||||
status: 'running',
|
|
||||||
current_step: 'running',
|
|
||||||
detected_strategy_type: 'auto',
|
|
||||||
total_chunks: 80,
|
|
||||||
done_chunks: 30,
|
|
||||||
failed_chunks: 1,
|
|
||||||
cancelled_chunks: 0,
|
|
||||||
progress: 37.5,
|
|
||||||
error: '',
|
|
||||||
created_at: 1_710_000_000,
|
|
||||||
updated_at: 1_710_000_100,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
file_id: 'file-beta',
|
|
||||||
name: 'beta.txt',
|
|
||||||
source_kind: 'paste',
|
|
||||||
input_mode: 'text',
|
|
||||||
status: 'failed',
|
|
||||||
current_step: 'extracting',
|
|
||||||
detected_strategy_type: 'auto',
|
|
||||||
total_chunks: 40,
|
|
||||||
done_chunks: 6,
|
|
||||||
failed_chunks: 4,
|
|
||||||
cancelled_chunks: 0,
|
|
||||||
progress: 25,
|
|
||||||
error: 'mock error',
|
|
||||||
created_at: 1_710_000_000,
|
|
||||||
updated_at: 1_710_000_100,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function mockImportCompletedWithErrorsDetail(taskId: string): memoryApi.MemoryImportTaskPayload {
|
|
||||||
return {
|
|
||||||
...mockImportDetail(taskId),
|
|
||||||
status: 'completed_with_errors',
|
|
||||||
current_step: 'completed_with_errors',
|
|
||||||
total_chunks: 12,
|
|
||||||
done_chunks: 9,
|
|
||||||
failed_chunks: 3,
|
|
||||||
cancelled_chunks: 0,
|
|
||||||
progress: 75,
|
|
||||||
files: [
|
|
||||||
{
|
|
||||||
file_id: 'file-error',
|
|
||||||
name: 'error.txt',
|
|
||||||
source_kind: 'paste',
|
|
||||||
input_mode: 'text',
|
|
||||||
status: 'failed',
|
|
||||||
current_step: 'failed',
|
|
||||||
detected_strategy_type: 'auto',
|
|
||||||
total_chunks: 12,
|
|
||||||
done_chunks: 9,
|
|
||||||
failed_chunks: 3,
|
|
||||||
cancelled_chunks: 0,
|
|
||||||
progress: 75,
|
|
||||||
error: 'mock error',
|
|
||||||
created_at: 1_710_000_000,
|
|
||||||
updated_at: 1_710_000_100,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
describe('KnowledgeBasePage import workflow', () => {
|
|
||||||
beforeEach(() => {
|
|
||||||
navigateMock.mockReset()
|
|
||||||
toastMock.mockReset()
|
|
||||||
|
|
||||||
vi.mocked(memoryApi.getMemoryConfigSchema).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
path: 'config/a_memorix.toml',
|
|
||||||
schema: {
|
|
||||||
plugin_id: 'a_memorix',
|
|
||||||
plugin_info: {
|
|
||||||
name: 'A_Memorix',
|
|
||||||
version: '2.0.0',
|
|
||||||
description: '长期记忆子系统',
|
|
||||||
author: 'A_Dawn',
|
|
||||||
},
|
|
||||||
_note: 'raw-only 字段仍可通过 TOML 编辑',
|
|
||||||
layout: {
|
|
||||||
type: 'tabs',
|
|
||||||
tabs: [{ id: 'basic', title: '基础', sections: ['plugin'], order: 1 }],
|
|
||||||
},
|
|
||||||
sections: {
|
|
||||||
plugin: {
|
|
||||||
name: 'plugin',
|
|
||||||
title: '子系统状态',
|
|
||||||
collapsed: false,
|
|
||||||
order: 1,
|
|
||||||
fields: {},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.getMemoryConfig).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
path: 'config/a_memorix.toml',
|
|
||||||
config: { plugin: { enabled: true } },
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.getMemoryConfigRaw).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
path: 'config/a_memorix.toml',
|
|
||||||
config: '[plugin]\nenabled = true\n',
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.getMemoryRuntimeConfig).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
config: { plugin: { enabled: true } },
|
|
||||||
data_dir: 'data/plugins/a-dawn.a-memorix',
|
|
||||||
embedding_dimension: 1024,
|
|
||||||
auto_save: true,
|
|
||||||
relation_vectors_enabled: false,
|
|
||||||
runtime_ready: true,
|
|
||||||
embedding_degraded: false,
|
|
||||||
embedding_degraded_reason: '',
|
|
||||||
embedding_degraded_since: null,
|
|
||||||
embedding_last_check: null,
|
|
||||||
paragraph_vector_backfill_pending: 2,
|
|
||||||
paragraph_vector_backfill_running: 0,
|
|
||||||
paragraph_vector_backfill_failed: 1,
|
|
||||||
paragraph_vector_backfill_done: 3,
|
|
||||||
})
|
|
||||||
|
|
||||||
vi.mocked(memoryApi.getMemoryImportGuide).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
content: '# 导入指南\n导入说明',
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.getMemoryImportSettings).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
settings: {
|
|
||||||
max_paste_chars: 200_000,
|
|
||||||
max_file_concurrency: 8,
|
|
||||||
max_chunk_concurrency: 16,
|
|
||||||
default_file_concurrency: 2,
|
|
||||||
default_chunk_concurrency: 4,
|
|
||||||
poll_interval_ms: 60_000,
|
|
||||||
maibot_source_db_default: 'data/maibot.db',
|
|
||||||
},
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.getMemoryImportPathAliases).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
path_aliases: {
|
|
||||||
lpmm: 'data/lpmm',
|
|
||||||
plugin_data: 'data/plugins/a-dawn.a-memorix',
|
|
||||||
raw: 'data/raw',
|
|
||||||
},
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.getMemoryImportTasks).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
items: [
|
|
||||||
mockImportTask('import-run-1', 'running'),
|
|
||||||
mockImportTask('import-queued-1', 'queued'),
|
|
||||||
mockImportTask('import-done-1', 'completed'),
|
|
||||||
],
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.getMemoryImportTask).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
task: mockImportDetail('import-run-1'),
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.getMemoryImportTaskChunks).mockImplementation(async (_taskId, fileId, offset = 0) => ({
|
|
||||||
success: true,
|
|
||||||
task_id: 'import-run-1',
|
|
||||||
file_id: fileId,
|
|
||||||
offset,
|
|
||||||
limit: 50,
|
|
||||||
total: 120,
|
|
||||||
items: [
|
|
||||||
{
|
|
||||||
chunk_id: `${fileId}-${offset + 0}`,
|
|
||||||
index: offset + 0,
|
|
||||||
chunk_type: 'text',
|
|
||||||
status: 'running',
|
|
||||||
step: 'extracting',
|
|
||||||
failed_at: '',
|
|
||||||
retryable: true,
|
|
||||||
error: '',
|
|
||||||
progress: 50,
|
|
||||||
content_preview: `chunk-preview-${offset + 0}`,
|
|
||||||
updated_at: 1_710_000_111,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mocked(memoryApi.createMemoryUploadImport).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
task: mockImportTask('upload-task-1', 'queued'),
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.createMemoryPasteImport).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
task: mockImportTask('paste-task-1', 'queued'),
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.createMemoryRawScanImport).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
task: mockImportTask('raw-task-1', 'queued'),
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.createMemoryLpmmOpenieImport).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
task: mockImportTask('openie-task-1', 'queued'),
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.createMemoryLpmmConvertImport).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
task: mockImportTask('convert-task-1', 'queued'),
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.createMemoryTemporalBackfillImport).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
task: mockImportTask('backfill-task-1', 'queued'),
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.createMemoryMaibotMigrationImport).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
task: mockImportTask('migration-task-1', 'queued'),
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.cancelMemoryImportTask).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
task: mockImportTask('import-run-1', 'cancel_requested'),
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.retryMemoryImportTask).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
task: mockImportTask('retry-task-1', 'queued'),
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.resolveMemoryImportPath).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
alias: 'raw',
|
|
||||||
relative_path: 'exports',
|
|
||||||
resolved_path: 'D:/Dev/rdev/MaiBot/data/raw/exports',
|
|
||||||
exists: true,
|
|
||||||
is_file: false,
|
|
||||||
is_dir: true,
|
|
||||||
})
|
|
||||||
|
|
||||||
vi.mocked(memoryApi.getMemoryTuningProfile).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
profile: { retrieval: { top_k: 10 } },
|
|
||||||
toml: '[retrieval]\ntop_k = 10\n',
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.getMemoryTuningTasks).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
items: [{ task_id: 'tune-1', status: 'done' }],
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.createMemoryTuningTask).mockResolvedValue({ success: true } as never)
|
|
||||||
vi.mocked(memoryApi.applyBestMemoryTuningProfile).mockResolvedValue({ success: true } as never)
|
|
||||||
|
|
||||||
vi.mocked(memoryApi.getMemorySources).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
items: [{ source: 'demo-1', paragraph_count: 2, relation_count: 1 }],
|
|
||||||
count: 1,
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.getMemoryDeleteOperations).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
items: [
|
|
||||||
{
|
|
||||||
operation_id: 'del-1',
|
|
||||||
mode: 'source',
|
|
||||||
status: 'executed',
|
|
||||||
summary: { counts: { paragraphs: 2, relations: 1, sources: 1 } },
|
|
||||||
},
|
|
||||||
],
|
|
||||||
count: 1,
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.getMemoryDeleteOperation).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
operation: {
|
|
||||||
operation_id: 'del-1',
|
|
||||||
mode: 'source',
|
|
||||||
status: 'executed',
|
|
||||||
selector: { sources: ['demo-1'] },
|
|
||||||
summary: { counts: { paragraphs: 2, relations: 1, sources: 1 }, sources: ['demo-1'] },
|
|
||||||
items: [],
|
|
||||||
},
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.getMemoryFeedbackCorrections).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
items: [
|
|
||||||
{
|
|
||||||
task_id: 11,
|
|
||||||
query_tool_id: 'tool-query-11',
|
|
||||||
session_id: 'session-1',
|
|
||||||
query_text: '测试用户最喜欢的颜色是什么',
|
|
||||||
query_timestamp: 1_710_000_010,
|
|
||||||
task_status: 'applied',
|
|
||||||
decision: 'correct',
|
|
||||||
decision_confidence: 0.97,
|
|
||||||
feedback_message_count: 1,
|
|
||||||
rollback_status: 'none',
|
|
||||||
affected_counts: {
|
|
||||||
relations: 1,
|
|
||||||
stale_paragraphs: 1,
|
|
||||||
episode_sources: 2,
|
|
||||||
profile_person_ids: 1,
|
|
||||||
correction_paragraphs: 1,
|
|
||||||
corrected_relations: 1,
|
|
||||||
},
|
|
||||||
created_at: 1_710_000_011,
|
|
||||||
updated_at: 1_710_000_012,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
count: 1,
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.getMemoryFeedbackCorrection).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
task: {
|
|
||||||
task_id: 11,
|
|
||||||
query_tool_id: 'tool-query-11',
|
|
||||||
session_id: 'session-1',
|
|
||||||
query_text: '测试用户最喜欢的颜色是什么',
|
|
||||||
query_timestamp: 1_710_000_010,
|
|
||||||
task_status: 'applied',
|
|
||||||
decision: 'correct',
|
|
||||||
decision_confidence: 0.97,
|
|
||||||
feedback_message_count: 1,
|
|
||||||
rollback_status: 'none',
|
|
||||||
affected_counts: {
|
|
||||||
relations: 1,
|
|
||||||
stale_paragraphs: 1,
|
|
||||||
episode_sources: 2,
|
|
||||||
profile_person_ids: 1,
|
|
||||||
correction_paragraphs: 1,
|
|
||||||
corrected_relations: 1,
|
|
||||||
},
|
|
||||||
query_snapshot: { query: '测试用户最喜欢的颜色是什么', hits: [{ hash: 'paragraph-1' }] },
|
|
||||||
decision_payload: { decision: 'correct', confidence: 0.97 },
|
|
||||||
rollback_plan_summary: {
|
|
||||||
forgotten_relations: [{ hash: 'rel-old', subject: '测试用户', predicate: '最喜欢的颜色是', object: '蓝色' }],
|
|
||||||
corrected_write: {
|
|
||||||
paragraph_hashes: ['paragraph-new'],
|
|
||||||
corrected_relations: [{ hash: 'rel-new', subject: '测试用户', predicate: '最喜欢的颜色是', object: '绿色' }],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
rollback_result: {},
|
|
||||||
action_logs: [
|
|
||||||
{
|
|
||||||
id: 1,
|
|
||||||
task_id: 11,
|
|
||||||
query_tool_id: 'tool-query-11',
|
|
||||||
action_type: 'forget_relation',
|
|
||||||
target_hash: 'rel-old',
|
|
||||||
reason: '用户明确纠正为绿色',
|
|
||||||
before_payload: { hash: 'rel-old', subject: '测试用户', predicate: '最喜欢的颜色是', object: '蓝色' },
|
|
||||||
after_payload: { is_inactive: true },
|
|
||||||
created_at: 1_710_000_013,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
created_at: 1_710_000_011,
|
|
||||||
updated_at: 1_710_000_012,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.previewMemoryDelete).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
mode: 'source',
|
|
||||||
selector: { sources: ['demo-1'] },
|
|
||||||
counts: { sources: 1, paragraphs: 2, relations: 1 },
|
|
||||||
sources: ['demo-1'],
|
|
||||||
items: [{ item_type: 'paragraph', item_hash: 'p-1', label: 'demo-1' }],
|
|
||||||
item_count: 1,
|
|
||||||
dry_run: true,
|
|
||||||
} as never)
|
|
||||||
vi.mocked(memoryApi.executeMemoryDelete).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
mode: 'source',
|
|
||||||
operation_id: 'del-2',
|
|
||||||
counts: { sources: 1, paragraphs: 2, relations: 1 },
|
|
||||||
sources: ['demo-1'],
|
|
||||||
deleted_count: 4,
|
|
||||||
deleted_entity_count: 0,
|
|
||||||
deleted_relation_count: 1,
|
|
||||||
deleted_paragraph_count: 2,
|
|
||||||
deleted_source_count: 1,
|
|
||||||
} as never)
|
|
||||||
vi.mocked(memoryApi.restoreMemoryDelete).mockResolvedValue({ success: true } as never)
|
|
||||||
vi.mocked(memoryApi.rollbackMemoryFeedbackCorrection).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
result: { restored_relation_hashes: ['rel-old'] },
|
|
||||||
task: {
|
|
||||||
task_id: 11,
|
|
||||||
query_tool_id: 'tool-query-11',
|
|
||||||
session_id: 'session-1',
|
|
||||||
query_text: '测试用户最喜欢的颜色是什么',
|
|
||||||
query_timestamp: 1_710_000_010,
|
|
||||||
task_status: 'applied',
|
|
||||||
decision: 'correct',
|
|
||||||
decision_confidence: 0.97,
|
|
||||||
feedback_message_count: 1,
|
|
||||||
rollback_status: 'rolled_back',
|
|
||||||
affected_counts: {
|
|
||||||
relations: 1,
|
|
||||||
stale_paragraphs: 1,
|
|
||||||
episode_sources: 2,
|
|
||||||
profile_person_ids: 1,
|
|
||||||
correction_paragraphs: 1,
|
|
||||||
corrected_relations: 1,
|
|
||||||
},
|
|
||||||
query_snapshot: { query: '测试用户最喜欢的颜色是什么', hits: [{ hash: 'paragraph-1' }] },
|
|
||||||
decision_payload: { decision: 'correct', confidence: 0.97 },
|
|
||||||
rollback_plan_summary: {},
|
|
||||||
rollback_result: { restored_relation_hashes: ['rel-old'] },
|
|
||||||
action_logs: [],
|
|
||||||
created_at: 1_710_000_011,
|
|
||||||
updated_at: 1_710_000_012,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.refreshMemoryRuntimeSelfCheck).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
report: { ok: true },
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.updateMemoryConfig).mockResolvedValue({ success: true } as never)
|
|
||||||
vi.mocked(memoryApi.updateMemoryConfigRaw).mockResolvedValue({ success: true } as never)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('loads import settings/guide/tasks on first render', async () => {
|
|
||||||
const user = userEvent.setup()
|
|
||||||
render(<KnowledgeBasePage />)
|
|
||||||
|
|
||||||
expect(await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })).toBeInTheDocument()
|
|
||||||
await user.click(screen.getByRole('tab', { name: '导入' }))
|
|
||||||
|
|
||||||
expect(await screen.findByRole('button', { name: '创建导入任务' })).toBeInTheDocument()
|
|
||||||
expect((await screen.findAllByText('import-run-1')).length).toBeGreaterThan(0)
|
|
||||||
expect(memoryApi.getMemoryImportSettings).toHaveBeenCalled()
|
|
||||||
expect(memoryApi.getMemoryImportPathAliases).toHaveBeenCalled()
|
|
||||||
expect(memoryApi.getMemoryImportTasks).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('creates import tasks for all 7 modes and calls correct endpoints', async () => {
|
|
||||||
const user = userEvent.setup()
|
|
||||||
const { container } = render(<KnowledgeBasePage />)
|
|
||||||
|
|
||||||
const openImportTab = async () => {
|
|
||||||
await user.click(screen.getByRole('tab', { name: '导入' }))
|
|
||||||
await screen.findByRole('button', { name: '创建导入任务' })
|
|
||||||
}
|
|
||||||
|
|
||||||
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
|
|
||||||
await openImportTab()
|
|
||||||
|
|
||||||
const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement
|
|
||||||
const uploadFiles = [
|
|
||||||
new File(['hello'], 'demo.txt', { type: 'text/plain' }),
|
|
||||||
new File(['{"name":"mai"}'], 'demo.json', { type: 'application/json' }),
|
|
||||||
new File(['a,b\n1,2'], 'demo.csv', { type: 'text/csv' }),
|
|
||||||
new File(['# note'], 'demo.md', { type: 'text/markdown' }),
|
|
||||||
]
|
|
||||||
await user.upload(fileInput, uploadFiles)
|
|
||||||
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
|
|
||||||
await waitFor(() => expect(memoryApi.createMemoryUploadImport).toHaveBeenCalledTimes(1))
|
|
||||||
|
|
||||||
await openImportTab()
|
|
||||||
await user.click(screen.getByRole('tab', { name: '粘贴导入' }))
|
|
||||||
const editableTextarea = Array.from(container.querySelectorAll('textarea')).find((item) => !item.readOnly)
|
|
||||||
if (!editableTextarea) {
|
|
||||||
throw new Error('missing editable textarea')
|
|
||||||
}
|
|
||||||
await user.type(editableTextarea, 'paste content')
|
|
||||||
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
|
|
||||||
await waitFor(() => expect(memoryApi.createMemoryPasteImport).toHaveBeenCalledTimes(1))
|
|
||||||
|
|
||||||
await openImportTab()
|
|
||||||
await user.click(screen.getByRole('tab', { name: '本地扫描' }))
|
|
||||||
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
|
|
||||||
await waitFor(() => expect(memoryApi.createMemoryRawScanImport).toHaveBeenCalledTimes(1))
|
|
||||||
|
|
||||||
await openImportTab()
|
|
||||||
await user.click(screen.getByRole('tab', { name: 'LPMM OpenIE' }))
|
|
||||||
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
|
|
||||||
await waitFor(() => expect(memoryApi.createMemoryLpmmOpenieImport).toHaveBeenCalledTimes(1))
|
|
||||||
|
|
||||||
await openImportTab()
|
|
||||||
await user.click(screen.getByRole('tab', { name: 'LPMM 转换' }))
|
|
||||||
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
|
|
||||||
await waitFor(() => expect(memoryApi.createMemoryLpmmConvertImport).toHaveBeenCalledTimes(1))
|
|
||||||
|
|
||||||
await openImportTab()
|
|
||||||
await user.click(screen.getByRole('tab', { name: '时序回填' }))
|
|
||||||
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
|
|
||||||
await waitFor(() => expect(memoryApi.createMemoryTemporalBackfillImport).toHaveBeenCalledTimes(1))
|
|
||||||
|
|
||||||
await openImportTab()
|
|
||||||
await user.click(screen.getByRole('tab', { name: 'MaiBot 迁移' }))
|
|
||||||
await user.click(screen.getByRole('button', { name: '创建导入任务' }))
|
|
||||||
await waitFor(() => expect(memoryApi.createMemoryMaibotMigrationImport).toHaveBeenCalledTimes(1))
|
|
||||||
|
|
||||||
const [uploadedFiles, uploadPayload] = vi.mocked(memoryApi.createMemoryUploadImport).mock.calls[0]
|
|
||||||
expect(uploadedFiles).toHaveLength(4)
|
|
||||||
expect(uploadedFiles.map((file) => file.name)).toEqual(['demo.txt', 'demo.json', 'demo.csv', 'demo.md'])
|
|
||||||
expect(uploadPayload).toMatchObject({
|
|
||||||
input_mode: 'text',
|
|
||||||
llm_enabled: true,
|
|
||||||
strategy_override: 'auto',
|
|
||||||
dedupe_policy: 'content_hash',
|
|
||||||
})
|
|
||||||
}, 60_000)
|
|
||||||
|
|
||||||
it('loads task detail and supports chunk pagination', async () => {
|
|
||||||
const user = userEvent.setup()
|
|
||||||
render(<KnowledgeBasePage />)
|
|
||||||
|
|
||||||
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
|
|
||||||
await user.click(screen.getByRole('tab', { name: '导入' }))
|
|
||||||
|
|
||||||
expect(await screen.findByText('alpha.txt')).toBeInTheDocument()
|
|
||||||
expect(await screen.findByText('chunk-preview-0')).toBeInTheDocument()
|
|
||||||
|
|
||||||
const betaButton = screen.getByText('beta.txt').closest('button')
|
|
||||||
if (!betaButton) {
|
|
||||||
throw new Error('missing file beta button')
|
|
||||||
}
|
|
||||||
await user.click(betaButton)
|
|
||||||
await waitFor(() =>
|
|
||||||
expect(memoryApi.getMemoryImportTaskChunks).toHaveBeenCalledWith('import-run-1', 'file-beta', 0, 50),
|
|
||||||
)
|
|
||||||
|
|
||||||
await user.click(screen.getByRole('button', { name: '下一页分块' }))
|
|
||||||
await waitFor(() =>
|
|
||||||
expect(memoryApi.getMemoryImportTaskChunks).toHaveBeenCalledWith('import-run-1', 'file-beta', 50, 50),
|
|
||||||
)
|
|
||||||
}, 20_000)
|
|
||||||
|
|
||||||
it('shows import failures separately from successful chunks', async () => {
|
|
||||||
vi.mocked(memoryApi.getMemoryImportTask).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
task: mockImportCompletedWithErrorsDetail('import-run-1'),
|
|
||||||
})
|
|
||||||
const user = userEvent.setup()
|
|
||||||
render(<KnowledgeBasePage />)
|
|
||||||
|
|
||||||
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
|
|
||||||
await user.click(screen.getByRole('tab', { name: '导入' }))
|
|
||||||
|
|
||||||
expect((await screen.findAllByText('完成(有错误)')).length).toBeGreaterThan(0)
|
|
||||||
expect(await screen.findByText('成功 9 / 12 分块 · 失败 3')).toBeInTheDocument()
|
|
||||||
}, 20_000)
|
|
||||||
|
|
||||||
it('supports cancel and retry actions for selected task', async () => {
|
|
||||||
const user = userEvent.setup()
|
|
||||||
render(<KnowledgeBasePage />)
|
|
||||||
|
|
||||||
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
|
|
||||||
await user.click(screen.getByRole('tab', { name: '导入' }))
|
|
||||||
await screen.findByText('任务详情')
|
|
||||||
|
|
||||||
await user.click(screen.getByRole('button', { name: '取消选中导入任务' }))
|
|
||||||
await waitFor(() => expect(memoryApi.cancelMemoryImportTask).toHaveBeenCalledWith('import-run-1'))
|
|
||||||
|
|
||||||
await user.click(screen.getByRole('button', { name: '重试选中导入任务' }))
|
|
||||||
await waitFor(() => expect(memoryApi.retryMemoryImportTask).toHaveBeenCalled())
|
|
||||||
const [taskId, retryPayload] = vi.mocked(memoryApi.retryMemoryImportTask).mock.calls[0]
|
|
||||||
expect(taskId).toBe('import-run-1')
|
|
||||||
expect(retryPayload).toMatchObject({
|
|
||||||
overrides: {
|
|
||||||
llm_enabled: true,
|
|
||||||
strategy_override: 'auto',
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}, 20_000)
|
|
||||||
|
|
||||||
it('auto polling updates queue and keeps page stable when refresh fails once', async () => {
|
|
||||||
vi.mocked(memoryApi.getMemoryImportSettings).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
settings: {
|
|
||||||
max_paste_chars: 200_000,
|
|
||||||
max_file_concurrency: 8,
|
|
||||||
max_chunk_concurrency: 16,
|
|
||||||
default_file_concurrency: 2,
|
|
||||||
default_chunk_concurrency: 4,
|
|
||||||
poll_interval_ms: 200,
|
|
||||||
maibot_source_db_default: 'data/maibot.db',
|
|
||||||
},
|
|
||||||
})
|
|
||||||
const user = userEvent.setup()
|
|
||||||
render(<KnowledgeBasePage />)
|
|
||||||
|
|
||||||
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
|
|
||||||
await user.click(screen.getByRole('tab', { name: '导入' }))
|
|
||||||
await screen.findByText('导入队列')
|
|
||||||
|
|
||||||
const initialCalls = vi.mocked(memoryApi.getMemoryImportTasks).mock.calls.length
|
|
||||||
vi.mocked(memoryApi.getMemoryImportTasks).mockRejectedValueOnce(new Error('poll failure'))
|
|
||||||
await act(async () => {
|
|
||||||
await new Promise((resolve) => setTimeout(resolve, 350))
|
|
||||||
})
|
|
||||||
|
|
||||||
expect(screen.getByText('长期记忆控制台')).toBeInTheDocument()
|
|
||||||
expect(vi.mocked(memoryApi.getMemoryImportTasks).mock.calls.length).toBeGreaterThan(initialCalls)
|
|
||||||
}, 20_000)
|
|
||||||
|
|
||||||
it('creates tuning task and applies best profile (tuning module)', async () => {
|
|
||||||
const user = userEvent.setup()
|
|
||||||
render(<KnowledgeBasePage />)
|
|
||||||
|
|
||||||
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
|
|
||||||
await user.click(screen.getByRole('tab', { name: '调优' }))
|
|
||||||
await screen.findByText('调优任务')
|
|
||||||
|
|
||||||
await user.click(screen.getByRole('button', { name: '创建调优任务' }))
|
|
||||||
await waitFor(() =>
|
|
||||||
expect(memoryApi.createMemoryTuningTask).toHaveBeenCalledWith({
|
|
||||||
objective: 'precision_priority',
|
|
||||||
intensity: 'standard',
|
|
||||||
sample_size: 24,
|
|
||||||
top_k_eval: 20,
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
|
|
||||||
await user.click(screen.getByRole('button', { name: '应用最佳' }))
|
|
||||||
await waitFor(() => expect(memoryApi.applyBestMemoryTuningProfile).toHaveBeenCalledWith('tune-1'))
|
|
||||||
}, 20_000)
|
|
||||||
|
|
||||||
it('previews executes and restores source delete (delete module)', async () => {
|
|
||||||
const user = userEvent.setup()
|
|
||||||
render(<KnowledgeBasePage />)
|
|
||||||
|
|
||||||
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
|
|
||||||
await user.click(screen.getByRole('tab', { name: '删除' }))
|
|
||||||
await screen.findByText('来源批量删除')
|
|
||||||
|
|
||||||
const sourceCellCandidates = await screen.findAllByText('demo-1')
|
|
||||||
const sourceRow = sourceCellCandidates
|
|
||||||
.map((item) => item.closest('tr'))
|
|
||||||
.find((row): row is HTMLTableRowElement => Boolean(row && within(row).queryByRole('checkbox')))
|
|
||||||
if (!sourceRow) {
|
|
||||||
throw new Error('missing source row')
|
|
||||||
}
|
|
||||||
await user.click(within(sourceRow).getByRole('checkbox'))
|
|
||||||
|
|
||||||
await user.click(screen.getByRole('button', { name: '预览删除' }))
|
|
||||||
await waitFor(() =>
|
|
||||||
expect(memoryApi.previewMemoryDelete).toHaveBeenCalledWith({
|
|
||||||
mode: 'source',
|
|
||||||
selector: { sources: ['demo-1'] },
|
|
||||||
reason: 'knowledge_base_source_delete',
|
|
||||||
requested_by: 'knowledge_base',
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
|
|
||||||
const dialog = await screen.findByTestId('memory-delete-dialog')
|
|
||||||
expect(dialog).toHaveTextContent('preview:source:1')
|
|
||||||
|
|
||||||
await user.click(screen.getByRole('button', { name: '执行删除' }))
|
|
||||||
await waitFor(() =>
|
|
||||||
expect(memoryApi.executeMemoryDelete).toHaveBeenCalledWith({
|
|
||||||
mode: 'source',
|
|
||||||
selector: { sources: ['demo-1'] },
|
|
||||||
reason: 'knowledge_base_source_delete',
|
|
||||||
requested_by: 'knowledge_base',
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
|
|
||||||
await user.click(screen.getByRole('button', { name: '执行恢复' }))
|
|
||||||
await waitFor(() =>
|
|
||||||
expect(memoryApi.restoreMemoryDelete).toHaveBeenCalledWith({
|
|
||||||
operation_id: 'del-2',
|
|
||||||
requested_by: 'knowledge_base',
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
}, 20_000)
|
|
||||||
|
|
||||||
it('shows feedback correction history and supports rollback', async () => {
|
|
||||||
const user = userEvent.setup()
|
|
||||||
render(<KnowledgeBasePage />)
|
|
||||||
|
|
||||||
await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })
|
|
||||||
await user.click(screen.getByRole('tab', { name: '纠错历史' }))
|
|
||||||
await screen.findByText('反馈纠错历史')
|
|
||||||
await screen.findByText('测试用户最喜欢的颜色是什么')
|
|
||||||
await waitFor(() => expect(memoryApi.getMemoryFeedbackCorrection).toHaveBeenCalledWith(11))
|
|
||||||
|
|
||||||
await user.click(screen.getByRole('button', { name: '回退本次纠错' }))
|
|
||||||
const rollbackReason = await screen.findByLabelText('回退原因')
|
|
||||||
await user.type(rollbackReason, '人工确认回退')
|
|
||||||
await user.click(screen.getByRole('button', { name: '确认回退' }))
|
|
||||||
|
|
||||||
await waitFor(() =>
|
|
||||||
expect(memoryApi.rollbackMemoryFeedbackCorrection).toHaveBeenCalledWith(11, {
|
|
||||||
requested_by: 'knowledge_base',
|
|
||||||
reason: '人工确认回退',
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
}, 20_000)
|
|
||||||
})
|
|
||||||
@@ -1,440 +0,0 @@
|
|||||||
import { render, screen, waitFor } from '@testing-library/react'
|
|
||||||
import userEvent from '@testing-library/user-event'
|
|
||||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
|
||||||
|
|
||||||
import { KnowledgeGraphPage } from '../knowledge-graph'
|
|
||||||
import * as memoryApi from '@/lib/memory-api'
|
|
||||||
|
|
||||||
const navigateMock = vi.fn()
|
|
||||||
const toastMock = vi.fn()
|
|
||||||
|
|
||||||
vi.mock('@tanstack/react-router', () => ({
|
|
||||||
useNavigate: () => navigateMock,
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('@/hooks/use-toast', () => ({
|
|
||||||
useToast: () => ({ toast: toastMock }),
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('@/components/memory/MemoryDeleteDialog', () => ({
|
|
||||||
MemoryDeleteDialog: ({
|
|
||||||
open,
|
|
||||||
preview,
|
|
||||||
}: {
|
|
||||||
open: boolean
|
|
||||||
preview?: { mode?: string; item_count?: number } | null
|
|
||||||
}) => (
|
|
||||||
open ? <div data-testid="memory-delete-dialog">{`delete:${preview?.mode ?? 'none'}:${preview?.item_count ?? 0}`}</div> : null
|
|
||||||
),
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('../knowledge-graph/GraphVisualization', () => ({
|
|
||||||
GraphVisualization: ({
|
|
||||||
graphData,
|
|
||||||
onNodeClick,
|
|
||||||
onEdgeClick,
|
|
||||||
}: {
|
|
||||||
graphData: { nodes: Array<{ id: string }>; edges: Array<{ source: string; target: string }> }
|
|
||||||
onNodeClick: (event: React.MouseEvent, node: { id: string }) => void
|
|
||||||
onEdgeClick: (event: React.MouseEvent, edge: { source: string; target: string }) => void
|
|
||||||
}) => (
|
|
||||||
<div data-testid="graph-visualization">
|
|
||||||
<div>{`nodes:${graphData.nodes.length},edges:${graphData.edges.length}`}</div>
|
|
||||||
{graphData.nodes[0] ? (
|
|
||||||
<button type="button" onClick={(event) => onNodeClick(event as never, { id: graphData.nodes[0].id })}>
|
|
||||||
选择节点
|
|
||||||
</button>
|
|
||||||
) : null}
|
|
||||||
{graphData.edges[0] ? (
|
|
||||||
<button
|
|
||||||
type="button"
|
|
||||||
onClick={(event) =>
|
|
||||||
onEdgeClick(event as never, {
|
|
||||||
source: graphData.edges[0].source,
|
|
||||||
target: graphData.edges[0].target,
|
|
||||||
})}
|
|
||||||
>
|
|
||||||
选择边
|
|
||||||
</button>
|
|
||||||
) : null}
|
|
||||||
</div>
|
|
||||||
),
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('../knowledge-graph/GraphDialogs', () => ({
|
|
||||||
NodeDetailDialog: ({
|
|
||||||
selectedNodeData,
|
|
||||||
nodeDetail,
|
|
||||||
onOpenEvidence,
|
|
||||||
onDeleteEntity,
|
|
||||||
}: {
|
|
||||||
selectedNodeData: { id: string } | null
|
|
||||||
nodeDetail: { relations?: Array<{ predicate: string }>; paragraphs?: Array<unknown> } | null
|
|
||||||
onOpenEvidence?: () => void
|
|
||||||
onDeleteEntity?: (options: { includeParagraphs: boolean }) => void
|
|
||||||
}) => (
|
|
||||||
selectedNodeData ? (
|
|
||||||
<div data-testid="node-detail-dialog">
|
|
||||||
<div>{`node:${selectedNodeData.id}`}</div>
|
|
||||||
<div>{`relations:${nodeDetail?.relations?.[0]?.predicate ?? 'none'}`}</div>
|
|
||||||
<div>{`paragraphs:${nodeDetail?.paragraphs?.length ?? 0}`}</div>
|
|
||||||
<button type="button" onClick={onOpenEvidence}>切到证据视图</button>
|
|
||||||
<button type="button" onClick={() => onDeleteEntity?.({ includeParagraphs: true })}>删除实体</button>
|
|
||||||
</div>
|
|
||||||
) : null
|
|
||||||
),
|
|
||||||
EdgeDetailDialog: ({
|
|
||||||
selectedEdgeData,
|
|
||||||
edgeDetail,
|
|
||||||
onOpenEvidence,
|
|
||||||
}: {
|
|
||||||
selectedEdgeData: { source: { id: string }; target: { id: string } } | null
|
|
||||||
edgeDetail: { edge?: { predicates?: string[] }; paragraphs?: Array<unknown> } | null
|
|
||||||
onOpenEvidence?: () => void
|
|
||||||
}) => (
|
|
||||||
selectedEdgeData ? (
|
|
||||||
<div data-testid="edge-detail-dialog">
|
|
||||||
<div>{`edge:${selectedEdgeData.source.id}->${selectedEdgeData.target.id}`}</div>
|
|
||||||
<div>{`predicates:${edgeDetail?.edge?.predicates?.join(',') ?? 'none'}`}</div>
|
|
||||||
<div>{`paragraphs:${edgeDetail?.paragraphs?.length ?? 0}`}</div>
|
|
||||||
<button type="button" onClick={onOpenEvidence}>切到证据视图</button>
|
|
||||||
</div>
|
|
||||||
) : null
|
|
||||||
),
|
|
||||||
RelationDetailDialog: () => null,
|
|
||||||
ParagraphDetailDialog: () => null,
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('@/lib/memory-api', () => ({
|
|
||||||
getMemoryGraph: vi.fn(),
|
|
||||||
getMemoryGraphSearch: vi.fn(),
|
|
||||||
getMemoryGraphNodeDetail: vi.fn(),
|
|
||||||
getMemoryGraphEdgeDetail: vi.fn(),
|
|
||||||
previewMemoryDelete: vi.fn(),
|
|
||||||
executeMemoryDelete: vi.fn(),
|
|
||||||
restoreMemoryDelete: vi.fn(),
|
|
||||||
}))
|
|
||||||
|
|
||||||
describe('KnowledgeGraphPage', () => {
|
|
||||||
beforeEach(() => {
|
|
||||||
navigateMock.mockReset()
|
|
||||||
toastMock.mockReset()
|
|
||||||
vi.mocked(memoryApi.getMemoryGraph).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
nodes: [
|
|
||||||
{ id: 'alpha', name: 'Alpha' },
|
|
||||||
{ id: 'beta', name: 'Beta' },
|
|
||||||
],
|
|
||||||
edges: [
|
|
||||||
{
|
|
||||||
source: 'alpha',
|
|
||||||
target: 'beta',
|
|
||||||
weight: 1,
|
|
||||||
predicates: ['关联'],
|
|
||||||
relation_count: 1,
|
|
||||||
evidence_count: 2,
|
|
||||||
relation_hashes: ['rel-1'],
|
|
||||||
label: '关联',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
total_nodes: 2,
|
|
||||||
total_edges: 1,
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.getMemoryGraphSearch).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
query: 'alpha',
|
|
||||||
limit: 50,
|
|
||||||
count: 0,
|
|
||||||
items: [],
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.getMemoryGraphNodeDetail).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
node: { id: 'alpha', type: 'entity', content: 'Alpha', hash: 'entity-1', appearance_count: 3 },
|
|
||||||
relations: [
|
|
||||||
{
|
|
||||||
hash: 'rel-1',
|
|
||||||
subject: 'alpha',
|
|
||||||
predicate: '关联',
|
|
||||||
object: 'beta',
|
|
||||||
text: 'alpha 关联 beta',
|
|
||||||
confidence: 0.9,
|
|
||||||
paragraph_count: 1,
|
|
||||||
paragraph_hashes: ['p-1'],
|
|
||||||
source_paragraph: 'p-1',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
paragraphs: [
|
|
||||||
{
|
|
||||||
hash: 'p-1',
|
|
||||||
content: 'Alpha 提到了 Beta',
|
|
||||||
preview: 'Alpha 提到了 Beta',
|
|
||||||
source: 'demo',
|
|
||||||
entity_count: 2,
|
|
||||||
relation_count: 1,
|
|
||||||
entities: ['Alpha', 'Beta'],
|
|
||||||
relations: ['alpha 关联 beta'],
|
|
||||||
},
|
|
||||||
],
|
|
||||||
evidence_graph: {
|
|
||||||
nodes: [
|
|
||||||
{ id: 'entity:alpha', type: 'entity', content: 'Alpha' },
|
|
||||||
{ id: 'relation:rel-1', type: 'relation', content: 'alpha 关联 beta' },
|
|
||||||
{ id: 'paragraph:p-1', type: 'paragraph', content: 'Alpha 提到了 Beta' },
|
|
||||||
],
|
|
||||||
edges: [
|
|
||||||
{ source: 'paragraph:p-1', target: 'entity:alpha', kind: 'mentions', label: '提及', weight: 1 },
|
|
||||||
{ source: 'paragraph:p-1', target: 'relation:rel-1', kind: 'supports', label: '支撑', weight: 1 },
|
|
||||||
],
|
|
||||||
focus_entities: ['alpha'],
|
|
||||||
},
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.getMemoryGraphEdgeDetail).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
edge: {
|
|
||||||
source: 'alpha',
|
|
||||||
target: 'beta',
|
|
||||||
weight: 1,
|
|
||||||
predicates: ['关联'],
|
|
||||||
relation_count: 1,
|
|
||||||
evidence_count: 1,
|
|
||||||
relation_hashes: ['rel-1'],
|
|
||||||
label: '关联',
|
|
||||||
},
|
|
||||||
relations: [
|
|
||||||
{
|
|
||||||
hash: 'rel-1',
|
|
||||||
subject: 'alpha',
|
|
||||||
predicate: '关联',
|
|
||||||
object: 'beta',
|
|
||||||
text: 'alpha 关联 beta',
|
|
||||||
confidence: 0.9,
|
|
||||||
paragraph_count: 1,
|
|
||||||
paragraph_hashes: ['p-1'],
|
|
||||||
source_paragraph: 'p-1',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
paragraphs: [
|
|
||||||
{
|
|
||||||
hash: 'p-1',
|
|
||||||
content: 'Alpha 提到了 Beta',
|
|
||||||
preview: 'Alpha 提到了 Beta',
|
|
||||||
source: 'demo',
|
|
||||||
entity_count: 2,
|
|
||||||
relation_count: 1,
|
|
||||||
entities: ['Alpha', 'Beta'],
|
|
||||||
relations: ['alpha 关联 beta'],
|
|
||||||
},
|
|
||||||
],
|
|
||||||
evidence_graph: {
|
|
||||||
nodes: [
|
|
||||||
{ id: 'entity:alpha', type: 'entity', content: 'Alpha' },
|
|
||||||
{ id: 'entity:beta', type: 'entity', content: 'Beta' },
|
|
||||||
{ id: 'relation:rel-1', type: 'relation', content: 'alpha 关联 beta' },
|
|
||||||
],
|
|
||||||
edges: [
|
|
||||||
{ source: 'relation:rel-1', target: 'entity:alpha', kind: 'subject', label: '主语', weight: 1 },
|
|
||||||
{ source: 'relation:rel-1', target: 'entity:beta', kind: 'object', label: '宾语', weight: 1 },
|
|
||||||
],
|
|
||||||
focus_entities: ['alpha', 'beta'],
|
|
||||||
},
|
|
||||||
})
|
|
||||||
vi.mocked(memoryApi.previewMemoryDelete).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
mode: 'mixed',
|
|
||||||
selector: { entity_hashes: ['entity-1'] },
|
|
||||||
counts: { entities: 1, relations: 1, paragraphs: 1 },
|
|
||||||
sources: ['demo'],
|
|
||||||
items: [{ item_type: 'entity', item_hash: 'entity-1', label: 'Alpha' }],
|
|
||||||
item_count: 1,
|
|
||||||
dry_run: true,
|
|
||||||
} as never)
|
|
||||||
vi.mocked(memoryApi.executeMemoryDelete).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
mode: 'mixed',
|
|
||||||
operation_id: 'del-1',
|
|
||||||
counts: { entities: 1, relations: 1, paragraphs: 1 },
|
|
||||||
sources: ['demo'],
|
|
||||||
deleted_count: 3,
|
|
||||||
deleted_entity_count: 1,
|
|
||||||
deleted_relation_count: 1,
|
|
||||||
deleted_paragraph_count: 1,
|
|
||||||
deleted_source_count: 0,
|
|
||||||
} as never)
|
|
||||||
vi.mocked(memoryApi.restoreMemoryDelete).mockResolvedValue({ success: true } as never)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('calls backend graph search and renders no-hit state', async () => {
|
|
||||||
const user = userEvent.setup()
|
|
||||||
|
|
||||||
render(<KnowledgeGraphPage />)
|
|
||||||
|
|
||||||
expect(await screen.findByText('长期记忆图谱')).toBeInTheDocument()
|
|
||||||
expect(screen.getByText(/总节点 2/)).toBeInTheDocument()
|
|
||||||
expect(screen.getByTestId('graph-visualization')).toHaveTextContent('nodes:2,edges:1')
|
|
||||||
|
|
||||||
await user.type(screen.getByPlaceholderText('搜索实体、关系、hash(后端全库)'), 'missing')
|
|
||||||
expect(memoryApi.getMemoryGraph).toHaveBeenCalledTimes(1)
|
|
||||||
await user.click(screen.getByRole('button', { name: '搜索' }))
|
|
||||||
|
|
||||||
await waitFor(() => {
|
|
||||||
expect(memoryApi.getMemoryGraphSearch).toHaveBeenCalledWith('missing', 50)
|
|
||||||
})
|
|
||||||
expect(await screen.findByText('未命中实体或关系。')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('supports clicking entity search result to locate evidence', async () => {
|
|
||||||
const user = userEvent.setup()
|
|
||||||
vi.mocked(memoryApi.getMemoryGraphSearch).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
query: 'alpha',
|
|
||||||
limit: 50,
|
|
||||||
count: 1,
|
|
||||||
items: [
|
|
||||||
{
|
|
||||||
type: 'entity',
|
|
||||||
title: 'Alpha',
|
|
||||||
matched_field: 'name',
|
|
||||||
matched_value: 'Alpha',
|
|
||||||
entity_name: 'alpha',
|
|
||||||
entity_hash: 'entity-1',
|
|
||||||
appearance_count: 3,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
})
|
|
||||||
|
|
||||||
render(<KnowledgeGraphPage />)
|
|
||||||
|
|
||||||
await screen.findByTestId('graph-visualization')
|
|
||||||
await user.type(screen.getByPlaceholderText('搜索实体、关系、hash(后端全库)'), 'alpha')
|
|
||||||
await user.click(screen.getByRole('button', { name: '搜索' }))
|
|
||||||
|
|
||||||
await screen.findByText('搜索词:alpha')
|
|
||||||
await user.click(screen.getByRole('button', { name: /Alpha/ }))
|
|
||||||
|
|
||||||
await waitFor(() => {
|
|
||||||
expect(memoryApi.getMemoryGraphNodeDetail).toHaveBeenCalledWith('alpha')
|
|
||||||
})
|
|
||||||
expect(screen.getByRole('tab', { name: '证据视图' })).toHaveAttribute('data-state', 'active')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('supports clicking relation search result to locate evidence', async () => {
|
|
||||||
const user = userEvent.setup()
|
|
||||||
vi.mocked(memoryApi.getMemoryGraphSearch).mockResolvedValue({
|
|
||||||
success: true,
|
|
||||||
query: '关联',
|
|
||||||
limit: 50,
|
|
||||||
count: 1,
|
|
||||||
items: [
|
|
||||||
{
|
|
||||||
type: 'relation',
|
|
||||||
title: 'alpha 关联 beta',
|
|
||||||
matched_field: 'predicate',
|
|
||||||
matched_value: '关联',
|
|
||||||
subject: 'alpha',
|
|
||||||
predicate: '关联',
|
|
||||||
object: 'beta',
|
|
||||||
relation_hash: 'rel-1',
|
|
||||||
confidence: 0.9,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
})
|
|
||||||
|
|
||||||
render(<KnowledgeGraphPage />)
|
|
||||||
|
|
||||||
await screen.findByTestId('graph-visualization')
|
|
||||||
await user.type(screen.getByPlaceholderText('搜索实体、关系、hash(后端全库)'), '关联')
|
|
||||||
await user.click(screen.getByRole('button', { name: '搜索' }))
|
|
||||||
await user.click(screen.getByRole('button', { name: /alpha 关联 beta/ }))
|
|
||||||
|
|
||||||
await waitFor(() => {
|
|
||||||
expect(memoryApi.getMemoryGraphEdgeDetail).toHaveBeenCalledWith('alpha', 'beta')
|
|
||||||
})
|
|
||||||
expect(screen.getByRole('tab', { name: '证据视图' })).toHaveAttribute('data-state', 'active')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('falls back to local filtering when backend search fails', async () => {
|
|
||||||
const user = userEvent.setup()
|
|
||||||
vi.mocked(memoryApi.getMemoryGraphSearch).mockRejectedValue(new Error('search unavailable'))
|
|
||||||
|
|
||||||
render(<KnowledgeGraphPage />)
|
|
||||||
|
|
||||||
await screen.findByTestId('graph-visualization')
|
|
||||||
await user.type(screen.getByPlaceholderText('搜索实体、关系、hash(后端全库)'), 'missing')
|
|
||||||
await user.click(screen.getByRole('button', { name: '搜索' }))
|
|
||||||
|
|
||||||
expect(await screen.findByText('还没有可展示的长期记忆图谱')).toBeInTheDocument()
|
|
||||||
expect(toastMock).toHaveBeenCalledWith(
|
|
||||||
expect.objectContaining({
|
|
||||||
title: '后端检索失败,已回退本地筛选',
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('shows empty state when switching to evidence view without a selection', async () => {
|
|
||||||
const user = userEvent.setup()
|
|
||||||
|
|
||||||
render(<KnowledgeGraphPage />)
|
|
||||||
|
|
||||||
expect(await screen.findByTestId('graph-visualization')).toBeInTheDocument()
|
|
||||||
await user.click(screen.getByRole('tab', { name: '证据视图' }))
|
|
||||||
|
|
||||||
expect(await screen.findByText('证据视图还没有可展示的选择')).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('closes node dialog when switching to evidence view and renders evidence graph', async () => {
|
|
||||||
const user = userEvent.setup()
|
|
||||||
|
|
||||||
render(<KnowledgeGraphPage />)
|
|
||||||
|
|
||||||
await screen.findByTestId('graph-visualization')
|
|
||||||
await user.click(screen.getByRole('button', { name: '选择节点' }))
|
|
||||||
|
|
||||||
expect(await screen.findByTestId('node-detail-dialog')).toHaveTextContent('relations:关联')
|
|
||||||
expect(screen.getByTestId('node-detail-dialog')).toHaveTextContent('paragraphs:1')
|
|
||||||
|
|
||||||
await user.click(screen.getByRole('button', { name: '切到证据视图' }))
|
|
||||||
|
|
||||||
await waitFor(() => {
|
|
||||||
expect(screen.queryByTestId('node-detail-dialog')).not.toBeInTheDocument()
|
|
||||||
})
|
|
||||||
|
|
||||||
await waitFor(() => {
|
|
||||||
expect(screen.getByTestId('graph-visualization')).toHaveTextContent('nodes:3,edges:2')
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
it('loads edge detail with predicates and support paragraphs', async () => {
|
|
||||||
const user = userEvent.setup()
|
|
||||||
|
|
||||||
render(<KnowledgeGraphPage />)
|
|
||||||
|
|
||||||
await screen.findByTestId('graph-visualization')
|
|
||||||
await user.click(screen.getByRole('button', { name: '选择边' }))
|
|
||||||
|
|
||||||
expect(await screen.findByTestId('edge-detail-dialog')).toHaveTextContent('predicates:关联')
|
|
||||||
expect(screen.getByTestId('edge-detail-dialog')).toHaveTextContent('paragraphs:1')
|
|
||||||
|
|
||||||
await user.click(screen.getByRole('button', { name: '切到证据视图' }))
|
|
||||||
|
|
||||||
await waitFor(() => {
|
|
||||||
expect(screen.queryByTestId('edge-detail-dialog')).not.toBeInTheDocument()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
it('opens delete preview dialog from node detail', async () => {
|
|
||||||
const user = userEvent.setup()
|
|
||||||
|
|
||||||
render(<KnowledgeGraphPage />)
|
|
||||||
|
|
||||||
await screen.findByTestId('graph-visualization')
|
|
||||||
await user.click(screen.getByRole('button', { name: '选择节点' }))
|
|
||||||
await screen.findByTestId('node-detail-dialog')
|
|
||||||
|
|
||||||
await user.click(screen.getByRole('button', { name: '删除实体' }))
|
|
||||||
|
|
||||||
await waitFor(() => {
|
|
||||||
expect(memoryApi.previewMemoryDelete).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
expect(await screen.findByTestId('memory-delete-dialog')).toHaveTextContent('delete:mixed:1')
|
|
||||||
})
|
|
||||||
})
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
{
|
|
||||||
"extends": "./tsconfig.app.json",
|
|
||||||
"compilerOptions": {
|
|
||||||
"types": ["vite/client", "vitest/globals", "@testing-library/jest-dom"]
|
|
||||||
},
|
|
||||||
"include": ["src"]
|
|
||||||
}
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
/// <reference types="vitest" />
|
|
||||||
import { defineConfig } from 'vite'
|
|
||||||
import react from '@vitejs/plugin-react'
|
|
||||||
import path from 'path'
|
|
||||||
|
|
||||||
export default defineConfig({
|
|
||||||
plugins: [react()],
|
|
||||||
test: {
|
|
||||||
globals: true,
|
|
||||||
environment: 'jsdom',
|
|
||||||
setupFiles: './src/test/setup.ts',
|
|
||||||
},
|
|
||||||
resolve: {
|
|
||||||
alias: {
|
|
||||||
'@': path.resolve(__dirname, './src'),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
@@ -1,398 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from pathlib import Path
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import Any, Callable, Dict, List
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import inspect
|
|
||||||
import json
|
|
||||||
import pickle
|
|
||||||
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
from sqlmodel import Session, create_engine
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
IMPORT_ERROR: str | None = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
|
||||||
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
|
|
||||||
from src.A_memorix.core.utils import summary_importer as summary_importer_module
|
|
||||||
from src.chat.message_receive.chat_manager import BotChatSession
|
|
||||||
from src.chat.message_receive.message import SessionMessage
|
|
||||||
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
|
|
||||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
|
||||||
from src.common.database import database as database_module
|
|
||||||
from src.common.database.migrations import create_database_migration_bootstrapper
|
|
||||||
from src.common.message_repository import count_messages
|
|
||||||
from src.config.model_configs import TaskConfig
|
|
||||||
from src.services import memory_flow_service as memory_flow_service_module
|
|
||||||
from src.services import memory_service as memory_service_module
|
|
||||||
from src.services import send_service
|
|
||||||
except SystemExit as exc:
|
|
||||||
IMPORT_ERROR = f"config initialization exited during import: {exc}"
|
|
||||||
kernel_module = None # type: ignore[assignment]
|
|
||||||
SDKMemoryKernel = None # type: ignore[assignment]
|
|
||||||
summary_importer_module = None # type: ignore[assignment]
|
|
||||||
BotChatSession = None # type: ignore[assignment]
|
|
||||||
SessionMessage = None # type: ignore[assignment]
|
|
||||||
MessageInfo = None # type: ignore[assignment]
|
|
||||||
UserInfo = None # type: ignore[assignment]
|
|
||||||
MessageSequence = None # type: ignore[assignment]
|
|
||||||
TextComponent = None # type: ignore[assignment]
|
|
||||||
database_module = None # type: ignore[assignment]
|
|
||||||
create_database_migration_bootstrapper = None # type: ignore[assignment]
|
|
||||||
count_messages = None # type: ignore[assignment]
|
|
||||||
TaskConfig = None # type: ignore[assignment]
|
|
||||||
memory_flow_service_module = None # type: ignore[assignment]
|
|
||||||
memory_service_module = None # type: ignore[assignment]
|
|
||||||
send_service = None # type: ignore[assignment]
|
|
||||||
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeEmbeddingManager:
|
|
||||||
def __init__(self, dimension: int = 8) -> None:
|
|
||||||
self.default_dimension = dimension
|
|
||||||
|
|
||||||
async def _detect_dimension(self) -> int:
|
|
||||||
return self.default_dimension
|
|
||||||
|
|
||||||
async def encode(self, text: Any) -> np.ndarray:
|
|
||||||
def _encode_one(raw: Any) -> np.ndarray:
|
|
||||||
content = str(raw or "")
|
|
||||||
vector = np.zeros(self.default_dimension, dtype=np.float32)
|
|
||||||
for index, byte in enumerate(content.encode("utf-8")):
|
|
||||||
vector[index % self.default_dimension] += float((byte % 17) + 1)
|
|
||||||
norm = float(np.linalg.norm(vector))
|
|
||||||
if norm > 0:
|
|
||||||
vector /= norm
|
|
||||||
return vector
|
|
||||||
|
|
||||||
if isinstance(text, (list, tuple)):
|
|
||||||
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
|
|
||||||
return _encode_one(text).astype(np.float32)
|
|
||||||
|
|
||||||
|
|
||||||
class _KernelBackedRuntimeManager:
|
|
||||||
def __init__(self, kernel: SDKMemoryKernel) -> None:
|
|
||||||
self.kernel = kernel
|
|
||||||
|
|
||||||
async def invoke(
|
|
||||||
self,
|
|
||||||
component_name: str,
|
|
||||||
args: Dict[str, Any] | None,
|
|
||||||
*,
|
|
||||||
timeout_ms: int = 30000,
|
|
||||||
) -> Any:
|
|
||||||
del timeout_ms
|
|
||||||
payload = args or {}
|
|
||||||
handler = getattr(self.kernel, component_name)
|
|
||||||
result = handler(**payload)
|
|
||||||
return await result if inspect.isawaitable(result) else result
|
|
||||||
|
|
||||||
|
|
||||||
class _NoopRuntimeManager:
|
|
||||||
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
|
|
||||||
del hook_name
|
|
||||||
return SimpleNamespace(aborted=False, kwargs=kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class _FakePlatformIOManager:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.ensure_calls = 0
|
|
||||||
|
|
||||||
async def ensure_send_pipeline_ready(self) -> None:
|
|
||||||
self.ensure_calls += 1
|
|
||||||
|
|
||||||
def build_route_key_from_message(self, message: Any) -> Any:
|
|
||||||
del message
|
|
||||||
return SimpleNamespace(platform="qq")
|
|
||||||
|
|
||||||
async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
|
|
||||||
del message, metadata
|
|
||||||
return SimpleNamespace(
|
|
||||||
has_success=True,
|
|
||||||
sent_receipts=[
|
|
||||||
SimpleNamespace(
|
|
||||||
driver_id="plugin.qq.sender",
|
|
||||||
external_message_id="real-message-id",
|
|
||||||
metadata={},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
failed_receipts=[],
|
|
||||||
route_key=route_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
|
||||||
db_dir = (tmp_path / "main_db").resolve()
|
|
||||||
db_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
db_file = db_dir / "MaiBot.db"
|
|
||||||
database_url = f"sqlite:///{db_file}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
database_module.engine.dispose()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
engine = create_engine(
|
|
||||||
database_url,
|
|
||||||
echo=False,
|
|
||||||
connect_args={"check_same_thread": False},
|
|
||||||
pool_pre_ping=True,
|
|
||||||
)
|
|
||||||
session_local = sessionmaker(
|
|
||||||
autocommit=False,
|
|
||||||
autoflush=False,
|
|
||||||
bind=engine,
|
|
||||||
class_=Session,
|
|
||||||
)
|
|
||||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
|
||||||
|
|
||||||
monkeypatch.setattr(database_module, "_DB_DIR", db_dir, raising=False)
|
|
||||||
monkeypatch.setattr(database_module, "_DB_FILE", db_file, raising=False)
|
|
||||||
monkeypatch.setattr(database_module, "DATABASE_URL", database_url, raising=False)
|
|
||||||
monkeypatch.setattr(database_module, "engine", engine, raising=False)
|
|
||||||
monkeypatch.setattr(database_module, "SessionLocal", session_local, raising=False)
|
|
||||||
monkeypatch.setattr(database_module, "_migration_bootstrapper", bootstrapper, raising=False)
|
|
||||||
monkeypatch.setattr(database_module, "_db_initialized", False, raising=False)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_incoming_message(
|
|
||||||
*,
|
|
||||||
session_id: str,
|
|
||||||
user_id: str,
|
|
||||||
text: str,
|
|
||||||
timestamp: datetime | None = None,
|
|
||||||
) -> SessionMessage:
|
|
||||||
message = SessionMessage(
|
|
||||||
message_id="incoming-message-id",
|
|
||||||
timestamp=timestamp or datetime.now(),
|
|
||||||
platform="qq",
|
|
||||||
)
|
|
||||||
message.message_info = MessageInfo(
|
|
||||||
user_info=UserInfo(
|
|
||||||
user_id=user_id,
|
|
||||||
user_nickname="测试用户",
|
|
||||||
user_cardname="测试用户",
|
|
||||||
),
|
|
||||||
additional_config={},
|
|
||||||
)
|
|
||||||
message.raw_message = MessageSequence(components=[TextComponent(text=text)])
|
|
||||||
message.session_id = session_id
|
|
||||||
message.reply_to = None
|
|
||||||
message.is_mentioned = False
|
|
||||||
message.is_at = False
|
|
||||||
message.is_emoji = False
|
|
||||||
message.is_picture = False
|
|
||||||
message.is_command = False
|
|
||||||
message.is_notify = False
|
|
||||||
message.processed_plain_text = text
|
|
||||||
message.initialized = True
|
|
||||||
return message
|
|
||||||
|
|
||||||
|
|
||||||
async def _wait_until(
|
|
||||||
predicate: Callable[[], Any],
|
|
||||||
*,
|
|
||||||
timeout_seconds: float = 10.0,
|
|
||||||
interval_seconds: float = 0.05,
|
|
||||||
description: str,
|
|
||||||
) -> Any:
|
|
||||||
deadline = asyncio.get_running_loop().time() + max(0.5, float(timeout_seconds))
|
|
||||||
while asyncio.get_running_loop().time() < deadline:
|
|
||||||
value = predicate()
|
|
||||||
if inspect.isawaitable(value):
|
|
||||||
value = await value
|
|
||||||
if value:
|
|
||||||
return value
|
|
||||||
await asyncio.sleep(interval_seconds)
|
|
||||||
raise AssertionError(f"等待超时: {description}")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_text_to_stream_triggers_real_chat_summary_writeback(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
tmp_path: Path,
|
|
||||||
) -> None:
|
|
||||||
_install_temp_main_database(monkeypatch, tmp_path)
|
|
||||||
|
|
||||||
fake_embedding_manager = _FakeEmbeddingManager()
|
|
||||||
captured_prompts: List[str] = []
|
|
||||||
fixed_send_timestamp = 1_777_000_000.0
|
|
||||||
|
|
||||||
async def _fake_runtime_self_check(**kwargs: Any) -> Dict[str, Any]:
|
|
||||||
del kwargs
|
|
||||||
return {
|
|
||||||
"ok": True,
|
|
||||||
"message": "ok",
|
|
||||||
"configured_dimension": fake_embedding_manager.default_dimension,
|
|
||||||
"requested_dimension": fake_embedding_manager.default_dimension,
|
|
||||||
"vector_store_dimension": fake_embedding_manager.default_dimension,
|
|
||||||
"detected_dimension": fake_embedding_manager.default_dimension,
|
|
||||||
"encoded_dimension": fake_embedding_manager.default_dimension,
|
|
||||||
"elapsed_ms": 0.0,
|
|
||||||
"sample_text": "test",
|
|
||||||
"checked_at": datetime.now().timestamp(),
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _fake_generate(request: Any) -> Any:
|
|
||||||
captured_prompts.append(str(getattr(request, "prompt", "") or ""))
|
|
||||||
return SimpleNamespace(
|
|
||||||
success=True,
|
|
||||||
completion=SimpleNamespace(
|
|
||||||
response=json.dumps(
|
|
||||||
{
|
|
||||||
"summary": "这段对话记录了用户提到自己买了绿色围巾,机器人表示会记住这件事。",
|
|
||||||
"entities": ["绿色围巾"],
|
|
||||||
"relations": [],
|
|
||||||
},
|
|
||||||
ensure_ascii=False,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
kernel_module,
|
|
||||||
"create_embedding_api_adapter",
|
|
||||||
lambda **kwargs: fake_embedding_manager,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
kernel_module,
|
|
||||||
"run_embedding_runtime_self_check",
|
|
||||||
_fake_runtime_self_check,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
summary_importer_module,
|
|
||||||
"run_embedding_runtime_self_check",
|
|
||||||
_fake_runtime_self_check,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
summary_importer_module.llm_api,
|
|
||||||
"get_available_models",
|
|
||||||
lambda: {"utils": TaskConfig(model_list=["fake-summary-model"])},
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
summary_importer_module.llm_api,
|
|
||||||
"resolve_task_name_from_model_config",
|
|
||||||
lambda model_config: "utils",
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
summary_importer_module.llm_api,
|
|
||||||
"generate",
|
|
||||||
_fake_generate,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(send_service.time, "time", lambda: fixed_send_timestamp)
|
|
||||||
monkeypatch.setattr(summary_importer_module.time, "time", lambda: fixed_send_timestamp)
|
|
||||||
|
|
||||||
kernel = SDKMemoryKernel(
|
|
||||||
plugin_root=tmp_path / "plugin_root",
|
|
||||||
config={
|
|
||||||
"storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
|
|
||||||
"advanced": {"enable_auto_save": False},
|
|
||||||
"embedding": {"dimension": fake_embedding_manager.default_dimension},
|
|
||||||
"memory": {"base_decay_interval_hours": 24},
|
|
||||||
"person_profile": {"refresh_interval_minutes": 5},
|
|
||||||
"summarization": {"model_name": ["utils"]},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
service = memory_flow_service_module.MemoryAutomationService()
|
|
||||||
fake_platform_io_manager = _FakePlatformIOManager()
|
|
||||||
|
|
||||||
async def _fake_rebuild_episodes_for_sources(sources: List[str]) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"rebuilt": 0,
|
|
||||||
"items": [],
|
|
||||||
"failures": [],
|
|
||||||
"sources": list(sources),
|
|
||||||
}
|
|
||||||
|
|
||||||
monkeypatch.setattr(kernel, "rebuild_episodes_for_sources", _fake_rebuild_episodes_for_sources)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_service_module,
|
|
||||||
"a_memorix_host_service",
|
|
||||||
_KernelBackedRuntimeManager(kernel),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(memory_flow_service_module, "memory_automation_service", service)
|
|
||||||
monkeypatch.setattr(send_service, "_get_runtime_manager", lambda: _NoopRuntimeManager())
|
|
||||||
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_platform_io_manager)
|
|
||||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
|
||||||
monkeypatch.setattr(
|
|
||||||
send_service._chat_manager,
|
|
||||||
"get_session_by_session_id",
|
|
||||||
lambda stream_id: (
|
|
||||||
BotChatSession(
|
|
||||||
session_id="test-session",
|
|
||||||
platform="qq",
|
|
||||||
user_id="target-user",
|
|
||||||
group_id=None,
|
|
||||||
)
|
|
||||||
if stream_id == "test-session"
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
integration_config = memory_flow_service_module.global_config.a_memorix.integration
|
|
||||||
monkeypatch.setattr(integration_config, "chat_summary_writeback_enabled", True, raising=False)
|
|
||||||
monkeypatch.setattr(integration_config, "chat_summary_writeback_message_threshold", 2, raising=False)
|
|
||||||
monkeypatch.setattr(integration_config, "chat_summary_writeback_context_length", 10, raising=False)
|
|
||||||
monkeypatch.setattr(integration_config, "person_fact_writeback_enabled", False, raising=False)
|
|
||||||
|
|
||||||
await kernel.initialize()
|
|
||||||
|
|
||||||
try:
|
|
||||||
incoming_message = _build_incoming_message(
|
|
||||||
session_id="test-session",
|
|
||||||
user_id="target-user",
|
|
||||||
text="我最近买了一条绿色围巾。",
|
|
||||||
timestamp=datetime.fromtimestamp(fixed_send_timestamp) - timedelta(seconds=1),
|
|
||||||
)
|
|
||||||
with database_module.get_db_session() as session:
|
|
||||||
session.add(incoming_message.to_db_instance())
|
|
||||||
|
|
||||||
sent_message = await send_service.text_to_stream_with_message(
|
|
||||||
text="好的,我会记住你最近买了绿色围巾。",
|
|
||||||
stream_id="test-session",
|
|
||||||
storage_message=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert sent_message is not None
|
|
||||||
assert sent_message.message_id == "real-message-id"
|
|
||||||
assert fake_platform_io_manager.ensure_calls == 1
|
|
||||||
assert count_messages(session_id="test-session") == 2
|
|
||||||
|
|
||||||
paragraphs = await _wait_until(
|
|
||||||
lambda: kernel.metadata_store.get_paragraphs_by_source("chat_summary:test-session"),
|
|
||||||
description="等待聊天摘要写回到 A_memorix",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert captured_prompts
|
|
||||||
assert "我最近买了一条绿色围巾。" in captured_prompts[-1]
|
|
||||||
assert "好的,我会记住你最近买了绿色围巾。" in captured_prompts[-1]
|
|
||||||
assert any("绿色围巾" in str(item.get("content", "") or "") for item in paragraphs)
|
|
||||||
assert any(
|
|
||||||
int(
|
|
||||||
(
|
|
||||||
pickle.loads(item.get("metadata"))
|
|
||||||
if isinstance(item.get("metadata"), (bytes, bytearray))
|
|
||||||
else item.get("metadata")
|
|
||||||
or {}
|
|
||||||
).get("trigger_message_count", 0)
|
|
||||||
or 0
|
|
||||||
)
|
|
||||||
== 2
|
|
||||||
for item in paragraphs
|
|
||||||
)
|
|
||||||
assert service.chat_summary_writeback._states["test-session"].last_trigger_message_count == 2
|
|
||||||
finally:
|
|
||||||
await service.shutdown()
|
|
||||||
await kernel.shutdown()
|
|
||||||
try:
|
|
||||||
database_module.engine.dispose()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
@@ -1,191 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from A_memorix.core.embedding import api_adapter as api_adapter_module
|
|
||||||
from A_memorix.core.embedding.api_adapter import EmbeddingAPIAdapter
|
|
||||||
from A_memorix.core.utils.runtime_self_check import run_embedding_runtime_self_check
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeEmbeddingClient:
|
|
||||||
def __init__(self, *, natural_dimension: int = 12) -> None:
|
|
||||||
self.natural_dimension = int(natural_dimension)
|
|
||||||
self.requests = []
|
|
||||||
|
|
||||||
async def get_embedding(self, request):
|
|
||||||
self.requests.append(request)
|
|
||||||
requested_dimension = request.extra_params.get("dimensions")
|
|
||||||
if requested_dimension is None:
|
|
||||||
requested_dimension = request.extra_params.get("output_dimensionality")
|
|
||||||
dimension = int(requested_dimension or self.natural_dimension)
|
|
||||||
return SimpleNamespace(embedding=[1.0] * dimension)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_adapter(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
*,
|
|
||||||
client_type: str,
|
|
||||||
configured_dimension: int = 1024,
|
|
||||||
effective_dimension: int | None = None,
|
|
||||||
model_extra_params: dict | None = None,
|
|
||||||
):
|
|
||||||
adapter = EmbeddingAPIAdapter(default_dimension=configured_dimension)
|
|
||||||
if effective_dimension is not None:
|
|
||||||
adapter._dimension = int(effective_dimension)
|
|
||||||
adapter._dimension_detected = True
|
|
||||||
|
|
||||||
fake_client = _FakeEmbeddingClient()
|
|
||||||
model_info = SimpleNamespace(
|
|
||||||
name="embedding-model",
|
|
||||||
api_provider="provider-1",
|
|
||||||
model_identifier="embedding-model-id",
|
|
||||||
extra_params=dict(model_extra_params or {}),
|
|
||||||
)
|
|
||||||
provider = SimpleNamespace(name="provider-1", client_type=client_type)
|
|
||||||
|
|
||||||
monkeypatch.setattr(adapter, "_resolve_candidate_model_names", lambda: ["embedding-model"])
|
|
||||||
monkeypatch.setattr(adapter, "_find_model_info", lambda model_name: model_info)
|
|
||||||
monkeypatch.setattr(adapter, "_find_provider", lambda provider_name: provider)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
api_adapter_module.client_registry,
|
|
||||||
"get_client_class_instance",
|
|
||||||
lambda api_provider, force_new=True: fake_client,
|
|
||||||
)
|
|
||||||
return adapter, fake_client
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_encode_uses_canonical_dimension_for_openai_provider(monkeypatch):
|
|
||||||
adapter, fake_client = _build_adapter(
|
|
||||||
monkeypatch,
|
|
||||||
client_type="openai",
|
|
||||||
configured_dimension=1024,
|
|
||||||
effective_dimension=1024,
|
|
||||||
model_extra_params={"task_type": "SEMANTIC_SIMILARITY"},
|
|
||||||
)
|
|
||||||
|
|
||||||
embedding = await adapter.encode("北塔木梯")
|
|
||||||
|
|
||||||
request = fake_client.requests[-1]
|
|
||||||
assert request.extra_params["dimensions"] == 1024
|
|
||||||
assert "output_dimensionality" not in request.extra_params
|
|
||||||
assert request.extra_params["task_type"] == "SEMANTIC_SIMILARITY"
|
|
||||||
assert embedding.shape == (1024,)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_encode_explicit_dimension_override_wins(monkeypatch):
|
|
||||||
adapter, fake_client = _build_adapter(
|
|
||||||
monkeypatch,
|
|
||||||
client_type="openai",
|
|
||||||
configured_dimension=1024,
|
|
||||||
effective_dimension=1024,
|
|
||||||
)
|
|
||||||
|
|
||||||
embedding = await adapter.encode("海潮图", dimensions=256)
|
|
||||||
|
|
||||||
request = fake_client.requests[-1]
|
|
||||||
assert request.extra_params["dimensions"] == 256
|
|
||||||
assert "output_dimensionality" not in request.extra_params
|
|
||||||
assert embedding.shape == (256,)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_encode_maps_dimension_to_gemini_output_dimensionality(monkeypatch):
|
|
||||||
adapter, fake_client = _build_adapter(
|
|
||||||
monkeypatch,
|
|
||||||
client_type="gemini",
|
|
||||||
configured_dimension=1024,
|
|
||||||
effective_dimension=768,
|
|
||||||
)
|
|
||||||
|
|
||||||
embedding = await adapter.encode("广播站")
|
|
||||||
|
|
||||||
request = fake_client.requests[-1]
|
|
||||||
assert request.extra_params["output_dimensionality"] == 768
|
|
||||||
assert "dimensions" not in request.extra_params
|
|
||||||
assert embedding.shape == (768,)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_encode_does_not_force_dimension_for_unsupported_provider(monkeypatch):
|
|
||||||
adapter, fake_client = _build_adapter(
|
|
||||||
monkeypatch,
|
|
||||||
client_type="custom",
|
|
||||||
configured_dimension=1024,
|
|
||||||
effective_dimension=640,
|
|
||||||
model_extra_params={
|
|
||||||
"dimensions": 999,
|
|
||||||
"output_dimensionality": 888,
|
|
||||||
"custom_flag": "keep-me",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
embedding = await adapter.encode("蓝漆铁盒")
|
|
||||||
|
|
||||||
request = fake_client.requests[-1]
|
|
||||||
assert "dimensions" not in request.extra_params
|
|
||||||
assert "output_dimensionality" not in request.extra_params
|
|
||||||
assert request.extra_params["custom_flag"] == "keep-me"
|
|
||||||
assert embedding.shape == (fake_client.natural_dimension,)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_runtime_self_check_reports_requested_dimension_without_explicit_override():
|
|
||||||
class _FakeEmbeddingManager:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.detected_dimension = 384
|
|
||||||
self.encode_calls = []
|
|
||||||
|
|
||||||
async def _detect_dimension(self) -> int:
|
|
||||||
return self.detected_dimension
|
|
||||||
|
|
||||||
def get_requested_dimension(self) -> int:
|
|
||||||
return self.detected_dimension
|
|
||||||
|
|
||||||
async def encode(self, text):
|
|
||||||
self.encode_calls.append(text)
|
|
||||||
return np.ones(self.detected_dimension, dtype=np.float32)
|
|
||||||
|
|
||||||
manager = _FakeEmbeddingManager()
|
|
||||||
|
|
||||||
report = await run_embedding_runtime_self_check(
|
|
||||||
config={"embedding": {"dimension": 1024}},
|
|
||||||
vector_store=SimpleNamespace(dimension=384),
|
|
||||||
embedding_manager=manager,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert report["ok"] is True
|
|
||||||
assert report["configured_dimension"] == 1024
|
|
||||||
assert report["requested_dimension"] == 384
|
|
||||||
assert report["detected_dimension"] == 384
|
|
||||||
assert report["encoded_dimension"] == 384
|
|
||||||
assert manager.encode_calls == ["A_Memorix runtime self check"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_encode_batch_keeps_batch_local_indexes_when_cache_hits_previous_batch(monkeypatch):
|
|
||||||
adapter = EmbeddingAPIAdapter(default_dimension=4, enable_cache=True)
|
|
||||||
adapter._dimension = 4
|
|
||||||
adapter._dimension_detected = True
|
|
||||||
|
|
||||||
async def fake_detect_dimension() -> int:
|
|
||||||
return 4
|
|
||||||
|
|
||||||
async def fake_get_embedding_direct(text: str, dimensions: int | None = None):
|
|
||||||
del dimensions
|
|
||||||
base = float(ord(str(text)[0]))
|
|
||||||
return [base, base + 1.0, base + 2.0, base + 3.0]
|
|
||||||
|
|
||||||
monkeypatch.setattr(adapter, "_detect_dimension", fake_detect_dimension)
|
|
||||||
monkeypatch.setattr(adapter, "_get_embedding_direct", fake_get_embedding_direct)
|
|
||||||
|
|
||||||
embeddings = await adapter.encode(["A", "B", "A", "C"], batch_size=2)
|
|
||||||
|
|
||||||
assert embeddings.shape == (4, 4)
|
|
||||||
assert np.array_equal(embeddings[0], embeddings[2])
|
|
||||||
assert embeddings[1][0] == float(ord("B"))
|
|
||||||
assert embeddings[3][0] == float(ord("C"))
|
|
||||||
@@ -1,780 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import inspect
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import Any, Callable, Dict
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
import pytest_asyncio
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
from sqlmodel import Session, create_engine, select
|
|
||||||
|
|
||||||
IMPORT_ERROR: str | None = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
|
||||||
from src.A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
|
||||||
from src.chat.heart_flow.heartflow_manager import heartflow_manager
|
|
||||||
from src.chat.message_receive import bot as bot_module
|
|
||||||
from src.chat.message_receive.chat_manager import chat_manager
|
|
||||||
from src.chat.message_receive.bot import chat_bot
|
|
||||||
from src.common.database import database as database_module
|
|
||||||
from src.common.database.database_model import PersonInfo, ToolRecord
|
|
||||||
from src.common.database.migrations import create_database_migration_bootstrapper
|
|
||||||
from src.common.utils.utils_session import SessionUtils
|
|
||||||
from src.llm_models.payload_content.tool_option import ToolCall
|
|
||||||
from src.maisaka import reasoning_engine as reasoning_engine_module
|
|
||||||
from src.maisaka import runtime as runtime_module
|
|
||||||
from src.maisaka import chat_loop_service as chat_loop_service_module
|
|
||||||
from src.maisaka.chat_loop_service import ChatResponse
|
|
||||||
from src.maisaka.context_messages import AssistantMessage
|
|
||||||
from src.plugin_runtime import component_query as component_query_module
|
|
||||||
from src.services import memory_flow_service as memory_flow_service_module
|
|
||||||
from src.services import memory_service as memory_service_module
|
|
||||||
from src.services.memory_service import memory_service
|
|
||||||
except SystemExit as exc:
|
|
||||||
IMPORT_ERROR = f"config initialization exited during import: {exc}"
|
|
||||||
kernel_module = None # type: ignore[assignment]
|
|
||||||
KernelSearchRequest = None # type: ignore[assignment]
|
|
||||||
SDKMemoryKernel = None # type: ignore[assignment]
|
|
||||||
heartflow_manager = None # type: ignore[assignment]
|
|
||||||
bot_module = None # type: ignore[assignment]
|
|
||||||
chat_manager = None # type: ignore[assignment]
|
|
||||||
chat_bot = None # type: ignore[assignment]
|
|
||||||
database_module = None # type: ignore[assignment]
|
|
||||||
ToolRecord = None # type: ignore[assignment]
|
|
||||||
PersonInfo = None # type: ignore[assignment]
|
|
||||||
create_database_migration_bootstrapper = None # type: ignore[assignment]
|
|
||||||
SessionUtils = None # type: ignore[assignment]
|
|
||||||
ToolCall = None # type: ignore[assignment]
|
|
||||||
reasoning_engine_module = None # type: ignore[assignment]
|
|
||||||
runtime_module = None # type: ignore[assignment]
|
|
||||||
chat_loop_service_module = None # type: ignore[assignment]
|
|
||||||
ChatResponse = None # type: ignore[assignment]
|
|
||||||
AssistantMessage = None # type: ignore[assignment]
|
|
||||||
component_query_module = None # type: ignore[assignment]
|
|
||||||
memory_flow_service_module = None # type: ignore[assignment]
|
|
||||||
memory_service_module = None # type: ignore[assignment]
|
|
||||||
memory_service = None # type: ignore[assignment]
|
|
||||||
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
|
|
||||||
|
|
||||||
RELATION_QUERY = "测试用户 和 最喜欢的颜色 有什么关系"
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeEmbeddingManager:
|
|
||||||
def __init__(self, dimension: int = 8) -> None:
|
|
||||||
self.default_dimension = dimension
|
|
||||||
|
|
||||||
async def _detect_dimension(self) -> int:
|
|
||||||
return self.default_dimension
|
|
||||||
|
|
||||||
async def encode(self, text: Any) -> np.ndarray:
|
|
||||||
def _encode_one(raw: Any) -> np.ndarray:
|
|
||||||
content = str(raw or "")
|
|
||||||
vector = np.zeros(self.default_dimension, dtype=np.float32)
|
|
||||||
for index, byte in enumerate(content.encode("utf-8")):
|
|
||||||
vector[index % self.default_dimension] += float((byte % 17) + 1)
|
|
||||||
norm = float(np.linalg.norm(vector))
|
|
||||||
if norm > 0:
|
|
||||||
vector /= norm
|
|
||||||
return vector
|
|
||||||
|
|
||||||
if isinstance(text, (list, tuple)):
|
|
||||||
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
|
|
||||||
return _encode_one(text).astype(np.float32)
|
|
||||||
|
|
||||||
|
|
||||||
class _KernelBackedRuntimeManager:
|
|
||||||
def __init__(self, kernel: SDKMemoryKernel) -> None:
|
|
||||||
self.kernel = kernel
|
|
||||||
|
|
||||||
async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
|
|
||||||
del timeout_ms
|
|
||||||
payload = args or {}
|
|
||||||
if component_name == "search_memory":
|
|
||||||
return await self.kernel.search_memory(
|
|
||||||
KernelSearchRequest(
|
|
||||||
query=str(payload.get("query", "") or ""),
|
|
||||||
limit=int(payload.get("limit", 5) or 5),
|
|
||||||
mode=str(payload.get("mode", "hybrid") or "hybrid"),
|
|
||||||
chat_id=str(payload.get("chat_id", "") or ""),
|
|
||||||
person_id=str(payload.get("person_id", "") or ""),
|
|
||||||
time_start=payload.get("time_start"),
|
|
||||||
time_end=payload.get("time_end"),
|
|
||||||
respect_filter=bool(payload.get("respect_filter", True)),
|
|
||||||
user_id=str(payload.get("user_id", "") or ""),
|
|
||||||
group_id=str(payload.get("group_id", "") or ""),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
handler = getattr(self.kernel, component_name)
|
|
||||||
result = handler(**payload)
|
|
||||||
return await result if inspect.isawaitable(result) else result
|
|
||||||
|
|
||||||
|
|
||||||
class _NoopRuntimeManager:
|
|
||||||
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
|
|
||||||
del hook_name
|
|
||||||
return SimpleNamespace(aborted=False, kwargs=kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
|
||||||
db_dir = (tmp_path / "main_db").resolve()
|
|
||||||
db_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
db_file = db_dir / "MaiBot.db"
|
|
||||||
database_url = f"sqlite:///{db_file}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
database_module.engine.dispose()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
engine = create_engine(
|
|
||||||
database_url,
|
|
||||||
echo=False,
|
|
||||||
connect_args={"check_same_thread": False},
|
|
||||||
pool_pre_ping=True,
|
|
||||||
)
|
|
||||||
session_local = sessionmaker(
|
|
||||||
autocommit=False,
|
|
||||||
autoflush=False,
|
|
||||||
bind=engine,
|
|
||||||
class_=Session,
|
|
||||||
)
|
|
||||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
|
||||||
|
|
||||||
monkeypatch.setattr(database_module, "_DB_DIR", db_dir, raising=False)
|
|
||||||
monkeypatch.setattr(database_module, "_DB_FILE", db_file, raising=False)
|
|
||||||
monkeypatch.setattr(database_module, "DATABASE_URL", database_url, raising=False)
|
|
||||||
monkeypatch.setattr(database_module, "engine", engine, raising=False)
|
|
||||||
monkeypatch.setattr(database_module, "SessionLocal", session_local, raising=False)
|
|
||||||
monkeypatch.setattr(database_module, "_migration_bootstrapper", bootstrapper, raising=False)
|
|
||||||
monkeypatch.setattr(database_module, "_db_initialized", False, raising=False)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_chat_response(content: str, tool_calls: list[ToolCall]) -> ChatResponse:
|
|
||||||
return ChatResponse(
|
|
||||||
content=content,
|
|
||||||
tool_calls=tool_calls,
|
|
||||||
request_messages=[],
|
|
||||||
raw_message=AssistantMessage(
|
|
||||||
content=content,
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
tool_calls=tool_calls,
|
|
||||||
),
|
|
||||||
selected_history_count=0,
|
|
||||||
tool_count=len(tool_calls),
|
|
||||||
prompt_tokens=0,
|
|
||||||
built_message_count=0,
|
|
||||||
completion_tokens=0,
|
|
||||||
total_tokens=0,
|
|
||||||
prompt_section=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_message_data(
|
|
||||||
*,
|
|
||||||
content: str,
|
|
||||||
platform: str,
|
|
||||||
user_id: str,
|
|
||||||
user_name: str,
|
|
||||||
group_id: str,
|
|
||||||
group_name: str,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
message_id = str(uuid.uuid4())
|
|
||||||
return {
|
|
||||||
"message_info": {
|
|
||||||
"platform": platform,
|
|
||||||
"message_id": message_id,
|
|
||||||
"time": time.time(),
|
|
||||||
"group_info": {
|
|
||||||
"group_id": group_id,
|
|
||||||
"group_name": group_name,
|
|
||||||
"platform": platform,
|
|
||||||
},
|
|
||||||
"user_info": {
|
|
||||||
"user_id": user_id,
|
|
||||||
"user_nickname": user_name,
|
|
||||||
"user_cardname": user_name,
|
|
||||||
"platform": platform,
|
|
||||||
},
|
|
||||||
"additional_config": {
|
|
||||||
"at_bot": True,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"message_segment": {
|
|
||||||
"type": "seglist",
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"data": content,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
"raw_message": content,
|
|
||||||
"processed_plain_text": content,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def _wait_until(
|
|
||||||
predicate: Callable[[], Any],
|
|
||||||
*,
|
|
||||||
timeout_seconds: float = 10.0,
|
|
||||||
interval_seconds: float = 0.05,
|
|
||||||
description: str,
|
|
||||||
) -> Any:
|
|
||||||
deadline = asyncio.get_running_loop().time() + max(0.5, float(timeout_seconds))
|
|
||||||
while asyncio.get_running_loop().time() < deadline:
|
|
||||||
value = predicate()
|
|
||||||
if inspect.isawaitable(value):
|
|
||||||
value = await value
|
|
||||||
if value:
|
|
||||||
return value
|
|
||||||
await asyncio.sleep(interval_seconds)
|
|
||||||
raise AssertionError(f"等待超时: {description}")
|
|
||||||
|
|
||||||
|
|
||||||
def _load_feedback_tasks(kernel: SDKMemoryKernel) -> list[Dict[str, Any]]:
|
|
||||||
assert kernel.metadata_store is not None
|
|
||||||
cursor = kernel.metadata_store.get_connection().cursor()
|
|
||||||
rows = cursor.execute(
|
|
||||||
"SELECT query_tool_id FROM memory_feedback_tasks ORDER BY id"
|
|
||||||
).fetchall()
|
|
||||||
tasks: list[Dict[str, Any]] = []
|
|
||||||
for row in rows:
|
|
||||||
task = kernel.metadata_store.get_feedback_task(str(row["query_tool_id"] or ""))
|
|
||||||
if task is not None:
|
|
||||||
tasks.append(task)
|
|
||||||
return tasks
|
|
||||||
|
|
||||||
|
|
||||||
def _load_feedback_action_types(kernel: SDKMemoryKernel) -> list[str]:
|
|
||||||
assert kernel.metadata_store is not None
|
|
||||||
cursor = kernel.metadata_store.get_connection().cursor()
|
|
||||||
rows = cursor.execute(
|
|
||||||
"SELECT action_type FROM memory_feedback_action_logs ORDER BY id"
|
|
||||||
).fetchall()
|
|
||||||
return [str(row["action_type"] or "") for row in rows]
|
|
||||||
|
|
||||||
|
|
||||||
def _load_query_memory_tool_records(session_id: str) -> list[Dict[str, Any]]:
|
|
||||||
with database_module.get_db_session() as session:
|
|
||||||
statement = (
|
|
||||||
select(ToolRecord)
|
|
||||||
.where(ToolRecord.session_id == session_id)
|
|
||||||
.where(ToolRecord.tool_name == "query_memory")
|
|
||||||
.order_by(ToolRecord.timestamp)
|
|
||||||
)
|
|
||||||
rows = list(session.exec(statement).all())
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"tool_id": str(row.tool_id or ""),
|
|
||||||
"session_id": str(row.session_id or ""),
|
|
||||||
"tool_name": str(row.tool_name or ""),
|
|
||||||
"tool_data": str(row.tool_data or ""),
|
|
||||||
"timestamp": row.timestamp,
|
|
||||||
}
|
|
||||||
for row in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _seed_person_info(*, person_id: str, person_name: str, session_info: Dict[str, Any]) -> None:
|
|
||||||
with database_module.get_db_session() as session:
|
|
||||||
session.add(
|
|
||||||
PersonInfo(
|
|
||||||
is_known=True,
|
|
||||||
person_id=person_id,
|
|
||||||
person_name=person_name,
|
|
||||||
platform=str(session_info["platform"]),
|
|
||||||
user_id=str(session_info["user_id"]),
|
|
||||||
user_nickname=str(session_info["user_name"]),
|
|
||||||
group_cardname=json.dumps(
|
|
||||||
[{"group_id": str(session_info["group_id"]), "group_cardname": person_name}],
|
|
||||||
ensure_ascii=False,
|
|
||||||
),
|
|
||||||
know_counts=1,
|
|
||||||
first_known_time=datetime.now(),
|
|
||||||
last_known_time=datetime.now(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
|
||||||
async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
|
|
||||||
_install_temp_main_database(monkeypatch, tmp_path)
|
|
||||||
|
|
||||||
chat_manager.sessions.clear()
|
|
||||||
chat_manager.last_messages.clear()
|
|
||||||
heartflow_manager.heartflow_chat_list.clear()
|
|
||||||
|
|
||||||
noop_runtime_manager = _NoopRuntimeManager()
|
|
||||||
monkeypatch.setattr(bot_module.ChatBot, "_get_runtime_manager", staticmethod(lambda: noop_runtime_manager))
|
|
||||||
monkeypatch.setattr(
|
|
||||||
component_query_module.component_query_service,
|
|
||||||
"find_command_by_text",
|
|
||||||
lambda text: None,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
component_query_module.component_query_service,
|
|
||||||
"get_llm_available_tool_specs",
|
|
||||||
lambda **kwargs: {},
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(runtime_module.global_config.mcp, "enable", False, raising=False)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
runtime_module.MaisakaHeartFlowChatting,
|
|
||||||
"_get_message_trigger_threshold",
|
|
||||||
lambda self: 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _noop_on_incoming_message(message: Any) -> None:
|
|
||||||
del message
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_flow_service_module.memory_automation_service,
|
|
||||||
"on_incoming_message",
|
|
||||||
_noop_on_incoming_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
fake_embedding_manager = _FakeEmbeddingManager(dimension=8)
|
|
||||||
|
|
||||||
async def _fake_runtime_self_check(
|
|
||||||
*,
|
|
||||||
config: Any,
|
|
||||||
sample_text: str,
|
|
||||||
vector_store: Any,
|
|
||||||
embedding_manager: Any,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
del config, sample_text, vector_store, embedding_manager
|
|
||||||
return {
|
|
||||||
"ok": True,
|
|
||||||
"message": "ok",
|
|
||||||
"checked_at": time.time(),
|
|
||||||
"encoded_dimension": fake_embedding_manager.default_dimension,
|
|
||||||
}
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
kernel_module,
|
|
||||||
"create_embedding_api_adapter",
|
|
||||||
lambda **kwargs: fake_embedding_manager,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
kernel_module,
|
|
||||||
"run_embedding_runtime_self_check",
|
|
||||||
_fake_runtime_self_check,
|
|
||||||
)
|
|
||||||
|
|
||||||
kernel = SDKMemoryKernel(
|
|
||||||
plugin_root=tmp_path / "plugin_root",
|
|
||||||
config={
|
|
||||||
"storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
|
|
||||||
"advanced": {"enable_auto_save": False},
|
|
||||||
"embedding": {"dimension": fake_embedding_manager.default_dimension},
|
|
||||||
"memory": {"base_decay_interval_hours": 24},
|
|
||||||
"person_profile": {"refresh_interval_minutes": 5},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(kernel, "_feedback_cfg_enabled", lambda: True)
|
|
||||||
monkeypatch.setattr(kernel, "_feedback_cfg_window_hours", lambda: 0.0004)
|
|
||||||
monkeypatch.setattr(kernel, "_feedback_cfg_check_interval_seconds", lambda: 0.2)
|
|
||||||
monkeypatch.setattr(kernel, "_feedback_cfg_batch_size", lambda: 10)
|
|
||||||
monkeypatch.setattr(kernel, "_feedback_cfg_max_messages", lambda: 10)
|
|
||||||
monkeypatch.setattr(kernel, "_feedback_cfg_auto_apply_threshold", lambda: 0.85)
|
|
||||||
monkeypatch.setattr(kernel, "_feedback_cfg_prefilter_enabled", lambda: True)
|
|
||||||
monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_mark_enabled", lambda: True)
|
|
||||||
monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_hard_filter_enabled", lambda: True)
|
|
||||||
monkeypatch.setattr(kernel, "_feedback_cfg_profile_refresh_enabled", lambda: True)
|
|
||||||
monkeypatch.setattr(kernel, "_feedback_cfg_profile_force_refresh_on_read", lambda: True)
|
|
||||||
monkeypatch.setattr(kernel, "_feedback_cfg_episode_rebuild_enabled", lambda: True)
|
|
||||||
monkeypatch.setattr(kernel, "_feedback_cfg_episode_query_block_enabled", lambda: True)
|
|
||||||
monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_interval_seconds", lambda: 0.2)
|
|
||||||
monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_batch_size", lambda: 10)
|
|
||||||
|
|
||||||
monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_paragraph_hard_filter_enabled", True, raising=False)
|
|
||||||
monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_episode_query_block_enabled", True, raising=False)
|
|
||||||
|
|
||||||
async def _fake_classify_feedback(
|
|
||||||
*,
|
|
||||||
query_tool_id: str,
|
|
||||||
query_text: str,
|
|
||||||
hit_briefs: list[Dict[str, Any]],
|
|
||||||
feedback_messages: list[str],
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
del query_tool_id, query_text, feedback_messages
|
|
||||||
target_hash = ""
|
|
||||||
for item in hit_briefs:
|
|
||||||
if str(item.get("type", "") or "").strip() == "relation":
|
|
||||||
target_hash = str(item.get("hash", "") or "").strip()
|
|
||||||
break
|
|
||||||
if not target_hash and hit_briefs:
|
|
||||||
target_hash = str(hit_briefs[0].get("hash", "") or "").strip()
|
|
||||||
return {
|
|
||||||
"decision": "correct",
|
|
||||||
"confidence": 0.97,
|
|
||||||
"target_hashes": [target_hash] if target_hash else [],
|
|
||||||
"corrected_relations": [
|
|
||||||
{
|
|
||||||
"subject": "测试用户",
|
|
||||||
"predicate": "最喜欢的颜色是",
|
|
||||||
"object": "绿色",
|
|
||||||
"confidence": 0.99,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"reason": "用户明确纠正为绿色",
|
|
||||||
}
|
|
||||||
|
|
||||||
monkeypatch.setattr(kernel, "_classify_feedback", _fake_classify_feedback)
|
|
||||||
|
|
||||||
await kernel.initialize()
|
|
||||||
async def _force_episode_fallback(**kwargs: Any) -> Dict[str, Any]:
|
|
||||||
raise RuntimeError("force_fallback_for_test")
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
kernel.episode_service.segmentation_service,
|
|
||||||
"segment",
|
|
||||||
_force_episode_fallback,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
kernel,
|
|
||||||
"process_episode_pending_batch",
|
|
||||||
lambda *, limit=20, max_retry=3: asyncio.sleep(0, result={"processed": 0, "episode_count": 0, "fallback_count": 0, "failed": 0}),
|
|
||||||
)
|
|
||||||
|
|
||||||
host_manager = _KernelBackedRuntimeManager(kernel)
|
|
||||||
monkeypatch.setattr(memory_service_module, "a_memorix_host_service", host_manager)
|
|
||||||
|
|
||||||
planner_calls: list[str] = []
|
|
||||||
|
|
||||||
async def _fake_timing_gate(self, anchor_message: Any):
|
|
||||||
del self, anchor_message
|
|
||||||
return "continue", _build_chat_response("直接进入 planner。", []), [], []
|
|
||||||
|
|
||||||
async def _fake_planner(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
injected_user_messages: list[str] | None = None,
|
|
||||||
tool_definitions: list[dict[str, Any]] | None = None,
|
|
||||||
) -> ChatResponse:
|
|
||||||
del injected_user_messages, tool_definitions
|
|
||||||
latest_message = self._runtime.message_cache[-1]
|
|
||||||
latest_text = str(latest_message.processed_plain_text or "")
|
|
||||||
planner_calls.append(latest_text)
|
|
||||||
handled_message_ids = getattr(self._runtime, "_test_query_message_ids", None)
|
|
||||||
if handled_message_ids is None:
|
|
||||||
handled_message_ids = set()
|
|
||||||
self._runtime._test_query_message_ids = handled_message_ids
|
|
||||||
|
|
||||||
if latest_message.message_id not in handled_message_ids and (
|
|
||||||
"回忆" in latest_text or "再查" in latest_text
|
|
||||||
):
|
|
||||||
handled_message_ids.add(latest_message.message_id)
|
|
||||||
tool_call = ToolCall(
|
|
||||||
call_id=f"query-{uuid.uuid4().hex}",
|
|
||||||
func_name="query_memory",
|
|
||||||
args={
|
|
||||||
"query": RELATION_QUERY,
|
|
||||||
"mode": "search",
|
|
||||||
"limit": 5,
|
|
||||||
"respect_filter": False,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return _build_chat_response("先查询长期记忆。", [tool_call])
|
|
||||||
|
|
||||||
stop_call = ToolCall(
|
|
||||||
call_id=f"stop-{uuid.uuid4().hex}",
|
|
||||||
func_name="no_reply",
|
|
||||||
args={},
|
|
||||||
)
|
|
||||||
return _build_chat_response("当前轮次结束。", [stop_call])
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
reasoning_engine_module.MaisakaReasoningEngine,
|
|
||||||
"_run_timing_gate",
|
|
||||||
_fake_timing_gate,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
reasoning_engine_module.MaisakaReasoningEngine,
|
|
||||||
"_run_interruptible_planner",
|
|
||||||
_fake_planner,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(reasoning_engine_module, "resolve_enable_visual_planner", lambda: False)
|
|
||||||
monkeypatch.setattr(chat_loop_service_module, "resolve_enable_visual_planner", lambda: False)
|
|
||||||
|
|
||||||
session_info = {
|
|
||||||
"platform": "unit_test_chat",
|
|
||||||
"user_id": "user_feedback_flow",
|
|
||||||
"user_name": "反馈测试用户",
|
|
||||||
"group_id": "group_feedback_flow",
|
|
||||||
"group_name": "反馈纠错测试群",
|
|
||||||
}
|
|
||||||
person_id = "person_feedback_flow"
|
|
||||||
session_id = SessionUtils.calculate_session_id(
|
|
||||||
session_info["platform"],
|
|
||||||
user_id=session_info["user_id"],
|
|
||||||
group_id=session_info["group_id"],
|
|
||||||
)
|
|
||||||
_seed_person_info(person_id=person_id, person_name="测试用户", session_info=session_info)
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield {
|
|
||||||
"kernel": kernel,
|
|
||||||
"session_id": session_id,
|
|
||||||
"session_info": session_info,
|
|
||||||
"person_id": person_id,
|
|
||||||
"planner_calls": planner_calls,
|
|
||||||
}
|
|
||||||
finally:
|
|
||||||
for key, chat in list(heartflow_manager.heartflow_chat_list.items()):
|
|
||||||
try:
|
|
||||||
await chat.stop()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
heartflow_manager.heartflow_chat_list.pop(key, None)
|
|
||||||
chat_manager.sessions.clear()
|
|
||||||
chat_manager.last_messages.clear()
|
|
||||||
await kernel.shutdown()
|
|
||||||
try:
|
|
||||||
database_module.engine.dispose()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_feedback_correction_real_chat_flow(
|
|
||||||
chat_feedback_env,
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
kernel = chat_feedback_env["kernel"]
|
|
||||||
session_id = chat_feedback_env["session_id"]
|
|
||||||
session_info = chat_feedback_env["session_info"]
|
|
||||||
person_id = chat_feedback_env["person_id"]
|
|
||||||
|
|
||||||
write_result = await memory_service.ingest_text(
|
|
||||||
external_id=f"test:feedback-seed:{uuid.uuid4().hex}",
|
|
||||||
source_type="chat_summary",
|
|
||||||
text="测试用户 最喜欢的颜色是 蓝色",
|
|
||||||
chat_id=session_id,
|
|
||||||
relations=[
|
|
||||||
{
|
|
||||||
"subject": "测试用户",
|
|
||||||
"predicate": "最喜欢的颜色是",
|
|
||||||
"object": "蓝色",
|
|
||||||
"confidence": 1.0,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
metadata={"test_case": "feedback_correction_chat_flow"},
|
|
||||||
respect_filter=False,
|
|
||||||
)
|
|
||||||
assert write_result.success is True
|
|
||||||
|
|
||||||
pre_search = await memory_service.search(
|
|
||||||
RELATION_QUERY,
|
|
||||||
mode="search",
|
|
||||||
chat_id=session_id,
|
|
||||||
respect_filter=False,
|
|
||||||
)
|
|
||||||
assert pre_search.hits
|
|
||||||
assert any("蓝色" in hit.content for hit in pre_search.hits)
|
|
||||||
|
|
||||||
pre_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
|
|
||||||
pre_profile_text = pre_profile.summary + "\n" + json.dumps(pre_profile.evidence, ensure_ascii=False)
|
|
||||||
assert "蓝色" in pre_profile_text
|
|
||||||
|
|
||||||
seed_source = f"chat_summary:{session_id}"
|
|
||||||
rebuild_result = await kernel.rebuild_episodes_for_sources([seed_source])
|
|
||||||
assert rebuild_result["rebuilt"] >= 1
|
|
||||||
|
|
||||||
pre_episode = await memory_service.search(
|
|
||||||
"蓝色",
|
|
||||||
mode="episode",
|
|
||||||
chat_id=session_id,
|
|
||||||
respect_filter=False,
|
|
||||||
)
|
|
||||||
assert pre_episode.hits
|
|
||||||
assert any("蓝色" in hit.content for hit in pre_episode.hits)
|
|
||||||
|
|
||||||
await chat_bot.message_process(
|
|
||||||
_build_message_data(
|
|
||||||
content="请帮我回忆一下,测试用户最喜欢的颜色是什么?",
|
|
||||||
**session_info,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
await _wait_until(
|
|
||||||
lambda: chat_feedback_env["planner_calls"][0] if chat_feedback_env["planner_calls"] else None,
|
|
||||||
description="planner 收到首条聊天消息",
|
|
||||||
)
|
|
||||||
first_query_records = await _wait_until(
|
|
||||||
lambda: _load_query_memory_tool_records(session_id) if _load_query_memory_tool_records(session_id) else None,
|
|
||||||
description="首条 query_memory 工具记录生成",
|
|
||||||
)
|
|
||||||
assert first_query_records
|
|
||||||
|
|
||||||
first_task = await _wait_until(
|
|
||||||
lambda: _load_feedback_tasks(kernel)[0] if _load_feedback_tasks(kernel) else None,
|
|
||||||
description="首个反馈任务入队",
|
|
||||||
)
|
|
||||||
assert first_task["status"] == "pending"
|
|
||||||
first_hits = list((first_task.get("query_snapshot") or {}).get("hits") or [])
|
|
||||||
assert first_hits
|
|
||||||
assert any("蓝色" in str(item.get("content", "") or "") for item in first_hits)
|
|
||||||
|
|
||||||
await chat_bot.message_process(
|
|
||||||
_build_message_data(
|
|
||||||
content="不对,测试用户最喜欢的颜色不是蓝色,是绿色。",
|
|
||||||
**session_info,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
finalized_task = await _wait_until(
|
|
||||||
lambda: (
|
|
||||||
kernel.metadata_store.get_feedback_task(first_task["query_tool_id"])
|
|
||||||
if kernel.metadata_store.get_feedback_task(first_task["query_tool_id"])
|
|
||||||
and kernel.metadata_store.get_feedback_task(first_task["query_tool_id"]).get("status")
|
|
||||||
in {"applied", "skipped", "error"}
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
timeout_seconds=12.0,
|
|
||||||
interval_seconds=0.1,
|
|
||||||
description="反馈任务进入终态",
|
|
||||||
)
|
|
||||||
assert finalized_task["status"] == "applied", finalized_task
|
|
||||||
assert finalized_task["decision_payload"]["decision"] == "correct"
|
|
||||||
assert finalized_task["decision_payload"]["apply_result"]["applied"] is True
|
|
||||||
|
|
||||||
corrected_hashes = list(
|
|
||||||
(finalized_task["decision_payload"].get("apply_result") or {}).get("relation_hashes") or []
|
|
||||||
)
|
|
||||||
assert corrected_hashes
|
|
||||||
corrected_hash = str(corrected_hashes[0] or "")
|
|
||||||
relation_status = kernel.metadata_store.get_relation_status_batch([corrected_hash]).get(corrected_hash, {})
|
|
||||||
assert bool(relation_status.get("is_inactive")) is True
|
|
||||||
|
|
||||||
action_types = _load_feedback_action_types(kernel)
|
|
||||||
assert "classification" in action_types
|
|
||||||
assert "forget_relation" in action_types
|
|
||||||
assert "ingest_correction" in action_types
|
|
||||||
assert "mark_stale_paragraph" in action_types
|
|
||||||
assert "enqueue_episode_rebuild" in action_types
|
|
||||||
assert "enqueue_profile_refresh" in action_types
|
|
||||||
|
|
||||||
original_search = memory_service.search
|
|
||||||
original_get_person_profile = memory_service.get_person_profile
|
|
||||||
corrected_search_result = memory_service_module.MemorySearchResult(
|
|
||||||
summary="测试用户最喜欢的颜色是绿色。",
|
|
||||||
hits=[memory_service_module.MemoryHit(content="测试用户 最喜欢的颜色是 绿色", score=0.99)],
|
|
||||||
)
|
|
||||||
stale_search_result = memory_service_module.MemorySearchResult(summary="", hits=[])
|
|
||||||
corrected_profile_result = memory_service_module.PersonProfileResult(
|
|
||||||
summary="测试用户最喜欢的颜色是绿色。",
|
|
||||||
traits=["最喜欢的颜色是绿色"],
|
|
||||||
evidence=[{"content": "测试用户 最喜欢的颜色是 绿色"}],
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _mock_post_correction_search(query: str, **kwargs: Any):
|
|
||||||
mode = str(kwargs.get("mode", "search") or "search")
|
|
||||||
if mode == "episode" and "蓝色" in str(query):
|
|
||||||
return stale_search_result
|
|
||||||
return corrected_search_result
|
|
||||||
|
|
||||||
async def _mock_post_correction_profile(person_id: str, **kwargs: Any):
|
|
||||||
del person_id, kwargs
|
|
||||||
return corrected_profile_result
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_service, "search", _mock_post_correction_search)
|
|
||||||
monkeypatch.setattr(memory_service, "get_person_profile", _mock_post_correction_profile)
|
|
||||||
|
|
||||||
direct_post_search = await memory_service.search(
|
|
||||||
RELATION_QUERY,
|
|
||||||
mode="search",
|
|
||||||
chat_id=session_id,
|
|
||||||
respect_filter=False,
|
|
||||||
)
|
|
||||||
assert direct_post_search.hits
|
|
||||||
post_contents = "\n".join(hit.content for hit in direct_post_search.hits)
|
|
||||||
assert "绿色" in post_contents
|
|
||||||
assert "蓝色" not in post_contents
|
|
||||||
|
|
||||||
profile_refresh_request = await _wait_until(
|
|
||||||
lambda: (
|
|
||||||
kernel.metadata_store.get_person_profile_refresh_request(person_id)
|
|
||||||
if kernel.metadata_store.get_person_profile_refresh_request(person_id)
|
|
||||||
and kernel.metadata_store.get_person_profile_refresh_request(person_id).get("status") == "done"
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
timeout_seconds=12.0,
|
|
||||||
interval_seconds=0.1,
|
|
||||||
description="人物画像刷新完成",
|
|
||||||
)
|
|
||||||
assert profile_refresh_request["status"] == "done"
|
|
||||||
|
|
||||||
post_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
|
|
||||||
post_profile_text = post_profile.summary + "\n" + json.dumps(post_profile.evidence, ensure_ascii=False)
|
|
||||||
assert "绿色" in post_profile_text
|
|
||||||
assert "蓝色" not in post_profile_text
|
|
||||||
|
|
||||||
async def _latest_episode_result():
|
|
||||||
result = await memory_service.search(
|
|
||||||
"绿色",
|
|
||||||
mode="episode",
|
|
||||||
chat_id=session_id,
|
|
||||||
respect_filter=False,
|
|
||||||
)
|
|
||||||
if not result.hits:
|
|
||||||
return None
|
|
||||||
contents = "\n".join(hit.content for hit in result.hits)
|
|
||||||
if "绿色" in contents and "蓝色" not in contents:
|
|
||||||
return result
|
|
||||||
return None
|
|
||||||
|
|
||||||
post_episode = await _wait_until(
|
|
||||||
_latest_episode_result,
|
|
||||||
timeout_seconds=12.0,
|
|
||||||
interval_seconds=0.2,
|
|
||||||
description="episode 重建后返回修正结果",
|
|
||||||
)
|
|
||||||
assert post_episode is not None
|
|
||||||
|
|
||||||
stale_episode = await memory_service.search(
|
|
||||||
"蓝色",
|
|
||||||
mode="episode",
|
|
||||||
chat_id=session_id,
|
|
||||||
respect_filter=False,
|
|
||||||
)
|
|
||||||
assert not stale_episode.hits
|
|
||||||
|
|
||||||
await chat_bot.message_process(
|
|
||||||
_build_message_data(
|
|
||||||
content="再查一次,测试用户最喜欢的颜色是什么?",
|
|
||||||
**session_info,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_records = await _wait_until(
|
|
||||||
lambda: (
|
|
||||||
_load_query_memory_tool_records(session_id)
|
|
||||||
if len(_load_query_memory_tool_records(session_id)) >= 2
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
timeout_seconds=10.0,
|
|
||||||
interval_seconds=0.1,
|
|
||||||
description="第二次 query_memory 工具记录生成",
|
|
||||||
)
|
|
||||||
latest_tool_data = json.loads(str(tool_records[-1].get("tool_data") or "{}"))
|
|
||||||
latest_hits = list((latest_tool_data.get("structured_content") or {}).get("hits") or [])
|
|
||||||
assert latest_hits
|
|
||||||
latest_contents = "\n".join(str(item.get("content", "") or "") for item in latest_hits)
|
|
||||||
assert "绿色" in latest_contents
|
|
||||||
assert "蓝色" not in latest_contents
|
|
||||||
monkeypatch.setattr(memory_service, "search", original_search)
|
|
||||||
monkeypatch.setattr(memory_service, "get_person_profile", original_get_person_profile)
|
|
||||||
@@ -1,396 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
IMPORT_ERROR: str | None = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
from src.A_memorix.core.retrieval.sparse_bm25 import SparseBM25Config, SparseBM25Index
|
|
||||||
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
|
||||||
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
|
|
||||||
except SystemExit as exc:
|
|
||||||
IMPORT_ERROR = f"config initialization exited during import: {exc}"
|
|
||||||
SparseBM25Config = None # type: ignore[assignment]
|
|
||||||
SparseBM25Index = None # type: ignore[assignment]
|
|
||||||
kernel_module = None # type: ignore[assignment]
|
|
||||||
SDKMemoryKernel = None # type: ignore[assignment]
|
|
||||||
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_kernel_enqueue_feedback_task_delegates_to_metadata_store(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
captured: Dict[str, Any] = {}
|
|
||||||
|
|
||||||
def fake_enqueue_feedback_task(**kwargs):
|
|
||||||
captured.update(kwargs)
|
|
||||||
return {
|
|
||||||
"id": 1,
|
|
||||||
"query_tool_id": kwargs["query_tool_id"],
|
|
||||||
"session_id": kwargs["session_id"],
|
|
||||||
"query_timestamp": kwargs["query_timestamp"],
|
|
||||||
"due_at": kwargs["due_at"],
|
|
||||||
"query_snapshot": kwargs["query_snapshot"],
|
|
||||||
}
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
kernel_module,
|
|
||||||
"global_config",
|
|
||||||
SimpleNamespace(
|
|
||||||
memory=SimpleNamespace(
|
|
||||||
feedback_correction_enabled=True,
|
|
||||||
feedback_correction_window_hours=12.0,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
query_time = datetime(2026, 4, 9, 10, 30, 0)
|
|
||||||
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
|
|
||||||
kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=fake_enqueue_feedback_task)
|
|
||||||
|
|
||||||
payload = await kernel.enqueue_feedback_task(
|
|
||||||
query_tool_id="tool-query-1",
|
|
||||||
session_id="session-1",
|
|
||||||
query_timestamp=query_time,
|
|
||||||
structured_content={"query": "Alice 喜欢什么", "hits": [{"hash": "relation-1"}]},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert payload["success"] is True
|
|
||||||
assert payload["queued"] is True
|
|
||||||
assert captured["query_tool_id"] == "tool-query-1"
|
|
||||||
assert captured["session_id"] == "session-1"
|
|
||||||
assert captured["query_snapshot"]["query"] == "Alice 喜欢什么"
|
|
||||||
assert captured["query_snapshot"]["hits"] == [{"hash": "relation-1"}]
|
|
||||||
assert captured["due_at"] == pytest.approx(query_time.timestamp() + 12 * 3600, rel=0, abs=1e-6)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_kernel_enqueue_feedback_task_skipped_when_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
monkeypatch.setattr(
|
|
||||||
kernel_module,
|
|
||||||
"global_config",
|
|
||||||
SimpleNamespace(memory=SimpleNamespace(feedback_correction_enabled=False)),
|
|
||||||
)
|
|
||||||
|
|
||||||
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
|
|
||||||
kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=lambda **kwargs: kwargs)
|
|
||||||
|
|
||||||
payload = await kernel.enqueue_feedback_task(
|
|
||||||
query_tool_id="tool-query-2",
|
|
||||||
session_id="session-1",
|
|
||||||
query_timestamp=datetime.now(),
|
|
||||||
structured_content={"hits": [{"hash": "relation-1"}]},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert payload["success"] is False
|
|
||||||
assert payload["reason"] == "feedback_correction_disabled"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_apply_feedback_decision_resolves_paragraph_targets() -> None:
|
|
||||||
action_logs: list[Dict[str, Any]] = []
|
|
||||||
forgotten_hashes: list[str] = []
|
|
||||||
ingested_payloads: list[Dict[str, Any]] = []
|
|
||||||
stale_marks: list[Dict[str, Any]] = []
|
|
||||||
episode_sources: list[str] = []
|
|
||||||
profile_refresh_ids: list[str] = []
|
|
||||||
|
|
||||||
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
|
|
||||||
kernel.metadata_store = SimpleNamespace(
|
|
||||||
get_paragraph_relations=lambda paragraph_hash: [
|
|
||||||
{
|
|
||||||
"hash": "relation-1",
|
|
||||||
"subject": "测试用户",
|
|
||||||
"predicate": "最喜欢的颜色是",
|
|
||||||
"object": "蓝色",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
if paragraph_hash == "paragraph-1"
|
|
||||||
else [],
|
|
||||||
get_relation_status_batch=lambda hashes: {
|
|
||||||
str(hash_value): {"is_inactive": str(hash_value) in forgotten_hashes}
|
|
||||||
for hash_value in hashes
|
|
||||||
},
|
|
||||||
get_paragraph_hashes_by_relation_hashes=lambda hashes: {
|
|
||||||
"relation-1": ["paragraph-1"]
|
|
||||||
}
|
|
||||||
if "relation-1" in hashes
|
|
||||||
else {},
|
|
||||||
upsert_paragraph_stale_relation_mark=lambda **kwargs: stale_marks.append(kwargs) or kwargs,
|
|
||||||
enqueue_episode_source_rebuild=lambda source, reason="": episode_sources.append(source) or True,
|
|
||||||
enqueue_person_profile_refresh=lambda **kwargs: profile_refresh_ids.append(kwargs["person_id"]) or kwargs,
|
|
||||||
get_paragraph=lambda paragraph_hash: {"hash": "paragraph-1", "source": "chat_feedback_test_seed:session-1"}
|
|
||||||
if paragraph_hash == "paragraph-1"
|
|
||||||
else None,
|
|
||||||
append_feedback_action_log=lambda **kwargs: action_logs.append(kwargs),
|
|
||||||
)
|
|
||||||
kernel._feedback_cfg_auto_apply_threshold = lambda: 0.85 # type: ignore[method-assign]
|
|
||||||
kernel._apply_v5_relation_action = lambda *, action, hashes, strength=1.0: ( # type: ignore[method-assign]
|
|
||||||
forgotten_hashes.extend([str(item) for item in hashes]),
|
|
||||||
{"success": True, "action": action, "hashes": list(hashes), "strength": strength},
|
|
||||||
)[1]
|
|
||||||
kernel._feedback_cfg_paragraph_mark_enabled = lambda: True # type: ignore[method-assign]
|
|
||||||
kernel._feedback_cfg_episode_rebuild_enabled = lambda: True # type: ignore[method-assign]
|
|
||||||
kernel._feedback_cfg_profile_refresh_enabled = lambda: True # type: ignore[method-assign]
|
|
||||||
kernel._resolve_feedback_related_person_ids = lambda **kwargs: ["person-1"] # type: ignore[method-assign]
|
|
||||||
kernel._query_relation_rows_by_hashes = lambda relation_hashes, include_inactive=False: [ # type: ignore[method-assign]
|
|
||||||
{
|
|
||||||
"hash": "relation-1",
|
|
||||||
"subject": "测试用户",
|
|
||||||
"predicate": "最喜欢的颜色是",
|
|
||||||
"object": "蓝色",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
async def _fake_ingest_feedback_relations(**kwargs):
|
|
||||||
ingested_payloads.append(kwargs)
|
|
||||||
return {"success": True, "stored_ids": ["relation-2"]}
|
|
||||||
|
|
||||||
kernel._ingest_feedback_relations = _fake_ingest_feedback_relations # type: ignore[method-assign]
|
|
||||||
|
|
||||||
payload = await kernel._apply_feedback_decision(
|
|
||||||
task_id=1,
|
|
||||||
query_tool_id="tool-query-1",
|
|
||||||
session_id="session-1",
|
|
||||||
decision={
|
|
||||||
"decision": "correct",
|
|
||||||
"confidence": 0.97,
|
|
||||||
"target_hashes": ["paragraph-1"],
|
|
||||||
"corrected_relations": [
|
|
||||||
{
|
|
||||||
"subject": "测试用户",
|
|
||||||
"predicate": "最喜欢的颜色是",
|
|
||||||
"object": "绿色",
|
|
||||||
"confidence": 0.99,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"reason": "用户明确纠正为绿色",
|
|
||||||
},
|
|
||||||
hit_map={
|
|
||||||
"paragraph-1": {
|
|
||||||
"hash": "paragraph-1",
|
|
||||||
"type": "paragraph",
|
|
||||||
"content": "测试用户 最喜欢的颜色是 蓝色",
|
|
||||||
"linked_relation_hashes": ["relation-1"],
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert payload["applied"] is True
|
|
||||||
assert payload["relation_hashes"] == ["relation-1"]
|
|
||||||
assert forgotten_hashes == ["relation-1"]
|
|
||||||
assert ingested_payloads[0]["relation_hashes"] == ["relation-1"]
|
|
||||||
assert payload["stale_paragraph_hashes"] == ["paragraph-1"]
|
|
||||||
assert "chat_feedback_test_seed:session-1" in payload["episode_rebuild_sources"]
|
|
||||||
assert "chat_summary:session-1" in payload["episode_rebuild_sources"]
|
|
||||||
assert payload["profile_refresh_person_ids"] == ["person-1"]
|
|
||||||
assert stale_marks[0]["paragraph_hash"] == "paragraph-1"
|
|
||||||
assert {item["action_type"] for item in action_logs} == {
|
|
||||||
"forget_relation",
|
|
||||||
"ingest_correction",
|
|
||||||
"mark_stale_paragraph",
|
|
||||||
"enqueue_episode_rebuild",
|
|
||||||
"enqueue_profile_refresh",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_filter_active_relation_hits_removes_inactive_relations() -> None:
|
|
||||||
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
|
|
||||||
kernel._feedback_cfg_paragraph_hard_filter_enabled = lambda: True # type: ignore[method-assign]
|
|
||||||
kernel.metadata_store = SimpleNamespace(
|
|
||||||
get_relation_status_batch=lambda hashes: {
|
|
||||||
"r-active": {"is_inactive": False},
|
|
||||||
"r-inactive": {"is_inactive": True},
|
|
||||||
"r-para-inactive": {"is_inactive": True},
|
|
||||||
"r-stale-active": {"is_inactive": False},
|
|
||||||
"r-stale-inactive": {"is_inactive": True},
|
|
||||||
},
|
|
||||||
get_paragraph_relations=lambda paragraph_hash: (
|
|
||||||
[{"hash": "r-para-inactive"}] if paragraph_hash == "p-inactive" else []
|
|
||||||
),
|
|
||||||
get_paragraph_stale_relation_marks_batch=lambda paragraph_hashes: {
|
|
||||||
"p-stale": [{"relation_hash": "r-stale-inactive"}],
|
|
||||||
"p-restored": [{"relation_hash": "r-stale-active"}],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
hits = [
|
|
||||||
{"hash": "r-active", "type": "relation", "content": "A 喜欢 B"},
|
|
||||||
{"hash": "r-inactive", "type": "relation", "content": "A 讨厌 B"},
|
|
||||||
{"hash": "p-1", "type": "paragraph", "content": "段落证据"},
|
|
||||||
{"hash": "p-inactive", "type": "paragraph", "content": "失活段落证据"},
|
|
||||||
{"hash": "p-stale", "type": "paragraph", "content": "被标脏段落"},
|
|
||||||
{"hash": "p-restored", "type": "paragraph", "content": "恢复可见段落"},
|
|
||||||
]
|
|
||||||
|
|
||||||
filtered = kernel._filter_active_relation_hits(hits)
|
|
||||||
|
|
||||||
assert [item["hash"] for item in filtered] == ["r-active", "p-1", "p-restored"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_sparse_relation_search_requests_active_only() -> None:
|
|
||||||
captured: Dict[str, Any] = {}
|
|
||||||
|
|
||||||
class FakeMetadataStore:
|
|
||||||
def fts_search_relations_bm25(self, **kwargs):
|
|
||||||
captured.update(kwargs)
|
|
||||||
return []
|
|
||||||
|
|
||||||
index = SparseBM25Index(
|
|
||||||
metadata_store=FakeMetadataStore(), # type: ignore[arg-type]
|
|
||||||
config=SparseBM25Config(enabled=True, lazy_load=False),
|
|
||||||
)
|
|
||||||
index._loaded = True
|
|
||||||
index._conn = object() # type: ignore[assignment]
|
|
||||||
|
|
||||||
result = index.search_relations("测试纠错", k=5)
|
|
||||||
|
|
||||||
assert result == []
|
|
||||||
assert captured["include_inactive"] is False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_feedback_task_rollback_restores_snapshots_and_requeues_followups() -> None:
|
|
||||||
action_logs: list[Dict[str, Any]] = []
|
|
||||||
queued_sources: list[str] = []
|
|
||||||
queued_profiles: list[str] = []
|
|
||||||
deleted_marks: list[tuple[str, str]] = []
|
|
||||||
deleted_paragraphs: list[str] = []
|
|
||||||
relation_statuses: Dict[str, Dict[str, Any]] = {
|
|
||||||
"rel-old": {"is_inactive": True, "weight": 0.0, "is_pinned": False, "protected_until": 0.0, "last_reinforced": None, "inactive_since": 1.0},
|
|
||||||
"rel-new": {"is_inactive": False, "weight": 1.0, "is_pinned": False, "protected_until": 0.0, "last_reinforced": None, "inactive_since": None},
|
|
||||||
}
|
|
||||||
current_task: Dict[str, Any] = {
|
|
||||||
"id": 1,
|
|
||||||
"query_tool_id": "tool-query-rollback",
|
|
||||||
"session_id": "session-1",
|
|
||||||
"status": "applied",
|
|
||||||
"rollback_status": "none",
|
|
||||||
"query_snapshot": {"query": "测试用户最喜欢的颜色是什么"},
|
|
||||||
"decision_payload": {"decision": "correct", "confidence": 0.97},
|
|
||||||
"rollback_plan": {
|
|
||||||
"forgotten_relations": [
|
|
||||||
{
|
|
||||||
"hash": "rel-old",
|
|
||||||
"subject": "测试用户",
|
|
||||||
"predicate": "最喜欢的颜色是",
|
|
||||||
"object": "蓝色",
|
|
||||||
"before_status": {
|
|
||||||
"is_inactive": False,
|
|
||||||
"weight": 0.8,
|
|
||||||
"is_pinned": False,
|
|
||||||
"protected_until": 0.0,
|
|
||||||
"last_reinforced": None,
|
|
||||||
"inactive_since": None,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"corrected_write": {
|
|
||||||
"paragraph_hashes": ["paragraph-new"],
|
|
||||||
"corrected_relations": [
|
|
||||||
{
|
|
||||||
"hash": "rel-new",
|
|
||||||
"subject": "测试用户",
|
|
||||||
"predicate": "最喜欢的颜色是",
|
|
||||||
"object": "绿色",
|
|
||||||
"existed_before": False,
|
|
||||||
"before_status": {},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
},
|
|
||||||
"stale_marks": [{"paragraph_hash": "paragraph-old", "relation_hash": "rel-old"}],
|
|
||||||
"episode_sources": ["chat_summary:session-1"],
|
|
||||||
"profile_person_ids": ["person-1"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
class _Conn:
|
|
||||||
def cursor(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def execute(self, *_args, **_kwargs):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def commit(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
metadata_store = SimpleNamespace(
|
|
||||||
get_feedback_task_by_id=lambda task_id: current_task if int(task_id) == 1 else None,
|
|
||||||
mark_feedback_task_rollback_running=lambda **kwargs: current_task.update({"rollback_status": "running"}) or current_task,
|
|
||||||
finalize_feedback_task_rollback=lambda **kwargs: current_task.update(
|
|
||||||
{
|
|
||||||
"rollback_status": kwargs["rollback_status"],
|
|
||||||
"rollback_result": kwargs.get("rollback_result") or {},
|
|
||||||
"rollback_error": kwargs.get("rollback_error", ""),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
or current_task,
|
|
||||||
get_relation_status_batch=lambda hashes: {
|
|
||||||
hash_value: dict(relation_statuses[hash_value])
|
|
||||||
for hash_value in hashes
|
|
||||||
if hash_value in relation_statuses
|
|
||||||
},
|
|
||||||
restore_relation_status_from_snapshot=lambda hash_value, snapshot: relation_statuses.update(
|
|
||||||
{hash_value: dict(snapshot)}
|
|
||||||
)
|
|
||||||
or dict(snapshot),
|
|
||||||
append_feedback_action_log=lambda **kwargs: action_logs.append(kwargs),
|
|
||||||
mark_as_deleted=lambda hashes, type_: deleted_paragraphs.extend(list(hashes)) or len(list(hashes)),
|
|
||||||
get_paragraph=lambda paragraph_hash: {"hash": paragraph_hash, "source": "chat_summary:session-1"},
|
|
||||||
get_connection=lambda: _Conn(),
|
|
||||||
delete_external_memory_refs_by_paragraphs=lambda hashes: [
|
|
||||||
{"paragraph_hash": str(hash_value), "external_id": f"external:{hash_value}"}
|
|
||||||
for hash_value in hashes
|
|
||||||
],
|
|
||||||
update_relations_protection=lambda hashes, **kwargs: None,
|
|
||||||
mark_relations_inactive=lambda hashes, inactive_since=None: [
|
|
||||||
relation_statuses.__setitem__(
|
|
||||||
hash_value,
|
|
||||||
{
|
|
||||||
**relation_statuses.get(hash_value, {}),
|
|
||||||
"is_inactive": True,
|
|
||||||
"inactive_since": inactive_since,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
for hash_value in hashes
|
|
||||||
],
|
|
||||||
delete_paragraph_stale_relation_marks=lambda marks: deleted_marks.extend(list(marks)) or len(list(marks)),
|
|
||||||
enqueue_episode_source_rebuild=lambda source, reason='': queued_sources.append(source) or True,
|
|
||||||
enqueue_person_profile_refresh=lambda **kwargs: queued_profiles.append(kwargs["person_id"]) or kwargs,
|
|
||||||
list_feedback_action_logs=lambda task_id: action_logs if int(task_id) == 1 else [],
|
|
||||||
)
|
|
||||||
|
|
||||||
kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
|
|
||||||
kernel.metadata_store = metadata_store
|
|
||||||
async def _noop_initialize() -> None:
|
|
||||||
return None
|
|
||||||
kernel.initialize = _noop_initialize # type: ignore[method-assign]
|
|
||||||
kernel._rebuild_graph_from_metadata = lambda: None # type: ignore[method-assign]
|
|
||||||
kernel._persist = lambda: None # type: ignore[method-assign]
|
|
||||||
|
|
||||||
payload = await kernel._rollback_feedback_task(
|
|
||||||
task_id=1,
|
|
||||||
requested_by="pytest",
|
|
||||||
reason="manual rollback",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert payload["success"] is True
|
|
||||||
assert relation_statuses["rel-old"]["is_inactive"] is False
|
|
||||||
assert relation_statuses["rel-new"]["is_inactive"] is True
|
|
||||||
assert deleted_paragraphs == ["paragraph-new"]
|
|
||||||
assert deleted_marks == [("paragraph-old", "rel-old")]
|
|
||||||
assert queued_sources == ["chat_summary:session-1"]
|
|
||||||
assert queued_profiles == ["person-1"]
|
|
||||||
assert current_task["rollback_status"] == "rolled_back"
|
|
||||||
assert {item["action_type"] for item in action_logs} >= {
|
|
||||||
"rollback_restore_relation",
|
|
||||||
"rollback_revert_corrected_relation",
|
|
||||||
"rollback_delete_correction_paragraph",
|
|
||||||
"rollback_clear_stale_mark",
|
|
||||||
"rollback_enqueue_episode_rebuild",
|
|
||||||
"rollback_enqueue_profile_refresh",
|
|
||||||
}
|
|
||||||
@@ -1,82 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import pickle
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
try:
|
|
||||||
from src.A_memorix.core.storage.graph_store import GraphStore
|
|
||||||
except SystemExit as exc:
|
|
||||||
GraphStore = None # type: ignore[assignment]
|
|
||||||
IMPORT_ERROR = f"config initialization exited during import: {exc}"
|
|
||||||
else:
|
|
||||||
IMPORT_ERROR = None
|
|
||||||
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
|
|
||||||
|
|
||||||
|
|
||||||
def _build_empty_graph_metadata() -> dict:
|
|
||||||
return {
|
|
||||||
"nodes": [],
|
|
||||||
"node_to_idx": {},
|
|
||||||
"node_attrs": {},
|
|
||||||
"matrix_format": "csr",
|
|
||||||
"total_nodes_added": 0,
|
|
||||||
"total_edges_added": 0,
|
|
||||||
"total_nodes_deleted": 0,
|
|
||||||
"total_edges_deleted": 0,
|
|
||||||
"edge_hash_map": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_graph_store_clear_save_removes_stale_adjacency(tmp_path: Path) -> None:
|
|
||||||
data_dir = tmp_path / "graph_data"
|
|
||||||
store = GraphStore(data_dir=data_dir)
|
|
||||||
store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
|
|
||||||
store.save()
|
|
||||||
|
|
||||||
matrix_path = data_dir / "graph_adjacency.npz"
|
|
||||||
assert matrix_path.exists()
|
|
||||||
|
|
||||||
store.clear()
|
|
||||||
store.save()
|
|
||||||
|
|
||||||
assert not matrix_path.exists()
|
|
||||||
|
|
||||||
|
|
||||||
def test_graph_store_load_resets_stale_adjacency_when_metadata_is_empty(tmp_path: Path) -> None:
|
|
||||||
data_dir = tmp_path / "graph_data"
|
|
||||||
store = GraphStore(data_dir=data_dir)
|
|
||||||
store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
|
|
||||||
store.save()
|
|
||||||
|
|
||||||
metadata_path = data_dir / "graph_metadata.pkl"
|
|
||||||
with metadata_path.open("wb") as handle:
|
|
||||||
pickle.dump(_build_empty_graph_metadata(), handle)
|
|
||||||
|
|
||||||
reloaded = GraphStore(data_dir=data_dir)
|
|
||||||
reloaded.load()
|
|
||||||
|
|
||||||
assert reloaded.num_nodes == 0
|
|
||||||
assert reloaded.num_edges == 0
|
|
||||||
assert reloaded.get_nodes() == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_graph_store_load_clears_stale_edge_hash_map_when_metadata_is_empty(tmp_path: Path) -> None:
|
|
||||||
data_dir = tmp_path / "graph_data"
|
|
||||||
store = GraphStore(data_dir=data_dir)
|
|
||||||
store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
|
|
||||||
store.save()
|
|
||||||
|
|
||||||
metadata_path = data_dir / "graph_metadata.pkl"
|
|
||||||
empty_metadata = _build_empty_graph_metadata()
|
|
||||||
empty_metadata["edge_hash_map"] = {(0, 1): {"rel-1"}}
|
|
||||||
with metadata_path.open("wb") as handle:
|
|
||||||
pickle.dump(empty_metadata, handle)
|
|
||||||
|
|
||||||
reloaded = GraphStore(data_dir=data_dir)
|
|
||||||
reloaded.load()
|
|
||||||
|
|
||||||
assert reloaded.has_edge_hash_map() is False
|
|
||||||
@@ -1,86 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
DATA_DIR = Path(__file__).parent / "data" / "benchmarks"
|
|
||||||
|
|
||||||
|
|
||||||
def _fixture_files() -> list[Path]:
|
|
||||||
return sorted(DATA_DIR.glob("group_chat_stream_memory_benchmark*.json"))
|
|
||||||
|
|
||||||
|
|
||||||
def _load_fixture(path: Path) -> dict:
|
|
||||||
return json.loads(path.read_text(encoding="utf-8"))
|
|
||||||
|
|
||||||
|
|
||||||
def _assert_fixture_matches_current_design_constraints(dataset: dict) -> None:
|
|
||||||
assert dataset["meta"]["scenario_id"]
|
|
||||||
|
|
||||||
assert dataset["session"]["group_id"]
|
|
||||||
assert dataset["session"]["platform"] == "qq"
|
|
||||||
|
|
||||||
simulated_batches = dataset["simulated_stream_batches"]
|
|
||||||
assert len(simulated_batches) >= 5
|
|
||||||
|
|
||||||
positive_batches = [item for item in simulated_batches if item["bot_participated"]]
|
|
||||||
negative_batches = [item for item in simulated_batches if not item["bot_participated"]]
|
|
||||||
|
|
||||||
assert len(positive_batches) >= 4
|
|
||||||
assert len(negative_batches) >= 1
|
|
||||||
assert any(item["expected_behavior"] == "ignored_by_summarizer_without_bot_message" for item in negative_batches)
|
|
||||||
|
|
||||||
for batch in positive_batches:
|
|
||||||
assert "Mai" in batch["participants"]
|
|
||||||
assert batch["message_count"] >= 10
|
|
||||||
assert len(batch["combined_text"]) >= 300
|
|
||||||
assert batch["start_time"] < batch["end_time"]
|
|
||||||
assert len(batch["expected_memory_targets"]) >= 4
|
|
||||||
|
|
||||||
runtime_streams = dataset["runtime_trigger_streams"]
|
|
||||||
assert len(runtime_streams) >= 2
|
|
||||||
|
|
||||||
runtime_positive = [item for item in runtime_streams if item["bot_participated"]]
|
|
||||||
runtime_negative = [item for item in runtime_streams if not item["bot_participated"]]
|
|
||||||
|
|
||||||
assert len(runtime_positive) >= 1
|
|
||||||
assert len(runtime_negative) >= 1
|
|
||||||
|
|
||||||
for stream in runtime_streams:
|
|
||||||
stream_text = "\n".join(stream["messages"])
|
|
||||||
assert stream["trigger_mode"] == "time_threshold"
|
|
||||||
assert stream["elapsed_since_last_check_hours"] >= 8.0
|
|
||||||
assert stream["message_count"] >= 20
|
|
||||||
assert len(stream["messages"]) == stream["message_count"]
|
|
||||||
assert len(stream_text) >= 1000
|
|
||||||
assert stream["start_time"] < stream["end_time"]
|
|
||||||
|
|
||||||
assert any(item["expected_check_outcome"] == "should_trigger_topic_check_and_pass_bot_gate" for item in runtime_positive)
|
|
||||||
assert any(
|
|
||||||
item["expected_check_outcome"] == "should_trigger_topic_check_but_be_discarded_without_bot_message"
|
|
||||||
for item in runtime_negative
|
|
||||||
)
|
|
||||||
|
|
||||||
records = dataset["chat_history_records"]
|
|
||||||
assert len(records) >= 4
|
|
||||||
for record in records:
|
|
||||||
assert "Mai" in record["participants"]
|
|
||||||
assert len(record["summary"]) >= 40
|
|
||||||
assert len(record["original_text"]) >= 200
|
|
||||||
assert record["start_time"] < record["end_time"]
|
|
||||||
|
|
||||||
assert len(dataset["person_writebacks"]) >= 3
|
|
||||||
assert len(dataset["search_cases"]) >= 4
|
|
||||||
assert len(dataset["time_cases"]) >= 3
|
|
||||||
assert len(dataset["episode_cases"]) >= 4
|
|
||||||
assert len(dataset["knowledge_fetcher_cases"]) >= 3
|
|
||||||
assert len(dataset["profile_cases"]) >= 3
|
|
||||||
assert len(dataset["negative_control_cases"]) >= 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_group_chat_stream_fixture_matches_current_design_constraints():
|
|
||||||
files = _fixture_files()
|
|
||||||
assert files, "未找到 group_chat_stream_memory_benchmark*.json fixture"
|
|
||||||
for path in files:
|
|
||||||
_assert_fixture_matches_current_design_constraints(_load_fixture(path))
|
|
||||||
@@ -1,124 +0,0 @@
|
|||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
|
|
||||||
from src.services.memory_service import MemoryHit, MemorySearchResult
|
|
||||||
|
|
||||||
|
|
||||||
def test_knowledge_fetcher_resolves_private_memory_context(monkeypatch):
|
|
||||||
monkeypatch.setattr(
|
|
||||||
knowledge_module,
|
|
||||||
"_chat_manager",
|
|
||||||
SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
knowledge_module,
|
|
||||||
"resolve_person_id_for_memory",
|
|
||||||
lambda *, person_name, platform, user_id: f"{person_name}:{platform}:{user_id}",
|
|
||||||
)
|
|
||||||
|
|
||||||
fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")
|
|
||||||
|
|
||||||
assert fetcher._resolve_private_memory_context() == {
|
|
||||||
"chat_id": "stream-1",
|
|
||||||
"person_id": "Alice:qq:42",
|
|
||||||
"user_id": "42",
|
|
||||||
"group_id": "",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_knowledge_fetcher_memory_get_knowledge_uses_memory_service(monkeypatch):
|
|
||||||
monkeypatch.setattr(
|
|
||||||
knowledge_module,
|
|
||||||
"_chat_manager",
|
|
||||||
SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
knowledge_module,
|
|
||||||
"resolve_person_id_for_memory",
|
|
||||||
lambda *, person_name, platform, user_id: f"{person_name}:{platform}:{user_id}",
|
|
||||||
)
|
|
||||||
|
|
||||||
calls = []
|
|
||||||
|
|
||||||
async def fake_search(query: str, **kwargs):
|
|
||||||
calls.append((query, kwargs))
|
|
||||||
return MemorySearchResult(summary="", hits=[MemoryHit(content="她喜欢猫", source="person_fact:qq:42")], filtered=False)
|
|
||||||
|
|
||||||
monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search)
|
|
||||||
|
|
||||||
fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")
|
|
||||||
result = await fetcher._memory_get_knowledge("她喜欢什么")
|
|
||||||
|
|
||||||
assert "1. 她喜欢猫" in result
|
|
||||||
assert calls == [
|
|
||||||
(
|
|
||||||
"她喜欢什么",
|
|
||||||
{
|
|
||||||
"limit": 5,
|
|
||||||
"mode": "search",
|
|
||||||
"chat_id": "stream-1",
|
|
||||||
"person_id": "Alice:qq:42",
|
|
||||||
"user_id": "42",
|
|
||||||
"group_id": "",
|
|
||||||
"respect_filter": True,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_knowledge_fetcher_falls_back_to_chat_scope_when_person_scope_misses(monkeypatch):
|
|
||||||
monkeypatch.setattr(
|
|
||||||
knowledge_module,
|
|
||||||
"_chat_manager",
|
|
||||||
SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
knowledge_module,
|
|
||||||
"resolve_person_id_for_memory",
|
|
||||||
lambda *, person_name, platform, user_id: "person-1",
|
|
||||||
)
|
|
||||||
|
|
||||||
calls = []
|
|
||||||
|
|
||||||
async def fake_search(query: str, **kwargs):
|
|
||||||
calls.append((query, kwargs))
|
|
||||||
if kwargs.get("person_id"):
|
|
||||||
return MemorySearchResult(summary="", hits=[], filtered=False)
|
|
||||||
return MemorySearchResult(summary="", hits=[MemoryHit(content="她计划去杭州音乐节", source="chat_summary:stream-1")], filtered=False)
|
|
||||||
|
|
||||||
monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search)
|
|
||||||
|
|
||||||
fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")
|
|
||||||
result = await fetcher._memory_get_knowledge("Alice 最近在忙什么")
|
|
||||||
|
|
||||||
assert "杭州音乐节" in result
|
|
||||||
assert calls == [
|
|
||||||
(
|
|
||||||
"Alice 最近在忙什么",
|
|
||||||
{
|
|
||||||
"limit": 5,
|
|
||||||
"mode": "search",
|
|
||||||
"chat_id": "stream-1",
|
|
||||||
"person_id": "person-1",
|
|
||||||
"user_id": "42",
|
|
||||||
"group_id": "",
|
|
||||||
"respect_filter": True,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"Alice 最近在忙什么",
|
|
||||||
{
|
|
||||||
"limit": 5,
|
|
||||||
"mode": "search",
|
|
||||||
"chat_id": "stream-1",
|
|
||||||
"person_id": "",
|
|
||||||
"user_id": "42",
|
|
||||||
"group_id": "",
|
|
||||||
"respect_filter": True,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
]
|
|
||||||
@@ -1,355 +0,0 @@
|
|||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.services import memory_flow_service as memory_flow_module
|
|
||||||
|
|
||||||
|
|
||||||
def _fake_global_config(**integration_values):
|
|
||||||
return SimpleNamespace(
|
|
||||||
a_memorix=SimpleNamespace(
|
|
||||||
integration=SimpleNamespace(**integration_values),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_person_fact_parse_fact_list_deduplicates_and_filters_short_items():
|
|
||||||
raw = '["他喜欢猫", "他喜欢猫", "好", "", "他会弹吉他"]'
|
|
||||||
|
|
||||||
result = memory_flow_module.PersonFactWritebackService._parse_fact_list(raw)
|
|
||||||
|
|
||||||
assert result == ["他喜欢猫", "他会弹吉他"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_person_fact_looks_ephemeral_detects_short_chitchat():
|
|
||||||
assert memory_flow_module.PersonFactWritebackService._looks_ephemeral("哈哈")
|
|
||||||
assert memory_flow_module.PersonFactWritebackService._looks_ephemeral("好的?")
|
|
||||||
assert not memory_flow_module.PersonFactWritebackService._looks_ephemeral("她最近在学法语和钢琴")
|
|
||||||
|
|
||||||
|
|
||||||
def test_person_fact_resolve_target_person_for_private_chat(monkeypatch):
|
|
||||||
class FakePerson:
|
|
||||||
def __init__(self, person_id: str):
|
|
||||||
self.person_id = person_id
|
|
||||||
self.is_known = True
|
|
||||||
|
|
||||||
service = memory_flow_module.PersonFactWritebackService.__new__(memory_flow_module.PersonFactWritebackService)
|
|
||||||
monkeypatch.setattr(memory_flow_module, "is_bot_self", lambda platform, user_id: False)
|
|
||||||
monkeypatch.setattr(memory_flow_module, "get_person_id", lambda platform, user_id: f"{platform}:{user_id}")
|
|
||||||
monkeypatch.setattr(memory_flow_module, "Person", FakePerson)
|
|
||||||
|
|
||||||
message = SimpleNamespace(session=SimpleNamespace(platform="qq", user_id="123", group_id=""))
|
|
||||||
|
|
||||||
person = service._resolve_target_person(message)
|
|
||||||
|
|
||||||
assert person is not None
|
|
||||||
assert person.person_id == "qq:123"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_person_fact_writeback_skips_bot_only_fact_without_user_evidence(monkeypatch):
|
|
||||||
stored_facts: list[tuple[str, str, str]] = []
|
|
||||||
|
|
||||||
class FakePerson:
|
|
||||||
person_id = "person-1"
|
|
||||||
person_name = "测试用户"
|
|
||||||
nickname = "测试用户"
|
|
||||||
is_known = True
|
|
||||||
|
|
||||||
service = memory_flow_module.PersonFactWritebackService.__new__(memory_flow_module.PersonFactWritebackService)
|
|
||||||
service._resolve_target_person = lambda message: FakePerson()
|
|
||||||
|
|
||||||
async def fake_extract_facts(person, reply_text, user_evidence_text):
|
|
||||||
del person, reply_text, user_evidence_text
|
|
||||||
return ["测试用户喜欢辣椒"]
|
|
||||||
|
|
||||||
async def fake_store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str, **kwargs):
|
|
||||||
del kwargs
|
|
||||||
stored_facts.append((person_name, memory_content, chat_id))
|
|
||||||
|
|
||||||
service._extract_facts = fake_extract_facts
|
|
||||||
monkeypatch.setattr(memory_flow_module, "store_person_memory_from_answer", fake_store_person_memory_from_answer)
|
|
||||||
monkeypatch.setattr(memory_flow_module, "find_messages", lambda **kwargs: [])
|
|
||||||
|
|
||||||
message = SimpleNamespace(
|
|
||||||
processed_plain_text="我记得你喜欢辣椒。",
|
|
||||||
session_id="session-1",
|
|
||||||
reply_to="",
|
|
||||||
session=SimpleNamespace(platform="qq", user_id="bot-1", group_id=""),
|
|
||||||
)
|
|
||||||
|
|
||||||
await service._handle_message(message)
|
|
||||||
|
|
||||||
assert stored_facts == []
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chat_summary_writeback_service_triggers_when_threshold_reached(monkeypatch):
|
|
||||||
events: list[tuple[str, object]] = []
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_flow_module,
|
|
||||||
"global_config",
|
|
||||||
_fake_global_config(
|
|
||||||
chat_summary_writeback_enabled=True,
|
|
||||||
chat_summary_writeback_message_threshold=3,
|
|
||||||
chat_summary_writeback_context_length=7,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
|
|
||||||
|
|
||||||
async def fake_ingest_summary(**kwargs):
|
|
||||||
events.append(("ingest_summary", kwargs))
|
|
||||||
return SimpleNamespace(success=True, detail="ok")
|
|
||||||
|
|
||||||
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
|
|
||||||
del self, session_id, total_message_count
|
|
||||||
return 0
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_flow_module.ChatSummaryWritebackService,
|
|
||||||
"_load_last_trigger_message_count",
|
|
||||||
fake_load_last_trigger_message_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
service = memory_flow_module.ChatSummaryWritebackService()
|
|
||||||
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
|
|
||||||
|
|
||||||
await service._handle_message(message)
|
|
||||||
|
|
||||||
assert len(events) == 1
|
|
||||||
_, payload = events[0]
|
|
||||||
assert payload["external_id"] == "chat_auto_summary:session-1:5"
|
|
||||||
assert payload["chat_id"] == "session-1"
|
|
||||||
assert payload["text"] == ""
|
|
||||||
assert payload["metadata"]["generate_from_chat"] is True
|
|
||||||
assert payload["metadata"]["context_length"] == 7
|
|
||||||
assert payload["metadata"]["trigger"] == "message_threshold"
|
|
||||||
assert payload["user_id"] == "user-1"
|
|
||||||
assert payload["group_id"] == "group-1"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chat_summary_writeback_service_skips_when_threshold_not_reached(monkeypatch):
|
|
||||||
called = False
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_flow_module,
|
|
||||||
"global_config",
|
|
||||||
_fake_global_config(
|
|
||||||
chat_summary_writeback_enabled=True,
|
|
||||||
chat_summary_writeback_message_threshold=6,
|
|
||||||
chat_summary_writeback_context_length=9,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
|
|
||||||
|
|
||||||
async def fake_ingest_summary(**kwargs):
|
|
||||||
nonlocal called
|
|
||||||
called = True
|
|
||||||
return SimpleNamespace(success=True, detail="ok")
|
|
||||||
|
|
||||||
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
|
|
||||||
del self, session_id, total_message_count
|
|
||||||
return 0
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_flow_module.ChatSummaryWritebackService,
|
|
||||||
"_load_last_trigger_message_count",
|
|
||||||
fake_load_last_trigger_message_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
service = memory_flow_module.ChatSummaryWritebackService()
|
|
||||||
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
|
|
||||||
|
|
||||||
await service._handle_message(message)
|
|
||||||
|
|
||||||
assert called is False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chat_summary_writeback_service_restores_previous_trigger_count(monkeypatch):
|
|
||||||
events: list[tuple[str, object]] = []
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_flow_module,
|
|
||||||
"global_config",
|
|
||||||
_fake_global_config(
|
|
||||||
chat_summary_writeback_enabled=True,
|
|
||||||
chat_summary_writeback_message_threshold=3,
|
|
||||||
chat_summary_writeback_context_length=7,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 8)
|
|
||||||
|
|
||||||
async def fake_ingest_summary(**kwargs):
|
|
||||||
events.append(("ingest_summary", kwargs))
|
|
||||||
return SimpleNamespace(success=True, detail="ok")
|
|
||||||
|
|
||||||
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
|
|
||||||
del self, session_id, total_message_count
|
|
||||||
return 5
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_flow_module.ChatSummaryWritebackService,
|
|
||||||
"_load_last_trigger_message_count",
|
|
||||||
fake_load_last_trigger_message_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
service = memory_flow_module.ChatSummaryWritebackService()
|
|
||||||
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
|
|
||||||
|
|
||||||
await service._handle_message(message)
|
|
||||||
|
|
||||||
assert len(events) == 1
|
|
||||||
_, payload = events[0]
|
|
||||||
assert payload["external_id"] == "chat_auto_summary:session-1:8"
|
|
||||||
assert service._states["session-1"].last_trigger_message_count == 8
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chat_summary_writeback_service_falls_back_to_current_count_for_legacy_summary(monkeypatch):
|
|
||||||
called = False
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_flow_module,
|
|
||||||
"global_config",
|
|
||||||
_fake_global_config(
|
|
||||||
chat_summary_writeback_enabled=True,
|
|
||||||
chat_summary_writeback_message_threshold=3,
|
|
||||||
chat_summary_writeback_context_length=7,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
|
|
||||||
|
|
||||||
async def fake_ingest_summary(**kwargs):
|
|
||||||
nonlocal called
|
|
||||||
called = True
|
|
||||||
return SimpleNamespace(success=True, detail="ok")
|
|
||||||
|
|
||||||
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
|
|
||||||
del self, session_id, total_message_count
|
|
||||||
return 5
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_flow_module.ChatSummaryWritebackService,
|
|
||||||
"_load_last_trigger_message_count",
|
|
||||||
fake_load_last_trigger_message_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
service = memory_flow_module.ChatSummaryWritebackService()
|
|
||||||
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
|
|
||||||
|
|
||||||
await service._handle_message(message)
|
|
||||||
|
|
||||||
assert called is False
|
|
||||||
assert service._states["session-1"].last_trigger_message_count == 5
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chat_summary_writeback_service_loads_trigger_count_from_summary_metadata(monkeypatch):
|
|
||||||
class FakeMetadataStore:
|
|
||||||
@staticmethod
|
|
||||||
def get_paragraphs_by_source(source: str):
|
|
||||||
assert source == "chat_summary:session-1"
|
|
||||||
return [
|
|
||||||
{"created_at": 1.0, "metadata": {"trigger_message_count": 3}},
|
|
||||||
{"created_at": 2.0, "metadata": {"trigger_message_count": 6}},
|
|
||||||
]
|
|
||||||
|
|
||||||
class FakeRuntimeManager:
|
|
||||||
@staticmethod
|
|
||||||
async def _ensure_kernel():
|
|
||||||
return SimpleNamespace(metadata_store=FakeMetadataStore())
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_flow_module.memory_service_module, "a_memorix_host_service", FakeRuntimeManager())
|
|
||||||
|
|
||||||
service = memory_flow_module.ChatSummaryWritebackService()
|
|
||||||
|
|
||||||
restored = await service._load_last_trigger_message_count(session_id="session-1", total_message_count=8)
|
|
||||||
|
|
||||||
assert restored == 6
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_memory_automation_service_auto_starts_and_delegates():
|
|
||||||
events: list[tuple[str, str]] = []
|
|
||||||
|
|
||||||
class FakeFactWriteback:
|
|
||||||
async def start(self):
|
|
||||||
events.append(("start", "fact"))
|
|
||||||
|
|
||||||
async def enqueue(self, message):
|
|
||||||
events.append(("sent", message.session_id))
|
|
||||||
|
|
||||||
async def shutdown(self):
|
|
||||||
events.append(("shutdown", "fact"))
|
|
||||||
|
|
||||||
class FakeChatSummaryWriteback:
|
|
||||||
async def start(self):
|
|
||||||
events.append(("start", "summary"))
|
|
||||||
|
|
||||||
async def enqueue(self, message):
|
|
||||||
events.append(("summary", message.session_id))
|
|
||||||
|
|
||||||
async def shutdown(self):
|
|
||||||
events.append(("shutdown", "summary"))
|
|
||||||
|
|
||||||
service = memory_flow_module.MemoryAutomationService()
|
|
||||||
service.fact_writeback = FakeFactWriteback()
|
|
||||||
service.chat_summary_writeback = FakeChatSummaryWriteback()
|
|
||||||
|
|
||||||
await service.on_message_sent(SimpleNamespace(session_id="session-1"))
|
|
||||||
await service.shutdown()
|
|
||||||
|
|
||||||
assert events == [
|
|
||||||
("start", "fact"),
|
|
||||||
("start", "summary"),
|
|
||||||
("sent", "session-1"),
|
|
||||||
("summary", "session-1"),
|
|
||||||
("shutdown", "summary"),
|
|
||||||
("shutdown", "fact"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_memory_automation_service_on_incoming_message_auto_starts_only():
|
|
||||||
events: list[tuple[str, str]] = []
|
|
||||||
|
|
||||||
class FakeFactWriteback:
|
|
||||||
async def start(self):
|
|
||||||
events.append(("start", "fact"))
|
|
||||||
|
|
||||||
async def enqueue(self, message):
|
|
||||||
events.append(("sent", message.session_id))
|
|
||||||
|
|
||||||
async def shutdown(self):
|
|
||||||
events.append(("shutdown", "fact"))
|
|
||||||
|
|
||||||
class FakeChatSummaryWriteback:
|
|
||||||
async def start(self):
|
|
||||||
events.append(("start", "summary"))
|
|
||||||
|
|
||||||
async def enqueue(self, message):
|
|
||||||
events.append(("summary", message.session_id))
|
|
||||||
|
|
||||||
async def shutdown(self):
|
|
||||||
events.append(("shutdown", "summary"))
|
|
||||||
|
|
||||||
service = memory_flow_module.MemoryAutomationService()
|
|
||||||
service.fact_writeback = FakeFactWriteback()
|
|
||||||
service.chat_summary_writeback = FakeChatSummaryWriteback()
|
|
||||||
|
|
||||||
await service.on_incoming_message(SimpleNamespace(session_id="session-1"))
|
|
||||||
await service.shutdown()
|
|
||||||
|
|
||||||
assert events == [
|
|
||||||
("start", "fact"),
|
|
||||||
("start", "summary"),
|
|
||||||
("shutdown", "summary"),
|
|
||||||
("shutdown", "fact"),
|
|
||||||
]
|
|
||||||
@@ -1,113 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
|
|
||||||
|
|
||||||
|
|
||||||
class _DummyMetadataStore:
|
|
||||||
def __init__(self, *, entities: list[dict[str, Any]], relations: list[dict[str, Any]]) -> None:
|
|
||||||
self._entities = entities
|
|
||||||
self._relations = relations
|
|
||||||
|
|
||||||
def query(self, sql: str, params: tuple[Any, ...] = ()) -> list[dict[str, Any]]:
|
|
||||||
sql_token = " ".join(str(sql or "").lower().split())
|
|
||||||
keyword = str(params[0] or "").strip("%").lower() if params else ""
|
|
||||||
if "from entities" in sql_token:
|
|
||||||
rows = [dict(item) for item in self._entities if not bool(item.get("is_deleted", 0))]
|
|
||||||
if not keyword:
|
|
||||||
return rows
|
|
||||||
return [
|
|
||||||
row
|
|
||||||
for row in rows
|
|
||||||
if keyword in str(row.get("name", "") or "").lower()
|
|
||||||
or keyword in str(row.get("hash", "") or "").lower()
|
|
||||||
]
|
|
||||||
if "from relations" in sql_token:
|
|
||||||
rows = [dict(item) for item in self._relations if not bool(item.get("is_inactive", 0))]
|
|
||||||
if not keyword:
|
|
||||||
return rows
|
|
||||||
return [
|
|
||||||
row
|
|
||||||
for row in rows
|
|
||||||
if keyword in str(row.get("subject", "") or "").lower()
|
|
||||||
or keyword in str(row.get("object", "") or "").lower()
|
|
||||||
or keyword in str(row.get("predicate", "") or "").lower()
|
|
||||||
or keyword in str(row.get("hash", "") or "").lower()
|
|
||||||
]
|
|
||||||
raise AssertionError(f"unexpected query: {sql_token}")
|
|
||||||
|
|
||||||
|
|
||||||
def _build_kernel(*, entities: list[dict[str, Any]], relations: list[dict[str, Any]]) -> SDKMemoryKernel:
|
|
||||||
kernel = SDKMemoryKernel(plugin_root=Path.cwd(), config={})
|
|
||||||
|
|
||||||
async def _fake_initialize() -> None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
kernel.initialize = _fake_initialize # type: ignore[method-assign]
|
|
||||||
kernel.metadata_store = _DummyMetadataStore(entities=entities, relations=relations)
|
|
||||||
kernel.graph_store = object() # type: ignore[assignment]
|
|
||||||
return kernel
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_memory_graph_admin_search_orders_and_dedupes_results() -> None:
|
|
||||||
kernel = _build_kernel(
|
|
||||||
entities=[
|
|
||||||
{"hash": "e1", "name": "Alice", "appearance_count": 5, "is_deleted": 0},
|
|
||||||
{"hash": "e1", "name": "Alice Duplicate", "appearance_count": 99, "is_deleted": 0},
|
|
||||||
{"hash": "e2", "name": "Alice Cooper", "appearance_count": 7, "is_deleted": 0},
|
|
||||||
{"hash": "e3", "name": "my alice note", "appearance_count": 11, "is_deleted": 0},
|
|
||||||
{"hash": "e4", "name": "alice deleted", "appearance_count": 100, "is_deleted": 1},
|
|
||||||
],
|
|
||||||
relations=[
|
|
||||||
{"hash": "r1", "subject": "Alice", "predicate": "knows", "object": "Bob", "confidence": 0.6, "created_at": 100, "is_inactive": 0},
|
|
||||||
{"hash": "r3", "subject": "Alice", "predicate": "supports", "object": "Carol", "confidence": 0.9, "created_at": 90, "is_inactive": 0},
|
|
||||||
{"hash": "r1", "subject": "Alice", "predicate": "knows duplicate", "object": "Bob", "confidence": 0.99, "created_at": 200, "is_inactive": 0},
|
|
||||||
{"hash": "r2", "subject": "Alice Cooper", "predicate": "likes", "object": "Tea", "confidence": 0.2, "created_at": 50, "is_inactive": 0},
|
|
||||||
{"hash": "", "subject": "Carol", "predicate": "mentions alice", "object": "Topic", "confidence": 0.8, "created_at": 70, "is_inactive": 0},
|
|
||||||
{"hash": "", "subject": "Carol", "predicate": "mentions alice", "object": "Topic", "confidence": 0.3, "created_at": 10, "is_inactive": 0},
|
|
||||||
{"hash": "r4", "subject": "alice inactive", "predicate": "old", "object": "Data", "confidence": 1.0, "created_at": 300, "is_inactive": 1},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = await kernel.memory_graph_admin(action="search", query="alice", limit=20)
|
|
||||||
|
|
||||||
assert payload["success"] is True
|
|
||||||
assert payload["count"] == len(payload["items"])
|
|
||||||
entity_items = [item for item in payload["items"] if item["type"] == "entity"]
|
|
||||||
relation_items = [item for item in payload["items"] if item["type"] == "relation"]
|
|
||||||
|
|
||||||
assert [item["entity_hash"] for item in entity_items] == ["e1", "e2", "e3"]
|
|
||||||
assert [item["relation_hash"] for item in relation_items] == ["r3", "r1", "r2", ""]
|
|
||||||
assert relation_items[0]["confidence"] == pytest.approx(0.9)
|
|
||||||
assert relation_items[1]["confidence"] == pytest.approx(0.6)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_memory_graph_admin_search_filters_deleted_and_inactive_records() -> None:
|
|
||||||
kernel = _build_kernel(
|
|
||||||
entities=[
|
|
||||||
{"hash": "e-deleted", "name": "Ghost Alice", "appearance_count": 10, "is_deleted": 1},
|
|
||||||
],
|
|
||||||
relations=[
|
|
||||||
{
|
|
||||||
"hash": "r-inactive",
|
|
||||||
"subject": "Ghost Alice",
|
|
||||||
"predicate": "linked",
|
|
||||||
"object": "Ghost Bob",
|
|
||||||
"confidence": 0.9,
|
|
||||||
"created_at": 10,
|
|
||||||
"is_inactive": 1,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = await kernel.memory_graph_admin(action="search", query="ghost", limit=50)
|
|
||||||
|
|
||||||
assert payload["success"] is True
|
|
||||||
assert payload["items"] == []
|
|
||||||
assert payload["count"] == 0
|
|
||||||
@@ -1,281 +0,0 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from src.services.memory_service import MemorySearchResult, MemoryService
|
|
||||||
|
|
||||||
|
|
||||||
def test_coerce_write_result_treats_skipped_payload_as_success():
|
|
||||||
result = MemoryService._coerce_write_result({"skipped_ids": ["p1"], "detail": "chat_filtered"})
|
|
||||||
|
|
||||||
assert result.success is True
|
|
||||||
assert result.stored_ids == []
|
|
||||||
assert result.skipped_ids == ["p1"]
|
|
||||||
assert result.detail == "chat_filtered"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_graph_admin_invokes_plugin(monkeypatch):
|
|
||||||
service = MemoryService()
|
|
||||||
calls = []
|
|
||||||
|
|
||||||
async def fake_invoke(component_name, args=None, **kwargs):
|
|
||||||
calls.append((component_name, args, kwargs))
|
|
||||||
return {"success": True, "nodes": [], "edges": []}
|
|
||||||
|
|
||||||
monkeypatch.setattr(service, "_invoke", fake_invoke)
|
|
||||||
|
|
||||||
result = await service.graph_admin(action="get_graph", limit=12)
|
|
||||||
|
|
||||||
assert result["success"] is True
|
|
||||||
assert calls == [("memory_graph_admin", {"action": "get_graph", "limit": 12}, {"timeout_ms": 30000})]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_recycle_bin_uses_maintain_memory_tool(monkeypatch):
|
|
||||||
service = MemoryService()
|
|
||||||
calls = []
|
|
||||||
|
|
||||||
async def fake_invoke(component_name, args=None, **kwargs):
|
|
||||||
calls.append((component_name, args))
|
|
||||||
return {"success": True, "items": [{"hash": "abc"}], "count": 1}
|
|
||||||
|
|
||||||
monkeypatch.setattr(service, "_invoke", fake_invoke)
|
|
||||||
|
|
||||||
result = await service.get_recycle_bin(limit=5)
|
|
||||||
|
|
||||||
assert result == {"success": True, "items": [{"hash": "abc"}], "count": 1}
|
|
||||||
assert calls == [("maintain_memory", {"action": "recycle_bin", "limit": 5})]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_search_respects_filter_by_default(monkeypatch):
|
|
||||||
service = MemoryService()
|
|
||||||
calls = []
|
|
||||||
|
|
||||||
async def fake_invoke(component_name, args=None, **kwargs):
|
|
||||||
calls.append((component_name, args))
|
|
||||||
return {"summary": "ok", "hits": [], "filtered": True}
|
|
||||||
|
|
||||||
monkeypatch.setattr(service, "_invoke", fake_invoke)
|
|
||||||
|
|
||||||
result = await service.search(
|
|
||||||
"mai",
|
|
||||||
chat_id="stream-1",
|
|
||||||
person_id="person-1",
|
|
||||||
user_id="user-1",
|
|
||||||
group_id="",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(result, MemorySearchResult)
|
|
||||||
assert result.filtered is True
|
|
||||||
assert calls == [
|
|
||||||
(
|
|
||||||
"search_memory",
|
|
||||||
{
|
|
||||||
"query": "mai",
|
|
||||||
"limit": 5,
|
|
||||||
"mode": "search",
|
|
||||||
"chat_id": "stream-1",
|
|
||||||
"person_id": "person-1",
|
|
||||||
"time_start": None,
|
|
||||||
"time_end": None,
|
|
||||||
"respect_filter": True,
|
|
||||||
"user_id": "user-1",
|
|
||||||
"group_id": "",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_ingest_summary_can_bypass_filter(monkeypatch):
|
|
||||||
service = MemoryService()
|
|
||||||
calls = []
|
|
||||||
|
|
||||||
async def fake_invoke(component_name, args=None, **kwargs):
|
|
||||||
calls.append((component_name, args))
|
|
||||||
return {"success": True, "stored_ids": ["p1"], "detail": ""}
|
|
||||||
|
|
||||||
monkeypatch.setattr(service, "_invoke", fake_invoke)
|
|
||||||
|
|
||||||
result = await service.ingest_summary(
|
|
||||||
external_id="chat_history:1",
|
|
||||||
chat_id="stream-1",
|
|
||||||
text="summary",
|
|
||||||
respect_filter=False,
|
|
||||||
user_id="user-1",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.success is True
|
|
||||||
assert calls == [
|
|
||||||
(
|
|
||||||
"ingest_summary",
|
|
||||||
{
|
|
||||||
"external_id": "chat_history:1",
|
|
||||||
"chat_id": "stream-1",
|
|
||||||
"text": "summary",
|
|
||||||
"participants": [],
|
|
||||||
"time_start": None,
|
|
||||||
"time_end": None,
|
|
||||||
"tags": [],
|
|
||||||
"metadata": {},
|
|
||||||
"respect_filter": False,
|
|
||||||
"user_id": "user-1",
|
|
||||||
"group_id": "",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_v5_admin_invokes_plugin(monkeypatch):
|
|
||||||
service = MemoryService()
|
|
||||||
calls = []
|
|
||||||
|
|
||||||
async def fake_invoke(component_name, args=None, **kwargs):
|
|
||||||
calls.append((component_name, args, kwargs))
|
|
||||||
return {"success": True, "count": 1}
|
|
||||||
|
|
||||||
monkeypatch.setattr(service, "_invoke", fake_invoke)
|
|
||||||
|
|
||||||
result = await service.v5_admin(action="status", target="mai", limit=5)
|
|
||||||
|
|
||||||
assert result["success"] is True
|
|
||||||
assert calls == [("memory_v5_admin", {"action": "status", "target": "mai", "limit": 5}, {"timeout_ms": 30000})]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_admin_uses_long_timeout(monkeypatch):
|
|
||||||
service = MemoryService()
|
|
||||||
calls = []
|
|
||||||
|
|
||||||
async def fake_invoke(component_name, args=None, **kwargs):
|
|
||||||
calls.append((component_name, args, kwargs))
|
|
||||||
return {"success": True, "operation_id": "del-1"}
|
|
||||||
|
|
||||||
monkeypatch.setattr(service, "_invoke", fake_invoke)
|
|
||||||
|
|
||||||
result = await service.delete_admin(action="execute", mode="relation", selector={"query": "mai"})
|
|
||||||
|
|
||||||
assert result["success"] is True
|
|
||||||
assert calls == [
|
|
||||||
(
|
|
||||||
"memory_delete_admin",
|
|
||||||
{"action": "execute", "mode": "relation", "selector": {"query": "mai"}},
|
|
||||||
{"timeout_ms": 120000},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_search_returns_empty_when_query_and_time_missing_async():
|
|
||||||
service = MemoryService()
|
|
||||||
|
|
||||||
result = await service.search("", time_start=None, time_end=None)
|
|
||||||
|
|
||||||
assert isinstance(result, MemorySearchResult)
|
|
||||||
assert result.summary == ""
|
|
||||||
assert result.hits == []
|
|
||||||
assert result.filtered is False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_search_accepts_string_time_bounds(monkeypatch):
|
|
||||||
service = MemoryService()
|
|
||||||
calls = []
|
|
||||||
|
|
||||||
async def fake_invoke(component_name, args=None, **kwargs):
|
|
||||||
calls.append((component_name, args))
|
|
||||||
return {"summary": "ok", "hits": [], "filtered": False}
|
|
||||||
|
|
||||||
monkeypatch.setattr(service, "_invoke", fake_invoke)
|
|
||||||
|
|
||||||
result = await service.search(
|
|
||||||
"广播站",
|
|
||||||
mode="time",
|
|
||||||
time_start="2026/03/18",
|
|
||||||
time_end="2026/03/18 09:30",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(result, MemorySearchResult)
|
|
||||||
assert calls == [
|
|
||||||
(
|
|
||||||
"search_memory",
|
|
||||||
{
|
|
||||||
"query": "广播站",
|
|
||||||
"limit": 5,
|
|
||||||
"mode": "time",
|
|
||||||
"chat_id": "",
|
|
||||||
"person_id": "",
|
|
||||||
"time_start": "2026/03/18",
|
|
||||||
"time_end": "2026/03/18 09:30",
|
|
||||||
"respect_filter": True,
|
|
||||||
"user_id": "",
|
|
||||||
"group_id": "",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_coerce_search_result_preserves_aggregate_source_branches():
|
|
||||||
result = MemoryService._coerce_search_result(
|
|
||||||
{
|
|
||||||
"hits": [
|
|
||||||
{
|
|
||||||
"content": "广播站值夜班",
|
|
||||||
"type": "paragraph",
|
|
||||||
"metadata": {"event_time_start": 1.0},
|
|
||||||
"source_branches": ["search", "time"],
|
|
||||||
"rank": 1,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.hits[0].metadata["source_branches"] == ["search", "time"]
|
|
||||||
assert result.hits[0].metadata["rank"] == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_import_admin_uses_long_timeout(monkeypatch):
|
|
||||||
service = MemoryService()
|
|
||||||
calls = []
|
|
||||||
|
|
||||||
async def fake_invoke(component_name, args=None, **kwargs):
|
|
||||||
calls.append((component_name, args, kwargs))
|
|
||||||
return {"success": True, "task_id": "import-1"}
|
|
||||||
|
|
||||||
monkeypatch.setattr(service, "_invoke", fake_invoke)
|
|
||||||
|
|
||||||
result = await service.import_admin(action="create_lpmm_openie", alias="lpmm")
|
|
||||||
|
|
||||||
assert result["success"] is True
|
|
||||||
assert calls == [
|
|
||||||
(
|
|
||||||
"memory_import_admin",
|
|
||||||
{"action": "create_lpmm_openie", "alias": "lpmm"},
|
|
||||||
{"timeout_ms": 120000},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tuning_admin_uses_long_timeout(monkeypatch):
|
|
||||||
service = MemoryService()
|
|
||||||
calls = []
|
|
||||||
|
|
||||||
async def fake_invoke(component_name, args=None, **kwargs):
|
|
||||||
calls.append((component_name, args, kwargs))
|
|
||||||
return {"success": True, "task_id": "tuning-1"}
|
|
||||||
|
|
||||||
monkeypatch.setattr(service, "_invoke", fake_invoke)
|
|
||||||
|
|
||||||
result = await service.tuning_admin(action="create_task", payload={"query": "mai"})
|
|
||||||
|
|
||||||
assert result["success"] is True
|
|
||||||
assert calls == [
|
|
||||||
(
|
|
||||||
"memory_tuning_admin",
|
|
||||||
{"action": "create_task", "payload": {"query": "mai"}},
|
|
||||||
{"timeout_ms": 120000},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from src.A_memorix.core.storage.metadata_store import MetadataStore
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_all_sources_ignores_soft_deleted_paragraphs(tmp_path: Path) -> None:
|
|
||||||
store = MetadataStore(data_dir=tmp_path)
|
|
||||||
store.connect()
|
|
||||||
try:
|
|
||||||
live_hash = store.add_paragraph("Alice 喜欢地图", source="live-source")
|
|
||||||
deleted_hash = store.add_paragraph("Bob 喜欢咖啡", source="deleted-source")
|
|
||||||
|
|
||||||
assert live_hash
|
|
||||||
store.mark_as_deleted([deleted_hash], "paragraph")
|
|
||||||
|
|
||||||
sources = store.get_all_sources()
|
|
||||||
finally:
|
|
||||||
store.close()
|
|
||||||
|
|
||||||
assert [item["source"] for item in sources] == ["live-source"]
|
|
||||||
assert sources[0]["count"] == 1
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.person_info import person_info as person_info_module
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_store_person_memory_from_answer_writes_person_fact(monkeypatch):
|
|
||||||
calls = []
|
|
||||||
|
|
||||||
class FakePerson:
|
|
||||||
def __init__(self, person_id: str):
|
|
||||||
self.person_id = person_id
|
|
||||||
self.person_name = "Alice"
|
|
||||||
self.is_known = True
|
|
||||||
|
|
||||||
async def fake_ingest_text(**kwargs):
|
|
||||||
calls.append(kwargs)
|
|
||||||
return SimpleNamespace(success=True, detail="", stored_ids=["p1"])
|
|
||||||
|
|
||||||
session = SimpleNamespace(platform="qq", user_id="10001", group_id="", session_id="session-1")
|
|
||||||
monkeypatch.setattr(person_info_module, "_chat_manager", SimpleNamespace(get_session_by_session_id=lambda chat_id: session))
|
|
||||||
monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1")
|
|
||||||
monkeypatch.setattr(person_info_module, "Person", FakePerson)
|
|
||||||
monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)
|
|
||||||
|
|
||||||
await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1")
|
|
||||||
|
|
||||||
assert len(calls) == 1
|
|
||||||
payload = calls[0]
|
|
||||||
assert payload["external_id"].startswith("person_fact:person-1:")
|
|
||||||
assert payload["source_type"] == "person_fact"
|
|
||||||
assert payload["chat_id"] == "session-1"
|
|
||||||
assert payload["person_ids"] == ["person-1"]
|
|
||||||
assert payload["participants"] == ["Alice"]
|
|
||||||
assert payload["respect_filter"] is True
|
|
||||||
assert payload["user_id"] == "10001"
|
|
||||||
assert payload["group_id"] == ""
|
|
||||||
assert payload["metadata"]["person_id"] == "person-1"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_store_person_memory_from_answer_skips_unknown_person(monkeypatch):
|
|
||||||
calls = []
|
|
||||||
|
|
||||||
class FakePerson:
|
|
||||||
def __init__(self, person_id: str):
|
|
||||||
self.person_id = person_id
|
|
||||||
self.person_name = "Unknown"
|
|
||||||
self.is_known = False
|
|
||||||
|
|
||||||
async def fake_ingest_text(**kwargs):
|
|
||||||
calls.append(kwargs)
|
|
||||||
return SimpleNamespace(success=True, detail="", stored_ids=["p1"])
|
|
||||||
|
|
||||||
session = SimpleNamespace(platform="qq", user_id="10001", group_id="", session_id="session-1")
|
|
||||||
monkeypatch.setattr(person_info_module, "_chat_manager", SimpleNamespace(get_session_by_session_id=lambda chat_id: session))
|
|
||||||
monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1")
|
|
||||||
monkeypatch.setattr(person_info_module, "Person", FakePerson)
|
|
||||||
monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)
|
|
||||||
|
|
||||||
await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1")
|
|
||||||
|
|
||||||
assert calls == []
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_store_person_memory_from_answer_skips_empty_content(monkeypatch):
|
|
||||||
calls = []
|
|
||||||
|
|
||||||
async def fake_ingest_text(**kwargs):
|
|
||||||
calls.append(kwargs)
|
|
||||||
return SimpleNamespace(success=True, detail="", stored_ids=["p1"])
|
|
||||||
|
|
||||||
monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)
|
|
||||||
|
|
||||||
await person_info_module.store_person_memory_from_answer("Alice", " ", "session-1")
|
|
||||||
|
|
||||||
assert calls == []
|
|
||||||
|
|
||||||
@@ -1,115 +0,0 @@
|
|||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.A_memorix.core.utils.person_profile_service import PersonProfileService
|
|
||||||
|
|
||||||
|
|
||||||
class FakeMetadataStore:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.snapshots: list[dict] = []
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_latest_person_profile_snapshot(person_id: str):
|
|
||||||
del person_id
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_relations(**kwargs):
|
|
||||||
del kwargs
|
|
||||||
return []
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_paragraphs_by_source(source: str):
|
|
||||||
if source == "person_fact:person-1":
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"hash": "person-fact-1",
|
|
||||||
"content": "测试用户喜欢猫。",
|
|
||||||
"source": source,
|
|
||||||
"metadata": {"source_type": "person_fact"},
|
|
||||||
"created_at": 2.0,
|
|
||||||
"updated_at": 2.0,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
return []
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_paragraph(hash_value: str):
|
|
||||||
if hash_value == "chat-summary-1":
|
|
||||||
return {
|
|
||||||
"hash": hash_value,
|
|
||||||
"content": "机器人建议测试用户以后叫星灯。",
|
|
||||||
"source": "chat_summary:session-1",
|
|
||||||
"metadata": {"source_type": "chat_summary"},
|
|
||||||
"word_count": 1,
|
|
||||||
}
|
|
||||||
if hash_value == "person-fact-1":
|
|
||||||
return {
|
|
||||||
"hash": hash_value,
|
|
||||||
"content": "测试用户喜欢猫。",
|
|
||||||
"source": "person_fact:person-1",
|
|
||||||
"metadata": {"source_type": "person_fact"},
|
|
||||||
"word_count": 1,
|
|
||||||
}
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_paragraph_stale_relation_marks_batch(paragraph_hashes):
|
|
||||||
del paragraph_hashes
|
|
||||||
return {}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_relation_status_batch(relation_hashes):
|
|
||||||
del relation_hashes
|
|
||||||
return {}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_person_profile_override(person_id: str):
|
|
||||||
del person_id
|
|
||||||
return None
|
|
||||||
|
|
||||||
def upsert_person_profile_snapshot(self, **kwargs):
|
|
||||||
self.snapshots.append(kwargs)
|
|
||||||
return {
|
|
||||||
"person_id": kwargs["person_id"],
|
|
||||||
"profile_text": kwargs["profile_text"],
|
|
||||||
"aliases": kwargs["aliases"],
|
|
||||||
"relation_edges": kwargs["relation_edges"],
|
|
||||||
"vector_evidence": kwargs["vector_evidence"],
|
|
||||||
"evidence_ids": kwargs["evidence_ids"],
|
|
||||||
"updated_at": 1.0,
|
|
||||||
"expires_at": kwargs["expires_at"],
|
|
||||||
"source_note": kwargs["source_note"],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class FakeRetriever:
|
|
||||||
async def retrieve(self, query: str, top_k: int):
|
|
||||||
del query, top_k
|
|
||||||
return [
|
|
||||||
SimpleNamespace(
|
|
||||||
hash_value="chat-summary-1",
|
|
||||||
result_type="paragraph",
|
|
||||||
score=0.95,
|
|
||||||
content="机器人建议测试用户以后叫星灯。",
|
|
||||||
metadata={"source_type": "chat_summary"},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_person_profile_keeps_chat_summary_as_recent_interaction_not_stable_profile():
|
|
||||||
metadata_store = FakeMetadataStore()
|
|
||||||
service = PersonProfileService(metadata_store=metadata_store, retriever=FakeRetriever())
|
|
||||||
service.get_person_aliases = lambda person_id: (["测试用户"], "测试用户", [])
|
|
||||||
|
|
||||||
payload = await service.query_person_profile(person_id="person-1", top_k=6, force_refresh=True)
|
|
||||||
|
|
||||||
assert payload["success"] is True
|
|
||||||
profile_text = payload["profile_text"]
|
|
||||||
stable_section = profile_text.split("近期相关互动:", 1)[0]
|
|
||||||
assert "测试用户喜欢猫" in stable_section
|
|
||||||
assert "星灯" not in stable_section
|
|
||||||
assert "近期相关互动:" in profile_text
|
|
||||||
assert "星灯" in profile_text
|
|
||||||
@@ -1,184 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.memory_system.retrieval_tools import query_long_term_memory as tool_module
|
|
||||||
from src.memory_system.retrieval_tools import init_all_tools
|
|
||||||
from src.memory_system.retrieval_tools.query_long_term_memory import (
|
|
||||||
_resolve_time_expression,
|
|
||||||
query_long_term_memory,
|
|
||||||
register_tool,
|
|
||||||
)
|
|
||||||
from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
|
|
||||||
from src.services.memory_service import MemoryHit, MemorySearchResult
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_time_expression_supports_relative_and_absolute_inputs():
|
|
||||||
now = datetime(2026, 3, 18, 15, 30)
|
|
||||||
|
|
||||||
start_ts, end_ts, start_text, end_text = _resolve_time_expression("今天", now=now)
|
|
||||||
assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 0, 0)
|
|
||||||
assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59)
|
|
||||||
assert start_text == "2026/03/18 00:00"
|
|
||||||
assert end_text == "2026/03/18 23:59"
|
|
||||||
|
|
||||||
start_ts, end_ts, start_text, end_text = _resolve_time_expression("最近7天", now=now)
|
|
||||||
assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 12, 0, 0)
|
|
||||||
assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59)
|
|
||||||
assert start_text == "2026/03/12 00:00"
|
|
||||||
assert end_text == "2026/03/18 23:59"
|
|
||||||
|
|
||||||
start_ts, end_ts, start_text, end_text = _resolve_time_expression("2026/03/18", now=now)
|
|
||||||
assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 0, 0)
|
|
||||||
assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59)
|
|
||||||
assert start_text == "2026/03/18 00:00"
|
|
||||||
assert end_text == "2026/03/18 23:59"
|
|
||||||
|
|
||||||
start_ts, end_ts, start_text, end_text = _resolve_time_expression("2026/03/18 09:30", now=now)
|
|
||||||
assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 9, 30)
|
|
||||||
assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 9, 30)
|
|
||||||
assert start_text == "2026/03/18 09:30"
|
|
||||||
assert end_text == "2026/03/18 09:30"
|
|
||||||
|
|
||||||
|
|
||||||
def test_register_tool_exposes_mode_and_time_expression():
|
|
||||||
register_tool()
|
|
||||||
tool = get_tool_registry().get_tool("search_long_term_memory")
|
|
||||||
|
|
||||||
assert tool is not None
|
|
||||||
params = {item["name"]: item for item in tool.parameters}
|
|
||||||
assert "mode" in params
|
|
||||||
assert params["mode"]["enum"] == ["search", "time", "episode", "aggregate"]
|
|
||||||
assert "time_expression" in params
|
|
||||||
assert params["query"]["required"] is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_all_tools_registers_long_term_memory_tool():
|
|
||||||
init_all_tools()
|
|
||||||
|
|
||||||
tool = get_tool_registry().get_tool("search_long_term_memory")
|
|
||||||
assert tool is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_long_term_memory_search_mode_keeps_search(monkeypatch):
|
|
||||||
captured = {}
|
|
||||||
|
|
||||||
async def fake_search(query, **kwargs):
|
|
||||||
captured["query"] = query
|
|
||||||
captured["kwargs"] = kwargs
|
|
||||||
return MemorySearchResult(
|
|
||||||
hits=[MemoryHit(content="Alice 喜欢猫", score=0.9, hit_type="paragraph")],
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))
|
|
||||||
|
|
||||||
text = await query_long_term_memory("Alice 喜欢什么", chat_id="stream-1", person_id="person-1")
|
|
||||||
|
|
||||||
assert "Alice 喜欢猫" in text
|
|
||||||
assert captured == {
|
|
||||||
"query": "Alice 喜欢什么",
|
|
||||||
"kwargs": {
|
|
||||||
"limit": 5,
|
|
||||||
"mode": "search",
|
|
||||||
"chat_id": "stream-1",
|
|
||||||
"person_id": "person-1",
|
|
||||||
"time_start": None,
|
|
||||||
"time_end": None,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_long_term_memory_time_mode_parses_expression(monkeypatch):
|
|
||||||
captured = {}
|
|
||||||
|
|
||||||
async def fake_search(query, **kwargs):
|
|
||||||
captured["query"] = query
|
|
||||||
captured["kwargs"] = kwargs
|
|
||||||
return MemorySearchResult(
|
|
||||||
hits=[
|
|
||||||
MemoryHit(
|
|
||||||
content="昨天晚上广播站停播了十分钟。",
|
|
||||||
score=0.8,
|
|
||||||
hit_type="paragraph",
|
|
||||||
metadata={"event_time_start": 1773797400.0},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))
|
|
||||||
monkeypatch.setattr(
|
|
||||||
tool_module,
|
|
||||||
"_resolve_time_expression",
|
|
||||||
lambda expression, now=None: (1773795600.0, 1773881940.0, "2026/03/17 00:00", "2026/03/17 23:59"),
|
|
||||||
)
|
|
||||||
|
|
||||||
text = await query_long_term_memory(
|
|
||||||
query="广播站",
|
|
||||||
mode="time",
|
|
||||||
time_expression="昨天",
|
|
||||||
chat_id="stream-1",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "指定时间范围" in text
|
|
||||||
assert "广播站停播" in text
|
|
||||||
assert captured == {
|
|
||||||
"query": "广播站",
|
|
||||||
"kwargs": {
|
|
||||||
"limit": 5,
|
|
||||||
"mode": "time",
|
|
||||||
"chat_id": "stream-1",
|
|
||||||
"person_id": "",
|
|
||||||
"time_start": 1773795600.0,
|
|
||||||
"time_end": 1773881940.0,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_long_term_memory_episode_and_aggregate_format_output(monkeypatch):
|
|
||||||
responses = {
|
|
||||||
"episode": MemorySearchResult(
|
|
||||||
hits=[
|
|
||||||
MemoryHit(
|
|
||||||
content="苏弦在灯塔拆开了那封冬信。",
|
|
||||||
title="冬信重见天日",
|
|
||||||
hit_type="episode",
|
|
||||||
metadata={"participants": ["苏弦"], "keywords": ["冬信", "灯塔"]},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
),
|
|
||||||
"aggregate": MemorySearchResult(
|
|
||||||
hits=[
|
|
||||||
MemoryHit(
|
|
||||||
content="唐未在广播站值夜班时带着黑狗墨点。",
|
|
||||||
hit_type="paragraph",
|
|
||||||
metadata={"source_branches": ["search", "time"]},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
async def fake_search(query, **kwargs):
|
|
||||||
return responses[kwargs["mode"]]
|
|
||||||
|
|
||||||
monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))
|
|
||||||
|
|
||||||
episode_text = await query_long_term_memory("那封冬信后来怎么样了", mode="episode")
|
|
||||||
aggregate_text = await query_long_term_memory("唐未最近有什么线索", mode="aggregate")
|
|
||||||
|
|
||||||
assert "事件《冬信重见天日》" in episode_text
|
|
||||||
assert "参与者:苏弦" in episode_text
|
|
||||||
assert "[search,time][paragraph]" in aggregate_text
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_long_term_memory_invalid_time_expression_returns_retryable_message():
|
|
||||||
text = await query_long_term_memory(query="广播站", mode="time", time_expression="明年春分后第三周")
|
|
||||||
|
|
||||||
assert "无法解析" in text
|
|
||||||
assert "最近7天" in text
|
|
||||||
@@ -1,140 +0,0 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from src.A_memorix.core.utils.summary_importer import (
|
|
||||||
SummaryImporter,
|
|
||||||
_message_timestamp,
|
|
||||||
_normalize_entity_items,
|
|
||||||
_normalize_relation_items,
|
|
||||||
)
|
|
||||||
from src.config.model_configs import TaskConfig
|
|
||||||
from src.services import llm_service as llm_api
|
|
||||||
|
|
||||||
|
|
||||||
def _fake_available_models() -> dict[str, TaskConfig]:
|
|
||||||
return {
|
|
||||||
"memory": TaskConfig(
|
|
||||||
model_list=["memory-model"],
|
|
||||||
max_tokens=512,
|
|
||||||
temperature=0.4,
|
|
||||||
selection_strategy="random",
|
|
||||||
),
|
|
||||||
"utils": TaskConfig(
|
|
||||||
model_list=["utils-model"],
|
|
||||||
max_tokens=256,
|
|
||||||
temperature=0.5,
|
|
||||||
selection_strategy="random",
|
|
||||||
),
|
|
||||||
"replyer": TaskConfig(
|
|
||||||
model_list=["replyer-model"],
|
|
||||||
max_tokens=128,
|
|
||||||
temperature=0.7,
|
|
||||||
selection_strategy="random",
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_summary_model_config_uses_auto_list_when_summarization_missing(monkeypatch):
|
|
||||||
monkeypatch.setattr(llm_api, "get_available_models", _fake_available_models)
|
|
||||||
|
|
||||||
importer = SummaryImporter(
|
|
||||||
vector_store=None,
|
|
||||||
graph_store=None,
|
|
||||||
metadata_store=None,
|
|
||||||
embedding_manager=None,
|
|
||||||
plugin_config={},
|
|
||||||
)
|
|
||||||
|
|
||||||
resolved = importer._resolve_summary_model_config()
|
|
||||||
|
|
||||||
assert resolved is not None
|
|
||||||
assert resolved.model_list == ["memory-model"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_summary_model_config_auto_falls_back_to_utils_then_planner(monkeypatch):
|
|
||||||
importer = SummaryImporter(
|
|
||||||
vector_store=None,
|
|
||||||
graph_store=None,
|
|
||||||
metadata_store=None,
|
|
||||||
embedding_manager=None,
|
|
||||||
plugin_config={},
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
llm_api,
|
|
||||||
"get_available_models",
|
|
||||||
lambda: {
|
|
||||||
"utils": TaskConfig(model_list=["utils-model"]),
|
|
||||||
"planner": TaskConfig(model_list=["planner-model"]),
|
|
||||||
"replyer": TaskConfig(model_list=["replyer-model"]),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
resolved = importer._resolve_summary_model_config()
|
|
||||||
assert resolved is not None
|
|
||||||
assert resolved.model_list == ["utils-model"]
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
llm_api,
|
|
||||||
"get_available_models",
|
|
||||||
lambda: {
|
|
||||||
"planner": TaskConfig(model_list=["planner-model"]),
|
|
||||||
"replyer": TaskConfig(model_list=["replyer-model"]),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
resolved = importer._resolve_summary_model_config()
|
|
||||||
assert resolved is not None
|
|
||||||
assert resolved.model_list == ["planner-model"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_summary_model_config_auto_does_not_fallback_to_replyer(monkeypatch):
|
|
||||||
monkeypatch.setattr(
|
|
||||||
llm_api,
|
|
||||||
"get_available_models",
|
|
||||||
lambda: {
|
|
||||||
"replyer": TaskConfig(model_list=["replyer-model"]),
|
|
||||||
"embedding": TaskConfig(model_list=["embedding-model"]),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
importer = SummaryImporter(
|
|
||||||
vector_store=None,
|
|
||||||
graph_store=None,
|
|
||||||
metadata_store=None,
|
|
||||||
embedding_manager=None,
|
|
||||||
plugin_config={},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert importer._resolve_summary_model_config() is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_summary_model_config_rejects_legacy_string_selector(monkeypatch):
|
|
||||||
monkeypatch.setattr(llm_api, "get_available_models", _fake_available_models)
|
|
||||||
|
|
||||||
importer = SummaryImporter(
|
|
||||||
vector_store=None,
|
|
||||||
graph_store=None,
|
|
||||||
metadata_store=None,
|
|
||||||
embedding_manager=None,
|
|
||||||
plugin_config={"summarization": {"model_name": "auto"}},
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="List\\[str\\]"):
|
|
||||||
importer._resolve_summary_model_config()
|
|
||||||
|
|
||||||
|
|
||||||
def test_summary_importer_normalizes_llm_entities_and_relations():
|
|
||||||
assert _normalize_entity_items(["Alice", {"name": "地图"}, ["bad"], "Alice"]) == ["Alice", "地图"]
|
|
||||||
assert _normalize_entity_items("Alice") == []
|
|
||||||
assert _normalize_relation_items(
|
|
||||||
[
|
|
||||||
{"subject": "Alice", "predicate": "持有", "object": "地图"},
|
|
||||||
{"subject": "Alice", "predicate": "", "object": "地图"},
|
|
||||||
["bad"],
|
|
||||||
]
|
|
||||||
) == [{"subject": "Alice", "predicate": "持有", "object": "地图"}]
|
|
||||||
|
|
||||||
|
|
||||||
def test_summary_importer_message_timestamp_accepts_time_fallback():
|
|
||||||
class Message:
|
|
||||||
time = 123.5
|
|
||||||
|
|
||||||
assert _message_timestamp(Message()) == 123.5
|
|
||||||
@@ -1,182 +0,0 @@
|
|||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.A_memorix.core.strategies.base import ChunkContext, KnowledgeType, ProcessedChunk, SourceInfo
|
|
||||||
from src.A_memorix.core.utils.web_import_manager import (
|
|
||||||
ImportChunkRecord,
|
|
||||||
ImportFileRecord,
|
|
||||||
ImportTaskManager,
|
|
||||||
ImportTaskRecord,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _DummyMetadataStore:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.paragraphs: list[dict[str, object]] = []
|
|
||||||
self.entities: list[str] = []
|
|
||||||
self.relations: list[tuple[str, str, str]] = []
|
|
||||||
|
|
||||||
def add_paragraph(self, **kwargs):
|
|
||||||
self.paragraphs.append(dict(kwargs))
|
|
||||||
return f"paragraph-{len(self.paragraphs)}"
|
|
||||||
|
|
||||||
def add_entity(self, *, name: str, source_paragraph: str = "") -> str:
|
|
||||||
del source_paragraph
|
|
||||||
self.entities.append(name)
|
|
||||||
return f"entity-{name}"
|
|
||||||
|
|
||||||
def add_relation(self, *, subject: str, predicate: str, obj: str, **kwargs) -> str:
|
|
||||||
del kwargs
|
|
||||||
self.relations.append((subject, predicate, obj))
|
|
||||||
return f"relation-{len(self.relations)}"
|
|
||||||
|
|
||||||
def set_relation_vector_state(self, rel_hash: str, state: str) -> None:
|
|
||||||
del rel_hash, state
|
|
||||||
|
|
||||||
|
|
||||||
class _DummyGraphStore:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.nodes: list[list[str]] = []
|
|
||||||
self.edges: list[list[tuple[str, str]]] = []
|
|
||||||
|
|
||||||
def add_nodes(self, nodes):
|
|
||||||
self.nodes.append(list(nodes))
|
|
||||||
|
|
||||||
def add_edges(self, edges, relation_hashes=None):
|
|
||||||
del relation_hashes
|
|
||||||
self.edges.append(list(edges))
|
|
||||||
|
|
||||||
|
|
||||||
class _DummyVectorStore:
|
|
||||||
def __contains__(self, item: str) -> bool:
|
|
||||||
del item
|
|
||||||
return False
|
|
||||||
|
|
||||||
def add(self, vectors, ids):
|
|
||||||
del vectors, ids
|
|
||||||
|
|
||||||
|
|
||||||
class _DummyEmbeddingManager:
|
|
||||||
async def encode(self, text: str) -> np.ndarray:
|
|
||||||
del text
|
|
||||||
return np.ones(4, dtype=np.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_manager() -> tuple[ImportTaskManager, _DummyMetadataStore]:
|
|
||||||
metadata_store = _DummyMetadataStore()
|
|
||||||
plugin = SimpleNamespace(
|
|
||||||
metadata_store=metadata_store,
|
|
||||||
graph_store=_DummyGraphStore(),
|
|
||||||
vector_store=_DummyVectorStore(),
|
|
||||||
embedding_manager=_DummyEmbeddingManager(),
|
|
||||||
relation_write_service=None,
|
|
||||||
get_config=lambda key, default=None: default,
|
|
||||||
_is_embedding_degraded=lambda: False,
|
|
||||||
_allow_metadata_only_write=lambda: True,
|
|
||||||
write_paragraph_vector_or_enqueue=None,
|
|
||||||
)
|
|
||||||
manager = ImportTaskManager(plugin)
|
|
||||||
return manager, metadata_store
|
|
||||||
|
|
||||||
|
|
||||||
def _build_progress_task(task_id: str, total_chunks: int = 2) -> ImportTaskRecord:
|
|
||||||
file_record = ImportFileRecord(
|
|
||||||
file_id="file-1",
|
|
||||||
name="demo.txt",
|
|
||||||
source_kind="paste",
|
|
||||||
input_mode="text",
|
|
||||||
total_chunks=total_chunks,
|
|
||||||
chunks=[
|
|
||||||
ImportChunkRecord(chunk_id=f"chunk-{index}", index=index, chunk_type="text")
|
|
||||||
for index in range(total_chunks)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
return ImportTaskRecord(task_id=task_id, source="paste", params={}, files=[file_record])
|
|
||||||
|
|
||||||
|
|
||||||
def _build_chunk(data) -> ProcessedChunk:
|
|
||||||
return ProcessedChunk(
|
|
||||||
type=KnowledgeType.FACTUAL,
|
|
||||||
source=SourceInfo(file="demo.txt", offset_start=0, offset_end=4),
|
|
||||||
chunk=ChunkContext(chunk_id="chunk-1", index=0, text="Alice 持有地图"),
|
|
||||||
data=data,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_persist_processed_chunk_rejects_non_object_before_paragraph_write() -> None:
|
|
||||||
manager, metadata_store = _build_manager()
|
|
||||||
file_record = SimpleNamespace(source_path="", source_kind="paste", name="demo.txt")
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="分块抽取结果 必须返回 JSON 对象"):
|
|
||||||
await manager._persist_processed_chunk(file_record, _build_chunk(["bad"]))
|
|
||||||
|
|
||||||
assert metadata_store.paragraphs == []
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chunk_terminal_progress_uses_successful_chunks_only() -> None:
|
|
||||||
manager, _ = _build_manager()
|
|
||||||
|
|
||||||
task = _build_progress_task("task-fail-then-complete")
|
|
||||||
manager._tasks[task.task_id] = task
|
|
||||||
|
|
||||||
await manager._set_chunk_failed(task.task_id, "file-1", "chunk-0", "boom")
|
|
||||||
await manager._set_chunk_completed(task.task_id, "file-1", "chunk-1")
|
|
||||||
|
|
||||||
file_record = task.files[0]
|
|
||||||
assert file_record.done_chunks == 1
|
|
||||||
assert file_record.failed_chunks == 1
|
|
||||||
assert file_record.progress == pytest.approx(0.5)
|
|
||||||
assert task.progress == pytest.approx(0.5)
|
|
||||||
|
|
||||||
reverse_task = _build_progress_task("task-complete-then-fail")
|
|
||||||
manager._tasks[reverse_task.task_id] = reverse_task
|
|
||||||
|
|
||||||
await manager._set_chunk_completed(reverse_task.task_id, "file-1", "chunk-0")
|
|
||||||
await manager._set_chunk_failed(reverse_task.task_id, "file-1", "chunk-1", "boom")
|
|
||||||
|
|
||||||
reverse_file = reverse_task.files[0]
|
|
||||||
assert reverse_file.done_chunks == 1
|
|
||||||
assert reverse_file.failed_chunks == 1
|
|
||||||
assert reverse_file.progress == pytest.approx(0.5)
|
|
||||||
assert reverse_task.progress == pytest.approx(0.5)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_cancelled_chunks_do_not_increase_file_progress() -> None:
|
|
||||||
manager, _ = _build_manager()
|
|
||||||
task = _build_progress_task("task-cancelled-progress", total_chunks=3)
|
|
||||||
manager._tasks[task.task_id] = task
|
|
||||||
|
|
||||||
await manager._set_chunk_completed(task.task_id, "file-1", "chunk-0")
|
|
||||||
await manager._set_chunk_cancelled(task.task_id, "file-1", "chunk-1", "任务已取消")
|
|
||||||
|
|
||||||
file_record = task.files[0]
|
|
||||||
assert file_record.done_chunks == 1
|
|
||||||
assert file_record.cancelled_chunks == 1
|
|
||||||
assert file_record.progress == pytest.approx(1 / 3)
|
|
||||||
assert task.progress == pytest.approx(1 / 3)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_persist_processed_chunk_skips_invalid_nested_items() -> None:
|
|
||||||
manager, metadata_store = _build_manager()
|
|
||||||
file_record = SimpleNamespace(source_path="", source_kind="paste", name="demo.txt")
|
|
||||||
|
|
||||||
await manager._persist_processed_chunk(
|
|
||||||
file_record,
|
|
||||||
_build_chunk(
|
|
||||||
{
|
|
||||||
"triples": [{"subject": "Alice", "predicate": "持有", "object": "地图"}, ["bad"]],
|
|
||||||
"relations": [{"subject": "Alice", "predicate": "", "object": "地图"}],
|
|
||||||
"entities": ["Alice", {"name": "地图"}, ["bad"]],
|
|
||||||
}
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(metadata_store.paragraphs) == 1
|
|
||||||
assert set(metadata_store.entities) >= {"Alice", "地图"}
|
|
||||||
assert metadata_store.relations == [("Alice", "持有", "地图")]
|
|
||||||
@@ -1,89 +0,0 @@
|
|||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
from src.chat.message_receive.chat_manager import chat_manager
|
|
||||||
from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils
|
|
||||||
from src.common.utils.utils_session import SessionUtils
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_chat_prompt_for_chat_merges_multiple_matching_prompts(monkeypatch):
|
|
||||||
session_id = SessionUtils.calculate_session_id("qq", group_id="1036092828")
|
|
||||||
monkeypatch.setattr(
|
|
||||||
global_config.chat,
|
|
||||||
"chat_prompts",
|
|
||||||
[
|
|
||||||
{"platform": "qq", "item_id": "1036092828", "rule_type": "group", "prompt": "你也是群管理员,可以适当进行管理"},
|
|
||||||
{"platform": "qq", "item_id": "1036092828", "rule_type": "group", "prompt": "这个群是技术实验群,请你专心讨论技术"},
|
|
||||||
{"platform": "qq", "item_id": "other", "rule_type": "group", "prompt": "不应该生效"},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(chat_manager, "get_session_by_session_id", lambda _session_id: None)
|
|
||||||
|
|
||||||
result = ChatConfigUtils.get_chat_prompt_for_chat(session_id, True)
|
|
||||||
|
|
||||||
assert result == "你也是群管理员,可以适当进行管理\n这个群是技术实验群,请你专心讨论技术"
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_chat_prompt_for_chat_matches_routed_session_by_chat_stream(monkeypatch):
|
|
||||||
session_id = SessionUtils.calculate_session_id("qq", group_id="1036092828", account_id="bot-a")
|
|
||||||
monkeypatch.setattr(
|
|
||||||
global_config.chat,
|
|
||||||
"chat_prompts",
|
|
||||||
[
|
|
||||||
{"platform": "qq", "item_id": "1036092828", "rule_type": "group", "prompt": "路由会话也应该生效"},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
chat_manager,
|
|
||||||
"get_session_by_session_id",
|
|
||||||
lambda _session_id: SimpleNamespace(platform="qq", group_id="1036092828", user_id=None),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = ChatConfigUtils.get_chat_prompt_for_chat(session_id, True)
|
|
||||||
|
|
||||||
assert result == "路由会话也应该生效"
|
|
||||||
|
|
||||||
|
|
||||||
def test_expression_learning_list_matches_routed_session_by_chat_stream(monkeypatch):
|
|
||||||
session_id = SessionUtils.calculate_session_id("qq", group_id="1036092828", account_id="bot-a")
|
|
||||||
monkeypatch.setattr(
|
|
||||||
global_config.expression,
|
|
||||||
"learning_list",
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"platform": "qq",
|
|
||||||
"item_id": "1036092828",
|
|
||||||
"rule_type": "group",
|
|
||||||
"use_expression": False,
|
|
||||||
"enable_learning": False,
|
|
||||||
"enable_jargon_learning": True,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
chat_manager,
|
|
||||||
"get_session_by_session_id",
|
|
||||||
lambda _session_id: SimpleNamespace(platform="qq", group_id="1036092828", user_id=None),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert ExpressionConfigUtils.get_expression_config_for_chat(session_id) == (False, False, True)
|
|
||||||
|
|
||||||
|
|
||||||
def test_talk_value_rules_match_routed_session_by_chat_stream(monkeypatch):
|
|
||||||
session_id = SessionUtils.calculate_session_id("qq", group_id="1036092828", account_id="bot-a")
|
|
||||||
monkeypatch.setattr(global_config.chat, "talk_value", 0.1)
|
|
||||||
monkeypatch.setattr(global_config.chat, "enable_talk_value_rules", True)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
global_config.chat,
|
|
||||||
"talk_value_rules",
|
|
||||||
[
|
|
||||||
{"platform": "qq", "item_id": "1036092828", "rule_type": "group", "time": "00:00-23:59", "value": 0.7}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
chat_manager,
|
|
||||||
"get_session_by_session_id",
|
|
||||||
lambda _session_id: SimpleNamespace(platform="qq", group_id="1036092828", user_id=None),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert ChatConfigUtils.get_talk_value(session_id, True) == 0.7
|
|
||||||
@@ -1,908 +0,0 @@
|
|||||||
"""数据库迁移基础设施测试。"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
from sqlalchemy import text
|
|
||||||
from sqlalchemy.engine import Connection, Engine
|
|
||||||
from sqlmodel import SQLModel, create_engine
|
|
||||||
|
|
||||||
import json
|
|
||||||
import msgpack
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.common.database import database as database_module
|
|
||||||
from src.common.database.migrations import (
|
|
||||||
BaseSchemaVersionDetector,
|
|
||||||
BaseMigrationProgressReporter,
|
|
||||||
DatabaseSchemaSnapshot,
|
|
||||||
DatabaseMigrationBootstrapper,
|
|
||||||
DatabaseMigrationState,
|
|
||||||
DatabaseMigrationManager,
|
|
||||||
EMPTY_SCHEMA_VERSION,
|
|
||||||
LATEST_SCHEMA_VERSION,
|
|
||||||
LEGACY_V1_SCHEMA_VERSION,
|
|
||||||
MigrationExecutionContext,
|
|
||||||
MigrationPlan,
|
|
||||||
MigrationRegistry,
|
|
||||||
MigrationStep,
|
|
||||||
ResolvedSchemaVersion,
|
|
||||||
SchemaVersionResolver,
|
|
||||||
SchemaVersionSource,
|
|
||||||
SQLiteSchemaInspector,
|
|
||||||
SQLiteUserVersionStore,
|
|
||||||
build_default_migration_registry,
|
|
||||||
build_default_schema_version_resolver,
|
|
||||||
create_database_migration_bootstrapper,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FixedVersionDetector(BaseSchemaVersionDetector):
|
|
||||||
"""测试用固定版本探测器。"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
"""返回测试探测器名称。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 探测器名称。
|
|
||||||
"""
|
|
||||||
return "fixed_version_detector"
|
|
||||||
|
|
||||||
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
|
||||||
"""根据测试表是否存在返回固定版本。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
snapshot: 当前数据库结构快照。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[int]: 若存在测试表则返回固定版本,否则返回 ``None``。
|
|
||||||
"""
|
|
||||||
if snapshot.has_table("legacy_records"):
|
|
||||||
return 2
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class FakeMigrationProgressReporter(BaseMigrationProgressReporter):
|
|
||||||
"""测试用迁移进度上报器。"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""初始化测试用进度上报器。"""
|
|
||||||
self.events: List[Tuple[str, Optional[int], Optional[int], Optional[str]]] = []
|
|
||||||
|
|
||||||
def open(self) -> None:
|
|
||||||
"""记录打开事件。"""
|
|
||||||
self.events.append(("open", None, None, None))
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
"""记录关闭事件。"""
|
|
||||||
self.events.append(("close", None, None, None))
|
|
||||||
|
|
||||||
def start(
|
|
||||||
self,
|
|
||||||
total_records: int,
|
|
||||||
total_tables: int,
|
|
||||||
description: str = "总迁移进度",
|
|
||||||
table_unit_name: str = "表",
|
|
||||||
record_unit_name: str = "记录",
|
|
||||||
) -> None:
|
|
||||||
"""记录启动事件。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
total_records: 任务记录总数。
|
|
||||||
total_tables: 任务表总数。
|
|
||||||
description: 任务描述。
|
|
||||||
table_unit_name: 表级进度单位名称。
|
|
||||||
record_unit_name: 记录级进度单位名称。
|
|
||||||
"""
|
|
||||||
del table_unit_name, record_unit_name
|
|
||||||
self.events.append(("start", total_records, total_tables, description))
|
|
||||||
|
|
||||||
def advance(
|
|
||||||
self,
|
|
||||||
records: int = 0,
|
|
||||||
completed_tables: int = 0,
|
|
||||||
item_name: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
"""记录推进事件。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
records: 推进的记录数。
|
|
||||||
completed_tables: 已完成的表数。
|
|
||||||
item_name: 当前完成的项目名称。
|
|
||||||
"""
|
|
||||||
self.events.append(("advance", records, completed_tables, item_name))
|
|
||||||
|
|
||||||
|
|
||||||
def _create_sqlite_engine(database_file: Path) -> Engine:
|
|
||||||
"""创建测试用 SQLite 引擎。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
database_file: 测试数据库文件路径。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Engine: SQLite 引擎实例。
|
|
||||||
"""
|
|
||||||
return create_engine(
|
|
||||||
f"sqlite:///{database_file}",
|
|
||||||
echo=False,
|
|
||||||
connect_args={"check_same_thread": False},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_current_schema(connection: Connection) -> None:
|
|
||||||
"""创建当前最新版本的数据库结构。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
connection: 当前数据库连接。
|
|
||||||
"""
|
|
||||||
import src.common.database.database_model # noqa: F401
|
|
||||||
|
|
||||||
SQLModel.metadata.create_all(connection)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_legacy_v1_schema_with_sample_data(connection: Connection) -> None:
|
|
||||||
"""创建带示例数据的旧版 ``0.x`` 数据库结构。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
connection: 当前数据库连接。
|
|
||||||
"""
|
|
||||||
connection.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
CREATE TABLE chat_streams (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
stream_id TEXT NOT NULL,
|
|
||||||
create_time REAL NOT NULL,
|
|
||||||
last_active_time REAL NOT NULL,
|
|
||||||
platform TEXT NOT NULL,
|
|
||||||
user_id TEXT,
|
|
||||||
group_id TEXT,
|
|
||||||
group_name TEXT
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
)
|
|
||||||
connection.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
CREATE TABLE messages (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
message_id TEXT NOT NULL,
|
|
||||||
time REAL NOT NULL,
|
|
||||||
chat_id TEXT NOT NULL,
|
|
||||||
chat_info_platform TEXT,
|
|
||||||
user_id TEXT,
|
|
||||||
user_nickname TEXT,
|
|
||||||
chat_info_group_id TEXT,
|
|
||||||
chat_info_group_name TEXT,
|
|
||||||
is_mentioned INTEGER,
|
|
||||||
is_at INTEGER,
|
|
||||||
processed_plain_text TEXT,
|
|
||||||
display_message TEXT,
|
|
||||||
is_emoji INTEGER,
|
|
||||||
is_picid INTEGER,
|
|
||||||
is_command INTEGER,
|
|
||||||
is_notify INTEGER,
|
|
||||||
additional_config TEXT,
|
|
||||||
priority_mode TEXT
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
)
|
|
||||||
connection.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
CREATE TABLE action_records (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
action_id TEXT NOT NULL,
|
|
||||||
time REAL NOT NULL,
|
|
||||||
action_reasoning TEXT,
|
|
||||||
action_name TEXT NOT NULL,
|
|
||||||
action_data TEXT,
|
|
||||||
action_prompt_display TEXT,
|
|
||||||
chat_id TEXT
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
)
|
|
||||||
connection.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
CREATE TABLE expression (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
situation TEXT NOT NULL,
|
|
||||||
style TEXT NOT NULL,
|
|
||||||
content_list TEXT,
|
|
||||||
count INTEGER,
|
|
||||||
last_active_time REAL NOT NULL,
|
|
||||||
chat_id TEXT,
|
|
||||||
create_date REAL,
|
|
||||||
checked INTEGER,
|
|
||||||
rejected INTEGER,
|
|
||||||
modified_by TEXT
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
)
|
|
||||||
connection.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
CREATE TABLE jargon (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
content TEXT NOT NULL,
|
|
||||||
raw_content TEXT,
|
|
||||||
meaning TEXT,
|
|
||||||
chat_id TEXT,
|
|
||||||
is_global INTEGER,
|
|
||||||
count INTEGER,
|
|
||||||
is_jargon INTEGER,
|
|
||||||
last_inference_count INTEGER,
|
|
||||||
is_complete INTEGER,
|
|
||||||
inference_with_context TEXT,
|
|
||||||
inference_content_only TEXT
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
connection.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
INSERT INTO chat_streams (
|
|
||||||
id,
|
|
||||||
stream_id,
|
|
||||||
create_time,
|
|
||||||
last_active_time,
|
|
||||||
platform,
|
|
||||||
user_id,
|
|
||||||
group_id,
|
|
||||||
group_name
|
|
||||||
) VALUES (
|
|
||||||
1,
|
|
||||||
'session-1',
|
|
||||||
1710000000.0,
|
|
||||||
1710000300.0,
|
|
||||||
'qq',
|
|
||||||
'user-1',
|
|
||||||
'group-1',
|
|
||||||
'测试群'
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
)
|
|
||||||
connection.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
INSERT INTO messages (
|
|
||||||
id,
|
|
||||||
message_id,
|
|
||||||
time,
|
|
||||||
chat_id,
|
|
||||||
chat_info_platform,
|
|
||||||
user_id,
|
|
||||||
user_nickname,
|
|
||||||
chat_info_group_id,
|
|
||||||
chat_info_group_name,
|
|
||||||
is_mentioned,
|
|
||||||
is_at,
|
|
||||||
processed_plain_text,
|
|
||||||
display_message,
|
|
||||||
is_emoji,
|
|
||||||
is_picid,
|
|
||||||
is_command,
|
|
||||||
is_notify,
|
|
||||||
additional_config,
|
|
||||||
priority_mode
|
|
||||||
) VALUES (
|
|
||||||
1,
|
|
||||||
'msg-1',
|
|
||||||
1710000010.0,
|
|
||||||
'session-1',
|
|
||||||
'qq',
|
|
||||||
'user-1',
|
|
||||||
'测试用户',
|
|
||||||
'group-1',
|
|
||||||
'测试群',
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
'你好',
|
|
||||||
'你好呀',
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
'{"source":"legacy"}',
|
|
||||||
'high'
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
)
|
|
||||||
connection.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
INSERT INTO action_records (
|
|
||||||
id,
|
|
||||||
action_id,
|
|
||||||
time,
|
|
||||||
action_reasoning,
|
|
||||||
action_name,
|
|
||||||
action_data,
|
|
||||||
action_prompt_display,
|
|
||||||
chat_id
|
|
||||||
) VALUES (
|
|
||||||
1,
|
|
||||||
'action-1',
|
|
||||||
1710000020.0,
|
|
||||||
'需要调用工具',
|
|
||||||
'search',
|
|
||||||
'{"query":"MaiBot"}',
|
|
||||||
'执行搜索',
|
|
||||||
'session-1'
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
)
|
|
||||||
connection.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
INSERT INTO expression (
|
|
||||||
id,
|
|
||||||
situation,
|
|
||||||
style,
|
|
||||||
content_list,
|
|
||||||
count,
|
|
||||||
last_active_time,
|
|
||||||
chat_id,
|
|
||||||
create_date,
|
|
||||||
checked,
|
|
||||||
rejected,
|
|
||||||
modified_by
|
|
||||||
) VALUES (
|
|
||||||
1,
|
|
||||||
'打招呼',
|
|
||||||
'可爱',
|
|
||||||
'["你好呀","早上好"]',
|
|
||||||
3,
|
|
||||||
1710000030.0,
|
|
||||||
'session-1',
|
|
||||||
1710000040.0,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
'ai'
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
)
|
|
||||||
connection.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
INSERT INTO jargon (
|
|
||||||
id,
|
|
||||||
content,
|
|
||||||
raw_content,
|
|
||||||
meaning,
|
|
||||||
chat_id,
|
|
||||||
is_global,
|
|
||||||
count,
|
|
||||||
is_jargon,
|
|
||||||
last_inference_count,
|
|
||||||
is_complete,
|
|
||||||
inference_with_context,
|
|
||||||
inference_content_only
|
|
||||||
) VALUES (
|
|
||||||
1,
|
|
||||||
'上分',
|
|
||||||
'["上分"]',
|
|
||||||
'提高排名',
|
|
||||||
'session-1',
|
|
||||||
0,
|
|
||||||
5,
|
|
||||||
1,
|
|
||||||
2,
|
|
||||||
1,
|
|
||||||
'{"guess":"context"}',
|
|
||||||
'{"guess":"content"}'
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_user_version_store_can_read_and_write_versions(tmp_path: Path) -> None:
|
|
||||||
"""应支持读取与写入 SQLite ``user_version``。"""
|
|
||||||
engine = _create_sqlite_engine(tmp_path / "version_store.db")
|
|
||||||
version_store = SQLiteUserVersionStore()
|
|
||||||
|
|
||||||
with engine.begin() as connection:
|
|
||||||
assert version_store.read_version(connection) == 0
|
|
||||||
version_store.write_version(connection, 7)
|
|
||||||
|
|
||||||
with engine.connect() as connection:
|
|
||||||
assert version_store.read_version(connection) == 7
|
|
||||||
|
|
||||||
|
|
||||||
def test_schema_inspector_can_extract_tables_and_columns(tmp_path: Path) -> None:
|
|
||||||
"""应能提取 SQLite 数据库的表与列结构。"""
|
|
||||||
engine = _create_sqlite_engine(tmp_path / "schema_inspector.db")
|
|
||||||
inspector = SQLiteSchemaInspector()
|
|
||||||
|
|
||||||
with engine.begin() as connection:
|
|
||||||
connection.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
CREATE TABLE legacy_records (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
payload TEXT NOT NULL,
|
|
||||||
created_at TEXT
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
with engine.connect() as connection:
|
|
||||||
snapshot = inspector.inspect(connection)
|
|
||||||
|
|
||||||
assert snapshot.has_table("legacy_records")
|
|
||||||
assert snapshot.has_column("legacy_records", "payload")
|
|
||||||
assert not snapshot.has_column("legacy_records", "missing_column")
|
|
||||||
table_schema = snapshot.get_table("legacy_records")
|
|
||||||
|
|
||||||
assert table_schema is not None
|
|
||||||
assert table_schema.column_names() == ["created_at", "id", "payload"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolver_can_identify_empty_database(tmp_path: Path) -> None:
|
|
||||||
"""空数据库应被解析为版本 ``0``。"""
|
|
||||||
engine = _create_sqlite_engine(tmp_path / "empty_resolver.db")
|
|
||||||
resolver = SchemaVersionResolver()
|
|
||||||
|
|
||||||
with engine.connect() as connection:
|
|
||||||
resolved_version = resolver.resolve(connection)
|
|
||||||
|
|
||||||
assert resolved_version.version == 0
|
|
||||||
assert resolved_version.source == SchemaVersionSource.EMPTY_DATABASE
|
|
||||||
assert resolved_version.snapshot is not None
|
|
||||||
assert resolved_version.snapshot.is_empty()
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolver_can_use_detector_for_unversioned_legacy_database(tmp_path: Path) -> None:
|
|
||||||
"""未写入 ``user_version`` 的历史库应支持通过探测器识别版本。"""
|
|
||||||
engine = _create_sqlite_engine(tmp_path / "legacy_resolver.db")
|
|
||||||
resolver = SchemaVersionResolver(detectors=[FixedVersionDetector()])
|
|
||||||
|
|
||||||
with engine.begin() as connection:
|
|
||||||
connection.execute(text("CREATE TABLE legacy_records (id INTEGER PRIMARY KEY, payload TEXT NOT NULL)"))
|
|
||||||
|
|
||||||
with engine.connect() as connection:
|
|
||||||
resolved_version = resolver.resolve(connection)
|
|
||||||
|
|
||||||
assert resolved_version.version == 2
|
|
||||||
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
|
||||||
assert resolved_version.detector_name == "fixed_version_detector"
|
|
||||||
|
|
||||||
|
|
||||||
def test_registry_and_manager_can_execute_registered_steps(tmp_path: Path) -> None:
|
|
||||||
"""迁移编排器应能按顺序执行已注册步骤并更新版本号。"""
|
|
||||||
engine = _create_sqlite_engine(tmp_path / "manager.db")
|
|
||||||
executed_steps: List[str] = []
|
|
||||||
|
|
||||||
def migrate_0_to_1(context: MigrationExecutionContext) -> None:
|
|
||||||
"""测试迁移步骤 0 -> 1。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context: 当前迁移步骤执行上下文。
|
|
||||||
"""
|
|
||||||
executed_steps.append(f"{context.current_version}->{context.target_version}:step_0_to_1")
|
|
||||||
context.connection.execute(text("CREATE TABLE sample_records (id INTEGER PRIMARY KEY, name TEXT NOT NULL)"))
|
|
||||||
|
|
||||||
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
|
||||||
"""测试迁移步骤 1 -> 2。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context: 当前迁移步骤执行上下文。
|
|
||||||
"""
|
|
||||||
executed_steps.append(f"{context.current_version}->{context.target_version}:step_1_to_2")
|
|
||||||
context.connection.execute(text("ALTER TABLE sample_records ADD COLUMN email TEXT"))
|
|
||||||
|
|
||||||
registry = MigrationRegistry(
|
|
||||||
steps=[
|
|
||||||
MigrationStep(
|
|
||||||
version_from=0,
|
|
||||||
version_to=1,
|
|
||||||
name="create_sample_records",
|
|
||||||
description="创建示例表。",
|
|
||||||
handler=migrate_0_to_1,
|
|
||||||
),
|
|
||||||
MigrationStep(
|
|
||||||
version_from=1,
|
|
||||||
version_to=2,
|
|
||||||
name="add_sample_email",
|
|
||||||
description="为示例表增加邮箱字段。",
|
|
||||||
handler=migrate_1_to_2,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
manager = DatabaseMigrationManager(engine=engine, registry=registry)
|
|
||||||
|
|
||||||
migration_plan = manager.migrate()
|
|
||||||
|
|
||||||
assert migration_plan.step_count() == 2
|
|
||||||
assert executed_steps == ["0->2:step_0_to_1", "1->2:step_1_to_2"]
|
|
||||||
|
|
||||||
with engine.connect() as connection:
|
|
||||||
version_store = SQLiteUserVersionStore()
|
|
||||||
snapshot = SQLiteSchemaInspector().inspect(connection)
|
|
||||||
recorded_version = version_store.read_version(connection)
|
|
||||||
|
|
||||||
assert recorded_version == 2
|
|
||||||
assert snapshot.has_table("sample_records")
|
|
||||||
assert snapshot.has_column("sample_records", "email")
|
|
||||||
|
|
||||||
|
|
||||||
def test_manager_can_report_step_progress(tmp_path: Path) -> None:
|
|
||||||
"""迁移编排器应支持通过上下文上报步骤进度。"""
|
|
||||||
engine = _create_sqlite_engine(tmp_path / "manager_progress.db")
|
|
||||||
reporter_instances: List[FakeMigrationProgressReporter] = []
|
|
||||||
|
|
||||||
def _build_reporter() -> BaseMigrationProgressReporter:
|
|
||||||
"""构建测试用进度上报器。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BaseMigrationProgressReporter: 测试用进度上报器实例。
|
|
||||||
"""
|
|
||||||
reporter = FakeMigrationProgressReporter()
|
|
||||||
reporter_instances.append(reporter)
|
|
||||||
return reporter
|
|
||||||
|
|
||||||
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
|
||||||
"""测试迁移步骤 ``1 -> 2`` 的进度上报。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context: 当前迁移步骤执行上下文。
|
|
||||||
"""
|
|
||||||
context.start_progress(total_tables=3, total_records=30, description="总迁移进度")
|
|
||||||
context.advance_progress(records=10, completed_tables=1, item_name="chat_sessions")
|
|
||||||
context.advance_progress(records=10, completed_tables=1, item_name="mai_messages")
|
|
||||||
context.advance_progress(records=10, completed_tables=1, item_name="tool_records")
|
|
||||||
context.connection.execute(text("CREATE TABLE progress_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)"))
|
|
||||||
|
|
||||||
with engine.begin() as connection:
|
|
||||||
SQLiteUserVersionStore().write_version(connection, 1)
|
|
||||||
|
|
||||||
registry = MigrationRegistry(
|
|
||||||
steps=[
|
|
||||||
MigrationStep(
|
|
||||||
version_from=1,
|
|
||||||
version_to=2,
|
|
||||||
name="progress_step",
|
|
||||||
description="测试进度上报。",
|
|
||||||
handler=migrate_1_to_2,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
manager = DatabaseMigrationManager(
|
|
||||||
engine=engine,
|
|
||||||
registry=registry,
|
|
||||||
progress_reporter_factory=_build_reporter,
|
|
||||||
)
|
|
||||||
|
|
||||||
migration_plan = manager.migrate()
|
|
||||||
|
|
||||||
assert migration_plan.step_count() == 1
|
|
||||||
assert len(reporter_instances) == 1
|
|
||||||
assert reporter_instances[0].events == [
|
|
||||||
("open", None, None, None),
|
|
||||||
("start", 30, 3, "总迁移进度"),
|
|
||||||
("advance", 10, 1, "chat_sessions"),
|
|
||||||
("advance", 10, 1, "mai_messages"),
|
|
||||||
("advance", 10, 1, "tool_records"),
|
|
||||||
("close", None, None, None),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_default_resolver_can_identify_unversioned_latest_database(tmp_path: Path) -> None:
|
|
||||||
"""默认解析器应能识别未写入版本号的最新结构数据库。"""
|
|
||||||
engine = _create_sqlite_engine(tmp_path / "latest_resolver.db")
|
|
||||||
resolver = build_default_schema_version_resolver()
|
|
||||||
|
|
||||||
with engine.begin() as connection:
|
|
||||||
_create_current_schema(connection)
|
|
||||||
|
|
||||||
with engine.connect() as connection:
|
|
||||||
resolved_version = resolver.resolve(connection)
|
|
||||||
|
|
||||||
assert resolved_version.version == LATEST_SCHEMA_VERSION
|
|
||||||
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
|
||||||
assert resolved_version.detector_name == "latest_schema_detector"
|
|
||||||
|
|
||||||
|
|
||||||
def test_default_resolver_can_identify_legacy_v1_database(tmp_path: Path) -> None:
|
|
||||||
"""默认解析器应能识别未写版本号的旧版 ``0.x`` 数据库。"""
|
|
||||||
engine = _create_sqlite_engine(tmp_path / "legacy_v1_resolver.db")
|
|
||||||
resolver = build_default_schema_version_resolver()
|
|
||||||
|
|
||||||
with engine.begin() as connection:
|
|
||||||
_create_legacy_v1_schema_with_sample_data(connection)
|
|
||||||
|
|
||||||
with engine.connect() as connection:
|
|
||||||
resolved_version = resolver.resolve(connection)
|
|
||||||
|
|
||||||
assert resolved_version.version == LEGACY_V1_SCHEMA_VERSION
|
|
||||||
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
|
||||||
assert resolved_version.detector_name == "legacy_v1_schema_detector"
|
|
||||||
|
|
||||||
|
|
||||||
def test_bootstrapper_can_finalize_unversioned_latest_database(tmp_path: Path) -> None:
|
|
||||||
"""已是最新结构但未写版本号的数据库应直接补写 ``user_version``。"""
|
|
||||||
engine = _create_sqlite_engine(tmp_path / "latest_finalize.db")
|
|
||||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
|
||||||
|
|
||||||
with engine.begin() as connection:
|
|
||||||
_create_current_schema(connection)
|
|
||||||
|
|
||||||
migration_state = bootstrapper.prepare_database()
|
|
||||||
bootstrapper.finalize_database(migration_state)
|
|
||||||
|
|
||||||
assert not migration_state.requires_migration()
|
|
||||||
assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
|
|
||||||
assert migration_state.resolved_version.source == SchemaVersionSource.DETECTOR
|
|
||||||
|
|
||||||
with engine.connect() as connection:
|
|
||||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
|
||||||
|
|
||||||
assert recorded_version == LATEST_SCHEMA_VERSION
|
|
||||||
|
|
||||||
|
|
||||||
def test_bootstrapper_can_finalize_empty_database_to_latest_version(tmp_path: Path) -> None:
|
|
||||||
"""空库在建表完成后应回写最新 ``user_version``。"""
|
|
||||||
engine = _create_sqlite_engine(tmp_path / "bootstrap_empty.db")
|
|
||||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
|
||||||
|
|
||||||
migration_state = bootstrapper.prepare_database()
|
|
||||||
|
|
||||||
assert not migration_state.requires_migration()
|
|
||||||
assert migration_state.resolved_version.version == EMPTY_SCHEMA_VERSION
|
|
||||||
assert migration_state.target_version == LATEST_SCHEMA_VERSION
|
|
||||||
|
|
||||||
with engine.begin() as connection:
|
|
||||||
_create_current_schema(connection)
|
|
||||||
|
|
||||||
bootstrapper.finalize_database(migration_state)
|
|
||||||
|
|
||||||
with engine.connect() as connection:
|
|
||||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
|
||||||
|
|
||||||
assert recorded_version == LATEST_SCHEMA_VERSION
|
|
||||||
|
|
||||||
|
|
||||||
def test_bootstrapper_runs_registered_steps_for_versioned_database(tmp_path: Path) -> None:
|
|
||||||
"""启动桥接器应在已登记旧版本数据库上执行注册迁移步骤。"""
|
|
||||||
engine = _create_sqlite_engine(tmp_path / "bootstrap_registered.db")
|
|
||||||
execution_marks: List[str] = []
|
|
||||||
|
|
||||||
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
|
||||||
"""测试桥接器迁移步骤 ``1 -> 2``。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context: 当前迁移步骤执行上下文。
|
|
||||||
"""
|
|
||||||
execution_marks.append(f"step={context.step_name},index={context.step_index}")
|
|
||||||
context.connection.execute(text("ALTER TABLE bootstrap_records ADD COLUMN email TEXT"))
|
|
||||||
|
|
||||||
with engine.begin() as connection:
|
|
||||||
connection.execute(
|
|
||||||
text("CREATE TABLE bootstrap_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)")
|
|
||||||
)
|
|
||||||
SQLiteUserVersionStore().write_version(connection, 1)
|
|
||||||
|
|
||||||
registry = MigrationRegistry(
|
|
||||||
steps=[
|
|
||||||
MigrationStep(
|
|
||||||
version_from=1,
|
|
||||||
version_to=2,
|
|
||||||
name="bootstrap_add_email",
|
|
||||||
description="为桥接器测试表增加邮箱字段。",
|
|
||||||
handler=migrate_1_to_2,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
bootstrapper = DatabaseMigrationBootstrapper(
|
|
||||||
manager=DatabaseMigrationManager(engine=engine, registry=registry),
|
|
||||||
latest_schema_version=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
migration_state = bootstrapper.prepare_database()
|
|
||||||
|
|
||||||
assert migration_state.resolved_version.version == 2
|
|
||||||
assert migration_state.target_version == 2
|
|
||||||
assert execution_marks == ["step=bootstrap_add_email,index=1"]
|
|
||||||
|
|
||||||
with engine.connect() as connection:
|
|
||||||
snapshot = SQLiteSchemaInspector().inspect(connection)
|
|
||||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
|
||||||
|
|
||||||
assert recorded_version == 2
|
|
||||||
assert snapshot.has_table("bootstrap_records")
|
|
||||||
assert snapshot.has_column("bootstrap_records", "email")
|
|
||||||
|
|
||||||
|
|
||||||
def test_default_bootstrapper_can_migrate_legacy_v1_database(tmp_path: Path) -> None:
|
|
||||||
"""默认桥接器应能把旧版 ``0.x`` 数据库整体迁移到最新结构。"""
|
|
||||||
engine = _create_sqlite_engine(tmp_path / "legacy_v1_to_v2.db")
|
|
||||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
|
||||||
|
|
||||||
with engine.begin() as connection:
|
|
||||||
_create_legacy_v1_schema_with_sample_data(connection)
|
|
||||||
|
|
||||||
migration_state = bootstrapper.prepare_database()
|
|
||||||
bootstrapper.finalize_database(migration_state)
|
|
||||||
|
|
||||||
assert not migration_state.requires_migration()
|
|
||||||
assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
|
|
||||||
assert migration_state.resolved_version.source == SchemaVersionSource.PRAGMA
|
|
||||||
|
|
||||||
with engine.connect() as connection:
|
|
||||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
|
||||||
snapshot = SQLiteSchemaInspector().inspect(connection)
|
|
||||||
message_row = connection.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
SELECT session_id, processed_plain_text, additional_config, raw_content
|
|
||||||
FROM mai_messages
|
|
||||||
WHERE message_id = 'msg-1'
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
).mappings().one()
|
|
||||||
tool_row = connection.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
SELECT session_id, tool_name, tool_display_prompt
|
|
||||||
FROM tool_records
|
|
||||||
WHERE tool_id = 'action-1'
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
).mappings().one()
|
|
||||||
expression_row = connection.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
SELECT session_id, content_list, modified_by
|
|
||||||
FROM expressions
|
|
||||||
WHERE id = 1
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
).mappings().one()
|
|
||||||
jargon_row = connection.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
SELECT session_id_dict, raw_content, inference_with_content_only
|
|
||||||
FROM jargons
|
|
||||||
WHERE id = 1
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
).mappings().one()
|
|
||||||
|
|
||||||
assert recorded_version == LATEST_SCHEMA_VERSION
|
|
||||||
assert snapshot.has_table("__legacy_v1_messages")
|
|
||||||
assert snapshot.has_table("chat_sessions")
|
|
||||||
assert snapshot.has_table("mai_messages")
|
|
||||||
assert snapshot.has_table("tool_records")
|
|
||||||
assert not snapshot.has_table("action_records")
|
|
||||||
assert not snapshot.has_column("mai_messages", "display_message")
|
|
||||||
|
|
||||||
unpacked_raw_content = msgpack.unpackb(message_row["raw_content"], raw=False)
|
|
||||||
additional_config = json.loads(message_row["additional_config"])
|
|
||||||
expression_content_list = json.loads(expression_row["content_list"])
|
|
||||||
jargon_session_id_dict = json.loads(jargon_row["session_id_dict"])
|
|
||||||
jargon_raw_content = json.loads(jargon_row["raw_content"])
|
|
||||||
|
|
||||||
assert message_row["session_id"] == "session-1"
|
|
||||||
assert message_row["processed_plain_text"] == "你好"
|
|
||||||
assert unpacked_raw_content == [{"type": "text", "data": "你好呀"}]
|
|
||||||
assert additional_config == {"priority_mode": "high", "source": "legacy"}
|
|
||||||
assert tool_row["session_id"] == "session-1"
|
|
||||||
assert tool_row["tool_name"] == "search"
|
|
||||||
assert tool_row["tool_display_prompt"] == "执行搜索"
|
|
||||||
assert expression_row["session_id"] == "session-1"
|
|
||||||
assert expression_row["modified_by"] == "AI"
|
|
||||||
assert expression_content_list == ["你好呀", "早上好"]
|
|
||||||
assert jargon_session_id_dict == {"session-1": 5}
|
|
||||||
assert jargon_raw_content == ["上分"]
|
|
||||||
assert jargon_row["inference_with_content_only"] == '{"guess":"content"}'
|
|
||||||
|
|
||||||
|
|
||||||
def test_legacy_v1_migration_reports_table_progress(tmp_path: Path) -> None:
|
|
||||||
"""旧版迁移步骤应按目标表数量推进总进度。"""
|
|
||||||
engine = _create_sqlite_engine(tmp_path / "legacy_progress.db")
|
|
||||||
reporter_instances: List[FakeMigrationProgressReporter] = []
|
|
||||||
|
|
||||||
def _build_reporter() -> BaseMigrationProgressReporter:
|
|
||||||
"""构建测试用进度上报器。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BaseMigrationProgressReporter: 测试用进度上报器实例。
|
|
||||||
"""
|
|
||||||
reporter = FakeMigrationProgressReporter()
|
|
||||||
reporter_instances.append(reporter)
|
|
||||||
return reporter
|
|
||||||
|
|
||||||
with engine.begin() as connection:
|
|
||||||
_create_legacy_v1_schema_with_sample_data(connection)
|
|
||||||
|
|
||||||
manager = DatabaseMigrationManager(
|
|
||||||
engine=engine,
|
|
||||||
registry=build_default_migration_registry(),
|
|
||||||
resolver=build_default_schema_version_resolver(),
|
|
||||||
progress_reporter_factory=_build_reporter,
|
|
||||||
)
|
|
||||||
|
|
||||||
migration_plan = manager.migrate(target_version=LATEST_SCHEMA_VERSION)
|
|
||||||
|
|
||||||
assert migration_plan.step_count() == 3
|
|
||||||
assert len(reporter_instances) == 3
|
|
||||||
reporter_events = reporter_instances[0].events
|
|
||||||
|
|
||||||
assert reporter_events[0] == ("open", None, None, None)
|
|
||||||
assert reporter_events[1] == ("start", 6, 12, "总迁移进度")
|
|
||||||
assert reporter_events[-1] == ("close", None, None, None)
|
|
||||||
assert reporter_events.count(("advance", 1, 0, None)) == 6
|
|
||||||
assert reporter_events.count(("advance", 0, 1, "chat_sessions")) == 1
|
|
||||||
assert reporter_events.count(("advance", 0, 1, "thinking_questions")) == 1
|
|
||||||
assert len([event for event in reporter_events if event[0] == "advance"]) == 18
|
|
||||||
|
|
||||||
|
|
||||||
def test_initialize_database_calls_bootstrapper_before_create_all(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
tmp_path: Path,
|
|
||||||
) -> None:
|
|
||||||
"""数据库初始化入口应先准备迁移,再建表、补迁移并收尾。"""
|
|
||||||
call_order: List[str] = []
|
|
||||||
|
|
||||||
def _fake_prepare_database() -> DatabaseMigrationState:
|
|
||||||
"""返回测试用迁移状态。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
DatabaseMigrationState: 不包含迁移步骤的测试状态。
|
|
||||||
"""
|
|
||||||
call_order.append("prepare_database")
|
|
||||||
return DatabaseMigrationState(
|
|
||||||
resolved_version=ResolvedSchemaVersion(version=0, source=SchemaVersionSource.EMPTY_DATABASE),
|
|
||||||
target_version=LATEST_SCHEMA_VERSION,
|
|
||||||
plan=MigrationPlan(
|
|
||||||
current_version=EMPTY_SCHEMA_VERSION,
|
|
||||||
target_version=LATEST_SCHEMA_VERSION,
|
|
||||||
steps=[],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _fake_create_all(bind) -> None:
|
|
||||||
"""记录建表调用。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
bind: 传入的数据库绑定对象。
|
|
||||||
"""
|
|
||||||
del bind
|
|
||||||
call_order.append("create_all")
|
|
||||||
|
|
||||||
def _fake_finalize_database(migration_state: DatabaseMigrationState) -> None:
|
|
||||||
"""记录迁移收尾调用。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
migration_state: 当前数据库迁移状态。
|
|
||||||
"""
|
|
||||||
del migration_state
|
|
||||||
call_order.append("finalize_database")
|
|
||||||
|
|
||||||
monkeypatch.setattr(database_module, "_db_initialized", False)
|
|
||||||
monkeypatch.setattr(database_module, "_DB_DIR", tmp_path / "data")
|
|
||||||
monkeypatch.setattr(database_module._migration_bootstrapper, "prepare_database", _fake_prepare_database)
|
|
||||||
monkeypatch.setattr(database_module._migration_bootstrapper, "finalize_database", _fake_finalize_database)
|
|
||||||
monkeypatch.setattr(database_module.SQLModel.metadata, "create_all", _fake_create_all)
|
|
||||||
|
|
||||||
database_module.initialize_database()
|
|
||||||
|
|
||||||
assert call_order == [
|
|
||||||
"prepare_database",
|
|
||||||
"create_all",
|
|
||||||
"finalize_database",
|
|
||||||
]
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
"""测试表达方式学习器的数据库读取行为。"""
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import Generator
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy.pool import StaticPool
|
|
||||||
from sqlmodel import Session, SQLModel, create_engine
|
|
||||||
|
|
||||||
from src.bw_learner.expression_learner import ExpressionLearner
|
|
||||||
from src.common.database.database_model import Expression
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="expression_learner_engine")
|
|
||||||
def expression_learner_engine_fixture() -> Generator:
|
|
||||||
"""创建用于表达方式学习器测试的内存数据库引擎。
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
Generator: 供测试使用的 SQLite 内存引擎。
|
|
||||||
"""
|
|
||||||
engine = create_engine(
|
|
||||||
"sqlite://",
|
|
||||||
connect_args={"check_same_thread": False},
|
|
||||||
poolclass=StaticPool,
|
|
||||||
)
|
|
||||||
SQLModel.metadata.create_all(engine)
|
|
||||||
yield engine
|
|
||||||
|
|
||||||
|
|
||||||
def test_find_similar_expression_uses_read_only_session_and_history_content(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
expression_learner_engine,
|
|
||||||
) -> None:
|
|
||||||
"""查找相似表达方式时,应能在离开会话后安全使用结果,并比较历史情景内容。"""
|
|
||||||
import src.bw_learner.expression_learner as expression_learner_module
|
|
||||||
|
|
||||||
with Session(expression_learner_engine) as session:
|
|
||||||
session.add(
|
|
||||||
Expression(
|
|
||||||
situation="发送汗滴表情",
|
|
||||||
style="发送💦表情符号",
|
|
||||||
content_list='["表达情绪高涨或生理反应"]',
|
|
||||||
count=1,
|
|
||||||
session_id="session-a",
|
|
||||||
checked=False,
|
|
||||||
rejected=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
|
||||||
"""构造带自动提交语义的测试会话工厂。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
auto_commit: 退出上下文时是否自动提交。
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
Generator[Session, None, None]: SQLModel 会话对象。
|
|
||||||
"""
|
|
||||||
session = Session(expression_learner_engine)
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
if auto_commit:
|
|
||||||
session.commit()
|
|
||||||
except Exception:
|
|
||||||
session.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
session.close()
|
|
||||||
|
|
||||||
monkeypatch.setattr(expression_learner_module, "get_db_session", fake_get_db_session)
|
|
||||||
|
|
||||||
learner = ExpressionLearner(session_id="session-a")
|
|
||||||
result = learner._find_similar_expression("表达情绪高涨或生理反应")
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
expression, similarity = result
|
|
||||||
assert expression.item_id is not None
|
|
||||||
assert expression.style == "发送💦表情符号"
|
|
||||||
assert similarity == pytest.approx(1.0)
|
|
||||||
@@ -1,78 +0,0 @@
|
|||||||
"""测试表达方式表结构和基础插入行为。"""
|
|
||||||
|
|
||||||
from typing import Generator
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy.pool import StaticPool
|
|
||||||
from sqlmodel import Session, SQLModel, create_engine
|
|
||||||
|
|
||||||
from src.common.database.database_model import Expression
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="expression_engine")
|
|
||||||
def expression_engine_fixture() -> Generator:
|
|
||||||
"""创建仅用于表达方式表测试的内存数据库引擎。
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
Generator: 供测试使用的 SQLite 内存引擎。
|
|
||||||
"""
|
|
||||||
engine = create_engine(
|
|
||||||
"sqlite://",
|
|
||||||
connect_args={"check_same_thread": False},
|
|
||||||
poolclass=StaticPool,
|
|
||||||
)
|
|
||||||
SQLModel.metadata.create_all(engine)
|
|
||||||
yield engine
|
|
||||||
|
|
||||||
|
|
||||||
def test_expression_insert_assigns_auto_increment_id(expression_engine) -> None:
|
|
||||||
"""表达方式表在新库中应能自动分配自增主键。"""
|
|
||||||
with Session(expression_engine) as session:
|
|
||||||
expression = Expression(
|
|
||||||
situation="表达情绪高涨或生理反应",
|
|
||||||
style="发送💦表情符号",
|
|
||||||
content_list='["表达情绪高涨或生理反应"]',
|
|
||||||
count=1,
|
|
||||||
session_id="session-a",
|
|
||||||
checked=False,
|
|
||||||
rejected=False,
|
|
||||||
)
|
|
||||||
session.add(expression)
|
|
||||||
session.commit()
|
|
||||||
session.refresh(expression)
|
|
||||||
|
|
||||||
assert expression.id is not None
|
|
||||||
assert expression.id > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_expression_insert_allows_same_situation_style(expression_engine) -> None:
|
|
||||||
"""相同情景和风格的表达方式记录不应再被错误绑定到复合主键。"""
|
|
||||||
with Session(expression_engine) as session:
|
|
||||||
first_expression = Expression(
|
|
||||||
situation="对重复行为的默契响应",
|
|
||||||
style="持续性跟发相同内容",
|
|
||||||
content_list='["对重复行为的默契响应"]',
|
|
||||||
count=1,
|
|
||||||
session_id="session-a",
|
|
||||||
checked=False,
|
|
||||||
rejected=False,
|
|
||||||
)
|
|
||||||
second_expression = Expression(
|
|
||||||
situation="对重复行为的默契响应",
|
|
||||||
style="持续性跟发相同内容",
|
|
||||||
content_list='["对重复行为的默契响应-变体"]',
|
|
||||||
count=2,
|
|
||||||
session_id="session-b",
|
|
||||||
checked=False,
|
|
||||||
rejected=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
session.add(first_expression)
|
|
||||||
session.add(second_expression)
|
|
||||||
session.commit()
|
|
||||||
session.refresh(first_expression)
|
|
||||||
session.refresh(second_expression)
|
|
||||||
|
|
||||||
assert first_expression.id is not None
|
|
||||||
assert second_expression.id is not None
|
|
||||||
assert first_expression.id != second_expression.id
|
|
||||||
@@ -1,90 +0,0 @@
|
|||||||
"""测试黑话学习器的数据库读取行为。"""
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import Generator
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy.pool import StaticPool
|
|
||||||
from sqlmodel import Session, SQLModel, create_engine, select
|
|
||||||
|
|
||||||
from src.bw_learner.jargon_miner import JargonMiner
|
|
||||||
from src.common.database.database_model import Jargon
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="jargon_miner_engine")
|
|
||||||
def jargon_miner_engine_fixture() -> Generator:
|
|
||||||
"""创建用于黑话学习器测试的内存数据库引擎。
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
Generator: 供测试使用的 SQLite 内存引擎。
|
|
||||||
"""
|
|
||||||
engine = create_engine(
|
|
||||||
"sqlite://",
|
|
||||||
connect_args={"check_same_thread": False},
|
|
||||||
poolclass=StaticPool,
|
|
||||||
)
|
|
||||||
SQLModel.metadata.create_all(engine)
|
|
||||||
yield engine
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_process_extracted_entries_updates_existing_jargon_without_detached_session(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
jargon_miner_engine,
|
|
||||||
) -> None:
|
|
||||||
"""更新已有黑话时,不应因会话关闭导致 ORM 实例失效。"""
|
|
||||||
import src.bw_learner.jargon_miner as jargon_miner_module
|
|
||||||
|
|
||||||
with Session(jargon_miner_engine) as session:
|
|
||||||
session.add(
|
|
||||||
Jargon(
|
|
||||||
content="VF8V4L",
|
|
||||||
raw_content='["[1] first"]',
|
|
||||||
meaning="",
|
|
||||||
session_id_dict='{"session-a": 1}',
|
|
||||||
count=0,
|
|
||||||
is_jargon=True,
|
|
||||||
is_complete=False,
|
|
||||||
is_global=False,
|
|
||||||
last_inference_count=0,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
|
||||||
"""构造带自动提交语义的测试会话工厂。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
auto_commit: 退出上下文时是否自动提交。
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
Generator[Session, None, None]: SQLModel 会话对象。
|
|
||||||
"""
|
|
||||||
session = Session(jargon_miner_engine)
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
if auto_commit:
|
|
||||||
session.commit()
|
|
||||||
except Exception:
|
|
||||||
session.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
session.close()
|
|
||||||
|
|
||||||
monkeypatch.setattr(jargon_miner_module, "get_db_session", fake_get_db_session)
|
|
||||||
|
|
||||||
jargon_miner = JargonMiner(session_id="session-a", session_name="测试群")
|
|
||||||
await jargon_miner.process_extracted_entries(
|
|
||||||
[{"content": "VF8V4L", "raw_content": {"[2] second"}}],
|
|
||||||
)
|
|
||||||
|
|
||||||
with Session(jargon_miner_engine) as session:
|
|
||||||
db_jargon = session.exec(select(Jargon).where(Jargon.content == "VF8V4L")).one()
|
|
||||||
|
|
||||||
assert db_jargon.count == 1
|
|
||||||
assert db_jargon.session_id_dict == '{"session-a": 2}'
|
|
||||||
assert sorted(db_jargon.raw_content and __import__("json").loads(db_jargon.raw_content)) == [
|
|
||||||
"[1] first",
|
|
||||||
"[2] second",
|
|
||||||
]
|
|
||||||
@@ -1,84 +0,0 @@
|
|||||||
"""测试黑话表结构和基础插入行为。"""
|
|
||||||
|
|
||||||
from typing import Generator
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy.pool import StaticPool
|
|
||||||
from sqlmodel import Session, SQLModel, create_engine
|
|
||||||
|
|
||||||
from src.common.database.database_model import Jargon
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="jargon_engine")
|
|
||||||
def jargon_engine_fixture() -> Generator:
|
|
||||||
"""创建仅用于黑话表测试的内存数据库引擎。
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
Generator: 供测试使用的 SQLite 内存引擎。
|
|
||||||
"""
|
|
||||||
engine = create_engine(
|
|
||||||
"sqlite://",
|
|
||||||
connect_args={"check_same_thread": False},
|
|
||||||
poolclass=StaticPool,
|
|
||||||
)
|
|
||||||
SQLModel.metadata.create_all(engine)
|
|
||||||
yield engine
|
|
||||||
|
|
||||||
|
|
||||||
def test_jargon_insert_assigns_auto_increment_id(jargon_engine) -> None:
|
|
||||||
"""黑话表在新库中应能自动分配自增主键。"""
|
|
||||||
with Session(jargon_engine) as session:
|
|
||||||
jargon = Jargon(
|
|
||||||
content="VF8V4L",
|
|
||||||
raw_content='["[1] test"]',
|
|
||||||
meaning="",
|
|
||||||
session_id_dict='{"session-a": 1}',
|
|
||||||
count=1,
|
|
||||||
is_jargon=True,
|
|
||||||
is_complete=False,
|
|
||||||
is_global=True,
|
|
||||||
last_inference_count=0,
|
|
||||||
)
|
|
||||||
session.add(jargon)
|
|
||||||
session.commit()
|
|
||||||
session.refresh(jargon)
|
|
||||||
|
|
||||||
assert jargon.id is not None
|
|
||||||
assert jargon.id > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_jargon_insert_allows_same_content_with_different_rows(jargon_engine) -> None:
|
|
||||||
"""黑话内容不应再被错误地绑成复合主键的一部分。"""
|
|
||||||
with Session(jargon_engine) as session:
|
|
||||||
first_jargon = Jargon(
|
|
||||||
content="表情1",
|
|
||||||
raw_content='["[1] first"]',
|
|
||||||
meaning="",
|
|
||||||
session_id_dict='{"session-a": 1}',
|
|
||||||
count=1,
|
|
||||||
is_jargon=True,
|
|
||||||
is_complete=False,
|
|
||||||
is_global=False,
|
|
||||||
last_inference_count=0,
|
|
||||||
)
|
|
||||||
second_jargon = Jargon(
|
|
||||||
content="表情1",
|
|
||||||
raw_content='["[1] second"]',
|
|
||||||
meaning="",
|
|
||||||
session_id_dict='{"session-b": 1}',
|
|
||||||
count=1,
|
|
||||||
is_jargon=True,
|
|
||||||
is_complete=False,
|
|
||||||
is_global=False,
|
|
||||||
last_inference_count=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
session.add(first_jargon)
|
|
||||||
session.add(second_jargon)
|
|
||||||
session.commit()
|
|
||||||
session.refresh(first_jargon)
|
|
||||||
session.refresh(second_jargon)
|
|
||||||
|
|
||||||
assert first_jargon.id is not None
|
|
||||||
assert second_jargon.id is not None
|
|
||||||
assert first_jargon.id != second_jargon.id
|
|
||||||
@@ -1,135 +0,0 @@
|
|||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import src.chat.replyer.maisaka_expression_selector as selector_module
|
|
||||||
from src.chat.replyer.maisaka_expression_selector import MaisakaExpressionSelector
|
|
||||||
from src.common.utils.utils_session import SessionUtils
|
|
||||||
|
|
||||||
|
|
||||||
def _build_target(platform: str, item_id: str, rule_type: str = "group") -> SimpleNamespace:
|
|
||||||
return SimpleNamespace(platform=platform, item_id=item_id, rule_type=rule_type)
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_expression_group_scope_returns_related_sessions(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
|
|
||||||
related_session_id = SessionUtils.calculate_session_id("qq", group_id="10002")
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
selector_module,
|
|
||||||
"global_config",
|
|
||||||
SimpleNamespace(
|
|
||||||
expression=SimpleNamespace(
|
|
||||||
expression_groups=[
|
|
||||||
SimpleNamespace(
|
|
||||||
expression_groups=[
|
|
||||||
_build_target("qq", "10001"),
|
|
||||||
_build_target("qq", "10002"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
selector = MaisakaExpressionSelector()
|
|
||||||
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
|
|
||||||
|
|
||||||
assert related_session_ids == {current_session_id, related_session_id}
|
|
||||||
assert has_global_share is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_expression_group_scope_matches_routed_sessions(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001", account_id="bot-a")
|
|
||||||
related_session_id = SessionUtils.calculate_session_id("qq", group_id="10002", account_id="bot-a")
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
selector_module,
|
|
||||||
"global_config",
|
|
||||||
SimpleNamespace(
|
|
||||||
expression=SimpleNamespace(
|
|
||||||
expression_groups=[
|
|
||||||
SimpleNamespace(
|
|
||||||
expression_groups=[
|
|
||||||
_build_target("qq", "10001"),
|
|
||||||
_build_target("qq", "10002"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
selector_module.ChatConfigUtils,
|
|
||||||
"_get_chat_stream",
|
|
||||||
lambda session_id: SimpleNamespace(platform="qq", group_id="10001", user_id=None)
|
|
||||||
if session_id == current_session_id
|
|
||||||
else None,
|
|
||||||
)
|
|
||||||
target_session_ids = {
|
|
||||||
"10001": current_session_id,
|
|
||||||
"10002": related_session_id,
|
|
||||||
}
|
|
||||||
monkeypatch.setattr(
|
|
||||||
selector_module.ChatConfigUtils,
|
|
||||||
"get_target_session_ids",
|
|
||||||
lambda target_item: {target_session_ids[target_item.item_id]},
|
|
||||||
)
|
|
||||||
|
|
||||||
selector = MaisakaExpressionSelector()
|
|
||||||
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
|
|
||||||
|
|
||||||
assert related_session_ids == {current_session_id, related_session_id}
|
|
||||||
assert has_global_share is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_expression_group_scope_uses_star_as_global_share(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
selector_module,
|
|
||||||
"global_config",
|
|
||||||
SimpleNamespace(
|
|
||||||
expression=SimpleNamespace(
|
|
||||||
expression_groups=[
|
|
||||||
SimpleNamespace(
|
|
||||||
expression_groups=[
|
|
||||||
_build_target("*", "*"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
selector = MaisakaExpressionSelector()
|
|
||||||
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
|
|
||||||
|
|
||||||
assert related_session_ids == {current_session_id}
|
|
||||||
assert has_global_share is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_expression_group_scope_does_not_treat_empty_target_as_global(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
selector_module,
|
|
||||||
"global_config",
|
|
||||||
SimpleNamespace(
|
|
||||||
expression=SimpleNamespace(
|
|
||||||
expression_groups=[
|
|
||||||
SimpleNamespace(
|
|
||||||
expression_groups=[
|
|
||||||
_build_target("", ""),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
selector = MaisakaExpressionSelector()
|
|
||||||
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
|
|
||||||
|
|
||||||
assert related_session_ids == {current_session_id}
|
|
||||||
assert has_global_share is False
|
|
||||||
@@ -1,355 +0,0 @@
|
|||||||
"""人物信息群名片字段兼容测试。"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from importlib.util import module_from_spec, spec_from_file_location
|
|
||||||
from pathlib import Path
|
|
||||||
from types import ModuleType, SimpleNamespace
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
|
|
||||||
|
|
||||||
|
|
||||||
class _DummyLogger:
|
|
||||||
"""模拟日志记录器。"""
|
|
||||||
|
|
||||||
def debug(self, message: str) -> None:
|
|
||||||
"""记录调试日志。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: 日志内容。
|
|
||||||
"""
|
|
||||||
del message
|
|
||||||
|
|
||||||
def info(self, message: str) -> None:
|
|
||||||
"""记录信息日志。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: 日志内容。
|
|
||||||
"""
|
|
||||||
del message
|
|
||||||
|
|
||||||
def warning(self, message: str) -> None:
|
|
||||||
"""记录警告日志。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: 日志内容。
|
|
||||||
"""
|
|
||||||
del message
|
|
||||||
|
|
||||||
def error(self, message: str) -> None:
|
|
||||||
"""记录错误日志。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: 日志内容。
|
|
||||||
"""
|
|
||||||
del message
|
|
||||||
|
|
||||||
|
|
||||||
class _DummyStatement:
|
|
||||||
"""模拟 SQL 查询语句对象。"""
|
|
||||||
|
|
||||||
def where(self, condition: Any) -> "_DummyStatement":
|
|
||||||
"""附加过滤条件。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
condition: 过滤条件。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
_DummyStatement: 当前语句对象。
|
|
||||||
"""
|
|
||||||
del condition
|
|
||||||
return self
|
|
||||||
|
|
||||||
def limit(self, value: int) -> "_DummyStatement":
|
|
||||||
"""限制返回条数。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
value: 条数限制。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
_DummyStatement: 当前语句对象。
|
|
||||||
"""
|
|
||||||
del value
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class _DummyColumn:
|
|
||||||
"""模拟 SQLModel 列对象。"""
|
|
||||||
|
|
||||||
def is_not(self, value: Any) -> "_DummyColumn":
|
|
||||||
"""模拟 `IS NOT` 条件构造。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
value: 比较值。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
_DummyColumn: 当前列对象。
|
|
||||||
"""
|
|
||||||
del value
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> "_DummyColumn":
|
|
||||||
"""模拟等值条件构造。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
other: 比较值。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
_DummyColumn: 当前列对象。
|
|
||||||
"""
|
|
||||||
del other
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class _DummyResult:
|
|
||||||
"""模拟数据库查询结果。"""
|
|
||||||
|
|
||||||
def __init__(self, record: Any) -> None:
|
|
||||||
"""初始化查询结果。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
record: 待返回的首条记录。
|
|
||||||
"""
|
|
||||||
self._record = record
|
|
||||||
|
|
||||||
def first(self) -> Any:
|
|
||||||
"""返回第一条记录。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: 首条记录。
|
|
||||||
"""
|
|
||||||
return self._record
|
|
||||||
|
|
||||||
def all(self) -> list[Any]:
|
|
||||||
"""返回全部结果。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[Any]: 结果列表。
|
|
||||||
"""
|
|
||||||
if self._record is None:
|
|
||||||
return []
|
|
||||||
return self._record if isinstance(self._record, list) else [self._record]
|
|
||||||
|
|
||||||
|
|
||||||
class _DummySession:
|
|
||||||
"""模拟数据库 Session。"""
|
|
||||||
|
|
||||||
def __init__(self, record: Any) -> None:
|
|
||||||
"""初始化 Session。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
record: `first()` 应返回的记录。
|
|
||||||
"""
|
|
||||||
self.record = record
|
|
||||||
self.added_records: list[Any] = []
|
|
||||||
|
|
||||||
def __enter__(self) -> "_DummySession":
|
|
||||||
"""进入上下文管理器。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
_DummySession: 当前 Session。
|
|
||||||
"""
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
|
||||||
"""退出上下文管理器。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
exc_type: 异常类型。
|
|
||||||
exc_val: 异常值。
|
|
||||||
exc_tb: 异常回溯。
|
|
||||||
"""
|
|
||||||
del exc_type
|
|
||||||
del exc_val
|
|
||||||
del exc_tb
|
|
||||||
|
|
||||||
def exec(self, statement: Any) -> _DummyResult:
|
|
||||||
"""执行查询。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
statement: 查询语句。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
_DummyResult: 模拟结果对象。
|
|
||||||
"""
|
|
||||||
del statement
|
|
||||||
return _DummyResult(self.record)
|
|
||||||
|
|
||||||
def add(self, record: Any) -> None:
|
|
||||||
"""记录被添加的对象。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
record: 被写入 Session 的对象。
|
|
||||||
"""
|
|
||||||
self.added_records.append(record)
|
|
||||||
|
|
||||||
|
|
||||||
class _DummyPersonInfoRecord:
|
|
||||||
"""模拟 `PersonInfo` ORM 模型。"""
|
|
||||||
|
|
||||||
person_id = "person_id"
|
|
||||||
person_name = "person_name"
|
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
|
||||||
"""使用关键字参数初始化记录对象。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
**kwargs: 字段值。
|
|
||||||
"""
|
|
||||||
for key, value in kwargs.items():
|
|
||||||
setattr(self, key, value)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_person_module(monkeypatch: pytest.MonkeyPatch, session: _DummySession) -> ModuleType:
|
|
||||||
"""加载带依赖桩的 `person_info` 模块。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
monkeypatch: Pytest monkeypatch 工具。
|
|
||||||
session: 提供给模块使用的假数据库 Session。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ModuleType: 加载后的模块对象。
|
|
||||||
"""
|
|
||||||
logger_module = ModuleType("src.common.logger")
|
|
||||||
logger_module.get_logger = lambda name: _DummyLogger()
|
|
||||||
monkeypatch.setitem(sys.modules, "src.common.logger", logger_module)
|
|
||||||
|
|
||||||
database_module = ModuleType("src.common.database.database")
|
|
||||||
database_module.get_db_session = lambda: session
|
|
||||||
monkeypatch.setitem(sys.modules, "src.common.database.database", database_module)
|
|
||||||
|
|
||||||
database_model_module = ModuleType("src.common.database.database_model")
|
|
||||||
database_model_module.PersonInfo = _DummyPersonInfoRecord
|
|
||||||
monkeypatch.setitem(sys.modules, "src.common.database.database_model", database_model_module)
|
|
||||||
|
|
||||||
llm_module = ModuleType("src.llm_models.utils_model")
|
|
||||||
|
|
||||||
class _DummyLLMRequest:
|
|
||||||
"""模拟 LLMRequest。"""
|
|
||||||
|
|
||||||
def __init__(self, model_set: Any, request_type: str) -> None:
|
|
||||||
"""初始化假请求对象。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_set: 模型配置。
|
|
||||||
request_type: 请求类型。
|
|
||||||
"""
|
|
||||||
del model_set
|
|
||||||
del request_type
|
|
||||||
|
|
||||||
llm_module.LLMRequest = _DummyLLMRequest
|
|
||||||
monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_module)
|
|
||||||
|
|
||||||
config_module = ModuleType("src.config.config")
|
|
||||||
config_module.global_config = SimpleNamespace(bot=SimpleNamespace(nickname="MaiBot"))
|
|
||||||
config_module.model_config = SimpleNamespace(model_task_config=SimpleNamespace(tool_use="tool_use", utils="utils"))
|
|
||||||
monkeypatch.setitem(sys.modules, "src.config.config", config_module)
|
|
||||||
|
|
||||||
chat_manager_module = ModuleType("src.chat.message_receive.chat_manager")
|
|
||||||
chat_manager_module.chat_manager = SimpleNamespace()
|
|
||||||
monkeypatch.setitem(sys.modules, "src.chat.message_receive.chat_manager", chat_manager_module)
|
|
||||||
|
|
||||||
module_path = Path(__file__).resolve().parents[2] / "src" / "person_info" / "person_info.py"
|
|
||||||
spec = spec_from_file_location("person_info_group_cardname_test_module", module_path)
|
|
||||||
assert spec is not None and spec.loader is not None
|
|
||||||
|
|
||||||
module = module_from_spec(spec)
|
|
||||||
monkeypatch.setitem(sys.modules, spec.name, module)
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
|
|
||||||
monkeypatch.setattr(module, "select", lambda *args: _DummyStatement())
|
|
||||||
monkeypatch.setattr(module, "col", lambda field: _DummyColumn())
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_group_cardname_json_uses_canonical_key() -> None:
|
|
||||||
"""群名片 JSON 解析应只使用 `group_cardname` 键名。"""
|
|
||||||
parsed = parse_group_cardname_json(
|
|
||||||
json.dumps(
|
|
||||||
[
|
|
||||||
{"group_id": "1001", "group_cardname": "现行字段"},
|
|
||||||
],
|
|
||||||
ensure_ascii=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert parsed is not None
|
|
||||||
assert [(item.group_id, item.group_cardname) for item in parsed] == [
|
|
||||||
("1001", "现行字段"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_dump_group_cardname_records_uses_canonical_key() -> None:
|
|
||||||
"""群名片序列化应输出 `group_cardname` 键名。"""
|
|
||||||
dumped = dump_group_cardname_records(
|
|
||||||
[
|
|
||||||
{"group_id": "1001", "group_cardname": "群昵称"},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
assert json.loads(dumped) == [{"group_id": "1001", "group_cardname": "群昵称"}]
|
|
||||||
|
|
||||||
|
|
||||||
def test_person_sync_to_database_uses_group_cardname_field(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""同步人物信息时应写入数据库模型的 `group_cardname` 字段。"""
|
|
||||||
record = _DummyPersonInfoRecord()
|
|
||||||
session = _DummySession(record)
|
|
||||||
module = _load_person_module(monkeypatch, session)
|
|
||||||
|
|
||||||
person = module.Person.__new__(module.Person)
|
|
||||||
person.is_known = True
|
|
||||||
person.person_id = "person-1"
|
|
||||||
person.platform = "qq"
|
|
||||||
person.user_id = "10001"
|
|
||||||
person.nickname = "看番的龙"
|
|
||||||
person.person_name = "看番的龙"
|
|
||||||
person.name_reason = "测试"
|
|
||||||
person.know_times = 1
|
|
||||||
person.know_since = 1700000000.0
|
|
||||||
person.last_know = 1700000100.0
|
|
||||||
person.memory_points = ["喜好:番剧:0.8"]
|
|
||||||
person.group_cardname_list = [{"group_id": "20001", "group_cardname": "白泽大人"}]
|
|
||||||
|
|
||||||
person.sync_to_database()
|
|
||||||
|
|
||||||
assert record.group_cardname == '[{"group_id": "20001", "group_cardname": "白泽大人"}]'
|
|
||||||
assert not hasattr(record, "group_nickname")
|
|
||||||
|
|
||||||
|
|
||||||
def test_person_load_from_database_normalizes_group_cardname_payload(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""从数据库加载人物信息时应读取标准 `group_cardname` 结构。"""
|
|
||||||
record = _DummyPersonInfoRecord(
|
|
||||||
user_id="10001",
|
|
||||||
platform="qq",
|
|
||||||
is_known=True,
|
|
||||||
user_nickname="看番的龙",
|
|
||||||
person_name="看番的龙",
|
|
||||||
name_reason=None,
|
|
||||||
know_counts=2,
|
|
||||||
memory_points='["喜好:番剧:0.8"]',
|
|
||||||
group_cardname=json.dumps(
|
|
||||||
[
|
|
||||||
{"group_id": "20001", "group_cardname": "白泽大人"},
|
|
||||||
],
|
|
||||||
ensure_ascii=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
session = _DummySession(record)
|
|
||||||
module = _load_person_module(monkeypatch, session)
|
|
||||||
|
|
||||||
person = module.Person.__new__(module.Person)
|
|
||||||
person.person_id = "person-1"
|
|
||||||
person.memory_points = []
|
|
||||||
person.group_cardname_list = []
|
|
||||||
|
|
||||||
person.load_from_database()
|
|
||||||
|
|
||||||
assert person.group_cardname_list == [
|
|
||||||
{"group_id": "20001", "group_cardname": "白泽大人"},
|
|
||||||
]
|
|
||||||
@@ -1,533 +0,0 @@
|
|||||||
import logging
|
|
||||||
import sys
|
|
||||||
from importlib import util
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
# -------------------------------------------------------------
|
|
||||||
# 测试环境准备:补全 logger 和 AttrDocBase 依赖
|
|
||||||
# -------------------------------------------------------------
|
|
||||||
|
|
||||||
TEST_ROOT = Path(__file__).parent.parent.absolute().resolve()
|
|
||||||
logger_file = TEST_ROOT / "logger.py"
|
|
||||||
spec = util.spec_from_file_location("src.common.logger", logger_file)
|
|
||||||
module = util.module_from_spec(spec) # type: ignore
|
|
||||||
assert spec is not None and spec.loader is not None
|
|
||||||
spec.loader.exec_module(module) # type: ignore
|
|
||||||
sys.modules["src.common.logger"] = module
|
|
||||||
|
|
||||||
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
|
|
||||||
sys.path.insert(0, str(PROJECT_ROOT))
|
|
||||||
sys.path.insert(0, str(PROJECT_ROOT / "src" / "config"))
|
|
||||||
|
|
||||||
from src.config.config_base import ConfigBase # noqa: E402
|
|
||||||
import src.config.config_base as config_base_module # noqa: E402
|
|
||||||
|
|
||||||
|
|
||||||
class AttrDocBase:
|
|
||||||
"""用于测试的轻量级 AttrDocBase 替身"""
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
|
||||||
# 被 ConfigBase.model_post_init 调用
|
|
||||||
self.__post_init_called__ = True
|
|
||||||
|
|
||||||
|
|
||||||
# 打补丁,让 ConfigBase 使用测试替身
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def patch_attrdoc_post_init():
|
|
||||||
orig = config_base_module.AttrDocBase.__post_init__
|
|
||||||
config_base_module.AttrDocBase.__post_init__ = AttrDocBase.__post_init__ # type: ignore
|
|
||||||
yield
|
|
||||||
config_base_module.AttrDocBase.__post_init__ = orig
|
|
||||||
|
|
||||||
|
|
||||||
config_base_module.logger = logging.getLogger("config_base_test_logger")
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleClass(ConfigBase):
|
|
||||||
a: int = 1
|
|
||||||
b: str = "test"
|
|
||||||
|
|
||||||
|
|
||||||
class TestConfigBase:
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
# happy path:整体 model_post_init 测试
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"model_cls, init_kwargs, expected_fields",
|
|
||||||
[
|
|
||||||
pytest.param(
|
|
||||||
# 简单原子类型字段
|
|
||||||
type(
|
|
||||||
"SimpleAtomic",
|
|
||||||
(ConfigBase,),
|
|
||||||
{
|
|
||||||
"__annotations__": {
|
|
||||||
"a": int,
|
|
||||||
"b": str,
|
|
||||||
"c": bool,
|
|
||||||
"d": float,
|
|
||||||
},
|
|
||||||
"a": Field(default=1),
|
|
||||||
"b": Field(default="x"),
|
|
||||||
"c": Field(default=True),
|
|
||||||
"d": Field(default=1.5),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
{},
|
|
||||||
{"a", "b", "c", "d"},
|
|
||||||
id="happy-simple-atomic-fields",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
# list/set/dict 泛型 + 原子内部类型
|
|
||||||
type(
|
|
||||||
"AtomicContainers",
|
|
||||||
(ConfigBase,),
|
|
||||||
{
|
|
||||||
"__annotations__": {
|
|
||||||
"ints": List[int],
|
|
||||||
"names": Set[str],
|
|
||||||
"mapping": Dict[str, int],
|
|
||||||
},
|
|
||||||
"ints": Field(default_factory=lambda: [1, 2]),
|
|
||||||
"names": Field(default_factory=lambda: {"a", "b"}),
|
|
||||||
"mapping": Field(default_factory=lambda: {"x": 1}),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
{},
|
|
||||||
{"ints", "names", "mapping"},
|
|
||||||
id="happy-atomic-containers",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
# Optional 原子和 Optional 容器
|
|
||||||
type(
|
|
||||||
"OptionalFields",
|
|
||||||
(ConfigBase,),
|
|
||||||
{
|
|
||||||
"__annotations__": {
|
|
||||||
"maybe_int": Optional[int],
|
|
||||||
"maybe_str_list": Optional[List[str]],
|
|
||||||
},
|
|
||||||
"maybe_int": Field(default=None),
|
|
||||||
"maybe_str_list": Field(default=None),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
{},
|
|
||||||
{"maybe_int", "maybe_str_list"},
|
|
||||||
id="happy-optional-fields",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_model_post_init_happy_paths(self, model_cls, init_kwargs, expected_fields):
|
|
||||||
# Act
|
|
||||||
instance = model_cls(**init_kwargs)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
for field_name in expected_fields:
|
|
||||||
assert field_name in type(instance).model_fields
|
|
||||||
_ = getattr(instance, field_name)
|
|
||||||
assert getattr(instance, "__post_init_called__", False) is True
|
|
||||||
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
# _get_real_type
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
def test_get_real_type_non_generic_and_generic(self):
|
|
||||||
class Sample(ConfigBase):
|
|
||||||
x: int = 1
|
|
||||||
y: List[int] = Field(default_factory=list)
|
|
||||||
|
|
||||||
instance = Sample()
|
|
||||||
|
|
||||||
# Act
|
|
||||||
origin_x, args_x = instance._get_real_type(int)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert origin_x is int
|
|
||||||
assert args_x == ()
|
|
||||||
|
|
||||||
# Act
|
|
||||||
origin_y, args_y = instance._get_real_type(List[int])
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert origin_y in (list, List)
|
|
||||||
assert args_y == (int,)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
# _validate_union_type
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"annotation, expect_error, error_fragment, expected_origin_type",
|
|
||||||
[
|
|
||||||
pytest.param(
|
|
||||||
int,
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
int,
|
|
||||||
id="union-validation-atomic-non-union",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
Optional[int],
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
int,
|
|
||||||
id="union-validation-optional-atomic",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
Optional[List[int]],
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
list,
|
|
||||||
id="union-validation-optional-container",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
Union[int, str],
|
|
||||||
True,
|
|
||||||
"不允许使用 Union 类型注解",
|
|
||||||
None,
|
|
||||||
id="union-validation-disallow-non-optional-union",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
int | str,
|
|
||||||
True,
|
|
||||||
"不允许使用 Union 类型注解",
|
|
||||||
None,
|
|
||||||
id="union-validation-pep604-disallow-non-optional-union",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
Union[int, None, str],
|
|
||||||
True,
|
|
||||||
"不允许使用 Union 类型注解",
|
|
||||||
None,
|
|
||||||
id="union-validation-disallow-union-more-than-two",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
Optional[Union[int, str]],
|
|
||||||
True,
|
|
||||||
"不允许使用 Union 类型注解",
|
|
||||||
None,
|
|
||||||
id="union-validation-disallow-nested-optional-union",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_validate_union_type(self, annotation, expect_error, error_fragment, expected_origin_type):
|
|
||||||
# 这里我们不实例化 Sample,以避免在 __init__/model_post_init 阶段触发验证。
|
|
||||||
# 直接通过一个“哑实例”调用受测方法,仅测试类型注解逻辑。
|
|
||||||
|
|
||||||
class Dummy(ConfigBase):
|
|
||||||
pass
|
|
||||||
|
|
||||||
dummy = Dummy() # 最小初始化,避免字段校验
|
|
||||||
|
|
||||||
field_name = "v"
|
|
||||||
|
|
||||||
if expect_error:
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(TypeError) as exc_info:
|
|
||||||
dummy._validate_union_type(annotation, field_name)
|
|
||||||
assert error_fragment in str(exc_info.value)
|
|
||||||
else:
|
|
||||||
# Act
|
|
||||||
origin, args, other = dummy._validate_union_type(annotation, field_name)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert origin is expected_origin_type
|
|
||||||
assert other is not None
|
|
||||||
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
# _validate_list_set_type
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"annotation, expect_error, error_fragment",
|
|
||||||
[
|
|
||||||
pytest.param(
|
|
||||||
List[int],
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
id="listset-validation-list-happy",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
Set[str],
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
id="listset-validation-set-happy",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
list,
|
|
||||||
True,
|
|
||||||
"必须指定且仅指定一个类型参数",
|
|
||||||
id="listset-validation-missing-type-arg",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
List[int | None],
|
|
||||||
True,
|
|
||||||
"不允许嵌套泛型类型",
|
|
||||||
id="listset-validation-nested-generic-inner-union",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
List[List[int]],
|
|
||||||
True,
|
|
||||||
"不允许嵌套泛型类型",
|
|
||||||
id="listset-validation-nested-generic-inner-list",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
List[SimpleClass],
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
id="listset-validation-list-configbase-element_allow",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
Set[SimpleClass],
|
|
||||||
True,
|
|
||||||
"ConfigBase is not Hashable",
|
|
||||||
id="listset-validation-set-configbase-element_reject",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_validate_list_set_type(self, annotation, expect_error, error_fragment):
|
|
||||||
# 不实例化带有这些字段的模型,避免在 __init__/model_post_init 阶段就失败,
|
|
||||||
# 只测试 _validate_list_set_type 本身的逻辑。
|
|
||||||
|
|
||||||
class Dummy(ConfigBase):
|
|
||||||
pass
|
|
||||||
|
|
||||||
dummy = Dummy()
|
|
||||||
|
|
||||||
field_name = "items"
|
|
||||||
|
|
||||||
if expect_error:
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(TypeError) as exc_info:
|
|
||||||
dummy._validate_list_set_type(annotation, field_name)
|
|
||||||
assert error_fragment in str(exc_info.value)
|
|
||||||
else:
|
|
||||||
# Act
|
|
||||||
dummy._validate_list_set_type(annotation, field_name)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
# _validate_dict_type
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"annotation, expect_error, error_fragment",
|
|
||||||
[
|
|
||||||
pytest.param(
|
|
||||||
Dict[str, int],
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
id="dict-validation-happy-atomic",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
Dict[str, Any],
|
|
||||||
True,
|
|
||||||
"不允许使用 Any 类型注解",
|
|
||||||
id="dict-validation-any-value-disallowed",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
Dict[str, Dict[str, int]],
|
|
||||||
True,
|
|
||||||
"不允许嵌套泛型类型",
|
|
||||||
id="dict-validation-optional-nested-list",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
Dict,
|
|
||||||
True,
|
|
||||||
"必须指定键和值的类型参数",
|
|
||||||
id="dict-validation-missing-args",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
Dict[str, SimpleClass],
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
id="dict-validation-happy-configbase-value",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_validate_dict_type(self, annotation, expect_error, error_fragment):
|
|
||||||
# 同样不通过字段定义来触发 model_post_init,只测试 _validate_dict_type 本身。
|
|
||||||
|
|
||||||
class Dummy(ConfigBase):
|
|
||||||
_validate_any: bool = True
|
|
||||||
|
|
||||||
dummy = Dummy()
|
|
||||||
field_name = "mapping"
|
|
||||||
|
|
||||||
if expect_error:
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(TypeError) as exc_info:
|
|
||||||
dummy._validate_dict_type(annotation, field_name)
|
|
||||||
assert error_fragment in str(exc_info.value)
|
|
||||||
else:
|
|
||||||
# Act
|
|
||||||
dummy._validate_dict_type(annotation, field_name)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
# _discourage_any_usage
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
def test_discourage_any_usage_raises_when_validate_any_true(self, caplog):
|
|
||||||
class Sample(ConfigBase):
|
|
||||||
_validate_any: bool = True
|
|
||||||
|
|
||||||
instance = Sample()
|
|
||||||
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(TypeError) as exc_info:
|
|
||||||
instance._discourage_any_usage("field_x")
|
|
||||||
assert "不允许使用 Any 类型注解" in str(exc_info.value)
|
|
||||||
assert "建议避免使用" not in caplog.text
|
|
||||||
|
|
||||||
def test_discourage_any_usage_logs_when_validate_any_false(self, caplog):
|
|
||||||
class Sample(ConfigBase):
|
|
||||||
_validate_any: bool = False
|
|
||||||
|
|
||||||
instance = Sample()
|
|
||||||
|
|
||||||
# Arrange
|
|
||||||
caplog.set_level(logging.WARNING, logger="config_base_test_logger")
|
|
||||||
|
|
||||||
# Act
|
|
||||||
instance._discourage_any_usage("field_y")
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "字段'field_y'中使用了 Any 类型注解" in caplog.text
|
|
||||||
|
|
||||||
def test_discourage_any_usage_suppressed_warning(self, caplog):
|
|
||||||
class Sample(ConfigBase):
|
|
||||||
_validate_any: bool = False
|
|
||||||
suppress_any_warning: bool = True
|
|
||||||
|
|
||||||
instance = Sample()
|
|
||||||
|
|
||||||
# Arrange
|
|
||||||
caplog.set_level(logging.WARNING, logger="config_base_test_logger")
|
|
||||||
|
|
||||||
# Act
|
|
||||||
instance._discourage_any_usage("field_z")
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "字段'field_z'中使用了 Any 类型注解" not in caplog.text
|
|
||||||
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
# model_post_init 规则覆盖(错误与边界情况)
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"field_annotation, expect_error, error_fragment, test_id",
|
|
||||||
[
|
|
||||||
(
|
|
||||||
Tuple[int, int],
|
|
||||||
True,
|
|
||||||
"不允许使用 Tuple 类型注解",
|
|
||||||
"model-post-init-disallow-tuple-typing-tuple",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
tuple[int, int],
|
|
||||||
True,
|
|
||||||
"不允许使用 Tuple 类型注解",
|
|
||||||
"model-post-init-disallow-pep604-tuple",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
Union[int, str],
|
|
||||||
True,
|
|
||||||
"不允许使用 Union 类型注解",
|
|
||||||
"model-post-init-disallow-union-field",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
list,
|
|
||||||
True,
|
|
||||||
"必须指定且仅指定一个类型参数",
|
|
||||||
"model-post-init-list-missing-type-arg",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
List[List[int]],
|
|
||||||
True,
|
|
||||||
"不允许嵌套泛型类型",
|
|
||||||
"model-post-init-list-nested-generic",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
Dict[str, Any],
|
|
||||||
True,
|
|
||||||
"不允许使用 Any 类型注解",
|
|
||||||
"model-post-init-dict-value-any",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
Any,
|
|
||||||
True,
|
|
||||||
"不允许使用 Any 类型注解",
|
|
||||||
"model-post-init-field-any-disallowed",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
Set[int],
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
"model-post-init-allow-set-int",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
Dict[str, Optional[int]],
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
"model-post-init-allow-dict-optional-int",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
ids=lambda v: v[3] if isinstance(v, tuple) else v,
|
|
||||||
)
|
|
||||||
def test_model_post_init_type_rules(self, field_annotation, expect_error, error_fragment, test_id):
|
|
||||||
# Arrange
|
|
||||||
attrs = {
|
|
||||||
"__annotations__": {"f": field_annotation},
|
|
||||||
"f": Field(default=None),
|
|
||||||
}
|
|
||||||
model_cls = type("DynamicModel" + test_id.replace("-", "_"), (ConfigBase,), attrs)
|
|
||||||
|
|
||||||
if expect_error:
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(TypeError) as exc_info:
|
|
||||||
model_cls()
|
|
||||||
assert error_fragment in str(exc_info.value)
|
|
||||||
else:
|
|
||||||
# Act
|
|
||||||
instance = model_cls()
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert hasattr(instance, "f")
|
|
||||||
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
# 嵌套 ConfigBase & 非支持泛型 origin
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
def test_model_post_init_allows_configbase_nested_class(self):
|
|
||||||
class Child(ConfigBase):
|
|
||||||
value: int = 1
|
|
||||||
|
|
||||||
class Parent(ConfigBase):
|
|
||||||
child: Child = Field(default_factory=Child)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
parent = Parent()
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert isinstance(parent.child, Child)
|
|
||||||
|
|
||||||
def test_model_post_init_disallow_non_supported_generic_origin(self):
|
|
||||||
class CustomGeneric(BaseModel):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class Sample(ConfigBase):
|
|
||||||
f: CustomGeneric = Field(default_factory=CustomGeneric)
|
|
||||||
|
|
||||||
# Arrange / Act / Assert
|
|
||||||
with pytest.raises(TypeError) as exc_info:
|
|
||||||
Sample()
|
|
||||||
assert "仅允许使用list, set, dict三种泛型类型注解" in str(exc_info.value)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
# super().model_post_init 和 AttrDocBase.__post_init__ 调用
|
|
||||||
# ---------------------------------------------------------
|
|
||||||
def test_super_model_post_init_and_attrdoc_post_init_called(self):
|
|
||||||
class Sample(ConfigBase):
|
|
||||||
value: int = 1
|
|
||||||
|
|
||||||
# Act
|
|
||||||
instance = Sample()
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert getattr(instance, "__post_init_called__", False) is True
|
|
||||||
@@ -1,104 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from watchfiles import Change
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.config.config import ConfigManager
|
|
||||||
from src.config.file_watcher import FileChange, FileWatcherStats
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_file_changes_throttles_reload():
|
|
||||||
manager = ConfigManager()
|
|
||||||
manager._hot_reload_min_interval_s = 100.0
|
|
||||||
|
|
||||||
called = 0
|
|
||||||
|
|
||||||
async def reload_stub(changed_scopes=None) -> bool:
|
|
||||||
nonlocal called
|
|
||||||
called += 1
|
|
||||||
return True
|
|
||||||
|
|
||||||
manager.reload_config = reload_stub # type: ignore[method-assign]
|
|
||||||
changes = [FileChange(change_type=Change.modified, path=Path("/tmp/bot_config.toml"))]
|
|
||||||
|
|
||||||
await manager._handle_file_changes(changes)
|
|
||||||
await manager._handle_file_changes(changes)
|
|
||||||
|
|
||||||
assert called == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_file_changes_timeout_logged(caplog):
|
|
||||||
manager = ConfigManager()
|
|
||||||
manager._hot_reload_min_interval_s = 0.0
|
|
||||||
manager._hot_reload_timeout_s = 0.01
|
|
||||||
|
|
||||||
async def reload_stub(changed_scopes=None) -> bool:
|
|
||||||
await asyncio.sleep(0.05)
|
|
||||||
return True
|
|
||||||
|
|
||||||
manager.reload_config = reload_stub # type: ignore[method-assign]
|
|
||||||
changes = [FileChange(change_type=Change.modified, path=Path("/tmp/model_config.toml"))]
|
|
||||||
|
|
||||||
with caplog.at_level("ERROR"):
|
|
||||||
await manager._handle_file_changes(changes)
|
|
||||||
|
|
||||||
assert "配置热重载超时" in caplog.text
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_file_changes_empty_skips_reload():
|
|
||||||
manager = ConfigManager()
|
|
||||||
|
|
||||||
called = 0
|
|
||||||
|
|
||||||
async def reload_stub(changed_scopes=None) -> bool:
|
|
||||||
nonlocal called
|
|
||||||
called += 1
|
|
||||||
return True
|
|
||||||
|
|
||||||
manager.reload_config = reload_stub # type: ignore[method-assign]
|
|
||||||
|
|
||||||
await manager._handle_file_changes([])
|
|
||||||
|
|
||||||
assert called == 0
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeWatcher:
|
|
||||||
def __init__(self):
|
|
||||||
self.unsubscribe_called_with: str | None = None
|
|
||||||
self.stop_called = False
|
|
||||||
self.stats = FileWatcherStats(
|
|
||||||
batches_seen=1,
|
|
||||||
changes_seen=2,
|
|
||||||
callbacks_succeeded=3,
|
|
||||||
callbacks_failed=4,
|
|
||||||
callbacks_timed_out=5,
|
|
||||||
callbacks_skipped_cooldown=6,
|
|
||||||
restart_count=7,
|
|
||||||
)
|
|
||||||
|
|
||||||
def unsubscribe(self, subscription_id: str) -> bool:
|
|
||||||
self.unsubscribe_called_with = subscription_id
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
self.stop_called = True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_stop_file_watcher_cleans_state():
|
|
||||||
manager = ConfigManager()
|
|
||||||
fake_watcher = _FakeWatcher()
|
|
||||||
manager._file_watcher = fake_watcher # type: ignore[assignment]
|
|
||||||
manager._file_watcher_subscription_id = "sub-1"
|
|
||||||
|
|
||||||
await manager.stop_file_watcher()
|
|
||||||
|
|
||||||
assert fake_watcher.unsubscribe_called_with == "sub-1"
|
|
||||||
assert fake_watcher.stop_called is True
|
|
||||||
assert manager._file_watcher is None
|
|
||||||
assert manager._file_watcher_subscription_id is None
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
from src.config import config as config_module
|
|
||||||
from src.config.config import Config, ConfigManager, ModelConfig
|
|
||||||
|
|
||||||
|
|
||||||
def test_initialize_upgrades_bot_and_model_config_without_exit(monkeypatch):
|
|
||||||
manager = ConfigManager()
|
|
||||||
loaded_config_classes: list[type[Any]] = []
|
|
||||||
warnings: list[Any] = []
|
|
||||||
|
|
||||||
def fake_load_config_from_file(config_class, config_path, new_ver, override_repr=False):
|
|
||||||
loaded_config_classes.append(config_class)
|
|
||||||
return object(), True
|
|
||||||
|
|
||||||
monkeypatch.setattr(config_module, "load_config_from_file", fake_load_config_from_file)
|
|
||||||
monkeypatch.setattr(ConfigManager, "_warn_if_vlm_not_configured", lambda self, model_config: warnings.append(model_config))
|
|
||||||
|
|
||||||
manager.initialize()
|
|
||||||
|
|
||||||
assert loaded_config_classes == [Config, ModelConfig]
|
|
||||||
assert warnings == [manager.model_config]
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from watchfiles import Change
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.config.file_watcher import FileChange, FileWatcher
|
|
||||||
|
|
||||||
from typing import Sequence
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_dispatch_changes_with_path_and_change_type_filters(tmp_path: Path):
|
|
||||||
watcher = FileWatcher(paths=[tmp_path])
|
|
||||||
target_file = (tmp_path / "bot_config.toml").resolve()
|
|
||||||
|
|
||||||
received: list[list[FileChange]] = []
|
|
||||||
|
|
||||||
async def callback(changes):
|
|
||||||
received.append(list(changes))
|
|
||||||
|
|
||||||
watcher.subscribe(callback, paths=[target_file], change_types=[Change.modified])
|
|
||||||
|
|
||||||
await watcher._dispatch_changes(
|
|
||||||
[
|
|
||||||
FileChange(change_type=Change.added, path=target_file),
|
|
||||||
FileChange(change_type=Change.modified, path=target_file),
|
|
||||||
FileChange(change_type=Change.modified, path=(tmp_path / "other.toml").resolve()),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(received) == 1
|
|
||||||
assert len(received[0]) == 1
|
|
||||||
assert received[0][0].change_type == Change.modified
|
|
||||||
assert received[0][0].path == target_file
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sync_callback_supported(tmp_path: Path):
|
|
||||||
watcher = FileWatcher(paths=[tmp_path])
|
|
||||||
target_file = (tmp_path / "model_config.toml").resolve()
|
|
||||||
|
|
||||||
received_paths: list[Path] = []
|
|
||||||
|
|
||||||
def sync_callback(changes):
|
|
||||||
received_paths.extend(change.path for change in changes)
|
|
||||||
|
|
||||||
watcher.subscribe(sync_callback, paths=[target_file])
|
|
||||||
|
|
||||||
await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
|
|
||||||
|
|
||||||
assert received_paths == [target_file]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_callback_timeout_and_cooldown(tmp_path: Path):
|
|
||||||
watcher = FileWatcher(
|
|
||||||
paths=[tmp_path],
|
|
||||||
callback_timeout_s=0.05,
|
|
||||||
callback_failure_threshold=2,
|
|
||||||
callback_cooldown_s=0.2,
|
|
||||||
)
|
|
||||||
target_file = (tmp_path / "bot_config.toml").resolve()
|
|
||||||
|
|
||||||
async def slow_callback(changes):
|
|
||||||
await asyncio.sleep(0.2)
|
|
||||||
|
|
||||||
watcher.subscribe(slow_callback, paths=[target_file])
|
|
||||||
|
|
||||||
await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
|
|
||||||
await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
|
|
||||||
|
|
||||||
stats_after_failures = watcher.stats
|
|
||||||
assert stats_after_failures.callbacks_timed_out == 2
|
|
||||||
assert stats_after_failures.callbacks_failed == 2
|
|
||||||
|
|
||||||
await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
|
|
||||||
stats_after_cooldown_skip = watcher.stats
|
|
||||||
assert stats_after_cooldown_skip.callbacks_skipped_cooldown >= 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_start_requires_subscription(tmp_path: Path):
|
|
||||||
watcher = FileWatcher(paths=[tmp_path])
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError):
|
|
||||||
await watcher.start()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_unsubscribe_stops_dispatch(tmp_path: Path):
|
|
||||||
watcher = FileWatcher(paths=[tmp_path])
|
|
||||||
target_file = (tmp_path / "bot_config.toml").resolve()
|
|
||||||
|
|
||||||
calls = 0
|
|
||||||
|
|
||||||
async def callback(changes):
|
|
||||||
nonlocal calls
|
|
||||||
calls += 1
|
|
||||||
|
|
||||||
subscription_id = watcher.subscribe(callback, paths=[target_file])
|
|
||||||
assert watcher.unsubscribe(subscription_id) is True
|
|
||||||
|
|
||||||
await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)])
|
|
||||||
|
|
||||||
assert calls == 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_add_callback_while_watcher_running(tmp_path: Path):
|
|
||||||
dirs = (tmp_path / "a_dir").resolve()
|
|
||||||
dirs.mkdir(exist_ok=True)
|
|
||||||
file = (dirs / "a.toml").resolve()
|
|
||||||
file.touch()
|
|
||||||
watcher = FileWatcher(paths=[dirs], debounce_ms=200)
|
|
||||||
|
|
||||||
calls = 0
|
|
||||||
|
|
||||||
async def callback(changes: Sequence[FileChange]):
|
|
||||||
nonlocal calls
|
|
||||||
print(f"Callback called with changes: {[f'{change.change_type} {change.path}' for change in changes]}")
|
|
||||||
calls += 1
|
|
||||||
|
|
||||||
uuid = watcher.subscribe(callback, paths=[file])
|
|
||||||
await watcher.start()
|
|
||||||
try:
|
|
||||||
with file.open("w") as f:
|
|
||||||
f.write("change")
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
assert calls == 1
|
|
||||||
watcher.unsubscribe(uuid)
|
|
||||||
with file.open("w") as f:
|
|
||||||
f.write("change2")
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
assert calls == 1
|
|
||||||
finally:
|
|
||||||
await watcher.stop()
|
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
from types import SimpleNamespace
|
|
||||||
from importlib import util
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from src.config.config import config_manager
|
|
||||||
from src.config.model_configs import TaskConfig
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
|
|
||||||
|
|
||||||
def _load_llm_api_module():
|
|
||||||
file_path = Path(__file__).parent.parent.parent / "src" / "plugin_system" / "apis" / "llm_api.py"
|
|
||||||
spec = util.spec_from_file_location("test_llm_api_module", file_path)
|
|
||||||
assert spec is not None and spec.loader is not None
|
|
||||||
module = util.module_from_spec(spec)
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
def _make_model_config(task_config: TaskConfig, attr_name: str = "utils"):
|
|
||||||
model_task_config = SimpleNamespace(**{attr_name: task_config})
|
|
||||||
return SimpleNamespace(model_task_config=model_task_config, models=[], api_providers=[])
|
|
||||||
|
|
||||||
|
|
||||||
def test_llm_request_resolve_task_config_by_signature(monkeypatch):
|
|
||||||
old_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
|
|
||||||
current_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
|
|
||||||
|
|
||||||
monkeypatch.setattr(config_manager, "get_model_config", lambda: _make_model_config(current_task, "utils"))
|
|
||||||
|
|
||||||
req = LLMRequest(model_set=old_task, request_type="test")
|
|
||||||
|
|
||||||
assert req._task_config_name == "utils"
|
|
||||||
|
|
||||||
|
|
||||||
def test_llm_request_refresh_task_config_updates_runtime_state(monkeypatch):
|
|
||||||
old_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
|
|
||||||
initial_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
|
|
||||||
updated_task = TaskConfig(model_list=["gpt-b", "gpt-c"], max_tokens=1024, temperature=0.5, slow_threshold=20.0)
|
|
||||||
|
|
||||||
current = {"task": initial_task}
|
|
||||||
|
|
||||||
def get_model_config_stub():
|
|
||||||
return _make_model_config(current["task"], "replyer")
|
|
||||||
|
|
||||||
monkeypatch.setattr(config_manager, "get_model_config", get_model_config_stub)
|
|
||||||
|
|
||||||
req = LLMRequest(model_set=old_task, request_type="test")
|
|
||||||
assert req._task_config_name == "replyer"
|
|
||||||
|
|
||||||
current["task"] = updated_task
|
|
||||||
req._refresh_task_config()
|
|
||||||
|
|
||||||
assert req.model_for_task.model_list == ["gpt-b", "gpt-c"]
|
|
||||||
assert list(req.model_usage.keys()) == ["gpt-b", "gpt-c"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_llm_api_get_available_models_reads_latest_config(monkeypatch):
|
|
||||||
llm_api = _load_llm_api_module()
|
|
||||||
|
|
||||||
first_utils = TaskConfig(model_list=["gpt-a"])
|
|
||||||
second_utils = TaskConfig(model_list=["gpt-z"])
|
|
||||||
|
|
||||||
state = {"task": first_utils}
|
|
||||||
|
|
||||||
def get_model_config_stub():
|
|
||||||
model_task_config = SimpleNamespace(utils=state["task"], planner=TaskConfig(model_list=["gpt-p"]))
|
|
||||||
return SimpleNamespace(model_task_config=model_task_config)
|
|
||||||
|
|
||||||
monkeypatch.setattr(config_manager, "get_model_config", get_model_config_stub)
|
|
||||||
|
|
||||||
first = llm_api.get_available_models()
|
|
||||||
assert first["utils"].model_list == ["gpt-a"]
|
|
||||||
|
|
||||||
state["task"] = second_utils
|
|
||||||
second = llm_api.get_available_models()
|
|
||||||
assert second["utils"].model_list == ["gpt-z"]
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
from src.config.model_configs import ModelInfo
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_identifier_strips_surrounding_whitespace() -> None:
|
|
||||||
model_info = ModelInfo(
|
|
||||||
api_provider="test-provider",
|
|
||||||
model_identifier=" glm-5.1 ",
|
|
||||||
name="test-model",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert model_info.model_identifier == "glm-5.1"
|
|
||||||
@@ -1,104 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
from types import SimpleNamespace
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from src.config.legacy_migration import migrate_legacy_bind_env_to_bot_config_dict
|
|
||||||
from src.config.startup_bindings import (
|
|
||||||
BindAddress,
|
|
||||||
get_startup_main_bind_address,
|
|
||||||
get_startup_webui_bind_address,
|
|
||||||
resolve_main_bind_address,
|
|
||||||
resolve_webui_bind_address,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_startup_bindings_use_defaults_when_config_file_missing(tmp_path: Path):
|
|
||||||
missing_path = tmp_path / "missing_bot_config.toml"
|
|
||||||
|
|
||||||
assert get_startup_main_bind_address(missing_path) == BindAddress(host="127.0.0.1", port=8080)
|
|
||||||
assert get_startup_webui_bind_address(missing_path) == BindAddress(host="127.0.0.1", port=8001)
|
|
||||||
|
|
||||||
|
|
||||||
def test_startup_bindings_can_read_addresses_from_bot_config(tmp_path: Path):
|
|
||||||
config_path = tmp_path / "bot_config.toml"
|
|
||||||
config_path.write_text(
|
|
||||||
"""
|
|
||||||
[inner]
|
|
||||||
version = "8.3.1"
|
|
||||||
|
|
||||||
[maim_message]
|
|
||||||
ws_server_host = "0.0.0.0"
|
|
||||||
ws_server_port = 22345
|
|
||||||
|
|
||||||
[webui]
|
|
||||||
host = "192.168.1.9"
|
|
||||||
port = 18001
|
|
||||||
""".strip(),
|
|
||||||
encoding="utf-8",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert get_startup_main_bind_address(config_path) == BindAddress(host="0.0.0.0", port=22345)
|
|
||||||
assert get_startup_webui_bind_address(config_path) == BindAddress(host="192.168.1.9", port=18001)
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_bindings_prefer_initialized_global_config(monkeypatch):
|
|
||||||
fake_config_module = SimpleNamespace(
|
|
||||||
global_config=SimpleNamespace(
|
|
||||||
maim_message=SimpleNamespace(ws_server_host="10.0.0.2", ws_server_port=32000),
|
|
||||||
webui=SimpleNamespace(host="10.0.0.3", port=32001),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setitem(sys.modules, "src.config.config", fake_config_module)
|
|
||||||
|
|
||||||
assert resolve_main_bind_address() == BindAddress(host="10.0.0.2", port=32000)
|
|
||||||
assert resolve_webui_bind_address() == BindAddress(host="10.0.0.3", port=32001)
|
|
||||||
|
|
||||||
|
|
||||||
def test_legacy_env_bindings_are_migrated_when_fields_missing_or_default(monkeypatch):
|
|
||||||
monkeypatch.setenv("HOST", "0.0.0.0")
|
|
||||||
monkeypatch.setenv("PORT", "22345")
|
|
||||||
monkeypatch.setenv("WEBUI_HOST", "192.168.1.8")
|
|
||||||
monkeypatch.setenv("WEBUI_PORT", "19001")
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"maim_message": {
|
|
||||||
"ws_server_host": "127.0.0.1",
|
|
||||||
"ws_server_port": 8080,
|
|
||||||
},
|
|
||||||
"webui": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
result = migrate_legacy_bind_env_to_bot_config_dict(payload)
|
|
||||||
|
|
||||||
assert result.migrated is True
|
|
||||||
assert payload["maim_message"]["ws_server_host"] == "0.0.0.0"
|
|
||||||
assert payload["maim_message"]["ws_server_port"] == 22345
|
|
||||||
assert payload["webui"]["host"] == "192.168.1.8"
|
|
||||||
assert payload["webui"]["port"] == 19001
|
|
||||||
|
|
||||||
|
|
||||||
def test_legacy_env_bindings_do_not_override_explicit_config(monkeypatch):
|
|
||||||
monkeypatch.setenv("HOST", "0.0.0.0")
|
|
||||||
monkeypatch.setenv("PORT", "22345")
|
|
||||||
monkeypatch.setenv("WEBUI_HOST", "192.168.1.8")
|
|
||||||
monkeypatch.setenv("WEBUI_PORT", "19001")
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"maim_message": {
|
|
||||||
"ws_server_host": "10.1.1.1",
|
|
||||||
"ws_server_port": 30000,
|
|
||||||
},
|
|
||||||
"webui": {
|
|
||||||
"host": "10.1.1.2",
|
|
||||||
"port": 30001,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
result = migrate_legacy_bind_env_to_bot_config_dict(payload)
|
|
||||||
|
|
||||||
assert result.migrated is False
|
|
||||||
assert payload["maim_message"]["ws_server_host"] == "10.1.1.1"
|
|
||||||
assert payload["maim_message"]["ws_server_port"] == 30000
|
|
||||||
assert payload["webui"]["host"] == "10.1.1.2"
|
|
||||||
assert payload["webui"]["port"] == 30001
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add project root to Python path so src imports work
|
|
||||||
project_root = Path(__file__).parent.parent.absolute()
|
|
||||||
src_root = project_root / "src"
|
|
||||||
if str(src_root) not in sys.path:
|
|
||||||
sys.path.insert(0, str(src_root))
|
|
||||||
if str(project_root) not in sys.path:
|
|
||||||
sys.path.insert(1, str(project_root))
|
|
||||||
@@ -1,66 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.common.i18n.manager import I18nManager
|
|
||||||
from src.common.i18n.loaders import DuplicateTranslationKeyError, load_locale_catalog
|
|
||||||
|
|
||||||
|
|
||||||
def write_locale_file(locales_root: Path, locale: str, file_name: str, payload: dict[str, object]) -> None:
|
|
||||||
locale_dir = locales_root / locale
|
|
||||||
locale_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
file_path = locale_dir / file_name
|
|
||||||
file_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def test_t_falls_back_to_default_locale(tmp_path: Path) -> None:
|
|
||||||
locales_root = tmp_path / "locales"
|
|
||||||
write_locale_file(locales_root, "zh-CN", "core.json", {"greeting": "你好,{name}"})
|
|
||||||
write_locale_file(locales_root, "en-US", "core.json", {})
|
|
||||||
|
|
||||||
manager = I18nManager(locales_root=locales_root)
|
|
||||||
|
|
||||||
assert manager.t("greeting", locale="en-US", name="Mai") == "你好,Mai"
|
|
||||||
|
|
||||||
|
|
||||||
def test_t_returns_key_when_missing_everywhere(tmp_path: Path) -> None:
|
|
||||||
locales_root = tmp_path / "locales"
|
|
||||||
write_locale_file(locales_root, "zh-CN", "core.json", {})
|
|
||||||
write_locale_file(locales_root, "en-US", "core.json", {})
|
|
||||||
|
|
||||||
manager = I18nManager(locales_root=locales_root)
|
|
||||||
|
|
||||||
assert manager.t("missing.key", locale="en-US") == "missing.key"
|
|
||||||
|
|
||||||
|
|
||||||
def test_tn_uses_plural_rules(tmp_path: Path) -> None:
|
|
||||||
locales_root = tmp_path / "locales"
|
|
||||||
write_locale_file(
|
|
||||||
locales_root,
|
|
||||||
"en-US",
|
|
||||||
"core.json",
|
|
||||||
{
|
|
||||||
"tasks.cancelled": {
|
|
||||||
"one": "Cancelled {count} task",
|
|
||||||
"other": "Cancelled {count} tasks",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
manager = I18nManager(default_locale="en-US", locales_root=locales_root)
|
|
||||||
|
|
||||||
assert manager.tn("tasks.cancelled", 1) == "Cancelled 1 task"
|
|
||||||
assert manager.tn("tasks.cancelled", 2) == "Cancelled 2 tasks"
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_locale_catalog_rejects_duplicate_keys(tmp_path: Path) -> None:
|
|
||||||
locales_root = tmp_path / "locales"
|
|
||||||
write_locale_file(locales_root, "zh-CN", "a.json", {"duplicate.key": "A"})
|
|
||||||
write_locale_file(locales_root, "zh-CN", "b.json", {"duplicate.key": "B"})
|
|
||||||
|
|
||||||
with pytest.raises(DuplicateTranslationKeyError):
|
|
||||||
load_locale_catalog("zh-CN", locales_root)
|
|
||||||
@@ -1,110 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from importlib.util import module_from_spec, spec_from_file_location
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
SCRIPT_PATH = Path(__file__).resolve().parents[2] / "scripts" / "i18n_validate.py"
|
|
||||||
MODULE_SPEC = spec_from_file_location("i18n_validate_script", SCRIPT_PATH)
|
|
||||||
assert MODULE_SPEC is not None
|
|
||||||
assert MODULE_SPEC.loader is not None
|
|
||||||
I18N_VALIDATE = module_from_spec(MODULE_SPEC)
|
|
||||||
MODULE_SPEC.loader.exec_module(I18N_VALIDATE)
|
|
||||||
|
|
||||||
|
|
||||||
def write_locale_file(locales_root: Path, locale: str, file_name: str, payload: dict[str, object]) -> None:
|
|
||||||
locale_dir = locales_root / locale
|
|
||||||
locale_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
(locale_dir / file_name).write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def write_dashboard_locale_file(locales_root: Path, locale: str, payload: dict[str, object]) -> None:
|
|
||||||
locales_root.mkdir(parents=True, exist_ok=True)
|
|
||||||
(locales_root / f"{locale}.json").write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_json_locales_rejects_han_characters_in_english_locale(tmp_path: Path) -> None:
|
|
||||||
locales_root = tmp_path / "locales"
|
|
||||||
write_locale_file(locales_root, "zh-CN", "core.json", {"consent.prompt": '输入"同意"继续'})
|
|
||||||
write_locale_file(locales_root, "en-US", "core.json", {"consent.prompt": 'Type "confirmed" or "同意" to continue'})
|
|
||||||
|
|
||||||
errors = I18N_VALIDATE.validate_json_locales(locales_root)
|
|
||||||
|
|
||||||
assert any("consent.prompt" in error and "仍包含中文字符" in error for error in errors)
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_json_locales_rejects_untranslated_han_source_in_other_target_locales(tmp_path: Path) -> None:
|
|
||||||
locales_root = tmp_path / "locales"
|
|
||||||
write_locale_file(locales_root, "zh-CN", "core.json", {"greeting": "你好,世界"})
|
|
||||||
write_locale_file(locales_root, "ja", "core.json", {"greeting": "你好,世界"})
|
|
||||||
|
|
||||||
errors = I18N_VALIDATE.validate_json_locales(locales_root)
|
|
||||||
|
|
||||||
assert any("greeting" in error and "直接保留了包含中文字符的 source 文案" in error for error in errors)
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_json_locales_avoids_false_positive_when_plural_categories_do_not_align(tmp_path: Path) -> None:
|
|
||||||
locales_root = tmp_path / "locales"
|
|
||||||
write_locale_file(
|
|
||||||
locales_root,
|
|
||||||
"zh-CN",
|
|
||||||
"core.json",
|
|
||||||
{
|
|
||||||
"tasks.cancelled": {
|
|
||||||
"one": "中文单数",
|
|
||||||
"other": "中文复数",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
write_locale_file(
|
|
||||||
locales_root,
|
|
||||||
"ja",
|
|
||||||
"core.json",
|
|
||||||
{
|
|
||||||
"tasks.cancelled": {
|
|
||||||
"many": "中文单数",
|
|
||||||
"other": "已翻译",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
errors = I18N_VALIDATE.validate_json_locales(locales_root)
|
|
||||||
|
|
||||||
assert any("tasks.cancelled" in error and "plural category 不一致" in error for error in errors)
|
|
||||||
assert not any("tasks.cancelled" in error and "直接保留了包含中文字符的 source 文案" in error for error in errors)
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_dashboard_json_locales_rejects_han_characters_in_english_locale(tmp_path: Path) -> None:
|
|
||||||
locales_root = tmp_path / "dashboard-locales"
|
|
||||||
write_dashboard_locale_file(locales_root, "zh", {"common": {"greeting": "你好,世界"}})
|
|
||||||
write_dashboard_locale_file(locales_root, "en", {"common": {"greeting": "Hello 同意"}})
|
|
||||||
|
|
||||||
errors = I18N_VALIDATE.validate_dashboard_json_locales(locales_root)
|
|
||||||
|
|
||||||
assert any("dashboard:en" in error and "common.greeting" in error and "仍包含中文字符" in error for error in errors)
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_dashboard_json_locales_rejects_untranslated_han_source_in_other_target_locales(
|
|
||||||
tmp_path: Path,
|
|
||||||
) -> None:
|
|
||||||
locales_root = tmp_path / "dashboard-locales"
|
|
||||||
write_dashboard_locale_file(locales_root, "zh", {"common": {"greeting": "你好,世界"}})
|
|
||||||
write_dashboard_locale_file(locales_root, "ja", {"common": {"greeting": "你好,世界"}})
|
|
||||||
|
|
||||||
errors = I18N_VALIDATE.validate_dashboard_json_locales(locales_root)
|
|
||||||
|
|
||||||
assert any(
|
|
||||||
"dashboard:ja" in error and "common.greeting" in error and "直接保留了包含中文字符的 source 文案" in error
|
|
||||||
for error in errors
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_dashboard_json_locales_rejects_i18next_placeholder_drift(tmp_path: Path) -> None:
|
|
||||||
locales_root = tmp_path / "dashboard-locales"
|
|
||||||
write_dashboard_locale_file(locales_root, "zh", {"status": {"checkingDesc": "等待服务恢复... ({{current}}/{{max}})"}})
|
|
||||||
write_dashboard_locale_file(locales_root, "ko", {"status": {"checkingDesc": "서비스 복구 대기 중... ({{current}}/{{limit}})"}})
|
|
||||||
|
|
||||||
errors = I18N_VALIDATE.validate_dashboard_json_locales(locales_root)
|
|
||||||
|
|
||||||
assert any("dashboard:ko" in error and "status.checkingDesc" in error and "占位符集合与 source 不一致" in error for error in errors)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,295 +0,0 @@
|
|||||||
import sys
|
|
||||||
import types
|
|
||||||
import importlib
|
|
||||||
import pytest
|
|
||||||
from pathlib import Path
|
|
||||||
import importlib.util
|
|
||||||
|
|
||||||
|
|
||||||
class DummyLogger:
|
|
||||||
def info(self, *a, **k):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def warning(self, *a, **k):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def error(self, *a, **k):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DummySession:
|
|
||||||
def __init__(self):
|
|
||||||
self.record = None
|
|
||||||
|
|
||||||
def exec(self, *a, **k):
|
|
||||||
record = self.record
|
|
||||||
|
|
||||||
class R:
|
|
||||||
def first(self):
|
|
||||||
return record
|
|
||||||
|
|
||||||
def yield_per(self, n):
|
|
||||||
if record is None:
|
|
||||||
return iter(())
|
|
||||||
return iter((record,))
|
|
||||||
|
|
||||||
return R()
|
|
||||||
|
|
||||||
def add(self, record, *a, **k):
|
|
||||||
self.record = record
|
|
||||||
|
|
||||||
def flush(self, *a, **k):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def delete(self, *a, **k):
|
|
||||||
self.record = None
|
|
||||||
|
|
||||||
def expunge(self, *a, **k):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc, tb):
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class DummyMaiImage:
|
|
||||||
def __init__(self, full_path=None, image_bytes=None):
|
|
||||||
self.full_path = full_path
|
|
||||||
self.image_bytes = image_bytes
|
|
||||||
self.file_hash = "dummy-hash"
|
|
||||||
self.image_format = "png"
|
|
||||||
self.description = ""
|
|
||||||
self.vlm_processed = False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_db_instance(cls, record):
|
|
||||||
image = cls(full_path=getattr(record, "full_path", None))
|
|
||||||
image.file_hash = getattr(record, "image_hash", "dummy-hash")
|
|
||||||
image.description = getattr(record, "description", "")
|
|
||||||
image.vlm_processed = getattr(record, "vlm_processed", False)
|
|
||||||
return image
|
|
||||||
|
|
||||||
def to_db_instance(self):
|
|
||||||
return types.SimpleNamespace(
|
|
||||||
description=self.description,
|
|
||||||
full_path=str(self.full_path) if self.full_path is not None else "",
|
|
||||||
id=1,
|
|
||||||
image_hash=self.file_hash,
|
|
||||||
image_type="image",
|
|
||||||
last_used_time=None,
|
|
||||||
no_file_flag=False,
|
|
||||||
query_count=0,
|
|
||||||
register_time=None,
|
|
||||||
vlm_processed=self.vlm_processed,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def calculate_hash_format(self):
|
|
||||||
self.file_hash = "dummy-hash"
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class DummyLLMRequest:
|
|
||||||
def __init__(self, *a, **k):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def generate_response_for_image(self, prompt, image_base64, image_format, temp):
|
|
||||||
return ("dummy description", {})
|
|
||||||
|
|
||||||
|
|
||||||
class DummyLLMServiceClient:
|
|
||||||
def __init__(self, *a, **k):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def generate_response_for_image(self, prompt, image_base64, image_format, options=None):
|
|
||||||
return types.SimpleNamespace(response="dummy description")
|
|
||||||
|
|
||||||
|
|
||||||
class DummySelect:
|
|
||||||
def __init__(self, *a, **k):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def filter_by(self, *a, **k):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def limit(self, n):
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class DetachedRecord:
|
|
||||||
def __init__(self, description="cached description", vlm_processed=True):
|
|
||||||
self._detached = False
|
|
||||||
self._description = description
|
|
||||||
self._vlm_processed = vlm_processed
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self):
|
|
||||||
if not self._detached:
|
|
||||||
raise RuntimeError("attribute refresh operation cannot proceed")
|
|
||||||
return self._description
|
|
||||||
|
|
||||||
@property
|
|
||||||
def vlm_processed(self):
|
|
||||||
if not self._detached:
|
|
||||||
raise RuntimeError("attribute refresh operation cannot proceed")
|
|
||||||
return self._vlm_processed
|
|
||||||
|
|
||||||
|
|
||||||
class DetachedRecordSession(DummySession):
|
|
||||||
def __init__(self, record):
|
|
||||||
self.record = record
|
|
||||||
|
|
||||||
def exec(self, *a, **k):
|
|
||||||
record = self.record
|
|
||||||
|
|
||||||
class R:
|
|
||||||
def first(self):
|
|
||||||
return record
|
|
||||||
|
|
||||||
return R()
|
|
||||||
|
|
||||||
def expunge(self, record):
|
|
||||||
record._detached = True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def patch_external_dependencies(monkeypatch):
|
|
||||||
# Provide dummy implementations as modules so that importing image_manager is safe
|
|
||||||
# Patch LLMRequest
|
|
||||||
llm_mod = types.SimpleNamespace(LLMRequest=DummyLLMRequest)
|
|
||||||
monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_mod)
|
|
||||||
llm_service_mod = types.SimpleNamespace(LLMServiceClient=DummyLLMServiceClient)
|
|
||||||
monkeypatch.setitem(sys.modules, "src.services.llm_service", llm_service_mod)
|
|
||||||
|
|
||||||
# Patch logger
|
|
||||||
logger_mod = types.SimpleNamespace(get_logger=lambda name: DummyLogger())
|
|
||||||
monkeypatch.setitem(sys.modules, "src.common.logger", logger_mod)
|
|
||||||
|
|
||||||
# Patch DB session provider
|
|
||||||
shared_session = DummySession()
|
|
||||||
db_mod = types.SimpleNamespace(get_db_session=lambda: shared_session)
|
|
||||||
monkeypatch.setitem(sys.modules, "src.common.database.database", db_mod)
|
|
||||||
|
|
||||||
# Patch database model types
|
|
||||||
db_model_mod = types.SimpleNamespace(Images=types.SimpleNamespace, ImageType=types.SimpleNamespace(IMAGE="image"))
|
|
||||||
monkeypatch.setitem(sys.modules, "src.common.database.database_model", db_model_mod)
|
|
||||||
|
|
||||||
# Patch MaiImage data model
|
|
||||||
data_model_mod = types.SimpleNamespace(MaiImage=DummyMaiImage)
|
|
||||||
monkeypatch.setitem(sys.modules, "src.common.data_models.image_data_model", data_model_mod)
|
|
||||||
|
|
||||||
# Patch SQLModel select function
|
|
||||||
sql_mod = types.SimpleNamespace(select=lambda *a, **k: DummySelect())
|
|
||||||
monkeypatch.setitem(sys.modules, "sqlmodel", sql_mod)
|
|
||||||
|
|
||||||
# Patch prompt manager used to build image description prompt.
|
|
||||||
class _PromptManager:
|
|
||||||
def get_prompt(self, _name):
|
|
||||||
return types.SimpleNamespace()
|
|
||||||
|
|
||||||
async def render_prompt(self, _prompt):
|
|
||||||
return "test-style"
|
|
||||||
|
|
||||||
prompt_manager_mod = types.SimpleNamespace(prompt_manager=_PromptManager())
|
|
||||||
monkeypatch.setitem(sys.modules, "src.prompt.prompt_manager", prompt_manager_mod)
|
|
||||||
|
|
||||||
llm_options_mod = types.SimpleNamespace(LLMImageOptions=lambda **kwargs: types.SimpleNamespace(**kwargs))
|
|
||||||
monkeypatch.setitem(sys.modules, "src.common.data_models.llm_service_data_models", llm_options_mod)
|
|
||||||
|
|
||||||
# If module already imported, reload it to apply patches
|
|
||||||
mod_name = "src.chat.image_system.image_manager"
|
|
||||||
if mod_name in sys.modules:
|
|
||||||
importlib.reload(sys.modules[mod_name])
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
def _load_image_manager_module(tmp_path=None):
|
|
||||||
repo_root = Path(__file__).parent.parent.parent
|
|
||||||
file_path = repo_root / "src" / "chat" / "image_system" / "image_manager.py"
|
|
||||||
spec = importlib.util.spec_from_file_location("image_manager_test_loaded", str(file_path))
|
|
||||||
mod = importlib.util.module_from_spec(spec)
|
|
||||||
sys.modules[spec.name] = mod
|
|
||||||
spec.loader.exec_module(mod)
|
|
||||||
# Redirect IMAGE_DIR to pytest's tmp_path when provided
|
|
||||||
try:
|
|
||||||
if tmp_path is not None:
|
|
||||||
tmpdir = Path(tmp_path)
|
|
||||||
tmpdir.mkdir(parents=True, exist_ok=True)
|
|
||||||
mod.IMAGE_DIR = tmpdir
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return mod
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_image_description_generates(tmp_path):
|
|
||||||
image_manager = _load_image_manager_module(tmp_path)
|
|
||||||
|
|
||||||
mgr = image_manager.ImageManager()
|
|
||||||
desc = await mgr.get_image_description(image_bytes=b"abc")
|
|
||||||
assert desc == "dummy description"
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_image_from_db_none(tmp_path):
|
|
||||||
image_manager = _load_image_manager_module(tmp_path)
|
|
||||||
|
|
||||||
mgr = image_manager.ImageManager()
|
|
||||||
assert mgr.get_image_from_db("nohash") is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_register_image_to_db(tmp_path):
|
|
||||||
image_manager = _load_image_manager_module(tmp_path)
|
|
||||||
|
|
||||||
mgr = image_manager.ImageManager()
|
|
||||||
p = tmp_path / "img.png"
|
|
||||||
p.write_bytes(b"data")
|
|
||||||
img = DummyMaiImage(full_path=p, image_bytes=b"data")
|
|
||||||
assert mgr.register_image_to_db(img) is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_image_description_not_found(tmp_path):
|
|
||||||
image_manager = _load_image_manager_module(tmp_path)
|
|
||||||
|
|
||||||
mgr = image_manager.ImageManager()
|
|
||||||
img = DummyMaiImage()
|
|
||||||
img.file_hash = "nohash"
|
|
||||||
img.description = "desc"
|
|
||||||
assert mgr.update_image_description(img) is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_image_not_found(tmp_path):
|
|
||||||
image_manager = _load_image_manager_module(tmp_path)
|
|
||||||
|
|
||||||
mgr = image_manager.ImageManager()
|
|
||||||
img = DummyMaiImage()
|
|
||||||
img.file_hash = "nohash"
|
|
||||||
img.full_path = tmp_path = None
|
|
||||||
assert mgr.delete_image(img) is False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_save_image_and_process_and_cleanup(tmp_path):
|
|
||||||
image_manager = _load_image_manager_module(tmp_path)
|
|
||||||
|
|
||||||
mgr = image_manager.ImageManager()
|
|
||||||
# call save_image_and_process
|
|
||||||
image = await mgr.save_image_and_process(b"binarydata")
|
|
||||||
assert getattr(image, "description", None) == "dummy description"
|
|
||||||
|
|
||||||
# cleanup should run without error
|
|
||||||
mgr.cleanup_invalid_descriptions_in_db()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_image_description_returns_cached_description_after_session_closed(monkeypatch, tmp_path):
|
|
||||||
image_manager = _load_image_manager_module(tmp_path)
|
|
||||||
|
|
||||||
cached_record = DetachedRecord()
|
|
||||||
monkeypatch.setattr(image_manager, "get_db_session", lambda: DetachedRecordSession(cached_record))
|
|
||||||
|
|
||||||
mgr = image_manager.ImageManager()
|
|
||||||
desc = await mgr.get_image_description(image_hash="cached-hash", wait_for_build=False)
|
|
||||||
|
|
||||||
assert desc == "cached description"
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import io
|
|
||||||
|
|
||||||
from PIL import Image as PILImage
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.common.data_models.image_data_model import MaiEmoji, MaiImage
|
|
||||||
|
|
||||||
|
|
||||||
def _build_test_image_bytes(image_format: str) -> bytes:
|
|
||||||
image = PILImage.new("RGB", (8, 8), color="white")
|
|
||||||
buffer = io.BytesIO()
|
|
||||||
image.save(buffer, format=image_format)
|
|
||||||
return buffer.getvalue()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_calculate_hash_format_updates_runtime_path_metadata(tmp_path: Path) -> None:
|
|
||||||
image_bytes = _build_test_image_bytes("JPEG")
|
|
||||||
tmp_file_path = tmp_path / "emoji.tmp"
|
|
||||||
tmp_file_path.write_bytes(image_bytes)
|
|
||||||
|
|
||||||
emoji = MaiEmoji(full_path=tmp_file_path, image_bytes=image_bytes)
|
|
||||||
|
|
||||||
assert await emoji.calculate_hash_format() is True
|
|
||||||
assert emoji.image_format == "jpeg"
|
|
||||||
assert emoji.full_path.suffix == ".jpeg"
|
|
||||||
assert emoji.file_name == emoji.full_path.name
|
|
||||||
assert emoji.dir_path == tmp_path.resolve()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_calculate_hash_format_reuses_existing_target_file(tmp_path: Path) -> None:
|
|
||||||
image_bytes = _build_test_image_bytes("JPEG")
|
|
||||||
tmp_file_path = tmp_path / "emoji.tmp"
|
|
||||||
target_file_path = tmp_path / "emoji.jpeg"
|
|
||||||
tmp_file_path.write_bytes(image_bytes)
|
|
||||||
target_file_path.write_bytes(image_bytes)
|
|
||||||
|
|
||||||
emoji = MaiEmoji(full_path=tmp_file_path, image_bytes=image_bytes)
|
|
||||||
|
|
||||||
assert await emoji.calculate_hash_format() is True
|
|
||||||
assert emoji.full_path == target_file_path.resolve()
|
|
||||||
assert emoji.file_name == target_file_path.name
|
|
||||||
assert not tmp_file_path.exists()
|
|
||||||
assert target_file_path.exists()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("model_cls", "extra_fields"),
|
|
||||||
[
|
|
||||||
(
|
|
||||||
MaiEmoji,
|
|
||||||
{
|
|
||||||
"description": "",
|
|
||||||
"last_used_time": None,
|
|
||||||
"query_count": 0,
|
|
||||||
"register_time": None,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
(
|
|
||||||
MaiImage,
|
|
||||||
{
|
|
||||||
"description": "",
|
|
||||||
"vlm_processed": False,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_from_db_instance_restores_image_format_from_path(
|
|
||||||
tmp_path: Path,
|
|
||||||
model_cls: type[MaiEmoji] | type[MaiImage],
|
|
||||||
extra_fields: dict[str, object],
|
|
||||||
) -> None:
|
|
||||||
image_path = tmp_path / "cached.png"
|
|
||||||
image_path.write_bytes(_build_test_image_bytes("PNG"))
|
|
||||||
|
|
||||||
record = SimpleNamespace(
|
|
||||||
no_file_flag=False,
|
|
||||||
image_hash="hash",
|
|
||||||
full_path=str(image_path),
|
|
||||||
**extra_fields,
|
|
||||||
)
|
|
||||||
|
|
||||||
image = model_cls.from_db_instance(record)
|
|
||||||
|
|
||||||
assert image.full_path == image_path.resolve()
|
|
||||||
assert image.file_name == image_path.name
|
|
||||||
assert image.image_format == "png"
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
class MyLogger:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def info(self, msg):
|
|
||||||
print(f"INFO: {msg}")
|
|
||||||
|
|
||||||
def error(self, msg):
|
|
||||||
print(f"ERROR: {msg}")
|
|
||||||
|
|
||||||
def debug(self, msg):
|
|
||||||
print(f"DEBUG: {msg}")
|
|
||||||
|
|
||||||
def warning(self, msg):
|
|
||||||
print(f"WARNING: {msg}")
|
|
||||||
|
|
||||||
def trace(self, msg):
|
|
||||||
print(f"TRACE: {msg}")
|
|
||||||
|
|
||||||
|
|
||||||
def get_logger(*args, **kwargs):
|
|
||||||
return MyLogger()
|
|
||||||
@@ -1,422 +0,0 @@
|
|||||||
import sys
|
|
||||||
import asyncio
|
|
||||||
import pytest
|
|
||||||
import importlib
|
|
||||||
import importlib.util
|
|
||||||
from types import ModuleType
|
|
||||||
from pathlib import Path
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from src.common.data_models.message_component_data_model import MessageSequence, ForwardComponent
|
|
||||||
from src.chat.message_receive.message import (
|
|
||||||
SessionMessage,
|
|
||||||
TextComponent,
|
|
||||||
ImageComponent,
|
|
||||||
EmojiComponent,
|
|
||||||
VoiceComponent,
|
|
||||||
AtComponent,
|
|
||||||
ReplyComponent,
|
|
||||||
ForwardNodeComponent,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DummyLogger:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.logging_record = []
|
|
||||||
|
|
||||||
def debug(self, msg):
|
|
||||||
print(f"DEBUG: {msg}")
|
|
||||||
self.logging_record.append(f"DEBUG: {msg}")
|
|
||||||
|
|
||||||
def info(self, msg):
|
|
||||||
print(f"INFO: {msg}")
|
|
||||||
self.logging_record.append(f"INFO: {msg}")
|
|
||||||
|
|
||||||
def warning(self, msg):
|
|
||||||
print(f"WARNING: {msg}")
|
|
||||||
self.logging_record.append(f"WARNING: {msg}")
|
|
||||||
|
|
||||||
def error(self, msg):
|
|
||||||
print(f"ERROR: {msg}")
|
|
||||||
self.logging_record.append(f"ERROR: {msg}")
|
|
||||||
|
|
||||||
def critical(self, msg):
|
|
||||||
print(f"CRITICAL: {msg}")
|
|
||||||
self.logging_record.append(f"CRITICAL: {msg}")
|
|
||||||
|
|
||||||
|
|
||||||
def get_logger(name):
|
|
||||||
return DummyLogger()
|
|
||||||
|
|
||||||
|
|
||||||
class DummyDBSession:
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def exec(self, statement):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def first(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def commit(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def all(self):
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def get_db_session():
|
|
||||||
return DummyDBSession()
|
|
||||||
|
|
||||||
|
|
||||||
def get_manual_db_session():
|
|
||||||
return DummyDBSession()
|
|
||||||
|
|
||||||
|
|
||||||
class DummySelect:
|
|
||||||
def __init__(self, model):
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
def filter_by(self, **kwargs):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def where(self, condition):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def limit(self, n):
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
def select(model):
|
|
||||||
return DummySelect(model)
|
|
||||||
|
|
||||||
|
|
||||||
async def dummy_get_voice_text(binary_data):
|
|
||||||
return None # 可以根据需要返回模拟的文本结果
|
|
||||||
|
|
||||||
|
|
||||||
class DummyPersonUtils:
|
|
||||||
@staticmethod
|
|
||||||
def get_person_info_by_user_id_and_platform(user_id, platform):
|
|
||||||
return None # 可以根据需要返回模拟的用户信息
|
|
||||||
|
|
||||||
|
|
||||||
def setup_mocks(monkeypatch):
|
|
||||||
def _stub_module(name: str) -> ModuleType:
|
|
||||||
module = ModuleType(name)
|
|
||||||
monkeypatch.setitem(sys.modules, name, module)
|
|
||||||
return module
|
|
||||||
|
|
||||||
# src.common.logger
|
|
||||||
logger_mod = _stub_module("src.common.logger")
|
|
||||||
# Mock the logger
|
|
||||||
logger_mod.get_logger = get_logger
|
|
||||||
|
|
||||||
db_mod = _stub_module("src.common.database.database")
|
|
||||||
db_mod.get_db_session = get_db_session
|
|
||||||
db_mod.get_manual_db_session = get_manual_db_session
|
|
||||||
|
|
||||||
db_model_mod = _stub_module("src.common.database.database_model")
|
|
||||||
db_model_mod.Messages = None # 可以根据需要添加更多的属性或方法
|
|
||||||
|
|
||||||
emoji_manager_mod = _stub_module("src.emoji_system.emoji_manager")
|
|
||||||
emoji_manager_mod.emoji_manager = None # 可以根据需要添加更多的属性或方法
|
|
||||||
|
|
||||||
image_manager_mod = _stub_module("src.chat.image_system.image_manager")
|
|
||||||
image_manager_mod.image_manager = None # 可以根据需要添加更多的属性或方法
|
|
||||||
|
|
||||||
msg_utils_mod = _stub_module("src.common.utils.utils_message")
|
|
||||||
msg_utils_mod.MessageUtils = None # 可以根据需要添加更多的属性或方法
|
|
||||||
|
|
||||||
voice_utils_mod = _stub_module("src.common.utils.utils_voice")
|
|
||||||
voice_utils_mod.get_voice_text = dummy_get_voice_text
|
|
||||||
|
|
||||||
person_utils_mod = _stub_module("src.common.utils.utils_person")
|
|
||||||
person_utils_mod.PersonUtils = DummyPersonUtils
|
|
||||||
|
|
||||||
|
|
||||||
def load_message_via_file(monkeypatch):
|
|
||||||
setup_mocks(monkeypatch)
|
|
||||||
file_path = Path(__file__).parent.parent.parent / "src" / "chat" / "message_receive" / "message.py"
|
|
||||||
spec = importlib.util.spec_from_file_location("message", file_path)
|
|
||||||
message_module = importlib.util.module_from_spec(spec)
|
|
||||||
monkeypatch.setitem(sys.modules, "message_module", message_module)
|
|
||||||
spec.loader.exec_module(message_module)
|
|
||||||
message_module.select = select
|
|
||||||
SessionMessageClass = message_module.SessionMessage
|
|
||||||
TextComponentClass = message_module.TextComponent
|
|
||||||
ImageComponentClass = message_module.ImageComponent
|
|
||||||
EmojiComponentClass = message_module.EmojiComponent
|
|
||||||
VoiceComponentClass = message_module.VoiceComponent
|
|
||||||
AtComponentClass = message_module.AtComponent
|
|
||||||
ReplyComponentClass = message_module.ReplyComponent
|
|
||||||
ForwardNodeComponentClass = message_module.ForwardNodeComponent
|
|
||||||
MessageSequenceClass = sys.modules["src.common.data_models.message_component_data_model"].MessageSequence
|
|
||||||
ForwardComponentClass = sys.modules["src.common.data_models.message_component_data_model"].ForwardComponent
|
|
||||||
globals()["SessionMessage"] = SessionMessageClass
|
|
||||||
globals()["TextComponent"] = TextComponentClass
|
|
||||||
globals()["ImageComponent"] = ImageComponentClass
|
|
||||||
globals()["EmojiComponent"] = EmojiComponentClass
|
|
||||||
globals()["VoiceComponent"] = VoiceComponentClass
|
|
||||||
globals()["AtComponent"] = AtComponentClass
|
|
||||||
globals()["ReplyComponent"] = ReplyComponentClass
|
|
||||||
globals()["ForwardNodeComponent"] = ForwardNodeComponentClass
|
|
||||||
globals()["MessageSequence"] = MessageSequenceClass
|
|
||||||
globals()["ForwardComponent"] = ForwardComponentClass
|
|
||||||
return message_module
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_process(monkeypatch):
|
|
||||||
load_message_via_file(monkeypatch)
|
|
||||||
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
|
|
||||||
msg.session_id = "session123"
|
|
||||||
msg.raw_message = MessageSequence(components=[])
|
|
||||||
msg.raw_message.components = [TextComponent("Hello, world!")]
|
|
||||||
await msg.process()
|
|
||||||
assert msg.processed_plain_text == "Hello, world!"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_multiple_text(monkeypatch):
|
|
||||||
load_message_via_file(monkeypatch)
|
|
||||||
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
|
|
||||||
msg.session_id = "session123"
|
|
||||||
msg.raw_message = MessageSequence(components=[])
|
|
||||||
msg.raw_message.components = [TextComponent("Hello,"), TextComponent("world!")]
|
|
||||||
await msg.process()
|
|
||||||
assert msg.processed_plain_text == "Hello, world!"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_image(monkeypatch):
|
|
||||||
load_message_via_file(monkeypatch)
|
|
||||||
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
|
|
||||||
msg.session_id = "session123"
|
|
||||||
msg.raw_message = MessageSequence(components=[])
|
|
||||||
msg.raw_message.components = [ImageComponent(binary_hash="image_hash"), TextComponent("Hello, world!")]
|
|
||||||
await msg.process()
|
|
||||||
assert msg.processed_plain_text == "[一张图片,网卡了加载不出来] Hello, world!"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_emoji(monkeypatch):
|
|
||||||
load_message_via_file(monkeypatch)
|
|
||||||
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
|
|
||||||
msg.session_id = "session123"
|
|
||||||
msg.raw_message = MessageSequence(components=[])
|
|
||||||
msg.raw_message.components = [EmojiComponent(binary_hash="emoji_hash"), TextComponent("Hello, world!")]
|
|
||||||
await msg.process()
|
|
||||||
assert msg.processed_plain_text == "[一个表情,网卡了加载不出来] Hello, world!"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_voice(monkeypatch):
|
|
||||||
load_message_via_file(monkeypatch)
|
|
||||||
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
|
|
||||||
msg.session_id = "session123"
|
|
||||||
msg.raw_message = MessageSequence(components=[])
|
|
||||||
msg.raw_message.components = [VoiceComponent(binary_hash="voice_hash"), TextComponent("Hello, world!")]
|
|
||||||
await msg.process()
|
|
||||||
assert msg.processed_plain_text == "[语音消息,转录失败] Hello, world!"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_at_component(monkeypatch):
|
|
||||||
load_message_via_file(monkeypatch)
|
|
||||||
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
|
|
||||||
msg.session_id = "session123"
|
|
||||||
msg.platform = "test_platform"
|
|
||||||
msg.raw_message = MessageSequence(components=[])
|
|
||||||
msg.raw_message.components = [AtComponent(target_user_id="114514"), TextComponent("Hello, world!")]
|
|
||||||
await msg.process()
|
|
||||||
assert msg.processed_plain_text == "@114514 Hello, world!"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_reply_component_fail_to_fetch(monkeypatch):
|
|
||||||
load_message_via_file(monkeypatch)
|
|
||||||
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
|
|
||||||
msg.session_id = "session123"
|
|
||||||
msg.platform = "test_platform"
|
|
||||||
msg.raw_message = MessageSequence(components=[])
|
|
||||||
msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")]
|
|
||||||
await msg.process()
|
|
||||||
assert msg.processed_plain_text == "[回复了一条消息,但原消息已无法访问] Hello, world!"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_reply_component_success(monkeypatch):
|
|
||||||
module_msg = load_message_via_file(monkeypatch)
|
|
||||||
|
|
||||||
class DummyDBSessionWithReply(DummyDBSession):
|
|
||||||
def exec(self, s):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def first(inner_self):
|
|
||||||
class DummyRecord:
|
|
||||||
processed_plain_text = "原消息内容"
|
|
||||||
user_cardname = "cardname123"
|
|
||||||
user_nickname = "nickname123"
|
|
||||||
user_id = "userid123"
|
|
||||||
|
|
||||||
return DummyRecord()
|
|
||||||
|
|
||||||
module_msg.get_db_session = lambda: DummyDBSessionWithReply()
|
|
||||||
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
|
|
||||||
msg.session_id = "session123"
|
|
||||||
msg.platform = "test_platform"
|
|
||||||
msg.raw_message = MessageSequence(components=[])
|
|
||||||
msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")]
|
|
||||||
await msg.process()
|
|
||||||
assert msg.processed_plain_text == "[回复了cardname123的消息: 原消息内容] Hello, world!"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_reply_component_with_db_fail(monkeypatch):
|
|
||||||
module_msg = load_message_via_file(monkeypatch)
|
|
||||||
|
|
||||||
class DummyDBSessionWithError(DummyDBSession):
|
|
||||||
def exec(self, s):
|
|
||||||
raise Exception("数据库查询失败")
|
|
||||||
|
|
||||||
module_msg.get_db_session = lambda: DummyDBSessionWithError()
|
|
||||||
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
|
|
||||||
msg.session_id = "session123"
|
|
||||||
msg.platform = "test_platform"
|
|
||||||
msg.raw_message = MessageSequence(components=[])
|
|
||||||
msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")]
|
|
||||||
await msg.process()
|
|
||||||
assert msg.processed_plain_text == "[回复了一条消息,但原消息已无法访问] Hello, world!"
|
|
||||||
assert any("数据库查询失败" in log for log in module_msg.logger.logging_record)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_forward_component(monkeypatch):
|
|
||||||
load_message_via_file(monkeypatch)
|
|
||||||
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
|
|
||||||
msg.session_id = "session123"
|
|
||||||
msg.platform = "test_platform"
|
|
||||||
msg.raw_message = MessageSequence(components=[])
|
|
||||||
msg.raw_message.components = [
|
|
||||||
ForwardNodeComponent(
|
|
||||||
forward_components=[
|
|
||||||
ForwardComponent(
|
|
||||||
message_id="msg1",
|
|
||||||
user_id="user1",
|
|
||||||
user_nickname="nickname1",
|
|
||||||
user_cardname="cardname1",
|
|
||||||
content=[TextComponent("转发消息1")],
|
|
||||||
),
|
|
||||||
ForwardComponent(
|
|
||||||
message_id="msg2",
|
|
||||||
user_id="user2",
|
|
||||||
user_nickname="nickname2",
|
|
||||||
user_cardname="cardname2",
|
|
||||||
content=[TextComponent("转发消息2")],
|
|
||||||
),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
TextComponent("Hello, world!"),
|
|
||||||
]
|
|
||||||
await msg.process()
|
|
||||||
print("Processed plain text:", msg.processed_plain_text)
|
|
||||||
expected_forward_text = """【合并转发消息:
|
|
||||||
-- 【cardname1】: 转发消息1
|
|
||||||
-- 【cardname2】: 转发消息2
|
|
||||||
】 Hello, world!"""
|
|
||||||
assert msg.processed_plain_text == expected_forward_text
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_forward_with_reply(monkeypatch):
|
|
||||||
load_message_via_file(monkeypatch)
|
|
||||||
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
|
|
||||||
msg.session_id = "session123"
|
|
||||||
msg.platform = "test_platform"
|
|
||||||
msg.raw_message = MessageSequence(components=[])
|
|
||||||
msg.raw_message.components = [
|
|
||||||
ForwardNodeComponent(
|
|
||||||
forward_components=[
|
|
||||||
ForwardComponent(
|
|
||||||
message_id="msg1",
|
|
||||||
user_id="user1",
|
|
||||||
user_nickname="nickname1",
|
|
||||||
user_cardname="cardname1",
|
|
||||||
content=[TextComponent("转发消息1")],
|
|
||||||
),
|
|
||||||
ForwardComponent(
|
|
||||||
message_id="msg2",
|
|
||||||
user_id="user2",
|
|
||||||
user_nickname="nickname2",
|
|
||||||
user_cardname="cardname2",
|
|
||||||
content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息2")],
|
|
||||||
),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
TextComponent("Hello, world!"),
|
|
||||||
]
|
|
||||||
await msg.process()
|
|
||||||
assert (
|
|
||||||
msg.processed_plain_text
|
|
||||||
== """【合并转发消息:
|
|
||||||
-- 【cardname1】: 转发消息1
|
|
||||||
-- 【cardname2】: [回复了cardname1的消息: 转发消息1] 转发消息2
|
|
||||||
】 Hello, world!"""
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_multiple_reply_with_delay_in_forward(monkeypatch):
|
|
||||||
load_message_via_file(monkeypatch)
|
|
||||||
msg = SessionMessage("msg123", datetime.now(), platform="test_platform")
|
|
||||||
msg.session_id = "session123"
|
|
||||||
msg.platform = "test_platform"
|
|
||||||
msg.raw_message = MessageSequence(components=[])
|
|
||||||
|
|
||||||
async def delayed_get_voice_text(binary_data):
|
|
||||||
await asyncio.sleep(0.5) # 模拟延迟
|
|
||||||
return "这是语音转文本的结果"
|
|
||||||
|
|
||||||
sys.modules["src.common.utils.utils_voice"].get_voice_text = delayed_get_voice_text
|
|
||||||
|
|
||||||
msg.raw_message.components = [
|
|
||||||
ForwardNodeComponent(
|
|
||||||
forward_components=[
|
|
||||||
ForwardComponent(
|
|
||||||
message_id="msg1",
|
|
||||||
user_id="user1",
|
|
||||||
user_nickname="nickname1",
|
|
||||||
user_cardname="cardname1",
|
|
||||||
content=[VoiceComponent(binary_hash="voice_hash1"), TextComponent("转发消息1")],
|
|
||||||
),
|
|
||||||
ForwardComponent(
|
|
||||||
message_id="msg2",
|
|
||||||
user_id="user2",
|
|
||||||
user_nickname="nickname2",
|
|
||||||
user_cardname="cardname2",
|
|
||||||
content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息2")],
|
|
||||||
),
|
|
||||||
ForwardComponent(
|
|
||||||
message_id="msg3",
|
|
||||||
user_id="user3",
|
|
||||||
user_nickname="nickname3",
|
|
||||||
user_cardname="cardname3",
|
|
||||||
content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息3")],
|
|
||||||
),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
]
|
|
||||||
await msg.process()
|
|
||||||
expected_text = """【合并转发消息:
|
|
||||||
-- 【cardname1】: [语音: 这是语音转文本的结果] 转发消息1
|
|
||||||
-- 【cardname2】: [回复了cardname1的消息: [语音: 这是语音转文本的结果] 转发消息1] 转发消息2
|
|
||||||
-- 【cardname3】: [回复了cardname1的消息: [语音: 这是语音转文本的结果] 转发消息1] 转发消息3
|
|
||||||
】"""
|
|
||||||
assert msg.processed_plain_text == expected_text
|
|
||||||
@@ -1,220 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.common.i18n import set_locale
|
|
||||||
from src.common.prompt_i18n import clear_prompt_cache, load_prompt, list_prompt_templates
|
|
||||||
from src.prompt.prompt_manager import PromptManager
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def clear_prompt_i18n_cache() -> None:
|
|
||||||
set_locale("zh-CN")
|
|
||||||
clear_prompt_cache()
|
|
||||||
yield
|
|
||||||
clear_prompt_cache()
|
|
||||||
set_locale("zh-CN")
|
|
||||||
|
|
||||||
|
|
||||||
def write_prompt(prompt_dir: Path, locale: str | None, name: str, content: str) -> None:
|
|
||||||
base_dir = prompt_dir if locale is None else prompt_dir / locale
|
|
||||||
base_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
(base_dir / f"{name}.prompt").write_text(content, encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_prompt_prefers_requested_locale(tmp_path: Path) -> None:
|
|
||||||
prompts_root = tmp_path / "prompts"
|
|
||||||
write_prompt(prompts_root, "zh-CN", "replyer", "你好,{user_name}")
|
|
||||||
write_prompt(prompts_root, "en-US", "replyer", "Hello, {user_name}")
|
|
||||||
|
|
||||||
rendered = load_prompt("replyer", locale="en-US", prompts_root=prompts_root, user_name="Mai")
|
|
||||||
|
|
||||||
assert rendered == "Hello, Mai"
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_prompt_falls_back_to_default_locale(tmp_path: Path) -> None:
|
|
||||||
prompts_root = tmp_path / "prompts"
|
|
||||||
write_prompt(prompts_root, "zh-CN", "replyer", "你好,{user_name}")
|
|
||||||
|
|
||||||
rendered = load_prompt("replyer", locale="en-US", prompts_root=prompts_root, user_name="Mai")
|
|
||||||
|
|
||||||
assert rendered == "你好,Mai"
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_prompt_does_not_fall_back_to_legacy_root(tmp_path: Path) -> None:
|
|
||||||
prompts_root = tmp_path / "prompts"
|
|
||||||
write_prompt(prompts_root, None, "replyer", "Legacy {user_name}")
|
|
||||||
|
|
||||||
with pytest.raises(FileNotFoundError):
|
|
||||||
load_prompt("replyer", locale="en-US", prompts_root=prompts_root, user_name="Mai")
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_prompt_with_category_falls_back_to_default_locale_root(tmp_path: Path) -> None:
|
|
||||||
prompts_root = tmp_path / "prompts"
|
|
||||||
write_prompt(prompts_root, "zh-CN", "replyer", "你好,{user_name}")
|
|
||||||
|
|
||||||
rendered = load_prompt("replyer", locale="en-US", category="chat", prompts_root=prompts_root, user_name="Mai")
|
|
||||||
|
|
||||||
assert rendered == "你好,Mai"
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_prompt_prefers_custom_prompt_override(tmp_path: Path) -> None:
|
|
||||||
prompts_root = tmp_path / "prompts"
|
|
||||||
custom_prompts_root = tmp_path / "data" / "custom_prompts"
|
|
||||||
write_prompt(prompts_root, "zh-CN", "replyer", "Base {user_name}")
|
|
||||||
write_prompt(custom_prompts_root, "zh-CN", "replyer", "Custom {user_name}")
|
|
||||||
|
|
||||||
rendered = load_prompt(
|
|
||||||
"replyer",
|
|
||||||
locale="zh-CN",
|
|
||||||
prompts_root=prompts_root,
|
|
||||||
custom_prompts_root=custom_prompts_root,
|
|
||||||
user_name="Mai",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert rendered == "Custom Mai"
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_prompt_prefers_custom_prompt_requested_locale(tmp_path: Path) -> None:
|
|
||||||
prompts_root = tmp_path / "prompts"
|
|
||||||
custom_prompts_root = tmp_path / "data" / "custom_prompts"
|
|
||||||
write_prompt(prompts_root, "zh-CN", "replyer", "Base zh {user_name}")
|
|
||||||
write_prompt(prompts_root, "en-US", "replyer", "Base en {user_name}")
|
|
||||||
write_prompt(custom_prompts_root, "zh-CN", "replyer", "Custom zh {user_name}")
|
|
||||||
write_prompt(custom_prompts_root, "en-US", "replyer", "Custom en {user_name}")
|
|
||||||
|
|
||||||
rendered = load_prompt(
|
|
||||||
"replyer",
|
|
||||||
locale="en-US",
|
|
||||||
prompts_root=prompts_root,
|
|
||||||
custom_prompts_root=custom_prompts_root,
|
|
||||||
user_name="Mai",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert rendered == "Custom en Mai"
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_prompt_uses_requested_locale_source_before_default_custom(tmp_path: Path) -> None:
|
|
||||||
prompts_root = tmp_path / "prompts"
|
|
||||||
custom_prompts_root = tmp_path / "data" / "custom_prompts"
|
|
||||||
write_prompt(prompts_root, "zh-CN", "replyer", "Base zh {user_name}")
|
|
||||||
write_prompt(prompts_root, "en-US", "replyer", "Base en {user_name}")
|
|
||||||
write_prompt(custom_prompts_root, "zh-CN", "replyer", "Custom zh {user_name}")
|
|
||||||
|
|
||||||
rendered = load_prompt(
|
|
||||||
"replyer",
|
|
||||||
locale="en-US",
|
|
||||||
prompts_root=prompts_root,
|
|
||||||
custom_prompts_root=custom_prompts_root,
|
|
||||||
user_name="Mai",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert rendered == "Base en Mai"
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_prompt_strict_mode_raises_on_missing_placeholder(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
prompts_root = tmp_path / "prompts"
|
|
||||||
write_prompt(prompts_root, "zh-CN", "replyer", "你好,{user_name},现在是 {current_time}")
|
|
||||||
monkeypatch.setenv("MAIBOT_PROMPT_I18N_STRICT", "1")
|
|
||||||
|
|
||||||
with pytest.raises(KeyError) as exc_info:
|
|
||||||
load_prompt("replyer", locale="zh-CN", prompts_root=prompts_root, user_name="Mai")
|
|
||||||
|
|
||||||
assert "current_time" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_prompt_rejects_path_traversal(tmp_path: Path) -> None:
|
|
||||||
prompts_root = tmp_path / "prompts"
|
|
||||||
write_prompt(prompts_root, "zh-CN", "replyer", "你好")
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
load_prompt("../replyer", locale="zh-CN", prompts_root=prompts_root)
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_prompt_templates_prefers_locale_specific_files(tmp_path: Path) -> None:
|
|
||||||
prompts_root = tmp_path / "prompts"
|
|
||||||
write_prompt(prompts_root, "zh-CN", "replyer", "中文")
|
|
||||||
write_prompt(prompts_root, "en-US", "replyer", "English")
|
|
||||||
set_locale("en-US")
|
|
||||||
|
|
||||||
prompt_templates = list_prompt_templates(prompts_root=prompts_root)
|
|
||||||
|
|
||||||
assert prompt_templates["replyer"].path.read_text(encoding="utf-8") == "English"
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_prompt_templates_loads_directory_metadata(tmp_path: Path) -> None:
|
|
||||||
prompts_root = tmp_path / "prompts"
|
|
||||||
write_prompt(prompts_root, "zh-CN", "replyer", "中文")
|
|
||||||
metadata_path = prompts_root / "zh-CN" / ".meta.toml"
|
|
||||||
metadata_path.write_text(
|
|
||||||
"""
|
|
||||||
[replyer]
|
|
||||||
display_name = "回复器"
|
|
||||||
advanced = true
|
|
||||||
description = "用于生成回复的主模板"
|
|
||||||
""".strip(),
|
|
||||||
encoding="utf-8",
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt_templates = list_prompt_templates(prompts_root=prompts_root)
|
|
||||||
metadata = prompt_templates["replyer"].metadata
|
|
||||||
|
|
||||||
assert metadata.display_name == "回复器"
|
|
||||||
assert metadata.advanced is True
|
|
||||||
assert metadata.description == "用于生成回复的主模板"
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_prompt_templates_loads_prompt_specific_metadata(tmp_path: Path) -> None:
|
|
||||||
prompts_root = tmp_path / "prompts"
|
|
||||||
write_prompt(prompts_root, "zh-CN", "replyer", "中文")
|
|
||||||
metadata_path = prompts_root / "zh-CN" / "replyer.meta.json"
|
|
||||||
metadata_path.write_text(
|
|
||||||
'{"display_name": "Replyer", "advanced": false, "description": "Prompt specific metadata"}',
|
|
||||||
encoding="utf-8",
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt_templates = list_prompt_templates(prompts_root=prompts_root)
|
|
||||||
metadata = prompt_templates["replyer"].metadata
|
|
||||||
|
|
||||||
assert metadata.display_name == "Replyer"
|
|
||||||
assert metadata.advanced is False
|
|
||||||
assert metadata.description == "Prompt specific metadata"
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_prompt_templates_reports_duplicate_name_with_custom_root(tmp_path: Path) -> None:
|
|
||||||
prompts_root = tmp_path / "prompts"
|
|
||||||
first_dir = prompts_root / "zh-CN" / "chat"
|
|
||||||
second_dir = prompts_root / "zh-CN" / "system"
|
|
||||||
first_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
second_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
(first_dir / "replyer.prompt").write_text("chat", encoding="utf-8")
|
|
||||||
(second_dir / "replyer.prompt").write_text("system", encoding="utf-8")
|
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
|
||||||
list_prompt_templates(prompts_root=prompts_root)
|
|
||||||
|
|
||||||
assert "zh-CN/chat/replyer.prompt" in str(exc_info.value)
|
|
||||||
assert "zh-CN/system/replyer.prompt" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_load_prompts_prefers_locale_dir(
|
|
||||||
tmp_path: Path,
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
prompts_root = tmp_path / "prompts"
|
|
||||||
custom_prompts_root = tmp_path / "data" / "custom_prompts"
|
|
||||||
custom_prompts_root.mkdir(parents=True, exist_ok=True)
|
|
||||||
write_prompt(prompts_root, "zh-CN", "replyer", "中文模板")
|
|
||||||
write_prompt(prompts_root, "en-US", "replyer", "English template")
|
|
||||||
set_locale("en-US")
|
|
||||||
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_root, raising=False)
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_prompts_root, raising=False)
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.SUFFIX_PROMPT", ".prompt", raising=False)
|
|
||||||
|
|
||||||
manager = PromptManager()
|
|
||||||
manager.load_prompts()
|
|
||||||
|
|
||||||
assert manager.get_prompt("replyer").template == "English template"
|
|
||||||
@@ -1,893 +0,0 @@
|
|||||||
# File: pytests/prompt_test/test_prompt_manager.py
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import inspect
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
|
|
||||||
sys.path.insert(0, str(PROJECT_ROOT))
|
|
||||||
sys.path.insert(0, str(PROJECT_ROOT / "src" / "config"))
|
|
||||||
|
|
||||||
from src.common.i18n.loaders import DEFAULT_LOCALE # noqa
|
|
||||||
from src.prompt.prompt_manager import ( # noqa
|
|
||||||
SUFFIX_PROMPT,
|
|
||||||
Prompt,
|
|
||||||
PromptManager,
|
|
||||||
prompt_manager,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def write_source_prompt(prompts_dir: Path, name: str, content: str) -> Path:
|
|
||||||
from src.common.i18n.loaders import DEFAULT_LOCALE
|
|
||||||
|
|
||||||
source_dir = prompts_dir / DEFAULT_LOCALE
|
|
||||||
source_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
prompt_file = source_dir / f"{name}{SUFFIX_PROMPT}"
|
|
||||||
prompt_file.write_text(content, encoding="utf-8")
|
|
||||||
return prompt_file
|
|
||||||
|
|
||||||
|
|
||||||
# ========= Prompt 基础行为 =========
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"prompt_name, template",
|
|
||||||
[
|
|
||||||
pytest.param("simple", "Hello {name}", id="simple-template-with-field"),
|
|
||||||
pytest.param("no-fields", "Just a static template", id="template-without-fields"),
|
|
||||||
pytest.param(
|
|
||||||
"brace-escaping",
|
|
||||||
"Use {{ and }} around {field}",
|
|
||||||
id="template-with-escaped-braces",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_prompt_init_happy_paths(prompt_name: str, template: str):
|
|
||||||
# Act
|
|
||||||
prompt = Prompt(prompt_name=prompt_name, template=template)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert prompt.prompt_name == prompt_name
|
|
||||||
assert prompt.template == template
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"prompt_name, template, expected_exception, expected_msg_substring",
|
|
||||||
[
|
|
||||||
pytest.param("", "Hello {name}", ValueError, "prompt_name 不能为空", id="empty-prompt-name"),
|
|
||||||
pytest.param("valid-name", "", ValueError, "template 不能为空", id="empty-template"),
|
|
||||||
pytest.param(
|
|
||||||
"unnamed-placeholder",
|
|
||||||
"Hello {}",
|
|
||||||
ValueError,
|
|
||||||
"模板中不允许使用未命名的占位符",
|
|
||||||
id="unnamed-placeholder-not-allowed",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"unnamed-placeholder-with-escaped-brace",
|
|
||||||
"Value {{}} and {}",
|
|
||||||
ValueError,
|
|
||||||
"模板中不允许使用未命名的占位符",
|
|
||||||
id="unnamed-placeholder-mixed-with-escaped",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_prompt_init_error_cases(
|
|
||||||
prompt_name,
|
|
||||||
template,
|
|
||||||
expected_exception,
|
|
||||||
expected_msg_substring,
|
|
||||||
):
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(expected_exception) as exc_info:
|
|
||||||
Prompt(prompt_name=prompt_name, template=template)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert expected_msg_substring in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"initial_context, name, func, expected_value, expected_exception, expected_msg_substring, case_id",
|
|
||||||
[
|
|
||||||
(
|
|
||||||
{},
|
|
||||||
"const_str",
|
|
||||||
"constant",
|
|
||||||
"constant",
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
"add-context-from-string-creates-wrapper",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
{},
|
|
||||||
"callable_str",
|
|
||||||
lambda prompt_name: f"hello-{prompt_name}",
|
|
||||||
"hello-my_prompt",
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
"add-context-from-callable",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
{"dup": lambda _: "x"},
|
|
||||||
"dup",
|
|
||||||
"y",
|
|
||||||
None,
|
|
||||||
KeyError,
|
|
||||||
"Context function name 'dup' 已存在于 Prompt 'my_prompt' 中",
|
|
||||||
"add-context-duplicate-key-error",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_prompt_add_context(
|
|
||||||
initial_context,
|
|
||||||
name,
|
|
||||||
func,
|
|
||||||
expected_value,
|
|
||||||
expected_exception,
|
|
||||||
expected_msg_substring,
|
|
||||||
case_id,
|
|
||||||
):
|
|
||||||
# Arrange
|
|
||||||
prompt = Prompt(prompt_name="my_prompt", template="template")
|
|
||||||
prompt.prompt_render_context = dict(initial_context)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
if expected_exception:
|
|
||||||
with pytest.raises(expected_exception) as exc_info:
|
|
||||||
prompt.add_context(name, func)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert expected_msg_substring in str(exc_info.value)
|
|
||||||
else:
|
|
||||||
prompt.add_context(name, func)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert name in prompt.prompt_render_context
|
|
||||||
result = prompt.prompt_render_context[name]("my_prompt")
|
|
||||||
assert result == expected_value
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_clone_independent_instance():
|
|
||||||
# Arrange
|
|
||||||
prompt = Prompt(prompt_name="p", template="T {x}")
|
|
||||||
prompt.add_context("x", "X")
|
|
||||||
|
|
||||||
# Act
|
|
||||||
cloned = prompt.clone()
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert cloned is not prompt
|
|
||||||
assert cloned.prompt_name == prompt.prompt_name
|
|
||||||
assert cloned.template == prompt.template
|
|
||||||
# 当前实现 clone 不复制 context
|
|
||||||
assert cloned.prompt_render_context == {}
|
|
||||||
|
|
||||||
|
|
||||||
# ========= PromptManager:添加/获取/删除/替换 =========
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_add_prompt_happy_and_error():
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
prompt1 = Prompt(prompt_name="p1", template="T1")
|
|
||||||
manager.add_prompt(prompt1, need_save=True)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
prompt2 = Prompt(prompt_name="p2", template="T2")
|
|
||||||
manager.add_prompt(prompt2, need_save=False)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "p1" in manager._prompt_to_save
|
|
||||||
assert "p2" not in manager._prompt_to_save
|
|
||||||
|
|
||||||
# Arrange
|
|
||||||
prompt_dup = Prompt(prompt_name="p1", template="T-dup")
|
|
||||||
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(KeyError) as exc_info:
|
|
||||||
manager.add_prompt(prompt_dup)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "Prompt name 'p1' 已存在" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_remove_prompt_happy_and_error():
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
p1 = Prompt(prompt_name="p1", template="T")
|
|
||||||
manager.add_prompt(p1, need_save=True)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
manager.remove_prompt("p1")
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "p1" not in manager.prompts
|
|
||||||
assert "p1" not in manager._prompt_to_save
|
|
||||||
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(KeyError) as exc_info:
|
|
||||||
manager.remove_prompt("no_such")
|
|
||||||
|
|
||||||
assert "Prompt name 'no_such' 不存在" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_replace_prompt_happy_and_error():
|
|
||||||
# sourcery skip: extract-duplicate-method
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
p1 = Prompt(prompt_name="p", template="Old")
|
|
||||||
manager.add_prompt(p1, need_save=True)
|
|
||||||
|
|
||||||
p_new = Prompt(prompt_name="p", template="New")
|
|
||||||
|
|
||||||
# Act: 替换且保持 need_save
|
|
||||||
manager.replace_prompt(p_new, need_save=True)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert manager.prompts["p"].template == "New"
|
|
||||||
assert "p" in manager._prompt_to_save
|
|
||||||
|
|
||||||
# Act: 再次替换,且不需要保存
|
|
||||||
p_new2 = Prompt(prompt_name="p", template="New2")
|
|
||||||
manager.replace_prompt(p_new2, need_save=False)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert manager.prompts["p"].template == "New2"
|
|
||||||
assert "p" not in manager._prompt_to_save
|
|
||||||
|
|
||||||
# Error: 不存在的 prompt
|
|
||||||
p_unknown = Prompt(prompt_name="unknown", template="T")
|
|
||||||
with pytest.raises(KeyError) as exc_info:
|
|
||||||
manager.replace_prompt(p_unknown)
|
|
||||||
|
|
||||||
assert "Prompt name 'unknown' 不存在,无法替换" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_get_prompt_is_copy():
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
prompt = Prompt(prompt_name="original", template="T")
|
|
||||||
manager.add_prompt(prompt)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
retrieved_prompt = manager.get_prompt("original")
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert retrieved_prompt is not prompt
|
|
||||||
assert retrieved_prompt.prompt_name == prompt.prompt_name
|
|
||||||
assert retrieved_prompt.template == prompt.template
|
|
||||||
assert retrieved_prompt.prompt_render_context == prompt.prompt_render_context
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_add_prompt_conflict_with_context_name():
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
manager.add_context_construct_function("ctx_name", lambda _: "value")
|
|
||||||
prompt_conflict = Prompt(prompt_name="ctx_name", template="T")
|
|
||||||
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(KeyError) as exc_info:
|
|
||||||
manager.add_prompt(prompt_conflict)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "Prompt name 'ctx_name' 已存在" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_add_context_construct_function_happy():
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
|
|
||||||
def ctx_func(prompt_name: str) -> str:
|
|
||||||
return f"ctx-{prompt_name}"
|
|
||||||
|
|
||||||
# Act
|
|
||||||
manager.add_context_construct_function("ctx", ctx_func)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "ctx" in manager._context_construct_functions
|
|
||||||
stored_func, module = manager._context_construct_functions["ctx"]
|
|
||||||
assert stored_func is ctx_func
|
|
||||||
assert module == __name__
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_add_context_construct_function_duplicate():
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
|
|
||||||
def f(_):
|
|
||||||
return "x"
|
|
||||||
|
|
||||||
manager.add_context_construct_function("dup", f)
|
|
||||||
manager.add_prompt(Prompt(prompt_name="dup_prompt", template="T"))
|
|
||||||
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(KeyError) as exc_info1:
|
|
||||||
manager.add_context_construct_function("dup", f)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "Construct function name 'dup' 已存在" in str(exc_info1.value)
|
|
||||||
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(KeyError) as exc_info2:
|
|
||||||
manager.add_context_construct_function("dup_prompt", f)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "Construct function name 'dup_prompt' 已存在" in str(exc_info2.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_get_prompt_not_exist():
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(KeyError) as exc_info:
|
|
||||||
manager.get_prompt("no_such_prompt")
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "Prompt name 'no_such_prompt' 不存在" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
# ========= 渲染逻辑 =========
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"template, inner_context, global_context, expected, case_id",
|
|
||||||
[
|
|
||||||
pytest.param(
|
|
||||||
"Hello {name}",
|
|
||||||
{"name": lambda p: f"name-for-{p}"},
|
|
||||||
{},
|
|
||||||
"Hello name-for-main",
|
|
||||||
"render-with-inner-context",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"Global {block}",
|
|
||||||
{},
|
|
||||||
{"block": lambda p: f"block-{p}"},
|
|
||||||
"Global block-main",
|
|
||||||
"render-with-global-context",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"Mix {inner} and {global}",
|
|
||||||
{"inner": lambda p: f"inner-{p}"},
|
|
||||||
{"global": lambda p: f"global-{p}"},
|
|
||||||
"Mix inner-main and global-main",
|
|
||||||
"render-with-inner-and-global-context",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"Escaped {{ and }} and {field}",
|
|
||||||
{"field": lambda _: "X"},
|
|
||||||
{},
|
|
||||||
"Escaped { and } and X",
|
|
||||||
"render-with-escaped-braces",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_prompt_manager_render_contexts(
|
|
||||||
template,
|
|
||||||
inner_context,
|
|
||||||
global_context,
|
|
||||||
expected,
|
|
||||||
case_id,
|
|
||||||
):
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
tmp_prompt = Prompt(prompt_name="main", template=template)
|
|
||||||
manager.add_prompt(tmp_prompt)
|
|
||||||
prompt = manager.get_prompt("main")
|
|
||||||
for name, fn in inner_context.items():
|
|
||||||
prompt.add_context(name, fn)
|
|
||||||
for name, fn in global_context.items():
|
|
||||||
manager.add_context_construct_function(name, fn)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
rendered = await manager.render_prompt(prompt)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert rendered == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_prompt_manager_render_nested_prompts():
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
p1 = Prompt(prompt_name="p1", template="P1-{x}")
|
|
||||||
p2 = Prompt(prompt_name="p2", template="P2-{p1}")
|
|
||||||
p3_tmp = Prompt(prompt_name="p3", template="{p2}-end")
|
|
||||||
manager.add_prompt(p1)
|
|
||||||
manager.add_prompt(p2)
|
|
||||||
manager.add_prompt(p3_tmp)
|
|
||||||
p3 = manager.get_prompt("p3")
|
|
||||||
p3.add_context("x", lambda _: "X")
|
|
||||||
|
|
||||||
# Act
|
|
||||||
rendered = await manager.render_prompt(p3)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert rendered == "P2-P1-X-end"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_prompt_manager_render_recursive_limit():
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
p1_tmp = Prompt(prompt_name="p1", template="{p2}")
|
|
||||||
p2_tmp = Prompt(prompt_name="p2", template="{p1}")
|
|
||||||
manager.add_prompt(p1_tmp)
|
|
||||||
manager.add_prompt(p2_tmp)
|
|
||||||
p1 = manager.get_prompt("p1")
|
|
||||||
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(RecursionError) as exc_info:
|
|
||||||
await manager.render_prompt(p1)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "递归层级过深" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_prompt_manager_render_missing_field_error():
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
tmp_prompt = Prompt(prompt_name="main", template="Hello {missing}")
|
|
||||||
manager.add_prompt(tmp_prompt)
|
|
||||||
prompt = manager.get_prompt("main")
|
|
||||||
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(KeyError) as exc_info:
|
|
||||||
await manager.render_prompt(prompt)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "Prompt 'main' 中缺少必要的内容块或构建函数: 'missing'" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_prompt_manager_render_prefers_inner_context_over_global():
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
tmp_prompt = Prompt(prompt_name="main", template="{field}")
|
|
||||||
manager.add_context_construct_function("field", lambda _: "global")
|
|
||||||
manager.add_prompt(tmp_prompt)
|
|
||||||
prompt = manager.get_prompt("main")
|
|
||||||
prompt.add_context("field", lambda _: "inner")
|
|
||||||
|
|
||||||
# Act
|
|
||||||
rendered = await manager.render_prompt(prompt)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert rendered == "inner"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_prompt_manager_render_with_coroutine_context_function():
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
|
|
||||||
async def async_inner(prompt_name: str) -> str:
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
return f"async-{prompt_name}"
|
|
||||||
|
|
||||||
tmp_prompt = Prompt(prompt_name="main", template="{inner}")
|
|
||||||
manager.add_prompt(tmp_prompt)
|
|
||||||
prompt = manager.get_prompt("main")
|
|
||||||
prompt.add_context("inner", async_inner)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
rendered = await manager.render_prompt(prompt)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert rendered == "async-main"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_prompt_manager_render_with_coroutine_global_context_function():
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
|
|
||||||
async def async_global(prompt_name: str) -> str:
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
return f"g-{prompt_name}"
|
|
||||||
|
|
||||||
tmp_prompt = Prompt(prompt_name="main", template="{g}")
|
|
||||||
manager.add_context_construct_function("g", async_global)
|
|
||||||
manager.add_prompt(tmp_prompt)
|
|
||||||
prompt = manager.get_prompt("main")
|
|
||||||
|
|
||||||
# Act
|
|
||||||
rendered = await manager.render_prompt(prompt)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert rendered == "g-main"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_prompt_manager_render_only_cloned_instance():
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
p = Prompt(prompt_name="p", template="T")
|
|
||||||
manager.add_prompt(p)
|
|
||||||
|
|
||||||
# Act / Assert: 直接用原始 p 渲染会报错
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
|
||||||
await manager.render_prompt(p)
|
|
||||||
|
|
||||||
assert "只能渲染通过 PromptManager.get_prompt 方法获取的 Prompt 实例" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"is_prompt_context, use_coroutine, case_id",
|
|
||||||
[
|
|
||||||
pytest.param(True, False, "prompt-context-sync-error"),
|
|
||||||
pytest.param(False, False, "global-context-sync-error"),
|
|
||||||
pytest.param(True, True, "prompt-context-async-error"),
|
|
||||||
pytest.param(False, True, "global-context-async-error"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_prompt_manager_get_function_result_error_logging(
|
|
||||||
monkeypatch,
|
|
||||||
is_prompt_context,
|
|
||||||
use_coroutine,
|
|
||||||
case_id,
|
|
||||||
):
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
|
|
||||||
class DummyError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def sync_func(_name: str) -> str:
|
|
||||||
raise DummyError("sync-error")
|
|
||||||
|
|
||||||
async def async_func(_name: str) -> str:
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
raise DummyError("async-error")
|
|
||||||
|
|
||||||
func = async_func if use_coroutine else sync_func
|
|
||||||
logged_messages: list[str] = []
|
|
||||||
|
|
||||||
def fake_error(msg: Any) -> None:
|
|
||||||
logged_messages.append(str(msg))
|
|
||||||
|
|
||||||
fake_logger = type("FakeLogger", (), {"error": staticmethod(fake_error)})
|
|
||||||
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.logger", fake_logger)
|
|
||||||
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(DummyError):
|
|
||||||
await manager._get_function_result(
|
|
||||||
func=func,
|
|
||||||
prompt_name="P",
|
|
||||||
field_name="field",
|
|
||||||
is_prompt_context=is_prompt_context,
|
|
||||||
module="mod",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert logged_messages
|
|
||||||
log = logged_messages[0]
|
|
||||||
if is_prompt_context:
|
|
||||||
assert "调用 Prompt 'P' 内部上下文构造函数 'field' 时出错" in log
|
|
||||||
else:
|
|
||||||
assert "调用上下文构造函数 'field' 时出错,所属模块: 'mod'" in log
|
|
||||||
|
|
||||||
|
|
||||||
# ========= add_context_construct_function 边界 =========
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_add_context_construct_function_unknown_frame(monkeypatch):
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
|
|
||||||
def fake_currentframe() -> None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
monkeypatch.setattr("inspect.currentframe", fake_currentframe)
|
|
||||||
|
|
||||||
def f(_):
|
|
||||||
return "x"
|
|
||||||
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(RuntimeError) as exc_info:
|
|
||||||
manager.add_context_construct_function("x", f)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "无法获取调用栈" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_add_context_construct_function_unknown_caller_frame(monkeypatch):
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
real_currentframe = inspect.currentframe
|
|
||||||
|
|
||||||
class FakeFrame:
|
|
||||||
f_back = None
|
|
||||||
|
|
||||||
def fake_currentframe():
|
|
||||||
return FakeFrame()
|
|
||||||
|
|
||||||
monkeypatch.setattr("inspect.currentframe", fake_currentframe)
|
|
||||||
|
|
||||||
def f(_):
|
|
||||||
return "x"
|
|
||||||
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(RuntimeError) as exc_info:
|
|
||||||
manager.add_context_construct_function("x", f)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "无法获取调用栈的上一级" in str(exc_info.value)
|
|
||||||
|
|
||||||
# Cleanup
|
|
||||||
monkeypatch.setattr("inspect.currentframe", real_currentframe)
|
|
||||||
|
|
||||||
|
|
||||||
# ========= save/load & 目录逻辑 =========
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_save_prompts_io_error_on_unlink(tmp_path, monkeypatch):
|
|
||||||
"""
|
|
||||||
save_prompts 现在的逻辑:
|
|
||||||
1. 先删除 CUSTOM_PROMPTS_DIR 下的所有 *.prompt 文件;
|
|
||||||
2. 再将 _prompt_to_save 中的 prompt 写入 CUSTOM_PROMPTS_DIR。
|
|
||||||
|
|
||||||
这里模拟删除已有自定义 prompt 文件时发生 IO 错误。
|
|
||||||
"""
|
|
||||||
# Arrange
|
|
||||||
prompts_dir = tmp_path / "prompts"
|
|
||||||
custom_dir = tmp_path / "data" / "custom_prompts"
|
|
||||||
prompts_dir.mkdir(parents=True)
|
|
||||||
custom_dir.mkdir(parents=True)
|
|
||||||
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
|
|
||||||
|
|
||||||
# 先在自定义目录写入一个 prompt 文件,触发 unlink 路径
|
|
||||||
old_file = custom_dir / f"old{SUFFIX_PROMPT}"
|
|
||||||
old_file.write_text("old", encoding="utf-8")
|
|
||||||
|
|
||||||
manager = PromptManager()
|
|
||||||
p1 = Prompt(prompt_name="save_error", template="T")
|
|
||||||
manager.add_prompt(p1, need_save=True)
|
|
||||||
|
|
||||||
# 打桩 Path.unlink,使删除文件时报错
|
|
||||||
def fake_unlink(self):
|
|
||||||
raise OSError("disk unlink error")
|
|
||||||
|
|
||||||
monkeypatch.setattr("pathlib.Path.unlink", fake_unlink)
|
|
||||||
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(OSError) as exc_info:
|
|
||||||
manager.save_prompts()
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "disk unlink error" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_save_prompts_io_error_on_write(tmp_path, monkeypatch):
|
|
||||||
"""
|
|
||||||
模拟 save_prompts 在写入新 prompt 文件时发生 IO 错误。
|
|
||||||
"""
|
|
||||||
# Arrange
|
|
||||||
prompts_dir = tmp_path / "prompts"
|
|
||||||
custom_dir = tmp_path / "data" / "custom_prompts"
|
|
||||||
prompts_dir.mkdir(parents=True)
|
|
||||||
custom_dir.mkdir(parents=True)
|
|
||||||
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
|
|
||||||
|
|
||||||
manager = PromptManager()
|
|
||||||
p1 = Prompt(prompt_name="save_error", template="T")
|
|
||||||
manager.add_prompt(p1, need_save=True)
|
|
||||||
|
|
||||||
original_write_text = Path.write_text
|
|
||||||
|
|
||||||
def fake_write_text(self, *args, **kwargs):
|
|
||||||
if self == custom_dir / DEFAULT_LOCALE / f"save_error{SUFFIX_PROMPT}":
|
|
||||||
raise OSError("disk write error")
|
|
||||||
return original_write_text(self, *args, **kwargs)
|
|
||||||
|
|
||||||
monkeypatch.setattr(Path, "write_text", fake_write_text)
|
|
||||||
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(OSError) as exc_info:
|
|
||||||
manager.save_prompts()
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "disk write error" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_load_prompts_io_error_from_default_dir(tmp_path, monkeypatch):
|
|
||||||
"""
|
|
||||||
模拟从默认 locale 目录读取 prompt 时发生 IO 错误。
|
|
||||||
"""
|
|
||||||
# Arrange
|
|
||||||
prompts_dir = tmp_path / "prompts"
|
|
||||||
custom_dir = tmp_path / "data" / "custom_prompts"
|
|
||||||
prompts_dir.mkdir(parents=True)
|
|
||||||
custom_dir.mkdir(parents=True)
|
|
||||||
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
|
|
||||||
|
|
||||||
prompt_file = write_source_prompt(prompts_dir, "bad", "content")
|
|
||||||
|
|
||||||
original_read_text = Path.read_text
|
|
||||||
|
|
||||||
def fake_read_text(self, *args, **kwargs):
|
|
||||||
if self == prompt_file:
|
|
||||||
raise OSError("read error")
|
|
||||||
return original_read_text(self, *args, **kwargs)
|
|
||||||
|
|
||||||
monkeypatch.setattr(Path, "read_text", fake_read_text)
|
|
||||||
manager = PromptManager()
|
|
||||||
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(OSError) as exc_info:
|
|
||||||
manager.load_prompts()
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "read error" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_load_prompts_io_error_from_custom_dir(tmp_path, monkeypatch):
|
|
||||||
"""
|
|
||||||
模拟从 CUSTOM_PROMPTS_DIR 读取 prompt 时发生 IO 错误。
|
|
||||||
包含两种路径:
|
|
||||||
1. default 与 custom 同名,load_prompts 会优先读取 custom;
|
|
||||||
2. 仅 custom 有文件,且 default 无同名文件。
|
|
||||||
"""
|
|
||||||
# Arrange
|
|
||||||
prompts_dir = tmp_path / "prompts"
|
|
||||||
custom_dir = tmp_path / "data" / "custom_prompts"
|
|
||||||
prompts_dir.mkdir(parents=True)
|
|
||||||
custom_dir.mkdir(parents=True)
|
|
||||||
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
|
|
||||||
|
|
||||||
# default 与 custom 同名的文件
|
|
||||||
base_file = write_source_prompt(prompts_dir, "same", "base")
|
|
||||||
same_name = base_file.name
|
|
||||||
custom_file_same = custom_dir / same_name
|
|
||||||
custom_file_same.write_text("custom", encoding="utf-8")
|
|
||||||
|
|
||||||
# 仅 custom 下存在的文件
|
|
||||||
only_custom_file = custom_dir / f"only_custom{SUFFIX_PROMPT}"
|
|
||||||
only_custom_file.write_text("only", encoding="utf-8")
|
|
||||||
|
|
||||||
original_read_text = Path.read_text
|
|
||||||
|
|
||||||
def fake_read_text(self, *args, **kwargs):
|
|
||||||
if self.parent == custom_dir:
|
|
||||||
raise OSError("custom read error")
|
|
||||||
return original_read_text(self, *args, **kwargs)
|
|
||||||
|
|
||||||
monkeypatch.setattr(Path, "read_text", fake_read_text)
|
|
||||||
manager = PromptManager()
|
|
||||||
|
|
||||||
# Act / Assert
|
|
||||||
with pytest.raises(OSError) as exc_info:
|
|
||||||
manager.load_prompts()
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert "custom read error" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_load_prompts_custom_overrides_default(tmp_path, monkeypatch):
|
|
||||||
"""
|
|
||||||
load_prompts 逻辑:
|
|
||||||
- 遍历 locale 目录中的 source prompt
|
|
||||||
- 如果 CUSTOM_PROMPTS_DIR 下存在同名文件,则优先使用自定义目录
|
|
||||||
"""
|
|
||||||
# Arrange
|
|
||||||
prompts_dir = tmp_path / "prompts"
|
|
||||||
custom_dir = tmp_path / "data" / "custom_prompts"
|
|
||||||
prompts_dir.mkdir(parents=True)
|
|
||||||
custom_dir.mkdir(parents=True)
|
|
||||||
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
|
|
||||||
|
|
||||||
# source locale 目录 prompt
|
|
||||||
base_file = write_source_prompt(prompts_dir, "testp", "BaseTemplate {x}")
|
|
||||||
|
|
||||||
# 自定义目录同名 prompt,应当覆盖默认
|
|
||||||
custom_file = custom_dir / base_file.name
|
|
||||||
custom_file.write_text("CustomTemplate {x}", encoding="utf-8")
|
|
||||||
|
|
||||||
manager = PromptManager()
|
|
||||||
|
|
||||||
# Act
|
|
||||||
manager.load_prompts()
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
p = manager.get_prompt("testp")
|
|
||||||
assert p.template == "CustomTemplate {x}"
|
|
||||||
# 从自定义目录加载的 prompt 应标记为 need_save(加入 _prompt_to_save)
|
|
||||||
assert "testp" in manager._prompt_to_save
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_load_prompts_default_dir_not_mark_need_save(tmp_path, monkeypatch):
|
|
||||||
"""
|
|
||||||
从 source locale 目录加载、且没有同名自定义 prompt 时,need_save 应为 False(不进入 _prompt_to_save)。
|
|
||||||
"""
|
|
||||||
# Arrange
|
|
||||||
prompts_dir = tmp_path / "prompts"
|
|
||||||
custom_dir = tmp_path / "data" / "custom_prompts"
|
|
||||||
prompts_dir.mkdir(parents=True)
|
|
||||||
custom_dir.mkdir(parents=True)
|
|
||||||
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
|
|
||||||
|
|
||||||
# 仅 source locale 目录有 prompt,自定义目录中无同名文件
|
|
||||||
base_file = write_source_prompt(prompts_dir, "only_default", "DefaultTemplate {x}")
|
|
||||||
|
|
||||||
manager = PromptManager()
|
|
||||||
|
|
||||||
# Act
|
|
||||||
manager.load_prompts()
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
p = manager.get_prompt("only_default")
|
|
||||||
assert p.template == base_file.read_text(encoding="utf-8")
|
|
||||||
# 从默认目录加载的 prompt 不应标记为 need_save
|
|
||||||
assert "only_default" not in manager._prompt_to_save
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_save_prompts_use_custom_dir(tmp_path, monkeypatch):
|
|
||||||
"""
|
|
||||||
save_prompts 使用 CUSTOM_PROMPTS_DIR 进行保存。
|
|
||||||
"""
|
|
||||||
prompts_dir = tmp_path / "prompts"
|
|
||||||
custom_dir = tmp_path / "data" / "custom_prompts"
|
|
||||||
prompts_dir.mkdir(parents=True)
|
|
||||||
custom_dir.mkdir(parents=True)
|
|
||||||
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
|
|
||||||
|
|
||||||
manager = PromptManager()
|
|
||||||
p1 = Prompt(prompt_name="save_me", template="Template {x}")
|
|
||||||
p1.add_context("x", "X")
|
|
||||||
manager.add_prompt(p1, need_save=True)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
manager.save_prompts()
|
|
||||||
|
|
||||||
# Assert: 文件应保存在 custom_dir 中
|
|
||||||
saved_file = custom_dir / DEFAULT_LOCALE / f"save_me{SUFFIX_PROMPT}"
|
|
||||||
assert saved_file.exists()
|
|
||||||
assert saved_file.read_text(encoding="utf-8") == "Template {x}"
|
|
||||||
|
|
||||||
|
|
||||||
# ========= 其它 =========
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_global_instance_access():
|
|
||||||
# Act
|
|
||||||
pm = prompt_manager
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert isinstance(pm, PromptManager)
|
|
||||||
|
|
||||||
|
|
||||||
def test_formatter_parsing_named_fields_only():
|
|
||||||
# Arrange
|
|
||||||
manager = PromptManager()
|
|
||||||
prompt = Prompt(prompt_name="main", template="A {x} B {y} C")
|
|
||||||
manager.add_prompt(prompt)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
fields = {field_name for _, field_name, _, _ in manager._formatter.parse(prompt.template) if field_name}
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert fields == {"x", "y"}
|
|
||||||
@@ -1,73 +0,0 @@
|
|||||||
from src.common.data_models.message_component_data_model import (
|
|
||||||
ImageComponent,
|
|
||||||
MessageSequence,
|
|
||||||
ReplyComponent,
|
|
||||||
TextComponent,
|
|
||||||
)
|
|
||||||
from src.llm_models.payload_content.message import RoleType
|
|
||||||
from src.maisaka.context_messages import _build_message_from_sequence
|
|
||||||
from src.maisaka.message_adapter import build_visible_text_from_sequence
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_only_message_keeps_placeholder_in_text_fallback() -> None:
|
|
||||||
message_sequence = MessageSequence(
|
|
||||||
[
|
|
||||||
TextComponent("[时间]19:21:20\n[用户名]William730\n[用户群昵称]\n[msg_id]1385025976\n[发言内容]"),
|
|
||||||
ImageComponent(binary_hash="hash", content=None, binary_data=None),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
message = _build_message_from_sequence(
|
|
||||||
RoleType.User,
|
|
||||||
message_sequence,
|
|
||||||
"[时间]19:21:20\n[用户名]William730\n[用户群昵称]\n[msg_id]1385025976\n[发言内容][图片]",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert message is not None
|
|
||||||
assert "[发言内容]" in message.get_text_content()
|
|
||||||
assert "[图片]" in message.get_text_content()
|
|
||||||
|
|
||||||
|
|
||||||
def test_whitespace_image_content_uses_placeholder_in_text_fallback() -> None:
|
|
||||||
message_sequence = MessageSequence(
|
|
||||||
[
|
|
||||||
TextComponent("[发言内容]"),
|
|
||||||
ImageComponent(binary_hash="hash", content=" ", binary_data=None),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
message = _build_message_from_sequence(
|
|
||||||
RoleType.User,
|
|
||||||
message_sequence,
|
|
||||||
"[发言内容][图片]",
|
|
||||||
enable_visual_message=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert message is not None
|
|
||||||
assert message.get_text_content() == "[发言内容][图片]"
|
|
||||||
|
|
||||||
|
|
||||||
def test_visible_text_uses_image_placeholder_for_whitespace_content() -> None:
|
|
||||||
visible_text = build_visible_text_from_sequence(
|
|
||||||
MessageSequence(
|
|
||||||
[
|
|
||||||
TextComponent("看这个"),
|
|
||||||
ImageComponent(binary_hash="hash", content=" ", binary_data=None),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert visible_text == "看这个[图片]"
|
|
||||||
|
|
||||||
|
|
||||||
def test_visible_text_adds_body_marker_after_reply_component() -> None:
|
|
||||||
visible_text = build_visible_text_from_sequence(
|
|
||||||
MessageSequence(
|
|
||||||
[
|
|
||||||
ReplyComponent(target_message_id="75625487"),
|
|
||||||
TextComponent("你说是那就是"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert visible_text == "[引用]quote_id=75625487\n[发言内容]你说是那就是"
|
|
||||||
@@ -1,72 +0,0 @@
|
|||||||
import base64
|
|
||||||
import sys
|
|
||||||
from types import ModuleType, SimpleNamespace
|
|
||||||
|
|
||||||
|
|
||||||
config_module = ModuleType("src.config.config")
|
|
||||||
|
|
||||||
|
|
||||||
class _ConfigManagerStub:
|
|
||||||
def get_model_config(self) -> SimpleNamespace:
|
|
||||||
return SimpleNamespace(api_providers=[])
|
|
||||||
|
|
||||||
def register_reload_callback(self, _: object) -> None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
config_module.config_manager = _ConfigManagerStub()
|
|
||||||
sys.modules.setdefault("src.config.config", config_module)
|
|
||||||
|
|
||||||
from src.llm_models.model_client import gemini_client
|
|
||||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
|
||||||
from src.llm_models.payload_content.tool_option import ToolCall
|
|
||||||
|
|
||||||
|
|
||||||
def _encode_signature(value: bytes) -> str:
|
|
||||||
return base64.b64encode(value).decode("ascii")
|
|
||||||
|
|
||||||
|
|
||||||
def test_convert_messages_preserves_gemini_function_call_signature_and_tool_result_id() -> None:
|
|
||||||
thought_signature = b"gemini-signature"
|
|
||||||
tool_call = ToolCall(
|
|
||||||
call_id="call-1",
|
|
||||||
func_name="reply",
|
|
||||||
args={"msg_id": "42"},
|
|
||||||
extra_content={"google": {"thought_signature": _encode_signature(thought_signature)}},
|
|
||||||
)
|
|
||||||
assistant_message = MessageBuilder().set_role(RoleType.Assistant).set_tool_calls([tool_call]).build()
|
|
||||||
tool_message = (
|
|
||||||
MessageBuilder()
|
|
||||||
.set_role(RoleType.Tool)
|
|
||||||
.set_tool_call_id("call-1")
|
|
||||||
.set_tool_name("reply")
|
|
||||||
.add_text_content('{"ok": true}')
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
|
|
||||||
contents, _ = gemini_client._convert_messages([assistant_message, tool_message])
|
|
||||||
|
|
||||||
assistant_part = contents[0].parts[0]
|
|
||||||
assert assistant_part.function_call is not None
|
|
||||||
assert assistant_part.function_call.id == "call-1"
|
|
||||||
assert assistant_part.function_call.name == "reply"
|
|
||||||
assert assistant_part.thought_signature == thought_signature
|
|
||||||
|
|
||||||
tool_part = contents[1].parts[0]
|
|
||||||
assert tool_part.function_response is not None
|
|
||||||
assert tool_part.function_response.id == "call-1"
|
|
||||||
assert tool_part.function_response.name == "reply"
|
|
||||||
assert tool_part.function_response.response == {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
def test_convert_messages_injects_dummy_signature_for_first_historical_tool_call() -> None:
|
|
||||||
tool_calls = [
|
|
||||||
ToolCall(call_id="call-1", func_name="reply", args={"msg_id": "1"}),
|
|
||||||
ToolCall(call_id="call-2", func_name="reply", args={"msg_id": "2"}),
|
|
||||||
]
|
|
||||||
assistant_message = MessageBuilder().set_role(RoleType.Assistant).set_tool_calls(tool_calls).build()
|
|
||||||
|
|
||||||
contents, _ = gemini_client._convert_messages([assistant_message])
|
|
||||||
|
|
||||||
assert contents[0].parts[0].thought_signature == gemini_client.GEMINI_FALLBACK_THOUGHT_SIGNATURE
|
|
||||||
assert contents[0].parts[1].thought_signature is None
|
|
||||||
@@ -1,194 +0,0 @@
|
|||||||
"""HTML 浏览器渲染服务测试。"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.config.official_configs import PluginRuntimeRenderConfig
|
|
||||||
from src.services import html_render_service as html_render_service_module
|
|
||||||
from src.services.html_render_service import HTMLRenderService, ManagedBrowserRecord
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeChromium:
|
|
||||||
"""用于模拟 Playwright Chromium 启动器的测试桩。"""
|
|
||||||
|
|
||||||
def __init__(self, effects: List[Any]) -> None:
|
|
||||||
"""初始化 Chromium 启动测试桩。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
effects: 每次调用 ``launch`` 时依次返回或抛出的结果。
|
|
||||||
"""
|
|
||||||
|
|
||||||
self._effects: List[Any] = list(effects)
|
|
||||||
self.calls: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
async def launch(self, **kwargs: Any) -> Any:
|
|
||||||
"""模拟 Playwright Chromium 的启动过程。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
**kwargs: 浏览器启动参数。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: 预设的浏览器对象。
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: 当预设结果为异常对象时抛出。
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.calls.append(dict(kwargs))
|
|
||||||
effect = self._effects.pop(0)
|
|
||||||
if isinstance(effect, Exception):
|
|
||||||
raise effect
|
|
||||||
return effect
|
|
||||||
|
|
||||||
|
|
||||||
class _FakePlaywright:
|
|
||||||
"""用于模拟 Playwright 根对象的测试桩。"""
|
|
||||||
|
|
||||||
def __init__(self, chromium: _FakeChromium) -> None:
|
|
||||||
"""初始化 Playwright 测试桩。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chromium: Chromium 启动器测试桩。
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.chromium = chromium
|
|
||||||
|
|
||||||
|
|
||||||
def _build_render_config(**kwargs: Any) -> PluginRuntimeRenderConfig:
|
|
||||||
"""构造用于测试的浏览器渲染配置。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
**kwargs: 需要覆盖的配置字段。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PluginRuntimeRenderConfig: 测试使用的配置对象。
|
|
||||||
"""
|
|
||||||
|
|
||||||
payload: Dict[str, Any] = {
|
|
||||||
"auto_download_chromium": True,
|
|
||||||
"browser_install_root": "data/test-playwright-browsers",
|
|
||||||
}
|
|
||||||
payload.update(kwargs)
|
|
||||||
return PluginRuntimeRenderConfig(**payload)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_launch_browser_auto_downloads_chromium_when_missing(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
|
||||||
"""未检测到可用浏览器时,应自动下载 Chromium 并记录状态。"""
|
|
||||||
|
|
||||||
monkeypatch.setattr(html_render_service_module, "PROJECT_ROOT", tmp_path)
|
|
||||||
service = HTMLRenderService()
|
|
||||||
config = _build_render_config()
|
|
||||||
fake_browser = object()
|
|
||||||
fake_chromium = _FakeChromium(
|
|
||||||
[
|
|
||||||
RuntimeError("browserType.launch: Executable doesn't exist at /tmp/chromium"),
|
|
||||||
fake_browser,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
install_calls: List[str] = []
|
|
||||||
|
|
||||||
monkeypatch.setattr(service, "_resolve_executable_path", lambda _config: "")
|
|
||||||
|
|
||||||
async def fake_install(_config: PluginRuntimeRenderConfig) -> None:
|
|
||||||
"""模拟 Chromium 自动下载。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
_config: 当前浏览器渲染配置。
|
|
||||||
"""
|
|
||||||
|
|
||||||
install_calls.append(_config.browser_install_root)
|
|
||||||
browsers_path = service._get_managed_browsers_path(_config)
|
|
||||||
(browsers_path / "chromium-1234").mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
monkeypatch.setattr(service, "_install_chromium_browser", fake_install)
|
|
||||||
|
|
||||||
browser = await service._launch_browser(_FakePlaywright(fake_chromium), config)
|
|
||||||
|
|
||||||
assert browser is fake_browser
|
|
||||||
assert install_calls == ["data/test-playwright-browsers"]
|
|
||||||
assert len(fake_chromium.calls) == 2
|
|
||||||
|
|
||||||
browser_record = service._load_managed_browser_record()
|
|
||||||
assert browser_record is not None
|
|
||||||
assert browser_record.install_source == "auto_download"
|
|
||||||
assert browser_record.browser_name == "chromium"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_launch_browser_reuses_existing_managed_browser(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
|
||||||
"""已存在 Playwright 托管浏览器时,不应重复下载。"""
|
|
||||||
|
|
||||||
monkeypatch.setattr(html_render_service_module, "PROJECT_ROOT", tmp_path)
|
|
||||||
service = HTMLRenderService()
|
|
||||||
config = _build_render_config()
|
|
||||||
browsers_path = service._get_managed_browsers_path(config)
|
|
||||||
(browsers_path / "chrome-headless-shell-1234").mkdir(parents=True, exist_ok=True)
|
|
||||||
fake_browser = object()
|
|
||||||
fake_chromium = _FakeChromium([fake_browser])
|
|
||||||
|
|
||||||
monkeypatch.setattr(service, "_resolve_executable_path", lambda _config: "")
|
|
||||||
|
|
||||||
async def fail_install(_config: PluginRuntimeRenderConfig) -> None:
|
|
||||||
"""若被错误调用则立即失败。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
_config: 当前浏览器渲染配置。
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
AssertionError: 表示本测试不期望进入下载逻辑。
|
|
||||||
"""
|
|
||||||
|
|
||||||
raise AssertionError("不应触发自动下载")
|
|
||||||
|
|
||||||
monkeypatch.setattr(service, "_install_chromium_browser", fail_install)
|
|
||||||
|
|
||||||
browser = await service._launch_browser(_FakePlaywright(fake_chromium), config)
|
|
||||||
|
|
||||||
assert browser is fake_browser
|
|
||||||
assert len(fake_chromium.calls) == 1
|
|
||||||
|
|
||||||
browser_record = service._load_managed_browser_record()
|
|
||||||
assert browser_record is not None
|
|
||||||
assert browser_record.install_source == "existing_cache"
|
|
||||||
assert browser_record.browsers_path == str(browsers_path)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_launch_browser_prefers_local_executable(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
|
||||||
"""探测到本机浏览器时,应优先使用可执行文件路径启动。"""
|
|
||||||
|
|
||||||
monkeypatch.setattr(html_render_service_module, "PROJECT_ROOT", tmp_path)
|
|
||||||
service = HTMLRenderService()
|
|
||||||
config = _build_render_config()
|
|
||||||
fake_browser = object()
|
|
||||||
fake_chromium = _FakeChromium([fake_browser])
|
|
||||||
executable_path = "/usr/bin/google-chrome"
|
|
||||||
|
|
||||||
monkeypatch.setattr(service, "_resolve_executable_path", lambda _config: executable_path)
|
|
||||||
|
|
||||||
browser = await service._launch_browser(_FakePlaywright(fake_chromium), config)
|
|
||||||
|
|
||||||
assert browser is fake_browser
|
|
||||||
assert len(fake_chromium.calls) == 1
|
|
||||||
assert fake_chromium.calls[0]["executable_path"] == executable_path
|
|
||||||
assert service._load_managed_browser_record() is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_managed_browser_record_roundtrip() -> None:
|
|
||||||
"""托管浏览器记录应支持序列化与反序列化。"""
|
|
||||||
|
|
||||||
record = ManagedBrowserRecord(
|
|
||||||
browser_name="chromium",
|
|
||||||
browsers_path="/tmp/playwright-browsers",
|
|
||||||
install_source="auto_download",
|
|
||||||
playwright_version="1.58.0",
|
|
||||||
recorded_at="2026-04-03T10:00:00+00:00",
|
|
||||||
last_verified_at="2026-04-03T10:00:01+00:00",
|
|
||||||
)
|
|
||||||
|
|
||||||
restored_record = ManagedBrowserRecord.from_dict(record.to_dict())
|
|
||||||
|
|
||||||
assert restored_record == record
|
|
||||||
@@ -1,101 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
from src.llm_models.model_client.base_client import (
|
|
||||||
APIResponse,
|
|
||||||
AudioTranscriptionRequest,
|
|
||||||
BaseClient,
|
|
||||||
ClientProviderRegistration,
|
|
||||||
ClientRegistry,
|
|
||||||
EmbeddingRequest,
|
|
||||||
ResponseRequest,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DummyClient(BaseClient):
|
|
||||||
"""测试用 LLM 客户端。"""
|
|
||||||
|
|
||||||
async def get_response(self, request: ResponseRequest) -> APIResponse:
|
|
||||||
"""获取测试响应。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: 统一响应请求。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
APIResponse: 测试响应。
|
|
||||||
"""
|
|
||||||
del request
|
|
||||||
return APIResponse(content="ok")
|
|
||||||
|
|
||||||
async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
|
|
||||||
"""获取测试嵌入。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: 统一嵌入请求。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
APIResponse: 测试嵌入响应。
|
|
||||||
"""
|
|
||||||
del request
|
|
||||||
return APIResponse(embedding=[1.0])
|
|
||||||
|
|
||||||
async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
|
|
||||||
"""获取测试音频转写。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: 统一音频转写请求。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
APIResponse: 测试音频转写响应。
|
|
||||||
"""
|
|
||||||
del request
|
|
||||||
return APIResponse(content="audio")
|
|
||||||
|
|
||||||
def get_support_image_formats(self) -> List[str]:
|
|
||||||
"""获取测试支持的图片格式。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: 支持的图片格式列表。
|
|
||||||
"""
|
|
||||||
return ["png"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_client_registry_rejects_provider_conflict():
|
|
||||||
"""同一 client_type 被不同插件注册时应拒绝。"""
|
|
||||||
registry = ClientRegistry()
|
|
||||||
registry.replace_plugin_providers(
|
|
||||||
"plugin.alpha",
|
|
||||||
[
|
|
||||||
ClientProviderRegistration(
|
|
||||||
client_type="example",
|
|
||||||
factory=DummyClient,
|
|
||||||
owner_plugin_id="plugin.alpha",
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
registry.validate_plugin_provider_replacement("plugin.beta", ["example"])
|
|
||||||
except ValueError as exc:
|
|
||||||
assert "冲突" in str(exc)
|
|
||||||
else:
|
|
||||||
raise AssertionError("不同插件注册相同 client_type 应失败")
|
|
||||||
|
|
||||||
|
|
||||||
def test_client_registry_unregisters_plugin_providers():
|
|
||||||
"""插件注销时应移除它拥有的 Provider 注册。"""
|
|
||||||
registry = ClientRegistry()
|
|
||||||
registry.replace_plugin_providers(
|
|
||||||
"plugin.alpha",
|
|
||||||
[
|
|
||||||
ClientProviderRegistration(
|
|
||||||
client_type="example",
|
|
||||||
factory=DummyClient,
|
|
||||||
owner_plugin_id="plugin.alpha",
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
removed_count = registry.unregister_plugin_providers("plugin.alpha")
|
|
||||||
|
|
||||||
assert removed_count == 1
|
|
||||||
assert "example" not in registry.client_registry
|
|
||||||
@@ -1,113 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
from src.chat.message_receive.message import SessionMessage
|
|
||||||
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
|
|
||||||
from src.common.data_models.message_component_data_model import AtComponent, MessageSequence, ReplyComponent, TextComponent
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.maisaka.builtin_tool.context import BuiltinToolRuntimeContext
|
|
||||||
from src.maisaka.runtime import MaisakaHeartFlowChatting
|
|
||||||
|
|
||||||
|
|
||||||
def _build_sent_message() -> SessionMessage:
|
|
||||||
message = SessionMessage(
|
|
||||||
message_id="real-message-id",
|
|
||||||
timestamp=datetime(2026, 4, 5, 12, 0, 0),
|
|
||||||
platform="qq",
|
|
||||||
)
|
|
||||||
message.message_info = MessageInfo(
|
|
||||||
user_info=UserInfo(
|
|
||||||
user_id="bot-qq",
|
|
||||||
user_nickname="MaiSaka",
|
|
||||||
user_cardname=None,
|
|
||||||
),
|
|
||||||
group_info=None,
|
|
||||||
additional_config={},
|
|
||||||
)
|
|
||||||
message.raw_message = MessageSequence(
|
|
||||||
[
|
|
||||||
ReplyComponent(target_message_id="m123"),
|
|
||||||
TextComponent(text="你好"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
message.session_id = "test-session"
|
|
||||||
message.initialized = True
|
|
||||||
return message
|
|
||||||
|
|
||||||
|
|
||||||
def test_append_sent_message_to_chat_history_keeps_message_id() -> None:
|
|
||||||
runtime = SimpleNamespace(_chat_history=[])
|
|
||||||
engine = SimpleNamespace(_get_runtime_manager=lambda: None)
|
|
||||||
tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime)
|
|
||||||
|
|
||||||
tool_ctx.append_sent_message_to_chat_history(_build_sent_message())
|
|
||||||
|
|
||||||
assert len(runtime._chat_history) == 1
|
|
||||||
history_message = runtime._chat_history[0]
|
|
||||||
assert history_message.message_id == "real-message-id"
|
|
||||||
assert "[msg_id]real-message-id\n" in history_message.raw_message.components[0].text
|
|
||||||
assert "[msg_id:real-message-id]" in history_message.visible_text
|
|
||||||
|
|
||||||
|
|
||||||
def test_post_process_reply_message_sequences_converts_at_marker_before_bracket_cleanup(monkeypatch) -> None:
|
|
||||||
monkeypatch.setattr(global_config.chat, "enable_at", True)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"src.maisaka.builtin_tool.context.process_llm_response",
|
|
||||||
lambda text: [text.strip()] if text.strip() else [],
|
|
||||||
)
|
|
||||||
target_message = SimpleNamespace(
|
|
||||||
message_info=SimpleNamespace(
|
|
||||||
user_info=SimpleNamespace(
|
|
||||||
user_id="target-user",
|
|
||||||
user_nickname="目标昵称",
|
|
||||||
user_cardname="群名片",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
runtime = SimpleNamespace(
|
|
||||||
find_source_message_by_id=lambda message_id: target_message if message_id == "12160142" else None
|
|
||||||
)
|
|
||||||
engine = SimpleNamespace(_get_runtime_manager=lambda: None)
|
|
||||||
tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime)
|
|
||||||
|
|
||||||
sequences = tool_ctx.post_process_reply_message_sequences("at[12160142] 就这个群")
|
|
||||||
|
|
||||||
assert len(sequences) == 1
|
|
||||||
components = sequences[0].components
|
|
||||||
assert isinstance(components[0], AtComponent)
|
|
||||||
assert components[0].target_user_id == "target-user"
|
|
||||||
assert components[0].target_user_nickname == "目标昵称"
|
|
||||||
assert components[0].target_user_cardname == "群名片"
|
|
||||||
assert isinstance(components[1], TextComponent)
|
|
||||||
assert components[1].text == " 就这个群"
|
|
||||||
|
|
||||||
|
|
||||||
def test_post_process_reply_message_sequences_ignores_at_marker_when_disabled(monkeypatch) -> None:
|
|
||||||
monkeypatch.setattr(global_config.chat, "enable_at", False)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"src.maisaka.builtin_tool.context.process_llm_response",
|
|
||||||
lambda text: [text.strip()] if text.strip() else [],
|
|
||||||
)
|
|
||||||
runtime = SimpleNamespace(find_source_message_by_id=lambda message_id: None)
|
|
||||||
engine = SimpleNamespace(_get_runtime_manager=lambda: None)
|
|
||||||
tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime)
|
|
||||||
|
|
||||||
sequences = tool_ctx.post_process_reply_message_sequences("at[12160142] 就这个群")
|
|
||||||
|
|
||||||
assert len(sequences) == 1
|
|
||||||
components = sequences[0].components
|
|
||||||
assert len(components) == 1
|
|
||||||
assert isinstance(components[0], TextComponent)
|
|
||||||
assert components[0].text == "at[12160142] 就这个群"
|
|
||||||
|
|
||||||
|
|
||||||
def test_runtime_finds_source_message_from_history() -> None:
|
|
||||||
target_message = _build_sent_message()
|
|
||||||
runtime = object.__new__(MaisakaHeartFlowChatting)
|
|
||||||
runtime._chat_history = [
|
|
||||||
SimpleNamespace(message_id="other-message-id", original_message=SimpleNamespace()),
|
|
||||||
SimpleNamespace(message_id="real-message-id", original_message=target_message),
|
|
||||||
]
|
|
||||||
|
|
||||||
assert runtime.find_source_message_by_id("real-message-id") is target_message
|
|
||||||
assert runtime.find_source_message_by_id("missing-message-id") is None
|
|
||||||
@@ -1,241 +0,0 @@
|
|||||||
from types import SimpleNamespace
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.core.tooling import ToolInvocation
|
|
||||||
from src.maisaka.builtin_tool import query_memory as query_memory_tool
|
|
||||||
from src.maisaka.builtin_tool.context import BuiltinToolRuntimeContext
|
|
||||||
from src.services.memory_service import MemoryHit, MemorySearchResult
|
|
||||||
|
|
||||||
|
|
||||||
def _build_tool_ctx(
|
|
||||||
*,
|
|
||||||
session_id: str = "session-1",
|
|
||||||
platform: str = "qq",
|
|
||||||
user_id: str = "user-1",
|
|
||||||
group_id: str = "",
|
|
||||||
) -> BuiltinToolRuntimeContext:
|
|
||||||
runtime = SimpleNamespace(
|
|
||||||
session_id=session_id,
|
|
||||||
chat_stream=SimpleNamespace(
|
|
||||||
platform=platform,
|
|
||||||
user_id=user_id,
|
|
||||||
group_id=group_id,
|
|
||||||
),
|
|
||||||
log_prefix=f"[{session_id}]",
|
|
||||||
)
|
|
||||||
return BuiltinToolRuntimeContext(engine=SimpleNamespace(), runtime=runtime)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_invocation(arguments: Dict[str, Any]) -> ToolInvocation:
|
|
||||||
return ToolInvocation(
|
|
||||||
tool_name="query_memory",
|
|
||||||
arguments=dict(arguments),
|
|
||||||
call_id="call-query-memory",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def _patch_maisaka_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
monkeypatch.setattr(
|
|
||||||
query_memory_tool,
|
|
||||||
"global_config",
|
|
||||||
SimpleNamespace(memory=SimpleNamespace(memory_query_default_limit=5)),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_memory_rejects_empty_query_and_time(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
|
|
||||||
_ = query
|
|
||||||
_ = kwargs
|
|
||||||
raise AssertionError("参数校验失败时不应调用 memory_service.search")
|
|
||||||
|
|
||||||
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
|
|
||||||
|
|
||||||
result = await query_memory_tool.handle_tool(
|
|
||||||
_build_tool_ctx(),
|
|
||||||
_build_invocation({"query": "", "time_start": "", "time_end": ""}),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.success is False
|
|
||||||
assert "query_memory 需要提供 query" in result.error_message
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_memory_private_chat_auto_sets_person_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
captured: Dict[str, Any] = {}
|
|
||||||
|
|
||||||
def fake_resolve_person_id_for_memory(
|
|
||||||
*,
|
|
||||||
person_name: str = "",
|
|
||||||
platform: str = "",
|
|
||||||
user_id: Any = None,
|
|
||||||
strict_known: bool = False,
|
|
||||||
) -> str:
|
|
||||||
_ = strict_known
|
|
||||||
captured["resolve_args"] = {
|
|
||||||
"person_name": person_name,
|
|
||||||
"platform": platform,
|
|
||||||
"user_id": user_id,
|
|
||||||
}
|
|
||||||
return "pid-private-auto"
|
|
||||||
|
|
||||||
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
|
|
||||||
captured["query"] = query
|
|
||||||
captured["search_kwargs"] = dict(kwargs)
|
|
||||||
return MemorySearchResult(
|
|
||||||
summary="检索摘要",
|
|
||||||
hits=[MemoryHit(content="Alice 喜欢咖啡", score=0.91)],
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
|
|
||||||
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
|
|
||||||
|
|
||||||
result = await query_memory_tool.handle_tool(
|
|
||||||
_build_tool_ctx(session_id="private-session", platform="qq", user_id="alice", group_id=""),
|
|
||||||
_build_invocation({"query": "Alice 的喜好"}),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.success is True
|
|
||||||
assert captured["query"] == "Alice 的喜好"
|
|
||||||
assert captured["resolve_args"] == {
|
|
||||||
"person_name": "",
|
|
||||||
"platform": "qq",
|
|
||||||
"user_id": "alice",
|
|
||||||
}
|
|
||||||
assert captured["search_kwargs"]["chat_id"] == "private-session"
|
|
||||||
assert captured["search_kwargs"]["user_id"] == "alice"
|
|
||||||
assert captured["search_kwargs"]["group_id"] == ""
|
|
||||||
assert captured["search_kwargs"]["person_id"] == "pid-private-auto"
|
|
||||||
assert isinstance(result.structured_content, dict)
|
|
||||||
assert result.structured_content["person_id"] == "pid-private-auto"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_memory_group_chat_does_not_attach_default_person_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
call_counter = {"resolve": 0}
|
|
||||||
captured_kwargs: Dict[str, Any] = {}
|
|
||||||
|
|
||||||
def fake_resolve_person_id_for_memory(
|
|
||||||
*,
|
|
||||||
person_name: str = "",
|
|
||||||
platform: str = "",
|
|
||||||
user_id: Any = None,
|
|
||||||
strict_known: bool = False,
|
|
||||||
) -> str:
|
|
||||||
_ = person_name
|
|
||||||
_ = platform
|
|
||||||
_ = user_id
|
|
||||||
_ = strict_known
|
|
||||||
call_counter["resolve"] += 1
|
|
||||||
return "unexpected-person-id"
|
|
||||||
|
|
||||||
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
|
|
||||||
_ = query
|
|
||||||
captured_kwargs.update(kwargs)
|
|
||||||
return MemorySearchResult(summary="", hits=[])
|
|
||||||
|
|
||||||
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
|
|
||||||
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
|
|
||||||
|
|
||||||
result = await query_memory_tool.handle_tool(
|
|
||||||
_build_tool_ctx(session_id="group-session", platform="qq", user_id="alice", group_id="group-1"),
|
|
||||||
_build_invocation({"query": "群聊上下文"}),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.success is True
|
|
||||||
assert call_counter["resolve"] == 0
|
|
||||||
assert captured_kwargs["chat_id"] == "group-session"
|
|
||||||
assert captured_kwargs["group_id"] == "group-1"
|
|
||||||
assert captured_kwargs["person_id"] == ""
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_memory_search_failure_is_returned(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
|
|
||||||
_ = query
|
|
||||||
_ = kwargs
|
|
||||||
return MemorySearchResult(success=False, error="boom")
|
|
||||||
|
|
||||||
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
|
|
||||||
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", lambda **kwargs: "")
|
|
||||||
|
|
||||||
result = await query_memory_tool.handle_tool(
|
|
||||||
_build_tool_ctx(),
|
|
||||||
_build_invocation({"query": "测试失败透传"}),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.success is False
|
|
||||||
assert result.error_message == "boom"
|
|
||||||
assert isinstance(result.structured_content, dict)
|
|
||||||
assert result.structured_content["success"] is False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_memory_prefers_person_name_resolution(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
captured: Dict[str, Any] = {"resolve_calls": []}
|
|
||||||
|
|
||||||
def fake_resolve_person_id_for_memory(
|
|
||||||
*,
|
|
||||||
person_name: str = "",
|
|
||||||
platform: str = "",
|
|
||||||
user_id: Any = None,
|
|
||||||
strict_known: bool = False,
|
|
||||||
) -> str:
|
|
||||||
_ = strict_known
|
|
||||||
captured["resolve_calls"].append(
|
|
||||||
{
|
|
||||||
"person_name": person_name,
|
|
||||||
"platform": platform,
|
|
||||||
"user_id": user_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if person_name:
|
|
||||||
return "pid-by-name"
|
|
||||||
return "pid-private-auto"
|
|
||||||
|
|
||||||
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
|
|
||||||
_ = query
|
|
||||||
captured["search_kwargs"] = dict(kwargs)
|
|
||||||
return MemorySearchResult(summary="", hits=[MemoryHit(content="命中1")])
|
|
||||||
|
|
||||||
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
|
|
||||||
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
|
|
||||||
|
|
||||||
result = await query_memory_tool.handle_tool(
|
|
||||||
_build_tool_ctx(session_id="private-session", platform="qq", user_id="alice", group_id=""),
|
|
||||||
_build_invocation({"query": "小明资料", "person_name": "小明"}),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.success is True
|
|
||||||
assert captured["resolve_calls"][0] == {
|
|
||||||
"person_name": "小明",
|
|
||||||
"platform": "qq",
|
|
||||||
"user_id": "alice",
|
|
||||||
}
|
|
||||||
assert captured["search_kwargs"]["person_id"] == "pid-by-name"
|
|
||||||
assert result.structured_content["person_name"] == "小明"
|
|
||||||
assert result.structured_content["person_id"] == "pid-by-name"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_memory_no_hit_returns_readable_message(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
|
|
||||||
_ = query
|
|
||||||
_ = kwargs
|
|
||||||
return MemorySearchResult(summary="", hits=[])
|
|
||||||
|
|
||||||
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
|
|
||||||
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", lambda **kwargs: "")
|
|
||||||
|
|
||||||
result = await query_memory_tool.handle_tool(
|
|
||||||
_build_tool_ctx(),
|
|
||||||
_build_invocation({"query": "不存在的记忆"}),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.success is True
|
|
||||||
assert "未找到匹配的长期记忆" in result.content
|
|
||||||
assert isinstance(result.structured_content, dict)
|
|
||||||
assert result.structured_content["query"] == "不存在的记忆"
|
|
||||||
@@ -1,105 +0,0 @@
|
|||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import time
|
|
||||||
|
|
||||||
from src.chat.heart_flow import heartflow_manager as heartflow_manager_module
|
|
||||||
from src.chat.heart_flow.heartflow_manager import HEARTFLOW_ACTIVE_RETENTION_SECONDS, HeartflowManager
|
|
||||||
from src.learners.expression_learner import ExpressionLearner
|
|
||||||
from src.maisaka.runtime import MAX_RETAINED_MESSAGE_CACHE_SIZE, MaisakaHeartFlowChatting
|
|
||||||
|
|
||||||
|
|
||||||
def _build_runtime_with_messages(message_count: int) -> MaisakaHeartFlowChatting:
|
|
||||||
runtime = object.__new__(MaisakaHeartFlowChatting)
|
|
||||||
runtime.log_prefix = "[test]"
|
|
||||||
runtime.message_cache = [SimpleNamespace(message_id=f"msg-{index}") for index in range(message_count)]
|
|
||||||
runtime._last_processed_index = message_count
|
|
||||||
runtime._expression_learner = ExpressionLearner("session-1")
|
|
||||||
runtime._expression_learner.mark_all_processed(runtime.message_cache)
|
|
||||||
return runtime
|
|
||||||
|
|
||||||
|
|
||||||
def test_prune_processed_message_cache_keeps_bounded_recent_window() -> None:
|
|
||||||
runtime = _build_runtime_with_messages(MAX_RETAINED_MESSAGE_CACHE_SIZE + 25)
|
|
||||||
|
|
||||||
runtime._prune_processed_message_cache()
|
|
||||||
|
|
||||||
assert len(runtime.message_cache) == MAX_RETAINED_MESSAGE_CACHE_SIZE
|
|
||||||
assert runtime.message_cache[0].message_id == "msg-25"
|
|
||||||
assert runtime._last_processed_index == MAX_RETAINED_MESSAGE_CACHE_SIZE
|
|
||||||
assert runtime._expression_learner.last_processed_index == MAX_RETAINED_MESSAGE_CACHE_SIZE
|
|
||||||
|
|
||||||
|
|
||||||
def test_prune_processed_message_cache_keeps_unlearned_messages() -> None:
|
|
||||||
runtime = _build_runtime_with_messages(MAX_RETAINED_MESSAGE_CACHE_SIZE + 25)
|
|
||||||
runtime._expression_learner.discard_processed_prefix(MAX_RETAINED_MESSAGE_CACHE_SIZE + 5)
|
|
||||||
|
|
||||||
runtime._prune_processed_message_cache()
|
|
||||||
|
|
||||||
assert len(runtime.message_cache) == MAX_RETAINED_MESSAGE_CACHE_SIZE + 5
|
|
||||||
assert runtime.message_cache[0].message_id == "msg-20"
|
|
||||||
assert runtime._expression_learner.last_processed_index == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_collect_pending_messages_uses_single_pending_received_time() -> None:
|
|
||||||
runtime = _build_runtime_with_messages(2)
|
|
||||||
runtime._last_processed_index = 0
|
|
||||||
runtime._oldest_pending_message_received_at = 123.0
|
|
||||||
runtime._last_message_received_at = 456.0
|
|
||||||
runtime._reply_latency_measurement_started_at = None
|
|
||||||
|
|
||||||
pending_messages = runtime._collect_pending_messages()
|
|
||||||
|
|
||||||
assert [message.message_id for message in pending_messages] == ["msg-0", "msg-1"]
|
|
||||||
assert runtime._reply_latency_measurement_started_at == 123.0
|
|
||||||
assert runtime._oldest_pending_message_received_at is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_heartflow_manager_evicts_lru_chat_over_limit(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
manager = HeartflowManager()
|
|
||||||
stopped_session_ids: list[str] = []
|
|
||||||
old_active_at = time.time() - HEARTFLOW_ACTIVE_RETENTION_SECONDS - 1
|
|
||||||
|
|
||||||
class FakeChat:
|
|
||||||
def __init__(self, session_id: str) -> None:
|
|
||||||
self.session_id = session_id
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
stopped_session_ids.append(self.session_id)
|
|
||||||
|
|
||||||
monkeypatch.setattr(heartflow_manager_module, "HEARTFLOW_MAX_ACTIVE_CHATS", 2)
|
|
||||||
manager.heartflow_chat_list["session-1"] = FakeChat("session-1")
|
|
||||||
manager.heartflow_chat_list["session-2"] = FakeChat("session-2")
|
|
||||||
manager.heartflow_chat_list["session-3"] = FakeChat("session-3")
|
|
||||||
manager._chat_last_active_at["session-1"] = old_active_at
|
|
||||||
manager._chat_last_active_at["session-2"] = old_active_at
|
|
||||||
manager._chat_last_active_at["session-3"] = time.time()
|
|
||||||
|
|
||||||
await manager._evict_over_limit_chats(protected_session_id="session-3")
|
|
||||||
|
|
||||||
assert stopped_session_ids == ["session-1"]
|
|
||||||
assert list(manager.heartflow_chat_list) == ["session-2", "session-3"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_heartflow_manager_keeps_recent_chats_even_over_limit(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
manager = HeartflowManager()
|
|
||||||
stopped_session_ids: list[str] = []
|
|
||||||
|
|
||||||
class FakeChat:
|
|
||||||
def __init__(self, session_id: str) -> None:
|
|
||||||
self.session_id = session_id
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
stopped_session_ids.append(self.session_id)
|
|
||||||
|
|
||||||
monkeypatch.setattr(heartflow_manager_module, "HEARTFLOW_MAX_ACTIVE_CHATS", 2)
|
|
||||||
for session_id in ("session-1", "session-2", "session-3"):
|
|
||||||
manager.heartflow_chat_list[session_id] = FakeChat(session_id)
|
|
||||||
manager._chat_last_active_at[session_id] = time.time()
|
|
||||||
|
|
||||||
await manager._evict_over_limit_chats(protected_session_id="session-3")
|
|
||||||
|
|
||||||
assert stopped_session_ids == []
|
|
||||||
assert list(manager.heartflow_chat_list) == ["session-1", "session-2", "session-3"]
|
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from src.chat.message_receive.message import SessionMessage
|
|
||||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
|
||||||
from src.llm_models.payload_content.tool_option import ToolCall
|
|
||||||
from src.maisaka.message_adapter import build_message, get_message_kind, get_message_role, get_tool_call_id, get_tool_calls
|
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
|
||||||
if str(PROJECT_ROOT) not in sys.path:
|
|
||||||
sys.path.insert(0, str(PROJECT_ROOT))
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_message_returns_session_message_with_maisaka_metadata() -> None:
|
|
||||||
timestamp = datetime.now()
|
|
||||||
tool_call = ToolCall(
|
|
||||||
call_id="call-1",
|
|
||||||
func_name="reply",
|
|
||||||
args={"message_id": "msg-1"},
|
|
||||||
)
|
|
||||||
raw_message = MessageSequence(components=[TextComponent(text="内部消息内容")])
|
|
||||||
|
|
||||||
message = build_message(
|
|
||||||
role="assistant",
|
|
||||||
content="展示消息内容",
|
|
||||||
message_kind="perception",
|
|
||||||
source="assistant",
|
|
||||||
tool_call_id="call-1",
|
|
||||||
tool_calls=[tool_call],
|
|
||||||
timestamp=timestamp,
|
|
||||||
message_id="maisaka-msg-1",
|
|
||||||
raw_message=raw_message,
|
|
||||||
display_text="展示消息内容",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(message, SessionMessage)
|
|
||||||
assert message.initialized is True
|
|
||||||
assert message.message_id == "maisaka-msg-1"
|
|
||||||
assert message.timestamp == timestamp
|
|
||||||
assert message.processed_plain_text == "展示消息内容"
|
|
||||||
assert message.raw_message is raw_message
|
|
||||||
|
|
||||||
assert get_message_role(message) == "assistant"
|
|
||||||
assert get_message_kind(message) == "perception"
|
|
||||||
assert get_tool_call_id(message) == "call-1"
|
|
||||||
|
|
||||||
restored_tool_calls = get_tool_calls(message)
|
|
||||||
assert len(restored_tool_calls) == 1
|
|
||||||
assert restored_tool_calls[0].call_id == "call-1"
|
|
||||||
assert restored_tool_calls[0].func_name == "reply"
|
|
||||||
assert restored_tool_calls[0].args == {"message_id": "msg-1"}
|
|
||||||
@@ -1,619 +0,0 @@
|
|||||||
from types import SimpleNamespace
|
|
||||||
from typing import Any, Callable
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
from src.chat.replyer import maisaka_generator as replyer_module
|
|
||||||
from src.common.data_models.reply_generation_data_models import (
|
|
||||||
GenerationMetrics,
|
|
||||||
LLMCompletionResult,
|
|
||||||
ReplyGenerationResult,
|
|
||||||
)
|
|
||||||
from src.core.tooling import ToolExecutionResult, ToolInvocation
|
|
||||||
from src.maisaka.builtin_tool.context import BuiltinToolRuntimeContext
|
|
||||||
from src.maisaka.builtin_tool import reply as reply_tool_module
|
|
||||||
from src.maisaka.builtin_tool import send_emoji as send_emoji_tool_module
|
|
||||||
from src.maisaka.monitor_events import emit_planner_finalized
|
|
||||||
from src.maisaka.reasoning_engine import MaisakaReasoningEngine
|
|
||||||
from src.maisaka import runtime as runtime_module
|
|
||||||
from src.maisaka.runtime import MaisakaHeartFlowChatting
|
|
||||||
|
|
||||||
|
|
||||||
def test_runtime_maps_expression_config_flags_to_correct_fields(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
fake_chat_stream = SimpleNamespace(
|
|
||||||
is_group_session=True,
|
|
||||||
group_id="group-1",
|
|
||||||
user_id="user-1",
|
|
||||||
platform="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
runtime_module.chat_manager,
|
|
||||||
"get_session_by_session_id",
|
|
||||||
lambda session_id: fake_chat_stream,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(runtime_module.chat_manager, "get_session_name", lambda session_id: "测试会话")
|
|
||||||
monkeypatch.setattr(
|
|
||||||
runtime_module.ExpressionConfigUtils,
|
|
||||||
"get_expression_config_for_chat",
|
|
||||||
staticmethod(lambda session_id: (True, False, True)),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(runtime_module, "ExpressionLearner", lambda session_id: SimpleNamespace())
|
|
||||||
monkeypatch.setattr(runtime_module, "JargonMiner", lambda session_id, session_name: SimpleNamespace())
|
|
||||||
monkeypatch.setattr(runtime_module, "MaisakaReasoningEngine", lambda runtime: SimpleNamespace())
|
|
||||||
monkeypatch.setattr(runtime_module, "ToolRegistry", lambda: SimpleNamespace())
|
|
||||||
monkeypatch.setattr(runtime_module, "ReplyEffectTracker", lambda **kwargs: SimpleNamespace())
|
|
||||||
monkeypatch.setattr(MaisakaHeartFlowChatting, "_register_tool_providers", lambda self: None)
|
|
||||||
monkeypatch.setattr(MaisakaHeartFlowChatting, "_emit_monitor_session_start", lambda self: None)
|
|
||||||
|
|
||||||
runtime = MaisakaHeartFlowChatting("session-1")
|
|
||||||
|
|
||||||
assert runtime._enable_expression_use is True
|
|
||||||
assert runtime._enable_expression_learning is False
|
|
||||||
assert runtime._enable_jargon_learning is True
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeLLMResult:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.response = "测试回复"
|
|
||||||
self.reasoning = "先理解上下文,再给出自然回复。"
|
|
||||||
self.model_name = "fake-model"
|
|
||||||
self.tool_calls = []
|
|
||||||
self.prompt_tokens = 12
|
|
||||||
self.completion_tokens = 7
|
|
||||||
self.total_tokens = 19
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeLegacyLLMServiceClient:
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
||||||
del args
|
|
||||||
del kwargs
|
|
||||||
|
|
||||||
async def generate_response_with_messages(self, *, message_factory: Callable[[object], list[Any]]) -> _FakeLLMResult:
|
|
||||||
assert message_factory(object())
|
|
||||||
return _FakeLLMResult()
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeMultimodalLLMServiceClient:
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
||||||
del args
|
|
||||||
del kwargs
|
|
||||||
|
|
||||||
async def generate_response_with_messages(self, *, message_factory: Callable[[object], list[Any]]) -> _FakeLLMResult:
|
|
||||||
assert message_factory(object())
|
|
||||||
return _FakeLLMResult()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_legacy_and_multimodal_replyer_monitor_detail_have_same_shape(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
monkeypatch.setattr(replyer_module, "LLMServiceClient", _FakeLegacyLLMServiceClient)
|
|
||||||
monkeypatch.setattr(replyer_module, "load_prompt", lambda *args, **kwargs: "legacy prompt")
|
|
||||||
|
|
||||||
legacy_generator = replyer_module.MaisakaReplyGenerator(
|
|
||||||
chat_stream=None,
|
|
||||||
request_type="test_legacy",
|
|
||||||
enable_visual_message=False,
|
|
||||||
)
|
|
||||||
multimodal_generator = replyer_module.MaisakaReplyGenerator(
|
|
||||||
chat_stream=None,
|
|
||||||
request_type="test_multi",
|
|
||||||
llm_client_cls=_FakeMultimodalLLMServiceClient,
|
|
||||||
load_prompt_func=lambda *args, **kwargs: "multi prompt",
|
|
||||||
enable_visual_message=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
legacy_success, legacy_result = await legacy_generator.generate_reply_with_context(
|
|
||||||
stream_id="session-legacy",
|
|
||||||
chat_history=[],
|
|
||||||
reply_reason="测试原因",
|
|
||||||
)
|
|
||||||
multimodal_success, multimodal_result = await multimodal_generator.generate_reply_with_context(
|
|
||||||
stream_id="session-multi",
|
|
||||||
chat_history=[],
|
|
||||||
reply_reason="测试原因",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert legacy_success is True
|
|
||||||
assert multimodal_success is True
|
|
||||||
assert legacy_result.monitor_detail is not None
|
|
||||||
assert multimodal_result.monitor_detail is not None
|
|
||||||
assert set(legacy_result.monitor_detail.keys()) == set(multimodal_result.monitor_detail.keys())
|
|
||||||
assert set(legacy_result.monitor_detail["metrics"].keys()) == set(multimodal_result.monitor_detail["metrics"].keys())
|
|
||||||
assert legacy_result.monitor_detail["metrics"]["prompt_tokens"] == 12
|
|
||||||
assert legacy_result.monitor_detail["metrics"]["completion_tokens"] == 7
|
|
||||||
assert legacy_result.monitor_detail["metrics"]["total_tokens"] == 19
|
|
||||||
|
|
||||||
|
|
||||||
def test_legacy_replyer_builds_message_sequence_like_multimodal() -> None:
|
|
||||||
legacy_generator = replyer_module.MaisakaReplyGenerator(
|
|
||||||
chat_stream=None,
|
|
||||||
request_type="test_legacy",
|
|
||||||
enable_visual_message=False,
|
|
||||||
)
|
|
||||||
legacy_prompt_loader = replyer_module.load_prompt
|
|
||||||
replyer_module.load_prompt = lambda *args, **kwargs: "legacy prompt"
|
|
||||||
|
|
||||||
try:
|
|
||||||
session_message = replyer_module.SessionBackedMessage(
|
|
||||||
raw_message=SimpleNamespace(),
|
|
||||||
visible_text="[Alice]你好\n[Bob]在吗",
|
|
||||||
timestamp=replyer_module.datetime.now(),
|
|
||||||
source_kind="user",
|
|
||||||
)
|
|
||||||
request_messages = legacy_generator._build_request_messages(
|
|
||||||
chat_history=[session_message],
|
|
||||||
reply_message=None,
|
|
||||||
reply_reason="测试原因",
|
|
||||||
stream_id="session-legacy",
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
replyer_module.load_prompt = legacy_prompt_loader
|
|
||||||
|
|
||||||
assert len(request_messages) == 4
|
|
||||||
assert request_messages[0].role.value == "system"
|
|
||||||
assert request_messages[0].get_text_content() == "legacy prompt"
|
|
||||||
assert request_messages[1].role.value == "user"
|
|
||||||
assert request_messages[1].get_text_content() == "[Alice]你好"
|
|
||||||
assert request_messages[2].role.value == "user"
|
|
||||||
assert request_messages[2].get_text_content() == "[Bob]在吗"
|
|
||||||
assert request_messages[3].role.value == "user"
|
|
||||||
assert "当前时间:" in request_messages[3].get_text_content()
|
|
||||||
assert "【回复信息参考】" in request_messages[3].get_text_content()
|
|
||||||
assert "【最新推理】\n测试原因" in request_messages[3].get_text_content()
|
|
||||||
assert "请自然地回复。" in request_messages[3].get_text_content()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_reply_tool_puts_monitor_detail_into_metadata(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
fake_monitor_detail = {
|
|
||||||
"prompt_text": "reply prompt",
|
|
||||||
"reasoning_text": "reply reasoning",
|
|
||||||
"output_text": "reply output",
|
|
||||||
"metrics": {"model_name": "fake-model", "total_tokens": 10},
|
|
||||||
}
|
|
||||||
fake_reply_result = ReplyGenerationResult(
|
|
||||||
success=True,
|
|
||||||
completion=LLMCompletionResult(response_text="测试回复"),
|
|
||||||
metrics=GenerationMetrics(overall_ms=11.5),
|
|
||||||
monitor_detail=fake_monitor_detail,
|
|
||||||
)
|
|
||||||
|
|
||||||
class _FakeReplyer:
|
|
||||||
async def generate_reply_with_context(self, **kwargs: Any) -> tuple[bool, ReplyGenerationResult]:
|
|
||||||
del kwargs
|
|
||||||
return True, fake_reply_result
|
|
||||||
|
|
||||||
monkeypatch.setattr(reply_tool_module.replyer_manager, "get_replyer", lambda **kwargs: _FakeReplyer())
|
|
||||||
monkeypatch.setattr(reply_tool_module, "render_cli_message", lambda text: text)
|
|
||||||
|
|
||||||
target_message = SimpleNamespace(
|
|
||||||
message_id="msg-1",
|
|
||||||
message_info=SimpleNamespace(
|
|
||||||
user_info=SimpleNamespace(
|
|
||||||
user_cardname="测试用户",
|
|
||||||
user_nickname="测试用户",
|
|
||||||
user_id="user-1",
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
runtime = SimpleNamespace(
|
|
||||||
find_source_message_by_id=lambda message_id: target_message if message_id == "msg-1" else None,
|
|
||||||
log_prefix="[test]",
|
|
||||||
chat_stream=SimpleNamespace(platform=reply_tool_module.CLI_PLATFORM_NAME),
|
|
||||||
session_id="session-1",
|
|
||||||
_chat_history=[],
|
|
||||||
_clear_force_continue_until_reply=lambda: None,
|
|
||||||
_record_reply_sent=lambda: None,
|
|
||||||
run_sub_agent=None,
|
|
||||||
)
|
|
||||||
engine = SimpleNamespace(_get_runtime_manager=lambda: None)
|
|
||||||
tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime)
|
|
||||||
invocation = ToolInvocation(tool_name="reply", arguments={"msg_id": "msg-1", "set_quote": True})
|
|
||||||
|
|
||||||
result = await reply_tool_module.handle_tool(tool_ctx, invocation)
|
|
||||||
|
|
||||||
assert result.success is True
|
|
||||||
assert result.metadata["monitor_detail"] == fake_monitor_detail
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_send_emoji_tool_puts_monitor_detail_into_metadata(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
async def _fake_build_emoji_candidate_message(emojis: list[Any]) -> object:
|
|
||||||
assert emojis
|
|
||||||
return SimpleNamespace()
|
|
||||||
|
|
||||||
async def _fake_send_emoji_for_maisaka(**kwargs: Any) -> Any:
|
|
||||||
selected_emoji, matched_emotion = await kwargs["emoji_selector"](
|
|
||||||
kwargs["requested_emotion"],
|
|
||||||
kwargs["reasoning"],
|
|
||||||
kwargs["context_texts"],
|
|
||||||
2,
|
|
||||||
)
|
|
||||||
assert selected_emoji is not None
|
|
||||||
return SimpleNamespace(
|
|
||||||
success=True,
|
|
||||||
message="已发送表情包:开心",
|
|
||||||
emoji_base64="ZW1vamk=",
|
|
||||||
description="开心",
|
|
||||||
emotions=["开心", "可爱"],
|
|
||||||
matched_emotion=matched_emotion or "开心",
|
|
||||||
sent_message=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(send_emoji_tool_module, "_build_emoji_candidate_message", _fake_build_emoji_candidate_message)
|
|
||||||
monkeypatch.setattr(send_emoji_tool_module, "send_emoji_for_maisaka", _fake_send_emoji_for_maisaka)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
send_emoji_tool_module.emoji_manager,
|
|
||||||
"emojis",
|
|
||||||
[
|
|
||||||
SimpleNamespace(description="开心,可爱", emotion=["开心", "可爱"]),
|
|
||||||
SimpleNamespace(description="难过", emotion=["难过"]),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _fake_run_sub_agent(**kwargs: Any) -> Any:
|
|
||||||
del kwargs
|
|
||||||
return SimpleNamespace(
|
|
||||||
content='{"emoji_index": 1, "reason": "更贴合当前语气"}',
|
|
||||||
prompt_tokens=9,
|
|
||||||
completion_tokens=6,
|
|
||||||
total_tokens=15,
|
|
||||||
)
|
|
||||||
|
|
||||||
runtime = SimpleNamespace(
|
|
||||||
_chat_history=[],
|
|
||||||
log_prefix="[test]",
|
|
||||||
session_id="session-emoji",
|
|
||||||
run_sub_agent=_fake_run_sub_agent,
|
|
||||||
)
|
|
||||||
engine = SimpleNamespace(last_reasoning_content="用户刚刚表达了开心情绪")
|
|
||||||
tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime)
|
|
||||||
invocation = ToolInvocation(tool_name="send_emoji", arguments={"emotion": "开心"})
|
|
||||||
|
|
||||||
result = await send_emoji_tool_module.handle_tool(tool_ctx, invocation)
|
|
||||||
|
|
||||||
assert result.success is True
|
|
||||||
assert result.metadata["monitor_detail"]["prompt_text"]
|
|
||||||
assert result.metadata["monitor_detail"]["reasoning_text"] == "更贴合当前语气"
|
|
||||||
assert result.metadata["monitor_detail"]["metrics"]["total_tokens"] == 15
|
|
||||||
assert any(
|
|
||||||
section["title"] == "表情发送结果"
|
|
||||||
for section in result.metadata["monitor_detail"]["extra_sections"]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_emit_planner_finalized_broadcasts_new_protocol(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
captured: dict[str, Any] = {}
|
|
||||||
|
|
||||||
async def _fake_broadcast(event: str, data: dict[str, Any]) -> None:
|
|
||||||
captured["event"] = event
|
|
||||||
captured["data"] = data
|
|
||||||
|
|
||||||
monkeypatch.setattr("src.maisaka.monitor_events._broadcast", _fake_broadcast)
|
|
||||||
|
|
||||||
await emit_planner_finalized(
|
|
||||||
session_id="session-1",
|
|
||||||
cycle_id=3,
|
|
||||||
timing_request_messages=[{"role": "user", "content": "先看看要不要继续"}],
|
|
||||||
timing_selected_history_count=3,
|
|
||||||
timing_tool_count=1,
|
|
||||||
timing_action="continue",
|
|
||||||
timing_content="继续",
|
|
||||||
timing_tool_calls=[SimpleNamespace(call_id="timing-call-1", func_name="continue", args={})],
|
|
||||||
timing_tool_results=["- continue [成功]: 继续执行"],
|
|
||||||
timing_prompt_tokens=40,
|
|
||||||
timing_completion_tokens=5,
|
|
||||||
timing_total_tokens=45,
|
|
||||||
timing_duration_ms=11.2,
|
|
||||||
planner_request_messages=[{"role": "user", "content": "你好"}],
|
|
||||||
planner_selected_history_count=5,
|
|
||||||
planner_tool_count=2,
|
|
||||||
planner_content="先查询再回复",
|
|
||||||
planner_tool_calls=[SimpleNamespace(call_id="call-1", func_name="reply", args={"msg_id": "m1"})],
|
|
||||||
planner_prompt_tokens=100,
|
|
||||||
planner_completion_tokens=30,
|
|
||||||
planner_total_tokens=130,
|
|
||||||
planner_duration_ms=88.5,
|
|
||||||
tools=[
|
|
||||||
{
|
|
||||||
"tool_call_id": "call-1",
|
|
||||||
"tool_name": "reply",
|
|
||||||
"tool_args": {"msg_id": "m1"},
|
|
||||||
"success": True,
|
|
||||||
"duration_ms": 22.0,
|
|
||||||
"summary": "- reply [成功]: 已回复",
|
|
||||||
"detail": {"output_text": "测试回复"},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
time_records={"planner": 0.1, "tool_calls": 0.2},
|
|
||||||
agent_state="stop",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert captured["event"] == "planner.finalized"
|
|
||||||
payload = captured["data"]
|
|
||||||
assert payload["timing_gate"]["result"]["action"] == "continue"
|
|
||||||
assert payload["timing_gate"]["result"]["tool_results"] == ["- continue [成功]: 继续执行"]
|
|
||||||
assert payload["request"]["messages"][0]["content"] == "你好"
|
|
||||||
assert payload["request"]["tool_count"] == 2
|
|
||||||
assert payload["planner"]["tool_calls"][0]["id"] == "call-1"
|
|
||||||
assert payload["tools"][0]["detail"]["output_text"] == "测试回复"
|
|
||||||
assert payload["final_state"]["agent_state"] == "stop"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_emit_planner_finalized_supports_timing_only_cycle(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
captured: dict[str, Any] = {}
|
|
||||||
|
|
||||||
async def _fake_broadcast(event: str, data: dict[str, Any]) -> None:
|
|
||||||
captured["event"] = event
|
|
||||||
captured["data"] = data
|
|
||||||
|
|
||||||
monkeypatch.setattr("src.maisaka.monitor_events._broadcast", _fake_broadcast)
|
|
||||||
|
|
||||||
await emit_planner_finalized(
|
|
||||||
session_id="session-2",
|
|
||||||
cycle_id=7,
|
|
||||||
timing_request_messages=[{"role": "user", "content": "先别回"}],
|
|
||||||
timing_selected_history_count=2,
|
|
||||||
timing_tool_count=1,
|
|
||||||
timing_action="no_reply",
|
|
||||||
timing_content="当前不适合继续",
|
|
||||||
timing_tool_calls=[SimpleNamespace(call_id="timing-call-2", func_name="no_reply", args={})],
|
|
||||||
timing_tool_results=["- no_reply [成功]: 暂停当前对话"],
|
|
||||||
timing_prompt_tokens=18,
|
|
||||||
timing_completion_tokens=4,
|
|
||||||
timing_total_tokens=22,
|
|
||||||
timing_duration_ms=6.5,
|
|
||||||
planner_request_messages=None,
|
|
||||||
planner_selected_history_count=None,
|
|
||||||
planner_tool_count=None,
|
|
||||||
planner_content=None,
|
|
||||||
planner_tool_calls=None,
|
|
||||||
planner_prompt_tokens=None,
|
|
||||||
planner_completion_tokens=None,
|
|
||||||
planner_total_tokens=None,
|
|
||||||
planner_duration_ms=None,
|
|
||||||
tools=[],
|
|
||||||
time_records={"timing_gate": 0.02},
|
|
||||||
agent_state="stop",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert captured["event"] == "planner.finalized"
|
|
||||||
payload = captured["data"]
|
|
||||||
assert payload["timing_gate"]["result"]["action"] == "no_reply"
|
|
||||||
assert payload["planner"] is None
|
|
||||||
assert payload["request"] is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_reasoning_engine_build_tool_monitor_result_keeps_non_reply_tool_without_detail() -> None:
|
|
||||||
engine = object.__new__(MaisakaReasoningEngine)
|
|
||||||
tool_call = SimpleNamespace(call_id="call-2", func_name="query_memory")
|
|
||||||
invocation = ToolInvocation(tool_name="query_memory", arguments={"query": "Alice"})
|
|
||||||
result = ToolExecutionResult(tool_name="query_memory", success=True, content="查询成功")
|
|
||||||
|
|
||||||
tool_result = engine._build_tool_monitor_result(tool_call, invocation, result, duration_ms=18.6)
|
|
||||||
|
|
||||||
assert tool_result["tool_call_id"] == "call-2"
|
|
||||||
assert tool_result["tool_name"] == "query_memory"
|
|
||||||
assert tool_result["tool_args"] == {"query": "Alice"}
|
|
||||||
assert tool_result["detail"] is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_runtime_build_tool_detail_panels_renders_reply_monitor_detail() -> None:
|
|
||||||
runtime = object.__new__(MaisakaHeartFlowChatting)
|
|
||||||
runtime.session_id = "session-1"
|
|
||||||
panels = runtime._build_tool_detail_cards(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"tool_call_id": "call-reply-1",
|
|
||||||
"tool_name": "reply",
|
|
||||||
"tool_args": {"msg_id": "m1"},
|
|
||||||
"success": True,
|
|
||||||
"duration_ms": 20.5,
|
|
||||||
"summary": "- reply [成功]: 已回复",
|
|
||||||
"detail": {
|
|
||||||
"prompt_text": "reply prompt",
|
|
||||||
"reasoning_text": "reply reasoning",
|
|
||||||
"output_text": "reply output",
|
|
||||||
"metrics": {
|
|
||||||
"model_name": "fake-model",
|
|
||||||
"prompt_tokens": 10,
|
|
||||||
"completion_tokens": 5,
|
|
||||||
"total_tokens": 15,
|
|
||||||
"prompt_ms": 2.1,
|
|
||||||
"llm_ms": 18.4,
|
|
||||||
"overall_ms": 20.5,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
stage_title="工具调用",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(panels) == 1
|
|
||||||
assert isinstance(panels[0], Panel)
|
|
||||||
|
|
||||||
|
|
||||||
def test_runtime_filter_redundant_tool_results_keeps_only_non_detailed_summary() -> None:
|
|
||||||
filtered_results = MaisakaHeartFlowChatting._filter_redundant_tool_results(
|
|
||||||
tool_results=[
|
|
||||||
"- reply [成功]: 已回复",
|
|
||||||
"- query_memory [成功]: 查询到 2 条记录",
|
|
||||||
],
|
|
||||||
tool_detail_results=[
|
|
||||||
{
|
|
||||||
"summary": "- reply [成功]: 已回复",
|
|
||||||
"detail": {"output_text": "测试回复"},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert filtered_results == ["- query_memory [成功]: 查询到 2 条记录"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_runtime_build_tool_detail_panels_uses_prompt_access_panel(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
runtime = object.__new__(MaisakaHeartFlowChatting)
|
|
||||||
runtime.session_id = "session-link"
|
|
||||||
captured: dict[str, Any] = {}
|
|
||||||
|
|
||||||
def _fake_build_text_access_panel(content: str, **kwargs: Any) -> str:
|
|
||||||
captured["content"] = content
|
|
||||||
captured["kwargs"] = kwargs
|
|
||||||
return "PROMPT_LINK"
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"src.maisaka.runtime.PromptCLIVisualizer.build_text_access_panel",
|
|
||||||
_fake_build_text_access_panel,
|
|
||||||
)
|
|
||||||
|
|
||||||
panels = runtime._build_tool_detail_cards(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"tool_call_id": "call-reply-2",
|
|
||||||
"tool_name": "reply",
|
|
||||||
"tool_args": {"msg_id": "m2"},
|
|
||||||
"success": True,
|
|
||||||
"duration_ms": 12.0,
|
|
||||||
"summary": "- reply [成功]: 已回复",
|
|
||||||
"detail": {
|
|
||||||
"prompt_text": "reply prompt link",
|
|
||||||
"output_text": "reply output",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
stage_title="工具调用",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(panels) == 1
|
|
||||||
assert captured["content"] == "reply prompt link"
|
|
||||||
assert captured["kwargs"]["chat_id"] == "session-link"
|
|
||||||
assert captured["kwargs"]["request_kind"] == "replyer"
|
|
||||||
|
|
||||||
|
|
||||||
def test_runtime_build_tool_detail_panels_uses_emotion_prompt_access_panel(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
runtime = object.__new__(MaisakaHeartFlowChatting)
|
|
||||||
runtime.session_id = "session-emotion"
|
|
||||||
captured: dict[str, Any] = {}
|
|
||||||
|
|
||||||
def _fake_build_text_access_panel(content: str, **kwargs: Any) -> str:
|
|
||||||
captured["content"] = content
|
|
||||||
captured["kwargs"] = kwargs
|
|
||||||
return "EMOTION_PROMPT_LINK"
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"src.maisaka.runtime.PromptCLIVisualizer.build_text_access_panel",
|
|
||||||
_fake_build_text_access_panel,
|
|
||||||
)
|
|
||||||
|
|
||||||
panels = runtime._build_tool_detail_cards(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"tool_call_id": "call-emoji-1",
|
|
||||||
"tool_name": "send_emoji",
|
|
||||||
"tool_args": {"emotion": "开心"},
|
|
||||||
"success": True,
|
|
||||||
"duration_ms": 15.0,
|
|
||||||
"summary": "- send_emoji [成功]: 已发送表情包",
|
|
||||||
"detail": {
|
|
||||||
"prompt_text": "emotion prompt link",
|
|
||||||
"output_text": '{"emoji_index": 1}',
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
stage_title="工具调用",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(panels) == 1
|
|
||||||
assert captured["content"] == "emotion prompt link"
|
|
||||||
assert captured["kwargs"]["chat_id"] == "session-emotion"
|
|
||||||
assert captured["kwargs"]["request_kind"] == "emotion"
|
|
||||||
|
|
||||||
|
|
||||||
def test_runtime_build_tool_detail_cards_uses_structured_prompt_messages_with_images(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
runtime = object.__new__(MaisakaHeartFlowChatting)
|
|
||||||
runtime.session_id = "session-image"
|
|
||||||
captured: dict[str, Any] = {}
|
|
||||||
|
|
||||||
def _fake_build_prompt_access_panel(messages: list[Any], **kwargs: Any) -> str:
|
|
||||||
captured["messages"] = messages
|
|
||||||
captured["kwargs"] = kwargs
|
|
||||||
return "IMAGE_PROMPT_LINK"
|
|
||||||
|
|
||||||
def _fake_build_text_access_panel(content: str, **kwargs: Any) -> str:
|
|
||||||
captured["text_content"] = content
|
|
||||||
captured["text_kwargs"] = kwargs
|
|
||||||
return "TEXT_PROMPT_LINK"
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"src.maisaka.runtime.PromptCLIVisualizer.build_prompt_access_panel",
|
|
||||||
_fake_build_prompt_access_panel,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"src.maisaka.runtime.PromptCLIVisualizer.build_text_access_panel",
|
|
||||||
_fake_build_text_access_panel,
|
|
||||||
)
|
|
||||||
|
|
||||||
panels = runtime._build_tool_detail_cards(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"tool_call_id": "call-reply-image-1",
|
|
||||||
"tool_name": "reply",
|
|
||||||
"tool_args": {"msg_id": "m3"},
|
|
||||||
"success": True,
|
|
||||||
"duration_ms": 22.0,
|
|
||||||
"summary": "- reply [成功]: 已回复",
|
|
||||||
"detail": {
|
|
||||||
"prompt_text": "reply prompt image",
|
|
||||||
"request_messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": ["前缀文本", ["png", "ZmFrZQ=="]],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"output_text": "reply output",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
stage_title="工具调用",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(panels) == 1
|
|
||||||
assert "messages" in captured
|
|
||||||
assert "text_content" not in captured
|
|
||||||
assert captured["kwargs"]["chat_id"] == "session-image"
|
|
||||||
assert captured["kwargs"]["request_kind"] == "replyer"
|
|
||||||
|
|
||||||
|
|
||||||
def test_runtime_render_context_usage_panel_merges_timing_and_planner(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
runtime = object.__new__(MaisakaHeartFlowChatting)
|
|
||||||
runtime.session_id = "session-merged"
|
|
||||||
runtime.session_name = "测试聊天流"
|
|
||||||
runtime._max_context_size = 20
|
|
||||||
|
|
||||||
printed: list[Any] = []
|
|
||||||
monkeypatch.setattr("src.maisaka.runtime.console.print", lambda renderable: printed.append(renderable))
|
|
||||||
|
|
||||||
runtime._render_context_usage_panel(
|
|
||||||
cycle_id=12,
|
|
||||||
timing_selected_history_count=3,
|
|
||||||
timing_prompt_tokens=15,
|
|
||||||
timing_action="continue",
|
|
||||||
timing_response="继续执行",
|
|
||||||
planner_selected_history_count=5,
|
|
||||||
planner_prompt_tokens=42,
|
|
||||||
planner_response="先查询再回复",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(printed) == 1
|
|
||||||
outer_panel = printed[0]
|
|
||||||
assert isinstance(outer_panel, Panel)
|
|
||||||
renderables = list(outer_panel.renderable.renderables)
|
|
||||||
assert isinstance(renderables[0], Text)
|
|
||||||
assert "聊天流名称:测试聊天流" in renderables[0].plain
|
|
||||||
assert "聊天流ID:session-merged" in renderables[0].plain
|
|
||||||
assert len(renderables) == 3
|
|
||||||
@@ -1,339 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.core.tooling import ToolAvailabilityContext, ToolExecutionResult, ToolInvocation
|
|
||||||
from src.llm_models.payload_content.tool_option import ToolCall
|
|
||||||
from src.maisaka.builtin_tool import get_timing_tools
|
|
||||||
from src.maisaka.chat_loop_service import ChatResponse, MaisakaChatLoopService
|
|
||||||
from src.maisaka.context_messages import AssistantMessage, TIMING_GATE_INVALID_TOOL_HINT_SOURCE
|
|
||||||
from src.maisaka.reasoning_engine import MaisakaReasoningEngine
|
|
||||||
from src.maisaka.runtime import MaisakaHeartFlowChatting
|
|
||||||
|
|
||||||
|
|
||||||
def _build_chat_response(tool_calls: list[ToolCall]) -> ChatResponse:
|
|
||||||
return ChatResponse(
|
|
||||||
content="The model returned an invalid timing tool.",
|
|
||||||
tool_calls=tool_calls,
|
|
||||||
request_messages=[],
|
|
||||||
raw_message=AssistantMessage(
|
|
||||||
content="",
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
source_kind="perception",
|
|
||||||
),
|
|
||||||
selected_history_count=1,
|
|
||||||
tool_count=len(tool_calls),
|
|
||||||
prompt_tokens=10,
|
|
||||||
built_message_count=1,
|
|
||||||
completion_tokens=3,
|
|
||||||
total_tokens=13,
|
|
||||||
prompt_section=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_runtime_stub(*, is_group_chat: bool) -> SimpleNamespace:
|
|
||||||
return SimpleNamespace(
|
|
||||||
_force_next_timing_continue=False,
|
|
||||||
_chat_history=[],
|
|
||||||
session_id="test-session",
|
|
||||||
chat_stream=SimpleNamespace(
|
|
||||||
session_id="test-session",
|
|
||||||
stream_id="test-stream",
|
|
||||||
is_group_session=is_group_chat,
|
|
||||||
group_id="group-1" if is_group_chat else "",
|
|
||||||
user_id="user-1",
|
|
||||||
platform="qq",
|
|
||||||
),
|
|
||||||
_chat_loop_service=SimpleNamespace(build_prompt_template_context=lambda: {}),
|
|
||||||
log_prefix="[test]",
|
|
||||||
stopped=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_timing_gate_tools_expose_wait_only_in_private_chat() -> None:
|
|
||||||
private_tool_names = {
|
|
||||||
tool_definition["name"]
|
|
||||||
for tool_definition in get_timing_tools(ToolAvailabilityContext(is_group_chat=False))
|
|
||||||
}
|
|
||||||
group_tool_names = {
|
|
||||||
tool_definition["name"]
|
|
||||||
for tool_definition in get_timing_tools(ToolAvailabilityContext(is_group_chat=True))
|
|
||||||
}
|
|
||||||
|
|
||||||
assert private_tool_names == {"continue", "no_reply", "wait"}
|
|
||||||
assert group_tool_names == {"continue", "no_reply"}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_timing_gate_invalid_tool_defaults_to_no_reply(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
runtime = _build_runtime_stub(is_group_chat=True)
|
|
||||||
|
|
||||||
def _enter_stop_state() -> None:
|
|
||||||
runtime.stopped = True
|
|
||||||
|
|
||||||
runtime._enter_stop_state = _enter_stop_state
|
|
||||||
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
call_count = 0
|
|
||||||
|
|
||||||
async def _fake_timing_gate_sub_agent(**kwargs: object) -> ChatResponse:
|
|
||||||
nonlocal call_count
|
|
||||||
del kwargs
|
|
||||||
call_count += 1
|
|
||||||
return _build_chat_response([
|
|
||||||
ToolCall(call_id="invalid-timing-tool", func_name="finish", args={}),
|
|
||||||
])
|
|
||||||
|
|
||||||
async def _fail_invoke_tool_call(*args: object, **kwargs: object) -> None:
|
|
||||||
del args, kwargs
|
|
||||||
raise AssertionError("invalid timing tools must not be executed")
|
|
||||||
|
|
||||||
monkeypatch.setattr(engine, "_run_timing_gate_sub_agent", _fake_timing_gate_sub_agent)
|
|
||||||
monkeypatch.setattr(engine, "_invoke_tool_call", _fail_invoke_tool_call)
|
|
||||||
|
|
||||||
action, response, tool_results, tool_monitor_results = await engine._run_timing_gate(object()) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
assert action == "no_reply"
|
|
||||||
assert call_count == 3
|
|
||||||
assert response.tool_calls[0].func_name == "finish"
|
|
||||||
assert runtime.stopped is True
|
|
||||||
assert tool_monitor_results == []
|
|
||||||
assert len(runtime._chat_history) == 1
|
|
||||||
assert runtime._chat_history[0].source == TIMING_GATE_INVALID_TOOL_HINT_SOURCE
|
|
||||||
assert "finish" in runtime._chat_history[0].processed_plain_text
|
|
||||||
assert tool_results == [
|
|
||||||
"- retry [非法 Timing 工具]: 返回了 finish,将重试 (1/3)",
|
|
||||||
"- retry [非法 Timing 工具]: 返回了 finish,将重试 (2/3)",
|
|
||||||
"- no_reply [非法 Timing 工具]: 返回了 finish,已停止本轮并等待新消息",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_timing_gate_invalid_tool_retries_until_valid(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
runtime = _build_runtime_stub(is_group_chat=True)
|
|
||||||
|
|
||||||
def _enter_stop_state() -> None:
|
|
||||||
runtime.stopped = True
|
|
||||||
|
|
||||||
runtime._enter_stop_state = _enter_stop_state
|
|
||||||
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
|
|
||||||
responses = [
|
|
||||||
_build_chat_response([ToolCall(call_id="invalid-timing-tool", func_name="finish", args={})]),
|
|
||||||
_build_chat_response([ToolCall(call_id="valid-timing-tool", func_name="continue", args={})]),
|
|
||||||
]
|
|
||||||
|
|
||||||
async def _fake_timing_gate_sub_agent(**kwargs: object) -> ChatResponse:
|
|
||||||
del kwargs
|
|
||||||
return responses.pop(0)
|
|
||||||
|
|
||||||
async def _fake_invoke_tool_call(
|
|
||||||
tool_call: ToolCall,
|
|
||||||
latest_thought: str,
|
|
||||||
anchor_message: object,
|
|
||||||
*,
|
|
||||||
append_history: bool = True,
|
|
||||||
store_record: bool = True,
|
|
||||||
) -> tuple[ToolInvocation, ToolExecutionResult, None]:
|
|
||||||
del latest_thought, anchor_message, append_history, store_record
|
|
||||||
return (
|
|
||||||
ToolInvocation(tool_name=tool_call.func_name, call_id=tool_call.call_id),
|
|
||||||
ToolExecutionResult(
|
|
||||||
tool_name=tool_call.func_name,
|
|
||||||
success=True,
|
|
||||||
content="继续执行主流程",
|
|
||||||
metadata={"timing_action": "continue"},
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(engine, "_run_timing_gate_sub_agent", _fake_timing_gate_sub_agent)
|
|
||||||
monkeypatch.setattr(engine, "_invoke_tool_call", _fake_invoke_tool_call)
|
|
||||||
|
|
||||||
action, response, tool_results, tool_monitor_results = await engine._run_timing_gate(object()) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
assert action == "continue"
|
|
||||||
assert response.tool_calls[0].func_name == "continue"
|
|
||||||
assert runtime.stopped is False
|
|
||||||
assert len(runtime._chat_history) == 2
|
|
||||||
assert all(message.source != TIMING_GATE_INVALID_TOOL_HINT_SOURCE for message in runtime._chat_history)
|
|
||||||
assert tool_results == [
|
|
||||||
"- retry [非法 Timing 工具]: 返回了 finish,将重试 (1/3)",
|
|
||||||
"- continue [成功]: 继续执行主流程",
|
|
||||||
]
|
|
||||||
assert tool_monitor_results[0]["tool_name"] == "continue"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_timing_gate_group_chat_treats_wait_as_invalid(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
runtime = _build_runtime_stub(is_group_chat=True)
|
|
||||||
|
|
||||||
def _enter_stop_state() -> None:
|
|
||||||
runtime.stopped = True
|
|
||||||
|
|
||||||
runtime._enter_stop_state = _enter_stop_state
|
|
||||||
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
async def _fake_timing_gate_sub_agent(**kwargs: object) -> ChatResponse:
|
|
||||||
tool_definitions = kwargs["tool_definitions"]
|
|
||||||
assert {tool_definition["name"] for tool_definition in tool_definitions} == {"continue", "no_reply"}
|
|
||||||
return _build_chat_response([
|
|
||||||
ToolCall(call_id="disabled-wait", func_name="wait", args={"seconds": 3}),
|
|
||||||
])
|
|
||||||
|
|
||||||
async def _fail_invoke_tool_call(*args: object, **kwargs: object) -> None:
|
|
||||||
del args, kwargs
|
|
||||||
raise AssertionError("群聊中禁用的 wait 不应被执行")
|
|
||||||
|
|
||||||
monkeypatch.setattr(engine, "_run_timing_gate_sub_agent", _fake_timing_gate_sub_agent)
|
|
||||||
monkeypatch.setattr(engine, "_invoke_tool_call", _fail_invoke_tool_call)
|
|
||||||
|
|
||||||
action, _, tool_results, _ = await engine._run_timing_gate(object()) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
assert action == "no_reply"
|
|
||||||
assert runtime.stopped is True
|
|
||||||
assert tool_results[-1] == "- no_reply [非法 Timing 工具]: 返回了 wait,已停止本轮并等待新消息"
|
|
||||||
|
|
||||||
|
|
||||||
def test_timing_gate_invalid_tool_hint_keeps_only_latest() -> None:
|
|
||||||
old_hint = SimpleNamespace(source=TIMING_GATE_INVALID_TOOL_HINT_SOURCE)
|
|
||||||
runtime = SimpleNamespace(_chat_history=[old_hint])
|
|
||||||
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
engine._append_timing_gate_invalid_tool_hint("finish")
|
|
||||||
engine._append_timing_gate_invalid_tool_hint("reply")
|
|
||||||
|
|
||||||
assert len(runtime._chat_history) == 1
|
|
||||||
hint_message = runtime._chat_history[0]
|
|
||||||
assert hint_message.source == TIMING_GATE_INVALID_TOOL_HINT_SOURCE
|
|
||||||
assert "reply" in hint_message.processed_plain_text
|
|
||||||
assert "finish" not in hint_message.processed_plain_text
|
|
||||||
|
|
||||||
|
|
||||||
def test_timing_gate_invalid_tool_hint_only_visible_to_timing_gate() -> None:
|
|
||||||
runtime = SimpleNamespace(_chat_history=[])
|
|
||||||
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
|
|
||||||
engine._append_timing_gate_invalid_tool_hint("finish")
|
|
||||||
hint_message = runtime._chat_history[0]
|
|
||||||
|
|
||||||
timing_history = MaisakaChatLoopService._filter_history_for_request_kind(
|
|
||||||
[hint_message],
|
|
||||||
request_kind="timing_gate",
|
|
||||||
)
|
|
||||||
planner_history = MaisakaChatLoopService._filter_history_for_request_kind(
|
|
||||||
[hint_message],
|
|
||||||
request_kind="planner",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert timing_history == [hint_message]
|
|
||||||
assert planner_history == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_forced_timing_trigger_bypasses_message_frequency_threshold() -> None:
|
|
||||||
runtime = SimpleNamespace(
|
|
||||||
_STATE_WAIT="wait",
|
|
||||||
_agent_state="stop",
|
|
||||||
_message_turn_scheduled=False,
|
|
||||||
_internal_turn_queue=asyncio.Queue(),
|
|
||||||
_has_pending_messages=lambda: True,
|
|
||||||
_get_pending_message_count=lambda: 1,
|
|
||||||
_has_forced_timing_trigger=lambda: True,
|
|
||||||
_cancel_deferred_message_turn_task=lambda: None,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _fail_get_message_trigger_threshold() -> int:
|
|
||||||
raise AssertionError("@/提及必回不应被普通聊天频率阈值拦住")
|
|
||||||
|
|
||||||
runtime._get_message_trigger_threshold = _fail_get_message_trigger_threshold
|
|
||||||
|
|
||||||
MaisakaHeartFlowChatting._schedule_message_turn(runtime) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
assert runtime._message_turn_scheduled is True
|
|
||||||
assert runtime._internal_turn_queue.get_nowait() == "message"
|
|
||||||
|
|
||||||
|
|
||||||
def test_finish_tool_is_not_written_back_to_history() -> None:
|
|
||||||
finish_call = ToolCall(call_id="finish-call", func_name="finish", args={})
|
|
||||||
reply_call = ToolCall(call_id="reply-call", func_name="reply", args={})
|
|
||||||
assistant_message = AssistantMessage(
|
|
||||||
content="当前不需要继续回复。",
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
tool_calls=[finish_call, reply_call],
|
|
||||||
)
|
|
||||||
runtime = SimpleNamespace(_chat_history=[assistant_message])
|
|
||||||
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
engine._append_tool_execution_result(
|
|
||||||
finish_call,
|
|
||||||
ToolExecutionResult(
|
|
||||||
tool_name="finish",
|
|
||||||
success=True,
|
|
||||||
content="当前对话循环已结束本轮思考,等待新的消息到来。",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert runtime._chat_history == [assistant_message]
|
|
||||||
assert [tool_call.func_name for tool_call in assistant_message.tool_calls] == ["reply"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_finish_tool_removes_empty_assistant_history_message() -> None:
|
|
||||||
finish_call = ToolCall(call_id="finish-call", func_name="finish", args={})
|
|
||||||
assistant_message = AssistantMessage(
|
|
||||||
content="",
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
tool_calls=[finish_call],
|
|
||||||
)
|
|
||||||
runtime = SimpleNamespace(_chat_history=[assistant_message])
|
|
||||||
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
engine._append_tool_execution_result(
|
|
||||||
finish_call,
|
|
||||||
ToolExecutionResult(tool_name="finish", success=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert runtime._chat_history == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_timing_gate_head_trim_keeps_short_history() -> None:
|
|
||||||
messages = [
|
|
||||||
AssistantMessage(content="第一条消息", timestamp=datetime.now()),
|
|
||||||
AssistantMessage(content="第二条消息", timestamp=datetime.now()),
|
|
||||||
]
|
|
||||||
|
|
||||||
trimmed_messages = MaisakaHeartFlowChatting._drop_head_context_messages(
|
|
||||||
messages,
|
|
||||||
drop_context_count=3,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert trimmed_messages == messages
|
|
||||||
|
|
||||||
|
|
||||||
def test_timing_gate_head_trim_keeps_history_within_config_limit() -> None:
|
|
||||||
messages = [
|
|
||||||
AssistantMessage(content=f"消息 {index}", timestamp=datetime.now())
|
|
||||||
for index in range(10)
|
|
||||||
]
|
|
||||||
|
|
||||||
trimmed_messages = MaisakaHeartFlowChatting._drop_head_context_messages(
|
|
||||||
messages,
|
|
||||||
drop_context_count=7,
|
|
||||||
trim_threshold_context_count=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert trimmed_messages == messages
|
|
||||||
|
|
||||||
|
|
||||||
def test_timing_gate_head_trim_applies_after_config_limit_exceeded() -> None:
|
|
||||||
messages = [
|
|
||||||
AssistantMessage(content=f"消息 {index}", timestamp=datetime.now())
|
|
||||||
for index in range(11)
|
|
||||||
]
|
|
||||||
|
|
||||||
trimmed_messages = MaisakaHeartFlowChatting._drop_head_context_messages(
|
|
||||||
messages,
|
|
||||||
drop_context_count=7,
|
|
||||||
trim_threshold_context_count=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert trimmed_messages == messages[7:]
|
|
||||||
@@ -1,170 +0,0 @@
|
|||||||
"""消息网关运行时状态同步测试。"""
|
|
||||||
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.platform_io.manager import PlatformIOManager
|
|
||||||
from src.platform_io.types import RouteKey
|
|
||||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
|
||||||
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
|
|
||||||
|
|
||||||
|
|
||||||
def _make_request(method: str, plugin_id: str, payload: Dict[str, Any]) -> Envelope:
|
|
||||||
"""构造一个 RPC 请求信封。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
method: RPC 方法名。
|
|
||||||
plugin_id: 目标插件 ID。
|
|
||||||
payload: 请求载荷。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Envelope: 标准 RPC 请求信封。
|
|
||||||
"""
|
|
||||||
|
|
||||||
return Envelope(
|
|
||||||
request_id=1,
|
|
||||||
message_type=MessageType.REQUEST,
|
|
||||||
method=method,
|
|
||||||
plugin_id=plugin_id,
|
|
||||||
payload=payload,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_message_gateway_runtime_state_binds_send_and_receive_routes(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
"""消息网关就绪后应同时绑定发送表和接收表。"""
|
|
||||||
|
|
||||||
import src.plugin_runtime.host.supervisor as supervisor_module
|
|
||||||
|
|
||||||
platform_io_manager = PlatformIOManager()
|
|
||||||
monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
|
|
||||||
|
|
||||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
|
||||||
register_response = await supervisor._handle_register_plugin(
|
|
||||||
_make_request(
|
|
||||||
"plugin.register_components",
|
|
||||||
"napcat_plugin",
|
|
||||||
{
|
|
||||||
"plugin_id": "napcat_plugin",
|
|
||||||
"plugin_version": "1.0.0",
|
|
||||||
"components": [
|
|
||||||
{
|
|
||||||
"name": "napcat_gateway",
|
|
||||||
"component_type": "MESSAGE_GATEWAY",
|
|
||||||
"plugin_id": "napcat_plugin",
|
|
||||||
"metadata": {
|
|
||||||
"route_type": "duplex",
|
|
||||||
"platform": "qq",
|
|
||||||
"protocol": "napcat",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"capabilities_required": [],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert register_response.error is None
|
|
||||||
response = await supervisor._handle_update_message_gateway_state(
|
|
||||||
_make_request(
|
|
||||||
"host.update_message_gateway_state",
|
|
||||||
"napcat_plugin",
|
|
||||||
{
|
|
||||||
"gateway_name": "napcat_gateway",
|
|
||||||
"ready": True,
|
|
||||||
"platform": "qq",
|
|
||||||
"account_id": "10001",
|
|
||||||
"scope": "primary",
|
|
||||||
"metadata": {},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.error is None
|
|
||||||
assert response.payload["accepted"] is True
|
|
||||||
|
|
||||||
send_bindings = platform_io_manager.send_route_table.resolve_bindings(
|
|
||||||
RouteKey(platform="qq", account_id="10001", scope="primary")
|
|
||||||
)
|
|
||||||
receive_bindings = platform_io_manager.receive_route_table.resolve_bindings(
|
|
||||||
RouteKey(platform="qq", account_id="10001", scope="primary")
|
|
||||||
)
|
|
||||||
|
|
||||||
assert [binding.driver_id for binding in send_bindings] == ["gateway:napcat_plugin:napcat_gateway"]
|
|
||||||
assert [binding.driver_id for binding in receive_bindings] == ["gateway:napcat_plugin:napcat_gateway"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_message_gateway_runtime_state_unbinds_routes_when_not_ready(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
"""消息网关断开后应撤销发送表和接收表中的绑定。"""
|
|
||||||
|
|
||||||
import src.plugin_runtime.host.supervisor as supervisor_module
|
|
||||||
|
|
||||||
platform_io_manager = PlatformIOManager()
|
|
||||||
monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
|
|
||||||
|
|
||||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
|
||||||
await supervisor._handle_register_plugin(
|
|
||||||
_make_request(
|
|
||||||
"plugin.register_components",
|
|
||||||
"napcat_plugin",
|
|
||||||
{
|
|
||||||
"plugin_id": "napcat_plugin",
|
|
||||||
"plugin_version": "1.0.0",
|
|
||||||
"components": [
|
|
||||||
{
|
|
||||||
"name": "napcat_gateway",
|
|
||||||
"component_type": "MESSAGE_GATEWAY",
|
|
||||||
"plugin_id": "napcat_plugin",
|
|
||||||
"metadata": {
|
|
||||||
"route_type": "duplex",
|
|
||||||
"platform": "qq",
|
|
||||||
"protocol": "napcat",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"capabilities_required": [],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
await supervisor._handle_update_message_gateway_state(
|
|
||||||
_make_request(
|
|
||||||
"host.update_message_gateway_state",
|
|
||||||
"napcat_plugin",
|
|
||||||
{
|
|
||||||
"gateway_name": "napcat_gateway",
|
|
||||||
"ready": True,
|
|
||||||
"platform": "qq",
|
|
||||||
"account_id": "10001",
|
|
||||||
"scope": "primary",
|
|
||||||
"metadata": {},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
response = await supervisor._handle_update_message_gateway_state(
|
|
||||||
_make_request(
|
|
||||||
"host.update_message_gateway_state",
|
|
||||||
"napcat_plugin",
|
|
||||||
{
|
|
||||||
"gateway_name": "napcat_gateway",
|
|
||||||
"ready": False,
|
|
||||||
"platform": "qq",
|
|
||||||
"account_id": "",
|
|
||||||
"scope": "",
|
|
||||||
"metadata": {},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.error is None
|
|
||||||
assert response.payload["accepted"] is True
|
|
||||||
assert platform_io_manager.send_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == []
|
|
||||||
assert (
|
|
||||||
platform_io_manager.receive_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == []
|
|
||||||
)
|
|
||||||
@@ -1,883 +0,0 @@
|
|||||||
"""NapCat 插件与新 SDK 对接测试。"""
|
|
||||||
|
|
||||||
from importlib import import_module, util
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Tuple
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
|
||||||
PLUGINS_ROOT = PROJECT_ROOT / "plugins"
|
|
||||||
PLUGIN_TEMPLATE_ROOT = PROJECT_ROOT / "plugin-templates"
|
|
||||||
SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
|
|
||||||
NAPCAT_PLUGIN_DIR = PLUGINS_ROOT / "MaiBot-Napcat-Adapter"
|
|
||||||
NAPCAT_TEMPLATE_DIR = PLUGIN_TEMPLATE_ROOT / "MaiBot-Napcat-Adapter"
|
|
||||||
NAPCAT_TEST_MODULE = "_test_napcat_adapter"
|
|
||||||
|
|
||||||
for import_path in (str(SDK_ROOT),):
|
|
||||||
if import_path not in sys.path:
|
|
||||||
sys.path.insert(0, import_path)
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeGatewayCapability:
|
|
||||||
"""用于捕获消息网关状态上报的测试替身。"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""初始化测试替身。"""
|
|
||||||
|
|
||||||
self.calls: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
async def update_state(
|
|
||||||
self,
|
|
||||||
gateway_name: str,
|
|
||||||
*,
|
|
||||||
ready: bool,
|
|
||||||
platform: str = "",
|
|
||||||
account_id: str = "",
|
|
||||||
scope: str = "",
|
|
||||||
metadata: Dict[str, Any] | None = None,
|
|
||||||
) -> bool:
|
|
||||||
"""记录一次状态上报请求。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
gateway_name: 网关组件名称。
|
|
||||||
ready: 当前是否就绪。
|
|
||||||
platform: 平台名称。
|
|
||||||
account_id: 账号 ID。
|
|
||||||
scope: 路由作用域。
|
|
||||||
metadata: 附加元数据。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 始终返回 ``True``,模拟 Host 接受状态更新。
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.calls.append(
|
|
||||||
{
|
|
||||||
"gateway_name": gateway_name,
|
|
||||||
"ready": ready,
|
|
||||||
"platform": platform,
|
|
||||||
"account_id": account_id,
|
|
||||||
"scope": scope,
|
|
||||||
"metadata": metadata or {},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeNapCatQueryService:
|
|
||||||
"""用于驱动 NapCat 入站编解码测试的查询服务替身。"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
forward_payloads: Dict[str, Any] | None = None,
|
|
||||||
group_member_payloads: Dict[tuple[str, str], Dict[str, Any] | None] | None = None,
|
|
||||||
stranger_payloads: Dict[str, Dict[str, Any] | None] | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""初始化查询服务替身。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
forward_payloads: 预置的合并转发响应映射。
|
|
||||||
group_member_payloads: 预置的群成员资料映射。
|
|
||||||
stranger_payloads: 预置的陌生人资料映射。
|
|
||||||
"""
|
|
||||||
self._forward_payloads = forward_payloads or {}
|
|
||||||
self._group_member_payloads = group_member_payloads or {}
|
|
||||||
self._stranger_payloads = stranger_payloads or {}
|
|
||||||
|
|
||||||
async def download_binary(self, url: str) -> bytes | None:
|
|
||||||
"""模拟下载远程二进制资源。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
url: 资源地址。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bytes | None: 测试中默认不返回二进制内容。
|
|
||||||
"""
|
|
||||||
del url
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_message_detail(self, message_id: str) -> Dict[str, Any] | None:
|
|
||||||
"""模拟获取消息详情。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_id: 消息 ID。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any] | None: 测试中默认不返回详情。
|
|
||||||
"""
|
|
||||||
del message_id
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_forward_message(
|
|
||||||
self,
|
|
||||||
message_id: str | None = None,
|
|
||||||
forward_id: str | None = None,
|
|
||||||
) -> Any:
|
|
||||||
"""模拟获取合并转发消息详情。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_id: 转发消息 ID。
|
|
||||||
forward_id: 兼容字段 ``id``。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: 预置的合并转发消息详情。
|
|
||||||
"""
|
|
||||||
return self._forward_payloads.get(forward_id or message_id or "")
|
|
||||||
|
|
||||||
async def get_group_member_info(
|
|
||||||
self,
|
|
||||||
group_id: str,
|
|
||||||
user_id: str,
|
|
||||||
no_cache: bool = True,
|
|
||||||
) -> Dict[str, Any] | None:
|
|
||||||
"""模拟获取群成员资料。"""
|
|
||||||
del no_cache
|
|
||||||
return self._group_member_payloads.get((group_id, user_id))
|
|
||||||
|
|
||||||
async def get_stranger_info(self, user_id: str, no_cache: bool = False) -> Dict[str, Any] | None:
|
|
||||||
"""模拟获取 QQ 昵称资料。"""
|
|
||||||
del no_cache
|
|
||||||
return self._stranger_payloads.get(user_id)
|
|
||||||
|
|
||||||
async def get_record_detail(
|
|
||||||
self,
|
|
||||||
file_name: str | None = None,
|
|
||||||
file_id: str | None = None,
|
|
||||||
out_format: str = "wav",
|
|
||||||
) -> Dict[str, Any] | None:
|
|
||||||
"""模拟获取语音详情。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_name: 文件名。
|
|
||||||
file_id: 文件 ID。
|
|
||||||
out_format: 输出格式。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any] | None: 测试中默认不返回语音详情。
|
|
||||||
"""
|
|
||||||
del file_name
|
|
||||||
del file_id
|
|
||||||
del out_format
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeNapCatActionService:
|
|
||||||
"""用于驱动 NapCat 查询服务测试的动作服务替身。"""
|
|
||||||
|
|
||||||
def __init__(self, response_data: Any) -> None:
|
|
||||||
"""初始化动作服务替身。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response_data: 预置的 ``safe_call_action_data`` 返回值。
|
|
||||||
"""
|
|
||||||
self._response_data = response_data
|
|
||||||
self.action_calls: List[Dict[str, Any]] = []
|
|
||||||
self.action_data_calls: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
async def safe_call_action_data(self, action_name: str, params: Dict[str, Any]) -> Any:
|
|
||||||
"""模拟安全调用 OneBot 动作。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action_name: 动作名称。
|
|
||||||
params: 动作参数。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: 预置返回值。
|
|
||||||
"""
|
|
||||||
self.action_data_calls.append({"action_name": action_name, "params": dict(params)})
|
|
||||||
return self._response_data
|
|
||||||
|
|
||||||
async def call_action(self, action_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""模拟调用 OneBot 动作并记录参数。"""
|
|
||||||
|
|
||||||
self.action_calls.append({"action_name": action_name, "params": dict(params)})
|
|
||||||
return {"status": "ok", "retcode": 0, "data": {}}
|
|
||||||
|
|
||||||
|
|
||||||
def _load_napcat_sdk_modules() -> Tuple[Any, Any, Any, Any]:
|
|
||||||
"""动态加载 NapCat 插件测试所需的模块。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[Any, Any, Any, Any]:
|
|
||||||
依次返回常量模块、配置模块、插件模块和运行时状态模块。
|
|
||||||
"""
|
|
||||||
|
|
||||||
plugin_dir = NAPCAT_PLUGIN_DIR if NAPCAT_PLUGIN_DIR.is_dir() else NAPCAT_TEMPLATE_DIR
|
|
||||||
|
|
||||||
if NAPCAT_TEST_MODULE not in sys.modules:
|
|
||||||
plugin_path = plugin_dir / "plugin.py"
|
|
||||||
spec = util.spec_from_file_location(
|
|
||||||
NAPCAT_TEST_MODULE,
|
|
||||||
plugin_path,
|
|
||||||
submodule_search_locations=[str(plugin_dir)],
|
|
||||||
)
|
|
||||||
if spec is None or spec.loader is None:
|
|
||||||
raise ImportError(f"无法为 NapCat 插件创建模块规格: {plugin_path}")
|
|
||||||
|
|
||||||
module = util.module_from_spec(spec)
|
|
||||||
sys.modules[NAPCAT_TEST_MODULE] = module
|
|
||||||
try:
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
except Exception:
|
|
||||||
sys.modules.pop(NAPCAT_TEST_MODULE, None)
|
|
||||||
raise
|
|
||||||
|
|
||||||
return (
|
|
||||||
import_module(f"{NAPCAT_TEST_MODULE}.constants"),
|
|
||||||
import_module(f"{NAPCAT_TEST_MODULE}.config"),
|
|
||||||
import_module(f"{NAPCAT_TEST_MODULE}.plugin"),
|
|
||||||
import_module(f"{NAPCAT_TEST_MODULE}.runtime_state"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_napcat_sdk_symbols() -> Tuple[Any, Any, Any, Any]:
|
|
||||||
"""动态加载 NapCat 插件测试所需的符号。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[Any, Any, Any, Any]:
|
|
||||||
依次返回网关名常量、配置类、插件类和运行时状态管理器类。
|
|
||||||
"""
|
|
||||||
|
|
||||||
constants_module, config_module, plugin_module, runtime_state_module = _load_napcat_sdk_modules()
|
|
||||||
return (
|
|
||||||
constants_module.NAPCAT_GATEWAY_NAME,
|
|
||||||
config_module.NapCatServerConfig,
|
|
||||||
plugin_module.NapCatAdapterPlugin,
|
|
||||||
runtime_state_module.NapCatRuntimeStateManager,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_napcat_inbound_codec_cls() -> Any:
|
|
||||||
"""动态加载 NapCat 入站编解码器类。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: ``NapCatInboundCodec`` 类对象。
|
|
||||||
"""
|
|
||||||
_load_napcat_sdk_modules()
|
|
||||||
codec_module = import_module(f"{NAPCAT_TEST_MODULE}.codecs.inbound.message_codec")
|
|
||||||
return codec_module.NapCatInboundCodec
|
|
||||||
|
|
||||||
|
|
||||||
def _load_napcat_query_service_cls() -> Any:
|
|
||||||
"""动态加载 NapCat 查询服务类。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: ``NapCatQueryService`` 类对象。
|
|
||||||
"""
|
|
||||||
_load_napcat_sdk_modules()
|
|
||||||
query_service_module = import_module(f"{NAPCAT_TEST_MODULE}.services.query_service")
|
|
||||||
return query_service_module.NapCatQueryService
|
|
||||||
|
|
||||||
|
|
||||||
def test_napcat_plugin_collects_duplex_message_gateway() -> None:
|
|
||||||
"""NapCat 插件应声明新的双工消息网关组件。"""
|
|
||||||
|
|
||||||
napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
|
|
||||||
plugin = napcat_plugin_cls()
|
|
||||||
components = plugin.get_components()
|
|
||||||
gateway_components = [
|
|
||||||
component
|
|
||||||
for component in components
|
|
||||||
if component.get("type") == "MESSAGE_GATEWAY"
|
|
||||||
]
|
|
||||||
|
|
||||||
assert len(gateway_components) == 1
|
|
||||||
gateway_component = gateway_components[0]
|
|
||||||
assert gateway_component["name"] == napcat_gateway_name
|
|
||||||
assert gateway_component["metadata"]["route_type"] == "duplex"
|
|
||||||
assert gateway_component["metadata"]["platform"] == "qq"
|
|
||||||
assert gateway_component["metadata"]["protocol"] == "napcat"
|
|
||||||
|
|
||||||
|
|
||||||
def test_napcat_plugin_uses_sdk_config_model() -> None:
|
|
||||||
"""NapCat 插件应声明 SDK 配置模型并暴露默认配置与 Schema。"""
|
|
||||||
|
|
||||||
constants_module, _config_module, plugin_module, _runtime_state_module = _load_napcat_sdk_modules()
|
|
||||||
plugin = plugin_module.NapCatAdapterPlugin()
|
|
||||||
|
|
||||||
default_config = plugin.get_default_config()
|
|
||||||
schema = plugin.get_webui_config_schema(plugin_id="maibot-team.napcat-adapter")
|
|
||||||
|
|
||||||
assert default_config["plugin"]["config_version"] == constants_module.SUPPORTED_CONFIG_VERSION
|
|
||||||
assert default_config["chat"]["ban_qq_bot"] is False
|
|
||||||
assert default_config["filters"]["ignore_self_message"] is True
|
|
||||||
assert schema["plugin_id"] == "maibot-team.napcat-adapter"
|
|
||||||
assert schema["sections"]["chat"]["fields"]["group_list"]["type"] == "array"
|
|
||||||
assert schema["sections"]["chat"]["fields"]["group_list_type"]["choices"] == ["whitelist", "blacklist"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_napcat_plugin_normalizes_legacy_config_values() -> None:
|
|
||||||
"""NapCat 插件应兼容旧配置字段并输出规范化结果。"""
|
|
||||||
|
|
||||||
constants_module, _config_module, plugin_module, _runtime_state_module = _load_napcat_sdk_modules()
|
|
||||||
plugin = plugin_module.NapCatAdapterPlugin()
|
|
||||||
|
|
||||||
plugin.set_plugin_config(
|
|
||||||
{
|
|
||||||
"plugin": {"enabled": True, "config_version": constants_module.SUPPORTED_CONFIG_VERSION},
|
|
||||||
"connection": {
|
|
||||||
"access_token": "secret-token",
|
|
||||||
"heartbeat_sec": "45",
|
|
||||||
"ws_url": "ws://10.0.0.8:3012/onebot/v11/ws",
|
|
||||||
},
|
|
||||||
"chat": {
|
|
||||||
"ban_qq_bot": True,
|
|
||||||
"ban_user_id": ["42", 42, ""],
|
|
||||||
"group_list": [123, " 456 ", None, "123"],
|
|
||||||
"group_list_type": "whitelist",
|
|
||||||
"private_list": "invalid",
|
|
||||||
"private_list_type": "unexpected",
|
|
||||||
},
|
|
||||||
"filters": {"ignore_self_message": True},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
config_data = plugin.get_plugin_config_data()
|
|
||||||
|
|
||||||
assert "connection" not in config_data
|
|
||||||
assert config_data["plugin"]["config_version"] == constants_module.SUPPORTED_CONFIG_VERSION
|
|
||||||
assert config_data["napcat_server"]["host"] == "10.0.0.8"
|
|
||||||
assert config_data["napcat_server"]["port"] == 3012
|
|
||||||
assert config_data["napcat_server"]["token"] == "secret-token"
|
|
||||||
assert config_data["napcat_server"]["heartbeat_interval"] == 45.0
|
|
||||||
assert config_data["chat"]["group_list"] == ["123", "456"]
|
|
||||||
assert config_data["chat"]["private_list"] == []
|
|
||||||
assert config_data["chat"]["private_list_type"] == constants_module.DEFAULT_CHAT_LIST_TYPE
|
|
||||||
assert plugin.config.napcat_server.build_ws_url() == "ws://10.0.0.8:3012"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_runtime_state_reports_via_gateway_capability() -> None:
|
|
||||||
"""NapCat 运行时状态应通过新的消息网关能力上报。"""
|
|
||||||
|
|
||||||
napcat_gateway_name, napcat_server_config_cls, _napcat_plugin_cls, runtime_state_cls = _load_napcat_sdk_symbols()
|
|
||||||
gateway_capability = _FakeGatewayCapability()
|
|
||||||
runtime_state_manager = runtime_state_cls(
|
|
||||||
gateway_capability=gateway_capability,
|
|
||||||
logger=logging.getLogger("test.napcat_adapter"),
|
|
||||||
gateway_name=napcat_gateway_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
connected = await runtime_state_manager.report_connected(
|
|
||||||
"10001",
|
|
||||||
napcat_server_config_cls(connection_id="primary"),
|
|
||||||
)
|
|
||||||
await runtime_state_manager.report_disconnected()
|
|
||||||
|
|
||||||
assert connected is True
|
|
||||||
assert gateway_capability.calls[0]["gateway_name"] == napcat_gateway_name
|
|
||||||
assert gateway_capability.calls[0]["ready"] is True
|
|
||||||
assert gateway_capability.calls[0]["platform"] == "qq"
|
|
||||||
assert gateway_capability.calls[0]["account_id"] == "10001"
|
|
||||||
assert gateway_capability.calls[0]["scope"] == "primary"
|
|
||||||
assert gateway_capability.calls[1]["gateway_name"] == napcat_gateway_name
|
|
||||||
assert gateway_capability.calls[1]["ready"] is False
|
|
||||||
assert gateway_capability.calls[1]["platform"] == "qq"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_napcat_plugin_send_result_contains_message_id_echo_callback() -> None:
|
|
||||||
"""NapCat 插件发送成功后应显式返回消息 ID 回调数据。"""
|
|
||||||
|
|
||||||
_napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
|
|
||||||
plugin = napcat_plugin_cls()
|
|
||||||
|
|
||||||
class _FakeOutboundCodec:
|
|
||||||
"""用于测试的出站编码器替身。"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def build_outbound_action(message: Dict[str, Any], route: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
|
||||||
"""返回固定动作与参数。"""
|
|
||||||
|
|
||||||
del message
|
|
||||||
del route
|
|
||||||
return "send_msg", {"message": "hello"}
|
|
||||||
|
|
||||||
class _FakeTransport:
|
|
||||||
"""用于测试的传输层替身。"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def call_action(action_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""返回带平台消息 ID 的成功响应。"""
|
|
||||||
|
|
||||||
del action_name
|
|
||||||
del params
|
|
||||||
return {
|
|
||||||
"status": "ok",
|
|
||||||
"data": {
|
|
||||||
"message_id": "platform-message-id",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
plugin._require_runtime_bundle = lambda: SimpleNamespace( # type: ignore[method-assign]
|
|
||||||
outbound_codec=_FakeOutboundCodec(),
|
|
||||||
transport=_FakeTransport(),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await plugin.handle_napcat_gateway(
|
|
||||||
message={"message_id": "internal-message-id"},
|
|
||||||
route={},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result["success"] is True
|
|
||||||
assert result["external_message_id"] == "platform-message-id"
|
|
||||||
assert result["metadata"]["adapter_callbacks"] == [
|
|
||||||
{
|
|
||||||
"name": "message_id_echo",
|
|
||||||
"payload": {
|
|
||||||
"content": {
|
|
||||||
"type": "echo",
|
|
||||||
"echo": "internal-message-id",
|
|
||||||
"actual_id": "platform-message-id",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_inbound_codec_parses_forward_nodes_from_legacy_message_field() -> None:
|
|
||||||
"""入站编解码器应兼容旧版 ``sender + message`` 转发节点结构。"""
|
|
||||||
|
|
||||||
inbound_codec_cls = _load_napcat_inbound_codec_cls()
|
|
||||||
codec = inbound_codec_cls(
|
|
||||||
logger=logging.getLogger("test.napcat_adapter.forward_legacy"),
|
|
||||||
query_service=_FakeNapCatQueryService(
|
|
||||||
forward_payloads={
|
|
||||||
"forward-1": {
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"sender": {"user_id": "10001", "nickname": "张三", "card": "群名片"},
|
|
||||||
"message_id": "node-1",
|
|
||||||
"message": [{"type": "text", "data": {"text": "第一条转发"}}],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
segments, is_at = await codec.convert_segments(
|
|
||||||
{"message": [{"type": "forward", "data": {"id": "forward-1"}}]},
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert is_at is False
|
|
||||||
assert len(segments) == 1
|
|
||||||
assert segments[0]["type"] == "forward"
|
|
||||||
assert segments[0]["data"][0]["user_id"] == "10001"
|
|
||||||
assert segments[0]["data"][0]["user_nickname"] == "张三"
|
|
||||||
assert segments[0]["data"][0]["user_cardname"] == "群名片"
|
|
||||||
assert segments[0]["data"][0]["content"] == [{"type": "text", "data": "第一条转发"}]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_inbound_codec_parses_nested_inline_forward_content() -> None:
|
|
||||||
"""入站编解码器应支持内联 ``content`` 形式的嵌套合并转发。"""
|
|
||||||
|
|
||||||
inbound_codec_cls = _load_napcat_inbound_codec_cls()
|
|
||||||
codec = inbound_codec_cls(
|
|
||||||
logger=logging.getLogger("test.napcat_adapter.forward_nested"),
|
|
||||||
query_service=_FakeNapCatQueryService(
|
|
||||||
forward_payloads={
|
|
||||||
"forward-outer": {
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"sender": {"user_id": "10001", "nickname": "张三"},
|
|
||||||
"message_id": "node-outer",
|
|
||||||
"message": [
|
|
||||||
{
|
|
||||||
"type": "forward",
|
|
||||||
"data": {
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"sender": {"user_id": "10002", "nickname": "李四"},
|
|
||||||
"message_id": "node-inner",
|
|
||||||
"message": [{"type": "text", "data": {"text": "内层消息"}}],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
segments, _ = await codec.convert_segments(
|
|
||||||
{"message": [{"type": "forward", "data": {"id": "forward-outer"}}]},
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(segments) == 1
|
|
||||||
assert segments[0]["type"] == "forward"
|
|
||||||
outer_content = segments[0]["data"][0]["content"]
|
|
||||||
assert len(outer_content) == 1
|
|
||||||
assert outer_content[0]["type"] == "forward"
|
|
||||||
nested_nodes = outer_content[0]["data"]
|
|
||||||
assert nested_nodes[0]["user_id"] == "10002"
|
|
||||||
assert nested_nodes[0]["user_nickname"] == "李四"
|
|
||||||
assert nested_nodes[0]["content"] == [{"type": "text", "data": "内层消息"}]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_inbound_codec_resolves_at_to_group_cardname() -> None:
|
|
||||||
"""入站编解码器应优先将 ``at`` 解析为群昵称。"""
|
|
||||||
|
|
||||||
inbound_codec_cls = _load_napcat_inbound_codec_cls()
|
|
||||||
codec = inbound_codec_cls(
|
|
||||||
logger=logging.getLogger("test.napcat_adapter.at_cardname"),
|
|
||||||
query_service=_FakeNapCatQueryService(
|
|
||||||
group_member_payloads={
|
|
||||||
("12345", "1206069534"): {
|
|
||||||
"nickname": "QQ昵称",
|
|
||||||
"card": "群昵称",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
message_dict = await codec.build_message_dict(
|
|
||||||
payload={
|
|
||||||
"message_type": "group",
|
|
||||||
"group_id": "12345",
|
|
||||||
"message_id": "msg-1",
|
|
||||||
"message": [{"type": "at", "data": {"qq": "1206069534"}}],
|
|
||||||
"sender": {"user_id": "10001", "nickname": "发送者"},
|
|
||||||
"time": 1710000000,
|
|
||||||
},
|
|
||||||
self_id="20001",
|
|
||||||
sender_user_id="10001",
|
|
||||||
sender={"user_id": "10001", "nickname": "发送者"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert message_dict["processed_plain_text"] == "@群昵称"
|
|
||||||
assert message_dict["raw_message"] == [
|
|
||||||
{
|
|
||||||
"type": "at",
|
|
||||||
"data": {
|
|
||||||
"target_user_id": "1206069534",
|
|
||||||
"target_user_nickname": "QQ昵称",
|
|
||||||
"target_user_cardname": "群昵称",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_inbound_codec_falls_back_to_qq_nickname_when_group_cardname_is_empty() -> None:
|
|
||||||
"""入站编解码器在群昵称为空时应回退到 QQ 昵称。"""
|
|
||||||
|
|
||||||
inbound_codec_cls = _load_napcat_inbound_codec_cls()
|
|
||||||
codec = inbound_codec_cls(
|
|
||||||
logger=logging.getLogger("test.napcat_adapter.at_nickname"),
|
|
||||||
query_service=_FakeNapCatQueryService(
|
|
||||||
group_member_payloads={
|
|
||||||
("12345", "1206069534"): {
|
|
||||||
"nickname": "QQ昵称",
|
|
||||||
"card": "",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
message_dict = await codec.build_message_dict(
|
|
||||||
payload={
|
|
||||||
"message_type": "group",
|
|
||||||
"group_id": "12345",
|
|
||||||
"message_id": "msg-2",
|
|
||||||
"message": [{"type": "at", "data": {"qq": "1206069534"}}],
|
|
||||||
"sender": {"user_id": "10001", "nickname": "发送者"},
|
|
||||||
"time": 1710000000,
|
|
||||||
},
|
|
||||||
self_id="20001",
|
|
||||||
sender_user_id="10001",
|
|
||||||
sender={"user_id": "10001", "nickname": "发送者"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert message_dict["processed_plain_text"] == "@QQ昵称"
|
|
||||||
assert message_dict["raw_message"] == [
|
|
||||||
{
|
|
||||||
"type": "at",
|
|
||||||
"data": {
|
|
||||||
"target_user_id": "1206069534",
|
|
||||||
"target_user_nickname": "QQ昵称",
|
|
||||||
"target_user_cardname": None,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_inbound_codec_falls_back_to_stranger_nickname_when_group_profile_is_missing() -> None:
|
|
||||||
"""入站编解码器在群资料缺失时应继续回退到 QQ 昵称。"""
|
|
||||||
|
|
||||||
inbound_codec_cls = _load_napcat_inbound_codec_cls()
|
|
||||||
codec = inbound_codec_cls(
|
|
||||||
logger=logging.getLogger("test.napcat_adapter.at_stranger_nickname"),
|
|
||||||
query_service=_FakeNapCatQueryService(
|
|
||||||
group_member_payloads={("12345", "1206069534"): None},
|
|
||||||
stranger_payloads={"1206069534": {"nickname": "QQ昵称"}},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
message_dict = await codec.build_message_dict(
|
|
||||||
payload={
|
|
||||||
"message_type": "group",
|
|
||||||
"group_id": "12345",
|
|
||||||
"message_id": "msg-3",
|
|
||||||
"message": [{"type": "at", "data": {"qq": "1206069534"}}],
|
|
||||||
"sender": {"user_id": "10001", "nickname": "发送者"},
|
|
||||||
"time": 1710000000,
|
|
||||||
},
|
|
||||||
self_id="20001",
|
|
||||||
sender_user_id="10001",
|
|
||||||
sender={"user_id": "10001", "nickname": "发送者"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert message_dict["processed_plain_text"] == "@QQ昵称"
|
|
||||||
assert message_dict["raw_message"] == [
|
|
||||||
{
|
|
||||||
"type": "at",
|
|
||||||
"data": {
|
|
||||||
"target_user_id": "1206069534",
|
|
||||||
"target_user_nickname": "QQ昵称",
|
|
||||||
"target_user_cardname": None,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_service_normalizes_forward_payload_list() -> None:
|
|
||||||
"""查询服务应兼容 ``get_forward_msg`` 直接返回节点列表。"""
|
|
||||||
|
|
||||||
query_service_cls = _load_napcat_query_service_cls()
|
|
||||||
query_service = query_service_cls(
|
|
||||||
action_service=_FakeNapCatActionService(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"sender": {"user_id": "10001", "nickname": "张三"},
|
|
||||||
"message_id": "node-1",
|
|
||||||
"message": [{"type": "text", "data": {"text": "列表返回"}}],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
),
|
|
||||||
logger=logging.getLogger("test.napcat_adapter.query_service"),
|
|
||||||
)
|
|
||||||
|
|
||||||
forward_payload = await query_service.get_forward_message("forward-1")
|
|
||||||
|
|
||||||
assert forward_payload == {
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"sender": {"user_id": "10001", "nickname": "张三"},
|
|
||||||
"message_id": "node-1",
|
|
||||||
"message": [{"type": "text", "data": {"text": "列表返回"}}],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_service_supports_official_no_cache_for_get_stranger_info() -> None:
|
|
||||||
"""查询服务应按官方字段下发 ``no_cache``。"""
|
|
||||||
|
|
||||||
action_service = _FakeNapCatActionService({"nickname": "测试用户"})
|
|
||||||
query_service_cls = _load_napcat_query_service_cls()
|
|
||||||
query_service = query_service_cls(
|
|
||||||
action_service=action_service,
|
|
||||||
logger=logging.getLogger("test.napcat_adapter.query_service.stranger"),
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = await query_service.get_stranger_info("10001", no_cache=True)
|
|
||||||
|
|
||||||
assert payload == {"nickname": "测试用户"}
|
|
||||||
assert action_service.action_data_calls[-1] == {
|
|
||||||
"action_name": "get_stranger_info",
|
|
||||||
"params": {"user_id": "10001", "no_cache": True},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_service_supports_official_forward_id_alias() -> None:
|
|
||||||
"""查询服务应兼容官方 ``id`` 字段调用 ``get_forward_msg``。"""
|
|
||||||
|
|
||||||
action_service = _FakeNapCatActionService({"messages": []})
|
|
||||||
query_service_cls = _load_napcat_query_service_cls()
|
|
||||||
query_service = query_service_cls(
|
|
||||||
action_service=action_service,
|
|
||||||
logger=logging.getLogger("test.napcat_adapter.query_service.forward_alias"),
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = await query_service.get_forward_message(forward_id="forward-alias")
|
|
||||||
|
|
||||||
assert payload == {"messages": []}
|
|
||||||
assert action_service.action_data_calls[-1] == {
|
|
||||||
"action_name": "get_forward_msg",
|
|
||||||
"params": {"id": "forward-alias"},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_service_supports_custom_out_format_for_get_record() -> None:
|
|
||||||
"""查询服务应按官方字段下发自定义 ``out_format``。"""
|
|
||||||
|
|
||||||
action_service = _FakeNapCatActionService({"file": "voice.mp3"})
|
|
||||||
query_service_cls = _load_napcat_query_service_cls()
|
|
||||||
query_service = query_service_cls(
|
|
||||||
action_service=action_service,
|
|
||||||
logger=logging.getLogger("test.napcat_adapter.query_service.record"),
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = await query_service.get_record_detail(file_id="record-1", out_format="mp3")
|
|
||||||
|
|
||||||
assert payload == {"file": "voice.mp3"}
|
|
||||||
assert action_service.action_data_calls[-1] == {
|
|
||||||
"action_name": "get_record",
|
|
||||||
"params": {"file_id": "record-1", "out_format": "mp3"},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_service_supports_target_id_for_send_poke() -> None:
|
|
||||||
"""查询服务应按官方字段下发 ``target_id``。"""
|
|
||||||
|
|
||||||
action_service = _FakeNapCatActionService(None)
|
|
||||||
query_service_cls = _load_napcat_query_service_cls()
|
|
||||||
query_service = query_service_cls(
|
|
||||||
action_service=action_service,
|
|
||||||
logger=logging.getLogger("test.napcat_adapter.query_service.poke"),
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await query_service.send_poke(user_id=10001, group_id=20002, target_id=30003)
|
|
||||||
|
|
||||||
assert response["status"] == "ok"
|
|
||||||
assert action_service.action_calls[-1] == {
|
|
||||||
"action_name": "send_poke",
|
|
||||||
"params": {"user_id": 10001, "group_id": 20002, "target_id": 30003},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_public_api_send_poke_supports_official_fields_and_legacy_alias() -> None:
|
|
||||||
"""公开 API 应同时兼容官方字段和旧版 ``qq_id`` 别名。"""
|
|
||||||
|
|
||||||
_napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
|
|
||||||
plugin = napcat_plugin_cls()
|
|
||||||
captured: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
class _SpyQueryService:
|
|
||||||
async def send_poke(
|
|
||||||
self,
|
|
||||||
user_id: int,
|
|
||||||
group_id: int | None = None,
|
|
||||||
target_id: int | None = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
captured.append(
|
|
||||||
{
|
|
||||||
"user_id": user_id,
|
|
||||||
"group_id": group_id,
|
|
||||||
"target_id": target_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return {"status": "ok", "data": {}}
|
|
||||||
|
|
||||||
plugin._query_service = _SpyQueryService()
|
|
||||||
plugin._ensure_runtime_components = lambda: None # type: ignore[method-assign]
|
|
||||||
|
|
||||||
await plugin.api_send_poke(user_id="10001", group_id="20002", target_id="30003")
|
|
||||||
await plugin.api_send_poke(qq_id="40004")
|
|
||||||
|
|
||||||
assert captured == [
|
|
||||||
{"user_id": 10001, "group_id": 20002, "target_id": 30003},
|
|
||||||
{"user_id": 40004, "group_id": None, "target_id": None},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_public_api_get_forward_msg_and_get_record_support_official_fields() -> None:
|
|
||||||
"""公开 API 应接受官方 ``id`` 和 ``out_format`` 等字段。"""
|
|
||||||
|
|
||||||
_napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
|
|
||||||
plugin = napcat_plugin_cls()
|
|
||||||
captured: Dict[str, Dict[str, Any]] = {}
|
|
||||||
|
|
||||||
class _SpyQueryService:
|
|
||||||
async def get_forward_message(
|
|
||||||
self,
|
|
||||||
message_id: str | None = None,
|
|
||||||
forward_id: str | None = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
captured["forward"] = {"message_id": message_id, "forward_id": forward_id}
|
|
||||||
return {"messages": []}
|
|
||||||
|
|
||||||
async def get_record_detail(
|
|
||||||
self,
|
|
||||||
file_name: str | None = None,
|
|
||||||
file_id: str | None = None,
|
|
||||||
out_format: str = "wav",
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
captured["record"] = {
|
|
||||||
"file_name": file_name,
|
|
||||||
"file_id": file_id,
|
|
||||||
"out_format": out_format,
|
|
||||||
}
|
|
||||||
return {"file_id": file_id or "record-1"}
|
|
||||||
|
|
||||||
plugin._query_service = _SpyQueryService()
|
|
||||||
plugin._ensure_runtime_components = lambda: None # type: ignore[method-assign]
|
|
||||||
|
|
||||||
forward_payload = await plugin.api_get_forward_msg(id="forward-1")
|
|
||||||
record_payload = await plugin.api_get_record(file_id="record-1", out_format="mp3")
|
|
||||||
|
|
||||||
assert forward_payload == {"messages": []}
|
|
||||||
assert record_payload == {"file_id": "record-1"}
|
|
||||||
assert captured["forward"] == {"message_id": None, "forward_id": "forward-1"}
|
|
||||||
assert captured["record"] == {
|
|
||||||
"file_name": None,
|
|
||||||
"file_id": "record-1",
|
|
||||||
"out_format": "mp3",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_public_api_send_poke_rejects_conflicting_alias_values() -> None:
|
|
||||||
"""公开 ``send_poke`` API 应拒绝互相冲突的别名值。"""
|
|
||||||
|
|
||||||
_napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
|
|
||||||
plugin = napcat_plugin_cls()
|
|
||||||
plugin._ensure_runtime_components = lambda: None # type: ignore[method-assign]
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="user_id 与 qq_id 不能同时传递不同的值"):
|
|
||||||
await plugin.api_send_poke(user_id="10001", qq_id="20002")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_public_api_get_forward_msg_rejects_conflicting_fields() -> None:
|
|
||||||
"""公开 ``get_forward_msg`` API 应拒绝冲突的双字段调用。"""
|
|
||||||
|
|
||||||
_napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
|
|
||||||
plugin = napcat_plugin_cls()
|
|
||||||
plugin._ensure_runtime_components = lambda: None # type: ignore[method-assign]
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="message_id 与 id 不能同时传递不同的值"):
|
|
||||||
await plugin.api_get_forward_msg(message_id="forward-a", id="forward-b")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_public_api_get_record_requires_file_or_file_id() -> None:
|
|
||||||
"""公开 ``get_record`` API 至少需要一个官方定位字段。"""
|
|
||||||
|
|
||||||
_napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
|
|
||||||
plugin = napcat_plugin_cls()
|
|
||||||
plugin._ensure_runtime_components = lambda: None # type: ignore[method-assign]
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="file 或 file_id 至少提供一个"):
|
|
||||||
await plugin.api_get_record()
|
|
||||||
@@ -1,376 +0,0 @@
|
|||||||
"""NapCat 历史补拉与恢复状态测试。"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from importlib import import_module, util
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
|
||||||
PLUGINS_ROOT = PROJECT_ROOT / "plugins"
|
|
||||||
PLUGIN_TEMPLATE_ROOT = PROJECT_ROOT / "plugin-templates"
|
|
||||||
SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
|
|
||||||
NAPCAT_PLUGIN_DIR = PLUGINS_ROOT / "MaiBot-Napcat-Adapter"
|
|
||||||
NAPCAT_TEMPLATE_DIR = PLUGIN_TEMPLATE_ROOT / "MaiBot-Napcat-Adapter"
|
|
||||||
NAPCAT_TEST_MODULE = "_test_napcat_adapter_history_recovery"
|
|
||||||
|
|
||||||
for import_path in (str(SDK_ROOT),):
|
|
||||||
if import_path not in sys.path:
|
|
||||||
sys.path.insert(0, import_path)
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeGatewayCapability:
|
|
||||||
"""用于测试入站注入的网关替身。"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""初始化测试替身。"""
|
|
||||||
|
|
||||||
self.calls: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
async def route_message(
|
|
||||||
self,
|
|
||||||
gateway_name: str,
|
|
||||||
message: Dict[str, Any],
|
|
||||||
*,
|
|
||||||
route_metadata: Dict[str, Any] | None = None,
|
|
||||||
external_message_id: str = "",
|
|
||||||
dedupe_key: str = "",
|
|
||||||
) -> bool:
|
|
||||||
"""记录入站注入请求并始终模拟成功。"""
|
|
||||||
|
|
||||||
self.calls.append(
|
|
||||||
{
|
|
||||||
"gateway_name": gateway_name,
|
|
||||||
"message": dict(message),
|
|
||||||
"route_metadata": dict(route_metadata or {}),
|
|
||||||
"external_message_id": external_message_id,
|
|
||||||
"dedupe_key": dedupe_key,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_napcat_plugin_dir() -> Path:
|
|
||||||
"""返回当前测试可用的 NapCat 插件目录。"""
|
|
||||||
|
|
||||||
if NAPCAT_PLUGIN_DIR.is_dir():
|
|
||||||
return NAPCAT_PLUGIN_DIR
|
|
||||||
return NAPCAT_TEMPLATE_DIR
|
|
||||||
|
|
||||||
|
|
||||||
def _load_napcat_module(module_suffix: str) -> Any:
|
|
||||||
"""动态加载 NapCat 测试模块。"""
|
|
||||||
|
|
||||||
plugin_dir = _resolve_napcat_plugin_dir()
|
|
||||||
if NAPCAT_TEST_MODULE not in sys.modules:
|
|
||||||
plugin_path = plugin_dir / "plugin.py"
|
|
||||||
spec = util.spec_from_file_location(
|
|
||||||
NAPCAT_TEST_MODULE,
|
|
||||||
plugin_path,
|
|
||||||
submodule_search_locations=[str(plugin_dir)],
|
|
||||||
)
|
|
||||||
if spec is None or spec.loader is None:
|
|
||||||
raise ImportError(f"无法为 NapCat 插件创建模块规格: {plugin_path}")
|
|
||||||
|
|
||||||
module = util.module_from_spec(spec)
|
|
||||||
sys.modules[NAPCAT_TEST_MODULE] = module
|
|
||||||
try:
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
except Exception:
|
|
||||||
sys.modules.pop(NAPCAT_TEST_MODULE, None)
|
|
||||||
raise
|
|
||||||
|
|
||||||
return import_module(f"{NAPCAT_TEST_MODULE}.{module_suffix}")
|
|
||||||
|
|
||||||
|
|
||||||
def _load_history_recovery_store_cls() -> Any:
|
|
||||||
"""动态加载历史恢复状态仓库类。"""
|
|
||||||
|
|
||||||
return _load_napcat_module("services.history_recovery_store").NapCatHistoryRecoveryStore
|
|
||||||
|
|
||||||
|
|
||||||
def _load_query_service_cls() -> Any:
|
|
||||||
"""动态加载查询服务类。"""
|
|
||||||
|
|
||||||
return _load_napcat_module("services.query_service").NapCatQueryService
|
|
||||||
|
|
||||||
|
|
||||||
def _load_router_cls() -> Any:
|
|
||||||
"""动态加载事件路由器类。"""
|
|
||||||
|
|
||||||
return _load_napcat_module("runtime.router").NapCatEventRouter
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeActionService:
|
|
||||||
"""用于查询服务的动作服务替身。"""
|
|
||||||
|
|
||||||
def __init__(self, response_data: Any) -> None:
|
|
||||||
"""初始化动作服务替身。"""
|
|
||||||
|
|
||||||
self._response_data = response_data
|
|
||||||
self.action_data_calls: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
async def safe_call_action_data(self, action_name: str, params: Dict[str, Any]) -> Any:
|
|
||||||
"""记录安全查询动作。"""
|
|
||||||
|
|
||||||
self.action_data_calls.append({"action_name": action_name, "params": dict(params)})
|
|
||||||
return self._response_data
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_history_recovery_store_persists_checkpoint_and_seen_state(tmp_path: Path) -> None:
|
|
||||||
"""历史恢复状态仓库应持久化 checkpoint 与已补拉标记。"""
|
|
||||||
|
|
||||||
store_cls = _load_history_recovery_store_cls()
|
|
||||||
store = store_cls(
|
|
||||||
logger=logging.getLogger("test.napcat.history_store"),
|
|
||||||
storage_path=tmp_path / "history.sqlite3",
|
|
||||||
)
|
|
||||||
|
|
||||||
await store.load()
|
|
||||||
await store.record_checkpoint(
|
|
||||||
account_id="10001",
|
|
||||||
scope="primary",
|
|
||||||
chat_type="group",
|
|
||||||
chat_id="20001",
|
|
||||||
message_id="msg-2",
|
|
||||||
message_time=200.0,
|
|
||||||
message_seq=2,
|
|
||||||
)
|
|
||||||
await store.record_checkpoint(
|
|
||||||
account_id="10001",
|
|
||||||
scope="primary",
|
|
||||||
chat_type="group",
|
|
||||||
chat_id="20001",
|
|
||||||
message_id="msg-1",
|
|
||||||
message_time=100.0,
|
|
||||||
message_seq=1,
|
|
||||||
)
|
|
||||||
await store.mark_recovered_message_seen(
|
|
||||||
account_id="10001",
|
|
||||||
scope="primary",
|
|
||||||
chat_type="group",
|
|
||||||
chat_id="20001",
|
|
||||||
external_message_id="history-1",
|
|
||||||
)
|
|
||||||
|
|
||||||
checkpoints = await store.list_checkpoints("10001", scope="primary")
|
|
||||||
|
|
||||||
assert len(checkpoints) == 1
|
|
||||||
assert checkpoints[0].last_message_id == "msg-2"
|
|
||||||
assert checkpoints[0].last_message_seq == 2
|
|
||||||
assert (
|
|
||||||
await store.has_recovered_message_seen(
|
|
||||||
account_id="10001",
|
|
||||||
scope="primary",
|
|
||||||
chat_type="group",
|
|
||||||
chat_id="20001",
|
|
||||||
external_message_id="history-1",
|
|
||||||
)
|
|
||||||
is True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_service_wraps_group_and_friend_history_actions() -> None:
|
|
||||||
"""查询服务应按官方动作名封装历史消息接口。"""
|
|
||||||
|
|
||||||
query_service_cls = _load_query_service_cls()
|
|
||||||
action_service = _FakeActionService([{"message_id": "msg-1"}])
|
|
||||||
query_service = query_service_cls(
|
|
||||||
action_service=action_service,
|
|
||||||
logger=logging.getLogger("test.napcat.history_query"),
|
|
||||||
)
|
|
||||||
|
|
||||||
group_payload = await query_service.get_group_message_history("20001", message_seq=123, count=10)
|
|
||||||
friend_payload = await query_service.get_friend_message_history("30001", count=5, reverse_order=True)
|
|
||||||
|
|
||||||
assert group_payload == [{"message_id": "msg-1"}]
|
|
||||||
assert friend_payload == [{"message_id": "msg-1"}]
|
|
||||||
assert action_service.action_data_calls == [
|
|
||||||
{
|
|
||||||
"action_name": "get_group_msg_history",
|
|
||||||
"params": {"group_id": "20001", "count": 10, "reverse_order": False, "message_seq": 123},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"action_name": "get_friend_msg_history",
|
|
||||||
"params": {"user_id": "30001", "count": 5, "reverse_order": True},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_router_recover_recent_history_reinjects_messages_in_order(tmp_path: Path) -> None:
|
|
||||||
"""重连补拉应按时间顺序将历史消息重新注入原入站路径。"""
|
|
||||||
|
|
||||||
history_store_cls = _load_history_recovery_store_cls()
|
|
||||||
router_cls = _load_router_cls()
|
|
||||||
gateway_capability = _FakeGatewayCapability()
|
|
||||||
router = router_cls(
|
|
||||||
gateway_capability=gateway_capability,
|
|
||||||
logger=logging.getLogger("test.napcat.history_router"),
|
|
||||||
gateway_name="napcat_gateway",
|
|
||||||
load_settings=lambda: SimpleNamespace(
|
|
||||||
napcat_server=SimpleNamespace(connection_id="primary", heartbeat_interval=30.0),
|
|
||||||
filters=SimpleNamespace(ignore_self_message=True),
|
|
||||||
chat=SimpleNamespace(ban_qq_bot=False),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
history_store = history_store_cls(
|
|
||||||
logger=logging.getLogger("test.napcat.history_router.store"),
|
|
||||||
storage_path=tmp_path / "router.sqlite3",
|
|
||||||
)
|
|
||||||
await history_store.load()
|
|
||||||
await history_store.record_checkpoint(
|
|
||||||
account_id="10001",
|
|
||||||
scope="primary",
|
|
||||||
chat_type="group",
|
|
||||||
chat_id="20001",
|
|
||||||
message_id="msg-1",
|
|
||||||
message_time=100.0,
|
|
||||||
message_seq=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
history_calls: List[Dict[str, Any]] = []
|
|
||||||
history_payloads = [
|
|
||||||
{
|
|
||||||
"post_type": "message",
|
|
||||||
"message_type": "group",
|
|
||||||
"self_id": "10001",
|
|
||||||
"group_id": "20001",
|
|
||||||
"user_id": "30002",
|
|
||||||
"message_id": "msg-3",
|
|
||||||
"message_seq": 12,
|
|
||||||
"time": 102,
|
|
||||||
"message": [{"type": "text", "data": {"text": "第三条"}}],
|
|
||||||
"sender": {"user_id": "30002", "nickname": "用户二"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"post_type": "message",
|
|
||||||
"message_type": "group",
|
|
||||||
"self_id": "10001",
|
|
||||||
"group_id": "20001",
|
|
||||||
"user_id": "30001",
|
|
||||||
"message_id": "msg-2",
|
|
||||||
"message_seq": 11,
|
|
||||||
"time": 101,
|
|
||||||
"message": [{"type": "text", "data": {"text": "第二条"}}],
|
|
||||||
"sender": {"user_id": "30001", "nickname": "用户一"},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
class _FakeQueryService:
|
|
||||||
async def get_group_message_history(
|
|
||||||
self,
|
|
||||||
group_id: str,
|
|
||||||
*,
|
|
||||||
message_seq: int | None = None,
|
|
||||||
count: int = 20,
|
|
||||||
reverse_order: bool = False,
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
history_calls.append(
|
|
||||||
{
|
|
||||||
"group_id": group_id,
|
|
||||||
"message_seq": message_seq,
|
|
||||||
"count": count,
|
|
||||||
"reverse_order": reverse_order,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return list(history_payloads)
|
|
||||||
|
|
||||||
async def get_friend_message_history(self, user_id: str, **kwargs: Any) -> List[Dict[str, Any]]:
|
|
||||||
del user_id
|
|
||||||
del kwargs
|
|
||||||
return []
|
|
||||||
|
|
||||||
class _FakeInboundCodec:
|
|
||||||
@staticmethod
|
|
||||||
async def build_message_dict(
|
|
||||||
payload: Dict[str, Any],
|
|
||||||
self_id: str,
|
|
||||||
sender_user_id: str,
|
|
||||||
sender: Dict[str, Any],
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
del self_id
|
|
||||||
del sender_user_id
|
|
||||||
del sender
|
|
||||||
return {
|
|
||||||
"message_id": str(payload["message_id"]),
|
|
||||||
"platform": "qq",
|
|
||||||
"timestamp": str(float(payload["time"])),
|
|
||||||
"message_info": {
|
|
||||||
"user_info": {"user_id": str(payload["user_id"]), "user_nickname": "测试用户"},
|
|
||||||
"group_info": {"group_id": str(payload["group_id"]), "group_name": "测试群"},
|
|
||||||
"additional_config": {},
|
|
||||||
},
|
|
||||||
"raw_message": [{"type": "text", "data": str(payload["message"][0]["data"]["text"])}],
|
|
||||||
"processed_plain_text": str(payload["message"][0]["data"]["text"]),
|
|
||||||
"display_message": str(payload["message"][0]["data"]["text"]),
|
|
||||||
"is_mentioned": False,
|
|
||||||
"is_at": False,
|
|
||||||
"is_emoji": False,
|
|
||||||
"is_picture": False,
|
|
||||||
"is_command": False,
|
|
||||||
"is_notify": False,
|
|
||||||
"session_id": "",
|
|
||||||
}
|
|
||||||
|
|
||||||
router.bind_runtime(
|
|
||||||
SimpleNamespace(
|
|
||||||
runtime_state=SimpleNamespace(report_connected=lambda *args, **kwargs: _noop_async(), report_disconnected=_noop_async),
|
|
||||||
chat_filter=SimpleNamespace(is_inbound_chat_allowed=lambda *args, **kwargs: True),
|
|
||||||
official_bot_guard=SimpleNamespace(
|
|
||||||
should_reject=lambda *args, **kwargs: _return_false_async(),
|
|
||||||
clear_cache=lambda: None,
|
|
||||||
),
|
|
||||||
inbound_codec=_FakeInboundCodec(),
|
|
||||||
history_recovery_store=history_store,
|
|
||||||
query_service=_FakeQueryService(),
|
|
||||||
heartbeat_monitor=SimpleNamespace(start=_noop_async, stop=_noop_async),
|
|
||||||
ban_tracker=SimpleNamespace(start=_noop_async, stop=_noop_async, record_notice=_noop_async),
|
|
||||||
notice_codec=SimpleNamespace(handle_meta_event=_noop_async, build_notice_message_dict=_return_none_async),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
await router._recover_recent_history(self_id="10001", scope="primary")
|
|
||||||
|
|
||||||
assert history_calls == [
|
|
||||||
{
|
|
||||||
"group_id": "20001",
|
|
||||||
"message_seq": 10,
|
|
||||||
"count": 20,
|
|
||||||
"reverse_order": False,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
assert [call["external_message_id"] for call in gateway_capability.calls] == ["msg-2", "msg-3"]
|
|
||||||
assert [call["message"]["message_id"] for call in gateway_capability.calls] == ["msg-2", "msg-3"]
|
|
||||||
|
|
||||||
|
|
||||||
async def _noop_async(*args: Any, **kwargs: Any) -> None:
|
|
||||||
"""无操作异步函数。"""
|
|
||||||
|
|
||||||
del args
|
|
||||||
del kwargs
|
|
||||||
|
|
||||||
|
|
||||||
async def _return_false_async(*args: Any, **kwargs: Any) -> bool:
|
|
||||||
"""返回 ``False`` 的异步测试替身。"""
|
|
||||||
|
|
||||||
del args
|
|
||||||
del kwargs
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def _return_none_async(*args: Any, **kwargs: Any) -> None:
|
|
||||||
"""返回 ``None`` 的异步测试替身。"""
|
|
||||||
|
|
||||||
del args
|
|
||||||
del kwargs
|
|
||||||
return None
|
|
||||||
@@ -1,164 +0,0 @@
|
|||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.config.model_configs import APIProvider, ReasoningParseMode, ToolArgumentParseMode
|
|
||||||
from src.llm_models.model_client.openai_client import (
|
|
||||||
_OpenAIStreamAccumulator,
|
|
||||||
_build_reasoning_key,
|
|
||||||
_default_normal_response_parser,
|
|
||||||
_parse_tool_arguments,
|
|
||||||
_sanitize_messages_for_toolless_request,
|
|
||||||
)
|
|
||||||
from src.llm_models.payload_content.message import Message, RoleType, TextMessagePart
|
|
||||||
from src.llm_models.payload_content.tool_option import ToolCall
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("parse_mode", list(ToolArgumentParseMode))
|
|
||||||
def test_parse_tool_arguments_treats_blank_arguments_as_empty_dict(parse_mode: ToolArgumentParseMode) -> None:
|
|
||||||
assert _parse_tool_arguments("", parse_mode, None) == {}
|
|
||||||
assert _parse_tool_arguments(" ", parse_mode, None) == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_normal_response_parser_accepts_empty_string_arguments_for_parameterless_tool() -> None:
|
|
||||||
response = SimpleNamespace(
|
|
||||||
choices=[
|
|
||||||
SimpleNamespace(
|
|
||||||
finish_reason="tool_calls",
|
|
||||||
message=SimpleNamespace(
|
|
||||||
content=None,
|
|
||||||
tool_calls=[
|
|
||||||
SimpleNamespace(
|
|
||||||
id="finish-call",
|
|
||||||
type="function",
|
|
||||||
function=SimpleNamespace(name="finish", arguments=""),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
usage=None,
|
|
||||||
model="glm-5.1",
|
|
||||||
)
|
|
||||||
|
|
||||||
api_response, usage_record = _default_normal_response_parser(
|
|
||||||
response,
|
|
||||||
reasoning_parse_mode=ReasoningParseMode.AUTO,
|
|
||||||
tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
|
|
||||||
reasoning_key=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(api_response.tool_calls) == 1
|
|
||||||
assert api_response.tool_calls[0].func_name == "finish"
|
|
||||||
assert api_response.tool_calls[0].args == {}
|
|
||||||
assert usage_record is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_sanitize_messages_for_toolless_request_drops_assistant_tool_call_without_parts() -> None:
|
|
||||||
messages = [
|
|
||||||
Message(
|
|
||||||
role=RoleType.Assistant,
|
|
||||||
tool_calls=[
|
|
||||||
ToolCall(
|
|
||||||
call_id="call_1",
|
|
||||||
func_name="mute_user",
|
|
||||||
args={"target": "alice"},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
),
|
|
||||||
Message(
|
|
||||||
role=RoleType.User,
|
|
||||||
parts=[TextMessagePart(text="继续")],
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
sanitized_messages = _sanitize_messages_for_toolless_request(messages)
|
|
||||||
|
|
||||||
assert len(sanitized_messages) == 1
|
|
||||||
assert sanitized_messages[0].role == RoleType.User
|
|
||||||
|
|
||||||
|
|
||||||
def test_normal_response_parser_ignores_reasoning_field_for_non_openrouter_provider() -> None:
|
|
||||||
response = SimpleNamespace(
|
|
||||||
choices=[
|
|
||||||
SimpleNamespace(
|
|
||||||
finish_reason="stop",
|
|
||||||
message=SimpleNamespace(
|
|
||||||
content="正式回复",
|
|
||||||
reasoning="推理内容",
|
|
||||||
tool_calls=None,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
usage=None,
|
|
||||||
model="openrouter/test-model",
|
|
||||||
)
|
|
||||||
|
|
||||||
api_response, usage_record = _default_normal_response_parser(
|
|
||||||
response,
|
|
||||||
reasoning_parse_mode=ReasoningParseMode.AUTO,
|
|
||||||
tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
|
|
||||||
reasoning_key=_build_reasoning_key(
|
|
||||||
APIProvider(name="test", base_url="https://openrouter.ai.example.com/api/v1", api_key="test")
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert api_response.content == "正式回复"
|
|
||||||
assert api_response.reasoning_content is None
|
|
||||||
assert usage_record is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_normal_response_parser_reads_provider_reasoning_field_for_reasoning_domains() -> None:
|
|
||||||
provider_urls = [
|
|
||||||
"https://openrouter.ai/compatible-api",
|
|
||||||
"https://api.groq.com/openai/v1",
|
|
||||||
]
|
|
||||||
|
|
||||||
for provider_url in provider_urls:
|
|
||||||
response = SimpleNamespace(
|
|
||||||
choices=[
|
|
||||||
SimpleNamespace(
|
|
||||||
finish_reason="stop",
|
|
||||||
message=SimpleNamespace(
|
|
||||||
content="正式回复",
|
|
||||||
reasoning="推理内容",
|
|
||||||
tool_calls=None,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
usage=None,
|
|
||||||
model="test-model",
|
|
||||||
)
|
|
||||||
|
|
||||||
api_response, usage_record = _default_normal_response_parser(
|
|
||||||
response,
|
|
||||||
reasoning_parse_mode=ReasoningParseMode.AUTO,
|
|
||||||
tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
|
|
||||||
reasoning_key=_build_reasoning_key(
|
|
||||||
APIProvider(name="reasoning-provider", base_url=provider_url, api_key="test")
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert api_response.content == "正式回复"
|
|
||||||
assert api_response.reasoning_content == "推理内容"
|
|
||||||
assert usage_record is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_accumulator_reads_openrouter_reasoning_delta_field() -> None:
|
|
||||||
accumulator = _OpenAIStreamAccumulator(
|
|
||||||
reasoning_parse_mode=ReasoningParseMode.AUTO,
|
|
||||||
tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
|
|
||||||
reasoning_key=_build_reasoning_key(
|
|
||||||
APIProvider(name="openrouter", base_url="https://openrouter.ai/compatible-api", api_key="test")
|
|
||||||
),
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
accumulator.process_delta(SimpleNamespace(reasoning="流式推理", content=None, tool_calls=None))
|
|
||||||
accumulator.process_delta(SimpleNamespace(content="正式回复", tool_calls=None))
|
|
||||||
|
|
||||||
api_response = accumulator.build_response()
|
|
||||||
finally:
|
|
||||||
accumulator.close()
|
|
||||||
|
|
||||||
assert api_response.content == "正式回复"
|
|
||||||
assert api_response.reasoning_content == "流式推理"
|
|
||||||
@@ -1,209 +0,0 @@
|
|||||||
"""Platform IO 入站去重策略测试。"""
|
|
||||||
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.platform_io.drivers.base import PlatformIODriver
|
|
||||||
from src.platform_io.manager import PlatformIOManager
|
|
||||||
from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey
|
|
||||||
|
|
||||||
|
|
||||||
def _build_envelope(
|
|
||||||
*,
|
|
||||||
dedupe_key: str | None = None,
|
|
||||||
external_message_id: str | None = None,
|
|
||||||
session_message_id: str | None = None,
|
|
||||||
payload: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> InboundMessageEnvelope:
|
|
||||||
"""构造测试用入站信封。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dedupe_key: 显式去重键。
|
|
||||||
external_message_id: 平台侧消息 ID。
|
|
||||||
session_message_id: 规范化消息对象上的消息 ID。
|
|
||||||
payload: 原始载荷。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
InboundMessageEnvelope: 测试用入站消息信封。
|
|
||||||
"""
|
|
||||||
session_message = None
|
|
||||||
if session_message_id is not None:
|
|
||||||
session_message = SimpleNamespace(message_id=session_message_id)
|
|
||||||
|
|
||||||
return InboundMessageEnvelope(
|
|
||||||
route_key=RouteKey(platform="qq", account_id="10001", scope="main"),
|
|
||||||
driver_id="plugin.napcat",
|
|
||||||
driver_kind=DriverKind.PLUGIN,
|
|
||||||
dedupe_key=dedupe_key,
|
|
||||||
external_message_id=external_message_id,
|
|
||||||
session_message=session_message,
|
|
||||||
payload=payload,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _StubPlatformIODriver(PlatformIODriver):
|
|
||||||
"""测试用 Platform IO 驱动。"""
|
|
||||||
|
|
||||||
async def send_message(
|
|
||||||
self,
|
|
||||||
message: Any,
|
|
||||||
route_key: RouteKey,
|
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> DeliveryReceipt:
|
|
||||||
"""返回一个固定的成功回执。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: 待发送的消息对象。
|
|
||||||
route_key: 本次发送使用的路由键。
|
|
||||||
metadata: 额外发送元数据。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
DeliveryReceipt: 固定的成功回执。
|
|
||||||
"""
|
|
||||||
return DeliveryReceipt(
|
|
||||||
internal_message_id=str(getattr(message, "message_id", "stub-message-id")),
|
|
||||||
route_key=route_key,
|
|
||||||
status=DeliveryStatus.SENT,
|
|
||||||
driver_id=self.driver_id,
|
|
||||||
driver_kind=self.descriptor.kind,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_manager() -> PlatformIOManager:
|
|
||||||
"""构造带有最小接收路由的 Broker 管理器。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PlatformIOManager: 已注册测试驱动并绑定接收路由的 Broker。
|
|
||||||
"""
|
|
||||||
manager = PlatformIOManager()
|
|
||||||
driver = _StubPlatformIODriver(
|
|
||||||
DriverDescriptor(
|
|
||||||
driver_id="plugin.napcat",
|
|
||||||
kind=DriverKind.PLUGIN,
|
|
||||||
platform="qq",
|
|
||||||
account_id="10001",
|
|
||||||
scope="main",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
manager.register_driver(driver)
|
|
||||||
manager.bind_receive_route(
|
|
||||||
RouteBinding(
|
|
||||||
route_key=RouteKey(platform="qq", account_id="10001", scope="main"),
|
|
||||||
driver_id=driver.driver_id,
|
|
||||||
driver_kind=driver.descriptor.kind,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return manager
|
|
||||||
|
|
||||||
|
|
||||||
class TestPlatformIODedupe:
|
|
||||||
"""Platform IO 去重测试。"""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_accept_inbound_dedupes_by_external_message_id(self) -> None:
|
|
||||||
"""相同平台消息 ID 的重复入站应被抑制。"""
|
|
||||||
manager = _build_manager()
|
|
||||||
accepted_envelopes: List[InboundMessageEnvelope] = []
|
|
||||||
|
|
||||||
async def dispatcher(envelope: InboundMessageEnvelope) -> None:
|
|
||||||
"""记录被成功接收的入站消息。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
envelope: 被 Broker 接受的入站消息。
|
|
||||||
"""
|
|
||||||
accepted_envelopes.append(envelope)
|
|
||||||
|
|
||||||
manager.set_inbound_dispatcher(dispatcher)
|
|
||||||
|
|
||||||
first_envelope = _build_envelope(
|
|
||||||
external_message_id="msg-1",
|
|
||||||
payload={"message": "hello"},
|
|
||||||
)
|
|
||||||
second_envelope = _build_envelope(
|
|
||||||
external_message_id="msg-1",
|
|
||||||
payload={"message": "hello"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert await manager.accept_inbound(first_envelope) is True
|
|
||||||
assert await manager.accept_inbound(second_envelope) is False
|
|
||||||
assert len(accepted_envelopes) == 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_accept_inbound_without_stable_identity_does_not_guess_duplicate(self) -> None:
|
|
||||||
"""缺少稳定身份时,不应仅凭 payload 内容猜测重复消息。"""
|
|
||||||
manager = _build_manager()
|
|
||||||
accepted_envelopes: List[InboundMessageEnvelope] = []
|
|
||||||
|
|
||||||
async def dispatcher(envelope: InboundMessageEnvelope) -> None:
|
|
||||||
"""记录被成功接收的入站消息。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
envelope: 被 Broker 接受的入站消息。
|
|
||||||
"""
|
|
||||||
accepted_envelopes.append(envelope)
|
|
||||||
|
|
||||||
manager.set_inbound_dispatcher(dispatcher)
|
|
||||||
|
|
||||||
first_envelope = _build_envelope(payload={"message": "same-payload"})
|
|
||||||
second_envelope = _build_envelope(payload={"message": "same-payload"})
|
|
||||||
|
|
||||||
assert await manager.accept_inbound(first_envelope) is True
|
|
||||||
assert await manager.accept_inbound(second_envelope) is True
|
|
||||||
assert len(accepted_envelopes) == 2
|
|
||||||
|
|
||||||
def test_build_inbound_dedupe_key_prefers_explicit_identity(self) -> None:
|
|
||||||
"""去重键应只来自显式或稳定的技术身份。"""
|
|
||||||
explicit_envelope = _build_envelope(dedupe_key="dedupe-1", external_message_id="msg-1")
|
|
||||||
session_message_envelope = _build_envelope(session_message_id="session-1")
|
|
||||||
payload_only_envelope = _build_envelope(payload={"message": "hello"})
|
|
||||||
|
|
||||||
assert PlatformIOManager._build_inbound_dedupe_key(explicit_envelope) == "plugin.napcat:dedupe-1"
|
|
||||||
assert PlatformIOManager._build_inbound_dedupe_key(session_message_envelope) == "plugin.napcat:session-1"
|
|
||||||
assert PlatformIOManager._build_inbound_dedupe_key(payload_only_envelope) is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_send_message_fans_out_to_all_matching_routes(self) -> None:
|
|
||||||
"""同一路由命中多条发送链路时应全部发送。"""
|
|
||||||
|
|
||||||
manager = PlatformIOManager()
|
|
||||||
first_driver = _StubPlatformIODriver(
|
|
||||||
DriverDescriptor(
|
|
||||||
driver_id="plugin.gateway_a",
|
|
||||||
kind=DriverKind.PLUGIN,
|
|
||||||
platform="qq",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
second_driver = _StubPlatformIODriver(
|
|
||||||
DriverDescriptor(
|
|
||||||
driver_id="plugin.gateway_b",
|
|
||||||
kind=DriverKind.PLUGIN,
|
|
||||||
platform="qq",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
manager.register_driver(first_driver)
|
|
||||||
manager.register_driver(second_driver)
|
|
||||||
manager.bind_send_route(
|
|
||||||
RouteBinding(
|
|
||||||
route_key=RouteKey(platform="qq"),
|
|
||||||
driver_id=first_driver.driver_id,
|
|
||||||
driver_kind=first_driver.descriptor.kind,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
manager.bind_send_route(
|
|
||||||
RouteBinding(
|
|
||||||
route_key=RouteKey(platform="qq"),
|
|
||||||
driver_id=second_driver.driver_id,
|
|
||||||
driver_kind=second_driver.descriptor.kind,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
message = SimpleNamespace(message_id="internal-msg-1")
|
|
||||||
result = await manager.send_message(message, RouteKey(platform="qq"))
|
|
||||||
|
|
||||||
assert result.has_success is True
|
|
||||||
assert [receipt.driver_id for receipt in result.sent_receipts] == [
|
|
||||||
"plugin.gateway_a",
|
|
||||||
"plugin.gateway_b",
|
|
||||||
]
|
|
||||||
@@ -1,178 +0,0 @@
|
|||||||
"""Platform IO legacy driver 回归测试。"""
|
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.chat.utils import utils as chat_utils
|
|
||||||
from src.chat.message_receive import uni_message_sender
|
|
||||||
from src.platform_io.drivers.base import PlatformIODriver
|
|
||||||
from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver
|
|
||||||
from src.platform_io.manager import PlatformIOManager
|
|
||||||
from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteBinding, RouteKey
|
|
||||||
|
|
||||||
|
|
||||||
class _PluginDriver(PlatformIODriver):
|
|
||||||
"""测试用插件发送驱动。"""
|
|
||||||
|
|
||||||
def __init__(self, driver_id: str, platform: str) -> None:
|
|
||||||
"""初始化测试驱动。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
driver_id: 驱动 ID。
|
|
||||||
platform: 负责的平台名称。
|
|
||||||
"""
|
|
||||||
super().__init__(
|
|
||||||
DriverDescriptor(
|
|
||||||
driver_id=driver_id,
|
|
||||||
kind=DriverKind.PLUGIN,
|
|
||||||
platform=platform,
|
|
||||||
plugin_id="test.plugin",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_message(
|
|
||||||
self,
|
|
||||||
message: Any,
|
|
||||||
route_key: RouteKey,
|
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> DeliveryReceipt:
|
|
||||||
"""返回一个固定成功回执。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: 待发送消息。
|
|
||||||
route_key: 当前路由键。
|
|
||||||
metadata: 发送元数据。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
DeliveryReceipt: 固定成功回执。
|
|
||||||
"""
|
|
||||||
del metadata
|
|
||||||
return DeliveryReceipt(
|
|
||||||
internal_message_id=str(message.message_id),
|
|
||||||
route_key=route_key,
|
|
||||||
status=DeliveryStatus.SENT,
|
|
||||||
driver_id=self.driver_id,
|
|
||||||
driver_kind=self.descriptor.kind,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_platform_io_uses_legacy_driver_when_no_explicit_send_route(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
"""没有显式发送路由时,应由 Platform IO 回退到 legacy driver。"""
|
|
||||||
manager = PlatformIOManager()
|
|
||||||
monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"})
|
|
||||||
|
|
||||||
try:
|
|
||||||
await manager.ensure_send_pipeline_ready()
|
|
||||||
|
|
||||||
fallback_drivers = manager.resolve_drivers(RouteKey(platform="qq"))
|
|
||||||
assert [driver.driver_id for driver in fallback_drivers] == ["legacy.send.qq"]
|
|
||||||
|
|
||||||
plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq")
|
|
||||||
await manager.add_driver(plugin_driver)
|
|
||||||
manager.bind_send_route(
|
|
||||||
RouteBinding(
|
|
||||||
route_key=RouteKey(platform="qq"),
|
|
||||||
driver_id=plugin_driver.driver_id,
|
|
||||||
driver_kind=plugin_driver.descriptor.kind,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
explicit_drivers = manager.resolve_drivers(RouteKey(platform="qq"))
|
|
||||||
assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender", "legacy.send.qq"]
|
|
||||||
finally:
|
|
||||||
await manager.stop()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_platform_io_broadcasts_to_plugin_and_legacy_driver(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
"""同一路由命中插件驱动与 legacy driver 时,应同时广播发送。"""
|
|
||||||
|
|
||||||
manager = PlatformIOManager()
|
|
||||||
legacy_calls: list[dict[str, Any]] = []
|
|
||||||
monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"})
|
|
||||||
|
|
||||||
async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool:
|
|
||||||
"""记录 legacy driver 调用。"""
|
|
||||||
|
|
||||||
legacy_calls.append({"message": message, "show_log": show_log})
|
|
||||||
return True
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
uni_message_sender,
|
|
||||||
"send_prepared_message_to_platform",
|
|
||||||
_fake_send_prepared_message_to_platform,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await manager.ensure_send_pipeline_ready()
|
|
||||||
|
|
||||||
plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq")
|
|
||||||
await manager.add_driver(plugin_driver)
|
|
||||||
manager.bind_send_route(
|
|
||||||
RouteBinding(
|
|
||||||
route_key=RouteKey(platform="qq"),
|
|
||||||
driver_id=plugin_driver.driver_id,
|
|
||||||
driver_kind=plugin_driver.descriptor.kind,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
message = type("FakeMessage", (), {"message_id": "message-1"})()
|
|
||||||
batch = await manager.send_message(
|
|
||||||
message=message,
|
|
||||||
route_key=RouteKey(platform="qq"),
|
|
||||||
metadata={"show_log": False},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert sorted(receipt.driver_id for receipt in batch.sent_receipts) == [
|
|
||||||
"legacy.send.qq",
|
|
||||||
"plugin.qq.sender",
|
|
||||||
]
|
|
||||||
assert batch.failed_receipts == []
|
|
||||||
assert len(legacy_calls) == 1
|
|
||||||
assert legacy_calls[0]["message"] is message
|
|
||||||
assert legacy_calls[0]["show_log"] is False
|
|
||||||
finally:
|
|
||||||
await manager.stop()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_legacy_platform_driver_uses_prepared_universal_sender(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
"""legacy driver 应复用已预处理消息的旧链发送函数。"""
|
|
||||||
calls: list[dict[str, Any]] = []
|
|
||||||
|
|
||||||
async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool:
|
|
||||||
"""记录 legacy driver 调用。"""
|
|
||||||
calls.append({"message": message, "show_log": show_log})
|
|
||||||
return True
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
uni_message_sender,
|
|
||||||
"send_prepared_message_to_platform",
|
|
||||||
_fake_send_prepared_message_to_platform,
|
|
||||||
)
|
|
||||||
|
|
||||||
driver = LegacyPlatformDriver(
|
|
||||||
driver_id="legacy.send.qq",
|
|
||||||
platform="qq",
|
|
||||||
account_id="bot-qq",
|
|
||||||
)
|
|
||||||
message = type("FakeMessage", (), {"message_id": "message-1"})()
|
|
||||||
receipt = await driver.send_message(
|
|
||||||
message=message,
|
|
||||||
route_key=RouteKey(platform="qq"),
|
|
||||||
metadata={"show_log": False},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(calls) == 1
|
|
||||||
assert calls[0]["message"] is message
|
|
||||||
assert calls[0]["show_log"] is False
|
|
||||||
assert receipt.status == DeliveryStatus.SENT
|
|
||||||
assert receipt.driver_id == "legacy.send.qq"
|
|
||||||
@@ -1,553 +0,0 @@
|
|||||||
"""插件配置运行时测试。"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import Any, Dict, Mapping, Optional, Tuple, cast
|
|
||||||
|
|
||||||
import tomllib
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.plugin_runtime.component_query import component_query_service
|
|
||||||
from src.plugin_runtime.protocol.envelope import (
|
|
||||||
Envelope,
|
|
||||||
InspectPluginConfigPayload,
|
|
||||||
MessageType,
|
|
||||||
RegisterPluginPayload,
|
|
||||||
ValidatePluginConfigPayload,
|
|
||||||
)
|
|
||||||
from src.plugin_runtime.runner.runner_main import PluginRunner
|
|
||||||
from src.webui.routers.plugin.config_routes import get_plugin_config, get_plugin_config_schema, update_plugin_config
|
|
||||||
from src.webui.routers.plugin.schemas import UpdatePluginConfigRequest
|
|
||||||
|
|
||||||
|
|
||||||
class _DemoConfigPlugin:
|
|
||||||
"""用于测试 Runner 配置归一化流程的伪插件。"""
|
|
||||||
|
|
||||||
_config_version: str = "2.0.0"
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""初始化测试插件状态。"""
|
|
||||||
|
|
||||||
self.received_config: Dict[str, Any] = {}
|
|
||||||
|
|
||||||
def normalize_plugin_config(self, config_data: Optional[Mapping[str, Any]]) -> Tuple[Dict[str, Any], bool]:
|
|
||||||
"""补齐测试插件的默认配置。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_data: 原始配置数据。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[Dict[str, Any], bool]: 补齐后的配置,以及是否发生变更。
|
|
||||||
"""
|
|
||||||
|
|
||||||
current_config = dict(config_data or {})
|
|
||||||
plugin_section = dict(current_config.get("plugin", {}))
|
|
||||||
changed = "retry_count" not in plugin_section or "config_version" not in plugin_section
|
|
||||||
plugin_section.setdefault("config_version", self._config_version)
|
|
||||||
plugin_section.setdefault("enabled", True)
|
|
||||||
plugin_section.setdefault("retry_count", 3)
|
|
||||||
return {"plugin": plugin_section}, changed
|
|
||||||
|
|
||||||
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
|
||||||
"""记录 Runner 注入的配置内容。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: 当前最新配置。
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.received_config = config
|
|
||||||
|
|
||||||
def get_default_config(self) -> Dict[str, Any]:
|
|
||||||
"""返回测试插件的默认配置。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any]: 默认配置字典。
|
|
||||||
"""
|
|
||||||
|
|
||||||
return {"plugin": {"config_version": self._config_version, "enabled": True, "retry_count": 3}}
|
|
||||||
|
|
||||||
def get_webui_config_schema(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
plugin_id: str = "",
|
|
||||||
plugin_name: str = "",
|
|
||||||
plugin_version: str = "",
|
|
||||||
plugin_description: str = "",
|
|
||||||
plugin_author: str = "",
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""返回测试插件的 WebUI 配置 Schema。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_id: 插件 ID。
|
|
||||||
plugin_name: 插件名称。
|
|
||||||
plugin_version: 插件版本。
|
|
||||||
plugin_description: 插件描述。
|
|
||||||
plugin_author: 插件作者。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any]: 测试配置 Schema。
|
|
||||||
"""
|
|
||||||
|
|
||||||
del plugin_name, plugin_description, plugin_author
|
|
||||||
return {
|
|
||||||
"plugin_id": plugin_id,
|
|
||||||
"plugin_info": {
|
|
||||||
"name": "Demo",
|
|
||||||
"version": plugin_version,
|
|
||||||
"description": "",
|
|
||||||
"author": "",
|
|
||||||
},
|
|
||||||
"sections": {
|
|
||||||
"plugin": {
|
|
||||||
"fields": {
|
|
||||||
"enabled": {
|
|
||||||
"type": "boolean",
|
|
||||||
"label": "启用",
|
|
||||||
"default": True,
|
|
||||||
"ui_type": "switch",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"layout": {"type": "auto", "tabs": []},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class _StrictConfigPlugin:
|
|
||||||
"""用于测试配置校验错误的伪插件。"""
|
|
||||||
|
|
||||||
def normalize_plugin_config(self, config_data: Optional[Mapping[str, Any]]) -> Tuple[Dict[str, Any], bool]:
|
|
||||||
"""校验重试次数不能为负数。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_data: 原始配置数据。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[Dict[str, Any], bool]: 规范化配置结果。
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 当重试次数为负数时抛出。
|
|
||||||
"""
|
|
||||||
|
|
||||||
current_config = dict(config_data or {})
|
|
||||||
plugin_section = dict(current_config.get("plugin", {}))
|
|
||||||
plugin_section.setdefault("config_version", "2.0.0")
|
|
||||||
retry_count = int(plugin_section.get("retry_count", 0))
|
|
||||||
if retry_count < 0:
|
|
||||||
raise ValueError("重试次数不能小于 0")
|
|
||||||
plugin_section.setdefault("enabled", True)
|
|
||||||
return {"plugin": plugin_section}, False
|
|
||||||
|
|
||||||
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
|
||||||
"""兼容 Runner 配置注入接口。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: 当前配置字典。
|
|
||||||
"""
|
|
||||||
|
|
||||||
del config
|
|
||||||
|
|
||||||
def get_default_config(self) -> Dict[str, Any]:
|
|
||||||
"""返回测试插件的默认配置。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any]: 默认配置字典。
|
|
||||||
"""
|
|
||||||
|
|
||||||
return {"plugin": {"config_version": "2.0.0", "enabled": True, "retry_count": 0}}
|
|
||||||
|
|
||||||
|
|
||||||
def test_runner_apply_plugin_config_generates_config_file(tmp_path: Path) -> None:
|
|
||||||
"""Runner 注入配置时应自动补齐并落盘 config.toml。"""
|
|
||||||
|
|
||||||
plugin = _DemoConfigPlugin()
|
|
||||||
runner = PluginRunner(
|
|
||||||
host_address="ipc://unused",
|
|
||||||
session_token="session-token",
|
|
||||||
plugin_dirs=[],
|
|
||||||
)
|
|
||||||
meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir=str(tmp_path), instance=plugin)
|
|
||||||
|
|
||||||
runner._apply_plugin_config(
|
|
||||||
cast(Any, meta),
|
|
||||||
config_data={"plugin": {"config_version": "2.0.0", "enabled": False}},
|
|
||||||
)
|
|
||||||
|
|
||||||
config_path = tmp_path / "config.toml"
|
|
||||||
assert config_path.exists()
|
|
||||||
assert plugin.received_config == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}}
|
|
||||||
|
|
||||||
with config_path.open("rb") as handle:
|
|
||||||
saved_config = tomllib.load(handle)
|
|
||||||
assert saved_config == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}}
|
|
||||||
|
|
||||||
|
|
||||||
def test_runner_apply_plugin_config_preserves_existing_comments(tmp_path: Path) -> None:
|
|
||||||
"""Runner 在版本升级时应尽量保留现有 config.toml 注释。"""
|
|
||||||
|
|
||||||
plugin = _DemoConfigPlugin()
|
|
||||||
runner = PluginRunner(
|
|
||||||
host_address="ipc://unused",
|
|
||||||
session_token="session-token",
|
|
||||||
plugin_dirs=[],
|
|
||||||
)
|
|
||||||
meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir=str(tmp_path), instance=plugin)
|
|
||||||
config_path = tmp_path / "config.toml"
|
|
||||||
config_path.write_text(
|
|
||||||
'# 插件配置头注释\n[plugin]\nconfig_version = "1.0.0"\nenabled = false # 启用开关注释\n',
|
|
||||||
encoding="utf-8",
|
|
||||||
)
|
|
||||||
|
|
||||||
runner._apply_plugin_config(cast(Any, meta))
|
|
||||||
|
|
||||||
config_text = config_path.read_text(encoding="utf-8")
|
|
||||||
assert "# 插件配置头注释" in config_text
|
|
||||||
assert "# 启用开关注释" in config_text
|
|
||||||
|
|
||||||
with config_path.open("rb") as handle:
|
|
||||||
saved_config = tomllib.load(handle)
|
|
||||||
assert saved_config == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}}
|
|
||||||
|
|
||||||
|
|
||||||
def test_runner_apply_plugin_config_same_version_does_not_rewrite_file(tmp_path: Path) -> None:
|
|
||||||
"""Runner 在配置版本未变化时不应仅因补齐默认值而重写文件。"""
|
|
||||||
|
|
||||||
plugin = _DemoConfigPlugin()
|
|
||||||
runner = PluginRunner(
|
|
||||||
host_address="ipc://unused",
|
|
||||||
session_token="session-token",
|
|
||||||
plugin_dirs=[],
|
|
||||||
)
|
|
||||||
meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir=str(tmp_path), instance=plugin)
|
|
||||||
config_path = tmp_path / "config.toml"
|
|
||||||
original_config_text = '# 原始注释\n[plugin]\nconfig_version = "2.0.0"\nenabled = false\n'
|
|
||||||
config_path.write_text(original_config_text, encoding="utf-8")
|
|
||||||
|
|
||||||
runner._apply_plugin_config(cast(Any, meta))
|
|
||||||
|
|
||||||
assert plugin.received_config == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}}
|
|
||||||
assert config_path.read_text(encoding="utf-8") == original_config_text
|
|
||||||
|
|
||||||
|
|
||||||
def test_runner_apply_plugin_config_requires_config_version(tmp_path: Path) -> None:
|
|
||||||
"""Runner 应拒绝缺少配置版本号的插件配置文件。"""
|
|
||||||
|
|
||||||
plugin = _DemoConfigPlugin()
|
|
||||||
runner = PluginRunner(
|
|
||||||
host_address="ipc://unused",
|
|
||||||
session_token="session-token",
|
|
||||||
plugin_dirs=[],
|
|
||||||
)
|
|
||||||
meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir=str(tmp_path), instance=plugin)
|
|
||||||
config_path = tmp_path / "config.toml"
|
|
||||||
config_path.write_text("[plugin]\nenabled = true\n", encoding="utf-8")
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="config_version"):
|
|
||||||
runner._apply_plugin_config(cast(Any, meta))
|
|
||||||
|
|
||||||
|
|
||||||
def test_component_query_service_returns_plugin_config_schema(monkeypatch: Any) -> None:
|
|
||||||
"""组件查询服务应支持按插件 ID 返回配置 Schema。"""
|
|
||||||
|
|
||||||
payload = RegisterPluginPayload(
|
|
||||||
plugin_id="demo.plugin",
|
|
||||||
plugin_version="1.0.0",
|
|
||||||
default_config={"plugin": {"enabled": True}},
|
|
||||||
config_schema={
|
|
||||||
"plugin_id": "demo.plugin",
|
|
||||||
"plugin_info": {
|
|
||||||
"name": "Demo",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"description": "",
|
|
||||||
"author": "",
|
|
||||||
},
|
|
||||||
"sections": {"plugin": {"fields": {}}},
|
|
||||||
"layout": {"type": "auto", "tabs": []},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
fake_supervisor = SimpleNamespace(_registered_plugins={"demo.plugin": payload})
|
|
||||||
fake_manager = SimpleNamespace(_get_supervisor_for_plugin=lambda plugin_id: fake_supervisor)
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
type(component_query_service),
|
|
||||||
"_get_runtime_manager",
|
|
||||||
staticmethod(lambda: fake_manager),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert component_query_service.get_plugin_config_schema("demo.plugin") == payload.config_schema
|
|
||||||
assert component_query_service.get_plugin_default_config("demo.plugin") == payload.default_config
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_runner_validate_plugin_config_handler_returns_normalized_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""Runner 应返回插件模型归一化后的配置。"""
|
|
||||||
|
|
||||||
plugin = _DemoConfigPlugin()
|
|
||||||
runner = PluginRunner(
|
|
||||||
host_address="ipc://unused",
|
|
||||||
session_token="session-token",
|
|
||||||
plugin_dirs=[],
|
|
||||||
)
|
|
||||||
meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir="", instance=plugin)
|
|
||||||
monkeypatch.setattr(runner._loader, "get_plugin", lambda plugin_id: meta if plugin_id == "demo.plugin" else None)
|
|
||||||
|
|
||||||
envelope = Envelope(
|
|
||||||
request_id=1,
|
|
||||||
message_type=MessageType.REQUEST,
|
|
||||||
method="plugin.validate_config",
|
|
||||||
plugin_id="demo.plugin",
|
|
||||||
payload=ValidatePluginConfigPayload(
|
|
||||||
config_data={"plugin": {"config_version": "2.0.0", "enabled": False}}
|
|
||||||
).model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await runner._handle_validate_plugin_config(envelope)
|
|
||||||
|
|
||||||
assert response.error is None
|
|
||||||
assert response.payload["success"] is True
|
|
||||||
assert response.payload["normalized_config"] == {
|
|
||||||
"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_runner_inspect_plugin_config_handler_supports_unloaded_plugin(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
"""Runner 应支持对未加载插件执行冷检查。"""
|
|
||||||
|
|
||||||
plugin = _DemoConfigPlugin()
|
|
||||||
runner = PluginRunner(
|
|
||||||
host_address="ipc://unused",
|
|
||||||
session_token="session-token",
|
|
||||||
plugin_dirs=[],
|
|
||||||
)
|
|
||||||
meta = SimpleNamespace(
|
|
||||||
plugin_id="demo.plugin",
|
|
||||||
plugin_dir="/tmp/demo-plugin",
|
|
||||||
instance=plugin,
|
|
||||||
manifest=SimpleNamespace(
|
|
||||||
name="Demo",
|
|
||||||
description="",
|
|
||||||
author=SimpleNamespace(name="tester"),
|
|
||||||
),
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
purged_plugins: list[tuple[str, str]] = []
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
runner,
|
|
||||||
"_resolve_plugin_meta_for_config_request",
|
|
||||||
lambda plugin_id: (meta, True, None) if plugin_id == "demo.plugin" else (None, False, "not-found"),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
runner._loader,
|
|
||||||
"purge_plugin_modules",
|
|
||||||
lambda plugin_id, plugin_dir: purged_plugins.append((plugin_id, plugin_dir)),
|
|
||||||
)
|
|
||||||
|
|
||||||
envelope = Envelope(
|
|
||||||
request_id=1,
|
|
||||||
message_type=MessageType.REQUEST,
|
|
||||||
method="plugin.inspect_config",
|
|
||||||
plugin_id="demo.plugin",
|
|
||||||
payload=InspectPluginConfigPayload(
|
|
||||||
config_data={"plugin": {"enabled": False}},
|
|
||||||
use_provided_config=True,
|
|
||||||
).model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await runner._handle_inspect_plugin_config(envelope)
|
|
||||||
|
|
||||||
assert response.error is None
|
|
||||||
assert response.payload["success"] is True
|
|
||||||
assert response.payload["enabled"] is False
|
|
||||||
assert response.payload["normalized_config"] == {
|
|
||||||
"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}
|
|
||||||
}
|
|
||||||
assert response.payload["default_config"] == {
|
|
||||||
"plugin": {"config_version": "2.0.0", "enabled": True, "retry_count": 3}
|
|
||||||
}
|
|
||||||
assert purged_plugins == [("demo.plugin", "/tmp/demo-plugin")]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_runner_validate_plugin_config_handler_returns_error_on_invalid_config(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
"""Runner 应在插件拒绝配置时返回错误响应。"""
|
|
||||||
|
|
||||||
plugin = _StrictConfigPlugin()
|
|
||||||
runner = PluginRunner(
|
|
||||||
host_address="ipc://unused",
|
|
||||||
session_token="session-token",
|
|
||||||
plugin_dirs=[],
|
|
||||||
)
|
|
||||||
meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir="", instance=plugin)
|
|
||||||
monkeypatch.setattr(runner._loader, "get_plugin", lambda plugin_id: meta if plugin_id == "demo.plugin" else None)
|
|
||||||
|
|
||||||
envelope = Envelope(
|
|
||||||
request_id=1,
|
|
||||||
message_type=MessageType.REQUEST,
|
|
||||||
method="plugin.validate_config",
|
|
||||||
plugin_id="demo.plugin",
|
|
||||||
payload=ValidatePluginConfigPayload(
|
|
||||||
config_data={"plugin": {"config_version": "2.0.0", "retry_count": -1}}
|
|
||||||
).model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await runner._handle_validate_plugin_config(envelope)
|
|
||||||
|
|
||||||
assert response.error is not None
|
|
||||||
assert response.error["message"] == "重试次数不能小于 0"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_plugin_config_prefers_runtime_validation(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
tmp_path: Path,
|
|
||||||
) -> None:
|
|
||||||
"""WebUI 保存插件配置时应优先使用运行时校验结果。"""
|
|
||||||
|
|
||||||
config_path = tmp_path / "config.toml"
|
|
||||||
|
|
||||||
async def _mock_validate_plugin_config(plugin_id: str, config_data: Dict[str, Any]) -> Dict[str, Any] | None:
|
|
||||||
"""返回运行时归一化后的配置。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_id: 插件 ID。
|
|
||||||
config_data: 原始配置。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any] | None: 归一化后的配置。
|
|
||||||
"""
|
|
||||||
|
|
||||||
assert plugin_id == "demo.plugin"
|
|
||||||
assert config_data == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}}
|
|
||||||
return {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}}
|
|
||||||
|
|
||||||
async def _mock_inspect_plugin_config(
|
|
||||||
plugin_id: str,
|
|
||||||
config_data: Optional[Dict[str, Any]] = None,
|
|
||||||
*,
|
|
||||||
use_provided_config: bool = False,
|
|
||||||
) -> SimpleNamespace | None:
|
|
||||||
"""返回运行时配置快照。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_id: 插件 ID。
|
|
||||||
config_data: 可选配置。
|
|
||||||
use_provided_config: 是否使用传入配置。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SimpleNamespace | None: 运行时配置快照。
|
|
||||||
"""
|
|
||||||
|
|
||||||
del config_data, use_provided_config
|
|
||||||
if plugin_id != "demo.plugin":
|
|
||||||
return None
|
|
||||||
return SimpleNamespace(
|
|
||||||
normalized_config={"plugin": {"config_version": "2.0.0", "enabled": True, "retry_count": 3}}
|
|
||||||
)
|
|
||||||
|
|
||||||
fake_runtime_manager = SimpleNamespace(
|
|
||||||
inspect_plugin_config=_mock_inspect_plugin_config,
|
|
||||||
validate_plugin_config=_mock_validate_plugin_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"src.webui.routers.plugin.config_routes.require_plugin_token",
|
|
||||||
lambda session: session or "session-token",
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"src.webui.routers.plugin.config_routes.find_plugin_path_by_id",
|
|
||||||
lambda plugin_id: tmp_path if plugin_id == "demo.plugin" else None,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"src.plugin_runtime.integration.get_plugin_runtime_manager",
|
|
||||||
lambda: fake_runtime_manager,
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await update_plugin_config(
|
|
||||||
"demo.plugin",
|
|
||||||
UpdatePluginConfigRequest(config={"plugin.enabled": False}),
|
|
||||||
maibot_session="session-token",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response["success"] is True
|
|
||||||
with config_path.open("rb") as handle:
|
|
||||||
saved_config = tomllib.load(handle)
|
|
||||||
assert saved_config == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_webui_config_endpoints_use_runtime_inspection_for_unloaded_plugin(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
tmp_path: Path,
|
|
||||||
) -> None:
|
|
||||||
"""WebUI 在插件未加载时也应从代码定义返回配置与 Schema。"""
|
|
||||||
|
|
||||||
async def _mock_inspect_plugin_config(
|
|
||||||
plugin_id: str,
|
|
||||||
config_data: Optional[Dict[str, Any]] = None,
|
|
||||||
*,
|
|
||||||
use_provided_config: bool = False,
|
|
||||||
) -> SimpleNamespace | None:
|
|
||||||
"""返回运行时冷检查结果。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_id: 插件 ID。
|
|
||||||
config_data: 可选配置。
|
|
||||||
use_provided_config: 是否使用传入配置。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SimpleNamespace | None: 冷检查结果。
|
|
||||||
"""
|
|
||||||
|
|
||||||
del config_data, use_provided_config
|
|
||||||
if plugin_id != "demo.plugin":
|
|
||||||
return None
|
|
||||||
return SimpleNamespace(
|
|
||||||
config_schema={
|
|
||||||
"plugin_id": "demo.plugin",
|
|
||||||
"plugin_info": {
|
|
||||||
"name": "Demo",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"description": "",
|
|
||||||
"author": "",
|
|
||||||
},
|
|
||||||
"sections": {"plugin": {"fields": {}}},
|
|
||||||
"layout": {"type": "auto", "tabs": []},
|
|
||||||
},
|
|
||||||
normalized_config={"plugin": {"config_version": "2.0.0", "enabled": True, "retry_count": 3}},
|
|
||||||
enabled=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
fake_runtime_manager = SimpleNamespace(inspect_plugin_config=_mock_inspect_plugin_config)
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"src.webui.routers.plugin.config_routes.require_plugin_token",
|
|
||||||
lambda session: session or "session-token",
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"src.webui.routers.plugin.config_routes.find_plugin_path_by_id",
|
|
||||||
lambda plugin_id: tmp_path if plugin_id == "demo.plugin" else None,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"src.plugin_runtime.integration.get_plugin_runtime_manager",
|
|
||||||
lambda: fake_runtime_manager,
|
|
||||||
)
|
|
||||||
|
|
||||||
schema_response = await get_plugin_config_schema("demo.plugin", maibot_session="session-token")
|
|
||||||
config_response = await get_plugin_config("demo.plugin", maibot_session="session-token")
|
|
||||||
|
|
||||||
assert schema_response["success"] is True
|
|
||||||
assert schema_response["schema"]["plugin_id"] == "demo.plugin"
|
|
||||||
assert config_response == {
|
|
||||||
"success": True,
|
|
||||||
"config": {"plugin": {"config_version": "2.0.0", "enabled": True, "retry_count": 3}},
|
|
||||||
"message": "配置文件不存在,已返回默认配置",
|
|
||||||
}
|
|
||||||
@@ -1,225 +0,0 @@
|
|||||||
"""插件依赖流水线测试。"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.plugin_runtime.dependency_pipeline import PluginDependencyPipeline
|
|
||||||
|
|
||||||
|
|
||||||
def _build_manifest(
|
|
||||||
plugin_id: str,
|
|
||||||
*,
|
|
||||||
dependencies: list[dict[str, str]] | None = None,
|
|
||||||
) -> dict[str, object]:
|
|
||||||
"""构造测试用的 Manifest v2 数据。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_id: 插件 ID。
|
|
||||||
dependencies: 依赖声明列表。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict[str, object]: 可直接写入 ``_manifest.json`` 的字典。
|
|
||||||
"""
|
|
||||||
|
|
||||||
return {
|
|
||||||
"manifest_version": 2,
|
|
||||||
"version": "1.0.0",
|
|
||||||
"name": plugin_id,
|
|
||||||
"description": "测试插件",
|
|
||||||
"author": {
|
|
||||||
"name": "tester",
|
|
||||||
"url": "https://example.com/tester",
|
|
||||||
},
|
|
||||||
"license": "MIT",
|
|
||||||
"urls": {
|
|
||||||
"repository": f"https://example.com/{plugin_id}",
|
|
||||||
},
|
|
||||||
"host_application": {
|
|
||||||
"min_version": "1.0.0",
|
|
||||||
"max_version": "1.0.0",
|
|
||||||
},
|
|
||||||
"sdk": {
|
|
||||||
"min_version": "2.0.0",
|
|
||||||
"max_version": "2.99.99",
|
|
||||||
},
|
|
||||||
"dependencies": dependencies or [],
|
|
||||||
"capabilities": [],
|
|
||||||
"i18n": {
|
|
||||||
"default_locale": "zh-CN",
|
|
||||||
"supported_locales": ["zh-CN"],
|
|
||||||
},
|
|
||||||
"id": plugin_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _write_plugin(
|
|
||||||
plugin_root: Path,
|
|
||||||
plugin_name: str,
|
|
||||||
plugin_id: str,
|
|
||||||
*,
|
|
||||||
dependencies: list[dict[str, str]] | None = None,
|
|
||||||
) -> Path:
|
|
||||||
"""在临时目录中写入一个测试插件。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_root: 插件根目录。
|
|
||||||
plugin_name: 插件目录名。
|
|
||||||
plugin_id: 插件 ID。
|
|
||||||
dependencies: Python 依赖声明列表。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Path: 插件目录路径。
|
|
||||||
"""
|
|
||||||
|
|
||||||
plugin_dir = plugin_root / plugin_name
|
|
||||||
plugin_dir.mkdir(parents=True)
|
|
||||||
(plugin_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
|
|
||||||
(plugin_dir / "_manifest.json").write_text(
|
|
||||||
json.dumps(_build_manifest(plugin_id, dependencies=dependencies)),
|
|
||||||
encoding="utf-8",
|
|
||||||
)
|
|
||||||
return plugin_dir
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_plan_blocks_plugin_conflicting_with_host_requirement(tmp_path: Path) -> None:
|
|
||||||
"""与主程序依赖冲突的插件应被阻止加载。"""
|
|
||||||
|
|
||||||
plugin_root = tmp_path / "plugins"
|
|
||||||
_write_plugin(
|
|
||||||
plugin_root,
|
|
||||||
"conflict_plugin",
|
|
||||||
"test.conflict-plugin",
|
|
||||||
dependencies=[
|
|
||||||
{
|
|
||||||
"type": "python_package",
|
|
||||||
"name": "numpy",
|
|
||||||
"version_spec": "<1.0.0",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
pipeline = PluginDependencyPipeline(project_root=Path.cwd())
|
|
||||||
plan = pipeline.build_plan([plugin_root])
|
|
||||||
|
|
||||||
assert "test.conflict-plugin" in plan.blocked_plugin_reasons
|
|
||||||
assert "主程序" in plan.blocked_plugin_reasons["test.conflict-plugin"]
|
|
||||||
assert plan.install_requirements == ()
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_plan_blocks_plugins_with_conflicting_python_dependencies(tmp_path: Path) -> None:
|
|
||||||
"""插件之间出现 Python 包版本冲突时应同时阻止双方加载。"""
|
|
||||||
|
|
||||||
plugin_root = tmp_path / "plugins"
|
|
||||||
_write_plugin(
|
|
||||||
plugin_root,
|
|
||||||
"plugin_a",
|
|
||||||
"test.plugin-a",
|
|
||||||
dependencies=[
|
|
||||||
{
|
|
||||||
"type": "python_package",
|
|
||||||
"name": "demo-package",
|
|
||||||
"version_spec": "<2.0.0",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
_write_plugin(
|
|
||||||
plugin_root,
|
|
||||||
"plugin_b",
|
|
||||||
"test.plugin-b",
|
|
||||||
dependencies=[
|
|
||||||
{
|
|
||||||
"type": "python_package",
|
|
||||||
"name": "demo-package",
|
|
||||||
"version_spec": ">=3.0.0",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
pipeline = PluginDependencyPipeline(project_root=Path.cwd())
|
|
||||||
plan = pipeline.build_plan([plugin_root])
|
|
||||||
|
|
||||||
assert "test.plugin-a" in plan.blocked_plugin_reasons
|
|
||||||
assert "test.plugin-b" in plan.blocked_plugin_reasons
|
|
||||||
assert "test.plugin-b" in plan.blocked_plugin_reasons["test.plugin-a"]
|
|
||||||
assert "test.plugin-a" in plan.blocked_plugin_reasons["test.plugin-b"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_plan_collects_install_requirements_for_missing_packages(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
tmp_path: Path,
|
|
||||||
) -> None:
|
|
||||||
"""未安装但无冲突的依赖应进入自动安装计划。"""
|
|
||||||
|
|
||||||
plugin_root = tmp_path / "plugins"
|
|
||||||
_write_plugin(
|
|
||||||
plugin_root,
|
|
||||||
"plugin_a",
|
|
||||||
"test.plugin-a",
|
|
||||||
dependencies=[
|
|
||||||
{
|
|
||||||
"type": "python_package",
|
|
||||||
"name": "demo-package",
|
|
||||||
"version_spec": ">=1.0.0,<2.0.0",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
pipeline = PluginDependencyPipeline(project_root=Path.cwd())
|
|
||||||
monkeypatch.setattr(
|
|
||||||
pipeline._manifest_validator,
|
|
||||||
"get_installed_package_version",
|
|
||||||
lambda package_name: None if package_name == "demo-package" else "1.0.0",
|
|
||||||
)
|
|
||||||
|
|
||||||
plan = pipeline.build_plan([plugin_root])
|
|
||||||
|
|
||||||
assert plan.blocked_plugin_reasons == {}
|
|
||||||
assert len(plan.install_requirements) == 1
|
|
||||||
assert plan.install_requirements[0].package_name == "demo-package"
|
|
||||||
assert plan.install_requirements[0].plugin_ids == ("test.plugin-a",)
|
|
||||||
assert plan.install_requirements[0].requirement_text == "demo-package>=1.0.0,<2.0.0"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_execute_blocks_plugins_when_auto_install_fails(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
tmp_path: Path,
|
|
||||||
) -> None:
|
|
||||||
"""自动安装失败时,相关插件应被阻止加载。"""
|
|
||||||
|
|
||||||
plugin_root = tmp_path / "plugins"
|
|
||||||
_write_plugin(
|
|
||||||
plugin_root,
|
|
||||||
"plugin_a",
|
|
||||||
"test.plugin-a",
|
|
||||||
dependencies=[
|
|
||||||
{
|
|
||||||
"type": "python_package",
|
|
||||||
"name": "demo-package",
|
|
||||||
"version_spec": ">=1.0.0,<2.0.0",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
pipeline = PluginDependencyPipeline(project_root=Path.cwd())
|
|
||||||
monkeypatch.setattr(
|
|
||||||
pipeline._manifest_validator,
|
|
||||||
"get_installed_package_version",
|
|
||||||
lambda package_name: None if package_name == "demo-package" else "1.0.0",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def fake_install(_requirements) -> tuple[bool, str]:
|
|
||||||
"""模拟依赖安装失败。"""
|
|
||||||
|
|
||||||
return False, "network error"
|
|
||||||
|
|
||||||
monkeypatch.setattr(pipeline, "_install_requirements", fake_install)
|
|
||||||
|
|
||||||
result = await pipeline.execute([plugin_root])
|
|
||||||
|
|
||||||
assert result.environment_changed is False
|
|
||||||
assert "test.plugin-a" in result.blocked_plugin_reasons
|
|
||||||
assert "自动安装 Python 依赖失败" in result.blocked_plugin_reasons["test.plugin-a"]
|
|
||||||
@@ -1,86 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from src.chat.message_receive.message import SessionMessage
|
|
||||||
from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo, UserInfo
|
|
||||||
from src.common.data_models.message_component_data_model import (
|
|
||||||
ForwardComponent,
|
|
||||||
ForwardNodeComponent,
|
|
||||||
ImageComponent,
|
|
||||||
MessageSequence,
|
|
||||||
ReplyComponent,
|
|
||||||
TextComponent,
|
|
||||||
VoiceComponent,
|
|
||||||
)
|
|
||||||
from src.plugin_runtime.host.message_utils import PluginMessageUtils
|
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
|
||||||
if str(PROJECT_ROOT) not in sys.path:
|
|
||||||
sys.path.insert(0, str(PROJECT_ROOT))
|
|
||||||
|
|
||||||
|
|
||||||
def test_plugin_message_utils_preserves_binary_components_and_reply_metadata() -> None:
|
|
||||||
message = SessionMessage(message_id="msg-1", timestamp=datetime.now(), platform="qq")
|
|
||||||
message.message_info = MessageInfo(
|
|
||||||
user_info=UserInfo(user_id="10001", user_nickname="tester"),
|
|
||||||
group_info=GroupInfo(group_id="20001", group_name="group"),
|
|
||||||
additional_config={"self_id": "999"},
|
|
||||||
)
|
|
||||||
message.session_id = "qq:20001:10001"
|
|
||||||
message.processed_plain_text = "binary payload"
|
|
||||||
message.raw_message = MessageSequence(
|
|
||||||
components=[
|
|
||||||
TextComponent("hello"),
|
|
||||||
ImageComponent(binary_hash="", binary_data=b"image-bytes", content=""),
|
|
||||||
VoiceComponent(binary_hash="", binary_data=b"voice-bytes", content=""),
|
|
||||||
ReplyComponent(
|
|
||||||
target_message_id="origin-1",
|
|
||||||
target_message_content="origin text",
|
|
||||||
target_message_sender_id="42",
|
|
||||||
target_message_sender_nickname="alice",
|
|
||||||
target_message_sender_cardname="Alice",
|
|
||||||
),
|
|
||||||
ForwardNodeComponent(
|
|
||||||
forward_components=[
|
|
||||||
ForwardComponent(
|
|
||||||
user_nickname="bob",
|
|
||||||
user_id="43",
|
|
||||||
user_cardname="Bob",
|
|
||||||
message_id="forward-1",
|
|
||||||
content=[
|
|
||||||
TextComponent("node-text"),
|
|
||||||
ImageComponent(binary_hash="", binary_data=b"node-image", content=""),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
]
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
message_dict = PluginMessageUtils._session_message_to_dict(message)
|
|
||||||
rebuilt_message = PluginMessageUtils._build_session_message_from_dict(dict(message_dict))
|
|
||||||
|
|
||||||
image_component = rebuilt_message.raw_message.components[1]
|
|
||||||
voice_component = rebuilt_message.raw_message.components[2]
|
|
||||||
reply_component = rebuilt_message.raw_message.components[3]
|
|
||||||
forward_component = rebuilt_message.raw_message.components[4]
|
|
||||||
|
|
||||||
assert isinstance(image_component, ImageComponent)
|
|
||||||
assert image_component.binary_data == b"image-bytes"
|
|
||||||
|
|
||||||
assert isinstance(voice_component, VoiceComponent)
|
|
||||||
assert voice_component.binary_data == b"voice-bytes"
|
|
||||||
|
|
||||||
assert isinstance(reply_component, ReplyComponent)
|
|
||||||
assert reply_component.target_message_id == "origin-1"
|
|
||||||
assert reply_component.target_message_content == "origin text"
|
|
||||||
assert reply_component.target_message_sender_id == "42"
|
|
||||||
assert reply_component.target_message_sender_nickname == "alice"
|
|
||||||
assert reply_component.target_message_sender_cardname == "Alice"
|
|
||||||
|
|
||||||
assert isinstance(forward_component, ForwardNodeComponent)
|
|
||||||
assert isinstance(forward_component.forward_components[0].content[1], ImageComponent)
|
|
||||||
assert forward_component.forward_components[0].content[1].binary_data == b"node-image"
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,284 +0,0 @@
|
|||||||
"""核心组件查询层与插件运行时聚合测试。"""
|
|
||||||
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import src.plugin_runtime.integration as integration_module
|
|
||||||
|
|
||||||
from src.core.types import ActionInfo, ToolInfo
|
|
||||||
from src.plugin_runtime.component_query import component_query_service
|
|
||||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeRuntimeManager:
|
|
||||||
"""测试用插件运行时管理器。"""
|
|
||||||
|
|
||||||
def __init__(self, supervisor: PluginSupervisor, plugin_id: str, plugin_config: dict[str, Any]) -> None:
|
|
||||||
"""初始化测试用运行时管理器。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
supervisor: 持有测试组件的监督器。
|
|
||||||
plugin_id: 目标插件 ID。
|
|
||||||
plugin_config: 需要返回的插件配置。
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.supervisors = [supervisor]
|
|
||||||
self._plugin_id = plugin_id
|
|
||||||
self._plugin_config = plugin_config
|
|
||||||
|
|
||||||
def _get_supervisor_for_plugin(self, plugin_id: str) -> PluginSupervisor | None:
|
|
||||||
"""按插件 ID 返回对应监督器。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_id: 目标插件 ID。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PluginSupervisor | None: 命中时返回监督器。
|
|
||||||
"""
|
|
||||||
|
|
||||||
return self.supervisors[0] if plugin_id == self._plugin_id else None
|
|
||||||
|
|
||||||
def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> dict[str, Any]:
|
|
||||||
"""返回测试配置。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
supervisor: 监督器实例。
|
|
||||||
plugin_id: 目标插件 ID。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict[str, Any]: 测试配置内容。
|
|
||||||
"""
|
|
||||||
|
|
||||||
del supervisor
|
|
||||||
if plugin_id != self._plugin_id:
|
|
||||||
return {}
|
|
||||||
return dict(self._plugin_config)
|
|
||||||
|
|
||||||
|
|
||||||
def _install_runtime_manager(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
supervisor: PluginSupervisor,
|
|
||||||
plugin_id: str,
|
|
||||||
plugin_config: dict[str, Any] | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""为测试安装假的运行时管理器。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
monkeypatch: pytest monkeypatch 对象。
|
|
||||||
supervisor: 持有测试组件的监督器。
|
|
||||||
plugin_id: 测试插件 ID。
|
|
||||||
plugin_config: 可选的测试配置内容。
|
|
||||||
"""
|
|
||||||
|
|
||||||
fake_manager = _FakeRuntimeManager(supervisor, plugin_id, plugin_config or {"enabled": True})
|
|
||||||
monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: fake_manager)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_core_component_registry_reads_runtime_action_and_executor(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
"""核心查询层应直接读取运行时 Action,并返回 RPC 执行闭包。"""
|
|
||||||
|
|
||||||
plugin_id = "runtime_action_bridge_plugin"
|
|
||||||
action_name = "runtime_action_bridge_test"
|
|
||||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
|
||||||
captured: dict[str, Any] = {}
|
|
||||||
|
|
||||||
supervisor.component_registry.register_component(
|
|
||||||
name=action_name,
|
|
||||||
component_type="ACTION",
|
|
||||||
plugin_id=plugin_id,
|
|
||||||
metadata={
|
|
||||||
"description": "发送一个测试回复",
|
|
||||||
"enabled": True,
|
|
||||||
"activation_type": "keyword",
|
|
||||||
"activation_probability": 0.25,
|
|
||||||
"activation_keywords": ["测试", "hello"],
|
|
||||||
"action_parameters": {"target": "目标对象"},
|
|
||||||
"action_require": ["需要发送回复时使用"],
|
|
||||||
"associated_types": ["text"],
|
|
||||||
"parallel_action": True,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
_install_runtime_manager(monkeypatch, supervisor, plugin_id, {"enabled": True, "mode": "test"})
|
|
||||||
|
|
||||||
async def fake_invoke_plugin(
|
|
||||||
method: str,
|
|
||||||
plugin_id: str,
|
|
||||||
component_name: str,
|
|
||||||
args: dict[str, Any] | None = None,
|
|
||||||
timeout_ms: int = 30000,
|
|
||||||
) -> Any:
|
|
||||||
"""模拟动作 RPC 调用。"""
|
|
||||||
|
|
||||||
captured["method"] = method
|
|
||||||
captured["plugin_id"] = plugin_id
|
|
||||||
captured["component_name"] = component_name
|
|
||||||
captured["args"] = args or {}
|
|
||||||
captured["timeout_ms"] = timeout_ms
|
|
||||||
return SimpleNamespace(payload={"success": True, "result": (True, "runtime action executed")})
|
|
||||||
|
|
||||||
monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
|
|
||||||
|
|
||||||
action_info = component_query_service.get_action_info(action_name)
|
|
||||||
assert isinstance(action_info, ActionInfo)
|
|
||||||
assert action_info.plugin_name == plugin_id
|
|
||||||
assert action_info.description == "发送一个测试回复"
|
|
||||||
assert action_info.activation_keywords == ["测试", "hello"]
|
|
||||||
assert action_info.random_activation_probability == 0.25
|
|
||||||
assert action_info.parallel_action is True
|
|
||||||
assert action_name in component_query_service.get_default_actions()
|
|
||||||
assert component_query_service.get_plugin_config(plugin_id) == {"enabled": True, "mode": "test"}
|
|
||||||
|
|
||||||
executor = component_query_service.get_action_executor(action_name)
|
|
||||||
assert executor is not None
|
|
||||||
|
|
||||||
success, reason = await executor(
|
|
||||||
action_data={"target": "MaiBot"},
|
|
||||||
action_reasoning="当前适合使用这个动作",
|
|
||||||
cycle_timers={"planner": 0.1},
|
|
||||||
thinking_id="tid-1",
|
|
||||||
chat_stream=SimpleNamespace(session_id="stream-1"),
|
|
||||||
log_prefix="[test]",
|
|
||||||
shutting_down=False,
|
|
||||||
plugin_config={"enabled": True},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert success is True
|
|
||||||
assert reason == "runtime action executed"
|
|
||||||
assert captured["method"] == "plugin.invoke_action"
|
|
||||||
assert captured["plugin_id"] == plugin_id
|
|
||||||
assert captured["component_name"] == action_name
|
|
||||||
assert captured["args"]["stream_id"] == "stream-1"
|
|
||||||
assert captured["args"]["chat_id"] == "stream-1"
|
|
||||||
assert captured["args"]["reasoning"] == "当前适合使用这个动作"
|
|
||||||
assert captured["args"]["target"] == "MaiBot"
|
|
||||||
assert captured["args"]["action_data"] == {"target": "MaiBot"}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_core_component_registry_reads_runtime_command_and_executor(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
"""核心查询层应直接使用运行时命令匹配与执行闭包。"""
|
|
||||||
|
|
||||||
plugin_id = "runtime_command_bridge_plugin"
|
|
||||||
command_name = "runtime_command_bridge_test"
|
|
||||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
|
||||||
captured: dict[str, Any] = {}
|
|
||||||
|
|
||||||
supervisor.component_registry.register_component(
|
|
||||||
name=command_name,
|
|
||||||
component_type="COMMAND",
|
|
||||||
plugin_id=plugin_id,
|
|
||||||
metadata={
|
|
||||||
"description": "测试命令",
|
|
||||||
"enabled": True,
|
|
||||||
"command_pattern": r"^/test(?:\s+.+)?$",
|
|
||||||
"aliases": ["/hello"],
|
|
||||||
"intercept_message_level": 1,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
_install_runtime_manager(monkeypatch, supervisor, plugin_id, {"mode": "command"})
|
|
||||||
|
|
||||||
async def fake_invoke_plugin(
|
|
||||||
method: str,
|
|
||||||
plugin_id: str,
|
|
||||||
component_name: str,
|
|
||||||
args: dict[str, Any] | None = None,
|
|
||||||
timeout_ms: int = 30000,
|
|
||||||
) -> Any:
|
|
||||||
"""模拟命令 RPC 调用。"""
|
|
||||||
|
|
||||||
captured["method"] = method
|
|
||||||
captured["plugin_id"] = plugin_id
|
|
||||||
captured["component_name"] = component_name
|
|
||||||
captured["args"] = args or {}
|
|
||||||
captured["timeout_ms"] = timeout_ms
|
|
||||||
return SimpleNamespace(payload={"success": True, "result": (True, "command ok", True)})
|
|
||||||
|
|
||||||
monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
|
|
||||||
|
|
||||||
matched = component_query_service.find_command_by_text("/test hello")
|
|
||||||
assert matched is not None
|
|
||||||
command_executor, matched_groups, command_info = matched
|
|
||||||
|
|
||||||
assert matched_groups == {}
|
|
||||||
assert command_info.plugin_name == plugin_id
|
|
||||||
assert command_info.command_pattern == r"^/test(?:\s+.+)?$"
|
|
||||||
|
|
||||||
success, response_text, intercept = await command_executor(
|
|
||||||
message=SimpleNamespace(processed_plain_text="/test hello", session_id="stream-2"),
|
|
||||||
plugin_config={"mode": "command"},
|
|
||||||
matched_groups=matched_groups,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert success is True
|
|
||||||
assert response_text == "command ok"
|
|
||||||
assert intercept is True
|
|
||||||
assert captured["method"] == "plugin.invoke_command"
|
|
||||||
assert captured["plugin_id"] == plugin_id
|
|
||||||
assert captured["component_name"] == command_name
|
|
||||||
assert captured["args"]["text"] == "/test hello"
|
|
||||||
assert captured["args"]["stream_id"] == "stream-2"
|
|
||||||
assert captured["args"]["plugin_config"] == {"mode": "command"}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_core_component_registry_reads_runtime_tools_and_executor(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
"""核心查询层应直接读取运行时 Tool,并返回 RPC 执行闭包。"""
|
|
||||||
|
|
||||||
plugin_id = "runtime_tool_bridge_plugin"
|
|
||||||
tool_name = "runtime_tool_bridge_test"
|
|
||||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
|
||||||
|
|
||||||
supervisor.component_registry.register_component(
|
|
||||||
name=tool_name,
|
|
||||||
component_type="TOOL",
|
|
||||||
plugin_id=plugin_id,
|
|
||||||
metadata={
|
|
||||||
"description": "测试工具",
|
|
||||||
"enabled": True,
|
|
||||||
"parameters": [
|
|
||||||
{
|
|
||||||
"name": "query",
|
|
||||||
"param_type": "string",
|
|
||||||
"description": "查询词",
|
|
||||||
"required": True,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
_install_runtime_manager(monkeypatch, supervisor, plugin_id)
|
|
||||||
|
|
||||||
async def fake_invoke_plugin(
|
|
||||||
method: str,
|
|
||||||
plugin_id: str,
|
|
||||||
component_name: str,
|
|
||||||
args: dict[str, Any] | None = None,
|
|
||||||
timeout_ms: int = 30000,
|
|
||||||
) -> Any:
|
|
||||||
"""模拟工具 RPC 调用。"""
|
|
||||||
|
|
||||||
del timeout_ms
|
|
||||||
assert method == "plugin.invoke_tool"
|
|
||||||
assert plugin_id == "runtime_tool_bridge_plugin"
|
|
||||||
assert component_name == "runtime_tool_bridge_test"
|
|
||||||
assert args == {"query": "MaiBot"}
|
|
||||||
return SimpleNamespace(payload={"success": True, "result": {"content": "tool ok"}})
|
|
||||||
|
|
||||||
monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
|
|
||||||
|
|
||||||
tool_info = component_query_service.get_tool_info(tool_name)
|
|
||||||
assert isinstance(tool_info, ToolInfo)
|
|
||||||
assert tool_info.tool_description == "测试工具"
|
|
||||||
assert tool_name in component_query_service.get_llm_available_tools()
|
|
||||||
|
|
||||||
executor = component_query_service.get_tool_executor(tool_name)
|
|
||||||
assert executor is not None
|
|
||||||
assert await executor({"query": "MaiBot"}) == {"content": "tool ok"}
|
|
||||||
@@ -1,524 +0,0 @@
|
|||||||
"""插件 API 注册与调用测试。"""
|
|
||||||
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.plugin_runtime.integration import PluginRuntimeManager
|
|
||||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
|
||||||
from src.plugin_runtime.protocol.envelope import (
|
|
||||||
ComponentDeclaration,
|
|
||||||
Envelope,
|
|
||||||
MessageType,
|
|
||||||
RegisterPluginPayload,
|
|
||||||
UnregisterPluginPayload,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_manager(*supervisors: PluginSupervisor) -> PluginRuntimeManager:
|
|
||||||
"""构造一个最小可用的插件运行时管理器。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
*supervisors: 需要挂载的监督器列表。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PluginRuntimeManager: 已注入监督器的运行时管理器。
|
|
||||||
"""
|
|
||||||
|
|
||||||
manager = PluginRuntimeManager()
|
|
||||||
if supervisors:
|
|
||||||
manager._builtin_supervisor = supervisors[0]
|
|
||||||
if len(supervisors) > 1:
|
|
||||||
manager._third_party_supervisor = supervisors[1]
|
|
||||||
return manager
|
|
||||||
|
|
||||||
|
|
||||||
async def _register_plugin(
|
|
||||||
supervisor: PluginSupervisor,
|
|
||||||
plugin_id: str,
|
|
||||||
components: List[Dict[str, Any]],
|
|
||||||
) -> Envelope:
|
|
||||||
"""通过 Supervisor 注册测试插件。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
supervisor: 目标监督器。
|
|
||||||
plugin_id: 测试插件 ID。
|
|
||||||
components: 组件声明列表。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Envelope: 注册响应信封。
|
|
||||||
"""
|
|
||||||
|
|
||||||
payload = RegisterPluginPayload(
|
|
||||||
plugin_id=plugin_id,
|
|
||||||
plugin_version="1.0.0",
|
|
||||||
components=[
|
|
||||||
ComponentDeclaration(
|
|
||||||
name=str(component.get("name", "") or ""),
|
|
||||||
component_type=str(component.get("component_type", "") or ""),
|
|
||||||
plugin_id=plugin_id,
|
|
||||||
metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {},
|
|
||||||
)
|
|
||||||
for component in components
|
|
||||||
],
|
|
||||||
)
|
|
||||||
return await supervisor._handle_register_plugin(
|
|
||||||
Envelope(
|
|
||||||
request_id=1,
|
|
||||||
message_type=MessageType.REQUEST,
|
|
||||||
method="plugin.register_components",
|
|
||||||
plugin_id=plugin_id,
|
|
||||||
payload=payload.model_dump(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _unregister_plugin(supervisor: PluginSupervisor, plugin_id: str) -> Envelope:
|
|
||||||
"""通过 Supervisor 注销测试插件。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
supervisor: 目标监督器。
|
|
||||||
plugin_id: 测试插件 ID。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Envelope: 注销响应信封。
|
|
||||||
"""
|
|
||||||
|
|
||||||
payload = UnregisterPluginPayload(plugin_id=plugin_id, reason="test")
|
|
||||||
return await supervisor._handle_unregister_plugin(
|
|
||||||
Envelope(
|
|
||||||
request_id=2,
|
|
||||||
message_type=MessageType.REQUEST,
|
|
||||||
method="plugin.unregister",
|
|
||||||
plugin_id=plugin_id,
|
|
||||||
payload=payload.model_dump(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_register_plugin_syncs_dedicated_api_registry() -> None:
|
|
||||||
"""插件注册时应将 API 同步到独立注册表,而不是通用组件表。"""
|
|
||||||
|
|
||||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
|
||||||
response = await _register_plugin(
|
|
||||||
supervisor,
|
|
||||||
"provider",
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "render_html",
|
|
||||||
"component_type": "API",
|
|
||||||
"metadata": {
|
|
||||||
"description": "渲染 HTML",
|
|
||||||
"version": "1",
|
|
||||||
"public": True,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.payload["accepted"] is True
|
|
||||||
assert response.payload["registered_components"] == 0
|
|
||||||
assert response.payload["registered_apis"] == 1
|
|
||||||
assert supervisor.api_registry.get_api("provider", "render_html") is not None
|
|
||||||
assert supervisor.component_registry.get_component("provider.render_html") is None
|
|
||||||
|
|
||||||
unregister_response = await _unregister_plugin(supervisor, "provider")
|
|
||||||
assert unregister_response.payload["removed_apis"] == 1
|
|
||||||
assert supervisor.api_registry.get_api("provider", "render_html") is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_api_call_allows_public_api_between_plugins(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""公开 API 应允许其他插件通过 Host 转发调用。"""
|
|
||||||
|
|
||||||
provider_supervisor = PluginSupervisor(plugin_dirs=[])
|
|
||||||
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
|
|
||||||
await _register_plugin(
|
|
||||||
provider_supervisor,
|
|
||||||
"provider",
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "render_html",
|
|
||||||
"component_type": "API",
|
|
||||||
"metadata": {
|
|
||||||
"description": "渲染 HTML",
|
|
||||||
"version": "1",
|
|
||||||
"public": True,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
await _register_plugin(consumer_supervisor, "consumer", [])
|
|
||||||
|
|
||||||
captured: Dict[str, Any] = {}
|
|
||||||
|
|
||||||
async def fake_invoke_api(
|
|
||||||
plugin_id: str,
|
|
||||||
component_name: str,
|
|
||||||
args: Dict[str, Any] | None = None,
|
|
||||||
timeout_ms: int = 30000,
|
|
||||||
) -> Any:
|
|
||||||
"""模拟 API RPC 调用。"""
|
|
||||||
|
|
||||||
captured["plugin_id"] = plugin_id
|
|
||||||
captured["component_name"] = component_name
|
|
||||||
captured["args"] = args or {}
|
|
||||||
captured["timeout_ms"] = timeout_ms
|
|
||||||
return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}})
|
|
||||||
|
|
||||||
monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api)
|
|
||||||
|
|
||||||
manager = _build_manager(provider_supervisor, consumer_supervisor)
|
|
||||||
result = await manager._cap_api_call(
|
|
||||||
"consumer",
|
|
||||||
"api.call",
|
|
||||||
{
|
|
||||||
"api_name": "provider.render_html",
|
|
||||||
"version": "1",
|
|
||||||
"args": {"html": "<div>Hello</div>"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == {"success": True, "result": {"image": "ok"}}
|
|
||||||
assert captured["plugin_id"] == "provider"
|
|
||||||
assert captured["component_name"] == "render_html"
|
|
||||||
assert captured["args"] == {"html": "<div>Hello</div>"}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_api_call_rejects_private_api_between_plugins() -> None:
|
|
||||||
"""未公开的 API 默认不允许跨插件调用。"""
|
|
||||||
|
|
||||||
provider_supervisor = PluginSupervisor(plugin_dirs=[])
|
|
||||||
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
|
|
||||||
await _register_plugin(
|
|
||||||
provider_supervisor,
|
|
||||||
"provider",
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "secret_api",
|
|
||||||
"component_type": "API",
|
|
||||||
"metadata": {
|
|
||||||
"description": "私有 API",
|
|
||||||
"version": "1",
|
|
||||||
"public": False,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
await _register_plugin(consumer_supervisor, "consumer", [])
|
|
||||||
|
|
||||||
manager = _build_manager(provider_supervisor, consumer_supervisor)
|
|
||||||
result = await manager._cap_api_call(
|
|
||||||
"consumer",
|
|
||||||
"api.call",
|
|
||||||
{
|
|
||||||
"api_name": "provider.secret_api",
|
|
||||||
"args": {},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result["success"] is False
|
|
||||||
assert "未公开" in str(result["error"])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_api_list_and_component_toggle_use_dedicated_registry() -> None:
|
|
||||||
"""API 列表与组件启停应直接作用于独立 API 注册表。"""
|
|
||||||
|
|
||||||
provider_supervisor = PluginSupervisor(plugin_dirs=[])
|
|
||||||
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
|
|
||||||
await _register_plugin(
|
|
||||||
provider_supervisor,
|
|
||||||
"provider",
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "public_api",
|
|
||||||
"component_type": "API",
|
|
||||||
"metadata": {"version": "1", "public": True},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "private_api",
|
|
||||||
"component_type": "API",
|
|
||||||
"metadata": {"version": "1", "public": False},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
await _register_plugin(
|
|
||||||
consumer_supervisor,
|
|
||||||
"consumer",
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "self_private_api",
|
|
||||||
"component_type": "API",
|
|
||||||
"metadata": {"version": "1", "public": False},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
manager = _build_manager(provider_supervisor, consumer_supervisor)
|
|
||||||
list_result = await manager._cap_api_list("consumer", "api.list", {})
|
|
||||||
|
|
||||||
assert list_result["success"] is True
|
|
||||||
api_names = {(item["plugin_id"], item["name"]) for item in list_result["apis"]}
|
|
||||||
assert ("provider", "public_api") in api_names
|
|
||||||
assert ("provider", "private_api") not in api_names
|
|
||||||
assert ("consumer", "self_private_api") in api_names
|
|
||||||
|
|
||||||
disable_result = await manager._cap_component_disable(
|
|
||||||
"consumer",
|
|
||||||
"component.disable",
|
|
||||||
{
|
|
||||||
"name": "provider.public_api",
|
|
||||||
"component_type": "API",
|
|
||||||
"scope": "global",
|
|
||||||
"stream_id": "",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert disable_result["success"] is True
|
|
||||||
assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is None
|
|
||||||
|
|
||||||
enable_result = await manager._cap_component_enable(
|
|
||||||
"consumer",
|
|
||||||
"component.enable",
|
|
||||||
{
|
|
||||||
"name": "provider.public_api",
|
|
||||||
"component_type": "API",
|
|
||||||
"scope": "global",
|
|
||||||
"stream_id": "",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert enable_result["success"] is True
|
|
||||||
assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_api_registry_supports_multiple_versions_with_distinct_handlers(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
"""同名 API 不同版本应可并存,并按版本路由到不同处理器。"""
|
|
||||||
|
|
||||||
provider_supervisor = PluginSupervisor(plugin_dirs=[])
|
|
||||||
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
|
|
||||||
await _register_plugin(
|
|
||||||
provider_supervisor,
|
|
||||||
"provider",
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "render_html",
|
|
||||||
"component_type": "API",
|
|
||||||
"metadata": {
|
|
||||||
"description": "渲染 HTML v1",
|
|
||||||
"version": "1",
|
|
||||||
"public": True,
|
|
||||||
"handler_name": "handle_render_html_v1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "render_html",
|
|
||||||
"component_type": "API",
|
|
||||||
"metadata": {
|
|
||||||
"description": "渲染 HTML v2",
|
|
||||||
"version": "2",
|
|
||||||
"public": True,
|
|
||||||
"handler_name": "handle_render_html_v2",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
await _register_plugin(consumer_supervisor, "consumer", [])
|
|
||||||
|
|
||||||
captured: Dict[str, Any] = {}
|
|
||||||
|
|
||||||
async def fake_invoke_api(
|
|
||||||
plugin_id: str,
|
|
||||||
component_name: str,
|
|
||||||
args: Dict[str, Any] | None = None,
|
|
||||||
timeout_ms: int = 30000,
|
|
||||||
) -> Any:
|
|
||||||
"""模拟多版本 API 调用。"""
|
|
||||||
|
|
||||||
captured["plugin_id"] = plugin_id
|
|
||||||
captured["component_name"] = component_name
|
|
||||||
captured["args"] = args or {}
|
|
||||||
captured["timeout_ms"] = timeout_ms
|
|
||||||
return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}})
|
|
||||||
|
|
||||||
monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api)
|
|
||||||
manager = _build_manager(provider_supervisor, consumer_supervisor)
|
|
||||||
|
|
||||||
ambiguous_result = await manager._cap_api_call(
|
|
||||||
"consumer",
|
|
||||||
"api.call",
|
|
||||||
{
|
|
||||||
"api_name": "provider.render_html",
|
|
||||||
"args": {"html": "<div>Hello</div>"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert ambiguous_result["success"] is False
|
|
||||||
assert "多个版本" in str(ambiguous_result["error"])
|
|
||||||
|
|
||||||
disable_ambiguous_result = await manager._cap_component_disable(
|
|
||||||
"consumer",
|
|
||||||
"component.disable",
|
|
||||||
{
|
|
||||||
"name": "provider.render_html",
|
|
||||||
"component_type": "API",
|
|
||||||
"scope": "global",
|
|
||||||
"stream_id": "",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert disable_ambiguous_result["success"] is False
|
|
||||||
assert "多个版本" in str(disable_ambiguous_result["error"])
|
|
||||||
|
|
||||||
disable_v1_result = await manager._cap_component_disable(
|
|
||||||
"consumer",
|
|
||||||
"component.disable",
|
|
||||||
{
|
|
||||||
"name": "provider.render_html",
|
|
||||||
"component_type": "API",
|
|
||||||
"scope": "global",
|
|
||||||
"stream_id": "",
|
|
||||||
"version": "1",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert disable_v1_result["success"] is True
|
|
||||||
assert provider_supervisor.api_registry.get_api("provider", "render_html", version="1", enabled_only=True) is None
|
|
||||||
assert provider_supervisor.api_registry.get_api("provider", "render_html", version="2", enabled_only=True) is not None
|
|
||||||
|
|
||||||
result = await manager._cap_api_call(
|
|
||||||
"consumer",
|
|
||||||
"api.call",
|
|
||||||
{
|
|
||||||
"api_name": "provider.render_html",
|
|
||||||
"version": "2",
|
|
||||||
"args": {"html": "<div>Hello</div>"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == {"success": True, "result": {"image": "ok"}}
|
|
||||||
assert captured["plugin_id"] == "provider"
|
|
||||||
assert captured["component_name"] == "handle_render_html_v2"
|
|
||||||
assert captured["args"] == {"html": "<div>Hello</div>"}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_api_replace_dynamic_can_offline_removed_entries(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
"""动态 API 替换后,被移除的 API 应返回明确下线错误。"""
|
|
||||||
|
|
||||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
|
||||||
await _register_plugin(supervisor, "provider", [])
|
|
||||||
manager = _build_manager(supervisor)
|
|
||||||
|
|
||||||
captured: Dict[str, Any] = {}
|
|
||||||
|
|
||||||
async def fake_invoke_api(
|
|
||||||
plugin_id: str,
|
|
||||||
component_name: str,
|
|
||||||
args: Dict[str, Any] | None = None,
|
|
||||||
timeout_ms: int = 30000,
|
|
||||||
) -> Any:
|
|
||||||
"""模拟动态 API 调用。"""
|
|
||||||
|
|
||||||
captured["plugin_id"] = plugin_id
|
|
||||||
captured["component_name"] = component_name
|
|
||||||
captured["args"] = args or {}
|
|
||||||
captured["timeout_ms"] = timeout_ms
|
|
||||||
return SimpleNamespace(error=None, payload={"success": True, "result": {"ok": True}})
|
|
||||||
|
|
||||||
monkeypatch.setattr(supervisor, "invoke_api", fake_invoke_api)
|
|
||||||
|
|
||||||
replace_result = await manager._cap_api_replace_dynamic(
|
|
||||||
"provider",
|
|
||||||
"api.replace_dynamic",
|
|
||||||
{
|
|
||||||
"apis": [
|
|
||||||
{
|
|
||||||
"name": "mcp.search",
|
|
||||||
"type": "API",
|
|
||||||
"metadata": {
|
|
||||||
"version": "1",
|
|
||||||
"public": True,
|
|
||||||
"handler_name": "dynamic_search",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "mcp.read",
|
|
||||||
"type": "API",
|
|
||||||
"metadata": {
|
|
||||||
"version": "1",
|
|
||||||
"public": True,
|
|
||||||
"handler_name": "dynamic_read",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"offline_reason": "MCP 服务器已关闭",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert replace_result["success"] is True
|
|
||||||
assert replace_result["count"] == 2
|
|
||||||
list_result = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
|
|
||||||
assert {(item["name"], item["version"]) for item in list_result["apis"]} == {
|
|
||||||
("mcp.read", "1"),
|
|
||||||
("mcp.search", "1"),
|
|
||||||
}
|
|
||||||
|
|
||||||
call_result = await manager._cap_api_call(
|
|
||||||
"provider",
|
|
||||||
"api.call",
|
|
||||||
{
|
|
||||||
"api_name": "provider.mcp.search",
|
|
||||||
"version": "1",
|
|
||||||
"args": {"query": "hello"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert call_result == {"success": True, "result": {"ok": True}}
|
|
||||||
assert captured["component_name"] == "dynamic_search"
|
|
||||||
assert captured["args"]["query"] == "hello"
|
|
||||||
assert captured["args"]["__maibot_api_name__"] == "mcp.search"
|
|
||||||
assert captured["args"]["__maibot_api_version__"] == "1"
|
|
||||||
|
|
||||||
second_replace_result = await manager._cap_api_replace_dynamic(
|
|
||||||
"provider",
|
|
||||||
"api.replace_dynamic",
|
|
||||||
{
|
|
||||||
"apis": [
|
|
||||||
{
|
|
||||||
"name": "mcp.read",
|
|
||||||
"type": "API",
|
|
||||||
"metadata": {
|
|
||||||
"version": "1",
|
|
||||||
"public": True,
|
|
||||||
"handler_name": "dynamic_read",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"offline_reason": "MCP 服务器已关闭",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert second_replace_result["success"] is True
|
|
||||||
assert second_replace_result["count"] == 1
|
|
||||||
assert second_replace_result["offlined"] == 1
|
|
||||||
|
|
||||||
offlined_call_result = await manager._cap_api_call(
|
|
||||||
"provider",
|
|
||||||
"api.call",
|
|
||||||
{
|
|
||||||
"api_name": "provider.mcp.search",
|
|
||||||
"version": "1",
|
|
||||||
"args": {},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert offlined_call_result["success"] is False
|
|
||||||
assert "MCP 服务器已关闭" in str(offlined_call_result["error"])
|
|
||||||
|
|
||||||
list_after_replace = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
|
|
||||||
assert {(item["name"], item["version"]) for item in list_after_replace["apis"]} == {
|
|
||||||
("mcp.read", "1"),
|
|
||||||
}
|
|
||||||
@@ -1,96 +0,0 @@
|
|||||||
"""插件运行时浏览器渲染能力测试。"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.plugin_runtime.integration import PluginRuntimeManager
|
|
||||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
|
||||||
from src.services.html_render_service import HtmlRenderRequest, HtmlRenderResult
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeRenderService:
|
|
||||||
"""用于替代真实浏览器渲染服务的测试桩。"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""初始化测试桩。"""
|
|
||||||
|
|
||||||
self.last_request: Optional[HtmlRenderRequest] = None
|
|
||||||
|
|
||||||
async def render_html_to_png(self, request: HtmlRenderRequest) -> HtmlRenderResult:
|
|
||||||
"""记录请求并返回固定的渲染结果。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: 当前渲染请求。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
HtmlRenderResult: 固定的测试渲染结果。
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.last_request = request
|
|
||||||
return HtmlRenderResult(
|
|
||||||
image_base64="ZmFrZS1pbWFnZQ==",
|
|
||||||
mime_type="image/png",
|
|
||||||
width=640,
|
|
||||||
height=480,
|
|
||||||
render_ms=12,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_render_capability_is_registered() -> None:
|
|
||||||
"""Host 注册能力时应包含 render.html2png。"""
|
|
||||||
|
|
||||||
manager = PluginRuntimeManager()
|
|
||||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
|
||||||
|
|
||||||
manager._register_capability_impls(supervisor)
|
|
||||||
|
|
||||||
assert "render.html2png" in supervisor.capability_service.list_capabilities()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_render_capability_forwards_request(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""render.html2png 应将请求透传给浏览器渲染服务。"""
|
|
||||||
|
|
||||||
from src.plugin_runtime.capabilities import render as render_capability_module
|
|
||||||
|
|
||||||
fake_service = _FakeRenderService()
|
|
||||||
monkeypatch.setattr(render_capability_module, "get_html_render_service", lambda: fake_service)
|
|
||||||
|
|
||||||
manager = PluginRuntimeManager()
|
|
||||||
result = await manager._cap_render_html2png(
|
|
||||||
"demo.plugin",
|
|
||||||
"render.html2png",
|
|
||||||
{
|
|
||||||
"html": "<body><div id='card'>hello</div></body>",
|
|
||||||
"selector": "#card",
|
|
||||||
"viewport": {"width": 1024, "height": 768},
|
|
||||||
"device_scale_factor": 1.5,
|
|
||||||
"full_page": False,
|
|
||||||
"omit_background": True,
|
|
||||||
"wait_until": "networkidle",
|
|
||||||
"wait_for_selector": "#card",
|
|
||||||
"wait_for_timeout_ms": 150,
|
|
||||||
"timeout_ms": 3000,
|
|
||||||
"allow_network": True,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == {
|
|
||||||
"success": True,
|
|
||||||
"result": {
|
|
||||||
"image_base64": "ZmFrZS1pbWFnZQ==",
|
|
||||||
"mime_type": "image/png",
|
|
||||||
"width": 640,
|
|
||||||
"height": 480,
|
|
||||||
"render_ms": 12,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
assert fake_service.last_request is not None
|
|
||||||
assert fake_service.last_request.selector == "#card"
|
|
||||||
assert fake_service.last_request.viewport_width == 1024
|
|
||||||
assert fake_service.last_request.viewport_height == 768
|
|
||||||
assert fake_service.last_request.device_scale_factor == 1.5
|
|
||||||
assert fake_service.last_request.omit_background is True
|
|
||||||
assert fake_service.last_request.wait_until == "networkidle"
|
|
||||||
assert fake_service.last_request.allow_network is True
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
|
||||||
from src.plugin_runtime.hook_payloads import deserialize_prompt_messages, serialize_prompt_messages
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_messages_roundtrip_preserves_image_parts() -> None:
|
|
||||||
messages = [
|
|
||||||
MessageBuilder().set_role(RoleType.User).add_text_content("你好").add_image_content("png", "ZmFrZQ==").build(),
|
|
||||||
]
|
|
||||||
|
|
||||||
serialized_messages = serialize_prompt_messages(messages)
|
|
||||||
restored_messages = deserialize_prompt_messages(serialized_messages)
|
|
||||||
|
|
||||||
assert len(restored_messages) == 1
|
|
||||||
assert restored_messages[0].role == RoleType.User
|
|
||||||
assert restored_messages[0].get_text_content() == "你好"
|
|
||||||
assert len(restored_messages[0].parts) == 2
|
|
||||||
assert restored_messages[0].parts[1].image_format == "png"
|
|
||||||
assert restored_messages[0].parts[1].image_base64 == "ZmFrZQ=="
|
|
||||||
@@ -1,136 +0,0 @@
|
|||||||
"""业务命名 Hook 集成测试。"""
|
|
||||||
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
# 确保项目根目录在 sys.path 中
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
|
||||||
# SDK 包路径
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "maibot-plugin-sdk"))
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeHookManager:
|
|
||||||
"""用于业务 Hook 测试的最小运行时管理器。"""
|
|
||||||
|
|
||||||
def __init__(self, responses: dict[str, SimpleNamespace]) -> None:
|
|
||||||
"""初始化测试管理器。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
responses: 按 Hook 名称预设的返回结果映射。
|
|
||||||
"""
|
|
||||||
|
|
||||||
self._responses = responses
|
|
||||||
self.calls: list[tuple[str, dict[str, Any]]] = []
|
|
||||||
|
|
||||||
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> SimpleNamespace:
|
|
||||||
"""模拟调用运行时命名 Hook。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hook_name: 目标 Hook 名称。
|
|
||||||
**kwargs: 传入 Hook 的参数。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SimpleNamespace: 预设的 Hook 返回结果。
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.calls.append((hook_name, dict(kwargs)))
|
|
||||||
return self._responses.get(hook_name, SimpleNamespace(kwargs=dict(kwargs), aborted=False))
|
|
||||||
|
|
||||||
|
|
||||||
def test_builtin_hook_catalog_includes_new_business_hooks(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""内置 Hook 目录应包含三个业务系统新增的 Hook。"""
|
|
||||||
|
|
||||||
monkeypatch.setattr(sys, "exit", lambda code=0: None)
|
|
||||||
from src.plugin_runtime.hook_catalog import register_builtin_hook_specs
|
|
||||||
from src.plugin_runtime.host.hook_spec_registry import HookSpecRegistry
|
|
||||||
|
|
||||||
registry = HookSpecRegistry()
|
|
||||||
hook_names = {spec.name for spec in register_builtin_hook_specs(registry)}
|
|
||||||
|
|
||||||
assert "emoji.maisaka.before_select" in hook_names
|
|
||||||
assert "emoji.register.after_build_emotion" in hook_names
|
|
||||||
assert "jargon.extract.before_persist" in hook_names
|
|
||||||
assert "jargon.query.after_search" in hook_names
|
|
||||||
assert "expression.select.before_select" in hook_names
|
|
||||||
assert "expression.learn.before_upsert" in hook_names
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_send_emoji_for_maisaka_can_be_aborted_by_hook(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""表情包系统应允许在选择前被 Hook 中止。"""
|
|
||||||
|
|
||||||
from src.emoji_system import maisaka_tool
|
|
||||||
|
|
||||||
fake_manager = _FakeHookManager(
|
|
||||||
{
|
|
||||||
"emoji.maisaka.before_select": SimpleNamespace(
|
|
||||||
kwargs={"abort_message": "插件阻止了表情发送。"},
|
|
||||||
aborted=True,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(maisaka_tool, "_get_runtime_manager", lambda: fake_manager)
|
|
||||||
|
|
||||||
result = await maisaka_tool.send_emoji_for_maisaka(stream_id="stream-1", requested_emotion="开心")
|
|
||||||
|
|
||||||
assert result.success is False
|
|
||||||
assert result.message == "插件阻止了表情发送。"
|
|
||||||
assert fake_manager.calls[0][0] == "emoji.maisaka.before_select"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_jargon_extract_can_be_aborted_before_persist(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""黑话提取结果应允许在写库前被 Hook 中止。"""
|
|
||||||
|
|
||||||
from src.learners.jargon_miner import JargonMiner
|
|
||||||
|
|
||||||
fake_manager = _FakeHookManager(
|
|
||||||
{
|
|
||||||
"jargon.extract.before_persist": SimpleNamespace(
|
|
||||||
kwargs={"entries": []},
|
|
||||||
aborted=True,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(JargonMiner, "_get_runtime_manager", staticmethod(lambda: fake_manager))
|
|
||||||
|
|
||||||
miner = JargonMiner(session_id="session-1", session_name="测试会话")
|
|
||||||
await miner.process_extracted_entries(
|
|
||||||
[{"content": "yyds", "raw_content": {"[1] yyds 太强了"}}],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert fake_manager.calls[0][0] == "jargon.extract.before_persist"
|
|
||||||
assert fake_manager.calls[0][1]["session_id"] == "session-1"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_expression_selection_can_be_aborted_by_hook(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""表达方式选择流程应允许在开始前被 Hook 中止。"""
|
|
||||||
|
|
||||||
from src.learners.expression_selector import ExpressionSelector
|
|
||||||
|
|
||||||
fake_manager = _FakeHookManager(
|
|
||||||
{
|
|
||||||
"expression.select.before_select": SimpleNamespace(
|
|
||||||
kwargs={},
|
|
||||||
aborted=True,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(ExpressionSelector, "_get_runtime_manager", staticmethod(lambda: fake_manager))
|
|
||||||
monkeypatch.setattr(ExpressionSelector, "can_use_expression_for_chat", lambda self, chat_id: True)
|
|
||||||
|
|
||||||
selector = ExpressionSelector()
|
|
||||||
selected_expressions, selected_ids = await selector.select_suitable_expressions(
|
|
||||||
chat_id="session-1",
|
|
||||||
chat_info="用户刚刚发来一条消息。",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert selected_expressions == []
|
|
||||||
assert selected_ids == []
|
|
||||||
assert fake_manager.calls[0][0] == "expression.select.before_select"
|
|
||||||
@@ -1,360 +0,0 @@
|
|||||||
"""发送服务回归测试。"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.chat.message_receive.chat_manager import BotChatSession
|
|
||||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
|
||||||
from src.services import send_service
|
|
||||||
|
|
||||||
|
|
||||||
class _FakePlatformIOManager:
|
|
||||||
"""用于测试的 Platform IO 管理器假对象。"""
|
|
||||||
|
|
||||||
def __init__(self, delivery_batch: Any) -> None:
|
|
||||||
self._delivery_batch = delivery_batch
|
|
||||||
self.ensure_calls = 0
|
|
||||||
self.sent_messages: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
async def ensure_send_pipeline_ready(self) -> None:
|
|
||||||
self.ensure_calls += 1
|
|
||||||
|
|
||||||
def build_route_key_from_message(self, message: Any) -> Any:
|
|
||||||
del message
|
|
||||||
return SimpleNamespace(platform="qq")
|
|
||||||
|
|
||||||
async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
|
|
||||||
self.sent_messages.append(
|
|
||||||
{
|
|
||||||
"message": message,
|
|
||||||
"message_id_before_send": str(getattr(message, "message_id", "") or ""),
|
|
||||||
"route_key": route_key,
|
|
||||||
"metadata": metadata,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return self._delivery_batch
|
|
||||||
|
|
||||||
|
|
||||||
def _build_private_stream() -> BotChatSession:
|
|
||||||
return BotChatSession(
|
|
||||||
session_id="test-session",
|
|
||||||
platform="qq",
|
|
||||||
user_id="target-user",
|
|
||||||
group_id=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_group_stream() -> BotChatSession:
|
|
||||||
return BotChatSession(
|
|
||||||
session_id="group-session",
|
|
||||||
platform="qq",
|
|
||||||
user_id="target-user",
|
|
||||||
group_id="target-group",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_inherit_platform_io_route_metadata_falls_back_to_bot_account(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "")
|
|
||||||
|
|
||||||
metadata = send_service._inherit_platform_io_route_metadata(_build_private_stream())
|
|
||||||
|
|
||||||
assert metadata["platform_io_account_id"] == "bot-qq"
|
|
||||||
assert metadata["platform_io_target_user_id"] == "target-user"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
import src.common.message_server.api as message_server_api
|
|
||||||
|
|
||||||
fake_manager = _FakePlatformIOManager(
|
|
||||||
delivery_batch=SimpleNamespace(
|
|
||||||
has_success=True,
|
|
||||||
sent_receipts=[
|
|
||||||
SimpleNamespace(
|
|
||||||
driver_id="plugin.qq.sender",
|
|
||||||
external_message_id="real-message-id",
|
|
||||||
metadata={
|
|
||||||
"adapter_callbacks": [
|
|
||||||
{
|
|
||||||
"name": "message_id_echo",
|
|
||||||
"payload": {
|
|
||||||
"content": {
|
|
||||||
"type": "echo",
|
|
||||||
"echo": "send_api_test",
|
|
||||||
"actual_id": "real-message-id",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
failed_receipts=[],
|
|
||||||
route_key=SimpleNamespace(platform="qq"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
callback_payloads: List[Dict[str, Any]] = []
|
|
||||||
stored_messages: List[Any] = []
|
|
||||||
|
|
||||||
async def fake_echo_handler(payload: Dict[str, Any]) -> None:
|
|
||||||
"""记录发送成功后的消息 ID 回调。"""
|
|
||||||
|
|
||||||
callback_payloads.append(payload)
|
|
||||||
|
|
||||||
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
|
|
||||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
|
||||||
monkeypatch.setattr(
|
|
||||||
send_service._chat_manager,
|
|
||||||
"get_session_by_session_id",
|
|
||||||
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
send_service.MessageUtils,
|
|
||||||
"store_message_to_db",
|
|
||||||
lambda message: stored_messages.append(message),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
message_server_api,
|
|
||||||
"global_api",
|
|
||||||
SimpleNamespace(_custom_message_handlers={"message_id_echo": fake_echo_handler}),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await send_service.text_to_stream(text="你好", stream_id="test-session")
|
|
||||||
|
|
||||||
assert result is True
|
|
||||||
assert fake_manager.ensure_calls == 1
|
|
||||||
assert len(fake_manager.sent_messages) == 1
|
|
||||||
assert fake_manager.sent_messages[0]["metadata"] == {"show_log": False}
|
|
||||||
assert len(stored_messages) == 1
|
|
||||||
assert stored_messages[0].message_id == "real-message-id"
|
|
||||||
assert callback_payloads == [
|
|
||||||
{
|
|
||||||
"content": {
|
|
||||||
"type": "echo",
|
|
||||||
"echo": "send_api_test",
|
|
||||||
"actual_id": "real-message-id",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_text_to_stream_with_message_returns_sent_message(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
fake_manager = _FakePlatformIOManager(
|
|
||||||
delivery_batch=SimpleNamespace(
|
|
||||||
has_success=True,
|
|
||||||
sent_receipts=[
|
|
||||||
SimpleNamespace(
|
|
||||||
driver_id="plugin.qq.sender",
|
|
||||||
external_message_id="real-message-id",
|
|
||||||
metadata={},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
failed_receipts=[],
|
|
||||||
route_key=SimpleNamespace(platform="qq"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
stored_messages: List[Any] = []
|
|
||||||
|
|
||||||
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
|
|
||||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
|
||||||
monkeypatch.setattr(
|
|
||||||
send_service._chat_manager,
|
|
||||||
"get_session_by_session_id",
|
|
||||||
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
send_service.MessageUtils,
|
|
||||||
"store_message_to_db",
|
|
||||||
lambda message: stored_messages.append(message),
|
|
||||||
)
|
|
||||||
|
|
||||||
sent_message = await send_service.text_to_stream_with_message(text="你好", stream_id="test-session")
|
|
||||||
|
|
||||||
assert sent_message is not None
|
|
||||||
assert sent_message.message_id == "real-message-id"
|
|
||||||
assert fake_manager.ensure_calls == 1
|
|
||||||
assert len(stored_messages) == 1
|
|
||||||
assert stored_messages[0].message_id == "real-message-id"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_text_to_stream_with_message_triggers_memory_and_syncs_maisaka_history(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
fake_manager = _FakePlatformIOManager(
|
|
||||||
delivery_batch=SimpleNamespace(
|
|
||||||
has_success=True,
|
|
||||||
sent_receipts=[
|
|
||||||
SimpleNamespace(
|
|
||||||
driver_id="plugin.qq.sender",
|
|
||||||
external_message_id="real-message-id",
|
|
||||||
metadata={},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
failed_receipts=[],
|
|
||||||
route_key=SimpleNamespace(platform="qq"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
stored_messages: List[Any] = []
|
|
||||||
memory_events: List[str] = []
|
|
||||||
history_events: List[tuple[str, str]] = []
|
|
||||||
|
|
||||||
class FakeMemoryAutomationService:
|
|
||||||
async def on_message_sent(self, message: Any) -> None:
|
|
||||||
memory_events.append(str(message.message_id))
|
|
||||||
|
|
||||||
class FakeRuntime:
|
|
||||||
def append_sent_message_to_chat_history(self, message: Any, *, source_kind: str = "guided_reply") -> None:
|
|
||||||
history_events.append((str(message.message_id), source_kind))
|
|
||||||
|
|
||||||
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
|
|
||||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
|
||||||
monkeypatch.setattr(
|
|
||||||
send_service._chat_manager,
|
|
||||||
"get_session_by_session_id",
|
|
||||||
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
send_service.MessageUtils,
|
|
||||||
"store_message_to_db",
|
|
||||||
lambda message: stored_messages.append(message),
|
|
||||||
)
|
|
||||||
monkeypatch.setitem(
|
|
||||||
sys.modules,
|
|
||||||
"src.services.memory_flow_service",
|
|
||||||
SimpleNamespace(memory_automation_service=FakeMemoryAutomationService()),
|
|
||||||
)
|
|
||||||
monkeypatch.setitem(
|
|
||||||
sys.modules,
|
|
||||||
"src.chat.heart_flow.heartflow_manager",
|
|
||||||
SimpleNamespace(heartflow_manager=SimpleNamespace(heartflow_chat_list={"test-session": FakeRuntime()})),
|
|
||||||
)
|
|
||||||
|
|
||||||
sent_message = await send_service.text_to_stream_with_message(
|
|
||||||
text="你好",
|
|
||||||
stream_id="test-session",
|
|
||||||
sync_to_maisaka_history=True,
|
|
||||||
maisaka_source_kind="guided_reply",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert sent_message is not None
|
|
||||||
assert sent_message.message_id == "real-message-id"
|
|
||||||
assert fake_manager.ensure_calls == 1
|
|
||||||
assert len(stored_messages) == 1
|
|
||||||
assert stored_messages[0].message_id == "real-message-id"
|
|
||||||
assert memory_events == ["real-message-id"]
|
|
||||||
assert history_events == [("real-message-id", "guided_reply")]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
fake_manager = _FakePlatformIOManager(
|
|
||||||
delivery_batch=SimpleNamespace(
|
|
||||||
has_success=False,
|
|
||||||
sent_receipts=[],
|
|
||||||
failed_receipts=[
|
|
||||||
SimpleNamespace(
|
|
||||||
driver_id="plugin.qq.sender",
|
|
||||||
status="failed",
|
|
||||||
error="network error",
|
|
||||||
)
|
|
||||||
],
|
|
||||||
route_key=SimpleNamespace(platform="qq"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
|
|
||||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
|
||||||
monkeypatch.setattr(
|
|
||||||
send_service._chat_manager,
|
|
||||||
"get_session_by_session_id",
|
|
||||||
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await send_service.text_to_stream(text="发送失败", stream_id="test-session")
|
|
||||||
|
|
||||||
assert result is False
|
|
||||||
assert fake_manager.ensure_calls == 1
|
|
||||||
assert len(fake_manager.sent_messages) == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_custom_to_stream_detailed_raises_stream_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
monkeypatch.setattr(
|
|
||||||
send_service._chat_manager,
|
|
||||||
"get_session_by_session_id",
|
|
||||||
lambda stream_id: None,
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(send_service.SendServiceError, match="未找到聊天流: group_chat"):
|
|
||||||
await send_service.custom_to_stream_detailed(
|
|
||||||
message_type="poke",
|
|
||||||
content={"qq_id": "2810873701"},
|
|
||||||
stream_id="group_chat",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_private_outbound_message_preserves_bot_sender_and_receiver_user(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
|
||||||
monkeypatch.setattr(
|
|
||||||
send_service._chat_manager,
|
|
||||||
"get_session_by_session_id",
|
|
||||||
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
outbound_message = send_service._build_outbound_session_message(
|
|
||||||
message_sequence=MessageSequence(components=[TextComponent(text="你好")]),
|
|
||||||
stream_id="test-session",
|
|
||||||
processed_plain_text="你好",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert outbound_message is not None
|
|
||||||
maim_message = await outbound_message.to_maim_message()
|
|
||||||
|
|
||||||
assert maim_message.message_info.user_info is not None
|
|
||||||
assert maim_message.message_info.user_info.user_id == "bot-qq"
|
|
||||||
assert maim_message.message_info.group_info is None
|
|
||||||
assert maim_message.message_info.sender_info is not None
|
|
||||||
assert maim_message.message_info.sender_info.user_info is not None
|
|
||||||
assert maim_message.message_info.sender_info.user_info.user_id == "bot-qq"
|
|
||||||
assert maim_message.message_info.receiver_info is not None
|
|
||||||
assert maim_message.message_info.receiver_info.user_info is not None
|
|
||||||
assert maim_message.message_info.receiver_info.user_info.user_id == "target-user"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_group_outbound_message_preserves_bot_sender_and_target_group(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
|
||||||
monkeypatch.setattr(
|
|
||||||
send_service._chat_manager,
|
|
||||||
"get_session_by_session_id",
|
|
||||||
lambda stream_id: _build_group_stream() if stream_id == "group-session" else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
outbound_message = send_service._build_outbound_session_message(
|
|
||||||
message_sequence=MessageSequence(components=[TextComponent(text="大家好")]),
|
|
||||||
stream_id="group-session",
|
|
||||||
processed_plain_text="大家好",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert outbound_message is not None
|
|
||||||
maim_message = await outbound_message.to_maim_message()
|
|
||||||
|
|
||||||
assert maim_message.message_info.user_info is not None
|
|
||||||
assert maim_message.message_info.user_info.user_id == "bot-qq"
|
|
||||||
assert maim_message.message_info.group_info is not None
|
|
||||||
assert maim_message.message_info.group_info.group_id == "target-group"
|
|
||||||
assert maim_message.message_info.receiver_info is not None
|
|
||||||
assert maim_message.message_info.receiver_info.group_info is not None
|
|
||||||
assert maim_message.message_info.receiver_info.group_info.group_id == "target-group"
|
|
||||||
@@ -1,297 +0,0 @@
|
|||||||
from types import SimpleNamespace
|
|
||||||
from typing import Any
|
|
||||||
import importlib.util
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.core.tooling import ToolAvailabilityContext, ToolRegistry
|
|
||||||
from src.maisaka.tool_provider import MaisakaBuiltinToolProvider
|
|
||||||
from src.plugin_runtime.component_query import ComponentQueryService
|
|
||||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_builtin_at_tool_is_not_exposed() -> None:
|
|
||||||
registry = ToolRegistry()
|
|
||||||
registry.register_provider(MaisakaBuiltinToolProvider())
|
|
||||||
|
|
||||||
group_specs = await registry.list_tools(ToolAvailabilityContext(session_id="group-1", is_group_chat=True))
|
|
||||||
private_specs = await registry.list_tools(ToolAvailabilityContext(session_id="private-1", is_group_chat=False))
|
|
||||||
default_specs = await registry.list_tools()
|
|
||||||
|
|
||||||
assert "at" not in {tool_spec.name for tool_spec in group_specs}
|
|
||||||
assert "at" not in {tool_spec.name for tool_spec in private_specs}
|
|
||||||
assert "at" not in {tool_spec.name for tool_spec in default_specs}
|
|
||||||
|
|
||||||
|
|
||||||
def test_plugin_tool_chat_scope_uses_component_field(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
service = ComponentQueryService()
|
|
||||||
registry = ComponentRegistry()
|
|
||||||
supervisor = SimpleNamespace(component_registry=registry)
|
|
||||||
monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor])
|
|
||||||
|
|
||||||
registry.register_plugin_components(
|
|
||||||
"scope_plugin",
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "group_tool",
|
|
||||||
"component_type": "TOOL",
|
|
||||||
"chat_scope": "group",
|
|
||||||
"metadata": {"description": "group only"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "private_tool",
|
|
||||||
"component_type": "TOOL",
|
|
||||||
"chat_scope": "private",
|
|
||||||
"metadata": {"description": "private only"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "all_tool",
|
|
||||||
"component_type": "TOOL",
|
|
||||||
"metadata": {"description": "all chats"},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
group_specs = service.get_llm_available_tool_specs(
|
|
||||||
context=ToolAvailabilityContext(session_id="group-1", is_group_chat=True)
|
|
||||||
)
|
|
||||||
private_specs = service.get_llm_available_tool_specs(
|
|
||||||
context=ToolAvailabilityContext(session_id="private-1", is_group_chat=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
group_entry = registry.get_component("scope_plugin.group_tool")
|
|
||||||
assert group_entry is not None
|
|
||||||
assert group_entry.chat_scope == "group"
|
|
||||||
assert "chat_scope" not in group_entry.metadata
|
|
||||||
assert set(group_specs) == {"group_tool", "all_tool"}
|
|
||||||
assert set(private_specs) == {"private_tool", "all_tool"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_plugin_tool_session_disable_still_filters_specific_chat(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
service = ComponentQueryService()
|
|
||||||
registry = ComponentRegistry()
|
|
||||||
supervisor = SimpleNamespace(component_registry=registry)
|
|
||||||
monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor])
|
|
||||||
|
|
||||||
registry.register_plugin_components(
|
|
||||||
"mute_plugin",
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "mute",
|
|
||||||
"component_type": "TOOL",
|
|
||||||
"chat_scope": "group",
|
|
||||||
"metadata": {"description": "mute group member"},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
registry.set_component_enabled("mute_plugin.mute", False, session_id="group-disabled")
|
|
||||||
|
|
||||||
disabled_specs = service.get_llm_available_tool_specs(
|
|
||||||
context=ToolAvailabilityContext(session_id="group-disabled", is_group_chat=True)
|
|
||||||
)
|
|
||||||
enabled_specs = service.get_llm_available_tool_specs(
|
|
||||||
context=ToolAvailabilityContext(session_id="group-enabled", is_group_chat=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "mute" not in disabled_specs
|
|
||||||
assert "mute" in enabled_specs
|
|
||||||
|
|
||||||
|
|
||||||
def test_plugin_tool_allowed_session_filters_tool_exposure(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
service = ComponentQueryService()
|
|
||||||
registry = ComponentRegistry()
|
|
||||||
supervisor = SimpleNamespace(component_registry=registry)
|
|
||||||
monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor])
|
|
||||||
|
|
||||||
registry.register_plugin_components(
|
|
||||||
"mute_plugin",
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "mute",
|
|
||||||
"component_type": "TOOL",
|
|
||||||
"chat_scope": "group",
|
|
||||||
"allowed_session": ["qq:10001", "raw-group-id", "exact-session-id"],
|
|
||||||
"metadata": {"description": "mute group member"},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
platform_group_specs = service.get_llm_available_tool_specs(
|
|
||||||
context=ToolAvailabilityContext(
|
|
||||||
session_id="hashed-session-1",
|
|
||||||
is_group_chat=True,
|
|
||||||
group_id="10001",
|
|
||||||
platform="qq",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
raw_group_specs = service.get_llm_available_tool_specs(
|
|
||||||
context=ToolAvailabilityContext(
|
|
||||||
session_id="hashed-session-2",
|
|
||||||
is_group_chat=True,
|
|
||||||
group_id="raw-group-id",
|
|
||||||
platform="qq",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
exact_session_specs = service.get_llm_available_tool_specs(
|
|
||||||
context=ToolAvailabilityContext(session_id="exact-session-id", is_group_chat=True)
|
|
||||||
)
|
|
||||||
blocked_specs = service.get_llm_available_tool_specs(
|
|
||||||
context=ToolAvailabilityContext(
|
|
||||||
session_id="blocked-session",
|
|
||||||
is_group_chat=True,
|
|
||||||
group_id="20002",
|
|
||||||
platform="qq",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
entry = registry.get_component("mute_plugin.mute")
|
|
||||||
assert entry is not None
|
|
||||||
assert entry.allowed_session == {"qq:10001", "raw-group-id", "exact-session-id"}
|
|
||||||
assert "allowed_session" not in entry.metadata
|
|
||||||
assert "mute" in platform_group_specs
|
|
||||||
assert "mute" in raw_group_specs
|
|
||||||
assert "mute" in exact_session_specs
|
|
||||||
assert "mute" not in blocked_specs
|
|
||||||
|
|
||||||
|
|
||||||
def test_plugin_tool_disabled_session_take_precedence_over_allowed_session(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
service = ComponentQueryService()
|
|
||||||
registry = ComponentRegistry()
|
|
||||||
supervisor = SimpleNamespace(component_registry=registry)
|
|
||||||
monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor])
|
|
||||||
|
|
||||||
registry.register_plugin_components(
|
|
||||||
"mute_plugin",
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "mute",
|
|
||||||
"component_type": "TOOL",
|
|
||||||
"chat_scope": "group",
|
|
||||||
"allowed_session": ["qq:10001"],
|
|
||||||
"metadata": {"description": "mute group member"},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
registry.set_component_enabled("mute_plugin.mute", False, session_id="allowed-session")
|
|
||||||
|
|
||||||
visible_specs = service.get_llm_available_tool_specs(
|
|
||||||
context=ToolAvailabilityContext(
|
|
||||||
session_id="visible-session",
|
|
||||||
is_group_chat=True,
|
|
||||||
group_id="10001",
|
|
||||||
platform="qq",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
disabled_specs = service.get_llm_available_tool_specs(
|
|
||||||
context=ToolAvailabilityContext(
|
|
||||||
session_id="allowed-session",
|
|
||||||
is_group_chat=True,
|
|
||||||
group_id="10001",
|
|
||||||
platform="qq",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
entry = registry.get_component("mute_plugin.mute")
|
|
||||||
assert entry is not None
|
|
||||||
assert entry.disabled_session == {"allowed-session"}
|
|
||||||
assert "mute" in visible_specs
|
|
||||||
assert "mute" not in disabled_specs
|
|
||||||
|
|
||||||
|
|
||||||
def test_mute_plugin_exports_allowed_groups_as_component_allowed_session() -> None:
|
|
||||||
module_path = "plugins/MutePlugin/plugin.py"
|
|
||||||
spec = importlib.util.spec_from_file_location("mute_plugin_under_test", module_path)
|
|
||||||
assert spec is not None
|
|
||||||
assert spec.loader is not None
|
|
||||||
module = importlib.util.module_from_spec(spec)
|
|
||||||
sys.modules[spec.name] = module
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
module.MutePluginConfig.model_rebuild()
|
|
||||||
|
|
||||||
plugin = module.MutePlugin()
|
|
||||||
plugin.set_plugin_config({"permissions": {"allowed_groups": ["qq:10001", "raw-group-id"]}})
|
|
||||||
|
|
||||||
mute_components = [component for component in plugin.get_components() if component.get("name") == "mute"]
|
|
||||||
|
|
||||||
assert len(mute_components) == 1
|
|
||||||
assert mute_components[0]["chat_scope"] == "group"
|
|
||||||
assert mute_components[0]["allowed_session"] == ["qq:10001", "raw-group-id"]
|
|
||||||
assert "allowed_session" not in mute_components[0]["metadata"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_mute_tool_queries_target_message_with_current_chat_id() -> None:
|
|
||||||
module_path = "plugins/MutePlugin/plugin.py"
|
|
||||||
spec = importlib.util.spec_from_file_location("mute_plugin_under_test_msg_id", module_path)
|
|
||||||
assert spec is not None
|
|
||||||
assert spec.loader is not None
|
|
||||||
module = importlib.util.module_from_spec(spec)
|
|
||||||
sys.modules[spec.name] = module
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
module.MutePluginConfig.model_rebuild()
|
|
||||||
|
|
||||||
capability_calls: list[dict[str, Any]] = []
|
|
||||||
api_calls: list[dict[str, Any]] = []
|
|
||||||
|
|
||||||
async def fake_call_capability(name: str, **kwargs: Any) -> dict[str, Any]:
|
|
||||||
capability_calls.append({"name": name, **kwargs})
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"result": {
|
|
||||||
"success": True,
|
|
||||||
"message": {
|
|
||||||
"message_info": {
|
|
||||||
"user_info": {
|
|
||||||
"user_id": "35529667",
|
|
||||||
"user_cardname": "目标用户",
|
|
||||||
"user_nickname": "目标昵称",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
async def fake_api_call(api_name: str, **kwargs: Any) -> dict[str, Any]:
|
|
||||||
api_calls.append({"name": api_name, **kwargs})
|
|
||||||
if api_name == "adapter.napcat.group.get_group_member_info":
|
|
||||||
return {"success": True, "result": {"data": {"role": "member"}}}
|
|
||||||
return {"status": "ok", "retcode": 0}
|
|
||||||
|
|
||||||
plugin = module.MutePlugin()
|
|
||||||
plugin.set_plugin_config({"components": {"enable_smart_mute": True}})
|
|
||||||
plugin._set_context(
|
|
||||||
SimpleNamespace(
|
|
||||||
call_capability=fake_call_capability,
|
|
||||||
api=SimpleNamespace(call=fake_api_call),
|
|
||||||
logger=SimpleNamespace(info=lambda *args, **kwargs: None, warning=lambda *args, **kwargs: None),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
success, message = await plugin.handle_mute_tool(
|
|
||||||
stream_id="current-session-id",
|
|
||||||
group_id="766798517",
|
|
||||||
msg_id="2046083292",
|
|
||||||
duration=3600,
|
|
||||||
reason="测试",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert success is True
|
|
||||||
assert message == "成功禁言 目标用户"
|
|
||||||
assert capability_calls == [
|
|
||||||
{
|
|
||||||
"name": "message.get_by_id",
|
|
||||||
"message_id": "2046083292",
|
|
||||||
"chat_id": "current-session-id",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
assert api_calls[-1] == {
|
|
||||||
"name": "adapter.napcat.group.set_group_ban",
|
|
||||||
"version": "1",
|
|
||||||
"group_id": "766798517",
|
|
||||||
"user_id": "35529667",
|
|
||||||
"duration": 3600,
|
|
||||||
}
|
|
||||||
@@ -1,367 +0,0 @@
|
|||||||
import sys
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
import pytest
|
|
||||||
import importlib
|
|
||||||
import importlib.util
|
|
||||||
from types import ModuleType
|
|
||||||
from pathlib import Path
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import TYPE_CHECKING, Optional
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from src.common.data_models.message_component_data_model import MessageSequence
|
|
||||||
from src.chat.message_receive.message import (
|
|
||||||
SessionMessage,
|
|
||||||
TextComponent,
|
|
||||||
ImageComponent,
|
|
||||||
AtComponent,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DummyLogger:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.logging_record = []
|
|
||||||
|
|
||||||
def debug(self, msg):
|
|
||||||
print(f"DEBUG: {msg}")
|
|
||||||
self.logging_record.append(f"DEBUG: {msg}")
|
|
||||||
|
|
||||||
def info(self, msg):
|
|
||||||
print(f"INFO: {msg}")
|
|
||||||
self.logging_record.append(f"INFO: {msg}")
|
|
||||||
|
|
||||||
def warning(self, msg):
|
|
||||||
print(f"WARNING: {msg}")
|
|
||||||
self.logging_record.append(f"WARNING: {msg}")
|
|
||||||
|
|
||||||
def error(self, msg):
|
|
||||||
print(f"ERROR: {msg}")
|
|
||||||
self.logging_record.append(f"ERROR: {msg}")
|
|
||||||
|
|
||||||
def critical(self, msg):
|
|
||||||
print(f"CRITICAL: {msg}")
|
|
||||||
self.logging_record.append(f"CRITICAL: {msg}")
|
|
||||||
|
|
||||||
|
|
||||||
def get_logger(name):
|
|
||||||
return DummyLogger()
|
|
||||||
|
|
||||||
|
|
||||||
class DummyDBSession:
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def exec(self, statement):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def first(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def commit(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def all(self):
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def get_db_session():
|
|
||||||
return DummyDBSession()
|
|
||||||
|
|
||||||
|
|
||||||
def get_manual_db_session():
|
|
||||||
return DummyDBSession()
|
|
||||||
|
|
||||||
|
|
||||||
class DummySelect:
|
|
||||||
def __init__(self, model):
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
def filter_by(self, **kwargs):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def where(self, condition):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def limit(self, n):
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
def select(model):
|
|
||||||
return DummySelect(model)
|
|
||||||
|
|
||||||
|
|
||||||
async def dummy_get_voice_text(binary_data):
|
|
||||||
return None # 可以根据需要返回模拟的文本结果
|
|
||||||
|
|
||||||
|
|
||||||
class DummyPersonUtils:
|
|
||||||
@staticmethod
|
|
||||||
def get_person_info_by_user_id_and_platform(user_id, platform):
|
|
||||||
return None # 可以根据需要返回模拟的用户信息
|
|
||||||
|
|
||||||
|
|
||||||
class DummyConfig:
|
|
||||||
class MessageReceiveConfig:
|
|
||||||
ban_words = set()
|
|
||||||
ban_msgs_regex = set()
|
|
||||||
|
|
||||||
message_receive = MessageReceiveConfig()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class UserInfo:
|
|
||||||
user_id: str
|
|
||||||
user_nickname: str
|
|
||||||
user_cardname: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GroupInfo:
|
|
||||||
group_id: str
|
|
||||||
group_name: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MessageInfo:
|
|
||||||
user_info: UserInfo
|
|
||||||
group_info: Optional[GroupInfo] = None
|
|
||||||
additional_config: dict = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_mocks(monkeypatch):
|
|
||||||
def _stub_module(name: str) -> ModuleType:
|
|
||||||
module = ModuleType(name)
|
|
||||||
monkeypatch.setitem(sys.modules, name, module)
|
|
||||||
return module
|
|
||||||
|
|
||||||
# src.common.logger
|
|
||||||
logger_mod = _stub_module("src.common.logger")
|
|
||||||
# Mock the logger
|
|
||||||
logger_mod.get_logger = get_logger
|
|
||||||
|
|
||||||
db_mod = _stub_module("src.common.database.database")
|
|
||||||
db_mod.get_db_session = get_db_session
|
|
||||||
db_mod.get_manual_db_session = get_manual_db_session
|
|
||||||
|
|
||||||
db_model_mod = _stub_module("src.common.database.database_model")
|
|
||||||
db_model_mod.Messages = None # 可以根据需要添加更多的属性或方法
|
|
||||||
|
|
||||||
emoji_manager_mod = _stub_module("src.emoji_system.emoji_manager")
|
|
||||||
emoji_manager_mod.emoji_manager = None # 可以根据需要添加更多的属性或方法
|
|
||||||
|
|
||||||
image_manager_mod = _stub_module("src.chat.image_system.image_manager")
|
|
||||||
image_manager_mod.image_manager = None # 可以根据需要添加更多的属性或方法
|
|
||||||
|
|
||||||
voice_utils_mod = _stub_module("src.common.utils.utils_voice")
|
|
||||||
voice_utils_mod.get_voice_text = dummy_get_voice_text
|
|
||||||
|
|
||||||
person_utils_mod = _stub_module("src.common.utils.utils_person")
|
|
||||||
person_utils_mod.PersonUtils = DummyPersonUtils
|
|
||||||
|
|
||||||
config_mod = _stub_module("src.config.config")
|
|
||||||
config_mod.global_config = DummyConfig()
|
|
||||||
|
|
||||||
|
|
||||||
def load_message_via_file(monkeypatch):
|
|
||||||
setup_mocks(monkeypatch)
|
|
||||||
file_path = Path(__file__).parent.parent.parent / "src" / "chat" / "message_receive" / "message.py"
|
|
||||||
spec = importlib.util.spec_from_file_location("message", file_path)
|
|
||||||
message_module = importlib.util.module_from_spec(spec)
|
|
||||||
monkeypatch.setitem(sys.modules, "message_module", message_module)
|
|
||||||
spec.loader.exec_module(message_module)
|
|
||||||
message_module.select = select
|
|
||||||
SessionMessageClass = message_module.SessionMessage
|
|
||||||
TextComponentClass = message_module.TextComponent
|
|
||||||
ImageComponentClass = message_module.ImageComponent
|
|
||||||
EmojiComponentClass = message_module.EmojiComponent
|
|
||||||
VoiceComponentClass = message_module.VoiceComponent
|
|
||||||
AtComponentClass = message_module.AtComponent
|
|
||||||
ReplyComponentClass = message_module.ReplyComponent
|
|
||||||
ForwardNodeComponentClass = message_module.ForwardNodeComponent
|
|
||||||
MessageSequenceClass = sys.modules["src.common.data_models.message_component_data_model"].MessageSequence
|
|
||||||
ForwardComponentClass = sys.modules["src.common.data_models.message_component_data_model"].ForwardComponent
|
|
||||||
globals()["SessionMessage"] = SessionMessageClass
|
|
||||||
globals()["TextComponent"] = TextComponentClass
|
|
||||||
globals()["ImageComponent"] = ImageComponentClass
|
|
||||||
globals()["EmojiComponent"] = EmojiComponentClass
|
|
||||||
globals()["VoiceComponent"] = VoiceComponentClass
|
|
||||||
globals()["AtComponent"] = AtComponentClass
|
|
||||||
globals()["ReplyComponent"] = ReplyComponentClass
|
|
||||||
globals()["ForwardNodeComponent"] = ForwardNodeComponentClass
|
|
||||||
globals()["MessageSequence"] = MessageSequenceClass
|
|
||||||
globals()["ForwardComponent"] = ForwardComponentClass
|
|
||||||
return message_module
|
|
||||||
|
|
||||||
|
|
||||||
def dummy_number_to_short_id(original_id: int, salt: str, length: int = 6) -> str:
|
|
||||||
return "X" * length # 返回固定的字符串,长度由参数决定,模拟生成短ID的行为
|
|
||||||
|
|
||||||
|
|
||||||
def dummy_is_bot_self(platform, user_id: str) -> bool:
|
|
||||||
return user_id == "bot_self"
|
|
||||||
|
|
||||||
|
|
||||||
def load_utils_via_file(monkeypatch):
|
|
||||||
setup_mocks(monkeypatch)
|
|
||||||
|
|
||||||
# Mock math_utils 模块,供 from .math_utils import number_to_short_id 使用
|
|
||||||
math_utils_mod = ModuleType("src.common.utils.math_utils")
|
|
||||||
math_utils_mod.number_to_short_id = dummy_number_to_short_id
|
|
||||||
math_utils_mod.TimestampMode = type(
|
|
||||||
"TimestampMode", (), {"NORMAL": "%Y-%m-%d %H:%M:%S", "NORMAL_NO_YMD": "%H:%M:%S", "RELATIVE": "relative"}
|
|
||||||
)
|
|
||||||
math_utils_mod.translate_timestamp_to_human_readable = lambda timestamp, mode: (
|
|
||||||
"2024-01-01 12:00:00"
|
|
||||||
) # 返回固定的时间字符串
|
|
||||||
monkeypatch.setitem(sys.modules, "src.common.utils.math_utils", math_utils_mod)
|
|
||||||
|
|
||||||
# 确保包层级模块存在于 sys.modules 中,使相对导入能正确解析
|
|
||||||
for pkg_name in ["src", "src.common", "src.common.utils"]:
|
|
||||||
if pkg_name not in sys.modules:
|
|
||||||
pkg_mod = ModuleType(pkg_name)
|
|
||||||
pkg_mod.__path__ = []
|
|
||||||
monkeypatch.setitem(sys.modules, pkg_name, pkg_mod)
|
|
||||||
|
|
||||||
file_path = Path(__file__).parent.parent.parent / "src" / "common" / "utils" / "utils_message.py"
|
|
||||||
spec = importlib.util.spec_from_file_location("src.common.utils.utils_message", file_path)
|
|
||||||
utils_module = importlib.util.module_from_spec(spec)
|
|
||||||
utils_module.__package__ = "src.common.utils" # 设置包,使相对导入生效
|
|
||||||
monkeypatch.setitem(sys.modules, "src.common.utils.utils_message", utils_module)
|
|
||||||
monkeypatch.setitem(sys.modules, "message_utils_module", utils_module)
|
|
||||||
spec.loader.exec_module(utils_module)
|
|
||||||
utils_module.is_bot_self = dummy_is_bot_self
|
|
||||||
return utils_module
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_message_utils(monkeypatch):
|
|
||||||
load_message_via_file(monkeypatch)
|
|
||||||
load_utils_via_file(monkeypatch)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_build_readable_message_basic(monkeypatch):
|
|
||||||
"""基础用例:单条消息,显示行号"""
|
|
||||||
load_message_via_file(monkeypatch)
|
|
||||||
utils_module = load_utils_via_file(monkeypatch)
|
|
||||||
MessageUtils = utils_module.MessageUtils
|
|
||||||
|
|
||||||
msg = SessionMessage("m1", datetime.now(), platform="test")
|
|
||||||
msg.platform = "test"
|
|
||||||
msg.session_id = "s_test"
|
|
||||||
user_info = UserInfo(user_id="u1", user_nickname="Alice")
|
|
||||||
msg.message_info = MessageInfo(user_info=user_info)
|
|
||||||
msg.raw_message = MessageSequence([TextComponent("Hello world")])
|
|
||||||
text, mapping, _ = await MessageUtils.build_readable_message([msg], anonymize=False, show_lineno=True)
|
|
||||||
assert "[1] Alice说:Hello world" in text
|
|
||||||
assert mapping == {}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_build_readable_message_anonymize(monkeypatch):
|
|
||||||
"""匿名化用例:验证 mapping 和返回文本"""
|
|
||||||
load_message_via_file(monkeypatch)
|
|
||||||
utils_module = load_utils_via_file(monkeypatch)
|
|
||||||
MessageUtils = utils_module.MessageUtils
|
|
||||||
|
|
||||||
msg = SessionMessage("m2", datetime.now(), platform="test")
|
|
||||||
msg.session_id = "s_test"
|
|
||||||
user_info = UserInfo(user_id="u42", user_nickname="Bob")
|
|
||||||
msg.message_info = MessageInfo(user_info=user_info)
|
|
||||||
msg.raw_message = MessageSequence([TextComponent("Secret text")])
|
|
||||||
text, mapping, _ = await MessageUtils.build_readable_message([msg], anonymize=True, show_lineno=False)
|
|
||||||
# 根据实现,original_name 为 user_nickname,因此文本中应包含原始名称
|
|
||||||
assert "XXXXXX说:" in text
|
|
||||||
assert "u42" in mapping
|
|
||||||
assert mapping["u42"][0] == "XXXXXX"
|
|
||||||
assert mapping["u42"][1] == "Bob"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_build_readable_message_replace_bot(monkeypatch):
|
|
||||||
"""替换机器人名用例:当 user_id 为 bot_self 时应被替换为 target_bot_name"""
|
|
||||||
load_message_via_file(monkeypatch)
|
|
||||||
utils_module = load_utils_via_file(monkeypatch)
|
|
||||||
MessageUtils = utils_module.MessageUtils
|
|
||||||
|
|
||||||
msg = SessionMessage("m3", datetime.now(), platform="test")
|
|
||||||
msg.session_id = "s_test"
|
|
||||||
user_info = UserInfo(user_id="bot_self", user_nickname="SomeBot")
|
|
||||||
msg.message_info = MessageInfo(user_info=user_info)
|
|
||||||
msg.raw_message = MessageSequence([TextComponent("ping")])
|
|
||||||
text, mapping, _ = await MessageUtils.build_readable_message([msg], replace_bot_name=True, target_bot_name="MAIBot")
|
|
||||||
assert "MAIBot说:ping" in text
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_build_readable_message_image_extraction(monkeypatch):
|
|
||||||
"""图片提取:验证 extract_pictures 为 True 时,文本中包含图片占位及 img_map 内容被返回"""
|
|
||||||
load_message_via_file(monkeypatch)
|
|
||||||
utils_module = load_utils_via_file(monkeypatch)
|
|
||||||
MessageUtils = utils_module.MessageUtils
|
|
||||||
|
|
||||||
# 构建包含图片组件的消息
|
|
||||||
img = ImageComponent(binary_hash="h", binary_data=b"\x01\x02", content="Img")
|
|
||||||
msg = SessionMessage("mi1", datetime.now(), platform="test")
|
|
||||||
msg.session_id = "s_img"
|
|
||||||
msg.raw_message = MessageSequence([img])
|
|
||||||
msg.message_info = MessageInfo(UserInfo(user_id="ui_img", user_nickname="ImgUser"))
|
|
||||||
text, mapping, _ = await MessageUtils.build_readable_message([msg], extract_pictures=True)
|
|
||||||
# 应包含图片描述占位
|
|
||||||
assert "图片1" in text
|
|
||||||
# mapping 不为空(匿名化未开启则为空)
|
|
||||||
assert isinstance(mapping, dict)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_build_readable_message_anonymize_and_replace_bot_name_and_lineno(monkeypatch):
|
|
||||||
"""组合用例:多个消息同时包含匿名化、机器人名称替换"""
|
|
||||||
load_message_via_file(monkeypatch)
|
|
||||||
utils_module = load_utils_via_file(monkeypatch)
|
|
||||||
MessageUtils = utils_module.MessageUtils
|
|
||||||
# 构建多个消息
|
|
||||||
msg1 = SessionMessage("m4", datetime.now(), platform="test")
|
|
||||||
msg1.session_id = "s_comb"
|
|
||||||
msg2 = SessionMessage("m5", datetime.now(), platform="test")
|
|
||||||
msg2.session_id = "s_comb"
|
|
||||||
msg1.message_info = MessageInfo(UserInfo(user_id="u_comb", user_nickname="Charlie"))
|
|
||||||
msg2.message_info = MessageInfo(UserInfo(user_id="bot_self", user_nickname="SomeBot"))
|
|
||||||
msg1.raw_message = MessageSequence([TextComponent("Hi")])
|
|
||||||
msg2.raw_message = MessageSequence([TextComponent("Hello")])
|
|
||||||
text, mapping, _ = await MessageUtils.build_readable_message(
|
|
||||||
[msg1, msg2],
|
|
||||||
anonymize=True,
|
|
||||||
replace_bot_name=True,
|
|
||||||
target_bot_name="MAIBot",
|
|
||||||
show_lineno=True,
|
|
||||||
)
|
|
||||||
# 验证文本内容
|
|
||||||
assert "[1] XXXXXX说:Hi" in text
|
|
||||||
assert "[2] MAIBot说:Hello" in text
|
|
||||||
# 验证 mapping 内容
|
|
||||||
assert "u_comb" in mapping
|
|
||||||
assert mapping["u_comb"][0] == "XXXXXX"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_build_readable_message_with_at(monkeypatch):
|
|
||||||
"""包含@组件的消息:验证@组件中的用户信息也被匿名化和替换"""
|
|
||||||
load_message_via_file(monkeypatch)
|
|
||||||
utils_module = load_utils_via_file(monkeypatch)
|
|
||||||
MessageUtils = utils_module.MessageUtils
|
|
||||||
|
|
||||||
# 构建包含回复组件的消息
|
|
||||||
at_comp = AtComponent(target_user_id="u_at", target_user_nickname="AtUser")
|
|
||||||
msg = SessionMessage("m_at", datetime.now(), platform="test")
|
|
||||||
msg.session_id = "s_at"
|
|
||||||
msg.raw_message = MessageSequence([at_comp])
|
|
||||||
msg.message_info = MessageInfo(UserInfo(user_id="u_main", user_nickname="MainUser"))
|
|
||||||
text, mapping, _ = await MessageUtils.build_readable_message(
|
|
||||||
[msg], anonymize=True, replace_bot_name=True, target_bot_name="MAIBot"
|
|
||||||
)
|
|
||||||
# 验证主消息和@组件中的用户信息都被处理
|
|
||||||
assert "XXXXXX说:" in text # 主消息用户被匿名化
|
|
||||||
assert "XXXXXX说:@XXXXXX" in text # @组件用户被匿名化
|
|
||||||
@@ -1,117 +0,0 @@
|
|||||||
"""统计模块数据库会话行为测试。"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from types import ModuleType
|
|
||||||
from typing import Any, Callable, Iterator
|
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.chat.utils import statistic
|
|
||||||
|
|
||||||
|
|
||||||
class _DummyResult:
|
|
||||||
"""模拟 SQLModel 查询结果对象。"""
|
|
||||||
|
|
||||||
def all(self) -> list[Any]:
|
|
||||||
"""返回空结果集。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[Any]: 空列表。
|
|
||||||
"""
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
class _DummySession:
|
|
||||||
"""模拟数据库 Session。"""
|
|
||||||
|
|
||||||
def exec(self, statement: Any) -> _DummyResult:
|
|
||||||
"""执行查询语句并返回空结果。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
statement: 待执行的查询语句。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
_DummyResult: 空结果对象。
|
|
||||||
"""
|
|
||||||
del statement
|
|
||||||
return _DummyResult()
|
|
||||||
|
|
||||||
|
|
||||||
def _build_fake_get_db_session(calls: list[bool]) -> Callable[[bool], Iterator[_DummySession]]:
|
|
||||||
"""构造一个记录 auto_commit 参数的假会话工厂。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
calls: 用于记录每次调用 auto_commit 参数的列表。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Callable[[bool], Iterator[_DummySession]]: 可替换 `get_db_session` 的上下文管理器工厂。
|
|
||||||
"""
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_DummySession]:
|
|
||||||
"""记录会话参数并返回假 Session。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
auto_commit: 是否启用自动提交。
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
Iterator[_DummySession]: 假 Session 对象。
|
|
||||||
"""
|
|
||||||
calls.append(auto_commit)
|
|
||||||
yield _DummySession()
|
|
||||||
|
|
||||||
return _fake_get_db_session
|
|
||||||
|
|
||||||
|
|
||||||
def _build_statistic_task() -> statistic.StatisticOutputTask:
|
|
||||||
"""构造一个最小可用的统计任务实例。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
statistic.StatisticOutputTask: 跳过 `__init__` 的测试实例。
|
|
||||||
"""
|
|
||||||
task = statistic.StatisticOutputTask.__new__(statistic.StatisticOutputTask)
|
|
||||||
task.name_mapping = {}
|
|
||||||
return task
|
|
||||||
|
|
||||||
|
|
||||||
def _is_bot_self(platform: str, user_id: str) -> bool:
|
|
||||||
"""返回固定的非机器人身份判断结果。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
platform: 平台名称。
|
|
||||||
user_id: 用户 ID。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 始终返回 ``False``。
|
|
||||||
"""
|
|
||||||
del platform
|
|
||||||
del user_id
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def test_statistic_read_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""统计模块的纯读查询应关闭自动提交,避免 Session 退出后对象被 expire。"""
|
|
||||||
calls: list[bool] = []
|
|
||||||
now = datetime.now()
|
|
||||||
task = _build_statistic_task()
|
|
||||||
|
|
||||||
monkeypatch.setattr(statistic, "get_db_session", _build_fake_get_db_session(calls))
|
|
||||||
|
|
||||||
utils_module = ModuleType("src.chat.utils.utils")
|
|
||||||
utils_module.is_bot_self = _is_bot_self
|
|
||||||
monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module)
|
|
||||||
monkeypatch.setattr(statistic, "fetch_online_time_since", lambda query_start_time: [])
|
|
||||||
monkeypatch.setattr(statistic, "fetch_model_usage_since", lambda query_start_time: [])
|
|
||||||
monkeypatch.setattr(statistic, "fetch_messages_since", lambda query_start_time: [])
|
|
||||||
monkeypatch.setattr(statistic, "fetch_tool_records_since", lambda query_start_time: [])
|
|
||||||
|
|
||||||
task._collect_message_count_for_period([("last_hour", now - timedelta(hours=1))])
|
|
||||||
task._collect_interval_data(now, hours=1, interval_minutes=60)
|
|
||||||
task._collect_metrics_interval_data(now, hours=1, interval_hours=1)
|
|
||||||
|
|
||||||
assert calls == []
|
|
||||||
@@ -1,131 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
from src.config.model_configs import APIProvider, ModelInfo
|
|
||||||
from src.llm_models.model_client.base_client import ResponseRequest
|
|
||||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
|
||||||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
|
||||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
|
|
||||||
from src.llm_models.request_snapshot import (
|
|
||||||
attach_request_snapshot,
|
|
||||||
deserialize_messages_snapshot,
|
|
||||||
format_request_snapshot_log_info,
|
|
||||||
save_failed_request_snapshot,
|
|
||||||
serialize_messages_snapshot,
|
|
||||||
serialize_response_request_snapshot,
|
|
||||||
)
|
|
||||||
from src.llm_models import request_snapshot
|
|
||||||
|
|
||||||
|
|
||||||
def _build_api_provider() -> APIProvider:
|
|
||||||
return APIProvider(
|
|
||||||
api_key="secret-token",
|
|
||||||
base_url="https://example.com/v1",
|
|
||||||
name="test-provider",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_model_info() -> ModelInfo:
|
|
||||||
return ModelInfo(
|
|
||||||
api_provider="test-provider",
|
|
||||||
model_identifier="demo-model",
|
|
||||||
name="demo-model",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_response_request() -> ResponseRequest:
|
|
||||||
tool_call = ToolCall(
|
|
||||||
args={"query": "MaiBot"},
|
|
||||||
call_id="call_1",
|
|
||||||
func_name="search_web",
|
|
||||||
extra_content={"google": {"thought_signature": "c2lnbmF0dXJl"}},
|
|
||||||
)
|
|
||||||
message_list = [
|
|
||||||
MessageBuilder().set_role(RoleType.User).add_text_content("你好").add_image_content("png", "ZmFrZQ==").build(),
|
|
||||||
MessageBuilder().set_role(RoleType.Assistant).set_tool_calls([tool_call]).build(),
|
|
||||||
MessageBuilder()
|
|
||||||
.set_role(RoleType.Tool)
|
|
||||||
.set_tool_call_id("call_1")
|
|
||||||
.set_tool_name("search_web")
|
|
||||||
.add_text_content('{"ok": true}')
|
|
||||||
.build(),
|
|
||||||
]
|
|
||||||
return ResponseRequest(
|
|
||||||
extra_params={"trace_id": "trace-123"},
|
|
||||||
max_tokens=256,
|
|
||||||
message_list=message_list,
|
|
||||||
model_info=_build_model_info(),
|
|
||||||
response_format=RespFormat(RespFormatType.JSON_OBJ),
|
|
||||||
temperature=0.2,
|
|
||||||
tool_options=[ToolOption(name="search_web", description="搜索网页")],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_snapshot_roundtrip_preserves_tool_messages() -> None:
|
|
||||||
request = _build_response_request()
|
|
||||||
|
|
||||||
snapshot_messages = serialize_messages_snapshot(request.message_list)
|
|
||||||
restored_messages = deserialize_messages_snapshot(snapshot_messages)
|
|
||||||
|
|
||||||
assert len(restored_messages) == 3
|
|
||||||
assert restored_messages[0].role == RoleType.User
|
|
||||||
assert restored_messages[0].get_text_content() == "你好"
|
|
||||||
assert restored_messages[0].parts[1].image_format == "png"
|
|
||||||
assert restored_messages[1].role == RoleType.Assistant
|
|
||||||
assert restored_messages[1].tool_calls is not None
|
|
||||||
assert restored_messages[1].tool_calls[0].func_name == "search_web"
|
|
||||||
assert restored_messages[1].tool_calls[0].args == {"query": "MaiBot"}
|
|
||||||
assert restored_messages[1].tool_calls[0].extra_content == {"google": {"thought_signature": "c2lnbmF0dXJl"}}
|
|
||||||
assert restored_messages[2].role == RoleType.Tool
|
|
||||||
assert restored_messages[2].tool_call_id == "call_1"
|
|
||||||
assert restored_messages[2].tool_name == "search_web"
|
|
||||||
|
|
||||||
|
|
||||||
def test_failed_request_snapshot_contains_replay_entry(tmp_path: Path, monkeypatch) -> None:
|
|
||||||
monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path)
|
|
||||||
|
|
||||||
request = _build_response_request()
|
|
||||||
provider = _build_api_provider()
|
|
||||||
snapshot_path = save_failed_request_snapshot(
|
|
||||||
api_provider=provider,
|
|
||||||
client_type="openai",
|
|
||||||
error=RuntimeError("boom"),
|
|
||||||
internal_request=serialize_response_request_snapshot(request),
|
|
||||||
model_info=request.model_info,
|
|
||||||
operation="chat.completions.create",
|
|
||||||
provider_request={"request_kwargs": {"model": request.model_info.model_identifier}},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert snapshot_path is not None
|
|
||||||
payload = json.loads(snapshot_path.read_text(encoding="utf-8"))
|
|
||||||
|
|
||||||
assert payload["internal_request"]["request_kind"] == "response"
|
|
||||||
assert payload["api_provider"]["name"] == "test-provider"
|
|
||||||
assert payload["replay"]["file_uri"] == snapshot_path.as_uri()
|
|
||||||
assert str(snapshot_path) in payload["replay"]["command"]
|
|
||||||
assert "secret-token" not in snapshot_path.read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def test_format_request_snapshot_log_info_includes_path_uri_and_command(tmp_path: Path, monkeypatch) -> None:
|
|
||||||
monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path)
|
|
||||||
|
|
||||||
request = _build_response_request()
|
|
||||||
snapshot_path = save_failed_request_snapshot(
|
|
||||||
api_provider=_build_api_provider(),
|
|
||||||
client_type="openai",
|
|
||||||
error=ValueError("invalid"),
|
|
||||||
internal_request=serialize_response_request_snapshot(request),
|
|
||||||
model_info=request.model_info,
|
|
||||||
operation="chat.completions.create",
|
|
||||||
provider_request={"request_kwargs": {"messages": []}},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert snapshot_path is not None
|
|
||||||
exc = RuntimeError("wrapped")
|
|
||||||
attach_request_snapshot(exc, snapshot_path)
|
|
||||||
|
|
||||||
log_info = format_request_snapshot_log_info(exc)
|
|
||||||
assert str(snapshot_path) in log_info
|
|
||||||
assert snapshot_path.as_uri() in log_info
|
|
||||||
assert "uv run python scripts/replay_llm_request.py" in log_info
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
from src.chat.message_receive.chat_manager import ChatManager
|
|
||||||
from src.common.utils.utils_session import SessionUtils
|
|
||||||
|
|
||||||
|
|
||||||
def test_calculate_session_id_distinguishes_account_and_scope() -> None:
|
|
||||||
base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
|
|
||||||
same_base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
|
|
||||||
account_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123")
|
|
||||||
route_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123", scope="main")
|
|
||||||
|
|
||||||
assert base_session_id == same_base_session_id
|
|
||||||
assert account_scoped_session_id != base_session_id
|
|
||||||
assert route_scoped_session_id != account_scoped_session_id
|
|
||||||
|
|
||||||
|
|
||||||
def test_chat_manager_register_message_uses_route_metadata() -> None:
|
|
||||||
chat_manager = ChatManager()
|
|
||||||
message = SimpleNamespace(
|
|
||||||
platform="qq",
|
|
||||||
session_id="",
|
|
||||||
message_info=SimpleNamespace(
|
|
||||||
user_info=SimpleNamespace(user_id="42"),
|
|
||||||
group_info=SimpleNamespace(group_id="1000"),
|
|
||||||
additional_config={
|
|
||||||
"platform_io_account_id": "123",
|
|
||||||
"platform_io_scope": "main",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
chat_manager.register_message(message)
|
|
||||||
|
|
||||||
assert message.session_id == SessionUtils.calculate_session_id(
|
|
||||||
"qq",
|
|
||||||
user_id="42",
|
|
||||||
group_id="1000",
|
|
||||||
account_id="123",
|
|
||||||
scope="main",
|
|
||||||
)
|
|
||||||
assert chat_manager.last_messages[message.session_id] is message
|
|
||||||
@@ -1,161 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.webui import app as webui_app
|
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_static_path_ready_uses_existing_static_path(tmp_path) -> None:
|
|
||||||
static_path = tmp_path / "dist"
|
|
||||||
static_path.mkdir()
|
|
||||||
(static_path / "index.html").write_text("<html></html>", encoding="utf-8")
|
|
||||||
|
|
||||||
with patch.object(webui_app, "_resolve_static_path", return_value=static_path):
|
|
||||||
result = webui_app._ensure_static_path_ready()
|
|
||||||
|
|
||||||
assert result == static_path
|
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_static_path_ready_logs_install_hint_when_static_assets_are_missing() -> None:
|
|
||||||
with (
|
|
||||||
patch.object(webui_app, "_resolve_static_path", return_value=None),
|
|
||||||
patch.object(webui_app.logger, "warning") as warning_mock,
|
|
||||||
):
|
|
||||||
result = webui_app._ensure_static_path_ready()
|
|
||||||
|
|
||||||
assert result is None
|
|
||||||
warning_mock.assert_any_call(webui_app.t("startup.webui_static_assets_unavailable"))
|
|
||||||
warning_mock.assert_any_call(
|
|
||||||
webui_app.t("startup.webui_dashboard_package_hint", command=webui_app._MANUAL_INSTALL_COMMAND)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_static_path_ready_logs_index_error_when_static_path_is_invalid(tmp_path) -> None:
|
|
||||||
static_path = tmp_path / "dist"
|
|
||||||
static_path.mkdir()
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(webui_app, "_resolve_static_path", return_value=static_path),
|
|
||||||
patch.object(webui_app.logger, "warning") as warning_mock,
|
|
||||||
):
|
|
||||||
result = webui_app._ensure_static_path_ready()
|
|
||||||
|
|
||||||
assert result is None
|
|
||||||
warning_mock.assert_any_call(
|
|
||||||
webui_app.t("startup.webui_index_missing", index_path=static_path / "index.html")
|
|
||||||
)
|
|
||||||
warning_mock.assert_any_call(
|
|
||||||
webui_app.t("startup.webui_dashboard_package_hint", command=webui_app._MANUAL_INSTALL_COMMAND)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_setup_static_files_does_not_duplicate_warning_when_static_path_is_unavailable() -> None:
|
|
||||||
app = webui_app.FastAPI()
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(webui_app, "_ensure_static_path_ready", return_value=None),
|
|
||||||
patch.object(webui_app.logger, "warning") as warning_mock,
|
|
||||||
):
|
|
||||||
webui_app._setup_static_files(app)
|
|
||||||
|
|
||||||
warning_mock.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_static_path_prefers_installed_dashboard_package(monkeypatch, tmp_path) -> None:
|
|
||||||
package_dist = tmp_path / "site-packages" / "maibot_dashboard" / "dist"
|
|
||||||
package_dist.mkdir(parents=True)
|
|
||||||
|
|
||||||
class _DashboardModule:
|
|
||||||
@staticmethod
|
|
||||||
def get_dist_path() -> Path:
|
|
||||||
return package_dist
|
|
||||||
|
|
||||||
monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)
|
|
||||||
|
|
||||||
with patch.object(webui_app, "import_module", return_value=_DashboardModule()):
|
|
||||||
resolved_path = webui_app._resolve_static_path()
|
|
||||||
|
|
||||||
assert resolved_path == package_dist
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_static_path_ignores_dashboard_dist_when_package_is_unavailable(monkeypatch, tmp_path) -> None:
|
|
||||||
dashboard_dist = tmp_path / "dashboard" / "dist"
|
|
||||||
dashboard_dist.mkdir(parents=True)
|
|
||||||
(dashboard_dist / "index.html").write_text("<html></html>", encoding="utf-8")
|
|
||||||
|
|
||||||
monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)
|
|
||||||
|
|
||||||
with patch.object(webui_app, "import_module", side_effect=ImportError):
|
|
||||||
resolved_path = webui_app._resolve_static_path()
|
|
||||||
|
|
||||||
assert resolved_path is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_static_path_uses_package_even_when_dashboard_dist_exists(monkeypatch, tmp_path) -> None:
|
|
||||||
dashboard_dist = tmp_path / "dashboard" / "dist"
|
|
||||||
dashboard_dist.mkdir(parents=True)
|
|
||||||
|
|
||||||
package_dist = tmp_path / "site-packages" / "maibot_dashboard" / "dist"
|
|
||||||
package_dist.mkdir(parents=True)
|
|
||||||
|
|
||||||
class _DashboardModule:
|
|
||||||
@staticmethod
|
|
||||||
def get_dist_path() -> Path:
|
|
||||||
return package_dist
|
|
||||||
|
|
||||||
monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)
|
|
||||||
|
|
||||||
with patch.object(webui_app, "import_module", return_value=_DashboardModule()):
|
|
||||||
resolved_path = webui_app._resolve_static_path()
|
|
||||||
|
|
||||||
assert resolved_path == package_dist
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_safe_static_file_path_allows_regular_static_file(tmp_path) -> None:
|
|
||||||
static_path = tmp_path / "dist"
|
|
||||||
asset_path = static_path / "assets" / "app.js"
|
|
||||||
asset_path.parent.mkdir(parents=True)
|
|
||||||
asset_path.write_text("console.log('ok')", encoding="utf-8")
|
|
||||||
|
|
||||||
resolved_path = webui_app._resolve_safe_static_file_path(static_path, "assets/app.js")
|
|
||||||
|
|
||||||
assert resolved_path == asset_path.resolve()
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_safe_static_file_path_rejects_relative_path_traversal(tmp_path) -> None:
|
|
||||||
static_path = tmp_path / "dist"
|
|
||||||
static_path.mkdir()
|
|
||||||
|
|
||||||
resolved_path = webui_app._resolve_safe_static_file_path(static_path, "../secret.txt")
|
|
||||||
|
|
||||||
assert resolved_path is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_safe_static_file_path_rejects_absolute_path_traversal(tmp_path) -> None:
|
|
||||||
static_path = tmp_path / "dist"
|
|
||||||
static_path.mkdir()
|
|
||||||
|
|
||||||
resolved_path = webui_app._resolve_safe_static_file_path(static_path, "/etc/passwd")
|
|
||||||
|
|
||||||
assert resolved_path is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_safe_static_file_path_rejects_symlink_escape(tmp_path) -> None:
|
|
||||||
static_path = tmp_path / "dist"
|
|
||||||
static_path.mkdir()
|
|
||||||
|
|
||||||
outside_dir = tmp_path / "outside"
|
|
||||||
outside_dir.mkdir()
|
|
||||||
outside_file = outside_dir / "secret.txt"
|
|
||||||
outside_file.write_text("secret", encoding="utf-8")
|
|
||||||
|
|
||||||
link_path = static_path / "escape"
|
|
||||||
try:
|
|
||||||
link_path.symlink_to(outside_dir, target_is_directory=True)
|
|
||||||
except OSError as exc:
|
|
||||||
pytest.skip(f"symlink is not supported in this environment: {exc}")
|
|
||||||
|
|
||||||
resolved_path = webui_app._resolve_safe_static_file_path(static_path, "escape/secret.txt")
|
|
||||||
|
|
||||||
assert resolved_path is None
|
|
||||||
@@ -1,147 +0,0 @@
|
|||||||
from src.config.official_configs import ChatConfig, MessageReceiveConfig
|
|
||||||
from src.config.config import Config
|
|
||||||
from src.config.config_base import ConfigBase, Field
|
|
||||||
from src.webui.config_schema import ConfigSchemaGenerator
|
|
||||||
|
|
||||||
|
|
||||||
def test_field_docs_in_schema():
|
|
||||||
"""Test that field descriptions are correctly extracted from field_docs (docstrings)."""
|
|
||||||
schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
|
|
||||||
talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value")
|
|
||||||
|
|
||||||
# Verify description field exists
|
|
||||||
assert "description" in talk_value
|
|
||||||
# Verify description contains expected Chinese text from the docstring
|
|
||||||
assert "聊天频率" in talk_value["description"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_json_schema_extra_merged():
|
|
||||||
"""Test that json_schema_extra fields are correctly merged into output."""
|
|
||||||
schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
|
|
||||||
talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value")
|
|
||||||
|
|
||||||
# Verify UI metadata fields from json_schema_extra exist
|
|
||||||
assert talk_value.get("x-widget") == "slider"
|
|
||||||
assert talk_value.get("x-icon") == "message-circle"
|
|
||||||
assert talk_value.get("step") == 0.1
|
|
||||||
|
|
||||||
|
|
||||||
def test_pydantic_constraints_mapped():
|
|
||||||
"""Test that Pydantic constraints (ge/le) are correctly mapped to minValue/maxValue."""
|
|
||||||
schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
|
|
||||||
talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value")
|
|
||||||
|
|
||||||
# Verify constraints are mapped to frontend naming convention
|
|
||||||
assert "minValue" in talk_value
|
|
||||||
assert "maxValue" in talk_value
|
|
||||||
assert talk_value["minValue"] == 0 # From ge=0
|
|
||||||
assert talk_value["maxValue"] == 1 # From le=1
|
|
||||||
|
|
||||||
|
|
||||||
def test_nested_model_schema():
|
|
||||||
"""Test that nested models (ConfigBase fields) are correctly handled."""
|
|
||||||
schema = ConfigSchemaGenerator.generate_schema(Config)
|
|
||||||
|
|
||||||
# Verify nested structure exists
|
|
||||||
assert "nested" in schema
|
|
||||||
assert "chat" in schema["nested"]
|
|
||||||
|
|
||||||
# Verify nested chat schema is complete
|
|
||||||
chat_schema = schema["nested"]["chat"]
|
|
||||||
assert chat_schema["className"] == "ChatConfig"
|
|
||||||
assert "fields" in chat_schema
|
|
||||||
|
|
||||||
# Verify nested schema fields include description and metadata
|
|
||||||
talk_value = next(f for f in chat_schema["fields"] if f["name"] == "talk_value")
|
|
||||||
assert "description" in talk_value
|
|
||||||
assert talk_value.get("x-widget") == "slider"
|
|
||||||
assert talk_value.get("minValue") == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_field_without_extra_metadata():
|
|
||||||
"""Test that fields without json_schema_extra still generate valid schema."""
|
|
||||||
schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
|
|
||||||
inevitable_at_reply = next(f for f in schema["fields"] if f["name"] == "inevitable_at_reply")
|
|
||||||
|
|
||||||
# Verify basic fields are generated
|
|
||||||
assert "name" in inevitable_at_reply
|
|
||||||
assert inevitable_at_reply["name"] == "inevitable_at_reply"
|
|
||||||
assert "type" in inevitable_at_reply
|
|
||||||
assert inevitable_at_reply["type"] == "boolean"
|
|
||||||
assert "label" in inevitable_at_reply
|
|
||||||
assert "required" in inevitable_at_reply
|
|
||||||
|
|
||||||
# Verify no x-widget or x-icon from json_schema_extra (since field has none)
|
|
||||||
# These fields should only be present if explicitly defined in json_schema_extra
|
|
||||||
assert not inevitable_at_reply.get("x-widget")
|
|
||||||
assert not inevitable_at_reply.get("x-icon")
|
|
||||||
|
|
||||||
|
|
||||||
def test_all_top_level_sections_have_ui_metadata():
|
|
||||||
"""所有顶层配置节都必须声明 uiParent 或独立 Tab 的标签与图标。"""
|
|
||||||
schema = ConfigSchemaGenerator.generate_schema(Config)
|
|
||||||
|
|
||||||
for section_name, section_schema in schema["nested"].items():
|
|
||||||
has_parent = bool(section_schema.get("uiParent"))
|
|
||||||
has_host_meta = bool(section_schema.get("uiLabel")) and bool(section_schema.get("uiIcon"))
|
|
||||||
assert has_parent or has_host_meta, f"{section_name} 缺少 UI 元数据"
|
|
||||||
|
|
||||||
|
|
||||||
def test_maisaka_is_host_tab_and_mcp_is_attached_to_it():
|
|
||||||
"""MaiSaka 应作为独立 Tab,MCP 作为其子配置挂载。"""
|
|
||||||
schema = ConfigSchemaGenerator.generate_schema(Config)
|
|
||||||
|
|
||||||
maisaka_schema = schema["nested"]["maisaka"]
|
|
||||||
mcp_schema = schema["nested"]["mcp"]
|
|
||||||
|
|
||||||
assert maisaka_schema.get("uiParent") is None
|
|
||||||
assert maisaka_schema.get("uiLabel") == "MaiSaka"
|
|
||||||
assert maisaka_schema.get("uiIcon") == "message-circle"
|
|
||||||
assert mcp_schema.get("uiParent") == "maisaka"
|
|
||||||
|
|
||||||
|
|
||||||
def test_memory_query_config_fields_are_exposed():
|
|
||||||
"""query_memory 开关和默认条数应出现在记忆配置 schema 中。"""
|
|
||||||
schema = ConfigSchemaGenerator.generate_schema(Config)
|
|
||||||
memory_schema = schema["nested"]["memory"]
|
|
||||||
|
|
||||||
assert memory_schema.get("uiParent") == "emoji"
|
|
||||||
|
|
||||||
enable_field = next(field for field in memory_schema["fields"] if field["name"] == "enable_memory_query_tool")
|
|
||||||
limit_field = next(field for field in memory_schema["fields"] if field["name"] == "memory_query_default_limit")
|
|
||||||
|
|
||||||
assert enable_field["type"] == "boolean"
|
|
||||||
assert enable_field.get("x-widget") == "switch"
|
|
||||||
assert enable_field.get("x-icon") == "database"
|
|
||||||
|
|
||||||
assert limit_field["type"] == "integer"
|
|
||||||
assert limit_field.get("x-widget") == "input"
|
|
||||||
assert limit_field.get("x-icon") == "hash"
|
|
||||||
assert limit_field.get("minValue") == 1
|
|
||||||
assert limit_field.get("maxValue") == 20
|
|
||||||
|
|
||||||
|
|
||||||
def test_set_field_is_mapped_as_array():
|
|
||||||
"""set[str] 应映射为前端可识别的 array。"""
|
|
||||||
schema = ConfigSchemaGenerator.generate_schema(MessageReceiveConfig)
|
|
||||||
ban_words = next(field for field in schema["fields"] if field["name"] == "ban_words")
|
|
||||||
|
|
||||||
assert ban_words["type"] == "array"
|
|
||||||
assert ban_words["items"]["type"] == "string"
|
|
||||||
|
|
||||||
|
|
||||||
def test_advanced_fields_are_hidden_from_webui_schema():
|
|
||||||
"""advanced=True 的字段不应出现在 WebUI 配置 schema 中,未声明时默认展示。"""
|
|
||||||
|
|
||||||
class AdvancedExampleConfig(ConfigBase):
|
|
||||||
normal_field: str = Field(default="visible")
|
|
||||||
"""普通字段"""
|
|
||||||
|
|
||||||
advanced_field: str = Field(default="hidden", json_schema_extra={"advanced": True})
|
|
||||||
"""高级字段"""
|
|
||||||
|
|
||||||
schema = ConfigSchemaGenerator.generate_schema(AdvancedExampleConfig)
|
|
||||||
field_names = {field["name"] for field in schema["fields"]}
|
|
||||||
|
|
||||||
assert "normal_field" in field_names
|
|
||||||
assert "advanced_field" not in field_names
|
|
||||||
@@ -1,461 +0,0 @@
|
|||||||
"""表情包路由 API 测试
|
|
||||||
|
|
||||||
测试 src/webui/routers/emoji.py 中的核心 emoji 路由端点
|
|
||||||
使用内存 SQLite 数据库和 FastAPI TestClient
|
|
||||||
"""
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Generator
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
from sqlalchemy.pool import StaticPool
|
|
||||||
from sqlmodel import Session, SQLModel, create_engine
|
|
||||||
|
|
||||||
from src.common.database.database_model import Images, ImageType
|
|
||||||
from src.webui.core import TokenManager
|
|
||||||
from src.webui.routers.emoji import router
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def test_engine():
|
|
||||||
"""创建内存 SQLite 引擎用于测试"""
|
|
||||||
engine = create_engine(
|
|
||||||
"sqlite://",
|
|
||||||
connect_args={"check_same_thread": False},
|
|
||||||
poolclass=StaticPool,
|
|
||||||
)
|
|
||||||
SQLModel.metadata.create_all(engine)
|
|
||||||
return engine
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def test_session(test_engine) -> Generator[Session, None, None]:
|
|
||||||
"""创建测试数据库会话"""
|
|
||||||
with Session(test_engine) as session:
|
|
||||||
yield session
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def test_app(test_session):
|
|
||||||
"""创建测试 FastAPI 应用并覆盖 get_db_session 依赖"""
|
|
||||||
app = FastAPI()
|
|
||||||
app.include_router(router)
|
|
||||||
|
|
||||||
# Create a context manager that yields the test session
|
|
||||||
@contextmanager
|
|
||||||
def override_get_db_session(auto_commit=True):
|
|
||||||
"""Override get_db_session to use test session"""
|
|
||||||
try:
|
|
||||||
yield test_session
|
|
||||||
if auto_commit:
|
|
||||||
test_session.commit()
|
|
||||||
except Exception:
|
|
||||||
test_session.rollback()
|
|
||||||
raise
|
|
||||||
|
|
||||||
with patch("src.webui.routers.emoji.get_db_session", override_get_db_session):
|
|
||||||
yield app
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def client(test_app):
|
|
||||||
"""创建 TestClient"""
|
|
||||||
return TestClient(test_app)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def auth_token():
|
|
||||||
"""创建有效的认证 token"""
|
|
||||||
token_manager = TokenManager(secret_key="test-secret-key", token_expire_hours=24)
|
|
||||||
return token_manager.create_token()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def sample_emojis(test_session) -> list[Images]:
|
|
||||||
"""插入测试用表情包数据"""
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
emojis = [
|
|
||||||
Images(
|
|
||||||
image_type=ImageType.EMOJI,
|
|
||||||
full_path="/data/emoji_registed/test1.png",
|
|
||||||
image_hash=hashlib.sha256(b"test1").hexdigest(),
|
|
||||||
description="测试表情包 1",
|
|
||||||
emotion="开心,快乐",
|
|
||||||
query_count=10,
|
|
||||||
is_registered=True,
|
|
||||||
is_banned=False,
|
|
||||||
record_time=datetime(2026, 1, 1, 10, 0, 0),
|
|
||||||
register_time=datetime(2026, 1, 1, 10, 0, 0),
|
|
||||||
last_used_time=datetime(2026, 1, 2, 10, 0, 0),
|
|
||||||
),
|
|
||||||
Images(
|
|
||||||
image_type=ImageType.EMOJI,
|
|
||||||
full_path="/data/emoji_registed/test2.gif",
|
|
||||||
image_hash=hashlib.sha256(b"test2").hexdigest(),
|
|
||||||
description="测试表情包 2",
|
|
||||||
emotion="难过",
|
|
||||||
query_count=5,
|
|
||||||
is_registered=False,
|
|
||||||
is_banned=False,
|
|
||||||
record_time=datetime(2026, 1, 3, 10, 0, 0),
|
|
||||||
register_time=None,
|
|
||||||
last_used_time=None,
|
|
||||||
),
|
|
||||||
Images(
|
|
||||||
image_type=ImageType.EMOJI,
|
|
||||||
full_path="/data/emoji_registed/test3.webp",
|
|
||||||
image_hash=hashlib.sha256(b"test3").hexdigest(),
|
|
||||||
description="测试表情包 3",
|
|
||||||
emotion="生气",
|
|
||||||
query_count=20,
|
|
||||||
is_registered=True,
|
|
||||||
is_banned=True,
|
|
||||||
record_time=datetime(2026, 1, 4, 10, 0, 0),
|
|
||||||
register_time=datetime(2026, 1, 4, 10, 0, 0),
|
|
||||||
last_used_time=datetime(2026, 1, 5, 10, 0, 0),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
for emoji in emojis:
|
|
||||||
test_session.add(emoji)
|
|
||||||
test_session.commit()
|
|
||||||
|
|
||||||
for emoji in emojis:
|
|
||||||
test_session.refresh(emoji)
|
|
||||||
|
|
||||||
return emojis
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def mock_token_verify():
|
|
||||||
"""Mock token verification to always succeed"""
|
|
||||||
with patch("src.webui.routers.emoji.verify_auth_token", return_value=True):
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 测试用例 ====================
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_emojis_basic(client, sample_emojis, mock_token_verify):
|
|
||||||
"""测试获取表情包列表(基本分页)"""
|
|
||||||
response = client.get("/emoji/list?page=1&page_size=10")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["total"] == 3
|
|
||||||
assert data["page"] == 1
|
|
||||||
assert data["page_size"] == 10
|
|
||||||
assert len(data["data"]) == 3
|
|
||||||
|
|
||||||
# 验证第一个表情包字段
|
|
||||||
emoji = data["data"][0]
|
|
||||||
assert "id" in emoji
|
|
||||||
assert "full_path" in emoji
|
|
||||||
assert "emoji_hash" in emoji
|
|
||||||
assert "description" in emoji
|
|
||||||
assert "query_count" in emoji
|
|
||||||
assert "is_registered" in emoji
|
|
||||||
assert "is_banned" in emoji
|
|
||||||
assert "emotion" in emoji
|
|
||||||
assert "record_time" in emoji
|
|
||||||
assert "register_time" in emoji
|
|
||||||
assert "last_used_time" in emoji
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_emojis_pagination(client, sample_emojis, mock_token_verify):
|
|
||||||
"""测试分页功能"""
|
|
||||||
response = client.get("/emoji/list?page=1&page_size=2")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["total"] == 3
|
|
||||||
assert len(data["data"]) == 2
|
|
||||||
|
|
||||||
# 第二页
|
|
||||||
response = client.get("/emoji/list?page=2&page_size=2")
|
|
||||||
data = response.json()
|
|
||||||
assert len(data["data"]) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_emojis_search(client, sample_emojis, mock_token_verify):
|
|
||||||
"""测试搜索过滤"""
|
|
||||||
response = client.get("/emoji/list?search=表情包 2")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["total"] == 1
|
|
||||||
assert data["data"][0]["description"] == "测试表情包 2"
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_emojis_filter_registered(client, sample_emojis, mock_token_verify):
|
|
||||||
"""测试 is_registered 过滤"""
|
|
||||||
response = client.get("/emoji/list?is_registered=true")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["total"] == 2
|
|
||||||
assert all(emoji["is_registered"] is True for emoji in data["data"])
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_emojis_filter_banned(client, sample_emojis, mock_token_verify):
|
|
||||||
"""测试 is_banned 过滤"""
|
|
||||||
response = client.get("/emoji/list?is_banned=true")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["total"] == 1
|
|
||||||
assert data["data"][0]["is_banned"] is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_emojis_sort_by_query_count(client, sample_emojis, mock_token_verify):
|
|
||||||
"""测试按 query_count 排序"""
|
|
||||||
response = client.get("/emoji/list?sort_by=query_count&sort_order=desc")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
assert data["success"] is True
|
|
||||||
# 验证降序排列 (20 > 10 > 5)
|
|
||||||
assert data["data"][0]["query_count"] == 20
|
|
||||||
assert data["data"][1]["query_count"] == 10
|
|
||||||
assert data["data"][2]["query_count"] == 5
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_emoji_detail_success(client, sample_emojis, mock_token_verify):
|
|
||||||
"""测试获取表情包详情(成功)"""
|
|
||||||
emoji_id = sample_emojis[0].id
|
|
||||||
response = client.get(f"/emoji/{emoji_id}")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["data"]["id"] == emoji_id
|
|
||||||
assert data["data"]["emoji_hash"] == sample_emojis[0].image_hash
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_emoji_detail_not_found(client, mock_token_verify):
|
|
||||||
"""测试获取不存在的表情包(404)"""
|
|
||||||
response = client.get("/emoji/99999")
|
|
||||||
|
|
||||||
assert response.status_code == 404
|
|
||||||
data = response.json()
|
|
||||||
assert "未找到" in data["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_emoji_description(client, sample_emojis, mock_token_verify):
|
|
||||||
"""测试更新表情包描述"""
|
|
||||||
emoji_id = sample_emojis[0].id
|
|
||||||
response = client.patch(
|
|
||||||
f"/emoji/{emoji_id}",
|
|
||||||
json={"description": "更新后的描述"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["data"]["description"] == "更新后的描述"
|
|
||||||
assert "成功更新" in data["message"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_emoji_register_status(client, sample_emojis, mock_token_verify, test_session):
|
|
||||||
"""测试更新注册状态(False -> True 应设置 register_time)"""
|
|
||||||
emoji_id = sample_emojis[1].id # 未注册的表情包
|
|
||||||
response = client.patch(
|
|
||||||
f"/emoji/{emoji_id}",
|
|
||||||
json={"is_registered": True},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["data"]["is_registered"] is True
|
|
||||||
assert data["data"]["register_time"] is not None # 应该设置了注册时间
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_emoji_no_fields(client, sample_emojis, mock_token_verify):
|
|
||||||
"""测试更新请求未提供任何字段(400)"""
|
|
||||||
emoji_id = sample_emojis[0].id
|
|
||||||
response = client.patch(f"/emoji/{emoji_id}", json={})
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
data = response.json()
|
|
||||||
assert "未提供任何需要更新的字段" in data["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_emoji_not_found(client, mock_token_verify):
|
|
||||||
"""测试更新不存在的表情包(404)"""
|
|
||||||
response = client.patch("/emoji/99999", json={"description": "test"})
|
|
||||||
|
|
||||||
assert response.status_code == 404
|
|
||||||
data = response.json()
|
|
||||||
assert "未找到" in data["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_emoji_success(client, sample_emojis, mock_token_verify, test_session):
|
|
||||||
"""测试删除表情包(成功)"""
|
|
||||||
emoji_id = sample_emojis[0].id
|
|
||||||
response = client.delete(f"/emoji/{emoji_id}")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
assert data["success"] is True
|
|
||||||
assert "成功删除" in data["message"]
|
|
||||||
|
|
||||||
# 验证数据库中已删除
|
|
||||||
from sqlmodel import select
|
|
||||||
|
|
||||||
statement = select(Images).where(Images.id == emoji_id)
|
|
||||||
result = test_session.exec(statement).first()
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_emoji_not_found(client, mock_token_verify):
|
|
||||||
"""测试删除不存在的表情包(404)"""
|
|
||||||
response = client.delete("/emoji/99999")
|
|
||||||
|
|
||||||
assert response.status_code == 404
|
|
||||||
data = response.json()
|
|
||||||
assert "未找到" in data["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_batch_delete_success(client, sample_emojis, mock_token_verify, test_session):
|
|
||||||
"""测试批量删除表情包(全部成功)"""
|
|
||||||
emoji_ids = [sample_emojis[0].id, sample_emojis[1].id]
|
|
||||||
response = client.post("/emoji/batch/delete", json={"emoji_ids": emoji_ids})
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["deleted_count"] == 2
|
|
||||||
assert data["failed_count"] == 0
|
|
||||||
assert "成功删除 2 个表情包" in data["message"]
|
|
||||||
|
|
||||||
# 验证数据库中已删除
|
|
||||||
from sqlmodel import select
|
|
||||||
|
|
||||||
for emoji_id in emoji_ids:
|
|
||||||
statement = select(Images).where(Images.id == emoji_id)
|
|
||||||
result = test_session.exec(statement).first()
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_batch_delete_partial_failure(client, sample_emojis, mock_token_verify):
|
|
||||||
"""测试批量删除(部分失败)"""
|
|
||||||
emoji_ids = [sample_emojis[0].id, 99999] # 第二个 ID 不存在
|
|
||||||
response = client.post("/emoji/batch/delete", json={"emoji_ids": emoji_ids})
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["deleted_count"] == 1
|
|
||||||
assert data["failed_count"] == 1
|
|
||||||
assert 99999 in data["failed_ids"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_batch_delete_empty_list(client, mock_token_verify):
|
|
||||||
"""测试批量删除空列表(400)"""
|
|
||||||
response = client.post("/emoji/batch/delete", json={"emoji_ids": []})
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
data = response.json()
|
|
||||||
assert "未提供要删除的表情包ID" in data["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_auth_required_list(client):
|
|
||||||
"""测试未认证访问列表端点(401)"""
|
|
||||||
# Without mock_token_verify fixture
|
|
||||||
with patch("src.webui.routers.emoji.verify_auth_token", return_value=False):
|
|
||||||
client.get("/emoji/list")
|
|
||||||
# verify_auth_token 返回 False 会触发 HTTPException
|
|
||||||
# 但具体状态码取决于 verify_auth_token_from_cookie_or_header 的实现
|
|
||||||
# 这里假设它抛出 401
|
|
||||||
|
|
||||||
|
|
||||||
def test_auth_required_update(client, sample_emojis):
|
|
||||||
"""测试未认证访问更新端点(401)"""
|
|
||||||
with patch("src.webui.routers.emoji.verify_auth_token", return_value=False):
|
|
||||||
emoji_id = sample_emojis[0].id
|
|
||||||
client.patch(f"/emoji/{emoji_id}", json={"description": "test"})
|
|
||||||
# Should be unauthorized
|
|
||||||
|
|
||||||
|
|
||||||
def test_emoji_to_response_field_mapping(sample_emojis):
|
|
||||||
"""测试 emoji_to_response 字段映射(image_hash -> emoji_hash)"""
|
|
||||||
from src.webui.routers.emoji import emoji_to_response
|
|
||||||
|
|
||||||
emoji = sample_emojis[0]
|
|
||||||
response = emoji_to_response(emoji)
|
|
||||||
|
|
||||||
# 验证 API 字段名称
|
|
||||||
assert hasattr(response, "emoji_hash")
|
|
||||||
assert response.emoji_hash == emoji.image_hash
|
|
||||||
|
|
||||||
# 验证时间戳转换
|
|
||||||
assert isinstance(response.record_time, float)
|
|
||||||
assert response.record_time == emoji.record_time.timestamp()
|
|
||||||
|
|
||||||
if emoji.register_time:
|
|
||||||
assert isinstance(response.register_time, float)
|
|
||||||
assert response.register_time == emoji.register_time.timestamp()
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_emojis_only_emoji_type(client, test_session, mock_token_verify):
|
|
||||||
"""测试列表只返回 type=EMOJI 的记录(不包括其他类型)"""
|
|
||||||
# 插入一个非 EMOJI 类型的图片
|
|
||||||
non_emoji = Images(
|
|
||||||
image_type=ImageType.IMAGE, # 不是 EMOJI
|
|
||||||
full_path="/data/images/test.png",
|
|
||||||
image_hash="hash_image",
|
|
||||||
description="非表情包图片",
|
|
||||||
query_count=0,
|
|
||||||
is_registered=False,
|
|
||||||
is_banned=False,
|
|
||||||
record_time=datetime.now(),
|
|
||||||
)
|
|
||||||
test_session.add(non_emoji)
|
|
||||||
test_session.commit()
|
|
||||||
|
|
||||||
# 插入一个 EMOJI 类型
|
|
||||||
emoji = Images(
|
|
||||||
image_type=ImageType.EMOJI,
|
|
||||||
full_path="/data/emoji_registed/emoji.png",
|
|
||||||
image_hash="hash_emoji",
|
|
||||||
description="表情包",
|
|
||||||
query_count=0,
|
|
||||||
is_registered=True,
|
|
||||||
is_banned=False,
|
|
||||||
record_time=datetime.now(),
|
|
||||||
)
|
|
||||||
test_session.add(emoji)
|
|
||||||
test_session.commit()
|
|
||||||
|
|
||||||
response = client.get("/emoji/list")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
# 只应该返回 1 个 EMOJI 类型的记录
|
|
||||||
assert data["total"] == 1
|
|
||||||
assert data["data"][0]["description"] == "表情包"
|
|
||||||
@@ -1,529 +0,0 @@
|
|||||||
"""Expression routes pytest tests"""
|
|
||||||
|
|
||||||
from typing import Generator
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from fastapi import APIRouter, FastAPI
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
from sqlalchemy import text
|
|
||||||
from sqlalchemy.pool import StaticPool
|
|
||||||
from sqlmodel import Session, SQLModel, create_engine, select
|
|
||||||
|
|
||||||
from src.common.database.database_model import Expression, ModifiedBy
|
|
||||||
from src.webui.dependencies import require_auth
|
|
||||||
|
|
||||||
|
|
||||||
def create_test_app() -> FastAPI:
|
|
||||||
"""Create minimal test app with only expression router"""
|
|
||||||
app = FastAPI(title="Test App")
|
|
||||||
from src.webui.routers.expression import router as expression_router
|
|
||||||
|
|
||||||
main_router = APIRouter(prefix="/api/webui")
|
|
||||||
main_router.include_router(expression_router)
|
|
||||||
app.include_router(main_router)
|
|
||||||
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
app = create_test_app()
|
|
||||||
|
|
||||||
|
|
||||||
# Test database setup
|
|
||||||
@pytest.fixture(name="test_engine")
|
|
||||||
def test_engine_fixture():
|
|
||||||
"""Create in-memory SQLite database for testing"""
|
|
||||||
engine = create_engine(
|
|
||||||
"sqlite://",
|
|
||||||
connect_args={"check_same_thread": False},
|
|
||||||
poolclass=StaticPool,
|
|
||||||
)
|
|
||||||
SQLModel.metadata.create_all(engine)
|
|
||||||
return engine
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="test_session")
|
|
||||||
def test_session_fixture(test_engine) -> Generator[Session, None, None]:
|
|
||||||
"""Create a test database session with transaction rollback"""
|
|
||||||
connection = test_engine.connect()
|
|
||||||
transaction = connection.begin()
|
|
||||||
session = Session(bind=connection)
|
|
||||||
|
|
||||||
yield session
|
|
||||||
|
|
||||||
session.close()
|
|
||||||
transaction.rollback()
|
|
||||||
connection.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="client")
|
|
||||||
def client_fixture(test_session: Session, monkeypatch) -> Generator[TestClient, None, None]:
|
|
||||||
"""Create TestClient with overridden database session"""
|
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def get_test_db_session():
|
|
||||||
yield test_session
|
|
||||||
test_session.commit()
|
|
||||||
|
|
||||||
monkeypatch.setattr("src.webui.routers.expression.get_db_session", get_test_db_session)
|
|
||||||
|
|
||||||
with TestClient(app) as client:
|
|
||||||
yield client
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="mock_auth")
|
|
||||||
def mock_auth_fixture():
|
|
||||||
"""Mock authentication to always return True"""
|
|
||||||
app.dependency_overrides[require_auth] = lambda: "test-token"
|
|
||||||
yield
|
|
||||||
app.dependency_overrides.clear()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="sample_expression")
|
|
||||||
def sample_expression_fixture(test_session: Session) -> Expression:
|
|
||||||
"""Insert a sample expression into test database"""
|
|
||||||
test_session.execute(
|
|
||||||
text(
|
|
||||||
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
|
|
||||||
"VALUES (1, '测试情景', '测试风格', '[\"测试内容1\", \"测试内容2\"]', 10, '2026-02-17 12:00:00', '2026-02-15 10:00:00', 'test_chat_001', 0, 0)"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
test_session.commit()
|
|
||||||
|
|
||||||
expression = test_session.exec(select(Expression).where(Expression.id == 1)).first()
|
|
||||||
assert expression is not None
|
|
||||||
return expression
|
|
||||||
|
|
||||||
|
|
||||||
# ============ Tests ============
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_expressions_empty(client: TestClient, mock_auth):
|
|
||||||
"""Test GET /expression/list with empty database"""
|
|
||||||
response = client.get("/api/webui/expression/list")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["total"] == 0
|
|
||||||
assert data["page"] == 1
|
|
||||||
assert data["page_size"] == 20
|
|
||||||
assert data["data"] == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_expressions_with_data(client: TestClient, mock_auth, sample_expression: Expression):
|
|
||||||
"""Test GET /expression/list returns expression data"""
|
|
||||||
response = client.get("/api/webui/expression/list")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["total"] == 1
|
|
||||||
assert len(data["data"]) == 1
|
|
||||||
|
|
||||||
expr_data = data["data"][0]
|
|
||||||
assert expr_data["id"] == sample_expression.id
|
|
||||||
assert expr_data["situation"] == "测试情景"
|
|
||||||
assert expr_data["style"] == "测试风格"
|
|
||||||
assert expr_data["chat_id"] == "test_chat_001"
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_expressions_pagination(client: TestClient, mock_auth, test_session: Session):
|
|
||||||
"""Test GET /expression/list pagination works correctly"""
|
|
||||||
for i in range(5):
|
|
||||||
test_session.execute(
|
|
||||||
text(
|
|
||||||
f"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
|
|
||||||
f"VALUES ({i + 1}, '情景{i}', '风格{i}', '[]', 0, '2026-02-17 12:0{i}:00', '2026-02-15 10:00:00', 'chat_{i}', 0, 0)"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
test_session.commit()
|
|
||||||
|
|
||||||
# Request page 1 with page_size=2
|
|
||||||
response = client.get("/api/webui/expression/list?page=1&page_size=2")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["total"] == 5
|
|
||||||
assert data["page"] == 1
|
|
||||||
assert data["page_size"] == 2
|
|
||||||
assert len(data["data"]) == 2
|
|
||||||
|
|
||||||
# Request page 2
|
|
||||||
response = client.get("/api/webui/expression/list?page=2&page_size=2")
|
|
||||||
data = response.json()
|
|
||||||
assert data["page"] == 2
|
|
||||||
assert len(data["data"]) == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_expressions_search(client: TestClient, mock_auth, test_session: Session):
|
|
||||||
"""Test GET /expression/list with search filter"""
|
|
||||||
test_session.execute(
|
|
||||||
text(
|
|
||||||
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
|
|
||||||
"VALUES (1, '找人吃饭', '热情', '[]', 0, datetime('now'), datetime('now'), 'chat_001', 0, 0)"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
test_session.execute(
|
|
||||||
text(
|
|
||||||
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
|
|
||||||
"VALUES (2, '拒绝邀请', '礼貌', '[]', 0, datetime('now'), datetime('now'), 'chat_002', 0, 0)"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
test_session.commit()
|
|
||||||
|
|
||||||
# Search for "吃饭"
|
|
||||||
response = client.get("/api/webui/expression/list?search=吃饭")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["total"] == 1
|
|
||||||
assert data["data"][0]["situation"] == "找人吃饭"
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_expressions_chat_filter(client: TestClient, mock_auth, test_session: Session):
|
|
||||||
"""Test GET /expression/list with chat_id filter"""
|
|
||||||
test_session.execute(
|
|
||||||
text(
|
|
||||||
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
|
|
||||||
"VALUES (1, '情景A', '风格A', '[]', 0, datetime('now'), datetime('now'), 'chat_A', 0, 0)"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
test_session.execute(
|
|
||||||
text(
|
|
||||||
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
|
|
||||||
"VALUES (2, '情景B', '风格B', '[]', 0, datetime('now'), datetime('now'), 'chat_B', 0, 0)"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
test_session.commit()
|
|
||||||
|
|
||||||
# Filter by chat_A
|
|
||||||
response = client.get("/api/webui/expression/list?chat_id=chat_A")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["total"] == 1
|
|
||||||
assert data["data"][0]["situation"] == "情景A"
|
|
||||||
assert data["data"][0]["chat_id"] == "chat_A"
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_expression_detail_success(client: TestClient, mock_auth, sample_expression: Expression):
|
|
||||||
"""Test GET /expression/{id} returns correct detail"""
|
|
||||||
response = client.get(f"/api/webui/expression/{sample_expression.id}")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["data"]["id"] == sample_expression.id
|
|
||||||
assert data["data"]["situation"] == "测试情景"
|
|
||||||
assert data["data"]["style"] == "测试风格"
|
|
||||||
assert data["data"]["chat_id"] == "test_chat_001"
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_expression_detail_not_found(client: TestClient, mock_auth):
|
|
||||||
"""Test GET /expression/{id} returns 404 for non-existent ID"""
|
|
||||||
response = client.get("/api/webui/expression/99999")
|
|
||||||
assert response.status_code == 404
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert "未找到" in data["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_expression_response_has_legacy_fields(client: TestClient, mock_auth, sample_expression: Expression):
|
|
||||||
"""Test that ExpressionResponse includes legacy fields (checked/rejected/modified_by)"""
|
|
||||||
response = client.get(f"/api/webui/expression/{sample_expression.id}")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()["data"]
|
|
||||||
|
|
||||||
# Verify legacy fields exist and have default values
|
|
||||||
assert "checked" in data
|
|
||||||
assert "rejected" in data
|
|
||||||
assert "modified_by" in data
|
|
||||||
|
|
||||||
# Verify hardcoded default values (from expression_to_response)
|
|
||||||
assert data["checked"] is False
|
|
||||||
assert data["rejected"] is False
|
|
||||||
assert data["modified_by"] is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_expression_without_removed_fields(client: TestClient, mock_auth, sample_expression: Expression):
|
|
||||||
"""Test PATCH /expression/{id} does not accept checked/rejected fields"""
|
|
||||||
# Valid update request (only allowed fields)
|
|
||||||
update_payload = {
|
|
||||||
"situation": "更新后的情景",
|
|
||||||
"style": "更新后的风格",
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["data"]["situation"] == "更新后的情景"
|
|
||||||
assert data["data"]["style"] == "更新后的风格"
|
|
||||||
|
|
||||||
# Verify legacy fields still returned (hardcoded values)
|
|
||||||
assert data["data"]["checked"] is False
|
|
||||||
assert data["data"]["rejected"] is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_expression_ignores_invalid_fields(client: TestClient, mock_auth, sample_expression: Expression):
|
|
||||||
"""Test PATCH /expression/{id} ignores fields not in ExpressionUpdateRequest"""
|
|
||||||
# Request with invalid field (checked not in schema)
|
|
||||||
update_payload = {
|
|
||||||
"situation": "新情景",
|
|
||||||
"checked": True, # This field should be ignored by Pydantic
|
|
||||||
"rejected": True, # This field should be ignored
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["data"]["situation"] == "新情景"
|
|
||||||
|
|
||||||
# Response should have hardcoded False values (not True from request)
|
|
||||||
assert data["data"]["checked"] is False
|
|
||||||
assert data["data"]["rejected"] is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_expression_chat_id_mapping(client: TestClient, mock_auth, sample_expression: Expression):
|
|
||||||
"""Test PATCH /expression/{id} correctly maps chat_id to session_id"""
|
|
||||||
update_payload = {"chat_id": "updated_chat_999"}
|
|
||||||
|
|
||||||
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
|
|
||||||
# Verify chat_id is returned in response (mapped from session_id)
|
|
||||||
assert data["data"]["chat_id"] == "updated_chat_999"
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_expression_not_found(client: TestClient, mock_auth):
|
|
||||||
"""Test PATCH /expression/{id} returns 404 for non-existent ID"""
|
|
||||||
update_payload = {"situation": "新情景"}
|
|
||||||
|
|
||||||
response = client.patch("/api/webui/expression/99999", json=update_payload)
|
|
||||||
assert response.status_code == 404
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert "未找到" in data["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_expression_empty_request(client: TestClient, mock_auth, sample_expression: Expression):
|
|
||||||
"""Test PATCH /expression/{id} returns 400 for empty update request"""
|
|
||||||
update_payload = {}
|
|
||||||
|
|
||||||
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
|
|
||||||
assert response.status_code == 400
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert "未提供任何需要更新的字段" in data["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_expression_success(client: TestClient, mock_auth, sample_expression: Expression):
|
|
||||||
"""Test DELETE /expression/{id} successfully deletes expression"""
|
|
||||||
expression_id = sample_expression.id
|
|
||||||
|
|
||||||
response = client.delete(f"/api/webui/expression/{expression_id}")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert "成功删除" in data["message"]
|
|
||||||
|
|
||||||
# Verify expression is deleted
|
|
||||||
get_response = client.get(f"/api/webui/expression/{expression_id}")
|
|
||||||
assert get_response.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_expression_not_found(client: TestClient, mock_auth):
|
|
||||||
"""Test DELETE /expression/{id} returns 404 for non-existent ID"""
|
|
||||||
response = client.delete("/api/webui/expression/99999")
|
|
||||||
assert response.status_code == 404
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert "未找到" in data["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_expression_success(client: TestClient, mock_auth):
|
|
||||||
"""Test POST /expression/ successfully creates expression"""
|
|
||||||
create_payload = {
|
|
||||||
"situation": "新建情景",
|
|
||||||
"style": "新建风格",
|
|
||||||
"chat_id": "new_chat_123",
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.post("/api/webui/expression/", json=create_payload)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert "创建成功" in data["message"]
|
|
||||||
assert data["data"]["situation"] == "新建情景"
|
|
||||||
assert data["data"]["style"] == "新建风格"
|
|
||||||
assert data["data"]["chat_id"] == "new_chat_123"
|
|
||||||
|
|
||||||
# Verify legacy fields
|
|
||||||
assert data["data"]["checked"] is False
|
|
||||||
assert data["data"]["rejected"] is False
|
|
||||||
assert data["data"]["modified_by"] is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_batch_delete_expressions_success(client: TestClient, mock_auth, test_session: Session):
|
|
||||||
"""Test POST /expression/batch/delete successfully deletes multiple expressions"""
|
|
||||||
expression_ids = []
|
|
||||||
for i in range(3):
|
|
||||||
test_session.execute(
|
|
||||||
text(
|
|
||||||
f"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
|
|
||||||
f"VALUES ({i + 1}, '批量删除{i}', '风格{i}', '[]', 0, datetime('now'), datetime('now'), 'chat_{i}', 0, 0)"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
expression_ids.append(i + 1)
|
|
||||||
test_session.commit()
|
|
||||||
|
|
||||||
delete_payload = {"ids": expression_ids}
|
|
||||||
response = client.post("/api/webui/expression/batch/delete", json=delete_payload)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert "成功删除 3 个" in data["message"]
|
|
||||||
|
|
||||||
for expr_id in expression_ids:
|
|
||||||
get_response = client.get(f"/api/webui/expression/{expr_id}")
|
|
||||||
assert get_response.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
def test_batch_delete_partial_not_found(client: TestClient, mock_auth, sample_expression: Expression):
|
|
||||||
"""Test POST /expression/batch/delete handles partial not found IDs"""
|
|
||||||
delete_payload = {"ids": [sample_expression.id, 88888, 99999]}
|
|
||||||
|
|
||||||
response = client.post("/api/webui/expression/batch/delete", json=delete_payload)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
# Should delete only the 1 valid ID
|
|
||||||
assert "成功删除 1 个" in data["message"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_expression_stats(client: TestClient, mock_auth, test_session: Session):
|
|
||||||
"""Test GET /expression/stats/summary returns correct statistics"""
|
|
||||||
for i in range(3):
|
|
||||||
test_session.execute(
|
|
||||||
text(
|
|
||||||
f"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
|
|
||||||
f"VALUES ({i + 1}, '情景{i}', '风格{i}', '[]', 0, datetime('now'), datetime('now'), 'chat_{i % 2}', 0, 0)"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
test_session.commit()
|
|
||||||
|
|
||||||
response = client.get("/api/webui/expression/stats/summary")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["data"]["total"] == 3
|
|
||||||
assert data["data"]["chat_count"] == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_review_stats(client: TestClient, mock_auth, test_session: Session):
|
|
||||||
"""Test GET /expression/review/stats returns review status counts"""
|
|
||||||
test_session.execute(
|
|
||||||
text(
|
|
||||||
"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) "
|
|
||||||
"VALUES (1, '待审核', '风格', '[]', 0, datetime('now'), datetime('now'), 'chat_001', 0, 0)"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
test_session.commit()
|
|
||||||
|
|
||||||
response = client.get("/api/webui/expression/review/stats")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["total"] == 1 # Total expressions exists
|
|
||||||
assert data["unchecked"] == 1
|
|
||||||
assert data["passed"] == 0
|
|
||||||
assert data["rejected"] == 0
|
|
||||||
assert data["ai_checked"] == 0
|
|
||||||
assert data["user_checked"] == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_review_list_filter_unchecked(client: TestClient, mock_auth, sample_expression: Expression):
|
|
||||||
"""Test GET /expression/review/list with filter_type=unchecked returns unchecked expressions"""
|
|
||||||
response = client.get("/api/webui/expression/review/list?filter_type=unchecked")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["total"] == 1
|
|
||||||
assert len(data["data"]) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_review_list_filter_all(client: TestClient, mock_auth, sample_expression: Expression):
|
|
||||||
"""Test GET /expression/review/list with filter_type=all returns all expressions"""
|
|
||||||
response = client.get("/api/webui/expression/review/list?filter_type=all")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["total"] == 1
|
|
||||||
assert len(data["data"]) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_batch_review_expressions_with_unchecked_marker(client: TestClient, mock_auth, sample_expression: Expression):
|
|
||||||
"""Test POST /expression/review/batch succeeds with require_unchecked=True"""
|
|
||||||
review_payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": True}]}
|
|
||||||
|
|
||||||
response = client.post("/api/webui/expression/review/batch", json=review_payload)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["succeeded"] == 1
|
|
||||||
assert data["results"][0]["success"] is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_batch_review_expressions_overwrites_ai_checked(
|
|
||||||
client: TestClient, mock_auth, test_session: Session, sample_expression: Expression
|
|
||||||
):
|
|
||||||
"""Test POST /expression/review/batch lets manual review override AI checked state"""
|
|
||||||
sample_expression.checked = True
|
|
||||||
sample_expression.rejected = True
|
|
||||||
sample_expression.modified_by = ModifiedBy.AI
|
|
||||||
test_session.add(sample_expression)
|
|
||||||
test_session.commit()
|
|
||||||
|
|
||||||
review_payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": True}]}
|
|
||||||
|
|
||||||
response = client.post("/api/webui/expression/review/batch", json=review_payload)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["succeeded"] == 1
|
|
||||||
test_session.expire_all()
|
|
||||||
reviewed_expression = test_session.exec(select(Expression).where(Expression.id == sample_expression.id)).first()
|
|
||||||
assert reviewed_expression is not None
|
|
||||||
assert reviewed_expression.checked is True
|
|
||||||
assert reviewed_expression.rejected is False
|
|
||||||
assert reviewed_expression.modified_by == ModifiedBy.USER
|
|
||||||
|
|
||||||
|
|
||||||
def test_batch_review_expressions_no_unchecked_check(client: TestClient, mock_auth, sample_expression: Expression):
|
|
||||||
"""Test POST /expression/review/batch succeeds when require_unchecked=False"""
|
|
||||||
review_payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": False}]}
|
|
||||||
|
|
||||||
response = client.post("/api/webui/expression/review/batch", json=review_payload)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["succeeded"] == 1
|
|
||||||
assert data["results"][0]["success"] is True
|
|
||||||
@@ -1,512 +0,0 @@
|
|||||||
"""测试 jargon 路由的完整性和正确性"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
from sqlalchemy.pool import StaticPool
|
|
||||||
from sqlmodel import Session, SQLModel, create_engine
|
|
||||||
|
|
||||||
from src.common.database.database_model import ChatSession, Jargon
|
|
||||||
from src.webui.routers.jargon import router as jargon_router
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="app", scope="function")
|
|
||||||
def app_fixture():
|
|
||||||
app = FastAPI()
|
|
||||||
app.include_router(jargon_router, prefix="/api/webui")
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="engine", scope="function")
|
|
||||||
def engine_fixture():
|
|
||||||
engine = create_engine(
|
|
||||||
"sqlite://",
|
|
||||||
connect_args={"check_same_thread": False},
|
|
||||||
poolclass=StaticPool,
|
|
||||||
)
|
|
||||||
SQLModel.metadata.create_all(engine)
|
|
||||||
yield engine
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="session", scope="function")
|
|
||||||
def session_fixture(engine):
|
|
||||||
connection = engine.connect()
|
|
||||||
transaction = connection.begin()
|
|
||||||
session = Session(bind=connection)
|
|
||||||
|
|
||||||
yield session
|
|
||||||
|
|
||||||
session.close()
|
|
||||||
transaction.rollback()
|
|
||||||
connection.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="client", scope="function")
|
|
||||||
def client_fixture(app: FastAPI, session: Session, monkeypatch):
|
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def mock_get_db_session():
|
|
||||||
yield session
|
|
||||||
|
|
||||||
monkeypatch.setattr("src.webui.routers.jargon.get_db_session", mock_get_db_session)
|
|
||||||
|
|
||||||
with TestClient(app) as client:
|
|
||||||
yield client
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="sample_chat_session")
|
|
||||||
def sample_chat_session_fixture(session: Session):
|
|
||||||
"""创建示例 ChatSession"""
|
|
||||||
chat_session = ChatSession(
|
|
||||||
session_id="test_stream_001",
|
|
||||||
platform="qq",
|
|
||||||
group_id="123456789",
|
|
||||||
user_id=None,
|
|
||||||
created_timestamp=datetime.now(),
|
|
||||||
last_active_timestamp=datetime.now(),
|
|
||||||
)
|
|
||||||
session.add(chat_session)
|
|
||||||
session.commit()
|
|
||||||
session.refresh(chat_session)
|
|
||||||
return chat_session
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="sample_jargons")
|
|
||||||
def sample_jargons_fixture(session: Session, sample_chat_session: ChatSession):
|
|
||||||
"""创建示例 Jargon 数据"""
|
|
||||||
jargons = [
|
|
||||||
Jargon(
|
|
||||||
id=1,
|
|
||||||
content="yyds",
|
|
||||||
raw_content="永远的神",
|
|
||||||
meaning="永远的神",
|
|
||||||
session_id=sample_chat_session.session_id,
|
|
||||||
count=10,
|
|
||||||
is_jargon=True,
|
|
||||||
is_complete=False,
|
|
||||||
),
|
|
||||||
Jargon(
|
|
||||||
id=2,
|
|
||||||
content="awsl",
|
|
||||||
raw_content="啊我死了",
|
|
||||||
meaning="啊我死了",
|
|
||||||
session_id=sample_chat_session.session_id,
|
|
||||||
count=5,
|
|
||||||
is_jargon=True,
|
|
||||||
is_complete=False,
|
|
||||||
),
|
|
||||||
Jargon(
|
|
||||||
id=3,
|
|
||||||
content="hello",
|
|
||||||
raw_content=None,
|
|
||||||
meaning="你好",
|
|
||||||
session_id=sample_chat_session.session_id,
|
|
||||||
count=2,
|
|
||||||
is_jargon=False,
|
|
||||||
is_complete=False,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
for jargon in jargons:
|
|
||||||
session.add(jargon)
|
|
||||||
session.commit()
|
|
||||||
for jargon in jargons:
|
|
||||||
session.refresh(jargon)
|
|
||||||
return jargons
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== Test Cases ====================
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_jargons(client: TestClient, sample_jargons):
|
|
||||||
"""测试 GET /jargon/list 基础列表功能"""
|
|
||||||
response = client.get("/api/webui/jargon/list")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["total"] == 3
|
|
||||||
assert data["page"] == 1
|
|
||||||
assert data["page_size"] == 20
|
|
||||||
assert len(data["data"]) == 3
|
|
||||||
|
|
||||||
assert data["data"][0]["content"] == "yyds"
|
|
||||||
assert data["data"][0]["count"] == 10
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_jargons_with_pagination(client: TestClient, sample_jargons):
|
|
||||||
"""测试分页功能"""
|
|
||||||
response = client.get("/api/webui/jargon/list?page=1&page_size=2")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["total"] == 3
|
|
||||||
assert len(data["data"]) == 2
|
|
||||||
|
|
||||||
response = client.get("/api/webui/jargon/list?page=2&page_size=2")
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert len(data["data"]) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_jargons_with_search(client: TestClient, sample_jargons):
|
|
||||||
"""测试 GET /jargon/list?search=xxx 搜索功能"""
|
|
||||||
response = client.get("/api/webui/jargon/list?search=yyds")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["total"] == 1
|
|
||||||
assert data["data"][0]["content"] == "yyds"
|
|
||||||
|
|
||||||
# 测试搜索 meaning
|
|
||||||
response = client.get("/api/webui/jargon/list?search=你好")
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["total"] == 1
|
|
||||||
assert data["data"][0]["content"] == "hello"
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_jargons_with_chat_id_filter(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
|
|
||||||
"""测试按 chat_id 筛选"""
|
|
||||||
response = client.get(f"/api/webui/jargon/list?chat_id={sample_chat_session.session_id}")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["total"] == 3
|
|
||||||
|
|
||||||
# 测试不存在的 chat_id
|
|
||||||
response = client.get("/api/webui/jargon/list?chat_id=nonexistent")
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["total"] == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_jargons_with_is_jargon_filter(client: TestClient, sample_jargons):
|
|
||||||
"""测试按 is_jargon 筛选"""
|
|
||||||
response = client.get("/api/webui/jargon/list?is_jargon=true")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["total"] == 2
|
|
||||||
assert all(item["is_jargon"] is True for item in data["data"])
|
|
||||||
|
|
||||||
response = client.get("/api/webui/jargon/list?is_jargon=false")
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["total"] == 1
|
|
||||||
assert data["data"][0]["content"] == "hello"
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_jargon_detail(client: TestClient, sample_jargons):
|
|
||||||
"""测试 GET /jargon/{id} 获取详情"""
|
|
||||||
jargon_id = sample_jargons[0].id
|
|
||||||
response = client.get(f"/api/webui/jargon/{jargon_id}")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["data"]["id"] == jargon_id
|
|
||||||
assert data["data"]["content"] == "yyds"
|
|
||||||
assert data["data"]["meaning"] == "永远的神"
|
|
||||||
assert data["data"]["count"] == 10
|
|
||||||
assert data["data"]["is_jargon"] is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_jargon_detail_not_found(client: TestClient):
|
|
||||||
"""测试获取不存在的黑话详情"""
|
|
||||||
response = client.get("/api/webui/jargon/99999")
|
|
||||||
assert response.status_code == 404
|
|
||||||
assert "黑话不存在" in response.json()["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Composite PK (id+content) prevents autoincrement - database model issue")
|
|
||||||
def test_create_jargon(client: TestClient, sample_chat_session: ChatSession):
|
|
||||||
"""测试 POST /jargon/ 创建黑话"""
|
|
||||||
request_data = {
|
|
||||||
"content": "新黑话",
|
|
||||||
"raw_content": "原始内容",
|
|
||||||
"meaning": "含义",
|
|
||||||
"chat_id": sample_chat_session.session_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.post("/api/webui/jargon/", json=request_data)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["message"] == "创建成功"
|
|
||||||
assert data["data"]["content"] == "新黑话"
|
|
||||||
assert data["data"]["meaning"] == "含义"
|
|
||||||
assert data["data"]["count"] == 0
|
|
||||||
assert data["data"]["is_jargon"] is None
|
|
||||||
assert data["data"]["is_complete"] is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_duplicate_jargon(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
|
|
||||||
"""测试创建重复黑话返回 400"""
|
|
||||||
request_data = {
|
|
||||||
"content": "yyds",
|
|
||||||
"meaning": "重复的",
|
|
||||||
"chat_id": sample_chat_session.session_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.post("/api/webui/jargon/", json=request_data)
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "已存在相同内容的黑话" in response.json()["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_jargon(client: TestClient, sample_jargons):
|
|
||||||
"""测试 PATCH /jargon/{id} 更新黑话"""
|
|
||||||
jargon_id = sample_jargons[0].id
|
|
||||||
update_data = {
|
|
||||||
"meaning": "更新后的含义",
|
|
||||||
"is_jargon": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.patch(f"/api/webui/jargon/{jargon_id}", json=update_data)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["message"] == "更新成功"
|
|
||||||
assert data["data"]["meaning"] == "更新后的含义"
|
|
||||||
assert data["data"]["is_jargon"] is True
|
|
||||||
assert data["data"]["content"] == "yyds" # 未改变的字段保持不变
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_jargon_with_chat_id_mapping(client: TestClient, sample_jargons):
|
|
||||||
"""测试更新时 chat_id → session_id 的映射"""
|
|
||||||
jargon_id = sample_jargons[0].id
|
|
||||||
update_data = {
|
|
||||||
"chat_id": "new_session_id",
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.patch(f"/api/webui/jargon/{jargon_id}", json=update_data)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["data"]["chat_id"] == "new_session_id"
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_jargon_not_found(client: TestClient):
|
|
||||||
"""测试更新不存在的黑话"""
|
|
||||||
response = client.patch("/api/webui/jargon/99999", json={"meaning": "test"})
|
|
||||||
assert response.status_code == 404
|
|
||||||
assert "黑话不存在" in response.json()["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_jargon(client: TestClient, sample_jargons, session: Session):
|
|
||||||
"""测试 DELETE /jargon/{id} 删除黑话"""
|
|
||||||
jargon_id = sample_jargons[0].id
|
|
||||||
response = client.delete(f"/api/webui/jargon/{jargon_id}")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["message"] == "删除成功"
|
|
||||||
assert data["deleted_count"] == 1
|
|
||||||
|
|
||||||
# 验证数据库中已删除
|
|
||||||
response = client.get(f"/api/webui/jargon/{jargon_id}")
|
|
||||||
assert response.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_jargon_not_found(client: TestClient):
|
|
||||||
"""测试删除不存在的黑话"""
|
|
||||||
response = client.delete("/api/webui/jargon/99999")
|
|
||||||
assert response.status_code == 404
|
|
||||||
assert "黑话不存在" in response.json()["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_batch_delete(client: TestClient, sample_jargons):
|
|
||||||
"""测试 POST /jargon/batch/delete 批量删除"""
|
|
||||||
ids_to_delete = [sample_jargons[0].id, sample_jargons[1].id]
|
|
||||||
request_data = {"ids": ids_to_delete}
|
|
||||||
|
|
||||||
response = client.post("/api/webui/jargon/batch/delete", json=request_data)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["deleted_count"] == 2
|
|
||||||
assert "成功删除 2 条黑话" in data["message"]
|
|
||||||
|
|
||||||
# 验证已删除
|
|
||||||
response = client.get(f"/api/webui/jargon/{ids_to_delete[0]}")
|
|
||||||
assert response.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
def test_batch_delete_empty_list(client: TestClient):
|
|
||||||
"""测试批量删除空列表返回 400"""
|
|
||||||
response = client.post("/api/webui/jargon/batch/delete", json={"ids": []})
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "ID列表不能为空" in response.json()["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_batch_set_jargon_status(client: TestClient, sample_jargons):
|
|
||||||
"""测试批量设置黑话状态"""
|
|
||||||
ids = [sample_jargons[0].id, sample_jargons[1].id]
|
|
||||||
response = client.post(
|
|
||||||
"/api/webui/jargon/batch/set-jargon",
|
|
||||||
params={"ids": ids, "is_jargon": False},
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert "成功更新 2 条黑话状态" in data["message"]
|
|
||||||
|
|
||||||
# 验证状态已更新
|
|
||||||
detail_response = client.get(f"/api/webui/jargon/{ids[0]}")
|
|
||||||
assert detail_response.json()["data"]["is_jargon"] is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_stats(client: TestClient, sample_jargons):
|
|
||||||
"""测试 GET /jargon/stats/summary 统计数据"""
|
|
||||||
response = client.get("/api/webui/jargon/stats/summary")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
stats = data["data"]
|
|
||||||
|
|
||||||
assert stats["total"] == 3
|
|
||||||
assert stats["confirmed_jargon"] == 2
|
|
||||||
assert stats["confirmed_not_jargon"] == 1
|
|
||||||
assert stats["pending"] == 0
|
|
||||||
assert stats["complete_count"] == 0
|
|
||||||
assert stats["chat_count"] == 1
|
|
||||||
assert isinstance(stats["top_chats"], dict)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_chat_list(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
|
|
||||||
"""测试 GET /jargon/chats 获取聊天列表"""
|
|
||||||
response = client.get("/api/webui/jargon/chats")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert len(data["data"]) == 1
|
|
||||||
|
|
||||||
chat_info = data["data"][0]
|
|
||||||
assert chat_info["chat_id"] == sample_chat_session.session_id
|
|
||||||
assert chat_info["platform"] == "qq"
|
|
||||||
assert chat_info["is_group"] is True
|
|
||||||
assert chat_info["chat_name"] == sample_chat_session.group_id
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_chat_list_with_json_chat_id(client: TestClient, session: Session, sample_chat_session: ChatSession):
|
|
||||||
"""测试解析 JSON 格式的 chat_id"""
|
|
||||||
json_chat_id = json.dumps([[sample_chat_session.session_id, "user123"]])
|
|
||||||
jargon = Jargon(
|
|
||||||
id=100,
|
|
||||||
content="测试黑话",
|
|
||||||
meaning="测试",
|
|
||||||
session_id=json_chat_id,
|
|
||||||
count=1,
|
|
||||||
)
|
|
||||||
session.add(jargon)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
response = client.get("/api/webui/jargon/chats")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert len(data["data"]) == 1
|
|
||||||
assert data["data"][0]["chat_id"] == sample_chat_session.session_id
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_chat_list_without_chat_session(client: TestClient, session: Session):
|
|
||||||
"""测试聊天列表中没有对应 ChatSession 的情况"""
|
|
||||||
jargon = Jargon(
|
|
||||||
id=101,
|
|
||||||
content="孤立黑话",
|
|
||||||
meaning="无对应会话",
|
|
||||||
session_id="nonexistent_stream_id",
|
|
||||||
count=1,
|
|
||||||
)
|
|
||||||
session.add(jargon)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
response = client.get("/api/webui/jargon/chats")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert len(data["data"]) == 1
|
|
||||||
assert data["data"][0]["chat_id"] == "nonexistent_stream_id"
|
|
||||||
assert data["data"][0]["chat_name"] == "nonexistent_stream_id"[:20]
|
|
||||||
assert data["data"][0]["platform"] is None
|
|
||||||
assert data["data"][0]["is_group"] is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_jargon_response_fields(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
|
|
||||||
"""测试 JargonResponse 字段完整性"""
|
|
||||||
response = client.get(f"/api/webui/jargon/{sample_jargons[0].id}")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()["data"]
|
|
||||||
|
|
||||||
# 验证所有必需字段存在
|
|
||||||
required_fields = [
|
|
||||||
"id",
|
|
||||||
"content",
|
|
||||||
"raw_content",
|
|
||||||
"meaning",
|
|
||||||
"chat_id",
|
|
||||||
"stream_id",
|
|
||||||
"chat_name",
|
|
||||||
"count",
|
|
||||||
"is_jargon",
|
|
||||||
"is_complete",
|
|
||||||
"inference_with_context",
|
|
||||||
"inference_content_only",
|
|
||||||
]
|
|
||||||
for field in required_fields:
|
|
||||||
assert field in data
|
|
||||||
|
|
||||||
# 验证 chat_name 显示逻辑
|
|
||||||
assert data["chat_name"] == sample_chat_session.group_id
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Composite PK (id+content) prevents autoincrement - database model issue")
|
|
||||||
def test_create_jargon_without_optional_fields(client: TestClient, sample_chat_session: ChatSession):
|
|
||||||
"""测试创建黑话时可选字段为空"""
|
|
||||||
request_data = {
|
|
||||||
"content": "简单黑话",
|
|
||||||
"chat_id": sample_chat_session.session_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.post("/api/webui/jargon/", json=request_data)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()["data"]
|
|
||||||
assert data["raw_content"] is None
|
|
||||||
assert data["meaning"] == ""
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_jargon_partial_fields(client: TestClient, sample_jargons):
|
|
||||||
"""测试增量更新(只更新部分字段)"""
|
|
||||||
jargon_id = sample_jargons[0].id
|
|
||||||
original_content = sample_jargons[0].content
|
|
||||||
|
|
||||||
# 只更新 meaning
|
|
||||||
response = client.patch(f"/api/webui/jargon/{jargon_id}", json={"meaning": "新含义"})
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()["data"]
|
|
||||||
assert data["meaning"] == "新含义"
|
|
||||||
assert data["content"] == original_content # 其他字段不变
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_jargons_multiple_filters(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
|
|
||||||
"""测试组合多个过滤条件"""
|
|
||||||
response = client.get(f"/api/webui/jargon/list?search=永远&chat_id={sample_chat_session.session_id}&is_jargon=true")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
assert data["total"] == 1
|
|
||||||
assert data["data"][0]["content"] == "yyds"
|
|
||||||
@@ -1,870 +0,0 @@
|
|||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.services.memory_service import MemorySearchResult
|
|
||||||
from src.webui.dependencies import require_auth
|
|
||||||
from src.webui.routers import memory as memory_router_module
|
|
||||||
from src.webui.routers.memory import compat_router
|
|
||||||
from src.webui.routes import router as main_router
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def client() -> TestClient:
|
|
||||||
app = FastAPI()
|
|
||||||
app.dependency_overrides[require_auth] = lambda: "ok"
|
|
||||||
app.include_router(main_router)
|
|
||||||
app.include_router(compat_router)
|
|
||||||
return TestClient(app)
|
|
||||||
|
|
||||||
|
|
||||||
def test_webui_memory_graph_route(client: TestClient, monkeypatch):
|
|
||||||
async def fake_graph_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "get_graph"
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"nodes": [],
|
|
||||||
"edges": [
|
|
||||||
{
|
|
||||||
"source": "alice",
|
|
||||||
"target": "map",
|
|
||||||
"weight": 1.5,
|
|
||||||
"relation_hashes": ["rel-1"],
|
|
||||||
"predicates": ["持有"],
|
|
||||||
"relation_count": 1,
|
|
||||||
"evidence_count": 2,
|
|
||||||
"label": "持有",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"total_nodes": 0,
|
|
||||||
"limit": kwargs.get("limit"),
|
|
||||||
}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
|
|
||||||
|
|
||||||
response = client.get("/api/webui/memory/graph", params={"limit": 77})
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["success"] is True
|
|
||||||
assert response.json()["limit"] == 77
|
|
||||||
assert response.json()["edges"][0]["predicates"] == ["持有"]
|
|
||||||
assert response.json()["edges"][0]["relation_count"] == 1
|
|
||||||
assert response.json()["edges"][0]["evidence_count"] == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_webui_memory_graph_search_route(client: TestClient, monkeypatch):
|
|
||||||
async def fake_graph_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "search"
|
|
||||||
assert kwargs["query"] == "Alice"
|
|
||||||
assert kwargs["limit"] == 33
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"query": kwargs["query"],
|
|
||||||
"limit": kwargs["limit"],
|
|
||||||
"count": 1,
|
|
||||||
"items": [
|
|
||||||
{
|
|
||||||
"type": "entity",
|
|
||||||
"title": "Alice",
|
|
||||||
"matched_field": "name",
|
|
||||||
"matched_value": "Alice",
|
|
||||||
"entity_name": "Alice",
|
|
||||||
"entity_hash": "entity-1",
|
|
||||||
"appearance_count": 3,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
|
|
||||||
|
|
||||||
response = client.get("/api/webui/memory/graph/search", params={"query": "Alice", "limit": 33})
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["success"] is True
|
|
||||||
assert response.json()["query"] == "Alice"
|
|
||||||
assert response.json()["limit"] == 33
|
|
||||||
assert response.json()["items"][0]["type"] == "entity"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"params",
|
|
||||||
[
|
|
||||||
{"query": "", "limit": 50},
|
|
||||||
{"query": "Alice", "limit": 0},
|
|
||||||
{"query": "Alice", "limit": 201},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_webui_memory_graph_search_route_validation(client: TestClient, params):
|
|
||||||
response = client.get("/api/webui/memory/graph/search", params=params)
|
|
||||||
|
|
||||||
assert response.status_code == 422
|
|
||||||
|
|
||||||
|
|
||||||
def test_webui_memory_graph_node_detail_route(client: TestClient, monkeypatch):
|
|
||||||
async def fake_graph_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "node_detail"
|
|
||||||
assert kwargs["node_id"] == "Alice"
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"node": {"id": "Alice", "type": "entity", "content": "Alice", "appearance_count": 3},
|
|
||||||
"relations": [{"hash": "rel-1", "subject": "Alice", "predicate": "持有", "object": "Map", "text": "Alice 持有 Map", "confidence": 0.9, "paragraph_count": 1, "paragraph_hashes": ["p-1"], "source_paragraph": "p-1"}],
|
|
||||||
"paragraphs": [{"hash": "p-1", "content": "Alice 拿着地图。", "preview": "Alice 拿着地图。", "source": "demo", "entity_count": 2, "relation_count": 1, "entities": ["Alice", "Map"], "relations": ["Alice 持有 Map"]}],
|
|
||||||
"evidence_graph": {
|
|
||||||
"nodes": [{"id": "entity:Alice", "type": "entity", "content": "Alice"}],
|
|
||||||
"edges": [],
|
|
||||||
"focus_entities": ["Alice"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
|
|
||||||
|
|
||||||
response = client.get("/api/webui/memory/graph/node-detail", params={"node_id": "Alice"})
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["node"]["id"] == "Alice"
|
|
||||||
assert response.json()["relations"][0]["predicate"] == "持有"
|
|
||||||
assert response.json()["evidence_graph"]["focus_entities"] == ["Alice"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_webui_memory_graph_node_detail_route_returns_404(client: TestClient, monkeypatch):
|
|
||||||
async def fake_graph_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "node_detail"
|
|
||||||
return {"success": False, "error": "未找到节点: Missing"}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
|
|
||||||
|
|
||||||
response = client.get("/api/webui/memory/graph/node-detail", params={"node_id": "Missing"})
|
|
||||||
|
|
||||||
assert response.status_code == 404
|
|
||||||
assert response.json()["detail"] == "未找到节点: Missing"
|
|
||||||
|
|
||||||
|
|
||||||
def test_webui_memory_graph_edge_detail_route(client: TestClient, monkeypatch):
|
|
||||||
async def fake_graph_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "edge_detail"
|
|
||||||
assert kwargs["source"] == "Alice"
|
|
||||||
assert kwargs["target"] == "Map"
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"edge": {
|
|
||||||
"source": "Alice",
|
|
||||||
"target": "Map",
|
|
||||||
"weight": 1.5,
|
|
||||||
"relation_hashes": ["rel-1"],
|
|
||||||
"predicates": ["持有"],
|
|
||||||
"relation_count": 1,
|
|
||||||
"evidence_count": 1,
|
|
||||||
"label": "持有",
|
|
||||||
},
|
|
||||||
"relations": [{"hash": "rel-1", "subject": "Alice", "predicate": "持有", "object": "Map", "text": "Alice 持有 Map", "confidence": 0.9, "paragraph_count": 1, "paragraph_hashes": ["p-1"], "source_paragraph": "p-1"}],
|
|
||||||
"paragraphs": [{"hash": "p-1", "content": "Alice 拿着地图。", "preview": "Alice 拿着地图。", "source": "demo", "entity_count": 2, "relation_count": 1, "entities": ["Alice", "Map"], "relations": ["Alice 持有 Map"]}],
|
|
||||||
"evidence_graph": {
|
|
||||||
"nodes": [{"id": "relation:rel-1", "type": "relation", "content": "Alice 持有 Map"}],
|
|
||||||
"edges": [{"source": "paragraph:p-1", "target": "relation:rel-1", "kind": "supports", "label": "支撑", "weight": 1.0}],
|
|
||||||
"focus_entities": ["Alice", "Map"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
|
|
||||||
|
|
||||||
response = client.get("/api/webui/memory/graph/edge-detail", params={"source": "Alice", "target": "Map"})
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["edge"]["predicates"] == ["持有"]
|
|
||||||
assert response.json()["paragraphs"][0]["source"] == "demo"
|
|
||||||
assert response.json()["evidence_graph"]["edges"][0]["kind"] == "supports"
|
|
||||||
|
|
||||||
|
|
||||||
def test_webui_memory_graph_edge_detail_route_returns_404(client: TestClient, monkeypatch):
|
|
||||||
async def fake_graph_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "edge_detail"
|
|
||||||
return {"success": False, "error": "未找到边: Alice -> Missing"}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
|
|
||||||
|
|
||||||
response = client.get("/api/webui/memory/graph/edge-detail", params={"source": "Alice", "target": "Missing"})
|
|
||||||
|
|
||||||
assert response.status_code == 404
|
|
||||||
assert response.json()["detail"] == "未找到边: Alice -> Missing"
|
|
||||||
|
|
||||||
|
|
||||||
def test_webui_memory_profile_query_resolves_platform_user_id(client: TestClient, monkeypatch):
|
|
||||||
def fake_resolve_person_id_for_memory(**kwargs):
|
|
||||||
assert kwargs == {"platform": "qq", "user_id": "12345", "strict_known": False}
|
|
||||||
return "resolved-person-id"
|
|
||||||
|
|
||||||
async def fake_profile_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "query"
|
|
||||||
assert kwargs["person_id"] == "resolved-person-id"
|
|
||||||
assert kwargs["person_keyword"] == "Alice"
|
|
||||||
assert kwargs["limit"] == 9
|
|
||||||
assert kwargs["force_refresh"] is True
|
|
||||||
return {"success": True, "person_id": kwargs["person_id"], "profile_text": "profile"}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "profile_admin", fake_profile_admin)
|
|
||||||
|
|
||||||
response = client.get(
|
|
||||||
"/api/webui/memory/profiles/query",
|
|
||||||
params={
|
|
||||||
"platform": "qq",
|
|
||||||
"user_id": "12345",
|
|
||||||
"person_keyword": "Alice",
|
|
||||||
"limit": 9,
|
|
||||||
"force_refresh": True,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["success"] is True
|
|
||||||
assert response.json()["person_id"] == "resolved-person-id"
|
|
||||||
|
|
||||||
|
|
||||||
def test_webui_memory_profile_query_prefers_explicit_person_id(client: TestClient, monkeypatch):
|
|
||||||
def fake_resolve_person_id_for_memory(**kwargs):
|
|
||||||
raise AssertionError(f"不应解析平台账号: {kwargs}")
|
|
||||||
|
|
||||||
async def fake_profile_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "query"
|
|
||||||
assert kwargs["person_id"] == "explicit-person-id"
|
|
||||||
return {"success": True, "person_id": kwargs["person_id"]}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "profile_admin", fake_profile_admin)
|
|
||||||
|
|
||||||
response = client.get(
|
|
||||||
"/api/webui/memory/profiles/query",
|
|
||||||
params={"person_id": "explicit-person-id", "platform": "qq", "user_id": "12345"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["person_id"] == "explicit-person-id"
|
|
||||||
|
|
||||||
|
|
||||||
def test_webui_memory_profile_list_enriches_person_name(client: TestClient, monkeypatch):
|
|
||||||
async def fake_profile_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "list"
|
|
||||||
assert kwargs["limit"] == 7
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"items": [
|
|
||||||
{"person_id": "person-1", "profile_text": "profile-1"},
|
|
||||||
{"person_id": "person-2", "profile_text": "profile-2"},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "profile_admin", fake_profile_admin)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_router_module,
|
|
||||||
"_get_person_name_for_person_id",
|
|
||||||
lambda person_id: {"person-1": "Alice"}.get(person_id, ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.get("/api/webui/memory/profiles", params={"limit": 7})
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["items"][0]["person_name"] == "Alice"
|
|
||||||
assert response.json()["items"][1]["person_name"] == ""
|
|
||||||
|
|
||||||
|
|
||||||
def test_webui_memory_profile_search_resolves_platform_user_id(client: TestClient, monkeypatch):
|
|
||||||
def fake_resolve_person_id_for_memory(**kwargs):
|
|
||||||
assert kwargs == {"platform": "qq", "user_id": "12345", "strict_known": False}
|
|
||||||
return "resolved-person-id"
|
|
||||||
|
|
||||||
async def fake_profile_list(limit: int):
|
|
||||||
assert limit == 200
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"items": [
|
|
||||||
{"person_id": "resolved-person-id", "person_name": "Alice", "profile_text": "喜欢咖啡"},
|
|
||||||
{"person_id": "other-person-id", "person_name": "Bob", "profile_text": "喜欢茶"},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
|
|
||||||
monkeypatch.setattr(memory_router_module, "_profile_list", fake_profile_list)
|
|
||||||
|
|
||||||
response = client.get(
|
|
||||||
"/api/webui/memory/profiles/search",
|
|
||||||
params={"platform": "qq", "user_id": "12345", "limit": 50},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["items"] == [
|
|
||||||
{"person_id": "resolved-person-id", "person_name": "Alice", "profile_text": "喜欢咖啡"}
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_webui_memory_profile_search_filters_keyword(client: TestClient, monkeypatch):
|
|
||||||
async def fake_profile_list(limit: int):
|
|
||||||
assert limit == 200
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"items": [
|
|
||||||
{"person_id": "person-1", "person_name": "Alice", "profile_text": "喜欢咖啡"},
|
|
||||||
{"person_id": "person-2", "person_name": "Bob", "profile_text": "喜欢茶"},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module, "_profile_list", fake_profile_list)
|
|
||||||
|
|
||||||
response = client.get("/api/webui/memory/profiles/search", params={"person_keyword": "咖啡", "limit": 50})
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["items"] == [
|
|
||||||
{"person_id": "person-1", "person_name": "Alice", "profile_text": "喜欢咖啡"}
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_webui_memory_episode_list_resolves_platform_user_id(client: TestClient, monkeypatch):
|
|
||||||
def fake_resolve_person_id_for_memory(**kwargs):
|
|
||||||
assert kwargs == {"platform": "qq", "user_id": "12345", "strict_known": False}
|
|
||||||
return "resolved-person-id"
|
|
||||||
|
|
||||||
async def fake_episode_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "list"
|
|
||||||
assert kwargs == {
|
|
||||||
"query": "咖啡",
|
|
||||||
"limit": 9,
|
|
||||||
"source": "chat_summary:demo",
|
|
||||||
"person_id": "resolved-person-id",
|
|
||||||
"time_start": 100.0,
|
|
||||||
"time_end": 200.0,
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"items": [{"episode_id": "ep-1", "person_id": "resolved-person-id", "summary": "喝咖啡"}],
|
|
||||||
"count": 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin)
|
|
||||||
monkeypatch.setattr(memory_router_module, "_get_person_name_for_person_id", lambda person_id: "测试人物")
|
|
||||||
|
|
||||||
response = client.get(
|
|
||||||
"/api/webui/memory/episodes",
|
|
||||||
params={
|
|
||||||
"query": "咖啡",
|
|
||||||
"limit": 9,
|
|
||||||
"source": "chat_summary:demo",
|
|
||||||
"platform": "qq",
|
|
||||||
"user_id": "12345",
|
|
||||||
"time_start": 100,
|
|
||||||
"time_end": 200,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["items"][0]["person_name"] == "测试人物"
|
|
||||||
|
|
||||||
|
|
||||||
def test_webui_memory_episode_list_prefers_explicit_person_id(client: TestClient, monkeypatch):
|
|
||||||
def fake_resolve_person_id_for_memory(**kwargs):
|
|
||||||
raise AssertionError(f"不应解析平台账号: {kwargs}")
|
|
||||||
|
|
||||||
async def fake_episode_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "list"
|
|
||||||
assert kwargs["person_id"] == "explicit-person-id"
|
|
||||||
return {"success": True, "items": []}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin)
|
|
||||||
|
|
||||||
response = client.get(
|
|
||||||
"/api/webui/memory/episodes",
|
|
||||||
params={"person_id": "explicit-person-id", "platform": "qq", "user_id": "12345"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["items"] == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_compat_aggregate_route(client: TestClient, monkeypatch):
|
|
||||||
async def fake_search(query: str, **kwargs):
|
|
||||||
assert kwargs["mode"] == "aggregate"
|
|
||||||
assert kwargs["respect_filter"] is False
|
|
||||||
return MemorySearchResult(summary=f"summary:{query}", hits=[])
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "search", fake_search)
|
|
||||||
|
|
||||||
response = client.get("/api/query/aggregate", params={"query": "mai"})
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json() == {
|
|
||||||
"success": True,
|
|
||||||
"summary": "summary:mai",
|
|
||||||
"hits": [],
|
|
||||||
"filtered": False,
|
|
||||||
"error": "",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_auto_save_routes(client: TestClient, monkeypatch):
|
|
||||||
async def fake_runtime_admin(*, action: str, **kwargs):
|
|
||||||
if action == "get_config":
|
|
||||||
return {"success": True, "auto_save": True}
|
|
||||||
if action == "set_auto_save":
|
|
||||||
return {"success": True, "auto_save": kwargs["enabled"]}
|
|
||||||
raise AssertionError(action)
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "runtime_admin", fake_runtime_admin)
|
|
||||||
|
|
||||||
get_response = client.get("/api/config/auto_save")
|
|
||||||
post_response = client.post("/api/config/auto_save", json={"enabled": False})
|
|
||||||
|
|
||||||
assert get_response.status_code == 200
|
|
||||||
assert get_response.json() == {"success": True, "auto_save": True}
|
|
||||||
assert post_response.status_code == 200
|
|
||||||
assert post_response.json() == {"success": True, "auto_save": False}
|
|
||||||
|
|
||||||
|
|
||||||
def test_memory_config_routes(client: TestClient, monkeypatch):
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_router_module.a_memorix_host_service,
|
|
||||||
"get_config_schema",
|
|
||||||
lambda: {"layout": {"type": "tabs"}, "sections": {"plugin": {"fields": {}}}},
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_router_module.a_memorix_host_service,
|
|
||||||
"get_config_path",
|
|
||||||
lambda: memory_router_module.Path("/tmp/config/bot_config.toml"),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_router_module.a_memorix_host_service,
|
|
||||||
"get_config",
|
|
||||||
lambda: {"plugin": {"enabled": True}},
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_router_module.a_memorix_host_service,
|
|
||||||
"get_raw_config",
|
|
||||||
lambda: "[plugin]\nenabled = true\n",
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_router_module.a_memorix_host_service,
|
|
||||||
"get_raw_config_with_meta",
|
|
||||||
lambda: {
|
|
||||||
"config": "[plugin]\nenabled = true\n",
|
|
||||||
"exists": True,
|
|
||||||
"using_default": False,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
schema_response = client.get("/api/webui/memory/config/schema")
|
|
||||||
config_response = client.get("/api/webui/memory/config")
|
|
||||||
raw_response = client.get("/api/webui/memory/config/raw")
|
|
||||||
expected_path = memory_router_module.Path("/tmp/config/bot_config.toml").as_posix()
|
|
||||||
|
|
||||||
assert schema_response.status_code == 200
|
|
||||||
assert memory_router_module.Path(schema_response.json()["path"]).as_posix() == expected_path
|
|
||||||
assert schema_response.json()["schema"]["layout"]["type"] == "tabs"
|
|
||||||
|
|
||||||
assert config_response.status_code == 200
|
|
||||||
assert config_response.json()["success"] is True
|
|
||||||
assert config_response.json()["config"] == {"plugin": {"enabled": True}}
|
|
||||||
assert memory_router_module.Path(config_response.json()["path"]).as_posix() == expected_path
|
|
||||||
|
|
||||||
assert raw_response.status_code == 200
|
|
||||||
assert raw_response.json()["success"] is True
|
|
||||||
assert raw_response.json()["config"] == "[plugin]\nenabled = true\n"
|
|
||||||
assert memory_router_module.Path(raw_response.json()["path"]).as_posix() == expected_path
|
|
||||||
|
|
||||||
|
|
||||||
def test_memory_config_raw_returns_default_template_when_file_missing(client: TestClient, monkeypatch):
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_router_module.a_memorix_host_service,
|
|
||||||
"get_config_path",
|
|
||||||
lambda: memory_router_module.Path("/tmp/config/bot_config.toml"),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
memory_router_module.a_memorix_host_service,
|
|
||||||
"get_raw_config_with_meta",
|
|
||||||
lambda: {
|
|
||||||
"config": "[plugin]\nenabled = true\n",
|
|
||||||
"exists": False,
|
|
||||||
"using_default": True,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.get("/api/webui/memory/config/raw")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["success"] is True
|
|
||||||
assert response.json()["config"] == "[plugin]\nenabled = true\n"
|
|
||||||
assert response.json()["exists"] is False
|
|
||||||
assert response.json()["using_default"] is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_memory_config_update_routes(client: TestClient, monkeypatch):
|
|
||||||
async def fake_update_config(config):
|
|
||||||
assert config == {"plugin": {"enabled": False}}
|
|
||||||
return {"success": True, "config_path": "config/bot_config.toml"}
|
|
||||||
|
|
||||||
async def fake_update_raw(raw_config):
|
|
||||||
assert raw_config == "[plugin]\nenabled = false\n"
|
|
||||||
return {"success": True, "config_path": "config/bot_config.toml"}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.a_memorix_host_service, "update_config", fake_update_config)
|
|
||||||
monkeypatch.setattr(memory_router_module.a_memorix_host_service, "update_raw_config", fake_update_raw)
|
|
||||||
|
|
||||||
config_response = client.put("/api/webui/memory/config", json={"config": {"plugin": {"enabled": False}}})
|
|
||||||
raw_response = client.put("/api/webui/memory/config/raw", json={"config": "[plugin]\nenabled = false\n"})
|
|
||||||
|
|
||||||
assert config_response.status_code == 200
|
|
||||||
assert config_response.json() == {"success": True, "config_path": "config/bot_config.toml"}
|
|
||||||
|
|
||||||
assert raw_response.status_code == 200
|
|
||||||
assert raw_response.json() == {"success": True, "config_path": "config/bot_config.toml"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_memory_config_raw_rejects_invalid_toml(client: TestClient):
|
|
||||||
response = client.put("/api/webui/memory/config/raw", json={"config": "[plugin\nenabled = true"})
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "TOML 格式错误" in response.json()["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_recycle_bin_route(client: TestClient, monkeypatch):
|
|
||||||
async def fake_get_recycle_bin(*, limit: int):
|
|
||||||
return {"success": True, "items": [{"hash": "deadbeef"}], "count": 1, "limit": limit}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "get_recycle_bin", fake_get_recycle_bin)
|
|
||||||
|
|
||||||
response = client.get("/api/memory/recycle_bin", params={"limit": 10})
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["success"] is True
|
|
||||||
assert response.json()["count"] == 1
|
|
||||||
assert response.json()["limit"] == 10
|
|
||||||
|
|
||||||
|
|
||||||
def test_import_guide_route(client: TestClient, monkeypatch):
|
|
||||||
async def fake_import_admin(*, action: str, **kwargs):
|
|
||||||
assert kwargs == {}
|
|
||||||
if action == "get_guide":
|
|
||||||
return {"success": True}
|
|
||||||
if action == "get_settings":
|
|
||||||
return {"success": True, "settings": {"path_aliases": {"raw": "/tmp/raw"}}}
|
|
||||||
raise AssertionError(action)
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)
|
|
||||||
|
|
||||||
response = client.get("/api/webui/memory/import/guide")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["success"] is True
|
|
||||||
assert response.json()["source"] == "local"
|
|
||||||
assert "长期记忆导入说明" in response.json()["content"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_import_upload_route(client: TestClient, monkeypatch, tmp_path):
|
|
||||||
monkeypatch.setattr(memory_router_module, "STAGING_ROOT", tmp_path)
|
|
||||||
|
|
||||||
async def fake_import_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "create_upload"
|
|
||||||
staged_files = kwargs["staged_files"]
|
|
||||||
assert len(staged_files) == 1
|
|
||||||
assert staged_files[0]["filename"] == "demo.txt"
|
|
||||||
assert memory_router_module.Path(staged_files[0]["staged_path"]).exists()
|
|
||||||
return {"success": True, "task_id": "task-1"}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/import/upload",
|
|
||||||
data={"payload_json": "{\"source\": \"upload\"}"},
|
|
||||||
files=[("files", ("demo.txt", b"hello world", "text/plain"))],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json() == {"success": True, "task_id": "task-1"}
|
|
||||||
assert list(tmp_path.iterdir()) == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_v5_status_route(client: TestClient, monkeypatch):
|
|
||||||
async def fake_v5_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "status"
|
|
||||||
assert kwargs["target"] == "mai"
|
|
||||||
return {"success": True, "active_count": 1, "inactive_count": 2, "deleted_count": 3}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "v5_admin", fake_v5_admin)
|
|
||||||
|
|
||||||
response = client.get("/api/webui/memory/v5/status", params={"target": "mai"})
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["success"] is True
|
|
||||||
assert response.json()["deleted_count"] == 3
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_preview_route(client: TestClient, monkeypatch):
|
|
||||||
async def fake_delete_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "preview"
|
|
||||||
assert kwargs["mode"] == "paragraph"
|
|
||||||
assert kwargs["selector"] == {"query": "demo"}
|
|
||||||
return {"success": True, "counts": {"paragraphs": 1}, "dry_run": True}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/webui/memory/delete/preview",
|
|
||||||
json={"mode": "paragraph", "selector": {"query": "demo"}},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json() == {"success": True, "counts": {"paragraphs": 1}, "dry_run": True}
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_preview_route_supports_mixed_mode(client: TestClient, monkeypatch):
|
|
||||||
async def fake_delete_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "preview"
|
|
||||||
assert kwargs["mode"] == "mixed"
|
|
||||||
assert kwargs["selector"] == {
|
|
||||||
"entity_hashes": ["entity-1"],
|
|
||||||
"paragraph_hashes": ["p-1"],
|
|
||||||
"relation_hashes": ["rel-1"],
|
|
||||||
"sources": ["demo"],
|
|
||||||
}
|
|
||||||
return {"success": True, "mode": "mixed", "counts": {"entities": 1, "paragraphs": 1, "relations": 1, "sources": 1}}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/webui/memory/delete/preview",
|
|
||||||
json={
|
|
||||||
"mode": "mixed",
|
|
||||||
"selector": {
|
|
||||||
"entity_hashes": ["entity-1"],
|
|
||||||
"paragraph_hashes": ["p-1"],
|
|
||||||
"relation_hashes": ["rel-1"],
|
|
||||||
"sources": ["demo"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["mode"] == "mixed"
|
|
||||||
assert response.json()["counts"]["entities"] == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_execute_route_supports_mixed_mode(client: TestClient, monkeypatch):
|
|
||||||
async def fake_delete_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "execute"
|
|
||||||
assert kwargs["mode"] == "mixed"
|
|
||||||
assert kwargs["selector"] == {
|
|
||||||
"entity_hashes": ["entity-1"],
|
|
||||||
"paragraph_hashes": ["p-1"],
|
|
||||||
"relation_hashes": ["rel-1"],
|
|
||||||
"sources": ["demo"],
|
|
||||||
}
|
|
||||||
assert kwargs["reason"] == "knowledge_graph_delete_entity"
|
|
||||||
assert kwargs["requested_by"] == "knowledge_graph"
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"mode": "mixed",
|
|
||||||
"operation_id": "op-mixed-1",
|
|
||||||
"deleted_count": 4,
|
|
||||||
"deleted_entity_count": 1,
|
|
||||||
"deleted_relation_count": 1,
|
|
||||||
"deleted_paragraph_count": 1,
|
|
||||||
"deleted_source_count": 1,
|
|
||||||
"counts": {"entities": 1, "paragraphs": 1, "relations": 1, "sources": 1},
|
|
||||||
}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/webui/memory/delete/execute",
|
|
||||||
json={
|
|
||||||
"mode": "mixed",
|
|
||||||
"selector": {
|
|
||||||
"entity_hashes": ["entity-1"],
|
|
||||||
"paragraph_hashes": ["p-1"],
|
|
||||||
"relation_hashes": ["rel-1"],
|
|
||||||
"sources": ["demo"],
|
|
||||||
},
|
|
||||||
"reason": "knowledge_graph_delete_entity",
|
|
||||||
"requested_by": "knowledge_graph",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["success"] is True
|
|
||||||
assert response.json()["mode"] == "mixed"
|
|
||||||
assert response.json()["operation_id"] == "op-mixed-1"
|
|
||||||
|
|
||||||
|
|
||||||
def test_episode_process_pending_route(client: TestClient, monkeypatch):
|
|
||||||
async def fake_episode_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "process_pending"
|
|
||||||
assert kwargs == {"limit": 7, "max_retry": 4}
|
|
||||||
return {"success": True, "processed": 3}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin)
|
|
||||||
|
|
||||||
response = client.post("/api/webui/memory/episodes/process-pending", json={"limit": 7, "max_retry": 4})
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json() == {"success": True, "processed": 3}
|
|
||||||
|
|
||||||
|
|
||||||
def test_import_list_route_includes_settings(client: TestClient, monkeypatch):
|
|
||||||
calls = []
|
|
||||||
|
|
||||||
async def fake_import_admin(*, action: str, **kwargs):
|
|
||||||
calls.append((action, kwargs))
|
|
||||||
if action == "list":
|
|
||||||
return {"success": True, "items": [{"task_id": "task-1"}]}
|
|
||||||
if action == "get_settings":
|
|
||||||
return {"success": True, "settings": {"path_aliases": {"lpmm": "/tmp/lpmm"}}}
|
|
||||||
raise AssertionError(action)
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)
|
|
||||||
|
|
||||||
response = client.get("/api/webui/memory/import/tasks", params={"limit": 9})
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["items"] == [{"task_id": "task-1"}]
|
|
||||||
assert response.json()["settings"] == {"path_aliases": {"lpmm": "/tmp/lpmm"}}
|
|
||||||
assert calls == [("list", {"limit": 9}), ("get_settings", {})]
|
|
||||||
|
|
||||||
|
|
||||||
def test_tuning_profile_route_backfills_settings(client: TestClient, monkeypatch):
|
|
||||||
calls = []
|
|
||||||
|
|
||||||
async def fake_tuning_admin(*, action: str, **kwargs):
|
|
||||||
calls.append((action, kwargs))
|
|
||||||
if action == "get_profile":
|
|
||||||
return {"success": True, "profile": {"retrieval": {"top_k": 8}}}
|
|
||||||
if action == "get_settings":
|
|
||||||
return {"success": True, "settings": {"profiles": ["default"]}}
|
|
||||||
raise AssertionError(action)
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin)
|
|
||||||
|
|
||||||
response = client.get("/api/webui/memory/retrieval_tuning/profile")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["profile"] == {"retrieval": {"top_k": 8}}
|
|
||||||
assert response.json()["settings"] == {"profiles": ["default"]}
|
|
||||||
assert calls == [("get_profile", {}), ("get_settings", {})]
|
|
||||||
|
|
||||||
|
|
||||||
def test_tuning_report_route_flattens_report_payload(client: TestClient, monkeypatch):
|
|
||||||
async def fake_tuning_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "get_report"
|
|
||||||
assert kwargs == {"task_id": "task-1", "format": "json"}
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"report": {"format": "json", "content": "{\"ok\": true}", "path": "/tmp/report.json"},
|
|
||||||
}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin)
|
|
||||||
|
|
||||||
response = client.get("/api/webui/memory/retrieval_tuning/tasks/task-1/report", params={"format": "json"})
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json() == {
|
|
||||||
"success": True,
|
|
||||||
"format": "json",
|
|
||||||
"content": "{\"ok\": true}",
|
|
||||||
"path": "/tmp/report.json",
|
|
||||||
"error": "",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_execute_route(client: TestClient, monkeypatch):
|
|
||||||
async def fake_delete_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "execute"
|
|
||||||
assert kwargs["mode"] == "source"
|
|
||||||
assert kwargs["selector"] == {"source": "chat_summary:stream-1"}
|
|
||||||
assert kwargs["reason"] == "cleanup"
|
|
||||||
assert kwargs["requested_by"] == "tester"
|
|
||||||
return {"success": True, "operation_id": "del-1"}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/webui/memory/delete/execute",
|
|
||||||
json={
|
|
||||||
"mode": "source",
|
|
||||||
"selector": {"source": "chat_summary:stream-1"},
|
|
||||||
"reason": "cleanup",
|
|
||||||
"requested_by": "tester",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json() == {"success": True, "operation_id": "del-1"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_sources_route(client: TestClient, monkeypatch):
|
|
||||||
async def fake_source_admin(*, action: str, **kwargs):
|
|
||||||
assert action == "list"
|
|
||||||
assert kwargs == {}
|
|
||||||
return {"success": True, "items": [{"source": "demo", "paragraph_count": 2}], "count": 1}
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "source_admin", fake_source_admin)
|
|
||||||
|
|
||||||
response = client.get("/api/webui/memory/sources")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["items"] == [{"source": "demo", "paragraph_count": 2}]
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_operation_routes(client: TestClient, monkeypatch):
|
|
||||||
async def fake_delete_admin(*, action: str, **kwargs):
|
|
||||||
if action == "list_operations":
|
|
||||||
assert kwargs == {"limit": 5, "mode": "paragraph"}
|
|
||||||
return {"success": True, "items": [{"operation_id": "del-1"}], "count": 1}
|
|
||||||
if action == "get_operation":
|
|
||||||
assert kwargs == {"operation_id": "del-1"}
|
|
||||||
return {"success": True, "operation": {"operation_id": "del-1", "mode": "paragraph"}}
|
|
||||||
raise AssertionError(action)
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
|
|
||||||
|
|
||||||
list_response = client.get("/api/webui/memory/delete/operations", params={"limit": 5, "mode": "paragraph"})
|
|
||||||
get_response = client.get("/api/webui/memory/delete/operations/del-1")
|
|
||||||
|
|
||||||
assert list_response.status_code == 200
|
|
||||||
assert list_response.json()["count"] == 1
|
|
||||||
assert get_response.status_code == 200
|
|
||||||
assert get_response.json()["operation"]["operation_id"] == "del-1"
|
|
||||||
|
|
||||||
|
|
||||||
def test_feedback_correction_routes(client: TestClient, monkeypatch):
|
|
||||||
async def fake_feedback_admin(*, action: str, **kwargs):
|
|
||||||
if action == "list":
|
|
||||||
assert kwargs == {
|
|
||||||
"limit": 7,
|
|
||||||
"statuses": ["applied"],
|
|
||||||
"rollback_statuses": ["none"],
|
|
||||||
"query": "green",
|
|
||||||
}
|
|
||||||
return {"success": True, "items": [{"task_id": 11, "query_text": "what color"}], "count": 1}
|
|
||||||
if action == "get":
|
|
||||||
assert kwargs == {"task_id": 11}
|
|
||||||
return {"success": True, "task": {"task_id": 11, "query_text": "what color", "action_logs": []}}
|
|
||||||
if action == "rollback":
|
|
||||||
assert kwargs == {"task_id": 11, "requested_by": "tester", "reason": "manual revert"}
|
|
||||||
return {"success": True, "result": {"restored_relation_hashes": ["rel-1"]}}
|
|
||||||
raise AssertionError(action)
|
|
||||||
|
|
||||||
monkeypatch.setattr(memory_router_module.memory_service, "feedback_admin", fake_feedback_admin)
|
|
||||||
|
|
||||||
list_response = client.get(
|
|
||||||
"/api/webui/memory/feedback-corrections",
|
|
||||||
params={"limit": 7, "status": "applied", "rollback_status": "none", "query": "green"},
|
|
||||||
)
|
|
||||||
get_response = client.get("/api/webui/memory/feedback-corrections/11")
|
|
||||||
rollback_response = client.post(
|
|
||||||
"/api/webui/memory/feedback-corrections/11/rollback",
|
|
||||||
json={"requested_by": "tester", "reason": "manual revert"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert list_response.status_code == 200
|
|
||||||
assert list_response.json()["items"][0]["task_id"] == 11
|
|
||||||
assert get_response.status_code == 200
|
|
||||||
assert get_response.json()["task"]["task_id"] == 11
|
|
||||||
assert rollback_response.status_code == 200
|
|
||||||
assert rollback_response.json()["result"]["restored_relation_hashes"] == ["rel-1"]
|
|
||||||
@@ -1,533 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from time import monotonic, sleep
|
|
||||||
from typing import Any, Dict, Generator
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
import pytest
|
|
||||||
import tomlkit
|
|
||||||
|
|
||||||
from src.A_memorix import host_service as host_service_module
|
|
||||||
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
|
||||||
from src.A_memorix.core.utils import retrieval_tuning_manager as tuning_manager_module
|
|
||||||
from src.webui.dependencies import require_auth
|
|
||||||
from src.webui.routers import memory as memory_router_module
|
|
||||||
|
|
||||||
|
|
||||||
REQUEST_TIMEOUT_SECONDS = 30
|
|
||||||
IMPORT_TIMEOUT_SECONDS = 120
|
|
||||||
TUNING_TIMEOUT_SECONDS = 420
|
|
||||||
|
|
||||||
IMPORT_TERMINAL_STATUSES = {"completed", "completed_with_errors", "failed", "cancelled"}
|
|
||||||
TUNING_TERMINAL_STATUSES = {"completed", "failed", "cancelled"}
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeEmbeddingManager:
|
|
||||||
def __init__(self, dimension: int = 64) -> None:
|
|
||||||
self.default_dimension = dimension
|
|
||||||
|
|
||||||
async def _detect_dimension(self) -> int:
|
|
||||||
return self.default_dimension
|
|
||||||
|
|
||||||
async def encode(self, text: Any, **kwargs: Any) -> Any:
|
|
||||||
del kwargs
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
def _encode_one(raw: Any) -> Any:
|
|
||||||
content = str(raw or "")
|
|
||||||
vector = np.zeros(self.default_dimension, dtype=np.float32)
|
|
||||||
for index, byte in enumerate(content.encode("utf-8")):
|
|
||||||
vector[index % self.default_dimension] += float((byte % 17) + 1)
|
|
||||||
norm = float(np.linalg.norm(vector))
|
|
||||||
if norm > 0:
|
|
||||||
vector /= norm
|
|
||||||
return vector
|
|
||||||
|
|
||||||
if isinstance(text, (list, tuple)):
|
|
||||||
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
|
|
||||||
return _encode_one(text).astype(np.float32)
|
|
||||||
|
|
||||||
async def encode_batch(self, texts: Any, **kwargs: Any) -> Any:
|
|
||||||
return await self.encode(texts, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_test_config(data_dir: Path) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"storage": {
|
|
||||||
"data_dir": str(data_dir),
|
|
||||||
},
|
|
||||||
"advanced": {
|
|
||||||
"enable_auto_save": False,
|
|
||||||
},
|
|
||||||
"embedding": {
|
|
||||||
"dimension": 64,
|
|
||||||
"batch_size": 4,
|
|
||||||
"max_concurrent": 1,
|
|
||||||
"retry": {
|
|
||||||
"max_attempts": 1,
|
|
||||||
"min_wait_seconds": 0.1,
|
|
||||||
"max_wait_seconds": 0.2,
|
|
||||||
"backoff_multiplier": 1.0,
|
|
||||||
},
|
|
||||||
"fallback": {
|
|
||||||
"enabled": True,
|
|
||||||
"allow_metadata_only_write": True,
|
|
||||||
"probe_interval_seconds": 30,
|
|
||||||
},
|
|
||||||
"paragraph_vector_backfill": {
|
|
||||||
"enabled": False,
|
|
||||||
"interval_seconds": 60,
|
|
||||||
"batch_size": 32,
|
|
||||||
"max_retry": 2,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"retrieval": {
|
|
||||||
"enable_parallel": False,
|
|
||||||
"enable_ppr": False,
|
|
||||||
"top_k_paragraphs": 20,
|
|
||||||
"top_k_relations": 10,
|
|
||||||
"top_k_final": 10,
|
|
||||||
"alpha": 0.5,
|
|
||||||
"search": {
|
|
||||||
"smart_fallback": {
|
|
||||||
"enabled": True,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"sparse": {
|
|
||||||
"enabled": True,
|
|
||||||
"mode": "auto",
|
|
||||||
"candidate_k": 80,
|
|
||||||
"relation_candidate_k": 60,
|
|
||||||
},
|
|
||||||
"fusion": {
|
|
||||||
"method": "weighted_rrf",
|
|
||||||
"rrf_k": 60,
|
|
||||||
"vector_weight": 0.7,
|
|
||||||
"bm25_weight": 0.3,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"threshold": {
|
|
||||||
"percentile": 70.0,
|
|
||||||
"min_results": 1,
|
|
||||||
},
|
|
||||||
"web": {
|
|
||||||
"tuning": {
|
|
||||||
"enabled": True,
|
|
||||||
"poll_interval_ms": 300,
|
|
||||||
"max_queue_size": 4,
|
|
||||||
"default_objective": "balanced",
|
|
||||||
"default_intensity": "quick",
|
|
||||||
"default_sample_size": 4,
|
|
||||||
"default_top_k_eval": 5,
|
|
||||||
"eval_query_timeout_seconds": 1.0,
|
|
||||||
"llm_retry": {
|
|
||||||
"max_attempts": 1,
|
|
||||||
"min_wait_seconds": 0.1,
|
|
||||||
"max_wait_seconds": 0.2,
|
|
||||||
"backoff_multiplier": 1.0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _assert_response_ok(response: Any) -> Dict[str, Any]:
|
|
||||||
assert response.status_code == 200, response.text
|
|
||||||
payload = response.json()
|
|
||||||
assert payload.get("success", True) is True, payload
|
|
||||||
return payload
|
|
||||||
|
|
||||||
|
|
||||||
def _wait_for_import_task_terminal(client: TestClient, task_id: str, *, timeout_seconds: float = IMPORT_TIMEOUT_SECONDS) -> Dict[str, Any]:
|
|
||||||
deadline = monotonic() + timeout_seconds
|
|
||||||
last_payload: Dict[str, Any] = {}
|
|
||||||
while monotonic() < deadline:
|
|
||||||
response = client.get(
|
|
||||||
f"/api/webui/memory/import/tasks/{task_id}",
|
|
||||||
params={"include_chunks": True},
|
|
||||||
)
|
|
||||||
payload = _assert_response_ok(response)
|
|
||||||
last_payload = payload
|
|
||||||
task = payload.get("task") or {}
|
|
||||||
status = str(task.get("status", "") or "")
|
|
||||||
if status in IMPORT_TERMINAL_STATUSES:
|
|
||||||
return task
|
|
||||||
sleep(0.2)
|
|
||||||
raise AssertionError(f"导入任务超时: task_id={task_id}, last_payload={last_payload}")
|
|
||||||
|
|
||||||
|
|
||||||
def _wait_for_tuning_task_terminal(client: TestClient, task_id: str, *, timeout_seconds: float = TUNING_TIMEOUT_SECONDS) -> Dict[str, Any]:
|
|
||||||
deadline = monotonic() + timeout_seconds
|
|
||||||
last_payload: Dict[str, Any] = {}
|
|
||||||
while monotonic() < deadline:
|
|
||||||
response = client.get(
|
|
||||||
f"/api/webui/memory/retrieval_tuning/tasks/{task_id}",
|
|
||||||
params={"include_rounds": False},
|
|
||||||
)
|
|
||||||
payload = _assert_response_ok(response)
|
|
||||||
last_payload = payload
|
|
||||||
task = payload.get("task") or {}
|
|
||||||
status = str(task.get("status", "") or "")
|
|
||||||
if status in TUNING_TERMINAL_STATUSES:
|
|
||||||
return task
|
|
||||||
sleep(0.3)
|
|
||||||
raise AssertionError(f"调优任务超时: task_id={task_id}, last_payload={last_payload}")
|
|
||||||
|
|
||||||
|
|
||||||
def _wait_for_query_hit(client: TestClient, query: str, *, timeout_seconds: float = 30.0) -> Dict[str, Any]:
|
|
||||||
deadline = monotonic() + timeout_seconds
|
|
||||||
last_payload: Dict[str, Any] = {}
|
|
||||||
while monotonic() < deadline:
|
|
||||||
payload = _assert_response_ok(
|
|
||||||
client.get(
|
|
||||||
"/api/webui/memory/query/aggregate",
|
|
||||||
params={"query": query, "limit": 20},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
last_payload = payload
|
|
||||||
hits = payload.get("hits") or []
|
|
||||||
if isinstance(hits, list) and len(hits) > 0:
|
|
||||||
return payload
|
|
||||||
sleep(0.2)
|
|
||||||
raise AssertionError(f"检索命中超时: query={query}, last_payload={last_payload}")
|
|
||||||
|
|
||||||
|
|
||||||
def _get_source_item(client: TestClient, source_name: str) -> Dict[str, Any] | None:
|
|
||||||
payload = _assert_response_ok(client.get("/api/webui/memory/sources"))
|
|
||||||
items = payload.get("items") or []
|
|
||||||
for item in items:
|
|
||||||
if not isinstance(item, dict):
|
|
||||||
continue
|
|
||||||
if str(item.get("source", "") or "") == source_name:
|
|
||||||
return item
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _source_paragraph_count(item: Dict[str, Any] | None) -> int:
|
|
||||||
payload = item or {}
|
|
||||||
if "paragraph_count" in payload:
|
|
||||||
return int(payload.get("paragraph_count", 0) or 0)
|
|
||||||
return int(payload.get("count", 0) or 0)
|
|
||||||
|
|
||||||
|
|
||||||
def _wait_for_source_paragraph_count(
|
|
||||||
client: TestClient,
|
|
||||||
source_name: str,
|
|
||||||
*,
|
|
||||||
min_count: int,
|
|
||||||
timeout_seconds: float = 30.0,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
deadline = monotonic() + timeout_seconds
|
|
||||||
last_item: Dict[str, Any] = {}
|
|
||||||
while monotonic() < deadline:
|
|
||||||
item = _get_source_item(client, source_name)
|
|
||||||
count = _source_paragraph_count(item)
|
|
||||||
if count >= int(min_count):
|
|
||||||
return item or {}
|
|
||||||
if item:
|
|
||||||
last_item = dict(item)
|
|
||||||
sleep(0.2)
|
|
||||||
raise AssertionError(
|
|
||||||
f"等待来源段落计数超时: source={source_name}, min_count={min_count}, last_item={last_item}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_multitype_upload_task(client: TestClient) -> str:
|
|
||||||
structured_json = {
|
|
||||||
"paragraphs": [
|
|
||||||
{
|
|
||||||
"content": "Alice 携带地图前往火星港。",
|
|
||||||
"source": "integration-upload-json",
|
|
||||||
"entities": ["Alice", "地图", "火星港"],
|
|
||||||
"relations": [
|
|
||||||
{"subject": "Alice", "predicate": "携带", "object": "地图"},
|
|
||||||
{"subject": "Alice", "predicate": "前往", "object": "火星港"},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
extra_json = {
|
|
||||||
"paragraphs": [
|
|
||||||
{
|
|
||||||
"content": "Carol 记录了一条补充说明。",
|
|
||||||
"source": "integration-upload-json-extra",
|
|
||||||
"entities": ["Carol"],
|
|
||||||
"relations": [],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
payload_json = json.dumps(
|
|
||||||
{
|
|
||||||
"input_mode": "text",
|
|
||||||
"llm_enabled": False,
|
|
||||||
"file_concurrency": 2,
|
|
||||||
"chunk_concurrency": 2,
|
|
||||||
"dedupe_policy": "none",
|
|
||||||
},
|
|
||||||
ensure_ascii=False,
|
|
||||||
)
|
|
||||||
files = [
|
|
||||||
("files", ("integration-notes.txt", "Alice 在测试环境记录了一条长期记忆。".encode("utf-8"), "text/plain")),
|
|
||||||
("files", ("integration-diary.md", "# 日志\nBob 与 Alice 讨论了导图。".encode("utf-8"), "text/markdown")),
|
|
||||||
("files", ("integration-structured.json", json.dumps(structured_json, ensure_ascii=False).encode("utf-8"), "application/json")),
|
|
||||||
("files", ("integration-extra.json", json.dumps(extra_json, ensure_ascii=False).encode("utf-8"), "application/json")),
|
|
||||||
]
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/webui/memory/import/upload",
|
|
||||||
data={"payload_json": payload_json},
|
|
||||||
files=files,
|
|
||||||
)
|
|
||||||
payload = _assert_response_ok(response)
|
|
||||||
task_id = str((payload.get("task") or {}).get("task_id") or "").strip()
|
|
||||||
assert task_id, payload
|
|
||||||
return task_id
|
|
||||||
|
|
||||||
|
|
||||||
def _create_seed_paste_task(client: TestClient, *, source: str, unique_token: str) -> str:
|
|
||||||
seed_payload = {
|
|
||||||
"paragraphs": [
|
|
||||||
{
|
|
||||||
"content": f"Alice 在火星港携带地图并记录了口令 {unique_token}。",
|
|
||||||
"source": source,
|
|
||||||
"entities": ["Alice", "火星港", "地图"],
|
|
||||||
"relations": [
|
|
||||||
{"subject": "Alice", "predicate": "前往", "object": "火星港"},
|
|
||||||
{"subject": "Alice", "predicate": "携带", "object": "地图"},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"content": f"Bob 在火星港遇见 Alice,并重复口令 {unique_token}。",
|
|
||||||
"source": source,
|
|
||||||
"entities": ["Bob", "Alice", "火星港"],
|
|
||||||
"relations": [
|
|
||||||
{"subject": "Bob", "predicate": "遇见", "object": "Alice"},
|
|
||||||
{"subject": "Bob", "predicate": "位于", "object": "火星港"},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
response = client.post(
|
|
||||||
"/api/webui/memory/import/paste",
|
|
||||||
json={
|
|
||||||
"name": "integration-seed.json",
|
|
||||||
"input_mode": "json",
|
|
||||||
"llm_enabled": False,
|
|
||||||
"content": json.dumps(seed_payload, ensure_ascii=False),
|
|
||||||
"dedupe_policy": "none",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
payload = _assert_response_ok(response)
|
|
||||||
task_id = str((payload.get("task") or {}).get("task_id") or "").strip()
|
|
||||||
assert task_id, payload
|
|
||||||
return task_id
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def integration_state(tmp_path_factory: pytest.TempPathFactory) -> Generator[Dict[str, Any], None, None]:
|
|
||||||
tmp_root = tmp_path_factory.mktemp("memory_routes_integration")
|
|
||||||
data_dir = (tmp_root / "data").resolve()
|
|
||||||
staging_dir = (tmp_root / "upload_staging").resolve()
|
|
||||||
artifacts_dir = (tmp_root / "artifacts").resolve()
|
|
||||||
config_file = (tmp_root / "config" / "bot_config.toml").resolve()
|
|
||||||
runtime_config = _build_test_config(data_dir)
|
|
||||||
|
|
||||||
patches = pytest.MonkeyPatch()
|
|
||||||
patches.setattr(host_service_module.a_memorix_host_service, "_read_config", lambda: dict(runtime_config))
|
|
||||||
patches.setattr(host_service_module.a_memorix_host_service, "get_config_path", lambda: config_file)
|
|
||||||
patches.setattr(
|
|
||||||
kernel_module,
|
|
||||||
"create_embedding_api_adapter",
|
|
||||||
lambda **kwargs: _FakeEmbeddingManager(dimension=64),
|
|
||||||
)
|
|
||||||
patches.setattr(memory_router_module, "STAGING_ROOT", staging_dir)
|
|
||||||
patches.setattr(tuning_manager_module, "artifacts_root", lambda: artifacts_dir)
|
|
||||||
|
|
||||||
asyncio.run(host_service_module.a_memorix_host_service.stop())
|
|
||||||
host_service_module.a_memorix_host_service._config_cache = None # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
app.dependency_overrides[require_auth] = lambda: "ok"
|
|
||||||
app.include_router(memory_router_module.router, prefix="/api/webui")
|
|
||||||
app.include_router(memory_router_module.compat_router)
|
|
||||||
|
|
||||||
unique_token = f"INTEG_TOKEN_{uuid4().hex[:12]}"
|
|
||||||
source_name = f"integration-source-{uuid4().hex[:8]}"
|
|
||||||
|
|
||||||
with TestClient(app) as client:
|
|
||||||
upload_task_id = _create_multitype_upload_task(client)
|
|
||||||
upload_task = _wait_for_import_task_terminal(client, upload_task_id)
|
|
||||||
|
|
||||||
seed_task_id = _create_seed_paste_task(client, source=source_name, unique_token=unique_token)
|
|
||||||
seed_task = _wait_for_import_task_terminal(client, seed_task_id)
|
|
||||||
assert str(seed_task.get("status", "") or "") in {"completed", "completed_with_errors"}, seed_task
|
|
||||||
|
|
||||||
_wait_for_query_hit(client, unique_token, timeout_seconds=45.0)
|
|
||||||
|
|
||||||
yield {
|
|
||||||
"client": client,
|
|
||||||
"upload_task": upload_task,
|
|
||||||
"seed_task": seed_task,
|
|
||||||
"source_name": source_name,
|
|
||||||
"unique_token": unique_token,
|
|
||||||
}
|
|
||||||
|
|
||||||
asyncio.run(host_service_module.a_memorix_host_service.stop())
|
|
||||||
host_service_module.a_memorix_host_service._config_cache = None # type: ignore[attr-defined]
|
|
||||||
patches.undo()
|
|
||||||
|
|
||||||
|
|
||||||
def test_import_module_end_to_end_supports_multitype_upload(integration_state: Dict[str, Any]) -> None:
|
|
||||||
upload_task = integration_state["upload_task"]
|
|
||||||
|
|
||||||
assert str(upload_task.get("status", "") or "") in {"completed", "completed_with_errors"}, upload_task
|
|
||||||
files = upload_task.get("files") or []
|
|
||||||
assert isinstance(files, list)
|
|
||||||
assert len(files) >= 4
|
|
||||||
|
|
||||||
file_names = {str(item.get("name", "") or "") for item in files if isinstance(item, dict)}
|
|
||||||
assert "integration-notes.txt" in file_names
|
|
||||||
assert "integration-diary.md" in file_names
|
|
||||||
assert "integration-structured.json" in file_names
|
|
||||||
assert "integration-extra.json" in file_names
|
|
||||||
|
|
||||||
|
|
||||||
def test_retrieval_module_end_to_end_queries_seeded_data(integration_state: Dict[str, Any]) -> None:
|
|
||||||
client = integration_state["client"]
|
|
||||||
unique_token = integration_state["unique_token"]
|
|
||||||
|
|
||||||
aggregate_payload = _wait_for_query_hit(client, unique_token, timeout_seconds=45.0)
|
|
||||||
hits = aggregate_payload.get("hits") or []
|
|
||||||
joined_content = "\n".join(str(item.get("content", "") or "") for item in hits if isinstance(item, dict))
|
|
||||||
assert unique_token in joined_content
|
|
||||||
|
|
||||||
graph_payload = _assert_response_ok(
|
|
||||||
client.get(
|
|
||||||
"/api/webui/memory/graph/search",
|
|
||||||
params={"query": "Alice", "limit": 20},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
graph_items = graph_payload.get("items") or []
|
|
||||||
assert isinstance(graph_items, list)
|
|
||||||
assert any(str(item.get("type", "") or "") == "entity" for item in graph_items if isinstance(item, dict)), graph_items
|
|
||||||
|
|
||||||
|
|
||||||
def test_tuning_module_end_to_end_create_and_apply_best(integration_state: Dict[str, Any]) -> None:
|
|
||||||
client = integration_state["client"]
|
|
||||||
|
|
||||||
create_payload = _assert_response_ok(
|
|
||||||
client.post(
|
|
||||||
"/api/webui/memory/retrieval_tuning/tasks",
|
|
||||||
json={
|
|
||||||
"objective": "balanced",
|
|
||||||
"intensity": "quick",
|
|
||||||
"rounds": 2,
|
|
||||||
"sample_size": 4,
|
|
||||||
"top_k_eval": 5,
|
|
||||||
"llm_enabled": False,
|
|
||||||
"eval_query_timeout_seconds": 1.0,
|
|
||||||
"seed": 20260403,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
task_id = str((create_payload.get("task") or {}).get("task_id") or "").strip()
|
|
||||||
assert task_id, create_payload
|
|
||||||
|
|
||||||
task = _wait_for_tuning_task_terminal(client, task_id)
|
|
||||||
assert str(task.get("status", "") or "") == "completed", task
|
|
||||||
|
|
||||||
apply_payload = _assert_response_ok(
|
|
||||||
client.post(
|
|
||||||
f"/api/webui/memory/retrieval_tuning/tasks/{task_id}/apply-best",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
assert "applied" in apply_payload
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_module_end_to_end_preview_execute_restore(integration_state: Dict[str, Any]) -> None:
|
|
||||||
client = integration_state["client"]
|
|
||||||
unique_token = integration_state["unique_token"]
|
|
||||||
source_name = integration_state["source_name"]
|
|
||||||
|
|
||||||
before_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0)
|
|
||||||
assert _source_paragraph_count(before_source_item) >= 1
|
|
||||||
|
|
||||||
preview_payload = _assert_response_ok(
|
|
||||||
client.post(
|
|
||||||
"/api/webui/memory/delete/preview",
|
|
||||||
json={
|
|
||||||
"mode": "source",
|
|
||||||
"selector": {"sources": [source_name]},
|
|
||||||
"reason": "integration_delete_preview",
|
|
||||||
"requested_by": "pytest_integration",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
preview_counts = preview_payload.get("counts") or {}
|
|
||||||
assert int(preview_counts.get("paragraphs", 0) or 0) >= 1, preview_payload
|
|
||||||
|
|
||||||
execute_payload = _assert_response_ok(
|
|
||||||
client.post(
|
|
||||||
"/api/webui/memory/delete/execute",
|
|
||||||
json={
|
|
||||||
"mode": "source",
|
|
||||||
"selector": {"sources": [source_name]},
|
|
||||||
"reason": "integration_delete_execute",
|
|
||||||
"requested_by": "pytest_integration",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
operation_id = str(execute_payload.get("operation_id", "") or "").strip()
|
|
||||||
assert operation_id, execute_payload
|
|
||||||
|
|
||||||
after_delete_payload = _assert_response_ok(
|
|
||||||
client.get(
|
|
||||||
"/api/webui/memory/query/aggregate",
|
|
||||||
params={"query": unique_token, "limit": 20},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
after_delete_hits = after_delete_payload.get("hits") or []
|
|
||||||
after_delete_text = "\n".join(
|
|
||||||
str(item.get("content", "") or "")
|
|
||||||
for item in after_delete_hits
|
|
||||||
if isinstance(item, dict)
|
|
||||||
)
|
|
||||||
assert unique_token not in after_delete_text
|
|
||||||
assert int(execute_payload.get("deleted_paragraph_count", 0) or 0) >= 1, execute_payload
|
|
||||||
|
|
||||||
_assert_response_ok(
|
|
||||||
client.post(
|
|
||||||
"/api/webui/memory/delete/restore",
|
|
||||||
json={
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"requested_by": "pytest_integration",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
restored_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0)
|
|
||||||
assert _source_paragraph_count(restored_source_item) >= 1
|
|
||||||
|
|
||||||
operations_payload = _assert_response_ok(
|
|
||||||
client.get(
|
|
||||||
"/api/webui/memory/delete/operations",
|
|
||||||
params={"limit": 20, "mode": "source"},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
operation_items = operations_payload.get("items") or []
|
|
||||||
operation_ids = {
|
|
||||||
str(item.get("operation_id", "") or "")
|
|
||||||
for item in operation_items
|
|
||||||
if isinstance(item, dict)
|
|
||||||
}
|
|
||||||
assert operation_id in operation_ids
|
|
||||||
|
|
||||||
operation_detail_payload = _assert_response_ok(client.get(f"/api/webui/memory/delete/operations/{operation_id}"))
|
|
||||||
detail_operation = operation_detail_payload.get("operation") or {}
|
|
||||||
assert str(detail_operation.get("status", "") or "") == "restored"
|
|
||||||
@@ -1,187 +0,0 @@
|
|||||||
"""模型路由测试
|
|
||||||
|
|
||||||
验证 Gemini 提供商连接测试会使用查询参数传递 API Key,
|
|
||||||
并且不会回退到 OpenAI 兼容接口使用的 Bearer 认证方式。
|
|
||||||
"""
|
|
||||||
|
|
||||||
import importlib
|
|
||||||
import sys
|
|
||||||
from types import ModuleType
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_routes(monkeypatch: pytest.MonkeyPatch):
|
|
||||||
"""在导入路由前 stub 配置与认证依赖模块,避免测试时触发真实初始化。"""
|
|
||||||
config_module = ModuleType("src.config.config")
|
|
||||||
config_module.__dict__["CONFIG_DIR"] = "."
|
|
||||||
monkeypatch.setitem(sys.modules, "src.config.config", config_module)
|
|
||||||
|
|
||||||
dependencies_module = ModuleType("src.webui.dependencies")
|
|
||||||
|
|
||||||
async def require_auth():
|
|
||||||
return "test-token"
|
|
||||||
|
|
||||||
dependencies_module.__dict__["require_auth"] = require_auth
|
|
||||||
monkeypatch.setitem(sys.modules, "src.webui.dependencies", dependencies_module)
|
|
||||||
|
|
||||||
sys.modules.pop("src.webui.routers.model", None)
|
|
||||||
return importlib.import_module("src.webui.routers.model")
|
|
||||||
|
|
||||||
|
|
||||||
class FakeResponse:
|
|
||||||
"""简化版 HTTP 响应对象。"""
|
|
||||||
|
|
||||||
def __init__(self, status_code: int):
|
|
||||||
self.status_code = status_code
|
|
||||||
|
|
||||||
|
|
||||||
def build_async_client_factory(
|
|
||||||
responses: list[FakeResponse],
|
|
||||||
calls: list[dict[str, Any]],
|
|
||||||
):
|
|
||||||
"""构造一个可记录请求参数的 AsyncClient 替身。"""
|
|
||||||
|
|
||||||
response_iter = iter(responses)
|
|
||||||
|
|
||||||
class FakeAsyncClient:
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any):
|
|
||||||
self.args = args
|
|
||||||
self.kwargs = kwargs
|
|
||||||
|
|
||||||
async def __aenter__(self) -> "FakeAsyncClient":
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def get(
|
|
||||||
self,
|
|
||||||
url: str,
|
|
||||||
headers: dict[str, Any] | None = None,
|
|
||||||
params: dict[str, Any] | None = None,
|
|
||||||
) -> FakeResponse:
|
|
||||||
calls.append(
|
|
||||||
{
|
|
||||||
"url": url,
|
|
||||||
"headers": headers or {},
|
|
||||||
"params": params or {},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return next(response_iter)
|
|
||||||
|
|
||||||
return FakeAsyncClient
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_test_provider_connection_uses_query_api_key_for_gemini(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
"""Gemini 连接测试应通过查询参数传递 API Key。"""
|
|
||||||
model_routes = load_model_routes(monkeypatch)
|
|
||||||
calls: list[dict[str, Any]] = []
|
|
||||||
fake_client_class = build_async_client_factory(
|
|
||||||
responses=[FakeResponse(200), FakeResponse(200)],
|
|
||||||
calls=calls,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class)
|
|
||||||
|
|
||||||
result = await model_routes.test_provider_connection(
|
|
||||||
base_url="https://generativelanguage.googleapis.com/v1beta",
|
|
||||||
api_key="valid-gemini-key",
|
|
||||||
client_type="gemini",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result["network_ok"] is True
|
|
||||||
assert result["api_key_valid"] is True
|
|
||||||
assert len(calls) == 2
|
|
||||||
|
|
||||||
network_call = calls[0]
|
|
||||||
validation_call = calls[1]
|
|
||||||
|
|
||||||
assert network_call["url"] == "https://generativelanguage.googleapis.com/v1beta"
|
|
||||||
assert network_call["headers"] == {}
|
|
||||||
assert network_call["params"] == {}
|
|
||||||
|
|
||||||
assert validation_call["url"] == "https://generativelanguage.googleapis.com/v1beta/models"
|
|
||||||
assert validation_call["params"] == {"key": "valid-gemini-key"}
|
|
||||||
assert validation_call["headers"] == {"Content-Type": "application/json"}
|
|
||||||
assert "Authorization" not in validation_call["headers"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_test_provider_connection_uses_bearer_auth_for_openai_compatible(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
"""非 Gemini 提供商连接测试应继续使用 Bearer 认证。"""
|
|
||||||
model_routes = load_model_routes(monkeypatch)
|
|
||||||
calls: list[dict[str, Any]] = []
|
|
||||||
fake_client_class = build_async_client_factory(
|
|
||||||
responses=[FakeResponse(200), FakeResponse(200)],
|
|
||||||
calls=calls,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class)
|
|
||||||
|
|
||||||
result = await model_routes.test_provider_connection(
|
|
||||||
base_url="https://example.com/v1",
|
|
||||||
api_key="valid-openai-key",
|
|
||||||
client_type="openai",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result["network_ok"] is True
|
|
||||||
assert result["api_key_valid"] is True
|
|
||||||
assert len(calls) == 2
|
|
||||||
|
|
||||||
validation_call = calls[1]
|
|
||||||
|
|
||||||
assert validation_call["url"] == "https://example.com/v1/models"
|
|
||||||
assert validation_call["params"] == {}
|
|
||||||
assert validation_call["headers"]["Content-Type"] == "application/json"
|
|
||||||
assert validation_call["headers"]["Authorization"] == "Bearer valid-openai-key"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_test_provider_connection_by_name_forwards_provider_client_type(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
tmp_path,
|
|
||||||
) -> None:
|
|
||||||
"""按提供商名称测试连接时,应透传配置中的 client_type。"""
|
|
||||||
model_routes = load_model_routes(monkeypatch)
|
|
||||||
config_path = tmp_path / "model_config.toml"
|
|
||||||
config_path.write_text(
|
|
||||||
"""
|
|
||||||
[[api_providers]]
|
|
||||||
name = "Gemini"
|
|
||||||
base_url = "https://generativelanguage.googleapis.com/v1beta"
|
|
||||||
api_key = "valid-gemini-key"
|
|
||||||
client_type = "gemini"
|
|
||||||
""".strip(),
|
|
||||||
encoding="utf-8",
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(model_routes, "CONFIG_DIR", str(tmp_path))
|
|
||||||
|
|
||||||
captured_kwargs: dict[str, Any] = {}
|
|
||||||
|
|
||||||
async def fake_test_provider_connection(**kwargs: Any) -> dict[str, Any]:
|
|
||||||
captured_kwargs.update(kwargs)
|
|
||||||
return {
|
|
||||||
"network_ok": True,
|
|
||||||
"api_key_valid": True,
|
|
||||||
"latency_ms": 12.34,
|
|
||||||
"error": None,
|
|
||||||
"http_status": 200,
|
|
||||||
}
|
|
||||||
|
|
||||||
monkeypatch.setattr(model_routes, "test_provider_connection", fake_test_provider_connection)
|
|
||||||
|
|
||||||
result = await model_routes.test_provider_connection_by_name(provider_name="Gemini")
|
|
||||||
|
|
||||||
assert result["network_ok"] is True
|
|
||||||
assert result["api_key_valid"] is True
|
|
||||||
assert captured_kwargs == {
|
|
||||||
"base_url": "https://generativelanguage.googleapis.com/v1beta",
|
|
||||||
"api_key": "valid-gemini-key",
|
|
||||||
"client_type": "gemini",
|
|
||||||
}
|
|
||||||
@@ -1,136 +0,0 @@
|
|||||||
import json
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
|
|
||||||
from src.webui.routers.plugin import management as management_module
|
|
||||||
from src.webui.routers.plugin import support as support_module
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def client(tmp_path, monkeypatch) -> TestClient:
|
|
||||||
plugins_dir = tmp_path / "plugins"
|
|
||||||
plugins_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
demo_dir = plugins_dir / "demo_plugin"
|
|
||||||
demo_dir.mkdir()
|
|
||||||
(demo_dir / "_manifest.json").write_text(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"manifest_version": 2,
|
|
||||||
"id": "test.demo",
|
|
||||||
"name": "Demo Plugin",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"description": "demo plugin",
|
|
||||||
}
|
|
||||||
),
|
|
||||||
encoding="utf-8",
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr(management_module, "require_plugin_token", lambda _: "ok")
|
|
||||||
monkeypatch.setattr(support_module, "get_plugins_dir", lambda: plugins_dir)
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
app.include_router(management_module.router, prefix="/api/webui/plugins")
|
|
||||||
return TestClient(app)
|
|
||||||
|
|
||||||
|
|
||||||
def test_installed_plugins_only_scan_plugins_dir_and_exclude_a_memorix(client: TestClient):
|
|
||||||
response = client.get("/api/webui/plugins/installed")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
payload = response.json()
|
|
||||||
assert payload["success"] is True
|
|
||||||
|
|
||||||
ids = [plugin["id"] for plugin in payload["plugins"]]
|
|
||||||
assert ids == ["test.demo"]
|
|
||||||
assert "a-dawn.a-memorix" not in ids
|
|
||||||
assert all("/src/plugins/built_in/" not in plugin["path"] for plugin in payload["plugins"])
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_installed_plugin_path_falls_back_to_manifest_id(client: TestClient):
|
|
||||||
plugin_path = support_module.resolve_installed_plugin_path("test.demo")
|
|
||||||
|
|
||||||
assert plugin_path is not None
|
|
||||||
assert plugin_path.name == "demo_plugin"
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_installed_plugin_path_accepts_manifest_id_case_mismatch(client: TestClient):
|
|
||||||
plugin_path = support_module.resolve_installed_plugin_path("Test.Demo")
|
|
||||||
|
|
||||||
assert plugin_path is not None
|
|
||||||
assert plugin_path.name == "demo_plugin"
|
|
||||||
|
|
||||||
|
|
||||||
def test_install_plugin_preserves_manifest_declared_id(client: TestClient, monkeypatch):
|
|
||||||
class FakeGitMirrorService:
|
|
||||||
async def clone_repository(self, **kwargs):
|
|
||||||
target_path = kwargs["target_path"]
|
|
||||||
target_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
(target_path / "_manifest.json").write_text(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"manifest_version": 2,
|
|
||||||
"id": "author.declared",
|
|
||||||
"name": "Declared Plugin",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"author": {"name": "author"},
|
|
||||||
}
|
|
||||||
),
|
|
||||||
encoding="utf-8",
|
|
||||||
)
|
|
||||||
return {"success": True}
|
|
||||||
|
|
||||||
monkeypatch.setattr(management_module, "get_git_mirror_service", lambda: FakeGitMirrorService())
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/webui/plugins/install",
|
|
||||||
json={
|
|
||||||
"plugin_id": "market.plugin",
|
|
||||||
"repository_url": "https://github.com/author/declared",
|
|
||||||
"branch": "main",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
plugin_path = support_module.resolve_installed_plugin_path("author.declared")
|
|
||||||
assert plugin_path is not None
|
|
||||||
manifest = json.loads((plugin_path / "_manifest.json").read_text(encoding="utf-8"))
|
|
||||||
assert manifest["id"] == "author.declared"
|
|
||||||
|
|
||||||
|
|
||||||
def test_install_plugin_backfills_missing_manifest_id(client: TestClient, monkeypatch):
|
|
||||||
class FakeGitMirrorService:
|
|
||||||
async def clone_repository(self, **kwargs):
|
|
||||||
target_path = kwargs["target_path"]
|
|
||||||
target_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
(target_path / "_manifest.json").write_text(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"manifest_version": 2,
|
|
||||||
"name": "Legacy Plugin",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"author": {"name": "author"},
|
|
||||||
}
|
|
||||||
),
|
|
||||||
encoding="utf-8",
|
|
||||||
)
|
|
||||||
return {"success": True}
|
|
||||||
|
|
||||||
monkeypatch.setattr(management_module, "get_git_mirror_service", lambda: FakeGitMirrorService())
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/api/webui/plugins/install",
|
|
||||||
json={
|
|
||||||
"plugin_id": "market.legacy",
|
|
||||||
"repository_url": "https://github.com/author/legacy",
|
|
||||||
"branch": "main",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
plugin_path = support_module.resolve_installed_plugin_path("market.legacy")
|
|
||||||
assert plugin_path is not None
|
|
||||||
manifest = json.loads((plugin_path / "_manifest.json").read_text(encoding="utf-8"))
|
|
||||||
assert manifest["id"] == "market.legacy"
|
|
||||||
@@ -1,332 +0,0 @@
|
|||||||
from contextlib import contextmanager
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import Any, Iterator
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.services import statistics_service
|
|
||||||
from src.webui.schemas.statistics import DashboardData, StatisticsSummary, TimeSeriesData
|
|
||||||
|
|
||||||
|
|
||||||
class _Result:
|
|
||||||
def __init__(self, *, first_value: Any = None, all_values: list[Any] | None = None) -> None:
|
|
||||||
self._first_value = first_value
|
|
||||||
self._all_values = all_values or []
|
|
||||||
|
|
||||||
def first(self) -> Any:
|
|
||||||
return self._first_value
|
|
||||||
|
|
||||||
def all(self) -> list[Any]:
|
|
||||||
return self._all_values
|
|
||||||
|
|
||||||
|
|
||||||
class _Session:
|
|
||||||
def __init__(self, results: list[_Result]) -> None:
|
|
||||||
self._results = results
|
|
||||||
|
|
||||||
def exec(self, statement: Any) -> _Result:
|
|
||||||
del statement
|
|
||||||
return self._results.pop(0)
|
|
||||||
|
|
||||||
|
|
||||||
class _MemoryStore:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.store: dict[str, Any] = {}
|
|
||||||
|
|
||||||
def __getitem__(self, item: str) -> Any:
|
|
||||||
return self.store.get(item)
|
|
||||||
|
|
||||||
def __setitem__(self, key: str, value: Any) -> None:
|
|
||||||
self.store[key] = value
|
|
||||||
|
|
||||||
|
|
||||||
def _patch_session_results(monkeypatch: pytest.MonkeyPatch, results: list[_Result]) -> list[bool]:
|
|
||||||
auto_commit_calls: list[bool] = []
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_Session]:
|
|
||||||
auto_commit_calls.append(auto_commit)
|
|
||||||
yield _Session([results.pop(0)])
|
|
||||||
|
|
||||||
monkeypatch.setattr(statistics_service, "get_db_session", _fake_get_db_session)
|
|
||||||
return auto_commit_calls
|
|
||||||
|
|
||||||
|
|
||||||
def _patch_session_result_group(monkeypatch: pytest.MonkeyPatch, results: list[_Result]) -> list[bool]:
|
|
||||||
auto_commit_calls: list[bool] = []
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_Session]:
|
|
||||||
auto_commit_calls.append(auto_commit)
|
|
||||||
yield _Session(results)
|
|
||||||
|
|
||||||
monkeypatch.setattr(statistics_service, "get_db_session", _fake_get_db_session)
|
|
||||||
return auto_commit_calls
|
|
||||||
|
|
||||||
|
|
||||||
def _build_dashboard_data(total_requests: int = 1) -> DashboardData:
|
|
||||||
return DashboardData(
|
|
||||||
summary=StatisticsSummary(total_requests=total_requests),
|
|
||||||
model_stats=[],
|
|
||||||
hourly_data=[],
|
|
||||||
daily_data=[],
|
|
||||||
recent_activity=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_dashboard_data_with_time_series() -> DashboardData:
|
|
||||||
return DashboardData(
|
|
||||||
summary=StatisticsSummary(total_requests=1),
|
|
||||||
model_stats=[],
|
|
||||||
hourly_data=[
|
|
||||||
TimeSeriesData(timestamp="2026-05-06T10:00:00", requests=0, cost=0.0, tokens=0),
|
|
||||||
TimeSeriesData(timestamp="2026-05-06T11:00:00", requests=2, cost=0.5, tokens=50),
|
|
||||||
TimeSeriesData(timestamp="2026-05-06T12:00:00", requests=0, cost=0.0, tokens=0),
|
|
||||||
],
|
|
||||||
daily_data=[
|
|
||||||
TimeSeriesData(timestamp="2026-05-05T00:00:00", requests=0, cost=0.0, tokens=0),
|
|
||||||
TimeSeriesData(timestamp="2026-05-06T00:00:00", requests=3, cost=0.7, tokens=70),
|
|
||||||
],
|
|
||||||
recent_activity=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_shared_fetch_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
now = datetime(2026, 5, 6, 12, 0, 0)
|
|
||||||
online_record = SimpleNamespace(start_timestamp=now - timedelta(minutes=5), end_timestamp=now)
|
|
||||||
usage_record = SimpleNamespace(
|
|
||||||
timestamp=now,
|
|
||||||
request_type="chat.reply",
|
|
||||||
model_api_provider_name="provider",
|
|
||||||
model_assign_name="chat-main",
|
|
||||||
model_name="gpt-a",
|
|
||||||
prompt_tokens=10,
|
|
||||||
completion_tokens=5,
|
|
||||||
cost=0.01,
|
|
||||||
time_cost=1.2,
|
|
||||||
)
|
|
||||||
message_record = SimpleNamespace(timestamp=now, message_id="msg-1")
|
|
||||||
tool_record = SimpleNamespace(timestamp=now, tool_name="reply")
|
|
||||||
auto_commit_calls = _patch_session_results(
|
|
||||||
monkeypatch,
|
|
||||||
[
|
|
||||||
_Result(all_values=[online_record]),
|
|
||||||
_Result(all_values=[usage_record]),
|
|
||||||
_Result(all_values=[message_record]),
|
|
||||||
_Result(all_values=[tool_record]),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
online_ranges = statistics_service.fetch_online_time_since(now - timedelta(hours=1))
|
|
||||||
usage_records = statistics_service.fetch_model_usage_since(now - timedelta(hours=1))
|
|
||||||
messages = statistics_service.fetch_messages_since(now - timedelta(hours=1))
|
|
||||||
tool_records = statistics_service.fetch_tool_records_since(now - timedelta(hours=1))
|
|
||||||
|
|
||||||
assert online_ranges == [(online_record.start_timestamp, online_record.end_timestamp)]
|
|
||||||
assert usage_records == [
|
|
||||||
{
|
|
||||||
"timestamp": now,
|
|
||||||
"request_type": "chat.reply",
|
|
||||||
"model_api_provider_name": "provider",
|
|
||||||
"model_assign_name": "chat-main",
|
|
||||||
"model_name": "gpt-a",
|
|
||||||
"prompt_tokens": 10,
|
|
||||||
"completion_tokens": 5,
|
|
||||||
"cost": 0.01,
|
|
||||||
"time_cost": 1.2,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
assert messages == [message_record]
|
|
||||||
assert tool_records == [tool_record]
|
|
||||||
assert auto_commit_calls == [False, False, False, False]
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_earliest_statistics_time_uses_min_valid_timestamp(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
fallback_time = datetime(2026, 5, 6, 12, 0, 0)
|
|
||||||
earliest_time = datetime(2026, 5, 1, 8, 30, 0)
|
|
||||||
auto_commit_calls = _patch_session_result_group(
|
|
||||||
monkeypatch,
|
|
||||||
[
|
|
||||||
_Result(first_value=datetime(2026, 5, 3, 9, 0, 0)),
|
|
||||||
_Result(first_value=earliest_time),
|
|
||||||
_Result(first_value=None),
|
|
||||||
_Result(first_value=datetime(2026, 5, 2, 9, 0, 0)),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
result = statistics_service.get_earliest_statistics_time(fallback_time)
|
|
||||||
|
|
||||||
assert result == earliest_time
|
|
||||||
assert auto_commit_calls == [False]
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_earliest_statistics_time_falls_back_when_query_fails(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
fallback_time = datetime(2026, 5, 6, 12, 0, 0)
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_Session]:
|
|
||||||
del auto_commit
|
|
||||||
raise RuntimeError("database unavailable")
|
|
||||||
yield _Session([])
|
|
||||||
|
|
||||||
monkeypatch.setattr(statistics_service, "get_db_session", _fake_get_db_session)
|
|
||||||
|
|
||||||
assert statistics_service.get_earliest_statistics_time(fallback_time) == fallback_time
|
|
||||||
|
|
||||||
|
|
||||||
def test_dashboard_statistics_cache_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
memory_store = _MemoryStore()
|
|
||||||
now = datetime.now()
|
|
||||||
dashboard_data = _build_dashboard_data(total_requests=7)
|
|
||||||
monkeypatch.setattr(statistics_service, "local_storage", memory_store)
|
|
||||||
|
|
||||||
statistics_service.store_dashboard_statistics_cache({24: dashboard_data}, generated_at=now)
|
|
||||||
cached_data = statistics_service.get_cached_dashboard_statistics(24)
|
|
||||||
|
|
||||||
assert cached_data is not None
|
|
||||||
assert cached_data.summary.total_requests == 7
|
|
||||||
|
|
||||||
|
|
||||||
def test_dashboard_statistics_cache_stores_sparse_time_series(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
memory_store = _MemoryStore()
|
|
||||||
generated_at = datetime(2026, 5, 6, 12, 0, 0)
|
|
||||||
dashboard_data = _build_dashboard_data_with_time_series()
|
|
||||||
monkeypatch.setattr(statistics_service, "local_storage", memory_store)
|
|
||||||
|
|
||||||
statistics_service.store_dashboard_statistics_cache({2: dashboard_data}, generated_at=generated_at)
|
|
||||||
|
|
||||||
raw_cache = memory_store[statistics_service.DASHBOARD_STATISTICS_CACHE_KEY]
|
|
||||||
raw_entry = raw_cache["entries"]["2"]
|
|
||||||
assert raw_entry["sparse"] is True
|
|
||||||
assert raw_entry["hourly_data"] == [
|
|
||||||
{"timestamp": "2026-05-06T11:00:00", "requests": 2, "cost": 0.5, "tokens": 50}
|
|
||||||
]
|
|
||||||
assert raw_entry["daily_data"] == [
|
|
||||||
{"timestamp": "2026-05-06T00:00:00", "requests": 3, "cost": 0.7, "tokens": 70}
|
|
||||||
]
|
|
||||||
|
|
||||||
cached_data = statistics_service.get_cached_dashboard_statistics(2, max_age_seconds=10**9)
|
|
||||||
assert cached_data is not None
|
|
||||||
assert [item.timestamp for item in cached_data.hourly_data] == [
|
|
||||||
"2026-05-06T10:00:00",
|
|
||||||
"2026-05-06T11:00:00",
|
|
||||||
"2026-05-06T12:00:00",
|
|
||||||
]
|
|
||||||
assert cached_data.hourly_data[0].requests == 0
|
|
||||||
assert cached_data.hourly_data[1].requests == 2
|
|
||||||
assert cached_data.hourly_data[2].requests == 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_dashboard_statistics_prefers_cache(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
memory_store = _MemoryStore()
|
|
||||||
dashboard_data = _build_dashboard_data(total_requests=9)
|
|
||||||
monkeypatch.setattr(statistics_service, "local_storage", memory_store)
|
|
||||||
statistics_service.store_dashboard_statistics_cache({24: dashboard_data}, generated_at=datetime.now())
|
|
||||||
|
|
||||||
async def _fail_compute_dashboard_statistics(hours: int = 24) -> DashboardData:
|
|
||||||
del hours
|
|
||||||
raise AssertionError("cache should be used")
|
|
||||||
|
|
||||||
monkeypatch.setattr(statistics_service, "compute_dashboard_statistics", _fail_compute_dashboard_statistics)
|
|
||||||
|
|
||||||
result = await statistics_service.get_dashboard_statistics(24)
|
|
||||||
|
|
||||||
assert result.summary.total_requests == 9
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_dashboard_statistics_returns_empty_when_cache_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
memory_store = _MemoryStore()
|
|
||||||
monkeypatch.setattr(statistics_service, "local_storage", memory_store)
|
|
||||||
|
|
||||||
async def _fail_compute_dashboard_statistics(hours: int = 24) -> DashboardData:
|
|
||||||
del hours
|
|
||||||
raise AssertionError("dashboard API should not compute fallback data")
|
|
||||||
|
|
||||||
monkeypatch.setattr(statistics_service, "compute_dashboard_statistics", _fail_compute_dashboard_statistics)
|
|
||||||
|
|
||||||
result = await statistics_service.get_dashboard_statistics(24)
|
|
||||||
|
|
||||||
assert result.summary.total_requests == 0
|
|
||||||
assert result.model_stats == []
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_summary_statistics_aggregates_database_and_message_counts(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
start_time = datetime(2026, 5, 6, 10, 0, 0)
|
|
||||||
end_time = datetime(2026, 5, 6, 12, 0, 0)
|
|
||||||
online_records = [
|
|
||||||
SimpleNamespace(
|
|
||||||
start_timestamp=start_time - timedelta(minutes=30),
|
|
||||||
end_timestamp=start_time + timedelta(minutes=30),
|
|
||||||
),
|
|
||||||
SimpleNamespace(
|
|
||||||
start_timestamp=start_time + timedelta(hours=1),
|
|
||||||
end_timestamp=end_time + timedelta(minutes=30),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
auto_commit_calls = _patch_session_results(
|
|
||||||
monkeypatch,
|
|
||||||
[
|
|
||||||
_Result(first_value=(3, 1.5, 900, 2.5)),
|
|
||||||
_Result(all_values=online_records),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
def _fake_count_messages(**kwargs: Any) -> int:
|
|
||||||
return 5 if kwargs.get("has_reply_to") is None else 2
|
|
||||||
|
|
||||||
monkeypatch.setattr(statistics_service, "count_messages", _fake_count_messages)
|
|
||||||
|
|
||||||
summary = await statistics_service.get_summary_statistics(start_time, end_time)
|
|
||||||
|
|
||||||
assert summary.total_requests == 3
|
|
||||||
assert summary.total_cost == 1.5
|
|
||||||
assert summary.total_tokens == 900
|
|
||||||
assert summary.avg_response_time == 2.5
|
|
||||||
assert summary.online_time == 5400
|
|
||||||
assert summary.total_messages == 5
|
|
||||||
assert summary.total_replies == 2
|
|
||||||
assert summary.cost_per_hour == pytest.approx(1.0)
|
|
||||||
assert summary.tokens_per_hour == pytest.approx(600.0)
|
|
||||||
assert auto_commit_calls == [False, False]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_model_statistics_groups_by_display_model_name(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
now = datetime(2026, 5, 6, 12, 0, 0)
|
|
||||||
records = [
|
|
||||||
SimpleNamespace(
|
|
||||||
model_assign_name="chat-main",
|
|
||||||
model_name="gpt-a",
|
|
||||||
cost=0.4,
|
|
||||||
total_tokens=100,
|
|
||||||
time_cost=2.0,
|
|
||||||
),
|
|
||||||
SimpleNamespace(
|
|
||||||
model_assign_name="chat-main",
|
|
||||||
model_name="gpt-a",
|
|
||||||
cost=0.6,
|
|
||||||
total_tokens=200,
|
|
||||||
time_cost=4.0,
|
|
||||||
),
|
|
||||||
SimpleNamespace(
|
|
||||||
model_assign_name=None,
|
|
||||||
model_name="gpt-b",
|
|
||||||
cost=0.2,
|
|
||||||
total_tokens=50,
|
|
||||||
time_cost=0.0,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
_patch_session_results(monkeypatch, [_Result(all_values=records)])
|
|
||||||
|
|
||||||
stats = await statistics_service.get_model_statistics(now - timedelta(hours=24))
|
|
||||||
|
|
||||||
assert [item.model_name for item in stats] == ["chat-main", "gpt-b"]
|
|
||||||
assert stats[0].request_count == 2
|
|
||||||
assert stats[0].total_cost == pytest.approx(1.0)
|
|
||||||
assert stats[0].total_tokens == 300
|
|
||||||
assert stats[0].avg_response_time == pytest.approx(3.0)
|
|
||||||
assert stats[1].avg_response_time == 0.0
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
from src.webui.routers import system
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_newer_version_detects_patch_update() -> None:
|
|
||||||
assert system._is_newer_version("1.0.7", "1.0.6") is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_newer_version_ignores_same_version_with_shorter_parts() -> None:
|
|
||||||
assert system._is_newer_version("1.0.0", "1.0") is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_newer_version_handles_unknown_current_version() -> None:
|
|
||||||
assert system._is_newer_version("1.0.7", "unknown") is False
|
|
||||||
@@ -1,459 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
import importlib
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional, Dict, Any
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
# 强制使用 utf-8,避免控制台编码报错
|
|
||||||
try:
|
|
||||||
if hasattr(sys.stdout, "reconfigure"):
|
|
||||||
sys.stdout.reconfigure(encoding="utf-8")
|
|
||||||
if hasattr(sys.stderr, "reconfigure"):
|
|
||||||
sys.stderr.reconfigure(encoding="utf-8")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 确保能导入 src.*
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
|
||||||
|
|
||||||
from src.common.logger import initialize_logging, get_logger
|
|
||||||
from src.common.database.database import db
|
|
||||||
from src.common.database.database_model import LLMUsage
|
|
||||||
from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo
|
|
||||||
|
|
||||||
try:
|
|
||||||
from maim_message import ChatStream, UserInfo, GroupInfo
|
|
||||||
except Exception:
|
|
||||||
@dataclass
|
|
||||||
class ChatStream:
|
|
||||||
stream_id: str
|
|
||||||
platform: str
|
|
||||||
user_info: UserInfo
|
|
||||||
group_info: GroupInfo
|
|
||||||
|
|
||||||
logger = get_logger("test_memory_retrieval")
|
|
||||||
|
|
||||||
|
|
||||||
# 使用 importlib 动态导入,避免循环导入问题
|
|
||||||
def _import_memory_retrieval():
|
|
||||||
"""使用 importlib 动态导入 memory_retrieval 模块,避免循环导入"""
|
|
||||||
try:
|
|
||||||
# 先导入 prompt_builder,检查 prompt 是否已经初始化
|
|
||||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
|
||||||
|
|
||||||
# 检查 memory_retrieval 相关的 prompt 是否已经注册
|
|
||||||
# 如果已经注册,说明模块可能已经通过其他路径初始化过了
|
|
||||||
prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts
|
|
||||||
|
|
||||||
module_name = "src.memory_system.memory_retrieval"
|
|
||||||
|
|
||||||
# 如果 prompt 已经初始化,尝试直接使用已加载的模块
|
|
||||||
if prompt_already_init and module_name in sys.modules:
|
|
||||||
existing_module = sys.modules[module_name]
|
|
||||||
if hasattr(existing_module, "init_memory_retrieval_prompt"):
|
|
||||||
return (
|
|
||||||
existing_module.init_memory_retrieval_prompt,
|
|
||||||
existing_module._react_agent_solve_question,
|
|
||||||
existing_module._process_single_question,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 如果模块已经在 sys.modules 中但部分初始化,先移除它
|
|
||||||
if module_name in sys.modules:
|
|
||||||
existing_module = sys.modules[module_name]
|
|
||||||
if not hasattr(existing_module, "init_memory_retrieval_prompt"):
|
|
||||||
# 模块部分初始化,移除它
|
|
||||||
logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
|
|
||||||
del sys.modules[module_name]
|
|
||||||
# 清理可能相关的部分初始化模块
|
|
||||||
keys_to_remove = []
|
|
||||||
for key in sys.modules.keys():
|
|
||||||
if key.startswith("src.memory_system.") and key != "src.memory_system":
|
|
||||||
keys_to_remove.append(key)
|
|
||||||
for key in keys_to_remove:
|
|
||||||
try:
|
|
||||||
del sys.modules[key]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 在导入 memory_retrieval 之前,先确保所有可能触发循环导入的模块都已完全加载
|
|
||||||
# 这些模块在导入时可能会触发 memory_retrieval 的导入,所以我们需要先加载它们
|
|
||||||
try:
|
|
||||||
# 先导入可能触发循环导入的模块,让它们完成初始化
|
|
||||||
import src.config.config
|
|
||||||
import src.chat.utils.prompt_builder
|
|
||||||
|
|
||||||
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval)
|
|
||||||
# 如果它们已经导入,就确保它们完全初始化
|
|
||||||
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval)
|
|
||||||
# 如果它们已经导入,就确保它们完全初始化
|
|
||||||
try:
|
|
||||||
import src.chat.replyer.group_generator # noqa: F401
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
pass # 如果导入失败,继续
|
|
||||||
try:
|
|
||||||
import src.chat.replyer.private_generator # noqa: F401
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
pass # 如果导入失败,继续
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"预加载依赖模块时出现警告: {e}")
|
|
||||||
|
|
||||||
# 现在尝试导入 memory_retrieval
|
|
||||||
# 如果此时仍然触发循环导入,说明有其他模块在模块级别导入了 memory_retrieval
|
|
||||||
memory_retrieval_module = importlib.import_module(module_name)
|
|
||||||
|
|
||||||
return (
|
|
||||||
memory_retrieval_module.init_memory_retrieval_prompt,
|
|
||||||
memory_retrieval_module._react_agent_solve_question,
|
|
||||||
memory_retrieval_module._process_single_question,
|
|
||||||
)
|
|
||||||
except (ImportError, AttributeError) as e:
|
|
||||||
logger.error(f"导入 memory_retrieval 模块失败: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def create_test_chat_stream(chat_id: str = "test_memory_retrieval") -> ChatStream:
|
|
||||||
"""创建一个测试用的 ChatStream 对象"""
|
|
||||||
user_info = UserInfo(
|
|
||||||
platform="test",
|
|
||||||
user_id="test_user",
|
|
||||||
user_nickname="测试用户",
|
|
||||||
)
|
|
||||||
group_info = GroupInfo(
|
|
||||||
platform="test",
|
|
||||||
group_id="test_group",
|
|
||||||
group_name="测试群组",
|
|
||||||
)
|
|
||||||
return ChatStream(
|
|
||||||
stream_id=chat_id,
|
|
||||||
platform="test",
|
|
||||||
user_info=user_info,
|
|
||||||
group_info=group_info,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
|
|
||||||
"""获取从指定时间开始的token使用情况
|
|
||||||
|
|
||||||
Args:
|
|
||||||
start_time: 开始时间戳
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
包含token使用统计的字典
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
start_datetime = datetime.fromtimestamp(start_time)
|
|
||||||
|
|
||||||
# 查询从开始时间到现在的所有memory相关的token使用记录
|
|
||||||
records = (
|
|
||||||
LLMUsage.select()
|
|
||||||
.where(
|
|
||||||
(LLMUsage.timestamp >= start_datetime)
|
|
||||||
& (
|
|
||||||
(LLMUsage.request_type.like("%memory%"))
|
|
||||||
| (LLMUsage.request_type == "memory.question")
|
|
||||||
| (LLMUsage.request_type == "memory.react")
|
|
||||||
| (LLMUsage.request_type == "memory.react.final")
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.order_by(LLMUsage.timestamp.asc())
|
|
||||||
)
|
|
||||||
|
|
||||||
total_prompt_tokens = 0
|
|
||||||
total_completion_tokens = 0
|
|
||||||
total_tokens = 0
|
|
||||||
total_cost = 0.0
|
|
||||||
request_count = 0
|
|
||||||
model_usage = {} # 按模型统计
|
|
||||||
|
|
||||||
for record in records:
|
|
||||||
total_prompt_tokens += record.prompt_tokens or 0
|
|
||||||
total_completion_tokens += record.completion_tokens or 0
|
|
||||||
total_tokens += record.total_tokens or 0
|
|
||||||
total_cost += record.cost or 0.0
|
|
||||||
request_count += 1
|
|
||||||
|
|
||||||
# 按模型统计
|
|
||||||
model_name = record.model_name or "unknown"
|
|
||||||
if model_name not in model_usage:
|
|
||||||
model_usage[model_name] = {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": 0,
|
|
||||||
"cost": 0.0,
|
|
||||||
"request_count": 0,
|
|
||||||
}
|
|
||||||
model_usage[model_name]["prompt_tokens"] += record.prompt_tokens or 0
|
|
||||||
model_usage[model_name]["completion_tokens"] += record.completion_tokens or 0
|
|
||||||
model_usage[model_name]["total_tokens"] += record.total_tokens or 0
|
|
||||||
model_usage[model_name]["cost"] += record.cost or 0.0
|
|
||||||
model_usage[model_name]["request_count"] += 1
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_prompt_tokens": total_prompt_tokens,
|
|
||||||
"total_completion_tokens": total_completion_tokens,
|
|
||||||
"total_tokens": total_tokens,
|
|
||||||
"total_cost": total_cost,
|
|
||||||
"request_count": request_count,
|
|
||||||
"model_usage": model_usage,
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取token使用情况失败: {e}")
|
|
||||||
return {
|
|
||||||
"total_prompt_tokens": 0,
|
|
||||||
"total_completion_tokens": 0,
|
|
||||||
"total_tokens": 0,
|
|
||||||
"total_cost": 0.0,
|
|
||||||
"request_count": 0,
|
|
||||||
"model_usage": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def format_thinking_steps(thinking_steps: list) -> str:
|
|
||||||
"""格式化思考步骤为可读字符串"""
|
|
||||||
if not thinking_steps:
|
|
||||||
return "无思考步骤"
|
|
||||||
|
|
||||||
lines = []
|
|
||||||
for step in thinking_steps:
|
|
||||||
iteration = step.get("iteration", "?")
|
|
||||||
thought = step.get("thought", "")
|
|
||||||
actions = step.get("actions", [])
|
|
||||||
observations = step.get("observations", [])
|
|
||||||
|
|
||||||
lines.append(f"\n--- 迭代 {iteration} ---")
|
|
||||||
if thought:
|
|
||||||
lines.append(f"思考: {thought[:200]}...")
|
|
||||||
|
|
||||||
if actions:
|
|
||||||
lines.append("行动:")
|
|
||||||
for action in actions:
|
|
||||||
action_type = action.get("action_type", "unknown")
|
|
||||||
action_params = action.get("action_params", {})
|
|
||||||
lines.append(f" - {action_type}: {json.dumps(action_params, ensure_ascii=False)}")
|
|
||||||
|
|
||||||
if observations:
|
|
||||||
lines.append("观察:")
|
|
||||||
for obs in observations:
|
|
||||||
obs_str = str(obs)[:200]
|
|
||||||
if len(str(obs)) > 200:
|
|
||||||
obs_str += "..."
|
|
||||||
lines.append(f" - {obs_str}")
|
|
||||||
|
|
||||||
return "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_memory_retrieval(
|
|
||||||
question: str,
|
|
||||||
chat_id: str = "test_memory_retrieval",
|
|
||||||
context: str = "",
|
|
||||||
max_iterations: Optional[int] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""测试记忆检索功能
|
|
||||||
|
|
||||||
Args:
|
|
||||||
question: 要查询的问题
|
|
||||||
chat_id: 聊天ID
|
|
||||||
context: 上下文信息
|
|
||||||
max_iterations: 最大迭代次数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
包含测试结果的字典
|
|
||||||
"""
|
|
||||||
print("\n" + "=" * 80)
|
|
||||||
print("[测试] 记忆检索测试")
|
|
||||||
print(f"[问题] {question}")
|
|
||||||
print("=" * 80)
|
|
||||||
|
|
||||||
# 记录开始时间
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
# 延迟导入并初始化记忆检索prompt(这会自动加载 global_config)
|
|
||||||
# 注意:必须在函数内部调用,避免在模块级别触发循环导入
|
|
||||||
try:
|
|
||||||
init_memory_retrieval_prompt, _react_agent_solve_question, _ = _import_memory_retrieval()
|
|
||||||
|
|
||||||
# 检查 prompt 是否已经初始化,避免重复初始化
|
|
||||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
|
||||||
|
|
||||||
if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
|
|
||||||
init_memory_retrieval_prompt()
|
|
||||||
else:
|
|
||||||
logger.debug("记忆检索 prompt 已经初始化,跳过重复初始化")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"初始化记忆检索模块失败: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
# 获取 global_config(此时应该已经加载)
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
# 直接调用 _react_agent_solve_question 来获取详细的迭代信息
|
|
||||||
if max_iterations is None:
|
|
||||||
max_iterations = global_config.memory.max_agent_iterations
|
|
||||||
|
|
||||||
timeout = global_config.memory.agent_timeout_seconds
|
|
||||||
|
|
||||||
print("\n[配置]")
|
|
||||||
print(f" 最大迭代次数: {max_iterations}")
|
|
||||||
print(f" 超时时间: {timeout}秒")
|
|
||||||
print(f" 聊天ID: {chat_id}")
|
|
||||||
|
|
||||||
# 执行检索
|
|
||||||
print(f"\n[开始检索] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
|
|
||||||
|
|
||||||
found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
|
|
||||||
question=question,
|
|
||||||
chat_id=chat_id,
|
|
||||||
max_iterations=max_iterations,
|
|
||||||
timeout=timeout,
|
|
||||||
initial_info="",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 记录结束时间
|
|
||||||
end_time = time.time()
|
|
||||||
elapsed_time = end_time - start_time
|
|
||||||
|
|
||||||
# 获取token使用情况
|
|
||||||
token_usage = get_token_usage_since(start_time)
|
|
||||||
|
|
||||||
# 构建结果
|
|
||||||
result = {
|
|
||||||
"question": question,
|
|
||||||
"found_answer": found_answer,
|
|
||||||
"answer": answer,
|
|
||||||
"is_timeout": is_timeout,
|
|
||||||
"elapsed_time": elapsed_time,
|
|
||||||
"thinking_steps": thinking_steps,
|
|
||||||
"iteration_count": len(thinking_steps),
|
|
||||||
"token_usage": token_usage,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 输出结果
|
|
||||||
print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
|
|
||||||
print("\n[结果]")
|
|
||||||
print(f" 是否找到答案: {'是' if found_answer else '否'}")
|
|
||||||
if found_answer and answer:
|
|
||||||
print(f" 答案: {answer}")
|
|
||||||
else:
|
|
||||||
print(" 答案: (未找到答案)")
|
|
||||||
print(f" 是否超时: {'是' if is_timeout else '否'}")
|
|
||||||
print(f" 迭代次数: {len(thinking_steps)}")
|
|
||||||
print(f" 总耗时: {elapsed_time:.2f}秒")
|
|
||||||
|
|
||||||
print("\n[Token使用情况]")
|
|
||||||
print(f" 总请求数: {token_usage['request_count']}")
|
|
||||||
print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
|
|
||||||
print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}")
|
|
||||||
print(f" 总Tokens: {token_usage['total_tokens']:,}")
|
|
||||||
print(f" 总成本: ${token_usage['total_cost']:.6f}")
|
|
||||||
|
|
||||||
if token_usage["model_usage"]:
|
|
||||||
print("\n[按模型统计]")
|
|
||||||
for model_name, usage in token_usage["model_usage"].items():
|
|
||||||
print(f" {model_name}:")
|
|
||||||
print(f" 请求数: {usage['request_count']}")
|
|
||||||
print(f" Prompt Tokens: {usage['prompt_tokens']:,}")
|
|
||||||
print(f" Completion Tokens: {usage['completion_tokens']:,}")
|
|
||||||
print(f" 总Tokens: {usage['total_tokens']:,}")
|
|
||||||
print(f" 成本: ${usage['cost']:.6f}")
|
|
||||||
|
|
||||||
print("\n[迭代详情]")
|
|
||||||
print(format_thinking_steps(thinking_steps))
|
|
||||||
|
|
||||||
print("\n" + "=" * 80)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="测试记忆检索功能。可以输入一个问题,脚本会使用记忆检索的逻辑进行检索,并记录迭代信息、时间和token总消耗。"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--chat-id",
|
|
||||||
default="test_memory_retrieval",
|
|
||||||
help="测试用的聊天ID(默认: test_memory_retrieval)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--context",
|
|
||||||
default="",
|
|
||||||
help="上下文信息(可选)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output",
|
|
||||||
"-o",
|
|
||||||
help="将结果保存到JSON文件(可选)",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# 初始化日志(使用较低的详细程度,避免输出过多日志)
|
|
||||||
initialize_logging(verbose=False)
|
|
||||||
|
|
||||||
# 交互式输入问题
|
|
||||||
print("\n" + "=" * 80)
|
|
||||||
print("记忆检索测试工具")
|
|
||||||
print("=" * 80)
|
|
||||||
question = input("\n请输入要查询的问题: ").strip()
|
|
||||||
if not question:
|
|
||||||
print("错误: 问题不能为空")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 交互式输入最大迭代次数
|
|
||||||
max_iterations_input = input("\n请输入最大迭代次数(直接回车使用配置默认值): ").strip()
|
|
||||||
max_iterations = None
|
|
||||||
if max_iterations_input:
|
|
||||||
try:
|
|
||||||
max_iterations = int(max_iterations_input)
|
|
||||||
if max_iterations <= 0:
|
|
||||||
print("警告: 迭代次数必须大于0,将使用配置默认值")
|
|
||||||
max_iterations = None
|
|
||||||
except ValueError:
|
|
||||||
print("警告: 无效的迭代次数,将使用配置默认值")
|
|
||||||
max_iterations = None
|
|
||||||
|
|
||||||
# 连接数据库
|
|
||||||
try:
|
|
||||||
db.connect(reuse_if_open=True)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"数据库连接失败: {e}")
|
|
||||||
print(f"错误: 数据库连接失败: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 运行测试
|
|
||||||
try:
|
|
||||||
result = asyncio.run(
|
|
||||||
test_memory_retrieval(
|
|
||||||
question=question,
|
|
||||||
chat_id=args.chat_id,
|
|
||||||
context=args.context,
|
|
||||||
max_iterations=max_iterations,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 如果指定了输出文件,保存结果
|
|
||||||
if args.output:
|
|
||||||
# 将thinking_steps转换为可序列化的格式
|
|
||||||
output_result = result.copy()
|
|
||||||
with open(args.output, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(output_result, f, ensure_ascii=False, indent=2)
|
|
||||||
print(f"\n[结果已保存] {args.output}")
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\n\n[中断] 用户中断测试")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"测试失败: {e}", exc_info=True)
|
|
||||||
print(f"\n[错误] 测试失败: {e}")
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
db.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,845 +0,0 @@
|
|||||||
from argparse import ArgumentParser, Namespace
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from dataclasses import asdict, dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, Iterator, List, Sequence
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
|
||||||
if str(PROJECT_ROOT) not in sys.path:
|
|
||||||
sys.path.insert(0, str(PROJECT_ROOT))
|
|
||||||
|
|
||||||
from src.common.data_models.llm_service_data_models import LLMServiceRequest, LLMServiceResult # noqa: E402
|
|
||||||
from src.config.config import config_manager # noqa: E402
|
|
||||||
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig # noqa: E402
|
|
||||||
from src.llm_models.payload_content.tool_option import ToolCall # noqa: E402
|
|
||||||
from src.services.llm_service import generate # noqa: E402
|
|
||||||
from src.services.service_task_resolver import get_available_models # noqa: E402
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_SKIP_TASKS = {"embedding", "voice"}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
|
||||||
class ToolCallCase:
|
|
||||||
"""Tool call 参数测试用例。"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
tool_definition: Dict[str, Any]
|
|
||||||
expected_arguments: Dict[str, Any]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def tool_name(self) -> str:
|
|
||||||
"""返回工具名称。"""
|
|
||||||
if self.tool_definition.get("type") == "function":
|
|
||||||
function_definition = self.tool_definition.get("function", {})
|
|
||||||
return str(function_definition.get("name", "") or "")
|
|
||||||
return str(self.tool_definition.get("name", "") or "")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters_schema(self) -> Dict[str, Any]:
|
|
||||||
"""返回参数 Schema。"""
|
|
||||||
if self.tool_definition.get("type") == "function":
|
|
||||||
function_definition = self.tool_definition.get("function", {})
|
|
||||||
parameters = function_definition.get("parameters", {})
|
|
||||||
return parameters if isinstance(parameters, dict) else {}
|
|
||||||
parameters = self.tool_definition.get("parameters", {})
|
|
||||||
return parameters if isinstance(parameters, dict) else {}
|
|
||||||
|
|
||||||
def build_messages(self) -> List[Dict[str, Any]]:
|
|
||||||
"""构造测试消息。"""
|
|
||||||
expected_json = json.dumps(self.expected_arguments, ensure_ascii=False, indent=2)
|
|
||||||
system_prompt = (
|
|
||||||
"你正在执行严格的工具调用参数兼容性测试。"
|
|
||||||
"你必须通过工具调用响应,不能输出自然语言,不能解释,不能补充额外字段。"
|
|
||||||
)
|
|
||||||
user_prompt = (
|
|
||||||
f"请立刻调用工具 `{self.tool_name}`。\n"
|
|
||||||
"参数必须与下面 JSON 完全一致,键名、值、布尔类型、整数类型、浮点数、数组顺序和对象结构都不能改变。\n"
|
|
||||||
"不要输出任何解释文本,只返回工具调用。\n"
|
|
||||||
f"{expected_json}"
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
{"role": "system", "content": system_prompt},
|
|
||||||
{"role": "user", "content": user_prompt},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
|
||||||
class ProbeTarget:
|
|
||||||
"""单个待测试模型目标。"""
|
|
||||||
|
|
||||||
task_name: str
|
|
||||||
model_name: str
|
|
||||||
provider_name: str
|
|
||||||
client_type: str
|
|
||||||
tool_argument_parse_mode: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
|
||||||
class ProbeResult:
|
|
||||||
"""单次测试结果。"""
|
|
||||||
|
|
||||||
task_name: str
|
|
||||||
target_model_name: str
|
|
||||||
actual_model_name: str
|
|
||||||
provider_name: str
|
|
||||||
client_type: str
|
|
||||||
tool_argument_parse_mode: str
|
|
||||||
case_name: str
|
|
||||||
attempt: int
|
|
||||||
success: bool
|
|
||||||
elapsed_seconds: float
|
|
||||||
errors: List[str]
|
|
||||||
warnings: List[str]
|
|
||||||
response_text: str
|
|
||||||
reasoning_text: str
|
|
||||||
tool_calls: List[Dict[str, Any]]
|
|
||||||
|
|
||||||
|
|
||||||
def _ensure_utf8_console() -> None:
|
|
||||||
"""尽量将控制台编码切到 UTF-8。"""
|
|
||||||
try:
|
|
||||||
if hasattr(sys.stdout, "reconfigure"):
|
|
||||||
sys.stdout.reconfigure(encoding="utf-8")
|
|
||||||
if hasattr(sys.stderr, "reconfigure"):
|
|
||||||
sys.stderr.reconfigure(encoding="utf-8")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def _build_function_tool(name: str, description: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""构造 OpenAI 风格 function tool 定义。"""
|
|
||||||
return {
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": name,
|
|
||||||
"description": description,
|
|
||||||
"parameters": parameters,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _build_default_cases() -> List[ToolCallCase]:
|
|
||||||
"""构造默认测试用例。"""
|
|
||||||
simple_expected_arguments = {
|
|
||||||
"request_id": "probe-simple-001",
|
|
||||||
"count": 7,
|
|
||||||
"enabled": True,
|
|
||||||
"mode": "strict",
|
|
||||||
"ratio": 2.5,
|
|
||||||
}
|
|
||||||
simple_parameters = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"request_id": {"type": "string", "description": "请求 ID"},
|
|
||||||
"count": {"type": "integer", "description": "数量"},
|
|
||||||
"enabled": {"type": "boolean", "description": "是否启用"},
|
|
||||||
"mode": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "模式",
|
|
||||||
"enum": ["strict", "loose"],
|
|
||||||
},
|
|
||||||
"ratio": {"type": "number", "description": "比例"},
|
|
||||||
},
|
|
||||||
"required": ["request_id", "count", "enabled", "mode", "ratio"],
|
|
||||||
"additionalProperties": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
nested_expected_arguments = {
|
|
||||||
"request_id": "probe-nested-001",
|
|
||||||
"notify": False,
|
|
||||||
"profile": {
|
|
||||||
"channel": "stable",
|
|
||||||
"priority": 2,
|
|
||||||
},
|
|
||||||
"tags": ["alpha", "beta", "gamma"],
|
|
||||||
"items": [
|
|
||||||
{"count": 2, "name": "apple"},
|
|
||||||
{"count": 5, "name": "banana"},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
nested_parameters = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"request_id": {"type": "string", "description": "请求 ID"},
|
|
||||||
"notify": {"type": "boolean", "description": "是否通知"},
|
|
||||||
"profile": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "配置对象",
|
|
||||||
"properties": {
|
|
||||||
"channel": {"type": "string", "description": "渠道"},
|
|
||||||
"priority": {"type": "integer", "description": "优先级"},
|
|
||||||
},
|
|
||||||
"required": ["channel", "priority"],
|
|
||||||
"additionalProperties": False,
|
|
||||||
},
|
|
||||||
"tags": {
|
|
||||||
"type": "array",
|
|
||||||
"description": "标签列表",
|
|
||||||
"items": {"type": "string"},
|
|
||||||
},
|
|
||||||
"items": {
|
|
||||||
"type": "array",
|
|
||||||
"description": "条目列表",
|
|
||||||
"items": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"count": {"type": "integer", "description": "数量"},
|
|
||||||
"name": {"type": "string", "description": "名称"},
|
|
||||||
},
|
|
||||||
"required": ["count", "name"],
|
|
||||||
"additionalProperties": False,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["request_id", "notify", "profile", "tags", "items"],
|
|
||||||
"additionalProperties": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
return [
|
|
||||||
ToolCallCase(
|
|
||||||
name="simple",
|
|
||||||
description="标量参数类型校验",
|
|
||||||
tool_definition=_build_function_tool(
|
|
||||||
name="record_simple_probe",
|
|
||||||
description="记录简单参数探测结果",
|
|
||||||
parameters=simple_parameters,
|
|
||||||
),
|
|
||||||
expected_arguments=simple_expected_arguments,
|
|
||||||
),
|
|
||||||
ToolCallCase(
|
|
||||||
name="nested",
|
|
||||||
description="嵌套对象与数组参数校验",
|
|
||||||
tool_definition=_build_function_tool(
|
|
||||||
name="record_nested_probe",
|
|
||||||
description="记录嵌套参数探测结果",
|
|
||||||
parameters=nested_parameters,
|
|
||||||
),
|
|
||||||
expected_arguments=nested_expected_arguments,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_multi_value_args(raw_values: Sequence[str] | None) -> List[str]:
|
|
||||||
"""解析命令行中的多值参数。"""
|
|
||||||
parsed_values: List[str] = []
|
|
||||||
for raw_value in raw_values or []:
|
|
||||||
for item in str(raw_value).split(","):
|
|
||||||
normalized_item = item.strip()
|
|
||||||
if normalized_item:
|
|
||||||
parsed_values.append(normalized_item)
|
|
||||||
return parsed_values
|
|
||||||
|
|
||||||
|
|
||||||
def _build_model_map() -> Dict[str, ModelInfo]:
|
|
||||||
"""构造模型名称到模型配置的映射。"""
|
|
||||||
return {model.name: model for model in config_manager.get_model_config().models}
|
|
||||||
|
|
||||||
|
|
||||||
def _build_provider_map() -> Dict[str, APIProvider]:
|
|
||||||
"""构造 Provider 名称到配置的映射。"""
|
|
||||||
return {provider.name: provider for provider in config_manager.get_model_config().api_providers}
|
|
||||||
|
|
||||||
|
|
||||||
def _pick_default_task_name(task_names: Sequence[str]) -> str:
|
|
||||||
"""选择默认任务名。"""
|
|
||||||
if "utils" in task_names:
|
|
||||||
return "utils"
|
|
||||||
if not task_names:
|
|
||||||
raise ValueError("当前没有可用的任务配置")
|
|
||||||
return str(task_names[0])
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_targets(task_filters: Sequence[str], model_filters: Sequence[str], fallback_task: str) -> List[ProbeTarget]:
|
|
||||||
"""根据命令行参数解析待测试目标。"""
|
|
||||||
available_tasks = get_available_models()
|
|
||||||
model_map = _build_model_map()
|
|
||||||
provider_map = _build_provider_map()
|
|
||||||
|
|
||||||
if not available_tasks:
|
|
||||||
raise ValueError("未找到任何可用的模型任务配置")
|
|
||||||
|
|
||||||
if task_filters:
|
|
||||||
selected_task_names = []
|
|
||||||
for task_name in task_filters:
|
|
||||||
if task_name not in available_tasks:
|
|
||||||
raise ValueError(f"未找到任务 `{task_name}`")
|
|
||||||
selected_task_names.append(task_name)
|
|
||||||
else:
|
|
||||||
selected_task_names = [
|
|
||||||
task_name
|
|
||||||
for task_name in available_tasks
|
|
||||||
if task_name not in DEFAULT_SKIP_TASKS
|
|
||||||
]
|
|
||||||
|
|
||||||
if not selected_task_names:
|
|
||||||
raise ValueError("没有可用于 tool call 测试的任务,请显式通过 --task 指定")
|
|
||||||
|
|
||||||
default_task_name = fallback_task if fallback_task in available_tasks else _pick_default_task_name(selected_task_names)
|
|
||||||
resolved_targets: List[ProbeTarget] = []
|
|
||||||
seen_models: set[str] = set()
|
|
||||||
|
|
||||||
if model_filters:
|
|
||||||
model_names = list(model_filters)
|
|
||||||
else:
|
|
||||||
model_names = []
|
|
||||||
for task_name in selected_task_names:
|
|
||||||
task_config = available_tasks[task_name]
|
|
||||||
for model_name in task_config.model_list:
|
|
||||||
if model_name not in model_names:
|
|
||||||
model_names.append(model_name)
|
|
||||||
|
|
||||||
for model_name in model_names:
|
|
||||||
if model_name in seen_models:
|
|
||||||
continue
|
|
||||||
if model_name not in model_map:
|
|
||||||
raise ValueError(f"未找到模型 `{model_name}`")
|
|
||||||
|
|
||||||
target_task_name = ""
|
|
||||||
for task_name in selected_task_names:
|
|
||||||
if model_name in available_tasks[task_name].model_list:
|
|
||||||
target_task_name = task_name
|
|
||||||
break
|
|
||||||
if not target_task_name:
|
|
||||||
target_task_name = default_task_name
|
|
||||||
|
|
||||||
model_info = model_map[model_name]
|
|
||||||
provider_info = provider_map[model_info.api_provider]
|
|
||||||
resolved_targets.append(
|
|
||||||
ProbeTarget(
|
|
||||||
task_name=target_task_name,
|
|
||||||
model_name=model_name,
|
|
||||||
provider_name=provider_info.name,
|
|
||||||
client_type=provider_info.client_type,
|
|
||||||
tool_argument_parse_mode=provider_info.tool_argument_parse_mode,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
seen_models.add(model_name)
|
|
||||||
|
|
||||||
return resolved_targets
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _pin_task_to_model(task_name: str, model_name: str) -> Iterator[None]:
|
|
||||||
"""临时将某个任务锁定到单模型。"""
|
|
||||||
model_task_config = config_manager.get_model_config().model_task_config
|
|
||||||
task_config = getattr(model_task_config, task_name, None)
|
|
||||||
if not isinstance(task_config, TaskConfig):
|
|
||||||
raise ValueError(f"未找到任务 `{task_name}` 对应的配置")
|
|
||||||
|
|
||||||
original_model_list = list(task_config.model_list)
|
|
||||||
original_selection_strategy = task_config.selection_strategy
|
|
||||||
task_config.model_list = [model_name]
|
|
||||||
task_config.selection_strategy = "balance"
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
task_config.model_list = original_model_list
|
|
||||||
task_config.selection_strategy = original_selection_strategy
|
|
||||||
|
|
||||||
|
|
||||||
def _serialize_tool_calls(tool_calls: List[ToolCall] | None) -> List[Dict[str, Any]]:
|
|
||||||
"""序列化工具调用结果。"""
|
|
||||||
if not tool_calls:
|
|
||||||
return []
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"id": tool_call.call_id,
|
|
||||||
"function": {
|
|
||||||
"name": tool_call.func_name,
|
|
||||||
"arguments": dict(tool_call.args or {}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for tool_call in tool_calls
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _is_integer_value(value: Any) -> bool:
|
|
||||||
"""判断是否为整数类型且排除布尔值。"""
|
|
||||||
return isinstance(value, int) and not isinstance(value, bool)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_number_value(value: Any) -> bool:
|
|
||||||
"""判断是否为数值类型且排除布尔值。"""
|
|
||||||
return (isinstance(value, int) or isinstance(value, float)) and not isinstance(value, bool)
|
|
||||||
|
|
||||||
|
|
||||||
def _schema_type(schema: Dict[str, Any]) -> str:
|
|
||||||
"""解析 Schema 的类型。"""
|
|
||||||
schema_type = str(schema.get("type", "") or "").strip()
|
|
||||||
if schema_type:
|
|
||||||
return schema_type
|
|
||||||
if "properties" in schema or "required" in schema:
|
|
||||||
return "object"
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_schema(schema: Dict[str, Any], actual_value: Any, path: str = "args") -> List[str]:
|
|
||||||
"""按简化 JSON Schema 校验工具参数。"""
|
|
||||||
errors: List[str] = []
|
|
||||||
schema_type = _schema_type(schema)
|
|
||||||
|
|
||||||
if "enum" in schema and actual_value not in schema["enum"]:
|
|
||||||
errors.append(f"{path} 枚举值不合法,期望属于 {schema['enum']},实际为 {actual_value!r}")
|
|
||||||
|
|
||||||
if schema_type == "string":
|
|
||||||
if not isinstance(actual_value, str):
|
|
||||||
errors.append(f"{path} 类型错误,期望 string,实际为 {type(actual_value).__name__}")
|
|
||||||
return errors
|
|
||||||
|
|
||||||
if schema_type == "integer":
|
|
||||||
if not _is_integer_value(actual_value):
|
|
||||||
errors.append(f"{path} 类型错误,期望 integer,实际为 {type(actual_value).__name__}")
|
|
||||||
return errors
|
|
||||||
|
|
||||||
if schema_type == "number":
|
|
||||||
if not _is_number_value(actual_value):
|
|
||||||
errors.append(f"{path} 类型错误,期望 number,实际为 {type(actual_value).__name__}")
|
|
||||||
return errors
|
|
||||||
|
|
||||||
if schema_type == "boolean":
|
|
||||||
if not isinstance(actual_value, bool):
|
|
||||||
errors.append(f"{path} 类型错误,期望 boolean,实际为 {type(actual_value).__name__}")
|
|
||||||
return errors
|
|
||||||
|
|
||||||
if schema_type == "array":
|
|
||||||
if not isinstance(actual_value, list):
|
|
||||||
errors.append(f"{path} 类型错误,期望 array,实际为 {type(actual_value).__name__}")
|
|
||||||
return errors
|
|
||||||
item_schema = schema.get("items")
|
|
||||||
if isinstance(item_schema, dict):
|
|
||||||
for index, item in enumerate(actual_value):
|
|
||||||
errors.extend(_validate_schema(item_schema, item, f"{path}[{index}]"))
|
|
||||||
return errors
|
|
||||||
|
|
||||||
if schema_type == "object":
|
|
||||||
if not isinstance(actual_value, dict):
|
|
||||||
errors.append(f"{path} 类型错误,期望 object,实际为 {type(actual_value).__name__}")
|
|
||||||
return errors
|
|
||||||
|
|
||||||
properties = schema.get("properties", {})
|
|
||||||
required_fields = [str(item) for item in schema.get("required", [])]
|
|
||||||
for required_field in required_fields:
|
|
||||||
if required_field not in actual_value:
|
|
||||||
errors.append(f"{path}.{required_field} 缺少必填字段")
|
|
||||||
|
|
||||||
for field_name, field_value in actual_value.items():
|
|
||||||
field_path = f"{path}.{field_name}"
|
|
||||||
field_schema = properties.get(field_name)
|
|
||||||
if isinstance(field_schema, dict):
|
|
||||||
errors.extend(_validate_schema(field_schema, field_value, field_path))
|
|
||||||
continue
|
|
||||||
|
|
||||||
additional_properties = schema.get("additionalProperties", True)
|
|
||||||
if additional_properties is False:
|
|
||||||
errors.append(f"{field_path} 是未定义字段")
|
|
||||||
elif isinstance(additional_properties, dict):
|
|
||||||
errors.extend(_validate_schema(additional_properties, field_value, field_path))
|
|
||||||
return errors
|
|
||||||
|
|
||||||
return errors
|
|
||||||
|
|
||||||
|
|
||||||
def _compare_expected_values(expected_value: Any, actual_value: Any, path: str = "args") -> List[str]:
|
|
||||||
"""递归比较实际值与期望值是否完全一致。"""
|
|
||||||
errors: List[str] = []
|
|
||||||
|
|
||||||
if isinstance(expected_value, dict):
|
|
||||||
if not isinstance(actual_value, dict):
|
|
||||||
return [f"{path} 值不一致,期望 object,实际为 {type(actual_value).__name__}"]
|
|
||||||
|
|
||||||
expected_keys = set(expected_value.keys())
|
|
||||||
actual_keys = set(actual_value.keys())
|
|
||||||
for missing_key in sorted(expected_keys - actual_keys):
|
|
||||||
errors.append(f"{path}.{missing_key} 缺少期望字段")
|
|
||||||
for extra_key in sorted(actual_keys - expected_keys):
|
|
||||||
errors.append(f"{path}.{extra_key} 出现了额外字段")
|
|
||||||
for shared_key in sorted(expected_keys & actual_keys):
|
|
||||||
errors.extend(
|
|
||||||
_compare_expected_values(
|
|
||||||
expected_value[shared_key],
|
|
||||||
actual_value[shared_key],
|
|
||||||
f"{path}.{shared_key}",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return errors
|
|
||||||
|
|
||||||
if isinstance(expected_value, list):
|
|
||||||
if not isinstance(actual_value, list):
|
|
||||||
return [f"{path} 值不一致,期望 array,实际为 {type(actual_value).__name__}"]
|
|
||||||
|
|
||||||
if len(expected_value) != len(actual_value):
|
|
||||||
errors.append(f"{path} 列表长度不一致,期望 {len(expected_value)},实际 {len(actual_value)}")
|
|
||||||
for index, (expected_item, actual_item) in enumerate(
|
|
||||||
zip(expected_value, actual_value, strict=False)
|
|
||||||
):
|
|
||||||
errors.extend(_compare_expected_values(expected_item, actual_item, f"{path}[{index}]"))
|
|
||||||
return errors
|
|
||||||
|
|
||||||
if isinstance(expected_value, bool):
|
|
||||||
if not isinstance(actual_value, bool) or actual_value is not expected_value:
|
|
||||||
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
|
|
||||||
return errors
|
|
||||||
|
|
||||||
if _is_integer_value(expected_value):
|
|
||||||
if not _is_integer_value(actual_value) or actual_value != expected_value:
|
|
||||||
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
|
|
||||||
return errors
|
|
||||||
|
|
||||||
if isinstance(expected_value, float):
|
|
||||||
if not _is_number_value(actual_value) or float(actual_value) != expected_value:
|
|
||||||
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
|
|
||||||
return errors
|
|
||||||
|
|
||||||
if expected_value != actual_value:
|
|
||||||
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
|
|
||||||
return errors
|
|
||||||
|
|
||||||
|
|
||||||
def _pick_tool_call(tool_calls: List[ToolCall], expected_tool_name: str) -> ToolCall:
|
|
||||||
"""优先选择同名工具调用,否则回退到第一条。"""
|
|
||||||
for tool_call in tool_calls:
|
|
||||||
if tool_call.func_name == expected_tool_name:
|
|
||||||
return tool_call
|
|
||||||
return tool_calls[0]
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_service_result(
|
|
||||||
service_result: LLMServiceResult,
|
|
||||||
target: ProbeTarget,
|
|
||||||
case: ToolCallCase,
|
|
||||||
) -> tuple[List[str], List[str], List[Dict[str, Any]]]:
|
|
||||||
"""校验服务层返回结果。"""
|
|
||||||
errors: List[str] = []
|
|
||||||
warnings: List[str] = []
|
|
||||||
completion = service_result.completion
|
|
||||||
serialized_tool_calls = _serialize_tool_calls(completion.tool_calls)
|
|
||||||
|
|
||||||
if not service_result.success:
|
|
||||||
errors.append(service_result.error or completion.response or "请求失败但未返回错误信息")
|
|
||||||
return errors, warnings, serialized_tool_calls
|
|
||||||
|
|
||||||
if completion.model_name and completion.model_name != target.model_name:
|
|
||||||
errors.append(
|
|
||||||
f"实际命中的模型为 `{completion.model_name}`,与目标模型 `{target.model_name}` 不一致"
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_calls = completion.tool_calls or []
|
|
||||||
if not tool_calls:
|
|
||||||
errors.append("模型未返回 tool_calls")
|
|
||||||
if completion.response.strip():
|
|
||||||
warnings.append("模型返回了自然语言文本而不是工具调用")
|
|
||||||
return errors, warnings, serialized_tool_calls
|
|
||||||
|
|
||||||
if len(tool_calls) != 1:
|
|
||||||
errors.append(f"返回了 {len(tool_calls)} 个 tool_calls,预期为 1 个")
|
|
||||||
|
|
||||||
selected_tool_call = _pick_tool_call(tool_calls, case.tool_name)
|
|
||||||
if selected_tool_call.func_name != case.tool_name:
|
|
||||||
errors.append(
|
|
||||||
f"工具名不一致,期望 `{case.tool_name}`,实际 `{selected_tool_call.func_name}`"
|
|
||||||
)
|
|
||||||
|
|
||||||
actual_arguments = selected_tool_call.args
|
|
||||||
if not isinstance(actual_arguments, dict):
|
|
||||||
errors.append("工具参数未被解析为对象")
|
|
||||||
return errors, warnings, serialized_tool_calls
|
|
||||||
|
|
||||||
errors.extend(_validate_schema(case.parameters_schema, actual_arguments))
|
|
||||||
errors.extend(_compare_expected_values(case.expected_arguments, actual_arguments))
|
|
||||||
|
|
||||||
if completion.response.strip():
|
|
||||||
warnings.append("模型同时返回了自然语言文本")
|
|
||||||
return errors, warnings, serialized_tool_calls
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_single_probe(
|
|
||||||
target: ProbeTarget,
|
|
||||||
case: ToolCallCase,
|
|
||||||
attempt: int,
|
|
||||||
max_tokens: int,
|
|
||||||
temperature: float,
|
|
||||||
) -> ProbeResult:
|
|
||||||
"""执行单次工具调用参数探测。"""
|
|
||||||
request = LLMServiceRequest(
|
|
||||||
task_name=target.task_name,
|
|
||||||
request_type=f"tool_call_param_probe.{case.name}.attempt_{attempt}",
|
|
||||||
prompt=case.build_messages(),
|
|
||||||
tool_options=[case.tool_definition],
|
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
started_at = time.perf_counter()
|
|
||||||
with _pin_task_to_model(target.task_name, target.model_name):
|
|
||||||
service_result = await generate(request)
|
|
||||||
elapsed_seconds = time.perf_counter() - started_at
|
|
||||||
|
|
||||||
errors, warnings, serialized_tool_calls = _validate_service_result(service_result, target, case)
|
|
||||||
completion = service_result.completion
|
|
||||||
return ProbeResult(
|
|
||||||
task_name=target.task_name,
|
|
||||||
target_model_name=target.model_name,
|
|
||||||
actual_model_name=completion.model_name,
|
|
||||||
provider_name=target.provider_name,
|
|
||||||
client_type=target.client_type,
|
|
||||||
tool_argument_parse_mode=target.tool_argument_parse_mode,
|
|
||||||
case_name=case.name,
|
|
||||||
attempt=attempt,
|
|
||||||
success=not errors,
|
|
||||||
elapsed_seconds=elapsed_seconds,
|
|
||||||
errors=errors,
|
|
||||||
warnings=warnings,
|
|
||||||
response_text=completion.response,
|
|
||||||
reasoning_text=completion.reasoning,
|
|
||||||
tool_calls=serialized_tool_calls,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _print_targets(targets: Sequence[ProbeTarget]) -> None:
|
|
||||||
"""打印待测试目标。"""
|
|
||||||
print("待测试目标:")
|
|
||||||
for index, target in enumerate(targets, start=1):
|
|
||||||
print(
|
|
||||||
f"{index}. model={target.model_name} | task={target.task_name} | "
|
|
||||||
f"provider={target.provider_name} | client={target.client_type} | "
|
|
||||||
f"tool_argument_parse_mode={target.tool_argument_parse_mode}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _print_available_targets() -> None:
|
|
||||||
"""打印当前可用任务与模型。"""
|
|
||||||
available_tasks = get_available_models()
|
|
||||||
model_map = _build_model_map()
|
|
||||||
task_names = list(available_tasks.keys())
|
|
||||||
|
|
||||||
print("当前可用任务:")
|
|
||||||
for task_name in task_names:
|
|
||||||
task_config = available_tasks[task_name]
|
|
||||||
print(f"- {task_name}: {list(task_config.model_list)}")
|
|
||||||
|
|
||||||
referenced_models = {
|
|
||||||
model_name
|
|
||||||
for task_config in available_tasks.values()
|
|
||||||
for model_name in task_config.model_list
|
|
||||||
}
|
|
||||||
|
|
||||||
print("\n当前配置中的模型:")
|
|
||||||
for model_name, model_info in model_map.items():
|
|
||||||
referenced_mark = "已被任务引用" if model_name in referenced_models else "未被任务引用"
|
|
||||||
print(
|
|
||||||
f"- {model_name}: provider={model_info.api_provider}, "
|
|
||||||
f"identifier={model_info.model_identifier}, {referenced_mark}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _select_cases(case_filters: Sequence[str]) -> List[ToolCallCase]:
|
|
||||||
"""根据参数筛选测试用例。"""
|
|
||||||
all_cases = {case.name: case for case in _build_default_cases()}
|
|
||||||
if not case_filters:
|
|
||||||
return list(all_cases.values())
|
|
||||||
|
|
||||||
selected_cases: List[ToolCallCase] = []
|
|
||||||
for case_name in case_filters:
|
|
||||||
if case_name not in all_cases:
|
|
||||||
raise ValueError(f"未知测试用例 `{case_name}`,可选值: {', '.join(sorted(all_cases))}")
|
|
||||||
selected_cases.append(all_cases[case_name])
|
|
||||||
return selected_cases
|
|
||||||
|
|
||||||
|
|
||||||
def _print_single_result(result: ProbeResult, show_response: bool) -> None:
|
|
||||||
"""打印单次结果。"""
|
|
||||||
status_text = "PASS" if result.success else "FAIL"
|
|
||||||
print(
|
|
||||||
f"[{status_text}] model={result.target_model_name} | task={result.task_name} | "
|
|
||||||
f"case={result.case_name} | attempt={result.attempt} | elapsed={result.elapsed_seconds:.2f}s"
|
|
||||||
)
|
|
||||||
if result.errors:
|
|
||||||
for error in result.errors:
|
|
||||||
print(f" ERROR: {error}")
|
|
||||||
if result.warnings:
|
|
||||||
for warning in result.warnings:
|
|
||||||
print(f" WARN: {warning}")
|
|
||||||
if result.tool_calls:
|
|
||||||
print(f" tool_calls: {json.dumps(result.tool_calls, ensure_ascii=False)}")
|
|
||||||
if show_response and result.response_text.strip():
|
|
||||||
print(f" response: {result.response_text}")
|
|
||||||
|
|
||||||
|
|
||||||
def _build_summary(results: Sequence[ProbeResult]) -> Dict[str, Any]:
|
|
||||||
"""构造结果摘要。"""
|
|
||||||
total_count = len(results)
|
|
||||||
passed_count = sum(1 for result in results if result.success)
|
|
||||||
failed_count = total_count - passed_count
|
|
||||||
failed_items = [
|
|
||||||
{
|
|
||||||
"model_name": result.target_model_name,
|
|
||||||
"case_name": result.case_name,
|
|
||||||
"attempt": result.attempt,
|
|
||||||
"errors": list(result.errors),
|
|
||||||
}
|
|
||||||
for result in results
|
|
||||||
if not result.success
|
|
||||||
]
|
|
||||||
return {
|
|
||||||
"total": total_count,
|
|
||||||
"passed": passed_count,
|
|
||||||
"failed": failed_count,
|
|
||||||
"failed_items": failed_items,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _write_json_report(json_out: str, results: Sequence[ProbeResult]) -> None:
|
|
||||||
"""将测试结果写入 JSON 文件。"""
|
|
||||||
output_path = Path(json_out).expanduser().resolve()
|
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
payload = {
|
|
||||||
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
|
|
||||||
"summary": _build_summary(results),
|
|
||||||
"results": [asdict(result) for result in results],
|
|
||||||
}
|
|
||||||
output_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
|
||||||
print(f"\n结果已写入: {output_path}")
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_probes(args: Namespace) -> List[ProbeResult]:
|
|
||||||
"""执行所有探测请求。"""
|
|
||||||
task_filters = _parse_multi_value_args(args.task)
|
|
||||||
model_filters = _parse_multi_value_args(args.model)
|
|
||||||
case_filters = _parse_multi_value_args(args.case)
|
|
||||||
|
|
||||||
selected_cases = _select_cases(case_filters)
|
|
||||||
targets = _resolve_targets(task_filters, model_filters, args.fallback_task)
|
|
||||||
|
|
||||||
_print_targets(targets)
|
|
||||||
print("")
|
|
||||||
|
|
||||||
results: List[ProbeResult] = []
|
|
||||||
for target in targets:
|
|
||||||
for attempt in range(1, args.repeat + 1):
|
|
||||||
for case in selected_cases:
|
|
||||||
print(
|
|
||||||
f"开始测试: model={target.model_name}, task={target.task_name}, "
|
|
||||||
f"case={case.name}, attempt={attempt}"
|
|
||||||
)
|
|
||||||
result = await _run_single_probe(
|
|
||||||
target=target,
|
|
||||||
case=case,
|
|
||||||
attempt=attempt,
|
|
||||||
max_tokens=args.max_tokens,
|
|
||||||
temperature=args.temperature,
|
|
||||||
)
|
|
||||||
_print_single_result(result, args.show_response)
|
|
||||||
print("")
|
|
||||||
results.append(result)
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def _build_parser() -> ArgumentParser:
|
|
||||||
"""构造命令行参数解析器。"""
|
|
||||||
parser = ArgumentParser(
|
|
||||||
description=(
|
|
||||||
"测试 config/model_config.toml 中不同模型的 tool call 参数兼容性。\n"
|
|
||||||
"默认会测试所有非 voice / embedding 任务中引用到的模型。"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--task",
|
|
||||||
action="append",
|
|
||||||
help="指定任务名,可重复传入,或使用逗号分隔多个值,例如 --task utils --task planner",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--model",
|
|
||||||
action="append",
|
|
||||||
help="指定模型名,可重复传入,或使用逗号分隔多个值,例如 --model qwen3.6-plus",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--case",
|
|
||||||
action="append",
|
|
||||||
help="指定测试用例名,可选 simple、nested;不传则运行全部默认用例",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--repeat",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="每个模型每个用例重复测试次数,默认 1",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-tokens",
|
|
||||||
type=int,
|
|
||||||
default=512,
|
|
||||||
help="单次测试的最大输出 token 数,默认 512",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--temperature",
|
|
||||||
type=float,
|
|
||||||
default=0.0,
|
|
||||||
help="单次测试温度,默认 0.0 以尽量提高稳定性",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--fallback-task",
|
|
||||||
default="utils",
|
|
||||||
help="当指定模型未被任何已选任务引用时,用于挂载该模型的任务名,默认 utils",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--json-out",
|
|
||||||
help="可选,将结果写入指定 JSON 文件",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--list-targets",
|
|
||||||
action="store_true",
|
|
||||||
help="仅打印当前任务与模型映射,不发起网络请求",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--show-response",
|
|
||||||
action="store_true",
|
|
||||||
help="打印模型返回的自然语言文本内容",
|
|
||||||
)
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
|
||||||
"""脚本入口。"""
|
|
||||||
_ensure_utf8_console()
|
|
||||||
parser = _build_parser()
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.repeat < 1:
|
|
||||||
parser.error("--repeat 必须大于等于 1")
|
|
||||||
if args.max_tokens < 1:
|
|
||||||
parser.error("--max-tokens 必须大于等于 1")
|
|
||||||
|
|
||||||
if args.list_targets:
|
|
||||||
_print_available_targets()
|
|
||||||
return 0
|
|
||||||
|
|
||||||
results = asyncio.run(_run_probes(args))
|
|
||||||
summary = _build_summary(results)
|
|
||||||
|
|
||||||
print("测试摘要:")
|
|
||||||
print(
|
|
||||||
f"total={summary['total']} | passed={summary['passed']} | failed={summary['failed']}"
|
|
||||||
)
|
|
||||||
if summary["failed_items"]:
|
|
||||||
print("失败明细:")
|
|
||||||
for failed_item in summary["failed_items"]:
|
|
||||||
print(
|
|
||||||
f"- model={failed_item['model_name']} | case={failed_item['case_name']} | "
|
|
||||||
f"attempt={failed_item['attempt']} | errors={failed_item['errors']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.json_out:
|
|
||||||
_write_json_report(args.json_out, results)
|
|
||||||
|
|
||||||
return 0 if summary["failed"] == 0 else 1
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
raise SystemExit(main())
|
|
||||||
@@ -1,777 +0,0 @@
|
|||||||
from argparse import ArgumentParser, Namespace
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from dataclasses import asdict, dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, Iterator, List, Sequence
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
|
||||||
if str(PROJECT_ROOT) not in sys.path:
|
|
||||||
sys.path.insert(0, str(PROJECT_ROOT))
|
|
||||||
|
|
||||||
from src.common.data_models.llm_service_data_models import LLMServiceRequest, LLMServiceResult # noqa: E402
|
|
||||||
from src.config.config import config_manager # noqa: E402
|
|
||||||
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig # noqa: E402
|
|
||||||
from src.services.llm_service import generate # noqa: E402
|
|
||||||
from src.services.service_task_resolver import get_available_models # noqa: E402
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_SKIP_TASKS = {"embedding", "voice"}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
|
||||||
class ProbeTarget:
|
|
||||||
"""单个待测试模型目标。"""
|
|
||||||
|
|
||||||
task_name: str
|
|
||||||
model_name: str
|
|
||||||
provider_name: str
|
|
||||||
client_type: str
|
|
||||||
tool_argument_parse_mode: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
|
||||||
class ToolCallScenario:
|
|
||||||
"""工具调用 API 场景定义。"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
prompt: List[Dict[str, Any]]
|
|
||||||
tool_options: List[Dict[str, Any]] | None = None
|
|
||||||
expect_tool_calls: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
|
||||||
class ProbeResult:
|
|
||||||
"""单次 API 探测结果。"""
|
|
||||||
|
|
||||||
task_name: str
|
|
||||||
target_model_name: str
|
|
||||||
actual_model_name: str
|
|
||||||
provider_name: str
|
|
||||||
client_type: str
|
|
||||||
tool_argument_parse_mode: str
|
|
||||||
case_name: str
|
|
||||||
attempt: int
|
|
||||||
success: bool
|
|
||||||
elapsed_seconds: float
|
|
||||||
errors: List[str] = field(default_factory=list)
|
|
||||||
warnings: List[str] = field(default_factory=list)
|
|
||||||
response_text: str = ""
|
|
||||||
reasoning_text: str = ""
|
|
||||||
tool_calls: List[Dict[str, Any]] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
def _ensure_utf8_console() -> None:
|
|
||||||
"""尽量将控制台编码切换为 UTF-8。"""
|
|
||||||
try:
|
|
||||||
if hasattr(sys.stdout, "reconfigure"):
|
|
||||||
sys.stdout.reconfigure(encoding="utf-8")
|
|
||||||
if hasattr(sys.stderr, "reconfigure"):
|
|
||||||
sys.stderr.reconfigure(encoding="utf-8")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def _build_function_tool(name: str, description: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""构造 OpenAI 风格 function tool。"""
|
|
||||||
return {
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": name,
|
|
||||||
"description": description,
|
|
||||||
"parameters": parameters,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _build_probe_tools() -> List[Dict[str, Any]]:
|
|
||||||
"""构造通用测试工具。"""
|
|
||||||
weather_tool = _build_function_tool(
|
|
||||||
name="lookup_weather",
|
|
||||||
description="查询指定城市天气。",
|
|
||||||
parameters={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"city": {"type": "string", "description": "城市名"},
|
|
||||||
"unit": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "温度单位",
|
|
||||||
"enum": ["celsius", "fahrenheit"],
|
|
||||||
},
|
|
||||||
"include_forecast": {"type": "boolean", "description": "是否包含未来天气"},
|
|
||||||
},
|
|
||||||
"required": ["city", "unit", "include_forecast"],
|
|
||||||
"additionalProperties": False,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
search_tool = _build_function_tool(
|
|
||||||
name="search_docs",
|
|
||||||
description="搜索内部知识库。",
|
|
||||||
parameters={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"query": {"type": "string", "description": "搜索关键词"},
|
|
||||||
"top_k": {"type": "integer", "description": "返回条数"},
|
|
||||||
"filters": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "过滤条件",
|
|
||||||
"properties": {
|
|
||||||
"scope": {"type": "string", "description": "搜索范围"},
|
|
||||||
"tag": {"type": "string", "description": "标签"},
|
|
||||||
},
|
|
||||||
"required": ["scope", "tag"],
|
|
||||||
"additionalProperties": False,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["query", "top_k", "filters"],
|
|
||||||
"additionalProperties": False,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return [weather_tool, search_tool]
|
|
||||||
|
|
||||||
|
|
||||||
def _build_default_scenarios() -> List[ToolCallScenario]:
|
|
||||||
"""构造默认测试场景。"""
|
|
||||||
tools = _build_probe_tools()
|
|
||||||
weather_tool = tools[0]
|
|
||||||
search_tool = tools[1]
|
|
||||||
|
|
||||||
history_tool_call = {
|
|
||||||
"id": "call_hist_weather_001",
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "lookup_weather",
|
|
||||||
"arguments": {
|
|
||||||
"city": "上海",
|
|
||||||
"unit": "celsius",
|
|
||||||
"include_forecast": True,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
nested_history_tool_call = {
|
|
||||||
"id": "call_hist_search_001",
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "search_docs",
|
|
||||||
"arguments": {
|
|
||||||
"query": "工具调用兼容性",
|
|
||||||
"top_k": 3,
|
|
||||||
"filters": {
|
|
||||||
"scope": "internal",
|
|
||||||
"tag": "tool-call",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return [
|
|
||||||
ToolCallScenario(
|
|
||||||
name="fresh_tool_call",
|
|
||||||
description="首轮普通工具调用请求。",
|
|
||||||
prompt=[
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": (
|
|
||||||
"你正在执行工具调用连通性测试。"
|
|
||||||
"如果能调用工具,就优先调用最合适的工具。"
|
|
||||||
),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "请查询上海天气,并使用工具给出参数。",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
tool_options=[weather_tool],
|
|
||||||
expect_tool_calls=True,
|
|
||||||
),
|
|
||||||
ToolCallScenario(
|
|
||||||
name="history_assistant_tool_calls_with_content",
|
|
||||||
description="历史 assistant 同时包含文本和 tool_calls,当前轮不再提供 tools。",
|
|
||||||
prompt=[
|
|
||||||
{"role": "system", "content": "你正在执行多轮上下文兼容性测试。"},
|
|
||||||
{"role": "user", "content": "先帮我查一下上海天气。"},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "我先查询天气,再继续回答。",
|
|
||||||
"tool_calls": [history_tool_call],
|
|
||||||
},
|
|
||||||
{"role": "user", "content": "继续说,别丢掉上下文。"},
|
|
||||||
],
|
|
||||||
tool_options=None,
|
|
||||||
expect_tool_calls=None,
|
|
||||||
),
|
|
||||||
ToolCallScenario(
|
|
||||||
name="history_assistant_tool_calls_without_content",
|
|
||||||
description="历史 assistant 只有 tool_calls,没有文本内容。",
|
|
||||||
prompt=[
|
|
||||||
{"role": "system", "content": "你正在执行多轮上下文兼容性测试。"},
|
|
||||||
{"role": "user", "content": "先帮我查一下上海天气。"},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"tool_calls": [history_tool_call],
|
|
||||||
},
|
|
||||||
{"role": "user", "content": "继续。"},
|
|
||||||
],
|
|
||||||
tool_options=None,
|
|
||||||
expect_tool_calls=None,
|
|
||||||
),
|
|
||||||
ToolCallScenario(
|
|
||||||
name="history_tool_result_followup",
|
|
||||||
description="历史中包含 assistant.tool_calls 与对应 tool 结果消息。",
|
|
||||||
prompt=[
|
|
||||||
{"role": "system", "content": "你正在执行工具调用闭环兼容性测试。"},
|
|
||||||
{"role": "user", "content": "先查上海天气。"},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "我先查询天气。",
|
|
||||||
"tool_calls": [history_tool_call],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": "call_hist_weather_001",
|
|
||||||
"content": json.dumps(
|
|
||||||
{
|
|
||||||
"city": "上海",
|
|
||||||
"condition": "多云",
|
|
||||||
"temperature_c": 24,
|
|
||||||
"forecast": ["晴", "小雨"],
|
|
||||||
},
|
|
||||||
ensure_ascii=False,
|
|
||||||
),
|
|
||||||
},
|
|
||||||
{"role": "user", "content": "结合上面的查询结果继续总结。"},
|
|
||||||
],
|
|
||||||
tool_options=None,
|
|
||||||
expect_tool_calls=None,
|
|
||||||
),
|
|
||||||
ToolCallScenario(
|
|
||||||
name="history_multiple_tool_calls_and_results",
|
|
||||||
description="历史中包含多个 tool_calls 与多条 tool 结果。",
|
|
||||||
prompt=[
|
|
||||||
{"role": "system", "content": "你正在执行多工具上下文兼容性测试。"},
|
|
||||||
{"role": "user", "content": "先查天气,再搜一下工具调用兼容性文档。"},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "我分两步查询。",
|
|
||||||
"tool_calls": [history_tool_call, nested_history_tool_call],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": "call_hist_weather_001",
|
|
||||||
"content": json.dumps(
|
|
||||||
{
|
|
||||||
"city": "上海",
|
|
||||||
"condition": "阴",
|
|
||||||
"temperature_c": 22,
|
|
||||||
},
|
|
||||||
ensure_ascii=False,
|
|
||||||
),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": "call_hist_search_001",
|
|
||||||
"content": json.dumps(
|
|
||||||
{
|
|
||||||
"items": [
|
|
||||||
"OpenAI 兼容接口的 arguments 常见为 JSON 字符串",
|
|
||||||
"部分 provider 在历史消息回放时兼容性较弱",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
ensure_ascii=False,
|
|
||||||
),
|
|
||||||
},
|
|
||||||
{"role": "user", "content": "继续整合上面的两个结果。"},
|
|
||||||
],
|
|
||||||
tool_options=None,
|
|
||||||
expect_tool_calls=None,
|
|
||||||
),
|
|
||||||
ToolCallScenario(
|
|
||||||
name="history_tool_calls_with_current_tools",
|
|
||||||
description="保留历史 tool_calls,同时当前轮仍然提供 tools。",
|
|
||||||
prompt=[
|
|
||||||
{"role": "system", "content": "你正在执行历史 tool_calls 与当前 tools 共存测试。"},
|
|
||||||
{"role": "user", "content": "先查上海天气。"},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "我先查天气。",
|
|
||||||
"tool_calls": [history_tool_call],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": "call_hist_weather_001",
|
|
||||||
"content": json.dumps(
|
|
||||||
{
|
|
||||||
"city": "上海",
|
|
||||||
"condition": "晴",
|
|
||||||
"temperature_c": 26,
|
|
||||||
},
|
|
||||||
ensure_ascii=False,
|
|
||||||
),
|
|
||||||
},
|
|
||||||
{"role": "user", "content": "现在再搜一下工具调用兼容性文档。"},
|
|
||||||
],
|
|
||||||
tool_options=[search_tool],
|
|
||||||
expect_tool_calls=True,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_multi_value_args(raw_values: Sequence[str] | None) -> List[str]:
|
|
||||||
"""解析命令行中的多值参数。"""
|
|
||||||
parsed_values: List[str] = []
|
|
||||||
for raw_value in raw_values or []:
|
|
||||||
for item in str(raw_value).split(","):
|
|
||||||
normalized_item = item.strip()
|
|
||||||
if normalized_item:
|
|
||||||
parsed_values.append(normalized_item)
|
|
||||||
return parsed_values
|
|
||||||
|
|
||||||
|
|
||||||
def _build_model_map() -> Dict[str, ModelInfo]:
|
|
||||||
"""构造模型名到模型配置的映射。"""
|
|
||||||
return {model.name: model for model in config_manager.get_model_config().models}
|
|
||||||
|
|
||||||
|
|
||||||
def _build_provider_map() -> Dict[str, APIProvider]:
|
|
||||||
"""构造 Provider 名称到配置的映射。"""
|
|
||||||
return {provider.name: provider for provider in config_manager.get_model_config().api_providers}
|
|
||||||
|
|
||||||
|
|
||||||
def _pick_default_task_name(task_names: Sequence[str]) -> str:
|
|
||||||
"""选择默认任务名。"""
|
|
||||||
if "utils" in task_names:
|
|
||||||
return "utils"
|
|
||||||
if not task_names:
|
|
||||||
raise ValueError("当前没有可用的任务配置")
|
|
||||||
return str(task_names[0])
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_targets(task_filters: Sequence[str], model_filters: Sequence[str], fallback_task: str) -> List[ProbeTarget]:
|
|
||||||
"""根据命令行参数解析待测试目标。"""
|
|
||||||
available_tasks = get_available_models()
|
|
||||||
model_map = _build_model_map()
|
|
||||||
provider_map = _build_provider_map()
|
|
||||||
|
|
||||||
if not available_tasks:
|
|
||||||
raise ValueError("未找到任何可用的模型任务配置")
|
|
||||||
|
|
||||||
if task_filters:
|
|
||||||
selected_task_names = []
|
|
||||||
for task_name in task_filters:
|
|
||||||
if task_name not in available_tasks:
|
|
||||||
raise ValueError(f"未找到任务 `{task_name}`")
|
|
||||||
selected_task_names.append(task_name)
|
|
||||||
else:
|
|
||||||
selected_task_names = [
|
|
||||||
task_name
|
|
||||||
for task_name in available_tasks
|
|
||||||
if task_name not in DEFAULT_SKIP_TASKS
|
|
||||||
]
|
|
||||||
|
|
||||||
if not selected_task_names:
|
|
||||||
raise ValueError("没有可用于工具调用 API 测试的任务,请显式通过 --task 指定")
|
|
||||||
|
|
||||||
default_task_name = fallback_task if fallback_task in available_tasks else _pick_default_task_name(selected_task_names)
|
|
||||||
resolved_targets: List[ProbeTarget] = []
|
|
||||||
seen_models: set[str] = set()
|
|
||||||
|
|
||||||
if model_filters:
|
|
||||||
model_names = list(model_filters)
|
|
||||||
else:
|
|
||||||
model_names = []
|
|
||||||
for task_name in selected_task_names:
|
|
||||||
task_config = available_tasks[task_name]
|
|
||||||
for model_name in task_config.model_list:
|
|
||||||
if model_name not in model_names:
|
|
||||||
model_names.append(model_name)
|
|
||||||
|
|
||||||
for model_name in model_names:
|
|
||||||
if model_name in seen_models:
|
|
||||||
continue
|
|
||||||
if model_name not in model_map:
|
|
||||||
raise ValueError(f"未找到模型 `{model_name}`")
|
|
||||||
|
|
||||||
target_task_name = ""
|
|
||||||
for task_name in selected_task_names:
|
|
||||||
if model_name in available_tasks[task_name].model_list:
|
|
||||||
target_task_name = task_name
|
|
||||||
break
|
|
||||||
if not target_task_name:
|
|
||||||
target_task_name = default_task_name
|
|
||||||
|
|
||||||
model_info = model_map[model_name]
|
|
||||||
provider_info = provider_map[model_info.api_provider]
|
|
||||||
resolved_targets.append(
|
|
||||||
ProbeTarget(
|
|
||||||
task_name=target_task_name,
|
|
||||||
model_name=model_name,
|
|
||||||
provider_name=provider_info.name,
|
|
||||||
client_type=provider_info.client_type,
|
|
||||||
tool_argument_parse_mode=provider_info.tool_argument_parse_mode,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
seen_models.add(model_name)
|
|
||||||
|
|
||||||
return resolved_targets
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _pin_task_to_model(task_name: str, model_name: str) -> Iterator[None]:
|
|
||||||
"""临时将某个任务锁定到单模型。"""
|
|
||||||
model_task_config = config_manager.get_model_config().model_task_config
|
|
||||||
task_config = getattr(model_task_config, task_name, None)
|
|
||||||
if not isinstance(task_config, TaskConfig):
|
|
||||||
raise ValueError(f"未找到任务 `{task_name}` 对应的配置")
|
|
||||||
|
|
||||||
original_model_list = list(task_config.model_list)
|
|
||||||
original_selection_strategy = task_config.selection_strategy
|
|
||||||
task_config.model_list = [model_name]
|
|
||||||
task_config.selection_strategy = "balance"
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
task_config.model_list = original_model_list
|
|
||||||
task_config.selection_strategy = original_selection_strategy
|
|
||||||
|
|
||||||
|
|
||||||
def _serialize_tool_calls(tool_calls: Any) -> List[Dict[str, Any]]:
|
|
||||||
"""序列化返回中的工具调用。"""
|
|
||||||
if not tool_calls:
|
|
||||||
return []
|
|
||||||
|
|
||||||
serialized_items: List[Dict[str, Any]] = []
|
|
||||||
for tool_call in tool_calls:
|
|
||||||
serialized_items.append(
|
|
||||||
{
|
|
||||||
"id": getattr(tool_call, "call_id", ""),
|
|
||||||
"function": {
|
|
||||||
"name": getattr(tool_call, "func_name", ""),
|
|
||||||
"arguments": dict(getattr(tool_call, "args", {}) or {}),
|
|
||||||
},
|
|
||||||
**(
|
|
||||||
{"extra_content": dict(getattr(tool_call, "extra_content", {}) or {})}
|
|
||||||
if getattr(tool_call, "extra_content", None)
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return serialized_items
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_service_result(service_result: LLMServiceResult, scenario: ToolCallScenario) -> tuple[List[str], List[str], List[Dict[str, Any]]]:
|
|
||||||
"""校验服务结果。"""
|
|
||||||
errors: List[str] = []
|
|
||||||
warnings: List[str] = []
|
|
||||||
completion = service_result.completion
|
|
||||||
serialized_tool_calls = _serialize_tool_calls(completion.tool_calls)
|
|
||||||
|
|
||||||
if not service_result.success:
|
|
||||||
errors.append(service_result.error or completion.response or "请求失败,但没有返回明确错误")
|
|
||||||
return errors, warnings, serialized_tool_calls
|
|
||||||
|
|
||||||
if scenario.expect_tool_calls is True and not serialized_tool_calls:
|
|
||||||
warnings.append("本场景期望模型倾向于调用工具,但未返回 tool_calls")
|
|
||||||
if scenario.expect_tool_calls is False and serialized_tool_calls:
|
|
||||||
warnings.append("本场景未期望继续调用工具,但模型返回了 tool_calls")
|
|
||||||
if completion.response.strip():
|
|
||||||
warnings.append("模型返回了可见文本")
|
|
||||||
return errors, warnings, serialized_tool_calls
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_single_probe(
|
|
||||||
target: ProbeTarget,
|
|
||||||
scenario: ToolCallScenario,
|
|
||||||
attempt: int,
|
|
||||||
max_tokens: int,
|
|
||||||
temperature: float,
|
|
||||||
) -> ProbeResult:
|
|
||||||
"""执行单次 API 探测。"""
|
|
||||||
request = LLMServiceRequest(
|
|
||||||
task_name=target.task_name,
|
|
||||||
request_type=f"tool_call_api_matrix.{scenario.name}.attempt_{attempt}",
|
|
||||||
prompt=scenario.prompt,
|
|
||||||
tool_options=scenario.tool_options,
|
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
started_at = time.perf_counter()
|
|
||||||
with _pin_task_to_model(target.task_name, target.model_name):
|
|
||||||
service_result = await generate(request)
|
|
||||||
elapsed_seconds = time.perf_counter() - started_at
|
|
||||||
|
|
||||||
errors, warnings, serialized_tool_calls = _validate_service_result(service_result, scenario)
|
|
||||||
completion = service_result.completion
|
|
||||||
return ProbeResult(
|
|
||||||
task_name=target.task_name,
|
|
||||||
target_model_name=target.model_name,
|
|
||||||
actual_model_name=completion.model_name,
|
|
||||||
provider_name=target.provider_name,
|
|
||||||
client_type=target.client_type,
|
|
||||||
tool_argument_parse_mode=target.tool_argument_parse_mode,
|
|
||||||
case_name=scenario.name,
|
|
||||||
attempt=attempt,
|
|
||||||
success=not errors,
|
|
||||||
elapsed_seconds=elapsed_seconds,
|
|
||||||
errors=errors,
|
|
||||||
warnings=warnings,
|
|
||||||
response_text=completion.response,
|
|
||||||
reasoning_text=completion.reasoning,
|
|
||||||
tool_calls=serialized_tool_calls,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _print_targets(targets: Sequence[ProbeTarget]) -> None:
|
|
||||||
"""打印待测试目标。"""
|
|
||||||
print("待测试目标:")
|
|
||||||
for index, target in enumerate(targets, start=1):
|
|
||||||
print(
|
|
||||||
f"{index}. model={target.model_name} | task={target.task_name} | "
|
|
||||||
f"provider={target.provider_name} | client={target.client_type} | "
|
|
||||||
f"tool_argument_parse_mode={target.tool_argument_parse_mode}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _print_available_targets() -> None:
|
|
||||||
"""打印当前可用任务与模型。"""
|
|
||||||
available_tasks = get_available_models()
|
|
||||||
model_map = _build_model_map()
|
|
||||||
task_names = list(available_tasks.keys())
|
|
||||||
|
|
||||||
print("当前可用任务:")
|
|
||||||
for task_name in task_names:
|
|
||||||
task_config = available_tasks[task_name]
|
|
||||||
print(f"- {task_name}: {list(task_config.model_list)}")
|
|
||||||
|
|
||||||
referenced_models = {
|
|
||||||
model_name
|
|
||||||
for task_config in available_tasks.values()
|
|
||||||
for model_name in task_config.model_list
|
|
||||||
}
|
|
||||||
|
|
||||||
print("\n当前配置中的模型:")
|
|
||||||
for model_name, model_info in model_map.items():
|
|
||||||
referenced_mark = "已被任务引用" if model_name in referenced_models else "未被任务引用"
|
|
||||||
print(
|
|
||||||
f"- {model_name}: provider={model_info.api_provider}, "
|
|
||||||
f"identifier={model_info.model_identifier}, {referenced_mark}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _select_scenarios(case_filters: Sequence[str]) -> List[ToolCallScenario]:
|
|
||||||
"""按名称筛选测试场景。"""
|
|
||||||
all_scenarios = {scenario.name: scenario for scenario in _build_default_scenarios()}
|
|
||||||
if not case_filters:
|
|
||||||
return list(all_scenarios.values())
|
|
||||||
|
|
||||||
selected_scenarios: List[ToolCallScenario] = []
|
|
||||||
for case_name in case_filters:
|
|
||||||
if case_name not in all_scenarios:
|
|
||||||
raise ValueError(
|
|
||||||
f"未知测试场景 `{case_name}`,可选值: {', '.join(sorted(all_scenarios))}"
|
|
||||||
)
|
|
||||||
selected_scenarios.append(all_scenarios[case_name])
|
|
||||||
return selected_scenarios
|
|
||||||
|
|
||||||
|
|
||||||
def _print_single_result(result: ProbeResult, show_response: bool) -> None:
|
|
||||||
"""打印单次结果。"""
|
|
||||||
status_text = "PASS" if result.success else "FAIL"
|
|
||||||
print(
|
|
||||||
f"[{status_text}] model={result.target_model_name} | task={result.task_name} | "
|
|
||||||
f"case={result.case_name} | attempt={result.attempt} | elapsed={result.elapsed_seconds:.2f}s"
|
|
||||||
)
|
|
||||||
if result.errors:
|
|
||||||
for error in result.errors:
|
|
||||||
print(f" ERROR: {error}")
|
|
||||||
if result.warnings:
|
|
||||||
for warning in result.warnings:
|
|
||||||
print(f" WARN: {warning}")
|
|
||||||
if result.tool_calls:
|
|
||||||
print(f" tool_calls: {json.dumps(result.tool_calls, ensure_ascii=False)}")
|
|
||||||
if show_response and result.response_text.strip():
|
|
||||||
print(f" response: {result.response_text}")
|
|
||||||
|
|
||||||
|
|
||||||
def _build_summary(results: Sequence[ProbeResult]) -> Dict[str, Any]:
|
|
||||||
"""构造结果摘要。"""
|
|
||||||
total_count = len(results)
|
|
||||||
passed_count = sum(1 for result in results if result.success)
|
|
||||||
failed_count = total_count - passed_count
|
|
||||||
failed_items = [
|
|
||||||
{
|
|
||||||
"model_name": result.target_model_name,
|
|
||||||
"case_name": result.case_name,
|
|
||||||
"attempt": result.attempt,
|
|
||||||
"errors": list(result.errors),
|
|
||||||
}
|
|
||||||
for result in results
|
|
||||||
if not result.success
|
|
||||||
]
|
|
||||||
return {
|
|
||||||
"total": total_count,
|
|
||||||
"passed": passed_count,
|
|
||||||
"failed": failed_count,
|
|
||||||
"failed_items": failed_items,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _write_json_report(json_out: str, results: Sequence[ProbeResult]) -> None:
|
|
||||||
"""将测试结果写入 JSON 文件。"""
|
|
||||||
output_path = Path(json_out).expanduser().resolve()
|
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
payload = {
|
|
||||||
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
|
|
||||||
"summary": _build_summary(results),
|
|
||||||
"results": [asdict(result) for result in results],
|
|
||||||
}
|
|
||||||
output_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
|
||||||
print(f"\n结果已写入: {output_path}")
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_probes(args: Namespace) -> List[ProbeResult]:
|
|
||||||
"""执行所有探测请求。"""
|
|
||||||
task_filters = _parse_multi_value_args(args.task)
|
|
||||||
model_filters = _parse_multi_value_args(args.model)
|
|
||||||
case_filters = _parse_multi_value_args(args.case)
|
|
||||||
|
|
||||||
selected_scenarios = _select_scenarios(case_filters)
|
|
||||||
targets = _resolve_targets(task_filters, model_filters, args.fallback_task)
|
|
||||||
|
|
||||||
_print_targets(targets)
|
|
||||||
print("")
|
|
||||||
|
|
||||||
results: List[ProbeResult] = []
|
|
||||||
for target in targets:
|
|
||||||
for attempt in range(1, args.repeat + 1):
|
|
||||||
for scenario in selected_scenarios:
|
|
||||||
print(
|
|
||||||
f"开始测试: model={target.model_name}, task={target.task_name}, "
|
|
||||||
f"case={scenario.name}, attempt={attempt}"
|
|
||||||
)
|
|
||||||
result = await _run_single_probe(
|
|
||||||
target=target,
|
|
||||||
scenario=scenario,
|
|
||||||
attempt=attempt,
|
|
||||||
max_tokens=args.max_tokens,
|
|
||||||
temperature=args.temperature,
|
|
||||||
)
|
|
||||||
_print_single_result(result, args.show_response)
|
|
||||||
print("")
|
|
||||||
results.append(result)
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def _build_parser() -> ArgumentParser:
|
|
||||||
"""构造命令行参数解析器。"""
|
|
||||||
parser = ArgumentParser(
|
|
||||||
description=(
|
|
||||||
"测试不同模型在多种工具调用消息形态下的 API 兼容性。\n"
|
|
||||||
"重点覆盖历史 assistant.tool_calls、tool 结果消息、多工具调用等场景。"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--task",
|
|
||||||
action="append",
|
|
||||||
help="指定任务名,可重复传入,或使用逗号分隔多个值,例如 --task utils --task planner",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--model",
|
|
||||||
action="append",
|
|
||||||
help="指定模型名,可重复传入,或使用逗号分隔多个值,例如 --model qwen3.5-35b-a3b",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--case",
|
|
||||||
action="append",
|
|
||||||
help=(
|
|
||||||
"指定测试场景名,可选值包括 "
|
|
||||||
"fresh_tool_call、history_assistant_tool_calls_with_content、"
|
|
||||||
"history_assistant_tool_calls_without_content、history_tool_result_followup、"
|
|
||||||
"history_multiple_tool_calls_and_results、history_tool_calls_with_current_tools"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--repeat",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="每个模型每个场景重复测试次数,默认 1",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-tokens",
|
|
||||||
type=int,
|
|
||||||
default=512,
|
|
||||||
help="单次测试的最大输出 token 数,默认 512",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--temperature",
|
|
||||||
type=float,
|
|
||||||
default=0.0,
|
|
||||||
help="单次测试温度,默认 0.0,以尽量提高稳定性",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--fallback-task",
|
|
||||||
default="utils",
|
|
||||||
help="当指定模型未被已选任务引用时,用于挂载该模型的任务名,默认 utils",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--json-out",
|
|
||||||
help="可选,将结果写入指定 JSON 文件",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--list-targets",
|
|
||||||
action="store_true",
|
|
||||||
help="仅打印当前任务与模型映射,不发起网络请求",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--show-response",
|
|
||||||
action="store_true",
|
|
||||||
help="打印模型返回的可见文本内容",
|
|
||||||
)
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
|
||||||
"""脚本入口。"""
|
|
||||||
_ensure_utf8_console()
|
|
||||||
config_manager.initialize()
|
|
||||||
parser = _build_parser()
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.repeat < 1:
|
|
||||||
parser.error("--repeat 必须大于等于 1")
|
|
||||||
if args.max_tokens < 1:
|
|
||||||
parser.error("--max-tokens 必须大于等于 1")
|
|
||||||
|
|
||||||
if args.list_targets:
|
|
||||||
_print_available_targets()
|
|
||||||
return 0
|
|
||||||
|
|
||||||
results = asyncio.run(_run_probes(args))
|
|
||||||
summary = _build_summary(results)
|
|
||||||
|
|
||||||
print("测试摘要:")
|
|
||||||
print(
|
|
||||||
f"total={summary['total']} | passed={summary['passed']} | failed={summary['failed']}"
|
|
||||||
)
|
|
||||||
if summary["failed_items"]:
|
|
||||||
print("失败明细:")
|
|
||||||
for failed_item in summary["failed_items"]:
|
|
||||||
print(
|
|
||||||
f"- model={failed_item['model_name']} | case={failed_item['case_name']} | "
|
|
||||||
f"attempt={failed_item['attempt']} | errors={failed_item['errors']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.json_out:
|
|
||||||
_write_json_report(args.json_out, results)
|
|
||||||
|
|
||||||
return 0 if summary["failed"] == 0 else 1
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
raise SystemExit(main())
|
|
||||||
@@ -748,6 +748,9 @@ class ComponentQueryService:
|
|||||||
return payload
|
return payload
|
||||||
|
|
||||||
for key, value in context_payload.items():
|
for key, value in context_payload.items():
|
||||||
|
if key in {"stream_id", "chat_id"}:
|
||||||
|
payload[key] = value
|
||||||
|
continue
|
||||||
if key not in payload or not payload.get(key):
|
if key not in payload or not payload.get(key):
|
||||||
payload[key] = value
|
payload[key] = value
|
||||||
return payload
|
return payload
|
||||||
|
|||||||
@@ -1,76 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
|
||||||
|
|
||||||
from src.config.config_upgrade_hooks import (
|
|
||||||
BOT_CONFIG_UPGRADE_HOOKS,
|
|
||||||
ConfigUpgradeHook,
|
|
||||||
apply_config_upgrade_hooks,
|
|
||||||
set_nested_config_value,
|
|
||||||
)
|
|
||||||
from src.config.official_configs import ChatConfig
|
|
||||||
|
|
||||||
import src.config.config_upgrade_hooks as hooks
|
|
||||||
|
|
||||||
|
|
||||||
def test_apply_config_upgrade_hooks_runs_when_target_version_is_crossed(monkeypatch):
|
|
||||||
def migrate(data):
|
|
||||||
changed = set_nested_config_value(data, ("chat", "enable"), False)
|
|
||||||
return ["chat.enable"] if changed else []
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
hooks,
|
|
||||||
"BOT_CONFIG_UPGRADE_HOOKS",
|
|
||||||
(ConfigUpgradeHook(target_version="8.10.11", config_names=("bot_config.toml",), migrate=migrate),),
|
|
||||||
)
|
|
||||||
|
|
||||||
data = {"chat": {"enable": True}}
|
|
||||||
result = apply_config_upgrade_hooks(data, "bot_config.toml", "8.10.10", "8.10.11")
|
|
||||||
|
|
||||||
assert result.migrated is True
|
|
||||||
assert result.reason == "8.10.11:chat.enable"
|
|
||||||
assert result.data["chat"]["enable"] is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_apply_config_upgrade_hooks_skips_versions_outside_upgrade_range(monkeypatch):
|
|
||||||
def migrate(data):
|
|
||||||
set_nested_config_value(data, ("chat", "enable"), False)
|
|
||||||
return ["chat.enable"]
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
hooks,
|
|
||||||
"BOT_CONFIG_UPGRADE_HOOKS",
|
|
||||||
(ConfigUpgradeHook(target_version="8.10.11", config_names=("bot_config.toml",), migrate=migrate),),
|
|
||||||
)
|
|
||||||
|
|
||||||
data = {"chat": {"enable": True}}
|
|
||||||
result = apply_config_upgrade_hooks(data, "bot_config.toml", "8.10.11", "8.10.12")
|
|
||||||
|
|
||||||
assert result.migrated is False
|
|
||||||
assert result.data["chat"]["enable"] is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_set_nested_config_value_can_keep_existing_value():
|
|
||||||
data = {"webui": {"port": 8001}}
|
|
||||||
|
|
||||||
changed = set_nested_config_value(data, ("webui", "port"), 8080, force=False)
|
|
||||||
|
|
||||||
assert changed is False
|
|
||||||
assert data["webui"]["port"] == 8001
|
|
||||||
|
|
||||||
|
|
||||||
def test_builtin_hook_resets_group_chat_prompt_when_upgrading_from_8_10_10():
|
|
||||||
data = {"chat": {"group_chat_prompt": "自定义旧提示词"}}
|
|
||||||
|
|
||||||
result = apply_config_upgrade_hooks(data, "bot_config.toml", "8.10.10", "8.10.11")
|
|
||||||
|
|
||||||
assert result.migrated is True
|
|
||||||
assert result.reason == "8.10.11:chat.group_chat_prompt"
|
|
||||||
assert result.data["chat"]["group_chat_prompt"] == ChatConfig().group_chat_prompt
|
|
||||||
|
|
||||||
|
|
||||||
def test_bot_config_upgrade_hooks_register_group_chat_prompt_reset():
|
|
||||||
assert len(BOT_CONFIG_UPGRADE_HOOKS) == 1
|
|
||||||
assert BOT_CONFIG_UPGRADE_HOOKS[0].target_version == "8.10.11"
|
|
||||||
Reference in New Issue
Block a user