重构绝大部分模块以适配新版本的数据库和数据模型,修复缺少依赖问题,更新 pyproject

This commit is contained in:
DrSmoothl
2026-02-13 20:39:11 +08:00
parent c14736ffca
commit 16b16d2ca6
29 changed files with 2459 additions and 1737 deletions

View File

@@ -3,11 +3,16 @@
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel
from typing import Optional, List, Dict
from sqlalchemy import case
from datetime import datetime, timedelta
from sqlalchemy import case, func
from sqlmodel import col, select, delete
from src.common.logger import get_logger
from src.common.database.database_model import Expression, ChatStreams
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.chat.message_receive.chat_stream import get_chat_manager
from src.webui.core import verify_auth_token_from_cookie_or_header
import time
logger = get_logger("webui.expression")
@@ -98,30 +103,32 @@ def verify_auth_token(
def expression_to_response(expression: Expression) -> ExpressionResponse:
"""将 Expression 模型转换为响应对象"""
last_active_time = expression.last_active_time.timestamp() if expression.last_active_time else 0.0
create_date = expression.create_time.timestamp() if expression.create_time else None
return ExpressionResponse(
id=expression.id,
id=expression.id if expression.id is not None else 0,
situation=expression.situation,
style=expression.style,
last_active_time=expression.last_active_time,
chat_id=expression.chat_id,
create_date=expression.create_date,
checked=expression.checked,
rejected=expression.rejected,
modified_by=expression.modified_by,
last_active_time=last_active_time,
chat_id=expression.session_id or "",
create_date=create_date,
checked=False,
rejected=False,
modified_by=None,
)
def get_chat_name(chat_id: str) -> str:
"""根据 chat_id 获取聊天名称"""
try:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream:
# 优先使用群聊名称,否则使用用户昵称
if chat_stream.group_name:
return chat_stream.group_name
elif chat_stream.user_nickname:
return chat_stream.user_nickname
return chat_id # 找不到时返回原始ID
chat_stream = get_chat_manager().get_stream(chat_id)
if not chat_stream:
return chat_id
if chat_stream.group_info and chat_stream.group_info.group_name:
return chat_stream.group_info.group_name
if chat_stream.user_info and chat_stream.user_info.user_nickname:
return chat_stream.user_info.user_nickname
return chat_id
except Exception:
return chat_id
@@ -130,12 +137,15 @@ def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
"""批量获取聊天名称"""
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
try:
chat_streams = ChatStreams.select().where(ChatStreams.stream_id.in_(chat_ids))
for cs in chat_streams:
if cs.group_name:
result[cs.stream_id] = cs.group_name
elif cs.user_nickname:
result[cs.stream_id] = cs.user_nickname
chat_manager = get_chat_manager()
for chat_id in chat_ids:
chat_stream = chat_manager.get_stream(chat_id)
if not chat_stream:
continue
if chat_stream.group_info and chat_stream.group_info.group_name:
result[chat_id] = chat_stream.group_info.group_name
elif chat_stream.user_info and chat_stream.user_info.user_nickname:
result[chat_id] = chat_stream.user_info.user_nickname
except Exception as e:
logger.warning(f"批量获取聊天名称失败: {e}")
return result
@@ -172,14 +182,17 @@ async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorizat
verify_auth_token(maibot_session, authorization)
chat_list = []
for cs in ChatStreams.select():
chat_name = cs.group_name if cs.group_name else (cs.user_nickname if cs.user_nickname else cs.stream_id)
for stream_id, stream in get_chat_manager().streams.items():
chat_name = stream.group_info.group_name if stream.group_info and stream.group_info.group_name else None
if not chat_name and stream.user_info and stream.user_info.user_nickname:
chat_name = stream.user_info.user_nickname
chat_name = chat_name or stream_id
chat_list.append(
ChatInfo(
chat_id=cs.stream_id,
chat_id=stream_id,
chat_name=chat_name,
platform=cs.platform,
is_group=bool(cs.group_id),
platform=stream.platform,
is_group=bool(stream.group_info and stream.group_info.group_id),
)
)
@@ -221,29 +234,39 @@ async def get_expression_list(
verify_auth_token(maibot_session, authorization)
# 构建查询
query = Expression.select()
statement = select(Expression)
# 搜索过滤
if search:
query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search)))
statement = statement.where(
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
)
# 聊天ID过滤
if chat_id:
query = query.where(Expression.chat_id == chat_id)
statement = statement.where(col(Expression.session_id) == chat_id)
# 排序最后活跃时间倒序NULL 值放在最后)
query = query.order_by(
case((Expression.last_active_time.is_null(), 1), else_=0), Expression.last_active_time.desc()
statement = statement.order_by(
case((col(Expression.last_active_time).is_(None), 1), else_=0),
col(Expression.last_active_time).desc(),
)
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
expressions = query.offset(offset).limit(page_size)
statement = statement.offset(offset).limit(page_size)
with get_db_session() as session:
expressions = session.exec(statement).all()
count_statement = select(Expression.id)
if search:
count_statement = count_statement.where(
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
)
if chat_id:
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
total = len(session.exec(count_statement).all())
# 转换为响应对象
data = [expression_to_response(expr) for expr in expressions]
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
@@ -272,7 +295,9 @@ async def get_expression_detail(
try:
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
with get_db_session() as session:
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
expression = session.exec(statement).first()
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
@@ -305,16 +330,22 @@ async def create_expression(
try:
verify_auth_token(maibot_session, authorization)
current_time = time.time()
current_time = datetime.now()
# 创建表达方式
expression = Expression.create(
situation=request.situation,
style=request.style,
chat_id=request.chat_id,
last_active_time=current_time,
create_date=current_time,
)
with get_db_session() as session:
expression = Expression(
situation=request.situation,
style=request.style,
context="",
up_content="",
content_list="[]",
count=0,
last_active_time=current_time,
create_time=current_time,
session_id=request.chat_id,
)
session.add(expression)
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
@@ -350,16 +381,18 @@ async def update_expression(
try:
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
with get_db_session() as session:
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
expression = session.exec(statement).first()
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
# 冲突检测:如果要求未检查状态,但已经被检查了
if request.require_unchecked and expression.checked:
if request.require_unchecked and getattr(expression, "checked", False):
raise HTTPException(
status_code=409,
detail=f"此表达方式已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查,请刷新列表",
detail=f"此表达方式已被{'AI自动' if getattr(expression, 'modified_by', None) == 'ai' else '人工'}检查,请刷新列表",
)
# 只更新提供的字段
@@ -376,13 +409,18 @@ async def update_expression(
update_data["modified_by"] = "user"
# 更新最后活跃时间
update_data["last_active_time"] = time.time()
update_data["last_active_time"] = datetime.now()
# 执行更新
for field, value in update_data.items():
setattr(expression, field, value)
expression.save()
with get_db_session() as session:
db_expression = session.exec(select(Expression).where(col(Expression.id) == expression_id).limit(1)).first()
if not db_expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
for field, value in update_data.items():
if hasattr(db_expression, field):
setattr(db_expression, field, value)
session.add(db_expression)
expression = db_expression
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
@@ -414,7 +452,9 @@ async def delete_expression(
try:
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
with get_db_session() as session:
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
expression = session.exec(statement).first()
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
@@ -423,7 +463,8 @@ async def delete_expression(
situation = expression.situation
# 执行删除
expression.delete_instance()
with get_db_session() as session:
session.exec(delete(Expression).where(col(Expression.id) == expression_id))
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
@@ -465,8 +506,9 @@ async def batch_delete_expressions(
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
# 查找所有要删除的表达方式
expressions = Expression.select().where(Expression.id.in_(request.ids))
found_ids = [expr.id for expr in expressions]
with get_db_session() as session:
statements = select(Expression.id).where(col(Expression.id).in_(request.ids))
found_ids = [expr_id for expr_id in session.exec(statements).all()]
# 检查是否有未找到的ID
not_found_ids = set(request.ids) - set(found_ids)
@@ -474,7 +516,9 @@ async def batch_delete_expressions(
logger.warning(f"部分表达方式未找到: {not_found_ids}")
# 执行批量删除
deleted_count = Expression.delete().where(Expression.id.in_(found_ids)).execute()
with get_db_session() as session:
result = session.exec(delete(Expression).where(col(Expression.id).in_(found_ids)))
deleted_count = result.rowcount or 0
logger.info(f"批量删除了 {deleted_count} 个表达方式")
@@ -503,21 +547,21 @@ async def get_expression_stats(
try:
verify_auth_token(maibot_session, authorization)
total = Expression.select().count()
with get_db_session() as session:
total = len(session.exec(select(Expression.id)).all())
# 按 chat_id 统计
chat_stats = {}
for expr in Expression.select(Expression.chat_id):
chat_id = expr.chat_id
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
chat_stats = {}
for chat_id in session.exec(select(Expression.session_id)).all():
if chat_id:
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
# 获取最近创建的记录数7天内
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
recent = (
Expression.select()
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
.count()
)
seven_days_ago = datetime.now() - timedelta(days=7)
recent_statement = (
select(func.count())
.select_from(Expression)
.where(col(Expression.create_time).is_not(None), col(Expression.create_time) >= seven_days_ago)
)
recent = session.exec(recent_statement).one()
return {
"success": True,
@@ -561,12 +605,13 @@ async def get_review_stats(maibot_session: Optional[str] = Cookie(None), authori
try:
verify_auth_token(maibot_session, authorization)
total = Expression.select().count()
unchecked = Expression.select().where(Expression.checked == False).count()
passed = Expression.select().where((Expression.checked == True) & (Expression.rejected == False)).count()
rejected = Expression.select().where((Expression.checked == True) & (Expression.rejected == True)).count()
ai_checked = Expression.select().where(Expression.modified_by == "ai").count()
user_checked = Expression.select().where(Expression.modified_by == "user").count()
with get_db_session() as session:
total = len(session.exec(select(Expression.id)).all())
unchecked = 0
passed = 0
rejected = 0
ai_checked = 0
user_checked = 0
return ReviewStatsResponse(
total=total,
@@ -620,31 +665,44 @@ async def get_review_list(
try:
verify_auth_token(maibot_session, authorization)
query = Expression.select()
statement = select(Expression)
# 根据筛选类型过滤
if filter_type == "unchecked":
query = query.where(Expression.checked == False)
elif filter_type == "passed":
query = query.where((Expression.checked == True) & (Expression.rejected == False))
elif filter_type == "rejected":
query = query.where((Expression.checked == True) & (Expression.rejected == True))
if filter_type in {"unchecked", "passed", "rejected"}:
statement = statement.where(col(Expression.id) == -1)
# all 不需要额外过滤
# 搜索过滤
if search:
query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search)))
statement = statement.where(
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
)
# 聊天ID过滤
if chat_id:
query = query.where(Expression.chat_id == chat_id)
statement = statement.where(col(Expression.session_id) == chat_id)
# 排序:创建时间倒序
query = query.order_by(case((Expression.create_date.is_null(), 1), else_=0), Expression.create_date.desc())
statement = statement.order_by(
case((col(Expression.create_time).is_(None), 1), else_=0),
col(Expression.create_time).desc(),
)
total = query.count()
offset = (page - 1) * page_size
expressions = query.offset(offset).limit(page_size)
statement = statement.offset(offset).limit(page_size)
with get_db_session() as session:
expressions = session.exec(statement).all()
count_statement = select(Expression.id)
if filter_type in {"unchecked", "passed", "rejected"}:
count_statement = count_statement.where(col(Expression.id) == -1)
if search:
count_statement = count_statement.where(
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
)
if chat_id:
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
total = len(session.exec(count_statement).all())
return ReviewListResponse(
success=True,
@@ -720,7 +778,8 @@ async def batch_review_expressions(
for item in request.items:
try:
expression = Expression.get_or_none(Expression.id == item.id)
with get_db_session() as session:
expression = session.exec(select(Expression).where(col(Expression.id) == item.id).limit(1)).first()
if not expression:
results.append(
@@ -730,23 +789,28 @@ async def batch_review_expressions(
continue
# 冲突检测
if item.require_unchecked and expression.checked:
if item.require_unchecked:
results.append(
BatchReviewResultItem(
id=item.id,
success=False,
message=f"已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查",
)
BatchReviewResultItem(id=item.id, success=False, message="当前模型不支持审核状态过滤")
)
failed += 1
continue
# 更新状态
expression.checked = True
expression.rejected = item.rejected
expression.modified_by = "user"
expression.last_active_time = time.time()
expression.save()
with get_db_session() as session:
db_expression = session.exec(
select(Expression).where(col(Expression.id) == item.id).limit(1)
).first()
if not db_expression:
results.append(
BatchReviewResultItem(
id=item.id, success=False, message=f"未找到 ID 为 {item.id} 的表达方式"
)
)
failed += 1
continue
db_expression.last_active_time = datetime.now()
session.add(db_expression)
results.append(
BatchReviewResultItem(id=item.id, success=True, message="通过" if not item.rejected else "拒绝")