合并到远程
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
from rich.traceback import install
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from sqlalchemy import create_engine, event, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy import inspect as sqlalchemy_inspect
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlmodel import create_engine, Session
|
||||
from typing import TYPE_CHECKING, Generator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -27,19 +27,12 @@ DATABASE_URL = f"sqlite:///{_DB_FILE}"
|
||||
def set_sqlite_pragma(dbapi_connection: "SQLite3Connection", connection_record):
|
||||
"""
|
||||
为每个新的数据库连接设置 SQLite PRAGMA。
|
||||
|
||||
这些设置优化了并发性能和数据安全性:
|
||||
- journal_mode=WAL: 启用预写式日志,提高并发性能
|
||||
- cache_size: 设置缓存大小为 64MB
|
||||
- foreign_keys: 启用外键约束
|
||||
- synchronous=NORMAL: 平衡性能和数据安全
|
||||
- busy_timeout: 设置1秒超时,避免锁定冲突
|
||||
"""
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.execute("PRAGMA cache_size=-64000") # 负值表示KB,64000KB = 64MB
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL") # NORMAL 模式在WAL下是安全的
|
||||
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||
cursor.execute("PRAGMA busy_timeout=1000") # 1秒超时
|
||||
cursor.close()
|
||||
|
||||
@@ -52,11 +45,12 @@ engine = create_engine(
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
# 创建会话工厂(使用 sqlmodel.Session)
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=engine,
|
||||
class_=Session,
|
||||
)
|
||||
|
||||
|
||||
@@ -96,7 +90,6 @@ def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||
session = SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
# 如果启用自动提交且没有异常,则提交事务
|
||||
if auto_commit:
|
||||
session.commit()
|
||||
except Exception:
|
||||
@@ -132,59 +125,3 @@ def get_db() -> Generator[Session, None, None]:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
class _AtomicContext:
|
||||
def __init__(self) -> None:
|
||||
self._session: Session | None = None
|
||||
|
||||
def __enter__(self) -> Session:
|
||||
self._session = SessionLocal()
|
||||
self._session.begin()
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
if self._session is None:
|
||||
return
|
||||
try:
|
||||
if exc_type is None:
|
||||
self._session.commit()
|
||||
else:
|
||||
self._session.rollback()
|
||||
finally:
|
||||
self._session.close()
|
||||
|
||||
|
||||
class DatabaseCompat:
|
||||
"""兼容旧 db 调用接口(Peewee 风格),底层使用 SQLAlchemy。"""
|
||||
|
||||
def connect(self, reuse_if_open: bool = True) -> None:
|
||||
# SQLAlchemy 由 engine 按需管理连接,这里保留兼容入口。
|
||||
_ = reuse_if_open
|
||||
|
||||
def create_tables(self, models: list[type], safe: bool = True) -> None:
|
||||
_ = safe
|
||||
tables = [model.__table__ for model in models if hasattr(model, "__table__")]
|
||||
if not tables:
|
||||
return
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
SQLModel.metadata.create_all(engine, tables=tables)
|
||||
|
||||
def atomic(self) -> _AtomicContext:
|
||||
return _AtomicContext()
|
||||
|
||||
def execute_sql(self, sql: str):
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(text(sql))
|
||||
conn.commit()
|
||||
return result
|
||||
|
||||
def table_exists(self, model: type) -> bool:
|
||||
if not hasattr(model, "__tablename__"):
|
||||
return False
|
||||
inspector = sqlalchemy_inspect(engine)
|
||||
return inspector.has_table(model.__tablename__)
|
||||
|
||||
|
||||
db = DatabaseCompat()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
from sqlalchemy import Column, Float, Enum as SQLEnum
|
||||
from sqlmodel import SQLModel, Field
|
||||
from sqlmodel import SQLModel, Field, LargeBinary
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
|
||||
@@ -45,8 +45,8 @@ class Messages(SQLModel, table=True):
|
||||
is_notify: bool = Field(default=False) # 是否为通知消息
|
||||
|
||||
# 消息内容
|
||||
raw_content: str # base64编码的原始消息内容
|
||||
processed_plain_text: str = Field(index=True) # 平面化处理后的纯文本消息
|
||||
raw_content: bytes = Field(sa_column=Column(LargeBinary)) # base64编码的原始消息内容
|
||||
processed_plain_text: str = Field() # 平面化处理后的纯文本消息
|
||||
display_message: str # 显示的消息内容(被放入Prompt)
|
||||
|
||||
# 其他配置
|
||||
@@ -85,9 +85,9 @@ class Images(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
||||
|
||||
# 元信息
|
||||
image_hash: str = Field(default="", max_length=255) # 图片哈希,使用sha256哈希值,亦作为图片唯一ID
|
||||
image_hash: str = Field(index=True, max_length=255) # 图片哈希,使用sha256哈希值,亦作为图片唯一ID
|
||||
description: str # 图片的描述
|
||||
full_path: str = Field(index=True, max_length=1024) # 文件的完整路径 (包括文件名)
|
||||
full_path: str = Field(max_length=1024) # 文件的完整路径 (包括文件名)
|
||||
image_type: ImageType = Field(sa_column=Column(SQLEnum(ImageType)), default=ImageType.EMOJI)
|
||||
"""图片类型,例如 'emoji' 或 'image'"""
|
||||
emotion: Optional[str] = Field(default=None, nullable=True) # 表情包的情感标签,逗号分隔
|
||||
@@ -116,7 +116,7 @@ class ActionRecord(SQLModel, table=True):
|
||||
session_id: str = Field(index=True, max_length=255) # 对应的 ChatSession session_id
|
||||
|
||||
# 调用信息
|
||||
action_name: str = Field(max_length=255) # 动作名称
|
||||
action_name: str = Field(index=True, max_length=255) # 动作名称
|
||||
action_reasoning: Optional[str] = Field(default=None) # 动作推理过程
|
||||
action_data: Optional[str] = Field(default=None) # 动作数据,JSON格式存储
|
||||
|
||||
@@ -153,7 +153,7 @@ class OnlineTime(SQLModel, table=True):
|
||||
timestamp: datetime = Field(default_factory=datetime.now, index=True) # 时间戳
|
||||
duration_minutes: int = Field() # 时长,单位秒
|
||||
start_timestamp: datetime = Field(default_factory=datetime.now) # 上线时间
|
||||
end_timestamp: datetime = Field(index=True) # 下线时间
|
||||
end_timestamp: datetime = Field() # 下线时间
|
||||
|
||||
|
||||
class Expression(SQLModel, table=True):
|
||||
@@ -230,7 +230,68 @@ class ThinkingQuestion(SQLModel, table=True):
|
||||
context: Optional[str] = Field(default=None, nullable=True) # 上下文
|
||||
found_answer: bool = Field(default=False) # 是否找到答案
|
||||
answer: Optional[str] = Field(default=None, nullable=True) # 问题答案
|
||||
|
||||
|
||||
thinking_steps: Optional[str] = Field(default=None, nullable=True) # 思考步骤,JSON格式存储
|
||||
created_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 创建时间
|
||||
updated_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 最后更新时间
|
||||
|
||||
|
||||
class BinaryData(SQLModel, table=True):
|
||||
"""存储二进制数据的模型"""
|
||||
|
||||
__tablename__ = "binary_data" # type: ignore
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
||||
|
||||
data_hash: str = Field(index=True, max_length=255) # 数据哈希,使用sha256哈希值,亦作为数据唯一ID
|
||||
full_path: str = Field(max_length=1024) # 文件的完整路径 (包括文件名)
|
||||
|
||||
|
||||
class PersonInfo(SQLModel, table=True):
|
||||
"""存储个人信息的模型"""
|
||||
|
||||
__tablename__ = "person_info" # type: ignore
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
||||
|
||||
is_known: bool = Field(default=False) # 是否为已知人
|
||||
person_id: str = Field(unique=True, index=True, max_length=255) # 人员ID
|
||||
person_name: Optional[str] = Field(default=None, max_length=255, nullable=True) # 人员名称
|
||||
name_reason: Optional[str] = Field(default=None, nullable=True) # 名称原因
|
||||
|
||||
# 身份元数据
|
||||
platform: str = Field(index=True, max_length=100) # 平台名称
|
||||
user_id: str = Field(index=True, max_length=255) # 用户ID
|
||||
user_nickname: str = Field(index=True, max_length=255) # 用户昵称
|
||||
group_nickname: Optional[str] = Field(
|
||||
default=None, nullable=True
|
||||
) # 群昵称 (JSON, [{"group_id": str, "group_nick_name": str}])
|
||||
|
||||
# 印象
|
||||
memory_points: Optional[str] = Field(default=None, nullable=True) # 记忆要点,JSON格式存储
|
||||
|
||||
# 认识次数和时间
|
||||
know_counts: int = Field(default=0) # 认识次数
|
||||
first_known_time: Optional[datetime] = Field(default=None, nullable=True) # 首次认识时间
|
||||
last_known_time: Optional[datetime] = Field(default=None, nullable=True) # 最后认识时间
|
||||
|
||||
|
||||
class ChatSession(SQLModel, table=True):
|
||||
"""存储聊天会话的模型"""
|
||||
|
||||
__tablename__ = "chat_sessions" # type: ignore
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
||||
|
||||
session_id: str = Field(unique=True, index=True, max_length=255) # 聊天会话ID
|
||||
|
||||
created_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 创建时间
|
||||
last_active_timestamp: datetime = Field(default_factory=datetime.now, index=True) # 最后活跃时间
|
||||
|
||||
# 身份元数据
|
||||
user_id: str = Field(index=True, max_length=255) # 用户ID
|
||||
user_nickname: str = Field(index=True, max_length=255) # 用户昵称
|
||||
user_cardname: Optional[str] = Field(default=None, max_length=255, nullable=True) # 用户备注名
|
||||
group_id: Optional[str] = Field(index=True, default=None, max_length=255, nullable=True) # 群组id
|
||||
group_name: Optional[str] = Field(index=True, default=None, max_length=255, nullable=True) # 群组名称
|
||||
platform: str = Field(index=True, max_length=100) # 用户平台
|
||||
|
||||
Reference in New Issue
Block a user