feat:webui支持更加优化的模型配置,优化多处UI体验,支持设置视觉和cache价格,修复多重表达不生效的问题,修复表情包路径错误

This commit is contained in:
SengokuCola
2026-05-04 22:52:41 +08:00
parent 14b7bc78a2
commit eea95c1961
38 changed files with 1188 additions and 454 deletions

View File

@@ -10,7 +10,7 @@ from sqlmodel import col, delete, select
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.database.database_model import ChatSession, Expression, Messages, ModifiedBy
from src.common.logger import get_logger
from src.webui.dependencies import require_auth
@@ -28,6 +28,7 @@ class ExpressionResponse(BaseModel):
style: str
last_active_time: float
chat_id: str
chat_name: Optional[str] = None
create_date: Optional[float]
checked: bool
rejected: bool
@@ -90,7 +91,61 @@ class ExpressionCreateResponse(BaseModel):
data: ExpressionResponse
def expression_to_response(expression: Expression) -> ExpressionResponse:
def get_chat_name_from_latest_message(chat_id: str, db_session: Any) -> Optional[str]:
"""从最近消息中解析聊天显示名称。"""
statement = (
select(Messages)
.where(col(Messages.session_id) == chat_id)
.order_by(col(Messages.timestamp).desc())
.limit(1)
)
message = db_session.exec(statement).first()
if not message:
return None
if message.group_id:
return message.group_name or f"群聊{message.group_id}"
return message.user_cardname or message.user_nickname or (f"用户{message.user_id}" if message.user_id else None)
def get_chat_name_from_session_record(chat_session: ChatSession) -> str:
"""从会话记录推断兜底显示名称。"""
if chat_session.group_id:
return f"群聊{chat_session.group_id}"
if chat_session.user_id:
return f"用户{chat_session.user_id}"
return chat_session.session_id
def get_chat_name(chat_id: str, db_session: Optional[Any] = None) -> str:
"""根据聊天 ID 获取聊天名称。
Args:
chat_id: 聊天会话 ID。
db_session: 可选数据库会话,用于从历史消息中解析群名或私聊用户名。
Returns:
str: 聊天显示名称,获取失败时返回原始聊天 ID。
"""
try:
if name := _chat_manager.get_session_name(chat_id):
return name
if db_session and (name := get_chat_name_from_latest_message(chat_id, db_session)):
return name
session = _chat_manager.get_session_by_session_id(chat_id)
if session:
if session.group_id:
return f"群聊{session.group_id}"
if session.user_id:
return f"用户{session.user_id}"
return chat_id
except Exception:
return chat_id
def expression_to_response(expression: Expression, db_session: Optional[Any] = None) -> ExpressionResponse:
"""将表达方式模型转换为响应对象。
Args:
@@ -101,38 +156,21 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
"""
last_active_time = expression.last_active_time.timestamp() if expression.last_active_time else 0.0
create_date = expression.create_time.timestamp() if expression.create_time else None
chat_id = expression.session_id or ""
return ExpressionResponse(
id=expression.id if expression.id is not None else 0,
situation=expression.situation,
style=expression.style,
last_active_time=last_active_time,
chat_id=expression.session_id or "",
chat_id=chat_id,
chat_name=get_chat_name(chat_id, db_session) if chat_id else None,
create_date=create_date,
checked=False,
rejected=False,
modified_by=None,
checked=expression.checked,
rejected=expression.rejected,
modified_by=expression.modified_by.value if expression.modified_by else None,
)
def get_chat_name(chat_id: str) -> str:
"""根据聊天 ID 获取聊天名称。
Args:
chat_id: 聊天会话 ID。
Returns:
str: 聊天显示名称,获取失败时返回原始聊天 ID。
"""
try:
session = _chat_manager.get_session_by_session_id(chat_id)
if not session:
return chat_id
name = _chat_manager.get_session_name(chat_id)
return name or chat_id
except Exception:
return chat_id
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
"""批量获取聊天名称。
@@ -145,8 +183,7 @@ def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
try:
for chat_id in chat_ids:
if name := _chat_manager.get_session_name(chat_id):
result[chat_id] = name
result[chat_id] = get_chat_name(chat_id)
except Exception as e:
logger.warning(f"批量获取聊天名称失败: {e}")
return result
@@ -176,19 +213,43 @@ async def get_chat_list() -> ChatListResponse:
ChatListResponse: 可用于下拉选择的聊天列表。
"""
try:
chat_list = []
chat_by_id: Dict[str, ChatInfo] = {}
for session_id, session in _chat_manager.sessions.items():
chat_name = _chat_manager.get_session_name(session_id) or session_id
chat_list.append(
ChatInfo(
chat_id=session_id,
chat_name=chat_name,
platform=session.platform,
is_group=session.is_group_session,
)
chat_by_id[session_id] = ChatInfo(
chat_id=session_id,
chat_name=chat_name,
platform=session.platform,
is_group=session.is_group_session,
)
with get_db_session() as session:
for chat_session in session.exec(select(ChatSession)).all():
if chat_session.session_id in chat_by_id:
continue
chat_name = get_chat_name_from_latest_message(chat_session.session_id, session)
chat_by_id[chat_session.session_id] = ChatInfo(
chat_id=chat_session.session_id,
chat_name=chat_name or get_chat_name_from_session_record(chat_session),
platform=chat_session.platform,
is_group=bool(chat_session.group_id),
)
expression_chat_ids = {
chat_id for chat_id in session.exec(select(Expression.session_id)).all() if chat_id
}
for session_id in expression_chat_ids:
if session_id in chat_by_id:
continue
chat_by_id[session_id] = ChatInfo(
chat_id=session_id,
chat_name=get_chat_name(session_id, session),
platform=None,
is_group=False,
)
# 按名称排序
chat_list = list(chat_by_id.values())
chat_list.sort(key=lambda x: x.chat_name)
return ChatListResponse(success=True, data=chat_list)
@@ -252,7 +313,7 @@ async def get_expression_list(
if chat_id:
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
total = len(session.exec(count_statement).all())
data = [expression_to_response(expr) for expr in expressions]
data = [expression_to_response(expr, session) for expr in expressions]
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
@@ -281,7 +342,7 @@ async def get_expression_detail(expression_id: int) -> ExpressionDetailResponse:
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
data = expression_to_response(expression)
data = expression_to_response(expression, session)
return ExpressionDetailResponse(success=True, data=data)
@@ -321,7 +382,7 @@ async def create_expression(
session.add(expression)
session.flush()
expression_id = expression.id
data = expression_to_response(expression)
data = expression_to_response(expression, session)
logger.info(f"表达方式已创建: ID={expression_id}, situation={request.situation}")
@@ -375,7 +436,7 @@ async def update_expression(
db_expression.session_id = update_data["session_id"]
db_expression.last_active_time = update_data["last_active_time"]
session.add(db_expression)
data = expression_to_response(db_expression)
data = expression_to_response(db_expression, session)
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
@@ -524,6 +585,22 @@ class ReviewStatsResponse(BaseModel):
user_checked: int
def apply_review_filter(statement: Any, filter_type: str) -> Any:
"""按审核状态过滤表达方式查询。"""
if filter_type == "unchecked":
return statement.where(col(Expression.checked).is_(False))
if filter_type == "passed":
return statement.where(col(Expression.checked).is_(True), col(Expression.rejected).is_(False))
if filter_type == "rejected":
return statement.where(col(Expression.checked).is_(True), col(Expression.rejected).is_(True))
return statement
def count_expressions(session: Any, statement: Any) -> int:
"""统计表达方式查询结果数量。"""
return len(session.exec(statement).all())
@router.get("/review/stats", response_model=ReviewStatsResponse)
async def get_review_stats() -> ReviewStatsResponse:
"""获取审核统计数据。
@@ -533,12 +610,24 @@ async def get_review_stats() -> ReviewStatsResponse:
"""
try:
with get_db_session() as session:
total = len(session.exec(select(Expression.id)).all())
unchecked = 0
passed = 0
rejected = 0
ai_checked = 0
user_checked = 0
total = count_expressions(session, select(Expression.id))
unchecked = count_expressions(session, apply_review_filter(select(Expression.id), "unchecked"))
passed = count_expressions(session, apply_review_filter(select(Expression.id), "passed"))
rejected = count_expressions(session, apply_review_filter(select(Expression.id), "rejected"))
ai_checked = count_expressions(
session,
select(Expression.id).where(
col(Expression.checked).is_(True),
col(Expression.modified_by) == ModifiedBy.AI,
),
)
user_checked = count_expressions(
session,
select(Expression.id).where(
col(Expression.checked).is_(True),
col(Expression.modified_by) == ModifiedBy.USER,
),
)
return ReviewStatsResponse(
total=total,
@@ -587,10 +676,7 @@ async def get_review_list(
ReviewListResponse: 审核列表响应。
"""
try:
statement = select(Expression)
if filter_type in {"unchecked", "passed", "rejected"}:
statement = statement.where(col(Expression.id) == -1)
statement = apply_review_filter(select(Expression), filter_type)
# all 不需要额外过滤
# 搜索过滤
@@ -615,9 +701,7 @@ async def get_review_list(
with get_db_session() as session:
expressions = session.exec(statement).all()
count_statement = select(Expression.id)
if filter_type in {"unchecked", "passed", "rejected"}:
count_statement = count_statement.where(col(Expression.id) == -1)
count_statement = apply_review_filter(select(Expression.id), filter_type)
if search:
count_statement = count_statement.where(
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
@@ -625,7 +709,7 @@ async def get_review_list(
if chat_id:
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
total = len(session.exec(count_statement).all())
data = [expression_to_response(expr) for expr in expressions]
data = [expression_to_response(expr, session) for expr in expressions]
return ReviewListResponse(
success=True,
@@ -706,10 +790,10 @@ async def batch_review_expressions(
failed += 1
continue
# 冲突检测
if item.require_unchecked:
# 冲突检测:未审核列表发起的操作只允许处理仍处于未审核状态的条目。
if item.require_unchecked and expression.checked:
results.append(
BatchReviewResultItem(id=item.id, success=False, message="当前模型不支持审核状态过滤")
BatchReviewResultItem(id=item.id, success=False, message="该表达方式已被审核,请刷新列表后重试")
)
failed += 1
continue
@@ -727,6 +811,9 @@ async def batch_review_expressions(
)
failed += 1
continue
db_expression.checked = True
db_expression.rejected = item.rejected
db_expression.modified_by = ModifiedBy.USER
db_expression.last_active_time = datetime.now()
session.add(db_expression)