From 56a6d2fd8ce13459395334c64d96ff6c991aef9c Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sun, 22 Mar 2026 00:22:24 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E6=93=8D=E4=BD=9C=E5=92=8C=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=AE=9A=E4=B9=89=EF=BC=8C=E5=A2=9E=E5=BC=BA=E8=A1=A8=E8=BE=BE?= =?UTF-8?q?=E6=96=B9=E5=BC=8F=E5=92=8C=E9=BB=91=E8=AF=9D=E8=A1=A8=E7=9A=84?= =?UTF-8?q?=E6=8F=92=E5=85=A5=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytests/common_test/test_expression_schema.py | 78 +++++++++++++++++ pytests/common_test/test_jargon_schema.py | 84 +++++++++++++++++++ src/common/database/database_model.py | 15 ++-- src/learners/expression_learner.py | 17 +++- src/learners/jargon_miner.py | 18 ++-- 5 files changed, 195 insertions(+), 17 deletions(-) create mode 100644 pytests/common_test/test_expression_schema.py create mode 100644 pytests/common_test/test_jargon_schema.py diff --git a/pytests/common_test/test_expression_schema.py b/pytests/common_test/test_expression_schema.py new file mode 100644 index 00000000..31fcd98f --- /dev/null +++ b/pytests/common_test/test_expression_schema.py @@ -0,0 +1,78 @@ +"""测试表达方式表结构和基础插入行为。""" + +from typing import Generator + +import pytest +from sqlalchemy.pool import StaticPool +from sqlmodel import Session, SQLModel, create_engine + +from src.common.database.database_model import Expression + + +@pytest.fixture(name="expression_engine") +def expression_engine_fixture() -> Generator: + """创建仅用于表达方式表测试的内存数据库引擎。 + + Yields: + Generator: 供测试使用的 SQLite 内存引擎。 + """ + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + yield engine + + +def test_expression_insert_assigns_auto_increment_id(expression_engine) -> None: + """表达方式表在新库中应能自动分配自增主键。""" + with Session(expression_engine) as session: + expression = Expression( + situation="表达情绪高涨或生理反应", + style="发送💦表情符号", + content_list='["表达情绪高涨或生理反应"]', + count=1, + session_id="session-a", + checked=False, + rejected=False, + ) + session.add(expression) + session.commit() + session.refresh(expression) + + assert expression.id is not None + assert expression.id > 0 + + +def test_expression_insert_allows_same_situation_style(expression_engine) -> None: + """相同情景和风格的表达方式记录不应再被错误绑定到复合主键。""" + with Session(expression_engine) as session: + first_expression = Expression( + situation="对重复行为的默契响应", + style="持续性跟发相同内容", + content_list='["对重复行为的默契响应"]', + count=1, + session_id="session-a", + checked=False, + rejected=False, + ) + second_expression = Expression( + situation="对重复行为的默契响应", + style="持续性跟发相同内容", + content_list='["对重复行为的默契响应-变体"]', + count=2, + session_id="session-b", + checked=False, + rejected=False, + ) + + session.add(first_expression) + session.add(second_expression) + session.commit() + session.refresh(first_expression) + session.refresh(second_expression) + + assert first_expression.id is not None + assert second_expression.id is not None + assert first_expression.id != second_expression.id diff --git a/pytests/common_test/test_jargon_schema.py b/pytests/common_test/test_jargon_schema.py new file mode 100644 index 00000000..909392ab --- /dev/null +++ b/pytests/common_test/test_jargon_schema.py @@ -0,0 +1,84 @@ +"""测试黑话表结构和基础插入行为。""" + +from typing import Generator + +import pytest +from sqlalchemy.pool import StaticPool +from sqlmodel import Session, SQLModel, create_engine + +from src.common.database.database_model import Jargon + + +@pytest.fixture(name="jargon_engine") +def jargon_engine_fixture() -> Generator: + """创建仅用于黑话表测试的内存数据库引擎。 + + Yields: + Generator: 供测试使用的 SQLite 内存引擎。 + """ + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + yield engine + + +def test_jargon_insert_assigns_auto_increment_id(jargon_engine) -> None: + """黑话表在新库中应能自动分配自增主键。""" + with Session(jargon_engine) as session: + jargon = Jargon( + content="VF8V4L", + raw_content='["[1] test"]', + meaning="", + session_id_dict='{"session-a": 1}', + count=1, + is_jargon=True, + is_complete=False, + is_global=True, + last_inference_count=0, + ) + session.add(jargon) + session.commit() + session.refresh(jargon) + + assert jargon.id is not None + assert jargon.id > 0 + + +def test_jargon_insert_allows_same_content_with_different_rows(jargon_engine) -> None: + """黑话内容不应再被错误地绑成复合主键的一部分。""" + with Session(jargon_engine) as session: + first_jargon = Jargon( + content="表情1", + raw_content='["[1] first"]', + meaning="", + session_id_dict='{"session-a": 1}', + count=1, + is_jargon=True, + is_complete=False, + is_global=False, + last_inference_count=0, + ) + second_jargon = Jargon( + content="表情1", + raw_content='["[1] second"]', + meaning="", + session_id_dict='{"session-b": 1}', + count=1, + is_jargon=True, + is_complete=False, + is_global=False, + last_inference_count=0, + ) + + session.add(first_jargon) + session.add(second_jargon) + session.commit() + session.refresh(first_jargon) + session.refresh(second_jargon) + + assert first_jargon.id is not None + assert second_jargon.id is not None + assert first_jargon.id != second_jargon.id diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index a0993a77..5b274c43 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -1,8 +1,9 @@ -from typing import Optional -from sqlalchemy import Column, Float, Enum as SQLEnum, DateTime -from sqlmodel import SQLModel, Field, LargeBinary -from enum import Enum from datetime import datetime +from enum import Enum +from typing import Optional + +from sqlalchemy import Column, DateTime, Enum as SQLEnum, Float +from sqlmodel import Field, LargeBinary, SQLModel class ModelUser(str, Enum): @@ -172,8 +173,8 @@ class Expression(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) # 自增主键 - situation: str = Field(index=True, max_length=255, primary_key=True) # 情景 - style: str = Field(index=True, max_length=255, primary_key=True) # 风格 + situation: str = Field(index=True, max_length=255) # 情景 + style: str = Field(index=True, max_length=255) # 风格 # context: str # 上下文 # up_content: str @@ -200,7 +201,7 @@ class Jargon(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) # 自增主键 - content: str = Field(index=True, max_length=255, primary_key=True) # 黑话内容 + content: str = Field(index=True, max_length=255) # 黑话内容 raw_content: Optional[str] = Field(default=None, nullable=True) # 原始内容,未处理的黑话内容,为List[str] meaning: str # 黑话含义 diff --git a/src/learners/expression_learner.py b/src/learners/expression_learner.py index 43e4ee7d..156fedc5 100644 --- a/src/learners/expression_learner.py +++ b/src/learners/expression_learner.py @@ -329,7 +329,13 @@ class ExpressionLearner: return filtered_expressions # ====== DB 操作相关 ====== - async def _upsert_expression_to_db(self, situation: str, style: str): + async def _upsert_expression_to_db(self, situation: str, style: str) -> None: + """将表达方式写入数据库,存在时更新,不存在时新增。 + + Args: + situation: 表达方式对应的使用情景。 + style: 表达方式风格。 + """ expr, similarity = self._find_similar_expression(situation) or (None, 0) if expr: # 根据相似度决定是否使用 LLM 总结 @@ -340,7 +346,13 @@ class ExpressionLearner: # 没有找到匹配的记录,创建新记录 self._create_expression(situation, style) - def _create_expression(self, situation: str, style: str): + def _create_expression(self, situation: str, style: str) -> None: + """创建新的表达方式记录。 + + Args: + situation: 表达方式对应的使用情景。 + style: 表达方式风格。 + """ content_list = [situation] try: with get_db_session() as db: @@ -353,6 +365,7 @@ class ExpressionLearner: last_active_time=datetime.now(), ) db.add(new_expr) + db.flush() except Exception as e: logger.error(f"创建表达方式失败: {e}") diff --git a/src/learners/jargon_miner.py b/src/learners/jargon_miner.py index 2fbf8a2e..674e5cc0 100644 --- a/src/learners/jargon_miner.py +++ b/src/learners/jargon_miner.py @@ -1,17 +1,18 @@ from collections import OrderedDict -from json_repair import repair_json -from sqlmodel import select -from typing import List, Optional, Dict, Callable, TypedDict, Set +from typing import Callable, Dict, List, Optional, Set, TypedDict import asyncio import json import random -from src.common.logger import get_logger +from json_repair import repair_json +from sqlmodel import select + +from src.common.data_models.jargon_data_model import MaiJargon from src.common.database.database import get_db_session from src.common.database.database_model import Jargon -from src.common.data_models.jargon_data_model import MaiJargon -from src.config.config import model_config, global_config +from src.common.logger import get_logger +from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.prompt.prompt_manager import prompt_manager @@ -273,11 +274,12 @@ class JargonMiner: try: with get_db_session() as session: session.add(new_jargon) + session.flush() + saved += 1 + self._add_to_cache(content) except Exception as e: logger.error(f"保存新黑话 '{content}' 失败: {e}") continue - finally: - self._add_to_cache(content) # 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出) if uniq_entries: # 收集所有提取的jargon内容