refactor: 优化数据库操作和模型定义,增强表达方式和黑话表的插入逻辑

This commit is contained in:
DrSmoothl
2026-03-22 00:22:24 +08:00
parent baabe4463e
commit 56a6d2fd8c
5 changed files with 195 additions and 17 deletions

View File

@@ -0,0 +1,78 @@
"""测试表达方式表结构和基础插入行为。"""
from typing import Generator
import pytest
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.common.database.database_model import Expression
@pytest.fixture(name="expression_engine")
def expression_engine_fixture() -> Generator:
    """Provide an in-memory SQLite engine for the expression table tests.

    The engine is built with ``StaticPool`` and ``check_same_thread=False``,
    and every table registered on ``SQLModel.metadata`` is created before the
    engine is handed to the test.

    Yields:
        The SQLite in-memory engine with all tables created.
    """
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    SQLModel.metadata.create_all(engine)
    try:
        yield engine
    finally:
        # Release the pooled connection once the test is finished instead of
        # leaving the in-memory database open until garbage collection.
        engine.dispose()
def test_expression_insert_assigns_auto_increment_id(expression_engine) -> None:
    """A new expression row in a fresh database must get an auto-increment id."""
    record = Expression(
        situation="表达情绪高涨或生理反应",
        style="发送💦表情符号",
        content_list='["表达情绪高涨或生理反应"]',
        count=1,
        session_id="session-a",
        checked=False,
        rejected=False,
    )
    with Session(expression_engine) as db:
        db.add(record)
        db.commit()
        # Reload the row so the database-generated primary key is populated.
        db.refresh(record)
        assigned_id = record.id

    assert assigned_id is not None
    assert assigned_id > 0
def test_expression_insert_allows_same_situation_style(expression_engine) -> None:
    """Rows sharing situation and style must both insert and get distinct ids.

    Guards against situation/style being wrongly bound into a composite
    primary key again.
    """
    shared_situation = "对重复行为的默契响应"
    shared_style = "持续性跟发相同内容"

    row_a = Expression(
        situation=shared_situation,
        style=shared_style,
        content_list='["对重复行为的默契响应"]',
        count=1,
        session_id="session-a",
        checked=False,
        rejected=False,
    )
    row_b = Expression(
        situation=shared_situation,
        style=shared_style,
        content_list='["对重复行为的默契响应-变体"]',
        count=2,
        session_id="session-b",
        checked=False,
        rejected=False,
    )

    with Session(expression_engine) as db:
        # Both rows go into a single transaction, in the same order as before.
        db.add_all([row_a, row_b])
        db.commit()
        for row in (row_a, row_b):
            db.refresh(row)
        id_a, id_b = row_a.id, row_b.id

    assert id_a is not None
    assert id_b is not None
    assert id_a != id_b

View File

@@ -0,0 +1,84 @@
"""测试黑话表结构和基础插入行为。"""
from typing import Generator
import pytest
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.common.database.database_model import Jargon
@pytest.fixture(name="jargon_engine")
def jargon_engine_fixture() -> Generator:
    """Provide an in-memory SQLite engine for the jargon table tests.

    The engine is built with ``StaticPool`` and ``check_same_thread=False``,
    and every table registered on ``SQLModel.metadata`` is created before the
    engine is handed to the test.

    Yields:
        The SQLite in-memory engine with all tables created.
    """
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    SQLModel.metadata.create_all(engine)
    try:
        yield engine
    finally:
        # Release the pooled connection once the test is finished instead of
        # leaving the in-memory database open until garbage collection.
        engine.dispose()
def test_jargon_insert_assigns_auto_increment_id(jargon_engine) -> None:
    """A new jargon row in a fresh database must get an auto-increment id."""
    record = Jargon(
        content="VF8V4L",
        raw_content='["[1] test"]',
        meaning="",
        session_id_dict='{"session-a": 1}',
        count=1,
        is_jargon=True,
        is_complete=False,
        is_global=True,
        last_inference_count=0,
    )
    with Session(jargon_engine) as db:
        db.add(record)
        db.commit()
        # Reload the row so the database-generated primary key is populated.
        db.refresh(record)
        assigned_id = record.id

    assert assigned_id is not None
    assert assigned_id > 0
def test_jargon_insert_allows_same_content_with_different_rows(jargon_engine) -> None:
    """Rows sharing the same jargon content must both insert with distinct ids.

    Guards against the content column being wrongly bound into a composite
    primary key again.
    """
    shared_content = "表情1"

    row_a = Jargon(
        content=shared_content,
        raw_content='["[1] first"]',
        meaning="",
        session_id_dict='{"session-a": 1}',
        count=1,
        is_jargon=True,
        is_complete=False,
        is_global=False,
        last_inference_count=0,
    )
    row_b = Jargon(
        content=shared_content,
        raw_content='["[1] second"]',
        meaning="",
        session_id_dict='{"session-b": 1}',
        count=1,
        is_jargon=True,
        is_complete=False,
        is_global=False,
        last_inference_count=0,
    )

    with Session(jargon_engine) as db:
        # Both rows go into a single transaction, in the same order as before.
        db.add_all([row_a, row_b])
        db.commit()
        for row in (row_a, row_b):
            db.refresh(row)
        id_a, id_b = row_a.id, row_b.id

    assert id_a is not None
    assert id_b is not None
    assert id_a != id_b

View File

@@ -1,8 +1,9 @@
from typing import Optional
from sqlalchemy import Column, Float, Enum as SQLEnum, DateTime
from sqlmodel import SQLModel, Field, LargeBinary
from enum import Enum
from datetime import datetime
from enum import Enum
from typing import Optional
from sqlalchemy import Column, DateTime, Enum as SQLEnum, Float
from sqlmodel import Field, LargeBinary, SQLModel
class ModelUser(str, Enum):
@@ -172,8 +173,8 @@ class Expression(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
situation: str = Field(index=True, max_length=255, primary_key=True) # 情景
style: str = Field(index=True, max_length=255, primary_key=True) # 风格
situation: str = Field(index=True, max_length=255) # 情景
style: str = Field(index=True, max_length=255) # 风格
# context: str # 上下文
# up_content: str
@@ -200,7 +201,7 @@ class Jargon(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
content: str = Field(index=True, max_length=255, primary_key=True) # 黑话内容
content: str = Field(index=True, max_length=255) # 黑话内容
raw_content: Optional[str] = Field(default=None, nullable=True) # 原始内容未处理的黑话内容为List[str]
meaning: str # 黑话含义

View File

@@ -329,7 +329,13 @@ class ExpressionLearner:
return filtered_expressions
# ====== DB 操作相关 ======
async def _upsert_expression_to_db(self, situation: str, style: str):
async def _upsert_expression_to_db(self, situation: str, style: str) -> None:
"""将表达方式写入数据库,存在时更新,不存在时新增。
Args:
situation: 表达方式对应的使用情景。
style: 表达方式风格。
"""
expr, similarity = self._find_similar_expression(situation) or (None, 0)
if expr:
# 根据相似度决定是否使用 LLM 总结
@@ -340,7 +346,13 @@ class ExpressionLearner:
# 没有找到匹配的记录,创建新记录
self._create_expression(situation, style)
def _create_expression(self, situation: str, style: str):
def _create_expression(self, situation: str, style: str) -> None:
"""创建新的表达方式记录。
Args:
situation: 表达方式对应的使用情景。
style: 表达方式风格。
"""
content_list = [situation]
try:
with get_db_session() as db:
@@ -353,6 +365,7 @@ class ExpressionLearner:
last_active_time=datetime.now(),
)
db.add(new_expr)
db.flush()
except Exception as e:
logger.error(f"创建表达方式失败: {e}")

View File

@@ -1,17 +1,18 @@
from collections import OrderedDict
from json_repair import repair_json
from sqlmodel import select
from typing import List, Optional, Dict, Callable, TypedDict, Set
from typing import Callable, Dict, List, Optional, Set, TypedDict
import asyncio
import json
import random
from src.common.logger import get_logger
from json_repair import repair_json
from sqlmodel import select
from src.common.data_models.jargon_data_model import MaiJargon
from src.common.database.database import get_db_session
from src.common.database.database_model import Jargon
from src.common.data_models.jargon_data_model import MaiJargon
from src.config.config import model_config, global_config
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.prompt.prompt_manager import prompt_manager
@@ -273,11 +274,12 @@ class JargonMiner:
try:
with get_db_session() as session:
session.add(new_jargon)
session.flush()
saved += 1
self._add_to_cache(content)
except Exception as e:
logger.error(f"保存新黑话 '{content}' 失败: {e}")
continue
finally:
self._add_to_cache(content)
# 固定输出提取的jargon结果格式化为可读形式只要有提取结果就输出
if uniq_entries:
# 收集所有提取的jargon内容