diff --git a/AGENTS.md b/AGENTS.md index fc5fd305..d725feee 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -33,6 +33,7 @@ # 运行/调试/构建/测试/依赖 优先使用uv 依赖项以 pyproject.toml 为准,要同步更新requirements.txt +如为当前任务临时创建测试文件,跑完测试后必须立刻删除,不要保留在仓库中,也不要进入共享历史 前端改动后,如需走离线发布工作流,必须先在 `dashboard` 目录执行 `npm run build`(例如:`D:\Nodejs\npm.cmd run build`),确保生成最新的 `dashboard/dist` 当前项目的离线发布工作流依赖仓库中已提交的 `dashboard/dist`;在修改 `dashboard/src`、`dashboard/public`、`dashboard/package.json`、`dashboard/package-lock.json` 或其他影响构建产物的前端文件后,要同步提交对应的 `dashboard/dist` 不要在离线发布场景中假设服务器或 runner 可以联网安装 Node.js 依赖;除非明确确认发布机已具备可用的 `npm` 环境,否则默认按“本地构建 `dashboard/dist` 并随仓库提交”处理 diff --git a/dashboard/src/components/dynamic-form/__tests__/DynamicConfigForm.test.tsx b/dashboard/src/components/dynamic-form/__tests__/DynamicConfigForm.test.tsx deleted file mode 100644 index b142e3b1..00000000 --- a/dashboard/src/components/dynamic-form/__tests__/DynamicConfigForm.test.tsx +++ /dev/null @@ -1,427 +0,0 @@ -import { describe, it, expect, vi } from 'vitest' -import { screen } from '@testing-library/dom' -import { render } from '@testing-library/react' -import userEvent from '@testing-library/user-event' - -import { DynamicConfigForm } from '../DynamicConfigForm' -import { FieldHookRegistry } from '@/lib/field-hooks' -import type { ConfigSchema } from '@/types/config-schema' -import type { FieldHookComponentProps } from '@/lib/field-hooks' - -describe('DynamicConfigForm', () => { - describe('basic rendering', () => { - it('renders simple fields', () => { - const schema: ConfigSchema = { - className: 'TestConfig', - classDoc: 'Test configuration', - fields: [ - { - name: 'field1', - type: 'string', - label: 'Field 1', - description: 'First field', - required: false, - default: 'value1', - }, - { - name: 'field2', - type: 'boolean', - label: 'Field 2', - description: 'Second field', - required: false, - default: false, - }, - ], - } - const values = { field1: 'value1', field2: false } - const onChange = vi.fn() - - render() - - expect(screen.getByText('Field 1')).toBeInTheDocument() - expect(screen.getByText('Field 2')).toBeInTheDocument() - expect(screen.getByText('First field')).toBeInTheDocument() - expect(screen.getByText('Second field')).toBeInTheDocument() - }) - - it('renders nested schema', () => { - const schema: ConfigSchema = { - className: 'MainConfig', - classDoc: 'Main configuration', - fields: [ - { - name: 'top_field', - type: 'string', - label: 'Top Field', - description: 'Top level field', - required: false, - }, - ], - nested: { - sub_config: { - className: 'SubConfig', - classDoc: 'Sub configuration', - fields: [ - { - name: 'nested_field', - type: 'number', - label: 'Nested Field', - description: 'Nested field', - required: false, - default: 42, - }, - ], - }, - }, - } - const values = { - top_field: 'top', - sub_config: { - nested_field: 42, - }, - } - const onChange = vi.fn() - - render() - - expect(screen.getByText('Top Field')).toBeInTheDocument() - expect(screen.getByText('Sub configuration')).toBeInTheDocument() - expect(screen.getByText('Nested Field')).toBeInTheDocument() - }) - }) - - describe('Hook system', () => { - it('renders Hook component in replace mode', () => { - const TestHookComponent: React.FC = ({ fieldPath, value }) => { - return
Hook: {fieldPath} = {String(value)}
- } - - const hooks = new FieldHookRegistry() - hooks.register('hooked_field', TestHookComponent, 'replace') - - const schema: ConfigSchema = { - className: 'TestConfig', - classDoc: 'Test configuration', - fields: [ - { - name: 'hooked_field', - type: 'string', - label: 'Hooked Field', - description: 'A field with hook', - required: false, - }, - { - name: 'normal_field', - type: 'string', - label: 'Normal Field', - description: 'A normal field', - required: false, - }, - ], - } - const values = { hooked_field: 'test', normal_field: 'normal' } - const onChange = vi.fn() - - render() - - expect(screen.getByTestId('hook-component')).toBeInTheDocument() - expect(screen.getByText('Hook: hooked_field = test')).toBeInTheDocument() - expect(screen.queryByText('Hooked Field')).not.toBeInTheDocument() - expect(screen.getByText('Normal Field')).toBeInTheDocument() - }) - - it('renders Hook component in wrapper mode', () => { - const WrapperHookComponent: React.FC = ({ fieldPath, children }) => { - return ( -
-
Wrapper for: {fieldPath}
- {children} -
- ) - } - - const hooks = new FieldHookRegistry() - hooks.register('wrapped_field', WrapperHookComponent, 'wrapper') - - const schema: ConfigSchema = { - className: 'TestConfig', - classDoc: 'Test configuration', - fields: [ - { - name: 'wrapped_field', - type: 'string', - label: 'Wrapped Field', - description: 'A wrapped field', - required: false, - }, - ], - } - const values = { wrapped_field: 'test' } - const onChange = vi.fn() - - render() - - expect(screen.getByTestId('wrapper-hook')).toBeInTheDocument() - expect(screen.getByText('Wrapper for: wrapped_field')).toBeInTheDocument() - expect(screen.getByText('Wrapped Field')).toBeInTheDocument() - }) - - it('passes correct props to Hook component', () => { - const TestHookComponent: React.FC = ({ fieldPath, value, onChange }) => { - return ( -
-
{fieldPath}
-
{String(value)}
- -
- ) - } - - const hooks = new FieldHookRegistry() - hooks.register('test_field', TestHookComponent, 'replace') - - const schema: ConfigSchema = { - className: 'TestConfig', - classDoc: 'Test configuration', - fields: [ - { - name: 'test_field', - type: 'string', - label: 'Test Field', - description: 'A test field', - required: false, - }, - ], - } - const values = { test_field: 'original' } - const onChange = vi.fn() - - render() - - expect(screen.getByTestId('field-path')).toHaveTextContent('test_field') - expect(screen.getByTestId('field-value')).toHaveTextContent('original') - }) - }) - - describe('onChange propagation', () => { - it('propagates onChange from simple field', async () => { - const schema: ConfigSchema = { - className: 'TestConfig', - classDoc: 'Test configuration', - fields: [ - { - name: 'test_field', - type: 'string', - label: 'Test Field', - description: 'A test field', - required: false, - }, - ], - } - const values = { test_field: '' } - const onChange = vi.fn() - - render() - - const input = screen.getByRole('textbox') - input.focus() - await userEvent.keyboard('Hello') - - expect(onChange).toHaveBeenCalledTimes(5) - expect(onChange.mock.calls.every(call => call[0] === 'test_field')).toBe(true) - expect(onChange).toHaveBeenNthCalledWith(1, 'test_field', 'H') - expect(onChange).toHaveBeenNthCalledWith(5, 'test_field', 'o') - }) - - it('propagates onChange from nested field with correct path', async () => { - const schema: ConfigSchema = { - className: 'MainConfig', - classDoc: 'Main configuration', - fields: [], - nested: { - sub_config: { - className: 'SubConfig', - classDoc: 'Sub configuration', - fields: [ - { - name: 'nested_field', - type: 'string', - label: 'Nested Field', - description: 'Nested field', - required: false, - }, - ], - }, - }, - } - const values = { - sub_config: { - nested_field: '', - }, - } - const onChange = vi.fn() - - render() - - const input = screen.getByRole('textbox') - input.focus() - await userEvent.keyboard('Test') - - expect(onChange).toHaveBeenCalledTimes(4) - expect(onChange.mock.calls.every(call => call[0] === 'sub_config.nested_field')).toBe(true) - expect(onChange).toHaveBeenNthCalledWith(1, 'sub_config.nested_field', 'T') - expect(onChange).toHaveBeenNthCalledWith(4, 'sub_config.nested_field', 't') - }) - - it('propagates onChange from Hook component', async () => { - const TestHookComponent: React.FC = ({ onChange }) => { - return - } - - const hooks = new FieldHookRegistry() - hooks.register('hooked_field', TestHookComponent, 'replace') - - const schema: ConfigSchema = { - className: 'TestConfig', - classDoc: 'Test configuration', - fields: [ - { - name: 'hooked_field', - type: 'string', - label: 'Hooked Field', - description: 'A hooked field', - required: false, - }, - ], - } - const values = { hooked_field: '' } - const onChange = vi.fn() - const user = userEvent.setup() - - render() - - await user.click(screen.getByRole('button')) - - expect(onChange).toHaveBeenCalledWith('hooked_field', 'hook_value') - }) - - it('renders nested Hook component with full field path', async () => { - const NestedHookComponent: React.FC = ({ fieldPath, onChange }) => { - return ( - - ) - } - - const hooks = new FieldHookRegistry() - hooks.register('mcp.servers', NestedHookComponent, 'replace') - - const schema: ConfigSchema = { - className: 'RootConfig', - classDoc: 'Root configuration', - fields: [], - nested: { - mcp: { - className: 'MCPConfig', - classDoc: 'MCP 配置', - fields: [ - { - name: 'enable', - type: 'boolean', - label: '启用 MCP', - description: '是否启用 MCP', - required: false, - }, - { - name: 'servers', - type: 'array', - label: '服务器列表', - description: '复杂对象数组', - required: false, - items: { - type: 'object', - }, - }, - ], - nested: { - servers: { - className: 'MCPServerItemConfig', - classDoc: 'MCP 服务器项', - fields: [], - }, - }, - }, - }, - } - const values = { - mcp: { - enable: true, - servers: [], - }, - } - const onChange = vi.fn() - const user = userEvent.setup() - - render() - - await user.click(screen.getByRole('button', { name: 'mcp.servers' })) - - expect(onChange).toHaveBeenCalledWith('mcp.servers', [{ enabled: true }]) - }) - }) - - describe('edge cases', () => { - it('renders with empty nested values', () => { - const schema: ConfigSchema = { - className: 'MainConfig', - classDoc: 'Main configuration', - fields: [], - nested: { - sub_config: { - className: 'SubConfig', - classDoc: 'Sub configuration', - fields: [ - { - name: 'nested_field', - type: 'string', - label: 'Nested Field', - description: 'Nested field', - required: false, - }, - ], - }, - }, - } - const values = {} - const onChange = vi.fn() - - render() - - expect(screen.getByText('Sub configuration')).toBeInTheDocument() - expect(screen.getByText('Nested Field')).toBeInTheDocument() - }) - - it('uses default hook registry when not provided', () => { - const schema: ConfigSchema = { - className: 'TestConfig', - classDoc: 'Test configuration', - fields: [ - { - name: 'test_field', - type: 'string', - label: 'Test Field', - description: 'A test field', - required: false, - }, - ], - } - const values = { test_field: 'test' } - const onChange = vi.fn() - - render() - - expect(screen.getByText('Test Field')).toBeInTheDocument() - }) - }) -}) diff --git a/dashboard/src/components/dynamic-form/__tests__/DynamicField.test.tsx b/dashboard/src/components/dynamic-form/__tests__/DynamicField.test.tsx deleted file mode 100644 index d916288c..00000000 --- a/dashboard/src/components/dynamic-form/__tests__/DynamicField.test.tsx +++ /dev/null @@ -1,475 +0,0 @@ -import { describe, it, expect, vi } from 'vitest' -import { screen } from '@testing-library/dom' -import { render } from '@testing-library/react' -import userEvent from '@testing-library/user-event' - -import { DynamicField } from '../DynamicField' -import type { FieldSchema } from '@/types/config-schema' - -describe('DynamicField', () => { - describe('x-widget priority', () => { - it('renders Slider when x-widget is slider', () => { - const schema: FieldSchema = { - name: 'test_slider', - type: 'number', - label: 'Test Slider', - description: 'A test slider', - required: false, - 'x-widget': 'slider', - minValue: 0, - maxValue: 100, - default: 50, - } - const onChange = vi.fn() - - render() - - expect(screen.getByText('Test Slider')).toBeInTheDocument() - expect(screen.getByRole('slider')).toBeInTheDocument() - expect(screen.getByText('50')).toBeInTheDocument() - }) - - it('renders Switch when x-widget is switch', () => { - const schema: FieldSchema = { - name: 'test_switch', - type: 'boolean', - label: 'Test Switch', - description: 'A test switch', - required: false, - 'x-widget': 'switch', - default: false, - } - const onChange = vi.fn() - - render() - - expect(screen.getByText('Test Switch')).toBeInTheDocument() - expect(screen.getByRole('switch')).toBeInTheDocument() - }) - - it('renders Textarea when x-widget is textarea', () => { - const schema: FieldSchema = { - name: 'test_textarea', - type: 'string', - label: 'Test Textarea', - description: 'A test textarea', - required: false, - 'x-widget': 'textarea', - default: 'Hello', - } - const onChange = vi.fn() - - render() - - expect(screen.getByText('Test Textarea')).toBeInTheDocument() - expect(screen.getByRole('textbox')).toBeInTheDocument() - expect(screen.getByRole('textbox')).toHaveValue('Hello') - }) - - it('renders Select when x-widget is select', () => { - const schema: FieldSchema = { - name: 'test_select', - type: 'string', - label: 'Test Select', - description: 'A test select', - required: false, - 'x-widget': 'select', - options: ['Option 1', 'Option 2', 'Option 3'], - default: 'Option 1', - } - const onChange = vi.fn() - - render() - - expect(screen.getByText('Test Select')).toBeInTheDocument() - expect(screen.getByRole('combobox')).toBeInTheDocument() - }) - - it('renders placeholder for custom widget', () => { - const schema: FieldSchema = { - name: 'test_custom', - type: 'string', - label: 'Test Custom', - description: 'A test custom field', - required: false, - 'x-widget': 'custom', - } - const onChange = vi.fn() - - render() - - expect(screen.getByText('Custom field requires Hook')).toBeInTheDocument() - }) - - it('renders number Input when x-widget is input but type is integer', () => { - const schema: FieldSchema = { - name: 'test_integer_input_widget', - type: 'integer', - label: 'Test Integer Input Widget', - description: 'A numeric field rendered as input', - required: false, - 'x-widget': 'input', - default: 0, - } - const onChange = vi.fn() - - render() - - const input = screen.getByRole('spinbutton') - expect(input).toBeInTheDocument() - expect(input).toHaveValue(2) - }) - - it('parses string values for numeric input widgets', () => { - const schema: FieldSchema = { - name: 'test_string_number_input_widget', - type: 'integer', - label: 'Test String Number Input Widget', - description: 'A numeric field with legacy string value', - required: false, - 'x-widget': 'input', - default: 0, - } - const onChange = vi.fn() - - render() - - expect(screen.getByRole('spinbutton')).toHaveValue(2) - }) - }) - - describe('type fallback', () => { - it('renders Input for string type', () => { - const schema: FieldSchema = { - name: 'test_string', - type: 'string', - label: 'Test String', - description: 'A test string', - required: false, - default: 'Hello', - } - const onChange = vi.fn() - - render() - - expect(screen.getByRole('textbox')).toBeInTheDocument() - expect(screen.getByRole('textbox')).toHaveValue('Hello') - }) - - it('renders Switch for boolean type', () => { - const schema: FieldSchema = { - name: 'test_bool', - type: 'boolean', - label: 'Test Boolean', - description: 'A test boolean', - required: false, - default: true, - } - const onChange = vi.fn() - - render() - - expect(screen.getByRole('switch')).toBeInTheDocument() - expect(screen.getByRole('switch')).toBeChecked() - }) - - it('renders number Input for number type', () => { - const schema: FieldSchema = { - name: 'test_number', - type: 'number', - label: 'Test Number', - description: 'A test number', - required: false, - default: 42, - } - const onChange = vi.fn() - - render() - - const input = screen.getByRole('spinbutton') - expect(input).toBeInTheDocument() - expect(input).toHaveValue(42) - }) - - it('renders number Input for integer type', () => { - const schema: FieldSchema = { - name: 'test_integer', - type: 'integer', - label: 'Test Integer', - description: 'A test integer', - required: false, - default: 10, - } - const onChange = vi.fn() - - render() - - const input = screen.getByRole('spinbutton') - expect(input).toBeInTheDocument() - expect(input).toHaveValue(10) - }) - - it('renders Textarea for textarea type', () => { - const schema: FieldSchema = { - name: 'test_textarea_type', - type: 'textarea', - label: 'Test Textarea Type', - description: 'A test textarea type', - required: false, - default: 'Long text', - } - const onChange = vi.fn() - - render() - - expect(screen.getByRole('textbox')).toBeInTheDocument() - expect(screen.getByRole('textbox')).toHaveValue('Long text') - }) - - it('renders Select for select type', () => { - const schema: FieldSchema = { - name: 'test_select_type', - type: 'select', - label: 'Test Select Type', - description: 'A test select type', - required: false, - options: ['A', 'B', 'C'], - default: 'A', - } - const onChange = vi.fn() - - render() - - expect(screen.getByRole('combobox')).toBeInTheDocument() - }) - - it('renders textarea editor for primitive array type', () => { - const schema: FieldSchema = { - name: 'test_array', - type: 'array', - label: 'Test Array', - description: 'A test array', - required: false, - items: { - type: 'string', - }, - } - const onChange = vi.fn() - - render() - - expect(screen.getByRole('textbox')).toHaveValue('a\nb') - }) - - it('renders key-value editor for object type', () => { - const schema: FieldSchema = { - name: 'test_object', - type: 'object', - label: 'Test Object', - description: 'A test object', - required: false, - } - const onChange = vi.fn() - - render() - - expect(screen.getByText('可视化编辑')).toBeInTheDocument() - expect(screen.getByDisplayValue('foo')).toBeInTheDocument() - }) - }) - - describe('onChange events', () => { - it('triggers onChange for Switch', async () => { - const schema: FieldSchema = { - name: 'test_switch', - type: 'boolean', - label: 'Test Switch', - description: 'A test switch', - required: false, - default: false, - } - const onChange = vi.fn() - const user = userEvent.setup() - - render() - - await user.click(screen.getByRole('switch')) - expect(onChange).toHaveBeenCalledWith(true) - }) - - it('triggers onChange for Input', async () => { - const schema: FieldSchema = { - name: 'test_input', - type: 'string', - label: 'Test Input', - description: 'A test input', - required: false, - default: '', - } - const onChange = vi.fn() - - render() - - const input = screen.getByRole('textbox') - input.focus() - await userEvent.keyboard('Hello') - - expect(onChange).toHaveBeenCalledTimes(5) - expect(onChange).toHaveBeenNthCalledWith(1, 'H') - expect(onChange).toHaveBeenNthCalledWith(2, 'e') - expect(onChange).toHaveBeenNthCalledWith(3, 'l') - expect(onChange).toHaveBeenNthCalledWith(4, 'l') - expect(onChange).toHaveBeenNthCalledWith(5, 'o') - }) - - it('triggers onChange for number Input', async () => { - const schema: FieldSchema = { - name: 'test_number', - type: 'number', - label: 'Test Number', - description: 'A test number', - required: false, - default: 0, - } - const onChange = vi.fn() - const user = userEvent.setup() - - render() - - const input = screen.getByRole('spinbutton') - await user.clear(input) - await user.type(input, '123') - expect(onChange).toHaveBeenCalled() - }) - - it('triggers numeric onChange for input widget with integer type', async () => { - const schema: FieldSchema = { - name: 'test_integer_input_widget_change', - type: 'integer', - label: 'Test Integer Input Widget Change', - description: 'A numeric field rendered as input', - required: false, - 'x-widget': 'input', - default: 0, - } - const onChange = vi.fn() - const user = userEvent.setup() - - render() - - const input = screen.getByRole('spinbutton') - await user.clear(input) - await user.type(input, '5') - expect(onChange).toHaveBeenLastCalledWith(5) - }) - }) - - describe('visual features', () => { - it('renders label with icon', () => { - const schema: FieldSchema = { - name: 'test_icon', - type: 'string', - label: 'Test Icon', - description: 'A test with icon', - required: false, - 'x-icon': 'Settings', - } - const onChange = vi.fn() - - render() - - expect(screen.getByText('Test Icon')).toBeInTheDocument() - }) - - it('renders required indicator', () => { - const schema: FieldSchema = { - name: 'test_required', - type: 'string', - label: 'Test Required', - description: 'A required field', - required: true, - } - const onChange = vi.fn() - - render() - - expect(screen.getByText('*')).toBeInTheDocument() - }) - - it('renders description', () => { - const schema: FieldSchema = { - name: 'test_desc', - type: 'string', - label: 'Test Description', - description: 'This is a description', - required: false, - } - const onChange = vi.fn() - - render() - - expect(screen.getByText('This is a description')).toBeInTheDocument() - }) - }) - - describe('slider features', () => { - it('renders slider with min/max/step', () => { - const schema: FieldSchema = { - name: 'test_slider_props', - type: 'number', - label: 'Test Slider Props', - description: 'A slider with props', - required: false, - 'x-widget': 'slider', - minValue: 10, - maxValue: 50, - step: 5, - default: 25, - } - const onChange = vi.fn() - - render() - - expect(screen.getByText('10')).toBeInTheDocument() - expect(screen.getByText('50')).toBeInTheDocument() - expect(screen.getByText('25')).toBeInTheDocument() - }) - - it('parses string values for slider widgets', () => { - const schema: FieldSchema = { - name: 'test_slider_string_value', - type: 'number', - label: 'Test Slider String Value', - description: 'A slider with legacy string value', - required: false, - 'x-widget': 'slider', - minValue: 0, - maxValue: 10, - default: 0, - } - const onChange = vi.fn() - - render() - - expect(screen.getByText('2.5')).toBeInTheDocument() - }) - }) - - describe('select features', () => { - it('renders placeholder when no options', () => { - const schema: FieldSchema = { - name: 'test_select_no_options', - type: 'string', - label: 'Test Select No Options', - description: 'A select with no options', - required: false, - 'x-widget': 'select', - } - const onChange = vi.fn() - - render() - - expect(screen.getByText('No options available for select')).toBeInTheDocument() - }) - }) -}) diff --git a/dashboard/src/lib/__tests__/field-hooks.test.ts b/dashboard/src/lib/__tests__/field-hooks.test.ts deleted file mode 100644 index 4a4fd7f1..00000000 --- a/dashboard/src/lib/__tests__/field-hooks.test.ts +++ /dev/null @@ -1,253 +0,0 @@ -import { describe, it, expect, beforeEach } from 'vitest' - -import { FieldHookRegistry } from '../field-hooks' -import type { FieldHookComponent } from '../field-hooks' - -describe('FieldHookRegistry', () => { - let registry: FieldHookRegistry - - beforeEach(() => { - registry = new FieldHookRegistry() - }) - - describe('register', () => { - it('registers a hook with replace type', () => { - const component: FieldHookComponent = () => null - - registry.register('test.field', component, 'replace') - - expect(registry.has('test.field')).toBe(true) - }) - - it('registers a hook with wrapper type', () => { - const component: FieldHookComponent = () => null - - registry.register('test.field', component, 'wrapper') - - expect(registry.has('test.field')).toBe(true) - const entry = registry.get('test.field') - expect(entry?.type).toBe('wrapper') - }) - - it('defaults to replace type when not specified', () => { - const component: FieldHookComponent = () => null - - registry.register('test.field', component) - - const entry = registry.get('test.field') - expect(entry?.type).toBe('replace') - }) - - it('overwrites existing hook for same field path', () => { - const component1: FieldHookComponent = () => null - const component2: FieldHookComponent = () => null - - registry.register('test.field', component1, 'replace') - registry.register('test.field', component2, 'wrapper') - - const entry = registry.get('test.field') - expect(entry?.component).toBe(component2) - expect(entry?.type).toBe('wrapper') - }) - }) - - describe('get', () => { - it('returns hook entry for registered field path', () => { - const component: FieldHookComponent = () => null - - registry.register('test.field', component, 'replace') - - const entry = registry.get('test.field') - expect(entry).toBeDefined() - expect(entry?.component).toBe(component) - expect(entry?.type).toBe('replace') - }) - - it('returns undefined for unregistered field path', () => { - const entry = registry.get('nonexistent.field') - expect(entry).toBeUndefined() - }) - - it('returns correct entry for nested field paths', () => { - const component: FieldHookComponent = () => null - - registry.register('config.section.field', component, 'wrapper') - - const entry = registry.get('config.section.field') - expect(entry).toBeDefined() - expect(entry?.type).toBe('wrapper') - }) - }) - - describe('has', () => { - it('returns true for registered field path', () => { - const component: FieldHookComponent = () => null - - registry.register('test.field', component) - - expect(registry.has('test.field')).toBe(true) - }) - - it('returns false for unregistered field path', () => { - expect(registry.has('nonexistent.field')).toBe(false) - }) - - it('returns false after unregistering', () => { - const component: FieldHookComponent = () => null - - registry.register('test.field', component) - registry.unregister('test.field') - - expect(registry.has('test.field')).toBe(false) - }) - }) - - describe('unregister', () => { - it('removes a registered hook', () => { - const component: FieldHookComponent = () => null - - registry.register('test.field', component) - expect(registry.has('test.field')).toBe(true) - - registry.unregister('test.field') - expect(registry.has('test.field')).toBe(false) - }) - - it('does not throw when unregistering non-existent hook', () => { - expect(() => registry.unregister('nonexistent.field')).not.toThrow() - }) - - it('only removes specified hook, not others', () => { - const component1: FieldHookComponent = () => null - const component2: FieldHookComponent = () => null - - registry.register('field1', component1) - registry.register('field2', component2) - - registry.unregister('field1') - - expect(registry.has('field1')).toBe(false) - expect(registry.has('field2')).toBe(true) - }) - }) - - describe('clear', () => { - it('removes all registered hooks', () => { - const component1: FieldHookComponent = () => null - const component2: FieldHookComponent = () => null - const component3: FieldHookComponent = () => null - - registry.register('field1', component1) - registry.register('field2', component2) - registry.register('field3', component3) - - expect(registry.getAllPaths()).toHaveLength(3) - - registry.clear() - - expect(registry.getAllPaths()).toHaveLength(0) - expect(registry.has('field1')).toBe(false) - expect(registry.has('field2')).toBe(false) - expect(registry.has('field3')).toBe(false) - }) - - it('works correctly on empty registry', () => { - expect(() => registry.clear()).not.toThrow() - expect(registry.getAllPaths()).toHaveLength(0) - }) - }) - - describe('getAllPaths', () => { - it('returns empty array when no hooks registered', () => { - expect(registry.getAllPaths()).toEqual([]) - }) - - it('returns all registered field paths', () => { - const component: FieldHookComponent = () => null - - registry.register('field1', component) - registry.register('field2', component) - registry.register('field3', component) - - const paths = registry.getAllPaths() - expect(paths).toHaveLength(3) - expect(paths).toContain('field1') - expect(paths).toContain('field2') - expect(paths).toContain('field3') - }) - - it('returns updated paths after unregister', () => { - const component: FieldHookComponent = () => null - - registry.register('field1', component) - registry.register('field2', component) - registry.register('field3', component) - - registry.unregister('field2') - - const paths = registry.getAllPaths() - expect(paths).toHaveLength(2) - expect(paths).toContain('field1') - expect(paths).toContain('field3') - expect(paths).not.toContain('field2') - }) - - it('handles nested field paths correctly', () => { - const component: FieldHookComponent = () => null - - registry.register('config.chat.enabled', component) - registry.register('config.chat.model', component) - registry.register('config.api.key', component) - - const paths = registry.getAllPaths() - expect(paths).toHaveLength(3) - expect(paths).toContain('config.chat.enabled') - expect(paths).toContain('config.chat.model') - expect(paths).toContain('config.api.key') - }) - }) - - describe('integration scenarios', () => { - it('supports full lifecycle of multiple hooks', () => { - const replaceComponent: FieldHookComponent = () => null - const wrapperComponent: FieldHookComponent = () => null - - registry.register('field1', replaceComponent, 'replace') - registry.register('field2', wrapperComponent, 'wrapper') - - expect(registry.getAllPaths()).toHaveLength(2) - - const entry1 = registry.get('field1') - expect(entry1?.type).toBe('replace') - expect(entry1?.component).toBe(replaceComponent) - - const entry2 = registry.get('field2') - expect(entry2?.type).toBe('wrapper') - expect(entry2?.component).toBe(wrapperComponent) - - registry.unregister('field1') - expect(registry.getAllPaths()).toHaveLength(1) - expect(registry.has('field2')).toBe(true) - - registry.clear() - expect(registry.getAllPaths()).toHaveLength(0) - }) - - it('handles rapid register/unregister cycles', () => { - const component: FieldHookComponent = () => null - - for (let i = 0; i < 100; i++) { - registry.register(`field${i}`, component) - } - expect(registry.getAllPaths()).toHaveLength(100) - - for (let i = 0; i < 50; i++) { - registry.unregister(`field${i}`) - } - expect(registry.getAllPaths()).toHaveLength(50) - - registry.clear() - expect(registry.getAllPaths()).toHaveLength(0) - }) - }) -}) diff --git a/dashboard/src/routes/__tests__/plugin-config.test.tsx b/dashboard/src/routes/__tests__/plugin-config.test.tsx deleted file mode 100644 index 29cb31a5..00000000 --- a/dashboard/src/routes/__tests__/plugin-config.test.tsx +++ /dev/null @@ -1,80 +0,0 @@ -import { render, screen } from '@testing-library/react' -import { beforeEach, describe, expect, it, vi } from 'vitest' -import type { ReactNode } from 'react' - -import { PluginConfigPage } from '../plugin-config' -import * as pluginApi from '@/lib/plugin-api' - -const toastMock = vi.fn() - -vi.mock('@/hooks/use-toast', () => ({ - useToast: () => ({ toast: toastMock }), -})) - -vi.mock('@/lib/restart-context', () => ({ - RestartProvider: ({ children }: { children: ReactNode }) => <>{children}, - useRestart: () => ({ - showRestartPrompt: false, - markRestartRequired: vi.fn(), - clearRestartRequired: vi.fn(), - }), -})) - -vi.mock('@/components/restart-overlay', () => ({ - RestartOverlay: () => null, -})) - -vi.mock('@/components', () => ({ - CodeEditor: ({ value }: { value: string }) =>
{value}
, - ListFieldEditor: () =>
list-field-editor
, -})) - -vi.mock('@/lib/plugin-api', () => ({ - getInstalledPlugins: vi.fn(), - getPluginConfigSchema: vi.fn(), - getPluginConfig: vi.fn(), - getPluginConfigRaw: vi.fn(), - updatePluginConfig: vi.fn(), - updatePluginConfigRaw: vi.fn(), - resetPluginConfig: vi.fn(), - togglePlugin: vi.fn(), -})) - -describe('PluginConfigPage', () => { - beforeEach(() => { - toastMock.mockReset() - vi.mocked(pluginApi.getInstalledPlugins).mockResolvedValue({ - success: true, - data: [ - { - id: 'test.emoji', - path: '/plugins/test_emoji', - manifest: { - manifest_version: 2, - name: 'Emoji Plugin', - version: '1.0.0', - description: 'emoji tools', - author: { name: 'tester' }, - license: 'MIT', - host_application: { min_version: '1.0.0' }, - }, - }, - ], - }) - vi.mocked(pluginApi.getPluginConfigSchema).mockResolvedValue({} as never) - vi.mocked(pluginApi.getPluginConfig).mockResolvedValue({} as never) - vi.mocked(pluginApi.getPluginConfigRaw).mockResolvedValue({} as never) - vi.mocked(pluginApi.updatePluginConfig).mockResolvedValue({} as never) - vi.mocked(pluginApi.updatePluginConfigRaw).mockResolvedValue({} as never) - vi.mocked(pluginApi.resetPluginConfig).mockResolvedValue({} as never) - vi.mocked(pluginApi.togglePlugin).mockResolvedValue({} as never) - }) - - it('shows real plugins and no longer surfaces A_Memorix in plugin config list', async () => { - render() - - expect(await screen.findByText('Emoji Plugin')).toBeInTheDocument() - expect(screen.getByText('点击插件查看和编辑配置')).toBeInTheDocument() - expect(screen.queryByText(/A_Memorix/i)).not.toBeInTheDocument() - }) -}) diff --git a/dashboard/src/routes/resource/__tests__/knowledge-base.test.tsx b/dashboard/src/routes/resource/__tests__/knowledge-base.test.tsx deleted file mode 100644 index 1276f4ed..00000000 --- a/dashboard/src/routes/resource/__tests__/knowledge-base.test.tsx +++ /dev/null @@ -1,802 +0,0 @@ -import { act, render, screen, waitFor, within } from '@testing-library/react' -import userEvent from '@testing-library/user-event' -import { beforeEach, describe, expect, it, vi } from 'vitest' - -import { KnowledgeBasePage } from '../knowledge-base' -import * as memoryApi from '@/lib/memory-api' - -const navigateMock = vi.fn() -const toastMock = vi.fn() - -vi.mock('@tanstack/react-router', () => ({ - useNavigate: () => navigateMock, -})) - -vi.mock('@/hooks/use-toast', () => ({ - useToast: () => ({ toast: toastMock }), -})) - -vi.mock('@/components', () => ({ - CodeEditor: ({ value }: { value: string }) =>
{value}
, - MarkdownRenderer: ({ content }: { content: string }) =>
{content}
, -})) - -vi.mock('@/components/memory/MemoryConfigEditor', () => ({ - MemoryConfigEditor: () =>
memory-config-editor
, -})) - -vi.mock('@/components/memory/MemoryDeleteDialog', () => ({ - MemoryDeleteDialog: ({ - open, - onExecute, - onRestore, - preview, - result, - }: { - open: boolean - preview?: { mode?: string; item_count?: number } | null - result?: { operation_id?: string } | null - onExecute?: () => void - onRestore?: () => void - }) => ( - open ? ( -
-
{`preview:${preview?.mode ?? 'none'}:${preview?.item_count ?? 0}`}
-
{`result:${result?.operation_id ?? 'none'}`}
- - -
- ) : null - ), -})) - -vi.mock('@/lib/memory-api', () => ({ - getMemoryConfigSchema: vi.fn(), - getMemoryConfig: vi.fn(), - getMemoryConfigRaw: vi.fn(), - getMemoryRuntimeConfig: vi.fn(), - getMemoryImportGuide: vi.fn(), - getMemoryImportSettings: vi.fn(), - getMemoryImportPathAliases: vi.fn(), - getMemoryImportTasks: vi.fn(), - getMemoryImportTask: vi.fn(), - getMemoryImportTaskChunks: vi.fn(), - createMemoryUploadImport: vi.fn(), - createMemoryPasteImport: vi.fn(), - createMemoryRawScanImport: vi.fn(), - createMemoryLpmmOpenieImport: vi.fn(), - createMemoryLpmmConvertImport: vi.fn(), - createMemoryTemporalBackfillImport: vi.fn(), - createMemoryMaibotMigrationImport: vi.fn(), - cancelMemoryImportTask: vi.fn(), - retryMemoryImportTask: vi.fn(), - resolveMemoryImportPath: vi.fn(), - refreshMemoryRuntimeSelfCheck: vi.fn(), - updateMemoryConfig: vi.fn(), - updateMemoryConfigRaw: vi.fn(), - getMemoryTuningProfile: vi.fn(), - getMemoryTuningTasks: vi.fn(), - createMemoryTuningTask: vi.fn(), - applyBestMemoryTuningProfile: vi.fn(), - getMemorySources: vi.fn(), - getMemoryDeleteOperations: vi.fn(), - getMemoryDeleteOperation: vi.fn(), - getMemoryFeedbackCorrections: vi.fn(), - getMemoryFeedbackCorrection: vi.fn(), - previewMemoryDelete: vi.fn(), - executeMemoryDelete: vi.fn(), - restoreMemoryDelete: vi.fn(), - rollbackMemoryFeedbackCorrection: vi.fn(), -})) - -function mockImportTask(taskId: string, status: string = 'running'): memoryApi.MemoryImportTaskPayload { - return { - task_id: taskId, - source: 'webui', - status, - current_step: status === 'completed' ? 'completed' : 'running', - total_chunks: 120, - done_chunks: status === 'completed' ? 120 : 36, - failed_chunks: status === 'completed' ? 0 : 2, - cancelled_chunks: 0, - progress: status === 'completed' ? 100 : 30, - error: '', - file_count: 2, - created_at: 1_710_000_000, - started_at: 1_710_000_001, - finished_at: status === 'completed' ? 1_710_000_099 : null, - updated_at: 1_710_000_100, - task_kind: 'paste', - params: {}, - files: [], - } -} - -function mockImportDetail(taskId: string): memoryApi.MemoryImportTaskPayload { - return { - ...mockImportTask(taskId), - files: [ - { - file_id: 'file-alpha', - name: 'alpha.txt', - source_kind: 'paste', - input_mode: 'text', - status: 'running', - current_step: 'running', - detected_strategy_type: 'auto', - total_chunks: 80, - done_chunks: 30, - failed_chunks: 1, - cancelled_chunks: 0, - progress: 37.5, - error: '', - created_at: 1_710_000_000, - updated_at: 1_710_000_100, - }, - { - file_id: 'file-beta', - name: 'beta.txt', - source_kind: 'paste', - input_mode: 'text', - status: 'failed', - current_step: 'extracting', - detected_strategy_type: 'auto', - total_chunks: 40, - done_chunks: 6, - failed_chunks: 4, - cancelled_chunks: 0, - progress: 25, - error: 'mock error', - created_at: 1_710_000_000, - updated_at: 1_710_000_100, - }, - ], - } -} - -function mockImportCompletedWithErrorsDetail(taskId: string): memoryApi.MemoryImportTaskPayload { - return { - ...mockImportDetail(taskId), - status: 'completed_with_errors', - current_step: 'completed_with_errors', - total_chunks: 12, - done_chunks: 9, - failed_chunks: 3, - cancelled_chunks: 0, - progress: 75, - files: [ - { - file_id: 'file-error', - name: 'error.txt', - source_kind: 'paste', - input_mode: 'text', - status: 'failed', - current_step: 'failed', - detected_strategy_type: 'auto', - total_chunks: 12, - done_chunks: 9, - failed_chunks: 3, - cancelled_chunks: 0, - progress: 75, - error: 'mock error', - created_at: 1_710_000_000, - updated_at: 1_710_000_100, - }, - ], - } -} - -describe('KnowledgeBasePage import workflow', () => { - beforeEach(() => { - navigateMock.mockReset() - toastMock.mockReset() - - vi.mocked(memoryApi.getMemoryConfigSchema).mockResolvedValue({ - success: true, - path: 'config/a_memorix.toml', - schema: { - plugin_id: 'a_memorix', - plugin_info: { - name: 'A_Memorix', - version: '2.0.0', - description: '长期记忆子系统', - author: 'A_Dawn', - }, - _note: 'raw-only 字段仍可通过 TOML 编辑', - layout: { - type: 'tabs', - tabs: [{ id: 'basic', title: '基础', sections: ['plugin'], order: 1 }], - }, - sections: { - plugin: { - name: 'plugin', - title: '子系统状态', - collapsed: false, - order: 1, - fields: {}, - }, - }, - }, - }) - vi.mocked(memoryApi.getMemoryConfig).mockResolvedValue({ - success: true, - path: 'config/a_memorix.toml', - config: { plugin: { enabled: true } }, - }) - vi.mocked(memoryApi.getMemoryConfigRaw).mockResolvedValue({ - success: true, - path: 'config/a_memorix.toml', - config: '[plugin]\nenabled = true\n', - }) - vi.mocked(memoryApi.getMemoryRuntimeConfig).mockResolvedValue({ - success: true, - config: { plugin: { enabled: true } }, - data_dir: 'data/plugins/a-dawn.a-memorix', - embedding_dimension: 1024, - auto_save: true, - relation_vectors_enabled: false, - runtime_ready: true, - embedding_degraded: false, - embedding_degraded_reason: '', - embedding_degraded_since: null, - embedding_last_check: null, - paragraph_vector_backfill_pending: 2, - paragraph_vector_backfill_running: 0, - paragraph_vector_backfill_failed: 1, - paragraph_vector_backfill_done: 3, - }) - - vi.mocked(memoryApi.getMemoryImportGuide).mockResolvedValue({ - success: true, - content: '# 导入指南\n导入说明', - }) - vi.mocked(memoryApi.getMemoryImportSettings).mockResolvedValue({ - success: true, - settings: { - max_paste_chars: 200_000, - max_file_concurrency: 8, - max_chunk_concurrency: 16, - default_file_concurrency: 2, - default_chunk_concurrency: 4, - poll_interval_ms: 60_000, - maibot_source_db_default: 'data/maibot.db', - }, - }) - vi.mocked(memoryApi.getMemoryImportPathAliases).mockResolvedValue({ - success: true, - path_aliases: { - lpmm: 'data/lpmm', - plugin_data: 'data/plugins/a-dawn.a-memorix', - raw: 'data/raw', - }, - }) - vi.mocked(memoryApi.getMemoryImportTasks).mockResolvedValue({ - success: true, - items: [ - mockImportTask('import-run-1', 'running'), - mockImportTask('import-queued-1', 'queued'), - mockImportTask('import-done-1', 'completed'), - ], - }) - vi.mocked(memoryApi.getMemoryImportTask).mockResolvedValue({ - success: true, - task: mockImportDetail('import-run-1'), - }) - vi.mocked(memoryApi.getMemoryImportTaskChunks).mockImplementation(async (_taskId, fileId, offset = 0) => ({ - success: true, - task_id: 'import-run-1', - file_id: fileId, - offset, - limit: 50, - total: 120, - items: [ - { - chunk_id: `${fileId}-${offset + 0}`, - index: offset + 0, - chunk_type: 'text', - status: 'running', - step: 'extracting', - failed_at: '', - retryable: true, - error: '', - progress: 50, - content_preview: `chunk-preview-${offset + 0}`, - updated_at: 1_710_000_111, - }, - ], - })) - - vi.mocked(memoryApi.createMemoryUploadImport).mockResolvedValue({ - success: true, - task: mockImportTask('upload-task-1', 'queued'), - }) - vi.mocked(memoryApi.createMemoryPasteImport).mockResolvedValue({ - success: true, - task: mockImportTask('paste-task-1', 'queued'), - }) - vi.mocked(memoryApi.createMemoryRawScanImport).mockResolvedValue({ - success: true, - task: mockImportTask('raw-task-1', 'queued'), - }) - vi.mocked(memoryApi.createMemoryLpmmOpenieImport).mockResolvedValue({ - success: true, - task: mockImportTask('openie-task-1', 'queued'), - }) - vi.mocked(memoryApi.createMemoryLpmmConvertImport).mockResolvedValue({ - success: true, - task: mockImportTask('convert-task-1', 'queued'), - }) - vi.mocked(memoryApi.createMemoryTemporalBackfillImport).mockResolvedValue({ - success: true, - task: mockImportTask('backfill-task-1', 'queued'), - }) - vi.mocked(memoryApi.createMemoryMaibotMigrationImport).mockResolvedValue({ - success: true, - task: mockImportTask('migration-task-1', 'queued'), - }) - vi.mocked(memoryApi.cancelMemoryImportTask).mockResolvedValue({ - success: true, - task: mockImportTask('import-run-1', 'cancel_requested'), - }) - vi.mocked(memoryApi.retryMemoryImportTask).mockResolvedValue({ - success: true, - task: mockImportTask('retry-task-1', 'queued'), - }) - vi.mocked(memoryApi.resolveMemoryImportPath).mockResolvedValue({ - success: true, - alias: 'raw', - relative_path: 'exports', - resolved_path: 'D:/Dev/rdev/MaiBot/data/raw/exports', - exists: true, - is_file: false, - is_dir: true, - }) - - vi.mocked(memoryApi.getMemoryTuningProfile).mockResolvedValue({ - success: true, - profile: { retrieval: { top_k: 10 } }, - toml: '[retrieval]\ntop_k = 10\n', - }) - vi.mocked(memoryApi.getMemoryTuningTasks).mockResolvedValue({ - success: true, - items: [{ task_id: 'tune-1', status: 'done' }], - }) - vi.mocked(memoryApi.createMemoryTuningTask).mockResolvedValue({ success: true } as never) - vi.mocked(memoryApi.applyBestMemoryTuningProfile).mockResolvedValue({ success: true } as never) - - vi.mocked(memoryApi.getMemorySources).mockResolvedValue({ - success: true, - items: [{ source: 'demo-1', paragraph_count: 2, relation_count: 1 }], - count: 1, - }) - vi.mocked(memoryApi.getMemoryDeleteOperations).mockResolvedValue({ - success: true, - items: [ - { - operation_id: 'del-1', - mode: 'source', - status: 'executed', - summary: { counts: { paragraphs: 2, relations: 1, sources: 1 } }, - }, - ], - count: 1, - }) - vi.mocked(memoryApi.getMemoryDeleteOperation).mockResolvedValue({ - success: true, - operation: { - operation_id: 'del-1', - mode: 'source', - status: 'executed', - selector: { sources: ['demo-1'] }, - summary: { counts: { paragraphs: 2, relations: 1, sources: 1 }, sources: ['demo-1'] }, - items: [], - }, - }) - vi.mocked(memoryApi.getMemoryFeedbackCorrections).mockResolvedValue({ - success: true, - items: [ - { - task_id: 11, - query_tool_id: 'tool-query-11', - session_id: 'session-1', - query_text: '测试用户最喜欢的颜色是什么', - query_timestamp: 1_710_000_010, - task_status: 'applied', - decision: 'correct', - decision_confidence: 0.97, - feedback_message_count: 1, - rollback_status: 'none', - affected_counts: { - relations: 1, - stale_paragraphs: 1, - episode_sources: 2, - profile_person_ids: 1, - correction_paragraphs: 1, - corrected_relations: 1, - }, - created_at: 1_710_000_011, - updated_at: 1_710_000_012, - }, - ], - count: 1, - }) - vi.mocked(memoryApi.getMemoryFeedbackCorrection).mockResolvedValue({ - success: true, - task: { - task_id: 11, - query_tool_id: 'tool-query-11', - session_id: 'session-1', - query_text: '测试用户最喜欢的颜色是什么', - query_timestamp: 1_710_000_010, - task_status: 'applied', - decision: 'correct', - decision_confidence: 0.97, - feedback_message_count: 1, - rollback_status: 'none', - affected_counts: { - relations: 1, - stale_paragraphs: 1, - episode_sources: 2, - profile_person_ids: 1, - correction_paragraphs: 1, - corrected_relations: 1, - }, - query_snapshot: { query: '测试用户最喜欢的颜色是什么', hits: [{ hash: 'paragraph-1' }] }, - decision_payload: { decision: 'correct', confidence: 0.97 }, - rollback_plan_summary: { - forgotten_relations: [{ hash: 'rel-old', subject: '测试用户', predicate: '最喜欢的颜色是', object: '蓝色' }], - corrected_write: { - paragraph_hashes: ['paragraph-new'], - corrected_relations: [{ hash: 'rel-new', subject: '测试用户', predicate: '最喜欢的颜色是', object: '绿色' }], - }, - }, - rollback_result: {}, - action_logs: [ - { - id: 1, - task_id: 11, - query_tool_id: 'tool-query-11', - action_type: 'forget_relation', - target_hash: 'rel-old', - reason: '用户明确纠正为绿色', - before_payload: { hash: 'rel-old', subject: '测试用户', predicate: '最喜欢的颜色是', object: '蓝色' }, - after_payload: { is_inactive: true }, - created_at: 1_710_000_013, - }, - ], - created_at: 1_710_000_011, - updated_at: 1_710_000_012, - }, - }) - vi.mocked(memoryApi.previewMemoryDelete).mockResolvedValue({ - success: true, - mode: 'source', - selector: { sources: ['demo-1'] }, - counts: { sources: 1, paragraphs: 2, relations: 1 }, - sources: ['demo-1'], - items: [{ item_type: 'paragraph', item_hash: 'p-1', label: 'demo-1' }], - item_count: 1, - dry_run: true, - } as never) - vi.mocked(memoryApi.executeMemoryDelete).mockResolvedValue({ - success: true, - mode: 'source', - operation_id: 'del-2', - counts: { sources: 1, paragraphs: 2, relations: 1 }, - sources: ['demo-1'], - deleted_count: 4, - deleted_entity_count: 0, - deleted_relation_count: 1, - deleted_paragraph_count: 2, - deleted_source_count: 1, - } as never) - vi.mocked(memoryApi.restoreMemoryDelete).mockResolvedValue({ success: true } as never) - vi.mocked(memoryApi.rollbackMemoryFeedbackCorrection).mockResolvedValue({ - success: true, - result: { restored_relation_hashes: ['rel-old'] }, - task: { - task_id: 11, - query_tool_id: 'tool-query-11', - session_id: 'session-1', - query_text: '测试用户最喜欢的颜色是什么', - query_timestamp: 1_710_000_010, - task_status: 'applied', - decision: 'correct', - decision_confidence: 0.97, - feedback_message_count: 1, - rollback_status: 'rolled_back', - affected_counts: { - relations: 1, - stale_paragraphs: 1, - episode_sources: 2, - profile_person_ids: 1, - correction_paragraphs: 1, - corrected_relations: 1, - }, - query_snapshot: { query: '测试用户最喜欢的颜色是什么', hits: [{ hash: 'paragraph-1' }] }, - decision_payload: { decision: 'correct', confidence: 0.97 }, - rollback_plan_summary: {}, - rollback_result: { restored_relation_hashes: ['rel-old'] }, - action_logs: [], - created_at: 1_710_000_011, - updated_at: 1_710_000_012, - }, - }) - vi.mocked(memoryApi.refreshMemoryRuntimeSelfCheck).mockResolvedValue({ - success: true, - report: { ok: true }, - }) - vi.mocked(memoryApi.updateMemoryConfig).mockResolvedValue({ success: true } as never) - vi.mocked(memoryApi.updateMemoryConfigRaw).mockResolvedValue({ success: true } as never) - }) - - it('loads import settings/guide/tasks on first render', async () => { - const user = userEvent.setup() - render() - - expect(await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 })).toBeInTheDocument() - await user.click(screen.getByRole('tab', { name: '导入' })) - - expect(await screen.findByRole('button', { name: '创建导入任务' })).toBeInTheDocument() - expect((await screen.findAllByText('import-run-1')).length).toBeGreaterThan(0) - expect(memoryApi.getMemoryImportSettings).toHaveBeenCalled() - expect(memoryApi.getMemoryImportPathAliases).toHaveBeenCalled() - expect(memoryApi.getMemoryImportTasks).toHaveBeenCalled() - }) - - it('creates import tasks for all 7 modes and calls correct endpoints', async () => { - const user = userEvent.setup() - const { container } = render() - - const openImportTab = async () => { - await user.click(screen.getByRole('tab', { name: '导入' })) - await screen.findByRole('button', { name: '创建导入任务' }) - } - - await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 }) - await openImportTab() - - const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement - const uploadFiles = [ - new File(['hello'], 'demo.txt', { type: 'text/plain' }), - new File(['{"name":"mai"}'], 'demo.json', { type: 'application/json' }), - new File(['a,b\n1,2'], 'demo.csv', { type: 'text/csv' }), - new File(['# note'], 'demo.md', { type: 'text/markdown' }), - ] - await user.upload(fileInput, uploadFiles) - await user.click(screen.getByRole('button', { name: '创建导入任务' })) - await waitFor(() => expect(memoryApi.createMemoryUploadImport).toHaveBeenCalledTimes(1)) - - await openImportTab() - await user.click(screen.getByRole('tab', { name: '粘贴导入' })) - const editableTextarea = Array.from(container.querySelectorAll('textarea')).find((item) => !item.readOnly) - if (!editableTextarea) { - throw new Error('missing editable textarea') - } - await user.type(editableTextarea, 'paste content') - await user.click(screen.getByRole('button', { name: '创建导入任务' })) - await waitFor(() => expect(memoryApi.createMemoryPasteImport).toHaveBeenCalledTimes(1)) - - await openImportTab() - await user.click(screen.getByRole('tab', { name: '本地扫描' })) - await user.click(screen.getByRole('button', { name: '创建导入任务' })) - await waitFor(() => expect(memoryApi.createMemoryRawScanImport).toHaveBeenCalledTimes(1)) - - await openImportTab() - await user.click(screen.getByRole('tab', { name: 'LPMM OpenIE' })) - await user.click(screen.getByRole('button', { name: '创建导入任务' })) - await waitFor(() => expect(memoryApi.createMemoryLpmmOpenieImport).toHaveBeenCalledTimes(1)) - - await openImportTab() - await user.click(screen.getByRole('tab', { name: 'LPMM 转换' })) - await user.click(screen.getByRole('button', { name: '创建导入任务' })) - await waitFor(() => expect(memoryApi.createMemoryLpmmConvertImport).toHaveBeenCalledTimes(1)) - - await openImportTab() - await user.click(screen.getByRole('tab', { name: '时序回填' })) - await user.click(screen.getByRole('button', { name: '创建导入任务' })) - await waitFor(() => expect(memoryApi.createMemoryTemporalBackfillImport).toHaveBeenCalledTimes(1)) - - await openImportTab() - await user.click(screen.getByRole('tab', { name: 'MaiBot 迁移' })) - await user.click(screen.getByRole('button', { name: '创建导入任务' })) - await waitFor(() => expect(memoryApi.createMemoryMaibotMigrationImport).toHaveBeenCalledTimes(1)) - - const [uploadedFiles, uploadPayload] = vi.mocked(memoryApi.createMemoryUploadImport).mock.calls[0] - expect(uploadedFiles).toHaveLength(4) - expect(uploadedFiles.map((file) => file.name)).toEqual(['demo.txt', 'demo.json', 'demo.csv', 'demo.md']) - expect(uploadPayload).toMatchObject({ - input_mode: 'text', - llm_enabled: true, - strategy_override: 'auto', - dedupe_policy: 'content_hash', - }) - }, 60_000) - - it('loads task detail and supports chunk pagination', async () => { - const user = userEvent.setup() - render() - - await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 }) - await user.click(screen.getByRole('tab', { name: '导入' })) - - expect(await screen.findByText('alpha.txt')).toBeInTheDocument() - expect(await screen.findByText('chunk-preview-0')).toBeInTheDocument() - - const betaButton = screen.getByText('beta.txt').closest('button') - if (!betaButton) { - throw new Error('missing file beta button') - } - await user.click(betaButton) - await waitFor(() => - expect(memoryApi.getMemoryImportTaskChunks).toHaveBeenCalledWith('import-run-1', 'file-beta', 0, 50), - ) - - await user.click(screen.getByRole('button', { name: '下一页分块' })) - await waitFor(() => - expect(memoryApi.getMemoryImportTaskChunks).toHaveBeenCalledWith('import-run-1', 'file-beta', 50, 50), - ) - }, 20_000) - - it('shows import failures separately from successful chunks', async () => { - vi.mocked(memoryApi.getMemoryImportTask).mockResolvedValue({ - success: true, - task: mockImportCompletedWithErrorsDetail('import-run-1'), - }) - const user = userEvent.setup() - render() - - await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 }) - await user.click(screen.getByRole('tab', { name: '导入' })) - - expect((await screen.findAllByText('完成(有错误)')).length).toBeGreaterThan(0) - expect(await screen.findByText('成功 9 / 12 分块 · 失败 3')).toBeInTheDocument() - }, 20_000) - - it('supports cancel and retry actions for selected task', async () => { - const user = userEvent.setup() - render() - - await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 }) - await user.click(screen.getByRole('tab', { name: '导入' })) - await screen.findByText('任务详情') - - await user.click(screen.getByRole('button', { name: '取消选中导入任务' })) - await waitFor(() => expect(memoryApi.cancelMemoryImportTask).toHaveBeenCalledWith('import-run-1')) - - await user.click(screen.getByRole('button', { name: '重试选中导入任务' })) - await waitFor(() => expect(memoryApi.retryMemoryImportTask).toHaveBeenCalled()) - const [taskId, retryPayload] = vi.mocked(memoryApi.retryMemoryImportTask).mock.calls[0] - expect(taskId).toBe('import-run-1') - expect(retryPayload).toMatchObject({ - overrides: { - llm_enabled: true, - strategy_override: 'auto', - }, - }) - }, 20_000) - - it('auto polling updates queue and keeps page stable when refresh fails once', async () => { - vi.mocked(memoryApi.getMemoryImportSettings).mockResolvedValue({ - success: true, - settings: { - max_paste_chars: 200_000, - max_file_concurrency: 8, - max_chunk_concurrency: 16, - default_file_concurrency: 2, - default_chunk_concurrency: 4, - poll_interval_ms: 200, - maibot_source_db_default: 'data/maibot.db', - }, - }) - const user = userEvent.setup() - render() - - await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 }) - await user.click(screen.getByRole('tab', { name: '导入' })) - await screen.findByText('导入队列') - - const initialCalls = vi.mocked(memoryApi.getMemoryImportTasks).mock.calls.length - vi.mocked(memoryApi.getMemoryImportTasks).mockRejectedValueOnce(new Error('poll failure')) - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 350)) - }) - - expect(screen.getByText('长期记忆控制台')).toBeInTheDocument() - expect(vi.mocked(memoryApi.getMemoryImportTasks).mock.calls.length).toBeGreaterThan(initialCalls) - }, 20_000) - - it('creates tuning task and applies best profile (tuning module)', async () => { - const user = userEvent.setup() - render() - - await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 }) - await user.click(screen.getByRole('tab', { name: '调优' })) - await screen.findByText('调优任务') - - await user.click(screen.getByRole('button', { name: '创建调优任务' })) - await waitFor(() => - expect(memoryApi.createMemoryTuningTask).toHaveBeenCalledWith({ - objective: 'precision_priority', - intensity: 'standard', - sample_size: 24, - top_k_eval: 20, - }), - ) - - await user.click(screen.getByRole('button', { name: '应用最佳' })) - await waitFor(() => expect(memoryApi.applyBestMemoryTuningProfile).toHaveBeenCalledWith('tune-1')) - }, 20_000) - - it('previews executes and restores source delete (delete module)', async () => { - const user = userEvent.setup() - render() - - await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 }) - await user.click(screen.getByRole('tab', { name: '删除' })) - await screen.findByText('来源批量删除') - - const sourceCellCandidates = await screen.findAllByText('demo-1') - const sourceRow = sourceCellCandidates - .map((item) => item.closest('tr')) - .find((row): row is HTMLTableRowElement => Boolean(row && within(row).queryByRole('checkbox'))) - if (!sourceRow) { - throw new Error('missing source row') - } - await user.click(within(sourceRow).getByRole('checkbox')) - - await user.click(screen.getByRole('button', { name: '预览删除' })) - await waitFor(() => - expect(memoryApi.previewMemoryDelete).toHaveBeenCalledWith({ - mode: 'source', - selector: { sources: ['demo-1'] }, - reason: 'knowledge_base_source_delete', - requested_by: 'knowledge_base', - }), - ) - - const dialog = await screen.findByTestId('memory-delete-dialog') - expect(dialog).toHaveTextContent('preview:source:1') - - await user.click(screen.getByRole('button', { name: '执行删除' })) - await waitFor(() => - expect(memoryApi.executeMemoryDelete).toHaveBeenCalledWith({ - mode: 'source', - selector: { sources: ['demo-1'] }, - reason: 'knowledge_base_source_delete', - requested_by: 'knowledge_base', - }), - ) - - await user.click(screen.getByRole('button', { name: '执行恢复' })) - await waitFor(() => - expect(memoryApi.restoreMemoryDelete).toHaveBeenCalledWith({ - operation_id: 'del-2', - requested_by: 'knowledge_base', - }), - ) - }, 20_000) - - it('shows feedback correction history and supports rollback', async () => { - const user = userEvent.setup() - render() - - await screen.findByText('长期记忆控制台', undefined, { timeout: 10_000 }) - await user.click(screen.getByRole('tab', { name: '纠错历史' })) - await screen.findByText('反馈纠错历史') - await screen.findByText('测试用户最喜欢的颜色是什么') - await waitFor(() => expect(memoryApi.getMemoryFeedbackCorrection).toHaveBeenCalledWith(11)) - - await user.click(screen.getByRole('button', { name: '回退本次纠错' })) - const rollbackReason = await screen.findByLabelText('回退原因') - await user.type(rollbackReason, '人工确认回退') - await user.click(screen.getByRole('button', { name: '确认回退' })) - - await waitFor(() => - expect(memoryApi.rollbackMemoryFeedbackCorrection).toHaveBeenCalledWith(11, { - requested_by: 'knowledge_base', - reason: '人工确认回退', - }), - ) - }, 20_000) -}) diff --git a/dashboard/src/routes/resource/__tests__/knowledge-graph.test.tsx b/dashboard/src/routes/resource/__tests__/knowledge-graph.test.tsx deleted file mode 100644 index 5445c7cb..00000000 --- a/dashboard/src/routes/resource/__tests__/knowledge-graph.test.tsx +++ /dev/null @@ -1,440 +0,0 @@ -import { render, screen, waitFor } from '@testing-library/react' -import userEvent from '@testing-library/user-event' -import { beforeEach, describe, expect, it, vi } from 'vitest' - -import { KnowledgeGraphPage } from '../knowledge-graph' -import * as memoryApi from '@/lib/memory-api' - -const navigateMock = vi.fn() -const toastMock = vi.fn() - -vi.mock('@tanstack/react-router', () => ({ - useNavigate: () => navigateMock, -})) - -vi.mock('@/hooks/use-toast', () => ({ - useToast: () => ({ toast: toastMock }), -})) - -vi.mock('@/components/memory/MemoryDeleteDialog', () => ({ - MemoryDeleteDialog: ({ - open, - preview, - }: { - open: boolean - preview?: { mode?: string; item_count?: number } | null - }) => ( - open ?
{`delete:${preview?.mode ?? 'none'}:${preview?.item_count ?? 0}`}
: null - ), -})) - -vi.mock('../knowledge-graph/GraphVisualization', () => ({ - GraphVisualization: ({ - graphData, - onNodeClick, - onEdgeClick, - }: { - graphData: { nodes: Array<{ id: string }>; edges: Array<{ source: string; target: string }> } - onNodeClick: (event: React.MouseEvent, node: { id: string }) => void - onEdgeClick: (event: React.MouseEvent, edge: { source: string; target: string }) => void - }) => ( -
-
{`nodes:${graphData.nodes.length},edges:${graphData.edges.length}`}
- {graphData.nodes[0] ? ( - - ) : null} - {graphData.edges[0] ? ( - - ) : null} -
- ), -})) - -vi.mock('../knowledge-graph/GraphDialogs', () => ({ - NodeDetailDialog: ({ - selectedNodeData, - nodeDetail, - onOpenEvidence, - onDeleteEntity, - }: { - selectedNodeData: { id: string } | null - nodeDetail: { relations?: Array<{ predicate: string }>; paragraphs?: Array } | null - onOpenEvidence?: () => void - onDeleteEntity?: (options: { includeParagraphs: boolean }) => void - }) => ( - selectedNodeData ? ( -
-
{`node:${selectedNodeData.id}`}
-
{`relations:${nodeDetail?.relations?.[0]?.predicate ?? 'none'}`}
-
{`paragraphs:${nodeDetail?.paragraphs?.length ?? 0}`}
- - -
- ) : null - ), - EdgeDetailDialog: ({ - selectedEdgeData, - edgeDetail, - onOpenEvidence, - }: { - selectedEdgeData: { source: { id: string }; target: { id: string } } | null - edgeDetail: { edge?: { predicates?: string[] }; paragraphs?: Array } | null - onOpenEvidence?: () => void - }) => ( - selectedEdgeData ? ( -
-
{`edge:${selectedEdgeData.source.id}->${selectedEdgeData.target.id}`}
-
{`predicates:${edgeDetail?.edge?.predicates?.join(',') ?? 'none'}`}
-
{`paragraphs:${edgeDetail?.paragraphs?.length ?? 0}`}
- -
- ) : null - ), - RelationDetailDialog: () => null, - ParagraphDetailDialog: () => null, -})) - -vi.mock('@/lib/memory-api', () => ({ - getMemoryGraph: vi.fn(), - getMemoryGraphSearch: vi.fn(), - getMemoryGraphNodeDetail: vi.fn(), - getMemoryGraphEdgeDetail: vi.fn(), - previewMemoryDelete: vi.fn(), - executeMemoryDelete: vi.fn(), - restoreMemoryDelete: vi.fn(), -})) - -describe('KnowledgeGraphPage', () => { - beforeEach(() => { - navigateMock.mockReset() - toastMock.mockReset() - vi.mocked(memoryApi.getMemoryGraph).mockResolvedValue({ - success: true, - nodes: [ - { id: 'alpha', name: 'Alpha' }, - { id: 'beta', name: 'Beta' }, - ], - edges: [ - { - source: 'alpha', - target: 'beta', - weight: 1, - predicates: ['关联'], - relation_count: 1, - evidence_count: 2, - relation_hashes: ['rel-1'], - label: '关联', - }, - ], - total_nodes: 2, - total_edges: 1, - }) - vi.mocked(memoryApi.getMemoryGraphSearch).mockResolvedValue({ - success: true, - query: 'alpha', - limit: 50, - count: 0, - items: [], - }) - vi.mocked(memoryApi.getMemoryGraphNodeDetail).mockResolvedValue({ - success: true, - node: { id: 'alpha', type: 'entity', content: 'Alpha', hash: 'entity-1', appearance_count: 3 }, - relations: [ - { - hash: 'rel-1', - subject: 'alpha', - predicate: '关联', - object: 'beta', - text: 'alpha 关联 beta', - confidence: 0.9, - paragraph_count: 1, - paragraph_hashes: ['p-1'], - source_paragraph: 'p-1', - }, - ], - paragraphs: [ - { - hash: 'p-1', - content: 'Alpha 提到了 Beta', - preview: 'Alpha 提到了 Beta', - source: 'demo', - entity_count: 2, - relation_count: 1, - entities: ['Alpha', 'Beta'], - relations: ['alpha 关联 beta'], - }, - ], - evidence_graph: { - nodes: [ - { id: 'entity:alpha', type: 'entity', content: 'Alpha' }, - { id: 'relation:rel-1', type: 'relation', content: 'alpha 关联 beta' }, - { id: 'paragraph:p-1', type: 'paragraph', content: 'Alpha 提到了 Beta' }, - ], - edges: [ - { source: 'paragraph:p-1', target: 'entity:alpha', kind: 'mentions', label: '提及', weight: 1 }, - { source: 'paragraph:p-1', target: 'relation:rel-1', kind: 'supports', label: '支撑', weight: 1 }, - ], - focus_entities: ['alpha'], - }, - }) - vi.mocked(memoryApi.getMemoryGraphEdgeDetail).mockResolvedValue({ - success: true, - edge: { - source: 'alpha', - target: 'beta', - weight: 1, - predicates: ['关联'], - relation_count: 1, - evidence_count: 1, - relation_hashes: ['rel-1'], - label: '关联', - }, - relations: [ - { - hash: 'rel-1', - subject: 'alpha', - predicate: '关联', - object: 'beta', - text: 'alpha 关联 beta', - confidence: 0.9, - paragraph_count: 1, - paragraph_hashes: ['p-1'], - source_paragraph: 'p-1', - }, - ], - paragraphs: [ - { - hash: 'p-1', - content: 'Alpha 提到了 Beta', - preview: 'Alpha 提到了 Beta', - source: 'demo', - entity_count: 2, - relation_count: 1, - entities: ['Alpha', 'Beta'], - relations: ['alpha 关联 beta'], - }, - ], - evidence_graph: { - nodes: [ - { id: 'entity:alpha', type: 'entity', content: 'Alpha' }, - { id: 'entity:beta', type: 'entity', content: 'Beta' }, - { id: 'relation:rel-1', type: 'relation', content: 'alpha 关联 beta' }, - ], - edges: [ - { source: 'relation:rel-1', target: 'entity:alpha', kind: 'subject', label: '主语', weight: 1 }, - { source: 'relation:rel-1', target: 'entity:beta', kind: 'object', label: '宾语', weight: 1 }, - ], - focus_entities: ['alpha', 'beta'], - }, - }) - vi.mocked(memoryApi.previewMemoryDelete).mockResolvedValue({ - success: true, - mode: 'mixed', - selector: { entity_hashes: ['entity-1'] }, - counts: { entities: 1, relations: 1, paragraphs: 1 }, - sources: ['demo'], - items: [{ item_type: 'entity', item_hash: 'entity-1', label: 'Alpha' }], - item_count: 1, - dry_run: true, - } as never) - vi.mocked(memoryApi.executeMemoryDelete).mockResolvedValue({ - success: true, - mode: 'mixed', - operation_id: 'del-1', - counts: { entities: 1, relations: 1, paragraphs: 1 }, - sources: ['demo'], - deleted_count: 3, - deleted_entity_count: 1, - deleted_relation_count: 1, - deleted_paragraph_count: 1, - deleted_source_count: 0, - } as never) - vi.mocked(memoryApi.restoreMemoryDelete).mockResolvedValue({ success: true } as never) - }) - - it('calls backend graph search and renders no-hit state', async () => { - const user = userEvent.setup() - - render() - - expect(await screen.findByText('长期记忆图谱')).toBeInTheDocument() - expect(screen.getByText(/总节点 2/)).toBeInTheDocument() - expect(screen.getByTestId('graph-visualization')).toHaveTextContent('nodes:2,edges:1') - - await user.type(screen.getByPlaceholderText('搜索实体、关系、hash(后端全库)'), 'missing') - expect(memoryApi.getMemoryGraph).toHaveBeenCalledTimes(1) - await user.click(screen.getByRole('button', { name: '搜索' })) - - await waitFor(() => { - expect(memoryApi.getMemoryGraphSearch).toHaveBeenCalledWith('missing', 50) - }) - expect(await screen.findByText('未命中实体或关系。')).toBeInTheDocument() - }) - - it('supports clicking entity search result to locate evidence', async () => { - const user = userEvent.setup() - vi.mocked(memoryApi.getMemoryGraphSearch).mockResolvedValue({ - success: true, - query: 'alpha', - limit: 50, - count: 1, - items: [ - { - type: 'entity', - title: 'Alpha', - matched_field: 'name', - matched_value: 'Alpha', - entity_name: 'alpha', - entity_hash: 'entity-1', - appearance_count: 3, - }, - ], - }) - - render() - - await screen.findByTestId('graph-visualization') - await user.type(screen.getByPlaceholderText('搜索实体、关系、hash(后端全库)'), 'alpha') - await user.click(screen.getByRole('button', { name: '搜索' })) - - await screen.findByText('搜索词:alpha') - await user.click(screen.getByRole('button', { name: /Alpha/ })) - - await waitFor(() => { - expect(memoryApi.getMemoryGraphNodeDetail).toHaveBeenCalledWith('alpha') - }) - expect(screen.getByRole('tab', { name: '证据视图' })).toHaveAttribute('data-state', 'active') - }) - - it('supports clicking relation search result to locate evidence', async () => { - const user = userEvent.setup() - vi.mocked(memoryApi.getMemoryGraphSearch).mockResolvedValue({ - success: true, - query: '关联', - limit: 50, - count: 1, - items: [ - { - type: 'relation', - title: 'alpha 关联 beta', - matched_field: 'predicate', - matched_value: '关联', - subject: 'alpha', - predicate: '关联', - object: 'beta', - relation_hash: 'rel-1', - confidence: 0.9, - }, - ], - }) - - render() - - await screen.findByTestId('graph-visualization') - await user.type(screen.getByPlaceholderText('搜索实体、关系、hash(后端全库)'), '关联') - await user.click(screen.getByRole('button', { name: '搜索' })) - await user.click(screen.getByRole('button', { name: /alpha 关联 beta/ })) - - await waitFor(() => { - expect(memoryApi.getMemoryGraphEdgeDetail).toHaveBeenCalledWith('alpha', 'beta') - }) - expect(screen.getByRole('tab', { name: '证据视图' })).toHaveAttribute('data-state', 'active') - }) - - it('falls back to local filtering when backend search fails', async () => { - const user = userEvent.setup() - vi.mocked(memoryApi.getMemoryGraphSearch).mockRejectedValue(new Error('search unavailable')) - - render() - - await screen.findByTestId('graph-visualization') - await user.type(screen.getByPlaceholderText('搜索实体、关系、hash(后端全库)'), 'missing') - await user.click(screen.getByRole('button', { name: '搜索' })) - - expect(await screen.findByText('还没有可展示的长期记忆图谱')).toBeInTheDocument() - expect(toastMock).toHaveBeenCalledWith( - expect.objectContaining({ - title: '后端检索失败,已回退本地筛选', - }), - ) - }) - - it('shows empty state when switching to evidence view without a selection', async () => { - const user = userEvent.setup() - - render() - - expect(await screen.findByTestId('graph-visualization')).toBeInTheDocument() - await user.click(screen.getByRole('tab', { name: '证据视图' })) - - expect(await screen.findByText('证据视图还没有可展示的选择')).toBeInTheDocument() - }) - - it('closes node dialog when switching to evidence view and renders evidence graph', async () => { - const user = userEvent.setup() - - render() - - await screen.findByTestId('graph-visualization') - await user.click(screen.getByRole('button', { name: '选择节点' })) - - expect(await screen.findByTestId('node-detail-dialog')).toHaveTextContent('relations:关联') - expect(screen.getByTestId('node-detail-dialog')).toHaveTextContent('paragraphs:1') - - await user.click(screen.getByRole('button', { name: '切到证据视图' })) - - await waitFor(() => { - expect(screen.queryByTestId('node-detail-dialog')).not.toBeInTheDocument() - }) - - await waitFor(() => { - expect(screen.getByTestId('graph-visualization')).toHaveTextContent('nodes:3,edges:2') - }) - }) - - it('loads edge detail with predicates and support paragraphs', async () => { - const user = userEvent.setup() - - render() - - await screen.findByTestId('graph-visualization') - await user.click(screen.getByRole('button', { name: '选择边' })) - - expect(await screen.findByTestId('edge-detail-dialog')).toHaveTextContent('predicates:关联') - expect(screen.getByTestId('edge-detail-dialog')).toHaveTextContent('paragraphs:1') - - await user.click(screen.getByRole('button', { name: '切到证据视图' })) - - await waitFor(() => { - expect(screen.queryByTestId('edge-detail-dialog')).not.toBeInTheDocument() - }) - }) - - it('opens delete preview dialog from node detail', async () => { - const user = userEvent.setup() - - render() - - await screen.findByTestId('graph-visualization') - await user.click(screen.getByRole('button', { name: '选择节点' })) - await screen.findByTestId('node-detail-dialog') - - await user.click(screen.getByRole('button', { name: '删除实体' })) - - await waitFor(() => { - expect(memoryApi.previewMemoryDelete).toHaveBeenCalled() - }) - expect(await screen.findByTestId('memory-delete-dialog')).toHaveTextContent('delete:mixed:1') - }) -}) diff --git a/dashboard/tsconfig.vitest.json b/dashboard/tsconfig.vitest.json deleted file mode 100644 index 3bc1dd81..00000000 --- a/dashboard/tsconfig.vitest.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "extends": "./tsconfig.app.json", - "compilerOptions": { - "types": ["vite/client", "vitest/globals", "@testing-library/jest-dom"] - }, - "include": ["src"] -} diff --git a/dashboard/vitest.config.ts b/dashboard/vitest.config.ts deleted file mode 100644 index 5770520a..00000000 --- a/dashboard/vitest.config.ts +++ /dev/null @@ -1,18 +0,0 @@ -/// -import { defineConfig } from 'vite' -import react from '@vitejs/plugin-react' -import path from 'path' - -export default defineConfig({ - plugins: [react()], - test: { - globals: true, - environment: 'jsdom', - setupFiles: './src/test/setup.ts', - }, - resolve: { - alias: { - '@': path.resolve(__dirname, './src'), - }, - }, -}) diff --git a/pytests/A_memorix_test/test_chat_summary_writeback_integration.py b/pytests/A_memorix_test/test_chat_summary_writeback_integration.py deleted file mode 100644 index 7618bca7..00000000 --- a/pytests/A_memorix_test/test_chat_summary_writeback_integration.py +++ /dev/null @@ -1,398 +0,0 @@ -from __future__ import annotations - -from datetime import datetime, timedelta -from pathlib import Path -from types import SimpleNamespace -from typing import Any, Callable, Dict, List - -import asyncio -import inspect -import json -import pickle - -from sqlalchemy.orm import sessionmaker -from sqlmodel import Session, create_engine -import numpy as np -import pytest - -IMPORT_ERROR: str | None = None - -try: - from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module - from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel - from src.A_memorix.core.utils import summary_importer as summary_importer_module - from src.chat.message_receive.chat_manager import BotChatSession - from src.chat.message_receive.message import SessionMessage - from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo - from src.common.data_models.message_component_data_model import MessageSequence, TextComponent - from src.common.database import database as database_module - from src.common.database.migrations import create_database_migration_bootstrapper - from src.common.message_repository import count_messages - from src.config.model_configs import TaskConfig - from src.services import memory_flow_service as memory_flow_service_module - from src.services import memory_service as memory_service_module - from src.services import send_service -except SystemExit as exc: - IMPORT_ERROR = f"config initialization exited during import: {exc}" - kernel_module = None # type: ignore[assignment] - SDKMemoryKernel = None # type: ignore[assignment] - summary_importer_module = None # type: ignore[assignment] - BotChatSession = None # type: ignore[assignment] - SessionMessage = None # type: ignore[assignment] - MessageInfo = None # type: ignore[assignment] - UserInfo = None # type: ignore[assignment] - MessageSequence = None # type: ignore[assignment] - TextComponent = None # type: ignore[assignment] - database_module = None # type: ignore[assignment] - create_database_migration_bootstrapper = None # type: ignore[assignment] - count_messages = None # type: ignore[assignment] - TaskConfig = None # type: ignore[assignment] - memory_flow_service_module = None # type: ignore[assignment] - memory_service_module = None # type: ignore[assignment] - send_service = None # type: ignore[assignment] - - -pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "") - - -class _FakeEmbeddingManager: - def __init__(self, dimension: int = 8) -> None: - self.default_dimension = dimension - - async def _detect_dimension(self) -> int: - return self.default_dimension - - async def encode(self, text: Any) -> np.ndarray: - def _encode_one(raw: Any) -> np.ndarray: - content = str(raw or "") - vector = np.zeros(self.default_dimension, dtype=np.float32) - for index, byte in enumerate(content.encode("utf-8")): - vector[index % self.default_dimension] += float((byte % 17) + 1) - norm = float(np.linalg.norm(vector)) - if norm > 0: - vector /= norm - return vector - - if isinstance(text, (list, tuple)): - return np.stack([_encode_one(item) for item in text]).astype(np.float32) - return _encode_one(text).astype(np.float32) - - -class _KernelBackedRuntimeManager: - def __init__(self, kernel: SDKMemoryKernel) -> None: - self.kernel = kernel - - async def invoke( - self, - component_name: str, - args: Dict[str, Any] | None, - *, - timeout_ms: int = 30000, - ) -> Any: - del timeout_ms - payload = args or {} - handler = getattr(self.kernel, component_name) - result = handler(**payload) - return await result if inspect.isawaitable(result) else result - - -class _NoopRuntimeManager: - async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any: - del hook_name - return SimpleNamespace(aborted=False, kwargs=kwargs) - - -class _FakePlatformIOManager: - def __init__(self) -> None: - self.ensure_calls = 0 - - async def ensure_send_pipeline_ready(self) -> None: - self.ensure_calls += 1 - - def build_route_key_from_message(self, message: Any) -> Any: - del message - return SimpleNamespace(platform="qq") - - async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any: - del message, metadata - return SimpleNamespace( - has_success=True, - sent_receipts=[ - SimpleNamespace( - driver_id="plugin.qq.sender", - external_message_id="real-message-id", - metadata={}, - ) - ], - failed_receipts=[], - route_key=route_key, - ) - - -def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - db_dir = (tmp_path / "main_db").resolve() - db_dir.mkdir(parents=True, exist_ok=True) - db_file = db_dir / "MaiBot.db" - database_url = f"sqlite:///{db_file}" - - try: - database_module.engine.dispose() - except Exception: - pass - - engine = create_engine( - database_url, - echo=False, - connect_args={"check_same_thread": False}, - pool_pre_ping=True, - ) - session_local = sessionmaker( - autocommit=False, - autoflush=False, - bind=engine, - class_=Session, - ) - bootstrapper = create_database_migration_bootstrapper(engine) - - monkeypatch.setattr(database_module, "_DB_DIR", db_dir, raising=False) - monkeypatch.setattr(database_module, "_DB_FILE", db_file, raising=False) - monkeypatch.setattr(database_module, "DATABASE_URL", database_url, raising=False) - monkeypatch.setattr(database_module, "engine", engine, raising=False) - monkeypatch.setattr(database_module, "SessionLocal", session_local, raising=False) - monkeypatch.setattr(database_module, "_migration_bootstrapper", bootstrapper, raising=False) - monkeypatch.setattr(database_module, "_db_initialized", False, raising=False) - - -def _build_incoming_message( - *, - session_id: str, - user_id: str, - text: str, - timestamp: datetime | None = None, -) -> SessionMessage: - message = SessionMessage( - message_id="incoming-message-id", - timestamp=timestamp or datetime.now(), - platform="qq", - ) - message.message_info = MessageInfo( - user_info=UserInfo( - user_id=user_id, - user_nickname="测试用户", - user_cardname="测试用户", - ), - additional_config={}, - ) - message.raw_message = MessageSequence(components=[TextComponent(text=text)]) - message.session_id = session_id - message.reply_to = None - message.is_mentioned = False - message.is_at = False - message.is_emoji = False - message.is_picture = False - message.is_command = False - message.is_notify = False - message.processed_plain_text = text - message.initialized = True - return message - - -async def _wait_until( - predicate: Callable[[], Any], - *, - timeout_seconds: float = 10.0, - interval_seconds: float = 0.05, - description: str, -) -> Any: - deadline = asyncio.get_running_loop().time() + max(0.5, float(timeout_seconds)) - while asyncio.get_running_loop().time() < deadline: - value = predicate() - if inspect.isawaitable(value): - value = await value - if value: - return value - await asyncio.sleep(interval_seconds) - raise AssertionError(f"等待超时: {description}") - - -@pytest.mark.asyncio -async def test_text_to_stream_triggers_real_chat_summary_writeback( - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - _install_temp_main_database(monkeypatch, tmp_path) - - fake_embedding_manager = _FakeEmbeddingManager() - captured_prompts: List[str] = [] - fixed_send_timestamp = 1_777_000_000.0 - - async def _fake_runtime_self_check(**kwargs: Any) -> Dict[str, Any]: - del kwargs - return { - "ok": True, - "message": "ok", - "configured_dimension": fake_embedding_manager.default_dimension, - "requested_dimension": fake_embedding_manager.default_dimension, - "vector_store_dimension": fake_embedding_manager.default_dimension, - "detected_dimension": fake_embedding_manager.default_dimension, - "encoded_dimension": fake_embedding_manager.default_dimension, - "elapsed_ms": 0.0, - "sample_text": "test", - "checked_at": datetime.now().timestamp(), - } - - async def _fake_generate(request: Any) -> Any: - captured_prompts.append(str(getattr(request, "prompt", "") or "")) - return SimpleNamespace( - success=True, - completion=SimpleNamespace( - response=json.dumps( - { - "summary": "这段对话记录了用户提到自己买了绿色围巾,机器人表示会记住这件事。", - "entities": ["绿色围巾"], - "relations": [], - }, - ensure_ascii=False, - ) - ), - ) - - monkeypatch.setattr( - kernel_module, - "create_embedding_api_adapter", - lambda **kwargs: fake_embedding_manager, - ) - monkeypatch.setattr( - kernel_module, - "run_embedding_runtime_self_check", - _fake_runtime_self_check, - ) - monkeypatch.setattr( - summary_importer_module, - "run_embedding_runtime_self_check", - _fake_runtime_self_check, - ) - monkeypatch.setattr( - summary_importer_module.llm_api, - "get_available_models", - lambda: {"utils": TaskConfig(model_list=["fake-summary-model"])}, - ) - monkeypatch.setattr( - summary_importer_module.llm_api, - "resolve_task_name_from_model_config", - lambda model_config: "utils", - ) - monkeypatch.setattr( - summary_importer_module.llm_api, - "generate", - _fake_generate, - ) - monkeypatch.setattr(send_service.time, "time", lambda: fixed_send_timestamp) - monkeypatch.setattr(summary_importer_module.time, "time", lambda: fixed_send_timestamp) - - kernel = SDKMemoryKernel( - plugin_root=tmp_path / "plugin_root", - config={ - "storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())}, - "advanced": {"enable_auto_save": False}, - "embedding": {"dimension": fake_embedding_manager.default_dimension}, - "memory": {"base_decay_interval_hours": 24}, - "person_profile": {"refresh_interval_minutes": 5}, - "summarization": {"model_name": ["utils"]}, - }, - ) - - service = memory_flow_service_module.MemoryAutomationService() - fake_platform_io_manager = _FakePlatformIOManager() - - async def _fake_rebuild_episodes_for_sources(sources: List[str]) -> Dict[str, Any]: - return { - "rebuilt": 0, - "items": [], - "failures": [], - "sources": list(sources), - } - - monkeypatch.setattr(kernel, "rebuild_episodes_for_sources", _fake_rebuild_episodes_for_sources) - monkeypatch.setattr( - memory_service_module, - "a_memorix_host_service", - _KernelBackedRuntimeManager(kernel), - ) - monkeypatch.setattr(memory_flow_service_module, "memory_automation_service", service) - monkeypatch.setattr(send_service, "_get_runtime_manager", lambda: _NoopRuntimeManager()) - monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_platform_io_manager) - monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") - monkeypatch.setattr( - send_service._chat_manager, - "get_session_by_session_id", - lambda stream_id: ( - BotChatSession( - session_id="test-session", - platform="qq", - user_id="target-user", - group_id=None, - ) - if stream_id == "test-session" - else None - ), - ) - integration_config = memory_flow_service_module.global_config.a_memorix.integration - monkeypatch.setattr(integration_config, "chat_summary_writeback_enabled", True, raising=False) - monkeypatch.setattr(integration_config, "chat_summary_writeback_message_threshold", 2, raising=False) - monkeypatch.setattr(integration_config, "chat_summary_writeback_context_length", 10, raising=False) - monkeypatch.setattr(integration_config, "person_fact_writeback_enabled", False, raising=False) - - await kernel.initialize() - - try: - incoming_message = _build_incoming_message( - session_id="test-session", - user_id="target-user", - text="我最近买了一条绿色围巾。", - timestamp=datetime.fromtimestamp(fixed_send_timestamp) - timedelta(seconds=1), - ) - with database_module.get_db_session() as session: - session.add(incoming_message.to_db_instance()) - - sent_message = await send_service.text_to_stream_with_message( - text="好的,我会记住你最近买了绿色围巾。", - stream_id="test-session", - storage_message=True, - ) - - assert sent_message is not None - assert sent_message.message_id == "real-message-id" - assert fake_platform_io_manager.ensure_calls == 1 - assert count_messages(session_id="test-session") == 2 - - paragraphs = await _wait_until( - lambda: kernel.metadata_store.get_paragraphs_by_source("chat_summary:test-session"), - description="等待聊天摘要写回到 A_memorix", - ) - - assert captured_prompts - assert "我最近买了一条绿色围巾。" in captured_prompts[-1] - assert "好的,我会记住你最近买了绿色围巾。" in captured_prompts[-1] - assert any("绿色围巾" in str(item.get("content", "") or "") for item in paragraphs) - assert any( - int( - ( - pickle.loads(item.get("metadata")) - if isinstance(item.get("metadata"), (bytes, bytearray)) - else item.get("metadata") - or {} - ).get("trigger_message_count", 0) - or 0 - ) - == 2 - for item in paragraphs - ) - assert service.chat_summary_writeback._states["test-session"].last_trigger_message_count == 2 - finally: - await service.shutdown() - await kernel.shutdown() - try: - database_module.engine.dispose() - except Exception: - pass diff --git a/pytests/A_memorix_test/test_embedding_dimension_control.py b/pytests/A_memorix_test/test_embedding_dimension_control.py deleted file mode 100644 index 4716fe68..00000000 --- a/pytests/A_memorix_test/test_embedding_dimension_control.py +++ /dev/null @@ -1,191 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace - -import numpy as np -import pytest - -from A_memorix.core.embedding import api_adapter as api_adapter_module -from A_memorix.core.embedding.api_adapter import EmbeddingAPIAdapter -from A_memorix.core.utils.runtime_self_check import run_embedding_runtime_self_check - - -class _FakeEmbeddingClient: - def __init__(self, *, natural_dimension: int = 12) -> None: - self.natural_dimension = int(natural_dimension) - self.requests = [] - - async def get_embedding(self, request): - self.requests.append(request) - requested_dimension = request.extra_params.get("dimensions") - if requested_dimension is None: - requested_dimension = request.extra_params.get("output_dimensionality") - dimension = int(requested_dimension or self.natural_dimension) - return SimpleNamespace(embedding=[1.0] * dimension) - - -def _build_adapter( - monkeypatch: pytest.MonkeyPatch, - *, - client_type: str, - configured_dimension: int = 1024, - effective_dimension: int | None = None, - model_extra_params: dict | None = None, -): - adapter = EmbeddingAPIAdapter(default_dimension=configured_dimension) - if effective_dimension is not None: - adapter._dimension = int(effective_dimension) - adapter._dimension_detected = True - - fake_client = _FakeEmbeddingClient() - model_info = SimpleNamespace( - name="embedding-model", - api_provider="provider-1", - model_identifier="embedding-model-id", - extra_params=dict(model_extra_params or {}), - ) - provider = SimpleNamespace(name="provider-1", client_type=client_type) - - monkeypatch.setattr(adapter, "_resolve_candidate_model_names", lambda: ["embedding-model"]) - monkeypatch.setattr(adapter, "_find_model_info", lambda model_name: model_info) - monkeypatch.setattr(adapter, "_find_provider", lambda provider_name: provider) - monkeypatch.setattr( - api_adapter_module.client_registry, - "get_client_class_instance", - lambda api_provider, force_new=True: fake_client, - ) - return adapter, fake_client - - -@pytest.mark.asyncio -async def test_encode_uses_canonical_dimension_for_openai_provider(monkeypatch): - adapter, fake_client = _build_adapter( - monkeypatch, - client_type="openai", - configured_dimension=1024, - effective_dimension=1024, - model_extra_params={"task_type": "SEMANTIC_SIMILARITY"}, - ) - - embedding = await adapter.encode("北塔木梯") - - request = fake_client.requests[-1] - assert request.extra_params["dimensions"] == 1024 - assert "output_dimensionality" not in request.extra_params - assert request.extra_params["task_type"] == "SEMANTIC_SIMILARITY" - assert embedding.shape == (1024,) - - -@pytest.mark.asyncio -async def test_encode_explicit_dimension_override_wins(monkeypatch): - adapter, fake_client = _build_adapter( - monkeypatch, - client_type="openai", - configured_dimension=1024, - effective_dimension=1024, - ) - - embedding = await adapter.encode("海潮图", dimensions=256) - - request = fake_client.requests[-1] - assert request.extra_params["dimensions"] == 256 - assert "output_dimensionality" not in request.extra_params - assert embedding.shape == (256,) - - -@pytest.mark.asyncio -async def test_encode_maps_dimension_to_gemini_output_dimensionality(monkeypatch): - adapter, fake_client = _build_adapter( - monkeypatch, - client_type="gemini", - configured_dimension=1024, - effective_dimension=768, - ) - - embedding = await adapter.encode("广播站") - - request = fake_client.requests[-1] - assert request.extra_params["output_dimensionality"] == 768 - assert "dimensions" not in request.extra_params - assert embedding.shape == (768,) - - -@pytest.mark.asyncio -async def test_encode_does_not_force_dimension_for_unsupported_provider(monkeypatch): - adapter, fake_client = _build_adapter( - monkeypatch, - client_type="custom", - configured_dimension=1024, - effective_dimension=640, - model_extra_params={ - "dimensions": 999, - "output_dimensionality": 888, - "custom_flag": "keep-me", - }, - ) - - embedding = await adapter.encode("蓝漆铁盒") - - request = fake_client.requests[-1] - assert "dimensions" not in request.extra_params - assert "output_dimensionality" not in request.extra_params - assert request.extra_params["custom_flag"] == "keep-me" - assert embedding.shape == (fake_client.natural_dimension,) - - -@pytest.mark.asyncio -async def test_runtime_self_check_reports_requested_dimension_without_explicit_override(): - class _FakeEmbeddingManager: - def __init__(self) -> None: - self.detected_dimension = 384 - self.encode_calls = [] - - async def _detect_dimension(self) -> int: - return self.detected_dimension - - def get_requested_dimension(self) -> int: - return self.detected_dimension - - async def encode(self, text): - self.encode_calls.append(text) - return np.ones(self.detected_dimension, dtype=np.float32) - - manager = _FakeEmbeddingManager() - - report = await run_embedding_runtime_self_check( - config={"embedding": {"dimension": 1024}}, - vector_store=SimpleNamespace(dimension=384), - embedding_manager=manager, - ) - - assert report["ok"] is True - assert report["configured_dimension"] == 1024 - assert report["requested_dimension"] == 384 - assert report["detected_dimension"] == 384 - assert report["encoded_dimension"] == 384 - assert manager.encode_calls == ["A_Memorix runtime self check"] - - -@pytest.mark.asyncio -async def test_encode_batch_keeps_batch_local_indexes_when_cache_hits_previous_batch(monkeypatch): - adapter = EmbeddingAPIAdapter(default_dimension=4, enable_cache=True) - adapter._dimension = 4 - adapter._dimension_detected = True - - async def fake_detect_dimension() -> int: - return 4 - - async def fake_get_embedding_direct(text: str, dimensions: int | None = None): - del dimensions - base = float(ord(str(text)[0])) - return [base, base + 1.0, base + 2.0, base + 3.0] - - monkeypatch.setattr(adapter, "_detect_dimension", fake_detect_dimension) - monkeypatch.setattr(adapter, "_get_embedding_direct", fake_get_embedding_direct) - - embeddings = await adapter.encode(["A", "B", "A", "C"], batch_size=2) - - assert embeddings.shape == (4, 4) - assert np.array_equal(embeddings[0], embeddings[2]) - assert embeddings[1][0] == float(ord("B")) - assert embeddings[3][0] == float(ord("C")) diff --git a/pytests/A_memorix_test/test_feedback_correction_chat_flow.py b/pytests/A_memorix_test/test_feedback_correction_chat_flow.py deleted file mode 100644 index bf6f8c72..00000000 --- a/pytests/A_memorix_test/test_feedback_correction_chat_flow.py +++ /dev/null @@ -1,780 +0,0 @@ -from __future__ import annotations - -import asyncio -import inspect -import json -import time -import uuid -from datetime import datetime -from pathlib import Path -from types import SimpleNamespace -from typing import Any, Callable, Dict - -import numpy as np -import pytest -import pytest_asyncio -from sqlalchemy.orm import sessionmaker -from sqlmodel import Session, create_engine, select - -IMPORT_ERROR: str | None = None - -try: - from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module - from src.A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel - from src.chat.heart_flow.heartflow_manager import heartflow_manager - from src.chat.message_receive import bot as bot_module - from src.chat.message_receive.chat_manager import chat_manager - from src.chat.message_receive.bot import chat_bot - from src.common.database import database as database_module - from src.common.database.database_model import PersonInfo, ToolRecord - from src.common.database.migrations import create_database_migration_bootstrapper - from src.common.utils.utils_session import SessionUtils - from src.llm_models.payload_content.tool_option import ToolCall - from src.maisaka import reasoning_engine as reasoning_engine_module - from src.maisaka import runtime as runtime_module - from src.maisaka import chat_loop_service as chat_loop_service_module - from src.maisaka.chat_loop_service import ChatResponse - from src.maisaka.context_messages import AssistantMessage - from src.plugin_runtime import component_query as component_query_module - from src.services import memory_flow_service as memory_flow_service_module - from src.services import memory_service as memory_service_module - from src.services.memory_service import memory_service -except SystemExit as exc: - IMPORT_ERROR = f"config initialization exited during import: {exc}" - kernel_module = None # type: ignore[assignment] - KernelSearchRequest = None # type: ignore[assignment] - SDKMemoryKernel = None # type: ignore[assignment] - heartflow_manager = None # type: ignore[assignment] - bot_module = None # type: ignore[assignment] - chat_manager = None # type: ignore[assignment] - chat_bot = None # type: ignore[assignment] - database_module = None # type: ignore[assignment] - ToolRecord = None # type: ignore[assignment] - PersonInfo = None # type: ignore[assignment] - create_database_migration_bootstrapper = None # type: ignore[assignment] - SessionUtils = None # type: ignore[assignment] - ToolCall = None # type: ignore[assignment] - reasoning_engine_module = None # type: ignore[assignment] - runtime_module = None # type: ignore[assignment] - chat_loop_service_module = None # type: ignore[assignment] - ChatResponse = None # type: ignore[assignment] - AssistantMessage = None # type: ignore[assignment] - component_query_module = None # type: ignore[assignment] - memory_flow_service_module = None # type: ignore[assignment] - memory_service_module = None # type: ignore[assignment] - memory_service = None # type: ignore[assignment] - - -pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "") - -RELATION_QUERY = "测试用户 和 最喜欢的颜色 有什么关系" - - -class _FakeEmbeddingManager: - def __init__(self, dimension: int = 8) -> None: - self.default_dimension = dimension - - async def _detect_dimension(self) -> int: - return self.default_dimension - - async def encode(self, text: Any) -> np.ndarray: - def _encode_one(raw: Any) -> np.ndarray: - content = str(raw or "") - vector = np.zeros(self.default_dimension, dtype=np.float32) - for index, byte in enumerate(content.encode("utf-8")): - vector[index % self.default_dimension] += float((byte % 17) + 1) - norm = float(np.linalg.norm(vector)) - if norm > 0: - vector /= norm - return vector - - if isinstance(text, (list, tuple)): - return np.stack([_encode_one(item) for item in text]).astype(np.float32) - return _encode_one(text).astype(np.float32) - - -class _KernelBackedRuntimeManager: - def __init__(self, kernel: SDKMemoryKernel) -> None: - self.kernel = kernel - - async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000): - del timeout_ms - payload = args or {} - if component_name == "search_memory": - return await self.kernel.search_memory( - KernelSearchRequest( - query=str(payload.get("query", "") or ""), - limit=int(payload.get("limit", 5) or 5), - mode=str(payload.get("mode", "hybrid") or "hybrid"), - chat_id=str(payload.get("chat_id", "") or ""), - person_id=str(payload.get("person_id", "") or ""), - time_start=payload.get("time_start"), - time_end=payload.get("time_end"), - respect_filter=bool(payload.get("respect_filter", True)), - user_id=str(payload.get("user_id", "") or ""), - group_id=str(payload.get("group_id", "") or ""), - ) - ) - - handler = getattr(self.kernel, component_name) - result = handler(**payload) - return await result if inspect.isawaitable(result) else result - - -class _NoopRuntimeManager: - async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any: - del hook_name - return SimpleNamespace(aborted=False, kwargs=kwargs) - - -def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - db_dir = (tmp_path / "main_db").resolve() - db_dir.mkdir(parents=True, exist_ok=True) - db_file = db_dir / "MaiBot.db" - database_url = f"sqlite:///{db_file}" - - try: - database_module.engine.dispose() - except Exception: - pass - - engine = create_engine( - database_url, - echo=False, - connect_args={"check_same_thread": False}, - pool_pre_ping=True, - ) - session_local = sessionmaker( - autocommit=False, - autoflush=False, - bind=engine, - class_=Session, - ) - bootstrapper = create_database_migration_bootstrapper(engine) - - monkeypatch.setattr(database_module, "_DB_DIR", db_dir, raising=False) - monkeypatch.setattr(database_module, "_DB_FILE", db_file, raising=False) - monkeypatch.setattr(database_module, "DATABASE_URL", database_url, raising=False) - monkeypatch.setattr(database_module, "engine", engine, raising=False) - monkeypatch.setattr(database_module, "SessionLocal", session_local, raising=False) - monkeypatch.setattr(database_module, "_migration_bootstrapper", bootstrapper, raising=False) - monkeypatch.setattr(database_module, "_db_initialized", False, raising=False) - - -def _build_chat_response(content: str, tool_calls: list[ToolCall]) -> ChatResponse: - return ChatResponse( - content=content, - tool_calls=tool_calls, - request_messages=[], - raw_message=AssistantMessage( - content=content, - timestamp=datetime.now(), - tool_calls=tool_calls, - ), - selected_history_count=0, - tool_count=len(tool_calls), - prompt_tokens=0, - built_message_count=0, - completion_tokens=0, - total_tokens=0, - prompt_section=None, - ) - - -def _build_message_data( - *, - content: str, - platform: str, - user_id: str, - user_name: str, - group_id: str, - group_name: str, -) -> Dict[str, Any]: - message_id = str(uuid.uuid4()) - return { - "message_info": { - "platform": platform, - "message_id": message_id, - "time": time.time(), - "group_info": { - "group_id": group_id, - "group_name": group_name, - "platform": platform, - }, - "user_info": { - "user_id": user_id, - "user_nickname": user_name, - "user_cardname": user_name, - "platform": platform, - }, - "additional_config": { - "at_bot": True, - }, - }, - "message_segment": { - "type": "seglist", - "data": [ - { - "type": "text", - "data": content, - }, - ], - }, - "raw_message": content, - "processed_plain_text": content, - } - - -async def _wait_until( - predicate: Callable[[], Any], - *, - timeout_seconds: float = 10.0, - interval_seconds: float = 0.05, - description: str, -) -> Any: - deadline = asyncio.get_running_loop().time() + max(0.5, float(timeout_seconds)) - while asyncio.get_running_loop().time() < deadline: - value = predicate() - if inspect.isawaitable(value): - value = await value - if value: - return value - await asyncio.sleep(interval_seconds) - raise AssertionError(f"等待超时: {description}") - - -def _load_feedback_tasks(kernel: SDKMemoryKernel) -> list[Dict[str, Any]]: - assert kernel.metadata_store is not None - cursor = kernel.metadata_store.get_connection().cursor() - rows = cursor.execute( - "SELECT query_tool_id FROM memory_feedback_tasks ORDER BY id" - ).fetchall() - tasks: list[Dict[str, Any]] = [] - for row in rows: - task = kernel.metadata_store.get_feedback_task(str(row["query_tool_id"] or "")) - if task is not None: - tasks.append(task) - return tasks - - -def _load_feedback_action_types(kernel: SDKMemoryKernel) -> list[str]: - assert kernel.metadata_store is not None - cursor = kernel.metadata_store.get_connection().cursor() - rows = cursor.execute( - "SELECT action_type FROM memory_feedback_action_logs ORDER BY id" - ).fetchall() - return [str(row["action_type"] or "") for row in rows] - - -def _load_query_memory_tool_records(session_id: str) -> list[Dict[str, Any]]: - with database_module.get_db_session() as session: - statement = ( - select(ToolRecord) - .where(ToolRecord.session_id == session_id) - .where(ToolRecord.tool_name == "query_memory") - .order_by(ToolRecord.timestamp) - ) - rows = list(session.exec(statement).all()) - return [ - { - "tool_id": str(row.tool_id or ""), - "session_id": str(row.session_id or ""), - "tool_name": str(row.tool_name or ""), - "tool_data": str(row.tool_data or ""), - "timestamp": row.timestamp, - } - for row in rows - ] - - -def _seed_person_info(*, person_id: str, person_name: str, session_info: Dict[str, Any]) -> None: - with database_module.get_db_session() as session: - session.add( - PersonInfo( - is_known=True, - person_id=person_id, - person_name=person_name, - platform=str(session_info["platform"]), - user_id=str(session_info["user_id"]), - user_nickname=str(session_info["user_name"]), - group_cardname=json.dumps( - [{"group_id": str(session_info["group_id"]), "group_cardname": person_name}], - ensure_ascii=False, - ), - know_counts=1, - first_known_time=datetime.now(), - last_known_time=datetime.now(), - ) - ) - session.commit() - - -@pytest_asyncio.fixture -async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path): - _install_temp_main_database(monkeypatch, tmp_path) - - chat_manager.sessions.clear() - chat_manager.last_messages.clear() - heartflow_manager.heartflow_chat_list.clear() - - noop_runtime_manager = _NoopRuntimeManager() - monkeypatch.setattr(bot_module.ChatBot, "_get_runtime_manager", staticmethod(lambda: noop_runtime_manager)) - monkeypatch.setattr( - component_query_module.component_query_service, - "find_command_by_text", - lambda text: None, - ) - monkeypatch.setattr( - component_query_module.component_query_service, - "get_llm_available_tool_specs", - lambda **kwargs: {}, - ) - monkeypatch.setattr(runtime_module.global_config.mcp, "enable", False, raising=False) - monkeypatch.setattr( - runtime_module.MaisakaHeartFlowChatting, - "_get_message_trigger_threshold", - lambda self: 1, - ) - - async def _noop_on_incoming_message(message: Any) -> None: - del message - - monkeypatch.setattr( - memory_flow_service_module.memory_automation_service, - "on_incoming_message", - _noop_on_incoming_message, - ) - - fake_embedding_manager = _FakeEmbeddingManager(dimension=8) - - async def _fake_runtime_self_check( - *, - config: Any, - sample_text: str, - vector_store: Any, - embedding_manager: Any, - ) -> Dict[str, Any]: - del config, sample_text, vector_store, embedding_manager - return { - "ok": True, - "message": "ok", - "checked_at": time.time(), - "encoded_dimension": fake_embedding_manager.default_dimension, - } - - monkeypatch.setattr( - kernel_module, - "create_embedding_api_adapter", - lambda **kwargs: fake_embedding_manager, - ) - monkeypatch.setattr( - kernel_module, - "run_embedding_runtime_self_check", - _fake_runtime_self_check, - ) - - kernel = SDKMemoryKernel( - plugin_root=tmp_path / "plugin_root", - config={ - "storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())}, - "advanced": {"enable_auto_save": False}, - "embedding": {"dimension": fake_embedding_manager.default_dimension}, - "memory": {"base_decay_interval_hours": 24}, - "person_profile": {"refresh_interval_minutes": 5}, - }, - ) - - monkeypatch.setattr(kernel, "_feedback_cfg_enabled", lambda: True) - monkeypatch.setattr(kernel, "_feedback_cfg_window_hours", lambda: 0.0004) - monkeypatch.setattr(kernel, "_feedback_cfg_check_interval_seconds", lambda: 0.2) - monkeypatch.setattr(kernel, "_feedback_cfg_batch_size", lambda: 10) - monkeypatch.setattr(kernel, "_feedback_cfg_max_messages", lambda: 10) - monkeypatch.setattr(kernel, "_feedback_cfg_auto_apply_threshold", lambda: 0.85) - monkeypatch.setattr(kernel, "_feedback_cfg_prefilter_enabled", lambda: True) - monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_mark_enabled", lambda: True) - monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_hard_filter_enabled", lambda: True) - monkeypatch.setattr(kernel, "_feedback_cfg_profile_refresh_enabled", lambda: True) - monkeypatch.setattr(kernel, "_feedback_cfg_profile_force_refresh_on_read", lambda: True) - monkeypatch.setattr(kernel, "_feedback_cfg_episode_rebuild_enabled", lambda: True) - monkeypatch.setattr(kernel, "_feedback_cfg_episode_query_block_enabled", lambda: True) - monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_interval_seconds", lambda: 0.2) - monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_batch_size", lambda: 10) - - monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_paragraph_hard_filter_enabled", True, raising=False) - monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_episode_query_block_enabled", True, raising=False) - - async def _fake_classify_feedback( - *, - query_tool_id: str, - query_text: str, - hit_briefs: list[Dict[str, Any]], - feedback_messages: list[str], - ) -> Dict[str, Any]: - del query_tool_id, query_text, feedback_messages - target_hash = "" - for item in hit_briefs: - if str(item.get("type", "") or "").strip() == "relation": - target_hash = str(item.get("hash", "") or "").strip() - break - if not target_hash and hit_briefs: - target_hash = str(hit_briefs[0].get("hash", "") or "").strip() - return { - "decision": "correct", - "confidence": 0.97, - "target_hashes": [target_hash] if target_hash else [], - "corrected_relations": [ - { - "subject": "测试用户", - "predicate": "最喜欢的颜色是", - "object": "绿色", - "confidence": 0.99, - } - ], - "reason": "用户明确纠正为绿色", - } - - monkeypatch.setattr(kernel, "_classify_feedback", _fake_classify_feedback) - - await kernel.initialize() - async def _force_episode_fallback(**kwargs: Any) -> Dict[str, Any]: - raise RuntimeError("force_fallback_for_test") - - monkeypatch.setattr( - kernel.episode_service.segmentation_service, - "segment", - _force_episode_fallback, - ) - monkeypatch.setattr( - kernel, - "process_episode_pending_batch", - lambda *, limit=20, max_retry=3: asyncio.sleep(0, result={"processed": 0, "episode_count": 0, "fallback_count": 0, "failed": 0}), - ) - - host_manager = _KernelBackedRuntimeManager(kernel) - monkeypatch.setattr(memory_service_module, "a_memorix_host_service", host_manager) - - planner_calls: list[str] = [] - - async def _fake_timing_gate(self, anchor_message: Any): - del self, anchor_message - return "continue", _build_chat_response("直接进入 planner。", []), [], [] - - async def _fake_planner( - self, - *, - injected_user_messages: list[str] | None = None, - tool_definitions: list[dict[str, Any]] | None = None, - ) -> ChatResponse: - del injected_user_messages, tool_definitions - latest_message = self._runtime.message_cache[-1] - latest_text = str(latest_message.processed_plain_text or "") - planner_calls.append(latest_text) - handled_message_ids = getattr(self._runtime, "_test_query_message_ids", None) - if handled_message_ids is None: - handled_message_ids = set() - self._runtime._test_query_message_ids = handled_message_ids - - if latest_message.message_id not in handled_message_ids and ( - "回忆" in latest_text or "再查" in latest_text - ): - handled_message_ids.add(latest_message.message_id) - tool_call = ToolCall( - call_id=f"query-{uuid.uuid4().hex}", - func_name="query_memory", - args={ - "query": RELATION_QUERY, - "mode": "search", - "limit": 5, - "respect_filter": False, - }, - ) - return _build_chat_response("先查询长期记忆。", [tool_call]) - - stop_call = ToolCall( - call_id=f"stop-{uuid.uuid4().hex}", - func_name="no_reply", - args={}, - ) - return _build_chat_response("当前轮次结束。", [stop_call]) - - monkeypatch.setattr( - reasoning_engine_module.MaisakaReasoningEngine, - "_run_timing_gate", - _fake_timing_gate, - ) - monkeypatch.setattr( - reasoning_engine_module.MaisakaReasoningEngine, - "_run_interruptible_planner", - _fake_planner, - ) - monkeypatch.setattr(reasoning_engine_module, "resolve_enable_visual_planner", lambda: False) - monkeypatch.setattr(chat_loop_service_module, "resolve_enable_visual_planner", lambda: False) - - session_info = { - "platform": "unit_test_chat", - "user_id": "user_feedback_flow", - "user_name": "反馈测试用户", - "group_id": "group_feedback_flow", - "group_name": "反馈纠错测试群", - } - person_id = "person_feedback_flow" - session_id = SessionUtils.calculate_session_id( - session_info["platform"], - user_id=session_info["user_id"], - group_id=session_info["group_id"], - ) - _seed_person_info(person_id=person_id, person_name="测试用户", session_info=session_info) - - try: - yield { - "kernel": kernel, - "session_id": session_id, - "session_info": session_info, - "person_id": person_id, - "planner_calls": planner_calls, - } - finally: - for key, chat in list(heartflow_manager.heartflow_chat_list.items()): - try: - await chat.stop() - except Exception: - pass - heartflow_manager.heartflow_chat_list.pop(key, None) - chat_manager.sessions.clear() - chat_manager.last_messages.clear() - await kernel.shutdown() - try: - database_module.engine.dispose() - except Exception: - pass - - -@pytest.mark.asyncio -async def test_feedback_correction_real_chat_flow( - chat_feedback_env, - monkeypatch: pytest.MonkeyPatch, -) -> None: - kernel = chat_feedback_env["kernel"] - session_id = chat_feedback_env["session_id"] - session_info = chat_feedback_env["session_info"] - person_id = chat_feedback_env["person_id"] - - write_result = await memory_service.ingest_text( - external_id=f"test:feedback-seed:{uuid.uuid4().hex}", - source_type="chat_summary", - text="测试用户 最喜欢的颜色是 蓝色", - chat_id=session_id, - relations=[ - { - "subject": "测试用户", - "predicate": "最喜欢的颜色是", - "object": "蓝色", - "confidence": 1.0, - } - ], - metadata={"test_case": "feedback_correction_chat_flow"}, - respect_filter=False, - ) - assert write_result.success is True - - pre_search = await memory_service.search( - RELATION_QUERY, - mode="search", - chat_id=session_id, - respect_filter=False, - ) - assert pre_search.hits - assert any("蓝色" in hit.content for hit in pre_search.hits) - - pre_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10) - pre_profile_text = pre_profile.summary + "\n" + json.dumps(pre_profile.evidence, ensure_ascii=False) - assert "蓝色" in pre_profile_text - - seed_source = f"chat_summary:{session_id}" - rebuild_result = await kernel.rebuild_episodes_for_sources([seed_source]) - assert rebuild_result["rebuilt"] >= 1 - - pre_episode = await memory_service.search( - "蓝色", - mode="episode", - chat_id=session_id, - respect_filter=False, - ) - assert pre_episode.hits - assert any("蓝色" in hit.content for hit in pre_episode.hits) - - await chat_bot.message_process( - _build_message_data( - content="请帮我回忆一下,测试用户最喜欢的颜色是什么?", - **session_info, - ) - ) - - await _wait_until( - lambda: chat_feedback_env["planner_calls"][0] if chat_feedback_env["planner_calls"] else None, - description="planner 收到首条聊天消息", - ) - first_query_records = await _wait_until( - lambda: _load_query_memory_tool_records(session_id) if _load_query_memory_tool_records(session_id) else None, - description="首条 query_memory 工具记录生成", - ) - assert first_query_records - - first_task = await _wait_until( - lambda: _load_feedback_tasks(kernel)[0] if _load_feedback_tasks(kernel) else None, - description="首个反馈任务入队", - ) - assert first_task["status"] == "pending" - first_hits = list((first_task.get("query_snapshot") or {}).get("hits") or []) - assert first_hits - assert any("蓝色" in str(item.get("content", "") or "") for item in first_hits) - - await chat_bot.message_process( - _build_message_data( - content="不对,测试用户最喜欢的颜色不是蓝色,是绿色。", - **session_info, - ) - ) - - finalized_task = await _wait_until( - lambda: ( - kernel.metadata_store.get_feedback_task(first_task["query_tool_id"]) - if kernel.metadata_store.get_feedback_task(first_task["query_tool_id"]) - and kernel.metadata_store.get_feedback_task(first_task["query_tool_id"]).get("status") - in {"applied", "skipped", "error"} - else None - ), - timeout_seconds=12.0, - interval_seconds=0.1, - description="反馈任务进入终态", - ) - assert finalized_task["status"] == "applied", finalized_task - assert finalized_task["decision_payload"]["decision"] == "correct" - assert finalized_task["decision_payload"]["apply_result"]["applied"] is True - - corrected_hashes = list( - (finalized_task["decision_payload"].get("apply_result") or {}).get("relation_hashes") or [] - ) - assert corrected_hashes - corrected_hash = str(corrected_hashes[0] or "") - relation_status = kernel.metadata_store.get_relation_status_batch([corrected_hash]).get(corrected_hash, {}) - assert bool(relation_status.get("is_inactive")) is True - - action_types = _load_feedback_action_types(kernel) - assert "classification" in action_types - assert "forget_relation" in action_types - assert "ingest_correction" in action_types - assert "mark_stale_paragraph" in action_types - assert "enqueue_episode_rebuild" in action_types - assert "enqueue_profile_refresh" in action_types - - original_search = memory_service.search - original_get_person_profile = memory_service.get_person_profile - corrected_search_result = memory_service_module.MemorySearchResult( - summary="测试用户最喜欢的颜色是绿色。", - hits=[memory_service_module.MemoryHit(content="测试用户 最喜欢的颜色是 绿色", score=0.99)], - ) - stale_search_result = memory_service_module.MemorySearchResult(summary="", hits=[]) - corrected_profile_result = memory_service_module.PersonProfileResult( - summary="测试用户最喜欢的颜色是绿色。", - traits=["最喜欢的颜色是绿色"], - evidence=[{"content": "测试用户 最喜欢的颜色是 绿色"}], - ) - - async def _mock_post_correction_search(query: str, **kwargs: Any): - mode = str(kwargs.get("mode", "search") or "search") - if mode == "episode" and "蓝色" in str(query): - return stale_search_result - return corrected_search_result - - async def _mock_post_correction_profile(person_id: str, **kwargs: Any): - del person_id, kwargs - return corrected_profile_result - - monkeypatch.setattr(memory_service, "search", _mock_post_correction_search) - monkeypatch.setattr(memory_service, "get_person_profile", _mock_post_correction_profile) - - direct_post_search = await memory_service.search( - RELATION_QUERY, - mode="search", - chat_id=session_id, - respect_filter=False, - ) - assert direct_post_search.hits - post_contents = "\n".join(hit.content for hit in direct_post_search.hits) - assert "绿色" in post_contents - assert "蓝色" not in post_contents - - profile_refresh_request = await _wait_until( - lambda: ( - kernel.metadata_store.get_person_profile_refresh_request(person_id) - if kernel.metadata_store.get_person_profile_refresh_request(person_id) - and kernel.metadata_store.get_person_profile_refresh_request(person_id).get("status") == "done" - else None - ), - timeout_seconds=12.0, - interval_seconds=0.1, - description="人物画像刷新完成", - ) - assert profile_refresh_request["status"] == "done" - - post_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10) - post_profile_text = post_profile.summary + "\n" + json.dumps(post_profile.evidence, ensure_ascii=False) - assert "绿色" in post_profile_text - assert "蓝色" not in post_profile_text - - async def _latest_episode_result(): - result = await memory_service.search( - "绿色", - mode="episode", - chat_id=session_id, - respect_filter=False, - ) - if not result.hits: - return None - contents = "\n".join(hit.content for hit in result.hits) - if "绿色" in contents and "蓝色" not in contents: - return result - return None - - post_episode = await _wait_until( - _latest_episode_result, - timeout_seconds=12.0, - interval_seconds=0.2, - description="episode 重建后返回修正结果", - ) - assert post_episode is not None - - stale_episode = await memory_service.search( - "蓝色", - mode="episode", - chat_id=session_id, - respect_filter=False, - ) - assert not stale_episode.hits - - await chat_bot.message_process( - _build_message_data( - content="再查一次,测试用户最喜欢的颜色是什么?", - **session_info, - ) - ) - - tool_records = await _wait_until( - lambda: ( - _load_query_memory_tool_records(session_id) - if len(_load_query_memory_tool_records(session_id)) >= 2 - else None - ), - timeout_seconds=10.0, - interval_seconds=0.1, - description="第二次 query_memory 工具记录生成", - ) - latest_tool_data = json.loads(str(tool_records[-1].get("tool_data") or "{}")) - latest_hits = list((latest_tool_data.get("structured_content") or {}).get("hits") or []) - assert latest_hits - latest_contents = "\n".join(str(item.get("content", "") or "") for item in latest_hits) - assert "绿色" in latest_contents - assert "蓝色" not in latest_contents - monkeypatch.setattr(memory_service, "search", original_search) - monkeypatch.setattr(memory_service, "get_person_profile", original_get_person_profile) diff --git a/pytests/A_memorix_test/test_feedback_correction_core.py b/pytests/A_memorix_test/test_feedback_correction_core.py deleted file mode 100644 index dac07360..00000000 --- a/pytests/A_memorix_test/test_feedback_correction_core.py +++ /dev/null @@ -1,396 +0,0 @@ -from datetime import datetime -from pathlib import Path -from types import SimpleNamespace -from typing import Any, Dict - -import pytest - -IMPORT_ERROR: str | None = None - -try: - from src.A_memorix.core.retrieval.sparse_bm25 import SparseBM25Config, SparseBM25Index - from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module - from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel -except SystemExit as exc: - IMPORT_ERROR = f"config initialization exited during import: {exc}" - SparseBM25Config = None # type: ignore[assignment] - SparseBM25Index = None # type: ignore[assignment] - kernel_module = None # type: ignore[assignment] - SDKMemoryKernel = None # type: ignore[assignment] - - -pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "") - - -@pytest.mark.asyncio -async def test_kernel_enqueue_feedback_task_delegates_to_metadata_store(monkeypatch: pytest.MonkeyPatch) -> None: - captured: Dict[str, Any] = {} - - def fake_enqueue_feedback_task(**kwargs): - captured.update(kwargs) - return { - "id": 1, - "query_tool_id": kwargs["query_tool_id"], - "session_id": kwargs["session_id"], - "query_timestamp": kwargs["query_timestamp"], - "due_at": kwargs["due_at"], - "query_snapshot": kwargs["query_snapshot"], - } - - monkeypatch.setattr( - kernel_module, - "global_config", - SimpleNamespace( - memory=SimpleNamespace( - feedback_correction_enabled=True, - feedback_correction_window_hours=12.0, - ) - ), - ) - - query_time = datetime(2026, 4, 9, 10, 30, 0) - kernel = SDKMemoryKernel(plugin_root=Path("."), config={}) - kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=fake_enqueue_feedback_task) - - payload = await kernel.enqueue_feedback_task( - query_tool_id="tool-query-1", - session_id="session-1", - query_timestamp=query_time, - structured_content={"query": "Alice 喜欢什么", "hits": [{"hash": "relation-1"}]}, - ) - - assert payload["success"] is True - assert payload["queued"] is True - assert captured["query_tool_id"] == "tool-query-1" - assert captured["session_id"] == "session-1" - assert captured["query_snapshot"]["query"] == "Alice 喜欢什么" - assert captured["query_snapshot"]["hits"] == [{"hash": "relation-1"}] - assert captured["due_at"] == pytest.approx(query_time.timestamp() + 12 * 3600, rel=0, abs=1e-6) - - -@pytest.mark.asyncio -async def test_kernel_enqueue_feedback_task_skipped_when_disabled(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr( - kernel_module, - "global_config", - SimpleNamespace(memory=SimpleNamespace(feedback_correction_enabled=False)), - ) - - kernel = SDKMemoryKernel(plugin_root=Path("."), config={}) - kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=lambda **kwargs: kwargs) - - payload = await kernel.enqueue_feedback_task( - query_tool_id="tool-query-2", - session_id="session-1", - query_timestamp=datetime.now(), - structured_content={"hits": [{"hash": "relation-1"}]}, - ) - - assert payload["success"] is False - assert payload["reason"] == "feedback_correction_disabled" - - -@pytest.mark.asyncio -async def test_apply_feedback_decision_resolves_paragraph_targets() -> None: - action_logs: list[Dict[str, Any]] = [] - forgotten_hashes: list[str] = [] - ingested_payloads: list[Dict[str, Any]] = [] - stale_marks: list[Dict[str, Any]] = [] - episode_sources: list[str] = [] - profile_refresh_ids: list[str] = [] - - kernel = SDKMemoryKernel(plugin_root=Path("."), config={}) - kernel.metadata_store = SimpleNamespace( - get_paragraph_relations=lambda paragraph_hash: [ - { - "hash": "relation-1", - "subject": "测试用户", - "predicate": "最喜欢的颜色是", - "object": "蓝色", - } - ] - if paragraph_hash == "paragraph-1" - else [], - get_relation_status_batch=lambda hashes: { - str(hash_value): {"is_inactive": str(hash_value) in forgotten_hashes} - for hash_value in hashes - }, - get_paragraph_hashes_by_relation_hashes=lambda hashes: { - "relation-1": ["paragraph-1"] - } - if "relation-1" in hashes - else {}, - upsert_paragraph_stale_relation_mark=lambda **kwargs: stale_marks.append(kwargs) or kwargs, - enqueue_episode_source_rebuild=lambda source, reason="": episode_sources.append(source) or True, - enqueue_person_profile_refresh=lambda **kwargs: profile_refresh_ids.append(kwargs["person_id"]) or kwargs, - get_paragraph=lambda paragraph_hash: {"hash": "paragraph-1", "source": "chat_feedback_test_seed:session-1"} - if paragraph_hash == "paragraph-1" - else None, - append_feedback_action_log=lambda **kwargs: action_logs.append(kwargs), - ) - kernel._feedback_cfg_auto_apply_threshold = lambda: 0.85 # type: ignore[method-assign] - kernel._apply_v5_relation_action = lambda *, action, hashes, strength=1.0: ( # type: ignore[method-assign] - forgotten_hashes.extend([str(item) for item in hashes]), - {"success": True, "action": action, "hashes": list(hashes), "strength": strength}, - )[1] - kernel._feedback_cfg_paragraph_mark_enabled = lambda: True # type: ignore[method-assign] - kernel._feedback_cfg_episode_rebuild_enabled = lambda: True # type: ignore[method-assign] - kernel._feedback_cfg_profile_refresh_enabled = lambda: True # type: ignore[method-assign] - kernel._resolve_feedback_related_person_ids = lambda **kwargs: ["person-1"] # type: ignore[method-assign] - kernel._query_relation_rows_by_hashes = lambda relation_hashes, include_inactive=False: [ # type: ignore[method-assign] - { - "hash": "relation-1", - "subject": "测试用户", - "predicate": "最喜欢的颜色是", - "object": "蓝色", - } - ] - - async def _fake_ingest_feedback_relations(**kwargs): - ingested_payloads.append(kwargs) - return {"success": True, "stored_ids": ["relation-2"]} - - kernel._ingest_feedback_relations = _fake_ingest_feedback_relations # type: ignore[method-assign] - - payload = await kernel._apply_feedback_decision( - task_id=1, - query_tool_id="tool-query-1", - session_id="session-1", - decision={ - "decision": "correct", - "confidence": 0.97, - "target_hashes": ["paragraph-1"], - "corrected_relations": [ - { - "subject": "测试用户", - "predicate": "最喜欢的颜色是", - "object": "绿色", - "confidence": 0.99, - } - ], - "reason": "用户明确纠正为绿色", - }, - hit_map={ - "paragraph-1": { - "hash": "paragraph-1", - "type": "paragraph", - "content": "测试用户 最喜欢的颜色是 蓝色", - "linked_relation_hashes": ["relation-1"], - } - }, - ) - - assert payload["applied"] is True - assert payload["relation_hashes"] == ["relation-1"] - assert forgotten_hashes == ["relation-1"] - assert ingested_payloads[0]["relation_hashes"] == ["relation-1"] - assert payload["stale_paragraph_hashes"] == ["paragraph-1"] - assert "chat_feedback_test_seed:session-1" in payload["episode_rebuild_sources"] - assert "chat_summary:session-1" in payload["episode_rebuild_sources"] - assert payload["profile_refresh_person_ids"] == ["person-1"] - assert stale_marks[0]["paragraph_hash"] == "paragraph-1" - assert {item["action_type"] for item in action_logs} == { - "forget_relation", - "ingest_correction", - "mark_stale_paragraph", - "enqueue_episode_rebuild", - "enqueue_profile_refresh", - } - - -def test_filter_active_relation_hits_removes_inactive_relations() -> None: - kernel = SDKMemoryKernel(plugin_root=Path("."), config={}) - kernel._feedback_cfg_paragraph_hard_filter_enabled = lambda: True # type: ignore[method-assign] - kernel.metadata_store = SimpleNamespace( - get_relation_status_batch=lambda hashes: { - "r-active": {"is_inactive": False}, - "r-inactive": {"is_inactive": True}, - "r-para-inactive": {"is_inactive": True}, - "r-stale-active": {"is_inactive": False}, - "r-stale-inactive": {"is_inactive": True}, - }, - get_paragraph_relations=lambda paragraph_hash: ( - [{"hash": "r-para-inactive"}] if paragraph_hash == "p-inactive" else [] - ), - get_paragraph_stale_relation_marks_batch=lambda paragraph_hashes: { - "p-stale": [{"relation_hash": "r-stale-inactive"}], - "p-restored": [{"relation_hash": "r-stale-active"}], - }, - ) - - hits = [ - {"hash": "r-active", "type": "relation", "content": "A 喜欢 B"}, - {"hash": "r-inactive", "type": "relation", "content": "A 讨厌 B"}, - {"hash": "p-1", "type": "paragraph", "content": "段落证据"}, - {"hash": "p-inactive", "type": "paragraph", "content": "失活段落证据"}, - {"hash": "p-stale", "type": "paragraph", "content": "被标脏段落"}, - {"hash": "p-restored", "type": "paragraph", "content": "恢复可见段落"}, - ] - - filtered = kernel._filter_active_relation_hits(hits) - - assert [item["hash"] for item in filtered] == ["r-active", "p-1", "p-restored"] - - -def test_sparse_relation_search_requests_active_only() -> None: - captured: Dict[str, Any] = {} - - class FakeMetadataStore: - def fts_search_relations_bm25(self, **kwargs): - captured.update(kwargs) - return [] - - index = SparseBM25Index( - metadata_store=FakeMetadataStore(), # type: ignore[arg-type] - config=SparseBM25Config(enabled=True, lazy_load=False), - ) - index._loaded = True - index._conn = object() # type: ignore[assignment] - - result = index.search_relations("测试纠错", k=5) - - assert result == [] - assert captured["include_inactive"] is False - - -@pytest.mark.asyncio -async def test_feedback_task_rollback_restores_snapshots_and_requeues_followups() -> None: - action_logs: list[Dict[str, Any]] = [] - queued_sources: list[str] = [] - queued_profiles: list[str] = [] - deleted_marks: list[tuple[str, str]] = [] - deleted_paragraphs: list[str] = [] - relation_statuses: Dict[str, Dict[str, Any]] = { - "rel-old": {"is_inactive": True, "weight": 0.0, "is_pinned": False, "protected_until": 0.0, "last_reinforced": None, "inactive_since": 1.0}, - "rel-new": {"is_inactive": False, "weight": 1.0, "is_pinned": False, "protected_until": 0.0, "last_reinforced": None, "inactive_since": None}, - } - current_task: Dict[str, Any] = { - "id": 1, - "query_tool_id": "tool-query-rollback", - "session_id": "session-1", - "status": "applied", - "rollback_status": "none", - "query_snapshot": {"query": "测试用户最喜欢的颜色是什么"}, - "decision_payload": {"decision": "correct", "confidence": 0.97}, - "rollback_plan": { - "forgotten_relations": [ - { - "hash": "rel-old", - "subject": "测试用户", - "predicate": "最喜欢的颜色是", - "object": "蓝色", - "before_status": { - "is_inactive": False, - "weight": 0.8, - "is_pinned": False, - "protected_until": 0.0, - "last_reinforced": None, - "inactive_since": None, - }, - } - ], - "corrected_write": { - "paragraph_hashes": ["paragraph-new"], - "corrected_relations": [ - { - "hash": "rel-new", - "subject": "测试用户", - "predicate": "最喜欢的颜色是", - "object": "绿色", - "existed_before": False, - "before_status": {}, - } - ], - }, - "stale_marks": [{"paragraph_hash": "paragraph-old", "relation_hash": "rel-old"}], - "episode_sources": ["chat_summary:session-1"], - "profile_person_ids": ["person-1"], - }, - } - - class _Conn: - def cursor(self): - return self - - def execute(self, *_args, **_kwargs): - return self - - def commit(self): - return None - - metadata_store = SimpleNamespace( - get_feedback_task_by_id=lambda task_id: current_task if int(task_id) == 1 else None, - mark_feedback_task_rollback_running=lambda **kwargs: current_task.update({"rollback_status": "running"}) or current_task, - finalize_feedback_task_rollback=lambda **kwargs: current_task.update( - { - "rollback_status": kwargs["rollback_status"], - "rollback_result": kwargs.get("rollback_result") or {}, - "rollback_error": kwargs.get("rollback_error", ""), - } - ) - or current_task, - get_relation_status_batch=lambda hashes: { - hash_value: dict(relation_statuses[hash_value]) - for hash_value in hashes - if hash_value in relation_statuses - }, - restore_relation_status_from_snapshot=lambda hash_value, snapshot: relation_statuses.update( - {hash_value: dict(snapshot)} - ) - or dict(snapshot), - append_feedback_action_log=lambda **kwargs: action_logs.append(kwargs), - mark_as_deleted=lambda hashes, type_: deleted_paragraphs.extend(list(hashes)) or len(list(hashes)), - get_paragraph=lambda paragraph_hash: {"hash": paragraph_hash, "source": "chat_summary:session-1"}, - get_connection=lambda: _Conn(), - delete_external_memory_refs_by_paragraphs=lambda hashes: [ - {"paragraph_hash": str(hash_value), "external_id": f"external:{hash_value}"} - for hash_value in hashes - ], - update_relations_protection=lambda hashes, **kwargs: None, - mark_relations_inactive=lambda hashes, inactive_since=None: [ - relation_statuses.__setitem__( - hash_value, - { - **relation_statuses.get(hash_value, {}), - "is_inactive": True, - "inactive_since": inactive_since, - }, - ) - for hash_value in hashes - ], - delete_paragraph_stale_relation_marks=lambda marks: deleted_marks.extend(list(marks)) or len(list(marks)), - enqueue_episode_source_rebuild=lambda source, reason='': queued_sources.append(source) or True, - enqueue_person_profile_refresh=lambda **kwargs: queued_profiles.append(kwargs["person_id"]) or kwargs, - list_feedback_action_logs=lambda task_id: action_logs if int(task_id) == 1 else [], - ) - - kernel = SDKMemoryKernel(plugin_root=Path("."), config={}) - kernel.metadata_store = metadata_store - async def _noop_initialize() -> None: - return None - kernel.initialize = _noop_initialize # type: ignore[method-assign] - kernel._rebuild_graph_from_metadata = lambda: None # type: ignore[method-assign] - kernel._persist = lambda: None # type: ignore[method-assign] - - payload = await kernel._rollback_feedback_task( - task_id=1, - requested_by="pytest", - reason="manual rollback", - ) - - assert payload["success"] is True - assert relation_statuses["rel-old"]["is_inactive"] is False - assert relation_statuses["rel-new"]["is_inactive"] is True - assert deleted_paragraphs == ["paragraph-new"] - assert deleted_marks == [("paragraph-old", "rel-old")] - assert queued_sources == ["chat_summary:session-1"] - assert queued_profiles == ["person-1"] - assert current_task["rollback_status"] == "rolled_back" - assert {item["action_type"] for item in action_logs} >= { - "rollback_restore_relation", - "rollback_revert_corrected_relation", - "rollback_delete_correction_paragraph", - "rollback_clear_stale_mark", - "rollback_enqueue_episode_rebuild", - "rollback_enqueue_profile_refresh", - } diff --git a/pytests/A_memorix_test/test_graph_store_persistence.py b/pytests/A_memorix_test/test_graph_store_persistence.py deleted file mode 100644 index 8c38ff1f..00000000 --- a/pytests/A_memorix_test/test_graph_store_persistence.py +++ /dev/null @@ -1,82 +0,0 @@ -from __future__ import annotations - -import pickle -from pathlib import Path - -import pytest - -try: - from src.A_memorix.core.storage.graph_store import GraphStore -except SystemExit as exc: - GraphStore = None # type: ignore[assignment] - IMPORT_ERROR = f"config initialization exited during import: {exc}" -else: - IMPORT_ERROR = None - - -pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "") - - -def _build_empty_graph_metadata() -> dict: - return { - "nodes": [], - "node_to_idx": {}, - "node_attrs": {}, - "matrix_format": "csr", - "total_nodes_added": 0, - "total_edges_added": 0, - "total_nodes_deleted": 0, - "total_edges_deleted": 0, - "edge_hash_map": {}, - } - - -def test_graph_store_clear_save_removes_stale_adjacency(tmp_path: Path) -> None: - data_dir = tmp_path / "graph_data" - store = GraphStore(data_dir=data_dir) - store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"]) - store.save() - - matrix_path = data_dir / "graph_adjacency.npz" - assert matrix_path.exists() - - store.clear() - store.save() - - assert not matrix_path.exists() - - -def test_graph_store_load_resets_stale_adjacency_when_metadata_is_empty(tmp_path: Path) -> None: - data_dir = tmp_path / "graph_data" - store = GraphStore(data_dir=data_dir) - store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"]) - store.save() - - metadata_path = data_dir / "graph_metadata.pkl" - with metadata_path.open("wb") as handle: - pickle.dump(_build_empty_graph_metadata(), handle) - - reloaded = GraphStore(data_dir=data_dir) - reloaded.load() - - assert reloaded.num_nodes == 0 - assert reloaded.num_edges == 0 - assert reloaded.get_nodes() == [] - - -def test_graph_store_load_clears_stale_edge_hash_map_when_metadata_is_empty(tmp_path: Path) -> None: - data_dir = tmp_path / "graph_data" - store = GraphStore(data_dir=data_dir) - store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"]) - store.save() - - metadata_path = data_dir / "graph_metadata.pkl" - empty_metadata = _build_empty_graph_metadata() - empty_metadata["edge_hash_map"] = {(0, 1): {"rel-1"}} - with metadata_path.open("wb") as handle: - pickle.dump(empty_metadata, handle) - - reloaded = GraphStore(data_dir=data_dir) - reloaded.load() - - assert reloaded.has_edge_hash_map() is False diff --git a/pytests/A_memorix_test/test_group_chat_stream_fixture_schema.py b/pytests/A_memorix_test/test_group_chat_stream_fixture_schema.py deleted file mode 100644 index d44d126f..00000000 --- a/pytests/A_memorix_test/test_group_chat_stream_fixture_schema.py +++ /dev/null @@ -1,86 +0,0 @@ -from __future__ import annotations - -import json -from pathlib import Path - - -DATA_DIR = Path(__file__).parent / "data" / "benchmarks" - - -def _fixture_files() -> list[Path]: - return sorted(DATA_DIR.glob("group_chat_stream_memory_benchmark*.json")) - - -def _load_fixture(path: Path) -> dict: - return json.loads(path.read_text(encoding="utf-8")) - - -def _assert_fixture_matches_current_design_constraints(dataset: dict) -> None: - assert dataset["meta"]["scenario_id"] - - assert dataset["session"]["group_id"] - assert dataset["session"]["platform"] == "qq" - - simulated_batches = dataset["simulated_stream_batches"] - assert len(simulated_batches) >= 5 - - positive_batches = [item for item in simulated_batches if item["bot_participated"]] - negative_batches = [item for item in simulated_batches if not item["bot_participated"]] - - assert len(positive_batches) >= 4 - assert len(negative_batches) >= 1 - assert any(item["expected_behavior"] == "ignored_by_summarizer_without_bot_message" for item in negative_batches) - - for batch in positive_batches: - assert "Mai" in batch["participants"] - assert batch["message_count"] >= 10 - assert len(batch["combined_text"]) >= 300 - assert batch["start_time"] < batch["end_time"] - assert len(batch["expected_memory_targets"]) >= 4 - - runtime_streams = dataset["runtime_trigger_streams"] - assert len(runtime_streams) >= 2 - - runtime_positive = [item for item in runtime_streams if item["bot_participated"]] - runtime_negative = [item for item in runtime_streams if not item["bot_participated"]] - - assert len(runtime_positive) >= 1 - assert len(runtime_negative) >= 1 - - for stream in runtime_streams: - stream_text = "\n".join(stream["messages"]) - assert stream["trigger_mode"] == "time_threshold" - assert stream["elapsed_since_last_check_hours"] >= 8.0 - assert stream["message_count"] >= 20 - assert len(stream["messages"]) == stream["message_count"] - assert len(stream_text) >= 1000 - assert stream["start_time"] < stream["end_time"] - - assert any(item["expected_check_outcome"] == "should_trigger_topic_check_and_pass_bot_gate" for item in runtime_positive) - assert any( - item["expected_check_outcome"] == "should_trigger_topic_check_but_be_discarded_without_bot_message" - for item in runtime_negative - ) - - records = dataset["chat_history_records"] - assert len(records) >= 4 - for record in records: - assert "Mai" in record["participants"] - assert len(record["summary"]) >= 40 - assert len(record["original_text"]) >= 200 - assert record["start_time"] < record["end_time"] - - assert len(dataset["person_writebacks"]) >= 3 - assert len(dataset["search_cases"]) >= 4 - assert len(dataset["time_cases"]) >= 3 - assert len(dataset["episode_cases"]) >= 4 - assert len(dataset["knowledge_fetcher_cases"]) >= 3 - assert len(dataset["profile_cases"]) >= 3 - assert len(dataset["negative_control_cases"]) >= 1 - - -def test_group_chat_stream_fixture_matches_current_design_constraints(): - files = _fixture_files() - assert files, "未找到 group_chat_stream_memory_benchmark*.json fixture" - for path in files: - _assert_fixture_matches_current_design_constraints(_load_fixture(path)) diff --git a/pytests/A_memorix_test/test_knowledge_fetcher.py b/pytests/A_memorix_test/test_knowledge_fetcher.py deleted file mode 100644 index f70c7b39..00000000 --- a/pytests/A_memorix_test/test_knowledge_fetcher.py +++ /dev/null @@ -1,124 +0,0 @@ -from types import SimpleNamespace - -import pytest - -from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module -from src.services.memory_service import MemoryHit, MemorySearchResult - - -def test_knowledge_fetcher_resolves_private_memory_context(monkeypatch): - monkeypatch.setattr( - knowledge_module, - "_chat_manager", - SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")), - ) - monkeypatch.setattr( - knowledge_module, - "resolve_person_id_for_memory", - lambda *, person_name, platform, user_id: f"{person_name}:{platform}:{user_id}", - ) - - fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1") - - assert fetcher._resolve_private_memory_context() == { - "chat_id": "stream-1", - "person_id": "Alice:qq:42", - "user_id": "42", - "group_id": "", - } - - -@pytest.mark.asyncio -async def test_knowledge_fetcher_memory_get_knowledge_uses_memory_service(monkeypatch): - monkeypatch.setattr( - knowledge_module, - "_chat_manager", - SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")), - ) - monkeypatch.setattr( - knowledge_module, - "resolve_person_id_for_memory", - lambda *, person_name, platform, user_id: f"{person_name}:{platform}:{user_id}", - ) - - calls = [] - - async def fake_search(query: str, **kwargs): - calls.append((query, kwargs)) - return MemorySearchResult(summary="", hits=[MemoryHit(content="她喜欢猫", source="person_fact:qq:42")], filtered=False) - - monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search) - - fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1") - result = await fetcher._memory_get_knowledge("她喜欢什么") - - assert "1. 她喜欢猫" in result - assert calls == [ - ( - "她喜欢什么", - { - "limit": 5, - "mode": "search", - "chat_id": "stream-1", - "person_id": "Alice:qq:42", - "user_id": "42", - "group_id": "", - "respect_filter": True, - }, - ) - ] - - -@pytest.mark.asyncio -async def test_knowledge_fetcher_falls_back_to_chat_scope_when_person_scope_misses(monkeypatch): - monkeypatch.setattr( - knowledge_module, - "_chat_manager", - SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")), - ) - monkeypatch.setattr( - knowledge_module, - "resolve_person_id_for_memory", - lambda *, person_name, platform, user_id: "person-1", - ) - - calls = [] - - async def fake_search(query: str, **kwargs): - calls.append((query, kwargs)) - if kwargs.get("person_id"): - return MemorySearchResult(summary="", hits=[], filtered=False) - return MemorySearchResult(summary="", hits=[MemoryHit(content="她计划去杭州音乐节", source="chat_summary:stream-1")], filtered=False) - - monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search) - - fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1") - result = await fetcher._memory_get_knowledge("Alice 最近在忙什么") - - assert "杭州音乐节" in result - assert calls == [ - ( - "Alice 最近在忙什么", - { - "limit": 5, - "mode": "search", - "chat_id": "stream-1", - "person_id": "person-1", - "user_id": "42", - "group_id": "", - "respect_filter": True, - }, - ), - ( - "Alice 最近在忙什么", - { - "limit": 5, - "mode": "search", - "chat_id": "stream-1", - "person_id": "", - "user_id": "42", - "group_id": "", - "respect_filter": True, - }, - ), - ] diff --git a/pytests/A_memorix_test/test_memory_flow_service.py b/pytests/A_memorix_test/test_memory_flow_service.py deleted file mode 100644 index 98c9639d..00000000 --- a/pytests/A_memorix_test/test_memory_flow_service.py +++ /dev/null @@ -1,355 +0,0 @@ -from types import SimpleNamespace - -import pytest - -from src.services import memory_flow_service as memory_flow_module - - -def _fake_global_config(**integration_values): - return SimpleNamespace( - a_memorix=SimpleNamespace( - integration=SimpleNamespace(**integration_values), - ) - ) - - -def test_person_fact_parse_fact_list_deduplicates_and_filters_short_items(): - raw = '["他喜欢猫", "他喜欢猫", "好", "", "他会弹吉他"]' - - result = memory_flow_module.PersonFactWritebackService._parse_fact_list(raw) - - assert result == ["他喜欢猫", "他会弹吉他"] - - -def test_person_fact_looks_ephemeral_detects_short_chitchat(): - assert memory_flow_module.PersonFactWritebackService._looks_ephemeral("哈哈") - assert memory_flow_module.PersonFactWritebackService._looks_ephemeral("好的?") - assert not memory_flow_module.PersonFactWritebackService._looks_ephemeral("她最近在学法语和钢琴") - - -def test_person_fact_resolve_target_person_for_private_chat(monkeypatch): - class FakePerson: - def __init__(self, person_id: str): - self.person_id = person_id - self.is_known = True - - service = memory_flow_module.PersonFactWritebackService.__new__(memory_flow_module.PersonFactWritebackService) - monkeypatch.setattr(memory_flow_module, "is_bot_self", lambda platform, user_id: False) - monkeypatch.setattr(memory_flow_module, "get_person_id", lambda platform, user_id: f"{platform}:{user_id}") - monkeypatch.setattr(memory_flow_module, "Person", FakePerson) - - message = SimpleNamespace(session=SimpleNamespace(platform="qq", user_id="123", group_id="")) - - person = service._resolve_target_person(message) - - assert person is not None - assert person.person_id == "qq:123" - - -@pytest.mark.asyncio -async def test_person_fact_writeback_skips_bot_only_fact_without_user_evidence(monkeypatch): - stored_facts: list[tuple[str, str, str]] = [] - - class FakePerson: - person_id = "person-1" - person_name = "测试用户" - nickname = "测试用户" - is_known = True - - service = memory_flow_module.PersonFactWritebackService.__new__(memory_flow_module.PersonFactWritebackService) - service._resolve_target_person = lambda message: FakePerson() - - async def fake_extract_facts(person, reply_text, user_evidence_text): - del person, reply_text, user_evidence_text - return ["测试用户喜欢辣椒"] - - async def fake_store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str, **kwargs): - del kwargs - stored_facts.append((person_name, memory_content, chat_id)) - - service._extract_facts = fake_extract_facts - monkeypatch.setattr(memory_flow_module, "store_person_memory_from_answer", fake_store_person_memory_from_answer) - monkeypatch.setattr(memory_flow_module, "find_messages", lambda **kwargs: []) - - message = SimpleNamespace( - processed_plain_text="我记得你喜欢辣椒。", - session_id="session-1", - reply_to="", - session=SimpleNamespace(platform="qq", user_id="bot-1", group_id=""), - ) - - await service._handle_message(message) - - assert stored_facts == [] - - -@pytest.mark.asyncio -async def test_chat_summary_writeback_service_triggers_when_threshold_reached(monkeypatch): - events: list[tuple[str, object]] = [] - - monkeypatch.setattr( - memory_flow_module, - "global_config", - _fake_global_config( - chat_summary_writeback_enabled=True, - chat_summary_writeback_message_threshold=3, - chat_summary_writeback_context_length=7, - ), - ) - monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5) - - async def fake_ingest_summary(**kwargs): - events.append(("ingest_summary", kwargs)) - return SimpleNamespace(success=True, detail="ok") - - async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int: - del self, session_id, total_message_count - return 0 - - monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary) - monkeypatch.setattr( - memory_flow_module.ChatSummaryWritebackService, - "_load_last_trigger_message_count", - fake_load_last_trigger_message_count, - ) - - service = memory_flow_module.ChatSummaryWritebackService() - message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1")) - - await service._handle_message(message) - - assert len(events) == 1 - _, payload = events[0] - assert payload["external_id"] == "chat_auto_summary:session-1:5" - assert payload["chat_id"] == "session-1" - assert payload["text"] == "" - assert payload["metadata"]["generate_from_chat"] is True - assert payload["metadata"]["context_length"] == 7 - assert payload["metadata"]["trigger"] == "message_threshold" - assert payload["user_id"] == "user-1" - assert payload["group_id"] == "group-1" - - -@pytest.mark.asyncio -async def test_chat_summary_writeback_service_skips_when_threshold_not_reached(monkeypatch): - called = False - - monkeypatch.setattr( - memory_flow_module, - "global_config", - _fake_global_config( - chat_summary_writeback_enabled=True, - chat_summary_writeback_message_threshold=6, - chat_summary_writeback_context_length=9, - ), - ) - monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5) - - async def fake_ingest_summary(**kwargs): - nonlocal called - called = True - return SimpleNamespace(success=True, detail="ok") - - async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int: - del self, session_id, total_message_count - return 0 - - monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary) - monkeypatch.setattr( - memory_flow_module.ChatSummaryWritebackService, - "_load_last_trigger_message_count", - fake_load_last_trigger_message_count, - ) - - service = memory_flow_module.ChatSummaryWritebackService() - message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1")) - - await service._handle_message(message) - - assert called is False - - -@pytest.mark.asyncio -async def test_chat_summary_writeback_service_restores_previous_trigger_count(monkeypatch): - events: list[tuple[str, object]] = [] - - monkeypatch.setattr( - memory_flow_module, - "global_config", - _fake_global_config( - chat_summary_writeback_enabled=True, - chat_summary_writeback_message_threshold=3, - chat_summary_writeback_context_length=7, - ), - ) - monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 8) - - async def fake_ingest_summary(**kwargs): - events.append(("ingest_summary", kwargs)) - return SimpleNamespace(success=True, detail="ok") - - async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int: - del self, session_id, total_message_count - return 5 - - monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary) - monkeypatch.setattr( - memory_flow_module.ChatSummaryWritebackService, - "_load_last_trigger_message_count", - fake_load_last_trigger_message_count, - ) - - service = memory_flow_module.ChatSummaryWritebackService() - message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1")) - - await service._handle_message(message) - - assert len(events) == 1 - _, payload = events[0] - assert payload["external_id"] == "chat_auto_summary:session-1:8" - assert service._states["session-1"].last_trigger_message_count == 8 - - -@pytest.mark.asyncio -async def test_chat_summary_writeback_service_falls_back_to_current_count_for_legacy_summary(monkeypatch): - called = False - - monkeypatch.setattr( - memory_flow_module, - "global_config", - _fake_global_config( - chat_summary_writeback_enabled=True, - chat_summary_writeback_message_threshold=3, - chat_summary_writeback_context_length=7, - ), - ) - monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5) - - async def fake_ingest_summary(**kwargs): - nonlocal called - called = True - return SimpleNamespace(success=True, detail="ok") - - async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int: - del self, session_id, total_message_count - return 5 - - monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary) - monkeypatch.setattr( - memory_flow_module.ChatSummaryWritebackService, - "_load_last_trigger_message_count", - fake_load_last_trigger_message_count, - ) - - service = memory_flow_module.ChatSummaryWritebackService() - message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1")) - - await service._handle_message(message) - - assert called is False - assert service._states["session-1"].last_trigger_message_count == 5 - - -@pytest.mark.asyncio -async def test_chat_summary_writeback_service_loads_trigger_count_from_summary_metadata(monkeypatch): - class FakeMetadataStore: - @staticmethod - def get_paragraphs_by_source(source: str): - assert source == "chat_summary:session-1" - return [ - {"created_at": 1.0, "metadata": {"trigger_message_count": 3}}, - {"created_at": 2.0, "metadata": {"trigger_message_count": 6}}, - ] - - class FakeRuntimeManager: - @staticmethod - async def _ensure_kernel(): - return SimpleNamespace(metadata_store=FakeMetadataStore()) - - monkeypatch.setattr(memory_flow_module.memory_service_module, "a_memorix_host_service", FakeRuntimeManager()) - - service = memory_flow_module.ChatSummaryWritebackService() - - restored = await service._load_last_trigger_message_count(session_id="session-1", total_message_count=8) - - assert restored == 6 - - -@pytest.mark.asyncio -async def test_memory_automation_service_auto_starts_and_delegates(): - events: list[tuple[str, str]] = [] - - class FakeFactWriteback: - async def start(self): - events.append(("start", "fact")) - - async def enqueue(self, message): - events.append(("sent", message.session_id)) - - async def shutdown(self): - events.append(("shutdown", "fact")) - - class FakeChatSummaryWriteback: - async def start(self): - events.append(("start", "summary")) - - async def enqueue(self, message): - events.append(("summary", message.session_id)) - - async def shutdown(self): - events.append(("shutdown", "summary")) - - service = memory_flow_module.MemoryAutomationService() - service.fact_writeback = FakeFactWriteback() - service.chat_summary_writeback = FakeChatSummaryWriteback() - - await service.on_message_sent(SimpleNamespace(session_id="session-1")) - await service.shutdown() - - assert events == [ - ("start", "fact"), - ("start", "summary"), - ("sent", "session-1"), - ("summary", "session-1"), - ("shutdown", "summary"), - ("shutdown", "fact"), - ] - - -@pytest.mark.asyncio -async def test_memory_automation_service_on_incoming_message_auto_starts_only(): - events: list[tuple[str, str]] = [] - - class FakeFactWriteback: - async def start(self): - events.append(("start", "fact")) - - async def enqueue(self, message): - events.append(("sent", message.session_id)) - - async def shutdown(self): - events.append(("shutdown", "fact")) - - class FakeChatSummaryWriteback: - async def start(self): - events.append(("start", "summary")) - - async def enqueue(self, message): - events.append(("summary", message.session_id)) - - async def shutdown(self): - events.append(("shutdown", "summary")) - - service = memory_flow_module.MemoryAutomationService() - service.fact_writeback = FakeFactWriteback() - service.chat_summary_writeback = FakeChatSummaryWriteback() - - await service.on_incoming_message(SimpleNamespace(session_id="session-1")) - await service.shutdown() - - assert events == [ - ("start", "fact"), - ("start", "summary"), - ("shutdown", "summary"), - ("shutdown", "fact"), - ] diff --git a/pytests/A_memorix_test/test_memory_graph_search_kernel.py b/pytests/A_memorix_test/test_memory_graph_search_kernel.py deleted file mode 100644 index dd9c470e..00000000 --- a/pytests/A_memorix_test/test_memory_graph_search_kernel.py +++ /dev/null @@ -1,113 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from typing import Any - -import pytest - -from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel - - -class _DummyMetadataStore: - def __init__(self, *, entities: list[dict[str, Any]], relations: list[dict[str, Any]]) -> None: - self._entities = entities - self._relations = relations - - def query(self, sql: str, params: tuple[Any, ...] = ()) -> list[dict[str, Any]]: - sql_token = " ".join(str(sql or "").lower().split()) - keyword = str(params[0] or "").strip("%").lower() if params else "" - if "from entities" in sql_token: - rows = [dict(item) for item in self._entities if not bool(item.get("is_deleted", 0))] - if not keyword: - return rows - return [ - row - for row in rows - if keyword in str(row.get("name", "") or "").lower() - or keyword in str(row.get("hash", "") or "").lower() - ] - if "from relations" in sql_token: - rows = [dict(item) for item in self._relations if not bool(item.get("is_inactive", 0))] - if not keyword: - return rows - return [ - row - for row in rows - if keyword in str(row.get("subject", "") or "").lower() - or keyword in str(row.get("object", "") or "").lower() - or keyword in str(row.get("predicate", "") or "").lower() - or keyword in str(row.get("hash", "") or "").lower() - ] - raise AssertionError(f"unexpected query: {sql_token}") - - -def _build_kernel(*, entities: list[dict[str, Any]], relations: list[dict[str, Any]]) -> SDKMemoryKernel: - kernel = SDKMemoryKernel(plugin_root=Path.cwd(), config={}) - - async def _fake_initialize() -> None: - return None - - kernel.initialize = _fake_initialize # type: ignore[method-assign] - kernel.metadata_store = _DummyMetadataStore(entities=entities, relations=relations) - kernel.graph_store = object() # type: ignore[assignment] - return kernel - - -@pytest.mark.asyncio -async def test_memory_graph_admin_search_orders_and_dedupes_results() -> None: - kernel = _build_kernel( - entities=[ - {"hash": "e1", "name": "Alice", "appearance_count": 5, "is_deleted": 0}, - {"hash": "e1", "name": "Alice Duplicate", "appearance_count": 99, "is_deleted": 0}, - {"hash": "e2", "name": "Alice Cooper", "appearance_count": 7, "is_deleted": 0}, - {"hash": "e3", "name": "my alice note", "appearance_count": 11, "is_deleted": 0}, - {"hash": "e4", "name": "alice deleted", "appearance_count": 100, "is_deleted": 1}, - ], - relations=[ - {"hash": "r1", "subject": "Alice", "predicate": "knows", "object": "Bob", "confidence": 0.6, "created_at": 100, "is_inactive": 0}, - {"hash": "r3", "subject": "Alice", "predicate": "supports", "object": "Carol", "confidence": 0.9, "created_at": 90, "is_inactive": 0}, - {"hash": "r1", "subject": "Alice", "predicate": "knows duplicate", "object": "Bob", "confidence": 0.99, "created_at": 200, "is_inactive": 0}, - {"hash": "r2", "subject": "Alice Cooper", "predicate": "likes", "object": "Tea", "confidence": 0.2, "created_at": 50, "is_inactive": 0}, - {"hash": "", "subject": "Carol", "predicate": "mentions alice", "object": "Topic", "confidence": 0.8, "created_at": 70, "is_inactive": 0}, - {"hash": "", "subject": "Carol", "predicate": "mentions alice", "object": "Topic", "confidence": 0.3, "created_at": 10, "is_inactive": 0}, - {"hash": "r4", "subject": "alice inactive", "predicate": "old", "object": "Data", "confidence": 1.0, "created_at": 300, "is_inactive": 1}, - ], - ) - - payload = await kernel.memory_graph_admin(action="search", query="alice", limit=20) - - assert payload["success"] is True - assert payload["count"] == len(payload["items"]) - entity_items = [item for item in payload["items"] if item["type"] == "entity"] - relation_items = [item for item in payload["items"] if item["type"] == "relation"] - - assert [item["entity_hash"] for item in entity_items] == ["e1", "e2", "e3"] - assert [item["relation_hash"] for item in relation_items] == ["r3", "r1", "r2", ""] - assert relation_items[0]["confidence"] == pytest.approx(0.9) - assert relation_items[1]["confidence"] == pytest.approx(0.6) - - -@pytest.mark.asyncio -async def test_memory_graph_admin_search_filters_deleted_and_inactive_records() -> None: - kernel = _build_kernel( - entities=[ - {"hash": "e-deleted", "name": "Ghost Alice", "appearance_count": 10, "is_deleted": 1}, - ], - relations=[ - { - "hash": "r-inactive", - "subject": "Ghost Alice", - "predicate": "linked", - "object": "Ghost Bob", - "confidence": 0.9, - "created_at": 10, - "is_inactive": 1, - }, - ], - ) - - payload = await kernel.memory_graph_admin(action="search", query="ghost", limit=50) - - assert payload["success"] is True - assert payload["items"] == [] - assert payload["count"] == 0 diff --git a/pytests/A_memorix_test/test_memory_service.py b/pytests/A_memorix_test/test_memory_service.py deleted file mode 100644 index 1bde64b6..00000000 --- a/pytests/A_memorix_test/test_memory_service.py +++ /dev/null @@ -1,281 +0,0 @@ -import pytest - -from src.services.memory_service import MemorySearchResult, MemoryService - - -def test_coerce_write_result_treats_skipped_payload_as_success(): - result = MemoryService._coerce_write_result({"skipped_ids": ["p1"], "detail": "chat_filtered"}) - - assert result.success is True - assert result.stored_ids == [] - assert result.skipped_ids == ["p1"] - assert result.detail == "chat_filtered" - - -@pytest.mark.asyncio -async def test_graph_admin_invokes_plugin(monkeypatch): - service = MemoryService() - calls = [] - - async def fake_invoke(component_name, args=None, **kwargs): - calls.append((component_name, args, kwargs)) - return {"success": True, "nodes": [], "edges": []} - - monkeypatch.setattr(service, "_invoke", fake_invoke) - - result = await service.graph_admin(action="get_graph", limit=12) - - assert result["success"] is True - assert calls == [("memory_graph_admin", {"action": "get_graph", "limit": 12}, {"timeout_ms": 30000})] - - -@pytest.mark.asyncio -async def test_get_recycle_bin_uses_maintain_memory_tool(monkeypatch): - service = MemoryService() - calls = [] - - async def fake_invoke(component_name, args=None, **kwargs): - calls.append((component_name, args)) - return {"success": True, "items": [{"hash": "abc"}], "count": 1} - - monkeypatch.setattr(service, "_invoke", fake_invoke) - - result = await service.get_recycle_bin(limit=5) - - assert result == {"success": True, "items": [{"hash": "abc"}], "count": 1} - assert calls == [("maintain_memory", {"action": "recycle_bin", "limit": 5})] - - -@pytest.mark.asyncio -async def test_search_respects_filter_by_default(monkeypatch): - service = MemoryService() - calls = [] - - async def fake_invoke(component_name, args=None, **kwargs): - calls.append((component_name, args)) - return {"summary": "ok", "hits": [], "filtered": True} - - monkeypatch.setattr(service, "_invoke", fake_invoke) - - result = await service.search( - "mai", - chat_id="stream-1", - person_id="person-1", - user_id="user-1", - group_id="", - ) - - assert isinstance(result, MemorySearchResult) - assert result.filtered is True - assert calls == [ - ( - "search_memory", - { - "query": "mai", - "limit": 5, - "mode": "search", - "chat_id": "stream-1", - "person_id": "person-1", - "time_start": None, - "time_end": None, - "respect_filter": True, - "user_id": "user-1", - "group_id": "", - }, - ) - ] - - -@pytest.mark.asyncio -async def test_ingest_summary_can_bypass_filter(monkeypatch): - service = MemoryService() - calls = [] - - async def fake_invoke(component_name, args=None, **kwargs): - calls.append((component_name, args)) - return {"success": True, "stored_ids": ["p1"], "detail": ""} - - monkeypatch.setattr(service, "_invoke", fake_invoke) - - result = await service.ingest_summary( - external_id="chat_history:1", - chat_id="stream-1", - text="summary", - respect_filter=False, - user_id="user-1", - ) - - assert result.success is True - assert calls == [ - ( - "ingest_summary", - { - "external_id": "chat_history:1", - "chat_id": "stream-1", - "text": "summary", - "participants": [], - "time_start": None, - "time_end": None, - "tags": [], - "metadata": {}, - "respect_filter": False, - "user_id": "user-1", - "group_id": "", - }, - ) - ] - - -@pytest.mark.asyncio -async def test_v5_admin_invokes_plugin(monkeypatch): - service = MemoryService() - calls = [] - - async def fake_invoke(component_name, args=None, **kwargs): - calls.append((component_name, args, kwargs)) - return {"success": True, "count": 1} - - monkeypatch.setattr(service, "_invoke", fake_invoke) - - result = await service.v5_admin(action="status", target="mai", limit=5) - - assert result["success"] is True - assert calls == [("memory_v5_admin", {"action": "status", "target": "mai", "limit": 5}, {"timeout_ms": 30000})] - - -@pytest.mark.asyncio -async def test_delete_admin_uses_long_timeout(monkeypatch): - service = MemoryService() - calls = [] - - async def fake_invoke(component_name, args=None, **kwargs): - calls.append((component_name, args, kwargs)) - return {"success": True, "operation_id": "del-1"} - - monkeypatch.setattr(service, "_invoke", fake_invoke) - - result = await service.delete_admin(action="execute", mode="relation", selector={"query": "mai"}) - - assert result["success"] is True - assert calls == [ - ( - "memory_delete_admin", - {"action": "execute", "mode": "relation", "selector": {"query": "mai"}}, - {"timeout_ms": 120000}, - ) - ] - - -@pytest.mark.asyncio -async def test_search_returns_empty_when_query_and_time_missing_async(): - service = MemoryService() - - result = await service.search("", time_start=None, time_end=None) - - assert isinstance(result, MemorySearchResult) - assert result.summary == "" - assert result.hits == [] - assert result.filtered is False - - -@pytest.mark.asyncio -async def test_search_accepts_string_time_bounds(monkeypatch): - service = MemoryService() - calls = [] - - async def fake_invoke(component_name, args=None, **kwargs): - calls.append((component_name, args)) - return {"summary": "ok", "hits": [], "filtered": False} - - monkeypatch.setattr(service, "_invoke", fake_invoke) - - result = await service.search( - "广播站", - mode="time", - time_start="2026/03/18", - time_end="2026/03/18 09:30", - ) - - assert isinstance(result, MemorySearchResult) - assert calls == [ - ( - "search_memory", - { - "query": "广播站", - "limit": 5, - "mode": "time", - "chat_id": "", - "person_id": "", - "time_start": "2026/03/18", - "time_end": "2026/03/18 09:30", - "respect_filter": True, - "user_id": "", - "group_id": "", - }, - ) - ] - - -def test_coerce_search_result_preserves_aggregate_source_branches(): - result = MemoryService._coerce_search_result( - { - "hits": [ - { - "content": "广播站值夜班", - "type": "paragraph", - "metadata": {"event_time_start": 1.0}, - "source_branches": ["search", "time"], - "rank": 1, - } - ] - } - ) - - assert result.hits[0].metadata["source_branches"] == ["search", "time"] - assert result.hits[0].metadata["rank"] == 1 - - -@pytest.mark.asyncio -async def test_import_admin_uses_long_timeout(monkeypatch): - service = MemoryService() - calls = [] - - async def fake_invoke(component_name, args=None, **kwargs): - calls.append((component_name, args, kwargs)) - return {"success": True, "task_id": "import-1"} - - monkeypatch.setattr(service, "_invoke", fake_invoke) - - result = await service.import_admin(action="create_lpmm_openie", alias="lpmm") - - assert result["success"] is True - assert calls == [ - ( - "memory_import_admin", - {"action": "create_lpmm_openie", "alias": "lpmm"}, - {"timeout_ms": 120000}, - ) - ] - - -@pytest.mark.asyncio -async def test_tuning_admin_uses_long_timeout(monkeypatch): - service = MemoryService() - calls = [] - - async def fake_invoke(component_name, args=None, **kwargs): - calls.append((component_name, args, kwargs)) - return {"success": True, "task_id": "tuning-1"} - - monkeypatch.setattr(service, "_invoke", fake_invoke) - - result = await service.tuning_admin(action="create_task", payload={"query": "mai"}) - - assert result["success"] is True - assert calls == [ - ( - "memory_tuning_admin", - {"action": "create_task", "payload": {"query": "mai"}}, - {"timeout_ms": 120000}, - ) - ] diff --git a/pytests/A_memorix_test/test_metadata_store_sources.py b/pytests/A_memorix_test/test_metadata_store_sources.py deleted file mode 100644 index dcedbd3d..00000000 --- a/pytests/A_memorix_test/test_metadata_store_sources.py +++ /dev/null @@ -1,21 +0,0 @@ -from pathlib import Path - -from src.A_memorix.core.storage.metadata_store import MetadataStore - - -def test_get_all_sources_ignores_soft_deleted_paragraphs(tmp_path: Path) -> None: - store = MetadataStore(data_dir=tmp_path) - store.connect() - try: - live_hash = store.add_paragraph("Alice 喜欢地图", source="live-source") - deleted_hash = store.add_paragraph("Bob 喜欢咖啡", source="deleted-source") - - assert live_hash - store.mark_as_deleted([deleted_hash], "paragraph") - - sources = store.get_all_sources() - finally: - store.close() - - assert [item["source"] for item in sources] == ["live-source"] - assert sources[0]["count"] == 1 diff --git a/pytests/A_memorix_test/test_person_memory_writeback.py b/pytests/A_memorix_test/test_person_memory_writeback.py deleted file mode 100644 index f177405a..00000000 --- a/pytests/A_memorix_test/test_person_memory_writeback.py +++ /dev/null @@ -1,81 +0,0 @@ -from types import SimpleNamespace - -import pytest - -from src.person_info import person_info as person_info_module - - -@pytest.mark.asyncio -async def test_store_person_memory_from_answer_writes_person_fact(monkeypatch): - calls = [] - - class FakePerson: - def __init__(self, person_id: str): - self.person_id = person_id - self.person_name = "Alice" - self.is_known = True - - async def fake_ingest_text(**kwargs): - calls.append(kwargs) - return SimpleNamespace(success=True, detail="", stored_ids=["p1"]) - - session = SimpleNamespace(platform="qq", user_id="10001", group_id="", session_id="session-1") - monkeypatch.setattr(person_info_module, "_chat_manager", SimpleNamespace(get_session_by_session_id=lambda chat_id: session)) - monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1") - monkeypatch.setattr(person_info_module, "Person", FakePerson) - monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text) - - await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1") - - assert len(calls) == 1 - payload = calls[0] - assert payload["external_id"].startswith("person_fact:person-1:") - assert payload["source_type"] == "person_fact" - assert payload["chat_id"] == "session-1" - assert payload["person_ids"] == ["person-1"] - assert payload["participants"] == ["Alice"] - assert payload["respect_filter"] is True - assert payload["user_id"] == "10001" - assert payload["group_id"] == "" - assert payload["metadata"]["person_id"] == "person-1" - - -@pytest.mark.asyncio -async def test_store_person_memory_from_answer_skips_unknown_person(monkeypatch): - calls = [] - - class FakePerson: - def __init__(self, person_id: str): - self.person_id = person_id - self.person_name = "Unknown" - self.is_known = False - - async def fake_ingest_text(**kwargs): - calls.append(kwargs) - return SimpleNamespace(success=True, detail="", stored_ids=["p1"]) - - session = SimpleNamespace(platform="qq", user_id="10001", group_id="", session_id="session-1") - monkeypatch.setattr(person_info_module, "_chat_manager", SimpleNamespace(get_session_by_session_id=lambda chat_id: session)) - monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1") - monkeypatch.setattr(person_info_module, "Person", FakePerson) - monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text) - - await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1") - - assert calls == [] - - -@pytest.mark.asyncio -async def test_store_person_memory_from_answer_skips_empty_content(monkeypatch): - calls = [] - - async def fake_ingest_text(**kwargs): - calls.append(kwargs) - return SimpleNamespace(success=True, detail="", stored_ids=["p1"]) - - monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text) - - await person_info_module.store_person_memory_from_answer("Alice", " ", "session-1") - - assert calls == [] - diff --git a/pytests/A_memorix_test/test_person_profile_service.py b/pytests/A_memorix_test/test_person_profile_service.py deleted file mode 100644 index b75beb16..00000000 --- a/pytests/A_memorix_test/test_person_profile_service.py +++ /dev/null @@ -1,115 +0,0 @@ -from types import SimpleNamespace - -import pytest - -from src.A_memorix.core.utils.person_profile_service import PersonProfileService - - -class FakeMetadataStore: - def __init__(self) -> None: - self.snapshots: list[dict] = [] - - @staticmethod - def get_latest_person_profile_snapshot(person_id: str): - del person_id - return None - - @staticmethod - def get_relations(**kwargs): - del kwargs - return [] - - @staticmethod - def get_paragraphs_by_source(source: str): - if source == "person_fact:person-1": - return [ - { - "hash": "person-fact-1", - "content": "测试用户喜欢猫。", - "source": source, - "metadata": {"source_type": "person_fact"}, - "created_at": 2.0, - "updated_at": 2.0, - } - ] - return [] - - @staticmethod - def get_paragraph(hash_value: str): - if hash_value == "chat-summary-1": - return { - "hash": hash_value, - "content": "机器人建议测试用户以后叫星灯。", - "source": "chat_summary:session-1", - "metadata": {"source_type": "chat_summary"}, - "word_count": 1, - } - if hash_value == "person-fact-1": - return { - "hash": hash_value, - "content": "测试用户喜欢猫。", - "source": "person_fact:person-1", - "metadata": {"source_type": "person_fact"}, - "word_count": 1, - } - return None - - @staticmethod - def get_paragraph_stale_relation_marks_batch(paragraph_hashes): - del paragraph_hashes - return {} - - @staticmethod - def get_relation_status_batch(relation_hashes): - del relation_hashes - return {} - - @staticmethod - def get_person_profile_override(person_id: str): - del person_id - return None - - def upsert_person_profile_snapshot(self, **kwargs): - self.snapshots.append(kwargs) - return { - "person_id": kwargs["person_id"], - "profile_text": kwargs["profile_text"], - "aliases": kwargs["aliases"], - "relation_edges": kwargs["relation_edges"], - "vector_evidence": kwargs["vector_evidence"], - "evidence_ids": kwargs["evidence_ids"], - "updated_at": 1.0, - "expires_at": kwargs["expires_at"], - "source_note": kwargs["source_note"], - } - - -class FakeRetriever: - async def retrieve(self, query: str, top_k: int): - del query, top_k - return [ - SimpleNamespace( - hash_value="chat-summary-1", - result_type="paragraph", - score=0.95, - content="机器人建议测试用户以后叫星灯。", - metadata={"source_type": "chat_summary"}, - ) - ] - - -@pytest.mark.asyncio -async def test_person_profile_keeps_chat_summary_as_recent_interaction_not_stable_profile(): - metadata_store = FakeMetadataStore() - service = PersonProfileService(metadata_store=metadata_store, retriever=FakeRetriever()) - service.get_person_aliases = lambda person_id: (["测试用户"], "测试用户", []) - - payload = await service.query_person_profile(person_id="person-1", top_k=6, force_refresh=True) - - assert payload["success"] is True - profile_text = payload["profile_text"] - stable_section = profile_text.split("近期相关互动:", 1)[0] - assert "测试用户喜欢猫" in stable_section - assert "星灯" not in stable_section - assert "近期相关互动:" in profile_text - assert "星灯" in profile_text diff --git a/pytests/A_memorix_test/test_query_long_term_memory_tool.py b/pytests/A_memorix_test/test_query_long_term_memory_tool.py deleted file mode 100644 index a034d7c3..00000000 --- a/pytests/A_memorix_test/test_query_long_term_memory_tool.py +++ /dev/null @@ -1,184 +0,0 @@ -from __future__ import annotations - -from datetime import datetime -from types import SimpleNamespace - -import pytest - -from src.memory_system.retrieval_tools import query_long_term_memory as tool_module -from src.memory_system.retrieval_tools import init_all_tools -from src.memory_system.retrieval_tools.query_long_term_memory import ( - _resolve_time_expression, - query_long_term_memory, - register_tool, -) -from src.memory_system.retrieval_tools.tool_registry import get_tool_registry -from src.services.memory_service import MemoryHit, MemorySearchResult - - -def test_resolve_time_expression_supports_relative_and_absolute_inputs(): - now = datetime(2026, 3, 18, 15, 30) - - start_ts, end_ts, start_text, end_text = _resolve_time_expression("今天", now=now) - assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 0, 0) - assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59) - assert start_text == "2026/03/18 00:00" - assert end_text == "2026/03/18 23:59" - - start_ts, end_ts, start_text, end_text = _resolve_time_expression("最近7天", now=now) - assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 12, 0, 0) - assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59) - assert start_text == "2026/03/12 00:00" - assert end_text == "2026/03/18 23:59" - - start_ts, end_ts, start_text, end_text = _resolve_time_expression("2026/03/18", now=now) - assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 0, 0) - assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59) - assert start_text == "2026/03/18 00:00" - assert end_text == "2026/03/18 23:59" - - start_ts, end_ts, start_text, end_text = _resolve_time_expression("2026/03/18 09:30", now=now) - assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 9, 30) - assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 9, 30) - assert start_text == "2026/03/18 09:30" - assert end_text == "2026/03/18 09:30" - - -def test_register_tool_exposes_mode_and_time_expression(): - register_tool() - tool = get_tool_registry().get_tool("search_long_term_memory") - - assert tool is not None - params = {item["name"]: item for item in tool.parameters} - assert "mode" in params - assert params["mode"]["enum"] == ["search", "time", "episode", "aggregate"] - assert "time_expression" in params - assert params["query"]["required"] is False - - -def test_init_all_tools_registers_long_term_memory_tool(): - init_all_tools() - - tool = get_tool_registry().get_tool("search_long_term_memory") - assert tool is not None - - -@pytest.mark.asyncio -async def test_query_long_term_memory_search_mode_keeps_search(monkeypatch): - captured = {} - - async def fake_search(query, **kwargs): - captured["query"] = query - captured["kwargs"] = kwargs - return MemorySearchResult( - hits=[MemoryHit(content="Alice 喜欢猫", score=0.9, hit_type="paragraph")], - ) - - monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search)) - - text = await query_long_term_memory("Alice 喜欢什么", chat_id="stream-1", person_id="person-1") - - assert "Alice 喜欢猫" in text - assert captured == { - "query": "Alice 喜欢什么", - "kwargs": { - "limit": 5, - "mode": "search", - "chat_id": "stream-1", - "person_id": "person-1", - "time_start": None, - "time_end": None, - }, - } - - -@pytest.mark.asyncio -async def test_query_long_term_memory_time_mode_parses_expression(monkeypatch): - captured = {} - - async def fake_search(query, **kwargs): - captured["query"] = query - captured["kwargs"] = kwargs - return MemorySearchResult( - hits=[ - MemoryHit( - content="昨天晚上广播站停播了十分钟。", - score=0.8, - hit_type="paragraph", - metadata={"event_time_start": 1773797400.0}, - ) - ] - ) - - monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search)) - monkeypatch.setattr( - tool_module, - "_resolve_time_expression", - lambda expression, now=None: (1773795600.0, 1773881940.0, "2026/03/17 00:00", "2026/03/17 23:59"), - ) - - text = await query_long_term_memory( - query="广播站", - mode="time", - time_expression="昨天", - chat_id="stream-1", - ) - - assert "指定时间范围" in text - assert "广播站停播" in text - assert captured == { - "query": "广播站", - "kwargs": { - "limit": 5, - "mode": "time", - "chat_id": "stream-1", - "person_id": "", - "time_start": 1773795600.0, - "time_end": 1773881940.0, - }, - } - - -@pytest.mark.asyncio -async def test_query_long_term_memory_episode_and_aggregate_format_output(monkeypatch): - responses = { - "episode": MemorySearchResult( - hits=[ - MemoryHit( - content="苏弦在灯塔拆开了那封冬信。", - title="冬信重见天日", - hit_type="episode", - metadata={"participants": ["苏弦"], "keywords": ["冬信", "灯塔"]}, - ) - ] - ), - "aggregate": MemorySearchResult( - hits=[ - MemoryHit( - content="唐未在广播站值夜班时带着黑狗墨点。", - hit_type="paragraph", - metadata={"source_branches": ["search", "time"]}, - ) - ] - ), - } - - async def fake_search(query, **kwargs): - return responses[kwargs["mode"]] - - monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search)) - - episode_text = await query_long_term_memory("那封冬信后来怎么样了", mode="episode") - aggregate_text = await query_long_term_memory("唐未最近有什么线索", mode="aggregate") - - assert "事件《冬信重见天日》" in episode_text - assert "参与者:苏弦" in episode_text - assert "[search,time][paragraph]" in aggregate_text - - -@pytest.mark.asyncio -async def test_query_long_term_memory_invalid_time_expression_returns_retryable_message(): - text = await query_long_term_memory(query="广播站", mode="time", time_expression="明年春分后第三周") - - assert "无法解析" in text - assert "最近7天" in text diff --git a/pytests/A_memorix_test/test_summary_importer_model_config.py b/pytests/A_memorix_test/test_summary_importer_model_config.py deleted file mode 100644 index d20e3f49..00000000 --- a/pytests/A_memorix_test/test_summary_importer_model_config.py +++ /dev/null @@ -1,140 +0,0 @@ -import pytest - -from src.A_memorix.core.utils.summary_importer import ( - SummaryImporter, - _message_timestamp, - _normalize_entity_items, - _normalize_relation_items, -) -from src.config.model_configs import TaskConfig -from src.services import llm_service as llm_api - - -def _fake_available_models() -> dict[str, TaskConfig]: - return { - "memory": TaskConfig( - model_list=["memory-model"], - max_tokens=512, - temperature=0.4, - selection_strategy="random", - ), - "utils": TaskConfig( - model_list=["utils-model"], - max_tokens=256, - temperature=0.5, - selection_strategy="random", - ), - "replyer": TaskConfig( - model_list=["replyer-model"], - max_tokens=128, - temperature=0.7, - selection_strategy="random", - ), - } - - -def test_resolve_summary_model_config_uses_auto_list_when_summarization_missing(monkeypatch): - monkeypatch.setattr(llm_api, "get_available_models", _fake_available_models) - - importer = SummaryImporter( - vector_store=None, - graph_store=None, - metadata_store=None, - embedding_manager=None, - plugin_config={}, - ) - - resolved = importer._resolve_summary_model_config() - - assert resolved is not None - assert resolved.model_list == ["memory-model"] - - -def test_resolve_summary_model_config_auto_falls_back_to_utils_then_planner(monkeypatch): - importer = SummaryImporter( - vector_store=None, - graph_store=None, - metadata_store=None, - embedding_manager=None, - plugin_config={}, - ) - - monkeypatch.setattr( - llm_api, - "get_available_models", - lambda: { - "utils": TaskConfig(model_list=["utils-model"]), - "planner": TaskConfig(model_list=["planner-model"]), - "replyer": TaskConfig(model_list=["replyer-model"]), - }, - ) - resolved = importer._resolve_summary_model_config() - assert resolved is not None - assert resolved.model_list == ["utils-model"] - - monkeypatch.setattr( - llm_api, - "get_available_models", - lambda: { - "planner": TaskConfig(model_list=["planner-model"]), - "replyer": TaskConfig(model_list=["replyer-model"]), - }, - ) - resolved = importer._resolve_summary_model_config() - assert resolved is not None - assert resolved.model_list == ["planner-model"] - - -def test_resolve_summary_model_config_auto_does_not_fallback_to_replyer(monkeypatch): - monkeypatch.setattr( - llm_api, - "get_available_models", - lambda: { - "replyer": TaskConfig(model_list=["replyer-model"]), - "embedding": TaskConfig(model_list=["embedding-model"]), - }, - ) - - importer = SummaryImporter( - vector_store=None, - graph_store=None, - metadata_store=None, - embedding_manager=None, - plugin_config={}, - ) - - assert importer._resolve_summary_model_config() is None - - -def test_resolve_summary_model_config_rejects_legacy_string_selector(monkeypatch): - monkeypatch.setattr(llm_api, "get_available_models", _fake_available_models) - - importer = SummaryImporter( - vector_store=None, - graph_store=None, - metadata_store=None, - embedding_manager=None, - plugin_config={"summarization": {"model_name": "auto"}}, - ) - - with pytest.raises(ValueError, match="List\\[str\\]"): - importer._resolve_summary_model_config() - - -def test_summary_importer_normalizes_llm_entities_and_relations(): - assert _normalize_entity_items(["Alice", {"name": "地图"}, ["bad"], "Alice"]) == ["Alice", "地图"] - assert _normalize_entity_items("Alice") == [] - assert _normalize_relation_items( - [ - {"subject": "Alice", "predicate": "持有", "object": "地图"}, - {"subject": "Alice", "predicate": "", "object": "地图"}, - ["bad"], - ] - ) == [{"subject": "Alice", "predicate": "持有", "object": "地图"}] - - -def test_summary_importer_message_timestamp_accepts_time_fallback(): - class Message: - time = 123.5 - - assert _message_timestamp(Message()) == 123.5 diff --git a/pytests/A_memorix_test/test_web_import_manager_payloads.py b/pytests/A_memorix_test/test_web_import_manager_payloads.py deleted file mode 100644 index f2d78df3..00000000 --- a/pytests/A_memorix_test/test_web_import_manager_payloads.py +++ /dev/null @@ -1,182 +0,0 @@ -from types import SimpleNamespace - -import numpy as np -import pytest - -from src.A_memorix.core.strategies.base import ChunkContext, KnowledgeType, ProcessedChunk, SourceInfo -from src.A_memorix.core.utils.web_import_manager import ( - ImportChunkRecord, - ImportFileRecord, - ImportTaskManager, - ImportTaskRecord, -) - - -class _DummyMetadataStore: - def __init__(self) -> None: - self.paragraphs: list[dict[str, object]] = [] - self.entities: list[str] = [] - self.relations: list[tuple[str, str, str]] = [] - - def add_paragraph(self, **kwargs): - self.paragraphs.append(dict(kwargs)) - return f"paragraph-{len(self.paragraphs)}" - - def add_entity(self, *, name: str, source_paragraph: str = "") -> str: - del source_paragraph - self.entities.append(name) - return f"entity-{name}" - - def add_relation(self, *, subject: str, predicate: str, obj: str, **kwargs) -> str: - del kwargs - self.relations.append((subject, predicate, obj)) - return f"relation-{len(self.relations)}" - - def set_relation_vector_state(self, rel_hash: str, state: str) -> None: - del rel_hash, state - - -class _DummyGraphStore: - def __init__(self) -> None: - self.nodes: list[list[str]] = [] - self.edges: list[list[tuple[str, str]]] = [] - - def add_nodes(self, nodes): - self.nodes.append(list(nodes)) - - def add_edges(self, edges, relation_hashes=None): - del relation_hashes - self.edges.append(list(edges)) - - -class _DummyVectorStore: - def __contains__(self, item: str) -> bool: - del item - return False - - def add(self, vectors, ids): - del vectors, ids - - -class _DummyEmbeddingManager: - async def encode(self, text: str) -> np.ndarray: - del text - return np.ones(4, dtype=np.float32) - - -def _build_manager() -> tuple[ImportTaskManager, _DummyMetadataStore]: - metadata_store = _DummyMetadataStore() - plugin = SimpleNamespace( - metadata_store=metadata_store, - graph_store=_DummyGraphStore(), - vector_store=_DummyVectorStore(), - embedding_manager=_DummyEmbeddingManager(), - relation_write_service=None, - get_config=lambda key, default=None: default, - _is_embedding_degraded=lambda: False, - _allow_metadata_only_write=lambda: True, - write_paragraph_vector_or_enqueue=None, - ) - manager = ImportTaskManager(plugin) - return manager, metadata_store - - -def _build_progress_task(task_id: str, total_chunks: int = 2) -> ImportTaskRecord: - file_record = ImportFileRecord( - file_id="file-1", - name="demo.txt", - source_kind="paste", - input_mode="text", - total_chunks=total_chunks, - chunks=[ - ImportChunkRecord(chunk_id=f"chunk-{index}", index=index, chunk_type="text") - for index in range(total_chunks) - ], - ) - return ImportTaskRecord(task_id=task_id, source="paste", params={}, files=[file_record]) - - -def _build_chunk(data) -> ProcessedChunk: - return ProcessedChunk( - type=KnowledgeType.FACTUAL, - source=SourceInfo(file="demo.txt", offset_start=0, offset_end=4), - chunk=ChunkContext(chunk_id="chunk-1", index=0, text="Alice 持有地图"), - data=data, - ) - - -@pytest.mark.asyncio -async def test_persist_processed_chunk_rejects_non_object_before_paragraph_write() -> None: - manager, metadata_store = _build_manager() - file_record = SimpleNamespace(source_path="", source_kind="paste", name="demo.txt") - - with pytest.raises(ValueError, match="分块抽取结果 必须返回 JSON 对象"): - await manager._persist_processed_chunk(file_record, _build_chunk(["bad"])) - - assert metadata_store.paragraphs == [] - - -@pytest.mark.asyncio -async def test_chunk_terminal_progress_uses_successful_chunks_only() -> None: - manager, _ = _build_manager() - - task = _build_progress_task("task-fail-then-complete") - manager._tasks[task.task_id] = task - - await manager._set_chunk_failed(task.task_id, "file-1", "chunk-0", "boom") - await manager._set_chunk_completed(task.task_id, "file-1", "chunk-1") - - file_record = task.files[0] - assert file_record.done_chunks == 1 - assert file_record.failed_chunks == 1 - assert file_record.progress == pytest.approx(0.5) - assert task.progress == pytest.approx(0.5) - - reverse_task = _build_progress_task("task-complete-then-fail") - manager._tasks[reverse_task.task_id] = reverse_task - - await manager._set_chunk_completed(reverse_task.task_id, "file-1", "chunk-0") - await manager._set_chunk_failed(reverse_task.task_id, "file-1", "chunk-1", "boom") - - reverse_file = reverse_task.files[0] - assert reverse_file.done_chunks == 1 - assert reverse_file.failed_chunks == 1 - assert reverse_file.progress == pytest.approx(0.5) - assert reverse_task.progress == pytest.approx(0.5) - - -@pytest.mark.asyncio -async def test_cancelled_chunks_do_not_increase_file_progress() -> None: - manager, _ = _build_manager() - task = _build_progress_task("task-cancelled-progress", total_chunks=3) - manager._tasks[task.task_id] = task - - await manager._set_chunk_completed(task.task_id, "file-1", "chunk-0") - await manager._set_chunk_cancelled(task.task_id, "file-1", "chunk-1", "任务已取消") - - file_record = task.files[0] - assert file_record.done_chunks == 1 - assert file_record.cancelled_chunks == 1 - assert file_record.progress == pytest.approx(1 / 3) - assert task.progress == pytest.approx(1 / 3) - - -@pytest.mark.asyncio -async def test_persist_processed_chunk_skips_invalid_nested_items() -> None: - manager, metadata_store = _build_manager() - file_record = SimpleNamespace(source_path="", source_kind="paste", name="demo.txt") - - await manager._persist_processed_chunk( - file_record, - _build_chunk( - { - "triples": [{"subject": "Alice", "predicate": "持有", "object": "地图"}, ["bad"]], - "relations": [{"subject": "Alice", "predicate": "", "object": "地图"}], - "entities": ["Alice", {"name": "地图"}, ["bad"]], - } - ), - ) - - assert len(metadata_store.paragraphs) == 1 - assert set(metadata_store.entities) >= {"Alice", "地图"} - assert metadata_store.relations == [("Alice", "持有", "地图")] diff --git a/pytests/common_test/test_chat_config_utils.py b/pytests/common_test/test_chat_config_utils.py deleted file mode 100644 index 354b24e8..00000000 --- a/pytests/common_test/test_chat_config_utils.py +++ /dev/null @@ -1,89 +0,0 @@ -from types import SimpleNamespace - -from src.chat.message_receive.chat_manager import chat_manager -from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils -from src.common.utils.utils_session import SessionUtils -from src.config.config import global_config - - -def test_get_chat_prompt_for_chat_merges_multiple_matching_prompts(monkeypatch): - session_id = SessionUtils.calculate_session_id("qq", group_id="1036092828") - monkeypatch.setattr( - global_config.chat, - "chat_prompts", - [ - {"platform": "qq", "item_id": "1036092828", "rule_type": "group", "prompt": "你也是群管理员,可以适当进行管理"}, - {"platform": "qq", "item_id": "1036092828", "rule_type": "group", "prompt": "这个群是技术实验群,请你专心讨论技术"}, - {"platform": "qq", "item_id": "other", "rule_type": "group", "prompt": "不应该生效"}, - ], - ) - monkeypatch.setattr(chat_manager, "get_session_by_session_id", lambda _session_id: None) - - result = ChatConfigUtils.get_chat_prompt_for_chat(session_id, True) - - assert result == "你也是群管理员,可以适当进行管理\n这个群是技术实验群,请你专心讨论技术" - - -def test_get_chat_prompt_for_chat_matches_routed_session_by_chat_stream(monkeypatch): - session_id = SessionUtils.calculate_session_id("qq", group_id="1036092828", account_id="bot-a") - monkeypatch.setattr( - global_config.chat, - "chat_prompts", - [ - {"platform": "qq", "item_id": "1036092828", "rule_type": "group", "prompt": "路由会话也应该生效"}, - ], - ) - monkeypatch.setattr( - chat_manager, - "get_session_by_session_id", - lambda _session_id: SimpleNamespace(platform="qq", group_id="1036092828", user_id=None), - ) - - result = ChatConfigUtils.get_chat_prompt_for_chat(session_id, True) - - assert result == "路由会话也应该生效" - - -def test_expression_learning_list_matches_routed_session_by_chat_stream(monkeypatch): - session_id = SessionUtils.calculate_session_id("qq", group_id="1036092828", account_id="bot-a") - monkeypatch.setattr( - global_config.expression, - "learning_list", - [ - { - "platform": "qq", - "item_id": "1036092828", - "rule_type": "group", - "use_expression": False, - "enable_learning": False, - "enable_jargon_learning": True, - } - ], - ) - monkeypatch.setattr( - chat_manager, - "get_session_by_session_id", - lambda _session_id: SimpleNamespace(platform="qq", group_id="1036092828", user_id=None), - ) - - assert ExpressionConfigUtils.get_expression_config_for_chat(session_id) == (False, False, True) - - -def test_talk_value_rules_match_routed_session_by_chat_stream(monkeypatch): - session_id = SessionUtils.calculate_session_id("qq", group_id="1036092828", account_id="bot-a") - monkeypatch.setattr(global_config.chat, "talk_value", 0.1) - monkeypatch.setattr(global_config.chat, "enable_talk_value_rules", True) - monkeypatch.setattr( - global_config.chat, - "talk_value_rules", - [ - {"platform": "qq", "item_id": "1036092828", "rule_type": "group", "time": "00:00-23:59", "value": 0.7} - ], - ) - monkeypatch.setattr( - chat_manager, - "get_session_by_session_id", - lambda _session_id: SimpleNamespace(platform="qq", group_id="1036092828", user_id=None), - ) - - assert ChatConfigUtils.get_talk_value(session_id, True) == 0.7 diff --git a/pytests/common_test/test_database_migration_foundation.py b/pytests/common_test/test_database_migration_foundation.py deleted file mode 100644 index 3e4955de..00000000 --- a/pytests/common_test/test_database_migration_foundation.py +++ /dev/null @@ -1,908 +0,0 @@ -"""数据库迁移基础设施测试。""" - -from pathlib import Path -from typing import List, Optional, Tuple - -from sqlalchemy import text -from sqlalchemy.engine import Connection, Engine -from sqlmodel import SQLModel, create_engine - -import json -import msgpack -import pytest - -from src.common.database import database as database_module -from src.common.database.migrations import ( - BaseSchemaVersionDetector, - BaseMigrationProgressReporter, - DatabaseSchemaSnapshot, - DatabaseMigrationBootstrapper, - DatabaseMigrationState, - DatabaseMigrationManager, - EMPTY_SCHEMA_VERSION, - LATEST_SCHEMA_VERSION, - LEGACY_V1_SCHEMA_VERSION, - MigrationExecutionContext, - MigrationPlan, - MigrationRegistry, - MigrationStep, - ResolvedSchemaVersion, - SchemaVersionResolver, - SchemaVersionSource, - SQLiteSchemaInspector, - SQLiteUserVersionStore, - build_default_migration_registry, - build_default_schema_version_resolver, - create_database_migration_bootstrapper, -) - - -class FixedVersionDetector(BaseSchemaVersionDetector): - """测试用固定版本探测器。""" - - @property - def name(self) -> str: - """返回测试探测器名称。 - - Returns: - str: 探测器名称。 - """ - return "fixed_version_detector" - - def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]: - """根据测试表是否存在返回固定版本。 - - Args: - snapshot: 当前数据库结构快照。 - - Returns: - Optional[int]: 若存在测试表则返回固定版本,否则返回 ``None``。 - """ - if snapshot.has_table("legacy_records"): - return 2 - return None - - -class FakeMigrationProgressReporter(BaseMigrationProgressReporter): - """测试用迁移进度上报器。""" - - def __init__(self) -> None: - """初始化测试用进度上报器。""" - self.events: List[Tuple[str, Optional[int], Optional[int], Optional[str]]] = [] - - def open(self) -> None: - """记录打开事件。""" - self.events.append(("open", None, None, None)) - - def close(self) -> None: - """记录关闭事件。""" - self.events.append(("close", None, None, None)) - - def start( - self, - total_records: int, - total_tables: int, - description: str = "总迁移进度", - table_unit_name: str = "表", - record_unit_name: str = "记录", - ) -> None: - """记录启动事件。 - - Args: - total_records: 任务记录总数。 - total_tables: 任务表总数。 - description: 任务描述。 - table_unit_name: 表级进度单位名称。 - record_unit_name: 记录级进度单位名称。 - """ - del table_unit_name, record_unit_name - self.events.append(("start", total_records, total_tables, description)) - - def advance( - self, - records: int = 0, - completed_tables: int = 0, - item_name: Optional[str] = None, - ) -> None: - """记录推进事件。 - - Args: - records: 推进的记录数。 - completed_tables: 已完成的表数。 - item_name: 当前完成的项目名称。 - """ - self.events.append(("advance", records, completed_tables, item_name)) - - -def _create_sqlite_engine(database_file: Path) -> Engine: - """创建测试用 SQLite 引擎。 - - Args: - database_file: 测试数据库文件路径。 - - Returns: - Engine: SQLite 引擎实例。 - """ - return create_engine( - f"sqlite:///{database_file}", - echo=False, - connect_args={"check_same_thread": False}, - ) - - -def _create_current_schema(connection: Connection) -> None: - """创建当前最新版本的数据库结构。 - - Args: - connection: 当前数据库连接。 - """ - import src.common.database.database_model # noqa: F401 - - SQLModel.metadata.create_all(connection) - - -def _create_legacy_v1_schema_with_sample_data(connection: Connection) -> None: - """创建带示例数据的旧版 ``0.x`` 数据库结构。 - - Args: - connection: 当前数据库连接。 - """ - connection.execute( - text( - """ - CREATE TABLE chat_streams ( - id INTEGER PRIMARY KEY, - stream_id TEXT NOT NULL, - create_time REAL NOT NULL, - last_active_time REAL NOT NULL, - platform TEXT NOT NULL, - user_id TEXT, - group_id TEXT, - group_name TEXT - ) - """ - ) - ) - connection.execute( - text( - """ - CREATE TABLE messages ( - id INTEGER PRIMARY KEY, - message_id TEXT NOT NULL, - time REAL NOT NULL, - chat_id TEXT NOT NULL, - chat_info_platform TEXT, - user_id TEXT, - user_nickname TEXT, - chat_info_group_id TEXT, - chat_info_group_name TEXT, - is_mentioned INTEGER, - is_at INTEGER, - processed_plain_text TEXT, - display_message TEXT, - is_emoji INTEGER, - is_picid INTEGER, - is_command INTEGER, - is_notify INTEGER, - additional_config TEXT, - priority_mode TEXT - ) - """ - ) - ) - connection.execute( - text( - """ - CREATE TABLE action_records ( - id INTEGER PRIMARY KEY, - action_id TEXT NOT NULL, - time REAL NOT NULL, - action_reasoning TEXT, - action_name TEXT NOT NULL, - action_data TEXT, - action_prompt_display TEXT, - chat_id TEXT - ) - """ - ) - ) - connection.execute( - text( - """ - CREATE TABLE expression ( - id INTEGER PRIMARY KEY, - situation TEXT NOT NULL, - style TEXT NOT NULL, - content_list TEXT, - count INTEGER, - last_active_time REAL NOT NULL, - chat_id TEXT, - create_date REAL, - checked INTEGER, - rejected INTEGER, - modified_by TEXT - ) - """ - ) - ) - connection.execute( - text( - """ - CREATE TABLE jargon ( - id INTEGER PRIMARY KEY, - content TEXT NOT NULL, - raw_content TEXT, - meaning TEXT, - chat_id TEXT, - is_global INTEGER, - count INTEGER, - is_jargon INTEGER, - last_inference_count INTEGER, - is_complete INTEGER, - inference_with_context TEXT, - inference_content_only TEXT - ) - """ - ) - ) - - connection.execute( - text( - """ - INSERT INTO chat_streams ( - id, - stream_id, - create_time, - last_active_time, - platform, - user_id, - group_id, - group_name - ) VALUES ( - 1, - 'session-1', - 1710000000.0, - 1710000300.0, - 'qq', - 'user-1', - 'group-1', - '测试群' - ) - """ - ) - ) - connection.execute( - text( - """ - INSERT INTO messages ( - id, - message_id, - time, - chat_id, - chat_info_platform, - user_id, - user_nickname, - chat_info_group_id, - chat_info_group_name, - is_mentioned, - is_at, - processed_plain_text, - display_message, - is_emoji, - is_picid, - is_command, - is_notify, - additional_config, - priority_mode - ) VALUES ( - 1, - 'msg-1', - 1710000010.0, - 'session-1', - 'qq', - 'user-1', - '测试用户', - 'group-1', - '测试群', - 1, - 0, - '你好', - '你好呀', - 0, - 1, - 0, - 1, - '{"source":"legacy"}', - 'high' - ) - """ - ) - ) - connection.execute( - text( - """ - INSERT INTO action_records ( - id, - action_id, - time, - action_reasoning, - action_name, - action_data, - action_prompt_display, - chat_id - ) VALUES ( - 1, - 'action-1', - 1710000020.0, - '需要调用工具', - 'search', - '{"query":"MaiBot"}', - '执行搜索', - 'session-1' - ) - """ - ) - ) - connection.execute( - text( - """ - INSERT INTO expression ( - id, - situation, - style, - content_list, - count, - last_active_time, - chat_id, - create_date, - checked, - rejected, - modified_by - ) VALUES ( - 1, - '打招呼', - '可爱', - '["你好呀","早上好"]', - 3, - 1710000030.0, - 'session-1', - 1710000040.0, - 1, - 0, - 'ai' - ) - """ - ) - ) - connection.execute( - text( - """ - INSERT INTO jargon ( - id, - content, - raw_content, - meaning, - chat_id, - is_global, - count, - is_jargon, - last_inference_count, - is_complete, - inference_with_context, - inference_content_only - ) VALUES ( - 1, - '上分', - '["上分"]', - '提高排名', - 'session-1', - 0, - 5, - 1, - 2, - 1, - '{"guess":"context"}', - '{"guess":"content"}' - ) - """ - ) - ) - - -def test_user_version_store_can_read_and_write_versions(tmp_path: Path) -> None: - """应支持读取与写入 SQLite ``user_version``。""" - engine = _create_sqlite_engine(tmp_path / "version_store.db") - version_store = SQLiteUserVersionStore() - - with engine.begin() as connection: - assert version_store.read_version(connection) == 0 - version_store.write_version(connection, 7) - - with engine.connect() as connection: - assert version_store.read_version(connection) == 7 - - -def test_schema_inspector_can_extract_tables_and_columns(tmp_path: Path) -> None: - """应能提取 SQLite 数据库的表与列结构。""" - engine = _create_sqlite_engine(tmp_path / "schema_inspector.db") - inspector = SQLiteSchemaInspector() - - with engine.begin() as connection: - connection.execute( - text( - """ - CREATE TABLE legacy_records ( - id INTEGER PRIMARY KEY, - payload TEXT NOT NULL, - created_at TEXT - ) - """ - ) - ) - - with engine.connect() as connection: - snapshot = inspector.inspect(connection) - - assert snapshot.has_table("legacy_records") - assert snapshot.has_column("legacy_records", "payload") - assert not snapshot.has_column("legacy_records", "missing_column") - table_schema = snapshot.get_table("legacy_records") - - assert table_schema is not None - assert table_schema.column_names() == ["created_at", "id", "payload"] - - -def test_resolver_can_identify_empty_database(tmp_path: Path) -> None: - """空数据库应被解析为版本 ``0``。""" - engine = _create_sqlite_engine(tmp_path / "empty_resolver.db") - resolver = SchemaVersionResolver() - - with engine.connect() as connection: - resolved_version = resolver.resolve(connection) - - assert resolved_version.version == 0 - assert resolved_version.source == SchemaVersionSource.EMPTY_DATABASE - assert resolved_version.snapshot is not None - assert resolved_version.snapshot.is_empty() - - -def test_resolver_can_use_detector_for_unversioned_legacy_database(tmp_path: Path) -> None: - """未写入 ``user_version`` 的历史库应支持通过探测器识别版本。""" - engine = _create_sqlite_engine(tmp_path / "legacy_resolver.db") - resolver = SchemaVersionResolver(detectors=[FixedVersionDetector()]) - - with engine.begin() as connection: - connection.execute(text("CREATE TABLE legacy_records (id INTEGER PRIMARY KEY, payload TEXT NOT NULL)")) - - with engine.connect() as connection: - resolved_version = resolver.resolve(connection) - - assert resolved_version.version == 2 - assert resolved_version.source == SchemaVersionSource.DETECTOR - assert resolved_version.detector_name == "fixed_version_detector" - - -def test_registry_and_manager_can_execute_registered_steps(tmp_path: Path) -> None: - """迁移编排器应能按顺序执行已注册步骤并更新版本号。""" - engine = _create_sqlite_engine(tmp_path / "manager.db") - executed_steps: List[str] = [] - - def migrate_0_to_1(context: MigrationExecutionContext) -> None: - """测试迁移步骤 0 -> 1。 - - Args: - context: 当前迁移步骤执行上下文。 - """ - executed_steps.append(f"{context.current_version}->{context.target_version}:step_0_to_1") - context.connection.execute(text("CREATE TABLE sample_records (id INTEGER PRIMARY KEY, name TEXT NOT NULL)")) - - def migrate_1_to_2(context: MigrationExecutionContext) -> None: - """测试迁移步骤 1 -> 2。 - - Args: - context: 当前迁移步骤执行上下文。 - """ - executed_steps.append(f"{context.current_version}->{context.target_version}:step_1_to_2") - context.connection.execute(text("ALTER TABLE sample_records ADD COLUMN email TEXT")) - - registry = MigrationRegistry( - steps=[ - MigrationStep( - version_from=0, - version_to=1, - name="create_sample_records", - description="创建示例表。", - handler=migrate_0_to_1, - ), - MigrationStep( - version_from=1, - version_to=2, - name="add_sample_email", - description="为示例表增加邮箱字段。", - handler=migrate_1_to_2, - ), - ] - ) - manager = DatabaseMigrationManager(engine=engine, registry=registry) - - migration_plan = manager.migrate() - - assert migration_plan.step_count() == 2 - assert executed_steps == ["0->2:step_0_to_1", "1->2:step_1_to_2"] - - with engine.connect() as connection: - version_store = SQLiteUserVersionStore() - snapshot = SQLiteSchemaInspector().inspect(connection) - recorded_version = version_store.read_version(connection) - - assert recorded_version == 2 - assert snapshot.has_table("sample_records") - assert snapshot.has_column("sample_records", "email") - - -def test_manager_can_report_step_progress(tmp_path: Path) -> None: - """迁移编排器应支持通过上下文上报步骤进度。""" - engine = _create_sqlite_engine(tmp_path / "manager_progress.db") - reporter_instances: List[FakeMigrationProgressReporter] = [] - - def _build_reporter() -> BaseMigrationProgressReporter: - """构建测试用进度上报器。 - - Returns: - BaseMigrationProgressReporter: 测试用进度上报器实例。 - """ - reporter = FakeMigrationProgressReporter() - reporter_instances.append(reporter) - return reporter - - def migrate_1_to_2(context: MigrationExecutionContext) -> None: - """测试迁移步骤 ``1 -> 2`` 的进度上报。 - - Args: - context: 当前迁移步骤执行上下文。 - """ - context.start_progress(total_tables=3, total_records=30, description="总迁移进度") - context.advance_progress(records=10, completed_tables=1, item_name="chat_sessions") - context.advance_progress(records=10, completed_tables=1, item_name="mai_messages") - context.advance_progress(records=10, completed_tables=1, item_name="tool_records") - context.connection.execute(text("CREATE TABLE progress_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)")) - - with engine.begin() as connection: - SQLiteUserVersionStore().write_version(connection, 1) - - registry = MigrationRegistry( - steps=[ - MigrationStep( - version_from=1, - version_to=2, - name="progress_step", - description="测试进度上报。", - handler=migrate_1_to_2, - ) - ] - ) - manager = DatabaseMigrationManager( - engine=engine, - registry=registry, - progress_reporter_factory=_build_reporter, - ) - - migration_plan = manager.migrate() - - assert migration_plan.step_count() == 1 - assert len(reporter_instances) == 1 - assert reporter_instances[0].events == [ - ("open", None, None, None), - ("start", 30, 3, "总迁移进度"), - ("advance", 10, 1, "chat_sessions"), - ("advance", 10, 1, "mai_messages"), - ("advance", 10, 1, "tool_records"), - ("close", None, None, None), - ] - - -def test_default_resolver_can_identify_unversioned_latest_database(tmp_path: Path) -> None: - """默认解析器应能识别未写入版本号的最新结构数据库。""" - engine = _create_sqlite_engine(tmp_path / "latest_resolver.db") - resolver = build_default_schema_version_resolver() - - with engine.begin() as connection: - _create_current_schema(connection) - - with engine.connect() as connection: - resolved_version = resolver.resolve(connection) - - assert resolved_version.version == LATEST_SCHEMA_VERSION - assert resolved_version.source == SchemaVersionSource.DETECTOR - assert resolved_version.detector_name == "latest_schema_detector" - - -def test_default_resolver_can_identify_legacy_v1_database(tmp_path: Path) -> None: - """默认解析器应能识别未写版本号的旧版 ``0.x`` 数据库。""" - engine = _create_sqlite_engine(tmp_path / "legacy_v1_resolver.db") - resolver = build_default_schema_version_resolver() - - with engine.begin() as connection: - _create_legacy_v1_schema_with_sample_data(connection) - - with engine.connect() as connection: - resolved_version = resolver.resolve(connection) - - assert resolved_version.version == LEGACY_V1_SCHEMA_VERSION - assert resolved_version.source == SchemaVersionSource.DETECTOR - assert resolved_version.detector_name == "legacy_v1_schema_detector" - - -def test_bootstrapper_can_finalize_unversioned_latest_database(tmp_path: Path) -> None: - """已是最新结构但未写版本号的数据库应直接补写 ``user_version``。""" - engine = _create_sqlite_engine(tmp_path / "latest_finalize.db") - bootstrapper = create_database_migration_bootstrapper(engine) - - with engine.begin() as connection: - _create_current_schema(connection) - - migration_state = bootstrapper.prepare_database() - bootstrapper.finalize_database(migration_state) - - assert not migration_state.requires_migration() - assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION - assert migration_state.resolved_version.source == SchemaVersionSource.DETECTOR - - with engine.connect() as connection: - recorded_version = SQLiteUserVersionStore().read_version(connection) - - assert recorded_version == LATEST_SCHEMA_VERSION - - -def test_bootstrapper_can_finalize_empty_database_to_latest_version(tmp_path: Path) -> None: - """空库在建表完成后应回写最新 ``user_version``。""" - engine = _create_sqlite_engine(tmp_path / "bootstrap_empty.db") - bootstrapper = create_database_migration_bootstrapper(engine) - - migration_state = bootstrapper.prepare_database() - - assert not migration_state.requires_migration() - assert migration_state.resolved_version.version == EMPTY_SCHEMA_VERSION - assert migration_state.target_version == LATEST_SCHEMA_VERSION - - with engine.begin() as connection: - _create_current_schema(connection) - - bootstrapper.finalize_database(migration_state) - - with engine.connect() as connection: - recorded_version = SQLiteUserVersionStore().read_version(connection) - - assert recorded_version == LATEST_SCHEMA_VERSION - - -def test_bootstrapper_runs_registered_steps_for_versioned_database(tmp_path: Path) -> None: - """启动桥接器应在已登记旧版本数据库上执行注册迁移步骤。""" - engine = _create_sqlite_engine(tmp_path / "bootstrap_registered.db") - execution_marks: List[str] = [] - - def migrate_1_to_2(context: MigrationExecutionContext) -> None: - """测试桥接器迁移步骤 ``1 -> 2``。 - - Args: - context: 当前迁移步骤执行上下文。 - """ - execution_marks.append(f"step={context.step_name},index={context.step_index}") - context.connection.execute(text("ALTER TABLE bootstrap_records ADD COLUMN email TEXT")) - - with engine.begin() as connection: - connection.execute( - text("CREATE TABLE bootstrap_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)") - ) - SQLiteUserVersionStore().write_version(connection, 1) - - registry = MigrationRegistry( - steps=[ - MigrationStep( - version_from=1, - version_to=2, - name="bootstrap_add_email", - description="为桥接器测试表增加邮箱字段。", - handler=migrate_1_to_2, - ) - ] - ) - bootstrapper = DatabaseMigrationBootstrapper( - manager=DatabaseMigrationManager(engine=engine, registry=registry), - latest_schema_version=2, - ) - - migration_state = bootstrapper.prepare_database() - - assert migration_state.resolved_version.version == 2 - assert migration_state.target_version == 2 - assert execution_marks == ["step=bootstrap_add_email,index=1"] - - with engine.connect() as connection: - snapshot = SQLiteSchemaInspector().inspect(connection) - recorded_version = SQLiteUserVersionStore().read_version(connection) - - assert recorded_version == 2 - assert snapshot.has_table("bootstrap_records") - assert snapshot.has_column("bootstrap_records", "email") - - -def test_default_bootstrapper_can_migrate_legacy_v1_database(tmp_path: Path) -> None: - """默认桥接器应能把旧版 ``0.x`` 数据库整体迁移到最新结构。""" - engine = _create_sqlite_engine(tmp_path / "legacy_v1_to_v2.db") - bootstrapper = create_database_migration_bootstrapper(engine) - - with engine.begin() as connection: - _create_legacy_v1_schema_with_sample_data(connection) - - migration_state = bootstrapper.prepare_database() - bootstrapper.finalize_database(migration_state) - - assert not migration_state.requires_migration() - assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION - assert migration_state.resolved_version.source == SchemaVersionSource.PRAGMA - - with engine.connect() as connection: - recorded_version = SQLiteUserVersionStore().read_version(connection) - snapshot = SQLiteSchemaInspector().inspect(connection) - message_row = connection.execute( - text( - """ - SELECT session_id, processed_plain_text, additional_config, raw_content - FROM mai_messages - WHERE message_id = 'msg-1' - """ - ) - ).mappings().one() - tool_row = connection.execute( - text( - """ - SELECT session_id, tool_name, tool_display_prompt - FROM tool_records - WHERE tool_id = 'action-1' - """ - ) - ).mappings().one() - expression_row = connection.execute( - text( - """ - SELECT session_id, content_list, modified_by - FROM expressions - WHERE id = 1 - """ - ) - ).mappings().one() - jargon_row = connection.execute( - text( - """ - SELECT session_id_dict, raw_content, inference_with_content_only - FROM jargons - WHERE id = 1 - """ - ) - ).mappings().one() - - assert recorded_version == LATEST_SCHEMA_VERSION - assert snapshot.has_table("__legacy_v1_messages") - assert snapshot.has_table("chat_sessions") - assert snapshot.has_table("mai_messages") - assert snapshot.has_table("tool_records") - assert not snapshot.has_table("action_records") - assert not snapshot.has_column("mai_messages", "display_message") - - unpacked_raw_content = msgpack.unpackb(message_row["raw_content"], raw=False) - additional_config = json.loads(message_row["additional_config"]) - expression_content_list = json.loads(expression_row["content_list"]) - jargon_session_id_dict = json.loads(jargon_row["session_id_dict"]) - jargon_raw_content = json.loads(jargon_row["raw_content"]) - - assert message_row["session_id"] == "session-1" - assert message_row["processed_plain_text"] == "你好" - assert unpacked_raw_content == [{"type": "text", "data": "你好呀"}] - assert additional_config == {"priority_mode": "high", "source": "legacy"} - assert tool_row["session_id"] == "session-1" - assert tool_row["tool_name"] == "search" - assert tool_row["tool_display_prompt"] == "执行搜索" - assert expression_row["session_id"] == "session-1" - assert expression_row["modified_by"] == "AI" - assert expression_content_list == ["你好呀", "早上好"] - assert jargon_session_id_dict == {"session-1": 5} - assert jargon_raw_content == ["上分"] - assert jargon_row["inference_with_content_only"] == '{"guess":"content"}' - - -def test_legacy_v1_migration_reports_table_progress(tmp_path: Path) -> None: - """旧版迁移步骤应按目标表数量推进总进度。""" - engine = _create_sqlite_engine(tmp_path / "legacy_progress.db") - reporter_instances: List[FakeMigrationProgressReporter] = [] - - def _build_reporter() -> BaseMigrationProgressReporter: - """构建测试用进度上报器。 - - Returns: - BaseMigrationProgressReporter: 测试用进度上报器实例。 - """ - reporter = FakeMigrationProgressReporter() - reporter_instances.append(reporter) - return reporter - - with engine.begin() as connection: - _create_legacy_v1_schema_with_sample_data(connection) - - manager = DatabaseMigrationManager( - engine=engine, - registry=build_default_migration_registry(), - resolver=build_default_schema_version_resolver(), - progress_reporter_factory=_build_reporter, - ) - - migration_plan = manager.migrate(target_version=LATEST_SCHEMA_VERSION) - - assert migration_plan.step_count() == 3 - assert len(reporter_instances) == 3 - reporter_events = reporter_instances[0].events - - assert reporter_events[0] == ("open", None, None, None) - assert reporter_events[1] == ("start", 6, 12, "总迁移进度") - assert reporter_events[-1] == ("close", None, None, None) - assert reporter_events.count(("advance", 1, 0, None)) == 6 - assert reporter_events.count(("advance", 0, 1, "chat_sessions")) == 1 - assert reporter_events.count(("advance", 0, 1, "thinking_questions")) == 1 - assert len([event for event in reporter_events if event[0] == "advance"]) == 18 - - -def test_initialize_database_calls_bootstrapper_before_create_all( - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - """数据库初始化入口应先准备迁移,再建表、补迁移并收尾。""" - call_order: List[str] = [] - - def _fake_prepare_database() -> DatabaseMigrationState: - """返回测试用迁移状态。 - - Returns: - DatabaseMigrationState: 不包含迁移步骤的测试状态。 - """ - call_order.append("prepare_database") - return DatabaseMigrationState( - resolved_version=ResolvedSchemaVersion(version=0, source=SchemaVersionSource.EMPTY_DATABASE), - target_version=LATEST_SCHEMA_VERSION, - plan=MigrationPlan( - current_version=EMPTY_SCHEMA_VERSION, - target_version=LATEST_SCHEMA_VERSION, - steps=[], - ), - ) - - def _fake_create_all(bind) -> None: - """记录建表调用。 - - Args: - bind: 传入的数据库绑定对象。 - """ - del bind - call_order.append("create_all") - - def _fake_finalize_database(migration_state: DatabaseMigrationState) -> None: - """记录迁移收尾调用。 - - Args: - migration_state: 当前数据库迁移状态。 - """ - del migration_state - call_order.append("finalize_database") - - monkeypatch.setattr(database_module, "_db_initialized", False) - monkeypatch.setattr(database_module, "_DB_DIR", tmp_path / "data") - monkeypatch.setattr(database_module._migration_bootstrapper, "prepare_database", _fake_prepare_database) - monkeypatch.setattr(database_module._migration_bootstrapper, "finalize_database", _fake_finalize_database) - monkeypatch.setattr(database_module.SQLModel.metadata, "create_all", _fake_create_all) - - database_module.initialize_database() - - assert call_order == [ - "prepare_database", - "create_all", - "finalize_database", - ] diff --git a/pytests/common_test/test_expression_learner.py b/pytests/common_test/test_expression_learner.py deleted file mode 100644 index 951aa424..00000000 --- a/pytests/common_test/test_expression_learner.py +++ /dev/null @@ -1,81 +0,0 @@ -"""测试表达方式学习器的数据库读取行为。""" - -from contextlib import contextmanager -from typing import Generator - -import pytest -from sqlalchemy.pool import StaticPool -from sqlmodel import Session, SQLModel, create_engine - -from src.bw_learner.expression_learner import ExpressionLearner -from src.common.database.database_model import Expression - - -@pytest.fixture(name="expression_learner_engine") -def expression_learner_engine_fixture() -> Generator: - """创建用于表达方式学习器测试的内存数据库引擎。 - - Yields: - Generator: 供测试使用的 SQLite 内存引擎。 - """ - engine = create_engine( - "sqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - SQLModel.metadata.create_all(engine) - yield engine - - -def test_find_similar_expression_uses_read_only_session_and_history_content( - monkeypatch: pytest.MonkeyPatch, - expression_learner_engine, -) -> None: - """查找相似表达方式时,应能在离开会话后安全使用结果,并比较历史情景内容。""" - import src.bw_learner.expression_learner as expression_learner_module - - with Session(expression_learner_engine) as session: - session.add( - Expression( - situation="发送汗滴表情", - style="发送💦表情符号", - content_list='["表达情绪高涨或生理反应"]', - count=1, - session_id="session-a", - checked=False, - rejected=False, - ) - ) - session.commit() - - @contextmanager - def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]: - """构造带自动提交语义的测试会话工厂。 - - Args: - auto_commit: 退出上下文时是否自动提交。 - - Yields: - Generator[Session, None, None]: SQLModel 会话对象。 - """ - session = Session(expression_learner_engine) - try: - yield session - if auto_commit: - session.commit() - except Exception: - session.rollback() - raise - finally: - session.close() - - monkeypatch.setattr(expression_learner_module, "get_db_session", fake_get_db_session) - - learner = ExpressionLearner(session_id="session-a") - result = learner._find_similar_expression("表达情绪高涨或生理反应") - - assert result is not None - expression, similarity = result - assert expression.item_id is not None - assert expression.style == "发送💦表情符号" - assert similarity == pytest.approx(1.0) diff --git a/pytests/common_test/test_expression_schema.py b/pytests/common_test/test_expression_schema.py deleted file mode 100644 index 31fcd98f..00000000 --- a/pytests/common_test/test_expression_schema.py +++ /dev/null @@ -1,78 +0,0 @@ -"""测试表达方式表结构和基础插入行为。""" - -from typing import Generator - -import pytest -from sqlalchemy.pool import StaticPool -from sqlmodel import Session, SQLModel, create_engine - -from src.common.database.database_model import Expression - - -@pytest.fixture(name="expression_engine") -def expression_engine_fixture() -> Generator: - """创建仅用于表达方式表测试的内存数据库引擎。 - - Yields: - Generator: 供测试使用的 SQLite 内存引擎。 - """ - engine = create_engine( - "sqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - SQLModel.metadata.create_all(engine) - yield engine - - -def test_expression_insert_assigns_auto_increment_id(expression_engine) -> None: - """表达方式表在新库中应能自动分配自增主键。""" - with Session(expression_engine) as session: - expression = Expression( - situation="表达情绪高涨或生理反应", - style="发送💦表情符号", - content_list='["表达情绪高涨或生理反应"]', - count=1, - session_id="session-a", - checked=False, - rejected=False, - ) - session.add(expression) - session.commit() - session.refresh(expression) - - assert expression.id is not None - assert expression.id > 0 - - -def test_expression_insert_allows_same_situation_style(expression_engine) -> None: - """相同情景和风格的表达方式记录不应再被错误绑定到复合主键。""" - with Session(expression_engine) as session: - first_expression = Expression( - situation="对重复行为的默契响应", - style="持续性跟发相同内容", - content_list='["对重复行为的默契响应"]', - count=1, - session_id="session-a", - checked=False, - rejected=False, - ) - second_expression = Expression( - situation="对重复行为的默契响应", - style="持续性跟发相同内容", - content_list='["对重复行为的默契响应-变体"]', - count=2, - session_id="session-b", - checked=False, - rejected=False, - ) - - session.add(first_expression) - session.add(second_expression) - session.commit() - session.refresh(first_expression) - session.refresh(second_expression) - - assert first_expression.id is not None - assert second_expression.id is not None - assert first_expression.id != second_expression.id diff --git a/pytests/common_test/test_jargon_miner.py b/pytests/common_test/test_jargon_miner.py deleted file mode 100644 index bf81e4d2..00000000 --- a/pytests/common_test/test_jargon_miner.py +++ /dev/null @@ -1,90 +0,0 @@ -"""测试黑话学习器的数据库读取行为。""" - -from contextlib import contextmanager -from typing import Generator - -import pytest -from sqlalchemy.pool import StaticPool -from sqlmodel import Session, SQLModel, create_engine, select - -from src.bw_learner.jargon_miner import JargonMiner -from src.common.database.database_model import Jargon - - -@pytest.fixture(name="jargon_miner_engine") -def jargon_miner_engine_fixture() -> Generator: - """创建用于黑话学习器测试的内存数据库引擎。 - - Yields: - Generator: 供测试使用的 SQLite 内存引擎。 - """ - engine = create_engine( - "sqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - SQLModel.metadata.create_all(engine) - yield engine - - -@pytest.mark.asyncio -async def test_process_extracted_entries_updates_existing_jargon_without_detached_session( - monkeypatch: pytest.MonkeyPatch, - jargon_miner_engine, -) -> None: - """更新已有黑话时,不应因会话关闭导致 ORM 实例失效。""" - import src.bw_learner.jargon_miner as jargon_miner_module - - with Session(jargon_miner_engine) as session: - session.add( - Jargon( - content="VF8V4L", - raw_content='["[1] first"]', - meaning="", - session_id_dict='{"session-a": 1}', - count=0, - is_jargon=True, - is_complete=False, - is_global=False, - last_inference_count=0, - ) - ) - session.commit() - - @contextmanager - def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]: - """构造带自动提交语义的测试会话工厂。 - - Args: - auto_commit: 退出上下文时是否自动提交。 - - Yields: - Generator[Session, None, None]: SQLModel 会话对象。 - """ - session = Session(jargon_miner_engine) - try: - yield session - if auto_commit: - session.commit() - except Exception: - session.rollback() - raise - finally: - session.close() - - monkeypatch.setattr(jargon_miner_module, "get_db_session", fake_get_db_session) - - jargon_miner = JargonMiner(session_id="session-a", session_name="测试群") - await jargon_miner.process_extracted_entries( - [{"content": "VF8V4L", "raw_content": {"[2] second"}}], - ) - - with Session(jargon_miner_engine) as session: - db_jargon = session.exec(select(Jargon).where(Jargon.content == "VF8V4L")).one() - - assert db_jargon.count == 1 - assert db_jargon.session_id_dict == '{"session-a": 2}' - assert sorted(db_jargon.raw_content and __import__("json").loads(db_jargon.raw_content)) == [ - "[1] first", - "[2] second", - ] diff --git a/pytests/common_test/test_jargon_schema.py b/pytests/common_test/test_jargon_schema.py deleted file mode 100644 index 909392ab..00000000 --- a/pytests/common_test/test_jargon_schema.py +++ /dev/null @@ -1,84 +0,0 @@ -"""测试黑话表结构和基础插入行为。""" - -from typing import Generator - -import pytest -from sqlalchemy.pool import StaticPool -from sqlmodel import Session, SQLModel, create_engine - -from src.common.database.database_model import Jargon - - -@pytest.fixture(name="jargon_engine") -def jargon_engine_fixture() -> Generator: - """创建仅用于黑话表测试的内存数据库引擎。 - - Yields: - Generator: 供测试使用的 SQLite 内存引擎。 - """ - engine = create_engine( - "sqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - SQLModel.metadata.create_all(engine) - yield engine - - -def test_jargon_insert_assigns_auto_increment_id(jargon_engine) -> None: - """黑话表在新库中应能自动分配自增主键。""" - with Session(jargon_engine) as session: - jargon = Jargon( - content="VF8V4L", - raw_content='["[1] test"]', - meaning="", - session_id_dict='{"session-a": 1}', - count=1, - is_jargon=True, - is_complete=False, - is_global=True, - last_inference_count=0, - ) - session.add(jargon) - session.commit() - session.refresh(jargon) - - assert jargon.id is not None - assert jargon.id > 0 - - -def test_jargon_insert_allows_same_content_with_different_rows(jargon_engine) -> None: - """黑话内容不应再被错误地绑成复合主键的一部分。""" - with Session(jargon_engine) as session: - first_jargon = Jargon( - content="表情1", - raw_content='["[1] first"]', - meaning="", - session_id_dict='{"session-a": 1}', - count=1, - is_jargon=True, - is_complete=False, - is_global=False, - last_inference_count=0, - ) - second_jargon = Jargon( - content="表情1", - raw_content='["[1] second"]', - meaning="", - session_id_dict='{"session-b": 1}', - count=1, - is_jargon=True, - is_complete=False, - is_global=False, - last_inference_count=0, - ) - - session.add(first_jargon) - session.add(second_jargon) - session.commit() - session.refresh(first_jargon) - session.refresh(second_jargon) - - assert first_jargon.id is not None - assert second_jargon.id is not None - assert first_jargon.id != second_jargon.id diff --git a/pytests/common_test/test_maisaka_expression_selector.py b/pytests/common_test/test_maisaka_expression_selector.py deleted file mode 100644 index 31358a6a..00000000 --- a/pytests/common_test/test_maisaka_expression_selector.py +++ /dev/null @@ -1,135 +0,0 @@ -from types import SimpleNamespace - -import pytest - -import src.chat.replyer.maisaka_expression_selector as selector_module -from src.chat.replyer.maisaka_expression_selector import MaisakaExpressionSelector -from src.common.utils.utils_session import SessionUtils - - -def _build_target(platform: str, item_id: str, rule_type: str = "group") -> SimpleNamespace: - return SimpleNamespace(platform=platform, item_id=item_id, rule_type=rule_type) - - -def test_resolve_expression_group_scope_returns_related_sessions(monkeypatch: pytest.MonkeyPatch) -> None: - current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001") - related_session_id = SessionUtils.calculate_session_id("qq", group_id="10002") - - monkeypatch.setattr( - selector_module, - "global_config", - SimpleNamespace( - expression=SimpleNamespace( - expression_groups=[ - SimpleNamespace( - expression_groups=[ - _build_target("qq", "10001"), - _build_target("qq", "10002"), - ] - ) - ] - ) - ), - ) - - selector = MaisakaExpressionSelector() - related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id) - - assert related_session_ids == {current_session_id, related_session_id} - assert has_global_share is False - - -def test_resolve_expression_group_scope_matches_routed_sessions(monkeypatch: pytest.MonkeyPatch) -> None: - current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001", account_id="bot-a") - related_session_id = SessionUtils.calculate_session_id("qq", group_id="10002", account_id="bot-a") - - monkeypatch.setattr( - selector_module, - "global_config", - SimpleNamespace( - expression=SimpleNamespace( - expression_groups=[ - SimpleNamespace( - expression_groups=[ - _build_target("qq", "10001"), - _build_target("qq", "10002"), - ] - ) - ] - ) - ), - ) - monkeypatch.setattr( - selector_module.ChatConfigUtils, - "_get_chat_stream", - lambda session_id: SimpleNamespace(platform="qq", group_id="10001", user_id=None) - if session_id == current_session_id - else None, - ) - target_session_ids = { - "10001": current_session_id, - "10002": related_session_id, - } - monkeypatch.setattr( - selector_module.ChatConfigUtils, - "get_target_session_ids", - lambda target_item: {target_session_ids[target_item.item_id]}, - ) - - selector = MaisakaExpressionSelector() - related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id) - - assert related_session_ids == {current_session_id, related_session_id} - assert has_global_share is False - - -def test_resolve_expression_group_scope_uses_star_as_global_share(monkeypatch: pytest.MonkeyPatch) -> None: - current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001") - - monkeypatch.setattr( - selector_module, - "global_config", - SimpleNamespace( - expression=SimpleNamespace( - expression_groups=[ - SimpleNamespace( - expression_groups=[ - _build_target("*", "*"), - ] - ) - ] - ) - ), - ) - - selector = MaisakaExpressionSelector() - related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id) - - assert related_session_ids == {current_session_id} - assert has_global_share is True - - -def test_resolve_expression_group_scope_does_not_treat_empty_target_as_global(monkeypatch: pytest.MonkeyPatch) -> None: - current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001") - - monkeypatch.setattr( - selector_module, - "global_config", - SimpleNamespace( - expression=SimpleNamespace( - expression_groups=[ - SimpleNamespace( - expression_groups=[ - _build_target("", ""), - ] - ) - ] - ) - ), - ) - - selector = MaisakaExpressionSelector() - related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id) - - assert related_session_ids == {current_session_id} - assert has_global_share is False diff --git a/pytests/common_test/test_person_info_group_cardname.py b/pytests/common_test/test_person_info_group_cardname.py deleted file mode 100644 index 62a63f43..00000000 --- a/pytests/common_test/test_person_info_group_cardname.py +++ /dev/null @@ -1,355 +0,0 @@ -"""人物信息群名片字段兼容测试。""" - -from __future__ import annotations - -from importlib.util import module_from_spec, spec_from_file_location -from pathlib import Path -from types import ModuleType, SimpleNamespace -from typing import Any - -import json -import sys - -import pytest - -from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json - - -class _DummyLogger: - """模拟日志记录器。""" - - def debug(self, message: str) -> None: - """记录调试日志。 - - Args: - message: 日志内容。 - """ - del message - - def info(self, message: str) -> None: - """记录信息日志。 - - Args: - message: 日志内容。 - """ - del message - - def warning(self, message: str) -> None: - """记录警告日志。 - - Args: - message: 日志内容。 - """ - del message - - def error(self, message: str) -> None: - """记录错误日志。 - - Args: - message: 日志内容。 - """ - del message - - -class _DummyStatement: - """模拟 SQL 查询语句对象。""" - - def where(self, condition: Any) -> "_DummyStatement": - """附加过滤条件。 - - Args: - condition: 过滤条件。 - - Returns: - _DummyStatement: 当前语句对象。 - """ - del condition - return self - - def limit(self, value: int) -> "_DummyStatement": - """限制返回条数。 - - Args: - value: 条数限制。 - - Returns: - _DummyStatement: 当前语句对象。 - """ - del value - return self - - -class _DummyColumn: - """模拟 SQLModel 列对象。""" - - def is_not(self, value: Any) -> "_DummyColumn": - """模拟 `IS NOT` 条件构造。 - - Args: - value: 比较值。 - - Returns: - _DummyColumn: 当前列对象。 - """ - del value - return self - - def __eq__(self, other: Any) -> "_DummyColumn": - """模拟等值条件构造。 - - Args: - other: 比较值。 - - Returns: - _DummyColumn: 当前列对象。 - """ - del other - return self - - -class _DummyResult: - """模拟数据库查询结果。""" - - def __init__(self, record: Any) -> None: - """初始化查询结果。 - - Args: - record: 待返回的首条记录。 - """ - self._record = record - - def first(self) -> Any: - """返回第一条记录。 - - Returns: - Any: 首条记录。 - """ - return self._record - - def all(self) -> list[Any]: - """返回全部结果。 - - Returns: - list[Any]: 结果列表。 - """ - if self._record is None: - return [] - return self._record if isinstance(self._record, list) else [self._record] - - -class _DummySession: - """模拟数据库 Session。""" - - def __init__(self, record: Any) -> None: - """初始化 Session。 - - Args: - record: `first()` 应返回的记录。 - """ - self.record = record - self.added_records: list[Any] = [] - - def __enter__(self) -> "_DummySession": - """进入上下文管理器。 - - Returns: - _DummySession: 当前 Session。 - """ - return self - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - """退出上下文管理器。 - - Args: - exc_type: 异常类型。 - exc_val: 异常值。 - exc_tb: 异常回溯。 - """ - del exc_type - del exc_val - del exc_tb - - def exec(self, statement: Any) -> _DummyResult: - """执行查询。 - - Args: - statement: 查询语句。 - - Returns: - _DummyResult: 模拟结果对象。 - """ - del statement - return _DummyResult(self.record) - - def add(self, record: Any) -> None: - """记录被添加的对象。 - - Args: - record: 被写入 Session 的对象。 - """ - self.added_records.append(record) - - -class _DummyPersonInfoRecord: - """模拟 `PersonInfo` ORM 模型。""" - - person_id = "person_id" - person_name = "person_name" - - def __init__(self, **kwargs: Any) -> None: - """使用关键字参数初始化记录对象。 - - Args: - **kwargs: 字段值。 - """ - for key, value in kwargs.items(): - setattr(self, key, value) - - -def _load_person_module(monkeypatch: pytest.MonkeyPatch, session: _DummySession) -> ModuleType: - """加载带依赖桩的 `person_info` 模块。 - - Args: - monkeypatch: Pytest monkeypatch 工具。 - session: 提供给模块使用的假数据库 Session。 - - Returns: - ModuleType: 加载后的模块对象。 - """ - logger_module = ModuleType("src.common.logger") - logger_module.get_logger = lambda name: _DummyLogger() - monkeypatch.setitem(sys.modules, "src.common.logger", logger_module) - - database_module = ModuleType("src.common.database.database") - database_module.get_db_session = lambda: session - monkeypatch.setitem(sys.modules, "src.common.database.database", database_module) - - database_model_module = ModuleType("src.common.database.database_model") - database_model_module.PersonInfo = _DummyPersonInfoRecord - monkeypatch.setitem(sys.modules, "src.common.database.database_model", database_model_module) - - llm_module = ModuleType("src.llm_models.utils_model") - - class _DummyLLMRequest: - """模拟 LLMRequest。""" - - def __init__(self, model_set: Any, request_type: str) -> None: - """初始化假请求对象。 - - Args: - model_set: 模型配置。 - request_type: 请求类型。 - """ - del model_set - del request_type - - llm_module.LLMRequest = _DummyLLMRequest - monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_module) - - config_module = ModuleType("src.config.config") - config_module.global_config = SimpleNamespace(bot=SimpleNamespace(nickname="MaiBot")) - config_module.model_config = SimpleNamespace(model_task_config=SimpleNamespace(tool_use="tool_use", utils="utils")) - monkeypatch.setitem(sys.modules, "src.config.config", config_module) - - chat_manager_module = ModuleType("src.chat.message_receive.chat_manager") - chat_manager_module.chat_manager = SimpleNamespace() - monkeypatch.setitem(sys.modules, "src.chat.message_receive.chat_manager", chat_manager_module) - - module_path = Path(__file__).resolve().parents[2] / "src" / "person_info" / "person_info.py" - spec = spec_from_file_location("person_info_group_cardname_test_module", module_path) - assert spec is not None and spec.loader is not None - - module = module_from_spec(spec) - monkeypatch.setitem(sys.modules, spec.name, module) - spec.loader.exec_module(module) - - monkeypatch.setattr(module, "select", lambda *args: _DummyStatement()) - monkeypatch.setattr(module, "col", lambda field: _DummyColumn()) - return module - - -def test_parse_group_cardname_json_uses_canonical_key() -> None: - """群名片 JSON 解析应只使用 `group_cardname` 键名。""" - parsed = parse_group_cardname_json( - json.dumps( - [ - {"group_id": "1001", "group_cardname": "现行字段"}, - ], - ensure_ascii=False, - ) - ) - - assert parsed is not None - assert [(item.group_id, item.group_cardname) for item in parsed] == [ - ("1001", "现行字段"), - ] - - -def test_dump_group_cardname_records_uses_canonical_key() -> None: - """群名片序列化应输出 `group_cardname` 键名。""" - dumped = dump_group_cardname_records( - [ - {"group_id": "1001", "group_cardname": "群昵称"}, - ] - ) - - assert json.loads(dumped) == [{"group_id": "1001", "group_cardname": "群昵称"}] - - -def test_person_sync_to_database_uses_group_cardname_field(monkeypatch: pytest.MonkeyPatch) -> None: - """同步人物信息时应写入数据库模型的 `group_cardname` 字段。""" - record = _DummyPersonInfoRecord() - session = _DummySession(record) - module = _load_person_module(monkeypatch, session) - - person = module.Person.__new__(module.Person) - person.is_known = True - person.person_id = "person-1" - person.platform = "qq" - person.user_id = "10001" - person.nickname = "看番的龙" - person.person_name = "看番的龙" - person.name_reason = "测试" - person.know_times = 1 - person.know_since = 1700000000.0 - person.last_know = 1700000100.0 - person.memory_points = ["喜好:番剧:0.8"] - person.group_cardname_list = [{"group_id": "20001", "group_cardname": "白泽大人"}] - - person.sync_to_database() - - assert record.group_cardname == '[{"group_id": "20001", "group_cardname": "白泽大人"}]' - assert not hasattr(record, "group_nickname") - - -def test_person_load_from_database_normalizes_group_cardname_payload(monkeypatch: pytest.MonkeyPatch) -> None: - """从数据库加载人物信息时应读取标准 `group_cardname` 结构。""" - record = _DummyPersonInfoRecord( - user_id="10001", - platform="qq", - is_known=True, - user_nickname="看番的龙", - person_name="看番的龙", - name_reason=None, - know_counts=2, - memory_points='["喜好:番剧:0.8"]', - group_cardname=json.dumps( - [ - {"group_id": "20001", "group_cardname": "白泽大人"}, - ], - ensure_ascii=False, - ), - ) - session = _DummySession(record) - module = _load_person_module(monkeypatch, session) - - person = module.Person.__new__(module.Person) - person.person_id = "person-1" - person.memory_points = [] - person.group_cardname_list = [] - - person.load_from_database() - - assert person.group_cardname_list == [ - {"group_id": "20001", "group_cardname": "白泽大人"}, - ] diff --git a/pytests/config_test/test_config_base.py b/pytests/config_test/test_config_base.py deleted file mode 100644 index f67c8c56..00000000 --- a/pytests/config_test/test_config_base.py +++ /dev/null @@ -1,533 +0,0 @@ -import logging -import sys -from importlib import util -from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, Union - -import pytest -from pydantic import BaseModel, Field - -# ------------------------------------------------------------- -# 测试环境准备:补全 logger 和 AttrDocBase 依赖 -# ------------------------------------------------------------- - -TEST_ROOT = Path(__file__).parent.parent.absolute().resolve() -logger_file = TEST_ROOT / "logger.py" -spec = util.spec_from_file_location("src.common.logger", logger_file) -module = util.module_from_spec(spec) # type: ignore -assert spec is not None and spec.loader is not None -spec.loader.exec_module(module) # type: ignore -sys.modules["src.common.logger"] = module - -PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve() -sys.path.insert(0, str(PROJECT_ROOT)) -sys.path.insert(0, str(PROJECT_ROOT / "src" / "config")) - -from src.config.config_base import ConfigBase # noqa: E402 -import src.config.config_base as config_base_module # noqa: E402 - - -class AttrDocBase: - """用于测试的轻量级 AttrDocBase 替身""" - - def __post_init__(self) -> None: - # 被 ConfigBase.model_post_init 调用 - self.__post_init_called__ = True - - -# 打补丁,让 ConfigBase 使用测试替身 -@pytest.fixture(autouse=True) -def patch_attrdoc_post_init(): - orig = config_base_module.AttrDocBase.__post_init__ - config_base_module.AttrDocBase.__post_init__ = AttrDocBase.__post_init__ # type: ignore - yield - config_base_module.AttrDocBase.__post_init__ = orig - - -config_base_module.logger = logging.getLogger("config_base_test_logger") - - -class SimpleClass(ConfigBase): - a: int = 1 - b: str = "test" - - -class TestConfigBase: - # --------------------------------------------------------- - # happy path:整体 model_post_init 测试 - # --------------------------------------------------------- - @pytest.mark.parametrize( - "model_cls, init_kwargs, expected_fields", - [ - pytest.param( - # 简单原子类型字段 - type( - "SimpleAtomic", - (ConfigBase,), - { - "__annotations__": { - "a": int, - "b": str, - "c": bool, - "d": float, - }, - "a": Field(default=1), - "b": Field(default="x"), - "c": Field(default=True), - "d": Field(default=1.5), - }, - ), - {}, - {"a", "b", "c", "d"}, - id="happy-simple-atomic-fields", - ), - pytest.param( - # list/set/dict 泛型 + 原子内部类型 - type( - "AtomicContainers", - (ConfigBase,), - { - "__annotations__": { - "ints": List[int], - "names": Set[str], - "mapping": Dict[str, int], - }, - "ints": Field(default_factory=lambda: [1, 2]), - "names": Field(default_factory=lambda: {"a", "b"}), - "mapping": Field(default_factory=lambda: {"x": 1}), - }, - ), - {}, - {"ints", "names", "mapping"}, - id="happy-atomic-containers", - ), - pytest.param( - # Optional 原子和 Optional 容器 - type( - "OptionalFields", - (ConfigBase,), - { - "__annotations__": { - "maybe_int": Optional[int], - "maybe_str_list": Optional[List[str]], - }, - "maybe_int": Field(default=None), - "maybe_str_list": Field(default=None), - }, - ), - {}, - {"maybe_int", "maybe_str_list"}, - id="happy-optional-fields", - ), - ], - ) - def test_model_post_init_happy_paths(self, model_cls, init_kwargs, expected_fields): - # Act - instance = model_cls(**init_kwargs) - - # Assert - for field_name in expected_fields: - assert field_name in type(instance).model_fields - _ = getattr(instance, field_name) - assert getattr(instance, "__post_init_called__", False) is True - - # --------------------------------------------------------- - # _get_real_type - # --------------------------------------------------------- - def test_get_real_type_non_generic_and_generic(self): - class Sample(ConfigBase): - x: int = 1 - y: List[int] = Field(default_factory=list) - - instance = Sample() - - # Act - origin_x, args_x = instance._get_real_type(int) - - # Assert - assert origin_x is int - assert args_x == () - - # Act - origin_y, args_y = instance._get_real_type(List[int]) - - # Assert - assert origin_y in (list, List) - assert args_y == (int,) - - # --------------------------------------------------------- - # _validate_union_type - # --------------------------------------------------------- - @pytest.mark.parametrize( - "annotation, expect_error, error_fragment, expected_origin_type", - [ - pytest.param( - int, - False, - None, - int, - id="union-validation-atomic-non-union", - ), - pytest.param( - Optional[int], - False, - None, - int, - id="union-validation-optional-atomic", - ), - pytest.param( - Optional[List[int]], - False, - None, - list, - id="union-validation-optional-container", - ), - pytest.param( - Union[int, str], - True, - "不允许使用 Union 类型注解", - None, - id="union-validation-disallow-non-optional-union", - ), - pytest.param( - int | str, - True, - "不允许使用 Union 类型注解", - None, - id="union-validation-pep604-disallow-non-optional-union", - ), - pytest.param( - Union[int, None, str], - True, - "不允许使用 Union 类型注解", - None, - id="union-validation-disallow-union-more-than-two", - ), - pytest.param( - Optional[Union[int, str]], - True, - "不允许使用 Union 类型注解", - None, - id="union-validation-disallow-nested-optional-union", - ), - ], - ) - def test_validate_union_type(self, annotation, expect_error, error_fragment, expected_origin_type): - # 这里我们不实例化 Sample,以避免在 __init__/model_post_init 阶段触发验证。 - # 直接通过一个“哑实例”调用受测方法,仅测试类型注解逻辑。 - - class Dummy(ConfigBase): - pass - - dummy = Dummy() # 最小初始化,避免字段校验 - - field_name = "v" - - if expect_error: - # Act / Assert - with pytest.raises(TypeError) as exc_info: - dummy._validate_union_type(annotation, field_name) - assert error_fragment in str(exc_info.value) - else: - # Act - origin, args, other = dummy._validate_union_type(annotation, field_name) - - # Assert - assert origin is expected_origin_type - assert other is not None - - # --------------------------------------------------------- - # _validate_list_set_type - # --------------------------------------------------------- - @pytest.mark.parametrize( - "annotation, expect_error, error_fragment", - [ - pytest.param( - List[int], - False, - None, - id="listset-validation-list-happy", - ), - pytest.param( - Set[str], - False, - None, - id="listset-validation-set-happy", - ), - pytest.param( - list, - True, - "必须指定且仅指定一个类型参数", - id="listset-validation-missing-type-arg", - ), - pytest.param( - List[int | None], - True, - "不允许嵌套泛型类型", - id="listset-validation-nested-generic-inner-union", - ), - pytest.param( - List[List[int]], - True, - "不允许嵌套泛型类型", - id="listset-validation-nested-generic-inner-list", - ), - pytest.param( - List[SimpleClass], - False, - None, - id="listset-validation-list-configbase-element_allow", - ), - pytest.param( - Set[SimpleClass], - True, - "ConfigBase is not Hashable", - id="listset-validation-set-configbase-element_reject", - ), - ], - ) - def test_validate_list_set_type(self, annotation, expect_error, error_fragment): - # 不实例化带有这些字段的模型,避免在 __init__/model_post_init 阶段就失败, - # 只测试 _validate_list_set_type 本身的逻辑。 - - class Dummy(ConfigBase): - pass - - dummy = Dummy() - - field_name = "items" - - if expect_error: - # Act / Assert - with pytest.raises(TypeError) as exc_info: - dummy._validate_list_set_type(annotation, field_name) - assert error_fragment in str(exc_info.value) - else: - # Act - dummy._validate_list_set_type(annotation, field_name) - - # --------------------------------------------------------- - # _validate_dict_type - # --------------------------------------------------------- - @pytest.mark.parametrize( - "annotation, expect_error, error_fragment", - [ - pytest.param( - Dict[str, int], - False, - None, - id="dict-validation-happy-atomic", - ), - pytest.param( - Dict[str, Any], - True, - "不允许使用 Any 类型注解", - id="dict-validation-any-value-disallowed", - ), - pytest.param( - Dict[str, Dict[str, int]], - True, - "不允许嵌套泛型类型", - id="dict-validation-optional-nested-list", - ), - pytest.param( - Dict, - True, - "必须指定键和值的类型参数", - id="dict-validation-missing-args", - ), - pytest.param( - Dict[str, SimpleClass], - False, - None, - id="dict-validation-happy-configbase-value", - ), - ], - ) - def test_validate_dict_type(self, annotation, expect_error, error_fragment): - # 同样不通过字段定义来触发 model_post_init,只测试 _validate_dict_type 本身。 - - class Dummy(ConfigBase): - _validate_any: bool = True - - dummy = Dummy() - field_name = "mapping" - - if expect_error: - # Act / Assert - with pytest.raises(TypeError) as exc_info: - dummy._validate_dict_type(annotation, field_name) - assert error_fragment in str(exc_info.value) - else: - # Act - dummy._validate_dict_type(annotation, field_name) - - # --------------------------------------------------------- - # _discourage_any_usage - # --------------------------------------------------------- - def test_discourage_any_usage_raises_when_validate_any_true(self, caplog): - class Sample(ConfigBase): - _validate_any: bool = True - - instance = Sample() - - # Act / Assert - with pytest.raises(TypeError) as exc_info: - instance._discourage_any_usage("field_x") - assert "不允许使用 Any 类型注解" in str(exc_info.value) - assert "建议避免使用" not in caplog.text - - def test_discourage_any_usage_logs_when_validate_any_false(self, caplog): - class Sample(ConfigBase): - _validate_any: bool = False - - instance = Sample() - - # Arrange - caplog.set_level(logging.WARNING, logger="config_base_test_logger") - - # Act - instance._discourage_any_usage("field_y") - - # Assert - assert "字段'field_y'中使用了 Any 类型注解" in caplog.text - - def test_discourage_any_usage_suppressed_warning(self, caplog): - class Sample(ConfigBase): - _validate_any: bool = False - suppress_any_warning: bool = True - - instance = Sample() - - # Arrange - caplog.set_level(logging.WARNING, logger="config_base_test_logger") - - # Act - instance._discourage_any_usage("field_z") - - # Assert - assert "字段'field_z'中使用了 Any 类型注解" not in caplog.text - - # --------------------------------------------------------- - # model_post_init 规则覆盖(错误与边界情况) - # --------------------------------------------------------- - @pytest.mark.parametrize( - "field_annotation, expect_error, error_fragment, test_id", - [ - ( - Tuple[int, int], - True, - "不允许使用 Tuple 类型注解", - "model-post-init-disallow-tuple-typing-tuple", - ), - ( - tuple[int, int], - True, - "不允许使用 Tuple 类型注解", - "model-post-init-disallow-pep604-tuple", - ), - ( - Union[int, str], - True, - "不允许使用 Union 类型注解", - "model-post-init-disallow-union-field", - ), - ( - list, - True, - "必须指定且仅指定一个类型参数", - "model-post-init-list-missing-type-arg", - ), - ( - List[List[int]], - True, - "不允许嵌套泛型类型", - "model-post-init-list-nested-generic", - ), - ( - Dict[str, Any], - True, - "不允许使用 Any 类型注解", - "model-post-init-dict-value-any", - ), - ( - Any, - True, - "不允许使用 Any 类型注解", - "model-post-init-field-any-disallowed", - ), - ( - Set[int], - False, - None, - "model-post-init-allow-set-int", - ), - ( - Dict[str, Optional[int]], - False, - None, - "model-post-init-allow-dict-optional-int", - ), - ], - ids=lambda v: v[3] if isinstance(v, tuple) else v, - ) - def test_model_post_init_type_rules(self, field_annotation, expect_error, error_fragment, test_id): - # Arrange - attrs = { - "__annotations__": {"f": field_annotation}, - "f": Field(default=None), - } - model_cls = type("DynamicModel" + test_id.replace("-", "_"), (ConfigBase,), attrs) - - if expect_error: - # Act / Assert - with pytest.raises(TypeError) as exc_info: - model_cls() - assert error_fragment in str(exc_info.value) - else: - # Act - instance = model_cls() - - # Assert - assert hasattr(instance, "f") - - # --------------------------------------------------------- - # 嵌套 ConfigBase & 非支持泛型 origin - # --------------------------------------------------------- - def test_model_post_init_allows_configbase_nested_class(self): - class Child(ConfigBase): - value: int = 1 - - class Parent(ConfigBase): - child: Child = Field(default_factory=Child) - - # Act - parent = Parent() - - # Assert - assert isinstance(parent.child, Child) - - def test_model_post_init_disallow_non_supported_generic_origin(self): - class CustomGeneric(BaseModel): - pass - - class Sample(ConfigBase): - f: CustomGeneric = Field(default_factory=CustomGeneric) - - # Arrange / Act / Assert - with pytest.raises(TypeError) as exc_info: - Sample() - assert "仅允许使用list, set, dict三种泛型类型注解" in str(exc_info.value) - - # --------------------------------------------------------- - # super().model_post_init 和 AttrDocBase.__post_init__ 调用 - # --------------------------------------------------------- - def test_super_model_post_init_and_attrdoc_post_init_called(self): - class Sample(ConfigBase): - value: int = 1 - - # Act - instance = Sample() - - # Assert - assert getattr(instance, "__post_init_called__", False) is True diff --git a/pytests/config_test/test_config_manager_hot_reload.py b/pytests/config_test/test_config_manager_hot_reload.py deleted file mode 100644 index a42a4133..00000000 --- a/pytests/config_test/test_config_manager_hot_reload.py +++ /dev/null @@ -1,104 +0,0 @@ -from pathlib import Path - -from watchfiles import Change - -import asyncio -import pytest - -from src.config.config import ConfigManager -from src.config.file_watcher import FileChange, FileWatcherStats - - -@pytest.mark.asyncio -async def test_handle_file_changes_throttles_reload(): - manager = ConfigManager() - manager._hot_reload_min_interval_s = 100.0 - - called = 0 - - async def reload_stub(changed_scopes=None) -> bool: - nonlocal called - called += 1 - return True - - manager.reload_config = reload_stub # type: ignore[method-assign] - changes = [FileChange(change_type=Change.modified, path=Path("/tmp/bot_config.toml"))] - - await manager._handle_file_changes(changes) - await manager._handle_file_changes(changes) - - assert called == 1 - - -@pytest.mark.asyncio -async def test_handle_file_changes_timeout_logged(caplog): - manager = ConfigManager() - manager._hot_reload_min_interval_s = 0.0 - manager._hot_reload_timeout_s = 0.01 - - async def reload_stub(changed_scopes=None) -> bool: - await asyncio.sleep(0.05) - return True - - manager.reload_config = reload_stub # type: ignore[method-assign] - changes = [FileChange(change_type=Change.modified, path=Path("/tmp/model_config.toml"))] - - with caplog.at_level("ERROR"): - await manager._handle_file_changes(changes) - - assert "配置热重载超时" in caplog.text - - -@pytest.mark.asyncio -async def test_handle_file_changes_empty_skips_reload(): - manager = ConfigManager() - - called = 0 - - async def reload_stub(changed_scopes=None) -> bool: - nonlocal called - called += 1 - return True - - manager.reload_config = reload_stub # type: ignore[method-assign] - - await manager._handle_file_changes([]) - - assert called == 0 - - -class _FakeWatcher: - def __init__(self): - self.unsubscribe_called_with: str | None = None - self.stop_called = False - self.stats = FileWatcherStats( - batches_seen=1, - changes_seen=2, - callbacks_succeeded=3, - callbacks_failed=4, - callbacks_timed_out=5, - callbacks_skipped_cooldown=6, - restart_count=7, - ) - - def unsubscribe(self, subscription_id: str) -> bool: - self.unsubscribe_called_with = subscription_id - return True - - async def stop(self) -> None: - self.stop_called = True - - -@pytest.mark.asyncio -async def test_stop_file_watcher_cleans_state(): - manager = ConfigManager() - fake_watcher = _FakeWatcher() - manager._file_watcher = fake_watcher # type: ignore[assignment] - manager._file_watcher_subscription_id = "sub-1" - - await manager.stop_file_watcher() - - assert fake_watcher.unsubscribe_called_with == "sub-1" - assert fake_watcher.stop_called is True - assert manager._file_watcher is None - assert manager._file_watcher_subscription_id is None diff --git a/pytests/config_test/test_config_manager_startup_upgrade.py b/pytests/config_test/test_config_manager_startup_upgrade.py deleted file mode 100644 index f0c92205..00000000 --- a/pytests/config_test/test_config_manager_startup_upgrade.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Any - -from src.config import config as config_module -from src.config.config import Config, ConfigManager, ModelConfig - - -def test_initialize_upgrades_bot_and_model_config_without_exit(monkeypatch): - manager = ConfigManager() - loaded_config_classes: list[type[Any]] = [] - warnings: list[Any] = [] - - def fake_load_config_from_file(config_class, config_path, new_ver, override_repr=False): - loaded_config_classes.append(config_class) - return object(), True - - monkeypatch.setattr(config_module, "load_config_from_file", fake_load_config_from_file) - monkeypatch.setattr(ConfigManager, "_warn_if_vlm_not_configured", lambda self, model_config: warnings.append(model_config)) - - manager.initialize() - - assert loaded_config_classes == [Config, ModelConfig] - assert warnings == [manager.model_config] diff --git a/pytests/config_test/test_file_watcher.py b/pytests/config_test/test_file_watcher.py deleted file mode 100644 index 27fab95e..00000000 --- a/pytests/config_test/test_file_watcher.py +++ /dev/null @@ -1,138 +0,0 @@ -from pathlib import Path - -from watchfiles import Change - -import asyncio -import pytest - -from src.config.file_watcher import FileChange, FileWatcher - -from typing import Sequence - - -@pytest.mark.asyncio -async def test_dispatch_changes_with_path_and_change_type_filters(tmp_path: Path): - watcher = FileWatcher(paths=[tmp_path]) - target_file = (tmp_path / "bot_config.toml").resolve() - - received: list[list[FileChange]] = [] - - async def callback(changes): - received.append(list(changes)) - - watcher.subscribe(callback, paths=[target_file], change_types=[Change.modified]) - - await watcher._dispatch_changes( - [ - FileChange(change_type=Change.added, path=target_file), - FileChange(change_type=Change.modified, path=target_file), - FileChange(change_type=Change.modified, path=(tmp_path / "other.toml").resolve()), - ] - ) - - assert len(received) == 1 - assert len(received[0]) == 1 - assert received[0][0].change_type == Change.modified - assert received[0][0].path == target_file - - -@pytest.mark.asyncio -async def test_sync_callback_supported(tmp_path: Path): - watcher = FileWatcher(paths=[tmp_path]) - target_file = (tmp_path / "model_config.toml").resolve() - - received_paths: list[Path] = [] - - def sync_callback(changes): - received_paths.extend(change.path for change in changes) - - watcher.subscribe(sync_callback, paths=[target_file]) - - await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)]) - - assert received_paths == [target_file] - - -@pytest.mark.asyncio -async def test_callback_timeout_and_cooldown(tmp_path: Path): - watcher = FileWatcher( - paths=[tmp_path], - callback_timeout_s=0.05, - callback_failure_threshold=2, - callback_cooldown_s=0.2, - ) - target_file = (tmp_path / "bot_config.toml").resolve() - - async def slow_callback(changes): - await asyncio.sleep(0.2) - - watcher.subscribe(slow_callback, paths=[target_file]) - - await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)]) - await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)]) - - stats_after_failures = watcher.stats - assert stats_after_failures.callbacks_timed_out == 2 - assert stats_after_failures.callbacks_failed == 2 - - await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)]) - stats_after_cooldown_skip = watcher.stats - assert stats_after_cooldown_skip.callbacks_skipped_cooldown >= 1 - - -@pytest.mark.asyncio -async def test_start_requires_subscription(tmp_path: Path): - watcher = FileWatcher(paths=[tmp_path]) - - with pytest.raises(RuntimeError): - await watcher.start() - - -@pytest.mark.asyncio -async def test_unsubscribe_stops_dispatch(tmp_path: Path): - watcher = FileWatcher(paths=[tmp_path]) - target_file = (tmp_path / "bot_config.toml").resolve() - - calls = 0 - - async def callback(changes): - nonlocal calls - calls += 1 - - subscription_id = watcher.subscribe(callback, paths=[target_file]) - assert watcher.unsubscribe(subscription_id) is True - - await watcher._dispatch_changes([FileChange(change_type=Change.modified, path=target_file)]) - - assert calls == 0 - - -@pytest.mark.asyncio -async def test_add_callback_while_watcher_running(tmp_path: Path): - dirs = (tmp_path / "a_dir").resolve() - dirs.mkdir(exist_ok=True) - file = (dirs / "a.toml").resolve() - file.touch() - watcher = FileWatcher(paths=[dirs], debounce_ms=200) - - calls = 0 - - async def callback(changes: Sequence[FileChange]): - nonlocal calls - print(f"Callback called with changes: {[f'{change.change_type} {change.path}' for change in changes]}") - calls += 1 - - uuid = watcher.subscribe(callback, paths=[file]) - await watcher.start() - try: - with file.open("w") as f: - f.write("change") - await asyncio.sleep(0.5) - assert calls == 1 - watcher.unsubscribe(uuid) - with file.open("w") as f: - f.write("change2") - await asyncio.sleep(0.5) - assert calls == 1 - finally: - await watcher.stop() diff --git a/pytests/config_test/test_llm_request_hot_reload.py b/pytests/config_test/test_llm_request_hot_reload.py deleted file mode 100644 index b6a6517b..00000000 --- a/pytests/config_test/test_llm_request_hot_reload.py +++ /dev/null @@ -1,76 +0,0 @@ -from types import SimpleNamespace -from importlib import util -from pathlib import Path - -from src.config.config import config_manager -from src.config.model_configs import TaskConfig -from src.llm_models.utils_model import LLMRequest - - -def _load_llm_api_module(): - file_path = Path(__file__).parent.parent.parent / "src" / "plugin_system" / "apis" / "llm_api.py" - spec = util.spec_from_file_location("test_llm_api_module", file_path) - assert spec is not None and spec.loader is not None - module = util.module_from_spec(spec) - spec.loader.exec_module(module) - return module - - -def _make_model_config(task_config: TaskConfig, attr_name: str = "utils"): - model_task_config = SimpleNamespace(**{attr_name: task_config}) - return SimpleNamespace(model_task_config=model_task_config, models=[], api_providers=[]) - - -def test_llm_request_resolve_task_config_by_signature(monkeypatch): - old_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0) - current_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0) - - monkeypatch.setattr(config_manager, "get_model_config", lambda: _make_model_config(current_task, "utils")) - - req = LLMRequest(model_set=old_task, request_type="test") - - assert req._task_config_name == "utils" - - -def test_llm_request_refresh_task_config_updates_runtime_state(monkeypatch): - old_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0) - initial_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0) - updated_task = TaskConfig(model_list=["gpt-b", "gpt-c"], max_tokens=1024, temperature=0.5, slow_threshold=20.0) - - current = {"task": initial_task} - - def get_model_config_stub(): - return _make_model_config(current["task"], "replyer") - - monkeypatch.setattr(config_manager, "get_model_config", get_model_config_stub) - - req = LLMRequest(model_set=old_task, request_type="test") - assert req._task_config_name == "replyer" - - current["task"] = updated_task - req._refresh_task_config() - - assert req.model_for_task.model_list == ["gpt-b", "gpt-c"] - assert list(req.model_usage.keys()) == ["gpt-b", "gpt-c"] - - -def test_llm_api_get_available_models_reads_latest_config(monkeypatch): - llm_api = _load_llm_api_module() - - first_utils = TaskConfig(model_list=["gpt-a"]) - second_utils = TaskConfig(model_list=["gpt-z"]) - - state = {"task": first_utils} - - def get_model_config_stub(): - model_task_config = SimpleNamespace(utils=state["task"], planner=TaskConfig(model_list=["gpt-p"])) - return SimpleNamespace(model_task_config=model_task_config) - - monkeypatch.setattr(config_manager, "get_model_config", get_model_config_stub) - - first = llm_api.get_available_models() - assert first["utils"].model_list == ["gpt-a"] - - state["task"] = second_utils - second = llm_api.get_available_models() - assert second["utils"].model_list == ["gpt-z"] diff --git a/pytests/config_test/test_model_info_normalization.py b/pytests/config_test/test_model_info_normalization.py deleted file mode 100644 index 72db7ea6..00000000 --- a/pytests/config_test/test_model_info_normalization.py +++ /dev/null @@ -1,11 +0,0 @@ -from src.config.model_configs import ModelInfo - - -def test_model_identifier_strips_surrounding_whitespace() -> None: - model_info = ModelInfo( - api_provider="test-provider", - model_identifier=" glm-5.1 ", - name="test-model", - ) - - assert model_info.model_identifier == "glm-5.1" diff --git a/pytests/config_test/test_startup_bindings.py b/pytests/config_test/test_startup_bindings.py deleted file mode 100644 index d11c436e..00000000 --- a/pytests/config_test/test_startup_bindings.py +++ /dev/null @@ -1,104 +0,0 @@ -from pathlib import Path -from types import SimpleNamespace -import sys - -from src.config.legacy_migration import migrate_legacy_bind_env_to_bot_config_dict -from src.config.startup_bindings import ( - BindAddress, - get_startup_main_bind_address, - get_startup_webui_bind_address, - resolve_main_bind_address, - resolve_webui_bind_address, -) - - -def test_startup_bindings_use_defaults_when_config_file_missing(tmp_path: Path): - missing_path = tmp_path / "missing_bot_config.toml" - - assert get_startup_main_bind_address(missing_path) == BindAddress(host="127.0.0.1", port=8080) - assert get_startup_webui_bind_address(missing_path) == BindAddress(host="127.0.0.1", port=8001) - - -def test_startup_bindings_can_read_addresses_from_bot_config(tmp_path: Path): - config_path = tmp_path / "bot_config.toml" - config_path.write_text( - """ -[inner] -version = "8.3.1" - -[maim_message] -ws_server_host = "0.0.0.0" -ws_server_port = 22345 - -[webui] -host = "192.168.1.9" -port = 18001 -""".strip(), - encoding="utf-8", - ) - - assert get_startup_main_bind_address(config_path) == BindAddress(host="0.0.0.0", port=22345) - assert get_startup_webui_bind_address(config_path) == BindAddress(host="192.168.1.9", port=18001) - - -def test_resolve_bindings_prefer_initialized_global_config(monkeypatch): - fake_config_module = SimpleNamespace( - global_config=SimpleNamespace( - maim_message=SimpleNamespace(ws_server_host="10.0.0.2", ws_server_port=32000), - webui=SimpleNamespace(host="10.0.0.3", port=32001), - ) - ) - - monkeypatch.setitem(sys.modules, "src.config.config", fake_config_module) - - assert resolve_main_bind_address() == BindAddress(host="10.0.0.2", port=32000) - assert resolve_webui_bind_address() == BindAddress(host="10.0.0.3", port=32001) - - -def test_legacy_env_bindings_are_migrated_when_fields_missing_or_default(monkeypatch): - monkeypatch.setenv("HOST", "0.0.0.0") - monkeypatch.setenv("PORT", "22345") - monkeypatch.setenv("WEBUI_HOST", "192.168.1.8") - monkeypatch.setenv("WEBUI_PORT", "19001") - - payload = { - "maim_message": { - "ws_server_host": "127.0.0.1", - "ws_server_port": 8080, - }, - "webui": {}, - } - - result = migrate_legacy_bind_env_to_bot_config_dict(payload) - - assert result.migrated is True - assert payload["maim_message"]["ws_server_host"] == "0.0.0.0" - assert payload["maim_message"]["ws_server_port"] == 22345 - assert payload["webui"]["host"] == "192.168.1.8" - assert payload["webui"]["port"] == 19001 - - -def test_legacy_env_bindings_do_not_override_explicit_config(monkeypatch): - monkeypatch.setenv("HOST", "0.0.0.0") - monkeypatch.setenv("PORT", "22345") - monkeypatch.setenv("WEBUI_HOST", "192.168.1.8") - monkeypatch.setenv("WEBUI_PORT", "19001") - - payload = { - "maim_message": { - "ws_server_host": "10.1.1.1", - "ws_server_port": 30000, - }, - "webui": { - "host": "10.1.1.2", - "port": 30001, - }, - } - - result = migrate_legacy_bind_env_to_bot_config_dict(payload) - - assert result.migrated is False - assert payload["maim_message"]["ws_server_host"] == "10.1.1.1" - assert payload["maim_message"]["ws_server_port"] == 30000 - assert payload["webui"]["host"] == "10.1.1.2" - assert payload["webui"]["port"] == 30001 diff --git a/pytests/conftest.py b/pytests/conftest.py deleted file mode 100644 index 0ad261ce..00000000 --- a/pytests/conftest.py +++ /dev/null @@ -1,10 +0,0 @@ -import sys -from pathlib import Path - -# Add project root to Python path so src imports work -project_root = Path(__file__).parent.parent.absolute() -src_root = project_root / "src" -if str(src_root) not in sys.path: - sys.path.insert(0, str(src_root)) -if str(project_root) not in sys.path: - sys.path.insert(1, str(project_root)) diff --git a/pytests/i18n_test/test_i18n.py b/pytests/i18n_test/test_i18n.py deleted file mode 100644 index 9c31f842..00000000 --- a/pytests/i18n_test/test_i18n.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -from pathlib import Path - -import json - -import pytest - -from src.common.i18n.manager import I18nManager -from src.common.i18n.loaders import DuplicateTranslationKeyError, load_locale_catalog - - -def write_locale_file(locales_root: Path, locale: str, file_name: str, payload: dict[str, object]) -> None: - locale_dir = locales_root / locale - locale_dir.mkdir(parents=True, exist_ok=True) - file_path = locale_dir / file_name - file_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") - - -def test_t_falls_back_to_default_locale(tmp_path: Path) -> None: - locales_root = tmp_path / "locales" - write_locale_file(locales_root, "zh-CN", "core.json", {"greeting": "你好,{name}"}) - write_locale_file(locales_root, "en-US", "core.json", {}) - - manager = I18nManager(locales_root=locales_root) - - assert manager.t("greeting", locale="en-US", name="Mai") == "你好,Mai" - - -def test_t_returns_key_when_missing_everywhere(tmp_path: Path) -> None: - locales_root = tmp_path / "locales" - write_locale_file(locales_root, "zh-CN", "core.json", {}) - write_locale_file(locales_root, "en-US", "core.json", {}) - - manager = I18nManager(locales_root=locales_root) - - assert manager.t("missing.key", locale="en-US") == "missing.key" - - -def test_tn_uses_plural_rules(tmp_path: Path) -> None: - locales_root = tmp_path / "locales" - write_locale_file( - locales_root, - "en-US", - "core.json", - { - "tasks.cancelled": { - "one": "Cancelled {count} task", - "other": "Cancelled {count} tasks", - } - }, - ) - - manager = I18nManager(default_locale="en-US", locales_root=locales_root) - - assert manager.tn("tasks.cancelled", 1) == "Cancelled 1 task" - assert manager.tn("tasks.cancelled", 2) == "Cancelled 2 tasks" - - -def test_load_locale_catalog_rejects_duplicate_keys(tmp_path: Path) -> None: - locales_root = tmp_path / "locales" - write_locale_file(locales_root, "zh-CN", "a.json", {"duplicate.key": "A"}) - write_locale_file(locales_root, "zh-CN", "b.json", {"duplicate.key": "B"}) - - with pytest.raises(DuplicateTranslationKeyError): - load_locale_catalog("zh-CN", locales_root) diff --git a/pytests/i18n_test/test_i18n_validate.py b/pytests/i18n_test/test_i18n_validate.py deleted file mode 100644 index 6e83b28e..00000000 --- a/pytests/i18n_test/test_i18n_validate.py +++ /dev/null @@ -1,110 +0,0 @@ -from __future__ import annotations - -from importlib.util import module_from_spec, spec_from_file_location -from pathlib import Path - -import json - -SCRIPT_PATH = Path(__file__).resolve().parents[2] / "scripts" / "i18n_validate.py" -MODULE_SPEC = spec_from_file_location("i18n_validate_script", SCRIPT_PATH) -assert MODULE_SPEC is not None -assert MODULE_SPEC.loader is not None -I18N_VALIDATE = module_from_spec(MODULE_SPEC) -MODULE_SPEC.loader.exec_module(I18N_VALIDATE) - - -def write_locale_file(locales_root: Path, locale: str, file_name: str, payload: dict[str, object]) -> None: - locale_dir = locales_root / locale - locale_dir.mkdir(parents=True, exist_ok=True) - (locale_dir / file_name).write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") - - -def write_dashboard_locale_file(locales_root: Path, locale: str, payload: dict[str, object]) -> None: - locales_root.mkdir(parents=True, exist_ok=True) - (locales_root / f"{locale}.json").write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") - - -def test_validate_json_locales_rejects_han_characters_in_english_locale(tmp_path: Path) -> None: - locales_root = tmp_path / "locales" - write_locale_file(locales_root, "zh-CN", "core.json", {"consent.prompt": '输入"同意"继续'}) - write_locale_file(locales_root, "en-US", "core.json", {"consent.prompt": 'Type "confirmed" or "同意" to continue'}) - - errors = I18N_VALIDATE.validate_json_locales(locales_root) - - assert any("consent.prompt" in error and "仍包含中文字符" in error for error in errors) - - -def test_validate_json_locales_rejects_untranslated_han_source_in_other_target_locales(tmp_path: Path) -> None: - locales_root = tmp_path / "locales" - write_locale_file(locales_root, "zh-CN", "core.json", {"greeting": "你好,世界"}) - write_locale_file(locales_root, "ja", "core.json", {"greeting": "你好,世界"}) - - errors = I18N_VALIDATE.validate_json_locales(locales_root) - - assert any("greeting" in error and "直接保留了包含中文字符的 source 文案" in error for error in errors) - - -def test_validate_json_locales_avoids_false_positive_when_plural_categories_do_not_align(tmp_path: Path) -> None: - locales_root = tmp_path / "locales" - write_locale_file( - locales_root, - "zh-CN", - "core.json", - { - "tasks.cancelled": { - "one": "中文单数", - "other": "中文复数", - } - }, - ) - write_locale_file( - locales_root, - "ja", - "core.json", - { - "tasks.cancelled": { - "many": "中文单数", - "other": "已翻译", - } - }, - ) - - errors = I18N_VALIDATE.validate_json_locales(locales_root) - - assert any("tasks.cancelled" in error and "plural category 不一致" in error for error in errors) - assert not any("tasks.cancelled" in error and "直接保留了包含中文字符的 source 文案" in error for error in errors) - - -def test_validate_dashboard_json_locales_rejects_han_characters_in_english_locale(tmp_path: Path) -> None: - locales_root = tmp_path / "dashboard-locales" - write_dashboard_locale_file(locales_root, "zh", {"common": {"greeting": "你好,世界"}}) - write_dashboard_locale_file(locales_root, "en", {"common": {"greeting": "Hello 同意"}}) - - errors = I18N_VALIDATE.validate_dashboard_json_locales(locales_root) - - assert any("dashboard:en" in error and "common.greeting" in error and "仍包含中文字符" in error for error in errors) - - -def test_validate_dashboard_json_locales_rejects_untranslated_han_source_in_other_target_locales( - tmp_path: Path, -) -> None: - locales_root = tmp_path / "dashboard-locales" - write_dashboard_locale_file(locales_root, "zh", {"common": {"greeting": "你好,世界"}}) - write_dashboard_locale_file(locales_root, "ja", {"common": {"greeting": "你好,世界"}}) - - errors = I18N_VALIDATE.validate_dashboard_json_locales(locales_root) - - assert any( - "dashboard:ja" in error and "common.greeting" in error and "直接保留了包含中文字符的 source 文案" in error - for error in errors - ) - - -def test_validate_dashboard_json_locales_rejects_i18next_placeholder_drift(tmp_path: Path) -> None: - locales_root = tmp_path / "dashboard-locales" - write_dashboard_locale_file(locales_root, "zh", {"status": {"checkingDesc": "等待服务恢复... ({{current}}/{{max}})"}}) - write_dashboard_locale_file(locales_root, "ko", {"status": {"checkingDesc": "서비스 복구 대기 중... ({{current}}/{{limit}})"}}) - - errors = I18N_VALIDATE.validate_dashboard_json_locales(locales_root) - - assert any("dashboard:ko" in error and "status.checkingDesc" in error and "占位符集合与 source 不一致" in error for error in errors) diff --git a/pytests/image_sys_test/emoji_manager_test.py b/pytests/image_sys_test/emoji_manager_test.py deleted file mode 100644 index d6c77c95..00000000 --- a/pytests/image_sys_test/emoji_manager_test.py +++ /dev/null @@ -1,2637 +0,0 @@ -# 本文件为测试文件,含有大量的MonkeyPatch和Mock代码,请忽略TypeChecker的报错 -import importlib.util -import sys -from dataclasses import dataclass -from types import ModuleType -from pathlib import Path - -import asyncio -import pytest - - -def _install_stub_modules(monkeypatch): - def _stub_module(name: str) -> ModuleType: - module = ModuleType(name) - monkeypatch.setitem(sys.modules, name, module) - return module - - # src.common.logger - logger_mod = _stub_module("src.common.logger") - - class _Logger: - def __init__(self): - self.info_calls = [] - self.debug_calls = [] - self.warning_calls = [] - self.error_calls = [] - self.critical_calls = [] - - def info(self, *args, **kwargs): - self.info_calls.append(args) - - def debug(self, *args, **kwargs): - self.debug_calls.append(args) - - def warning(self, *args, **kwargs): - self.warning_calls.append(args) - - def error(self, *args, **kwargs): - self.error_calls.append(args) - - def critical(self, *args, **kwargs): - self.critical_calls.append(args) - - def get_logger(_name: str): - return _Logger() - - logger_mod.get_logger = get_logger - - # src.common.data_models.image_data_model - data_model_mod = _stub_module("src.common.data_models.image_data_model") - - @dataclass - class MaiEmoji: - full_path: Path | None = None - file_name: str = "" - description: str | None = None - emotion: list[str] | None = None - file_hash: str | None = None - query_count: int = 0 - register_time: object | None = None - image_format: str | None = None - image_bytes: bytes | None = None - - @staticmethod - def from_db_instance(_record): - return MaiEmoji() - - def to_db_instance(self): - return Images() - - async def calculate_hash_format(self): - return True - - @staticmethod - def read_image_bytes(_path): - return b"" - - data_model_mod.MaiEmoji = MaiEmoji - - # src.common.database.database_model - db_model_mod = _stub_module("src.common.database.database_model") - - class Images: - id = 0 - is_registered = False - is_banned = False - no_file_flag = False - register_time = None - query_count = 0 - last_used_time = None - full_path = "" - image_hash = "" - image_type = None - - class ImageType: - EMOJI = "EMOJI" - - db_model_mod.Images = Images - db_model_mod.ImageType = ImageType - - # src.common.database.database - db_mod = _stub_module("src.common.database.database") - - class _DummySession: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - class _Result: - def scalars(self): - return self - - def all(self): - return [] - - def first(self): - return None - - return _Result() - - def add(self, _record): - pass - - def delete(self, _record): - pass - - def flush(self): - pass - - def commit(self): - pass - - def get_db_session(): - return _DummySession() - - def get_db_session_manual(): - return _DummySession() - - db_mod.get_db_session = get_db_session - db_mod.get_db_session_manual = get_db_session_manual - - # src.common.utils.utils_image - image_utils_mod = _stub_module("src.common.utils.utils_image") - - class ImageUtils: - @staticmethod - def gif_2_static_image(_image_bytes): - return b"" - - @staticmethod - def image_bytes_to_base64(_image_bytes): - return "" - - image_utils_mod.ImageUtils = ImageUtils - - # src.prompt.prompt_manager - prompt_manager_mod = _stub_module("src.prompt.prompt_manager") - - class _Prompt: - def add_context(self, _key, _value): - pass - - class _PromptManager: - def get_prompt(self, _name): - return _Prompt() - - async def render_prompt(self, _prompt): - return "" - - prompt_manager_mod.prompt_manager = _PromptManager() - - # src.config.config - config_mod = _stub_module("src.config.config") - - class _EmojiConfig: - max_reg_num = 20 - content_filtration = False - steal_emoji = False - do_replace = False - check_interval = 1 - - class _BotConfig: - nickname = "bot" - - class _ModelTaskConfig: - vlm = None - utils = None - - class _ModelConfig: - model_task_config = _ModelTaskConfig() - - class _GlobalConfig: - emoji = _EmojiConfig() - bot = _BotConfig() - - class _ConfigManager: - def __init__(self): - self.reload_callbacks = [] - - def register_reload_callback(self, callback): - self.reload_callbacks.append(callback) - - def unregister_reload_callback(self, callback): - if callback in self.reload_callbacks: - self.reload_callbacks.remove(callback) - - config_mod.global_config = _GlobalConfig() - config_mod.model_config = _ModelConfig() - config_mod.config_manager = _ConfigManager() - - # src.llm_models.utils_model - llm_mod = _stub_module("src.llm_models.utils_model") - - class LLMRequest: - def __init__(self, *args, **kwargs): - pass - - async def generate_response_async(self, *args, **kwargs): - return "", None - - async def generate_response_for_image(self, *args, **kwargs): - return "", None - - llm_mod.LLMRequest = LLMRequest - - # third-party stubs - rich_traceback_mod = _stub_module("rich.traceback") - - def install(*_args, **_kwargs): - pass - - rich_traceback_mod.install = install - - sqlmodel_mod = _stub_module("sqlmodel") - - def select(_model): - return object() - - sqlmodel_mod.select = select - - levenshtein_mod = _stub_module("Levenshtein") - - def distance(a, b): - return abs(len(str(a)) - len(str(b))) - - levenshtein_mod.distance = distance - - -def import_emoji_manager_new(monkeypatch): - _install_stub_modules(monkeypatch) - file_path = Path(__file__).resolve().parents[2] / "src" / "chat" / "emoji_system" / "emoji_manager.py" - spec = importlib.util.spec_from_file_location("emoji_manager", file_path) - module = importlib.util.module_from_spec(spec) - monkeypatch.setitem(sys.modules, "emoji_manager_new", module) - spec.loader.exec_module(module) - - class _Select: - def filter_by(self, **kwargs): - return self - - def limit(self, n): - return self - - module.select = lambda _model: _Select() - return module - - -def _messages(calls): - return [" ".join(map(str, args)) for args in calls] - - -@pytest.mark.asyncio -async def test_replace_an_emoji_by_llm_decision_no_delete(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - manager.emojis = [emoji_manager_new.MaiEmoji()] - - async def _generate_response_async(*_args, **_kwargs): - return "不删除", None - - monkeypatch.setattr( - emoji_manager_new.emoji_manager_emotion_judge_llm, - "generate_response_async", - _generate_response_async, - ) - - result = await manager.replace_an_emoji_by_llm(emoji_manager_new.MaiEmoji()) - - assert result is False - assert any("不删除任何表情包" in m for m in _messages(logger.info_calls)) - - -@pytest.mark.asyncio -async def test_replace_an_emoji_by_llm_decision_parse_error(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - manager.emojis = [emoji_manager_new.MaiEmoji()] - - async def _generate_response_async(*_args, **_kwargs): - return "删除编号1", None - - def _bad_search(*_args, **_kwargs): - raise RuntimeError("search failed") - - monkeypatch.setattr( - emoji_manager_new.emoji_manager_emotion_judge_llm, - "generate_response_async", - _generate_response_async, - ) - monkeypatch.setattr(emoji_manager_new.re, "search", _bad_search) - - result = await manager.replace_an_emoji_by_llm(emoji_manager_new.MaiEmoji()) - - assert result is False - assert any("解析决策结果时出错" in m for m in _messages(logger.error_calls)) - - -@pytest.mark.asyncio -async def test_replace_an_emoji_by_llm_decision_missing_number(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - manager.emojis = [emoji_manager_new.MaiEmoji()] - - async def _generate_response_async(*_args, **_kwargs): - return "删除编号ABC", None - - monkeypatch.setattr( - emoji_manager_new.emoji_manager_emotion_judge_llm, - "generate_response_async", - _generate_response_async, - ) - - result = await manager.replace_an_emoji_by_llm(emoji_manager_new.MaiEmoji()) - - assert result is False - assert any("未能解析删除编号" in m for m in _messages(logger.error_calls)) - - -@pytest.mark.asyncio -async def test_replace_an_emoji_by_llm_decision_index_out_of_range(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - manager.emojis = [emoji_manager_new.MaiEmoji()] - - async def _generate_response_async(*_args, **_kwargs): - return "删除编号3", None - - monkeypatch.setattr( - emoji_manager_new.emoji_manager_emotion_judge_llm, - "generate_response_async", - _generate_response_async, - ) - - result = await manager.replace_an_emoji_by_llm(emoji_manager_new.MaiEmoji()) - - assert result is False - assert any("无效的表情包编号" in m for m in _messages(logger.error_calls)) - - -@pytest.mark.asyncio -async def test_replace_an_emoji_by_llm_delete_failed(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - manager.emojis = [emoji_manager_new.MaiEmoji()] - - async def _generate_response_async(*_args, **_kwargs): - return "删除编号1", None - - def _delete(_emoji): - return False - - monkeypatch.setattr( - emoji_manager_new.emoji_manager_emotion_judge_llm, - "generate_response_async", - _generate_response_async, - ) - monkeypatch.setattr(manager, "delete_emoji", _delete) - - result = await manager.replace_an_emoji_by_llm(emoji_manager_new.MaiEmoji()) - - assert result is False - assert any("删除表情包失败" in m for m in _messages(logger.error_calls)) - - -@pytest.mark.asyncio -async def test_replace_an_emoji_by_llm_register_failed(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - manager.emojis = [emoji_manager_new.MaiEmoji()] - - async def _generate_response_async(*_args, **_kwargs): - return "删除编号1", None - - def _delete(_emoji): - return True - - def _register(_emoji): - return False - - monkeypatch.setattr( - emoji_manager_new.emoji_manager_emotion_judge_llm, - "generate_response_async", - _generate_response_async, - ) - monkeypatch.setattr(manager, "delete_emoji", _delete) - monkeypatch.setattr(manager, "register_emoji_to_db", _register) - - result = await manager.replace_an_emoji_by_llm(emoji_manager_new.MaiEmoji()) - - assert result is False - assert any("注册新表情包失败" in m for m in _messages(logger.error_calls)) - - -@pytest.mark.asyncio -async def test_replace_an_emoji_by_llm_success(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - old_emoji = emoji_manager_new.MaiEmoji() - old_emoji.description = "old" - manager.emojis = [old_emoji] - - async def _generate_response_async(*_args, **_kwargs): - return "删除编号1", None - - def _delete(_emoji): - return True - - def _register(_emoji): - return True - - monkeypatch.setattr( - emoji_manager_new.emoji_manager_emotion_judge_llm, - "generate_response_async", - _generate_response_async, - ) - monkeypatch.setattr(manager, "delete_emoji", _delete) - monkeypatch.setattr(manager, "register_emoji_to_db", _register) - - new_emoji = emoji_manager_new.MaiEmoji() - new_emoji.description = "new" - - result = await manager.replace_an_emoji_by_llm(new_emoji) - - assert result is True - assert new_emoji in manager.emojis - assert old_emoji not in manager.emojis - assert any("成功替换并注册新表情包" in m for m in _messages(logger.info_calls)) - - -def test_load_emojis_from_db_empty(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - class _Result: - def scalars(self): - return self - - def all(self): - return [] - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - return _Result() - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - manager = emoji_manager_new.EmojiManager() - - manager.load_emojis_from_db() - - assert manager.emojis == [] - assert manager._emoji_num == 0 - assert any("成功加载" in m for m in _messages(logger.info_calls)) - - -def test_emoji_manager_registers_reload_callback(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - - assert emoji_manager_new.emoji_manager.reload_runtime_config in emoji_manager_new.config_manager.reload_callbacks - - -def test_emoji_manager_shutdown_unregisters_reload_callback(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - manager = emoji_manager_new.EmojiManager() - - assert manager.reload_runtime_config in emoji_manager_new.config_manager.reload_callbacks - - manager.shutdown() - - assert manager.reload_runtime_config not in emoji_manager_new.config_manager.reload_callbacks - - # 重复调用应保持幂等,不应抛错也不应重复注册 - manager.shutdown() - - assert manager.reload_runtime_config not in emoji_manager_new.config_manager.reload_callbacks - - -@pytest.mark.asyncio -async def test_reload_runtime_config_wakes_maintenance_loop(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - manager = emoji_manager_new.EmojiManager() - - emoji_manager_new.global_config.emoji.steal_emoji = False - emoji_manager_new.global_config.emoji.check_interval = 60 - - maintenance_runs = 0 - second_run_event = asyncio.Event() - - def _check_emoji_file_integrity(): - nonlocal maintenance_runs - maintenance_runs += 1 - if maintenance_runs >= 2: - second_run_event.set() - - monkeypatch.setattr(manager, "check_emoji_file_integrity", _check_emoji_file_integrity) - monkeypatch.setattr(manager, "remove_untracked_emoji_files", lambda: None) - - task = asyncio.create_task(manager.periodic_emoji_maintenance()) - try: - await asyncio.sleep(0.05) - assert maintenance_runs >= 1 - - manager.reload_runtime_config() - - await asyncio.wait_for(second_run_event.wait(), timeout=0.2) - finally: - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task - - -def test_load_emojis_from_db_partial_bad_records(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - class _Record: - def __init__(self, record_id, full_path): - self.id = record_id - self.full_path = full_path - self.image_type = emoji_manager_new.ImageType.EMOJI - self.no_file_flag = False - self.is_banned = False - - records = [_Record(1, "bad"), _Record(2, "ok")] - - class _Result: - def scalars(self): - return self - - def all(self): - return records - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - return _Result() - - def _get_db_session(): - return _Session() - - def _from_db_instance(record): - if record.id == 1: - raise ValueError("bad record") - emoji = emoji_manager_new.MaiEmoji() - emoji.file_name = "ok" - return emoji - - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - monkeypatch.setattr(emoji_manager_new.MaiEmoji, "from_db_instance", staticmethod(_from_db_instance)) - manager = emoji_manager_new.EmojiManager() - - manager.load_emojis_from_db() - - assert len(manager.emojis) == 1 - assert manager.emojis[0].file_name == "ok" - assert manager._emoji_num == 1 - assert any("加载表情包记录时出错" in m for m in _messages(logger.error_calls)) - assert any("成功加载" in m for m in _messages(logger.info_calls)) - - -def test_load_emojis_from_db_execute_error(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - raise RuntimeError("execute failed") - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - manager = emoji_manager_new.EmojiManager() - manager.emojis = [emoji_manager_new.MaiEmoji()] - manager._emoji_num = 1 - - with pytest.raises(RuntimeError): - manager.load_emojis_from_db() - - assert manager.emojis == [] - assert manager._emoji_num == 0 - assert any("不可恢复错误" in m for m in _messages(logger.critical_calls)) - - -def test_load_emojis_from_db_get_db_session_error(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - def _get_db_session(): - raise RuntimeError("get_db_session failed") - - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - manager = emoji_manager_new.EmojiManager() - manager.emojis = [emoji_manager_new.MaiEmoji()] - manager._emoji_num = 1 - - with pytest.raises(RuntimeError): - manager.load_emojis_from_db() - - assert manager.emojis == [] - assert manager._emoji_num == 0 - assert any("不可恢复错误" in m for m in _messages(logger.critical_calls)) - - -def test_load_emojis_from_db_scalars_all_error(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - class _Result: - def scalars(self): - return self - - def all(self): - raise RuntimeError("all failed") - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - return _Result() - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - manager = emoji_manager_new.EmojiManager() - manager.emojis = [emoji_manager_new.MaiEmoji()] - manager._emoji_num = 1 - - with pytest.raises(RuntimeError): - manager.load_emojis_from_db() - - assert manager.emojis == [] - assert manager._emoji_num == 0 - assert any("不可恢复错误" in m for m in _messages(logger.critical_calls)) - - -def test_load_emojis_from_db_skips_filtered_records(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - class _Record: - def __init__(self, record_id, full_path, image_type, no_file_flag=False, is_banned=False): - self.id = record_id - self.full_path = full_path - self.image_type = image_type - self.no_file_flag = no_file_flag - self.is_banned = is_banned - - records = [ - _Record(1, "img.png", "IMAGE"), - _Record(2, "nofile.png", emoji_manager_new.ImageType.EMOJI, no_file_flag=True), - _Record(3, "banned.png", emoji_manager_new.ImageType.EMOJI, is_banned=True), - _Record(4, "ok.png", emoji_manager_new.ImageType.EMOJI), - ] - - class _Result: - def all(self): - return records - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - return _Result() - - def _get_db_session(): - return _Session() - - created = [] - - def _from_db_instance(record): - emoji = emoji_manager_new.MaiEmoji() - emoji.file_name = record.full_path - created.append(record.id) - return emoji - - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - monkeypatch.setattr(emoji_manager_new.MaiEmoji, "from_db_instance", staticmethod(_from_db_instance)) - manager = emoji_manager_new.EmojiManager() - - manager.load_emojis_from_db() - - assert created == [4] - assert len(manager.emojis) == 1 - assert manager._emoji_num == 1 - assert any("成功加载" in m for m in _messages(logger.info_calls)) - - -def test_register_emoji_to_db_invalid_object(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - result = manager.register_emoji_to_db(None) - - assert result is False - assert any("无效的表情包对象" in m for m in _messages(logger.error_calls)) - - -def test_register_emoji_to_db_wrong_type(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - result = manager.register_emoji_to_db(object()) - - assert result is False - assert any("无效的表情包对象" in m for m in _messages(logger.error_calls)) - - -def test_register_emoji_to_db_file_missing(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - emoji = emoji_manager_new.MaiEmoji() - emoji.full_path = Path("/missing/file.png") - - result = manager.register_emoji_to_db(emoji) - - assert result is False - assert any("表情包文件不存在" in m for m in _messages(logger.error_calls)) - - -def test_register_emoji_to_db_move_error(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _DummyPath: - def __init__(self): - self._name = "a.png" - self._exists = True - - def exists(self): - return self._exists - - def replace(self, _target): - raise RuntimeError("move failed") - - @property - def name(self): - return self._name - - emoji = emoji_manager_new.MaiEmoji() - emoji.full_path = _DummyPath() - emoji.file_name = "a.png" - - result = manager.register_emoji_to_db(emoji) - - assert result is False - assert any("移动表情包文件时出错" in m for m in _messages(logger.error_calls)) - - -def test_register_emoji_to_db_db_error(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _DummyPath: - def __init__(self): - self._name = "a.png" - self._exists = True - self._replaced = False - - def exists(self): - return self._exists - - def replace(self, _target): - self._replaced = True - - @property - def name(self): - return self._name - - emoji = emoji_manager_new.MaiEmoji() - emoji.full_path = _DummyPath() - emoji.file_name = "a.png" - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def add(self, _record): - raise RuntimeError("db add failed") - - def flush(self): - pass - - def exec(self, _statement): - return self - - def first(self): - return None - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - - result = manager.register_emoji_to_db(emoji) - - assert result is False - assert any("注册到数据库时出错" in m for m in _messages(logger.error_calls)) - - -def test_register_emoji_to_db_success(monkeypatch, tmp_path): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _DummyPath: - def __init__(self, name): - self._name = name - self._exists = True - self._replaced = False - self._target = None - - def exists(self): - return self._exists - - def replace(self, target): - self._replaced = True - self._target = target - - @property - def name(self): - return self._name - - emoji = emoji_manager_new.MaiEmoji() - emoji.full_path = _DummyPath("a.png") - emoji.file_name = "a.png" - - class _Record: - def __init__(self): - self.id = 123 - self.is_registered = False - self.is_banned = False - self.register_time = None - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def add(self, _record): - pass - - def flush(self): - pass - - def exec(self, _statement): - return self - - def first(self): - return None - - def _get_db_session(): - return _Session() - - def _to_db_instance(self): - return _Record() - - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - monkeypatch.setattr(emoji_manager_new.MaiEmoji, "to_db_instance", _to_db_instance, raising=False) - - result = manager.register_emoji_to_db(emoji) - - assert result is True - assert any("成功注册表情包到数据库" in m for m in _messages(logger.info_calls)) - - -def test_delete_emoji_file_missing_and_db_record_missing(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _DummyPath: - def __init__(self): - self._name = "missing.png" - - def unlink(self): - raise FileNotFoundError("missing") - - def exists(self): - return False - - @property - def name(self): - return self._name - - class _Select: - def filter_by(self, **_kwargs): - return self - - def limit(self, _num): - return self - - def _select(_model): - return _Select() - - class _Result: - def scalars(self): - return self - - def first(self): - return None - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - return _Result() - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "select", _select) - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - - emoji = emoji_manager_new.MaiEmoji() - emoji.full_path = _DummyPath() - emoji.file_name = "missing.png" - emoji.file_hash = "hash-missing" - - result = manager.delete_emoji(emoji) - - assert result is True - assert any("不存在" in m for m in _messages(logger.warning_calls)) - assert any("未找到表情包记录" in m for m in _messages(logger.warning_calls)) - - -def test_delete_emoji_file_delete_error(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _DummyPath: - def __init__(self): - self._name = "boom.png" - - def unlink(self): - raise RuntimeError("unlink failed") - - @property - def name(self): - return self._name - - emoji = emoji_manager_new.MaiEmoji() - emoji.full_path = _DummyPath() - emoji.file_name = "boom.png" - emoji.file_hash = "hash-boom" - - result = manager.delete_emoji(emoji) - - assert result is False - assert any("删除表情包文件时出错" in m for m in _messages(logger.error_calls)) - - -def test_delete_emoji_db_error_file_still_exists(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _DummyPath: - def __init__(self): - self._name = "keep.png" - - def unlink(self): - return None - - def exists(self): - return True - - @property - def name(self): - return self._name - - class _Select: - def filter_by(self, **_kwargs): - return self - - def limit(self, _num): - return self - - def _select(_model): - return _Select() - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - raise RuntimeError("db delete failed") - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "select", _select) - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - - emoji = emoji_manager_new.MaiEmoji() - emoji.full_path = _DummyPath() - emoji.file_name = "keep.png" - emoji.file_hash = "hash-keep" - - result = manager.delete_emoji(emoji) - - assert result is False - assert any("删除数据库记录时出错" in m for m in _messages(logger.error_calls)) - assert any("数据库记录修改失败,但文件仍存在" in m for m in _messages(logger.warning_calls)) - - -def test_delete_emoji_success(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _DummyPath: - def __init__(self): - self._name = "ok.png" - self._deleted = False - - def unlink(self): - self._deleted = True - - def exists(self): - return not self._deleted - - @property - def name(self): - return self._name - - class _Select: - def filter_by(self, **_kwargs): - return self - - def limit(self, _num): - return self - - def _select(_model): - return _Select() - - class _Record: - def __init__(self): - self.no_file_flag = False - - class _Result: - def scalars(self): - return self - - def first(self): - return _Record() - - class _Session: - def __init__(self): - self.added = False - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - return _Result() - - def add(self, _record): - self.added = True - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "select", _select) - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - - emoji = emoji_manager_new.MaiEmoji() - emoji.full_path = _DummyPath() - emoji.file_name = "ok.png" - emoji.file_hash = "hash-ok" - - result = manager.delete_emoji(emoji) - - assert result is True - assert any("成功删除表情包文件" in m for m in _messages(logger.info_calls)) - assert any("成功修改数据库中的表情包记录" in m for m in _messages(logger.info_calls)) - - -def test_delete_emoji_no_desc_deletes_record(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _DummyPath: - def __init__(self): - self._name = "empty.png" - - def unlink(self): - return None - - def exists(self): - return False - - @property - def name(self): - return self._name - - class _Select: - def filter_by(self, **_kwargs): - return self - - def limit(self, _num): - return self - - def _select(_model): - return _Select() - - class _Result: - def scalars(self): - return self - - def first(self): - return object() - - class _Session: - def __init__(self): - self.deleted = False - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - return _Result() - - def delete(self, _record): - self.deleted = True - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "select", _select) - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - - emoji = emoji_manager_new.MaiEmoji() - emoji.full_path = _DummyPath() - emoji.file_name = "empty.png" - emoji.file_hash = "hash-empty" - - result = manager.delete_emoji(emoji, no_desc=True) - - assert result is True - assert any("成功删除数据库中的空表情包记录" in m for m in _messages(logger.info_calls)) - - -def test_update_emoji_usage_success(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _Select: - def filter_by(self, **_kwargs): - return self - - def limit(self, _num): - return self - - def _select(_model): - return _Select() - - class _Record: - def __init__(self): - self.query_count = 2 - self.last_used_time = None - - record = _Record() - - class _Result: - def scalars(self): - return self - - def first(self): - return record - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - return _Result() - - def add(self, _record): - self.added = True - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "select", _select) - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - - emoji = emoji_manager_new.MaiEmoji() - emoji.file_hash = "hash-ok" - - result = manager.update_emoji_usage(emoji) - - assert result is True - assert emoji.query_count == 1 - assert record.query_count == 1 - assert any("成功记录表情包使用" in m for m in _messages(logger.info_calls)) - - -def test_update_emoji_usage_missing_record(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _Select: - def filter_by(self, **_kwargs): - return self - - def limit(self, _num): - return self - - def _select(_model): - return _Select() - - class _Result: - def scalars(self): - return self - - def first(self): - return None - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - return _Result() - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "select", _select) - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - - emoji = emoji_manager_new.MaiEmoji() - emoji.file_hash = "hash-missing" - - result = manager.update_emoji_usage(emoji) - - assert result is False - assert any("未找到表情包记录" in m for m in _messages(logger.error_calls)) - - -def test_update_emoji_usage_execute_error(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _Select: - def filter_by(self, **_kwargs): - return self - - def limit(self, _num): - return self - - def _select(_model): - return _Select() - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - raise RuntimeError("execute failed") - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "select", _select) - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - - emoji = emoji_manager_new.MaiEmoji() - emoji.file_hash = "hash-execute" - - result = manager.update_emoji_usage(emoji) - - assert result is False - assert any("记录使用时出错" in m for m in _messages(logger.error_calls)) - - -def test_update_emoji_usage_get_db_session_error(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - def _get_db_session(): - raise RuntimeError("get_db_session failed") - - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - - emoji = emoji_manager_new.MaiEmoji() - emoji.file_hash = "hash-session" - - result = manager.update_emoji_usage(emoji) - - assert result is False - assert any("记录使用时出错" in m for m in _messages(logger.error_calls)) - - -def test_update_emoji_success(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _Select: - def filter_by(self, **_kwargs): - return self - - def limit(self, _num): - return self - - def _select(_model): - return _Select() - - class _Record: - def __init__(self): - self.description = None - self.emotion = None - - class _Result: - def scalars(self): - return self - - def first(self): - return _Record() - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - return _Result() - - def add(self, _record): - self.added = True - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "select", _select) - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - - emoji = emoji_manager_new.MaiEmoji() - emoji.file_hash = "hash-update" - emoji.description = "new-desc" - emoji.emotion = ["a", "b"] - - result = manager.update_emoji(emoji) - - assert result is True - assert any("成功更新表情包信息" in m for m in _messages(logger.info_calls)) - - -def test_update_emoji_missing_record(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _Select: - def filter_by(self, **_kwargs): - return self - - def limit(self, _num): - return self - - def _select(_model): - return _Select() - - class _Result: - def scalars(self): - return self - - def first(self): - return None - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - return _Result() - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "select", _select) - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - - emoji = emoji_manager_new.MaiEmoji() - emoji.file_hash = "hash-missing" - - result = manager.update_emoji(emoji) - - assert result is False - assert any("未找到表情包记录" in m for m in _messages(logger.error_calls)) - - -def test_update_emoji_execute_error(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _Select: - def filter_by(self, **_kwargs): - return self - - def limit(self, _num): - return self - - def _select(_model): - return _Select() - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - raise RuntimeError("execute failed") - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "select", _select) - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - - emoji = emoji_manager_new.MaiEmoji() - emoji.file_hash = "hash-execute" - - result = manager.update_emoji(emoji) - - assert result is False - assert any("更新数据库记录时出错" in m for m in _messages(logger.error_calls)) - - -def test_update_emoji_get_db_session_error(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - def _get_db_session(): - raise RuntimeError("get_db_session failed") - - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - - emoji = emoji_manager_new.MaiEmoji() - emoji.file_hash = "hash-session" - - result = manager.update_emoji(emoji) - - assert result is False - assert any("更新数据库记录时出错" in m for m in _messages(logger.error_calls)) - - -def test_get_emoji_by_hash_from_db_no_file_flag(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _Select: - def filter_by(self, **_kwargs): - return self - - def limit(self, _num): - return self - - def _select(_model): - return _Select() - - class _Record: - def __init__(self): - self.no_file_flag = True - self.image_hash = "hash-nofile" - - class _Result: - def scalars(self): - return self - - def first(self): - return _Record() - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - return _Result() - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "select", _select) - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - - result = manager.get_emoji_by_hash_from_db("hash-nofile") - - assert result is None - assert any("标记为文件不存在" in m for m in _messages(logger.warning_calls)) - - -def test_get_emoji_by_hash_from_db_success(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - manager = emoji_manager_new.EmojiManager() - - class _Select: - def filter_by(self, **_kwargs): - return self - - def limit(self, _num): - return self - - def _select(_model): - return _Select() - - class _Record: - def __init__(self): - self.no_file_flag = False - self.image_hash = "hash-ok" - - class _Result: - def scalars(self): - return self - - def first(self): - return _Record() - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - return _Result() - - def _get_db_session(): - return _Session() - - emoji = emoji_manager_new.MaiEmoji() - emoji.file_hash = "hash-ok" - - monkeypatch.setattr(emoji_manager_new, "select", _select) - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - monkeypatch.setattr(emoji_manager_new.MaiEmoji, "from_db_instance", staticmethod(lambda _r: emoji)) - - result = manager.get_emoji_by_hash_from_db("hash-ok") - - assert result is emoji - - -def test_ban_emoji_success(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _Select: - def filter_by(self, **_kwargs): - return self - - def limit(self, _num): - return self - - def _select(_model): - return _Select() - - class _Record: - def __init__(self): - self.is_banned = False - - class _Result: - def scalars(self): - return self - - def first(self): - return _Record() - - class _Session: - def __init__(self): - self.added = False - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - return _Result() - - def add(self, _record): - self.added = True - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "select", _select) - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - - emoji = emoji_manager_new.MaiEmoji() - emoji.file_name = "ban.png" - emoji.file_hash = "hash-ban" - manager.emojis = [emoji] - - result = manager.ban_emoji(emoji) - - assert result is True - assert emoji not in manager.emojis - assert any("成功封禁表情包" in m for m in _messages(logger.info_calls)) - - -def test_ban_emoji_missing_record(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _Select: - def filter_by(self, **_kwargs): - return self - - def limit(self, _num): - return self - - def _select(_model): - return _Select() - - class _Result: - def scalars(self): - return self - - def first(self): - return None - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - return _Result() - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "select", _select) - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - - emoji = emoji_manager_new.MaiEmoji() - emoji.file_name = "missing.png" - emoji.file_hash = "hash-missing" - - result = manager.ban_emoji(emoji) - - assert result is False - assert any("未找到表情包记录" in m for m in _messages(logger.warning_calls)) - - -def test_ban_emoji_db_error(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - class _Select: - def filter_by(self, **_kwargs): - return self - - def limit(self, _num): - return self - - def _select(_model): - return _Select() - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - raise RuntimeError("db failed") - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "select", _select) - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - - emoji = emoji_manager_new.MaiEmoji() - emoji.file_name = "boom.png" - emoji.file_hash = "hash-boom" - - result = manager.ban_emoji(emoji) - - assert result is False - assert any("封禁时出错" in m for m in _messages(logger.error_calls)) - - -@pytest.mark.asyncio -async def test_get_emoji_for_emotion_empty_list(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - manager.emojis = [] - - result = await manager.get_emoji_for_emotion("开心") - - assert result is None - assert any("表情包列表为空" in m for m in _messages(logger.warning_calls)) - - -@pytest.mark.asyncio -async def test_get_emoji_for_emotion_no_matches(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - manager.emojis = [emoji_manager_new.MaiEmoji()] - - def _calc(_label): - return [] - - monkeypatch.setattr(manager, "_calculate_emotion_similarity_list", _calc) - - result = await manager.get_emoji_for_emotion("无匹配") - - assert result is None - assert any("未找到匹配的表情包" in m for m in _messages(logger.info_calls)) - - -@pytest.mark.asyncio -async def test_get_emoji_for_emotion_success_updates_usage(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - emoji1 = emoji_manager_new.MaiEmoji() - emoji1.file_name = "e1.png" - emoji1.emotion = ["开心"] - emoji2 = emoji_manager_new.MaiEmoji() - emoji2.file_name = "e2.png" - emoji2.emotion = ["难过"] - manager.emojis = [emoji1, emoji2] - - def _calc(_label): - return [(emoji1, 0.9), (emoji2, 0.2)] - - monkeypatch.setattr(manager, "_calculate_emotion_similarity_list", _calc) - monkeypatch.setattr(emoji_manager_new.random, "choice", lambda items: items[0]) - - called = {"emoji": None} - - def _update(emoji): - called["emoji"] = emoji - return True - - monkeypatch.setattr(manager, "update_emoji_usage", _update) - - result = await manager.get_emoji_for_emotion("开心") - - assert result is emoji1 - assert called["emoji"] is emoji1 - assert any("选中表情包" in m for m in _messages(logger.info_calls)) - - -@pytest.mark.asyncio -async def test_get_emoji_for_emotion_similarity_error_propagates(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - manager = emoji_manager_new.EmojiManager() - manager.emojis = [emoji_manager_new.MaiEmoji()] - - def _calc(_label): - raise RuntimeError("calc failed") - - monkeypatch.setattr(manager, "_calculate_emotion_similarity_list", _calc) - - with pytest.raises(RuntimeError): - await manager.get_emoji_for_emotion("异常") - - -@pytest.mark.asyncio -async def test_build_emoji_description_calls_hash_and_sets_description(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - called = {"hash": False, "vlm": False} - - async def _calc(self): - called["hash"] = True - return True - - def _read_bytes(_path): - return b"" - - async def _vlm_response(*_args, **_kwargs): - called["vlm"] = True - return "desc", None - - monkeypatch.setattr(emoji_manager_new.MaiEmoji, "calculate_hash_format", _calc, raising=False) - monkeypatch.setattr(emoji_manager_new.MaiEmoji, "read_image_bytes", staticmethod(_read_bytes), raising=False) - monkeypatch.setattr( - emoji_manager_new.emoji_manager_vlm, - "generate_response_for_image", - _vlm_response, - ) - - emoji = emoji_manager_new.MaiEmoji() - emoji.file_hash = None - emoji.image_format = "png" - emoji.full_path = Path("/tmp/a.png") - - result, updated = await emoji_manager_new.EmojiManager().build_emoji_description(emoji) - - assert result is True - assert updated.description == "desc" - assert called["hash"] is True - assert called["vlm"] is True - assert any("成功为表情包构建描述" in m for m in _messages(logger.info_calls)) - - -@pytest.mark.asyncio -async def test_build_emoji_description_gif_conversion_error(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - def _read_bytes(_path): - return b"" - - def _gif_to_static(_bytes): - raise RuntimeError("gif fail") - - monkeypatch.setattr(emoji_manager_new.MaiEmoji, "read_image_bytes", staticmethod(_read_bytes), raising=False) - monkeypatch.setattr(emoji_manager_new.ImageUtils, "gif_2_static_image", staticmethod(_gif_to_static)) - - emoji = emoji_manager_new.MaiEmoji() - emoji.file_hash = "hash" - emoji.image_format = "gif" - emoji.full_path = Path("/tmp/a.gif") - - result, updated = await emoji_manager_new.EmojiManager().build_emoji_description(emoji) - - assert result is False - assert updated.description is None - assert any("转换 GIF 图片时出错" in m for m in _messages(logger.error_calls)) - - -@pytest.mark.asyncio -async def test_build_emoji_description_content_filtration_reject(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - emoji_manager_new.global_config.emoji.content_filtration = True - - def _read_bytes(_path): - return b"" - - call_count = {"n": 0} - - async def _vlm_response(*_args, **_kwargs): - call_count["n"] += 1 - if call_count["n"] == 2: - return "否", None - return "desc", None - - monkeypatch.setattr(emoji_manager_new.MaiEmoji, "read_image_bytes", staticmethod(_read_bytes), raising=False) - monkeypatch.setattr( - emoji_manager_new.emoji_manager_vlm, - "generate_response_for_image", - _vlm_response, - ) - - emoji = emoji_manager_new.MaiEmoji() - emoji.file_hash = "hash" - emoji.image_format = "png" - emoji.full_path = Path("/tmp/a.png") - - result, updated = await emoji_manager_new.EmojiManager().build_emoji_description(emoji) - - assert result is False - assert updated.description is None - assert any("表情包内容不符合要求" in m for m in _messages(logger.warning_calls)) - - -@pytest.mark.asyncio -async def test_build_emoji_description_content_filtration_pass(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - emoji_manager_new.global_config.emoji.content_filtration = True - - def _read_bytes(_path): - return b"" - - call_count = {"n": 0} - - async def _vlm_response(*_args, **_kwargs): - call_count["n"] += 1 - if call_count["n"] == 2: - return "是", None - return "desc", None - - monkeypatch.setattr(emoji_manager_new.MaiEmoji, "read_image_bytes", staticmethod(_read_bytes), raising=False) - monkeypatch.setattr( - emoji_manager_new.emoji_manager_vlm, - "generate_response_for_image", - _vlm_response, - ) - - emoji = emoji_manager_new.MaiEmoji() - emoji.file_hash = "hash" - emoji.image_format = "png" - emoji.full_path = Path("/tmp/a.png") - - result, updated = await emoji_manager_new.EmojiManager().build_emoji_description(emoji) - - assert result is True - assert updated.description == "desc" - assert any("成功为表情包构建描述" in m for m in _messages(logger.info_calls)) - - -@pytest.mark.asyncio -async def test_build_emoji_description_vlm_exception_propagates(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - - def _read_bytes(_path): - return b"" - - async def _vlm_response(*_args, **_kwargs): - raise RuntimeError("vlm failed") - - monkeypatch.setattr(emoji_manager_new.MaiEmoji, "read_image_bytes", staticmethod(_read_bytes), raising=False) - monkeypatch.setattr( - emoji_manager_new.emoji_manager_vlm, - "generate_response_for_image", - _vlm_response, - ) - - emoji = emoji_manager_new.MaiEmoji() - emoji.file_hash = "hash" - emoji.image_format = "png" - emoji.full_path = Path("/tmp/a.png") - - with pytest.raises(RuntimeError): - await emoji_manager_new.EmojiManager().build_emoji_description(emoji) - - -@pytest.mark.asyncio -async def test_build_emoji_emotion_description_missing(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - emoji = emoji_manager_new.MaiEmoji() - emoji.description = None - - result, updated = await emoji_manager_new.EmojiManager().build_emoji_emotion(emoji) - - assert result is False - assert updated.emotion is None - assert any("表情包描述为空" in m for m in _messages(logger.error_calls)) - - -@pytest.mark.asyncio -async def test_build_emoji_emotion_llm_exception_propagates(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - - async def _generate_response_async(*_args, **_kwargs): - raise RuntimeError("llm failed") - - monkeypatch.setattr( - emoji_manager_new.emoji_manager_emotion_judge_llm, - "generate_response_async", - _generate_response_async, - ) - - emoji = emoji_manager_new.MaiEmoji() - emoji.description = "desc" - - with pytest.raises(RuntimeError): - await emoji_manager_new.EmojiManager().build_emoji_emotion(emoji) - - -@pytest.mark.asyncio -async def test_build_emoji_emotion_empty_result(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - async def _generate_response_async(*_args, **_kwargs): - return " , , ", None - - monkeypatch.setattr( - emoji_manager_new.emoji_manager_emotion_judge_llm, - "generate_response_async", - _generate_response_async, - ) - - emoji = emoji_manager_new.MaiEmoji() - emoji.description = "desc" - - result, updated = await emoji_manager_new.EmojiManager().build_emoji_emotion(emoji) - - assert result is True - assert updated.emotion == [] - assert any("成功为表情包构建情感标签" in m for m in _messages(logger.info_calls)) - - -@pytest.mark.asyncio -async def test_build_emoji_emotion_more_than_five_random_sample(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - async def _generate_response_async(*_args, **_kwargs): - return "a,b,c,d,e,f", None - - monkeypatch.setattr( - emoji_manager_new.emoji_manager_emotion_judge_llm, - "generate_response_async", - _generate_response_async, - ) - monkeypatch.setattr(emoji_manager_new.random, "sample", lambda items, _k: items[:3]) - - emoji = emoji_manager_new.MaiEmoji() - emoji.description = "desc" - - result, updated = await emoji_manager_new.EmojiManager().build_emoji_emotion(emoji) - - assert result is True - assert updated.emotion == ["a", "b", "c"] - assert any("成功为表情包构建情感标签" in m for m in _messages(logger.info_calls)) - - -@pytest.mark.asyncio -async def test_build_emoji_emotion_three_items_random_sample(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - async def _generate_response_async(*_args, **_kwargs): - return "a,b,c", None - - monkeypatch.setattr( - emoji_manager_new.emoji_manager_emotion_judge_llm, - "generate_response_async", - _generate_response_async, - ) - monkeypatch.setattr(emoji_manager_new.random, "sample", lambda items, _k: items[:2]) - - emoji = emoji_manager_new.MaiEmoji() - emoji.description = "desc" - - result, updated = await emoji_manager_new.EmojiManager().build_emoji_emotion(emoji) - - assert result is True - assert updated.emotion == ["a", "b"] - assert any("成功为表情包构建情感标签" in m for m in _messages(logger.info_calls)) - - -def test_check_emoji_file_integrity_no_issues(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - class _DummyPath: - def __init__(self, name): - self._name = name - self._exists = True - - def exists(self): - return self._exists - - @property - def name(self): - return self._name - - manager = emoji_manager_new.EmojiManager() - emoji = emoji_manager_new.MaiEmoji() - emoji.file_name = "ok.png" - emoji.full_path = _DummyPath("ok.png") - emoji.description = "desc" - manager.emojis = [emoji] - manager._emoji_num = 1 - - called = {"count": 0} - - def _delete(_emoji, no_desc=False): - called["count"] += 1 - return True - - monkeypatch.setattr(manager, "delete_emoji", _delete) - - manager.check_emoji_file_integrity() - - assert manager.emojis == [emoji] - assert manager._emoji_num == 1 - assert called["count"] == 0 - assert logger.warning_calls == [] - assert any("完整性检查完成" in m for m in _messages(logger.info_calls)) - - -def test_check_emoji_file_integrity_removes_invalid_records(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - class _DummyPath: - def __init__(self, name, exists=True): - self._name = name - self._exists = exists - - def exists(self): - return self._exists - - @property - def name(self): - return self._name - - manager = emoji_manager_new.EmojiManager() - missing_file = emoji_manager_new.MaiEmoji() - missing_file.file_name = "missing.png" - missing_file.full_path = _DummyPath("missing.png", exists=False) - missing_file.description = "desc" - - missing_desc = emoji_manager_new.MaiEmoji() - missing_desc.file_name = "nodesc.png" - missing_desc.full_path = _DummyPath("nodesc.png", exists=True) - missing_desc.description = None - - manager.emojis = [missing_file, missing_desc] - manager._emoji_num = 2 - - deleted = [] - - def _delete(emoji, no_desc=False): - deleted.append((emoji.file_name, no_desc)) - return True - - monkeypatch.setattr(manager, "delete_emoji", _delete) - - manager.check_emoji_file_integrity() - - assert manager.emojis == [] - assert manager._emoji_num == 0 - assert set(deleted) == {("missing.png", False), ("nodesc.png", True)} - messages = _messages(logger.warning_calls) - assert any("文件缺失" in m for m in messages) - assert any("缺失描述" in m for m in messages) - assert any("成功删除缺失文件的表情包记录" in m for m in _messages(logger.info_calls)) - assert any("删除了 2 条记录" in m for m in _messages(logger.info_calls)) - - -def test_check_emoji_file_integrity_delete_failed(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - class _DummyPath: - def __init__(self, name): - self._name = name - self._exists = False - - def exists(self): - return self._exists - - @property - def name(self): - return self._name - - manager = emoji_manager_new.EmojiManager() - emoji = emoji_manager_new.MaiEmoji() - emoji.file_name = "bad.png" - emoji.full_path = _DummyPath("bad.png") - emoji.description = "desc" - manager.emojis = [emoji] - manager._emoji_num = 1 - - def _delete(_emoji, no_desc=False): - return False - - monkeypatch.setattr(manager, "delete_emoji", _delete) - - manager.check_emoji_file_integrity() - - assert manager.emojis == [emoji] - assert manager._emoji_num == 1 - assert any("表情包文件缺失" in m for m in _messages(logger.warning_calls)) - assert any("删除缺失文件的表情包记录失败" in m for m in _messages(logger.error_calls)) - - -@pytest.mark.asyncio -async def test_build_emoji_emotion_two_items_no_sample(monkeypatch): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - - async def _generate_response_async(*_args, **_kwargs): - return "a, b", None - - monkeypatch.setattr( - emoji_manager_new.emoji_manager_emotion_judge_llm, - "generate_response_async", - _generate_response_async, - ) - - emoji = emoji_manager_new.MaiEmoji() - emoji.description = "desc" - - result, updated = await emoji_manager_new.EmojiManager().build_emoji_emotion(emoji) - - assert result is True - assert updated.emotion == ["a", "b"] - assert any("成功为表情包构建情感标签" in m for m in _messages(logger.info_calls)) - - -@pytest.mark.asyncio -async def test_register_emoji_by_filename_file_missing(monkeypatch, tmp_path): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - missing_file = tmp_path / "missing.png" - - result = await manager.register_emoji_by_filename(missing_file) - - assert result is False - assert any("表情包文件不存在" in m for m in _messages(logger.error_calls)) - - -@pytest.mark.asyncio -async def test_register_emoji_by_filename_create_object_error(monkeypatch, tmp_path): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - file_path = tmp_path / "ok.png" - file_path.write_bytes(b"") - - class _BadEmoji: - def __init__(self, *args, **kwargs): - raise RuntimeError("create failed") - - monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _BadEmoji) - - result = await manager.register_emoji_by_filename(file_path) - - assert result is False - assert any("创建表情包对象时出错" in m for m in _messages(logger.error_calls)) - - -@pytest.mark.asyncio -async def test_register_emoji_by_filename_hash_format_failed(monkeypatch, tmp_path): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - file_path = tmp_path / "hash.png" - file_path.write_bytes(b"") - - class _Emoji(emoji_manager_new.MaiEmoji): - async def calculate_hash_format(self): - return False - - class _Session: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def exec(self, _statement): - return self - - def first(self): - return None - - class _Select: - def __init__(self) -> None: - pass - - def filter_by(self, **_kwargs): - return self - - def limit(self, _num): - return self - - def _get_db_session_manual(): - return _Session() - - def _get_db_session(): - return _Session() - - monkeypatch.setattr(emoji_manager_new, "get_db_session_manual", _get_db_session_manual) - monkeypatch.setattr(emoji_manager_new, "get_db_session", _get_db_session) - monkeypatch.setattr(emoji_manager_new, "select", lambda _model: _Select()) - monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) - - result = await manager.register_emoji_by_filename(file_path) - - assert result is False - assert any("计算表情包哈希值和格式失败" in m for m in _messages(logger.error_calls)) - - -@pytest.mark.asyncio -async def test_register_emoji_by_filename_duplicate_hash(monkeypatch, tmp_path): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - file_path = tmp_path / "dup.png" - file_path.write_bytes(b"") - - class _Emoji(emoji_manager_new.MaiEmoji): - async def calculate_hash_format(self): - self.file_hash = "hash-dup" - self.full_path = file_path - return True - - monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) - - existing = emoji_manager_new.MaiEmoji() - existing.file_name = "exist.png" - monkeypatch.setattr(manager, "get_emoji_by_hash", lambda _h: existing) - - result = await manager.register_emoji_by_filename(file_path) - - assert result is False - assert any("表情包已存在" in m for m in _messages(logger.warning_calls)) - - -@pytest.mark.asyncio -async def test_register_emoji_by_filename_build_description_failed(monkeypatch, tmp_path): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - file_path = tmp_path / "desc.png" - file_path.write_bytes(b"") - - class _Emoji(emoji_manager_new.MaiEmoji): - async def calculate_hash_format(self): - self.file_hash = "hash-desc" - self.full_path = file_path - return True - - monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) - - async def _build_desc(_e): - return False, _e - - monkeypatch.setattr(manager, "build_emoji_description", _build_desc) - - result = await manager.register_emoji_by_filename(file_path) - - assert result is False - assert any("构建表情包描述失败" in m for m in _messages(logger.error_calls)) - - -@pytest.mark.asyncio -async def test_register_emoji_by_filename_build_emotion_failed(monkeypatch, tmp_path): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - - file_path = tmp_path / "emo.png" - file_path.write_bytes(b"") - - class _Emoji(emoji_manager_new.MaiEmoji): - async def calculate_hash_format(self): - self.file_hash = "hash-emo" - self.full_path = file_path - return True - - monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) - - async def _build_desc(_e): - return True, _e - - async def _build_emo(_e): - return False, _e - - monkeypatch.setattr(manager, "build_emoji_description", _build_desc) - monkeypatch.setattr(manager, "build_emoji_emotion", _build_emo) - - result = await manager.register_emoji_by_filename(file_path) - - assert result is False - assert any("构建表情包情感标签失败" in m for m in _messages(logger.error_calls)) - - -@pytest.mark.asyncio -async def test_register_emoji_by_filename_capacity_replace_failed(monkeypatch, tmp_path): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - manager._emoji_num = 1 - emoji_manager_new.global_config.emoji.max_reg_num = 1 - emoji_manager_new.global_config.emoji.do_replace = True - - file_path = tmp_path / "full.png" - file_path.write_bytes(b"") - - class _Emoji(emoji_manager_new.MaiEmoji): - async def calculate_hash_format(self): - self.file_hash = "hash-full" - self.full_path = file_path - return True - - monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) - - async def _build_desc(_e): - return True, _e - - async def _build_emo(_e): - return True, _e - - async def _replace(_e): - return False - - monkeypatch.setattr(manager, "build_emoji_description", _build_desc) - monkeypatch.setattr(manager, "build_emoji_emotion", _build_emo) - monkeypatch.setattr(manager, "replace_an_emoji_by_llm", _replace) - - result = await manager.register_emoji_by_filename(file_path) - - assert result is False - assert any("数量已达上限" in m for m in _messages(logger.warning_calls)) - assert any("替换表情包失败" in m for m in _messages(logger.error_calls)) - - -@pytest.mark.asyncio -async def test_register_emoji_by_filename_capacity_replace_success(monkeypatch, tmp_path): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - manager._emoji_num = 1 - emoji_manager_new.global_config.emoji.max_reg_num = 1 - emoji_manager_new.global_config.emoji.do_replace = True - - file_path = tmp_path / "full-ok.png" - file_path.write_bytes(b"") - - class _Emoji(emoji_manager_new.MaiEmoji): - async def calculate_hash_format(self): - self.file_hash = "hash-full-ok" - self.full_path = file_path - return True - - monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) - - async def _build_desc(_e): - return True, _e - - async def _build_emo(_e): - return True, _e - - async def _replace(_e): - return True - - monkeypatch.setattr(manager, "build_emoji_description", _build_desc) - monkeypatch.setattr(manager, "build_emoji_emotion", _build_emo) - monkeypatch.setattr(manager, "replace_an_emoji_by_llm", _replace) - - result = await manager.register_emoji_by_filename(file_path) - - assert result is True - assert any("数量已达上限" in m for m in _messages(logger.warning_calls)) - - -@pytest.mark.asyncio -async def test_register_emoji_by_filename_register_db_failed(monkeypatch, tmp_path): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - manager._emoji_num = 0 - emoji_manager_new.global_config.emoji.max_reg_num = 10 - - file_path = tmp_path / "db-fail.png" - file_path.write_bytes(b"") - - class _Emoji(emoji_manager_new.MaiEmoji): - async def calculate_hash_format(self): - self.file_hash = "hash-db-fail" - self.full_path = file_path - return True - - monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) - - async def _build_desc(_e): - return True, _e - - async def _build_emo(_e): - return True, _e - - monkeypatch.setattr(manager, "build_emoji_description", _build_desc) - monkeypatch.setattr(manager, "build_emoji_emotion", _build_emo) - monkeypatch.setattr(manager, "register_emoji_to_db", lambda _e: False) - - result = await manager.register_emoji_by_filename(file_path) - - assert result is False - assert any("注册表情包到数据库失败" in m for m in _messages(logger.error_calls)) - - -@pytest.mark.asyncio -async def test_register_emoji_by_filename_register_db_success(monkeypatch, tmp_path): - emoji_manager_new = import_emoji_manager_new(monkeypatch) - logger = emoji_manager_new.logger - manager = emoji_manager_new.EmojiManager() - manager._emoji_num = 0 - emoji_manager_new.global_config.emoji.max_reg_num = 10 - - file_path = tmp_path / "db-ok.png" - file_path.write_bytes(b"") - - class _Emoji(emoji_manager_new.MaiEmoji): - async def calculate_hash_format(self): - self.file_hash = "hash-db-ok" - self.full_path = file_path - self.file_name = "db-ok.png" - return True - - monkeypatch.setattr(emoji_manager_new, "MaiEmoji", _Emoji) - - async def _build_desc(_e): - return True, _e - - async def _build_emo(_e): - return True, _e - - monkeypatch.setattr(manager, "build_emoji_description", _build_desc) - monkeypatch.setattr(manager, "build_emoji_emotion", _build_emo) - monkeypatch.setattr(manager, "register_emoji_to_db", lambda _e: True) - - result = await manager.register_emoji_by_filename(file_path) - - assert result is True - assert manager._emoji_num == 1 - assert len(manager.emojis) == 1 - assert any("成功注册新表情包" in m for m in _messages(logger.info_calls)) diff --git a/pytests/image_sys_test/image_manager_test.py b/pytests/image_sys_test/image_manager_test.py deleted file mode 100644 index 51fd0c5d..00000000 --- a/pytests/image_sys_test/image_manager_test.py +++ /dev/null @@ -1,295 +0,0 @@ -import sys -import types -import importlib -import pytest -from pathlib import Path -import importlib.util - - -class DummyLogger: - def info(self, *a, **k): - pass - - def warning(self, *a, **k): - pass - - def error(self, *a, **k): - pass - - -class DummySession: - def __init__(self): - self.record = None - - def exec(self, *a, **k): - record = self.record - - class R: - def first(self): - return record - - def yield_per(self, n): - if record is None: - return iter(()) - return iter((record,)) - - return R() - - def add(self, record, *a, **k): - self.record = record - - def flush(self, *a, **k): - pass - - def delete(self, *a, **k): - self.record = None - - def expunge(self, *a, **k): - pass - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - -class DummyMaiImage: - def __init__(self, full_path=None, image_bytes=None): - self.full_path = full_path - self.image_bytes = image_bytes - self.file_hash = "dummy-hash" - self.image_format = "png" - self.description = "" - self.vlm_processed = False - - @classmethod - def from_db_instance(cls, record): - image = cls(full_path=getattr(record, "full_path", None)) - image.file_hash = getattr(record, "image_hash", "dummy-hash") - image.description = getattr(record, "description", "") - image.vlm_processed = getattr(record, "vlm_processed", False) - return image - - def to_db_instance(self): - return types.SimpleNamespace( - description=self.description, - full_path=str(self.full_path) if self.full_path is not None else "", - id=1, - image_hash=self.file_hash, - image_type="image", - last_used_time=None, - no_file_flag=False, - query_count=0, - register_time=None, - vlm_processed=self.vlm_processed, - ) - - async def calculate_hash_format(self): - self.file_hash = "dummy-hash" - return None - - -class DummyLLMRequest: - def __init__(self, *a, **k): - pass - - async def generate_response_for_image(self, prompt, image_base64, image_format, temp): - return ("dummy description", {}) - - -class DummyLLMServiceClient: - def __init__(self, *a, **k): - pass - - async def generate_response_for_image(self, prompt, image_base64, image_format, options=None): - return types.SimpleNamespace(response="dummy description") - - -class DummySelect: - def __init__(self, *a, **k): - pass - - def filter_by(self, *a, **k): - return self - - def limit(self, n): - return self - - -class DetachedRecord: - def __init__(self, description="cached description", vlm_processed=True): - self._detached = False - self._description = description - self._vlm_processed = vlm_processed - - @property - def description(self): - if not self._detached: - raise RuntimeError("attribute refresh operation cannot proceed") - return self._description - - @property - def vlm_processed(self): - if not self._detached: - raise RuntimeError("attribute refresh operation cannot proceed") - return self._vlm_processed - - -class DetachedRecordSession(DummySession): - def __init__(self, record): - self.record = record - - def exec(self, *a, **k): - record = self.record - - class R: - def first(self): - return record - - return R() - - def expunge(self, record): - record._detached = True - - -@pytest.fixture(autouse=True) -def patch_external_dependencies(monkeypatch): - # Provide dummy implementations as modules so that importing image_manager is safe - # Patch LLMRequest - llm_mod = types.SimpleNamespace(LLMRequest=DummyLLMRequest) - monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_mod) - llm_service_mod = types.SimpleNamespace(LLMServiceClient=DummyLLMServiceClient) - monkeypatch.setitem(sys.modules, "src.services.llm_service", llm_service_mod) - - # Patch logger - logger_mod = types.SimpleNamespace(get_logger=lambda name: DummyLogger()) - monkeypatch.setitem(sys.modules, "src.common.logger", logger_mod) - - # Patch DB session provider - shared_session = DummySession() - db_mod = types.SimpleNamespace(get_db_session=lambda: shared_session) - monkeypatch.setitem(sys.modules, "src.common.database.database", db_mod) - - # Patch database model types - db_model_mod = types.SimpleNamespace(Images=types.SimpleNamespace, ImageType=types.SimpleNamespace(IMAGE="image")) - monkeypatch.setitem(sys.modules, "src.common.database.database_model", db_model_mod) - - # Patch MaiImage data model - data_model_mod = types.SimpleNamespace(MaiImage=DummyMaiImage) - monkeypatch.setitem(sys.modules, "src.common.data_models.image_data_model", data_model_mod) - - # Patch SQLModel select function - sql_mod = types.SimpleNamespace(select=lambda *a, **k: DummySelect()) - monkeypatch.setitem(sys.modules, "sqlmodel", sql_mod) - - # Patch prompt manager used to build image description prompt. - class _PromptManager: - def get_prompt(self, _name): - return types.SimpleNamespace() - - async def render_prompt(self, _prompt): - return "test-style" - - prompt_manager_mod = types.SimpleNamespace(prompt_manager=_PromptManager()) - monkeypatch.setitem(sys.modules, "src.prompt.prompt_manager", prompt_manager_mod) - - llm_options_mod = types.SimpleNamespace(LLMImageOptions=lambda **kwargs: types.SimpleNamespace(**kwargs)) - monkeypatch.setitem(sys.modules, "src.common.data_models.llm_service_data_models", llm_options_mod) - - # If module already imported, reload it to apply patches - mod_name = "src.chat.image_system.image_manager" - if mod_name in sys.modules: - importlib.reload(sys.modules[mod_name]) - - yield - - -def _load_image_manager_module(tmp_path=None): - repo_root = Path(__file__).parent.parent.parent - file_path = repo_root / "src" / "chat" / "image_system" / "image_manager.py" - spec = importlib.util.spec_from_file_location("image_manager_test_loaded", str(file_path)) - mod = importlib.util.module_from_spec(spec) - sys.modules[spec.name] = mod - spec.loader.exec_module(mod) - # Redirect IMAGE_DIR to pytest's tmp_path when provided - try: - if tmp_path is not None: - tmpdir = Path(tmp_path) - tmpdir.mkdir(parents=True, exist_ok=True) - mod.IMAGE_DIR = tmpdir - except Exception: - pass - return mod - - -@pytest.mark.asyncio -async def test_get_image_description_generates(tmp_path): - image_manager = _load_image_manager_module(tmp_path) - - mgr = image_manager.ImageManager() - desc = await mgr.get_image_description(image_bytes=b"abc") - assert desc == "dummy description" - - -def test_get_image_from_db_none(tmp_path): - image_manager = _load_image_manager_module(tmp_path) - - mgr = image_manager.ImageManager() - assert mgr.get_image_from_db("nohash") is None - - -def test_register_image_to_db(tmp_path): - image_manager = _load_image_manager_module(tmp_path) - - mgr = image_manager.ImageManager() - p = tmp_path / "img.png" - p.write_bytes(b"data") - img = DummyMaiImage(full_path=p, image_bytes=b"data") - assert mgr.register_image_to_db(img) is True - - -def test_update_image_description_not_found(tmp_path): - image_manager = _load_image_manager_module(tmp_path) - - mgr = image_manager.ImageManager() - img = DummyMaiImage() - img.file_hash = "nohash" - img.description = "desc" - assert mgr.update_image_description(img) is False - - -def test_delete_image_not_found(tmp_path): - image_manager = _load_image_manager_module(tmp_path) - - mgr = image_manager.ImageManager() - img = DummyMaiImage() - img.file_hash = "nohash" - img.full_path = tmp_path = None - assert mgr.delete_image(img) is False - - -@pytest.mark.asyncio -async def test_save_image_and_process_and_cleanup(tmp_path): - image_manager = _load_image_manager_module(tmp_path) - - mgr = image_manager.ImageManager() - # call save_image_and_process - image = await mgr.save_image_and_process(b"binarydata") - assert getattr(image, "description", None) == "dummy description" - - # cleanup should run without error - mgr.cleanup_invalid_descriptions_in_db() - - -@pytest.mark.asyncio -async def test_get_image_description_returns_cached_description_after_session_closed(monkeypatch, tmp_path): - image_manager = _load_image_manager_module(tmp_path) - - cached_record = DetachedRecord() - monkeypatch.setattr(image_manager, "get_db_session", lambda: DetachedRecordSession(cached_record)) - - mgr = image_manager.ImageManager() - desc = await mgr.get_image_description(image_hash="cached-hash", wait_for_build=False) - - assert desc == "cached description" diff --git a/pytests/image_sys_test/test_image_data_model.py b/pytests/image_sys_test/test_image_data_model.py deleted file mode 100644 index 964e209c..00000000 --- a/pytests/image_sys_test/test_image_data_model.py +++ /dev/null @@ -1,91 +0,0 @@ -from pathlib import Path -from types import SimpleNamespace - -import io - -from PIL import Image as PILImage -import pytest - -from src.common.data_models.image_data_model import MaiEmoji, MaiImage - - -def _build_test_image_bytes(image_format: str) -> bytes: - image = PILImage.new("RGB", (8, 8), color="white") - buffer = io.BytesIO() - image.save(buffer, format=image_format) - return buffer.getvalue() - - -@pytest.mark.asyncio -async def test_calculate_hash_format_updates_runtime_path_metadata(tmp_path: Path) -> None: - image_bytes = _build_test_image_bytes("JPEG") - tmp_file_path = tmp_path / "emoji.tmp" - tmp_file_path.write_bytes(image_bytes) - - emoji = MaiEmoji(full_path=tmp_file_path, image_bytes=image_bytes) - - assert await emoji.calculate_hash_format() is True - assert emoji.image_format == "jpeg" - assert emoji.full_path.suffix == ".jpeg" - assert emoji.file_name == emoji.full_path.name - assert emoji.dir_path == tmp_path.resolve() - - -@pytest.mark.asyncio -async def test_calculate_hash_format_reuses_existing_target_file(tmp_path: Path) -> None: - image_bytes = _build_test_image_bytes("JPEG") - tmp_file_path = tmp_path / "emoji.tmp" - target_file_path = tmp_path / "emoji.jpeg" - tmp_file_path.write_bytes(image_bytes) - target_file_path.write_bytes(image_bytes) - - emoji = MaiEmoji(full_path=tmp_file_path, image_bytes=image_bytes) - - assert await emoji.calculate_hash_format() is True - assert emoji.full_path == target_file_path.resolve() - assert emoji.file_name == target_file_path.name - assert not tmp_file_path.exists() - assert target_file_path.exists() - - -@pytest.mark.parametrize( - ("model_cls", "extra_fields"), - [ - ( - MaiEmoji, - { - "description": "", - "last_used_time": None, - "query_count": 0, - "register_time": None, - }, - ), - ( - MaiImage, - { - "description": "", - "vlm_processed": False, - }, - ), - ], -) -def test_from_db_instance_restores_image_format_from_path( - tmp_path: Path, - model_cls: type[MaiEmoji] | type[MaiImage], - extra_fields: dict[str, object], -) -> None: - image_path = tmp_path / "cached.png" - image_path.write_bytes(_build_test_image_bytes("PNG")) - - record = SimpleNamespace( - no_file_flag=False, - image_hash="hash", - full_path=str(image_path), - **extra_fields, - ) - - image = model_cls.from_db_instance(record) - - assert image.full_path == image_path.resolve() - assert image.file_name == image_path.name - assert image.image_format == "png" diff --git a/pytests/logger.py b/pytests/logger.py deleted file mode 100644 index 913392a1..00000000 --- a/pytests/logger.py +++ /dev/null @@ -1,22 +0,0 @@ -class MyLogger: - def __init__(self): - pass - - def info(self, msg): - print(f"INFO: {msg}") - - def error(self, msg): - print(f"ERROR: {msg}") - - def debug(self, msg): - print(f"DEBUG: {msg}") - - def warning(self, msg): - print(f"WARNING: {msg}") - - def trace(self, msg): - print(f"TRACE: {msg}") - - -def get_logger(*args, **kwargs): - return MyLogger() diff --git a/pytests/message_test/session_message_test.py b/pytests/message_test/session_message_test.py deleted file mode 100644 index 48e0a574..00000000 --- a/pytests/message_test/session_message_test.py +++ /dev/null @@ -1,422 +0,0 @@ -import sys -import asyncio -import pytest -import importlib -import importlib.util -from types import ModuleType -from pathlib import Path -from datetime import datetime -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from src.common.data_models.message_component_data_model import MessageSequence, ForwardComponent - from src.chat.message_receive.message import ( - SessionMessage, - TextComponent, - ImageComponent, - EmojiComponent, - VoiceComponent, - AtComponent, - ReplyComponent, - ForwardNodeComponent, - ) - - -class DummyLogger: - def __init__(self) -> None: - self.logging_record = [] - - def debug(self, msg): - print(f"DEBUG: {msg}") - self.logging_record.append(f"DEBUG: {msg}") - - def info(self, msg): - print(f"INFO: {msg}") - self.logging_record.append(f"INFO: {msg}") - - def warning(self, msg): - print(f"WARNING: {msg}") - self.logging_record.append(f"WARNING: {msg}") - - def error(self, msg): - print(f"ERROR: {msg}") - self.logging_record.append(f"ERROR: {msg}") - - def critical(self, msg): - print(f"CRITICAL: {msg}") - self.logging_record.append(f"CRITICAL: {msg}") - - -def get_logger(name): - return DummyLogger() - - -class DummyDBSession: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - def exec(self, statement): - return self - - def first(self): - return None - - def commit(self): - pass - - def all(self): - return [] - - -def get_db_session(): - return DummyDBSession() - - -def get_manual_db_session(): - return DummyDBSession() - - -class DummySelect: - def __init__(self, model): - self.model = model - - def filter_by(self, **kwargs): - return self - - def where(self, condition): - return self - - def limit(self, n): - return self - - -def select(model): - return DummySelect(model) - - -async def dummy_get_voice_text(binary_data): - return None # 可以根据需要返回模拟的文本结果 - - -class DummyPersonUtils: - @staticmethod - def get_person_info_by_user_id_and_platform(user_id, platform): - return None # 可以根据需要返回模拟的用户信息 - - -def setup_mocks(monkeypatch): - def _stub_module(name: str) -> ModuleType: - module = ModuleType(name) - monkeypatch.setitem(sys.modules, name, module) - return module - - # src.common.logger - logger_mod = _stub_module("src.common.logger") - # Mock the logger - logger_mod.get_logger = get_logger - - db_mod = _stub_module("src.common.database.database") - db_mod.get_db_session = get_db_session - db_mod.get_manual_db_session = get_manual_db_session - - db_model_mod = _stub_module("src.common.database.database_model") - db_model_mod.Messages = None # 可以根据需要添加更多的属性或方法 - - emoji_manager_mod = _stub_module("src.emoji_system.emoji_manager") - emoji_manager_mod.emoji_manager = None # 可以根据需要添加更多的属性或方法 - - image_manager_mod = _stub_module("src.chat.image_system.image_manager") - image_manager_mod.image_manager = None # 可以根据需要添加更多的属性或方法 - - msg_utils_mod = _stub_module("src.common.utils.utils_message") - msg_utils_mod.MessageUtils = None # 可以根据需要添加更多的属性或方法 - - voice_utils_mod = _stub_module("src.common.utils.utils_voice") - voice_utils_mod.get_voice_text = dummy_get_voice_text - - person_utils_mod = _stub_module("src.common.utils.utils_person") - person_utils_mod.PersonUtils = DummyPersonUtils - - -def load_message_via_file(monkeypatch): - setup_mocks(monkeypatch) - file_path = Path(__file__).parent.parent.parent / "src" / "chat" / "message_receive" / "message.py" - spec = importlib.util.spec_from_file_location("message", file_path) - message_module = importlib.util.module_from_spec(spec) - monkeypatch.setitem(sys.modules, "message_module", message_module) - spec.loader.exec_module(message_module) - message_module.select = select - SessionMessageClass = message_module.SessionMessage - TextComponentClass = message_module.TextComponent - ImageComponentClass = message_module.ImageComponent - EmojiComponentClass = message_module.EmojiComponent - VoiceComponentClass = message_module.VoiceComponent - AtComponentClass = message_module.AtComponent - ReplyComponentClass = message_module.ReplyComponent - ForwardNodeComponentClass = message_module.ForwardNodeComponent - MessageSequenceClass = sys.modules["src.common.data_models.message_component_data_model"].MessageSequence - ForwardComponentClass = sys.modules["src.common.data_models.message_component_data_model"].ForwardComponent - globals()["SessionMessage"] = SessionMessageClass - globals()["TextComponent"] = TextComponentClass - globals()["ImageComponent"] = ImageComponentClass - globals()["EmojiComponent"] = EmojiComponentClass - globals()["VoiceComponent"] = VoiceComponentClass - globals()["AtComponent"] = AtComponentClass - globals()["ReplyComponent"] = ReplyComponentClass - globals()["ForwardNodeComponent"] = ForwardNodeComponentClass - globals()["MessageSequence"] = MessageSequenceClass - globals()["ForwardComponent"] = ForwardComponentClass - return message_module - - -@pytest.mark.asyncio -async def test_process(monkeypatch): - load_message_via_file(monkeypatch) - msg = SessionMessage("msg123", datetime.now(), platform="test_platform") - msg.session_id = "session123" - msg.raw_message = MessageSequence(components=[]) - msg.raw_message.components = [TextComponent("Hello, world!")] - await msg.process() - assert msg.processed_plain_text == "Hello, world!" - - -@pytest.mark.asyncio -async def test_multiple_text(monkeypatch): - load_message_via_file(monkeypatch) - msg = SessionMessage("msg123", datetime.now(), platform="test_platform") - msg.session_id = "session123" - msg.raw_message = MessageSequence(components=[]) - msg.raw_message.components = [TextComponent("Hello,"), TextComponent("world!")] - await msg.process() - assert msg.processed_plain_text == "Hello, world!" - - -@pytest.mark.asyncio -async def test_image(monkeypatch): - load_message_via_file(monkeypatch) - msg = SessionMessage("msg123", datetime.now(), platform="test_platform") - msg.session_id = "session123" - msg.raw_message = MessageSequence(components=[]) - msg.raw_message.components = [ImageComponent(binary_hash="image_hash"), TextComponent("Hello, world!")] - await msg.process() - assert msg.processed_plain_text == "[一张图片,网卡了加载不出来] Hello, world!" - - -@pytest.mark.asyncio -async def test_emoji(monkeypatch): - load_message_via_file(monkeypatch) - msg = SessionMessage("msg123", datetime.now(), platform="test_platform") - msg.session_id = "session123" - msg.raw_message = MessageSequence(components=[]) - msg.raw_message.components = [EmojiComponent(binary_hash="emoji_hash"), TextComponent("Hello, world!")] - await msg.process() - assert msg.processed_plain_text == "[一个表情,网卡了加载不出来] Hello, world!" - - -@pytest.mark.asyncio -async def test_voice(monkeypatch): - load_message_via_file(monkeypatch) - msg = SessionMessage("msg123", datetime.now(), platform="test_platform") - msg.session_id = "session123" - msg.raw_message = MessageSequence(components=[]) - msg.raw_message.components = [VoiceComponent(binary_hash="voice_hash"), TextComponent("Hello, world!")] - await msg.process() - assert msg.processed_plain_text == "[语音消息,转录失败] Hello, world!" - - -@pytest.mark.asyncio -async def test_at_component(monkeypatch): - load_message_via_file(monkeypatch) - msg = SessionMessage("msg123", datetime.now(), platform="test_platform") - msg.session_id = "session123" - msg.platform = "test_platform" - msg.raw_message = MessageSequence(components=[]) - msg.raw_message.components = [AtComponent(target_user_id="114514"), TextComponent("Hello, world!")] - await msg.process() - assert msg.processed_plain_text == "@114514 Hello, world!" - - -@pytest.mark.asyncio -async def test_reply_component_fail_to_fetch(monkeypatch): - load_message_via_file(monkeypatch) - msg = SessionMessage("msg123", datetime.now(), platform="test_platform") - msg.session_id = "session123" - msg.platform = "test_platform" - msg.raw_message = MessageSequence(components=[]) - msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")] - await msg.process() - assert msg.processed_plain_text == "[回复了一条消息,但原消息已无法访问] Hello, world!" - - -@pytest.mark.asyncio -async def test_reply_component_success(monkeypatch): - module_msg = load_message_via_file(monkeypatch) - - class DummyDBSessionWithReply(DummyDBSession): - def exec(self, s): - return self - - def first(inner_self): - class DummyRecord: - processed_plain_text = "原消息内容" - user_cardname = "cardname123" - user_nickname = "nickname123" - user_id = "userid123" - - return DummyRecord() - - module_msg.get_db_session = lambda: DummyDBSessionWithReply() - msg = SessionMessage("msg123", datetime.now(), platform="test_platform") - msg.session_id = "session123" - msg.platform = "test_platform" - msg.raw_message = MessageSequence(components=[]) - msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")] - await msg.process() - assert msg.processed_plain_text == "[回复了cardname123的消息: 原消息内容] Hello, world!" - - -@pytest.mark.asyncio -async def test_reply_component_with_db_fail(monkeypatch): - module_msg = load_message_via_file(monkeypatch) - - class DummyDBSessionWithError(DummyDBSession): - def exec(self, s): - raise Exception("数据库查询失败") - - module_msg.get_db_session = lambda: DummyDBSessionWithError() - msg = SessionMessage("msg123", datetime.now(), platform="test_platform") - msg.session_id = "session123" - msg.platform = "test_platform" - msg.raw_message = MessageSequence(components=[]) - msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")] - await msg.process() - assert msg.processed_plain_text == "[回复了一条消息,但原消息已无法访问] Hello, world!" - assert any("数据库查询失败" in log for log in module_msg.logger.logging_record) - - -@pytest.mark.asyncio -async def test_forward_component(monkeypatch): - load_message_via_file(monkeypatch) - msg = SessionMessage("msg123", datetime.now(), platform="test_platform") - msg.session_id = "session123" - msg.platform = "test_platform" - msg.raw_message = MessageSequence(components=[]) - msg.raw_message.components = [ - ForwardNodeComponent( - forward_components=[ - ForwardComponent( - message_id="msg1", - user_id="user1", - user_nickname="nickname1", - user_cardname="cardname1", - content=[TextComponent("转发消息1")], - ), - ForwardComponent( - message_id="msg2", - user_id="user2", - user_nickname="nickname2", - user_cardname="cardname2", - content=[TextComponent("转发消息2")], - ), - ] - ), - TextComponent("Hello, world!"), - ] - await msg.process() - print("Processed plain text:", msg.processed_plain_text) - expected_forward_text = """【合并转发消息: --- 【cardname1】: 转发消息1 --- 【cardname2】: 转发消息2 -】 Hello, world!""" - assert msg.processed_plain_text == expected_forward_text - - -@pytest.mark.asyncio -async def test_forward_with_reply(monkeypatch): - load_message_via_file(monkeypatch) - msg = SessionMessage("msg123", datetime.now(), platform="test_platform") - msg.session_id = "session123" - msg.platform = "test_platform" - msg.raw_message = MessageSequence(components=[]) - msg.raw_message.components = [ - ForwardNodeComponent( - forward_components=[ - ForwardComponent( - message_id="msg1", - user_id="user1", - user_nickname="nickname1", - user_cardname="cardname1", - content=[TextComponent("转发消息1")], - ), - ForwardComponent( - message_id="msg2", - user_id="user2", - user_nickname="nickname2", - user_cardname="cardname2", - content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息2")], - ), - ] - ), - TextComponent("Hello, world!"), - ] - await msg.process() - assert ( - msg.processed_plain_text - == """【合并转发消息: --- 【cardname1】: 转发消息1 --- 【cardname2】: [回复了cardname1的消息: 转发消息1] 转发消息2 -】 Hello, world!""" - ) - - -@pytest.mark.asyncio -async def test_multiple_reply_with_delay_in_forward(monkeypatch): - load_message_via_file(monkeypatch) - msg = SessionMessage("msg123", datetime.now(), platform="test_platform") - msg.session_id = "session123" - msg.platform = "test_platform" - msg.raw_message = MessageSequence(components=[]) - - async def delayed_get_voice_text(binary_data): - await asyncio.sleep(0.5) # 模拟延迟 - return "这是语音转文本的结果" - - sys.modules["src.common.utils.utils_voice"].get_voice_text = delayed_get_voice_text - - msg.raw_message.components = [ - ForwardNodeComponent( - forward_components=[ - ForwardComponent( - message_id="msg1", - user_id="user1", - user_nickname="nickname1", - user_cardname="cardname1", - content=[VoiceComponent(binary_hash="voice_hash1"), TextComponent("转发消息1")], - ), - ForwardComponent( - message_id="msg2", - user_id="user2", - user_nickname="nickname2", - user_cardname="cardname2", - content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息2")], - ), - ForwardComponent( - message_id="msg3", - user_id="user3", - user_nickname="nickname3", - user_cardname="cardname3", - content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息3")], - ), - ] - ), - ] - await msg.process() - expected_text = """【合并转发消息: --- 【cardname1】: [语音: 这是语音转文本的结果] 转发消息1 --- 【cardname2】: [回复了cardname1的消息: [语音: 这是语音转文本的结果] 转发消息1] 转发消息2 --- 【cardname3】: [回复了cardname1的消息: [语音: 这是语音转文本的结果] 转发消息1] 转发消息3 -】""" - assert msg.processed_plain_text == expected_text diff --git a/pytests/prompt_test/test_prompt_i18n.py b/pytests/prompt_test/test_prompt_i18n.py deleted file mode 100644 index 0b586f0e..00000000 --- a/pytests/prompt_test/test_prompt_i18n.py +++ /dev/null @@ -1,220 +0,0 @@ -from __future__ import annotations - -from pathlib import Path - -import pytest - -from src.common.i18n import set_locale -from src.common.prompt_i18n import clear_prompt_cache, load_prompt, list_prompt_templates -from src.prompt.prompt_manager import PromptManager - - -@pytest.fixture(autouse=True) -def clear_prompt_i18n_cache() -> None: - set_locale("zh-CN") - clear_prompt_cache() - yield - clear_prompt_cache() - set_locale("zh-CN") - - -def write_prompt(prompt_dir: Path, locale: str | None, name: str, content: str) -> None: - base_dir = prompt_dir if locale is None else prompt_dir / locale - base_dir.mkdir(parents=True, exist_ok=True) - (base_dir / f"{name}.prompt").write_text(content, encoding="utf-8") - - -def test_load_prompt_prefers_requested_locale(tmp_path: Path) -> None: - prompts_root = tmp_path / "prompts" - write_prompt(prompts_root, "zh-CN", "replyer", "你好,{user_name}") - write_prompt(prompts_root, "en-US", "replyer", "Hello, {user_name}") - - rendered = load_prompt("replyer", locale="en-US", prompts_root=prompts_root, user_name="Mai") - - assert rendered == "Hello, Mai" - - -def test_load_prompt_falls_back_to_default_locale(tmp_path: Path) -> None: - prompts_root = tmp_path / "prompts" - write_prompt(prompts_root, "zh-CN", "replyer", "你好,{user_name}") - - rendered = load_prompt("replyer", locale="en-US", prompts_root=prompts_root, user_name="Mai") - - assert rendered == "你好,Mai" - - -def test_load_prompt_does_not_fall_back_to_legacy_root(tmp_path: Path) -> None: - prompts_root = tmp_path / "prompts" - write_prompt(prompts_root, None, "replyer", "Legacy {user_name}") - - with pytest.raises(FileNotFoundError): - load_prompt("replyer", locale="en-US", prompts_root=prompts_root, user_name="Mai") - - -def test_load_prompt_with_category_falls_back_to_default_locale_root(tmp_path: Path) -> None: - prompts_root = tmp_path / "prompts" - write_prompt(prompts_root, "zh-CN", "replyer", "你好,{user_name}") - - rendered = load_prompt("replyer", locale="en-US", category="chat", prompts_root=prompts_root, user_name="Mai") - - assert rendered == "你好,Mai" - - -def test_load_prompt_prefers_custom_prompt_override(tmp_path: Path) -> None: - prompts_root = tmp_path / "prompts" - custom_prompts_root = tmp_path / "data" / "custom_prompts" - write_prompt(prompts_root, "zh-CN", "replyer", "Base {user_name}") - write_prompt(custom_prompts_root, "zh-CN", "replyer", "Custom {user_name}") - - rendered = load_prompt( - "replyer", - locale="zh-CN", - prompts_root=prompts_root, - custom_prompts_root=custom_prompts_root, - user_name="Mai", - ) - - assert rendered == "Custom Mai" - - -def test_load_prompt_prefers_custom_prompt_requested_locale(tmp_path: Path) -> None: - prompts_root = tmp_path / "prompts" - custom_prompts_root = tmp_path / "data" / "custom_prompts" - write_prompt(prompts_root, "zh-CN", "replyer", "Base zh {user_name}") - write_prompt(prompts_root, "en-US", "replyer", "Base en {user_name}") - write_prompt(custom_prompts_root, "zh-CN", "replyer", "Custom zh {user_name}") - write_prompt(custom_prompts_root, "en-US", "replyer", "Custom en {user_name}") - - rendered = load_prompt( - "replyer", - locale="en-US", - prompts_root=prompts_root, - custom_prompts_root=custom_prompts_root, - user_name="Mai", - ) - - assert rendered == "Custom en Mai" - - -def test_load_prompt_uses_requested_locale_source_before_default_custom(tmp_path: Path) -> None: - prompts_root = tmp_path / "prompts" - custom_prompts_root = tmp_path / "data" / "custom_prompts" - write_prompt(prompts_root, "zh-CN", "replyer", "Base zh {user_name}") - write_prompt(prompts_root, "en-US", "replyer", "Base en {user_name}") - write_prompt(custom_prompts_root, "zh-CN", "replyer", "Custom zh {user_name}") - - rendered = load_prompt( - "replyer", - locale="en-US", - prompts_root=prompts_root, - custom_prompts_root=custom_prompts_root, - user_name="Mai", - ) - - assert rendered == "Base en Mai" - - -def test_load_prompt_strict_mode_raises_on_missing_placeholder(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: - prompts_root = tmp_path / "prompts" - write_prompt(prompts_root, "zh-CN", "replyer", "你好,{user_name},现在是 {current_time}") - monkeypatch.setenv("MAIBOT_PROMPT_I18N_STRICT", "1") - - with pytest.raises(KeyError) as exc_info: - load_prompt("replyer", locale="zh-CN", prompts_root=prompts_root, user_name="Mai") - - assert "current_time" in str(exc_info.value) - - -def test_load_prompt_rejects_path_traversal(tmp_path: Path) -> None: - prompts_root = tmp_path / "prompts" - write_prompt(prompts_root, "zh-CN", "replyer", "你好") - - with pytest.raises(ValueError): - load_prompt("../replyer", locale="zh-CN", prompts_root=prompts_root) - - -def test_list_prompt_templates_prefers_locale_specific_files(tmp_path: Path) -> None: - prompts_root = tmp_path / "prompts" - write_prompt(prompts_root, "zh-CN", "replyer", "中文") - write_prompt(prompts_root, "en-US", "replyer", "English") - set_locale("en-US") - - prompt_templates = list_prompt_templates(prompts_root=prompts_root) - - assert prompt_templates["replyer"].path.read_text(encoding="utf-8") == "English" - - -def test_list_prompt_templates_loads_directory_metadata(tmp_path: Path) -> None: - prompts_root = tmp_path / "prompts" - write_prompt(prompts_root, "zh-CN", "replyer", "中文") - metadata_path = prompts_root / "zh-CN" / ".meta.toml" - metadata_path.write_text( - """ -[replyer] -display_name = "回复器" -advanced = true -description = "用于生成回复的主模板" -""".strip(), - encoding="utf-8", - ) - - prompt_templates = list_prompt_templates(prompts_root=prompts_root) - metadata = prompt_templates["replyer"].metadata - - assert metadata.display_name == "回复器" - assert metadata.advanced is True - assert metadata.description == "用于生成回复的主模板" - - -def test_list_prompt_templates_loads_prompt_specific_metadata(tmp_path: Path) -> None: - prompts_root = tmp_path / "prompts" - write_prompt(prompts_root, "zh-CN", "replyer", "中文") - metadata_path = prompts_root / "zh-CN" / "replyer.meta.json" - metadata_path.write_text( - '{"display_name": "Replyer", "advanced": false, "description": "Prompt specific metadata"}', - encoding="utf-8", - ) - - prompt_templates = list_prompt_templates(prompts_root=prompts_root) - metadata = prompt_templates["replyer"].metadata - - assert metadata.display_name == "Replyer" - assert metadata.advanced is False - assert metadata.description == "Prompt specific metadata" - - -def test_list_prompt_templates_reports_duplicate_name_with_custom_root(tmp_path: Path) -> None: - prompts_root = tmp_path / "prompts" - first_dir = prompts_root / "zh-CN" / "chat" - second_dir = prompts_root / "zh-CN" / "system" - first_dir.mkdir(parents=True, exist_ok=True) - second_dir.mkdir(parents=True, exist_ok=True) - (first_dir / "replyer.prompt").write_text("chat", encoding="utf-8") - (second_dir / "replyer.prompt").write_text("system", encoding="utf-8") - - with pytest.raises(ValueError) as exc_info: - list_prompt_templates(prompts_root=prompts_root) - - assert "zh-CN/chat/replyer.prompt" in str(exc_info.value) - assert "zh-CN/system/replyer.prompt" in str(exc_info.value) - - -def test_prompt_manager_load_prompts_prefers_locale_dir( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, -) -> None: - prompts_root = tmp_path / "prompts" - custom_prompts_root = tmp_path / "data" / "custom_prompts" - custom_prompts_root.mkdir(parents=True, exist_ok=True) - write_prompt(prompts_root, "zh-CN", "replyer", "中文模板") - write_prompt(prompts_root, "en-US", "replyer", "English template") - set_locale("en-US") - - monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_root, raising=False) - monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_prompts_root, raising=False) - monkeypatch.setattr("src.prompt.prompt_manager.SUFFIX_PROMPT", ".prompt", raising=False) - - manager = PromptManager() - manager.load_prompts() - - assert manager.get_prompt("replyer").template == "English template" diff --git a/pytests/prompt_test/test_prompt_manager.py b/pytests/prompt_test/test_prompt_manager.py deleted file mode 100644 index e00a2fcc..00000000 --- a/pytests/prompt_test/test_prompt_manager.py +++ /dev/null @@ -1,893 +0,0 @@ -# File: pytests/prompt_test/test_prompt_manager.py - -from pathlib import Path -from typing import Any - -import asyncio -import inspect -import sys - -import pytest - -PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve() -sys.path.insert(0, str(PROJECT_ROOT)) -sys.path.insert(0, str(PROJECT_ROOT / "src" / "config")) - -from src.common.i18n.loaders import DEFAULT_LOCALE # noqa -from src.prompt.prompt_manager import ( # noqa - SUFFIX_PROMPT, - Prompt, - PromptManager, - prompt_manager, -) - - -def write_source_prompt(prompts_dir: Path, name: str, content: str) -> Path: - from src.common.i18n.loaders import DEFAULT_LOCALE - - source_dir = prompts_dir / DEFAULT_LOCALE - source_dir.mkdir(parents=True, exist_ok=True) - prompt_file = source_dir / f"{name}{SUFFIX_PROMPT}" - prompt_file.write_text(content, encoding="utf-8") - return prompt_file - - -# ========= Prompt 基础行为 ========= - - -@pytest.mark.parametrize( - "prompt_name, template", - [ - pytest.param("simple", "Hello {name}", id="simple-template-with-field"), - pytest.param("no-fields", "Just a static template", id="template-without-fields"), - pytest.param( - "brace-escaping", - "Use {{ and }} around {field}", - id="template-with-escaped-braces", - ), - ], -) -def test_prompt_init_happy_paths(prompt_name: str, template: str): - # Act - prompt = Prompt(prompt_name=prompt_name, template=template) - - # Assert - assert prompt.prompt_name == prompt_name - assert prompt.template == template - - -@pytest.mark.parametrize( - "prompt_name, template, expected_exception, expected_msg_substring", - [ - pytest.param("", "Hello {name}", ValueError, "prompt_name 不能为空", id="empty-prompt-name"), - pytest.param("valid-name", "", ValueError, "template 不能为空", id="empty-template"), - pytest.param( - "unnamed-placeholder", - "Hello {}", - ValueError, - "模板中不允许使用未命名的占位符", - id="unnamed-placeholder-not-allowed", - ), - pytest.param( - "unnamed-placeholder-with-escaped-brace", - "Value {{}} and {}", - ValueError, - "模板中不允许使用未命名的占位符", - id="unnamed-placeholder-mixed-with-escaped", - ), - ], -) -def test_prompt_init_error_cases( - prompt_name, - template, - expected_exception, - expected_msg_substring, -): - # Act / Assert - with pytest.raises(expected_exception) as exc_info: - Prompt(prompt_name=prompt_name, template=template) - - # Assert - assert expected_msg_substring in str(exc_info.value) - - -@pytest.mark.parametrize( - "initial_context, name, func, expected_value, expected_exception, expected_msg_substring, case_id", - [ - ( - {}, - "const_str", - "constant", - "constant", - None, - None, - "add-context-from-string-creates-wrapper", - ), - ( - {}, - "callable_str", - lambda prompt_name: f"hello-{prompt_name}", - "hello-my_prompt", - None, - None, - "add-context-from-callable", - ), - ( - {"dup": lambda _: "x"}, - "dup", - "y", - None, - KeyError, - "Context function name 'dup' 已存在于 Prompt 'my_prompt' 中", - "add-context-duplicate-key-error", - ), - ], -) -def test_prompt_add_context( - initial_context, - name, - func, - expected_value, - expected_exception, - expected_msg_substring, - case_id, -): - # Arrange - prompt = Prompt(prompt_name="my_prompt", template="template") - prompt.prompt_render_context = dict(initial_context) - - # Act - if expected_exception: - with pytest.raises(expected_exception) as exc_info: - prompt.add_context(name, func) - - # Assert - assert expected_msg_substring in str(exc_info.value) - else: - prompt.add_context(name, func) - - # Assert - assert name in prompt.prompt_render_context - result = prompt.prompt_render_context[name]("my_prompt") - assert result == expected_value - - -def test_prompt_clone_independent_instance(): - # Arrange - prompt = Prompt(prompt_name="p", template="T {x}") - prompt.add_context("x", "X") - - # Act - cloned = prompt.clone() - - # Assert - assert cloned is not prompt - assert cloned.prompt_name == prompt.prompt_name - assert cloned.template == prompt.template - # 当前实现 clone 不复制 context - assert cloned.prompt_render_context == {} - - -# ========= PromptManager:添加/获取/删除/替换 ========= - - -def test_prompt_manager_add_prompt_happy_and_error(): - # Arrange - manager = PromptManager() - prompt1 = Prompt(prompt_name="p1", template="T1") - manager.add_prompt(prompt1, need_save=True) - - # Act - prompt2 = Prompt(prompt_name="p2", template="T2") - manager.add_prompt(prompt2, need_save=False) - - # Assert - assert "p1" in manager._prompt_to_save - assert "p2" not in manager._prompt_to_save - - # Arrange - prompt_dup = Prompt(prompt_name="p1", template="T-dup") - - # Act / Assert - with pytest.raises(KeyError) as exc_info: - manager.add_prompt(prompt_dup) - - # Assert - assert "Prompt name 'p1' 已存在" in str(exc_info.value) - - -def test_prompt_manager_remove_prompt_happy_and_error(): - # Arrange - manager = PromptManager() - p1 = Prompt(prompt_name="p1", template="T") - manager.add_prompt(p1, need_save=True) - - # Act - manager.remove_prompt("p1") - - # Assert - assert "p1" not in manager.prompts - assert "p1" not in manager._prompt_to_save - - # Act / Assert - with pytest.raises(KeyError) as exc_info: - manager.remove_prompt("no_such") - - assert "Prompt name 'no_such' 不存在" in str(exc_info.value) - - -def test_prompt_manager_replace_prompt_happy_and_error(): - # sourcery skip: extract-duplicate-method - # Arrange - manager = PromptManager() - p1 = Prompt(prompt_name="p", template="Old") - manager.add_prompt(p1, need_save=True) - - p_new = Prompt(prompt_name="p", template="New") - - # Act: 替换且保持 need_save - manager.replace_prompt(p_new, need_save=True) - - # Assert - assert manager.prompts["p"].template == "New" - assert "p" in manager._prompt_to_save - - # Act: 再次替换,且不需要保存 - p_new2 = Prompt(prompt_name="p", template="New2") - manager.replace_prompt(p_new2, need_save=False) - - # Assert - assert manager.prompts["p"].template == "New2" - assert "p" not in manager._prompt_to_save - - # Error: 不存在的 prompt - p_unknown = Prompt(prompt_name="unknown", template="T") - with pytest.raises(KeyError) as exc_info: - manager.replace_prompt(p_unknown) - - assert "Prompt name 'unknown' 不存在,无法替换" in str(exc_info.value) - - -def test_prompt_manager_get_prompt_is_copy(): - # Arrange - manager = PromptManager() - prompt = Prompt(prompt_name="original", template="T") - manager.add_prompt(prompt) - - # Act - retrieved_prompt = manager.get_prompt("original") - - # Assert - assert retrieved_prompt is not prompt - assert retrieved_prompt.prompt_name == prompt.prompt_name - assert retrieved_prompt.template == prompt.template - assert retrieved_prompt.prompt_render_context == prompt.prompt_render_context - - -def test_prompt_manager_add_prompt_conflict_with_context_name(): - # Arrange - manager = PromptManager() - manager.add_context_construct_function("ctx_name", lambda _: "value") - prompt_conflict = Prompt(prompt_name="ctx_name", template="T") - - # Act / Assert - with pytest.raises(KeyError) as exc_info: - manager.add_prompt(prompt_conflict) - - # Assert - assert "Prompt name 'ctx_name' 已存在" in str(exc_info.value) - - -def test_prompt_manager_add_context_construct_function_happy(): - # Arrange - manager = PromptManager() - - def ctx_func(prompt_name: str) -> str: - return f"ctx-{prompt_name}" - - # Act - manager.add_context_construct_function("ctx", ctx_func) - - # Assert - assert "ctx" in manager._context_construct_functions - stored_func, module = manager._context_construct_functions["ctx"] - assert stored_func is ctx_func - assert module == __name__ - - -def test_prompt_manager_add_context_construct_function_duplicate(): - # Arrange - manager = PromptManager() - - def f(_): - return "x" - - manager.add_context_construct_function("dup", f) - manager.add_prompt(Prompt(prompt_name="dup_prompt", template="T")) - - # Act / Assert - with pytest.raises(KeyError) as exc_info1: - manager.add_context_construct_function("dup", f) - - # Assert - assert "Construct function name 'dup' 已存在" in str(exc_info1.value) - - # Act / Assert - with pytest.raises(KeyError) as exc_info2: - manager.add_context_construct_function("dup_prompt", f) - - # Assert - assert "Construct function name 'dup_prompt' 已存在" in str(exc_info2.value) - - -def test_prompt_manager_get_prompt_not_exist(): - # Arrange - manager = PromptManager() - - # Act / Assert - with pytest.raises(KeyError) as exc_info: - manager.get_prompt("no_such_prompt") - - # Assert - assert "Prompt name 'no_such_prompt' 不存在" in str(exc_info.value) - - -# ========= 渲染逻辑 ========= - - -@pytest.mark.parametrize( - "template, inner_context, global_context, expected, case_id", - [ - pytest.param( - "Hello {name}", - {"name": lambda p: f"name-for-{p}"}, - {}, - "Hello name-for-main", - "render-with-inner-context", - ), - pytest.param( - "Global {block}", - {}, - {"block": lambda p: f"block-{p}"}, - "Global block-main", - "render-with-global-context", - ), - pytest.param( - "Mix {inner} and {global}", - {"inner": lambda p: f"inner-{p}"}, - {"global": lambda p: f"global-{p}"}, - "Mix inner-main and global-main", - "render-with-inner-and-global-context", - ), - pytest.param( - "Escaped {{ and }} and {field}", - {"field": lambda _: "X"}, - {}, - "Escaped { and } and X", - "render-with-escaped-braces", - ), - ], -) -@pytest.mark.asyncio -async def test_prompt_manager_render_contexts( - template, - inner_context, - global_context, - expected, - case_id, -): - # Arrange - manager = PromptManager() - tmp_prompt = Prompt(prompt_name="main", template=template) - manager.add_prompt(tmp_prompt) - prompt = manager.get_prompt("main") - for name, fn in inner_context.items(): - prompt.add_context(name, fn) - for name, fn in global_context.items(): - manager.add_context_construct_function(name, fn) - - # Act - rendered = await manager.render_prompt(prompt) - - # Assert - assert rendered == expected - - -@pytest.mark.asyncio -async def test_prompt_manager_render_nested_prompts(): - # Arrange - manager = PromptManager() - p1 = Prompt(prompt_name="p1", template="P1-{x}") - p2 = Prompt(prompt_name="p2", template="P2-{p1}") - p3_tmp = Prompt(prompt_name="p3", template="{p2}-end") - manager.add_prompt(p1) - manager.add_prompt(p2) - manager.add_prompt(p3_tmp) - p3 = manager.get_prompt("p3") - p3.add_context("x", lambda _: "X") - - # Act - rendered = await manager.render_prompt(p3) - - # Assert - assert rendered == "P2-P1-X-end" - - -@pytest.mark.asyncio -async def test_prompt_manager_render_recursive_limit(): - # Arrange - manager = PromptManager() - p1_tmp = Prompt(prompt_name="p1", template="{p2}") - p2_tmp = Prompt(prompt_name="p2", template="{p1}") - manager.add_prompt(p1_tmp) - manager.add_prompt(p2_tmp) - p1 = manager.get_prompt("p1") - - # Act / Assert - with pytest.raises(RecursionError) as exc_info: - await manager.render_prompt(p1) - - # Assert - assert "递归层级过深" in str(exc_info.value) - - -@pytest.mark.asyncio -async def test_prompt_manager_render_missing_field_error(): - # Arrange - manager = PromptManager() - tmp_prompt = Prompt(prompt_name="main", template="Hello {missing}") - manager.add_prompt(tmp_prompt) - prompt = manager.get_prompt("main") - - # Act / Assert - with pytest.raises(KeyError) as exc_info: - await manager.render_prompt(prompt) - - # Assert - assert "Prompt 'main' 中缺少必要的内容块或构建函数: 'missing'" in str(exc_info.value) - - -@pytest.mark.asyncio -async def test_prompt_manager_render_prefers_inner_context_over_global(): - # Arrange - manager = PromptManager() - tmp_prompt = Prompt(prompt_name="main", template="{field}") - manager.add_context_construct_function("field", lambda _: "global") - manager.add_prompt(tmp_prompt) - prompt = manager.get_prompt("main") - prompt.add_context("field", lambda _: "inner") - - # Act - rendered = await manager.render_prompt(prompt) - - # Assert - assert rendered == "inner" - - -@pytest.mark.asyncio -async def test_prompt_manager_render_with_coroutine_context_function(): - # Arrange - manager = PromptManager() - - async def async_inner(prompt_name: str) -> str: - await asyncio.sleep(0) - return f"async-{prompt_name}" - - tmp_prompt = Prompt(prompt_name="main", template="{inner}") - manager.add_prompt(tmp_prompt) - prompt = manager.get_prompt("main") - prompt.add_context("inner", async_inner) - - # Act - rendered = await manager.render_prompt(prompt) - - # Assert - assert rendered == "async-main" - - -@pytest.mark.asyncio -async def test_prompt_manager_render_with_coroutine_global_context_function(): - # Arrange - manager = PromptManager() - - async def async_global(prompt_name: str) -> str: - await asyncio.sleep(0) - return f"g-{prompt_name}" - - tmp_prompt = Prompt(prompt_name="main", template="{g}") - manager.add_context_construct_function("g", async_global) - manager.add_prompt(tmp_prompt) - prompt = manager.get_prompt("main") - - # Act - rendered = await manager.render_prompt(prompt) - - # Assert - assert rendered == "g-main" - - -@pytest.mark.asyncio -async def test_prompt_manager_render_only_cloned_instance(): - # Arrange - manager = PromptManager() - p = Prompt(prompt_name="p", template="T") - manager.add_prompt(p) - - # Act / Assert: 直接用原始 p 渲染会报错 - with pytest.raises(ValueError) as exc_info: - await manager.render_prompt(p) - - assert "只能渲染通过 PromptManager.get_prompt 方法获取的 Prompt 实例" in str(exc_info.value) - - -@pytest.mark.parametrize( - "is_prompt_context, use_coroutine, case_id", - [ - pytest.param(True, False, "prompt-context-sync-error"), - pytest.param(False, False, "global-context-sync-error"), - pytest.param(True, True, "prompt-context-async-error"), - pytest.param(False, True, "global-context-async-error"), - ], -) -@pytest.mark.asyncio -async def test_prompt_manager_get_function_result_error_logging( - monkeypatch, - is_prompt_context, - use_coroutine, - case_id, -): - # Arrange - manager = PromptManager() - - class DummyError(Exception): - pass - - def sync_func(_name: str) -> str: - raise DummyError("sync-error") - - async def async_func(_name: str) -> str: - await asyncio.sleep(0) - raise DummyError("async-error") - - func = async_func if use_coroutine else sync_func - logged_messages: list[str] = [] - - def fake_error(msg: Any) -> None: - logged_messages.append(str(msg)) - - fake_logger = type("FakeLogger", (), {"error": staticmethod(fake_error)}) - - monkeypatch.setattr("src.prompt.prompt_manager.logger", fake_logger) - - # Act / Assert - with pytest.raises(DummyError): - await manager._get_function_result( - func=func, - prompt_name="P", - field_name="field", - is_prompt_context=is_prompt_context, - module="mod", - ) - - # Assert - assert logged_messages - log = logged_messages[0] - if is_prompt_context: - assert "调用 Prompt 'P' 内部上下文构造函数 'field' 时出错" in log - else: - assert "调用上下文构造函数 'field' 时出错,所属模块: 'mod'" in log - - -# ========= add_context_construct_function 边界 ========= - - -def test_prompt_manager_add_context_construct_function_unknown_frame(monkeypatch): - # Arrange - manager = PromptManager() - - def fake_currentframe() -> None: - return None - - monkeypatch.setattr("inspect.currentframe", fake_currentframe) - - def f(_): - return "x" - - # Act / Assert - with pytest.raises(RuntimeError) as exc_info: - manager.add_context_construct_function("x", f) - - # Assert - assert "无法获取调用栈" in str(exc_info.value) - - -def test_prompt_manager_add_context_construct_function_unknown_caller_frame(monkeypatch): - # Arrange - manager = PromptManager() - real_currentframe = inspect.currentframe - - class FakeFrame: - f_back = None - - def fake_currentframe(): - return FakeFrame() - - monkeypatch.setattr("inspect.currentframe", fake_currentframe) - - def f(_): - return "x" - - # Act / Assert - with pytest.raises(RuntimeError) as exc_info: - manager.add_context_construct_function("x", f) - - # Assert - assert "无法获取调用栈的上一级" in str(exc_info.value) - - # Cleanup - monkeypatch.setattr("inspect.currentframe", real_currentframe) - - -# ========= save/load & 目录逻辑 ========= - - -def test_prompt_manager_save_prompts_io_error_on_unlink(tmp_path, monkeypatch): - """ - save_prompts 现在的逻辑: - 1. 先删除 CUSTOM_PROMPTS_DIR 下的所有 *.prompt 文件; - 2. 再将 _prompt_to_save 中的 prompt 写入 CUSTOM_PROMPTS_DIR。 - - 这里模拟删除已有自定义 prompt 文件时发生 IO 错误。 - """ - # Arrange - prompts_dir = tmp_path / "prompts" - custom_dir = tmp_path / "data" / "custom_prompts" - prompts_dir.mkdir(parents=True) - custom_dir.mkdir(parents=True) - - monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False) - monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False) - - # 先在自定义目录写入一个 prompt 文件,触发 unlink 路径 - old_file = custom_dir / f"old{SUFFIX_PROMPT}" - old_file.write_text("old", encoding="utf-8") - - manager = PromptManager() - p1 = Prompt(prompt_name="save_error", template="T") - manager.add_prompt(p1, need_save=True) - - # 打桩 Path.unlink,使删除文件时报错 - def fake_unlink(self): - raise OSError("disk unlink error") - - monkeypatch.setattr("pathlib.Path.unlink", fake_unlink) - - # Act / Assert - with pytest.raises(OSError) as exc_info: - manager.save_prompts() - - # Assert - assert "disk unlink error" in str(exc_info.value) - - -def test_prompt_manager_save_prompts_io_error_on_write(tmp_path, monkeypatch): - """ - 模拟 save_prompts 在写入新 prompt 文件时发生 IO 错误。 - """ - # Arrange - prompts_dir = tmp_path / "prompts" - custom_dir = tmp_path / "data" / "custom_prompts" - prompts_dir.mkdir(parents=True) - custom_dir.mkdir(parents=True) - - monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False) - monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False) - - manager = PromptManager() - p1 = Prompt(prompt_name="save_error", template="T") - manager.add_prompt(p1, need_save=True) - - original_write_text = Path.write_text - - def fake_write_text(self, *args, **kwargs): - if self == custom_dir / DEFAULT_LOCALE / f"save_error{SUFFIX_PROMPT}": - raise OSError("disk write error") - return original_write_text(self, *args, **kwargs) - - monkeypatch.setattr(Path, "write_text", fake_write_text) - - # Act / Assert - with pytest.raises(OSError) as exc_info: - manager.save_prompts() - - # Assert - assert "disk write error" in str(exc_info.value) - - -def test_prompt_manager_load_prompts_io_error_from_default_dir(tmp_path, monkeypatch): - """ - 模拟从默认 locale 目录读取 prompt 时发生 IO 错误。 - """ - # Arrange - prompts_dir = tmp_path / "prompts" - custom_dir = tmp_path / "data" / "custom_prompts" - prompts_dir.mkdir(parents=True) - custom_dir.mkdir(parents=True) - - monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False) - monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False) - - prompt_file = write_source_prompt(prompts_dir, "bad", "content") - - original_read_text = Path.read_text - - def fake_read_text(self, *args, **kwargs): - if self == prompt_file: - raise OSError("read error") - return original_read_text(self, *args, **kwargs) - - monkeypatch.setattr(Path, "read_text", fake_read_text) - manager = PromptManager() - - # Act / Assert - with pytest.raises(OSError) as exc_info: - manager.load_prompts() - - # Assert - assert "read error" in str(exc_info.value) - - -def test_prompt_manager_load_prompts_io_error_from_custom_dir(tmp_path, monkeypatch): - """ - 模拟从 CUSTOM_PROMPTS_DIR 读取 prompt 时发生 IO 错误。 - 包含两种路径: - 1. default 与 custom 同名,load_prompts 会优先读取 custom; - 2. 仅 custom 有文件,且 default 无同名文件。 - """ - # Arrange - prompts_dir = tmp_path / "prompts" - custom_dir = tmp_path / "data" / "custom_prompts" - prompts_dir.mkdir(parents=True) - custom_dir.mkdir(parents=True) - - monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False) - monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False) - - # default 与 custom 同名的文件 - base_file = write_source_prompt(prompts_dir, "same", "base") - same_name = base_file.name - custom_file_same = custom_dir / same_name - custom_file_same.write_text("custom", encoding="utf-8") - - # 仅 custom 下存在的文件 - only_custom_file = custom_dir / f"only_custom{SUFFIX_PROMPT}" - only_custom_file.write_text("only", encoding="utf-8") - - original_read_text = Path.read_text - - def fake_read_text(self, *args, **kwargs): - if self.parent == custom_dir: - raise OSError("custom read error") - return original_read_text(self, *args, **kwargs) - - monkeypatch.setattr(Path, "read_text", fake_read_text) - manager = PromptManager() - - # Act / Assert - with pytest.raises(OSError) as exc_info: - manager.load_prompts() - - # Assert - assert "custom read error" in str(exc_info.value) - - -def test_prompt_manager_load_prompts_custom_overrides_default(tmp_path, monkeypatch): - """ - load_prompts 逻辑: - - 遍历 locale 目录中的 source prompt - - 如果 CUSTOM_PROMPTS_DIR 下存在同名文件,则优先使用自定义目录 - """ - # Arrange - prompts_dir = tmp_path / "prompts" - custom_dir = tmp_path / "data" / "custom_prompts" - prompts_dir.mkdir(parents=True) - custom_dir.mkdir(parents=True) - - monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False) - monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False) - - # source locale 目录 prompt - base_file = write_source_prompt(prompts_dir, "testp", "BaseTemplate {x}") - - # 自定义目录同名 prompt,应当覆盖默认 - custom_file = custom_dir / base_file.name - custom_file.write_text("CustomTemplate {x}", encoding="utf-8") - - manager = PromptManager() - - # Act - manager.load_prompts() - - # Assert - p = manager.get_prompt("testp") - assert p.template == "CustomTemplate {x}" - # 从自定义目录加载的 prompt 应标记为 need_save(加入 _prompt_to_save) - assert "testp" in manager._prompt_to_save - - -def test_prompt_manager_load_prompts_default_dir_not_mark_need_save(tmp_path, monkeypatch): - """ - 从 source locale 目录加载、且没有同名自定义 prompt 时,need_save 应为 False(不进入 _prompt_to_save)。 - """ - # Arrange - prompts_dir = tmp_path / "prompts" - custom_dir = tmp_path / "data" / "custom_prompts" - prompts_dir.mkdir(parents=True) - custom_dir.mkdir(parents=True) - - monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False) - monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False) - - # 仅 source locale 目录有 prompt,自定义目录中无同名文件 - base_file = write_source_prompt(prompts_dir, "only_default", "DefaultTemplate {x}") - - manager = PromptManager() - - # Act - manager.load_prompts() - - # Assert - p = manager.get_prompt("only_default") - assert p.template == base_file.read_text(encoding="utf-8") - # 从默认目录加载的 prompt 不应标记为 need_save - assert "only_default" not in manager._prompt_to_save - - -def test_prompt_manager_save_prompts_use_custom_dir(tmp_path, monkeypatch): - """ - save_prompts 使用 CUSTOM_PROMPTS_DIR 进行保存。 - """ - prompts_dir = tmp_path / "prompts" - custom_dir = tmp_path / "data" / "custom_prompts" - prompts_dir.mkdir(parents=True) - custom_dir.mkdir(parents=True) - - monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False) - monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False) - - manager = PromptManager() - p1 = Prompt(prompt_name="save_me", template="Template {x}") - p1.add_context("x", "X") - manager.add_prompt(p1, need_save=True) - - # Act - manager.save_prompts() - - # Assert: 文件应保存在 custom_dir 中 - saved_file = custom_dir / DEFAULT_LOCALE / f"save_me{SUFFIX_PROMPT}" - assert saved_file.exists() - assert saved_file.read_text(encoding="utf-8") == "Template {x}" - - -# ========= 其它 ========= - - -def test_prompt_manager_global_instance_access(): - # Act - pm = prompt_manager - - # Assert - assert isinstance(pm, PromptManager) - - -def test_formatter_parsing_named_fields_only(): - # Arrange - manager = PromptManager() - prompt = Prompt(prompt_name="main", template="A {x} B {y} C") - manager.add_prompt(prompt) - - # Act - fields = {field_name for _, field_name, _, _ in manager._formatter.parse(prompt.template) if field_name} - - # Assert - assert fields == {"x", "y"} diff --git a/pytests/test_context_message_fallback.py b/pytests/test_context_message_fallback.py deleted file mode 100644 index 2c344dc5..00000000 --- a/pytests/test_context_message_fallback.py +++ /dev/null @@ -1,73 +0,0 @@ -from src.common.data_models.message_component_data_model import ( - ImageComponent, - MessageSequence, - ReplyComponent, - TextComponent, -) -from src.llm_models.payload_content.message import RoleType -from src.maisaka.context_messages import _build_message_from_sequence -from src.maisaka.message_adapter import build_visible_text_from_sequence - - -def test_image_only_message_keeps_placeholder_in_text_fallback() -> None: - message_sequence = MessageSequence( - [ - TextComponent("[时间]19:21:20\n[用户名]William730\n[用户群昵称]\n[msg_id]1385025976\n[发言内容]"), - ImageComponent(binary_hash="hash", content=None, binary_data=None), - ] - ) - - message = _build_message_from_sequence( - RoleType.User, - message_sequence, - "[时间]19:21:20\n[用户名]William730\n[用户群昵称]\n[msg_id]1385025976\n[发言内容][图片]", - ) - - assert message is not None - assert "[发言内容]" in message.get_text_content() - assert "[图片]" in message.get_text_content() - - -def test_whitespace_image_content_uses_placeholder_in_text_fallback() -> None: - message_sequence = MessageSequence( - [ - TextComponent("[发言内容]"), - ImageComponent(binary_hash="hash", content=" ", binary_data=None), - ] - ) - - message = _build_message_from_sequence( - RoleType.User, - message_sequence, - "[发言内容][图片]", - enable_visual_message=False, - ) - - assert message is not None - assert message.get_text_content() == "[发言内容][图片]" - - -def test_visible_text_uses_image_placeholder_for_whitespace_content() -> None: - visible_text = build_visible_text_from_sequence( - MessageSequence( - [ - TextComponent("看这个"), - ImageComponent(binary_hash="hash", content=" ", binary_data=None), - ] - ) - ) - - assert visible_text == "看这个[图片]" - - -def test_visible_text_adds_body_marker_after_reply_component() -> None: - visible_text = build_visible_text_from_sequence( - MessageSequence( - [ - ReplyComponent(target_message_id="75625487"), - TextComponent("你说是那就是"), - ] - ) - ) - - assert visible_text == "[引用]quote_id=75625487\n[发言内容]你说是那就是" diff --git a/pytests/test_gemini_thought_signatures.py b/pytests/test_gemini_thought_signatures.py deleted file mode 100644 index ace63d6f..00000000 --- a/pytests/test_gemini_thought_signatures.py +++ /dev/null @@ -1,72 +0,0 @@ -import base64 -import sys -from types import ModuleType, SimpleNamespace - - -config_module = ModuleType("src.config.config") - - -class _ConfigManagerStub: - def get_model_config(self) -> SimpleNamespace: - return SimpleNamespace(api_providers=[]) - - def register_reload_callback(self, _: object) -> None: - return None - - -config_module.config_manager = _ConfigManagerStub() -sys.modules.setdefault("src.config.config", config_module) - -from src.llm_models.model_client import gemini_client -from src.llm_models.payload_content.message import MessageBuilder, RoleType -from src.llm_models.payload_content.tool_option import ToolCall - - -def _encode_signature(value: bytes) -> str: - return base64.b64encode(value).decode("ascii") - - -def test_convert_messages_preserves_gemini_function_call_signature_and_tool_result_id() -> None: - thought_signature = b"gemini-signature" - tool_call = ToolCall( - call_id="call-1", - func_name="reply", - args={"msg_id": "42"}, - extra_content={"google": {"thought_signature": _encode_signature(thought_signature)}}, - ) - assistant_message = MessageBuilder().set_role(RoleType.Assistant).set_tool_calls([tool_call]).build() - tool_message = ( - MessageBuilder() - .set_role(RoleType.Tool) - .set_tool_call_id("call-1") - .set_tool_name("reply") - .add_text_content('{"ok": true}') - .build() - ) - - contents, _ = gemini_client._convert_messages([assistant_message, tool_message]) - - assistant_part = contents[0].parts[0] - assert assistant_part.function_call is not None - assert assistant_part.function_call.id == "call-1" - assert assistant_part.function_call.name == "reply" - assert assistant_part.thought_signature == thought_signature - - tool_part = contents[1].parts[0] - assert tool_part.function_response is not None - assert tool_part.function_response.id == "call-1" - assert tool_part.function_response.name == "reply" - assert tool_part.function_response.response == {"ok": True} - - -def test_convert_messages_injects_dummy_signature_for_first_historical_tool_call() -> None: - tool_calls = [ - ToolCall(call_id="call-1", func_name="reply", args={"msg_id": "1"}), - ToolCall(call_id="call-2", func_name="reply", args={"msg_id": "2"}), - ] - assistant_message = MessageBuilder().set_role(RoleType.Assistant).set_tool_calls(tool_calls).build() - - contents, _ = gemini_client._convert_messages([assistant_message]) - - assert contents[0].parts[0].thought_signature == gemini_client.GEMINI_FALLBACK_THOUGHT_SIGNATURE - assert contents[0].parts[1].thought_signature is None diff --git a/pytests/test_html_render_service.py b/pytests/test_html_render_service.py deleted file mode 100644 index 3485ee5a..00000000 --- a/pytests/test_html_render_service.py +++ /dev/null @@ -1,194 +0,0 @@ -"""HTML 浏览器渲染服务测试。""" - -from pathlib import Path -from typing import Any, Dict, List - -import pytest - -from src.config.official_configs import PluginRuntimeRenderConfig -from src.services import html_render_service as html_render_service_module -from src.services.html_render_service import HTMLRenderService, ManagedBrowserRecord - - -class _FakeChromium: - """用于模拟 Playwright Chromium 启动器的测试桩。""" - - def __init__(self, effects: List[Any]) -> None: - """初始化 Chromium 启动测试桩。 - - Args: - effects: 每次调用 ``launch`` 时依次返回或抛出的结果。 - """ - - self._effects: List[Any] = list(effects) - self.calls: List[Dict[str, Any]] = [] - - async def launch(self, **kwargs: Any) -> Any: - """模拟 Playwright Chromium 的启动过程。 - - Args: - **kwargs: 浏览器启动参数。 - - Returns: - Any: 预设的浏览器对象。 - - Raises: - Exception: 当预设结果为异常对象时抛出。 - """ - - self.calls.append(dict(kwargs)) - effect = self._effects.pop(0) - if isinstance(effect, Exception): - raise effect - return effect - - -class _FakePlaywright: - """用于模拟 Playwright 根对象的测试桩。""" - - def __init__(self, chromium: _FakeChromium) -> None: - """初始化 Playwright 测试桩。 - - Args: - chromium: Chromium 启动器测试桩。 - """ - - self.chromium = chromium - - -def _build_render_config(**kwargs: Any) -> PluginRuntimeRenderConfig: - """构造用于测试的浏览器渲染配置。 - - Args: - **kwargs: 需要覆盖的配置字段。 - - Returns: - PluginRuntimeRenderConfig: 测试使用的配置对象。 - """ - - payload: Dict[str, Any] = { - "auto_download_chromium": True, - "browser_install_root": "data/test-playwright-browsers", - } - payload.update(kwargs) - return PluginRuntimeRenderConfig(**payload) - - -@pytest.mark.asyncio -async def test_launch_browser_auto_downloads_chromium_when_missing(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - """未检测到可用浏览器时,应自动下载 Chromium 并记录状态。""" - - monkeypatch.setattr(html_render_service_module, "PROJECT_ROOT", tmp_path) - service = HTMLRenderService() - config = _build_render_config() - fake_browser = object() - fake_chromium = _FakeChromium( - [ - RuntimeError("browserType.launch: Executable doesn't exist at /tmp/chromium"), - fake_browser, - ] - ) - install_calls: List[str] = [] - - monkeypatch.setattr(service, "_resolve_executable_path", lambda _config: "") - - async def fake_install(_config: PluginRuntimeRenderConfig) -> None: - """模拟 Chromium 自动下载。 - - Args: - _config: 当前浏览器渲染配置。 - """ - - install_calls.append(_config.browser_install_root) - browsers_path = service._get_managed_browsers_path(_config) - (browsers_path / "chromium-1234").mkdir(parents=True, exist_ok=True) - - monkeypatch.setattr(service, "_install_chromium_browser", fake_install) - - browser = await service._launch_browser(_FakePlaywright(fake_chromium), config) - - assert browser is fake_browser - assert install_calls == ["data/test-playwright-browsers"] - assert len(fake_chromium.calls) == 2 - - browser_record = service._load_managed_browser_record() - assert browser_record is not None - assert browser_record.install_source == "auto_download" - assert browser_record.browser_name == "chromium" - - -@pytest.mark.asyncio -async def test_launch_browser_reuses_existing_managed_browser(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - """已存在 Playwright 托管浏览器时,不应重复下载。""" - - monkeypatch.setattr(html_render_service_module, "PROJECT_ROOT", tmp_path) - service = HTMLRenderService() - config = _build_render_config() - browsers_path = service._get_managed_browsers_path(config) - (browsers_path / "chrome-headless-shell-1234").mkdir(parents=True, exist_ok=True) - fake_browser = object() - fake_chromium = _FakeChromium([fake_browser]) - - monkeypatch.setattr(service, "_resolve_executable_path", lambda _config: "") - - async def fail_install(_config: PluginRuntimeRenderConfig) -> None: - """若被错误调用则立即失败。 - - Args: - _config: 当前浏览器渲染配置。 - - Raises: - AssertionError: 表示本测试不期望进入下载逻辑。 - """ - - raise AssertionError("不应触发自动下载") - - monkeypatch.setattr(service, "_install_chromium_browser", fail_install) - - browser = await service._launch_browser(_FakePlaywright(fake_chromium), config) - - assert browser is fake_browser - assert len(fake_chromium.calls) == 1 - - browser_record = service._load_managed_browser_record() - assert browser_record is not None - assert browser_record.install_source == "existing_cache" - assert browser_record.browsers_path == str(browsers_path) - - -@pytest.mark.asyncio -async def test_launch_browser_prefers_local_executable(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - """探测到本机浏览器时,应优先使用可执行文件路径启动。""" - - monkeypatch.setattr(html_render_service_module, "PROJECT_ROOT", tmp_path) - service = HTMLRenderService() - config = _build_render_config() - fake_browser = object() - fake_chromium = _FakeChromium([fake_browser]) - executable_path = "/usr/bin/google-chrome" - - monkeypatch.setattr(service, "_resolve_executable_path", lambda _config: executable_path) - - browser = await service._launch_browser(_FakePlaywright(fake_chromium), config) - - assert browser is fake_browser - assert len(fake_chromium.calls) == 1 - assert fake_chromium.calls[0]["executable_path"] == executable_path - assert service._load_managed_browser_record() is None - - -def test_managed_browser_record_roundtrip() -> None: - """托管浏览器记录应支持序列化与反序列化。""" - - record = ManagedBrowserRecord( - browser_name="chromium", - browsers_path="/tmp/playwright-browsers", - install_source="auto_download", - playwright_version="1.58.0", - recorded_at="2026-04-03T10:00:00+00:00", - last_verified_at="2026-04-03T10:00:01+00:00", - ) - - restored_record = ManagedBrowserRecord.from_dict(record.to_dict()) - - assert restored_record == record diff --git a/pytests/test_llm_provider_registry.py b/pytests/test_llm_provider_registry.py deleted file mode 100644 index abc412ad..00000000 --- a/pytests/test_llm_provider_registry.py +++ /dev/null @@ -1,101 +0,0 @@ -from typing import List - -from src.llm_models.model_client.base_client import ( - APIResponse, - AudioTranscriptionRequest, - BaseClient, - ClientProviderRegistration, - ClientRegistry, - EmbeddingRequest, - ResponseRequest, -) - - -class DummyClient(BaseClient): - """测试用 LLM 客户端。""" - - async def get_response(self, request: ResponseRequest) -> APIResponse: - """获取测试响应。 - - Args: - request: 统一响应请求。 - - Returns: - APIResponse: 测试响应。 - """ - del request - return APIResponse(content="ok") - - async def get_embedding(self, request: EmbeddingRequest) -> APIResponse: - """获取测试嵌入。 - - Args: - request: 统一嵌入请求。 - - Returns: - APIResponse: 测试嵌入响应。 - """ - del request - return APIResponse(embedding=[1.0]) - - async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse: - """获取测试音频转写。 - - Args: - request: 统一音频转写请求。 - - Returns: - APIResponse: 测试音频转写响应。 - """ - del request - return APIResponse(content="audio") - - def get_support_image_formats(self) -> List[str]: - """获取测试支持的图片格式。 - - Returns: - List[str]: 支持的图片格式列表。 - """ - return ["png"] - - -def test_client_registry_rejects_provider_conflict(): - """同一 client_type 被不同插件注册时应拒绝。""" - registry = ClientRegistry() - registry.replace_plugin_providers( - "plugin.alpha", - [ - ClientProviderRegistration( - client_type="example", - factory=DummyClient, - owner_plugin_id="plugin.alpha", - ) - ], - ) - - try: - registry.validate_plugin_provider_replacement("plugin.beta", ["example"]) - except ValueError as exc: - assert "冲突" in str(exc) - else: - raise AssertionError("不同插件注册相同 client_type 应失败") - - -def test_client_registry_unregisters_plugin_providers(): - """插件注销时应移除它拥有的 Provider 注册。""" - registry = ClientRegistry() - registry.replace_plugin_providers( - "plugin.alpha", - [ - ClientProviderRegistration( - client_type="example", - factory=DummyClient, - owner_plugin_id="plugin.alpha", - ) - ], - ) - - removed_count = registry.unregister_plugin_providers("plugin.alpha") - - assert removed_count == 1 - assert "example" not in registry.client_registry diff --git a/pytests/test_maisaka_builtin_context.py b/pytests/test_maisaka_builtin_context.py deleted file mode 100644 index 4d4bf3e9..00000000 --- a/pytests/test_maisaka_builtin_context.py +++ /dev/null @@ -1,113 +0,0 @@ -from datetime import datetime -from types import SimpleNamespace - -from src.chat.message_receive.message import SessionMessage -from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo -from src.common.data_models.message_component_data_model import AtComponent, MessageSequence, ReplyComponent, TextComponent -from src.config.config import global_config -from src.maisaka.builtin_tool.context import BuiltinToolRuntimeContext -from src.maisaka.runtime import MaisakaHeartFlowChatting - - -def _build_sent_message() -> SessionMessage: - message = SessionMessage( - message_id="real-message-id", - timestamp=datetime(2026, 4, 5, 12, 0, 0), - platform="qq", - ) - message.message_info = MessageInfo( - user_info=UserInfo( - user_id="bot-qq", - user_nickname="MaiSaka", - user_cardname=None, - ), - group_info=None, - additional_config={}, - ) - message.raw_message = MessageSequence( - [ - ReplyComponent(target_message_id="m123"), - TextComponent(text="你好"), - ] - ) - message.session_id = "test-session" - message.initialized = True - return message - - -def test_append_sent_message_to_chat_history_keeps_message_id() -> None: - runtime = SimpleNamespace(_chat_history=[]) - engine = SimpleNamespace(_get_runtime_manager=lambda: None) - tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime) - - tool_ctx.append_sent_message_to_chat_history(_build_sent_message()) - - assert len(runtime._chat_history) == 1 - history_message = runtime._chat_history[0] - assert history_message.message_id == "real-message-id" - assert "[msg_id]real-message-id\n" in history_message.raw_message.components[0].text - assert "[msg_id:real-message-id]" in history_message.visible_text - - -def test_post_process_reply_message_sequences_converts_at_marker_before_bracket_cleanup(monkeypatch) -> None: - monkeypatch.setattr(global_config.chat, "enable_at", True) - monkeypatch.setattr( - "src.maisaka.builtin_tool.context.process_llm_response", - lambda text: [text.strip()] if text.strip() else [], - ) - target_message = SimpleNamespace( - message_info=SimpleNamespace( - user_info=SimpleNamespace( - user_id="target-user", - user_nickname="目标昵称", - user_cardname="群名片", - ) - ) - ) - runtime = SimpleNamespace( - find_source_message_by_id=lambda message_id: target_message if message_id == "12160142" else None - ) - engine = SimpleNamespace(_get_runtime_manager=lambda: None) - tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime) - - sequences = tool_ctx.post_process_reply_message_sequences("at[12160142] 就这个群") - - assert len(sequences) == 1 - components = sequences[0].components - assert isinstance(components[0], AtComponent) - assert components[0].target_user_id == "target-user" - assert components[0].target_user_nickname == "目标昵称" - assert components[0].target_user_cardname == "群名片" - assert isinstance(components[1], TextComponent) - assert components[1].text == " 就这个群" - - -def test_post_process_reply_message_sequences_ignores_at_marker_when_disabled(monkeypatch) -> None: - monkeypatch.setattr(global_config.chat, "enable_at", False) - monkeypatch.setattr( - "src.maisaka.builtin_tool.context.process_llm_response", - lambda text: [text.strip()] if text.strip() else [], - ) - runtime = SimpleNamespace(find_source_message_by_id=lambda message_id: None) - engine = SimpleNamespace(_get_runtime_manager=lambda: None) - tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime) - - sequences = tool_ctx.post_process_reply_message_sequences("at[12160142] 就这个群") - - assert len(sequences) == 1 - components = sequences[0].components - assert len(components) == 1 - assert isinstance(components[0], TextComponent) - assert components[0].text == "at[12160142] 就这个群" - - -def test_runtime_finds_source_message_from_history() -> None: - target_message = _build_sent_message() - runtime = object.__new__(MaisakaHeartFlowChatting) - runtime._chat_history = [ - SimpleNamespace(message_id="other-message-id", original_message=SimpleNamespace()), - SimpleNamespace(message_id="real-message-id", original_message=target_message), - ] - - assert runtime.find_source_message_by_id("real-message-id") is target_message - assert runtime.find_source_message_by_id("missing-message-id") is None diff --git a/pytests/test_maisaka_builtin_query_memory.py b/pytests/test_maisaka_builtin_query_memory.py deleted file mode 100644 index 697e1114..00000000 --- a/pytests/test_maisaka_builtin_query_memory.py +++ /dev/null @@ -1,241 +0,0 @@ -from types import SimpleNamespace -from typing import Any, Dict - -import pytest - -from src.core.tooling import ToolInvocation -from src.maisaka.builtin_tool import query_memory as query_memory_tool -from src.maisaka.builtin_tool.context import BuiltinToolRuntimeContext -from src.services.memory_service import MemoryHit, MemorySearchResult - - -def _build_tool_ctx( - *, - session_id: str = "session-1", - platform: str = "qq", - user_id: str = "user-1", - group_id: str = "", -) -> BuiltinToolRuntimeContext: - runtime = SimpleNamespace( - session_id=session_id, - chat_stream=SimpleNamespace( - platform=platform, - user_id=user_id, - group_id=group_id, - ), - log_prefix=f"[{session_id}]", - ) - return BuiltinToolRuntimeContext(engine=SimpleNamespace(), runtime=runtime) - - -def _build_invocation(arguments: Dict[str, Any]) -> ToolInvocation: - return ToolInvocation( - tool_name="query_memory", - arguments=dict(arguments), - call_id="call-query-memory", - ) - - -@pytest.fixture(autouse=True) -def _patch_maisaka_config(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr( - query_memory_tool, - "global_config", - SimpleNamespace(memory=SimpleNamespace(memory_query_default_limit=5)), - ) - - -@pytest.mark.asyncio -async def test_query_memory_rejects_empty_query_and_time(monkeypatch: pytest.MonkeyPatch) -> None: - async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult: - _ = query - _ = kwargs - raise AssertionError("参数校验失败时不应调用 memory_service.search") - - monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search) - - result = await query_memory_tool.handle_tool( - _build_tool_ctx(), - _build_invocation({"query": "", "time_start": "", "time_end": ""}), - ) - - assert result.success is False - assert "query_memory 需要提供 query" in result.error_message - - -@pytest.mark.asyncio -async def test_query_memory_private_chat_auto_sets_person_id(monkeypatch: pytest.MonkeyPatch) -> None: - captured: Dict[str, Any] = {} - - def fake_resolve_person_id_for_memory( - *, - person_name: str = "", - platform: str = "", - user_id: Any = None, - strict_known: bool = False, - ) -> str: - _ = strict_known - captured["resolve_args"] = { - "person_name": person_name, - "platform": platform, - "user_id": user_id, - } - return "pid-private-auto" - - async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult: - captured["query"] = query - captured["search_kwargs"] = dict(kwargs) - return MemorySearchResult( - summary="检索摘要", - hits=[MemoryHit(content="Alice 喜欢咖啡", score=0.91)], - ) - - monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory) - monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search) - - result = await query_memory_tool.handle_tool( - _build_tool_ctx(session_id="private-session", platform="qq", user_id="alice", group_id=""), - _build_invocation({"query": "Alice 的喜好"}), - ) - - assert result.success is True - assert captured["query"] == "Alice 的喜好" - assert captured["resolve_args"] == { - "person_name": "", - "platform": "qq", - "user_id": "alice", - } - assert captured["search_kwargs"]["chat_id"] == "private-session" - assert captured["search_kwargs"]["user_id"] == "alice" - assert captured["search_kwargs"]["group_id"] == "" - assert captured["search_kwargs"]["person_id"] == "pid-private-auto" - assert isinstance(result.structured_content, dict) - assert result.structured_content["person_id"] == "pid-private-auto" - - -@pytest.mark.asyncio -async def test_query_memory_group_chat_does_not_attach_default_person_id(monkeypatch: pytest.MonkeyPatch) -> None: - call_counter = {"resolve": 0} - captured_kwargs: Dict[str, Any] = {} - - def fake_resolve_person_id_for_memory( - *, - person_name: str = "", - platform: str = "", - user_id: Any = None, - strict_known: bool = False, - ) -> str: - _ = person_name - _ = platform - _ = user_id - _ = strict_known - call_counter["resolve"] += 1 - return "unexpected-person-id" - - async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult: - _ = query - captured_kwargs.update(kwargs) - return MemorySearchResult(summary="", hits=[]) - - monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory) - monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search) - - result = await query_memory_tool.handle_tool( - _build_tool_ctx(session_id="group-session", platform="qq", user_id="alice", group_id="group-1"), - _build_invocation({"query": "群聊上下文"}), - ) - - assert result.success is True - assert call_counter["resolve"] == 0 - assert captured_kwargs["chat_id"] == "group-session" - assert captured_kwargs["group_id"] == "group-1" - assert captured_kwargs["person_id"] == "" - - -@pytest.mark.asyncio -async def test_query_memory_search_failure_is_returned(monkeypatch: pytest.MonkeyPatch) -> None: - async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult: - _ = query - _ = kwargs - return MemorySearchResult(success=False, error="boom") - - monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search) - monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", lambda **kwargs: "") - - result = await query_memory_tool.handle_tool( - _build_tool_ctx(), - _build_invocation({"query": "测试失败透传"}), - ) - - assert result.success is False - assert result.error_message == "boom" - assert isinstance(result.structured_content, dict) - assert result.structured_content["success"] is False - - -@pytest.mark.asyncio -async def test_query_memory_prefers_person_name_resolution(monkeypatch: pytest.MonkeyPatch) -> None: - captured: Dict[str, Any] = {"resolve_calls": []} - - def fake_resolve_person_id_for_memory( - *, - person_name: str = "", - platform: str = "", - user_id: Any = None, - strict_known: bool = False, - ) -> str: - _ = strict_known - captured["resolve_calls"].append( - { - "person_name": person_name, - "platform": platform, - "user_id": user_id, - } - ) - if person_name: - return "pid-by-name" - return "pid-private-auto" - - async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult: - _ = query - captured["search_kwargs"] = dict(kwargs) - return MemorySearchResult(summary="", hits=[MemoryHit(content="命中1")]) - - monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory) - monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search) - - result = await query_memory_tool.handle_tool( - _build_tool_ctx(session_id="private-session", platform="qq", user_id="alice", group_id=""), - _build_invocation({"query": "小明资料", "person_name": "小明"}), - ) - - assert result.success is True - assert captured["resolve_calls"][0] == { - "person_name": "小明", - "platform": "qq", - "user_id": "alice", - } - assert captured["search_kwargs"]["person_id"] == "pid-by-name" - assert result.structured_content["person_name"] == "小明" - assert result.structured_content["person_id"] == "pid-by-name" - - -@pytest.mark.asyncio -async def test_query_memory_no_hit_returns_readable_message(monkeypatch: pytest.MonkeyPatch) -> None: - async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult: - _ = query - _ = kwargs - return MemorySearchResult(summary="", hits=[]) - - monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search) - monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", lambda **kwargs: "") - - result = await query_memory_tool.handle_tool( - _build_tool_ctx(), - _build_invocation({"query": "不存在的记忆"}), - ) - - assert result.success is True - assert "未找到匹配的长期记忆" in result.content - assert isinstance(result.structured_content, dict) - assert result.structured_content["query"] == "不存在的记忆" diff --git a/pytests/test_maisaka_memory_retention.py b/pytests/test_maisaka_memory_retention.py deleted file mode 100644 index 921302a7..00000000 --- a/pytests/test_maisaka_memory_retention.py +++ /dev/null @@ -1,105 +0,0 @@ -from types import SimpleNamespace - -import pytest -import time - -from src.chat.heart_flow import heartflow_manager as heartflow_manager_module -from src.chat.heart_flow.heartflow_manager import HEARTFLOW_ACTIVE_RETENTION_SECONDS, HeartflowManager -from src.learners.expression_learner import ExpressionLearner -from src.maisaka.runtime import MAX_RETAINED_MESSAGE_CACHE_SIZE, MaisakaHeartFlowChatting - - -def _build_runtime_with_messages(message_count: int) -> MaisakaHeartFlowChatting: - runtime = object.__new__(MaisakaHeartFlowChatting) - runtime.log_prefix = "[test]" - runtime.message_cache = [SimpleNamespace(message_id=f"msg-{index}") for index in range(message_count)] - runtime._last_processed_index = message_count - runtime._expression_learner = ExpressionLearner("session-1") - runtime._expression_learner.mark_all_processed(runtime.message_cache) - return runtime - - -def test_prune_processed_message_cache_keeps_bounded_recent_window() -> None: - runtime = _build_runtime_with_messages(MAX_RETAINED_MESSAGE_CACHE_SIZE + 25) - - runtime._prune_processed_message_cache() - - assert len(runtime.message_cache) == MAX_RETAINED_MESSAGE_CACHE_SIZE - assert runtime.message_cache[0].message_id == "msg-25" - assert runtime._last_processed_index == MAX_RETAINED_MESSAGE_CACHE_SIZE - assert runtime._expression_learner.last_processed_index == MAX_RETAINED_MESSAGE_CACHE_SIZE - - -def test_prune_processed_message_cache_keeps_unlearned_messages() -> None: - runtime = _build_runtime_with_messages(MAX_RETAINED_MESSAGE_CACHE_SIZE + 25) - runtime._expression_learner.discard_processed_prefix(MAX_RETAINED_MESSAGE_CACHE_SIZE + 5) - - runtime._prune_processed_message_cache() - - assert len(runtime.message_cache) == MAX_RETAINED_MESSAGE_CACHE_SIZE + 5 - assert runtime.message_cache[0].message_id == "msg-20" - assert runtime._expression_learner.last_processed_index == 0 - - -def test_collect_pending_messages_uses_single_pending_received_time() -> None: - runtime = _build_runtime_with_messages(2) - runtime._last_processed_index = 0 - runtime._oldest_pending_message_received_at = 123.0 - runtime._last_message_received_at = 456.0 - runtime._reply_latency_measurement_started_at = None - - pending_messages = runtime._collect_pending_messages() - - assert [message.message_id for message in pending_messages] == ["msg-0", "msg-1"] - assert runtime._reply_latency_measurement_started_at == 123.0 - assert runtime._oldest_pending_message_received_at is None - - -@pytest.mark.asyncio -async def test_heartflow_manager_evicts_lru_chat_over_limit(monkeypatch: pytest.MonkeyPatch) -> None: - manager = HeartflowManager() - stopped_session_ids: list[str] = [] - old_active_at = time.time() - HEARTFLOW_ACTIVE_RETENTION_SECONDS - 1 - - class FakeChat: - def __init__(self, session_id: str) -> None: - self.session_id = session_id - - async def stop(self) -> None: - stopped_session_ids.append(self.session_id) - - monkeypatch.setattr(heartflow_manager_module, "HEARTFLOW_MAX_ACTIVE_CHATS", 2) - manager.heartflow_chat_list["session-1"] = FakeChat("session-1") - manager.heartflow_chat_list["session-2"] = FakeChat("session-2") - manager.heartflow_chat_list["session-3"] = FakeChat("session-3") - manager._chat_last_active_at["session-1"] = old_active_at - manager._chat_last_active_at["session-2"] = old_active_at - manager._chat_last_active_at["session-3"] = time.time() - - await manager._evict_over_limit_chats(protected_session_id="session-3") - - assert stopped_session_ids == ["session-1"] - assert list(manager.heartflow_chat_list) == ["session-2", "session-3"] - - -@pytest.mark.asyncio -async def test_heartflow_manager_keeps_recent_chats_even_over_limit(monkeypatch: pytest.MonkeyPatch) -> None: - manager = HeartflowManager() - stopped_session_ids: list[str] = [] - - class FakeChat: - def __init__(self, session_id: str) -> None: - self.session_id = session_id - - async def stop(self) -> None: - stopped_session_ids.append(self.session_id) - - monkeypatch.setattr(heartflow_manager_module, "HEARTFLOW_MAX_ACTIVE_CHATS", 2) - for session_id in ("session-1", "session-2", "session-3"): - manager.heartflow_chat_list[session_id] = FakeChat(session_id) - manager._chat_last_active_at[session_id] = time.time() - - await manager._evict_over_limit_chats(protected_session_id="session-3") - - assert stopped_session_ids == [] - assert list(manager.heartflow_chat_list) == ["session-1", "session-2", "session-3"] diff --git a/pytests/test_maisaka_message_adapter.py b/pytests/test_maisaka_message_adapter.py deleted file mode 100644 index de9130b1..00000000 --- a/pytests/test_maisaka_message_adapter.py +++ /dev/null @@ -1,54 +0,0 @@ -from datetime import datetime -from pathlib import Path - -import sys - -from src.chat.message_receive.message import SessionMessage -from src.common.data_models.message_component_data_model import MessageSequence, TextComponent -from src.llm_models.payload_content.tool_option import ToolCall -from src.maisaka.message_adapter import build_message, get_message_kind, get_message_role, get_tool_call_id, get_tool_calls - - -PROJECT_ROOT = Path(__file__).resolve().parents[1] -if str(PROJECT_ROOT) not in sys.path: - sys.path.insert(0, str(PROJECT_ROOT)) - - -def test_build_message_returns_session_message_with_maisaka_metadata() -> None: - timestamp = datetime.now() - tool_call = ToolCall( - call_id="call-1", - func_name="reply", - args={"message_id": "msg-1"}, - ) - raw_message = MessageSequence(components=[TextComponent(text="内部消息内容")]) - - message = build_message( - role="assistant", - content="展示消息内容", - message_kind="perception", - source="assistant", - tool_call_id="call-1", - tool_calls=[tool_call], - timestamp=timestamp, - message_id="maisaka-msg-1", - raw_message=raw_message, - display_text="展示消息内容", - ) - - assert isinstance(message, SessionMessage) - assert message.initialized is True - assert message.message_id == "maisaka-msg-1" - assert message.timestamp == timestamp - assert message.processed_plain_text == "展示消息内容" - assert message.raw_message is raw_message - - assert get_message_role(message) == "assistant" - assert get_message_kind(message) == "perception" - assert get_tool_call_id(message) == "call-1" - - restored_tool_calls = get_tool_calls(message) - assert len(restored_tool_calls) == 1 - assert restored_tool_calls[0].call_id == "call-1" - assert restored_tool_calls[0].func_name == "reply" - assert restored_tool_calls[0].args == {"message_id": "msg-1"} diff --git a/pytests/test_maisaka_monitor_protocol.py b/pytests/test_maisaka_monitor_protocol.py deleted file mode 100644 index 7fede4e3..00000000 --- a/pytests/test_maisaka_monitor_protocol.py +++ /dev/null @@ -1,619 +0,0 @@ -from types import SimpleNamespace -from typing import Any, Callable - -import pytest -from rich.panel import Panel -from rich.text import Text - -from src.chat.replyer import maisaka_generator as replyer_module -from src.common.data_models.reply_generation_data_models import ( - GenerationMetrics, - LLMCompletionResult, - ReplyGenerationResult, -) -from src.core.tooling import ToolExecutionResult, ToolInvocation -from src.maisaka.builtin_tool.context import BuiltinToolRuntimeContext -from src.maisaka.builtin_tool import reply as reply_tool_module -from src.maisaka.builtin_tool import send_emoji as send_emoji_tool_module -from src.maisaka.monitor_events import emit_planner_finalized -from src.maisaka.reasoning_engine import MaisakaReasoningEngine -from src.maisaka import runtime as runtime_module -from src.maisaka.runtime import MaisakaHeartFlowChatting - - -def test_runtime_maps_expression_config_flags_to_correct_fields(monkeypatch: pytest.MonkeyPatch) -> None: - fake_chat_stream = SimpleNamespace( - is_group_session=True, - group_id="group-1", - user_id="user-1", - platform="test", - ) - - monkeypatch.setattr( - runtime_module.chat_manager, - "get_session_by_session_id", - lambda session_id: fake_chat_stream, - ) - monkeypatch.setattr(runtime_module.chat_manager, "get_session_name", lambda session_id: "测试会话") - monkeypatch.setattr( - runtime_module.ExpressionConfigUtils, - "get_expression_config_for_chat", - staticmethod(lambda session_id: (True, False, True)), - ) - monkeypatch.setattr(runtime_module, "ExpressionLearner", lambda session_id: SimpleNamespace()) - monkeypatch.setattr(runtime_module, "JargonMiner", lambda session_id, session_name: SimpleNamespace()) - monkeypatch.setattr(runtime_module, "MaisakaReasoningEngine", lambda runtime: SimpleNamespace()) - monkeypatch.setattr(runtime_module, "ToolRegistry", lambda: SimpleNamespace()) - monkeypatch.setattr(runtime_module, "ReplyEffectTracker", lambda **kwargs: SimpleNamespace()) - monkeypatch.setattr(MaisakaHeartFlowChatting, "_register_tool_providers", lambda self: None) - monkeypatch.setattr(MaisakaHeartFlowChatting, "_emit_monitor_session_start", lambda self: None) - - runtime = MaisakaHeartFlowChatting("session-1") - - assert runtime._enable_expression_use is True - assert runtime._enable_expression_learning is False - assert runtime._enable_jargon_learning is True - - -class _FakeLLMResult: - def __init__(self) -> None: - self.response = "测试回复" - self.reasoning = "先理解上下文,再给出自然回复。" - self.model_name = "fake-model" - self.tool_calls = [] - self.prompt_tokens = 12 - self.completion_tokens = 7 - self.total_tokens = 19 - - -class _FakeLegacyLLMServiceClient: - def __init__(self, *args: Any, **kwargs: Any) -> None: - del args - del kwargs - - async def generate_response_with_messages(self, *, message_factory: Callable[[object], list[Any]]) -> _FakeLLMResult: - assert message_factory(object()) - return _FakeLLMResult() - - -class _FakeMultimodalLLMServiceClient: - def __init__(self, *args: Any, **kwargs: Any) -> None: - del args - del kwargs - - async def generate_response_with_messages(self, *, message_factory: Callable[[object], list[Any]]) -> _FakeLLMResult: - assert message_factory(object()) - return _FakeLLMResult() - - -@pytest.mark.asyncio -async def test_legacy_and_multimodal_replyer_monitor_detail_have_same_shape(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(replyer_module, "LLMServiceClient", _FakeLegacyLLMServiceClient) - monkeypatch.setattr(replyer_module, "load_prompt", lambda *args, **kwargs: "legacy prompt") - - legacy_generator = replyer_module.MaisakaReplyGenerator( - chat_stream=None, - request_type="test_legacy", - enable_visual_message=False, - ) - multimodal_generator = replyer_module.MaisakaReplyGenerator( - chat_stream=None, - request_type="test_multi", - llm_client_cls=_FakeMultimodalLLMServiceClient, - load_prompt_func=lambda *args, **kwargs: "multi prompt", - enable_visual_message=True, - ) - - legacy_success, legacy_result = await legacy_generator.generate_reply_with_context( - stream_id="session-legacy", - chat_history=[], - reply_reason="测试原因", - ) - multimodal_success, multimodal_result = await multimodal_generator.generate_reply_with_context( - stream_id="session-multi", - chat_history=[], - reply_reason="测试原因", - ) - - assert legacy_success is True - assert multimodal_success is True - assert legacy_result.monitor_detail is not None - assert multimodal_result.monitor_detail is not None - assert set(legacy_result.monitor_detail.keys()) == set(multimodal_result.monitor_detail.keys()) - assert set(legacy_result.monitor_detail["metrics"].keys()) == set(multimodal_result.monitor_detail["metrics"].keys()) - assert legacy_result.monitor_detail["metrics"]["prompt_tokens"] == 12 - assert legacy_result.monitor_detail["metrics"]["completion_tokens"] == 7 - assert legacy_result.monitor_detail["metrics"]["total_tokens"] == 19 - - -def test_legacy_replyer_builds_message_sequence_like_multimodal() -> None: - legacy_generator = replyer_module.MaisakaReplyGenerator( - chat_stream=None, - request_type="test_legacy", - enable_visual_message=False, - ) - legacy_prompt_loader = replyer_module.load_prompt - replyer_module.load_prompt = lambda *args, **kwargs: "legacy prompt" - - try: - session_message = replyer_module.SessionBackedMessage( - raw_message=SimpleNamespace(), - visible_text="[Alice]你好\n[Bob]在吗", - timestamp=replyer_module.datetime.now(), - source_kind="user", - ) - request_messages = legacy_generator._build_request_messages( - chat_history=[session_message], - reply_message=None, - reply_reason="测试原因", - stream_id="session-legacy", - ) - finally: - replyer_module.load_prompt = legacy_prompt_loader - - assert len(request_messages) == 4 - assert request_messages[0].role.value == "system" - assert request_messages[0].get_text_content() == "legacy prompt" - assert request_messages[1].role.value == "user" - assert request_messages[1].get_text_content() == "[Alice]你好" - assert request_messages[2].role.value == "user" - assert request_messages[2].get_text_content() == "[Bob]在吗" - assert request_messages[3].role.value == "user" - assert "当前时间:" in request_messages[3].get_text_content() - assert "【回复信息参考】" in request_messages[3].get_text_content() - assert "【最新推理】\n测试原因" in request_messages[3].get_text_content() - assert "请自然地回复。" in request_messages[3].get_text_content() - - -@pytest.mark.asyncio -async def test_reply_tool_puts_monitor_detail_into_metadata(monkeypatch: pytest.MonkeyPatch) -> None: - fake_monitor_detail = { - "prompt_text": "reply prompt", - "reasoning_text": "reply reasoning", - "output_text": "reply output", - "metrics": {"model_name": "fake-model", "total_tokens": 10}, - } - fake_reply_result = ReplyGenerationResult( - success=True, - completion=LLMCompletionResult(response_text="测试回复"), - metrics=GenerationMetrics(overall_ms=11.5), - monitor_detail=fake_monitor_detail, - ) - - class _FakeReplyer: - async def generate_reply_with_context(self, **kwargs: Any) -> tuple[bool, ReplyGenerationResult]: - del kwargs - return True, fake_reply_result - - monkeypatch.setattr(reply_tool_module.replyer_manager, "get_replyer", lambda **kwargs: _FakeReplyer()) - monkeypatch.setattr(reply_tool_module, "render_cli_message", lambda text: text) - - target_message = SimpleNamespace( - message_id="msg-1", - message_info=SimpleNamespace( - user_info=SimpleNamespace( - user_cardname="测试用户", - user_nickname="测试用户", - user_id="user-1", - ) - ), - ) - runtime = SimpleNamespace( - find_source_message_by_id=lambda message_id: target_message if message_id == "msg-1" else None, - log_prefix="[test]", - chat_stream=SimpleNamespace(platform=reply_tool_module.CLI_PLATFORM_NAME), - session_id="session-1", - _chat_history=[], - _clear_force_continue_until_reply=lambda: None, - _record_reply_sent=lambda: None, - run_sub_agent=None, - ) - engine = SimpleNamespace(_get_runtime_manager=lambda: None) - tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime) - invocation = ToolInvocation(tool_name="reply", arguments={"msg_id": "msg-1", "set_quote": True}) - - result = await reply_tool_module.handle_tool(tool_ctx, invocation) - - assert result.success is True - assert result.metadata["monitor_detail"] == fake_monitor_detail - - -@pytest.mark.asyncio -async def test_send_emoji_tool_puts_monitor_detail_into_metadata(monkeypatch: pytest.MonkeyPatch) -> None: - async def _fake_build_emoji_candidate_message(emojis: list[Any]) -> object: - assert emojis - return SimpleNamespace() - - async def _fake_send_emoji_for_maisaka(**kwargs: Any) -> Any: - selected_emoji, matched_emotion = await kwargs["emoji_selector"]( - kwargs["requested_emotion"], - kwargs["reasoning"], - kwargs["context_texts"], - 2, - ) - assert selected_emoji is not None - return SimpleNamespace( - success=True, - message="已发送表情包:开心", - emoji_base64="ZW1vamk=", - description="开心", - emotions=["开心", "可爱"], - matched_emotion=matched_emotion or "开心", - sent_message=None, - ) - - monkeypatch.setattr(send_emoji_tool_module, "_build_emoji_candidate_message", _fake_build_emoji_candidate_message) - monkeypatch.setattr(send_emoji_tool_module, "send_emoji_for_maisaka", _fake_send_emoji_for_maisaka) - monkeypatch.setattr( - send_emoji_tool_module.emoji_manager, - "emojis", - [ - SimpleNamespace(description="开心,可爱", emotion=["开心", "可爱"]), - SimpleNamespace(description="难过", emotion=["难过"]), - ], - ) - - async def _fake_run_sub_agent(**kwargs: Any) -> Any: - del kwargs - return SimpleNamespace( - content='{"emoji_index": 1, "reason": "更贴合当前语气"}', - prompt_tokens=9, - completion_tokens=6, - total_tokens=15, - ) - - runtime = SimpleNamespace( - _chat_history=[], - log_prefix="[test]", - session_id="session-emoji", - run_sub_agent=_fake_run_sub_agent, - ) - engine = SimpleNamespace(last_reasoning_content="用户刚刚表达了开心情绪") - tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime) - invocation = ToolInvocation(tool_name="send_emoji", arguments={"emotion": "开心"}) - - result = await send_emoji_tool_module.handle_tool(tool_ctx, invocation) - - assert result.success is True - assert result.metadata["monitor_detail"]["prompt_text"] - assert result.metadata["monitor_detail"]["reasoning_text"] == "更贴合当前语气" - assert result.metadata["monitor_detail"]["metrics"]["total_tokens"] == 15 - assert any( - section["title"] == "表情发送结果" - for section in result.metadata["monitor_detail"]["extra_sections"] - ) - - -@pytest.mark.asyncio -async def test_emit_planner_finalized_broadcasts_new_protocol(monkeypatch: pytest.MonkeyPatch) -> None: - captured: dict[str, Any] = {} - - async def _fake_broadcast(event: str, data: dict[str, Any]) -> None: - captured["event"] = event - captured["data"] = data - - monkeypatch.setattr("src.maisaka.monitor_events._broadcast", _fake_broadcast) - - await emit_planner_finalized( - session_id="session-1", - cycle_id=3, - timing_request_messages=[{"role": "user", "content": "先看看要不要继续"}], - timing_selected_history_count=3, - timing_tool_count=1, - timing_action="continue", - timing_content="继续", - timing_tool_calls=[SimpleNamespace(call_id="timing-call-1", func_name="continue", args={})], - timing_tool_results=["- continue [成功]: 继续执行"], - timing_prompt_tokens=40, - timing_completion_tokens=5, - timing_total_tokens=45, - timing_duration_ms=11.2, - planner_request_messages=[{"role": "user", "content": "你好"}], - planner_selected_history_count=5, - planner_tool_count=2, - planner_content="先查询再回复", - planner_tool_calls=[SimpleNamespace(call_id="call-1", func_name="reply", args={"msg_id": "m1"})], - planner_prompt_tokens=100, - planner_completion_tokens=30, - planner_total_tokens=130, - planner_duration_ms=88.5, - tools=[ - { - "tool_call_id": "call-1", - "tool_name": "reply", - "tool_args": {"msg_id": "m1"}, - "success": True, - "duration_ms": 22.0, - "summary": "- reply [成功]: 已回复", - "detail": {"output_text": "测试回复"}, - } - ], - time_records={"planner": 0.1, "tool_calls": 0.2}, - agent_state="stop", - ) - - assert captured["event"] == "planner.finalized" - payload = captured["data"] - assert payload["timing_gate"]["result"]["action"] == "continue" - assert payload["timing_gate"]["result"]["tool_results"] == ["- continue [成功]: 继续执行"] - assert payload["request"]["messages"][0]["content"] == "你好" - assert payload["request"]["tool_count"] == 2 - assert payload["planner"]["tool_calls"][0]["id"] == "call-1" - assert payload["tools"][0]["detail"]["output_text"] == "测试回复" - assert payload["final_state"]["agent_state"] == "stop" - - -@pytest.mark.asyncio -async def test_emit_planner_finalized_supports_timing_only_cycle(monkeypatch: pytest.MonkeyPatch) -> None: - captured: dict[str, Any] = {} - - async def _fake_broadcast(event: str, data: dict[str, Any]) -> None: - captured["event"] = event - captured["data"] = data - - monkeypatch.setattr("src.maisaka.monitor_events._broadcast", _fake_broadcast) - - await emit_planner_finalized( - session_id="session-2", - cycle_id=7, - timing_request_messages=[{"role": "user", "content": "先别回"}], - timing_selected_history_count=2, - timing_tool_count=1, - timing_action="no_reply", - timing_content="当前不适合继续", - timing_tool_calls=[SimpleNamespace(call_id="timing-call-2", func_name="no_reply", args={})], - timing_tool_results=["- no_reply [成功]: 暂停当前对话"], - timing_prompt_tokens=18, - timing_completion_tokens=4, - timing_total_tokens=22, - timing_duration_ms=6.5, - planner_request_messages=None, - planner_selected_history_count=None, - planner_tool_count=None, - planner_content=None, - planner_tool_calls=None, - planner_prompt_tokens=None, - planner_completion_tokens=None, - planner_total_tokens=None, - planner_duration_ms=None, - tools=[], - time_records={"timing_gate": 0.02}, - agent_state="stop", - ) - - assert captured["event"] == "planner.finalized" - payload = captured["data"] - assert payload["timing_gate"]["result"]["action"] == "no_reply" - assert payload["planner"] is None - assert payload["request"] is None - - -def test_reasoning_engine_build_tool_monitor_result_keeps_non_reply_tool_without_detail() -> None: - engine = object.__new__(MaisakaReasoningEngine) - tool_call = SimpleNamespace(call_id="call-2", func_name="query_memory") - invocation = ToolInvocation(tool_name="query_memory", arguments={"query": "Alice"}) - result = ToolExecutionResult(tool_name="query_memory", success=True, content="查询成功") - - tool_result = engine._build_tool_monitor_result(tool_call, invocation, result, duration_ms=18.6) - - assert tool_result["tool_call_id"] == "call-2" - assert tool_result["tool_name"] == "query_memory" - assert tool_result["tool_args"] == {"query": "Alice"} - assert tool_result["detail"] is None - - -def test_runtime_build_tool_detail_panels_renders_reply_monitor_detail() -> None: - runtime = object.__new__(MaisakaHeartFlowChatting) - runtime.session_id = "session-1" - panels = runtime._build_tool_detail_cards( - [ - { - "tool_call_id": "call-reply-1", - "tool_name": "reply", - "tool_args": {"msg_id": "m1"}, - "success": True, - "duration_ms": 20.5, - "summary": "- reply [成功]: 已回复", - "detail": { - "prompt_text": "reply prompt", - "reasoning_text": "reply reasoning", - "output_text": "reply output", - "metrics": { - "model_name": "fake-model", - "prompt_tokens": 10, - "completion_tokens": 5, - "total_tokens": 15, - "prompt_ms": 2.1, - "llm_ms": 18.4, - "overall_ms": 20.5, - }, - }, - } - ], - stage_title="工具调用", - ) - - assert len(panels) == 1 - assert isinstance(panels[0], Panel) - - -def test_runtime_filter_redundant_tool_results_keeps_only_non_detailed_summary() -> None: - filtered_results = MaisakaHeartFlowChatting._filter_redundant_tool_results( - tool_results=[ - "- reply [成功]: 已回复", - "- query_memory [成功]: 查询到 2 条记录", - ], - tool_detail_results=[ - { - "summary": "- reply [成功]: 已回复", - "detail": {"output_text": "测试回复"}, - } - ], - ) - - assert filtered_results == ["- query_memory [成功]: 查询到 2 条记录"] - - -def test_runtime_build_tool_detail_panels_uses_prompt_access_panel(monkeypatch: pytest.MonkeyPatch) -> None: - runtime = object.__new__(MaisakaHeartFlowChatting) - runtime.session_id = "session-link" - captured: dict[str, Any] = {} - - def _fake_build_text_access_panel(content: str, **kwargs: Any) -> str: - captured["content"] = content - captured["kwargs"] = kwargs - return "PROMPT_LINK" - - monkeypatch.setattr( - "src.maisaka.runtime.PromptCLIVisualizer.build_text_access_panel", - _fake_build_text_access_panel, - ) - - panels = runtime._build_tool_detail_cards( - [ - { - "tool_call_id": "call-reply-2", - "tool_name": "reply", - "tool_args": {"msg_id": "m2"}, - "success": True, - "duration_ms": 12.0, - "summary": "- reply [成功]: 已回复", - "detail": { - "prompt_text": "reply prompt link", - "output_text": "reply output", - }, - } - ], - stage_title="工具调用", - ) - - assert len(panels) == 1 - assert captured["content"] == "reply prompt link" - assert captured["kwargs"]["chat_id"] == "session-link" - assert captured["kwargs"]["request_kind"] == "replyer" - - -def test_runtime_build_tool_detail_panels_uses_emotion_prompt_access_panel(monkeypatch: pytest.MonkeyPatch) -> None: - runtime = object.__new__(MaisakaHeartFlowChatting) - runtime.session_id = "session-emotion" - captured: dict[str, Any] = {} - - def _fake_build_text_access_panel(content: str, **kwargs: Any) -> str: - captured["content"] = content - captured["kwargs"] = kwargs - return "EMOTION_PROMPT_LINK" - - monkeypatch.setattr( - "src.maisaka.runtime.PromptCLIVisualizer.build_text_access_panel", - _fake_build_text_access_panel, - ) - - panels = runtime._build_tool_detail_cards( - [ - { - "tool_call_id": "call-emoji-1", - "tool_name": "send_emoji", - "tool_args": {"emotion": "开心"}, - "success": True, - "duration_ms": 15.0, - "summary": "- send_emoji [成功]: 已发送表情包", - "detail": { - "prompt_text": "emotion prompt link", - "output_text": '{"emoji_index": 1}', - }, - } - ], - stage_title="工具调用", - ) - - assert len(panels) == 1 - assert captured["content"] == "emotion prompt link" - assert captured["kwargs"]["chat_id"] == "session-emotion" - assert captured["kwargs"]["request_kind"] == "emotion" - - -def test_runtime_build_tool_detail_cards_uses_structured_prompt_messages_with_images( - monkeypatch: pytest.MonkeyPatch, -) -> None: - runtime = object.__new__(MaisakaHeartFlowChatting) - runtime.session_id = "session-image" - captured: dict[str, Any] = {} - - def _fake_build_prompt_access_panel(messages: list[Any], **kwargs: Any) -> str: - captured["messages"] = messages - captured["kwargs"] = kwargs - return "IMAGE_PROMPT_LINK" - - def _fake_build_text_access_panel(content: str, **kwargs: Any) -> str: - captured["text_content"] = content - captured["text_kwargs"] = kwargs - return "TEXT_PROMPT_LINK" - - monkeypatch.setattr( - "src.maisaka.runtime.PromptCLIVisualizer.build_prompt_access_panel", - _fake_build_prompt_access_panel, - ) - monkeypatch.setattr( - "src.maisaka.runtime.PromptCLIVisualizer.build_text_access_panel", - _fake_build_text_access_panel, - ) - - panels = runtime._build_tool_detail_cards( - [ - { - "tool_call_id": "call-reply-image-1", - "tool_name": "reply", - "tool_args": {"msg_id": "m3"}, - "success": True, - "duration_ms": 22.0, - "summary": "- reply [成功]: 已回复", - "detail": { - "prompt_text": "reply prompt image", - "request_messages": [ - { - "role": "user", - "content": ["前缀文本", ["png", "ZmFrZQ=="]], - } - ], - "output_text": "reply output", - }, - } - ], - stage_title="工具调用", - ) - - assert len(panels) == 1 - assert "messages" in captured - assert "text_content" not in captured - assert captured["kwargs"]["chat_id"] == "session-image" - assert captured["kwargs"]["request_kind"] == "replyer" - - -def test_runtime_render_context_usage_panel_merges_timing_and_planner(monkeypatch: pytest.MonkeyPatch) -> None: - runtime = object.__new__(MaisakaHeartFlowChatting) - runtime.session_id = "session-merged" - runtime.session_name = "测试聊天流" - runtime._max_context_size = 20 - - printed: list[Any] = [] - monkeypatch.setattr("src.maisaka.runtime.console.print", lambda renderable: printed.append(renderable)) - - runtime._render_context_usage_panel( - cycle_id=12, - timing_selected_history_count=3, - timing_prompt_tokens=15, - timing_action="continue", - timing_response="继续执行", - planner_selected_history_count=5, - planner_prompt_tokens=42, - planner_response="先查询再回复", - ) - - assert len(printed) == 1 - outer_panel = printed[0] - assert isinstance(outer_panel, Panel) - renderables = list(outer_panel.renderable.renderables) - assert isinstance(renderables[0], Text) - assert "聊天流名称:测试聊天流" in renderables[0].plain - assert "聊天流ID:session-merged" in renderables[0].plain - assert len(renderables) == 3 diff --git a/pytests/test_maisaka_timing_gate.py b/pytests/test_maisaka_timing_gate.py deleted file mode 100644 index 2722c3c4..00000000 --- a/pytests/test_maisaka_timing_gate.py +++ /dev/null @@ -1,339 +0,0 @@ -from datetime import datetime -from types import SimpleNamespace - -import asyncio -import pytest - -from src.core.tooling import ToolAvailabilityContext, ToolExecutionResult, ToolInvocation -from src.llm_models.payload_content.tool_option import ToolCall -from src.maisaka.builtin_tool import get_timing_tools -from src.maisaka.chat_loop_service import ChatResponse, MaisakaChatLoopService -from src.maisaka.context_messages import AssistantMessage, TIMING_GATE_INVALID_TOOL_HINT_SOURCE -from src.maisaka.reasoning_engine import MaisakaReasoningEngine -from src.maisaka.runtime import MaisakaHeartFlowChatting - - -def _build_chat_response(tool_calls: list[ToolCall]) -> ChatResponse: - return ChatResponse( - content="The model returned an invalid timing tool.", - tool_calls=tool_calls, - request_messages=[], - raw_message=AssistantMessage( - content="", - timestamp=datetime.now(), - source_kind="perception", - ), - selected_history_count=1, - tool_count=len(tool_calls), - prompt_tokens=10, - built_message_count=1, - completion_tokens=3, - total_tokens=13, - prompt_section=None, - ) - - -def _build_runtime_stub(*, is_group_chat: bool) -> SimpleNamespace: - return SimpleNamespace( - _force_next_timing_continue=False, - _chat_history=[], - session_id="test-session", - chat_stream=SimpleNamespace( - session_id="test-session", - stream_id="test-stream", - is_group_session=is_group_chat, - group_id="group-1" if is_group_chat else "", - user_id="user-1", - platform="qq", - ), - _chat_loop_service=SimpleNamespace(build_prompt_template_context=lambda: {}), - log_prefix="[test]", - stopped=False, - ) - - -def test_timing_gate_tools_expose_wait_only_in_private_chat() -> None: - private_tool_names = { - tool_definition["name"] - for tool_definition in get_timing_tools(ToolAvailabilityContext(is_group_chat=False)) - } - group_tool_names = { - tool_definition["name"] - for tool_definition in get_timing_tools(ToolAvailabilityContext(is_group_chat=True)) - } - - assert private_tool_names == {"continue", "no_reply", "wait"} - assert group_tool_names == {"continue", "no_reply"} - - -@pytest.mark.asyncio -async def test_timing_gate_invalid_tool_defaults_to_no_reply(monkeypatch: pytest.MonkeyPatch) -> None: - runtime = _build_runtime_stub(is_group_chat=True) - - def _enter_stop_state() -> None: - runtime.stopped = True - - runtime._enter_stop_state = _enter_stop_state - engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type] - - call_count = 0 - - async def _fake_timing_gate_sub_agent(**kwargs: object) -> ChatResponse: - nonlocal call_count - del kwargs - call_count += 1 - return _build_chat_response([ - ToolCall(call_id="invalid-timing-tool", func_name="finish", args={}), - ]) - - async def _fail_invoke_tool_call(*args: object, **kwargs: object) -> None: - del args, kwargs - raise AssertionError("invalid timing tools must not be executed") - - monkeypatch.setattr(engine, "_run_timing_gate_sub_agent", _fake_timing_gate_sub_agent) - monkeypatch.setattr(engine, "_invoke_tool_call", _fail_invoke_tool_call) - - action, response, tool_results, tool_monitor_results = await engine._run_timing_gate(object()) # type: ignore[arg-type] - - assert action == "no_reply" - assert call_count == 3 - assert response.tool_calls[0].func_name == "finish" - assert runtime.stopped is True - assert tool_monitor_results == [] - assert len(runtime._chat_history) == 1 - assert runtime._chat_history[0].source == TIMING_GATE_INVALID_TOOL_HINT_SOURCE - assert "finish" in runtime._chat_history[0].processed_plain_text - assert tool_results == [ - "- retry [非法 Timing 工具]: 返回了 finish,将重试 (1/3)", - "- retry [非法 Timing 工具]: 返回了 finish,将重试 (2/3)", - "- no_reply [非法 Timing 工具]: 返回了 finish,已停止本轮并等待新消息", - ] - - -@pytest.mark.asyncio -async def test_timing_gate_invalid_tool_retries_until_valid(monkeypatch: pytest.MonkeyPatch) -> None: - runtime = _build_runtime_stub(is_group_chat=True) - - def _enter_stop_state() -> None: - runtime.stopped = True - - runtime._enter_stop_state = _enter_stop_state - engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type] - responses = [ - _build_chat_response([ToolCall(call_id="invalid-timing-tool", func_name="finish", args={})]), - _build_chat_response([ToolCall(call_id="valid-timing-tool", func_name="continue", args={})]), - ] - - async def _fake_timing_gate_sub_agent(**kwargs: object) -> ChatResponse: - del kwargs - return responses.pop(0) - - async def _fake_invoke_tool_call( - tool_call: ToolCall, - latest_thought: str, - anchor_message: object, - *, - append_history: bool = True, - store_record: bool = True, - ) -> tuple[ToolInvocation, ToolExecutionResult, None]: - del latest_thought, anchor_message, append_history, store_record - return ( - ToolInvocation(tool_name=tool_call.func_name, call_id=tool_call.call_id), - ToolExecutionResult( - tool_name=tool_call.func_name, - success=True, - content="继续执行主流程", - metadata={"timing_action": "continue"}, - ), - None, - ) - - monkeypatch.setattr(engine, "_run_timing_gate_sub_agent", _fake_timing_gate_sub_agent) - monkeypatch.setattr(engine, "_invoke_tool_call", _fake_invoke_tool_call) - - action, response, tool_results, tool_monitor_results = await engine._run_timing_gate(object()) # type: ignore[arg-type] - - assert action == "continue" - assert response.tool_calls[0].func_name == "continue" - assert runtime.stopped is False - assert len(runtime._chat_history) == 2 - assert all(message.source != TIMING_GATE_INVALID_TOOL_HINT_SOURCE for message in runtime._chat_history) - assert tool_results == [ - "- retry [非法 Timing 工具]: 返回了 finish,将重试 (1/3)", - "- continue [成功]: 继续执行主流程", - ] - assert tool_monitor_results[0]["tool_name"] == "continue" - - -@pytest.mark.asyncio -async def test_timing_gate_group_chat_treats_wait_as_invalid(monkeypatch: pytest.MonkeyPatch) -> None: - runtime = _build_runtime_stub(is_group_chat=True) - - def _enter_stop_state() -> None: - runtime.stopped = True - - runtime._enter_stop_state = _enter_stop_state - engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type] - - async def _fake_timing_gate_sub_agent(**kwargs: object) -> ChatResponse: - tool_definitions = kwargs["tool_definitions"] - assert {tool_definition["name"] for tool_definition in tool_definitions} == {"continue", "no_reply"} - return _build_chat_response([ - ToolCall(call_id="disabled-wait", func_name="wait", args={"seconds": 3}), - ]) - - async def _fail_invoke_tool_call(*args: object, **kwargs: object) -> None: - del args, kwargs - raise AssertionError("群聊中禁用的 wait 不应被执行") - - monkeypatch.setattr(engine, "_run_timing_gate_sub_agent", _fake_timing_gate_sub_agent) - monkeypatch.setattr(engine, "_invoke_tool_call", _fail_invoke_tool_call) - - action, _, tool_results, _ = await engine._run_timing_gate(object()) # type: ignore[arg-type] - - assert action == "no_reply" - assert runtime.stopped is True - assert tool_results[-1] == "- no_reply [非法 Timing 工具]: 返回了 wait,已停止本轮并等待新消息" - - -def test_timing_gate_invalid_tool_hint_keeps_only_latest() -> None: - old_hint = SimpleNamespace(source=TIMING_GATE_INVALID_TOOL_HINT_SOURCE) - runtime = SimpleNamespace(_chat_history=[old_hint]) - engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type] - - engine._append_timing_gate_invalid_tool_hint("finish") - engine._append_timing_gate_invalid_tool_hint("reply") - - assert len(runtime._chat_history) == 1 - hint_message = runtime._chat_history[0] - assert hint_message.source == TIMING_GATE_INVALID_TOOL_HINT_SOURCE - assert "reply" in hint_message.processed_plain_text - assert "finish" not in hint_message.processed_plain_text - - -def test_timing_gate_invalid_tool_hint_only_visible_to_timing_gate() -> None: - runtime = SimpleNamespace(_chat_history=[]) - engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type] - engine._append_timing_gate_invalid_tool_hint("finish") - hint_message = runtime._chat_history[0] - - timing_history = MaisakaChatLoopService._filter_history_for_request_kind( - [hint_message], - request_kind="timing_gate", - ) - planner_history = MaisakaChatLoopService._filter_history_for_request_kind( - [hint_message], - request_kind="planner", - ) - - assert timing_history == [hint_message] - assert planner_history == [] - - -def test_forced_timing_trigger_bypasses_message_frequency_threshold() -> None: - runtime = SimpleNamespace( - _STATE_WAIT="wait", - _agent_state="stop", - _message_turn_scheduled=False, - _internal_turn_queue=asyncio.Queue(), - _has_pending_messages=lambda: True, - _get_pending_message_count=lambda: 1, - _has_forced_timing_trigger=lambda: True, - _cancel_deferred_message_turn_task=lambda: None, - ) - - def _fail_get_message_trigger_threshold() -> int: - raise AssertionError("@/提及必回不应被普通聊天频率阈值拦住") - - runtime._get_message_trigger_threshold = _fail_get_message_trigger_threshold - - MaisakaHeartFlowChatting._schedule_message_turn(runtime) # type: ignore[arg-type] - - assert runtime._message_turn_scheduled is True - assert runtime._internal_turn_queue.get_nowait() == "message" - - -def test_finish_tool_is_not_written_back_to_history() -> None: - finish_call = ToolCall(call_id="finish-call", func_name="finish", args={}) - reply_call = ToolCall(call_id="reply-call", func_name="reply", args={}) - assistant_message = AssistantMessage( - content="当前不需要继续回复。", - timestamp=datetime.now(), - tool_calls=[finish_call, reply_call], - ) - runtime = SimpleNamespace(_chat_history=[assistant_message]) - engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type] - - engine._append_tool_execution_result( - finish_call, - ToolExecutionResult( - tool_name="finish", - success=True, - content="当前对话循环已结束本轮思考,等待新的消息到来。", - ), - ) - - assert runtime._chat_history == [assistant_message] - assert [tool_call.func_name for tool_call in assistant_message.tool_calls] == ["reply"] - - -def test_finish_tool_removes_empty_assistant_history_message() -> None: - finish_call = ToolCall(call_id="finish-call", func_name="finish", args={}) - assistant_message = AssistantMessage( - content="", - timestamp=datetime.now(), - tool_calls=[finish_call], - ) - runtime = SimpleNamespace(_chat_history=[assistant_message]) - engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type] - - engine._append_tool_execution_result( - finish_call, - ToolExecutionResult(tool_name="finish", success=True), - ) - - assert runtime._chat_history == [] - - -def test_timing_gate_head_trim_keeps_short_history() -> None: - messages = [ - AssistantMessage(content="第一条消息", timestamp=datetime.now()), - AssistantMessage(content="第二条消息", timestamp=datetime.now()), - ] - - trimmed_messages = MaisakaHeartFlowChatting._drop_head_context_messages( - messages, - drop_context_count=3, - ) - - assert trimmed_messages == messages - - -def test_timing_gate_head_trim_keeps_history_within_config_limit() -> None: - messages = [ - AssistantMessage(content=f"消息 {index}", timestamp=datetime.now()) - for index in range(10) - ] - - trimmed_messages = MaisakaHeartFlowChatting._drop_head_context_messages( - messages, - drop_context_count=7, - trim_threshold_context_count=10, - ) - - assert trimmed_messages == messages - - -def test_timing_gate_head_trim_applies_after_config_limit_exceeded() -> None: - messages = [ - AssistantMessage(content=f"消息 {index}", timestamp=datetime.now()) - for index in range(11) - ] - - trimmed_messages = MaisakaHeartFlowChatting._drop_head_context_messages( - messages, - drop_context_count=7, - trim_threshold_context_count=10, - ) - - assert trimmed_messages == messages[7:] diff --git a/pytests/test_message_gateway_runtime.py b/pytests/test_message_gateway_runtime.py deleted file mode 100644 index 9650bc10..00000000 --- a/pytests/test_message_gateway_runtime.py +++ /dev/null @@ -1,170 +0,0 @@ -"""消息网关运行时状态同步测试。""" - -from typing import Any, Dict - -import pytest - -from src.platform_io.manager import PlatformIOManager -from src.platform_io.types import RouteKey -from src.plugin_runtime.host.supervisor import PluginSupervisor -from src.plugin_runtime.protocol.envelope import Envelope, MessageType - - -def _make_request(method: str, plugin_id: str, payload: Dict[str, Any]) -> Envelope: - """构造一个 RPC 请求信封。 - - Args: - method: RPC 方法名。 - plugin_id: 目标插件 ID。 - payload: 请求载荷。 - - Returns: - Envelope: 标准 RPC 请求信封。 - """ - - return Envelope( - request_id=1, - message_type=MessageType.REQUEST, - method=method, - plugin_id=plugin_id, - payload=payload, - ) - - -@pytest.mark.asyncio -async def test_message_gateway_runtime_state_binds_send_and_receive_routes( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """消息网关就绪后应同时绑定发送表和接收表。""" - - import src.plugin_runtime.host.supervisor as supervisor_module - - platform_io_manager = PlatformIOManager() - monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager) - - supervisor = PluginSupervisor(plugin_dirs=[]) - register_response = await supervisor._handle_register_plugin( - _make_request( - "plugin.register_components", - "napcat_plugin", - { - "plugin_id": "napcat_plugin", - "plugin_version": "1.0.0", - "components": [ - { - "name": "napcat_gateway", - "component_type": "MESSAGE_GATEWAY", - "plugin_id": "napcat_plugin", - "metadata": { - "route_type": "duplex", - "platform": "qq", - "protocol": "napcat", - }, - } - ], - "capabilities_required": [], - }, - ) - ) - - assert register_response.error is None - response = await supervisor._handle_update_message_gateway_state( - _make_request( - "host.update_message_gateway_state", - "napcat_plugin", - { - "gateway_name": "napcat_gateway", - "ready": True, - "platform": "qq", - "account_id": "10001", - "scope": "primary", - "metadata": {}, - }, - ) - ) - - assert response.error is None - assert response.payload["accepted"] is True - - send_bindings = platform_io_manager.send_route_table.resolve_bindings( - RouteKey(platform="qq", account_id="10001", scope="primary") - ) - receive_bindings = platform_io_manager.receive_route_table.resolve_bindings( - RouteKey(platform="qq", account_id="10001", scope="primary") - ) - - assert [binding.driver_id for binding in send_bindings] == ["gateway:napcat_plugin:napcat_gateway"] - assert [binding.driver_id for binding in receive_bindings] == ["gateway:napcat_plugin:napcat_gateway"] - - -@pytest.mark.asyncio -async def test_message_gateway_runtime_state_unbinds_routes_when_not_ready( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """消息网关断开后应撤销发送表和接收表中的绑定。""" - - import src.plugin_runtime.host.supervisor as supervisor_module - - platform_io_manager = PlatformIOManager() - monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager) - - supervisor = PluginSupervisor(plugin_dirs=[]) - await supervisor._handle_register_plugin( - _make_request( - "plugin.register_components", - "napcat_plugin", - { - "plugin_id": "napcat_plugin", - "plugin_version": "1.0.0", - "components": [ - { - "name": "napcat_gateway", - "component_type": "MESSAGE_GATEWAY", - "plugin_id": "napcat_plugin", - "metadata": { - "route_type": "duplex", - "platform": "qq", - "protocol": "napcat", - }, - } - ], - "capabilities_required": [], - }, - ) - ) - - await supervisor._handle_update_message_gateway_state( - _make_request( - "host.update_message_gateway_state", - "napcat_plugin", - { - "gateway_name": "napcat_gateway", - "ready": True, - "platform": "qq", - "account_id": "10001", - "scope": "primary", - "metadata": {}, - }, - ) - ) - response = await supervisor._handle_update_message_gateway_state( - _make_request( - "host.update_message_gateway_state", - "napcat_plugin", - { - "gateway_name": "napcat_gateway", - "ready": False, - "platform": "qq", - "account_id": "", - "scope": "", - "metadata": {}, - }, - ) - ) - - assert response.error is None - assert response.payload["accepted"] is True - assert platform_io_manager.send_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == [] - assert ( - platform_io_manager.receive_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == [] - ) diff --git a/pytests/test_napcat_adapter_sdk.py b/pytests/test_napcat_adapter_sdk.py deleted file mode 100644 index f53408b9..00000000 --- a/pytests/test_napcat_adapter_sdk.py +++ /dev/null @@ -1,883 +0,0 @@ -"""NapCat 插件与新 SDK 对接测试。""" - -from importlib import import_module, util -from pathlib import Path -from typing import Any, Dict, List, Tuple - -import logging -import sys -from types import SimpleNamespace - -import pytest - -PROJECT_ROOT = Path(__file__).resolve().parents[1] -PLUGINS_ROOT = PROJECT_ROOT / "plugins" -PLUGIN_TEMPLATE_ROOT = PROJECT_ROOT / "plugin-templates" -SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk" -NAPCAT_PLUGIN_DIR = PLUGINS_ROOT / "MaiBot-Napcat-Adapter" -NAPCAT_TEMPLATE_DIR = PLUGIN_TEMPLATE_ROOT / "MaiBot-Napcat-Adapter" -NAPCAT_TEST_MODULE = "_test_napcat_adapter" - -for import_path in (str(SDK_ROOT),): - if import_path not in sys.path: - sys.path.insert(0, import_path) - - -class _FakeGatewayCapability: - """用于捕获消息网关状态上报的测试替身。""" - - def __init__(self) -> None: - """初始化测试替身。""" - - self.calls: List[Dict[str, Any]] = [] - - async def update_state( - self, - gateway_name: str, - *, - ready: bool, - platform: str = "", - account_id: str = "", - scope: str = "", - metadata: Dict[str, Any] | None = None, - ) -> bool: - """记录一次状态上报请求。 - - Args: - gateway_name: 网关组件名称。 - ready: 当前是否就绪。 - platform: 平台名称。 - account_id: 账号 ID。 - scope: 路由作用域。 - metadata: 附加元数据。 - - Returns: - bool: 始终返回 ``True``,模拟 Host 接受状态更新。 - """ - - self.calls.append( - { - "gateway_name": gateway_name, - "ready": ready, - "platform": platform, - "account_id": account_id, - "scope": scope, - "metadata": metadata or {}, - } - ) - return True - - -class _FakeNapCatQueryService: - """用于驱动 NapCat 入站编解码测试的查询服务替身。""" - - def __init__( - self, - forward_payloads: Dict[str, Any] | None = None, - group_member_payloads: Dict[tuple[str, str], Dict[str, Any] | None] | None = None, - stranger_payloads: Dict[str, Dict[str, Any] | None] | None = None, - ) -> None: - """初始化查询服务替身。 - - Args: - forward_payloads: 预置的合并转发响应映射。 - group_member_payloads: 预置的群成员资料映射。 - stranger_payloads: 预置的陌生人资料映射。 - """ - self._forward_payloads = forward_payloads or {} - self._group_member_payloads = group_member_payloads or {} - self._stranger_payloads = stranger_payloads or {} - - async def download_binary(self, url: str) -> bytes | None: - """模拟下载远程二进制资源。 - - Args: - url: 资源地址。 - - Returns: - bytes | None: 测试中默认不返回二进制内容。 - """ - del url - return None - - async def get_message_detail(self, message_id: str) -> Dict[str, Any] | None: - """模拟获取消息详情。 - - Args: - message_id: 消息 ID。 - - Returns: - Dict[str, Any] | None: 测试中默认不返回详情。 - """ - del message_id - return None - - async def get_forward_message( - self, - message_id: str | None = None, - forward_id: str | None = None, - ) -> Any: - """模拟获取合并转发消息详情。 - - Args: - message_id: 转发消息 ID。 - forward_id: 兼容字段 ``id``。 - - Returns: - Any: 预置的合并转发消息详情。 - """ - return self._forward_payloads.get(forward_id or message_id or "") - - async def get_group_member_info( - self, - group_id: str, - user_id: str, - no_cache: bool = True, - ) -> Dict[str, Any] | None: - """模拟获取群成员资料。""" - del no_cache - return self._group_member_payloads.get((group_id, user_id)) - - async def get_stranger_info(self, user_id: str, no_cache: bool = False) -> Dict[str, Any] | None: - """模拟获取 QQ 昵称资料。""" - del no_cache - return self._stranger_payloads.get(user_id) - - async def get_record_detail( - self, - file_name: str | None = None, - file_id: str | None = None, - out_format: str = "wav", - ) -> Dict[str, Any] | None: - """模拟获取语音详情。 - - Args: - file_name: 文件名。 - file_id: 文件 ID。 - out_format: 输出格式。 - - Returns: - Dict[str, Any] | None: 测试中默认不返回语音详情。 - """ - del file_name - del file_id - del out_format - return None - - -class _FakeNapCatActionService: - """用于驱动 NapCat 查询服务测试的动作服务替身。""" - - def __init__(self, response_data: Any) -> None: - """初始化动作服务替身。 - - Args: - response_data: 预置的 ``safe_call_action_data`` 返回值。 - """ - self._response_data = response_data - self.action_calls: List[Dict[str, Any]] = [] - self.action_data_calls: List[Dict[str, Any]] = [] - - async def safe_call_action_data(self, action_name: str, params: Dict[str, Any]) -> Any: - """模拟安全调用 OneBot 动作。 - - Args: - action_name: 动作名称。 - params: 动作参数。 - - Returns: - Any: 预置返回值。 - """ - self.action_data_calls.append({"action_name": action_name, "params": dict(params)}) - return self._response_data - - async def call_action(self, action_name: str, params: Dict[str, Any]) -> Dict[str, Any]: - """模拟调用 OneBot 动作并记录参数。""" - - self.action_calls.append({"action_name": action_name, "params": dict(params)}) - return {"status": "ok", "retcode": 0, "data": {}} - - -def _load_napcat_sdk_modules() -> Tuple[Any, Any, Any, Any]: - """动态加载 NapCat 插件测试所需的模块。 - - Returns: - tuple[Any, Any, Any, Any]: - 依次返回常量模块、配置模块、插件模块和运行时状态模块。 - """ - - plugin_dir = NAPCAT_PLUGIN_DIR if NAPCAT_PLUGIN_DIR.is_dir() else NAPCAT_TEMPLATE_DIR - - if NAPCAT_TEST_MODULE not in sys.modules: - plugin_path = plugin_dir / "plugin.py" - spec = util.spec_from_file_location( - NAPCAT_TEST_MODULE, - plugin_path, - submodule_search_locations=[str(plugin_dir)], - ) - if spec is None or spec.loader is None: - raise ImportError(f"无法为 NapCat 插件创建模块规格: {plugin_path}") - - module = util.module_from_spec(spec) - sys.modules[NAPCAT_TEST_MODULE] = module - try: - spec.loader.exec_module(module) - except Exception: - sys.modules.pop(NAPCAT_TEST_MODULE, None) - raise - - return ( - import_module(f"{NAPCAT_TEST_MODULE}.constants"), - import_module(f"{NAPCAT_TEST_MODULE}.config"), - import_module(f"{NAPCAT_TEST_MODULE}.plugin"), - import_module(f"{NAPCAT_TEST_MODULE}.runtime_state"), - ) - - -def _load_napcat_sdk_symbols() -> Tuple[Any, Any, Any, Any]: - """动态加载 NapCat 插件测试所需的符号。 - - Returns: - tuple[Any, Any, Any, Any]: - 依次返回网关名常量、配置类、插件类和运行时状态管理器类。 - """ - - constants_module, config_module, plugin_module, runtime_state_module = _load_napcat_sdk_modules() - return ( - constants_module.NAPCAT_GATEWAY_NAME, - config_module.NapCatServerConfig, - plugin_module.NapCatAdapterPlugin, - runtime_state_module.NapCatRuntimeStateManager, - ) - - -def _load_napcat_inbound_codec_cls() -> Any: - """动态加载 NapCat 入站编解码器类。 - - Returns: - Any: ``NapCatInboundCodec`` 类对象。 - """ - _load_napcat_sdk_modules() - codec_module = import_module(f"{NAPCAT_TEST_MODULE}.codecs.inbound.message_codec") - return codec_module.NapCatInboundCodec - - -def _load_napcat_query_service_cls() -> Any: - """动态加载 NapCat 查询服务类。 - - Returns: - Any: ``NapCatQueryService`` 类对象。 - """ - _load_napcat_sdk_modules() - query_service_module = import_module(f"{NAPCAT_TEST_MODULE}.services.query_service") - return query_service_module.NapCatQueryService - - -def test_napcat_plugin_collects_duplex_message_gateway() -> None: - """NapCat 插件应声明新的双工消息网关组件。""" - - napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols() - plugin = napcat_plugin_cls() - components = plugin.get_components() - gateway_components = [ - component - for component in components - if component.get("type") == "MESSAGE_GATEWAY" - ] - - assert len(gateway_components) == 1 - gateway_component = gateway_components[0] - assert gateway_component["name"] == napcat_gateway_name - assert gateway_component["metadata"]["route_type"] == "duplex" - assert gateway_component["metadata"]["platform"] == "qq" - assert gateway_component["metadata"]["protocol"] == "napcat" - - -def test_napcat_plugin_uses_sdk_config_model() -> None: - """NapCat 插件应声明 SDK 配置模型并暴露默认配置与 Schema。""" - - constants_module, _config_module, plugin_module, _runtime_state_module = _load_napcat_sdk_modules() - plugin = plugin_module.NapCatAdapterPlugin() - - default_config = plugin.get_default_config() - schema = plugin.get_webui_config_schema(plugin_id="maibot-team.napcat-adapter") - - assert default_config["plugin"]["config_version"] == constants_module.SUPPORTED_CONFIG_VERSION - assert default_config["chat"]["ban_qq_bot"] is False - assert default_config["filters"]["ignore_self_message"] is True - assert schema["plugin_id"] == "maibot-team.napcat-adapter" - assert schema["sections"]["chat"]["fields"]["group_list"]["type"] == "array" - assert schema["sections"]["chat"]["fields"]["group_list_type"]["choices"] == ["whitelist", "blacklist"] - - -def test_napcat_plugin_normalizes_legacy_config_values() -> None: - """NapCat 插件应兼容旧配置字段并输出规范化结果。""" - - constants_module, _config_module, plugin_module, _runtime_state_module = _load_napcat_sdk_modules() - plugin = plugin_module.NapCatAdapterPlugin() - - plugin.set_plugin_config( - { - "plugin": {"enabled": True, "config_version": constants_module.SUPPORTED_CONFIG_VERSION}, - "connection": { - "access_token": "secret-token", - "heartbeat_sec": "45", - "ws_url": "ws://10.0.0.8:3012/onebot/v11/ws", - }, - "chat": { - "ban_qq_bot": True, - "ban_user_id": ["42", 42, ""], - "group_list": [123, " 456 ", None, "123"], - "group_list_type": "whitelist", - "private_list": "invalid", - "private_list_type": "unexpected", - }, - "filters": {"ignore_self_message": True}, - } - ) - - config_data = plugin.get_plugin_config_data() - - assert "connection" not in config_data - assert config_data["plugin"]["config_version"] == constants_module.SUPPORTED_CONFIG_VERSION - assert config_data["napcat_server"]["host"] == "10.0.0.8" - assert config_data["napcat_server"]["port"] == 3012 - assert config_data["napcat_server"]["token"] == "secret-token" - assert config_data["napcat_server"]["heartbeat_interval"] == 45.0 - assert config_data["chat"]["group_list"] == ["123", "456"] - assert config_data["chat"]["private_list"] == [] - assert config_data["chat"]["private_list_type"] == constants_module.DEFAULT_CHAT_LIST_TYPE - assert plugin.config.napcat_server.build_ws_url() == "ws://10.0.0.8:3012" - - -@pytest.mark.asyncio -async def test_runtime_state_reports_via_gateway_capability() -> None: - """NapCat 运行时状态应通过新的消息网关能力上报。""" - - napcat_gateway_name, napcat_server_config_cls, _napcat_plugin_cls, runtime_state_cls = _load_napcat_sdk_symbols() - gateway_capability = _FakeGatewayCapability() - runtime_state_manager = runtime_state_cls( - gateway_capability=gateway_capability, - logger=logging.getLogger("test.napcat_adapter"), - gateway_name=napcat_gateway_name, - ) - - connected = await runtime_state_manager.report_connected( - "10001", - napcat_server_config_cls(connection_id="primary"), - ) - await runtime_state_manager.report_disconnected() - - assert connected is True - assert gateway_capability.calls[0]["gateway_name"] == napcat_gateway_name - assert gateway_capability.calls[0]["ready"] is True - assert gateway_capability.calls[0]["platform"] == "qq" - assert gateway_capability.calls[0]["account_id"] == "10001" - assert gateway_capability.calls[0]["scope"] == "primary" - assert gateway_capability.calls[1]["gateway_name"] == napcat_gateway_name - assert gateway_capability.calls[1]["ready"] is False - assert gateway_capability.calls[1]["platform"] == "qq" - - -@pytest.mark.asyncio -async def test_napcat_plugin_send_result_contains_message_id_echo_callback() -> None: - """NapCat 插件发送成功后应显式返回消息 ID 回调数据。""" - - _napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols() - plugin = napcat_plugin_cls() - - class _FakeOutboundCodec: - """用于测试的出站编码器替身。""" - - @staticmethod - def build_outbound_action(message: Dict[str, Any], route: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: - """返回固定动作与参数。""" - - del message - del route - return "send_msg", {"message": "hello"} - - class _FakeTransport: - """用于测试的传输层替身。""" - - @staticmethod - async def call_action(action_name: str, params: Dict[str, Any]) -> Dict[str, Any]: - """返回带平台消息 ID 的成功响应。""" - - del action_name - del params - return { - "status": "ok", - "data": { - "message_id": "platform-message-id", - }, - } - - plugin._require_runtime_bundle = lambda: SimpleNamespace( # type: ignore[method-assign] - outbound_codec=_FakeOutboundCodec(), - transport=_FakeTransport(), - ) - - result = await plugin.handle_napcat_gateway( - message={"message_id": "internal-message-id"}, - route={}, - ) - - assert result["success"] is True - assert result["external_message_id"] == "platform-message-id" - assert result["metadata"]["adapter_callbacks"] == [ - { - "name": "message_id_echo", - "payload": { - "content": { - "type": "echo", - "echo": "internal-message-id", - "actual_id": "platform-message-id", - } - }, - } - ] - - -@pytest.mark.asyncio -async def test_inbound_codec_parses_forward_nodes_from_legacy_message_field() -> None: - """入站编解码器应兼容旧版 ``sender + message`` 转发节点结构。""" - - inbound_codec_cls = _load_napcat_inbound_codec_cls() - codec = inbound_codec_cls( - logger=logging.getLogger("test.napcat_adapter.forward_legacy"), - query_service=_FakeNapCatQueryService( - forward_payloads={ - "forward-1": { - "messages": [ - { - "sender": {"user_id": "10001", "nickname": "张三", "card": "群名片"}, - "message_id": "node-1", - "message": [{"type": "text", "data": {"text": "第一条转发"}}], - } - ] - } - } - ), - ) - - segments, is_at = await codec.convert_segments( - {"message": [{"type": "forward", "data": {"id": "forward-1"}}]}, - "", - ) - - assert is_at is False - assert len(segments) == 1 - assert segments[0]["type"] == "forward" - assert segments[0]["data"][0]["user_id"] == "10001" - assert segments[0]["data"][0]["user_nickname"] == "张三" - assert segments[0]["data"][0]["user_cardname"] == "群名片" - assert segments[0]["data"][0]["content"] == [{"type": "text", "data": "第一条转发"}] - - -@pytest.mark.asyncio -async def test_inbound_codec_parses_nested_inline_forward_content() -> None: - """入站编解码器应支持内联 ``content`` 形式的嵌套合并转发。""" - - inbound_codec_cls = _load_napcat_inbound_codec_cls() - codec = inbound_codec_cls( - logger=logging.getLogger("test.napcat_adapter.forward_nested"), - query_service=_FakeNapCatQueryService( - forward_payloads={ - "forward-outer": { - "messages": [ - { - "sender": {"user_id": "10001", "nickname": "张三"}, - "message_id": "node-outer", - "message": [ - { - "type": "forward", - "data": { - "content": [ - { - "sender": {"user_id": "10002", "nickname": "李四"}, - "message_id": "node-inner", - "message": [{"type": "text", "data": {"text": "内层消息"}}], - } - ] - }, - } - ], - } - ] - } - } - ), - ) - - segments, _ = await codec.convert_segments( - {"message": [{"type": "forward", "data": {"id": "forward-outer"}}]}, - "", - ) - - assert len(segments) == 1 - assert segments[0]["type"] == "forward" - outer_content = segments[0]["data"][0]["content"] - assert len(outer_content) == 1 - assert outer_content[0]["type"] == "forward" - nested_nodes = outer_content[0]["data"] - assert nested_nodes[0]["user_id"] == "10002" - assert nested_nodes[0]["user_nickname"] == "李四" - assert nested_nodes[0]["content"] == [{"type": "text", "data": "内层消息"}] - - -@pytest.mark.asyncio -async def test_inbound_codec_resolves_at_to_group_cardname() -> None: - """入站编解码器应优先将 ``at`` 解析为群昵称。""" - - inbound_codec_cls = _load_napcat_inbound_codec_cls() - codec = inbound_codec_cls( - logger=logging.getLogger("test.napcat_adapter.at_cardname"), - query_service=_FakeNapCatQueryService( - group_member_payloads={ - ("12345", "1206069534"): { - "nickname": "QQ昵称", - "card": "群昵称", - } - } - ), - ) - - message_dict = await codec.build_message_dict( - payload={ - "message_type": "group", - "group_id": "12345", - "message_id": "msg-1", - "message": [{"type": "at", "data": {"qq": "1206069534"}}], - "sender": {"user_id": "10001", "nickname": "发送者"}, - "time": 1710000000, - }, - self_id="20001", - sender_user_id="10001", - sender={"user_id": "10001", "nickname": "发送者"}, - ) - - assert message_dict["processed_plain_text"] == "@群昵称" - assert message_dict["raw_message"] == [ - { - "type": "at", - "data": { - "target_user_id": "1206069534", - "target_user_nickname": "QQ昵称", - "target_user_cardname": "群昵称", - }, - } - ] - - -@pytest.mark.asyncio -async def test_inbound_codec_falls_back_to_qq_nickname_when_group_cardname_is_empty() -> None: - """入站编解码器在群昵称为空时应回退到 QQ 昵称。""" - - inbound_codec_cls = _load_napcat_inbound_codec_cls() - codec = inbound_codec_cls( - logger=logging.getLogger("test.napcat_adapter.at_nickname"), - query_service=_FakeNapCatQueryService( - group_member_payloads={ - ("12345", "1206069534"): { - "nickname": "QQ昵称", - "card": "", - } - } - ), - ) - - message_dict = await codec.build_message_dict( - payload={ - "message_type": "group", - "group_id": "12345", - "message_id": "msg-2", - "message": [{"type": "at", "data": {"qq": "1206069534"}}], - "sender": {"user_id": "10001", "nickname": "发送者"}, - "time": 1710000000, - }, - self_id="20001", - sender_user_id="10001", - sender={"user_id": "10001", "nickname": "发送者"}, - ) - - assert message_dict["processed_plain_text"] == "@QQ昵称" - assert message_dict["raw_message"] == [ - { - "type": "at", - "data": { - "target_user_id": "1206069534", - "target_user_nickname": "QQ昵称", - "target_user_cardname": None, - }, - } - ] - - -@pytest.mark.asyncio -async def test_inbound_codec_falls_back_to_stranger_nickname_when_group_profile_is_missing() -> None: - """入站编解码器在群资料缺失时应继续回退到 QQ 昵称。""" - - inbound_codec_cls = _load_napcat_inbound_codec_cls() - codec = inbound_codec_cls( - logger=logging.getLogger("test.napcat_adapter.at_stranger_nickname"), - query_service=_FakeNapCatQueryService( - group_member_payloads={("12345", "1206069534"): None}, - stranger_payloads={"1206069534": {"nickname": "QQ昵称"}}, - ), - ) - - message_dict = await codec.build_message_dict( - payload={ - "message_type": "group", - "group_id": "12345", - "message_id": "msg-3", - "message": [{"type": "at", "data": {"qq": "1206069534"}}], - "sender": {"user_id": "10001", "nickname": "发送者"}, - "time": 1710000000, - }, - self_id="20001", - sender_user_id="10001", - sender={"user_id": "10001", "nickname": "发送者"}, - ) - - assert message_dict["processed_plain_text"] == "@QQ昵称" - assert message_dict["raw_message"] == [ - { - "type": "at", - "data": { - "target_user_id": "1206069534", - "target_user_nickname": "QQ昵称", - "target_user_cardname": None, - }, - } - ] - - -@pytest.mark.asyncio -async def test_query_service_normalizes_forward_payload_list() -> None: - """查询服务应兼容 ``get_forward_msg`` 直接返回节点列表。""" - - query_service_cls = _load_napcat_query_service_cls() - query_service = query_service_cls( - action_service=_FakeNapCatActionService( - [ - { - "sender": {"user_id": "10001", "nickname": "张三"}, - "message_id": "node-1", - "message": [{"type": "text", "data": {"text": "列表返回"}}], - } - ] - ), - logger=logging.getLogger("test.napcat_adapter.query_service"), - ) - - forward_payload = await query_service.get_forward_message("forward-1") - - assert forward_payload == { - "messages": [ - { - "sender": {"user_id": "10001", "nickname": "张三"}, - "message_id": "node-1", - "message": [{"type": "text", "data": {"text": "列表返回"}}], - } - ] - } - - -@pytest.mark.asyncio -async def test_query_service_supports_official_no_cache_for_get_stranger_info() -> None: - """查询服务应按官方字段下发 ``no_cache``。""" - - action_service = _FakeNapCatActionService({"nickname": "测试用户"}) - query_service_cls = _load_napcat_query_service_cls() - query_service = query_service_cls( - action_service=action_service, - logger=logging.getLogger("test.napcat_adapter.query_service.stranger"), - ) - - payload = await query_service.get_stranger_info("10001", no_cache=True) - - assert payload == {"nickname": "测试用户"} - assert action_service.action_data_calls[-1] == { - "action_name": "get_stranger_info", - "params": {"user_id": "10001", "no_cache": True}, - } - - -@pytest.mark.asyncio -async def test_query_service_supports_official_forward_id_alias() -> None: - """查询服务应兼容官方 ``id`` 字段调用 ``get_forward_msg``。""" - - action_service = _FakeNapCatActionService({"messages": []}) - query_service_cls = _load_napcat_query_service_cls() - query_service = query_service_cls( - action_service=action_service, - logger=logging.getLogger("test.napcat_adapter.query_service.forward_alias"), - ) - - payload = await query_service.get_forward_message(forward_id="forward-alias") - - assert payload == {"messages": []} - assert action_service.action_data_calls[-1] == { - "action_name": "get_forward_msg", - "params": {"id": "forward-alias"}, - } - - -@pytest.mark.asyncio -async def test_query_service_supports_custom_out_format_for_get_record() -> None: - """查询服务应按官方字段下发自定义 ``out_format``。""" - - action_service = _FakeNapCatActionService({"file": "voice.mp3"}) - query_service_cls = _load_napcat_query_service_cls() - query_service = query_service_cls( - action_service=action_service, - logger=logging.getLogger("test.napcat_adapter.query_service.record"), - ) - - payload = await query_service.get_record_detail(file_id="record-1", out_format="mp3") - - assert payload == {"file": "voice.mp3"} - assert action_service.action_data_calls[-1] == { - "action_name": "get_record", - "params": {"file_id": "record-1", "out_format": "mp3"}, - } - - -@pytest.mark.asyncio -async def test_query_service_supports_target_id_for_send_poke() -> None: - """查询服务应按官方字段下发 ``target_id``。""" - - action_service = _FakeNapCatActionService(None) - query_service_cls = _load_napcat_query_service_cls() - query_service = query_service_cls( - action_service=action_service, - logger=logging.getLogger("test.napcat_adapter.query_service.poke"), - ) - - response = await query_service.send_poke(user_id=10001, group_id=20002, target_id=30003) - - assert response["status"] == "ok" - assert action_service.action_calls[-1] == { - "action_name": "send_poke", - "params": {"user_id": 10001, "group_id": 20002, "target_id": 30003}, - } - - -@pytest.mark.asyncio -async def test_public_api_send_poke_supports_official_fields_and_legacy_alias() -> None: - """公开 API 应同时兼容官方字段和旧版 ``qq_id`` 别名。""" - - _napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols() - plugin = napcat_plugin_cls() - captured: List[Dict[str, Any]] = [] - - class _SpyQueryService: - async def send_poke( - self, - user_id: int, - group_id: int | None = None, - target_id: int | None = None, - ) -> Dict[str, Any]: - captured.append( - { - "user_id": user_id, - "group_id": group_id, - "target_id": target_id, - } - ) - return {"status": "ok", "data": {}} - - plugin._query_service = _SpyQueryService() - plugin._ensure_runtime_components = lambda: None # type: ignore[method-assign] - - await plugin.api_send_poke(user_id="10001", group_id="20002", target_id="30003") - await plugin.api_send_poke(qq_id="40004") - - assert captured == [ - {"user_id": 10001, "group_id": 20002, "target_id": 30003}, - {"user_id": 40004, "group_id": None, "target_id": None}, - ] - - -@pytest.mark.asyncio -async def test_public_api_get_forward_msg_and_get_record_support_official_fields() -> None: - """公开 API 应接受官方 ``id`` 和 ``out_format`` 等字段。""" - - _napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols() - plugin = napcat_plugin_cls() - captured: Dict[str, Dict[str, Any]] = {} - - class _SpyQueryService: - async def get_forward_message( - self, - message_id: str | None = None, - forward_id: str | None = None, - ) -> Dict[str, Any]: - captured["forward"] = {"message_id": message_id, "forward_id": forward_id} - return {"messages": []} - - async def get_record_detail( - self, - file_name: str | None = None, - file_id: str | None = None, - out_format: str = "wav", - ) -> Dict[str, Any]: - captured["record"] = { - "file_name": file_name, - "file_id": file_id, - "out_format": out_format, - } - return {"file_id": file_id or "record-1"} - - plugin._query_service = _SpyQueryService() - plugin._ensure_runtime_components = lambda: None # type: ignore[method-assign] - - forward_payload = await plugin.api_get_forward_msg(id="forward-1") - record_payload = await plugin.api_get_record(file_id="record-1", out_format="mp3") - - assert forward_payload == {"messages": []} - assert record_payload == {"file_id": "record-1"} - assert captured["forward"] == {"message_id": None, "forward_id": "forward-1"} - assert captured["record"] == { - "file_name": None, - "file_id": "record-1", - "out_format": "mp3", - } - - -@pytest.mark.asyncio -async def test_public_api_send_poke_rejects_conflicting_alias_values() -> None: - """公开 ``send_poke`` API 应拒绝互相冲突的别名值。""" - - _napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols() - plugin = napcat_plugin_cls() - plugin._ensure_runtime_components = lambda: None # type: ignore[method-assign] - - with pytest.raises(ValueError, match="user_id 与 qq_id 不能同时传递不同的值"): - await plugin.api_send_poke(user_id="10001", qq_id="20002") - - -@pytest.mark.asyncio -async def test_public_api_get_forward_msg_rejects_conflicting_fields() -> None: - """公开 ``get_forward_msg`` API 应拒绝冲突的双字段调用。""" - - _napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols() - plugin = napcat_plugin_cls() - plugin._ensure_runtime_components = lambda: None # type: ignore[method-assign] - - with pytest.raises(ValueError, match="message_id 与 id 不能同时传递不同的值"): - await plugin.api_get_forward_msg(message_id="forward-a", id="forward-b") - - -@pytest.mark.asyncio -async def test_public_api_get_record_requires_file_or_file_id() -> None: - """公开 ``get_record`` API 至少需要一个官方定位字段。""" - - _napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols() - plugin = napcat_plugin_cls() - plugin._ensure_runtime_components = lambda: None # type: ignore[method-assign] - - with pytest.raises(ValueError, match="file 或 file_id 至少提供一个"): - await plugin.api_get_record() diff --git a/pytests/test_napcat_history_recovery.py b/pytests/test_napcat_history_recovery.py deleted file mode 100644 index 30bc58f5..00000000 --- a/pytests/test_napcat_history_recovery.py +++ /dev/null @@ -1,376 +0,0 @@ -"""NapCat 历史补拉与恢复状态测试。""" - -from __future__ import annotations - -from importlib import import_module, util -from pathlib import Path -from typing import Any, Dict, List - -import logging -import sys -from types import SimpleNamespace - -import pytest - -PROJECT_ROOT = Path(__file__).resolve().parents[1] -PLUGINS_ROOT = PROJECT_ROOT / "plugins" -PLUGIN_TEMPLATE_ROOT = PROJECT_ROOT / "plugin-templates" -SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk" -NAPCAT_PLUGIN_DIR = PLUGINS_ROOT / "MaiBot-Napcat-Adapter" -NAPCAT_TEMPLATE_DIR = PLUGIN_TEMPLATE_ROOT / "MaiBot-Napcat-Adapter" -NAPCAT_TEST_MODULE = "_test_napcat_adapter_history_recovery" - -for import_path in (str(SDK_ROOT),): - if import_path not in sys.path: - sys.path.insert(0, import_path) - - -class _FakeGatewayCapability: - """用于测试入站注入的网关替身。""" - - def __init__(self) -> None: - """初始化测试替身。""" - - self.calls: List[Dict[str, Any]] = [] - - async def route_message( - self, - gateway_name: str, - message: Dict[str, Any], - *, - route_metadata: Dict[str, Any] | None = None, - external_message_id: str = "", - dedupe_key: str = "", - ) -> bool: - """记录入站注入请求并始终模拟成功。""" - - self.calls.append( - { - "gateway_name": gateway_name, - "message": dict(message), - "route_metadata": dict(route_metadata or {}), - "external_message_id": external_message_id, - "dedupe_key": dedupe_key, - } - ) - return True - - -def _resolve_napcat_plugin_dir() -> Path: - """返回当前测试可用的 NapCat 插件目录。""" - - if NAPCAT_PLUGIN_DIR.is_dir(): - return NAPCAT_PLUGIN_DIR - return NAPCAT_TEMPLATE_DIR - - -def _load_napcat_module(module_suffix: str) -> Any: - """动态加载 NapCat 测试模块。""" - - plugin_dir = _resolve_napcat_plugin_dir() - if NAPCAT_TEST_MODULE not in sys.modules: - plugin_path = plugin_dir / "plugin.py" - spec = util.spec_from_file_location( - NAPCAT_TEST_MODULE, - plugin_path, - submodule_search_locations=[str(plugin_dir)], - ) - if spec is None or spec.loader is None: - raise ImportError(f"无法为 NapCat 插件创建模块规格: {plugin_path}") - - module = util.module_from_spec(spec) - sys.modules[NAPCAT_TEST_MODULE] = module - try: - spec.loader.exec_module(module) - except Exception: - sys.modules.pop(NAPCAT_TEST_MODULE, None) - raise - - return import_module(f"{NAPCAT_TEST_MODULE}.{module_suffix}") - - -def _load_history_recovery_store_cls() -> Any: - """动态加载历史恢复状态仓库类。""" - - return _load_napcat_module("services.history_recovery_store").NapCatHistoryRecoveryStore - - -def _load_query_service_cls() -> Any: - """动态加载查询服务类。""" - - return _load_napcat_module("services.query_service").NapCatQueryService - - -def _load_router_cls() -> Any: - """动态加载事件路由器类。""" - - return _load_napcat_module("runtime.router").NapCatEventRouter - - -class _FakeActionService: - """用于查询服务的动作服务替身。""" - - def __init__(self, response_data: Any) -> None: - """初始化动作服务替身。""" - - self._response_data = response_data - self.action_data_calls: List[Dict[str, Any]] = [] - - async def safe_call_action_data(self, action_name: str, params: Dict[str, Any]) -> Any: - """记录安全查询动作。""" - - self.action_data_calls.append({"action_name": action_name, "params": dict(params)}) - return self._response_data - - -@pytest.mark.asyncio -async def test_history_recovery_store_persists_checkpoint_and_seen_state(tmp_path: Path) -> None: - """历史恢复状态仓库应持久化 checkpoint 与已补拉标记。""" - - store_cls = _load_history_recovery_store_cls() - store = store_cls( - logger=logging.getLogger("test.napcat.history_store"), - storage_path=tmp_path / "history.sqlite3", - ) - - await store.load() - await store.record_checkpoint( - account_id="10001", - scope="primary", - chat_type="group", - chat_id="20001", - message_id="msg-2", - message_time=200.0, - message_seq=2, - ) - await store.record_checkpoint( - account_id="10001", - scope="primary", - chat_type="group", - chat_id="20001", - message_id="msg-1", - message_time=100.0, - message_seq=1, - ) - await store.mark_recovered_message_seen( - account_id="10001", - scope="primary", - chat_type="group", - chat_id="20001", - external_message_id="history-1", - ) - - checkpoints = await store.list_checkpoints("10001", scope="primary") - - assert len(checkpoints) == 1 - assert checkpoints[0].last_message_id == "msg-2" - assert checkpoints[0].last_message_seq == 2 - assert ( - await store.has_recovered_message_seen( - account_id="10001", - scope="primary", - chat_type="group", - chat_id="20001", - external_message_id="history-1", - ) - is True - ) - - -@pytest.mark.asyncio -async def test_query_service_wraps_group_and_friend_history_actions() -> None: - """查询服务应按官方动作名封装历史消息接口。""" - - query_service_cls = _load_query_service_cls() - action_service = _FakeActionService([{"message_id": "msg-1"}]) - query_service = query_service_cls( - action_service=action_service, - logger=logging.getLogger("test.napcat.history_query"), - ) - - group_payload = await query_service.get_group_message_history("20001", message_seq=123, count=10) - friend_payload = await query_service.get_friend_message_history("30001", count=5, reverse_order=True) - - assert group_payload == [{"message_id": "msg-1"}] - assert friend_payload == [{"message_id": "msg-1"}] - assert action_service.action_data_calls == [ - { - "action_name": "get_group_msg_history", - "params": {"group_id": "20001", "count": 10, "reverse_order": False, "message_seq": 123}, - }, - { - "action_name": "get_friend_msg_history", - "params": {"user_id": "30001", "count": 5, "reverse_order": True}, - }, - ] - - -@pytest.mark.asyncio -async def test_router_recover_recent_history_reinjects_messages_in_order(tmp_path: Path) -> None: - """重连补拉应按时间顺序将历史消息重新注入原入站路径。""" - - history_store_cls = _load_history_recovery_store_cls() - router_cls = _load_router_cls() - gateway_capability = _FakeGatewayCapability() - router = router_cls( - gateway_capability=gateway_capability, - logger=logging.getLogger("test.napcat.history_router"), - gateway_name="napcat_gateway", - load_settings=lambda: SimpleNamespace( - napcat_server=SimpleNamespace(connection_id="primary", heartbeat_interval=30.0), - filters=SimpleNamespace(ignore_self_message=True), - chat=SimpleNamespace(ban_qq_bot=False), - ), - ) - - history_store = history_store_cls( - logger=logging.getLogger("test.napcat.history_router.store"), - storage_path=tmp_path / "router.sqlite3", - ) - await history_store.load() - await history_store.record_checkpoint( - account_id="10001", - scope="primary", - chat_type="group", - chat_id="20001", - message_id="msg-1", - message_time=100.0, - message_seq=10, - ) - - history_calls: List[Dict[str, Any]] = [] - history_payloads = [ - { - "post_type": "message", - "message_type": "group", - "self_id": "10001", - "group_id": "20001", - "user_id": "30002", - "message_id": "msg-3", - "message_seq": 12, - "time": 102, - "message": [{"type": "text", "data": {"text": "第三条"}}], - "sender": {"user_id": "30002", "nickname": "用户二"}, - }, - { - "post_type": "message", - "message_type": "group", - "self_id": "10001", - "group_id": "20001", - "user_id": "30001", - "message_id": "msg-2", - "message_seq": 11, - "time": 101, - "message": [{"type": "text", "data": {"text": "第二条"}}], - "sender": {"user_id": "30001", "nickname": "用户一"}, - }, - ] - - class _FakeQueryService: - async def get_group_message_history( - self, - group_id: str, - *, - message_seq: int | None = None, - count: int = 20, - reverse_order: bool = False, - ) -> List[Dict[str, Any]]: - history_calls.append( - { - "group_id": group_id, - "message_seq": message_seq, - "count": count, - "reverse_order": reverse_order, - } - ) - return list(history_payloads) - - async def get_friend_message_history(self, user_id: str, **kwargs: Any) -> List[Dict[str, Any]]: - del user_id - del kwargs - return [] - - class _FakeInboundCodec: - @staticmethod - async def build_message_dict( - payload: Dict[str, Any], - self_id: str, - sender_user_id: str, - sender: Dict[str, Any], - ) -> Dict[str, Any]: - del self_id - del sender_user_id - del sender - return { - "message_id": str(payload["message_id"]), - "platform": "qq", - "timestamp": str(float(payload["time"])), - "message_info": { - "user_info": {"user_id": str(payload["user_id"]), "user_nickname": "测试用户"}, - "group_info": {"group_id": str(payload["group_id"]), "group_name": "测试群"}, - "additional_config": {}, - }, - "raw_message": [{"type": "text", "data": str(payload["message"][0]["data"]["text"])}], - "processed_plain_text": str(payload["message"][0]["data"]["text"]), - "display_message": str(payload["message"][0]["data"]["text"]), - "is_mentioned": False, - "is_at": False, - "is_emoji": False, - "is_picture": False, - "is_command": False, - "is_notify": False, - "session_id": "", - } - - router.bind_runtime( - SimpleNamespace( - runtime_state=SimpleNamespace(report_connected=lambda *args, **kwargs: _noop_async(), report_disconnected=_noop_async), - chat_filter=SimpleNamespace(is_inbound_chat_allowed=lambda *args, **kwargs: True), - official_bot_guard=SimpleNamespace( - should_reject=lambda *args, **kwargs: _return_false_async(), - clear_cache=lambda: None, - ), - inbound_codec=_FakeInboundCodec(), - history_recovery_store=history_store, - query_service=_FakeQueryService(), - heartbeat_monitor=SimpleNamespace(start=_noop_async, stop=_noop_async), - ban_tracker=SimpleNamespace(start=_noop_async, stop=_noop_async, record_notice=_noop_async), - notice_codec=SimpleNamespace(handle_meta_event=_noop_async, build_notice_message_dict=_return_none_async), - ) - ) - - await router._recover_recent_history(self_id="10001", scope="primary") - - assert history_calls == [ - { - "group_id": "20001", - "message_seq": 10, - "count": 20, - "reverse_order": False, - } - ] - assert [call["external_message_id"] for call in gateway_capability.calls] == ["msg-2", "msg-3"] - assert [call["message"]["message_id"] for call in gateway_capability.calls] == ["msg-2", "msg-3"] - - -async def _noop_async(*args: Any, **kwargs: Any) -> None: - """无操作异步函数。""" - - del args - del kwargs - - -async def _return_false_async(*args: Any, **kwargs: Any) -> bool: - """返回 ``False`` 的异步测试替身。""" - - del args - del kwargs - return False - - -async def _return_none_async(*args: Any, **kwargs: Any) -> None: - """返回 ``None`` 的异步测试替身。""" - - del args - del kwargs - return None diff --git a/pytests/test_openai_client_toolless_request.py b/pytests/test_openai_client_toolless_request.py deleted file mode 100644 index d99a691f..00000000 --- a/pytests/test_openai_client_toolless_request.py +++ /dev/null @@ -1,164 +0,0 @@ -from types import SimpleNamespace - -import pytest - -from src.config.model_configs import APIProvider, ReasoningParseMode, ToolArgumentParseMode -from src.llm_models.model_client.openai_client import ( - _OpenAIStreamAccumulator, - _build_reasoning_key, - _default_normal_response_parser, - _parse_tool_arguments, - _sanitize_messages_for_toolless_request, -) -from src.llm_models.payload_content.message import Message, RoleType, TextMessagePart -from src.llm_models.payload_content.tool_option import ToolCall - - -@pytest.mark.parametrize("parse_mode", list(ToolArgumentParseMode)) -def test_parse_tool_arguments_treats_blank_arguments_as_empty_dict(parse_mode: ToolArgumentParseMode) -> None: - assert _parse_tool_arguments("", parse_mode, None) == {} - assert _parse_tool_arguments(" ", parse_mode, None) == {} - - -def test_normal_response_parser_accepts_empty_string_arguments_for_parameterless_tool() -> None: - response = SimpleNamespace( - choices=[ - SimpleNamespace( - finish_reason="tool_calls", - message=SimpleNamespace( - content=None, - tool_calls=[ - SimpleNamespace( - id="finish-call", - type="function", - function=SimpleNamespace(name="finish", arguments=""), - ) - ], - ), - ) - ], - usage=None, - model="glm-5.1", - ) - - api_response, usage_record = _default_normal_response_parser( - response, - reasoning_parse_mode=ReasoningParseMode.AUTO, - tool_argument_parse_mode=ToolArgumentParseMode.AUTO, - reasoning_key=None, - ) - - assert len(api_response.tool_calls) == 1 - assert api_response.tool_calls[0].func_name == "finish" - assert api_response.tool_calls[0].args == {} - assert usage_record is None - - -def test_sanitize_messages_for_toolless_request_drops_assistant_tool_call_without_parts() -> None: - messages = [ - Message( - role=RoleType.Assistant, - tool_calls=[ - ToolCall( - call_id="call_1", - func_name="mute_user", - args={"target": "alice"}, - ) - ], - ), - Message( - role=RoleType.User, - parts=[TextMessagePart(text="继续")], - ), - ] - - sanitized_messages = _sanitize_messages_for_toolless_request(messages) - - assert len(sanitized_messages) == 1 - assert sanitized_messages[0].role == RoleType.User - - -def test_normal_response_parser_ignores_reasoning_field_for_non_openrouter_provider() -> None: - response = SimpleNamespace( - choices=[ - SimpleNamespace( - finish_reason="stop", - message=SimpleNamespace( - content="正式回复", - reasoning="推理内容", - tool_calls=None, - ), - ) - ], - usage=None, - model="openrouter/test-model", - ) - - api_response, usage_record = _default_normal_response_parser( - response, - reasoning_parse_mode=ReasoningParseMode.AUTO, - tool_argument_parse_mode=ToolArgumentParseMode.AUTO, - reasoning_key=_build_reasoning_key( - APIProvider(name="test", base_url="https://openrouter.ai.example.com/api/v1", api_key="test") - ), - ) - - assert api_response.content == "正式回复" - assert api_response.reasoning_content is None - assert usage_record is None - - -def test_normal_response_parser_reads_provider_reasoning_field_for_reasoning_domains() -> None: - provider_urls = [ - "https://openrouter.ai/compatible-api", - "https://api.groq.com/openai/v1", - ] - - for provider_url in provider_urls: - response = SimpleNamespace( - choices=[ - SimpleNamespace( - finish_reason="stop", - message=SimpleNamespace( - content="正式回复", - reasoning="推理内容", - tool_calls=None, - ), - ) - ], - usage=None, - model="test-model", - ) - - api_response, usage_record = _default_normal_response_parser( - response, - reasoning_parse_mode=ReasoningParseMode.AUTO, - tool_argument_parse_mode=ToolArgumentParseMode.AUTO, - reasoning_key=_build_reasoning_key( - APIProvider(name="reasoning-provider", base_url=provider_url, api_key="test") - ), - ) - - assert api_response.content == "正式回复" - assert api_response.reasoning_content == "推理内容" - assert usage_record is None - - -def test_stream_accumulator_reads_openrouter_reasoning_delta_field() -> None: - accumulator = _OpenAIStreamAccumulator( - reasoning_parse_mode=ReasoningParseMode.AUTO, - tool_argument_parse_mode=ToolArgumentParseMode.AUTO, - reasoning_key=_build_reasoning_key( - APIProvider(name="openrouter", base_url="https://openrouter.ai/compatible-api", api_key="test") - ), - ) - try: - accumulator.process_delta(SimpleNamespace(reasoning="流式推理", content=None, tool_calls=None)) - accumulator.process_delta(SimpleNamespace(content="正式回复", tool_calls=None)) - - api_response = accumulator.build_response() - finally: - accumulator.close() - - assert api_response.content == "正式回复" - assert api_response.reasoning_content == "流式推理" diff --git a/pytests/test_platform_io_dedupe.py b/pytests/test_platform_io_dedupe.py deleted file mode 100644 index d6bdd1dd..00000000 --- a/pytests/test_platform_io_dedupe.py +++ /dev/null @@ -1,209 +0,0 @@ -"""Platform IO 入站去重策略测试。""" - -from types import SimpleNamespace -from typing import Any, Dict, List, Optional - -import pytest - -from src.platform_io.drivers.base import PlatformIODriver -from src.platform_io.manager import PlatformIOManager -from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey - - -def _build_envelope( - *, - dedupe_key: str | None = None, - external_message_id: str | None = None, - session_message_id: str | None = None, - payload: Optional[Dict[str, Any]] = None, -) -> InboundMessageEnvelope: - """构造测试用入站信封。 - - Args: - dedupe_key: 显式去重键。 - external_message_id: 平台侧消息 ID。 - session_message_id: 规范化消息对象上的消息 ID。 - payload: 原始载荷。 - - Returns: - InboundMessageEnvelope: 测试用入站消息信封。 - """ - session_message = None - if session_message_id is not None: - session_message = SimpleNamespace(message_id=session_message_id) - - return InboundMessageEnvelope( - route_key=RouteKey(platform="qq", account_id="10001", scope="main"), - driver_id="plugin.napcat", - driver_kind=DriverKind.PLUGIN, - dedupe_key=dedupe_key, - external_message_id=external_message_id, - session_message=session_message, - payload=payload, - ) - - -class _StubPlatformIODriver(PlatformIODriver): - """测试用 Platform IO 驱动。""" - - async def send_message( - self, - message: Any, - route_key: RouteKey, - metadata: Optional[Dict[str, Any]] = None, - ) -> DeliveryReceipt: - """返回一个固定的成功回执。 - - Args: - message: 待发送的消息对象。 - route_key: 本次发送使用的路由键。 - metadata: 额外发送元数据。 - - Returns: - DeliveryReceipt: 固定的成功回执。 - """ - return DeliveryReceipt( - internal_message_id=str(getattr(message, "message_id", "stub-message-id")), - route_key=route_key, - status=DeliveryStatus.SENT, - driver_id=self.driver_id, - driver_kind=self.descriptor.kind, - ) - - -def _build_manager() -> PlatformIOManager: - """构造带有最小接收路由的 Broker 管理器。 - - Returns: - PlatformIOManager: 已注册测试驱动并绑定接收路由的 Broker。 - """ - manager = PlatformIOManager() - driver = _StubPlatformIODriver( - DriverDescriptor( - driver_id="plugin.napcat", - kind=DriverKind.PLUGIN, - platform="qq", - account_id="10001", - scope="main", - ) - ) - manager.register_driver(driver) - manager.bind_receive_route( - RouteBinding( - route_key=RouteKey(platform="qq", account_id="10001", scope="main"), - driver_id=driver.driver_id, - driver_kind=driver.descriptor.kind, - ) - ) - return manager - - -class TestPlatformIODedupe: - """Platform IO 去重测试。""" - - @pytest.mark.asyncio - async def test_accept_inbound_dedupes_by_external_message_id(self) -> None: - """相同平台消息 ID 的重复入站应被抑制。""" - manager = _build_manager() - accepted_envelopes: List[InboundMessageEnvelope] = [] - - async def dispatcher(envelope: InboundMessageEnvelope) -> None: - """记录被成功接收的入站消息。 - - Args: - envelope: 被 Broker 接受的入站消息。 - """ - accepted_envelopes.append(envelope) - - manager.set_inbound_dispatcher(dispatcher) - - first_envelope = _build_envelope( - external_message_id="msg-1", - payload={"message": "hello"}, - ) - second_envelope = _build_envelope( - external_message_id="msg-1", - payload={"message": "hello"}, - ) - - assert await manager.accept_inbound(first_envelope) is True - assert await manager.accept_inbound(second_envelope) is False - assert len(accepted_envelopes) == 1 - - @pytest.mark.asyncio - async def test_accept_inbound_without_stable_identity_does_not_guess_duplicate(self) -> None: - """缺少稳定身份时,不应仅凭 payload 内容猜测重复消息。""" - manager = _build_manager() - accepted_envelopes: List[InboundMessageEnvelope] = [] - - async def dispatcher(envelope: InboundMessageEnvelope) -> None: - """记录被成功接收的入站消息。 - - Args: - envelope: 被 Broker 接受的入站消息。 - """ - accepted_envelopes.append(envelope) - - manager.set_inbound_dispatcher(dispatcher) - - first_envelope = _build_envelope(payload={"message": "same-payload"}) - second_envelope = _build_envelope(payload={"message": "same-payload"}) - - assert await manager.accept_inbound(first_envelope) is True - assert await manager.accept_inbound(second_envelope) is True - assert len(accepted_envelopes) == 2 - - def test_build_inbound_dedupe_key_prefers_explicit_identity(self) -> None: - """去重键应只来自显式或稳定的技术身份。""" - explicit_envelope = _build_envelope(dedupe_key="dedupe-1", external_message_id="msg-1") - session_message_envelope = _build_envelope(session_message_id="session-1") - payload_only_envelope = _build_envelope(payload={"message": "hello"}) - - assert PlatformIOManager._build_inbound_dedupe_key(explicit_envelope) == "plugin.napcat:dedupe-1" - assert PlatformIOManager._build_inbound_dedupe_key(session_message_envelope) == "plugin.napcat:session-1" - assert PlatformIOManager._build_inbound_dedupe_key(payload_only_envelope) is None - - @pytest.mark.asyncio - async def test_send_message_fans_out_to_all_matching_routes(self) -> None: - """同一路由命中多条发送链路时应全部发送。""" - - manager = PlatformIOManager() - first_driver = _StubPlatformIODriver( - DriverDescriptor( - driver_id="plugin.gateway_a", - kind=DriverKind.PLUGIN, - platform="qq", - ) - ) - second_driver = _StubPlatformIODriver( - DriverDescriptor( - driver_id="plugin.gateway_b", - kind=DriverKind.PLUGIN, - platform="qq", - ) - ) - manager.register_driver(first_driver) - manager.register_driver(second_driver) - manager.bind_send_route( - RouteBinding( - route_key=RouteKey(platform="qq"), - driver_id=first_driver.driver_id, - driver_kind=first_driver.descriptor.kind, - ) - ) - manager.bind_send_route( - RouteBinding( - route_key=RouteKey(platform="qq"), - driver_id=second_driver.driver_id, - driver_kind=second_driver.descriptor.kind, - ) - ) - - message = SimpleNamespace(message_id="internal-msg-1") - result = await manager.send_message(message, RouteKey(platform="qq")) - - assert result.has_success is True - assert [receipt.driver_id for receipt in result.sent_receipts] == [ - "plugin.gateway_a", - "plugin.gateway_b", - ] diff --git a/pytests/test_platform_io_legacy_driver.py b/pytests/test_platform_io_legacy_driver.py deleted file mode 100644 index 76f14d8f..00000000 --- a/pytests/test_platform_io_legacy_driver.py +++ /dev/null @@ -1,178 +0,0 @@ -"""Platform IO legacy driver 回归测试。""" - -from typing import Any, Dict, Optional - -import pytest - -from src.chat.utils import utils as chat_utils -from src.chat.message_receive import uni_message_sender -from src.platform_io.drivers.base import PlatformIODriver -from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver -from src.platform_io.manager import PlatformIOManager -from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteBinding, RouteKey - - -class _PluginDriver(PlatformIODriver): - """测试用插件发送驱动。""" - - def __init__(self, driver_id: str, platform: str) -> None: - """初始化测试驱动。 - - Args: - driver_id: 驱动 ID。 - platform: 负责的平台名称。 - """ - super().__init__( - DriverDescriptor( - driver_id=driver_id, - kind=DriverKind.PLUGIN, - platform=platform, - plugin_id="test.plugin", - ) - ) - - async def send_message( - self, - message: Any, - route_key: RouteKey, - metadata: Optional[Dict[str, Any]] = None, - ) -> DeliveryReceipt: - """返回一个固定成功回执。 - - Args: - message: 待发送消息。 - route_key: 当前路由键。 - metadata: 发送元数据。 - - Returns: - DeliveryReceipt: 固定成功回执。 - """ - del metadata - return DeliveryReceipt( - internal_message_id=str(message.message_id), - route_key=route_key, - status=DeliveryStatus.SENT, - driver_id=self.driver_id, - driver_kind=self.descriptor.kind, - ) - - -@pytest.mark.asyncio -async def test_platform_io_uses_legacy_driver_when_no_explicit_send_route( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """没有显式发送路由时,应由 Platform IO 回退到 legacy driver。""" - manager = PlatformIOManager() - monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"}) - - try: - await manager.ensure_send_pipeline_ready() - - fallback_drivers = manager.resolve_drivers(RouteKey(platform="qq")) - assert [driver.driver_id for driver in fallback_drivers] == ["legacy.send.qq"] - - plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq") - await manager.add_driver(plugin_driver) - manager.bind_send_route( - RouteBinding( - route_key=RouteKey(platform="qq"), - driver_id=plugin_driver.driver_id, - driver_kind=plugin_driver.descriptor.kind, - ) - ) - - explicit_drivers = manager.resolve_drivers(RouteKey(platform="qq")) - assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender", "legacy.send.qq"] - finally: - await manager.stop() - - -@pytest.mark.asyncio -async def test_platform_io_broadcasts_to_plugin_and_legacy_driver( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """同一路由命中插件驱动与 legacy driver 时,应同时广播发送。""" - - manager = PlatformIOManager() - legacy_calls: list[dict[str, Any]] = [] - monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"}) - - async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool: - """记录 legacy driver 调用。""" - - legacy_calls.append({"message": message, "show_log": show_log}) - return True - - monkeypatch.setattr( - uni_message_sender, - "send_prepared_message_to_platform", - _fake_send_prepared_message_to_platform, - ) - - try: - await manager.ensure_send_pipeline_ready() - - plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq") - await manager.add_driver(plugin_driver) - manager.bind_send_route( - RouteBinding( - route_key=RouteKey(platform="qq"), - driver_id=plugin_driver.driver_id, - driver_kind=plugin_driver.descriptor.kind, - ) - ) - - message = type("FakeMessage", (), {"message_id": "message-1"})() - batch = await manager.send_message( - message=message, - route_key=RouteKey(platform="qq"), - metadata={"show_log": False}, - ) - - assert sorted(receipt.driver_id for receipt in batch.sent_receipts) == [ - "legacy.send.qq", - "plugin.qq.sender", - ] - assert batch.failed_receipts == [] - assert len(legacy_calls) == 1 - assert legacy_calls[0]["message"] is message - assert legacy_calls[0]["show_log"] is False - finally: - await manager.stop() - - -@pytest.mark.asyncio -async def test_legacy_platform_driver_uses_prepared_universal_sender( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """legacy driver 应复用已预处理消息的旧链发送函数。""" - calls: list[dict[str, Any]] = [] - - async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool: - """记录 legacy driver 调用。""" - calls.append({"message": message, "show_log": show_log}) - return True - - monkeypatch.setattr( - uni_message_sender, - "send_prepared_message_to_platform", - _fake_send_prepared_message_to_platform, - ) - - driver = LegacyPlatformDriver( - driver_id="legacy.send.qq", - platform="qq", - account_id="bot-qq", - ) - message = type("FakeMessage", (), {"message_id": "message-1"})() - receipt = await driver.send_message( - message=message, - route_key=RouteKey(platform="qq"), - metadata={"show_log": False}, - ) - - assert len(calls) == 1 - assert calls[0]["message"] is message - assert calls[0]["show_log"] is False - assert receipt.status == DeliveryStatus.SENT - assert receipt.driver_id == "legacy.send.qq" diff --git a/pytests/test_plugin_config_runtime.py b/pytests/test_plugin_config_runtime.py deleted file mode 100644 index b84bafb8..00000000 --- a/pytests/test_plugin_config_runtime.py +++ /dev/null @@ -1,553 +0,0 @@ -"""插件配置运行时测试。""" - -from __future__ import annotations - -from pathlib import Path -from types import SimpleNamespace -from typing import Any, Dict, Mapping, Optional, Tuple, cast - -import tomllib - -import pytest - -from src.plugin_runtime.component_query import component_query_service -from src.plugin_runtime.protocol.envelope import ( - Envelope, - InspectPluginConfigPayload, - MessageType, - RegisterPluginPayload, - ValidatePluginConfigPayload, -) -from src.plugin_runtime.runner.runner_main import PluginRunner -from src.webui.routers.plugin.config_routes import get_plugin_config, get_plugin_config_schema, update_plugin_config -from src.webui.routers.plugin.schemas import UpdatePluginConfigRequest - - -class _DemoConfigPlugin: - """用于测试 Runner 配置归一化流程的伪插件。""" - - _config_version: str = "2.0.0" - - def __init__(self) -> None: - """初始化测试插件状态。""" - - self.received_config: Dict[str, Any] = {} - - def normalize_plugin_config(self, config_data: Optional[Mapping[str, Any]]) -> Tuple[Dict[str, Any], bool]: - """补齐测试插件的默认配置。 - - Args: - config_data: 原始配置数据。 - - Returns: - Tuple[Dict[str, Any], bool]: 补齐后的配置,以及是否发生变更。 - """ - - current_config = dict(config_data or {}) - plugin_section = dict(current_config.get("plugin", {})) - changed = "retry_count" not in plugin_section or "config_version" not in plugin_section - plugin_section.setdefault("config_version", self._config_version) - plugin_section.setdefault("enabled", True) - plugin_section.setdefault("retry_count", 3) - return {"plugin": plugin_section}, changed - - def set_plugin_config(self, config: Dict[str, Any]) -> None: - """记录 Runner 注入的配置内容。 - - Args: - config: 当前最新配置。 - """ - - self.received_config = config - - def get_default_config(self) -> Dict[str, Any]: - """返回测试插件的默认配置。 - - Returns: - Dict[str, Any]: 默认配置字典。 - """ - - return {"plugin": {"config_version": self._config_version, "enabled": True, "retry_count": 3}} - - def get_webui_config_schema( - self, - *, - plugin_id: str = "", - plugin_name: str = "", - plugin_version: str = "", - plugin_description: str = "", - plugin_author: str = "", - ) -> Dict[str, Any]: - """返回测试插件的 WebUI 配置 Schema。 - - Args: - plugin_id: 插件 ID。 - plugin_name: 插件名称。 - plugin_version: 插件版本。 - plugin_description: 插件描述。 - plugin_author: 插件作者。 - - Returns: - Dict[str, Any]: 测试配置 Schema。 - """ - - del plugin_name, plugin_description, plugin_author - return { - "plugin_id": plugin_id, - "plugin_info": { - "name": "Demo", - "version": plugin_version, - "description": "", - "author": "", - }, - "sections": { - "plugin": { - "fields": { - "enabled": { - "type": "boolean", - "label": "启用", - "default": True, - "ui_type": "switch", - } - } - } - }, - "layout": {"type": "auto", "tabs": []}, - } - - -class _StrictConfigPlugin: - """用于测试配置校验错误的伪插件。""" - - def normalize_plugin_config(self, config_data: Optional[Mapping[str, Any]]) -> Tuple[Dict[str, Any], bool]: - """校验重试次数不能为负数。 - - Args: - config_data: 原始配置数据。 - - Returns: - Tuple[Dict[str, Any], bool]: 规范化配置结果。 - - Raises: - ValueError: 当重试次数为负数时抛出。 - """ - - current_config = dict(config_data or {}) - plugin_section = dict(current_config.get("plugin", {})) - plugin_section.setdefault("config_version", "2.0.0") - retry_count = int(plugin_section.get("retry_count", 0)) - if retry_count < 0: - raise ValueError("重试次数不能小于 0") - plugin_section.setdefault("enabled", True) - return {"plugin": plugin_section}, False - - def set_plugin_config(self, config: Dict[str, Any]) -> None: - """兼容 Runner 配置注入接口。 - - Args: - config: 当前配置字典。 - """ - - del config - - def get_default_config(self) -> Dict[str, Any]: - """返回测试插件的默认配置。 - - Returns: - Dict[str, Any]: 默认配置字典。 - """ - - return {"plugin": {"config_version": "2.0.0", "enabled": True, "retry_count": 0}} - - -def test_runner_apply_plugin_config_generates_config_file(tmp_path: Path) -> None: - """Runner 注入配置时应自动补齐并落盘 config.toml。""" - - plugin = _DemoConfigPlugin() - runner = PluginRunner( - host_address="ipc://unused", - session_token="session-token", - plugin_dirs=[], - ) - meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir=str(tmp_path), instance=plugin) - - runner._apply_plugin_config( - cast(Any, meta), - config_data={"plugin": {"config_version": "2.0.0", "enabled": False}}, - ) - - config_path = tmp_path / "config.toml" - assert config_path.exists() - assert plugin.received_config == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}} - - with config_path.open("rb") as handle: - saved_config = tomllib.load(handle) - assert saved_config == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}} - - -def test_runner_apply_plugin_config_preserves_existing_comments(tmp_path: Path) -> None: - """Runner 在版本升级时应尽量保留现有 config.toml 注释。""" - - plugin = _DemoConfigPlugin() - runner = PluginRunner( - host_address="ipc://unused", - session_token="session-token", - plugin_dirs=[], - ) - meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir=str(tmp_path), instance=plugin) - config_path = tmp_path / "config.toml" - config_path.write_text( - '# 插件配置头注释\n[plugin]\nconfig_version = "1.0.0"\nenabled = false # 启用开关注释\n', - encoding="utf-8", - ) - - runner._apply_plugin_config(cast(Any, meta)) - - config_text = config_path.read_text(encoding="utf-8") - assert "# 插件配置头注释" in config_text - assert "# 启用开关注释" in config_text - - with config_path.open("rb") as handle: - saved_config = tomllib.load(handle) - assert saved_config == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}} - - -def test_runner_apply_plugin_config_same_version_does_not_rewrite_file(tmp_path: Path) -> None: - """Runner 在配置版本未变化时不应仅因补齐默认值而重写文件。""" - - plugin = _DemoConfigPlugin() - runner = PluginRunner( - host_address="ipc://unused", - session_token="session-token", - plugin_dirs=[], - ) - meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir=str(tmp_path), instance=plugin) - config_path = tmp_path / "config.toml" - original_config_text = '# 原始注释\n[plugin]\nconfig_version = "2.0.0"\nenabled = false\n' - config_path.write_text(original_config_text, encoding="utf-8") - - runner._apply_plugin_config(cast(Any, meta)) - - assert plugin.received_config == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}} - assert config_path.read_text(encoding="utf-8") == original_config_text - - -def test_runner_apply_plugin_config_requires_config_version(tmp_path: Path) -> None: - """Runner 应拒绝缺少配置版本号的插件配置文件。""" - - plugin = _DemoConfigPlugin() - runner = PluginRunner( - host_address="ipc://unused", - session_token="session-token", - plugin_dirs=[], - ) - meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir=str(tmp_path), instance=plugin) - config_path = tmp_path / "config.toml" - config_path.write_text("[plugin]\nenabled = true\n", encoding="utf-8") - - with pytest.raises(ValueError, match="config_version"): - runner._apply_plugin_config(cast(Any, meta)) - - -def test_component_query_service_returns_plugin_config_schema(monkeypatch: Any) -> None: - """组件查询服务应支持按插件 ID 返回配置 Schema。""" - - payload = RegisterPluginPayload( - plugin_id="demo.plugin", - plugin_version="1.0.0", - default_config={"plugin": {"enabled": True}}, - config_schema={ - "plugin_id": "demo.plugin", - "plugin_info": { - "name": "Demo", - "version": "1.0.0", - "description": "", - "author": "", - }, - "sections": {"plugin": {"fields": {}}}, - "layout": {"type": "auto", "tabs": []}, - }, - ) - fake_supervisor = SimpleNamespace(_registered_plugins={"demo.plugin": payload}) - fake_manager = SimpleNamespace(_get_supervisor_for_plugin=lambda plugin_id: fake_supervisor) - - monkeypatch.setattr( - type(component_query_service), - "_get_runtime_manager", - staticmethod(lambda: fake_manager), - ) - - assert component_query_service.get_plugin_config_schema("demo.plugin") == payload.config_schema - assert component_query_service.get_plugin_default_config("demo.plugin") == payload.default_config - - -@pytest.mark.asyncio -async def test_runner_validate_plugin_config_handler_returns_normalized_config(monkeypatch: pytest.MonkeyPatch) -> None: - """Runner 应返回插件模型归一化后的配置。""" - - plugin = _DemoConfigPlugin() - runner = PluginRunner( - host_address="ipc://unused", - session_token="session-token", - plugin_dirs=[], - ) - meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir="", instance=plugin) - monkeypatch.setattr(runner._loader, "get_plugin", lambda plugin_id: meta if plugin_id == "demo.plugin" else None) - - envelope = Envelope( - request_id=1, - message_type=MessageType.REQUEST, - method="plugin.validate_config", - plugin_id="demo.plugin", - payload=ValidatePluginConfigPayload( - config_data={"plugin": {"config_version": "2.0.0", "enabled": False}} - ).model_dump(), - ) - - response = await runner._handle_validate_plugin_config(envelope) - - assert response.error is None - assert response.payload["success"] is True - assert response.payload["normalized_config"] == { - "plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3} - } - - -@pytest.mark.asyncio -async def test_runner_inspect_plugin_config_handler_supports_unloaded_plugin( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Runner 应支持对未加载插件执行冷检查。""" - - plugin = _DemoConfigPlugin() - runner = PluginRunner( - host_address="ipc://unused", - session_token="session-token", - plugin_dirs=[], - ) - meta = SimpleNamespace( - plugin_id="demo.plugin", - plugin_dir="/tmp/demo-plugin", - instance=plugin, - manifest=SimpleNamespace( - name="Demo", - description="", - author=SimpleNamespace(name="tester"), - ), - version="1.0.0", - ) - purged_plugins: list[tuple[str, str]] = [] - - monkeypatch.setattr( - runner, - "_resolve_plugin_meta_for_config_request", - lambda plugin_id: (meta, True, None) if plugin_id == "demo.plugin" else (None, False, "not-found"), - ) - monkeypatch.setattr( - runner._loader, - "purge_plugin_modules", - lambda plugin_id, plugin_dir: purged_plugins.append((plugin_id, plugin_dir)), - ) - - envelope = Envelope( - request_id=1, - message_type=MessageType.REQUEST, - method="plugin.inspect_config", - plugin_id="demo.plugin", - payload=InspectPluginConfigPayload( - config_data={"plugin": {"enabled": False}}, - use_provided_config=True, - ).model_dump(), - ) - - response = await runner._handle_inspect_plugin_config(envelope) - - assert response.error is None - assert response.payload["success"] is True - assert response.payload["enabled"] is False - assert response.payload["normalized_config"] == { - "plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3} - } - assert response.payload["default_config"] == { - "plugin": {"config_version": "2.0.0", "enabled": True, "retry_count": 3} - } - assert purged_plugins == [("demo.plugin", "/tmp/demo-plugin")] - - -@pytest.mark.asyncio -async def test_runner_validate_plugin_config_handler_returns_error_on_invalid_config( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Runner 应在插件拒绝配置时返回错误响应。""" - - plugin = _StrictConfigPlugin() - runner = PluginRunner( - host_address="ipc://unused", - session_token="session-token", - plugin_dirs=[], - ) - meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir="", instance=plugin) - monkeypatch.setattr(runner._loader, "get_plugin", lambda plugin_id: meta if plugin_id == "demo.plugin" else None) - - envelope = Envelope( - request_id=1, - message_type=MessageType.REQUEST, - method="plugin.validate_config", - plugin_id="demo.plugin", - payload=ValidatePluginConfigPayload( - config_data={"plugin": {"config_version": "2.0.0", "retry_count": -1}} - ).model_dump(), - ) - - response = await runner._handle_validate_plugin_config(envelope) - - assert response.error is not None - assert response.error["message"] == "重试次数不能小于 0" - - -@pytest.mark.asyncio -async def test_update_plugin_config_prefers_runtime_validation( - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - """WebUI 保存插件配置时应优先使用运行时校验结果。""" - - config_path = tmp_path / "config.toml" - - async def _mock_validate_plugin_config(plugin_id: str, config_data: Dict[str, Any]) -> Dict[str, Any] | None: - """返回运行时归一化后的配置。 - - Args: - plugin_id: 插件 ID。 - config_data: 原始配置。 - - Returns: - Dict[str, Any] | None: 归一化后的配置。 - """ - - assert plugin_id == "demo.plugin" - assert config_data == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}} - return {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}} - - async def _mock_inspect_plugin_config( - plugin_id: str, - config_data: Optional[Dict[str, Any]] = None, - *, - use_provided_config: bool = False, - ) -> SimpleNamespace | None: - """返回运行时配置快照。 - - Args: - plugin_id: 插件 ID。 - config_data: 可选配置。 - use_provided_config: 是否使用传入配置。 - - Returns: - SimpleNamespace | None: 运行时配置快照。 - """ - - del config_data, use_provided_config - if plugin_id != "demo.plugin": - return None - return SimpleNamespace( - normalized_config={"plugin": {"config_version": "2.0.0", "enabled": True, "retry_count": 3}} - ) - - fake_runtime_manager = SimpleNamespace( - inspect_plugin_config=_mock_inspect_plugin_config, - validate_plugin_config=_mock_validate_plugin_config, - ) - - monkeypatch.setattr( - "src.webui.routers.plugin.config_routes.require_plugin_token", - lambda session: session or "session-token", - ) - monkeypatch.setattr( - "src.webui.routers.plugin.config_routes.find_plugin_path_by_id", - lambda plugin_id: tmp_path if plugin_id == "demo.plugin" else None, - ) - monkeypatch.setattr( - "src.plugin_runtime.integration.get_plugin_runtime_manager", - lambda: fake_runtime_manager, - ) - - response = await update_plugin_config( - "demo.plugin", - UpdatePluginConfigRequest(config={"plugin.enabled": False}), - maibot_session="session-token", - ) - - assert response["success"] is True - with config_path.open("rb") as handle: - saved_config = tomllib.load(handle) - assert saved_config == {"plugin": {"config_version": "2.0.0", "enabled": False, "retry_count": 3}} - - -@pytest.mark.asyncio -async def test_webui_config_endpoints_use_runtime_inspection_for_unloaded_plugin( - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - """WebUI 在插件未加载时也应从代码定义返回配置与 Schema。""" - - async def _mock_inspect_plugin_config( - plugin_id: str, - config_data: Optional[Dict[str, Any]] = None, - *, - use_provided_config: bool = False, - ) -> SimpleNamespace | None: - """返回运行时冷检查结果。 - - Args: - plugin_id: 插件 ID。 - config_data: 可选配置。 - use_provided_config: 是否使用传入配置。 - - Returns: - SimpleNamespace | None: 冷检查结果。 - """ - - del config_data, use_provided_config - if plugin_id != "demo.plugin": - return None - return SimpleNamespace( - config_schema={ - "plugin_id": "demo.plugin", - "plugin_info": { - "name": "Demo", - "version": "1.0.0", - "description": "", - "author": "", - }, - "sections": {"plugin": {"fields": {}}}, - "layout": {"type": "auto", "tabs": []}, - }, - normalized_config={"plugin": {"config_version": "2.0.0", "enabled": True, "retry_count": 3}}, - enabled=True, - ) - - fake_runtime_manager = SimpleNamespace(inspect_plugin_config=_mock_inspect_plugin_config) - - monkeypatch.setattr( - "src.webui.routers.plugin.config_routes.require_plugin_token", - lambda session: session or "session-token", - ) - monkeypatch.setattr( - "src.webui.routers.plugin.config_routes.find_plugin_path_by_id", - lambda plugin_id: tmp_path if plugin_id == "demo.plugin" else None, - ) - monkeypatch.setattr( - "src.plugin_runtime.integration.get_plugin_runtime_manager", - lambda: fake_runtime_manager, - ) - - schema_response = await get_plugin_config_schema("demo.plugin", maibot_session="session-token") - config_response = await get_plugin_config("demo.plugin", maibot_session="session-token") - - assert schema_response["success"] is True - assert schema_response["schema"]["plugin_id"] == "demo.plugin" - assert config_response == { - "success": True, - "config": {"plugin": {"config_version": "2.0.0", "enabled": True, "retry_count": 3}}, - "message": "配置文件不存在,已返回默认配置", - } diff --git a/pytests/test_plugin_dependency_pipeline.py b/pytests/test_plugin_dependency_pipeline.py deleted file mode 100644 index 68fd1786..00000000 --- a/pytests/test_plugin_dependency_pipeline.py +++ /dev/null @@ -1,225 +0,0 @@ -"""插件依赖流水线测试。""" - -from pathlib import Path - -import json - -import pytest - -from src.plugin_runtime.dependency_pipeline import PluginDependencyPipeline - - -def _build_manifest( - plugin_id: str, - *, - dependencies: list[dict[str, str]] | None = None, -) -> dict[str, object]: - """构造测试用的 Manifest v2 数据。 - - Args: - plugin_id: 插件 ID。 - dependencies: 依赖声明列表。 - - Returns: - dict[str, object]: 可直接写入 ``_manifest.json`` 的字典。 - """ - - return { - "manifest_version": 2, - "version": "1.0.0", - "name": plugin_id, - "description": "测试插件", - "author": { - "name": "tester", - "url": "https://example.com/tester", - }, - "license": "MIT", - "urls": { - "repository": f"https://example.com/{plugin_id}", - }, - "host_application": { - "min_version": "1.0.0", - "max_version": "1.0.0", - }, - "sdk": { - "min_version": "2.0.0", - "max_version": "2.99.99", - }, - "dependencies": dependencies or [], - "capabilities": [], - "i18n": { - "default_locale": "zh-CN", - "supported_locales": ["zh-CN"], - }, - "id": plugin_id, - } - - -def _write_plugin( - plugin_root: Path, - plugin_name: str, - plugin_id: str, - *, - dependencies: list[dict[str, str]] | None = None, -) -> Path: - """在临时目录中写入一个测试插件。 - - Args: - plugin_root: 插件根目录。 - plugin_name: 插件目录名。 - plugin_id: 插件 ID。 - dependencies: Python 依赖声明列表。 - - Returns: - Path: 插件目录路径。 - """ - - plugin_dir = plugin_root / plugin_name - plugin_dir.mkdir(parents=True) - (plugin_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") - (plugin_dir / "_manifest.json").write_text( - json.dumps(_build_manifest(plugin_id, dependencies=dependencies)), - encoding="utf-8", - ) - return plugin_dir - - -def test_build_plan_blocks_plugin_conflicting_with_host_requirement(tmp_path: Path) -> None: - """与主程序依赖冲突的插件应被阻止加载。""" - - plugin_root = tmp_path / "plugins" - _write_plugin( - plugin_root, - "conflict_plugin", - "test.conflict-plugin", - dependencies=[ - { - "type": "python_package", - "name": "numpy", - "version_spec": "<1.0.0", - } - ], - ) - - pipeline = PluginDependencyPipeline(project_root=Path.cwd()) - plan = pipeline.build_plan([plugin_root]) - - assert "test.conflict-plugin" in plan.blocked_plugin_reasons - assert "主程序" in plan.blocked_plugin_reasons["test.conflict-plugin"] - assert plan.install_requirements == () - - -def test_build_plan_blocks_plugins_with_conflicting_python_dependencies(tmp_path: Path) -> None: - """插件之间出现 Python 包版本冲突时应同时阻止双方加载。""" - - plugin_root = tmp_path / "plugins" - _write_plugin( - plugin_root, - "plugin_a", - "test.plugin-a", - dependencies=[ - { - "type": "python_package", - "name": "demo-package", - "version_spec": "<2.0.0", - } - ], - ) - _write_plugin( - plugin_root, - "plugin_b", - "test.plugin-b", - dependencies=[ - { - "type": "python_package", - "name": "demo-package", - "version_spec": ">=3.0.0", - } - ], - ) - - pipeline = PluginDependencyPipeline(project_root=Path.cwd()) - plan = pipeline.build_plan([plugin_root]) - - assert "test.plugin-a" in plan.blocked_plugin_reasons - assert "test.plugin-b" in plan.blocked_plugin_reasons - assert "test.plugin-b" in plan.blocked_plugin_reasons["test.plugin-a"] - assert "test.plugin-a" in plan.blocked_plugin_reasons["test.plugin-b"] - - -def test_build_plan_collects_install_requirements_for_missing_packages( - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - """未安装但无冲突的依赖应进入自动安装计划。""" - - plugin_root = tmp_path / "plugins" - _write_plugin( - plugin_root, - "plugin_a", - "test.plugin-a", - dependencies=[ - { - "type": "python_package", - "name": "demo-package", - "version_spec": ">=1.0.0,<2.0.0", - } - ], - ) - - pipeline = PluginDependencyPipeline(project_root=Path.cwd()) - monkeypatch.setattr( - pipeline._manifest_validator, - "get_installed_package_version", - lambda package_name: None if package_name == "demo-package" else "1.0.0", - ) - - plan = pipeline.build_plan([plugin_root]) - - assert plan.blocked_plugin_reasons == {} - assert len(plan.install_requirements) == 1 - assert plan.install_requirements[0].package_name == "demo-package" - assert plan.install_requirements[0].plugin_ids == ("test.plugin-a",) - assert plan.install_requirements[0].requirement_text == "demo-package>=1.0.0,<2.0.0" - - -@pytest.mark.asyncio -async def test_execute_blocks_plugins_when_auto_install_fails( - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - """自动安装失败时,相关插件应被阻止加载。""" - - plugin_root = tmp_path / "plugins" - _write_plugin( - plugin_root, - "plugin_a", - "test.plugin-a", - dependencies=[ - { - "type": "python_package", - "name": "demo-package", - "version_spec": ">=1.0.0,<2.0.0", - } - ], - ) - - pipeline = PluginDependencyPipeline(project_root=Path.cwd()) - monkeypatch.setattr( - pipeline._manifest_validator, - "get_installed_package_version", - lambda package_name: None if package_name == "demo-package" else "1.0.0", - ) - - async def fake_install(_requirements) -> tuple[bool, str]: - """模拟依赖安装失败。""" - - return False, "network error" - - monkeypatch.setattr(pipeline, "_install_requirements", fake_install) - - result = await pipeline.execute([plugin_root]) - - assert result.environment_changed is False - assert "test.plugin-a" in result.blocked_plugin_reasons - assert "自动安装 Python 依赖失败" in result.blocked_plugin_reasons["test.plugin-a"] diff --git a/pytests/test_plugin_message_utils_runtime.py b/pytests/test_plugin_message_utils_runtime.py deleted file mode 100644 index 82e63db2..00000000 --- a/pytests/test_plugin_message_utils_runtime.py +++ /dev/null @@ -1,86 +0,0 @@ -from datetime import datetime -from pathlib import Path - -import sys - -from src.chat.message_receive.message import SessionMessage -from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo, UserInfo -from src.common.data_models.message_component_data_model import ( - ForwardComponent, - ForwardNodeComponent, - ImageComponent, - MessageSequence, - ReplyComponent, - TextComponent, - VoiceComponent, -) -from src.plugin_runtime.host.message_utils import PluginMessageUtils - - -PROJECT_ROOT = Path(__file__).resolve().parents[1] -if str(PROJECT_ROOT) not in sys.path: - sys.path.insert(0, str(PROJECT_ROOT)) - - -def test_plugin_message_utils_preserves_binary_components_and_reply_metadata() -> None: - message = SessionMessage(message_id="msg-1", timestamp=datetime.now(), platform="qq") - message.message_info = MessageInfo( - user_info=UserInfo(user_id="10001", user_nickname="tester"), - group_info=GroupInfo(group_id="20001", group_name="group"), - additional_config={"self_id": "999"}, - ) - message.session_id = "qq:20001:10001" - message.processed_plain_text = "binary payload" - message.raw_message = MessageSequence( - components=[ - TextComponent("hello"), - ImageComponent(binary_hash="", binary_data=b"image-bytes", content=""), - VoiceComponent(binary_hash="", binary_data=b"voice-bytes", content=""), - ReplyComponent( - target_message_id="origin-1", - target_message_content="origin text", - target_message_sender_id="42", - target_message_sender_nickname="alice", - target_message_sender_cardname="Alice", - ), - ForwardNodeComponent( - forward_components=[ - ForwardComponent( - user_nickname="bob", - user_id="43", - user_cardname="Bob", - message_id="forward-1", - content=[ - TextComponent("node-text"), - ImageComponent(binary_hash="", binary_data=b"node-image", content=""), - ], - ) - ] - ), - ] - ) - - message_dict = PluginMessageUtils._session_message_to_dict(message) - rebuilt_message = PluginMessageUtils._build_session_message_from_dict(dict(message_dict)) - - image_component = rebuilt_message.raw_message.components[1] - voice_component = rebuilt_message.raw_message.components[2] - reply_component = rebuilt_message.raw_message.components[3] - forward_component = rebuilt_message.raw_message.components[4] - - assert isinstance(image_component, ImageComponent) - assert image_component.binary_data == b"image-bytes" - - assert isinstance(voice_component, VoiceComponent) - assert voice_component.binary_data == b"voice-bytes" - - assert isinstance(reply_component, ReplyComponent) - assert reply_component.target_message_id == "origin-1" - assert reply_component.target_message_content == "origin text" - assert reply_component.target_message_sender_id == "42" - assert reply_component.target_message_sender_nickname == "alice" - assert reply_component.target_message_sender_cardname == "Alice" - - assert isinstance(forward_component, ForwardNodeComponent) - assert isinstance(forward_component.forward_components[0].content[1], ImageComponent) - assert forward_component.forward_components[0].content[1].binary_data == b"node-image" diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py deleted file mode 100644 index 151cb20b..00000000 --- a/pytests/test_plugin_runtime.py +++ /dev/null @@ -1,3688 +0,0 @@ -"""插件运行时框架基础测试 - -验证协议层、传输层、RPC 通信链路的正确性。 -""" - -# pyright: reportArgumentType=false, reportAttributeAccessIssue=false, reportCallIssue=false, reportIndexIssue=false, reportMissingImports=false, reportOptionalMemberAccess=false - -from pathlib import Path -from types import SimpleNamespace -from typing import Any, Awaitable, Callable, Dict, List, Optional, Sequence - -import asyncio -import json -import logging -import os -import sys - -import pytest - -# 确保项目根目录在 sys.path 中 -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) -# SDK 包路径 -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "maibot-plugin-sdk")) - - -def build_test_manifest( - plugin_id: str, - *, - version: str = "1.0.0", - name: str = "测试插件", - description: str = "测试插件描述", - dependencies: list[dict[str, str]] | None = None, - llm_providers: list[dict[str, str]] | None = None, - capabilities: list[str] | None = None, - host_min_version: str = "0.12.0", - host_max_version: str = "1.0.0", - sdk_min_version: str = "2.0.0", - sdk_max_version: str = "2.99.99", -) -> dict[str, object]: - """构造一个合法的 Manifest v2 测试样例。 - - Args: - plugin_id: 插件 ID。 - version: 插件版本。 - name: 展示名称。 - description: 插件描述。 - dependencies: 依赖声明列表。 - llm_providers: LLM Provider 静态声明列表。 - capabilities: 能力声明列表。 - host_min_version: Host 最低支持版本。 - host_max_version: Host 最高支持版本。 - sdk_min_version: SDK 最低支持版本。 - sdk_max_version: SDK 最高支持版本。 - - Returns: - dict[str, object]: 可直接序列化为 ``_manifest.json`` 的字典。 - """ - return { - "manifest_version": 2, - "version": version, - "name": name, - "description": description, - "author": { - "name": "tester", - "url": "https://example.com/tester", - }, - "license": "MIT", - "urls": { - "repository": f"https://example.com/{plugin_id}", - }, - "host_application": { - "min_version": host_min_version, - "max_version": host_max_version, - }, - "sdk": { - "min_version": sdk_min_version, - "max_version": sdk_max_version, - }, - "dependencies": dependencies or [], - "llm_providers": llm_providers or [], - "capabilities": capabilities or [], - "i18n": { - "default_locale": "zh-CN", - "supported_locales": ["zh-CN"], - }, - "id": plugin_id, - } - - -def build_test_manifest_model( - plugin_id: str, - *, - version: str = "1.0.0", - dependencies: list[dict[str, str]] | None = None, - llm_providers: list[dict[str, str]] | None = None, - capabilities: list[str] | None = None, - host_version: str = "1.0.0", - sdk_version: str = "2.0.1", -) -> object: - """构造一个已经通过校验的强类型 Manifest 测试对象。 - - Args: - plugin_id: 插件 ID。 - version: 插件版本。 - dependencies: 依赖声明列表。 - llm_providers: LLM Provider 静态声明列表。 - capabilities: 能力声明列表。 - host_version: 当前测试使用的 Host 版本。 - sdk_version: 当前测试使用的 SDK 版本。 - - Returns: - object: ``PluginManifest`` 实例。 - """ - from src.plugin_runtime.runner.manifest_validator import ManifestValidator - - validator = ManifestValidator(host_version=host_version, sdk_version=sdk_version) - manifest = validator.parse_manifest( - build_test_manifest( - plugin_id, - version=version, - dependencies=dependencies, - llm_providers=llm_providers, - capabilities=capabilities, - ) - ) - assert manifest is not None - return manifest - - -# ─── 协议层测试 ─────────────────────────────────────────── - - -class TestProtocol: - """协议层测试""" - - def test_envelope_create_and_serialize(self): - """Envelope 创建与序列化""" - from src.plugin_runtime.protocol.envelope import Envelope, MessageType - - env = Envelope( - request_id=1, - message_type=MessageType.REQUEST, - method="plugin.invoke_command", - plugin_id="test_plugin", - payload={"component_name": "greet", "args": {}}, - ) - - assert env.request_id == 1 - assert env.is_request() - assert env.method == "plugin.invoke_command" - - # 测试 make_response - resp = env.make_response(payload={"success": True}) - assert resp.is_response() - assert resp.request_id == 1 - assert resp.payload["success"] is True - - def test_envelope_make_error_response(self): - """错误响应生成""" - from src.plugin_runtime.protocol.envelope import Envelope, MessageType - - env = Envelope( - request_id=42, - message_type=MessageType.REQUEST, - method="cap.request", - ) - - err_resp = env.make_error_response("E_UNAUTHORIZED", "没有权限") - assert err_resp.error is not None - assert err_resp.error["code"] == "E_UNAUTHORIZED" - assert err_resp.error["message"] == "没有权限" - - def test_msgpack_codec(self): - """MsgPack 编解码""" - from src.plugin_runtime.protocol.codec import MsgPackCodec - from src.plugin_runtime.protocol.envelope import Envelope, MessageType - - codec = MsgPackCodec() - env = Envelope( - request_id=100, - message_type=MessageType.REQUEST, - method="test.method", - payload={"key": "value", "number": 42}, - ) - - # 编码 - data = codec.encode_envelope(env) - assert isinstance(data, bytes) - - # 解码 - decoded = codec.decode_envelope(data) - assert decoded.request_id == 100 - assert decoded.method == "test.method" - assert decoded.payload["key"] == "value" - assert decoded.payload["number"] == 42 - - def test_json_codec(self): - """JSON 编解码已移除,仅保留 MsgPack""" - pass - - def test_request_id_generator(self): - """请求 ID 生成器单调递增""" - from src.plugin_runtime.protocol.envelope import RequestIdGenerator - - gen = RequestIdGenerator() - ids = [gen.next() for _ in range(100)] - assert ids == list(range(1, 101)) - - def test_error_codes(self): - """错误码枚举""" - from src.plugin_runtime.protocol.errors import ErrorCode, RPCError - - err = RPCError(ErrorCode.E_TIMEOUT, "请求超时") - assert err.code == ErrorCode.E_TIMEOUT - assert "E_TIMEOUT" in str(err) - - # 序列化/反序列化 - d = err.to_dict() - err2 = RPCError.from_dict(d) - assert err2.code == ErrorCode.E_TIMEOUT - - -# ─── 传输层测试 ─────────────────────────────────────────── - - -class TestTransport: - """传输层测试""" - - @pytest.mark.asyncio - async def test_uds_connection_framing(self): - """UDS 分帧协议测试""" - from src.plugin_runtime.transport.uds import UDSTransportServer, UDSTransportClient - - server = UDSTransportServer() - received = asyncio.Event() - received_data = [] - - async def handler(conn): - data = await conn.recv_frame() - received_data.append(data) - await conn.send_frame(b"pong") - received.set() - - await server.start(handler) - address = server.get_address() - - client = UDSTransportClient(address) - conn = await client.connect() - await conn.send_frame(b"ping") - - # 等待服务端处理 - await asyncio.wait_for(received.wait(), timeout=5.0) - assert received_data[0] == b"ping" - - # 接收服务端回复 - resp = await conn.recv_frame() - assert resp == b"pong" - - await conn.close() - await server.stop() - - @pytest.mark.asyncio - async def test_tcp_connection_framing(self): - """TCP 分帧协议测试""" - from src.plugin_runtime.transport.tcp import TCPTransportServer, TCPTransportClient - - server = TCPTransportServer() - received = asyncio.Event() - received_data = [] - - async def handler(conn): - data = await conn.recv_frame() - received_data.append(data) - await conn.send_frame(b"tcp_pong") - received.set() - - await server.start(handler) - address = server.get_address() - host, port = address.split(":") - - client = TCPTransportClient(host, int(port)) - conn = await client.connect() - await conn.send_frame(b"tcp_ping") - - await asyncio.wait_for(received.wait(), timeout=5.0) - assert received_data[0] == b"tcp_ping" - - resp = await conn.recv_frame() - assert resp == b"tcp_pong" - - await conn.close() - await server.stop() - - @pytest.mark.asyncio - @pytest.mark.skipif(sys.platform != "win32", reason="Windows only") - async def test_named_pipe_connection_framing(self): - """Windows Named Pipe 分帧协议测试""" - from src.plugin_runtime.transport.named_pipe import NamedPipeTransportClient, NamedPipeTransportServer - - server = NamedPipeTransportServer() - received = asyncio.Event() - received_data = [] - - async def handler(conn): - data = await conn.recv_frame() - received_data.append(data) - await conn.send_frame(b"pipe_pong") - received.set() - - await server.start(handler) - client = NamedPipeTransportClient(server.get_address()) - conn = await client.connect() - await conn.send_frame(b"pipe_ping") - - await asyncio.wait_for(received.wait(), timeout=5.0) - assert received_data[0] == b"pipe_ping" - - resp = await conn.recv_frame() - assert resp == b"pipe_pong" - - await conn.close() - await server.stop() - - @pytest.mark.asyncio - async def test_transport_factory(self): - """传输工厂测试""" - from src.plugin_runtime.transport.factory import create_transport_server, create_transport_client - - server = create_transport_server() - assert server is not None - - # UDS 路径 - client = create_transport_client("/tmp/test.sock") - assert client is not None - - # Windows Named Pipe 地址 - client = create_transport_client(r"\\.\pipe\maibot-test") - assert client is not None - - # TCP 地址 - client = create_transport_client("127.0.0.1:9999") - assert client is not None - - -# ─── Host 层测试 ────────────────────────────────────────── - - -class TestHost: - """Host 端基础设施测试""" - - def test_policy_engine(self): - """策略引擎测试""" - from src.plugin_runtime.host.policy_engine import PolicyEngine - - engine = PolicyEngine() - - # 注册插件 - token = engine.register_plugin( - plugin_id="test_plugin", - generation=1, - capabilities=["send.text", "db.query"], - ) - - assert token.plugin_id == "test_plugin" - assert "send.text" in token.capabilities - - # 能力检查 - ok, _ = engine.check_capability("test_plugin", "send.text") - assert ok - - ok, reason = engine.check_capability("test_plugin", "llm.generate") - assert not ok - assert "未获授权" in reason - - # 未注册插件 - ok, reason = engine.check_capability("unknown", "send.text") - assert not ok - - ok, reason = engine.check_capability("test_plugin", "send.text", generation=2) - assert not ok - assert "generation 不匹配" in reason - - def test_policy_engine_allows_parallel_generations(self): - """同一插件在热重载期间应允许 active/staged 两代并行持有能力令牌。""" - from src.plugin_runtime.host.policy_engine import PolicyEngine - - engine = PolicyEngine() - engine.register_plugin("test_plugin", generation=1, capabilities=["send.text"]) - engine.register_plugin("test_plugin", generation=2, capabilities=["send.text", "llm.generate"]) - - ok, _ = engine.check_capability("test_plugin", "send.text", generation=1) - assert ok is True - - ok, _ = engine.check_capability("test_plugin", "llm.generate", generation=2) - assert ok is True - - ok, reason = engine.check_capability("test_plugin", "llm.generate", generation=1) - assert ok is False - assert "未获授权" in reason - - def test_circuit_breaker_removed(self): - """熔断器已移除,验证 supervisor 不依赖它""" - pass - - def test_circuit_breaker_registry_removed(self): - """熔断器注册表已移除""" - pass - - -# ─── SDK 测试 ───────────────────────────────────────────── - - -class TestSDK: - """SDK 框架测试""" - - def test_component_decorators(self): - """组件装饰器测试""" - from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler - from maibot_sdk.types import ActivationType, EventType - - class TestPlugin(MaiBotPlugin): - @Action("greet", activation_type=ActivationType.KEYWORD, activation_keywords=["hi"]) - async def handle_greet(self, **kwargs): - return True, "ok" - - @Command("echo", pattern=r"^/echo") - async def handle_echo(self, **kwargs): - return True, "echoed", 2 - - @Tool("search", parameters={"query": {"type": "string"}}) - async def handle_search(self, **kwargs): - return {"result": "found"} - - @EventHandler("on_start", event_type=EventType.ON_START) - async def handle_start(self, **kwargs): - return True, False, "started" - - plugin = TestPlugin() - components = plugin.get_components() - - assert len(components) == 4 - - names = {c["name"] for c in components} - assert "greet" in names - assert "echo" in names - assert "search" in names - assert "on_start" in names - - types = {c["type"] for c in components} - assert "action" in types - assert "command" in types - assert "tool" in types - assert "event_handler" in types - - def test_plugin_context_not_initialized(self): - """未初始化上下文时应报错""" - from maibot_sdk import MaiBotPlugin - - plugin = MaiBotPlugin() - with pytest.raises(RuntimeError, match="尚未初始化"): - _ = plugin.ctx - - def test_plugin_context_injection(self): - """上下文注入测试""" - from maibot_sdk import MaiBotPlugin - from maibot_sdk.context import PluginContext - - plugin = MaiBotPlugin() - ctx = PluginContext(plugin_id="test") - plugin._set_context(ctx) - - assert plugin.ctx.plugin_id == "test" - assert plugin.ctx.send is not None - assert plugin.ctx.db is not None - assert plugin.ctx.llm is not None - assert plugin.ctx.config is not None - - @pytest.mark.asyncio - async def test_runner_injected_context_binds_plugin_identity(self): - """Runner 注入的上下文应忽略调用方伪造的 plugin_id。""" - from src.plugin_runtime.runner.runner_main import PluginRunner - - class DummyRPCClient: - def __init__(self): - self.calls = [] - - async def send_request(self, method, plugin_id="", payload=None, timeout_ms=30000): - self.calls.append( - { - "method": method, - "plugin_id": plugin_id, - "payload": payload, - "timeout_ms": timeout_ms, - } - ) - return SimpleNamespace( - error=None, - payload={"success": True, "result": {"success": True, "result": {"ok": True}}}, - ) - - class DummyPlugin: - def _set_context(self, ctx): - self.ctx = ctx - - runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) - runner._rpc_client = DummyRPCClient() - - plugin = DummyPlugin() - runner._inject_context("owner_plugin", plugin) - - plugin.ctx._plugin_id = "forged_plugin" - result = await plugin.ctx.call_capability("send.text", text="hello", stream_id="stream-1") - - assert result is True - assert runner._rpc_client.calls[0]["plugin_id"] == "owner_plugin" - assert runner._rpc_client.calls[0]["method"] == "cap.call" - - @pytest.mark.asyncio - async def test_runner_injected_context_unwraps_llm_available_models(self): - """Runner 应为 SDK 解开 cap.call 响应外层,避免模型列表被规整成空列表。""" - from src.plugin_runtime.runner.runner_main import PluginRunner - - class DummyRPCClient: - async def send_request(self, method, plugin_id="", payload=None, timeout_ms=30000): - assert method == "cap.call" - assert plugin_id == "owner_plugin" - assert payload == {"capability": "llm.get_available_models", "args": {}} - return SimpleNamespace( - error=None, - payload={"success": True, "result": {"success": True, "models": ["utils", "replyer"]}}, - ) - - class DummyPlugin: - def _set_context(self, ctx): - self.ctx = ctx - - runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) - runner._rpc_client = DummyRPCClient() - - plugin = DummyPlugin() - runner._inject_context("owner_plugin", plugin) - - assert await plugin.ctx.llm.get_available_models() == ["utils", "replyer"] - - @pytest.mark.asyncio - async def test_runner_injected_context_raises_send_capability_error_details(self): - """Runner 应将 send.* 能力失败的底层错误透传为异常。""" - from src.plugin_runtime.runner.runner_main import PluginRunner - - class DummyRPCClient: - async def send_request(self, method, plugin_id="", payload=None, timeout_ms=30000): - assert method == "cap.call" - assert plugin_id == "owner_plugin" - assert payload == { - "capability": "send.custom", - "args": { - "message_type": "poke", - "content": {"qq_id": "1"}, - "custom_type": "poke", - "data": {"qq_id": "1"}, - "stream_id": "当前聊天流", - }, - } - return SimpleNamespace( - error=None, - payload={"success": True, "result": {"success": False, "error": "未找到聊天流: 当前聊天流"}}, - ) - - class DummyPlugin: - def _set_context(self, ctx): - self.ctx = ctx - - runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) - runner._rpc_client = DummyRPCClient() - - plugin = DummyPlugin() - runner._inject_context("owner_plugin", plugin) - - with pytest.raises(RuntimeError, match="未找到聊天流: 当前聊天流"): - await plugin.ctx.send.custom( - custom_type="poke", - data={"qq_id": "1"}, - stream_id="当前聊天流", - ) - - @pytest.mark.asyncio - async def test_runner_invoke_tool_propagates_send_failure_details(self): - """插件工具捕获 send.* 失败时,应能拿到底层错误详情。""" - from src.plugin_runtime.protocol.envelope import Envelope, MessageType - from src.plugin_runtime.runner.runner_main import PluginRunner - - class DummyRPCClient: - async def send_request(self, method, plugin_id="", payload=None, timeout_ms=30000): - assert method == "cap.call" - return SimpleNamespace( - error=None, - payload={"success": True, "result": {"success": False, "error": "未找到聊天流: 当前聊天流"}}, - ) - - class DummyPlugin: - def _set_context(self, ctx): - self.ctx = ctx - - async def handle_poke(self, **kwargs): - try: - await self.ctx.send.custom( - custom_type="poke", - data={"qq_id": "1"}, - stream_id=str(kwargs.get("stream_id", "")), - ) - except Exception as exc: - return {"success": False, "message": f"戳一戳失败: {exc}"} - return {"success": True, "message": "戳一戳成功"} - - runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) - runner._rpc_client = DummyRPCClient() - - plugin = DummyPlugin() - runner._inject_context("demo_plugin", plugin) - meta = SimpleNamespace( - plugin_id="demo_plugin", - instance=plugin, - component_handlers={"poke": "handle_poke"}, - ) - runner._loader._loaded_plugins["demo_plugin"] = meta - - envelope = Envelope( - request_id=1, - message_type=MessageType.REQUEST, - method="plugin.invoke_tool", - plugin_id="demo_plugin", - payload={"component_name": "poke", "args": {"stream_id": "当前聊天流"}}, - ) - - response = await runner._handle_invoke(envelope) - - assert response.payload["success"] is True - assert response.payload["result"] == {"success": False, "message": "戳一戳失败: 未找到聊天流: 当前聊天流"} - - @pytest.mark.asyncio - async def test_runner_applies_initial_plugin_config(self, tmp_path): - """Runner 应在 on_load 前为支持的插件实例注入 config.toml。""" - from src.plugin_runtime.runner.runner_main import PluginRunner - - class DummyPlugin: - def __init__(self): - self.configs = [] - - def set_plugin_config(self, config): - self.configs.append(config) - - plugin_dir = tmp_path / "demo_plugin" - plugin_dir.mkdir() - (plugin_dir / "config.toml").write_text("[section]\nvalue = 1\n", encoding="utf-8") - - runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) - plugin = DummyPlugin() - meta = SimpleNamespace(plugin_id="demo_plugin", plugin_dir=str(plugin_dir), instance=plugin) - - runner._apply_plugin_config(meta) - - assert plugin.configs == [{"section": {"value": 1}}] - - @pytest.mark.asyncio - async def test_runner_config_update_refreshes_plugin_config_before_callback(self): - """配置更新时应先刷新插件配置,再调用 on_config_update。""" - from src.plugin_runtime.protocol.envelope import Envelope, MessageType - from src.plugin_runtime.runner.runner_main import PluginRunner - - class DummyPlugin: - def __init__(self): - self.configs = [] - self.updates = [] - - def set_plugin_config(self, config): - self.configs.append(config) - - async def on_config_update(self, scope, config, version): - self.updates.append((scope, config, version, list(self.configs))) - - runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) - plugin = DummyPlugin() - runner._loader._loaded_plugins["demo_plugin"] = SimpleNamespace(instance=plugin) - - envelope = Envelope( - request_id=1, - message_type=MessageType.REQUEST, - method="plugin.config_updated", - plugin_id="demo_plugin", - payload={ - "plugin_id": "demo_plugin", - "config_scope": "self", - "config_data": {"enabled": True}, - "config_version": "v2", - }, - ) - - response = await runner._handle_config_updated(envelope) - - assert response.payload["acknowledged"] is True - assert plugin.configs == [{"enabled": True}] - assert plugin.updates == [("self", {"enabled": True}, "v2", [{"enabled": True}])] - - @pytest.mark.asyncio - async def test_runner_global_config_update_does_not_override_plugin_config(self): - """bot/model 广播不应覆盖插件自身配置缓存。""" - from src.plugin_runtime.protocol.envelope import Envelope, MessageType - from src.plugin_runtime.runner.runner_main import PluginRunner - - class DummyPlugin: - def __init__(self): - self.configs = [] - self.updates = [] - - def set_plugin_config(self, config): - self.configs.append(config) - - async def on_config_update(self, scope, config, version): - self.updates.append((scope, config, version, list(self.configs))) - - runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) - plugin = DummyPlugin() - runner._loader._loaded_plugins["demo_plugin"] = SimpleNamespace(instance=plugin) - plugin.set_plugin_config({"plugin_enabled": True}) - - envelope = Envelope( - request_id=1, - message_type=MessageType.REQUEST, - method="plugin.config_updated", - plugin_id="demo_plugin", - payload={ - "plugin_id": "demo_plugin", - "config_scope": "model", - "config_data": {"models": []}, - "config_version": "", - }, - ) - - response = await runner._handle_config_updated(envelope) - - assert response.payload["acknowledged"] is True - assert plugin.configs == [{"plugin_enabled": True}] - assert plugin.updates == [("model", {"models": []}, "", [{"plugin_enabled": True}])] - - @pytest.mark.asyncio - async def test_host_logs_runner_ready_plugin_failures(self, caplog): - """Host 收到 runner.ready 时应明确记录插件注册失败。""" - from src.plugin_runtime.host.supervisor import PluginRunnerSupervisor - from src.plugin_runtime.protocol.envelope import Envelope, MessageType - - supervisor = PluginRunnerSupervisor(plugin_dirs=[], runner_spawn_timeout_sec=1) - envelope = Envelope( - request_id=1, - message_type=MessageType.REQUEST, - method="runner.ready", - plugin_id="", - payload={ - "loaded_plugins": ["ok_plugin"], - "failed_plugins": ["bad_plugin"], - "inactive_plugins": ["disabled_plugin"], - }, - ) - - with caplog.at_level(logging.INFO, logger="plugin_runtime.host.runner_manager"): - response = await supervisor._handle_runner_ready(envelope) - - assert response.payload["accepted"] is True - assert "插件注册失败: bad_plugin" in caplog.text - assert "插件未激活: disabled_plugin" in caplog.text - assert "Runner 插件初始化完成: loaded=1 failed=1 inactive=1" in caplog.text - - @pytest.mark.asyncio - async def test_runner_bootstraps_capabilities_before_on_load(self, monkeypatch): - """on_load 期间的 capability 调用应在 bootstrap 后生效。""" - from src.plugin_runtime.runner.runner_main import PluginRunner - - class DummyRPCClient: - def __init__(self): - self.calls = [] - - async def connect_and_handshake(self): - return True - - def register_method(self, method, handler): - return None - - async def send_request(self, method, plugin_id="", payload=None, timeout_ms=30000): - self.calls.append( - { - "method": method, - "plugin_id": plugin_id, - "payload": payload, - "timeout_ms": timeout_ms, - } - ) - if method == "cap.call": - bootstrap_methods = [call["method"] for call in self.calls[:-1]] - assert "plugin.bootstrap" in bootstrap_methods - return SimpleNamespace(error=None, payload={"success": True, "result": {"success": True}}) - return SimpleNamespace(error=None, payload={"accepted": True}) - - async def disconnect(self): - return None - - class DummyPlugin: - def __init__(self, runner): - self.runner = runner - - def _set_context(self, ctx): - self.ctx = ctx - - def get_components(self): - return [{"name": "handler", "type": "command", "metadata": {}}] - - async def on_load(self): - result = await self.ctx.call_capability("send.text", text="hello", stream_id="stream-1") - assert result is True - self.runner._shutting_down = True - - runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) - runner._rpc_client = DummyRPCClient() - - plugin = DummyPlugin(runner) - meta = SimpleNamespace( - plugin_id="demo_plugin", - plugin_dir="/tmp/demo_plugin", - instance=plugin, - version="1.0.0", - capabilities_required=["send.text"], - dependencies=[], - manifest=SimpleNamespace(plugin_dependencies=[], llm_provider_client_types=[]), - component_handlers={}, - llm_provider_handlers={}, - ) - - monkeypatch.setattr(runner, "_install_log_handler", lambda: None) - monkeypatch.setattr(runner, "_uninstall_log_handler", lambda: asyncio.sleep(0)) - monkeypatch.setattr(runner._loader, "discover_and_load", lambda plugin_dirs, **kwargs: [meta]) - - await runner.run() - - methods = [call["method"] for call in runner._rpc_client.calls] - assert methods == ["plugin.bootstrap", "plugin.register_components", "cap.call", "runner.ready"] - - @pytest.mark.asyncio - async def test_runner_batch_reload_merges_overlapping_reverse_dependents(self, monkeypatch): - """批量重载应只对重叠依赖闭包执行一次 unload/load。""" - from src.plugin_runtime.runner.runner_main import PluginRunner - - runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) - plugin_a_id = "test.plugin-a" - plugin_b_id = "test.plugin-b" - plugin_c_id = "test.plugin-c" - - def build_meta(plugin_id: str, dependencies: list[str]) -> SimpleNamespace: - return SimpleNamespace( - plugin_id=plugin_id, - dependencies=dependencies, - plugin_dir=f"/tmp/{plugin_id}", - version="1.0.0", - instance=SimpleNamespace(), - ) - - loaded_metas = { - plugin_a_id: build_meta(plugin_a_id, []), - plugin_b_id: build_meta(plugin_b_id, [plugin_a_id]), - plugin_c_id: build_meta(plugin_c_id, [plugin_b_id]), - } - reloaded_metas = { - plugin_id: build_meta(plugin_id, list(meta.dependencies)) - for plugin_id, meta in loaded_metas.items() - } - candidates = { - plugin_a_id: ( - "dir_plugin_a", - build_test_manifest_model(plugin_a_id), - "plugin_a/plugin.py", - ), - plugin_b_id: ( - "dir_plugin_b", - build_test_manifest_model( - plugin_b_id, - dependencies=[{"type": "plugin", "id": plugin_a_id, "version_spec": ">=1.0.0,<2.0.0"}], - ), - "plugin_b/plugin.py", - ), - plugin_c_id: ( - "dir_plugin_c", - build_test_manifest_model( - plugin_c_id, - dependencies=[{"type": "plugin", "id": plugin_b_id, "version_spec": ">=1.0.0,<2.0.0"}], - ), - "plugin_c/plugin.py", - ), - } - unloaded_plugins: list[str] = [] - activated_plugins: list[str] = [] - - monkeypatch.setattr(runner._loader, "discover_candidates", lambda plugin_dirs: (candidates, {})) - monkeypatch.setattr(runner._loader, "list_plugins", lambda: sorted(loaded_metas.keys())) - monkeypatch.setattr(runner._loader, "get_plugin", lambda plugin_id: loaded_metas.get(plugin_id)) - monkeypatch.setattr( - runner._loader, - "remove_loaded_plugin", - lambda plugin_id: loaded_metas.pop(plugin_id, None), - ) - monkeypatch.setattr(runner._loader, "purge_plugin_modules", lambda plugin_id, plugin_dir: []) - monkeypatch.setattr( - runner._loader, - "resolve_dependencies", - lambda reload_candidates, extra_available=None: (sorted(reload_candidates.keys()), {}), - ) - monkeypatch.setattr( - runner._loader, - "load_candidate", - lambda plugin_id, candidate: reloaded_metas[plugin_id], - ) - - async def fake_unload_plugin(meta, reason, purge_modules=False): - del reason, purge_modules - unloaded_plugins.append(meta.plugin_id) - loaded_metas.pop(meta.plugin_id, None) - - async def fake_activate_plugin(meta): - activated_plugins.append(meta.plugin_id) - loaded_metas[meta.plugin_id] = meta - return True - - monkeypatch.setattr(runner, "_unload_plugin", fake_unload_plugin) - monkeypatch.setattr(runner, "_activate_plugin", fake_activate_plugin) - - result = await runner._reload_plugins_by_ids([plugin_a_id, plugin_b_id], reason="manual") - - assert result.success is True - assert result.requested_plugin_ids == [plugin_a_id, plugin_b_id] - assert unloaded_plugins == [plugin_c_id, plugin_b_id, plugin_a_id] - assert activated_plugins == [plugin_a_id, plugin_b_id, plugin_c_id] - assert result.reloaded_plugins == [plugin_a_id, plugin_b_id, plugin_c_id] - - -class TestPluginSdkUsage: - """验证仓库内插件按新 SDK 归一化返回值工作。""" - - def test_runner_skips_signal_handler_registration_on_windows(self, monkeypatch): - """Windows 下不应尝试注册 add_signal_handler。""" - from src.plugin_runtime.runner import runner_main - - registered_signals = [] - - class DummyLoop: - def add_signal_handler(self, sig, callback): - registered_signals.append((sig, callback)) - - monkeypatch.setattr(runner_main.sys, "platform", "win32") - - runner_main._install_shutdown_signal_handlers(lambda: None, DummyLoop()) - - assert not registered_signals - - @pytest.mark.asyncio - async def test_builtin_emoji_plugin_handles_normalized_results(self): - from maibot_sdk.context import PluginContext - from src.plugins.built_in.emoji_plugin.plugin import EmojiPlugin - - async def fake_rpc_call(method: str, plugin_id: str = "", payload: dict | None = None): - assert method == "cap.request" - assert payload is not None - capability = payload["capability"] - return { - "emoji.get_random": { - "success": True, - "emojis": [{"base64": "img-1", "emotion": "happy"}], - }, - "message.get_recent": {"success": True, "messages": [{"id": 1}]}, - "message.build_readable": {"success": True, "text": "最近消息"}, - "llm.generate": {"success": True, "response": "happy", "reasoning": "", "model_name": "m"}, - "send.emoji": {"success": True}, - }[capability] - - plugin = EmojiPlugin() - plugin._set_context(PluginContext(plugin_id="emoji", rpc_call=fake_rpc_call)) - - success, message = await plugin.handle_emoji(stream_id="stream-1", reasoning="测试", chat_id="chat-1") - - assert success is True - assert "成功发送表情包" in message - - @pytest.mark.asyncio - async def test_tts_plugin_uses_send_custom_bool_result(self): - from maibot_sdk.context import PluginContext - from src.plugins.built_in.tts_plugin.plugin import TTSPlugin - - async def fake_rpc_call(method: str, plugin_id: str = "", payload: dict | None = None): - assert method == "cap.request" - assert payload is not None - assert payload["capability"] == "send.custom" - return {"success": True} - - plugin = TTSPlugin() - plugin._set_context(PluginContext(plugin_id="tts", rpc_call=fake_rpc_call)) - - success, message = await plugin.handle_tts_action( - stream_id="stream-1", - action_data={"voice_text": "你好!!!"}, - ) - - assert success is True - assert message == "TTS动作执行成功" - - @pytest.mark.asyncio - async def test_hello_world_plugin_handles_random_emoji_list(self): - from maibot_sdk.context import PluginContext - from plugins.hello_world_plugin.plugin import HelloWorldPlugin - - async def fake_rpc_call(method: str, plugin_id: str = "", payload: dict | None = None): - assert method == "cap.request" - assert payload is not None - capability = payload["capability"] - return { - "emoji.get_random": {"success": True, "emojis": [{"base64": "img-1"}, {"base64": "img-2"}]}, - "send.forward": {"success": True}, - }[capability] - - plugin = HelloWorldPlugin() - plugin._set_context(PluginContext(plugin_id="hello", rpc_call=fake_rpc_call)) - - success, message, should_continue = await plugin.handle_random_emojis(stream_id="stream-1") - - assert success is True - assert message == "已发送随机表情包" - assert should_continue is True - - -# ─── 端到端集成测试 ──────────────────────────────────────── - - -class TestE2E: - """端到端集成测试(Host + Runner 通信)""" - - @pytest.mark.asyncio - async def test_handshake(self): - """Host-Runner 握手流程测试""" - from src.plugin_runtime.protocol.codec import MsgPackCodec - from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType - from src.plugin_runtime.transport.factory import create_transport_client, create_transport_server - - import secrets - - session_token = secrets.token_hex(16) - codec = MsgPackCodec() - handshake_done = asyncio.Event() - server_result = {} - - async def server_handler(conn): - # 接收握手 - data = await conn.recv_frame() - env = codec.decode_envelope(data) - assert env.method == "runner.hello" - - hello = HelloPayload.model_validate(env.payload) - assert hello.session_token == session_token - - # 发送响应 - resp_payload = HelloResponsePayload( - accepted=True, - host_version="1.0", - assigned_generation=1, - ) - resp = env.make_response(payload=resp_payload.model_dump()) - await conn.send_frame(codec.encode_envelope(resp)) - - server_result["runner_id"] = hello.runner_id - handshake_done.set() - - # 保持连接一会儿 - await asyncio.sleep(1.0) - - server = create_transport_server() - await server.start(server_handler) - - # 客户端握手 - client = create_transport_client(server.get_address()) - conn = await client.connect() - - hello = HelloPayload( - runner_id="test-runner", - sdk_version="1.0.0", - session_token=session_token, - ) - env = Envelope( - request_id=1, - message_type=MessageType.REQUEST, - method="runner.hello", - payload=hello.model_dump(), - ) - await conn.send_frame(codec.encode_envelope(env)) - - resp_data = await conn.recv_frame() - resp = codec.decode_envelope(resp_data) - resp_payload = HelloResponsePayload.model_validate(resp.payload) - - assert resp_payload.accepted - assert resp_payload.assigned_generation == 1 - - await asyncio.wait_for(handshake_done.wait(), timeout=5.0) - assert server_result["runner_id"] == "test-runner" - - await conn.close() - await server.stop() - - -# ─── Manifest 校验测试 ───────────────────────────────────── - - -class TestManifestValidator: - """Manifest 校验器测试""" - - def test_valid_manifest(self): - from src.plugin_runtime.runner.manifest_validator import ManifestValidator - - validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") - manifest = build_test_manifest("test.valid-plugin", capabilities=["send.text"]) - assert validator.validate(manifest) is True - assert len(validator.errors) == 0 - assert validator.warnings == [] - - def test_manifest_id_allows_uppercase_and_underscore(self): - from src.plugin_runtime.runner.manifest_validator import ManifestValidator - - validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") - manifest = build_test_manifest("XXXxx7258.google_search_plugin", capabilities=["send.text"]) - assert validator.validate(manifest) is True - assert validator.errors == [] - - def test_missing_required_fields(self): - from src.plugin_runtime.runner.manifest_validator import ManifestValidator - - validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") - manifest = {"manifest_version": 2} - assert validator.validate(manifest) is False - assert len(validator.errors) >= 6 - assert any("缺少必需字段" in error for error in validator.errors) - - def test_unsupported_manifest_version(self): - from src.plugin_runtime.runner.manifest_validator import ManifestValidator - - validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") - manifest = build_test_manifest("test.invalid-version") - manifest["manifest_version"] = 999 - assert validator.validate(manifest) is False - assert any("manifest_version" in e for e in validator.errors) - - def test_host_version_compatibility(self): - from src.plugin_runtime.runner.manifest_validator import ManifestValidator - - validator = ManifestValidator(host_version="0.8.5", sdk_version="2.0.1") - manifest = build_test_manifest( - "test.host-check", - host_min_version="0.9.0", - host_max_version="1.0.0", - ) - assert validator.validate(manifest) is False - assert any("Host 版本不兼容" in e for e in validator.errors) - - def test_sdk_version_compatibility(self): - from src.plugin_runtime.runner.manifest_validator import ManifestValidator - - validator = ManifestValidator(host_version="1.0.0", sdk_version="1.9.9") - manifest = build_test_manifest("test.sdk-check") - assert validator.validate(manifest) is False - assert any("SDK 版本不兼容" in e for e in validator.errors) - - def test_extra_fields_are_rejected(self): - from src.plugin_runtime.runner.manifest_validator import ManifestValidator - - validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") - manifest = build_test_manifest("test.extra-field") - manifest["unexpected"] = True - - assert validator.validate(manifest) is False - assert any("存在未声明字段" in error for error in validator.errors) - - def test_python_package_conflict_rejects_manifest(self): - from src.plugin_runtime.runner.manifest_validator import ManifestValidator - - validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") - manifest = build_test_manifest( - "test.numpy-conflict", - dependencies=[ - { - "type": "python_package", - "name": "numpy", - "version_spec": ">=999.0.0", - } - ], - ) - - assert validator.validate(manifest) is False - assert any("Python 包依赖冲突" in error for error in validator.errors) - - def test_llm_provider_manifest_declaration(self): - from src.plugin_runtime.runner.manifest_validator import ManifestValidator - - validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") - manifest = build_test_manifest( - "test.llm-provider", - llm_providers=[ - { - "client_type": "example.provider", - "name": "Example Provider", - "description": "测试 Provider", - "version": "1.0.0", - } - ], - ) - - parsed_manifest = validator.parse_manifest(manifest) - - assert parsed_manifest is not None - assert parsed_manifest.llm_provider_client_types == ["example.provider"] - - def test_duplicate_llm_provider_manifest_declaration_is_rejected(self): - from src.plugin_runtime.runner.manifest_validator import ManifestValidator - - validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") - manifest = build_test_manifest( - "test.llm-provider-duplicate", - llm_providers=[ - {"client_type": "example.provider"}, - {"client_type": "example.provider"}, - ], - ) - - assert validator.validate(manifest) is False - assert any("重复的 LLM Provider" in error for error in validator.errors) - - -def test_llm_provider_conflict_blocks_all_conflicting_plugins(tmp_path: Path): - from src.plugin_runtime.integration import PluginRuntimeManager - - plugin_root = tmp_path / "plugins" - plugin_root.mkdir() - for plugin_id in ["test.provider-alpha", "test.provider-beta"]: - plugin_dir = plugin_root / plugin_id - plugin_dir.mkdir() - manifest = build_test_manifest( - plugin_id, - llm_providers=[{"client_type": "example.provider"}], - ) - (plugin_dir / "_manifest.json").write_text(json.dumps(manifest), encoding="utf-8") - (plugin_dir / "plugin.py").write_text("def create_plugin():\n return None\n", encoding="utf-8") - - blocked_reasons = PluginRuntimeManager._discover_llm_provider_conflicts([plugin_root]) - - assert set(blocked_reasons) == {"test.provider-alpha", "test.provider-beta"} - assert all("example.provider" in reason for reason in blocked_reasons.values()) - - -class TestVersionComparator: - """版本号比较器测试""" - - def test_normalize(self): - from src.plugin_runtime.runner.manifest_validator import VersionComparator - - assert VersionComparator.normalize_version("0.8.0-snapshot.1") == "0.8.0" - assert VersionComparator.normalize_version("1.2") == "1.2.0" - assert VersionComparator.normalize_version("1.0.0rc16") == "1.0.0" - assert VersionComparator.normalize_version("1.0.0-pre.16") == "1.0.0" - assert VersionComparator.normalize_version("") == "0.0.0" - - def test_compare(self): - from src.plugin_runtime.runner.manifest_validator import VersionComparator - - assert VersionComparator.compare("0.8.0", "0.8.0") == 0 - assert VersionComparator.compare("0.8.0", "0.9.0") == -1 - assert VersionComparator.compare("1.0.0", "0.9.0") == 1 - - def test_is_in_range(self): - from src.plugin_runtime.runner.manifest_validator import VersionComparator - - ok, _ = VersionComparator.is_in_range("0.8.5", "0.8.0", "0.9.0") - assert ok - ok, _ = VersionComparator.is_in_range("0.7.0", "0.8.0", "0.9.0") - assert not ok - ok, _ = VersionComparator.is_in_range("1.0.0", "0.8.0", "0.9.0") - assert not ok - - -# ─── 依赖解析测试 ────────────────────────────────────────── - - -class TestDependencyResolution: - """插件依赖解析测试""" - - def test_topological_sort(self): - from src.plugin_runtime.runner.plugin_loader import PluginLoader - - loader = PluginLoader() - candidates = { - "test.core": ( - "dir_core", - build_test_manifest_model("test.core"), - "plugin.py", - ), - "test.auth": ( - "dir_auth", - build_test_manifest_model( - "test.auth", - dependencies=[ - {"type": "plugin", "id": "test.core", "version_spec": ">=1.0.0,<2.0.0"}, - ], - ), - "plugin.py", - ), - "test.api": ( - "dir_api", - build_test_manifest_model( - "test.api", - dependencies=[ - {"type": "plugin", "id": "test.core", "version_spec": ">=1.0.0,<2.0.0"}, - {"type": "plugin", "id": "test.auth", "version_spec": ">=1.0.0,<2.0.0"}, - ], - ), - "plugin.py", - ), - } - - order, failed = loader._resolve_dependencies(candidates) - assert len(failed) == 0 - assert order.index("test.core") < order.index("test.auth") - assert order.index("test.auth") < order.index("test.api") - - def test_missing_dependency(self): - from src.plugin_runtime.runner.plugin_loader import PluginLoader - - loader = PluginLoader() - candidates = { - "test.plugin-a": ( - "dir_a", - build_test_manifest_model( - "test.plugin-a", - dependencies=[ - {"type": "plugin", "id": "test.nonexistent", "version_spec": ">=1.0.0,<2.0.0"}, - ], - ), - "plugin.py", - ), - } - - order, failed = loader._resolve_dependencies(candidates) - assert "test.plugin-a" in failed - assert "依赖未满足" in failed["test.plugin-a"] - - def test_circular_dependency(self): - from src.plugin_runtime.runner.plugin_loader import PluginLoader - - loader = PluginLoader() - candidates = { - "test.a": ( - "dir_a", - build_test_manifest_model( - "test.a", - dependencies=[ - {"type": "plugin", "id": "test.b", "version_spec": ">=1.0.0,<2.0.0"}, - ], - ), - "p.py", - ), - "test.b": ( - "dir_b", - build_test_manifest_model( - "test.b", - dependencies=[ - {"type": "plugin", "id": "test.a", "version_spec": ">=1.0.0,<2.0.0"}, - ], - ), - "p.py", - ), - } - - order, failed = loader._resolve_dependencies(candidates) - assert len(failed) >= 1 # 至少一个循环插件被标记 - - def test_loader_supports_package_imports_inside_create_plugin(self, tmp_path): - from src.plugin_runtime.runner.plugin_loader import PluginLoader - - plugin_root = tmp_path / "plugins" - plugin_root.mkdir() - plugin_dir = plugin_root / "grok_search_plugin" - plugin_dir.mkdir() - - (plugin_dir / "_manifest.json").write_text( - json.dumps( - build_test_manifest( - "test.grok-search-plugin", - name="grok_search_plugin", - description="demo", - ) - ), - encoding="utf-8", - ) - (plugin_dir / "__init__.py").write_text("VALUE = 1\n", encoding="utf-8") - (plugin_dir / "services.py").write_text("def answer():\n return 42\n", encoding="utf-8") - (plugin_dir / "plugin.py").write_text( - "class DemoPlugin:\n" - " pass\n\n" - "def create_plugin():\n" - " from grok_search_plugin.services import answer\n" - " plugin = DemoPlugin()\n" - " plugin.answer = answer\n" - " return plugin\n", - encoding="utf-8", - ) - - loader = PluginLoader() - loaded = loader.discover_and_load([str(plugin_root)]) - - assert [meta.plugin_id for meta in loaded] == ["test.grok-search-plugin"] - assert loader.failed_plugins == {} - assert loaded[0].instance.answer() == 42 - - def test_loader_requires_sdk_plugin_to_override_on_config_update(self, tmp_path): - from src.plugin_runtime.runner.plugin_loader import PluginLoader - - plugin_root = tmp_path / "plugins" - plugin_root.mkdir() - plugin_dir = plugin_root / "demo_plugin" - plugin_dir.mkdir() - - (plugin_dir / "_manifest.json").write_text( - json.dumps( - build_test_manifest( - "test.demo-plugin", - name="demo_plugin", - description="demo", - ) - ), - encoding="utf-8", - ) - (plugin_dir / "plugin.py").write_text( - "from maibot_sdk import MaiBotPlugin\n\n" - "class DemoPlugin(MaiBotPlugin):\n" - " async def on_load(self):\n" - " pass\n\n" - " async def on_unload(self):\n" - " pass\n\n" - "def create_plugin():\n" - " return DemoPlugin()\n", - encoding="utf-8", - ) - - loader = PluginLoader() - loaded = loader.discover_and_load([str(plugin_root)]) - - assert loaded == [] - assert "test.demo-plugin" in loader.failed_plugins - assert "on_config_update" in loader.failed_plugins["test.demo-plugin"] - - def test_loader_requires_sdk_plugin_to_override_on_load(self, tmp_path): - from src.plugin_runtime.runner.plugin_loader import PluginLoader - - plugin_root = tmp_path / "plugins" - plugin_root.mkdir() - plugin_dir = plugin_root / "demo_plugin" - plugin_dir.mkdir() - - (plugin_dir / "_manifest.json").write_text( - json.dumps( - build_test_manifest( - "test.demo-plugin", - name="demo_plugin", - description="demo", - ) - ), - encoding="utf-8", - ) - (plugin_dir / "plugin.py").write_text( - "from maibot_sdk import MaiBotPlugin\n\n" - "class DemoPlugin(MaiBotPlugin):\n" - " async def on_unload(self):\n" - " pass\n\n" - " async def on_config_update(self, scope, config_data, version):\n" - " pass\n\n" - "def create_plugin():\n" - " return DemoPlugin()\n", - encoding="utf-8", - ) - - loader = PluginLoader() - loaded = loader.discover_and_load([str(plugin_root)]) - - assert loaded == [] - assert "test.demo-plugin" in loader.failed_plugins - assert "on_load" in loader.failed_plugins["test.demo-plugin"] - - def test_loader_requires_sdk_plugin_to_override_on_unload(self, tmp_path): - from src.plugin_runtime.runner.plugin_loader import PluginLoader - - plugin_root = tmp_path / "plugins" - plugin_root.mkdir() - plugin_dir = plugin_root / "demo_plugin" - plugin_dir.mkdir() - - (plugin_dir / "_manifest.json").write_text( - json.dumps( - build_test_manifest( - "test.demo-plugin", - name="demo_plugin", - description="demo", - ) - ), - encoding="utf-8", - ) - (plugin_dir / "plugin.py").write_text( - "from maibot_sdk import MaiBotPlugin\n\n" - "class DemoPlugin(MaiBotPlugin):\n" - " async def on_load(self):\n" - " pass\n\n" - " async def on_config_update(self, scope, config_data, version):\n" - " pass\n\n" - "def create_plugin():\n" - " return DemoPlugin()\n", - encoding="utf-8", - ) - - loader = PluginLoader() - loaded = loader.discover_and_load([str(plugin_root)]) - - assert loaded == [] - assert "test.demo-plugin" in loader.failed_plugins - assert "on_unload" in loader.failed_plugins["test.demo-plugin"] - - @pytest.mark.asyncio - async def test_async_main_removes_sensitive_runtime_env_vars(self, monkeypatch): - from src.plugin_runtime.runner import runner_main - - captured = {} - original_path = list(sys.path) - - class FakeRunner: - def __init__( - self, - host_address: str, - session_token: str, - plugin_dirs: list[str], - external_available_plugins: dict[str, str] | None = None, - ) -> None: - captured["host_address"] = host_address - captured["session_token"] = session_token - captured["plugin_dirs"] = plugin_dirs - captured["external_available_plugins"] = external_available_plugins or {} - - async def run(self) -> None: - assert os.environ.get(runner_main.ENV_IPC_ADDRESS) is None - assert os.environ.get(runner_main.ENV_SESSION_TOKEN) is None - - monkeypatch.setenv(runner_main.ENV_IPC_ADDRESS, "tcp://127.0.0.1:9999") - monkeypatch.setenv(runner_main.ENV_SESSION_TOKEN, "secret-token") - monkeypatch.setenv(runner_main.ENV_PLUGIN_DIRS, "/tmp/plugins") - monkeypatch.setenv(runner_main.ENV_EXTERNAL_PLUGIN_IDS, '{"demo.plugin":"1.0.0"}') - monkeypatch.setattr(runner_main, "_install_shutdown_signal_handlers", lambda callback: None) - monkeypatch.setattr(runner_main, "PluginRunner", FakeRunner) - - await runner_main._async_main() - - assert captured["host_address"] == "tcp://127.0.0.1:9999" - assert captured["session_token"] == "secret-token" - assert captured["plugin_dirs"] == ["/tmp/plugins"] - assert captured["external_available_plugins"] == {"demo.plugin": "1.0.0"} - assert sys.path == original_path - - -# ─── Host-side ComponentRegistry 测试 ────────────────────── - - -class TestComponentRegistry: - """Host-side 组件注册表测试""" - - def test_register_and_query(self): - from src.plugin_runtime.host.component_registry import ComponentRegistry - - reg = ComponentRegistry() - reg.register_component( - "greet", - "action", - "plugin_a", - { - "description": "打招呼", - "activation_type": "keyword", - "activation_keywords": ["hi"], - }, - ) - reg.register_component( - "help", - "command", - "plugin_a", - { - "command_pattern": r"^/help", - }, - ) - reg.register_component( - "search", - "tool", - "plugin_b", - { - "description": "搜索", - }, - ) - - stats = reg.get_stats() - assert stats["total"] == 3 - assert stats["action"] == 1 - assert stats["command"] == 1 - assert stats["tool"] == 1 - - def test_register_command_with_invalid_regex_only_warns(self, monkeypatch): - from src.plugin_runtime.host.component_registry import ComponentRegistry - - reg = ComponentRegistry() - warnings: list[str] = [] - monkeypatch.setattr( - "src.plugin_runtime.host.component_registry.logger.warning", - lambda message: warnings.append(str(message)), - ) - - success = reg.register_component( - "broken", - "command", - "plugin_a", - { - "command_pattern": "[", - }, - ) - - assert success is True - assert reg.get_component("plugin_a.broken") is not None - assert warnings - assert "plugin_a.broken" in warnings[0] - - def test_register_hook_handler_rejects_unknown_hook(self): - from src.plugin_runtime.host.component_registry import ComponentRegistrationError, ComponentRegistry - from src.plugin_runtime.host.hook_spec_registry import HookSpecRegistry - - reg = ComponentRegistry(hook_spec_registry=HookSpecRegistry()) - - with pytest.raises(ComponentRegistrationError, match="未注册的 Hook"): - reg.register_component( - "broken_hook", - "hook_handler", - "plugin_a", - { - "hook": "chat.receive.unknown", - "mode": "blocking", - }, - ) - - def test_register_plugin_components_is_atomic_when_hook_invalid(self): - from src.plugin_runtime.host.component_registry import ComponentRegistrationError, ComponentRegistry - from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry - - hook_spec_registry = HookSpecRegistry() - hook_spec_registry.register_hook_spec(HookSpec(name="chat.receive.before_process")) - reg = ComponentRegistry(hook_spec_registry=hook_spec_registry) - reg.register_plugin_components( - "plugin_a", - [ - {"name": "cmd_old", "component_type": "command", "metadata": {"command_pattern": r"^/old"}}, - ], - ) - - with pytest.raises(ComponentRegistrationError, match="未注册的 Hook"): - reg.register_plugin_components( - "plugin_a", - [ - { - "name": "hook_ok", - "component_type": "hook_handler", - "metadata": {"hook": "chat.receive.before_process", "mode": "blocking"}, - }, - { - "name": "hook_bad", - "component_type": "hook_handler", - "metadata": {"hook": "chat.receive.missing", "mode": "blocking"}, - }, - ], - ) - - assert reg.get_component("plugin_a.cmd_old") is not None - assert reg.get_component("plugin_a.hook_ok") is None - - def test_query_by_type(self): - from src.plugin_runtime.host.component_registry import ComponentRegistry - - reg = ComponentRegistry() - reg.register_component("a1", "action", "p1", {}) - reg.register_component("a2", "action", "p2", {}) - - actions = reg.get_components_by_type("action") - assert len(actions) == 2 - - def test_find_command_by_text(self): - from src.plugin_runtime.host.component_registry import ComponentRegistry - - reg = ComponentRegistry() - reg.register_component( - "help", - "command", - "p1", - { - "command_pattern": r"^/help", - }, - ) - reg.register_component( - "echo", - "command", - "p1", - { - "command_pattern": r"^/echo\s", - }, - ) - - match = reg.find_command_by_text("/help me") - assert match is not None - comp, groups = match - assert comp.name == "help" - - match = reg.find_command_by_text("/echo hello") - assert match is not None - comp, groups = match - assert comp.name == "echo" - - match = reg.find_command_by_text("no match") - assert match is None - - def test_enable_disable(self): - from src.plugin_runtime.host.component_registry import ComponentRegistry - - reg = ComponentRegistry() - reg.register_component("a1", "action", "p1", {}) - reg.set_component_enabled("p1.a1", False) - - actions = reg.get_components_by_type("action", enabled_only=True) - assert len(actions) == 0 - - actions = reg.get_components_by_type("action", enabled_only=False) - assert len(actions) == 1 - - def test_remove_by_plugin(self): - from src.plugin_runtime.host.component_registry import ComponentRegistry - - reg = ComponentRegistry() - reg.register_component("a1", "action", "p1", {}) - reg.register_component("c1", "command", "p1", {}) - reg.register_component("a2", "action", "p2", {}) - - removed = reg.remove_components_by_plugin("p1") - assert removed == 2 - assert reg.get_stats()["total"] == 1 - - def test_reregister_same_plugin_replaces_component_set(self): - from src.plugin_runtime.host.component_registry import ComponentRegistry - - reg = ComponentRegistry() - reg.register_plugin_components( - "p1", - [ - {"name": "a1", "component_type": "action", "metadata": {}}, - {"name": "a2", "component_type": "action", "metadata": {}}, - ], - ) - reg.remove_components_by_plugin("p1") - reg.register_plugin_components( - "p1", - [ - {"name": "a1", "component_type": "action", "metadata": {}}, - ], - ) - - assert reg.get_component("p1.a1") is not None - assert reg.get_component("p1.a2") is None - - def test_event_handlers_sorted_by_weight(self): - from src.plugin_runtime.host.component_registry import ComponentRegistry - - reg = ComponentRegistry() - reg.register_component( - "h_low", - "event_handler", - "p1", - { - "event_type": "on_message", - "weight": 10, - }, - ) - reg.register_component( - "h_high", - "event_handler", - "p2", - { - "event_type": "on_message", - "weight": 100, - }, - ) - - handlers = reg.get_event_handlers("on_message") - assert handlers[0].name == "h_high" - assert handlers[1].name == "h_low" - - def test_tools_for_llm(self): - from src.plugin_runtime.host.component_registry import ComponentRegistry - - reg = ComponentRegistry() - reg.register_component( - "search", - "tool", - "p1", - { - "description": "搜索工具", - "parameters_raw": {"query": {"type": "string"}}, - }, - ) - - tools = reg.get_tools_for_llm() - assert len(tools) == 1 - assert tools[0]["name"] == "p1.search" - assert tools[0]["parameters"]["query"]["type"] == "string" - - -# ─── EventDispatcher 测试 ───────────────────────────────── - - -class TestEventDispatcher: - """Host-side 事件分发器测试""" - - @pytest.mark.asyncio - async def test_dispatch_non_blocking(self): - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.event_dispatcher import EventDispatcher - - reg = ComponentRegistry() - reg.register_component( - "h1", - "event_handler", - "p1", - { - "event_type": "on_start", - "weight": 0, - "intercept_message": False, - }, - ) - - dispatcher = EventDispatcher(reg) - call_log = [] - - async def mock_invoke(plugin_id, comp_name, args): - call_log.append((plugin_id, comp_name)) - return {"success": True, "continue_processing": True} - - should_continue, modified = await dispatcher.dispatch_event("on_start", mock_invoke) - assert should_continue - # 非阻塞分发是异步的,等一下让 task 完成 - await asyncio.sleep(0.1) - assert len(call_log) == 1 - assert call_log[0] == ("p1", "h1") - - @pytest.mark.asyncio - async def test_dispatch_intercepting(self): - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.event_dispatcher import EventDispatcher - - reg = ComponentRegistry() - reg.register_component( - "filter", - "event_handler", - "p1", - { - "event_type": "on_message_pre_process", - "weight": 100, - "intercept_message": True, - }, - ) - - dispatcher = EventDispatcher(reg) - - async def mock_invoke(plugin_id, comp_name, args): - return { - "success": True, - "continue_processing": False, - "modified_message": {"plain_text": "filtered"}, - } - - should_continue, modified = await dispatcher.dispatch_event( - "on_message_pre_process", mock_invoke, message={"plain_text": "hello"} - ) - assert not should_continue - assert modified is not None - assert modified["plain_text"] == "filtered" - - -class TestEventBus: - """核心事件总线与 IPC 桥接测试""" - - @pytest.mark.asyncio - async def test_bridge_preserves_modified_message(self, monkeypatch): - import types - - fake_message_data_model = types.ModuleType("src.common.data_models.message_data_model") - fake_message_data_model.ReplyContentType = object - fake_message_data_model.ReplyContent = object - fake_message_data_model.ForwardNode = object - fake_message_data_model.ReplySetModel = object - monkeypatch.setitem(sys.modules, "src.common.data_models.message_data_model", fake_message_data_model) - - from src.core.event_bus import EventBus - from src.core.types import EventType, MaiMessages - from src.plugin_runtime import integration as integration_module - - bus = EventBus() - - async def noop_handler(message): - return True, message - - bus.subscribe(EventType.ON_MESSAGE, noop_handler, name="noop", intercept=True) - - class FakeManager: - is_running = True - - async def bridge_event(self, event_type_value, message_dict=None, extra_args=None): - assert event_type_value == EventType.ON_MESSAGE.value - return True, {"plain_text": "modified by ipc"} - - monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) - - original = MaiMessages(plain_text="original") - continue_flag, modified = await bus.emit(EventType.ON_MESSAGE, original) - - assert continue_flag is True - assert modified is not None - assert modified.plain_text == "modified by ipc" - assert original.plain_text == "original" - - -# ─── MaiMessages 测试 ───────────────────────────────────── - - -class TestMaiMessages: - """统一消息模型测试""" - - def test_create_and_serialize(self): - from maibot_sdk.messages import MaiMessages, MessageSegment - - msg = MaiMessages( - message_segments=[MessageSegment(type="text", data={"text": "hello"})], - plain_text="hello", - stream_id="stream_1", - ) - - d = msg.to_rpc_dict() - assert d["plain_text"] == "hello" - assert len(d["message_segments"]) == 1 - - msg2 = MaiMessages.from_rpc_dict(d) - assert msg2.plain_text == "hello" - - def test_deepcopy(self): - from maibot_sdk.messages import MaiMessages - - msg = MaiMessages(plain_text="original") - msg2 = msg.deepcopy() - msg2.plain_text = "modified" - assert msg.plain_text == "original" - - def test_modify_flags(self): - from maibot_sdk.messages import MaiMessages - from maibot_sdk.types import ModifyFlag - - msg = MaiMessages(plain_text="hello") - assert msg.can_modify(ModifyFlag.CAN_MODIFY_PROMPT) - - msg.set_modify_flag(ModifyFlag.CAN_MODIFY_PROMPT, False) - assert not msg.modify_prompt("new prompt") - assert msg.llm_prompt is None - - assert msg.modify_response("new response") - assert msg.llm_response_content == "new response" - - -class _FakeHookSupervisor: - """用于 Hook 分发测试的简化 Supervisor。""" - - def __init__( - self, - group_name: str, - component_registry: Any, - handlers: Dict[str, Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]] | Dict[str, Any]]], - call_log: List[tuple[str, str]], - ) -> None: - """初始化测试用 Supervisor。 - - Args: - group_name: 运行时分组名称。 - component_registry: 组件注册表实例。 - handlers: 处理器映射,键为 `plugin_id.component_name`。 - call_log: 记录调用顺序的列表。 - """ - - self._group_name = group_name - self.component_registry = component_registry - self._handlers = handlers - self._call_log = call_log - - @property - def group_name(self) -> str: - """返回当前测试 Supervisor 的分组名称。""" - - return self._group_name - - async def invoke_plugin( - self, - method: str, - plugin_id: str, - component_name: str, - args: Optional[Dict[str, Any]] = None, - timeout_ms: int = 30000, - ) -> SimpleNamespace: - """模拟调用插件组件。 - - Args: - method: RPC 方法名。 - plugin_id: 目标插件 ID。 - component_name: 目标组件名称。 - args: 调用参数。 - timeout_ms: 超时配置,测试中仅用于保持接口一致。 - - Returns: - SimpleNamespace: 仅包含 `payload` 字段的简化响应对象。 - """ - - del method - del timeout_ms - - full_name = f"{plugin_id}.{component_name}" - handler = self._handlers[full_name] - self._call_log.append((plugin_id, component_name)) - result = handler(dict(args or {})) - if asyncio.iscoroutine(result): - result = await result - return SimpleNamespace(payload=result) - - -# ─── HookDispatcher 测试 ──────────────────────────────── - - -class TestHookDispatcher: - """命名 Hook 分发器测试。""" - - @staticmethod - def _import_dispatcher_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]: - """导入 Hook 分发相关模块,并屏蔽配置初始化触发的退出。 - - Args: - monkeypatch: pytest 的 monkeypatch 工具。 - - Returns: - tuple[Any, Any]: `ComponentRegistry` 与 `HookDispatcher` 类型。 - """ - - monkeypatch.setattr(sys, "exit", lambda code=0: None) - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.hook_dispatcher import HookDispatcher - - return ComponentRegistry, HookDispatcher - - @pytest.mark.asyncio - async def test_empty_hook_returns_original_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None: - """未注册处理器时应直接返回原始参数。""" - - ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) - - dispatcher = HookDispatcher() - supervisor = _FakeHookSupervisor("builtin", ComponentRegistry(), {}, []) - - result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1") - - assert result.hook_name == "heart_fc.cycle_start" - assert result.kwargs == {"session_id": "s-1"} - assert result.aborted is False - - @pytest.mark.asyncio - async def test_blocking_hook_modifies_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None: - """blocking 处理器可以修改参数。""" - - ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) - - registry = ComponentRegistry() - registry.register_component( - "upper", - "HOOK_HANDLER", - "p1", - { - "hook": "heart_fc.cycle_start", - "mode": "blocking", - "order": "normal", - }, - ) - dispatcher = HookDispatcher() - supervisor = _FakeHookSupervisor( - "builtin", - registry, - { - "p1.upper": lambda args: { - "success": True, - "action": "continue", - "modified_kwargs": { - "session_id": args["session_id"], - "text": str(args["text"]).upper(), - }, - } - }, - [], - ) - - result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1", text="hello") - - assert result.kwargs["session_id"] == "s-1" - assert result.kwargs["text"] == "HELLO" - assert result.aborted is False - - @pytest.mark.asyncio - async def test_abort_stops_following_blocking_handlers(self, monkeypatch: pytest.MonkeyPatch) -> None: - """blocking 处理器的 abort 应阻止后续 blocking 处理器执行。""" - - ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) - - registry = ComponentRegistry() - registry.register_component( - "stopper", - "HOOK_HANDLER", - "p1", - {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"}, - ) - registry.register_component( - "after_stop", - "HOOK_HANDLER", - "p2", - {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"}, - ) - call_log: List[tuple[str, str]] = [] - dispatcher = HookDispatcher() - supervisor = _FakeHookSupervisor( - "builtin", - registry, - { - "p1.stopper": lambda args: {"success": True, "action": "abort"}, - "p2.after_stop": lambda args: {"success": True, "action": "continue"}, - }, - call_log, - ) - - result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], cycle_id="c-1") - - assert result.aborted is True - assert result.stopped_by == "p1.stopper" - assert call_log == [("p1", "stopper")] - - @pytest.mark.asyncio - async def test_observe_handler_runs_in_background_without_mutation( - self, - monkeypatch: pytest.MonkeyPatch, - ) -> None: - """observe 处理器应后台执行且不能影响主流程参数。""" - - ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) - - registry = ComponentRegistry() - registry.register_component( - "observer", - "HOOK_HANDLER", - "p1", - {"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"}, - ) - started = asyncio.Event() - release = asyncio.Event() - call_log: List[tuple[str, str]] = [] - - async def observe_handler(args: Dict[str, Any]) -> Dict[str, Any]: - """模拟耗时观察型处理器。""" - - started.set() - await release.wait() - return { - "success": True, - "action": "abort", - "modified_kwargs": {"session_id": "changed"}, - "custom_result": args["session_id"], - } - - dispatcher = HookDispatcher() - supervisor = _FakeHookSupervisor( - "builtin", - registry, - {"p1.observer": observe_handler}, - call_log, - ) - - result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1") - - await asyncio.sleep(0) - assert result.aborted is False - assert result.kwargs["session_id"] == "s-1" - assert started.is_set() - assert len(dispatcher._background_tasks) == 1 - - release.set() - await asyncio.sleep(0) - await asyncio.sleep(0) - assert call_log == [("p1", "observer")] - assert not dispatcher._background_tasks - - @pytest.mark.asyncio - async def test_global_order_prefers_order_slot_then_source(self, monkeypatch: pytest.MonkeyPatch) -> None: - """全局排序应先看 order,再看内置/第三方来源。""" - - ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) - - builtin_registry = ComponentRegistry() - third_registry = ComponentRegistry() - builtin_registry.register_component( - "builtin_early", - "HOOK_HANDLER", - "b1", - {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"}, - ) - builtin_registry.register_component( - "builtin_normal", - "HOOK_HANDLER", - "b1", - {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"}, - ) - third_registry.register_component( - "third_early", - "HOOK_HANDLER", - "t1", - {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"}, - ) - third_registry.register_component( - "third_normal", - "HOOK_HANDLER", - "t1", - {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"}, - ) - - call_log: List[tuple[str, str]] = [] - dispatcher = HookDispatcher() - builtin_supervisor = _FakeHookSupervisor( - "builtin", - builtin_registry, - { - "b1.builtin_early": lambda args: {"success": True, "action": "continue"}, - "b1.builtin_normal": lambda args: {"success": True, "action": "continue"}, - }, - call_log, - ) - third_supervisor = _FakeHookSupervisor( - "third_party", - third_registry, - { - "t1.third_early": lambda args: {"success": True, "action": "continue"}, - "t1.third_normal": lambda args: {"success": True, "action": "continue"}, - }, - call_log, - ) - - await dispatcher.invoke_hook( - "heart_fc.cycle_start", - [third_supervisor, builtin_supervisor], - cycle_id="c-1", - ) - - assert call_log == [ - ("b1", "builtin_early"), - ("t1", "third_early"), - ("b1", "builtin_normal"), - ("t1", "third_normal"), - ] - - @pytest.mark.asyncio - async def test_error_policy_abort_stops_dispatch(self, monkeypatch: pytest.MonkeyPatch) -> None: - """error_policy=abort 时应中止本次 Hook 调用。""" - - ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) - - registry = ComponentRegistry() - registry.register_component( - "failer", - "HOOK_HANDLER", - "p1", - { - "hook": "heart_fc.cycle_start", - "mode": "blocking", - "order": "normal", - "error_policy": "abort", - }, - ) - call_log: List[tuple[str, str]] = [] - - async def fail_handler(args: Dict[str, Any]) -> Dict[str, Any]: - """抛出异常以触发 abort 策略。""" - - del args - raise RuntimeError("boom") - - dispatcher = HookDispatcher() - supervisor = _FakeHookSupervisor("builtin", registry, {"p1.failer": fail_handler}, call_log) - - result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1") - - assert result.aborted is True - assert result.stopped_by == "p1.failer" - assert any("boom" in error for error in result.errors) - assert call_log == [("p1", "failer")] - - @pytest.mark.asyncio - async def test_timeout_respects_handler_timeout_ms(self, monkeypatch: pytest.MonkeyPatch) -> None: - """处理器超时应被记录为错误并继续。""" - - ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) - - registry = ComponentRegistry() - registry.register_component( - "slow", - "HOOK_HANDLER", - "p1", - { - "hook": "heart_fc.cycle_start", - "mode": "blocking", - "order": "normal", - "timeout_ms": 10, - }, - ) - call_log: List[tuple[str, str]] = [] - - async def slow_handler(args: Dict[str, Any]) -> Dict[str, Any]: - """模拟超时处理器。""" - - del args - await asyncio.sleep(0.05) - return {"success": True, "action": "continue"} - - dispatcher = HookDispatcher() - supervisor = _FakeHookSupervisor("builtin", registry, {"p1.slow": slow_handler}, call_log) - - result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1") - - assert result.aborted is False - assert any("超时" in error for error in result.errors) - assert call_log == [("p1", "slow")] - - -class TestPluginRuntimeHookEntry: - """PluginRuntimeManager 命名 Hook 入口测试。""" - - @staticmethod - def _import_manager_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]: - """导入运行时管理器相关模块,并屏蔽配置初始化触发的退出。 - - Args: - monkeypatch: pytest 的 monkeypatch 工具。 - - Returns: - tuple[Any, Any]: `ComponentRegistry` 与 `PluginRuntimeManager` 类型。 - """ - - monkeypatch.setattr(sys, "exit", lambda code=0: None) - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.integration import PluginRuntimeManager - - return ComponentRegistry, PluginRuntimeManager - - @pytest.mark.asyncio - async def test_manager_invoke_hook_dispatches_across_supervisors( - self, - monkeypatch: pytest.MonkeyPatch, - ) -> None: - """PluginRuntimeManager.invoke_hook() 应调用全局 Hook 分发器。""" - - ComponentRegistry, PluginRuntimeManager = self._import_manager_modules(monkeypatch) - - builtin_registry = ComponentRegistry() - builtin_registry.register_component( - "builtin_guard", - "HOOK_HANDLER", - "b1", - {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"}, - ) - third_registry = ComponentRegistry() - third_registry.register_component( - "observer", - "HOOK_HANDLER", - "t1", - {"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"}, - ) - - call_log: List[tuple[str, str]] = [] - manager = PluginRuntimeManager() - manager._started = True - manager._builtin_supervisor = _FakeHookSupervisor( - "builtin", - builtin_registry, - {"b1.builtin_guard": lambda args: {"success": True, "action": "continue"}}, - call_log, - ) - manager._third_party_supervisor = _FakeHookSupervisor( - "third_party", - third_registry, - {"t1.observer": lambda args: {"success": True, "action": "continue"}}, - call_log, - ) - - result = await manager.invoke_dispatcher.invoke_hook("heart_fc.cycle_start", session_id="s-1") - - await asyncio.sleep(0) - assert manager.invoke_dispatcher is manager.hook_dispatcher - assert result.aborted is False - assert result.kwargs["session_id"] == "s-1" - assert ("b1", "builtin_guard") in call_log - - def test_manager_lists_builtin_hook_specs(self, monkeypatch: pytest.MonkeyPatch) -> None: - """PluginRuntimeManager 应暴露内置 Hook 规格清单。""" - - _ComponentRegistry, PluginRuntimeManager = self._import_manager_modules(monkeypatch) - - manager = PluginRuntimeManager() - hook_names = {spec.name for spec in manager.list_hook_specs()} - - assert "chat.receive.before_process" in hook_names - assert "send_service.before_send" in hook_names - assert "maisaka.planner.after_response" in hook_names - - -class TestRPCServer: - """RPC Server 代际保护测试""" - - @pytest.mark.asyncio - async def test_reject_second_active_runner_connection(self): - from src.plugin_runtime.host.rpc_server import RPCServer - from src.plugin_runtime.protocol.codec import MsgPackCodec - from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType - - class DummyTransport: - async def start(self, handler): - return None - - async def stop(self): - return None - - def get_address(self): - return "dummy" - - class FakeConnection: - def __init__(self, incoming_frames: list[bytes]): - self._incoming_frames = list(incoming_frames) - self.sent_frames: list[bytes] = [] - self.is_closed = False - - async def recv_frame(self): - return self._incoming_frames.pop(0) - - async def send_frame(self, data): - self.sent_frames.append(data) - - async def close(self): - self.is_closed = True - - codec = MsgPackCodec() - server = RPCServer(transport=DummyTransport(), session_token="session-token") - active_conn = SimpleNamespace(is_closed=False) - server._connection = active_conn - - hello = HelloPayload( - runner_id="runner-b", - sdk_version="1.0.0", - session_token="session-token", - ) - envelope = Envelope( - request_id=1, - message_type=MessageType.REQUEST, - method="runner.hello", - payload=hello.model_dump(), - ) - incoming_conn = FakeConnection([codec.encode_envelope(envelope)]) - - await server._handle_connection(incoming_conn) - - assert incoming_conn.is_closed is True - assert server._connection is active_conn - assert server.last_handshake_rejection_reason == "已有活跃 Runner 连接,拒绝新的握手" - assert len(incoming_conn.sent_frames) == 1 - - response = codec.decode_envelope(incoming_conn.sent_frames[0]) - response_payload = HelloResponsePayload.model_validate(response.payload) - assert response_payload.accepted is False - assert response_payload.reason == "已有活跃 Runner 连接,拒绝新的握手" - - def test_ignore_stale_generation_response(self): - from src.plugin_runtime.host.rpc_server import RPCServer - from src.plugin_runtime.protocol.envelope import Envelope, MessageType - - class DummyTransport: - async def start(self, handler): - return None - - async def stop(self): - return None - - def get_address(self): - return "dummy" - - server = RPCServer(transport=DummyTransport()) - server._runner_generation = 2 - - loop = asyncio.new_event_loop() - try: - future = loop.create_future() - server._pending_requests[1] = (future, 2) - - stale_response = Envelope( - request_id=1, - message_type=MessageType.RESPONSE, - method="plugin.health", - generation=1, - payload={"healthy": True}, - ) - server._handle_response(stale_response) - - assert not future.done() - assert 1 in server._pending_requests - finally: - loop.close() - - @pytest.mark.asyncio - async def test_send_queue_backpressure_is_enforced(self): - from src.plugin_runtime.host.rpc_server import RPCServer - from src.plugin_runtime.protocol.errors import ErrorCode, RPCError - - class DummyTransport: - async def start(self, handler): - return None - - async def stop(self): - return None - - def get_address(self): - return "dummy" - - class BlockingConnection: - def __init__(self): - self.is_closed = False - self.release = asyncio.Event() - - async def send_frame(self, data): - await self.release.wait() - - async def close(self): - self.is_closed = True - - server = RPCServer(transport=DummyTransport(), send_queue_size=1) - await server.start() - - conn = BlockingConnection() - server._connection = conn - server._runner_generation = 1 - - first_send = asyncio.create_task(server.send_event("runner.log_batch")) - await asyncio.sleep(0) - second_send = asyncio.create_task(server.send_event("runner.log_batch")) - await asyncio.sleep(0) - - with pytest.raises(RPCError) as exc_info: - await server.send_event("runner.log_batch") - - assert exc_info.value.code == ErrorCode.E_BACKPRESSURE - - conn.release.set() - await asyncio.gather(first_send, second_send) - await server.stop() - - -class TestRPCClient: - """Runner RPCClient 后台任务生命周期测试""" - - @pytest.mark.asyncio - async def test_background_tasks_retained_and_cancelled_on_disconnect(self): - from src.plugin_runtime.runner.rpc_client import RPCClient - - client = RPCClient(host_address="dummy", session_token="token") - release = asyncio.Event() - - async def pending_task(): - await release.wait() - - task = asyncio.create_task(pending_task()) - client._track_background_task(task) - - assert task in client._background_tasks - - await asyncio.sleep(0) - assert task in client._background_tasks - - await client.disconnect() - - assert task.cancelled() is True - assert not client._background_tasks - - -class TestSupervisor: - """Supervisor 生命周期边界测试""" - - @staticmethod - def _build_register_payload(plugin_id: str = "plugin_a", component_names=None): - from src.plugin_runtime.protocol.envelope import ComponentDeclaration, RegisterComponentsPayload - - component_names = component_names or ["handler"] - - return RegisterComponentsPayload( - plugin_id=plugin_id, - plugin_version="1.0.0", - components=[ - ComponentDeclaration( - name=name, - component_type="event_handler", - plugin_id=plugin_id, - metadata={"event_type": "on_message"}, - ) - for name in component_names - ], - capabilities_required=["send.text"], - ) - - @staticmethod - def _make_process(pid: int): - class FakeProcess: - def __init__(self): - self.pid = pid - self.returncode = None - self.stdout = None - self.stderr = None - self.terminated = False - self.killed = False - - def terminate(self): - self.terminated = True - self.returncode = 0 - - def kill(self): - self.killed = True - self.returncode = -9 - - async def wait(self): - return self.returncode - - return FakeProcess() - - @pytest.mark.asyncio - async def test_reload_waits_for_target_generation(self, monkeypatch): - from src.plugin_runtime.host.supervisor import PluginSupervisor - from src.plugin_runtime.protocol.envelope import HealthPayload - - supervisor = PluginSupervisor(plugin_dirs=[]) - old_process = self._make_process(1) - new_process = self._make_process(2) - - class FakeRPCServer: - def __init__(self): - self.runner_generation = 1 - self.staged_generation = 0 - self.is_connected = True - self.session_token = "fake-token" - self.committed = False - self.staging_started = False - - def reset_session_token(self): - self.session_token = "new-fake-token" - return self.session_token - - def restore_session_token(self, token): - self.session_token = token - - def begin_staged_takeover(self): - self.staging_started = True - self.staged_generation = 2 - - async def commit_staged_takeover(self): - self.runner_generation = self.staged_generation - self.staged_generation = 0 - self.committed = True - - async def rollback_staged_takeover(self): - self.staged_generation = 0 - - def has_generation(self, generation): - return generation in {self.runner_generation, self.staged_generation} - - async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs): - assert target_generation == 2 - return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump()) - - supervisor._rpc_server = FakeRPCServer() - supervisor._runner_process = old_process - - async def fake_spawn_runner(): - supervisor._runner_process = new_process - supervisor._staged_registered_plugins["plugin_a"] = self._build_register_payload("plugin_a") - supervisor._runner_ready_payloads[2] = SimpleNamespace(loaded_plugins=["plugin_a"], failed_plugins=[]) - supervisor._runner_ready_events[2] = asyncio.Event() - supervisor._runner_ready_events[2].set() - - monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner) - - reloaded = await supervisor.reload_plugins("test") - - assert reloaded is True - assert supervisor._runner_process is new_process - assert supervisor._rpc_server.committed is True - assert old_process.terminated is True - - @pytest.mark.asyncio - async def test_reload_restores_runtime_state_on_failure(self, monkeypatch): - from src.plugin_runtime.host.supervisor import PluginSupervisor - - supervisor = PluginSupervisor(plugin_dirs=[]) - old_process = self._make_process(1) - new_process = self._make_process(2) - old_reg = self._build_register_payload() - - supervisor._runner_process = old_process - supervisor._registered_plugins[old_reg.plugin_id] = old_reg - supervisor._rebuild_runtime_state() - - class FakeRPCServer: - def __init__(self): - self.runner_generation = 1 - self.staged_generation = 0 - self.is_connected = True - self.session_token = "fake-token" - self.rolled_back = False - - def reset_session_token(self): - self.session_token = "new-fake-token" - return self.session_token - - def restore_session_token(self, token): - self.session_token = token - - def begin_staged_takeover(self): - self.staged_generation = 2 - - async def commit_staged_takeover(self): - self.runner_generation = self.staged_generation - self.staged_generation = 0 - - async def rollback_staged_takeover(self): - self.rolled_back = True - self.staged_generation = 0 - - def has_generation(self, generation): - return generation in {self.runner_generation, self.staged_generation} - - async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs): - raise RuntimeError("new runner unhealthy") - - supervisor._rpc_server = FakeRPCServer() - - async def fake_spawn_runner(): - supervisor._runner_process = new_process - supervisor._staged_registered_plugins["plugin_a"] = self._build_register_payload("plugin_a") - supervisor._runner_ready_payloads[2] = SimpleNamespace(loaded_plugins=["plugin_a"], failed_plugins=[]) - supervisor._runner_ready_events[2] = asyncio.Event() - supervisor._runner_ready_events[2].set() - - monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner) - - reloaded = await supervisor.reload_plugins("test") - - assert reloaded is False - assert supervisor._runner_process is old_process - assert supervisor._rpc_server.rolled_back is True - assert old_reg.plugin_id in supervisor._registered_plugins - assert supervisor.component_registry.get_component("plugin_a.handler") is not None - - @pytest.mark.asyncio - async def test_reload_rebuilds_exact_component_set(self, monkeypatch): - from src.plugin_runtime.host.supervisor import PluginSupervisor - from src.plugin_runtime.protocol.envelope import HealthPayload - - supervisor = PluginSupervisor(plugin_dirs=[]) - old_process = self._make_process(1) - new_process = self._make_process(2) - old_reg = self._build_register_payload("plugin_a", component_names=["handler", "obsolete"]) - new_reg = self._build_register_payload("plugin_a", component_names=["handler"]) - - supervisor._runner_process = old_process - supervisor._registered_plugins[old_reg.plugin_id] = old_reg - supervisor._rebuild_runtime_state() - - class FakeRPCServer: - def __init__(self): - self.runner_generation = 1 - self.staged_generation = 0 - self.is_connected = True - self.session_token = "fake-token" - - def reset_session_token(self): - self.session_token = "new-fake-token" - return self.session_token - - def restore_session_token(self, token): - self.session_token = token - - def begin_staged_takeover(self): - self.staged_generation = 2 - - async def commit_staged_takeover(self): - self.runner_generation = self.staged_generation - self.staged_generation = 0 - - async def rollback_staged_takeover(self): - self.staged_generation = 0 - - def has_generation(self, generation): - return generation in {self.runner_generation, self.staged_generation} - - async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs): - return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump()) - - supervisor._rpc_server = FakeRPCServer() - - async def fake_spawn_runner(): - supervisor._runner_process = new_process - supervisor._staged_registered_plugins[new_reg.plugin_id] = new_reg - supervisor._runner_ready_payloads[2] = SimpleNamespace(loaded_plugins=["plugin_a"], failed_plugins=[]) - supervisor._runner_ready_events[2] = asyncio.Event() - supervisor._runner_ready_events[2].set() - - monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner) - - reloaded = await supervisor.reload_plugins("test") - - assert reloaded is True - assert supervisor.component_registry.get_component("plugin_a.handler") is not None - assert supervisor.component_registry.get_component("plugin_a.obsolete") is None - - @pytest.mark.asyncio - async def test_reload_plugins_uses_batch_rpc_for_multiple_roots(self): - from src.plugin_runtime.host.supervisor import PluginSupervisor - from src.plugin_runtime.protocol.envelope import ReloadPluginsResultPayload - - supervisor = PluginSupervisor(plugin_dirs=[]) - sent_requests: list[tuple[str, dict[str, object], int]] = [] - - class FakeRPCServer: - async def send_request(self, method, payload, timeout_ms=5000, **kwargs): - del kwargs - sent_requests.append((method, payload, timeout_ms)) - return SimpleNamespace( - payload=ReloadPluginsResultPayload( - success=True, - requested_plugin_ids=["plugin_a", "plugin_b"], - reloaded_plugins=["plugin_a", "plugin_b", "plugin_c"], - unloaded_plugins=["plugin_c", "plugin_b", "plugin_a"], - ).model_dump() - ) - - supervisor._rpc_server = FakeRPCServer() - - reloaded = await supervisor.reload_plugins(["plugin_a", "plugin_b", "plugin_a"], reason="manual") - - assert reloaded is True - assert len(sent_requests) == 1 - method, payload, timeout_ms = sent_requests[0] - assert method == "plugin.reload_batch" - assert payload["plugin_ids"] == ["plugin_a", "plugin_b"] - assert payload["reason"] == "manual" - assert timeout_ms >= 10000 - - @pytest.mark.asyncio - async def test_reload_rolls_back_when_runner_ready_not_received(self, monkeypatch): - from src.plugin_runtime.host.supervisor import PluginSupervisor - - supervisor = PluginSupervisor(plugin_dirs=[], runner_spawn_timeout_sec=0.01) - old_process = self._make_process(1) - new_process = self._make_process(2) - old_reg = self._build_register_payload() - - supervisor._runner_process = old_process - supervisor._registered_plugins[old_reg.plugin_id] = old_reg - supervisor._rebuild_runtime_state() - - class FakeRPCServer: - def __init__(self): - self.runner_generation = 1 - self.staged_generation = 0 - self.is_connected = True - self.session_token = "fake-token" - self.rolled_back = False - - def reset_session_token(self): - self.session_token = "new-fake-token" - return self.session_token - - def restore_session_token(self, token): - self.session_token = token - - def begin_staged_takeover(self): - self.staged_generation = 2 - - async def commit_staged_takeover(self): - raise AssertionError("runner.ready 未到达前不应提交 staged takeover") - - async def rollback_staged_takeover(self): - self.rolled_back = True - self.staged_generation = 0 - - def has_generation(self, generation): - return generation in {self.runner_generation, self.staged_generation} - - async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs): - raise AssertionError("runner.ready 未到达前不应执行健康检查") - - supervisor._rpc_server = FakeRPCServer() - - async def fake_spawn_runner(): - supervisor._runner_process = new_process - supervisor._staged_registered_plugins["plugin_a"] = self._build_register_payload("plugin_a") - - monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner) - - reloaded = await supervisor.reload_plugins("test") - - assert reloaded is False - assert supervisor._runner_process is old_process - assert supervisor._rpc_server.rolled_back is True - - @pytest.mark.asyncio - async def test_attach_stderr_drain_drains_stream(self): - """_attach_stderr_drain 为 stderr 创建排空任务,读完后任务自动完成。""" - from src.plugin_runtime.host.supervisor import PluginSupervisor - - supervisor = PluginSupervisor(plugin_dirs=[]) - - stderr = asyncio.StreamReader() - stderr.feed_data(b"fatal startup error\n") - stderr.feed_eof() - - # stdout=None 模拟新架构(不再捕获 stdout) - process = SimpleNamespace(pid=99, stdout=None, stderr=stderr) - supervisor._attach_stderr_drain(process) - - # 给 drain task 足够时间消费完数据 - await asyncio.sleep(0.05) - - assert supervisor._stderr_drain_task is None or supervisor._stderr_drain_task.done() - - -class TestIntegration: - """运行时集成层启动/清理测试""" - - @pytest.mark.asyncio - async def test_cap_database_get_with_filters_does_not_reference_unbound_key_value(self, monkeypatch): - from src.plugin_runtime import integration as integration_module - import src.common.database.database_model as real_db_models - from src.services import database_service as real_database_service - - captured: dict[str, object] = {} - - class DummyModel: - pass - - async def fake_db_get(model_class, filters=None, limit=None, order_by=None, single_result=False): - captured["model_class"] = model_class - captured["filters"] = filters - captured["limit"] = limit - captured["order_by"] = order_by - captured["single_result"] = single_result - return [{"id": 1}] - - monkeypatch.setattr(real_database_service, "db_get", fake_db_get) - monkeypatch.setattr(real_db_models, "DemoTable", DummyModel, raising=False) - - manager = object.__new__(integration_module.PluginRuntimeManager) - result = await manager._cap_database_get( - "plugin_a", - "database.get", - { - "model_name": "DemoTable", - "filters": {"status": "active"}, - "limit": 5, - }, - ) - - assert result == [{"id": 1}] - assert captured["model_class"] is DummyModel - assert captured["filters"] == {"status": "active"} - assert captured["limit"] == 5 - assert captured["single_result"] is False - - @pytest.mark.asyncio - async def test_cap_database_get_response_is_not_double_wrapped(self, monkeypatch): - from src.plugin_runtime import integration as integration_module - import src.common.database.database_model as real_db_models - from src.plugin_runtime.host.capability_service import CapabilityService - from src.plugin_runtime.protocol.envelope import CapabilityRequestPayload, Envelope, MessageType - from src.services import database_service as real_database_service - - class AllowAllAuthorization: - def check_capability(self, plugin_id, capability): - return True, "" - - class DummyModel: - pass - - async def fake_db_get(model_class, filters=None, limit=None, order_by=None, single_result=False): - return {"id": 1, "full_path": "E:\\test.png"} - - monkeypatch.setattr(real_database_service, "db_get", fake_db_get) - monkeypatch.setattr(real_db_models, "DemoTable", DummyModel, raising=False) - - manager = object.__new__(integration_module.PluginRuntimeManager) - service = CapabilityService(AllowAllAuthorization()) - service.register_capability("database.get", manager._cap_database_get) - - request = Envelope( - request_id=1, - message_type=MessageType.REQUEST, - method="cap.call", - plugin_id="plugin_a", - payload=CapabilityRequestPayload( - capability="database.get", - args={"model_name": "DemoTable", "single_result": True}, - ).model_dump(), - ) - - response = await service.handle_capability_request(request) - - assert response.payload == { - "success": True, - "result": {"id": 1, "full_path": "E:\\test.png"}, - } - - @pytest.mark.asyncio - async def test_cap_database_success_handlers_return_raw_results(self, monkeypatch): - from src.plugin_runtime import integration as integration_module - import src.common.database.database_model as real_db_models - from src.services import database_service as real_database_service - - class DummyModel: - pass - - async def fake_db_get(**kwargs): - return [{"id": 1}] - - async def fake_db_save(**kwargs): - return {"id": 2} - - async def fake_db_delete(**kwargs): - return 3 - - async def fake_db_count(**kwargs): - return 4 - - monkeypatch.setattr(real_database_service, "db_get", fake_db_get) - monkeypatch.setattr(real_database_service, "db_save", fake_db_save) - monkeypatch.setattr(real_database_service, "db_delete", fake_db_delete) - monkeypatch.setattr(real_database_service, "db_count", fake_db_count) - monkeypatch.setattr(real_db_models, "DemoTable", DummyModel, raising=False) - - manager = object.__new__(integration_module.PluginRuntimeManager) - base_args = {"model_name": "DemoTable"} - - assert await manager._cap_database_query("plugin_a", "database.query", base_args) == [{"id": 1}] - assert await manager._cap_database_save( - "plugin_a", "database.save", {**base_args, "data": {"name": "demo"}} - ) == {"id": 2} - assert await manager._cap_database_delete( - "plugin_a", "database.delete", {**base_args, "filters": {"id": 2}} - ) == 3 - assert await manager._cap_database_count("plugin_a", "database.count", base_args) == 4 - - @pytest.mark.asyncio - async def test_component_enable_rejects_ambiguous_short_name(self, monkeypatch): - from src.plugin_runtime import integration as integration_module - from src.plugin_runtime.host.component_registry import ComponentRegistry - - class FakeSupervisor: - def __init__(self, plugin_id: str): - self.component_registry = ComponentRegistry() - self.component_registry.register_component( - name="shared", - component_type="tool", - plugin_id=plugin_id, - metadata={}, - ) - - class FakeManager: - def __init__(self): - self.supervisors = [FakeSupervisor("plugin_a"), FakeSupervisor("plugin_b")] - - monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) - manager = integration_module.PluginRuntimeManager() - manager._builtin_supervisor = FakeSupervisor("plugin_a") - manager._third_party_supervisor = FakeSupervisor("plugin_b") - - result = await manager._cap_component_enable( - "plugin_a", - "component.enable", - {"name": "shared", "component_type": "tool", "scope": "global", "stream_id": ""}, - ) - - assert result["success"] is False - assert "组件名不唯一" in result["error"] - - @pytest.mark.asyncio - async def test_component_disable_rejects_non_global_scope(self, monkeypatch): - from src.plugin_runtime import integration as integration_module - from src.plugin_runtime.host.component_registry import ComponentRegistry - - class FakeSupervisor: - def __init__(self): - self.component_registry = ComponentRegistry() - self.component_registry.register_component( - name="handler", - component_type="tool", - plugin_id="plugin_a", - metadata={}, - ) - - class FakeManager: - def __init__(self): - self.supervisors = [FakeSupervisor()] - - monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) - manager = integration_module.PluginRuntimeManager() - manager._builtin_supervisor = FakeSupervisor() - - result = await manager._cap_component_disable( - "plugin_a", - "component.disable", - {"name": "plugin_a.handler", "component_type": "tool", "scope": "stream", "stream_id": "s1"}, - ) - - assert result["success"] is False - assert "仅支持全局组件禁用" in result["error"] - - @pytest.mark.asyncio - async def test_start_cleans_up_started_supervisors_on_failure(self, monkeypatch): - from src.plugin_runtime import integration as integration_module - - instances = [] - builtin_dir = Path("builtin") - thirdparty_dir = Path("thirdparty") - - class FakeCapabilityService: - def register_capability(self, name, impl): - return None - - class FakeSupervisor: - def __init__(self, plugin_dirs=None, socket_path=None): - self._plugin_dirs = plugin_dirs or [] - self.capability_service = FakeCapabilityService() - self.external_plugin_versions = {} - self.stopped = False - instances.append(self) - - def set_external_available_plugins(self, plugin_versions): - self.external_plugin_versions = dict(plugin_versions) - - def get_loaded_plugin_ids(self): - return [] - - def get_loaded_plugin_versions(self): - return {} - - async def start(self): - if len(instances) == 2 and self is instances[1]: - raise RuntimeError("boom") - - async def stop(self): - self.stopped = True - - monkeypatch.setattr( - integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: [builtin_dir]) - ) - monkeypatch.setattr( - integration_module.PluginRuntimeManager, "_get_third_party_plugin_dirs", staticmethod(lambda: [thirdparty_dir]) - ) - - import src.plugin_runtime.host.supervisor as supervisor_module - - monkeypatch.setattr(supervisor_module, "PluginSupervisor", FakeSupervisor) - - manager = integration_module.PluginRuntimeManager() - await manager.start() - - assert manager.is_running is False - assert len(instances) == 2 - assert instances[0].stopped is True - - @pytest.mark.asyncio - async def test_handle_plugin_source_changes_restarts_supervisors_after_dependency_sync(self, monkeypatch, tmp_path): - from src.config.file_watcher import FileChange - from src.plugin_runtime import integration as integration_module - import json - - builtin_root = tmp_path / "src" / "plugins" / "built_in" - thirdparty_root = tmp_path / "plugins" - alpha_dir = builtin_root / "alpha" - beta_dir = thirdparty_root / "beta" - alpha_dir.mkdir(parents=True) - beta_dir.mkdir(parents=True) - (alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8") - (beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8") - (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") - (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") - (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") - (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8") - - monkeypatch.chdir(tmp_path) - - class FakeSupervisor: - def __init__(self, plugin_dirs, registered_plugins): - self._plugin_dirs = plugin_dirs - self._registered_plugins = registered_plugins - self.config_updates = [] - - def get_loaded_plugin_ids(self): - return sorted(self._registered_plugins.keys()) - - def get_loaded_plugin_versions(self): - return {plugin_id: "1.0.0" for plugin_id in self._registered_plugins} - - async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""): - self.config_updates.append((plugin_id, config_data, config_version)) - return True - - manager = integration_module.PluginRuntimeManager() - manager._started = True - manager._builtin_supervisor = FakeSupervisor([builtin_root], {"test.alpha": object()}) - manager._third_party_supervisor = FakeSupervisor([thirdparty_root], {"test.beta": object()}) - dependency_sync_calls = [] - restart_calls = [] - - async def fake_sync(plugin_dirs: Sequence[Path]) -> Any: - """记录依赖同步调用。""" - - dependency_sync_calls.append(list(plugin_dirs)) - return integration_module.DependencySyncState( - blocked_changed_plugin_ids={"test.beta"}, - environment_changed=False, - ) - - async def fake_restart(reason: str) -> bool: - """记录 Supervisor 重启调用。""" - - restart_calls.append(reason) - return True - - monkeypatch.setattr(manager, "_sync_plugin_dependencies", fake_sync) - monkeypatch.setattr(manager, "_restart_supervisors", fake_restart) - - changes = [ - FileChange(change_type=1, path=beta_dir / "plugin.py"), - ] - - await manager._handle_plugin_source_changes(changes) - - assert dependency_sync_calls == [[builtin_root, thirdparty_root]] - assert restart_calls == ["file_watcher_blocklist_changed"] - assert manager._builtin_supervisor.config_updates == [] - assert manager._third_party_supervisor.config_updates == [] - - @pytest.mark.asyncio - async def test_reload_plugins_globally_warns_and_skips_cross_supervisor_dependents(self, monkeypatch): - from src.plugin_runtime import integration as integration_module - - class FakeRegistration: - def __init__(self, dependencies): - self.dependencies = dependencies - - class FakeSupervisor: - def __init__(self, registrations): - self._registered_plugins = registrations - self.reload_calls = [] - - def get_loaded_plugin_ids(self): - return sorted(self._registered_plugins.keys()) - - def get_loaded_plugin_versions(self): - return {plugin_id: "1.0.0" for plugin_id in self._registered_plugins} - - async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None): - self.reload_calls.append((plugin_ids, reason, dict(sorted((external_available_plugins or {}).items())))) - return True - - builtin_supervisor = FakeSupervisor({"test.alpha": FakeRegistration([])}) - third_party_supervisor = FakeSupervisor( - { - "test.beta": FakeRegistration(["test.alpha"]), - "test.gamma": FakeRegistration(["test.beta"]), - } - ) - - manager = integration_module.PluginRuntimeManager() - manager._builtin_supervisor = builtin_supervisor - manager._third_party_supervisor = third_party_supervisor - warning_messages = [] - - monkeypatch.setattr( - integration_module.logger, - "warning", - lambda message: warning_messages.append(message), - ) - - reloaded = await manager.reload_plugins_globally(["test.alpha"], reason="manual") - - assert reloaded is True - assert builtin_supervisor.reload_calls == [ - (["test.alpha"], "manual", {"test.beta": "1.0.0", "test.gamma": "1.0.0"}) - ] - assert third_party_supervisor.reload_calls == [] - assert len(warning_messages) == 1 - assert "test.beta, test.gamma" in warning_messages[0] - assert "跨 Supervisor API 调用仍然可用" in warning_messages[0] - - @pytest.mark.asyncio - async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path): - from src.plugin_runtime import integration as integration_module - from src.config.file_watcher import FileChange - import json - - builtin_root = tmp_path / "src" / "plugins" / "built_in" - thirdparty_root = tmp_path / "plugins" - alpha_dir = builtin_root / "alpha" - beta_dir = thirdparty_root / "beta" - alpha_dir.mkdir(parents=True) - beta_dir.mkdir(parents=True) - (alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8") - (beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8") - (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") - (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") - (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") - (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8") - - monkeypatch.chdir(tmp_path) - - class FakeSupervisor: - def __init__(self, plugin_dirs, plugins): - self._plugin_dirs = plugin_dirs - self._registered_plugins = {plugin_id: object() for plugin_id in plugins} - self.config_updates = [] - - async def inspect_plugin_config( - self, - plugin_id: str, - config_data: Optional[Dict[str, Any]] = None, - use_provided_config: bool = False, - ) -> SimpleNamespace: - """返回测试用的配置解析结果。""" - del config_data, use_provided_config - return SimpleNamespace(enabled=True, normalized_config={"enabled": True}, plugin_id=plugin_id) - - async def notify_plugin_config_updated( - self, - plugin_id, - config_data, - config_version="", - config_scope="self", - ): - self.config_updates.append((plugin_id, config_data, config_version, config_scope)) - return True - - manager = integration_module.PluginRuntimeManager() - manager._started = True - manager._builtin_supervisor = FakeSupervisor([builtin_root], ["test.alpha"]) - manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["test.beta"]) - - await manager._handle_plugin_config_changes( - "test.alpha", - [FileChange(change_type=1, path=alpha_dir / "config.toml")], - ) - - assert manager._builtin_supervisor.config_updates == [("test.alpha", {"enabled": True}, "", "self")] - assert manager._third_party_supervisor.config_updates == [] - - @pytest.mark.asyncio - async def test_handle_plugin_config_changes_loads_unloaded_enabled_plugin(self, monkeypatch, tmp_path): - from src.plugin_runtime import integration as integration_module - from src.config.file_watcher import FileChange - import json - - thirdparty_root = tmp_path / "plugins" - alpha_dir = thirdparty_root / "alpha" - alpha_dir.mkdir(parents=True) - (alpha_dir / "config.toml").write_text("[plugin]\nenabled = true\n", encoding="utf-8") - (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") - (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") - - monkeypatch.chdir(tmp_path) - - class FakeSupervisor: - def __init__(self, plugin_dirs): - self._plugin_dirs = plugin_dirs - self._registered_plugins = {} - - async def inspect_plugin_config( - self, - plugin_id: str, - config_data: Optional[Dict[str, Any]] = None, - use_provided_config: bool = False, - ) -> SimpleNamespace: - """返回测试用的启用配置快照。""" - del config_data, use_provided_config - return SimpleNamespace(enabled=True, normalized_config={"plugin": {"enabled": True}}, plugin_id=plugin_id) - - manager = integration_module.PluginRuntimeManager() - manager._started = True - manager._third_party_supervisor = FakeSupervisor([thirdparty_root]) - - load_calls = [] - - async def fake_load_plugin_globally(plugin_id: str, reason: str = "manual") -> bool: - """记录自动加载调用。""" - load_calls.append((plugin_id, reason)) - return True - - monkeypatch.setattr(manager, "load_plugin_globally", fake_load_plugin_globally) - - await manager._handle_plugin_config_changes( - "test.alpha", - [FileChange(change_type=1, path=alpha_dir / "config.toml")], - ) - - assert load_calls == [("test.alpha", "config_enabled")] - - @pytest.mark.asyncio - async def test_handle_plugin_config_changes_unloads_loaded_disabled_plugin(self, monkeypatch, tmp_path): - from src.plugin_runtime import integration as integration_module - from src.config.file_watcher import FileChange - import json - - builtin_root = tmp_path / "src" / "plugins" / "built_in" - alpha_dir = builtin_root / "alpha" - alpha_dir.mkdir(parents=True) - (alpha_dir / "config.toml").write_text("[plugin]\nenabled = false\n", encoding="utf-8") - (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") - (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") - - monkeypatch.chdir(tmp_path) - - class FakeSupervisor: - def __init__(self, plugin_dirs, plugins): - self._plugin_dirs = plugin_dirs - self._registered_plugins = {plugin_id: object() for plugin_id in plugins} - - async def inspect_plugin_config( - self, - plugin_id: str, - config_data: Optional[Dict[str, Any]] = None, - use_provided_config: bool = False, - ) -> SimpleNamespace: - """返回测试用的禁用配置快照。""" - del config_data, use_provided_config - return SimpleNamespace( - enabled=False, - normalized_config={"plugin": {"enabled": False}}, - plugin_id=plugin_id, - ) - - manager = integration_module.PluginRuntimeManager() - manager._started = True - manager._builtin_supervisor = FakeSupervisor([builtin_root], ["test.alpha"]) - - reload_calls = [] - - async def fake_reload_plugins_globally(plugin_ids: Sequence[str], reason: str = "manual") -> bool: - """记录自动卸载调用。""" - reload_calls.append((list(plugin_ids), reason)) - return True - - monkeypatch.setattr(manager, "reload_plugins_globally", fake_reload_plugins_globally) - - await manager._handle_plugin_config_changes( - "test.alpha", - [FileChange(change_type=1, path=alpha_dir / "config.toml")], - ) - - assert reload_calls == [(["test.alpha"], "config_disabled")] - - @pytest.mark.asyncio - async def test_handle_main_config_reload_only_notifies_subscribers(self, monkeypatch): - from src.plugin_runtime import integration as integration_module - - class FakeRegistration: - def __init__(self, subscriptions): - self.config_reload_subscriptions = subscriptions - - class FakeSupervisor: - def __init__(self, registrations): - self._registered_plugins = registrations - self.config_updates = [] - - def get_config_reload_subscribers(self, scope): - matched_plugins = [] - for plugin_id, registration in self._registered_plugins.items(): - if scope in registration.config_reload_subscriptions: - matched_plugins.append(plugin_id) - return matched_plugins - - async def notify_plugin_config_updated( - self, - plugin_id, - config_data, - config_version="", - config_scope="self", - ): - self.config_updates.append((plugin_id, config_data, config_version, config_scope)) - return True - - fake_global = SimpleNamespace(plugin_runtime=SimpleNamespace(enabled=True)) - monkeypatch.setattr( - integration_module.config_manager, - "get_global_config", - lambda: SimpleNamespace(model_dump=lambda: {"bot": {"name": "MaiBot"}}, plugin_runtime=fake_global.plugin_runtime), - ) - monkeypatch.setattr( - integration_module.config_manager, - "get_model_config", - lambda: SimpleNamespace(model_dump=lambda: {"models": [{"name": "demo"}]}), - ) - - manager = integration_module.PluginRuntimeManager() - manager._started = True - manager._builtin_supervisor = FakeSupervisor( - { - "test.alpha": FakeRegistration(["bot"]), - "test.beta": FakeRegistration([]), - } - ) - manager._third_party_supervisor = FakeSupervisor( - { - "test.gamma": FakeRegistration(["model"]), - } - ) - - await manager._handle_main_config_reload(["bot", "model"]) - - assert manager._builtin_supervisor.config_updates == [ - ("test.alpha", {"bot": {"name": "MaiBot"}}, "", "bot") - ] - assert manager._third_party_supervisor.config_updates == [ - ("test.gamma", {"models": [{"name": "demo"}]}, "", "model") - ] - - def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path): - from src.plugin_runtime import integration as integration_module - import json - - builtin_root = tmp_path / "src" / "plugins" / "built_in" - thirdparty_root = tmp_path / "plugins" - alpha_dir = builtin_root / "alpha" - beta_dir = thirdparty_root / "beta" - alpha_dir.mkdir(parents=True) - beta_dir.mkdir(parents=True) - (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") - (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") - (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") - (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8") - - class FakeWatcher: - def __init__(self): - self.subscriptions = [] - self.unsubscribed = [] - - def subscribe(self, callback, *, paths=None, change_types=None): - subscription_id = f"sub-{len(self.subscriptions) + 1}" - self.subscriptions.append({"id": subscription_id, "callback": callback, "paths": tuple(paths or ())}) - return subscription_id - - def unsubscribe(self, subscription_id): - self.unsubscribed.append(subscription_id) - return True - - class FakeSupervisor: - def __init__(self, plugin_dirs, plugins): - self._plugin_dirs = plugin_dirs - self._registered_plugins = {plugin_id: object() for plugin_id in plugins} - - manager = integration_module.PluginRuntimeManager() - manager._plugin_file_watcher = FakeWatcher() - manager._builtin_supervisor = FakeSupervisor([builtin_root], ["test.alpha"]) - manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["test.beta"]) - - manager._refresh_plugin_config_watch_subscriptions() - - assert set(manager._plugin_config_watcher_subscriptions.keys()) == {"test.alpha", "test.beta"} - assert { - subscription["paths"][0] for subscription in manager._plugin_file_watcher.subscriptions - } == {alpha_dir / "config.toml", beta_dir / "config.toml"} - - def test_refresh_plugin_config_watch_subscriptions_includes_unloaded_plugins(self, tmp_path): - from src.plugin_runtime import integration as integration_module - import json - - thirdparty_root = tmp_path / "plugins" - alpha_dir = thirdparty_root / "alpha" - beta_dir = thirdparty_root / "beta" - alpha_dir.mkdir(parents=True) - beta_dir.mkdir(parents=True) - (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") - (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") - (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") - (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8") - - class FakeWatcher: - def __init__(self): - self.subscriptions = [] - - def subscribe( - self, - callback: Any, - *, - paths: Optional[Sequence[Path]] = None, - change_types: Any = None, - ) -> str: - """记录新的监听订阅。""" - del callback, change_types - subscription_id = f"sub-{len(self.subscriptions) + 1}" - self.subscriptions.append({"id": subscription_id, "paths": tuple(paths or ())}) - return subscription_id - - def unsubscribe(self, subscription_id: str) -> bool: - """兼容 watcher 取消订阅接口。""" - del subscription_id - return True - - class FakeSupervisor: - def __init__(self, plugin_dirs, plugins): - self._plugin_dirs = plugin_dirs - self._registered_plugins = {plugin_id: object() for plugin_id in plugins} - - manager = integration_module.PluginRuntimeManager() - manager._plugin_file_watcher = FakeWatcher() - manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["test.alpha"]) - - manager._refresh_plugin_config_watch_subscriptions() - - assert set(manager._plugin_config_watcher_subscriptions.keys()) == {"test.alpha", "test.beta"} - - @pytest.mark.asyncio - async def test_component_reload_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch): - from src.plugin_runtime import integration as integration_module - - manager = integration_module.PluginRuntimeManager() - monkeypatch.setattr(manager, "reload_plugins_globally", lambda plugin_ids, reason="manual": asyncio.sleep(0, False)) - - result = await manager._cap_component_reload_plugin( - "plugin_a", - "component.reload_plugin", - {"plugin_name": "alpha"}, - ) - - assert result["success"] is False - assert result["error"] == "插件 alpha 热重载失败" - - @pytest.mark.asyncio - async def test_component_load_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch, tmp_path): - from src.plugin_runtime import integration as integration_module - - manager = integration_module.PluginRuntimeManager() - monkeypatch.setattr(manager, "load_plugin_globally", lambda plugin_id, reason="manual": asyncio.sleep(0, False)) - - result = await manager._cap_component_load_plugin( - "plugin_a", - "component.load_plugin", - {"plugin_name": "alpha"}, - ) - - assert result["success"] is False - assert result["error"] == "插件 alpha 热重载失败" diff --git a/pytests/test_plugin_runtime_action_bridge.py b/pytests/test_plugin_runtime_action_bridge.py deleted file mode 100644 index e13dfaf3..00000000 --- a/pytests/test_plugin_runtime_action_bridge.py +++ /dev/null @@ -1,284 +0,0 @@ -"""核心组件查询层与插件运行时聚合测试。""" - -from types import SimpleNamespace -from typing import Any - -import pytest - -import src.plugin_runtime.integration as integration_module - -from src.core.types import ActionInfo, ToolInfo -from src.plugin_runtime.component_query import component_query_service -from src.plugin_runtime.host.supervisor import PluginSupervisor - - -class _FakeRuntimeManager: - """测试用插件运行时管理器。""" - - def __init__(self, supervisor: PluginSupervisor, plugin_id: str, plugin_config: dict[str, Any]) -> None: - """初始化测试用运行时管理器。 - - Args: - supervisor: 持有测试组件的监督器。 - plugin_id: 目标插件 ID。 - plugin_config: 需要返回的插件配置。 - """ - - self.supervisors = [supervisor] - self._plugin_id = plugin_id - self._plugin_config = plugin_config - - def _get_supervisor_for_plugin(self, plugin_id: str) -> PluginSupervisor | None: - """按插件 ID 返回对应监督器。 - - Args: - plugin_id: 目标插件 ID。 - - Returns: - PluginSupervisor | None: 命中时返回监督器。 - """ - - return self.supervisors[0] if plugin_id == self._plugin_id else None - - def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> dict[str, Any]: - """返回测试配置。 - - Args: - supervisor: 监督器实例。 - plugin_id: 目标插件 ID。 - - Returns: - dict[str, Any]: 测试配置内容。 - """ - - del supervisor - if plugin_id != self._plugin_id: - return {} - return dict(self._plugin_config) - - -def _install_runtime_manager( - monkeypatch: pytest.MonkeyPatch, - supervisor: PluginSupervisor, - plugin_id: str, - plugin_config: dict[str, Any] | None = None, -) -> None: - """为测试安装假的运行时管理器。 - - Args: - monkeypatch: pytest monkeypatch 对象。 - supervisor: 持有测试组件的监督器。 - plugin_id: 测试插件 ID。 - plugin_config: 可选的测试配置内容。 - """ - - fake_manager = _FakeRuntimeManager(supervisor, plugin_id, plugin_config or {"enabled": True}) - monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: fake_manager) - - -@pytest.mark.asyncio -async def test_core_component_registry_reads_runtime_action_and_executor( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """核心查询层应直接读取运行时 Action,并返回 RPC 执行闭包。""" - - plugin_id = "runtime_action_bridge_plugin" - action_name = "runtime_action_bridge_test" - supervisor = PluginSupervisor(plugin_dirs=[]) - captured: dict[str, Any] = {} - - supervisor.component_registry.register_component( - name=action_name, - component_type="ACTION", - plugin_id=plugin_id, - metadata={ - "description": "发送一个测试回复", - "enabled": True, - "activation_type": "keyword", - "activation_probability": 0.25, - "activation_keywords": ["测试", "hello"], - "action_parameters": {"target": "目标对象"}, - "action_require": ["需要发送回复时使用"], - "associated_types": ["text"], - "parallel_action": True, - }, - ) - _install_runtime_manager(monkeypatch, supervisor, plugin_id, {"enabled": True, "mode": "test"}) - - async def fake_invoke_plugin( - method: str, - plugin_id: str, - component_name: str, - args: dict[str, Any] | None = None, - timeout_ms: int = 30000, - ) -> Any: - """模拟动作 RPC 调用。""" - - captured["method"] = method - captured["plugin_id"] = plugin_id - captured["component_name"] = component_name - captured["args"] = args or {} - captured["timeout_ms"] = timeout_ms - return SimpleNamespace(payload={"success": True, "result": (True, "runtime action executed")}) - - monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin) - - action_info = component_query_service.get_action_info(action_name) - assert isinstance(action_info, ActionInfo) - assert action_info.plugin_name == plugin_id - assert action_info.description == "发送一个测试回复" - assert action_info.activation_keywords == ["测试", "hello"] - assert action_info.random_activation_probability == 0.25 - assert action_info.parallel_action is True - assert action_name in component_query_service.get_default_actions() - assert component_query_service.get_plugin_config(plugin_id) == {"enabled": True, "mode": "test"} - - executor = component_query_service.get_action_executor(action_name) - assert executor is not None - - success, reason = await executor( - action_data={"target": "MaiBot"}, - action_reasoning="当前适合使用这个动作", - cycle_timers={"planner": 0.1}, - thinking_id="tid-1", - chat_stream=SimpleNamespace(session_id="stream-1"), - log_prefix="[test]", - shutting_down=False, - plugin_config={"enabled": True}, - ) - - assert success is True - assert reason == "runtime action executed" - assert captured["method"] == "plugin.invoke_action" - assert captured["plugin_id"] == plugin_id - assert captured["component_name"] == action_name - assert captured["args"]["stream_id"] == "stream-1" - assert captured["args"]["chat_id"] == "stream-1" - assert captured["args"]["reasoning"] == "当前适合使用这个动作" - assert captured["args"]["target"] == "MaiBot" - assert captured["args"]["action_data"] == {"target": "MaiBot"} - - -@pytest.mark.asyncio -async def test_core_component_registry_reads_runtime_command_and_executor( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """核心查询层应直接使用运行时命令匹配与执行闭包。""" - - plugin_id = "runtime_command_bridge_plugin" - command_name = "runtime_command_bridge_test" - supervisor = PluginSupervisor(plugin_dirs=[]) - captured: dict[str, Any] = {} - - supervisor.component_registry.register_component( - name=command_name, - component_type="COMMAND", - plugin_id=plugin_id, - metadata={ - "description": "测试命令", - "enabled": True, - "command_pattern": r"^/test(?:\s+.+)?$", - "aliases": ["/hello"], - "intercept_message_level": 1, - }, - ) - _install_runtime_manager(monkeypatch, supervisor, plugin_id, {"mode": "command"}) - - async def fake_invoke_plugin( - method: str, - plugin_id: str, - component_name: str, - args: dict[str, Any] | None = None, - timeout_ms: int = 30000, - ) -> Any: - """模拟命令 RPC 调用。""" - - captured["method"] = method - captured["plugin_id"] = plugin_id - captured["component_name"] = component_name - captured["args"] = args or {} - captured["timeout_ms"] = timeout_ms - return SimpleNamespace(payload={"success": True, "result": (True, "command ok", True)}) - - monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin) - - matched = component_query_service.find_command_by_text("/test hello") - assert matched is not None - command_executor, matched_groups, command_info = matched - - assert matched_groups == {} - assert command_info.plugin_name == plugin_id - assert command_info.command_pattern == r"^/test(?:\s+.+)?$" - - success, response_text, intercept = await command_executor( - message=SimpleNamespace(processed_plain_text="/test hello", session_id="stream-2"), - plugin_config={"mode": "command"}, - matched_groups=matched_groups, - ) - - assert success is True - assert response_text == "command ok" - assert intercept is True - assert captured["method"] == "plugin.invoke_command" - assert captured["plugin_id"] == plugin_id - assert captured["component_name"] == command_name - assert captured["args"]["text"] == "/test hello" - assert captured["args"]["stream_id"] == "stream-2" - assert captured["args"]["plugin_config"] == {"mode": "command"} - - -@pytest.mark.asyncio -async def test_core_component_registry_reads_runtime_tools_and_executor( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """核心查询层应直接读取运行时 Tool,并返回 RPC 执行闭包。""" - - plugin_id = "runtime_tool_bridge_plugin" - tool_name = "runtime_tool_bridge_test" - supervisor = PluginSupervisor(plugin_dirs=[]) - - supervisor.component_registry.register_component( - name=tool_name, - component_type="TOOL", - plugin_id=plugin_id, - metadata={ - "description": "测试工具", - "enabled": True, - "parameters": [ - { - "name": "query", - "param_type": "string", - "description": "查询词", - "required": True, - } - ], - }, - ) - _install_runtime_manager(monkeypatch, supervisor, plugin_id) - - async def fake_invoke_plugin( - method: str, - plugin_id: str, - component_name: str, - args: dict[str, Any] | None = None, - timeout_ms: int = 30000, - ) -> Any: - """模拟工具 RPC 调用。""" - - del timeout_ms - assert method == "plugin.invoke_tool" - assert plugin_id == "runtime_tool_bridge_plugin" - assert component_name == "runtime_tool_bridge_test" - assert args == {"query": "MaiBot"} - return SimpleNamespace(payload={"success": True, "result": {"content": "tool ok"}}) - - monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin) - - tool_info = component_query_service.get_tool_info(tool_name) - assert isinstance(tool_info, ToolInfo) - assert tool_info.tool_description == "测试工具" - assert tool_name in component_query_service.get_llm_available_tools() - - executor = component_query_service.get_tool_executor(tool_name) - assert executor is not None - assert await executor({"query": "MaiBot"}) == {"content": "tool ok"} diff --git a/pytests/test_plugin_runtime_api.py b/pytests/test_plugin_runtime_api.py deleted file mode 100644 index 58a8e6ba..00000000 --- a/pytests/test_plugin_runtime_api.py +++ /dev/null @@ -1,524 +0,0 @@ -"""插件 API 注册与调用测试。""" - -from types import SimpleNamespace -from typing import Any, Dict, List - -import pytest - -from src.plugin_runtime.integration import PluginRuntimeManager -from src.plugin_runtime.host.supervisor import PluginSupervisor -from src.plugin_runtime.protocol.envelope import ( - ComponentDeclaration, - Envelope, - MessageType, - RegisterPluginPayload, - UnregisterPluginPayload, -) - - -def _build_manager(*supervisors: PluginSupervisor) -> PluginRuntimeManager: - """构造一个最小可用的插件运行时管理器。 - - Args: - *supervisors: 需要挂载的监督器列表。 - - Returns: - PluginRuntimeManager: 已注入监督器的运行时管理器。 - """ - - manager = PluginRuntimeManager() - if supervisors: - manager._builtin_supervisor = supervisors[0] - if len(supervisors) > 1: - manager._third_party_supervisor = supervisors[1] - return manager - - -async def _register_plugin( - supervisor: PluginSupervisor, - plugin_id: str, - components: List[Dict[str, Any]], -) -> Envelope: - """通过 Supervisor 注册测试插件。 - - Args: - supervisor: 目标监督器。 - plugin_id: 测试插件 ID。 - components: 组件声明列表。 - - Returns: - Envelope: 注册响应信封。 - """ - - payload = RegisterPluginPayload( - plugin_id=plugin_id, - plugin_version="1.0.0", - components=[ - ComponentDeclaration( - name=str(component.get("name", "") or ""), - component_type=str(component.get("component_type", "") or ""), - plugin_id=plugin_id, - metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {}, - ) - for component in components - ], - ) - return await supervisor._handle_register_plugin( - Envelope( - request_id=1, - message_type=MessageType.REQUEST, - method="plugin.register_components", - plugin_id=plugin_id, - payload=payload.model_dump(), - ) - ) - - -async def _unregister_plugin(supervisor: PluginSupervisor, plugin_id: str) -> Envelope: - """通过 Supervisor 注销测试插件。 - - Args: - supervisor: 目标监督器。 - plugin_id: 测试插件 ID。 - - Returns: - Envelope: 注销响应信封。 - """ - - payload = UnregisterPluginPayload(plugin_id=plugin_id, reason="test") - return await supervisor._handle_unregister_plugin( - Envelope( - request_id=2, - message_type=MessageType.REQUEST, - method="plugin.unregister", - plugin_id=plugin_id, - payload=payload.model_dump(), - ) - ) - - -@pytest.mark.asyncio -async def test_register_plugin_syncs_dedicated_api_registry() -> None: - """插件注册时应将 API 同步到独立注册表,而不是通用组件表。""" - - supervisor = PluginSupervisor(plugin_dirs=[]) - response = await _register_plugin( - supervisor, - "provider", - [ - { - "name": "render_html", - "component_type": "API", - "metadata": { - "description": "渲染 HTML", - "version": "1", - "public": True, - }, - } - ], - ) - - assert response.payload["accepted"] is True - assert response.payload["registered_components"] == 0 - assert response.payload["registered_apis"] == 1 - assert supervisor.api_registry.get_api("provider", "render_html") is not None - assert supervisor.component_registry.get_component("provider.render_html") is None - - unregister_response = await _unregister_plugin(supervisor, "provider") - assert unregister_response.payload["removed_apis"] == 1 - assert supervisor.api_registry.get_api("provider", "render_html") is None - - -@pytest.mark.asyncio -async def test_api_call_allows_public_api_between_plugins(monkeypatch: pytest.MonkeyPatch) -> None: - """公开 API 应允许其他插件通过 Host 转发调用。""" - - provider_supervisor = PluginSupervisor(plugin_dirs=[]) - consumer_supervisor = PluginSupervisor(plugin_dirs=[]) - await _register_plugin( - provider_supervisor, - "provider", - [ - { - "name": "render_html", - "component_type": "API", - "metadata": { - "description": "渲染 HTML", - "version": "1", - "public": True, - }, - } - ], - ) - await _register_plugin(consumer_supervisor, "consumer", []) - - captured: Dict[str, Any] = {} - - async def fake_invoke_api( - plugin_id: str, - component_name: str, - args: Dict[str, Any] | None = None, - timeout_ms: int = 30000, - ) -> Any: - """模拟 API RPC 调用。""" - - captured["plugin_id"] = plugin_id - captured["component_name"] = component_name - captured["args"] = args or {} - captured["timeout_ms"] = timeout_ms - return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}}) - - monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api) - - manager = _build_manager(provider_supervisor, consumer_supervisor) - result = await manager._cap_api_call( - "consumer", - "api.call", - { - "api_name": "provider.render_html", - "version": "1", - "args": {"html": "
Hello
"}, - }, - ) - - assert result == {"success": True, "result": {"image": "ok"}} - assert captured["plugin_id"] == "provider" - assert captured["component_name"] == "render_html" - assert captured["args"] == {"html": "
Hello
"} - - -@pytest.mark.asyncio -async def test_api_call_rejects_private_api_between_plugins() -> None: - """未公开的 API 默认不允许跨插件调用。""" - - provider_supervisor = PluginSupervisor(plugin_dirs=[]) - consumer_supervisor = PluginSupervisor(plugin_dirs=[]) - await _register_plugin( - provider_supervisor, - "provider", - [ - { - "name": "secret_api", - "component_type": "API", - "metadata": { - "description": "私有 API", - "version": "1", - "public": False, - }, - } - ], - ) - await _register_plugin(consumer_supervisor, "consumer", []) - - manager = _build_manager(provider_supervisor, consumer_supervisor) - result = await manager._cap_api_call( - "consumer", - "api.call", - { - "api_name": "provider.secret_api", - "args": {}, - }, - ) - - assert result["success"] is False - assert "未公开" in str(result["error"]) - - -@pytest.mark.asyncio -async def test_api_list_and_component_toggle_use_dedicated_registry() -> None: - """API 列表与组件启停应直接作用于独立 API 注册表。""" - - provider_supervisor = PluginSupervisor(plugin_dirs=[]) - consumer_supervisor = PluginSupervisor(plugin_dirs=[]) - await _register_plugin( - provider_supervisor, - "provider", - [ - { - "name": "public_api", - "component_type": "API", - "metadata": {"version": "1", "public": True}, - }, - { - "name": "private_api", - "component_type": "API", - "metadata": {"version": "1", "public": False}, - }, - ], - ) - await _register_plugin( - consumer_supervisor, - "consumer", - [ - { - "name": "self_private_api", - "component_type": "API", - "metadata": {"version": "1", "public": False}, - } - ], - ) - - manager = _build_manager(provider_supervisor, consumer_supervisor) - list_result = await manager._cap_api_list("consumer", "api.list", {}) - - assert list_result["success"] is True - api_names = {(item["plugin_id"], item["name"]) for item in list_result["apis"]} - assert ("provider", "public_api") in api_names - assert ("provider", "private_api") not in api_names - assert ("consumer", "self_private_api") in api_names - - disable_result = await manager._cap_component_disable( - "consumer", - "component.disable", - { - "name": "provider.public_api", - "component_type": "API", - "scope": "global", - "stream_id": "", - }, - ) - assert disable_result["success"] is True - assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is None - - enable_result = await manager._cap_component_enable( - "consumer", - "component.enable", - { - "name": "provider.public_api", - "component_type": "API", - "scope": "global", - "stream_id": "", - }, - ) - assert enable_result["success"] is True - assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is not None - - -@pytest.mark.asyncio -async def test_api_registry_supports_multiple_versions_with_distinct_handlers( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """同名 API 不同版本应可并存,并按版本路由到不同处理器。""" - - provider_supervisor = PluginSupervisor(plugin_dirs=[]) - consumer_supervisor = PluginSupervisor(plugin_dirs=[]) - await _register_plugin( - provider_supervisor, - "provider", - [ - { - "name": "render_html", - "component_type": "API", - "metadata": { - "description": "渲染 HTML v1", - "version": "1", - "public": True, - "handler_name": "handle_render_html_v1", - }, - }, - { - "name": "render_html", - "component_type": "API", - "metadata": { - "description": "渲染 HTML v2", - "version": "2", - "public": True, - "handler_name": "handle_render_html_v2", - }, - }, - ], - ) - await _register_plugin(consumer_supervisor, "consumer", []) - - captured: Dict[str, Any] = {} - - async def fake_invoke_api( - plugin_id: str, - component_name: str, - args: Dict[str, Any] | None = None, - timeout_ms: int = 30000, - ) -> Any: - """模拟多版本 API 调用。""" - - captured["plugin_id"] = plugin_id - captured["component_name"] = component_name - captured["args"] = args or {} - captured["timeout_ms"] = timeout_ms - return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}}) - - monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api) - manager = _build_manager(provider_supervisor, consumer_supervisor) - - ambiguous_result = await manager._cap_api_call( - "consumer", - "api.call", - { - "api_name": "provider.render_html", - "args": {"html": "
Hello
"}, - }, - ) - assert ambiguous_result["success"] is False - assert "多个版本" in str(ambiguous_result["error"]) - - disable_ambiguous_result = await manager._cap_component_disable( - "consumer", - "component.disable", - { - "name": "provider.render_html", - "component_type": "API", - "scope": "global", - "stream_id": "", - }, - ) - assert disable_ambiguous_result["success"] is False - assert "多个版本" in str(disable_ambiguous_result["error"]) - - disable_v1_result = await manager._cap_component_disable( - "consumer", - "component.disable", - { - "name": "provider.render_html", - "component_type": "API", - "scope": "global", - "stream_id": "", - "version": "1", - }, - ) - assert disable_v1_result["success"] is True - assert provider_supervisor.api_registry.get_api("provider", "render_html", version="1", enabled_only=True) is None - assert provider_supervisor.api_registry.get_api("provider", "render_html", version="2", enabled_only=True) is not None - - result = await manager._cap_api_call( - "consumer", - "api.call", - { - "api_name": "provider.render_html", - "version": "2", - "args": {"html": "
Hello
"}, - }, - ) - - assert result == {"success": True, "result": {"image": "ok"}} - assert captured["plugin_id"] == "provider" - assert captured["component_name"] == "handle_render_html_v2" - assert captured["args"] == {"html": "
Hello
"} - - -@pytest.mark.asyncio -async def test_api_replace_dynamic_can_offline_removed_entries( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """动态 API 替换后,被移除的 API 应返回明确下线错误。""" - - supervisor = PluginSupervisor(plugin_dirs=[]) - await _register_plugin(supervisor, "provider", []) - manager = _build_manager(supervisor) - - captured: Dict[str, Any] = {} - - async def fake_invoke_api( - plugin_id: str, - component_name: str, - args: Dict[str, Any] | None = None, - timeout_ms: int = 30000, - ) -> Any: - """模拟动态 API 调用。""" - - captured["plugin_id"] = plugin_id - captured["component_name"] = component_name - captured["args"] = args or {} - captured["timeout_ms"] = timeout_ms - return SimpleNamespace(error=None, payload={"success": True, "result": {"ok": True}}) - - monkeypatch.setattr(supervisor, "invoke_api", fake_invoke_api) - - replace_result = await manager._cap_api_replace_dynamic( - "provider", - "api.replace_dynamic", - { - "apis": [ - { - "name": "mcp.search", - "type": "API", - "metadata": { - "version": "1", - "public": True, - "handler_name": "dynamic_search", - }, - }, - { - "name": "mcp.read", - "type": "API", - "metadata": { - "version": "1", - "public": True, - "handler_name": "dynamic_read", - }, - }, - ], - "offline_reason": "MCP 服务器已关闭", - }, - ) - - assert replace_result["success"] is True - assert replace_result["count"] == 2 - list_result = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"}) - assert {(item["name"], item["version"]) for item in list_result["apis"]} == { - ("mcp.read", "1"), - ("mcp.search", "1"), - } - - call_result = await manager._cap_api_call( - "provider", - "api.call", - { - "api_name": "provider.mcp.search", - "version": "1", - "args": {"query": "hello"}, - }, - ) - assert call_result == {"success": True, "result": {"ok": True}} - assert captured["component_name"] == "dynamic_search" - assert captured["args"]["query"] == "hello" - assert captured["args"]["__maibot_api_name__"] == "mcp.search" - assert captured["args"]["__maibot_api_version__"] == "1" - - second_replace_result = await manager._cap_api_replace_dynamic( - "provider", - "api.replace_dynamic", - { - "apis": [ - { - "name": "mcp.read", - "type": "API", - "metadata": { - "version": "1", - "public": True, - "handler_name": "dynamic_read", - }, - } - ], - "offline_reason": "MCP 服务器已关闭", - }, - ) - - assert second_replace_result["success"] is True - assert second_replace_result["count"] == 1 - assert second_replace_result["offlined"] == 1 - - offlined_call_result = await manager._cap_api_call( - "provider", - "api.call", - { - "api_name": "provider.mcp.search", - "version": "1", - "args": {}, - }, - ) - assert offlined_call_result["success"] is False - assert "MCP 服务器已关闭" in str(offlined_call_result["error"]) - - list_after_replace = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"}) - assert {(item["name"], item["version"]) for item in list_after_replace["apis"]} == { - ("mcp.read", "1"), - } diff --git a/pytests/test_plugin_runtime_render.py b/pytests/test_plugin_runtime_render.py deleted file mode 100644 index f90dfad9..00000000 --- a/pytests/test_plugin_runtime_render.py +++ /dev/null @@ -1,96 +0,0 @@ -"""插件运行时浏览器渲染能力测试。""" - -from typing import Optional - -import pytest - -from src.plugin_runtime.integration import PluginRuntimeManager -from src.plugin_runtime.host.supervisor import PluginSupervisor -from src.services.html_render_service import HtmlRenderRequest, HtmlRenderResult - - -class _FakeRenderService: - """用于替代真实浏览器渲染服务的测试桩。""" - - def __init__(self) -> None: - """初始化测试桩。""" - - self.last_request: Optional[HtmlRenderRequest] = None - - async def render_html_to_png(self, request: HtmlRenderRequest) -> HtmlRenderResult: - """记录请求并返回固定的渲染结果。 - - Args: - request: 当前渲染请求。 - - Returns: - HtmlRenderResult: 固定的测试渲染结果。 - """ - - self.last_request = request - return HtmlRenderResult( - image_base64="ZmFrZS1pbWFnZQ==", - mime_type="image/png", - width=640, - height=480, - render_ms=12, - ) - - -def test_render_capability_is_registered() -> None: - """Host 注册能力时应包含 render.html2png。""" - - manager = PluginRuntimeManager() - supervisor = PluginSupervisor(plugin_dirs=[]) - - manager._register_capability_impls(supervisor) - - assert "render.html2png" in supervisor.capability_service.list_capabilities() - - -@pytest.mark.asyncio -async def test_render_capability_forwards_request(monkeypatch: pytest.MonkeyPatch) -> None: - """render.html2png 应将请求透传给浏览器渲染服务。""" - - from src.plugin_runtime.capabilities import render as render_capability_module - - fake_service = _FakeRenderService() - monkeypatch.setattr(render_capability_module, "get_html_render_service", lambda: fake_service) - - manager = PluginRuntimeManager() - result = await manager._cap_render_html2png( - "demo.plugin", - "render.html2png", - { - "html": "
hello
", - "selector": "#card", - "viewport": {"width": 1024, "height": 768}, - "device_scale_factor": 1.5, - "full_page": False, - "omit_background": True, - "wait_until": "networkidle", - "wait_for_selector": "#card", - "wait_for_timeout_ms": 150, - "timeout_ms": 3000, - "allow_network": True, - }, - ) - - assert result == { - "success": True, - "result": { - "image_base64": "ZmFrZS1pbWFnZQ==", - "mime_type": "image/png", - "width": 640, - "height": 480, - "render_ms": 12, - }, - } - assert fake_service.last_request is not None - assert fake_service.last_request.selector == "#card" - assert fake_service.last_request.viewport_width == 1024 - assert fake_service.last_request.viewport_height == 768 - assert fake_service.last_request.device_scale_factor == 1.5 - assert fake_service.last_request.omit_background is True - assert fake_service.last_request.wait_until == "networkidle" - assert fake_service.last_request.allow_network is True diff --git a/pytests/test_prompt_message_roundtrip.py b/pytests/test_prompt_message_roundtrip.py deleted file mode 100644 index 01878585..00000000 --- a/pytests/test_prompt_message_roundtrip.py +++ /dev/null @@ -1,18 +0,0 @@ -from src.llm_models.payload_content.message import MessageBuilder, RoleType -from src.plugin_runtime.hook_payloads import deserialize_prompt_messages, serialize_prompt_messages - - -def test_prompt_messages_roundtrip_preserves_image_parts() -> None: - messages = [ - MessageBuilder().set_role(RoleType.User).add_text_content("你好").add_image_content("png", "ZmFrZQ==").build(), - ] - - serialized_messages = serialize_prompt_messages(messages) - restored_messages = deserialize_prompt_messages(serialized_messages) - - assert len(restored_messages) == 1 - assert restored_messages[0].role == RoleType.User - assert restored_messages[0].get_text_content() == "你好" - assert len(restored_messages[0].parts) == 2 - assert restored_messages[0].parts[1].image_format == "png" - assert restored_messages[0].parts[1].image_base64 == "ZmFrZQ==" diff --git a/pytests/test_runtime_business_hooks.py b/pytests/test_runtime_business_hooks.py deleted file mode 100644 index baaa052f..00000000 --- a/pytests/test_runtime_business_hooks.py +++ /dev/null @@ -1,136 +0,0 @@ -"""业务命名 Hook 集成测试。""" - -from types import SimpleNamespace -from typing import Any - -import os -import sys - -import pytest - -# 确保项目根目录在 sys.path 中 -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) -# SDK 包路径 -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "maibot-plugin-sdk")) - - -class _FakeHookManager: - """用于业务 Hook 测试的最小运行时管理器。""" - - def __init__(self, responses: dict[str, SimpleNamespace]) -> None: - """初始化测试管理器。 - - Args: - responses: 按 Hook 名称预设的返回结果映射。 - """ - - self._responses = responses - self.calls: list[tuple[str, dict[str, Any]]] = [] - - async def invoke_hook(self, hook_name: str, **kwargs: Any) -> SimpleNamespace: - """模拟调用运行时命名 Hook。 - - Args: - hook_name: 目标 Hook 名称。 - **kwargs: 传入 Hook 的参数。 - - Returns: - SimpleNamespace: 预设的 Hook 返回结果。 - """ - - self.calls.append((hook_name, dict(kwargs))) - return self._responses.get(hook_name, SimpleNamespace(kwargs=dict(kwargs), aborted=False)) - - -def test_builtin_hook_catalog_includes_new_business_hooks(monkeypatch: pytest.MonkeyPatch) -> None: - """内置 Hook 目录应包含三个业务系统新增的 Hook。""" - - monkeypatch.setattr(sys, "exit", lambda code=0: None) - from src.plugin_runtime.hook_catalog import register_builtin_hook_specs - from src.plugin_runtime.host.hook_spec_registry import HookSpecRegistry - - registry = HookSpecRegistry() - hook_names = {spec.name for spec in register_builtin_hook_specs(registry)} - - assert "emoji.maisaka.before_select" in hook_names - assert "emoji.register.after_build_emotion" in hook_names - assert "jargon.extract.before_persist" in hook_names - assert "jargon.query.after_search" in hook_names - assert "expression.select.before_select" in hook_names - assert "expression.learn.before_upsert" in hook_names - - -@pytest.mark.asyncio -async def test_send_emoji_for_maisaka_can_be_aborted_by_hook(monkeypatch: pytest.MonkeyPatch) -> None: - """表情包系统应允许在选择前被 Hook 中止。""" - - from src.emoji_system import maisaka_tool - - fake_manager = _FakeHookManager( - { - "emoji.maisaka.before_select": SimpleNamespace( - kwargs={"abort_message": "插件阻止了表情发送。"}, - aborted=True, - ) - } - ) - monkeypatch.setattr(maisaka_tool, "_get_runtime_manager", lambda: fake_manager) - - result = await maisaka_tool.send_emoji_for_maisaka(stream_id="stream-1", requested_emotion="开心") - - assert result.success is False - assert result.message == "插件阻止了表情发送。" - assert fake_manager.calls[0][0] == "emoji.maisaka.before_select" - - -@pytest.mark.asyncio -async def test_jargon_extract_can_be_aborted_before_persist(monkeypatch: pytest.MonkeyPatch) -> None: - """黑话提取结果应允许在写库前被 Hook 中止。""" - - from src.learners.jargon_miner import JargonMiner - - fake_manager = _FakeHookManager( - { - "jargon.extract.before_persist": SimpleNamespace( - kwargs={"entries": []}, - aborted=True, - ) - } - ) - monkeypatch.setattr(JargonMiner, "_get_runtime_manager", staticmethod(lambda: fake_manager)) - - miner = JargonMiner(session_id="session-1", session_name="测试会话") - await miner.process_extracted_entries( - [{"content": "yyds", "raw_content": {"[1] yyds 太强了"}}], - ) - - assert fake_manager.calls[0][0] == "jargon.extract.before_persist" - assert fake_manager.calls[0][1]["session_id"] == "session-1" - - -@pytest.mark.asyncio -async def test_expression_selection_can_be_aborted_by_hook(monkeypatch: pytest.MonkeyPatch) -> None: - """表达方式选择流程应允许在开始前被 Hook 中止。""" - - from src.learners.expression_selector import ExpressionSelector - - fake_manager = _FakeHookManager( - { - "expression.select.before_select": SimpleNamespace( - kwargs={}, - aborted=True, - ) - } - ) - monkeypatch.setattr(ExpressionSelector, "_get_runtime_manager", staticmethod(lambda: fake_manager)) - monkeypatch.setattr(ExpressionSelector, "can_use_expression_for_chat", lambda self, chat_id: True) - - selector = ExpressionSelector() - selected_expressions, selected_ids = await selector.select_suitable_expressions( - chat_id="session-1", - chat_info="用户刚刚发来一条消息。", - ) - - assert selected_expressions == [] - assert selected_ids == [] - assert fake_manager.calls[0][0] == "expression.select.before_select" diff --git a/pytests/test_send_service.py b/pytests/test_send_service.py deleted file mode 100644 index 23d51eb6..00000000 --- a/pytests/test_send_service.py +++ /dev/null @@ -1,360 +0,0 @@ -"""发送服务回归测试。""" - -import sys -from types import SimpleNamespace -from typing import Any, Dict, List - -import pytest - -from src.chat.message_receive.chat_manager import BotChatSession -from src.common.data_models.message_component_data_model import MessageSequence, TextComponent -from src.services import send_service - - -class _FakePlatformIOManager: - """用于测试的 Platform IO 管理器假对象。""" - - def __init__(self, delivery_batch: Any) -> None: - self._delivery_batch = delivery_batch - self.ensure_calls = 0 - self.sent_messages: List[Dict[str, Any]] = [] - - async def ensure_send_pipeline_ready(self) -> None: - self.ensure_calls += 1 - - def build_route_key_from_message(self, message: Any) -> Any: - del message - return SimpleNamespace(platform="qq") - - async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any: - self.sent_messages.append( - { - "message": message, - "message_id_before_send": str(getattr(message, "message_id", "") or ""), - "route_key": route_key, - "metadata": metadata, - } - ) - return self._delivery_batch - - -def _build_private_stream() -> BotChatSession: - return BotChatSession( - session_id="test-session", - platform="qq", - user_id="target-user", - group_id=None, - ) - - -def _build_group_stream() -> BotChatSession: - return BotChatSession( - session_id="group-session", - platform="qq", - user_id="target-user", - group_id="target-group", - ) - - -def test_inherit_platform_io_route_metadata_falls_back_to_bot_account( - monkeypatch: pytest.MonkeyPatch, -) -> None: - monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "") - - metadata = send_service._inherit_platform_io_route_metadata(_build_private_stream()) - - assert metadata["platform_io_account_id"] == "bot-qq" - assert metadata["platform_io_target_user_id"] == "target-user" - - -@pytest.mark.asyncio -async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None: - import src.common.message_server.api as message_server_api - - fake_manager = _FakePlatformIOManager( - delivery_batch=SimpleNamespace( - has_success=True, - sent_receipts=[ - SimpleNamespace( - driver_id="plugin.qq.sender", - external_message_id="real-message-id", - metadata={ - "adapter_callbacks": [ - { - "name": "message_id_echo", - "payload": { - "content": { - "type": "echo", - "echo": "send_api_test", - "actual_id": "real-message-id", - } - }, - } - ] - }, - ) - ], - failed_receipts=[], - route_key=SimpleNamespace(platform="qq"), - ) - ) - callback_payloads: List[Dict[str, Any]] = [] - stored_messages: List[Any] = [] - - async def fake_echo_handler(payload: Dict[str, Any]) -> None: - """记录发送成功后的消息 ID 回调。""" - - callback_payloads.append(payload) - - monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager) - monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") - monkeypatch.setattr( - send_service._chat_manager, - "get_session_by_session_id", - lambda stream_id: _build_private_stream() if stream_id == "test-session" else None, - ) - monkeypatch.setattr( - send_service.MessageUtils, - "store_message_to_db", - lambda message: stored_messages.append(message), - ) - monkeypatch.setattr( - message_server_api, - "global_api", - SimpleNamespace(_custom_message_handlers={"message_id_echo": fake_echo_handler}), - ) - - result = await send_service.text_to_stream(text="你好", stream_id="test-session") - - assert result is True - assert fake_manager.ensure_calls == 1 - assert len(fake_manager.sent_messages) == 1 - assert fake_manager.sent_messages[0]["metadata"] == {"show_log": False} - assert len(stored_messages) == 1 - assert stored_messages[0].message_id == "real-message-id" - assert callback_payloads == [ - { - "content": { - "type": "echo", - "echo": "send_api_test", - "actual_id": "real-message-id", - } - } - ] - - -@pytest.mark.asyncio -async def test_text_to_stream_with_message_returns_sent_message(monkeypatch: pytest.MonkeyPatch) -> None: - fake_manager = _FakePlatformIOManager( - delivery_batch=SimpleNamespace( - has_success=True, - sent_receipts=[ - SimpleNamespace( - driver_id="plugin.qq.sender", - external_message_id="real-message-id", - metadata={}, - ) - ], - failed_receipts=[], - route_key=SimpleNamespace(platform="qq"), - ) - ) - stored_messages: List[Any] = [] - - monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager) - monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") - monkeypatch.setattr( - send_service._chat_manager, - "get_session_by_session_id", - lambda stream_id: _build_private_stream() if stream_id == "test-session" else None, - ) - monkeypatch.setattr( - send_service.MessageUtils, - "store_message_to_db", - lambda message: stored_messages.append(message), - ) - - sent_message = await send_service.text_to_stream_with_message(text="你好", stream_id="test-session") - - assert sent_message is not None - assert sent_message.message_id == "real-message-id" - assert fake_manager.ensure_calls == 1 - assert len(stored_messages) == 1 - assert stored_messages[0].message_id == "real-message-id" - - -@pytest.mark.asyncio -async def test_text_to_stream_with_message_triggers_memory_and_syncs_maisaka_history( - monkeypatch: pytest.MonkeyPatch, -) -> None: - fake_manager = _FakePlatformIOManager( - delivery_batch=SimpleNamespace( - has_success=True, - sent_receipts=[ - SimpleNamespace( - driver_id="plugin.qq.sender", - external_message_id="real-message-id", - metadata={}, - ) - ], - failed_receipts=[], - route_key=SimpleNamespace(platform="qq"), - ) - ) - stored_messages: List[Any] = [] - memory_events: List[str] = [] - history_events: List[tuple[str, str]] = [] - - class FakeMemoryAutomationService: - async def on_message_sent(self, message: Any) -> None: - memory_events.append(str(message.message_id)) - - class FakeRuntime: - def append_sent_message_to_chat_history(self, message: Any, *, source_kind: str = "guided_reply") -> None: - history_events.append((str(message.message_id), source_kind)) - - monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager) - monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") - monkeypatch.setattr( - send_service._chat_manager, - "get_session_by_session_id", - lambda stream_id: _build_private_stream() if stream_id == "test-session" else None, - ) - monkeypatch.setattr( - send_service.MessageUtils, - "store_message_to_db", - lambda message: stored_messages.append(message), - ) - monkeypatch.setitem( - sys.modules, - "src.services.memory_flow_service", - SimpleNamespace(memory_automation_service=FakeMemoryAutomationService()), - ) - monkeypatch.setitem( - sys.modules, - "src.chat.heart_flow.heartflow_manager", - SimpleNamespace(heartflow_manager=SimpleNamespace(heartflow_chat_list={"test-session": FakeRuntime()})), - ) - - sent_message = await send_service.text_to_stream_with_message( - text="你好", - stream_id="test-session", - sync_to_maisaka_history=True, - maisaka_source_kind="guided_reply", - ) - - assert sent_message is not None - assert sent_message.message_id == "real-message-id" - assert fake_manager.ensure_calls == 1 - assert len(stored_messages) == 1 - assert stored_messages[0].message_id == "real-message-id" - assert memory_events == ["real-message-id"] - assert history_events == [("real-message-id", "guided_reply")] - - -@pytest.mark.asyncio -async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None: - fake_manager = _FakePlatformIOManager( - delivery_batch=SimpleNamespace( - has_success=False, - sent_receipts=[], - failed_receipts=[ - SimpleNamespace( - driver_id="plugin.qq.sender", - status="failed", - error="network error", - ) - ], - route_key=SimpleNamespace(platform="qq"), - ) - ) - - monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager) - monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") - monkeypatch.setattr( - send_service._chat_manager, - "get_session_by_session_id", - lambda stream_id: _build_private_stream() if stream_id == "test-session" else None, - ) - - result = await send_service.text_to_stream(text="发送失败", stream_id="test-session") - - assert result is False - assert fake_manager.ensure_calls == 1 - assert len(fake_manager.sent_messages) == 1 - - -@pytest.mark.asyncio -async def test_custom_to_stream_detailed_raises_stream_not_found(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr( - send_service._chat_manager, - "get_session_by_session_id", - lambda stream_id: None, - ) - - with pytest.raises(send_service.SendServiceError, match="未找到聊天流: group_chat"): - await send_service.custom_to_stream_detailed( - message_type="poke", - content={"qq_id": "2810873701"}, - stream_id="group_chat", - ) - - -@pytest.mark.asyncio -async def test_private_outbound_message_preserves_bot_sender_and_receiver_user( - monkeypatch: pytest.MonkeyPatch, -) -> None: - monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") - monkeypatch.setattr( - send_service._chat_manager, - "get_session_by_session_id", - lambda stream_id: _build_private_stream() if stream_id == "test-session" else None, - ) - - outbound_message = send_service._build_outbound_session_message( - message_sequence=MessageSequence(components=[TextComponent(text="你好")]), - stream_id="test-session", - processed_plain_text="你好", - ) - - assert outbound_message is not None - maim_message = await outbound_message.to_maim_message() - - assert maim_message.message_info.user_info is not None - assert maim_message.message_info.user_info.user_id == "bot-qq" - assert maim_message.message_info.group_info is None - assert maim_message.message_info.sender_info is not None - assert maim_message.message_info.sender_info.user_info is not None - assert maim_message.message_info.sender_info.user_info.user_id == "bot-qq" - assert maim_message.message_info.receiver_info is not None - assert maim_message.message_info.receiver_info.user_info is not None - assert maim_message.message_info.receiver_info.user_info.user_id == "target-user" - - -@pytest.mark.asyncio -async def test_group_outbound_message_preserves_bot_sender_and_target_group( - monkeypatch: pytest.MonkeyPatch, -) -> None: - monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") - monkeypatch.setattr( - send_service._chat_manager, - "get_session_by_session_id", - lambda stream_id: _build_group_stream() if stream_id == "group-session" else None, - ) - - outbound_message = send_service._build_outbound_session_message( - message_sequence=MessageSequence(components=[TextComponent(text="大家好")]), - stream_id="group-session", - processed_plain_text="大家好", - ) - - assert outbound_message is not None - maim_message = await outbound_message.to_maim_message() - - assert maim_message.message_info.user_info is not None - assert maim_message.message_info.user_info.user_id == "bot-qq" - assert maim_message.message_info.group_info is not None - assert maim_message.message_info.group_info.group_id == "target-group" - assert maim_message.message_info.receiver_info is not None - assert maim_message.message_info.receiver_info.group_info is not None - assert maim_message.message_info.receiver_info.group_info.group_id == "target-group" diff --git a/pytests/test_tool_availability.py b/pytests/test_tool_availability.py deleted file mode 100644 index 009b496b..00000000 --- a/pytests/test_tool_availability.py +++ /dev/null @@ -1,297 +0,0 @@ -from types import SimpleNamespace -from typing import Any -import importlib.util -import sys - -import pytest - -from src.core.tooling import ToolAvailabilityContext, ToolRegistry -from src.maisaka.tool_provider import MaisakaBuiltinToolProvider -from src.plugin_runtime.component_query import ComponentQueryService -from src.plugin_runtime.host.component_registry import ComponentRegistry - - -@pytest.mark.asyncio -async def test_builtin_at_tool_is_not_exposed() -> None: - registry = ToolRegistry() - registry.register_provider(MaisakaBuiltinToolProvider()) - - group_specs = await registry.list_tools(ToolAvailabilityContext(session_id="group-1", is_group_chat=True)) - private_specs = await registry.list_tools(ToolAvailabilityContext(session_id="private-1", is_group_chat=False)) - default_specs = await registry.list_tools() - - assert "at" not in {tool_spec.name for tool_spec in group_specs} - assert "at" not in {tool_spec.name for tool_spec in private_specs} - assert "at" not in {tool_spec.name for tool_spec in default_specs} - - -def test_plugin_tool_chat_scope_uses_component_field(monkeypatch: pytest.MonkeyPatch) -> None: - service = ComponentQueryService() - registry = ComponentRegistry() - supervisor = SimpleNamespace(component_registry=registry) - monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor]) - - registry.register_plugin_components( - "scope_plugin", - [ - { - "name": "group_tool", - "component_type": "TOOL", - "chat_scope": "group", - "metadata": {"description": "group only"}, - }, - { - "name": "private_tool", - "component_type": "TOOL", - "chat_scope": "private", - "metadata": {"description": "private only"}, - }, - { - "name": "all_tool", - "component_type": "TOOL", - "metadata": {"description": "all chats"}, - }, - ], - ) - - group_specs = service.get_llm_available_tool_specs( - context=ToolAvailabilityContext(session_id="group-1", is_group_chat=True) - ) - private_specs = service.get_llm_available_tool_specs( - context=ToolAvailabilityContext(session_id="private-1", is_group_chat=False) - ) - - group_entry = registry.get_component("scope_plugin.group_tool") - assert group_entry is not None - assert group_entry.chat_scope == "group" - assert "chat_scope" not in group_entry.metadata - assert set(group_specs) == {"group_tool", "all_tool"} - assert set(private_specs) == {"private_tool", "all_tool"} - - -def test_plugin_tool_session_disable_still_filters_specific_chat(monkeypatch: pytest.MonkeyPatch) -> None: - service = ComponentQueryService() - registry = ComponentRegistry() - supervisor = SimpleNamespace(component_registry=registry) - monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor]) - - registry.register_plugin_components( - "mute_plugin", - [ - { - "name": "mute", - "component_type": "TOOL", - "chat_scope": "group", - "metadata": {"description": "mute group member"}, - } - ], - ) - registry.set_component_enabled("mute_plugin.mute", False, session_id="group-disabled") - - disabled_specs = service.get_llm_available_tool_specs( - context=ToolAvailabilityContext(session_id="group-disabled", is_group_chat=True) - ) - enabled_specs = service.get_llm_available_tool_specs( - context=ToolAvailabilityContext(session_id="group-enabled", is_group_chat=True) - ) - - assert "mute" not in disabled_specs - assert "mute" in enabled_specs - - -def test_plugin_tool_allowed_session_filters_tool_exposure(monkeypatch: pytest.MonkeyPatch) -> None: - service = ComponentQueryService() - registry = ComponentRegistry() - supervisor = SimpleNamespace(component_registry=registry) - monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor]) - - registry.register_plugin_components( - "mute_plugin", - [ - { - "name": "mute", - "component_type": "TOOL", - "chat_scope": "group", - "allowed_session": ["qq:10001", "raw-group-id", "exact-session-id"], - "metadata": {"description": "mute group member"}, - } - ], - ) - - platform_group_specs = service.get_llm_available_tool_specs( - context=ToolAvailabilityContext( - session_id="hashed-session-1", - is_group_chat=True, - group_id="10001", - platform="qq", - ) - ) - raw_group_specs = service.get_llm_available_tool_specs( - context=ToolAvailabilityContext( - session_id="hashed-session-2", - is_group_chat=True, - group_id="raw-group-id", - platform="qq", - ) - ) - exact_session_specs = service.get_llm_available_tool_specs( - context=ToolAvailabilityContext(session_id="exact-session-id", is_group_chat=True) - ) - blocked_specs = service.get_llm_available_tool_specs( - context=ToolAvailabilityContext( - session_id="blocked-session", - is_group_chat=True, - group_id="20002", - platform="qq", - ) - ) - - entry = registry.get_component("mute_plugin.mute") - assert entry is not None - assert entry.allowed_session == {"qq:10001", "raw-group-id", "exact-session-id"} - assert "allowed_session" not in entry.metadata - assert "mute" in platform_group_specs - assert "mute" in raw_group_specs - assert "mute" in exact_session_specs - assert "mute" not in blocked_specs - - -def test_plugin_tool_disabled_session_take_precedence_over_allowed_session( - monkeypatch: pytest.MonkeyPatch, -) -> None: - service = ComponentQueryService() - registry = ComponentRegistry() - supervisor = SimpleNamespace(component_registry=registry) - monkeypatch.setattr(service, "_iter_supervisors", lambda: [supervisor]) - - registry.register_plugin_components( - "mute_plugin", - [ - { - "name": "mute", - "component_type": "TOOL", - "chat_scope": "group", - "allowed_session": ["qq:10001"], - "metadata": {"description": "mute group member"}, - } - ], - ) - registry.set_component_enabled("mute_plugin.mute", False, session_id="allowed-session") - - visible_specs = service.get_llm_available_tool_specs( - context=ToolAvailabilityContext( - session_id="visible-session", - is_group_chat=True, - group_id="10001", - platform="qq", - ) - ) - disabled_specs = service.get_llm_available_tool_specs( - context=ToolAvailabilityContext( - session_id="allowed-session", - is_group_chat=True, - group_id="10001", - platform="qq", - ) - ) - - entry = registry.get_component("mute_plugin.mute") - assert entry is not None - assert entry.disabled_session == {"allowed-session"} - assert "mute" in visible_specs - assert "mute" not in disabled_specs - - -def test_mute_plugin_exports_allowed_groups_as_component_allowed_session() -> None: - module_path = "plugins/MutePlugin/plugin.py" - spec = importlib.util.spec_from_file_location("mute_plugin_under_test", module_path) - assert spec is not None - assert spec.loader is not None - module = importlib.util.module_from_spec(spec) - sys.modules[spec.name] = module - spec.loader.exec_module(module) - module.MutePluginConfig.model_rebuild() - - plugin = module.MutePlugin() - plugin.set_plugin_config({"permissions": {"allowed_groups": ["qq:10001", "raw-group-id"]}}) - - mute_components = [component for component in plugin.get_components() if component.get("name") == "mute"] - - assert len(mute_components) == 1 - assert mute_components[0]["chat_scope"] == "group" - assert mute_components[0]["allowed_session"] == ["qq:10001", "raw-group-id"] - assert "allowed_session" not in mute_components[0]["metadata"] - - -@pytest.mark.asyncio -async def test_mute_tool_queries_target_message_with_current_chat_id() -> None: - module_path = "plugins/MutePlugin/plugin.py" - spec = importlib.util.spec_from_file_location("mute_plugin_under_test_msg_id", module_path) - assert spec is not None - assert spec.loader is not None - module = importlib.util.module_from_spec(spec) - sys.modules[spec.name] = module - spec.loader.exec_module(module) - module.MutePluginConfig.model_rebuild() - - capability_calls: list[dict[str, Any]] = [] - api_calls: list[dict[str, Any]] = [] - - async def fake_call_capability(name: str, **kwargs: Any) -> dict[str, Any]: - capability_calls.append({"name": name, **kwargs}) - return { - "success": True, - "result": { - "success": True, - "message": { - "message_info": { - "user_info": { - "user_id": "35529667", - "user_cardname": "目标用户", - "user_nickname": "目标昵称", - } - } - }, - }, - } - - async def fake_api_call(api_name: str, **kwargs: Any) -> dict[str, Any]: - api_calls.append({"name": api_name, **kwargs}) - if api_name == "adapter.napcat.group.get_group_member_info": - return {"success": True, "result": {"data": {"role": "member"}}} - return {"status": "ok", "retcode": 0} - - plugin = module.MutePlugin() - plugin.set_plugin_config({"components": {"enable_smart_mute": True}}) - plugin._set_context( - SimpleNamespace( - call_capability=fake_call_capability, - api=SimpleNamespace(call=fake_api_call), - logger=SimpleNamespace(info=lambda *args, **kwargs: None, warning=lambda *args, **kwargs: None), - ) - ) - - success, message = await plugin.handle_mute_tool( - stream_id="current-session-id", - group_id="766798517", - msg_id="2046083292", - duration=3600, - reason="测试", - ) - - assert success is True - assert message == "成功禁言 目标用户" - assert capability_calls == [ - { - "name": "message.get_by_id", - "message_id": "2046083292", - "chat_id": "current-session-id", - } - ] - assert api_calls[-1] == { - "name": "adapter.napcat.group.set_group_ban", - "version": "1", - "group_id": "766798517", - "user_id": "35529667", - "duration": 3600, - } diff --git a/pytests/utils_test/message_utils_test.py b/pytests/utils_test/message_utils_test.py deleted file mode 100644 index 4c5287a2..00000000 --- a/pytests/utils_test/message_utils_test.py +++ /dev/null @@ -1,367 +0,0 @@ -import sys -from dataclasses import dataclass, field -import pytest -import importlib -import importlib.util -from types import ModuleType -from pathlib import Path -from datetime import datetime -from typing import TYPE_CHECKING, Optional - -if TYPE_CHECKING: - from src.common.data_models.message_component_data_model import MessageSequence - from src.chat.message_receive.message import ( - SessionMessage, - TextComponent, - ImageComponent, - AtComponent, - ) - - -class DummyLogger: - def __init__(self) -> None: - self.logging_record = [] - - def debug(self, msg): - print(f"DEBUG: {msg}") - self.logging_record.append(f"DEBUG: {msg}") - - def info(self, msg): - print(f"INFO: {msg}") - self.logging_record.append(f"INFO: {msg}") - - def warning(self, msg): - print(f"WARNING: {msg}") - self.logging_record.append(f"WARNING: {msg}") - - def error(self, msg): - print(f"ERROR: {msg}") - self.logging_record.append(f"ERROR: {msg}") - - def critical(self, msg): - print(f"CRITICAL: {msg}") - self.logging_record.append(f"CRITICAL: {msg}") - - -def get_logger(name): - return DummyLogger() - - -class DummyDBSession: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - def exec(self, statement): - return self - - def first(self): - return None - - def commit(self): - pass - - def all(self): - return [] - - -def get_db_session(): - return DummyDBSession() - - -def get_manual_db_session(): - return DummyDBSession() - - -class DummySelect: - def __init__(self, model): - self.model = model - - def filter_by(self, **kwargs): - return self - - def where(self, condition): - return self - - def limit(self, n): - return self - - -def select(model): - return DummySelect(model) - - -async def dummy_get_voice_text(binary_data): - return None # 可以根据需要返回模拟的文本结果 - - -class DummyPersonUtils: - @staticmethod - def get_person_info_by_user_id_and_platform(user_id, platform): - return None # 可以根据需要返回模拟的用户信息 - - -class DummyConfig: - class MessageReceiveConfig: - ban_words = set() - ban_msgs_regex = set() - - message_receive = MessageReceiveConfig() - - -@dataclass -class UserInfo: - user_id: str - user_nickname: str - user_cardname: Optional[str] = None - - -@dataclass -class GroupInfo: - group_id: str - group_name: str - - -@dataclass -class MessageInfo: - user_info: UserInfo - group_info: Optional[GroupInfo] = None - additional_config: dict = field(default_factory=dict) - - -def setup_mocks(monkeypatch): - def _stub_module(name: str) -> ModuleType: - module = ModuleType(name) - monkeypatch.setitem(sys.modules, name, module) - return module - - # src.common.logger - logger_mod = _stub_module("src.common.logger") - # Mock the logger - logger_mod.get_logger = get_logger - - db_mod = _stub_module("src.common.database.database") - db_mod.get_db_session = get_db_session - db_mod.get_manual_db_session = get_manual_db_session - - db_model_mod = _stub_module("src.common.database.database_model") - db_model_mod.Messages = None # 可以根据需要添加更多的属性或方法 - - emoji_manager_mod = _stub_module("src.emoji_system.emoji_manager") - emoji_manager_mod.emoji_manager = None # 可以根据需要添加更多的属性或方法 - - image_manager_mod = _stub_module("src.chat.image_system.image_manager") - image_manager_mod.image_manager = None # 可以根据需要添加更多的属性或方法 - - voice_utils_mod = _stub_module("src.common.utils.utils_voice") - voice_utils_mod.get_voice_text = dummy_get_voice_text - - person_utils_mod = _stub_module("src.common.utils.utils_person") - person_utils_mod.PersonUtils = DummyPersonUtils - - config_mod = _stub_module("src.config.config") - config_mod.global_config = DummyConfig() - - -def load_message_via_file(monkeypatch): - setup_mocks(monkeypatch) - file_path = Path(__file__).parent.parent.parent / "src" / "chat" / "message_receive" / "message.py" - spec = importlib.util.spec_from_file_location("message", file_path) - message_module = importlib.util.module_from_spec(spec) - monkeypatch.setitem(sys.modules, "message_module", message_module) - spec.loader.exec_module(message_module) - message_module.select = select - SessionMessageClass = message_module.SessionMessage - TextComponentClass = message_module.TextComponent - ImageComponentClass = message_module.ImageComponent - EmojiComponentClass = message_module.EmojiComponent - VoiceComponentClass = message_module.VoiceComponent - AtComponentClass = message_module.AtComponent - ReplyComponentClass = message_module.ReplyComponent - ForwardNodeComponentClass = message_module.ForwardNodeComponent - MessageSequenceClass = sys.modules["src.common.data_models.message_component_data_model"].MessageSequence - ForwardComponentClass = sys.modules["src.common.data_models.message_component_data_model"].ForwardComponent - globals()["SessionMessage"] = SessionMessageClass - globals()["TextComponent"] = TextComponentClass - globals()["ImageComponent"] = ImageComponentClass - globals()["EmojiComponent"] = EmojiComponentClass - globals()["VoiceComponent"] = VoiceComponentClass - globals()["AtComponent"] = AtComponentClass - globals()["ReplyComponent"] = ReplyComponentClass - globals()["ForwardNodeComponent"] = ForwardNodeComponentClass - globals()["MessageSequence"] = MessageSequenceClass - globals()["ForwardComponent"] = ForwardComponentClass - return message_module - - -def dummy_number_to_short_id(original_id: int, salt: str, length: int = 6) -> str: - return "X" * length # 返回固定的字符串,长度由参数决定,模拟生成短ID的行为 - - -def dummy_is_bot_self(platform, user_id: str) -> bool: - return user_id == "bot_self" - - -def load_utils_via_file(monkeypatch): - setup_mocks(monkeypatch) - - # Mock math_utils 模块,供 from .math_utils import number_to_short_id 使用 - math_utils_mod = ModuleType("src.common.utils.math_utils") - math_utils_mod.number_to_short_id = dummy_number_to_short_id - math_utils_mod.TimestampMode = type( - "TimestampMode", (), {"NORMAL": "%Y-%m-%d %H:%M:%S", "NORMAL_NO_YMD": "%H:%M:%S", "RELATIVE": "relative"} - ) - math_utils_mod.translate_timestamp_to_human_readable = lambda timestamp, mode: ( - "2024-01-01 12:00:00" - ) # 返回固定的时间字符串 - monkeypatch.setitem(sys.modules, "src.common.utils.math_utils", math_utils_mod) - - # 确保包层级模块存在于 sys.modules 中,使相对导入能正确解析 - for pkg_name in ["src", "src.common", "src.common.utils"]: - if pkg_name not in sys.modules: - pkg_mod = ModuleType(pkg_name) - pkg_mod.__path__ = [] - monkeypatch.setitem(sys.modules, pkg_name, pkg_mod) - - file_path = Path(__file__).parent.parent.parent / "src" / "common" / "utils" / "utils_message.py" - spec = importlib.util.spec_from_file_location("src.common.utils.utils_message", file_path) - utils_module = importlib.util.module_from_spec(spec) - utils_module.__package__ = "src.common.utils" # 设置包,使相对导入生效 - monkeypatch.setitem(sys.modules, "src.common.utils.utils_message", utils_module) - monkeypatch.setitem(sys.modules, "message_utils_module", utils_module) - spec.loader.exec_module(utils_module) - utils_module.is_bot_self = dummy_is_bot_self - return utils_module - - -@pytest.mark.asyncio -async def test_message_utils(monkeypatch): - load_message_via_file(monkeypatch) - load_utils_via_file(monkeypatch) - - -@pytest.mark.asyncio -async def test_build_readable_message_basic(monkeypatch): - """基础用例:单条消息,显示行号""" - load_message_via_file(monkeypatch) - utils_module = load_utils_via_file(monkeypatch) - MessageUtils = utils_module.MessageUtils - - msg = SessionMessage("m1", datetime.now(), platform="test") - msg.platform = "test" - msg.session_id = "s_test" - user_info = UserInfo(user_id="u1", user_nickname="Alice") - msg.message_info = MessageInfo(user_info=user_info) - msg.raw_message = MessageSequence([TextComponent("Hello world")]) - text, mapping, _ = await MessageUtils.build_readable_message([msg], anonymize=False, show_lineno=True) - assert "[1] Alice说:Hello world" in text - assert mapping == {} - - -@pytest.mark.asyncio -async def test_build_readable_message_anonymize(monkeypatch): - """匿名化用例:验证 mapping 和返回文本""" - load_message_via_file(monkeypatch) - utils_module = load_utils_via_file(monkeypatch) - MessageUtils = utils_module.MessageUtils - - msg = SessionMessage("m2", datetime.now(), platform="test") - msg.session_id = "s_test" - user_info = UserInfo(user_id="u42", user_nickname="Bob") - msg.message_info = MessageInfo(user_info=user_info) - msg.raw_message = MessageSequence([TextComponent("Secret text")]) - text, mapping, _ = await MessageUtils.build_readable_message([msg], anonymize=True, show_lineno=False) - # 根据实现,original_name 为 user_nickname,因此文本中应包含原始名称 - assert "XXXXXX说:" in text - assert "u42" in mapping - assert mapping["u42"][0] == "XXXXXX" - assert mapping["u42"][1] == "Bob" - - -@pytest.mark.asyncio -async def test_build_readable_message_replace_bot(monkeypatch): - """替换机器人名用例:当 user_id 为 bot_self 时应被替换为 target_bot_name""" - load_message_via_file(monkeypatch) - utils_module = load_utils_via_file(monkeypatch) - MessageUtils = utils_module.MessageUtils - - msg = SessionMessage("m3", datetime.now(), platform="test") - msg.session_id = "s_test" - user_info = UserInfo(user_id="bot_self", user_nickname="SomeBot") - msg.message_info = MessageInfo(user_info=user_info) - msg.raw_message = MessageSequence([TextComponent("ping")]) - text, mapping, _ = await MessageUtils.build_readable_message([msg], replace_bot_name=True, target_bot_name="MAIBot") - assert "MAIBot说:ping" in text - - -@pytest.mark.asyncio -async def test_build_readable_message_image_extraction(monkeypatch): - """图片提取:验证 extract_pictures 为 True 时,文本中包含图片占位及 img_map 内容被返回""" - load_message_via_file(monkeypatch) - utils_module = load_utils_via_file(monkeypatch) - MessageUtils = utils_module.MessageUtils - - # 构建包含图片组件的消息 - img = ImageComponent(binary_hash="h", binary_data=b"\x01\x02", content="Img") - msg = SessionMessage("mi1", datetime.now(), platform="test") - msg.session_id = "s_img" - msg.raw_message = MessageSequence([img]) - msg.message_info = MessageInfo(UserInfo(user_id="ui_img", user_nickname="ImgUser")) - text, mapping, _ = await MessageUtils.build_readable_message([msg], extract_pictures=True) - # 应包含图片描述占位 - assert "图片1" in text - # mapping 不为空(匿名化未开启则为空) - assert isinstance(mapping, dict) - - -@pytest.mark.asyncio -async def test_build_readable_message_anonymize_and_replace_bot_name_and_lineno(monkeypatch): - """组合用例:多个消息同时包含匿名化、机器人名称替换""" - load_message_via_file(monkeypatch) - utils_module = load_utils_via_file(monkeypatch) - MessageUtils = utils_module.MessageUtils - # 构建多个消息 - msg1 = SessionMessage("m4", datetime.now(), platform="test") - msg1.session_id = "s_comb" - msg2 = SessionMessage("m5", datetime.now(), platform="test") - msg2.session_id = "s_comb" - msg1.message_info = MessageInfo(UserInfo(user_id="u_comb", user_nickname="Charlie")) - msg2.message_info = MessageInfo(UserInfo(user_id="bot_self", user_nickname="SomeBot")) - msg1.raw_message = MessageSequence([TextComponent("Hi")]) - msg2.raw_message = MessageSequence([TextComponent("Hello")]) - text, mapping, _ = await MessageUtils.build_readable_message( - [msg1, msg2], - anonymize=True, - replace_bot_name=True, - target_bot_name="MAIBot", - show_lineno=True, - ) - # 验证文本内容 - assert "[1] XXXXXX说:Hi" in text - assert "[2] MAIBot说:Hello" in text - # 验证 mapping 内容 - assert "u_comb" in mapping - assert mapping["u_comb"][0] == "XXXXXX" - - -@pytest.mark.asyncio -async def test_build_readable_message_with_at(monkeypatch): - """包含@组件的消息:验证@组件中的用户信息也被匿名化和替换""" - load_message_via_file(monkeypatch) - utils_module = load_utils_via_file(monkeypatch) - MessageUtils = utils_module.MessageUtils - - # 构建包含回复组件的消息 - at_comp = AtComponent(target_user_id="u_at", target_user_nickname="AtUser") - msg = SessionMessage("m_at", datetime.now(), platform="test") - msg.session_id = "s_at" - msg.raw_message = MessageSequence([at_comp]) - msg.message_info = MessageInfo(UserInfo(user_id="u_main", user_nickname="MainUser")) - text, mapping, _ = await MessageUtils.build_readable_message( - [msg], anonymize=True, replace_bot_name=True, target_bot_name="MAIBot" - ) - # 验证主消息和@组件中的用户信息都被处理 - assert "XXXXXX说:" in text # 主消息用户被匿名化 - assert "XXXXXX说:@XXXXXX" in text # @组件用户被匿名化 diff --git a/pytests/utils_test/statistic_test.py b/pytests/utils_test/statistic_test.py deleted file mode 100644 index 6e8a17d2..00000000 --- a/pytests/utils_test/statistic_test.py +++ /dev/null @@ -1,117 +0,0 @@ -"""统计模块数据库会话行为测试。""" - -from __future__ import annotations - -from contextlib import contextmanager -from datetime import datetime, timedelta -from types import ModuleType -from typing import Any, Callable, Iterator - -import sys - -import pytest - -from src.chat.utils import statistic - - -class _DummyResult: - """模拟 SQLModel 查询结果对象。""" - - def all(self) -> list[Any]: - """返回空结果集。 - - Returns: - list[Any]: 空列表。 - """ - return [] - - -class _DummySession: - """模拟数据库 Session。""" - - def exec(self, statement: Any) -> _DummyResult: - """执行查询语句并返回空结果。 - - Args: - statement: 待执行的查询语句。 - - Returns: - _DummyResult: 空结果对象。 - """ - del statement - return _DummyResult() - - -def _build_fake_get_db_session(calls: list[bool]) -> Callable[[bool], Iterator[_DummySession]]: - """构造一个记录 auto_commit 参数的假会话工厂。 - - Args: - calls: 用于记录每次调用 auto_commit 参数的列表。 - - Returns: - Callable[[bool], Iterator[_DummySession]]: 可替换 `get_db_session` 的上下文管理器工厂。 - """ - - @contextmanager - def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_DummySession]: - """记录会话参数并返回假 Session。 - - Args: - auto_commit: 是否启用自动提交。 - - Yields: - Iterator[_DummySession]: 假 Session 对象。 - """ - calls.append(auto_commit) - yield _DummySession() - - return _fake_get_db_session - - -def _build_statistic_task() -> statistic.StatisticOutputTask: - """构造一个最小可用的统计任务实例。 - - Returns: - statistic.StatisticOutputTask: 跳过 `__init__` 的测试实例。 - """ - task = statistic.StatisticOutputTask.__new__(statistic.StatisticOutputTask) - task.name_mapping = {} - return task - - -def _is_bot_self(platform: str, user_id: str) -> bool: - """返回固定的非机器人身份判断结果。 - - Args: - platform: 平台名称。 - user_id: 用户 ID。 - - Returns: - bool: 始终返回 ``False``。 - """ - del platform - del user_id - return False - - -def test_statistic_read_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None: - """统计模块的纯读查询应关闭自动提交,避免 Session 退出后对象被 expire。""" - calls: list[bool] = [] - now = datetime.now() - task = _build_statistic_task() - - monkeypatch.setattr(statistic, "get_db_session", _build_fake_get_db_session(calls)) - - utils_module = ModuleType("src.chat.utils.utils") - utils_module.is_bot_self = _is_bot_self - monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module) - monkeypatch.setattr(statistic, "fetch_online_time_since", lambda query_start_time: []) - monkeypatch.setattr(statistic, "fetch_model_usage_since", lambda query_start_time: []) - monkeypatch.setattr(statistic, "fetch_messages_since", lambda query_start_time: []) - monkeypatch.setattr(statistic, "fetch_tool_records_since", lambda query_start_time: []) - - task._collect_message_count_for_period([("last_hour", now - timedelta(hours=1))]) - task._collect_interval_data(now, hours=1, interval_minutes=60) - task._collect_metrics_interval_data(now, hours=1, interval_hours=1) - - assert calls == [] diff --git a/pytests/utils_test/test_request_snapshot.py b/pytests/utils_test/test_request_snapshot.py deleted file mode 100644 index 99694eb5..00000000 --- a/pytests/utils_test/test_request_snapshot.py +++ /dev/null @@ -1,131 +0,0 @@ -from pathlib import Path - -import json - -from src.config.model_configs import APIProvider, ModelInfo -from src.llm_models.model_client.base_client import ResponseRequest -from src.llm_models.payload_content.message import MessageBuilder, RoleType -from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType -from src.llm_models.payload_content.tool_option import ToolCall, ToolOption -from src.llm_models.request_snapshot import ( - attach_request_snapshot, - deserialize_messages_snapshot, - format_request_snapshot_log_info, - save_failed_request_snapshot, - serialize_messages_snapshot, - serialize_response_request_snapshot, -) -from src.llm_models import request_snapshot - - -def _build_api_provider() -> APIProvider: - return APIProvider( - api_key="secret-token", - base_url="https://example.com/v1", - name="test-provider", - ) - - -def _build_model_info() -> ModelInfo: - return ModelInfo( - api_provider="test-provider", - model_identifier="demo-model", - name="demo-model", - ) - - -def _build_response_request() -> ResponseRequest: - tool_call = ToolCall( - args={"query": "MaiBot"}, - call_id="call_1", - func_name="search_web", - extra_content={"google": {"thought_signature": "c2lnbmF0dXJl"}}, - ) - message_list = [ - MessageBuilder().set_role(RoleType.User).add_text_content("你好").add_image_content("png", "ZmFrZQ==").build(), - MessageBuilder().set_role(RoleType.Assistant).set_tool_calls([tool_call]).build(), - MessageBuilder() - .set_role(RoleType.Tool) - .set_tool_call_id("call_1") - .set_tool_name("search_web") - .add_text_content('{"ok": true}') - .build(), - ] - return ResponseRequest( - extra_params={"trace_id": "trace-123"}, - max_tokens=256, - message_list=message_list, - model_info=_build_model_info(), - response_format=RespFormat(RespFormatType.JSON_OBJ), - temperature=0.2, - tool_options=[ToolOption(name="search_web", description="搜索网页")], - ) - - -def test_message_snapshot_roundtrip_preserves_tool_messages() -> None: - request = _build_response_request() - - snapshot_messages = serialize_messages_snapshot(request.message_list) - restored_messages = deserialize_messages_snapshot(snapshot_messages) - - assert len(restored_messages) == 3 - assert restored_messages[0].role == RoleType.User - assert restored_messages[0].get_text_content() == "你好" - assert restored_messages[0].parts[1].image_format == "png" - assert restored_messages[1].role == RoleType.Assistant - assert restored_messages[1].tool_calls is not None - assert restored_messages[1].tool_calls[0].func_name == "search_web" - assert restored_messages[1].tool_calls[0].args == {"query": "MaiBot"} - assert restored_messages[1].tool_calls[0].extra_content == {"google": {"thought_signature": "c2lnbmF0dXJl"}} - assert restored_messages[2].role == RoleType.Tool - assert restored_messages[2].tool_call_id == "call_1" - assert restored_messages[2].tool_name == "search_web" - - -def test_failed_request_snapshot_contains_replay_entry(tmp_path: Path, monkeypatch) -> None: - monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path) - - request = _build_response_request() - provider = _build_api_provider() - snapshot_path = save_failed_request_snapshot( - api_provider=provider, - client_type="openai", - error=RuntimeError("boom"), - internal_request=serialize_response_request_snapshot(request), - model_info=request.model_info, - operation="chat.completions.create", - provider_request={"request_kwargs": {"model": request.model_info.model_identifier}}, - ) - - assert snapshot_path is not None - payload = json.loads(snapshot_path.read_text(encoding="utf-8")) - - assert payload["internal_request"]["request_kind"] == "response" - assert payload["api_provider"]["name"] == "test-provider" - assert payload["replay"]["file_uri"] == snapshot_path.as_uri() - assert str(snapshot_path) in payload["replay"]["command"] - assert "secret-token" not in snapshot_path.read_text(encoding="utf-8") - - -def test_format_request_snapshot_log_info_includes_path_uri_and_command(tmp_path: Path, monkeypatch) -> None: - monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path) - - request = _build_response_request() - snapshot_path = save_failed_request_snapshot( - api_provider=_build_api_provider(), - client_type="openai", - error=ValueError("invalid"), - internal_request=serialize_response_request_snapshot(request), - model_info=request.model_info, - operation="chat.completions.create", - provider_request={"request_kwargs": {"messages": []}}, - ) - - assert snapshot_path is not None - exc = RuntimeError("wrapped") - attach_request_snapshot(exc, snapshot_path) - - log_info = format_request_snapshot_log_info(exc) - assert str(snapshot_path) in log_info - assert snapshot_path.as_uri() in log_info - assert "uv run python scripts/replay_llm_request.py" in log_info diff --git a/pytests/utils_test/test_session_utils.py b/pytests/utils_test/test_session_utils.py deleted file mode 100644 index c44e2eba..00000000 --- a/pytests/utils_test/test_session_utils.py +++ /dev/null @@ -1,42 +0,0 @@ -from types import SimpleNamespace - -from src.chat.message_receive.chat_manager import ChatManager -from src.common.utils.utils_session import SessionUtils - - -def test_calculate_session_id_distinguishes_account_and_scope() -> None: - base_session_id = SessionUtils.calculate_session_id("qq", user_id="42") - same_base_session_id = SessionUtils.calculate_session_id("qq", user_id="42") - account_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123") - route_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123", scope="main") - - assert base_session_id == same_base_session_id - assert account_scoped_session_id != base_session_id - assert route_scoped_session_id != account_scoped_session_id - - -def test_chat_manager_register_message_uses_route_metadata() -> None: - chat_manager = ChatManager() - message = SimpleNamespace( - platform="qq", - session_id="", - message_info=SimpleNamespace( - user_info=SimpleNamespace(user_id="42"), - group_info=SimpleNamespace(group_id="1000"), - additional_config={ - "platform_io_account_id": "123", - "platform_io_scope": "main", - }, - ), - ) - - chat_manager.register_message(message) - - assert message.session_id == SessionUtils.calculate_session_id( - "qq", - user_id="42", - group_id="1000", - account_id="123", - scope="main", - ) - assert chat_manager.last_messages[message.session_id] is message diff --git a/pytests/webui/__init__.py b/pytests/webui/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pytests/webui/test_app.py b/pytests/webui/test_app.py deleted file mode 100644 index bc7343ea..00000000 --- a/pytests/webui/test_app.py +++ /dev/null @@ -1,161 +0,0 @@ -from pathlib import Path -from unittest.mock import patch - -import pytest - -from src.webui import app as webui_app - - -def test_ensure_static_path_ready_uses_existing_static_path(tmp_path) -> None: - static_path = tmp_path / "dist" - static_path.mkdir() - (static_path / "index.html").write_text("", encoding="utf-8") - - with patch.object(webui_app, "_resolve_static_path", return_value=static_path): - result = webui_app._ensure_static_path_ready() - - assert result == static_path - - -def test_ensure_static_path_ready_logs_install_hint_when_static_assets_are_missing() -> None: - with ( - patch.object(webui_app, "_resolve_static_path", return_value=None), - patch.object(webui_app.logger, "warning") as warning_mock, - ): - result = webui_app._ensure_static_path_ready() - - assert result is None - warning_mock.assert_any_call(webui_app.t("startup.webui_static_assets_unavailable")) - warning_mock.assert_any_call( - webui_app.t("startup.webui_dashboard_package_hint", command=webui_app._MANUAL_INSTALL_COMMAND) - ) - - -def test_ensure_static_path_ready_logs_index_error_when_static_path_is_invalid(tmp_path) -> None: - static_path = tmp_path / "dist" - static_path.mkdir() - - with ( - patch.object(webui_app, "_resolve_static_path", return_value=static_path), - patch.object(webui_app.logger, "warning") as warning_mock, - ): - result = webui_app._ensure_static_path_ready() - - assert result is None - warning_mock.assert_any_call( - webui_app.t("startup.webui_index_missing", index_path=static_path / "index.html") - ) - warning_mock.assert_any_call( - webui_app.t("startup.webui_dashboard_package_hint", command=webui_app._MANUAL_INSTALL_COMMAND) - ) - - -def test_setup_static_files_does_not_duplicate_warning_when_static_path_is_unavailable() -> None: - app = webui_app.FastAPI() - - with ( - patch.object(webui_app, "_ensure_static_path_ready", return_value=None), - patch.object(webui_app.logger, "warning") as warning_mock, - ): - webui_app._setup_static_files(app) - - warning_mock.assert_not_called() - - -def test_resolve_static_path_prefers_installed_dashboard_package(monkeypatch, tmp_path) -> None: - package_dist = tmp_path / "site-packages" / "maibot_dashboard" / "dist" - package_dist.mkdir(parents=True) - - class _DashboardModule: - @staticmethod - def get_dist_path() -> Path: - return package_dist - - monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path) - - with patch.object(webui_app, "import_module", return_value=_DashboardModule()): - resolved_path = webui_app._resolve_static_path() - - assert resolved_path == package_dist - - -def test_resolve_static_path_ignores_dashboard_dist_when_package_is_unavailable(monkeypatch, tmp_path) -> None: - dashboard_dist = tmp_path / "dashboard" / "dist" - dashboard_dist.mkdir(parents=True) - (dashboard_dist / "index.html").write_text("", encoding="utf-8") - - monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path) - - with patch.object(webui_app, "import_module", side_effect=ImportError): - resolved_path = webui_app._resolve_static_path() - - assert resolved_path is None - - -def test_resolve_static_path_uses_package_even_when_dashboard_dist_exists(monkeypatch, tmp_path) -> None: - dashboard_dist = tmp_path / "dashboard" / "dist" - dashboard_dist.mkdir(parents=True) - - package_dist = tmp_path / "site-packages" / "maibot_dashboard" / "dist" - package_dist.mkdir(parents=True) - - class _DashboardModule: - @staticmethod - def get_dist_path() -> Path: - return package_dist - - monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path) - - with patch.object(webui_app, "import_module", return_value=_DashboardModule()): - resolved_path = webui_app._resolve_static_path() - - assert resolved_path == package_dist - - -def test_resolve_safe_static_file_path_allows_regular_static_file(tmp_path) -> None: - static_path = tmp_path / "dist" - asset_path = static_path / "assets" / "app.js" - asset_path.parent.mkdir(parents=True) - asset_path.write_text("console.log('ok')", encoding="utf-8") - - resolved_path = webui_app._resolve_safe_static_file_path(static_path, "assets/app.js") - - assert resolved_path == asset_path.resolve() - - -def test_resolve_safe_static_file_path_rejects_relative_path_traversal(tmp_path) -> None: - static_path = tmp_path / "dist" - static_path.mkdir() - - resolved_path = webui_app._resolve_safe_static_file_path(static_path, "../secret.txt") - - assert resolved_path is None - - -def test_resolve_safe_static_file_path_rejects_absolute_path_traversal(tmp_path) -> None: - static_path = tmp_path / "dist" - static_path.mkdir() - - resolved_path = webui_app._resolve_safe_static_file_path(static_path, "/etc/passwd") - - assert resolved_path is None - - -def test_resolve_safe_static_file_path_rejects_symlink_escape(tmp_path) -> None: - static_path = tmp_path / "dist" - static_path.mkdir() - - outside_dir = tmp_path / "outside" - outside_dir.mkdir() - outside_file = outside_dir / "secret.txt" - outside_file.write_text("secret", encoding="utf-8") - - link_path = static_path / "escape" - try: - link_path.symlink_to(outside_dir, target_is_directory=True) - except OSError as exc: - pytest.skip(f"symlink is not supported in this environment: {exc}") - - resolved_path = webui_app._resolve_safe_static_file_path(static_path, "escape/secret.txt") - - assert resolved_path is None diff --git a/pytests/webui/test_config_schema.py b/pytests/webui/test_config_schema.py deleted file mode 100644 index 498c6965..00000000 --- a/pytests/webui/test_config_schema.py +++ /dev/null @@ -1,147 +0,0 @@ -from src.config.official_configs import ChatConfig, MessageReceiveConfig -from src.config.config import Config -from src.config.config_base import ConfigBase, Field -from src.webui.config_schema import ConfigSchemaGenerator - - -def test_field_docs_in_schema(): - """Test that field descriptions are correctly extracted from field_docs (docstrings).""" - schema = ConfigSchemaGenerator.generate_schema(ChatConfig) - talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value") - - # Verify description field exists - assert "description" in talk_value - # Verify description contains expected Chinese text from the docstring - assert "聊天频率" in talk_value["description"] - - -def test_json_schema_extra_merged(): - """Test that json_schema_extra fields are correctly merged into output.""" - schema = ConfigSchemaGenerator.generate_schema(ChatConfig) - talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value") - - # Verify UI metadata fields from json_schema_extra exist - assert talk_value.get("x-widget") == "slider" - assert talk_value.get("x-icon") == "message-circle" - assert talk_value.get("step") == 0.1 - - -def test_pydantic_constraints_mapped(): - """Test that Pydantic constraints (ge/le) are correctly mapped to minValue/maxValue.""" - schema = ConfigSchemaGenerator.generate_schema(ChatConfig) - talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value") - - # Verify constraints are mapped to frontend naming convention - assert "minValue" in talk_value - assert "maxValue" in talk_value - assert talk_value["minValue"] == 0 # From ge=0 - assert talk_value["maxValue"] == 1 # From le=1 - - -def test_nested_model_schema(): - """Test that nested models (ConfigBase fields) are correctly handled.""" - schema = ConfigSchemaGenerator.generate_schema(Config) - - # Verify nested structure exists - assert "nested" in schema - assert "chat" in schema["nested"] - - # Verify nested chat schema is complete - chat_schema = schema["nested"]["chat"] - assert chat_schema["className"] == "ChatConfig" - assert "fields" in chat_schema - - # Verify nested schema fields include description and metadata - talk_value = next(f for f in chat_schema["fields"] if f["name"] == "talk_value") - assert "description" in talk_value - assert talk_value.get("x-widget") == "slider" - assert talk_value.get("minValue") == 0 - - -def test_field_without_extra_metadata(): - """Test that fields without json_schema_extra still generate valid schema.""" - schema = ConfigSchemaGenerator.generate_schema(ChatConfig) - inevitable_at_reply = next(f for f in schema["fields"] if f["name"] == "inevitable_at_reply") - - # Verify basic fields are generated - assert "name" in inevitable_at_reply - assert inevitable_at_reply["name"] == "inevitable_at_reply" - assert "type" in inevitable_at_reply - assert inevitable_at_reply["type"] == "boolean" - assert "label" in inevitable_at_reply - assert "required" in inevitable_at_reply - - # Verify no x-widget or x-icon from json_schema_extra (since field has none) - # These fields should only be present if explicitly defined in json_schema_extra - assert not inevitable_at_reply.get("x-widget") - assert not inevitable_at_reply.get("x-icon") - - -def test_all_top_level_sections_have_ui_metadata(): - """所有顶层配置节都必须声明 uiParent 或独立 Tab 的标签与图标。""" - schema = ConfigSchemaGenerator.generate_schema(Config) - - for section_name, section_schema in schema["nested"].items(): - has_parent = bool(section_schema.get("uiParent")) - has_host_meta = bool(section_schema.get("uiLabel")) and bool(section_schema.get("uiIcon")) - assert has_parent or has_host_meta, f"{section_name} 缺少 UI 元数据" - - -def test_maisaka_is_host_tab_and_mcp_is_attached_to_it(): - """MaiSaka 应作为独立 Tab,MCP 作为其子配置挂载。""" - schema = ConfigSchemaGenerator.generate_schema(Config) - - maisaka_schema = schema["nested"]["maisaka"] - mcp_schema = schema["nested"]["mcp"] - - assert maisaka_schema.get("uiParent") is None - assert maisaka_schema.get("uiLabel") == "MaiSaka" - assert maisaka_schema.get("uiIcon") == "message-circle" - assert mcp_schema.get("uiParent") == "maisaka" - - -def test_memory_query_config_fields_are_exposed(): - """query_memory 开关和默认条数应出现在记忆配置 schema 中。""" - schema = ConfigSchemaGenerator.generate_schema(Config) - memory_schema = schema["nested"]["memory"] - - assert memory_schema.get("uiParent") == "emoji" - - enable_field = next(field for field in memory_schema["fields"] if field["name"] == "enable_memory_query_tool") - limit_field = next(field for field in memory_schema["fields"] if field["name"] == "memory_query_default_limit") - - assert enable_field["type"] == "boolean" - assert enable_field.get("x-widget") == "switch" - assert enable_field.get("x-icon") == "database" - - assert limit_field["type"] == "integer" - assert limit_field.get("x-widget") == "input" - assert limit_field.get("x-icon") == "hash" - assert limit_field.get("minValue") == 1 - assert limit_field.get("maxValue") == 20 - - -def test_set_field_is_mapped_as_array(): - """set[str] 应映射为前端可识别的 array。""" - schema = ConfigSchemaGenerator.generate_schema(MessageReceiveConfig) - ban_words = next(field for field in schema["fields"] if field["name"] == "ban_words") - - assert ban_words["type"] == "array" - assert ban_words["items"]["type"] == "string" - - -def test_advanced_fields_are_hidden_from_webui_schema(): - """advanced=True 的字段不应出现在 WebUI 配置 schema 中,未声明时默认展示。""" - - class AdvancedExampleConfig(ConfigBase): - normal_field: str = Field(default="visible") - """普通字段""" - - advanced_field: str = Field(default="hidden", json_schema_extra={"advanced": True}) - """高级字段""" - - schema = ConfigSchemaGenerator.generate_schema(AdvancedExampleConfig) - field_names = {field["name"] for field in schema["fields"]} - - assert "normal_field" in field_names - assert "advanced_field" not in field_names diff --git a/pytests/webui/test_emoji_routes.py b/pytests/webui/test_emoji_routes.py deleted file mode 100644 index 92900726..00000000 --- a/pytests/webui/test_emoji_routes.py +++ /dev/null @@ -1,461 +0,0 @@ -"""表情包路由 API 测试 - -测试 src/webui/routers/emoji.py 中的核心 emoji 路由端点 -使用内存 SQLite 数据库和 FastAPI TestClient -""" - -from contextlib import contextmanager -from datetime import datetime -from typing import Generator -from unittest.mock import patch - -import pytest - -from fastapi import FastAPI -from fastapi.testclient import TestClient -from sqlalchemy.pool import StaticPool -from sqlmodel import Session, SQLModel, create_engine - -from src.common.database.database_model import Images, ImageType -from src.webui.core import TokenManager -from src.webui.routers.emoji import router - - -@pytest.fixture(scope="function") -def test_engine(): - """创建内存 SQLite 引擎用于测试""" - engine = create_engine( - "sqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - SQLModel.metadata.create_all(engine) - return engine - - -@pytest.fixture(scope="function") -def test_session(test_engine) -> Generator[Session, None, None]: - """创建测试数据库会话""" - with Session(test_engine) as session: - yield session - - -@pytest.fixture(scope="function") -def test_app(test_session): - """创建测试 FastAPI 应用并覆盖 get_db_session 依赖""" - app = FastAPI() - app.include_router(router) - - # Create a context manager that yields the test session - @contextmanager - def override_get_db_session(auto_commit=True): - """Override get_db_session to use test session""" - try: - yield test_session - if auto_commit: - test_session.commit() - except Exception: - test_session.rollback() - raise - - with patch("src.webui.routers.emoji.get_db_session", override_get_db_session): - yield app - - -@pytest.fixture(scope="function") -def client(test_app): - """创建 TestClient""" - return TestClient(test_app) - - -@pytest.fixture(scope="function") -def auth_token(): - """创建有效的认证 token""" - token_manager = TokenManager(secret_key="test-secret-key", token_expire_hours=24) - return token_manager.create_token() - - -@pytest.fixture(scope="function") -def sample_emojis(test_session) -> list[Images]: - """插入测试用表情包数据""" - import hashlib - - emojis = [ - Images( - image_type=ImageType.EMOJI, - full_path="/data/emoji_registed/test1.png", - image_hash=hashlib.sha256(b"test1").hexdigest(), - description="测试表情包 1", - emotion="开心,快乐", - query_count=10, - is_registered=True, - is_banned=False, - record_time=datetime(2026, 1, 1, 10, 0, 0), - register_time=datetime(2026, 1, 1, 10, 0, 0), - last_used_time=datetime(2026, 1, 2, 10, 0, 0), - ), - Images( - image_type=ImageType.EMOJI, - full_path="/data/emoji_registed/test2.gif", - image_hash=hashlib.sha256(b"test2").hexdigest(), - description="测试表情包 2", - emotion="难过", - query_count=5, - is_registered=False, - is_banned=False, - record_time=datetime(2026, 1, 3, 10, 0, 0), - register_time=None, - last_used_time=None, - ), - Images( - image_type=ImageType.EMOJI, - full_path="/data/emoji_registed/test3.webp", - image_hash=hashlib.sha256(b"test3").hexdigest(), - description="测试表情包 3", - emotion="生气", - query_count=20, - is_registered=True, - is_banned=True, - record_time=datetime(2026, 1, 4, 10, 0, 0), - register_time=datetime(2026, 1, 4, 10, 0, 0), - last_used_time=datetime(2026, 1, 5, 10, 0, 0), - ), - ] - - for emoji in emojis: - test_session.add(emoji) - test_session.commit() - - for emoji in emojis: - test_session.refresh(emoji) - - return emojis - - -@pytest.fixture(scope="function") -def mock_token_verify(): - """Mock token verification to always succeed""" - with patch("src.webui.routers.emoji.verify_auth_token", return_value=True): - yield - - -# ==================== 测试用例 ==================== - - -def test_list_emojis_basic(client, sample_emojis, mock_token_verify): - """测试获取表情包列表(基本分页)""" - response = client.get("/emoji/list?page=1&page_size=10") - - assert response.status_code == 200 - data = response.json() - - assert data["success"] is True - assert data["total"] == 3 - assert data["page"] == 1 - assert data["page_size"] == 10 - assert len(data["data"]) == 3 - - # 验证第一个表情包字段 - emoji = data["data"][0] - assert "id" in emoji - assert "full_path" in emoji - assert "emoji_hash" in emoji - assert "description" in emoji - assert "query_count" in emoji - assert "is_registered" in emoji - assert "is_banned" in emoji - assert "emotion" in emoji - assert "record_time" in emoji - assert "register_time" in emoji - assert "last_used_time" in emoji - - -def test_list_emojis_pagination(client, sample_emojis, mock_token_verify): - """测试分页功能""" - response = client.get("/emoji/list?page=1&page_size=2") - - assert response.status_code == 200 - data = response.json() - - assert data["success"] is True - assert data["total"] == 3 - assert len(data["data"]) == 2 - - # 第二页 - response = client.get("/emoji/list?page=2&page_size=2") - data = response.json() - assert len(data["data"]) == 1 - - -def test_list_emojis_search(client, sample_emojis, mock_token_verify): - """测试搜索过滤""" - response = client.get("/emoji/list?search=表情包 2") - - assert response.status_code == 200 - data = response.json() - - assert data["success"] is True - assert data["total"] == 1 - assert data["data"][0]["description"] == "测试表情包 2" - - -def test_list_emojis_filter_registered(client, sample_emojis, mock_token_verify): - """测试 is_registered 过滤""" - response = client.get("/emoji/list?is_registered=true") - - assert response.status_code == 200 - data = response.json() - - assert data["success"] is True - assert data["total"] == 2 - assert all(emoji["is_registered"] is True for emoji in data["data"]) - - -def test_list_emojis_filter_banned(client, sample_emojis, mock_token_verify): - """测试 is_banned 过滤""" - response = client.get("/emoji/list?is_banned=true") - - assert response.status_code == 200 - data = response.json() - - assert data["success"] is True - assert data["total"] == 1 - assert data["data"][0]["is_banned"] is True - - -def test_list_emojis_sort_by_query_count(client, sample_emojis, mock_token_verify): - """测试按 query_count 排序""" - response = client.get("/emoji/list?sort_by=query_count&sort_order=desc") - - assert response.status_code == 200 - data = response.json() - - assert data["success"] is True - # 验证降序排列 (20 > 10 > 5) - assert data["data"][0]["query_count"] == 20 - assert data["data"][1]["query_count"] == 10 - assert data["data"][2]["query_count"] == 5 - - -def test_get_emoji_detail_success(client, sample_emojis, mock_token_verify): - """测试获取表情包详情(成功)""" - emoji_id = sample_emojis[0].id - response = client.get(f"/emoji/{emoji_id}") - - assert response.status_code == 200 - data = response.json() - - assert data["success"] is True - assert data["data"]["id"] == emoji_id - assert data["data"]["emoji_hash"] == sample_emojis[0].image_hash - - -def test_get_emoji_detail_not_found(client, mock_token_verify): - """测试获取不存在的表情包(404)""" - response = client.get("/emoji/99999") - - assert response.status_code == 404 - data = response.json() - assert "未找到" in data["detail"] - - -def test_update_emoji_description(client, sample_emojis, mock_token_verify): - """测试更新表情包描述""" - emoji_id = sample_emojis[0].id - response = client.patch( - f"/emoji/{emoji_id}", - json={"description": "更新后的描述"}, - ) - - assert response.status_code == 200 - data = response.json() - - assert data["success"] is True - assert data["data"]["description"] == "更新后的描述" - assert "成功更新" in data["message"] - - -def test_update_emoji_register_status(client, sample_emojis, mock_token_verify, test_session): - """测试更新注册状态(False -> True 应设置 register_time)""" - emoji_id = sample_emojis[1].id # 未注册的表情包 - response = client.patch( - f"/emoji/{emoji_id}", - json={"is_registered": True}, - ) - - assert response.status_code == 200 - data = response.json() - - assert data["success"] is True - assert data["data"]["is_registered"] is True - assert data["data"]["register_time"] is not None # 应该设置了注册时间 - - -def test_update_emoji_no_fields(client, sample_emojis, mock_token_verify): - """测试更新请求未提供任何字段(400)""" - emoji_id = sample_emojis[0].id - response = client.patch(f"/emoji/{emoji_id}", json={}) - - assert response.status_code == 400 - data = response.json() - assert "未提供任何需要更新的字段" in data["detail"] - - -def test_update_emoji_not_found(client, mock_token_verify): - """测试更新不存在的表情包(404)""" - response = client.patch("/emoji/99999", json={"description": "test"}) - - assert response.status_code == 404 - data = response.json() - assert "未找到" in data["detail"] - - -def test_delete_emoji_success(client, sample_emojis, mock_token_verify, test_session): - """测试删除表情包(成功)""" - emoji_id = sample_emojis[0].id - response = client.delete(f"/emoji/{emoji_id}") - - assert response.status_code == 200 - data = response.json() - - assert data["success"] is True - assert "成功删除" in data["message"] - - # 验证数据库中已删除 - from sqlmodel import select - - statement = select(Images).where(Images.id == emoji_id) - result = test_session.exec(statement).first() - assert result is None - - -def test_delete_emoji_not_found(client, mock_token_verify): - """测试删除不存在的表情包(404)""" - response = client.delete("/emoji/99999") - - assert response.status_code == 404 - data = response.json() - assert "未找到" in data["detail"] - - -def test_batch_delete_success(client, sample_emojis, mock_token_verify, test_session): - """测试批量删除表情包(全部成功)""" - emoji_ids = [sample_emojis[0].id, sample_emojis[1].id] - response = client.post("/emoji/batch/delete", json={"emoji_ids": emoji_ids}) - - assert response.status_code == 200 - data = response.json() - - assert data["success"] is True - assert data["deleted_count"] == 2 - assert data["failed_count"] == 0 - assert "成功删除 2 个表情包" in data["message"] - - # 验证数据库中已删除 - from sqlmodel import select - - for emoji_id in emoji_ids: - statement = select(Images).where(Images.id == emoji_id) - result = test_session.exec(statement).first() - assert result is None - - -def test_batch_delete_partial_failure(client, sample_emojis, mock_token_verify): - """测试批量删除(部分失败)""" - emoji_ids = [sample_emojis[0].id, 99999] # 第二个 ID 不存在 - response = client.post("/emoji/batch/delete", json={"emoji_ids": emoji_ids}) - - assert response.status_code == 200 - data = response.json() - - assert data["success"] is True - assert data["deleted_count"] == 1 - assert data["failed_count"] == 1 - assert 99999 in data["failed_ids"] - - -def test_batch_delete_empty_list(client, mock_token_verify): - """测试批量删除空列表(400)""" - response = client.post("/emoji/batch/delete", json={"emoji_ids": []}) - - assert response.status_code == 400 - data = response.json() - assert "未提供要删除的表情包ID" in data["detail"] - - -def test_auth_required_list(client): - """测试未认证访问列表端点(401)""" - # Without mock_token_verify fixture - with patch("src.webui.routers.emoji.verify_auth_token", return_value=False): - client.get("/emoji/list") - # verify_auth_token 返回 False 会触发 HTTPException - # 但具体状态码取决于 verify_auth_token_from_cookie_or_header 的实现 - # 这里假设它抛出 401 - - -def test_auth_required_update(client, sample_emojis): - """测试未认证访问更新端点(401)""" - with patch("src.webui.routers.emoji.verify_auth_token", return_value=False): - emoji_id = sample_emojis[0].id - client.patch(f"/emoji/{emoji_id}", json={"description": "test"}) - # Should be unauthorized - - -def test_emoji_to_response_field_mapping(sample_emojis): - """测试 emoji_to_response 字段映射(image_hash -> emoji_hash)""" - from src.webui.routers.emoji import emoji_to_response - - emoji = sample_emojis[0] - response = emoji_to_response(emoji) - - # 验证 API 字段名称 - assert hasattr(response, "emoji_hash") - assert response.emoji_hash == emoji.image_hash - - # 验证时间戳转换 - assert isinstance(response.record_time, float) - assert response.record_time == emoji.record_time.timestamp() - - if emoji.register_time: - assert isinstance(response.register_time, float) - assert response.register_time == emoji.register_time.timestamp() - - -def test_list_emojis_only_emoji_type(client, test_session, mock_token_verify): - """测试列表只返回 type=EMOJI 的记录(不包括其他类型)""" - # 插入一个非 EMOJI 类型的图片 - non_emoji = Images( - image_type=ImageType.IMAGE, # 不是 EMOJI - full_path="/data/images/test.png", - image_hash="hash_image", - description="非表情包图片", - query_count=0, - is_registered=False, - is_banned=False, - record_time=datetime.now(), - ) - test_session.add(non_emoji) - test_session.commit() - - # 插入一个 EMOJI 类型 - emoji = Images( - image_type=ImageType.EMOJI, - full_path="/data/emoji_registed/emoji.png", - image_hash="hash_emoji", - description="表情包", - query_count=0, - is_registered=True, - is_banned=False, - record_time=datetime.now(), - ) - test_session.add(emoji) - test_session.commit() - - response = client.get("/emoji/list") - - assert response.status_code == 200 - data = response.json() - - # 只应该返回 1 个 EMOJI 类型的记录 - assert data["total"] == 1 - assert data["data"][0]["description"] == "表情包" diff --git a/pytests/webui/test_expression_routes.py b/pytests/webui/test_expression_routes.py deleted file mode 100644 index 45e476e1..00000000 --- a/pytests/webui/test_expression_routes.py +++ /dev/null @@ -1,529 +0,0 @@ -"""Expression routes pytest tests""" - -from typing import Generator - -import pytest -from fastapi import APIRouter, FastAPI -from fastapi.testclient import TestClient -from sqlalchemy import text -from sqlalchemy.pool import StaticPool -from sqlmodel import Session, SQLModel, create_engine, select - -from src.common.database.database_model import Expression, ModifiedBy -from src.webui.dependencies import require_auth - - -def create_test_app() -> FastAPI: - """Create minimal test app with only expression router""" - app = FastAPI(title="Test App") - from src.webui.routers.expression import router as expression_router - - main_router = APIRouter(prefix="/api/webui") - main_router.include_router(expression_router) - app.include_router(main_router) - - return app - - -app = create_test_app() - - -# Test database setup -@pytest.fixture(name="test_engine") -def test_engine_fixture(): - """Create in-memory SQLite database for testing""" - engine = create_engine( - "sqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - SQLModel.metadata.create_all(engine) - return engine - - -@pytest.fixture(name="test_session") -def test_session_fixture(test_engine) -> Generator[Session, None, None]: - """Create a test database session with transaction rollback""" - connection = test_engine.connect() - transaction = connection.begin() - session = Session(bind=connection) - - yield session - - session.close() - transaction.rollback() - connection.close() - - -@pytest.fixture(name="client") -def client_fixture(test_session: Session, monkeypatch) -> Generator[TestClient, None, None]: - """Create TestClient with overridden database session""" - from contextlib import contextmanager - - @contextmanager - def get_test_db_session(): - yield test_session - test_session.commit() - - monkeypatch.setattr("src.webui.routers.expression.get_db_session", get_test_db_session) - - with TestClient(app) as client: - yield client - - -@pytest.fixture(name="mock_auth") -def mock_auth_fixture(): - """Mock authentication to always return True""" - app.dependency_overrides[require_auth] = lambda: "test-token" - yield - app.dependency_overrides.clear() - - -@pytest.fixture(name="sample_expression") -def sample_expression_fixture(test_session: Session) -> Expression: - """Insert a sample expression into test database""" - test_session.execute( - text( - "INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) " - "VALUES (1, '测试情景', '测试风格', '[\"测试内容1\", \"测试内容2\"]', 10, '2026-02-17 12:00:00', '2026-02-15 10:00:00', 'test_chat_001', 0, 0)" - ) - ) - test_session.commit() - - expression = test_session.exec(select(Expression).where(Expression.id == 1)).first() - assert expression is not None - return expression - - -# ============ Tests ============ - - -def test_list_expressions_empty(client: TestClient, mock_auth): - """Test GET /expression/list with empty database""" - response = client.get("/api/webui/expression/list") - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["total"] == 0 - assert data["page"] == 1 - assert data["page_size"] == 20 - assert data["data"] == [] - - -def test_list_expressions_with_data(client: TestClient, mock_auth, sample_expression: Expression): - """Test GET /expression/list returns expression data""" - response = client.get("/api/webui/expression/list") - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["total"] == 1 - assert len(data["data"]) == 1 - - expr_data = data["data"][0] - assert expr_data["id"] == sample_expression.id - assert expr_data["situation"] == "测试情景" - assert expr_data["style"] == "测试风格" - assert expr_data["chat_id"] == "test_chat_001" - - -def test_list_expressions_pagination(client: TestClient, mock_auth, test_session: Session): - """Test GET /expression/list pagination works correctly""" - for i in range(5): - test_session.execute( - text( - f"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) " - f"VALUES ({i + 1}, '情景{i}', '风格{i}', '[]', 0, '2026-02-17 12:0{i}:00', '2026-02-15 10:00:00', 'chat_{i}', 0, 0)" - ) - ) - test_session.commit() - - # Request page 1 with page_size=2 - response = client.get("/api/webui/expression/list?page=1&page_size=2") - assert response.status_code == 200 - - data = response.json() - assert data["total"] == 5 - assert data["page"] == 1 - assert data["page_size"] == 2 - assert len(data["data"]) == 2 - - # Request page 2 - response = client.get("/api/webui/expression/list?page=2&page_size=2") - data = response.json() - assert data["page"] == 2 - assert len(data["data"]) == 2 - - -def test_list_expressions_search(client: TestClient, mock_auth, test_session: Session): - """Test GET /expression/list with search filter""" - test_session.execute( - text( - "INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) " - "VALUES (1, '找人吃饭', '热情', '[]', 0, datetime('now'), datetime('now'), 'chat_001', 0, 0)" - ) - ) - test_session.execute( - text( - "INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) " - "VALUES (2, '拒绝邀请', '礼貌', '[]', 0, datetime('now'), datetime('now'), 'chat_002', 0, 0)" - ) - ) - test_session.commit() - - # Search for "吃饭" - response = client.get("/api/webui/expression/list?search=吃饭") - assert response.status_code == 200 - - data = response.json() - assert data["total"] == 1 - assert data["data"][0]["situation"] == "找人吃饭" - - -def test_list_expressions_chat_filter(client: TestClient, mock_auth, test_session: Session): - """Test GET /expression/list with chat_id filter""" - test_session.execute( - text( - "INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) " - "VALUES (1, '情景A', '风格A', '[]', 0, datetime('now'), datetime('now'), 'chat_A', 0, 0)" - ) - ) - test_session.execute( - text( - "INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) " - "VALUES (2, '情景B', '风格B', '[]', 0, datetime('now'), datetime('now'), 'chat_B', 0, 0)" - ) - ) - test_session.commit() - - # Filter by chat_A - response = client.get("/api/webui/expression/list?chat_id=chat_A") - assert response.status_code == 200 - - data = response.json() - assert data["total"] == 1 - assert data["data"][0]["situation"] == "情景A" - assert data["data"][0]["chat_id"] == "chat_A" - - -def test_get_expression_detail_success(client: TestClient, mock_auth, sample_expression: Expression): - """Test GET /expression/{id} returns correct detail""" - response = client.get(f"/api/webui/expression/{sample_expression.id}") - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["data"]["id"] == sample_expression.id - assert data["data"]["situation"] == "测试情景" - assert data["data"]["style"] == "测试风格" - assert data["data"]["chat_id"] == "test_chat_001" - - -def test_get_expression_detail_not_found(client: TestClient, mock_auth): - """Test GET /expression/{id} returns 404 for non-existent ID""" - response = client.get("/api/webui/expression/99999") - assert response.status_code == 404 - - data = response.json() - assert "未找到" in data["detail"] - - -def test_expression_response_has_legacy_fields(client: TestClient, mock_auth, sample_expression: Expression): - """Test that ExpressionResponse includes legacy fields (checked/rejected/modified_by)""" - response = client.get(f"/api/webui/expression/{sample_expression.id}") - assert response.status_code == 200 - - data = response.json()["data"] - - # Verify legacy fields exist and have default values - assert "checked" in data - assert "rejected" in data - assert "modified_by" in data - - # Verify hardcoded default values (from expression_to_response) - assert data["checked"] is False - assert data["rejected"] is False - assert data["modified_by"] is None - - -def test_update_expression_without_removed_fields(client: TestClient, mock_auth, sample_expression: Expression): - """Test PATCH /expression/{id} does not accept checked/rejected fields""" - # Valid update request (only allowed fields) - update_payload = { - "situation": "更新后的情景", - "style": "更新后的风格", - } - - response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload) - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["data"]["situation"] == "更新后的情景" - assert data["data"]["style"] == "更新后的风格" - - # Verify legacy fields still returned (hardcoded values) - assert data["data"]["checked"] is False - assert data["data"]["rejected"] is False - - -def test_update_expression_ignores_invalid_fields(client: TestClient, mock_auth, sample_expression: Expression): - """Test PATCH /expression/{id} ignores fields not in ExpressionUpdateRequest""" - # Request with invalid field (checked not in schema) - update_payload = { - "situation": "新情景", - "checked": True, # This field should be ignored by Pydantic - "rejected": True, # This field should be ignored - } - - response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload) - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["data"]["situation"] == "新情景" - - # Response should have hardcoded False values (not True from request) - assert data["data"]["checked"] is False - assert data["data"]["rejected"] is False - - -def test_update_expression_chat_id_mapping(client: TestClient, mock_auth, sample_expression: Expression): - """Test PATCH /expression/{id} correctly maps chat_id to session_id""" - update_payload = {"chat_id": "updated_chat_999"} - - response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload) - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - - # Verify chat_id is returned in response (mapped from session_id) - assert data["data"]["chat_id"] == "updated_chat_999" - - -def test_update_expression_not_found(client: TestClient, mock_auth): - """Test PATCH /expression/{id} returns 404 for non-existent ID""" - update_payload = {"situation": "新情景"} - - response = client.patch("/api/webui/expression/99999", json=update_payload) - assert response.status_code == 404 - - data = response.json() - assert "未找到" in data["detail"] - - -def test_update_expression_empty_request(client: TestClient, mock_auth, sample_expression: Expression): - """Test PATCH /expression/{id} returns 400 for empty update request""" - update_payload = {} - - response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload) - assert response.status_code == 400 - - data = response.json() - assert "未提供任何需要更新的字段" in data["detail"] - - -def test_delete_expression_success(client: TestClient, mock_auth, sample_expression: Expression): - """Test DELETE /expression/{id} successfully deletes expression""" - expression_id = sample_expression.id - - response = client.delete(f"/api/webui/expression/{expression_id}") - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert "成功删除" in data["message"] - - # Verify expression is deleted - get_response = client.get(f"/api/webui/expression/{expression_id}") - assert get_response.status_code == 404 - - -def test_delete_expression_not_found(client: TestClient, mock_auth): - """Test DELETE /expression/{id} returns 404 for non-existent ID""" - response = client.delete("/api/webui/expression/99999") - assert response.status_code == 404 - - data = response.json() - assert "未找到" in data["detail"] - - -def test_create_expression_success(client: TestClient, mock_auth): - """Test POST /expression/ successfully creates expression""" - create_payload = { - "situation": "新建情景", - "style": "新建风格", - "chat_id": "new_chat_123", - } - - response = client.post("/api/webui/expression/", json=create_payload) - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert "创建成功" in data["message"] - assert data["data"]["situation"] == "新建情景" - assert data["data"]["style"] == "新建风格" - assert data["data"]["chat_id"] == "new_chat_123" - - # Verify legacy fields - assert data["data"]["checked"] is False - assert data["data"]["rejected"] is False - assert data["data"]["modified_by"] is None - - -def test_batch_delete_expressions_success(client: TestClient, mock_auth, test_session: Session): - """Test POST /expression/batch/delete successfully deletes multiple expressions""" - expression_ids = [] - for i in range(3): - test_session.execute( - text( - f"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) " - f"VALUES ({i + 1}, '批量删除{i}', '风格{i}', '[]', 0, datetime('now'), datetime('now'), 'chat_{i}', 0, 0)" - ) - ) - expression_ids.append(i + 1) - test_session.commit() - - delete_payload = {"ids": expression_ids} - response = client.post("/api/webui/expression/batch/delete", json=delete_payload) - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert "成功删除 3 个" in data["message"] - - for expr_id in expression_ids: - get_response = client.get(f"/api/webui/expression/{expr_id}") - assert get_response.status_code == 404 - - -def test_batch_delete_partial_not_found(client: TestClient, mock_auth, sample_expression: Expression): - """Test POST /expression/batch/delete handles partial not found IDs""" - delete_payload = {"ids": [sample_expression.id, 88888, 99999]} - - response = client.post("/api/webui/expression/batch/delete", json=delete_payload) - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - # Should delete only the 1 valid ID - assert "成功删除 1 个" in data["message"] - - -def test_get_expression_stats(client: TestClient, mock_auth, test_session: Session): - """Test GET /expression/stats/summary returns correct statistics""" - for i in range(3): - test_session.execute( - text( - f"INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) " - f"VALUES ({i + 1}, '情景{i}', '风格{i}', '[]', 0, datetime('now'), datetime('now'), 'chat_{i % 2}', 0, 0)" - ) - ) - test_session.commit() - - response = client.get("/api/webui/expression/stats/summary") - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["data"]["total"] == 3 - assert data["data"]["chat_count"] == 2 - - -def test_get_review_stats(client: TestClient, mock_auth, test_session: Session): - """Test GET /expression/review/stats returns review status counts""" - test_session.execute( - text( - "INSERT INTO expressions (id, situation, style, content_list, count, last_active_time, create_time, session_id, checked, rejected) " - "VALUES (1, '待审核', '风格', '[]', 0, datetime('now'), datetime('now'), 'chat_001', 0, 0)" - ) - ) - test_session.commit() - - response = client.get("/api/webui/expression/review/stats") - assert response.status_code == 200 - - data = response.json() - assert data["total"] == 1 # Total expressions exists - assert data["unchecked"] == 1 - assert data["passed"] == 0 - assert data["rejected"] == 0 - assert data["ai_checked"] == 0 - assert data["user_checked"] == 0 - - -def test_get_review_list_filter_unchecked(client: TestClient, mock_auth, sample_expression: Expression): - """Test GET /expression/review/list with filter_type=unchecked returns unchecked expressions""" - response = client.get("/api/webui/expression/review/list?filter_type=unchecked") - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["total"] == 1 - assert len(data["data"]) == 1 - - -def test_get_review_list_filter_all(client: TestClient, mock_auth, sample_expression: Expression): - """Test GET /expression/review/list with filter_type=all returns all expressions""" - response = client.get("/api/webui/expression/review/list?filter_type=all") - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["total"] == 1 - assert len(data["data"]) == 1 - - -def test_batch_review_expressions_with_unchecked_marker(client: TestClient, mock_auth, sample_expression: Expression): - """Test POST /expression/review/batch succeeds with require_unchecked=True""" - review_payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": True}]} - - response = client.post("/api/webui/expression/review/batch", json=review_payload) - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["succeeded"] == 1 - assert data["results"][0]["success"] is True - - -def test_batch_review_expressions_overwrites_ai_checked( - client: TestClient, mock_auth, test_session: Session, sample_expression: Expression -): - """Test POST /expression/review/batch lets manual review override AI checked state""" - sample_expression.checked = True - sample_expression.rejected = True - sample_expression.modified_by = ModifiedBy.AI - test_session.add(sample_expression) - test_session.commit() - - review_payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": True}]} - - response = client.post("/api/webui/expression/review/batch", json=review_payload) - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["succeeded"] == 1 - test_session.expire_all() - reviewed_expression = test_session.exec(select(Expression).where(Expression.id == sample_expression.id)).first() - assert reviewed_expression is not None - assert reviewed_expression.checked is True - assert reviewed_expression.rejected is False - assert reviewed_expression.modified_by == ModifiedBy.USER - - -def test_batch_review_expressions_no_unchecked_check(client: TestClient, mock_auth, sample_expression: Expression): - """Test POST /expression/review/batch succeeds when require_unchecked=False""" - review_payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": False}]} - - response = client.post("/api/webui/expression/review/batch", json=review_payload) - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["succeeded"] == 1 - assert data["results"][0]["success"] is True diff --git a/pytests/webui/test_jargon_routes.py b/pytests/webui/test_jargon_routes.py deleted file mode 100644 index 8251c98d..00000000 --- a/pytests/webui/test_jargon_routes.py +++ /dev/null @@ -1,512 +0,0 @@ -"""测试 jargon 路由的完整性和正确性""" - -import json -from datetime import datetime - -import pytest -from fastapi import FastAPI -from fastapi.testclient import TestClient -from sqlalchemy.pool import StaticPool -from sqlmodel import Session, SQLModel, create_engine - -from src.common.database.database_model import ChatSession, Jargon -from src.webui.routers.jargon import router as jargon_router - - -@pytest.fixture(name="app", scope="function") -def app_fixture(): - app = FastAPI() - app.include_router(jargon_router, prefix="/api/webui") - return app - - -@pytest.fixture(name="engine", scope="function") -def engine_fixture(): - engine = create_engine( - "sqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - SQLModel.metadata.create_all(engine) - yield engine - - -@pytest.fixture(name="session", scope="function") -def session_fixture(engine): - connection = engine.connect() - transaction = connection.begin() - session = Session(bind=connection) - - yield session - - session.close() - transaction.rollback() - connection.close() - - -@pytest.fixture(name="client", scope="function") -def client_fixture(app: FastAPI, session: Session, monkeypatch): - from contextlib import contextmanager - - @contextmanager - def mock_get_db_session(): - yield session - - monkeypatch.setattr("src.webui.routers.jargon.get_db_session", mock_get_db_session) - - with TestClient(app) as client: - yield client - - -@pytest.fixture(name="sample_chat_session") -def sample_chat_session_fixture(session: Session): - """创建示例 ChatSession""" - chat_session = ChatSession( - session_id="test_stream_001", - platform="qq", - group_id="123456789", - user_id=None, - created_timestamp=datetime.now(), - last_active_timestamp=datetime.now(), - ) - session.add(chat_session) - session.commit() - session.refresh(chat_session) - return chat_session - - -@pytest.fixture(name="sample_jargons") -def sample_jargons_fixture(session: Session, sample_chat_session: ChatSession): - """创建示例 Jargon 数据""" - jargons = [ - Jargon( - id=1, - content="yyds", - raw_content="永远的神", - meaning="永远的神", - session_id=sample_chat_session.session_id, - count=10, - is_jargon=True, - is_complete=False, - ), - Jargon( - id=2, - content="awsl", - raw_content="啊我死了", - meaning="啊我死了", - session_id=sample_chat_session.session_id, - count=5, - is_jargon=True, - is_complete=False, - ), - Jargon( - id=3, - content="hello", - raw_content=None, - meaning="你好", - session_id=sample_chat_session.session_id, - count=2, - is_jargon=False, - is_complete=False, - ), - ] - for jargon in jargons: - session.add(jargon) - session.commit() - for jargon in jargons: - session.refresh(jargon) - return jargons - - -# ==================== Test Cases ==================== - - -def test_list_jargons(client: TestClient, sample_jargons): - """测试 GET /jargon/list 基础列表功能""" - response = client.get("/api/webui/jargon/list") - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["total"] == 3 - assert data["page"] == 1 - assert data["page_size"] == 20 - assert len(data["data"]) == 3 - - assert data["data"][0]["content"] == "yyds" - assert data["data"][0]["count"] == 10 - - -def test_list_jargons_with_pagination(client: TestClient, sample_jargons): - """测试分页功能""" - response = client.get("/api/webui/jargon/list?page=1&page_size=2") - assert response.status_code == 200 - - data = response.json() - assert data["total"] == 3 - assert len(data["data"]) == 2 - - response = client.get("/api/webui/jargon/list?page=2&page_size=2") - assert response.status_code == 200 - data = response.json() - assert len(data["data"]) == 1 - - -def test_list_jargons_with_search(client: TestClient, sample_jargons): - """测试 GET /jargon/list?search=xxx 搜索功能""" - response = client.get("/api/webui/jargon/list?search=yyds") - assert response.status_code == 200 - - data = response.json() - assert data["total"] == 1 - assert data["data"][0]["content"] == "yyds" - - # 测试搜索 meaning - response = client.get("/api/webui/jargon/list?search=你好") - assert response.status_code == 200 - data = response.json() - assert data["total"] == 1 - assert data["data"][0]["content"] == "hello" - - -def test_list_jargons_with_chat_id_filter(client: TestClient, sample_jargons, sample_chat_session: ChatSession): - """测试按 chat_id 筛选""" - response = client.get(f"/api/webui/jargon/list?chat_id={sample_chat_session.session_id}") - assert response.status_code == 200 - - data = response.json() - assert data["total"] == 3 - - # 测试不存在的 chat_id - response = client.get("/api/webui/jargon/list?chat_id=nonexistent") - assert response.status_code == 200 - data = response.json() - assert data["total"] == 0 - - -def test_list_jargons_with_is_jargon_filter(client: TestClient, sample_jargons): - """测试按 is_jargon 筛选""" - response = client.get("/api/webui/jargon/list?is_jargon=true") - assert response.status_code == 200 - - data = response.json() - assert data["total"] == 2 - assert all(item["is_jargon"] is True for item in data["data"]) - - response = client.get("/api/webui/jargon/list?is_jargon=false") - assert response.status_code == 200 - data = response.json() - assert data["total"] == 1 - assert data["data"][0]["content"] == "hello" - - -def test_get_jargon_detail(client: TestClient, sample_jargons): - """测试 GET /jargon/{id} 获取详情""" - jargon_id = sample_jargons[0].id - response = client.get(f"/api/webui/jargon/{jargon_id}") - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["data"]["id"] == jargon_id - assert data["data"]["content"] == "yyds" - assert data["data"]["meaning"] == "永远的神" - assert data["data"]["count"] == 10 - assert data["data"]["is_jargon"] is True - - -def test_get_jargon_detail_not_found(client: TestClient): - """测试获取不存在的黑话详情""" - response = client.get("/api/webui/jargon/99999") - assert response.status_code == 404 - assert "黑话不存在" in response.json()["detail"] - - -@pytest.mark.skip(reason="Composite PK (id+content) prevents autoincrement - database model issue") -def test_create_jargon(client: TestClient, sample_chat_session: ChatSession): - """测试 POST /jargon/ 创建黑话""" - request_data = { - "content": "新黑话", - "raw_content": "原始内容", - "meaning": "含义", - "chat_id": sample_chat_session.session_id, - } - - response = client.post("/api/webui/jargon/", json=request_data) - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["message"] == "创建成功" - assert data["data"]["content"] == "新黑话" - assert data["data"]["meaning"] == "含义" - assert data["data"]["count"] == 0 - assert data["data"]["is_jargon"] is None - assert data["data"]["is_complete"] is False - - -def test_create_duplicate_jargon(client: TestClient, sample_jargons, sample_chat_session: ChatSession): - """测试创建重复黑话返回 400""" - request_data = { - "content": "yyds", - "meaning": "重复的", - "chat_id": sample_chat_session.session_id, - } - - response = client.post("/api/webui/jargon/", json=request_data) - assert response.status_code == 400 - assert "已存在相同内容的黑话" in response.json()["detail"] - - -def test_update_jargon(client: TestClient, sample_jargons): - """测试 PATCH /jargon/{id} 更新黑话""" - jargon_id = sample_jargons[0].id - update_data = { - "meaning": "更新后的含义", - "is_jargon": True, - } - - response = client.patch(f"/api/webui/jargon/{jargon_id}", json=update_data) - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["message"] == "更新成功" - assert data["data"]["meaning"] == "更新后的含义" - assert data["data"]["is_jargon"] is True - assert data["data"]["content"] == "yyds" # 未改变的字段保持不变 - - -def test_update_jargon_with_chat_id_mapping(client: TestClient, sample_jargons): - """测试更新时 chat_id → session_id 的映射""" - jargon_id = sample_jargons[0].id - update_data = { - "chat_id": "new_session_id", - } - - response = client.patch(f"/api/webui/jargon/{jargon_id}", json=update_data) - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["data"]["chat_id"] == "new_session_id" - - -def test_update_jargon_not_found(client: TestClient): - """测试更新不存在的黑话""" - response = client.patch("/api/webui/jargon/99999", json={"meaning": "test"}) - assert response.status_code == 404 - assert "黑话不存在" in response.json()["detail"] - - -def test_delete_jargon(client: TestClient, sample_jargons, session: Session): - """测试 DELETE /jargon/{id} 删除黑话""" - jargon_id = sample_jargons[0].id - response = client.delete(f"/api/webui/jargon/{jargon_id}") - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["message"] == "删除成功" - assert data["deleted_count"] == 1 - - # 验证数据库中已删除 - response = client.get(f"/api/webui/jargon/{jargon_id}") - assert response.status_code == 404 - - -def test_delete_jargon_not_found(client: TestClient): - """测试删除不存在的黑话""" - response = client.delete("/api/webui/jargon/99999") - assert response.status_code == 404 - assert "黑话不存在" in response.json()["detail"] - - -def test_batch_delete(client: TestClient, sample_jargons): - """测试 POST /jargon/batch/delete 批量删除""" - ids_to_delete = [sample_jargons[0].id, sample_jargons[1].id] - request_data = {"ids": ids_to_delete} - - response = client.post("/api/webui/jargon/batch/delete", json=request_data) - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert data["deleted_count"] == 2 - assert "成功删除 2 条黑话" in data["message"] - - # 验证已删除 - response = client.get(f"/api/webui/jargon/{ids_to_delete[0]}") - assert response.status_code == 404 - - -def test_batch_delete_empty_list(client: TestClient): - """测试批量删除空列表返回 400""" - response = client.post("/api/webui/jargon/batch/delete", json={"ids": []}) - assert response.status_code == 400 - assert "ID列表不能为空" in response.json()["detail"] - - -def test_batch_set_jargon_status(client: TestClient, sample_jargons): - """测试批量设置黑话状态""" - ids = [sample_jargons[0].id, sample_jargons[1].id] - response = client.post( - "/api/webui/jargon/batch/set-jargon", - params={"ids": ids, "is_jargon": False}, - ) - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert "成功更新 2 条黑话状态" in data["message"] - - # 验证状态已更新 - detail_response = client.get(f"/api/webui/jargon/{ids[0]}") - assert detail_response.json()["data"]["is_jargon"] is False - - -def test_get_stats(client: TestClient, sample_jargons): - """测试 GET /jargon/stats/summary 统计数据""" - response = client.get("/api/webui/jargon/stats/summary") - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - stats = data["data"] - - assert stats["total"] == 3 - assert stats["confirmed_jargon"] == 2 - assert stats["confirmed_not_jargon"] == 1 - assert stats["pending"] == 0 - assert stats["complete_count"] == 0 - assert stats["chat_count"] == 1 - assert isinstance(stats["top_chats"], dict) - - -def test_get_chat_list(client: TestClient, sample_jargons, sample_chat_session: ChatSession): - """测试 GET /jargon/chats 获取聊天列表""" - response = client.get("/api/webui/jargon/chats") - assert response.status_code == 200 - - data = response.json() - assert data["success"] is True - assert len(data["data"]) == 1 - - chat_info = data["data"][0] - assert chat_info["chat_id"] == sample_chat_session.session_id - assert chat_info["platform"] == "qq" - assert chat_info["is_group"] is True - assert chat_info["chat_name"] == sample_chat_session.group_id - - -def test_get_chat_list_with_json_chat_id(client: TestClient, session: Session, sample_chat_session: ChatSession): - """测试解析 JSON 格式的 chat_id""" - json_chat_id = json.dumps([[sample_chat_session.session_id, "user123"]]) - jargon = Jargon( - id=100, - content="测试黑话", - meaning="测试", - session_id=json_chat_id, - count=1, - ) - session.add(jargon) - session.commit() - - response = client.get("/api/webui/jargon/chats") - assert response.status_code == 200 - - data = response.json() - assert len(data["data"]) == 1 - assert data["data"][0]["chat_id"] == sample_chat_session.session_id - - -def test_get_chat_list_without_chat_session(client: TestClient, session: Session): - """测试聊天列表中没有对应 ChatSession 的情况""" - jargon = Jargon( - id=101, - content="孤立黑话", - meaning="无对应会话", - session_id="nonexistent_stream_id", - count=1, - ) - session.add(jargon) - session.commit() - - response = client.get("/api/webui/jargon/chats") - assert response.status_code == 200 - - data = response.json() - assert len(data["data"]) == 1 - assert data["data"][0]["chat_id"] == "nonexistent_stream_id" - assert data["data"][0]["chat_name"] == "nonexistent_stream_id"[:20] - assert data["data"][0]["platform"] is None - assert data["data"][0]["is_group"] is False - - -def test_jargon_response_fields(client: TestClient, sample_jargons, sample_chat_session: ChatSession): - """测试 JargonResponse 字段完整性""" - response = client.get(f"/api/webui/jargon/{sample_jargons[0].id}") - assert response.status_code == 200 - - data = response.json()["data"] - - # 验证所有必需字段存在 - required_fields = [ - "id", - "content", - "raw_content", - "meaning", - "chat_id", - "stream_id", - "chat_name", - "count", - "is_jargon", - "is_complete", - "inference_with_context", - "inference_content_only", - ] - for field in required_fields: - assert field in data - - # 验证 chat_name 显示逻辑 - assert data["chat_name"] == sample_chat_session.group_id - - -@pytest.mark.skip(reason="Composite PK (id+content) prevents autoincrement - database model issue") -def test_create_jargon_without_optional_fields(client: TestClient, sample_chat_session: ChatSession): - """测试创建黑话时可选字段为空""" - request_data = { - "content": "简单黑话", - "chat_id": sample_chat_session.session_id, - } - - response = client.post("/api/webui/jargon/", json=request_data) - assert response.status_code == 200 - - data = response.json()["data"] - assert data["raw_content"] is None - assert data["meaning"] == "" - - -def test_update_jargon_partial_fields(client: TestClient, sample_jargons): - """测试增量更新(只更新部分字段)""" - jargon_id = sample_jargons[0].id - original_content = sample_jargons[0].content - - # 只更新 meaning - response = client.patch(f"/api/webui/jargon/{jargon_id}", json={"meaning": "新含义"}) - assert response.status_code == 200 - - data = response.json()["data"] - assert data["meaning"] == "新含义" - assert data["content"] == original_content # 其他字段不变 - - -def test_list_jargons_multiple_filters(client: TestClient, sample_jargons, sample_chat_session: ChatSession): - """测试组合多个过滤条件""" - response = client.get(f"/api/webui/jargon/list?search=永远&chat_id={sample_chat_session.session_id}&is_jargon=true") - assert response.status_code == 200 - - data = response.json() - assert data["total"] == 1 - assert data["data"][0]["content"] == "yyds" diff --git a/pytests/webui/test_memory_routes.py b/pytests/webui/test_memory_routes.py deleted file mode 100644 index 8681c561..00000000 --- a/pytests/webui/test_memory_routes.py +++ /dev/null @@ -1,870 +0,0 @@ -from fastapi import FastAPI -from fastapi.testclient import TestClient -import pytest - -from src.services.memory_service import MemorySearchResult -from src.webui.dependencies import require_auth -from src.webui.routers import memory as memory_router_module -from src.webui.routers.memory import compat_router -from src.webui.routes import router as main_router - - -@pytest.fixture -def client() -> TestClient: - app = FastAPI() - app.dependency_overrides[require_auth] = lambda: "ok" - app.include_router(main_router) - app.include_router(compat_router) - return TestClient(app) - - -def test_webui_memory_graph_route(client: TestClient, monkeypatch): - async def fake_graph_admin(*, action: str, **kwargs): - assert action == "get_graph" - return { - "success": True, - "nodes": [], - "edges": [ - { - "source": "alice", - "target": "map", - "weight": 1.5, - "relation_hashes": ["rel-1"], - "predicates": ["持有"], - "relation_count": 1, - "evidence_count": 2, - "label": "持有", - } - ], - "total_nodes": 0, - "limit": kwargs.get("limit"), - } - - monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin) - - response = client.get("/api/webui/memory/graph", params={"limit": 77}) - - assert response.status_code == 200 - assert response.json()["success"] is True - assert response.json()["limit"] == 77 - assert response.json()["edges"][0]["predicates"] == ["持有"] - assert response.json()["edges"][0]["relation_count"] == 1 - assert response.json()["edges"][0]["evidence_count"] == 2 - - -def test_webui_memory_graph_search_route(client: TestClient, monkeypatch): - async def fake_graph_admin(*, action: str, **kwargs): - assert action == "search" - assert kwargs["query"] == "Alice" - assert kwargs["limit"] == 33 - return { - "success": True, - "query": kwargs["query"], - "limit": kwargs["limit"], - "count": 1, - "items": [ - { - "type": "entity", - "title": "Alice", - "matched_field": "name", - "matched_value": "Alice", - "entity_name": "Alice", - "entity_hash": "entity-1", - "appearance_count": 3, - } - ], - } - - monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin) - - response = client.get("/api/webui/memory/graph/search", params={"query": "Alice", "limit": 33}) - - assert response.status_code == 200 - assert response.json()["success"] is True - assert response.json()["query"] == "Alice" - assert response.json()["limit"] == 33 - assert response.json()["items"][0]["type"] == "entity" - - -@pytest.mark.parametrize( - "params", - [ - {"query": "", "limit": 50}, - {"query": "Alice", "limit": 0}, - {"query": "Alice", "limit": 201}, - ], -) -def test_webui_memory_graph_search_route_validation(client: TestClient, params): - response = client.get("/api/webui/memory/graph/search", params=params) - - assert response.status_code == 422 - - -def test_webui_memory_graph_node_detail_route(client: TestClient, monkeypatch): - async def fake_graph_admin(*, action: str, **kwargs): - assert action == "node_detail" - assert kwargs["node_id"] == "Alice" - return { - "success": True, - "node": {"id": "Alice", "type": "entity", "content": "Alice", "appearance_count": 3}, - "relations": [{"hash": "rel-1", "subject": "Alice", "predicate": "持有", "object": "Map", "text": "Alice 持有 Map", "confidence": 0.9, "paragraph_count": 1, "paragraph_hashes": ["p-1"], "source_paragraph": "p-1"}], - "paragraphs": [{"hash": "p-1", "content": "Alice 拿着地图。", "preview": "Alice 拿着地图。", "source": "demo", "entity_count": 2, "relation_count": 1, "entities": ["Alice", "Map"], "relations": ["Alice 持有 Map"]}], - "evidence_graph": { - "nodes": [{"id": "entity:Alice", "type": "entity", "content": "Alice"}], - "edges": [], - "focus_entities": ["Alice"], - }, - } - - monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin) - - response = client.get("/api/webui/memory/graph/node-detail", params={"node_id": "Alice"}) - - assert response.status_code == 200 - assert response.json()["node"]["id"] == "Alice" - assert response.json()["relations"][0]["predicate"] == "持有" - assert response.json()["evidence_graph"]["focus_entities"] == ["Alice"] - - -def test_webui_memory_graph_node_detail_route_returns_404(client: TestClient, monkeypatch): - async def fake_graph_admin(*, action: str, **kwargs): - assert action == "node_detail" - return {"success": False, "error": "未找到节点: Missing"} - - monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin) - - response = client.get("/api/webui/memory/graph/node-detail", params={"node_id": "Missing"}) - - assert response.status_code == 404 - assert response.json()["detail"] == "未找到节点: Missing" - - -def test_webui_memory_graph_edge_detail_route(client: TestClient, monkeypatch): - async def fake_graph_admin(*, action: str, **kwargs): - assert action == "edge_detail" - assert kwargs["source"] == "Alice" - assert kwargs["target"] == "Map" - return { - "success": True, - "edge": { - "source": "Alice", - "target": "Map", - "weight": 1.5, - "relation_hashes": ["rel-1"], - "predicates": ["持有"], - "relation_count": 1, - "evidence_count": 1, - "label": "持有", - }, - "relations": [{"hash": "rel-1", "subject": "Alice", "predicate": "持有", "object": "Map", "text": "Alice 持有 Map", "confidence": 0.9, "paragraph_count": 1, "paragraph_hashes": ["p-1"], "source_paragraph": "p-1"}], - "paragraphs": [{"hash": "p-1", "content": "Alice 拿着地图。", "preview": "Alice 拿着地图。", "source": "demo", "entity_count": 2, "relation_count": 1, "entities": ["Alice", "Map"], "relations": ["Alice 持有 Map"]}], - "evidence_graph": { - "nodes": [{"id": "relation:rel-1", "type": "relation", "content": "Alice 持有 Map"}], - "edges": [{"source": "paragraph:p-1", "target": "relation:rel-1", "kind": "supports", "label": "支撑", "weight": 1.0}], - "focus_entities": ["Alice", "Map"], - }, - } - - monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin) - - response = client.get("/api/webui/memory/graph/edge-detail", params={"source": "Alice", "target": "Map"}) - - assert response.status_code == 200 - assert response.json()["edge"]["predicates"] == ["持有"] - assert response.json()["paragraphs"][0]["source"] == "demo" - assert response.json()["evidence_graph"]["edges"][0]["kind"] == "supports" - - -def test_webui_memory_graph_edge_detail_route_returns_404(client: TestClient, monkeypatch): - async def fake_graph_admin(*, action: str, **kwargs): - assert action == "edge_detail" - return {"success": False, "error": "未找到边: Alice -> Missing"} - - monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin) - - response = client.get("/api/webui/memory/graph/edge-detail", params={"source": "Alice", "target": "Missing"}) - - assert response.status_code == 404 - assert response.json()["detail"] == "未找到边: Alice -> Missing" - - -def test_webui_memory_profile_query_resolves_platform_user_id(client: TestClient, monkeypatch): - def fake_resolve_person_id_for_memory(**kwargs): - assert kwargs == {"platform": "qq", "user_id": "12345", "strict_known": False} - return "resolved-person-id" - - async def fake_profile_admin(*, action: str, **kwargs): - assert action == "query" - assert kwargs["person_id"] == "resolved-person-id" - assert kwargs["person_keyword"] == "Alice" - assert kwargs["limit"] == 9 - assert kwargs["force_refresh"] is True - return {"success": True, "person_id": kwargs["person_id"], "profile_text": "profile"} - - monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory) - monkeypatch.setattr(memory_router_module.memory_service, "profile_admin", fake_profile_admin) - - response = client.get( - "/api/webui/memory/profiles/query", - params={ - "platform": "qq", - "user_id": "12345", - "person_keyword": "Alice", - "limit": 9, - "force_refresh": True, - }, - ) - - assert response.status_code == 200 - assert response.json()["success"] is True - assert response.json()["person_id"] == "resolved-person-id" - - -def test_webui_memory_profile_query_prefers_explicit_person_id(client: TestClient, monkeypatch): - def fake_resolve_person_id_for_memory(**kwargs): - raise AssertionError(f"不应解析平台账号: {kwargs}") - - async def fake_profile_admin(*, action: str, **kwargs): - assert action == "query" - assert kwargs["person_id"] == "explicit-person-id" - return {"success": True, "person_id": kwargs["person_id"]} - - monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory) - monkeypatch.setattr(memory_router_module.memory_service, "profile_admin", fake_profile_admin) - - response = client.get( - "/api/webui/memory/profiles/query", - params={"person_id": "explicit-person-id", "platform": "qq", "user_id": "12345"}, - ) - - assert response.status_code == 200 - assert response.json()["person_id"] == "explicit-person-id" - - -def test_webui_memory_profile_list_enriches_person_name(client: TestClient, monkeypatch): - async def fake_profile_admin(*, action: str, **kwargs): - assert action == "list" - assert kwargs["limit"] == 7 - return { - "success": True, - "items": [ - {"person_id": "person-1", "profile_text": "profile-1"}, - {"person_id": "person-2", "profile_text": "profile-2"}, - ], - } - - monkeypatch.setattr(memory_router_module.memory_service, "profile_admin", fake_profile_admin) - monkeypatch.setattr( - memory_router_module, - "_get_person_name_for_person_id", - lambda person_id: {"person-1": "Alice"}.get(person_id, ""), - ) - - response = client.get("/api/webui/memory/profiles", params={"limit": 7}) - - assert response.status_code == 200 - assert response.json()["items"][0]["person_name"] == "Alice" - assert response.json()["items"][1]["person_name"] == "" - - -def test_webui_memory_profile_search_resolves_platform_user_id(client: TestClient, monkeypatch): - def fake_resolve_person_id_for_memory(**kwargs): - assert kwargs == {"platform": "qq", "user_id": "12345", "strict_known": False} - return "resolved-person-id" - - async def fake_profile_list(limit: int): - assert limit == 200 - return { - "success": True, - "items": [ - {"person_id": "resolved-person-id", "person_name": "Alice", "profile_text": "喜欢咖啡"}, - {"person_id": "other-person-id", "person_name": "Bob", "profile_text": "喜欢茶"}, - ], - } - - monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory) - monkeypatch.setattr(memory_router_module, "_profile_list", fake_profile_list) - - response = client.get( - "/api/webui/memory/profiles/search", - params={"platform": "qq", "user_id": "12345", "limit": 50}, - ) - - assert response.status_code == 200 - assert response.json()["items"] == [ - {"person_id": "resolved-person-id", "person_name": "Alice", "profile_text": "喜欢咖啡"} - ] - - -def test_webui_memory_profile_search_filters_keyword(client: TestClient, monkeypatch): - async def fake_profile_list(limit: int): - assert limit == 200 - return { - "success": True, - "items": [ - {"person_id": "person-1", "person_name": "Alice", "profile_text": "喜欢咖啡"}, - {"person_id": "person-2", "person_name": "Bob", "profile_text": "喜欢茶"}, - ], - } - - monkeypatch.setattr(memory_router_module, "_profile_list", fake_profile_list) - - response = client.get("/api/webui/memory/profiles/search", params={"person_keyword": "咖啡", "limit": 50}) - - assert response.status_code == 200 - assert response.json()["items"] == [ - {"person_id": "person-1", "person_name": "Alice", "profile_text": "喜欢咖啡"} - ] - - -def test_webui_memory_episode_list_resolves_platform_user_id(client: TestClient, monkeypatch): - def fake_resolve_person_id_for_memory(**kwargs): - assert kwargs == {"platform": "qq", "user_id": "12345", "strict_known": False} - return "resolved-person-id" - - async def fake_episode_admin(*, action: str, **kwargs): - assert action == "list" - assert kwargs == { - "query": "咖啡", - "limit": 9, - "source": "chat_summary:demo", - "person_id": "resolved-person-id", - "time_start": 100.0, - "time_end": 200.0, - } - return { - "success": True, - "items": [{"episode_id": "ep-1", "person_id": "resolved-person-id", "summary": "喝咖啡"}], - "count": 1, - } - - monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory) - monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin) - monkeypatch.setattr(memory_router_module, "_get_person_name_for_person_id", lambda person_id: "测试人物") - - response = client.get( - "/api/webui/memory/episodes", - params={ - "query": "咖啡", - "limit": 9, - "source": "chat_summary:demo", - "platform": "qq", - "user_id": "12345", - "time_start": 100, - "time_end": 200, - }, - ) - - assert response.status_code == 200 - assert response.json()["items"][0]["person_name"] == "测试人物" - - -def test_webui_memory_episode_list_prefers_explicit_person_id(client: TestClient, monkeypatch): - def fake_resolve_person_id_for_memory(**kwargs): - raise AssertionError(f"不应解析平台账号: {kwargs}") - - async def fake_episode_admin(*, action: str, **kwargs): - assert action == "list" - assert kwargs["person_id"] == "explicit-person-id" - return {"success": True, "items": []} - - monkeypatch.setattr(memory_router_module, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory) - monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin) - - response = client.get( - "/api/webui/memory/episodes", - params={"person_id": "explicit-person-id", "platform": "qq", "user_id": "12345"}, - ) - - assert response.status_code == 200 - assert response.json()["items"] == [] - - -def test_compat_aggregate_route(client: TestClient, monkeypatch): - async def fake_search(query: str, **kwargs): - assert kwargs["mode"] == "aggregate" - assert kwargs["respect_filter"] is False - return MemorySearchResult(summary=f"summary:{query}", hits=[]) - - monkeypatch.setattr(memory_router_module.memory_service, "search", fake_search) - - response = client.get("/api/query/aggregate", params={"query": "mai"}) - - assert response.status_code == 200 - assert response.json() == { - "success": True, - "summary": "summary:mai", - "hits": [], - "filtered": False, - "error": "", - } - - -def test_auto_save_routes(client: TestClient, monkeypatch): - async def fake_runtime_admin(*, action: str, **kwargs): - if action == "get_config": - return {"success": True, "auto_save": True} - if action == "set_auto_save": - return {"success": True, "auto_save": kwargs["enabled"]} - raise AssertionError(action) - - monkeypatch.setattr(memory_router_module.memory_service, "runtime_admin", fake_runtime_admin) - - get_response = client.get("/api/config/auto_save") - post_response = client.post("/api/config/auto_save", json={"enabled": False}) - - assert get_response.status_code == 200 - assert get_response.json() == {"success": True, "auto_save": True} - assert post_response.status_code == 200 - assert post_response.json() == {"success": True, "auto_save": False} - - -def test_memory_config_routes(client: TestClient, monkeypatch): - monkeypatch.setattr( - memory_router_module.a_memorix_host_service, - "get_config_schema", - lambda: {"layout": {"type": "tabs"}, "sections": {"plugin": {"fields": {}}}}, - ) - monkeypatch.setattr( - memory_router_module.a_memorix_host_service, - "get_config_path", - lambda: memory_router_module.Path("/tmp/config/bot_config.toml"), - ) - monkeypatch.setattr( - memory_router_module.a_memorix_host_service, - "get_config", - lambda: {"plugin": {"enabled": True}}, - ) - monkeypatch.setattr( - memory_router_module.a_memorix_host_service, - "get_raw_config", - lambda: "[plugin]\nenabled = true\n", - ) - monkeypatch.setattr( - memory_router_module.a_memorix_host_service, - "get_raw_config_with_meta", - lambda: { - "config": "[plugin]\nenabled = true\n", - "exists": True, - "using_default": False, - }, - ) - - schema_response = client.get("/api/webui/memory/config/schema") - config_response = client.get("/api/webui/memory/config") - raw_response = client.get("/api/webui/memory/config/raw") - expected_path = memory_router_module.Path("/tmp/config/bot_config.toml").as_posix() - - assert schema_response.status_code == 200 - assert memory_router_module.Path(schema_response.json()["path"]).as_posix() == expected_path - assert schema_response.json()["schema"]["layout"]["type"] == "tabs" - - assert config_response.status_code == 200 - assert config_response.json()["success"] is True - assert config_response.json()["config"] == {"plugin": {"enabled": True}} - assert memory_router_module.Path(config_response.json()["path"]).as_posix() == expected_path - - assert raw_response.status_code == 200 - assert raw_response.json()["success"] is True - assert raw_response.json()["config"] == "[plugin]\nenabled = true\n" - assert memory_router_module.Path(raw_response.json()["path"]).as_posix() == expected_path - - -def test_memory_config_raw_returns_default_template_when_file_missing(client: TestClient, monkeypatch): - monkeypatch.setattr( - memory_router_module.a_memorix_host_service, - "get_config_path", - lambda: memory_router_module.Path("/tmp/config/bot_config.toml"), - ) - monkeypatch.setattr( - memory_router_module.a_memorix_host_service, - "get_raw_config_with_meta", - lambda: { - "config": "[plugin]\nenabled = true\n", - "exists": False, - "using_default": True, - }, - ) - - response = client.get("/api/webui/memory/config/raw") - - assert response.status_code == 200 - assert response.json()["success"] is True - assert response.json()["config"] == "[plugin]\nenabled = true\n" - assert response.json()["exists"] is False - assert response.json()["using_default"] is True - - -def test_memory_config_update_routes(client: TestClient, monkeypatch): - async def fake_update_config(config): - assert config == {"plugin": {"enabled": False}} - return {"success": True, "config_path": "config/bot_config.toml"} - - async def fake_update_raw(raw_config): - assert raw_config == "[plugin]\nenabled = false\n" - return {"success": True, "config_path": "config/bot_config.toml"} - - monkeypatch.setattr(memory_router_module.a_memorix_host_service, "update_config", fake_update_config) - monkeypatch.setattr(memory_router_module.a_memorix_host_service, "update_raw_config", fake_update_raw) - - config_response = client.put("/api/webui/memory/config", json={"config": {"plugin": {"enabled": False}}}) - raw_response = client.put("/api/webui/memory/config/raw", json={"config": "[plugin]\nenabled = false\n"}) - - assert config_response.status_code == 200 - assert config_response.json() == {"success": True, "config_path": "config/bot_config.toml"} - - assert raw_response.status_code == 200 - assert raw_response.json() == {"success": True, "config_path": "config/bot_config.toml"} - - -def test_memory_config_raw_rejects_invalid_toml(client: TestClient): - response = client.put("/api/webui/memory/config/raw", json={"config": "[plugin\nenabled = true"}) - - assert response.status_code == 400 - assert "TOML 格式错误" in response.json()["detail"] - - -def test_recycle_bin_route(client: TestClient, monkeypatch): - async def fake_get_recycle_bin(*, limit: int): - return {"success": True, "items": [{"hash": "deadbeef"}], "count": 1, "limit": limit} - - monkeypatch.setattr(memory_router_module.memory_service, "get_recycle_bin", fake_get_recycle_bin) - - response = client.get("/api/memory/recycle_bin", params={"limit": 10}) - - assert response.status_code == 200 - assert response.json()["success"] is True - assert response.json()["count"] == 1 - assert response.json()["limit"] == 10 - - -def test_import_guide_route(client: TestClient, monkeypatch): - async def fake_import_admin(*, action: str, **kwargs): - assert kwargs == {} - if action == "get_guide": - return {"success": True} - if action == "get_settings": - return {"success": True, "settings": {"path_aliases": {"raw": "/tmp/raw"}}} - raise AssertionError(action) - - monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin) - - response = client.get("/api/webui/memory/import/guide") - - assert response.status_code == 200 - assert response.json()["success"] is True - assert response.json()["source"] == "local" - assert "长期记忆导入说明" in response.json()["content"] - - -def test_import_upload_route(client: TestClient, monkeypatch, tmp_path): - monkeypatch.setattr(memory_router_module, "STAGING_ROOT", tmp_path) - - async def fake_import_admin(*, action: str, **kwargs): - assert action == "create_upload" - staged_files = kwargs["staged_files"] - assert len(staged_files) == 1 - assert staged_files[0]["filename"] == "demo.txt" - assert memory_router_module.Path(staged_files[0]["staged_path"]).exists() - return {"success": True, "task_id": "task-1"} - - monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin) - - response = client.post( - "/api/import/upload", - data={"payload_json": "{\"source\": \"upload\"}"}, - files=[("files", ("demo.txt", b"hello world", "text/plain"))], - ) - - assert response.status_code == 200 - assert response.json() == {"success": True, "task_id": "task-1"} - assert list(tmp_path.iterdir()) == [] - - -def test_v5_status_route(client: TestClient, monkeypatch): - async def fake_v5_admin(*, action: str, **kwargs): - assert action == "status" - assert kwargs["target"] == "mai" - return {"success": True, "active_count": 1, "inactive_count": 2, "deleted_count": 3} - - monkeypatch.setattr(memory_router_module.memory_service, "v5_admin", fake_v5_admin) - - response = client.get("/api/webui/memory/v5/status", params={"target": "mai"}) - - assert response.status_code == 200 - assert response.json()["success"] is True - assert response.json()["deleted_count"] == 3 - - -def test_delete_preview_route(client: TestClient, monkeypatch): - async def fake_delete_admin(*, action: str, **kwargs): - assert action == "preview" - assert kwargs["mode"] == "paragraph" - assert kwargs["selector"] == {"query": "demo"} - return {"success": True, "counts": {"paragraphs": 1}, "dry_run": True} - - monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin) - - response = client.post( - "/api/webui/memory/delete/preview", - json={"mode": "paragraph", "selector": {"query": "demo"}}, - ) - - assert response.status_code == 200 - assert response.json() == {"success": True, "counts": {"paragraphs": 1}, "dry_run": True} - - -def test_delete_preview_route_supports_mixed_mode(client: TestClient, monkeypatch): - async def fake_delete_admin(*, action: str, **kwargs): - assert action == "preview" - assert kwargs["mode"] == "mixed" - assert kwargs["selector"] == { - "entity_hashes": ["entity-1"], - "paragraph_hashes": ["p-1"], - "relation_hashes": ["rel-1"], - "sources": ["demo"], - } - return {"success": True, "mode": "mixed", "counts": {"entities": 1, "paragraphs": 1, "relations": 1, "sources": 1}} - - monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin) - - response = client.post( - "/api/webui/memory/delete/preview", - json={ - "mode": "mixed", - "selector": { - "entity_hashes": ["entity-1"], - "paragraph_hashes": ["p-1"], - "relation_hashes": ["rel-1"], - "sources": ["demo"], - }, - }, - ) - - assert response.status_code == 200 - assert response.json()["mode"] == "mixed" - assert response.json()["counts"]["entities"] == 1 - - -def test_delete_execute_route_supports_mixed_mode(client: TestClient, monkeypatch): - async def fake_delete_admin(*, action: str, **kwargs): - assert action == "execute" - assert kwargs["mode"] == "mixed" - assert kwargs["selector"] == { - "entity_hashes": ["entity-1"], - "paragraph_hashes": ["p-1"], - "relation_hashes": ["rel-1"], - "sources": ["demo"], - } - assert kwargs["reason"] == "knowledge_graph_delete_entity" - assert kwargs["requested_by"] == "knowledge_graph" - return { - "success": True, - "mode": "mixed", - "operation_id": "op-mixed-1", - "deleted_count": 4, - "deleted_entity_count": 1, - "deleted_relation_count": 1, - "deleted_paragraph_count": 1, - "deleted_source_count": 1, - "counts": {"entities": 1, "paragraphs": 1, "relations": 1, "sources": 1}, - } - - monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin) - - response = client.post( - "/api/webui/memory/delete/execute", - json={ - "mode": "mixed", - "selector": { - "entity_hashes": ["entity-1"], - "paragraph_hashes": ["p-1"], - "relation_hashes": ["rel-1"], - "sources": ["demo"], - }, - "reason": "knowledge_graph_delete_entity", - "requested_by": "knowledge_graph", - }, - ) - - assert response.status_code == 200 - assert response.json()["success"] is True - assert response.json()["mode"] == "mixed" - assert response.json()["operation_id"] == "op-mixed-1" - - -def test_episode_process_pending_route(client: TestClient, monkeypatch): - async def fake_episode_admin(*, action: str, **kwargs): - assert action == "process_pending" - assert kwargs == {"limit": 7, "max_retry": 4} - return {"success": True, "processed": 3} - - monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin) - - response = client.post("/api/webui/memory/episodes/process-pending", json={"limit": 7, "max_retry": 4}) - - assert response.status_code == 200 - assert response.json() == {"success": True, "processed": 3} - - -def test_import_list_route_includes_settings(client: TestClient, monkeypatch): - calls = [] - - async def fake_import_admin(*, action: str, **kwargs): - calls.append((action, kwargs)) - if action == "list": - return {"success": True, "items": [{"task_id": "task-1"}]} - if action == "get_settings": - return {"success": True, "settings": {"path_aliases": {"lpmm": "/tmp/lpmm"}}} - raise AssertionError(action) - - monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin) - - response = client.get("/api/webui/memory/import/tasks", params={"limit": 9}) - - assert response.status_code == 200 - assert response.json()["items"] == [{"task_id": "task-1"}] - assert response.json()["settings"] == {"path_aliases": {"lpmm": "/tmp/lpmm"}} - assert calls == [("list", {"limit": 9}), ("get_settings", {})] - - -def test_tuning_profile_route_backfills_settings(client: TestClient, monkeypatch): - calls = [] - - async def fake_tuning_admin(*, action: str, **kwargs): - calls.append((action, kwargs)) - if action == "get_profile": - return {"success": True, "profile": {"retrieval": {"top_k": 8}}} - if action == "get_settings": - return {"success": True, "settings": {"profiles": ["default"]}} - raise AssertionError(action) - - monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin) - - response = client.get("/api/webui/memory/retrieval_tuning/profile") - - assert response.status_code == 200 - assert response.json()["profile"] == {"retrieval": {"top_k": 8}} - assert response.json()["settings"] == {"profiles": ["default"]} - assert calls == [("get_profile", {}), ("get_settings", {})] - - -def test_tuning_report_route_flattens_report_payload(client: TestClient, monkeypatch): - async def fake_tuning_admin(*, action: str, **kwargs): - assert action == "get_report" - assert kwargs == {"task_id": "task-1", "format": "json"} - return { - "success": True, - "report": {"format": "json", "content": "{\"ok\": true}", "path": "/tmp/report.json"}, - } - - monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin) - - response = client.get("/api/webui/memory/retrieval_tuning/tasks/task-1/report", params={"format": "json"}) - - assert response.status_code == 200 - assert response.json() == { - "success": True, - "format": "json", - "content": "{\"ok\": true}", - "path": "/tmp/report.json", - "error": "", - } - - -def test_delete_execute_route(client: TestClient, monkeypatch): - async def fake_delete_admin(*, action: str, **kwargs): - assert action == "execute" - assert kwargs["mode"] == "source" - assert kwargs["selector"] == {"source": "chat_summary:stream-1"} - assert kwargs["reason"] == "cleanup" - assert kwargs["requested_by"] == "tester" - return {"success": True, "operation_id": "del-1"} - - monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin) - - response = client.post( - "/api/webui/memory/delete/execute", - json={ - "mode": "source", - "selector": {"source": "chat_summary:stream-1"}, - "reason": "cleanup", - "requested_by": "tester", - }, - ) - - assert response.status_code == 200 - assert response.json() == {"success": True, "operation_id": "del-1"} - - -def test_sources_route(client: TestClient, monkeypatch): - async def fake_source_admin(*, action: str, **kwargs): - assert action == "list" - assert kwargs == {} - return {"success": True, "items": [{"source": "demo", "paragraph_count": 2}], "count": 1} - - monkeypatch.setattr(memory_router_module.memory_service, "source_admin", fake_source_admin) - - response = client.get("/api/webui/memory/sources") - - assert response.status_code == 200 - assert response.json()["items"] == [{"source": "demo", "paragraph_count": 2}] - - -def test_delete_operation_routes(client: TestClient, monkeypatch): - async def fake_delete_admin(*, action: str, **kwargs): - if action == "list_operations": - assert kwargs == {"limit": 5, "mode": "paragraph"} - return {"success": True, "items": [{"operation_id": "del-1"}], "count": 1} - if action == "get_operation": - assert kwargs == {"operation_id": "del-1"} - return {"success": True, "operation": {"operation_id": "del-1", "mode": "paragraph"}} - raise AssertionError(action) - - monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin) - - list_response = client.get("/api/webui/memory/delete/operations", params={"limit": 5, "mode": "paragraph"}) - get_response = client.get("/api/webui/memory/delete/operations/del-1") - - assert list_response.status_code == 200 - assert list_response.json()["count"] == 1 - assert get_response.status_code == 200 - assert get_response.json()["operation"]["operation_id"] == "del-1" - - -def test_feedback_correction_routes(client: TestClient, monkeypatch): - async def fake_feedback_admin(*, action: str, **kwargs): - if action == "list": - assert kwargs == { - "limit": 7, - "statuses": ["applied"], - "rollback_statuses": ["none"], - "query": "green", - } - return {"success": True, "items": [{"task_id": 11, "query_text": "what color"}], "count": 1} - if action == "get": - assert kwargs == {"task_id": 11} - return {"success": True, "task": {"task_id": 11, "query_text": "what color", "action_logs": []}} - if action == "rollback": - assert kwargs == {"task_id": 11, "requested_by": "tester", "reason": "manual revert"} - return {"success": True, "result": {"restored_relation_hashes": ["rel-1"]}} - raise AssertionError(action) - - monkeypatch.setattr(memory_router_module.memory_service, "feedback_admin", fake_feedback_admin) - - list_response = client.get( - "/api/webui/memory/feedback-corrections", - params={"limit": 7, "status": "applied", "rollback_status": "none", "query": "green"}, - ) - get_response = client.get("/api/webui/memory/feedback-corrections/11") - rollback_response = client.post( - "/api/webui/memory/feedback-corrections/11/rollback", - json={"requested_by": "tester", "reason": "manual revert"}, - ) - - assert list_response.status_code == 200 - assert list_response.json()["items"][0]["task_id"] == 11 - assert get_response.status_code == 200 - assert get_response.json()["task"]["task_id"] == 11 - assert rollback_response.status_code == 200 - assert rollback_response.json()["result"]["restored_relation_hashes"] == ["rel-1"] diff --git a/pytests/webui/test_memory_routes_integration.py b/pytests/webui/test_memory_routes_integration.py deleted file mode 100644 index 5b139960..00000000 --- a/pytests/webui/test_memory_routes_integration.py +++ /dev/null @@ -1,533 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from time import monotonic, sleep -from typing import Any, Dict, Generator -from uuid import uuid4 - -import asyncio -import json - -from fastapi import FastAPI -from fastapi.testclient import TestClient -import pytest -import tomlkit - -from src.A_memorix import host_service as host_service_module -from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module -from src.A_memorix.core.utils import retrieval_tuning_manager as tuning_manager_module -from src.webui.dependencies import require_auth -from src.webui.routers import memory as memory_router_module - - -REQUEST_TIMEOUT_SECONDS = 30 -IMPORT_TIMEOUT_SECONDS = 120 -TUNING_TIMEOUT_SECONDS = 420 - -IMPORT_TERMINAL_STATUSES = {"completed", "completed_with_errors", "failed", "cancelled"} -TUNING_TERMINAL_STATUSES = {"completed", "failed", "cancelled"} - - -class _FakeEmbeddingManager: - def __init__(self, dimension: int = 64) -> None: - self.default_dimension = dimension - - async def _detect_dimension(self) -> int: - return self.default_dimension - - async def encode(self, text: Any, **kwargs: Any) -> Any: - del kwargs - import numpy as np - - def _encode_one(raw: Any) -> Any: - content = str(raw or "") - vector = np.zeros(self.default_dimension, dtype=np.float32) - for index, byte in enumerate(content.encode("utf-8")): - vector[index % self.default_dimension] += float((byte % 17) + 1) - norm = float(np.linalg.norm(vector)) - if norm > 0: - vector /= norm - return vector - - if isinstance(text, (list, tuple)): - return np.stack([_encode_one(item) for item in text]).astype(np.float32) - return _encode_one(text).astype(np.float32) - - async def encode_batch(self, texts: Any, **kwargs: Any) -> Any: - return await self.encode(texts, **kwargs) - - -def _build_test_config(data_dir: Path) -> Dict[str, Any]: - return { - "storage": { - "data_dir": str(data_dir), - }, - "advanced": { - "enable_auto_save": False, - }, - "embedding": { - "dimension": 64, - "batch_size": 4, - "max_concurrent": 1, - "retry": { - "max_attempts": 1, - "min_wait_seconds": 0.1, - "max_wait_seconds": 0.2, - "backoff_multiplier": 1.0, - }, - "fallback": { - "enabled": True, - "allow_metadata_only_write": True, - "probe_interval_seconds": 30, - }, - "paragraph_vector_backfill": { - "enabled": False, - "interval_seconds": 60, - "batch_size": 32, - "max_retry": 2, - }, - }, - "retrieval": { - "enable_parallel": False, - "enable_ppr": False, - "top_k_paragraphs": 20, - "top_k_relations": 10, - "top_k_final": 10, - "alpha": 0.5, - "search": { - "smart_fallback": { - "enabled": True, - }, - }, - "sparse": { - "enabled": True, - "mode": "auto", - "candidate_k": 80, - "relation_candidate_k": 60, - }, - "fusion": { - "method": "weighted_rrf", - "rrf_k": 60, - "vector_weight": 0.7, - "bm25_weight": 0.3, - }, - }, - "threshold": { - "percentile": 70.0, - "min_results": 1, - }, - "web": { - "tuning": { - "enabled": True, - "poll_interval_ms": 300, - "max_queue_size": 4, - "default_objective": "balanced", - "default_intensity": "quick", - "default_sample_size": 4, - "default_top_k_eval": 5, - "eval_query_timeout_seconds": 1.0, - "llm_retry": { - "max_attempts": 1, - "min_wait_seconds": 0.1, - "max_wait_seconds": 0.2, - "backoff_multiplier": 1.0, - }, - }, - }, - } - - -def _assert_response_ok(response: Any) -> Dict[str, Any]: - assert response.status_code == 200, response.text - payload = response.json() - assert payload.get("success", True) is True, payload - return payload - - -def _wait_for_import_task_terminal(client: TestClient, task_id: str, *, timeout_seconds: float = IMPORT_TIMEOUT_SECONDS) -> Dict[str, Any]: - deadline = monotonic() + timeout_seconds - last_payload: Dict[str, Any] = {} - while monotonic() < deadline: - response = client.get( - f"/api/webui/memory/import/tasks/{task_id}", - params={"include_chunks": True}, - ) - payload = _assert_response_ok(response) - last_payload = payload - task = payload.get("task") or {} - status = str(task.get("status", "") or "") - if status in IMPORT_TERMINAL_STATUSES: - return task - sleep(0.2) - raise AssertionError(f"导入任务超时: task_id={task_id}, last_payload={last_payload}") - - -def _wait_for_tuning_task_terminal(client: TestClient, task_id: str, *, timeout_seconds: float = TUNING_TIMEOUT_SECONDS) -> Dict[str, Any]: - deadline = monotonic() + timeout_seconds - last_payload: Dict[str, Any] = {} - while monotonic() < deadline: - response = client.get( - f"/api/webui/memory/retrieval_tuning/tasks/{task_id}", - params={"include_rounds": False}, - ) - payload = _assert_response_ok(response) - last_payload = payload - task = payload.get("task") or {} - status = str(task.get("status", "") or "") - if status in TUNING_TERMINAL_STATUSES: - return task - sleep(0.3) - raise AssertionError(f"调优任务超时: task_id={task_id}, last_payload={last_payload}") - - -def _wait_for_query_hit(client: TestClient, query: str, *, timeout_seconds: float = 30.0) -> Dict[str, Any]: - deadline = monotonic() + timeout_seconds - last_payload: Dict[str, Any] = {} - while monotonic() < deadline: - payload = _assert_response_ok( - client.get( - "/api/webui/memory/query/aggregate", - params={"query": query, "limit": 20}, - ) - ) - last_payload = payload - hits = payload.get("hits") or [] - if isinstance(hits, list) and len(hits) > 0: - return payload - sleep(0.2) - raise AssertionError(f"检索命中超时: query={query}, last_payload={last_payload}") - - -def _get_source_item(client: TestClient, source_name: str) -> Dict[str, Any] | None: - payload = _assert_response_ok(client.get("/api/webui/memory/sources")) - items = payload.get("items") or [] - for item in items: - if not isinstance(item, dict): - continue - if str(item.get("source", "") or "") == source_name: - return item - return None - - -def _source_paragraph_count(item: Dict[str, Any] | None) -> int: - payload = item or {} - if "paragraph_count" in payload: - return int(payload.get("paragraph_count", 0) or 0) - return int(payload.get("count", 0) or 0) - - -def _wait_for_source_paragraph_count( - client: TestClient, - source_name: str, - *, - min_count: int, - timeout_seconds: float = 30.0, -) -> Dict[str, Any]: - deadline = monotonic() + timeout_seconds - last_item: Dict[str, Any] = {} - while monotonic() < deadline: - item = _get_source_item(client, source_name) - count = _source_paragraph_count(item) - if count >= int(min_count): - return item or {} - if item: - last_item = dict(item) - sleep(0.2) - raise AssertionError( - f"等待来源段落计数超时: source={source_name}, min_count={min_count}, last_item={last_item}" - ) - - -def _create_multitype_upload_task(client: TestClient) -> str: - structured_json = { - "paragraphs": [ - { - "content": "Alice 携带地图前往火星港。", - "source": "integration-upload-json", - "entities": ["Alice", "地图", "火星港"], - "relations": [ - {"subject": "Alice", "predicate": "携带", "object": "地图"}, - {"subject": "Alice", "predicate": "前往", "object": "火星港"}, - ], - } - ] - } - extra_json = { - "paragraphs": [ - { - "content": "Carol 记录了一条补充说明。", - "source": "integration-upload-json-extra", - "entities": ["Carol"], - "relations": [], - } - ] - } - payload_json = json.dumps( - { - "input_mode": "text", - "llm_enabled": False, - "file_concurrency": 2, - "chunk_concurrency": 2, - "dedupe_policy": "none", - }, - ensure_ascii=False, - ) - files = [ - ("files", ("integration-notes.txt", "Alice 在测试环境记录了一条长期记忆。".encode("utf-8"), "text/plain")), - ("files", ("integration-diary.md", "# 日志\nBob 与 Alice 讨论了导图。".encode("utf-8"), "text/markdown")), - ("files", ("integration-structured.json", json.dumps(structured_json, ensure_ascii=False).encode("utf-8"), "application/json")), - ("files", ("integration-extra.json", json.dumps(extra_json, ensure_ascii=False).encode("utf-8"), "application/json")), - ] - - response = client.post( - "/api/webui/memory/import/upload", - data={"payload_json": payload_json}, - files=files, - ) - payload = _assert_response_ok(response) - task_id = str((payload.get("task") or {}).get("task_id") or "").strip() - assert task_id, payload - return task_id - - -def _create_seed_paste_task(client: TestClient, *, source: str, unique_token: str) -> str: - seed_payload = { - "paragraphs": [ - { - "content": f"Alice 在火星港携带地图并记录了口令 {unique_token}。", - "source": source, - "entities": ["Alice", "火星港", "地图"], - "relations": [ - {"subject": "Alice", "predicate": "前往", "object": "火星港"}, - {"subject": "Alice", "predicate": "携带", "object": "地图"}, - ], - }, - { - "content": f"Bob 在火星港遇见 Alice,并重复口令 {unique_token}。", - "source": source, - "entities": ["Bob", "Alice", "火星港"], - "relations": [ - {"subject": "Bob", "predicate": "遇见", "object": "Alice"}, - {"subject": "Bob", "predicate": "位于", "object": "火星港"}, - ], - }, - ] - } - response = client.post( - "/api/webui/memory/import/paste", - json={ - "name": "integration-seed.json", - "input_mode": "json", - "llm_enabled": False, - "content": json.dumps(seed_payload, ensure_ascii=False), - "dedupe_policy": "none", - }, - ) - payload = _assert_response_ok(response) - task_id = str((payload.get("task") or {}).get("task_id") or "").strip() - assert task_id, payload - return task_id - - -@pytest.fixture(scope="module") -def integration_state(tmp_path_factory: pytest.TempPathFactory) -> Generator[Dict[str, Any], None, None]: - tmp_root = tmp_path_factory.mktemp("memory_routes_integration") - data_dir = (tmp_root / "data").resolve() - staging_dir = (tmp_root / "upload_staging").resolve() - artifacts_dir = (tmp_root / "artifacts").resolve() - config_file = (tmp_root / "config" / "bot_config.toml").resolve() - runtime_config = _build_test_config(data_dir) - - patches = pytest.MonkeyPatch() - patches.setattr(host_service_module.a_memorix_host_service, "_read_config", lambda: dict(runtime_config)) - patches.setattr(host_service_module.a_memorix_host_service, "get_config_path", lambda: config_file) - patches.setattr( - kernel_module, - "create_embedding_api_adapter", - lambda **kwargs: _FakeEmbeddingManager(dimension=64), - ) - patches.setattr(memory_router_module, "STAGING_ROOT", staging_dir) - patches.setattr(tuning_manager_module, "artifacts_root", lambda: artifacts_dir) - - asyncio.run(host_service_module.a_memorix_host_service.stop()) - host_service_module.a_memorix_host_service._config_cache = None # type: ignore[attr-defined] - - app = FastAPI() - app.dependency_overrides[require_auth] = lambda: "ok" - app.include_router(memory_router_module.router, prefix="/api/webui") - app.include_router(memory_router_module.compat_router) - - unique_token = f"INTEG_TOKEN_{uuid4().hex[:12]}" - source_name = f"integration-source-{uuid4().hex[:8]}" - - with TestClient(app) as client: - upload_task_id = _create_multitype_upload_task(client) - upload_task = _wait_for_import_task_terminal(client, upload_task_id) - - seed_task_id = _create_seed_paste_task(client, source=source_name, unique_token=unique_token) - seed_task = _wait_for_import_task_terminal(client, seed_task_id) - assert str(seed_task.get("status", "") or "") in {"completed", "completed_with_errors"}, seed_task - - _wait_for_query_hit(client, unique_token, timeout_seconds=45.0) - - yield { - "client": client, - "upload_task": upload_task, - "seed_task": seed_task, - "source_name": source_name, - "unique_token": unique_token, - } - - asyncio.run(host_service_module.a_memorix_host_service.stop()) - host_service_module.a_memorix_host_service._config_cache = None # type: ignore[attr-defined] - patches.undo() - - -def test_import_module_end_to_end_supports_multitype_upload(integration_state: Dict[str, Any]) -> None: - upload_task = integration_state["upload_task"] - - assert str(upload_task.get("status", "") or "") in {"completed", "completed_with_errors"}, upload_task - files = upload_task.get("files") or [] - assert isinstance(files, list) - assert len(files) >= 4 - - file_names = {str(item.get("name", "") or "") for item in files if isinstance(item, dict)} - assert "integration-notes.txt" in file_names - assert "integration-diary.md" in file_names - assert "integration-structured.json" in file_names - assert "integration-extra.json" in file_names - - -def test_retrieval_module_end_to_end_queries_seeded_data(integration_state: Dict[str, Any]) -> None: - client = integration_state["client"] - unique_token = integration_state["unique_token"] - - aggregate_payload = _wait_for_query_hit(client, unique_token, timeout_seconds=45.0) - hits = aggregate_payload.get("hits") or [] - joined_content = "\n".join(str(item.get("content", "") or "") for item in hits if isinstance(item, dict)) - assert unique_token in joined_content - - graph_payload = _assert_response_ok( - client.get( - "/api/webui/memory/graph/search", - params={"query": "Alice", "limit": 20}, - ) - ) - graph_items = graph_payload.get("items") or [] - assert isinstance(graph_items, list) - assert any(str(item.get("type", "") or "") == "entity" for item in graph_items if isinstance(item, dict)), graph_items - - -def test_tuning_module_end_to_end_create_and_apply_best(integration_state: Dict[str, Any]) -> None: - client = integration_state["client"] - - create_payload = _assert_response_ok( - client.post( - "/api/webui/memory/retrieval_tuning/tasks", - json={ - "objective": "balanced", - "intensity": "quick", - "rounds": 2, - "sample_size": 4, - "top_k_eval": 5, - "llm_enabled": False, - "eval_query_timeout_seconds": 1.0, - "seed": 20260403, - }, - ) - ) - task_id = str((create_payload.get("task") or {}).get("task_id") or "").strip() - assert task_id, create_payload - - task = _wait_for_tuning_task_terminal(client, task_id) - assert str(task.get("status", "") or "") == "completed", task - - apply_payload = _assert_response_ok( - client.post( - f"/api/webui/memory/retrieval_tuning/tasks/{task_id}/apply-best", - ) - ) - assert "applied" in apply_payload - - -def test_delete_module_end_to_end_preview_execute_restore(integration_state: Dict[str, Any]) -> None: - client = integration_state["client"] - unique_token = integration_state["unique_token"] - source_name = integration_state["source_name"] - - before_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0) - assert _source_paragraph_count(before_source_item) >= 1 - - preview_payload = _assert_response_ok( - client.post( - "/api/webui/memory/delete/preview", - json={ - "mode": "source", - "selector": {"sources": [source_name]}, - "reason": "integration_delete_preview", - "requested_by": "pytest_integration", - }, - ) - ) - preview_counts = preview_payload.get("counts") or {} - assert int(preview_counts.get("paragraphs", 0) or 0) >= 1, preview_payload - - execute_payload = _assert_response_ok( - client.post( - "/api/webui/memory/delete/execute", - json={ - "mode": "source", - "selector": {"sources": [source_name]}, - "reason": "integration_delete_execute", - "requested_by": "pytest_integration", - }, - ) - ) - operation_id = str(execute_payload.get("operation_id", "") or "").strip() - assert operation_id, execute_payload - - after_delete_payload = _assert_response_ok( - client.get( - "/api/webui/memory/query/aggregate", - params={"query": unique_token, "limit": 20}, - ) - ) - after_delete_hits = after_delete_payload.get("hits") or [] - after_delete_text = "\n".join( - str(item.get("content", "") or "") - for item in after_delete_hits - if isinstance(item, dict) - ) - assert unique_token not in after_delete_text - assert int(execute_payload.get("deleted_paragraph_count", 0) or 0) >= 1, execute_payload - - _assert_response_ok( - client.post( - "/api/webui/memory/delete/restore", - json={ - "operation_id": operation_id, - "requested_by": "pytest_integration", - }, - ) - ) - - restored_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0) - assert _source_paragraph_count(restored_source_item) >= 1 - - operations_payload = _assert_response_ok( - client.get( - "/api/webui/memory/delete/operations", - params={"limit": 20, "mode": "source"}, - ) - ) - operation_items = operations_payload.get("items") or [] - operation_ids = { - str(item.get("operation_id", "") or "") - for item in operation_items - if isinstance(item, dict) - } - assert operation_id in operation_ids - - operation_detail_payload = _assert_response_ok(client.get(f"/api/webui/memory/delete/operations/{operation_id}")) - detail_operation = operation_detail_payload.get("operation") or {} - assert str(detail_operation.get("status", "") or "") == "restored" diff --git a/pytests/webui/test_model_routes.py b/pytests/webui/test_model_routes.py deleted file mode 100644 index 0e05ad87..00000000 --- a/pytests/webui/test_model_routes.py +++ /dev/null @@ -1,187 +0,0 @@ -"""模型路由测试 - -验证 Gemini 提供商连接测试会使用查询参数传递 API Key, -并且不会回退到 OpenAI 兼容接口使用的 Bearer 认证方式。 -""" - -import importlib -import sys -from types import ModuleType -from typing import Any - -import pytest - - -def load_model_routes(monkeypatch: pytest.MonkeyPatch): - """在导入路由前 stub 配置与认证依赖模块,避免测试时触发真实初始化。""" - config_module = ModuleType("src.config.config") - config_module.__dict__["CONFIG_DIR"] = "." - monkeypatch.setitem(sys.modules, "src.config.config", config_module) - - dependencies_module = ModuleType("src.webui.dependencies") - - async def require_auth(): - return "test-token" - - dependencies_module.__dict__["require_auth"] = require_auth - monkeypatch.setitem(sys.modules, "src.webui.dependencies", dependencies_module) - - sys.modules.pop("src.webui.routers.model", None) - return importlib.import_module("src.webui.routers.model") - - -class FakeResponse: - """简化版 HTTP 响应对象。""" - - def __init__(self, status_code: int): - self.status_code = status_code - - -def build_async_client_factory( - responses: list[FakeResponse], - calls: list[dict[str, Any]], -): - """构造一个可记录请求参数的 AsyncClient 替身。""" - - response_iter = iter(responses) - - class FakeAsyncClient: - def __init__(self, *args: Any, **kwargs: Any): - self.args = args - self.kwargs = kwargs - - async def __aenter__(self) -> "FakeAsyncClient": - return self - - async def __aexit__(self, exc_type, exc, tb) -> bool: - return False - - async def get( - self, - url: str, - headers: dict[str, Any] | None = None, - params: dict[str, Any] | None = None, - ) -> FakeResponse: - calls.append( - { - "url": url, - "headers": headers or {}, - "params": params or {}, - } - ) - return next(response_iter) - - return FakeAsyncClient - - -@pytest.mark.asyncio -async def test_test_provider_connection_uses_query_api_key_for_gemini( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Gemini 连接测试应通过查询参数传递 API Key。""" - model_routes = load_model_routes(monkeypatch) - calls: list[dict[str, Any]] = [] - fake_client_class = build_async_client_factory( - responses=[FakeResponse(200), FakeResponse(200)], - calls=calls, - ) - monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class) - - result = await model_routes.test_provider_connection( - base_url="https://generativelanguage.googleapis.com/v1beta", - api_key="valid-gemini-key", - client_type="gemini", - ) - - assert result["network_ok"] is True - assert result["api_key_valid"] is True - assert len(calls) == 2 - - network_call = calls[0] - validation_call = calls[1] - - assert network_call["url"] == "https://generativelanguage.googleapis.com/v1beta" - assert network_call["headers"] == {} - assert network_call["params"] == {} - - assert validation_call["url"] == "https://generativelanguage.googleapis.com/v1beta/models" - assert validation_call["params"] == {"key": "valid-gemini-key"} - assert validation_call["headers"] == {"Content-Type": "application/json"} - assert "Authorization" not in validation_call["headers"] - - -@pytest.mark.asyncio -async def test_test_provider_connection_uses_bearer_auth_for_openai_compatible( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """非 Gemini 提供商连接测试应继续使用 Bearer 认证。""" - model_routes = load_model_routes(monkeypatch) - calls: list[dict[str, Any]] = [] - fake_client_class = build_async_client_factory( - responses=[FakeResponse(200), FakeResponse(200)], - calls=calls, - ) - monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class) - - result = await model_routes.test_provider_connection( - base_url="https://example.com/v1", - api_key="valid-openai-key", - client_type="openai", - ) - - assert result["network_ok"] is True - assert result["api_key_valid"] is True - assert len(calls) == 2 - - validation_call = calls[1] - - assert validation_call["url"] == "https://example.com/v1/models" - assert validation_call["params"] == {} - assert validation_call["headers"]["Content-Type"] == "application/json" - assert validation_call["headers"]["Authorization"] == "Bearer valid-openai-key" - - -@pytest.mark.asyncio -async def test_test_provider_connection_by_name_forwards_provider_client_type( - monkeypatch: pytest.MonkeyPatch, - tmp_path, -) -> None: - """按提供商名称测试连接时,应透传配置中的 client_type。""" - model_routes = load_model_routes(monkeypatch) - config_path = tmp_path / "model_config.toml" - config_path.write_text( - """ -[[api_providers]] -name = "Gemini" -base_url = "https://generativelanguage.googleapis.com/v1beta" -api_key = "valid-gemini-key" -client_type = "gemini" -""".strip(), - encoding="utf-8", - ) - - monkeypatch.setattr(model_routes, "CONFIG_DIR", str(tmp_path)) - - captured_kwargs: dict[str, Any] = {} - - async def fake_test_provider_connection(**kwargs: Any) -> dict[str, Any]: - captured_kwargs.update(kwargs) - return { - "network_ok": True, - "api_key_valid": True, - "latency_ms": 12.34, - "error": None, - "http_status": 200, - } - - monkeypatch.setattr(model_routes, "test_provider_connection", fake_test_provider_connection) - - result = await model_routes.test_provider_connection_by_name(provider_name="Gemini") - - assert result["network_ok"] is True - assert result["api_key_valid"] is True - assert captured_kwargs == { - "base_url": "https://generativelanguage.googleapis.com/v1beta", - "api_key": "valid-gemini-key", - "client_type": "gemini", - } \ No newline at end of file diff --git a/pytests/webui/test_plugin_management_routes.py b/pytests/webui/test_plugin_management_routes.py deleted file mode 100644 index 4a3fb011..00000000 --- a/pytests/webui/test_plugin_management_routes.py +++ /dev/null @@ -1,136 +0,0 @@ -import json - -import pytest -from fastapi import FastAPI -from fastapi.testclient import TestClient - -from src.webui.routers.plugin import management as management_module -from src.webui.routers.plugin import support as support_module - - -@pytest.fixture -def client(tmp_path, monkeypatch) -> TestClient: - plugins_dir = tmp_path / "plugins" - plugins_dir.mkdir(parents=True, exist_ok=True) - - demo_dir = plugins_dir / "demo_plugin" - demo_dir.mkdir() - (demo_dir / "_manifest.json").write_text( - json.dumps( - { - "manifest_version": 2, - "id": "test.demo", - "name": "Demo Plugin", - "version": "1.0.0", - "description": "demo plugin", - } - ), - encoding="utf-8", - ) - - monkeypatch.setattr(management_module, "require_plugin_token", lambda _: "ok") - monkeypatch.setattr(support_module, "get_plugins_dir", lambda: plugins_dir) - - app = FastAPI() - app.include_router(management_module.router, prefix="/api/webui/plugins") - return TestClient(app) - - -def test_installed_plugins_only_scan_plugins_dir_and_exclude_a_memorix(client: TestClient): - response = client.get("/api/webui/plugins/installed") - - assert response.status_code == 200 - payload = response.json() - assert payload["success"] is True - - ids = [plugin["id"] for plugin in payload["plugins"]] - assert ids == ["test.demo"] - assert "a-dawn.a-memorix" not in ids - assert all("/src/plugins/built_in/" not in plugin["path"] for plugin in payload["plugins"]) - - -def test_resolve_installed_plugin_path_falls_back_to_manifest_id(client: TestClient): - plugin_path = support_module.resolve_installed_plugin_path("test.demo") - - assert plugin_path is not None - assert plugin_path.name == "demo_plugin" - - -def test_resolve_installed_plugin_path_accepts_manifest_id_case_mismatch(client: TestClient): - plugin_path = support_module.resolve_installed_plugin_path("Test.Demo") - - assert plugin_path is not None - assert plugin_path.name == "demo_plugin" - - -def test_install_plugin_preserves_manifest_declared_id(client: TestClient, monkeypatch): - class FakeGitMirrorService: - async def clone_repository(self, **kwargs): - target_path = kwargs["target_path"] - target_path.mkdir(parents=True, exist_ok=True) - (target_path / "_manifest.json").write_text( - json.dumps( - { - "manifest_version": 2, - "id": "author.declared", - "name": "Declared Plugin", - "version": "1.0.0", - "author": {"name": "author"}, - } - ), - encoding="utf-8", - ) - return {"success": True} - - monkeypatch.setattr(management_module, "get_git_mirror_service", lambda: FakeGitMirrorService()) - - response = client.post( - "/api/webui/plugins/install", - json={ - "plugin_id": "market.plugin", - "repository_url": "https://github.com/author/declared", - "branch": "main", - }, - ) - - assert response.status_code == 200 - plugin_path = support_module.resolve_installed_plugin_path("author.declared") - assert plugin_path is not None - manifest = json.loads((plugin_path / "_manifest.json").read_text(encoding="utf-8")) - assert manifest["id"] == "author.declared" - - -def test_install_plugin_backfills_missing_manifest_id(client: TestClient, monkeypatch): - class FakeGitMirrorService: - async def clone_repository(self, **kwargs): - target_path = kwargs["target_path"] - target_path.mkdir(parents=True, exist_ok=True) - (target_path / "_manifest.json").write_text( - json.dumps( - { - "manifest_version": 2, - "name": "Legacy Plugin", - "version": "1.0.0", - "author": {"name": "author"}, - } - ), - encoding="utf-8", - ) - return {"success": True} - - monkeypatch.setattr(management_module, "get_git_mirror_service", lambda: FakeGitMirrorService()) - - response = client.post( - "/api/webui/plugins/install", - json={ - "plugin_id": "market.legacy", - "repository_url": "https://github.com/author/legacy", - "branch": "main", - }, - ) - - assert response.status_code == 200 - plugin_path = support_module.resolve_installed_plugin_path("market.legacy") - assert plugin_path is not None - manifest = json.loads((plugin_path / "_manifest.json").read_text(encoding="utf-8")) - assert manifest["id"] == "market.legacy" diff --git a/pytests/webui/test_statistics_service.py b/pytests/webui/test_statistics_service.py deleted file mode 100644 index df5288c8..00000000 --- a/pytests/webui/test_statistics_service.py +++ /dev/null @@ -1,332 +0,0 @@ -from contextlib import contextmanager -from datetime import datetime, timedelta -from types import SimpleNamespace -from typing import Any, Iterator - -import pytest - -from src.services import statistics_service -from src.webui.schemas.statistics import DashboardData, StatisticsSummary, TimeSeriesData - - -class _Result: - def __init__(self, *, first_value: Any = None, all_values: list[Any] | None = None) -> None: - self._first_value = first_value - self._all_values = all_values or [] - - def first(self) -> Any: - return self._first_value - - def all(self) -> list[Any]: - return self._all_values - - -class _Session: - def __init__(self, results: list[_Result]) -> None: - self._results = results - - def exec(self, statement: Any) -> _Result: - del statement - return self._results.pop(0) - - -class _MemoryStore: - def __init__(self) -> None: - self.store: dict[str, Any] = {} - - def __getitem__(self, item: str) -> Any: - return self.store.get(item) - - def __setitem__(self, key: str, value: Any) -> None: - self.store[key] = value - - -def _patch_session_results(monkeypatch: pytest.MonkeyPatch, results: list[_Result]) -> list[bool]: - auto_commit_calls: list[bool] = [] - - @contextmanager - def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_Session]: - auto_commit_calls.append(auto_commit) - yield _Session([results.pop(0)]) - - monkeypatch.setattr(statistics_service, "get_db_session", _fake_get_db_session) - return auto_commit_calls - - -def _patch_session_result_group(monkeypatch: pytest.MonkeyPatch, results: list[_Result]) -> list[bool]: - auto_commit_calls: list[bool] = [] - - @contextmanager - def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_Session]: - auto_commit_calls.append(auto_commit) - yield _Session(results) - - monkeypatch.setattr(statistics_service, "get_db_session", _fake_get_db_session) - return auto_commit_calls - - -def _build_dashboard_data(total_requests: int = 1) -> DashboardData: - return DashboardData( - summary=StatisticsSummary(total_requests=total_requests), - model_stats=[], - hourly_data=[], - daily_data=[], - recent_activity=[], - ) - - -def _build_dashboard_data_with_time_series() -> DashboardData: - return DashboardData( - summary=StatisticsSummary(total_requests=1), - model_stats=[], - hourly_data=[ - TimeSeriesData(timestamp="2026-05-06T10:00:00", requests=0, cost=0.0, tokens=0), - TimeSeriesData(timestamp="2026-05-06T11:00:00", requests=2, cost=0.5, tokens=50), - TimeSeriesData(timestamp="2026-05-06T12:00:00", requests=0, cost=0.0, tokens=0), - ], - daily_data=[ - TimeSeriesData(timestamp="2026-05-05T00:00:00", requests=0, cost=0.0, tokens=0), - TimeSeriesData(timestamp="2026-05-06T00:00:00", requests=3, cost=0.7, tokens=70), - ], - recent_activity=[], - ) - - -def test_shared_fetch_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None: - now = datetime(2026, 5, 6, 12, 0, 0) - online_record = SimpleNamespace(start_timestamp=now - timedelta(minutes=5), end_timestamp=now) - usage_record = SimpleNamespace( - timestamp=now, - request_type="chat.reply", - model_api_provider_name="provider", - model_assign_name="chat-main", - model_name="gpt-a", - prompt_tokens=10, - completion_tokens=5, - cost=0.01, - time_cost=1.2, - ) - message_record = SimpleNamespace(timestamp=now, message_id="msg-1") - tool_record = SimpleNamespace(timestamp=now, tool_name="reply") - auto_commit_calls = _patch_session_results( - monkeypatch, - [ - _Result(all_values=[online_record]), - _Result(all_values=[usage_record]), - _Result(all_values=[message_record]), - _Result(all_values=[tool_record]), - ], - ) - - online_ranges = statistics_service.fetch_online_time_since(now - timedelta(hours=1)) - usage_records = statistics_service.fetch_model_usage_since(now - timedelta(hours=1)) - messages = statistics_service.fetch_messages_since(now - timedelta(hours=1)) - tool_records = statistics_service.fetch_tool_records_since(now - timedelta(hours=1)) - - assert online_ranges == [(online_record.start_timestamp, online_record.end_timestamp)] - assert usage_records == [ - { - "timestamp": now, - "request_type": "chat.reply", - "model_api_provider_name": "provider", - "model_assign_name": "chat-main", - "model_name": "gpt-a", - "prompt_tokens": 10, - "completion_tokens": 5, - "cost": 0.01, - "time_cost": 1.2, - } - ] - assert messages == [message_record] - assert tool_records == [tool_record] - assert auto_commit_calls == [False, False, False, False] - - -def test_get_earliest_statistics_time_uses_min_valid_timestamp(monkeypatch: pytest.MonkeyPatch) -> None: - fallback_time = datetime(2026, 5, 6, 12, 0, 0) - earliest_time = datetime(2026, 5, 1, 8, 30, 0) - auto_commit_calls = _patch_session_result_group( - monkeypatch, - [ - _Result(first_value=datetime(2026, 5, 3, 9, 0, 0)), - _Result(first_value=earliest_time), - _Result(first_value=None), - _Result(first_value=datetime(2026, 5, 2, 9, 0, 0)), - ], - ) - - result = statistics_service.get_earliest_statistics_time(fallback_time) - - assert result == earliest_time - assert auto_commit_calls == [False] - - -def test_get_earliest_statistics_time_falls_back_when_query_fails(monkeypatch: pytest.MonkeyPatch) -> None: - fallback_time = datetime(2026, 5, 6, 12, 0, 0) - - @contextmanager - def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_Session]: - del auto_commit - raise RuntimeError("database unavailable") - yield _Session([]) - - monkeypatch.setattr(statistics_service, "get_db_session", _fake_get_db_session) - - assert statistics_service.get_earliest_statistics_time(fallback_time) == fallback_time - - -def test_dashboard_statistics_cache_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None: - memory_store = _MemoryStore() - now = datetime.now() - dashboard_data = _build_dashboard_data(total_requests=7) - monkeypatch.setattr(statistics_service, "local_storage", memory_store) - - statistics_service.store_dashboard_statistics_cache({24: dashboard_data}, generated_at=now) - cached_data = statistics_service.get_cached_dashboard_statistics(24) - - assert cached_data is not None - assert cached_data.summary.total_requests == 7 - - -def test_dashboard_statistics_cache_stores_sparse_time_series(monkeypatch: pytest.MonkeyPatch) -> None: - memory_store = _MemoryStore() - generated_at = datetime(2026, 5, 6, 12, 0, 0) - dashboard_data = _build_dashboard_data_with_time_series() - monkeypatch.setattr(statistics_service, "local_storage", memory_store) - - statistics_service.store_dashboard_statistics_cache({2: dashboard_data}, generated_at=generated_at) - - raw_cache = memory_store[statistics_service.DASHBOARD_STATISTICS_CACHE_KEY] - raw_entry = raw_cache["entries"]["2"] - assert raw_entry["sparse"] is True - assert raw_entry["hourly_data"] == [ - {"timestamp": "2026-05-06T11:00:00", "requests": 2, "cost": 0.5, "tokens": 50} - ] - assert raw_entry["daily_data"] == [ - {"timestamp": "2026-05-06T00:00:00", "requests": 3, "cost": 0.7, "tokens": 70} - ] - - cached_data = statistics_service.get_cached_dashboard_statistics(2, max_age_seconds=10**9) - assert cached_data is not None - assert [item.timestamp for item in cached_data.hourly_data] == [ - "2026-05-06T10:00:00", - "2026-05-06T11:00:00", - "2026-05-06T12:00:00", - ] - assert cached_data.hourly_data[0].requests == 0 - assert cached_data.hourly_data[1].requests == 2 - assert cached_data.hourly_data[2].requests == 0 - - -@pytest.mark.asyncio -async def test_get_dashboard_statistics_prefers_cache(monkeypatch: pytest.MonkeyPatch) -> None: - memory_store = _MemoryStore() - dashboard_data = _build_dashboard_data(total_requests=9) - monkeypatch.setattr(statistics_service, "local_storage", memory_store) - statistics_service.store_dashboard_statistics_cache({24: dashboard_data}, generated_at=datetime.now()) - - async def _fail_compute_dashboard_statistics(hours: int = 24) -> DashboardData: - del hours - raise AssertionError("cache should be used") - - monkeypatch.setattr(statistics_service, "compute_dashboard_statistics", _fail_compute_dashboard_statistics) - - result = await statistics_service.get_dashboard_statistics(24) - - assert result.summary.total_requests == 9 - - -@pytest.mark.asyncio -async def test_get_dashboard_statistics_returns_empty_when_cache_missing(monkeypatch: pytest.MonkeyPatch) -> None: - memory_store = _MemoryStore() - monkeypatch.setattr(statistics_service, "local_storage", memory_store) - - async def _fail_compute_dashboard_statistics(hours: int = 24) -> DashboardData: - del hours - raise AssertionError("dashboard API should not compute fallback data") - - monkeypatch.setattr(statistics_service, "compute_dashboard_statistics", _fail_compute_dashboard_statistics) - - result = await statistics_service.get_dashboard_statistics(24) - - assert result.summary.total_requests == 0 - assert result.model_stats == [] - - -@pytest.mark.asyncio -async def test_get_summary_statistics_aggregates_database_and_message_counts(monkeypatch: pytest.MonkeyPatch) -> None: - start_time = datetime(2026, 5, 6, 10, 0, 0) - end_time = datetime(2026, 5, 6, 12, 0, 0) - online_records = [ - SimpleNamespace( - start_timestamp=start_time - timedelta(minutes=30), - end_timestamp=start_time + timedelta(minutes=30), - ), - SimpleNamespace( - start_timestamp=start_time + timedelta(hours=1), - end_timestamp=end_time + timedelta(minutes=30), - ), - ] - auto_commit_calls = _patch_session_results( - monkeypatch, - [ - _Result(first_value=(3, 1.5, 900, 2.5)), - _Result(all_values=online_records), - ], - ) - - def _fake_count_messages(**kwargs: Any) -> int: - return 5 if kwargs.get("has_reply_to") is None else 2 - - monkeypatch.setattr(statistics_service, "count_messages", _fake_count_messages) - - summary = await statistics_service.get_summary_statistics(start_time, end_time) - - assert summary.total_requests == 3 - assert summary.total_cost == 1.5 - assert summary.total_tokens == 900 - assert summary.avg_response_time == 2.5 - assert summary.online_time == 5400 - assert summary.total_messages == 5 - assert summary.total_replies == 2 - assert summary.cost_per_hour == pytest.approx(1.0) - assert summary.tokens_per_hour == pytest.approx(600.0) - assert auto_commit_calls == [False, False] - - -@pytest.mark.asyncio -async def test_get_model_statistics_groups_by_display_model_name(monkeypatch: pytest.MonkeyPatch) -> None: - now = datetime(2026, 5, 6, 12, 0, 0) - records = [ - SimpleNamespace( - model_assign_name="chat-main", - model_name="gpt-a", - cost=0.4, - total_tokens=100, - time_cost=2.0, - ), - SimpleNamespace( - model_assign_name="chat-main", - model_name="gpt-a", - cost=0.6, - total_tokens=200, - time_cost=4.0, - ), - SimpleNamespace( - model_assign_name=None, - model_name="gpt-b", - cost=0.2, - total_tokens=50, - time_cost=0.0, - ), - ] - _patch_session_results(monkeypatch, [_Result(all_values=records)]) - - stats = await statistics_service.get_model_statistics(now - timedelta(hours=24)) - - assert [item.model_name for item in stats] == ["chat-main", "gpt-b"] - assert stats[0].request_count == 2 - assert stats[0].total_cost == pytest.approx(1.0) - assert stats[0].total_tokens == 300 - assert stats[0].avg_response_time == pytest.approx(3.0) - assert stats[1].avg_response_time == 0.0 diff --git a/pytests/webui/test_system_routes.py b/pytests/webui/test_system_routes.py deleted file mode 100644 index a812aca4..00000000 --- a/pytests/webui/test_system_routes.py +++ /dev/null @@ -1,13 +0,0 @@ -from src.webui.routers import system - - -def test_is_newer_version_detects_patch_update() -> None: - assert system._is_newer_version("1.0.7", "1.0.6") is True - - -def test_is_newer_version_ignores_same_version_with_shorter_parts() -> None: - assert system._is_newer_version("1.0.0", "1.0") is False - - -def test_is_newer_version_handles_unknown_current_version() -> None: - assert system._is_newer_version("1.0.7", "unknown") is False diff --git a/scripts/test_memory_retrieval.py b/scripts/test_memory_retrieval.py deleted file mode 100644 index 9acf04c4..00000000 --- a/scripts/test_memory_retrieval.py +++ /dev/null @@ -1,459 +0,0 @@ -import argparse -import asyncio -import os -import sys -import time -import json -import importlib -from dataclasses import dataclass -from typing import Optional, Dict, Any -from datetime import datetime - -# 强制使用 utf-8,避免控制台编码报错 -try: - if hasattr(sys.stdout, "reconfigure"): - sys.stdout.reconfigure(encoding="utf-8") - if hasattr(sys.stderr, "reconfigure"): - sys.stderr.reconfigure(encoding="utf-8") -except Exception: - pass - -# 确保能导入 src.* -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - -from src.common.logger import initialize_logging, get_logger -from src.common.database.database import db -from src.common.database.database_model import LLMUsage -from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo - -try: - from maim_message import ChatStream, UserInfo, GroupInfo -except Exception: - @dataclass - class ChatStream: - stream_id: str - platform: str - user_info: UserInfo - group_info: GroupInfo - -logger = get_logger("test_memory_retrieval") - - -# 使用 importlib 动态导入,避免循环导入问题 -def _import_memory_retrieval(): - """使用 importlib 动态导入 memory_retrieval 模块,避免循环导入""" - try: - # 先导入 prompt_builder,检查 prompt 是否已经初始化 - from src.chat.utils.prompt_builder import global_prompt_manager - - # 检查 memory_retrieval 相关的 prompt 是否已经注册 - # 如果已经注册,说明模块可能已经通过其他路径初始化过了 - prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts - - module_name = "src.memory_system.memory_retrieval" - - # 如果 prompt 已经初始化,尝试直接使用已加载的模块 - if prompt_already_init and module_name in sys.modules: - existing_module = sys.modules[module_name] - if hasattr(existing_module, "init_memory_retrieval_prompt"): - return ( - existing_module.init_memory_retrieval_prompt, - existing_module._react_agent_solve_question, - existing_module._process_single_question, - ) - - # 如果模块已经在 sys.modules 中但部分初始化,先移除它 - if module_name in sys.modules: - existing_module = sys.modules[module_name] - if not hasattr(existing_module, "init_memory_retrieval_prompt"): - # 模块部分初始化,移除它 - logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入") - del sys.modules[module_name] - # 清理可能相关的部分初始化模块 - keys_to_remove = [] - for key in sys.modules.keys(): - if key.startswith("src.memory_system.") and key != "src.memory_system": - keys_to_remove.append(key) - for key in keys_to_remove: - try: - del sys.modules[key] - except KeyError: - pass - - # 在导入 memory_retrieval 之前,先确保所有可能触发循环导入的模块都已完全加载 - # 这些模块在导入时可能会触发 memory_retrieval 的导入,所以我们需要先加载它们 - try: - # 先导入可能触发循环导入的模块,让它们完成初始化 - import src.config.config - import src.chat.utils.prompt_builder - - # 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval) - # 如果它们已经导入,就确保它们完全初始化 - # 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval) - # 如果它们已经导入,就确保它们完全初始化 - try: - import src.chat.replyer.group_generator # noqa: F401 - except (ImportError, AttributeError): - pass # 如果导入失败,继续 - try: - import src.chat.replyer.private_generator # noqa: F401 - except (ImportError, AttributeError): - pass # 如果导入失败,继续 - except Exception as e: - logger.warning(f"预加载依赖模块时出现警告: {e}") - - # 现在尝试导入 memory_retrieval - # 如果此时仍然触发循环导入,说明有其他模块在模块级别导入了 memory_retrieval - memory_retrieval_module = importlib.import_module(module_name) - - return ( - memory_retrieval_module.init_memory_retrieval_prompt, - memory_retrieval_module._react_agent_solve_question, - memory_retrieval_module._process_single_question, - ) - except (ImportError, AttributeError) as e: - logger.error(f"导入 memory_retrieval 模块失败: {e}", exc_info=True) - raise - - -def create_test_chat_stream(chat_id: str = "test_memory_retrieval") -> ChatStream: - """创建一个测试用的 ChatStream 对象""" - user_info = UserInfo( - platform="test", - user_id="test_user", - user_nickname="测试用户", - ) - group_info = GroupInfo( - platform="test", - group_id="test_group", - group_name="测试群组", - ) - return ChatStream( - stream_id=chat_id, - platform="test", - user_info=user_info, - group_info=group_info, - ) - - -def get_token_usage_since(start_time: float) -> Dict[str, Any]: - """获取从指定时间开始的token使用情况 - - Args: - start_time: 开始时间戳 - - Returns: - 包含token使用统计的字典 - """ - try: - start_datetime = datetime.fromtimestamp(start_time) - - # 查询从开始时间到现在的所有memory相关的token使用记录 - records = ( - LLMUsage.select() - .where( - (LLMUsage.timestamp >= start_datetime) - & ( - (LLMUsage.request_type.like("%memory%")) - | (LLMUsage.request_type == "memory.question") - | (LLMUsage.request_type == "memory.react") - | (LLMUsage.request_type == "memory.react.final") - ) - ) - .order_by(LLMUsage.timestamp.asc()) - ) - - total_prompt_tokens = 0 - total_completion_tokens = 0 - total_tokens = 0 - total_cost = 0.0 - request_count = 0 - model_usage = {} # 按模型统计 - - for record in records: - total_prompt_tokens += record.prompt_tokens or 0 - total_completion_tokens += record.completion_tokens or 0 - total_tokens += record.total_tokens or 0 - total_cost += record.cost or 0.0 - request_count += 1 - - # 按模型统计 - model_name = record.model_name or "unknown" - if model_name not in model_usage: - model_usage[model_name] = { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0, - "cost": 0.0, - "request_count": 0, - } - model_usage[model_name]["prompt_tokens"] += record.prompt_tokens or 0 - model_usage[model_name]["completion_tokens"] += record.completion_tokens or 0 - model_usage[model_name]["total_tokens"] += record.total_tokens or 0 - model_usage[model_name]["cost"] += record.cost or 0.0 - model_usage[model_name]["request_count"] += 1 - - return { - "total_prompt_tokens": total_prompt_tokens, - "total_completion_tokens": total_completion_tokens, - "total_tokens": total_tokens, - "total_cost": total_cost, - "request_count": request_count, - "model_usage": model_usage, - } - except Exception as e: - logger.error(f"获取token使用情况失败: {e}") - return { - "total_prompt_tokens": 0, - "total_completion_tokens": 0, - "total_tokens": 0, - "total_cost": 0.0, - "request_count": 0, - "model_usage": {}, - } - - -def format_thinking_steps(thinking_steps: list) -> str: - """格式化思考步骤为可读字符串""" - if not thinking_steps: - return "无思考步骤" - - lines = [] - for step in thinking_steps: - iteration = step.get("iteration", "?") - thought = step.get("thought", "") - actions = step.get("actions", []) - observations = step.get("observations", []) - - lines.append(f"\n--- 迭代 {iteration} ---") - if thought: - lines.append(f"思考: {thought[:200]}...") - - if actions: - lines.append("行动:") - for action in actions: - action_type = action.get("action_type", "unknown") - action_params = action.get("action_params", {}) - lines.append(f" - {action_type}: {json.dumps(action_params, ensure_ascii=False)}") - - if observations: - lines.append("观察:") - for obs in observations: - obs_str = str(obs)[:200] - if len(str(obs)) > 200: - obs_str += "..." - lines.append(f" - {obs_str}") - - return "\n".join(lines) - - -async def test_memory_retrieval( - question: str, - chat_id: str = "test_memory_retrieval", - context: str = "", - max_iterations: Optional[int] = None, -) -> Dict[str, Any]: - """测试记忆检索功能 - - Args: - question: 要查询的问题 - chat_id: 聊天ID - context: 上下文信息 - max_iterations: 最大迭代次数 - - Returns: - 包含测试结果的字典 - """ - print("\n" + "=" * 80) - print("[测试] 记忆检索测试") - print(f"[问题] {question}") - print("=" * 80) - - # 记录开始时间 - start_time = time.time() - - # 延迟导入并初始化记忆检索prompt(这会自动加载 global_config) - # 注意:必须在函数内部调用,避免在模块级别触发循环导入 - try: - init_memory_retrieval_prompt, _react_agent_solve_question, _ = _import_memory_retrieval() - - # 检查 prompt 是否已经初始化,避免重复初始化 - from src.chat.utils.prompt_builder import global_prompt_manager - - if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts: - init_memory_retrieval_prompt() - else: - logger.debug("记忆检索 prompt 已经初始化,跳过重复初始化") - except Exception as e: - logger.error(f"初始化记忆检索模块失败: {e}", exc_info=True) - raise - - # 获取 global_config(此时应该已经加载) - from src.config.config import global_config - - # 直接调用 _react_agent_solve_question 来获取详细的迭代信息 - if max_iterations is None: - max_iterations = global_config.memory.max_agent_iterations - - timeout = global_config.memory.agent_timeout_seconds - - print("\n[配置]") - print(f" 最大迭代次数: {max_iterations}") - print(f" 超时时间: {timeout}秒") - print(f" 聊天ID: {chat_id}") - - # 执行检索 - print(f"\n[开始检索] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}") - - found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question( - question=question, - chat_id=chat_id, - max_iterations=max_iterations, - timeout=timeout, - initial_info="", - ) - - # 记录结束时间 - end_time = time.time() - elapsed_time = end_time - start_time - - # 获取token使用情况 - token_usage = get_token_usage_since(start_time) - - # 构建结果 - result = { - "question": question, - "found_answer": found_answer, - "answer": answer, - "is_timeout": is_timeout, - "elapsed_time": elapsed_time, - "thinking_steps": thinking_steps, - "iteration_count": len(thinking_steps), - "token_usage": token_usage, - } - - # 输出结果 - print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}") - print("\n[结果]") - print(f" 是否找到答案: {'是' if found_answer else '否'}") - if found_answer and answer: - print(f" 答案: {answer}") - else: - print(" 答案: (未找到答案)") - print(f" 是否超时: {'是' if is_timeout else '否'}") - print(f" 迭代次数: {len(thinking_steps)}") - print(f" 总耗时: {elapsed_time:.2f}秒") - - print("\n[Token使用情况]") - print(f" 总请求数: {token_usage['request_count']}") - print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}") - print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}") - print(f" 总Tokens: {token_usage['total_tokens']:,}") - print(f" 总成本: ${token_usage['total_cost']:.6f}") - - if token_usage["model_usage"]: - print("\n[按模型统计]") - for model_name, usage in token_usage["model_usage"].items(): - print(f" {model_name}:") - print(f" 请求数: {usage['request_count']}") - print(f" Prompt Tokens: {usage['prompt_tokens']:,}") - print(f" Completion Tokens: {usage['completion_tokens']:,}") - print(f" 总Tokens: {usage['total_tokens']:,}") - print(f" 成本: ${usage['cost']:.6f}") - - print("\n[迭代详情]") - print(format_thinking_steps(thinking_steps)) - - print("\n" + "=" * 80) - - return result - - -def main() -> None: - parser = argparse.ArgumentParser( - description="测试记忆检索功能。可以输入一个问题,脚本会使用记忆检索的逻辑进行检索,并记录迭代信息、时间和token总消耗。" - ) - parser.add_argument( - "--chat-id", - default="test_memory_retrieval", - help="测试用的聊天ID(默认: test_memory_retrieval)", - ) - parser.add_argument( - "--context", - default="", - help="上下文信息(可选)", - ) - parser.add_argument( - "--output", - "-o", - help="将结果保存到JSON文件(可选)", - ) - - args = parser.parse_args() - - # 初始化日志(使用较低的详细程度,避免输出过多日志) - initialize_logging(verbose=False) - - # 交互式输入问题 - print("\n" + "=" * 80) - print("记忆检索测试工具") - print("=" * 80) - question = input("\n请输入要查询的问题: ").strip() - if not question: - print("错误: 问题不能为空") - return - - # 交互式输入最大迭代次数 - max_iterations_input = input("\n请输入最大迭代次数(直接回车使用配置默认值): ").strip() - max_iterations = None - if max_iterations_input: - try: - max_iterations = int(max_iterations_input) - if max_iterations <= 0: - print("警告: 迭代次数必须大于0,将使用配置默认值") - max_iterations = None - except ValueError: - print("警告: 无效的迭代次数,将使用配置默认值") - max_iterations = None - - # 连接数据库 - try: - db.connect(reuse_if_open=True) - except Exception as e: - logger.error(f"数据库连接失败: {e}") - print(f"错误: 数据库连接失败: {e}") - return - - # 运行测试 - try: - result = asyncio.run( - test_memory_retrieval( - question=question, - chat_id=args.chat_id, - context=args.context, - max_iterations=max_iterations, - ) - ) - - # 如果指定了输出文件,保存结果 - if args.output: - # 将thinking_steps转换为可序列化的格式 - output_result = result.copy() - with open(args.output, "w", encoding="utf-8") as f: - json.dump(output_result, f, ensure_ascii=False, indent=2) - print(f"\n[结果已保存] {args.output}") - - except KeyboardInterrupt: - print("\n\n[中断] 用户中断测试") - except Exception as e: - logger.error(f"测试失败: {e}", exc_info=True) - print(f"\n[错误] 测试失败: {e}") - finally: - try: - db.close() - except Exception: - pass - - -if __name__ == "__main__": - main() diff --git a/scripts/test_model_tool_call_params.py b/scripts/test_model_tool_call_params.py deleted file mode 100644 index f25fd3b3..00000000 --- a/scripts/test_model_tool_call_params.py +++ /dev/null @@ -1,845 +0,0 @@ -from argparse import ArgumentParser, Namespace -from contextlib import contextmanager -from dataclasses import asdict, dataclass -from pathlib import Path -from typing import Any, Dict, Iterator, List, Sequence - -import asyncio -import json -import sys -import time - - -PROJECT_ROOT = Path(__file__).resolve().parent.parent -if str(PROJECT_ROOT) not in sys.path: - sys.path.insert(0, str(PROJECT_ROOT)) - -from src.common.data_models.llm_service_data_models import LLMServiceRequest, LLMServiceResult # noqa: E402 -from src.config.config import config_manager # noqa: E402 -from src.config.model_configs import APIProvider, ModelInfo, TaskConfig # noqa: E402 -from src.llm_models.payload_content.tool_option import ToolCall # noqa: E402 -from src.services.llm_service import generate # noqa: E402 -from src.services.service_task_resolver import get_available_models # noqa: E402 - - -DEFAULT_SKIP_TASKS = {"embedding", "voice"} - - -@dataclass(slots=True) -class ToolCallCase: - """Tool call 参数测试用例。""" - - name: str - description: str - tool_definition: Dict[str, Any] - expected_arguments: Dict[str, Any] - - @property - def tool_name(self) -> str: - """返回工具名称。""" - if self.tool_definition.get("type") == "function": - function_definition = self.tool_definition.get("function", {}) - return str(function_definition.get("name", "") or "") - return str(self.tool_definition.get("name", "") or "") - - @property - def parameters_schema(self) -> Dict[str, Any]: - """返回参数 Schema。""" - if self.tool_definition.get("type") == "function": - function_definition = self.tool_definition.get("function", {}) - parameters = function_definition.get("parameters", {}) - return parameters if isinstance(parameters, dict) else {} - parameters = self.tool_definition.get("parameters", {}) - return parameters if isinstance(parameters, dict) else {} - - def build_messages(self) -> List[Dict[str, Any]]: - """构造测试消息。""" - expected_json = json.dumps(self.expected_arguments, ensure_ascii=False, indent=2) - system_prompt = ( - "你正在执行严格的工具调用参数兼容性测试。" - "你必须通过工具调用响应,不能输出自然语言,不能解释,不能补充额外字段。" - ) - user_prompt = ( - f"请立刻调用工具 `{self.tool_name}`。\n" - "参数必须与下面 JSON 完全一致,键名、值、布尔类型、整数类型、浮点数、数组顺序和对象结构都不能改变。\n" - "不要输出任何解释文本,只返回工具调用。\n" - f"{expected_json}" - ) - return [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ] - - -@dataclass(slots=True) -class ProbeTarget: - """单个待测试模型目标。""" - - task_name: str - model_name: str - provider_name: str - client_type: str - tool_argument_parse_mode: str - - -@dataclass(slots=True) -class ProbeResult: - """单次测试结果。""" - - task_name: str - target_model_name: str - actual_model_name: str - provider_name: str - client_type: str - tool_argument_parse_mode: str - case_name: str - attempt: int - success: bool - elapsed_seconds: float - errors: List[str] - warnings: List[str] - response_text: str - reasoning_text: str - tool_calls: List[Dict[str, Any]] - - -def _ensure_utf8_console() -> None: - """尽量将控制台编码切到 UTF-8。""" - try: - if hasattr(sys.stdout, "reconfigure"): - sys.stdout.reconfigure(encoding="utf-8") - if hasattr(sys.stderr, "reconfigure"): - sys.stderr.reconfigure(encoding="utf-8") - except Exception: - pass - - -def _build_function_tool(name: str, description: str, parameters: Dict[str, Any]) -> Dict[str, Any]: - """构造 OpenAI 风格 function tool 定义。""" - return { - "type": "function", - "function": { - "name": name, - "description": description, - "parameters": parameters, - }, - } - - -def _build_default_cases() -> List[ToolCallCase]: - """构造默认测试用例。""" - simple_expected_arguments = { - "request_id": "probe-simple-001", - "count": 7, - "enabled": True, - "mode": "strict", - "ratio": 2.5, - } - simple_parameters = { - "type": "object", - "properties": { - "request_id": {"type": "string", "description": "请求 ID"}, - "count": {"type": "integer", "description": "数量"}, - "enabled": {"type": "boolean", "description": "是否启用"}, - "mode": { - "type": "string", - "description": "模式", - "enum": ["strict", "loose"], - }, - "ratio": {"type": "number", "description": "比例"}, - }, - "required": ["request_id", "count", "enabled", "mode", "ratio"], - "additionalProperties": False, - } - - nested_expected_arguments = { - "request_id": "probe-nested-001", - "notify": False, - "profile": { - "channel": "stable", - "priority": 2, - }, - "tags": ["alpha", "beta", "gamma"], - "items": [ - {"count": 2, "name": "apple"}, - {"count": 5, "name": "banana"}, - ], - } - nested_parameters = { - "type": "object", - "properties": { - "request_id": {"type": "string", "description": "请求 ID"}, - "notify": {"type": "boolean", "description": "是否通知"}, - "profile": { - "type": "object", - "description": "配置对象", - "properties": { - "channel": {"type": "string", "description": "渠道"}, - "priority": {"type": "integer", "description": "优先级"}, - }, - "required": ["channel", "priority"], - "additionalProperties": False, - }, - "tags": { - "type": "array", - "description": "标签列表", - "items": {"type": "string"}, - }, - "items": { - "type": "array", - "description": "条目列表", - "items": { - "type": "object", - "properties": { - "count": {"type": "integer", "description": "数量"}, - "name": {"type": "string", "description": "名称"}, - }, - "required": ["count", "name"], - "additionalProperties": False, - }, - }, - }, - "required": ["request_id", "notify", "profile", "tags", "items"], - "additionalProperties": False, - } - - return [ - ToolCallCase( - name="simple", - description="标量参数类型校验", - tool_definition=_build_function_tool( - name="record_simple_probe", - description="记录简单参数探测结果", - parameters=simple_parameters, - ), - expected_arguments=simple_expected_arguments, - ), - ToolCallCase( - name="nested", - description="嵌套对象与数组参数校验", - tool_definition=_build_function_tool( - name="record_nested_probe", - description="记录嵌套参数探测结果", - parameters=nested_parameters, - ), - expected_arguments=nested_expected_arguments, - ), - ] - - -def _parse_multi_value_args(raw_values: Sequence[str] | None) -> List[str]: - """解析命令行中的多值参数。""" - parsed_values: List[str] = [] - for raw_value in raw_values or []: - for item in str(raw_value).split(","): - normalized_item = item.strip() - if normalized_item: - parsed_values.append(normalized_item) - return parsed_values - - -def _build_model_map() -> Dict[str, ModelInfo]: - """构造模型名称到模型配置的映射。""" - return {model.name: model for model in config_manager.get_model_config().models} - - -def _build_provider_map() -> Dict[str, APIProvider]: - """构造 Provider 名称到配置的映射。""" - return {provider.name: provider for provider in config_manager.get_model_config().api_providers} - - -def _pick_default_task_name(task_names: Sequence[str]) -> str: - """选择默认任务名。""" - if "utils" in task_names: - return "utils" - if not task_names: - raise ValueError("当前没有可用的任务配置") - return str(task_names[0]) - - -def _resolve_targets(task_filters: Sequence[str], model_filters: Sequence[str], fallback_task: str) -> List[ProbeTarget]: - """根据命令行参数解析待测试目标。""" - available_tasks = get_available_models() - model_map = _build_model_map() - provider_map = _build_provider_map() - - if not available_tasks: - raise ValueError("未找到任何可用的模型任务配置") - - if task_filters: - selected_task_names = [] - for task_name in task_filters: - if task_name not in available_tasks: - raise ValueError(f"未找到任务 `{task_name}`") - selected_task_names.append(task_name) - else: - selected_task_names = [ - task_name - for task_name in available_tasks - if task_name not in DEFAULT_SKIP_TASKS - ] - - if not selected_task_names: - raise ValueError("没有可用于 tool call 测试的任务,请显式通过 --task 指定") - - default_task_name = fallback_task if fallback_task in available_tasks else _pick_default_task_name(selected_task_names) - resolved_targets: List[ProbeTarget] = [] - seen_models: set[str] = set() - - if model_filters: - model_names = list(model_filters) - else: - model_names = [] - for task_name in selected_task_names: - task_config = available_tasks[task_name] - for model_name in task_config.model_list: - if model_name not in model_names: - model_names.append(model_name) - - for model_name in model_names: - if model_name in seen_models: - continue - if model_name not in model_map: - raise ValueError(f"未找到模型 `{model_name}`") - - target_task_name = "" - for task_name in selected_task_names: - if model_name in available_tasks[task_name].model_list: - target_task_name = task_name - break - if not target_task_name: - target_task_name = default_task_name - - model_info = model_map[model_name] - provider_info = provider_map[model_info.api_provider] - resolved_targets.append( - ProbeTarget( - task_name=target_task_name, - model_name=model_name, - provider_name=provider_info.name, - client_type=provider_info.client_type, - tool_argument_parse_mode=provider_info.tool_argument_parse_mode, - ) - ) - seen_models.add(model_name) - - return resolved_targets - - -@contextmanager -def _pin_task_to_model(task_name: str, model_name: str) -> Iterator[None]: - """临时将某个任务锁定到单模型。""" - model_task_config = config_manager.get_model_config().model_task_config - task_config = getattr(model_task_config, task_name, None) - if not isinstance(task_config, TaskConfig): - raise ValueError(f"未找到任务 `{task_name}` 对应的配置") - - original_model_list = list(task_config.model_list) - original_selection_strategy = task_config.selection_strategy - task_config.model_list = [model_name] - task_config.selection_strategy = "balance" - try: - yield - finally: - task_config.model_list = original_model_list - task_config.selection_strategy = original_selection_strategy - - -def _serialize_tool_calls(tool_calls: List[ToolCall] | None) -> List[Dict[str, Any]]: - """序列化工具调用结果。""" - if not tool_calls: - return [] - return [ - { - "id": tool_call.call_id, - "function": { - "name": tool_call.func_name, - "arguments": dict(tool_call.args or {}), - }, - } - for tool_call in tool_calls - ] - - -def _is_integer_value(value: Any) -> bool: - """判断是否为整数类型且排除布尔值。""" - return isinstance(value, int) and not isinstance(value, bool) - - -def _is_number_value(value: Any) -> bool: - """判断是否为数值类型且排除布尔值。""" - return (isinstance(value, int) or isinstance(value, float)) and not isinstance(value, bool) - - -def _schema_type(schema: Dict[str, Any]) -> str: - """解析 Schema 的类型。""" - schema_type = str(schema.get("type", "") or "").strip() - if schema_type: - return schema_type - if "properties" in schema or "required" in schema: - return "object" - return "" - - -def _validate_schema(schema: Dict[str, Any], actual_value: Any, path: str = "args") -> List[str]: - """按简化 JSON Schema 校验工具参数。""" - errors: List[str] = [] - schema_type = _schema_type(schema) - - if "enum" in schema and actual_value not in schema["enum"]: - errors.append(f"{path} 枚举值不合法,期望属于 {schema['enum']},实际为 {actual_value!r}") - - if schema_type == "string": - if not isinstance(actual_value, str): - errors.append(f"{path} 类型错误,期望 string,实际为 {type(actual_value).__name__}") - return errors - - if schema_type == "integer": - if not _is_integer_value(actual_value): - errors.append(f"{path} 类型错误,期望 integer,实际为 {type(actual_value).__name__}") - return errors - - if schema_type == "number": - if not _is_number_value(actual_value): - errors.append(f"{path} 类型错误,期望 number,实际为 {type(actual_value).__name__}") - return errors - - if schema_type == "boolean": - if not isinstance(actual_value, bool): - errors.append(f"{path} 类型错误,期望 boolean,实际为 {type(actual_value).__name__}") - return errors - - if schema_type == "array": - if not isinstance(actual_value, list): - errors.append(f"{path} 类型错误,期望 array,实际为 {type(actual_value).__name__}") - return errors - item_schema = schema.get("items") - if isinstance(item_schema, dict): - for index, item in enumerate(actual_value): - errors.extend(_validate_schema(item_schema, item, f"{path}[{index}]")) - return errors - - if schema_type == "object": - if not isinstance(actual_value, dict): - errors.append(f"{path} 类型错误,期望 object,实际为 {type(actual_value).__name__}") - return errors - - properties = schema.get("properties", {}) - required_fields = [str(item) for item in schema.get("required", [])] - for required_field in required_fields: - if required_field not in actual_value: - errors.append(f"{path}.{required_field} 缺少必填字段") - - for field_name, field_value in actual_value.items(): - field_path = f"{path}.{field_name}" - field_schema = properties.get(field_name) - if isinstance(field_schema, dict): - errors.extend(_validate_schema(field_schema, field_value, field_path)) - continue - - additional_properties = schema.get("additionalProperties", True) - if additional_properties is False: - errors.append(f"{field_path} 是未定义字段") - elif isinstance(additional_properties, dict): - errors.extend(_validate_schema(additional_properties, field_value, field_path)) - return errors - - return errors - - -def _compare_expected_values(expected_value: Any, actual_value: Any, path: str = "args") -> List[str]: - """递归比较实际值与期望值是否完全一致。""" - errors: List[str] = [] - - if isinstance(expected_value, dict): - if not isinstance(actual_value, dict): - return [f"{path} 值不一致,期望 object,实际为 {type(actual_value).__name__}"] - - expected_keys = set(expected_value.keys()) - actual_keys = set(actual_value.keys()) - for missing_key in sorted(expected_keys - actual_keys): - errors.append(f"{path}.{missing_key} 缺少期望字段") - for extra_key in sorted(actual_keys - expected_keys): - errors.append(f"{path}.{extra_key} 出现了额外字段") - for shared_key in sorted(expected_keys & actual_keys): - errors.extend( - _compare_expected_values( - expected_value[shared_key], - actual_value[shared_key], - f"{path}.{shared_key}", - ) - ) - return errors - - if isinstance(expected_value, list): - if not isinstance(actual_value, list): - return [f"{path} 值不一致,期望 array,实际为 {type(actual_value).__name__}"] - - if len(expected_value) != len(actual_value): - errors.append(f"{path} 列表长度不一致,期望 {len(expected_value)},实际 {len(actual_value)}") - for index, (expected_item, actual_item) in enumerate( - zip(expected_value, actual_value, strict=False) - ): - errors.extend(_compare_expected_values(expected_item, actual_item, f"{path}[{index}]")) - return errors - - if isinstance(expected_value, bool): - if not isinstance(actual_value, bool) or actual_value is not expected_value: - errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}") - return errors - - if _is_integer_value(expected_value): - if not _is_integer_value(actual_value) or actual_value != expected_value: - errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}") - return errors - - if isinstance(expected_value, float): - if not _is_number_value(actual_value) or float(actual_value) != expected_value: - errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}") - return errors - - if expected_value != actual_value: - errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}") - return errors - - -def _pick_tool_call(tool_calls: List[ToolCall], expected_tool_name: str) -> ToolCall: - """优先选择同名工具调用,否则回退到第一条。""" - for tool_call in tool_calls: - if tool_call.func_name == expected_tool_name: - return tool_call - return tool_calls[0] - - -def _validate_service_result( - service_result: LLMServiceResult, - target: ProbeTarget, - case: ToolCallCase, -) -> tuple[List[str], List[str], List[Dict[str, Any]]]: - """校验服务层返回结果。""" - errors: List[str] = [] - warnings: List[str] = [] - completion = service_result.completion - serialized_tool_calls = _serialize_tool_calls(completion.tool_calls) - - if not service_result.success: - errors.append(service_result.error or completion.response or "请求失败但未返回错误信息") - return errors, warnings, serialized_tool_calls - - if completion.model_name and completion.model_name != target.model_name: - errors.append( - f"实际命中的模型为 `{completion.model_name}`,与目标模型 `{target.model_name}` 不一致" - ) - - tool_calls = completion.tool_calls or [] - if not tool_calls: - errors.append("模型未返回 tool_calls") - if completion.response.strip(): - warnings.append("模型返回了自然语言文本而不是工具调用") - return errors, warnings, serialized_tool_calls - - if len(tool_calls) != 1: - errors.append(f"返回了 {len(tool_calls)} 个 tool_calls,预期为 1 个") - - selected_tool_call = _pick_tool_call(tool_calls, case.tool_name) - if selected_tool_call.func_name != case.tool_name: - errors.append( - f"工具名不一致,期望 `{case.tool_name}`,实际 `{selected_tool_call.func_name}`" - ) - - actual_arguments = selected_tool_call.args - if not isinstance(actual_arguments, dict): - errors.append("工具参数未被解析为对象") - return errors, warnings, serialized_tool_calls - - errors.extend(_validate_schema(case.parameters_schema, actual_arguments)) - errors.extend(_compare_expected_values(case.expected_arguments, actual_arguments)) - - if completion.response.strip(): - warnings.append("模型同时返回了自然语言文本") - return errors, warnings, serialized_tool_calls - - -async def _run_single_probe( - target: ProbeTarget, - case: ToolCallCase, - attempt: int, - max_tokens: int, - temperature: float, -) -> ProbeResult: - """执行单次工具调用参数探测。""" - request = LLMServiceRequest( - task_name=target.task_name, - request_type=f"tool_call_param_probe.{case.name}.attempt_{attempt}", - prompt=case.build_messages(), - tool_options=[case.tool_definition], - temperature=temperature, - max_tokens=max_tokens, - ) - - started_at = time.perf_counter() - with _pin_task_to_model(target.task_name, target.model_name): - service_result = await generate(request) - elapsed_seconds = time.perf_counter() - started_at - - errors, warnings, serialized_tool_calls = _validate_service_result(service_result, target, case) - completion = service_result.completion - return ProbeResult( - task_name=target.task_name, - target_model_name=target.model_name, - actual_model_name=completion.model_name, - provider_name=target.provider_name, - client_type=target.client_type, - tool_argument_parse_mode=target.tool_argument_parse_mode, - case_name=case.name, - attempt=attempt, - success=not errors, - elapsed_seconds=elapsed_seconds, - errors=errors, - warnings=warnings, - response_text=completion.response, - reasoning_text=completion.reasoning, - tool_calls=serialized_tool_calls, - ) - - -def _print_targets(targets: Sequence[ProbeTarget]) -> None: - """打印待测试目标。""" - print("待测试目标:") - for index, target in enumerate(targets, start=1): - print( - f"{index}. model={target.model_name} | task={target.task_name} | " - f"provider={target.provider_name} | client={target.client_type} | " - f"tool_argument_parse_mode={target.tool_argument_parse_mode}" - ) - - -def _print_available_targets() -> None: - """打印当前可用任务与模型。""" - available_tasks = get_available_models() - model_map = _build_model_map() - task_names = list(available_tasks.keys()) - - print("当前可用任务:") - for task_name in task_names: - task_config = available_tasks[task_name] - print(f"- {task_name}: {list(task_config.model_list)}") - - referenced_models = { - model_name - for task_config in available_tasks.values() - for model_name in task_config.model_list - } - - print("\n当前配置中的模型:") - for model_name, model_info in model_map.items(): - referenced_mark = "已被任务引用" if model_name in referenced_models else "未被任务引用" - print( - f"- {model_name}: provider={model_info.api_provider}, " - f"identifier={model_info.model_identifier}, {referenced_mark}" - ) - - -def _select_cases(case_filters: Sequence[str]) -> List[ToolCallCase]: - """根据参数筛选测试用例。""" - all_cases = {case.name: case for case in _build_default_cases()} - if not case_filters: - return list(all_cases.values()) - - selected_cases: List[ToolCallCase] = [] - for case_name in case_filters: - if case_name not in all_cases: - raise ValueError(f"未知测试用例 `{case_name}`,可选值: {', '.join(sorted(all_cases))}") - selected_cases.append(all_cases[case_name]) - return selected_cases - - -def _print_single_result(result: ProbeResult, show_response: bool) -> None: - """打印单次结果。""" - status_text = "PASS" if result.success else "FAIL" - print( - f"[{status_text}] model={result.target_model_name} | task={result.task_name} | " - f"case={result.case_name} | attempt={result.attempt} | elapsed={result.elapsed_seconds:.2f}s" - ) - if result.errors: - for error in result.errors: - print(f" ERROR: {error}") - if result.warnings: - for warning in result.warnings: - print(f" WARN: {warning}") - if result.tool_calls: - print(f" tool_calls: {json.dumps(result.tool_calls, ensure_ascii=False)}") - if show_response and result.response_text.strip(): - print(f" response: {result.response_text}") - - -def _build_summary(results: Sequence[ProbeResult]) -> Dict[str, Any]: - """构造结果摘要。""" - total_count = len(results) - passed_count = sum(1 for result in results if result.success) - failed_count = total_count - passed_count - failed_items = [ - { - "model_name": result.target_model_name, - "case_name": result.case_name, - "attempt": result.attempt, - "errors": list(result.errors), - } - for result in results - if not result.success - ] - return { - "total": total_count, - "passed": passed_count, - "failed": failed_count, - "failed_items": failed_items, - } - - -def _write_json_report(json_out: str, results: Sequence[ProbeResult]) -> None: - """将测试结果写入 JSON 文件。""" - output_path = Path(json_out).expanduser().resolve() - output_path.parent.mkdir(parents=True, exist_ok=True) - payload = { - "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"), - "summary": _build_summary(results), - "results": [asdict(result) for result in results], - } - output_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") - print(f"\n结果已写入: {output_path}") - - -async def _run_probes(args: Namespace) -> List[ProbeResult]: - """执行所有探测请求。""" - task_filters = _parse_multi_value_args(args.task) - model_filters = _parse_multi_value_args(args.model) - case_filters = _parse_multi_value_args(args.case) - - selected_cases = _select_cases(case_filters) - targets = _resolve_targets(task_filters, model_filters, args.fallback_task) - - _print_targets(targets) - print("") - - results: List[ProbeResult] = [] - for target in targets: - for attempt in range(1, args.repeat + 1): - for case in selected_cases: - print( - f"开始测试: model={target.model_name}, task={target.task_name}, " - f"case={case.name}, attempt={attempt}" - ) - result = await _run_single_probe( - target=target, - case=case, - attempt=attempt, - max_tokens=args.max_tokens, - temperature=args.temperature, - ) - _print_single_result(result, args.show_response) - print("") - results.append(result) - return results - - -def _build_parser() -> ArgumentParser: - """构造命令行参数解析器。""" - parser = ArgumentParser( - description=( - "测试 config/model_config.toml 中不同模型的 tool call 参数兼容性。\n" - "默认会测试所有非 voice / embedding 任务中引用到的模型。" - ) - ) - parser.add_argument( - "--task", - action="append", - help="指定任务名,可重复传入,或使用逗号分隔多个值,例如 --task utils --task planner", - ) - parser.add_argument( - "--model", - action="append", - help="指定模型名,可重复传入,或使用逗号分隔多个值,例如 --model qwen3.6-plus", - ) - parser.add_argument( - "--case", - action="append", - help="指定测试用例名,可选 simple、nested;不传则运行全部默认用例", - ) - parser.add_argument( - "--repeat", - type=int, - default=1, - help="每个模型每个用例重复测试次数,默认 1", - ) - parser.add_argument( - "--max-tokens", - type=int, - default=512, - help="单次测试的最大输出 token 数,默认 512", - ) - parser.add_argument( - "--temperature", - type=float, - default=0.0, - help="单次测试温度,默认 0.0 以尽量提高稳定性", - ) - parser.add_argument( - "--fallback-task", - default="utils", - help="当指定模型未被任何已选任务引用时,用于挂载该模型的任务名,默认 utils", - ) - parser.add_argument( - "--json-out", - help="可选,将结果写入指定 JSON 文件", - ) - parser.add_argument( - "--list-targets", - action="store_true", - help="仅打印当前任务与模型映射,不发起网络请求", - ) - parser.add_argument( - "--show-response", - action="store_true", - help="打印模型返回的自然语言文本内容", - ) - return parser - - -def main() -> int: - """脚本入口。""" - _ensure_utf8_console() - parser = _build_parser() - args = parser.parse_args() - - if args.repeat < 1: - parser.error("--repeat 必须大于等于 1") - if args.max_tokens < 1: - parser.error("--max-tokens 必须大于等于 1") - - if args.list_targets: - _print_available_targets() - return 0 - - results = asyncio.run(_run_probes(args)) - summary = _build_summary(results) - - print("测试摘要:") - print( - f"total={summary['total']} | passed={summary['passed']} | failed={summary['failed']}" - ) - if summary["failed_items"]: - print("失败明细:") - for failed_item in summary["failed_items"]: - print( - f"- model={failed_item['model_name']} | case={failed_item['case_name']} | " - f"attempt={failed_item['attempt']} | errors={failed_item['errors']}" - ) - - if args.json_out: - _write_json_report(args.json_out, results) - - return 0 if summary["failed"] == 0 else 1 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/scripts/test_tool_call_api_matrix.py b/scripts/test_tool_call_api_matrix.py deleted file mode 100644 index d533bb57..00000000 --- a/scripts/test_tool_call_api_matrix.py +++ /dev/null @@ -1,777 +0,0 @@ -from argparse import ArgumentParser, Namespace -from contextlib import contextmanager -from dataclasses import asdict, dataclass, field -from pathlib import Path -from typing import Any, Dict, Iterator, List, Sequence - -import asyncio -import json -import sys -import time - - -PROJECT_ROOT = Path(__file__).resolve().parent.parent -if str(PROJECT_ROOT) not in sys.path: - sys.path.insert(0, str(PROJECT_ROOT)) - -from src.common.data_models.llm_service_data_models import LLMServiceRequest, LLMServiceResult # noqa: E402 -from src.config.config import config_manager # noqa: E402 -from src.config.model_configs import APIProvider, ModelInfo, TaskConfig # noqa: E402 -from src.services.llm_service import generate # noqa: E402 -from src.services.service_task_resolver import get_available_models # noqa: E402 - - -DEFAULT_SKIP_TASKS = {"embedding", "voice"} - - -@dataclass(slots=True) -class ProbeTarget: - """单个待测试模型目标。""" - - task_name: str - model_name: str - provider_name: str - client_type: str - tool_argument_parse_mode: str - - -@dataclass(slots=True) -class ToolCallScenario: - """工具调用 API 场景定义。""" - - name: str - description: str - prompt: List[Dict[str, Any]] - tool_options: List[Dict[str, Any]] | None = None - expect_tool_calls: bool | None = None - - -@dataclass(slots=True) -class ProbeResult: - """单次 API 探测结果。""" - - task_name: str - target_model_name: str - actual_model_name: str - provider_name: str - client_type: str - tool_argument_parse_mode: str - case_name: str - attempt: int - success: bool - elapsed_seconds: float - errors: List[str] = field(default_factory=list) - warnings: List[str] = field(default_factory=list) - response_text: str = "" - reasoning_text: str = "" - tool_calls: List[Dict[str, Any]] = field(default_factory=list) - - -def _ensure_utf8_console() -> None: - """尽量将控制台编码切换为 UTF-8。""" - try: - if hasattr(sys.stdout, "reconfigure"): - sys.stdout.reconfigure(encoding="utf-8") - if hasattr(sys.stderr, "reconfigure"): - sys.stderr.reconfigure(encoding="utf-8") - except Exception: - pass - - -def _build_function_tool(name: str, description: str, parameters: Dict[str, Any]) -> Dict[str, Any]: - """构造 OpenAI 风格 function tool。""" - return { - "type": "function", - "function": { - "name": name, - "description": description, - "parameters": parameters, - }, - } - - -def _build_probe_tools() -> List[Dict[str, Any]]: - """构造通用测试工具。""" - weather_tool = _build_function_tool( - name="lookup_weather", - description="查询指定城市天气。", - parameters={ - "type": "object", - "properties": { - "city": {"type": "string", "description": "城市名"}, - "unit": { - "type": "string", - "description": "温度单位", - "enum": ["celsius", "fahrenheit"], - }, - "include_forecast": {"type": "boolean", "description": "是否包含未来天气"}, - }, - "required": ["city", "unit", "include_forecast"], - "additionalProperties": False, - }, - ) - search_tool = _build_function_tool( - name="search_docs", - description="搜索内部知识库。", - parameters={ - "type": "object", - "properties": { - "query": {"type": "string", "description": "搜索关键词"}, - "top_k": {"type": "integer", "description": "返回条数"}, - "filters": { - "type": "object", - "description": "过滤条件", - "properties": { - "scope": {"type": "string", "description": "搜索范围"}, - "tag": {"type": "string", "description": "标签"}, - }, - "required": ["scope", "tag"], - "additionalProperties": False, - }, - }, - "required": ["query", "top_k", "filters"], - "additionalProperties": False, - }, - ) - return [weather_tool, search_tool] - - -def _build_default_scenarios() -> List[ToolCallScenario]: - """构造默认测试场景。""" - tools = _build_probe_tools() - weather_tool = tools[0] - search_tool = tools[1] - - history_tool_call = { - "id": "call_hist_weather_001", - "type": "function", - "function": { - "name": "lookup_weather", - "arguments": { - "city": "上海", - "unit": "celsius", - "include_forecast": True, - }, - }, - } - nested_history_tool_call = { - "id": "call_hist_search_001", - "type": "function", - "function": { - "name": "search_docs", - "arguments": { - "query": "工具调用兼容性", - "top_k": 3, - "filters": { - "scope": "internal", - "tag": "tool-call", - }, - }, - }, - } - - return [ - ToolCallScenario( - name="fresh_tool_call", - description="首轮普通工具调用请求。", - prompt=[ - { - "role": "system", - "content": ( - "你正在执行工具调用连通性测试。" - "如果能调用工具,就优先调用最合适的工具。" - ), - }, - { - "role": "user", - "content": "请查询上海天气,并使用工具给出参数。", - }, - ], - tool_options=[weather_tool], - expect_tool_calls=True, - ), - ToolCallScenario( - name="history_assistant_tool_calls_with_content", - description="历史 assistant 同时包含文本和 tool_calls,当前轮不再提供 tools。", - prompt=[ - {"role": "system", "content": "你正在执行多轮上下文兼容性测试。"}, - {"role": "user", "content": "先帮我查一下上海天气。"}, - { - "role": "assistant", - "content": "我先查询天气,再继续回答。", - "tool_calls": [history_tool_call], - }, - {"role": "user", "content": "继续说,别丢掉上下文。"}, - ], - tool_options=None, - expect_tool_calls=None, - ), - ToolCallScenario( - name="history_assistant_tool_calls_without_content", - description="历史 assistant 只有 tool_calls,没有文本内容。", - prompt=[ - {"role": "system", "content": "你正在执行多轮上下文兼容性测试。"}, - {"role": "user", "content": "先帮我查一下上海天气。"}, - { - "role": "assistant", - "tool_calls": [history_tool_call], - }, - {"role": "user", "content": "继续。"}, - ], - tool_options=None, - expect_tool_calls=None, - ), - ToolCallScenario( - name="history_tool_result_followup", - description="历史中包含 assistant.tool_calls 与对应 tool 结果消息。", - prompt=[ - {"role": "system", "content": "你正在执行工具调用闭环兼容性测试。"}, - {"role": "user", "content": "先查上海天气。"}, - { - "role": "assistant", - "content": "我先查询天气。", - "tool_calls": [history_tool_call], - }, - { - "role": "tool", - "tool_call_id": "call_hist_weather_001", - "content": json.dumps( - { - "city": "上海", - "condition": "多云", - "temperature_c": 24, - "forecast": ["晴", "小雨"], - }, - ensure_ascii=False, - ), - }, - {"role": "user", "content": "结合上面的查询结果继续总结。"}, - ], - tool_options=None, - expect_tool_calls=None, - ), - ToolCallScenario( - name="history_multiple_tool_calls_and_results", - description="历史中包含多个 tool_calls 与多条 tool 结果。", - prompt=[ - {"role": "system", "content": "你正在执行多工具上下文兼容性测试。"}, - {"role": "user", "content": "先查天气,再搜一下工具调用兼容性文档。"}, - { - "role": "assistant", - "content": "我分两步查询。", - "tool_calls": [history_tool_call, nested_history_tool_call], - }, - { - "role": "tool", - "tool_call_id": "call_hist_weather_001", - "content": json.dumps( - { - "city": "上海", - "condition": "阴", - "temperature_c": 22, - }, - ensure_ascii=False, - ), - }, - { - "role": "tool", - "tool_call_id": "call_hist_search_001", - "content": json.dumps( - { - "items": [ - "OpenAI 兼容接口的 arguments 常见为 JSON 字符串", - "部分 provider 在历史消息回放时兼容性较弱", - ], - }, - ensure_ascii=False, - ), - }, - {"role": "user", "content": "继续整合上面的两个结果。"}, - ], - tool_options=None, - expect_tool_calls=None, - ), - ToolCallScenario( - name="history_tool_calls_with_current_tools", - description="保留历史 tool_calls,同时当前轮仍然提供 tools。", - prompt=[ - {"role": "system", "content": "你正在执行历史 tool_calls 与当前 tools 共存测试。"}, - {"role": "user", "content": "先查上海天气。"}, - { - "role": "assistant", - "content": "我先查天气。", - "tool_calls": [history_tool_call], - }, - { - "role": "tool", - "tool_call_id": "call_hist_weather_001", - "content": json.dumps( - { - "city": "上海", - "condition": "晴", - "temperature_c": 26, - }, - ensure_ascii=False, - ), - }, - {"role": "user", "content": "现在再搜一下工具调用兼容性文档。"}, - ], - tool_options=[search_tool], - expect_tool_calls=True, - ), - ] - - -def _parse_multi_value_args(raw_values: Sequence[str] | None) -> List[str]: - """解析命令行中的多值参数。""" - parsed_values: List[str] = [] - for raw_value in raw_values or []: - for item in str(raw_value).split(","): - normalized_item = item.strip() - if normalized_item: - parsed_values.append(normalized_item) - return parsed_values - - -def _build_model_map() -> Dict[str, ModelInfo]: - """构造模型名到模型配置的映射。""" - return {model.name: model for model in config_manager.get_model_config().models} - - -def _build_provider_map() -> Dict[str, APIProvider]: - """构造 Provider 名称到配置的映射。""" - return {provider.name: provider for provider in config_manager.get_model_config().api_providers} - - -def _pick_default_task_name(task_names: Sequence[str]) -> str: - """选择默认任务名。""" - if "utils" in task_names: - return "utils" - if not task_names: - raise ValueError("当前没有可用的任务配置") - return str(task_names[0]) - - -def _resolve_targets(task_filters: Sequence[str], model_filters: Sequence[str], fallback_task: str) -> List[ProbeTarget]: - """根据命令行参数解析待测试目标。""" - available_tasks = get_available_models() - model_map = _build_model_map() - provider_map = _build_provider_map() - - if not available_tasks: - raise ValueError("未找到任何可用的模型任务配置") - - if task_filters: - selected_task_names = [] - for task_name in task_filters: - if task_name not in available_tasks: - raise ValueError(f"未找到任务 `{task_name}`") - selected_task_names.append(task_name) - else: - selected_task_names = [ - task_name - for task_name in available_tasks - if task_name not in DEFAULT_SKIP_TASKS - ] - - if not selected_task_names: - raise ValueError("没有可用于工具调用 API 测试的任务,请显式通过 --task 指定") - - default_task_name = fallback_task if fallback_task in available_tasks else _pick_default_task_name(selected_task_names) - resolved_targets: List[ProbeTarget] = [] - seen_models: set[str] = set() - - if model_filters: - model_names = list(model_filters) - else: - model_names = [] - for task_name in selected_task_names: - task_config = available_tasks[task_name] - for model_name in task_config.model_list: - if model_name not in model_names: - model_names.append(model_name) - - for model_name in model_names: - if model_name in seen_models: - continue - if model_name not in model_map: - raise ValueError(f"未找到模型 `{model_name}`") - - target_task_name = "" - for task_name in selected_task_names: - if model_name in available_tasks[task_name].model_list: - target_task_name = task_name - break - if not target_task_name: - target_task_name = default_task_name - - model_info = model_map[model_name] - provider_info = provider_map[model_info.api_provider] - resolved_targets.append( - ProbeTarget( - task_name=target_task_name, - model_name=model_name, - provider_name=provider_info.name, - client_type=provider_info.client_type, - tool_argument_parse_mode=provider_info.tool_argument_parse_mode, - ) - ) - seen_models.add(model_name) - - return resolved_targets - - -@contextmanager -def _pin_task_to_model(task_name: str, model_name: str) -> Iterator[None]: - """临时将某个任务锁定到单模型。""" - model_task_config = config_manager.get_model_config().model_task_config - task_config = getattr(model_task_config, task_name, None) - if not isinstance(task_config, TaskConfig): - raise ValueError(f"未找到任务 `{task_name}` 对应的配置") - - original_model_list = list(task_config.model_list) - original_selection_strategy = task_config.selection_strategy - task_config.model_list = [model_name] - task_config.selection_strategy = "balance" - try: - yield - finally: - task_config.model_list = original_model_list - task_config.selection_strategy = original_selection_strategy - - -def _serialize_tool_calls(tool_calls: Any) -> List[Dict[str, Any]]: - """序列化返回中的工具调用。""" - if not tool_calls: - return [] - - serialized_items: List[Dict[str, Any]] = [] - for tool_call in tool_calls: - serialized_items.append( - { - "id": getattr(tool_call, "call_id", ""), - "function": { - "name": getattr(tool_call, "func_name", ""), - "arguments": dict(getattr(tool_call, "args", {}) or {}), - }, - **( - {"extra_content": dict(getattr(tool_call, "extra_content", {}) or {})} - if getattr(tool_call, "extra_content", None) - else {} - ), - } - ) - return serialized_items - - -def _validate_service_result(service_result: LLMServiceResult, scenario: ToolCallScenario) -> tuple[List[str], List[str], List[Dict[str, Any]]]: - """校验服务结果。""" - errors: List[str] = [] - warnings: List[str] = [] - completion = service_result.completion - serialized_tool_calls = _serialize_tool_calls(completion.tool_calls) - - if not service_result.success: - errors.append(service_result.error or completion.response or "请求失败,但没有返回明确错误") - return errors, warnings, serialized_tool_calls - - if scenario.expect_tool_calls is True and not serialized_tool_calls: - warnings.append("本场景期望模型倾向于调用工具,但未返回 tool_calls") - if scenario.expect_tool_calls is False and serialized_tool_calls: - warnings.append("本场景未期望继续调用工具,但模型返回了 tool_calls") - if completion.response.strip(): - warnings.append("模型返回了可见文本") - return errors, warnings, serialized_tool_calls - - -async def _run_single_probe( - target: ProbeTarget, - scenario: ToolCallScenario, - attempt: int, - max_tokens: int, - temperature: float, -) -> ProbeResult: - """执行单次 API 探测。""" - request = LLMServiceRequest( - task_name=target.task_name, - request_type=f"tool_call_api_matrix.{scenario.name}.attempt_{attempt}", - prompt=scenario.prompt, - tool_options=scenario.tool_options, - temperature=temperature, - max_tokens=max_tokens, - ) - - started_at = time.perf_counter() - with _pin_task_to_model(target.task_name, target.model_name): - service_result = await generate(request) - elapsed_seconds = time.perf_counter() - started_at - - errors, warnings, serialized_tool_calls = _validate_service_result(service_result, scenario) - completion = service_result.completion - return ProbeResult( - task_name=target.task_name, - target_model_name=target.model_name, - actual_model_name=completion.model_name, - provider_name=target.provider_name, - client_type=target.client_type, - tool_argument_parse_mode=target.tool_argument_parse_mode, - case_name=scenario.name, - attempt=attempt, - success=not errors, - elapsed_seconds=elapsed_seconds, - errors=errors, - warnings=warnings, - response_text=completion.response, - reasoning_text=completion.reasoning, - tool_calls=serialized_tool_calls, - ) - - -def _print_targets(targets: Sequence[ProbeTarget]) -> None: - """打印待测试目标。""" - print("待测试目标:") - for index, target in enumerate(targets, start=1): - print( - f"{index}. model={target.model_name} | task={target.task_name} | " - f"provider={target.provider_name} | client={target.client_type} | " - f"tool_argument_parse_mode={target.tool_argument_parse_mode}" - ) - - -def _print_available_targets() -> None: - """打印当前可用任务与模型。""" - available_tasks = get_available_models() - model_map = _build_model_map() - task_names = list(available_tasks.keys()) - - print("当前可用任务:") - for task_name in task_names: - task_config = available_tasks[task_name] - print(f"- {task_name}: {list(task_config.model_list)}") - - referenced_models = { - model_name - for task_config in available_tasks.values() - for model_name in task_config.model_list - } - - print("\n当前配置中的模型:") - for model_name, model_info in model_map.items(): - referenced_mark = "已被任务引用" if model_name in referenced_models else "未被任务引用" - print( - f"- {model_name}: provider={model_info.api_provider}, " - f"identifier={model_info.model_identifier}, {referenced_mark}" - ) - - -def _select_scenarios(case_filters: Sequence[str]) -> List[ToolCallScenario]: - """按名称筛选测试场景。""" - all_scenarios = {scenario.name: scenario for scenario in _build_default_scenarios()} - if not case_filters: - return list(all_scenarios.values()) - - selected_scenarios: List[ToolCallScenario] = [] - for case_name in case_filters: - if case_name not in all_scenarios: - raise ValueError( - f"未知测试场景 `{case_name}`,可选值: {', '.join(sorted(all_scenarios))}" - ) - selected_scenarios.append(all_scenarios[case_name]) - return selected_scenarios - - -def _print_single_result(result: ProbeResult, show_response: bool) -> None: - """打印单次结果。""" - status_text = "PASS" if result.success else "FAIL" - print( - f"[{status_text}] model={result.target_model_name} | task={result.task_name} | " - f"case={result.case_name} | attempt={result.attempt} | elapsed={result.elapsed_seconds:.2f}s" - ) - if result.errors: - for error in result.errors: - print(f" ERROR: {error}") - if result.warnings: - for warning in result.warnings: - print(f" WARN: {warning}") - if result.tool_calls: - print(f" tool_calls: {json.dumps(result.tool_calls, ensure_ascii=False)}") - if show_response and result.response_text.strip(): - print(f" response: {result.response_text}") - - -def _build_summary(results: Sequence[ProbeResult]) -> Dict[str, Any]: - """构造结果摘要。""" - total_count = len(results) - passed_count = sum(1 for result in results if result.success) - failed_count = total_count - passed_count - failed_items = [ - { - "model_name": result.target_model_name, - "case_name": result.case_name, - "attempt": result.attempt, - "errors": list(result.errors), - } - for result in results - if not result.success - ] - return { - "total": total_count, - "passed": passed_count, - "failed": failed_count, - "failed_items": failed_items, - } - - -def _write_json_report(json_out: str, results: Sequence[ProbeResult]) -> None: - """将测试结果写入 JSON 文件。""" - output_path = Path(json_out).expanduser().resolve() - output_path.parent.mkdir(parents=True, exist_ok=True) - payload = { - "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"), - "summary": _build_summary(results), - "results": [asdict(result) for result in results], - } - output_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") - print(f"\n结果已写入: {output_path}") - - -async def _run_probes(args: Namespace) -> List[ProbeResult]: - """执行所有探测请求。""" - task_filters = _parse_multi_value_args(args.task) - model_filters = _parse_multi_value_args(args.model) - case_filters = _parse_multi_value_args(args.case) - - selected_scenarios = _select_scenarios(case_filters) - targets = _resolve_targets(task_filters, model_filters, args.fallback_task) - - _print_targets(targets) - print("") - - results: List[ProbeResult] = [] - for target in targets: - for attempt in range(1, args.repeat + 1): - for scenario in selected_scenarios: - print( - f"开始测试: model={target.model_name}, task={target.task_name}, " - f"case={scenario.name}, attempt={attempt}" - ) - result = await _run_single_probe( - target=target, - scenario=scenario, - attempt=attempt, - max_tokens=args.max_tokens, - temperature=args.temperature, - ) - _print_single_result(result, args.show_response) - print("") - results.append(result) - return results - - -def _build_parser() -> ArgumentParser: - """构造命令行参数解析器。""" - parser = ArgumentParser( - description=( - "测试不同模型在多种工具调用消息形态下的 API 兼容性。\n" - "重点覆盖历史 assistant.tool_calls、tool 结果消息、多工具调用等场景。" - ) - ) - parser.add_argument( - "--task", - action="append", - help="指定任务名,可重复传入,或使用逗号分隔多个值,例如 --task utils --task planner", - ) - parser.add_argument( - "--model", - action="append", - help="指定模型名,可重复传入,或使用逗号分隔多个值,例如 --model qwen3.5-35b-a3b", - ) - parser.add_argument( - "--case", - action="append", - help=( - "指定测试场景名,可选值包括 " - "fresh_tool_call、history_assistant_tool_calls_with_content、" - "history_assistant_tool_calls_without_content、history_tool_result_followup、" - "history_multiple_tool_calls_and_results、history_tool_calls_with_current_tools" - ), - ) - parser.add_argument( - "--repeat", - type=int, - default=1, - help="每个模型每个场景重复测试次数,默认 1", - ) - parser.add_argument( - "--max-tokens", - type=int, - default=512, - help="单次测试的最大输出 token 数,默认 512", - ) - parser.add_argument( - "--temperature", - type=float, - default=0.0, - help="单次测试温度,默认 0.0,以尽量提高稳定性", - ) - parser.add_argument( - "--fallback-task", - default="utils", - help="当指定模型未被已选任务引用时,用于挂载该模型的任务名,默认 utils", - ) - parser.add_argument( - "--json-out", - help="可选,将结果写入指定 JSON 文件", - ) - parser.add_argument( - "--list-targets", - action="store_true", - help="仅打印当前任务与模型映射,不发起网络请求", - ) - parser.add_argument( - "--show-response", - action="store_true", - help="打印模型返回的可见文本内容", - ) - return parser - - -def main() -> int: - """脚本入口。""" - _ensure_utf8_console() - config_manager.initialize() - parser = _build_parser() - args = parser.parse_args() - - if args.repeat < 1: - parser.error("--repeat 必须大于等于 1") - if args.max_tokens < 1: - parser.error("--max-tokens 必须大于等于 1") - - if args.list_targets: - _print_available_targets() - return 0 - - results = asyncio.run(_run_probes(args)) - summary = _build_summary(results) - - print("测试摘要:") - print( - f"total={summary['total']} | passed={summary['passed']} | failed={summary['failed']}" - ) - if summary["failed_items"]: - print("失败明细:") - for failed_item in summary["failed_items"]: - print( - f"- model={failed_item['model_name']} | case={failed_item['case_name']} | " - f"attempt={failed_item['attempt']} | errors={failed_item['errors']}" - ) - - if args.json_out: - _write_json_report(args.json_out, results) - - return 0 if summary["failed"] == 0 else 1 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/src/plugin_runtime/component_query.py b/src/plugin_runtime/component_query.py index f5c3e1f7..e316e625 100644 --- a/src/plugin_runtime/component_query.py +++ b/src/plugin_runtime/component_query.py @@ -748,6 +748,9 @@ class ComponentQueryService: return payload for key, value in context_payload.items(): + if key in {"stream_id", "chat_id"}: + payload[key] = value + continue if key not in payload or not payload.get(key): payload[key] = value return payload diff --git a/tests/test_config_upgrade_hooks.py b/tests/test_config_upgrade_hooks.py deleted file mode 100644 index 44bafaf6..00000000 --- a/tests/test_config_upgrade_hooks.py +++ /dev/null @@ -1,76 +0,0 @@ -from pathlib import Path - -import sys - -sys.path.insert(0, str(Path(__file__).resolve().parents[1])) - -from src.config.config_upgrade_hooks import ( - BOT_CONFIG_UPGRADE_HOOKS, - ConfigUpgradeHook, - apply_config_upgrade_hooks, - set_nested_config_value, -) -from src.config.official_configs import ChatConfig - -import src.config.config_upgrade_hooks as hooks - - -def test_apply_config_upgrade_hooks_runs_when_target_version_is_crossed(monkeypatch): - def migrate(data): - changed = set_nested_config_value(data, ("chat", "enable"), False) - return ["chat.enable"] if changed else [] - - monkeypatch.setattr( - hooks, - "BOT_CONFIG_UPGRADE_HOOKS", - (ConfigUpgradeHook(target_version="8.10.11", config_names=("bot_config.toml",), migrate=migrate),), - ) - - data = {"chat": {"enable": True}} - result = apply_config_upgrade_hooks(data, "bot_config.toml", "8.10.10", "8.10.11") - - assert result.migrated is True - assert result.reason == "8.10.11:chat.enable" - assert result.data["chat"]["enable"] is False - - -def test_apply_config_upgrade_hooks_skips_versions_outside_upgrade_range(monkeypatch): - def migrate(data): - set_nested_config_value(data, ("chat", "enable"), False) - return ["chat.enable"] - - monkeypatch.setattr( - hooks, - "BOT_CONFIG_UPGRADE_HOOKS", - (ConfigUpgradeHook(target_version="8.10.11", config_names=("bot_config.toml",), migrate=migrate),), - ) - - data = {"chat": {"enable": True}} - result = apply_config_upgrade_hooks(data, "bot_config.toml", "8.10.11", "8.10.12") - - assert result.migrated is False - assert result.data["chat"]["enable"] is True - - -def test_set_nested_config_value_can_keep_existing_value(): - data = {"webui": {"port": 8001}} - - changed = set_nested_config_value(data, ("webui", "port"), 8080, force=False) - - assert changed is False - assert data["webui"]["port"] == 8001 - - -def test_builtin_hook_resets_group_chat_prompt_when_upgrading_from_8_10_10(): - data = {"chat": {"group_chat_prompt": "自定义旧提示词"}} - - result = apply_config_upgrade_hooks(data, "bot_config.toml", "8.10.10", "8.10.11") - - assert result.migrated is True - assert result.reason == "8.10.11:chat.group_chat_prompt" - assert result.data["chat"]["group_chat_prompt"] == ChatConfig().group_chat_prompt - - -def test_bot_config_upgrade_hooks_register_group_chat_prompt_reset(): - assert len(BOT_CONFIG_UPGRADE_HOOKS) == 1 - assert BOT_CONFIG_UPGRADE_HOOKS[0].target_version == "8.10.11"