feat:webui支持更加优化的模型配置,优化多处UI体验,支持设置视觉和cache价格,修复多重表达不生效的问题,修复表情包路径错误
This commit is contained in:
@@ -10,7 +10,7 @@ from sqlmodel import col, delete, select
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.database_model import ChatSession, Expression, Messages, ModifiedBy
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
@@ -28,6 +28,7 @@ class ExpressionResponse(BaseModel):
|
||||
style: str
|
||||
last_active_time: float
|
||||
chat_id: str
|
||||
chat_name: Optional[str] = None
|
||||
create_date: Optional[float]
|
||||
checked: bool
|
||||
rejected: bool
|
||||
@@ -90,7 +91,61 @@ class ExpressionCreateResponse(BaseModel):
|
||||
data: ExpressionResponse
|
||||
|
||||
|
||||
def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
def get_chat_name_from_latest_message(chat_id: str, db_session: Any) -> Optional[str]:
|
||||
"""从最近消息中解析聊天显示名称。"""
|
||||
|
||||
statement = (
|
||||
select(Messages)
|
||||
.where(col(Messages.session_id) == chat_id)
|
||||
.order_by(col(Messages.timestamp).desc())
|
||||
.limit(1)
|
||||
)
|
||||
message = db_session.exec(statement).first()
|
||||
if not message:
|
||||
return None
|
||||
if message.group_id:
|
||||
return message.group_name or f"群聊{message.group_id}"
|
||||
return message.user_cardname or message.user_nickname or (f"用户{message.user_id}" if message.user_id else None)
|
||||
|
||||
|
||||
def get_chat_name_from_session_record(chat_session: ChatSession) -> str:
|
||||
"""从会话记录推断兜底显示名称。"""
|
||||
|
||||
if chat_session.group_id:
|
||||
return f"群聊{chat_session.group_id}"
|
||||
if chat_session.user_id:
|
||||
return f"用户{chat_session.user_id}"
|
||||
return chat_session.session_id
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str, db_session: Optional[Any] = None) -> str:
|
||||
"""根据聊天 ID 获取聊天名称。
|
||||
|
||||
Args:
|
||||
chat_id: 聊天会话 ID。
|
||||
db_session: 可选数据库会话,用于从历史消息中解析群名或私聊用户名。
|
||||
|
||||
Returns:
|
||||
str: 聊天显示名称,获取失败时返回原始聊天 ID。
|
||||
"""
|
||||
|
||||
try:
|
||||
if name := _chat_manager.get_session_name(chat_id):
|
||||
return name
|
||||
if db_session and (name := get_chat_name_from_latest_message(chat_id, db_session)):
|
||||
return name
|
||||
session = _chat_manager.get_session_by_session_id(chat_id)
|
||||
if session:
|
||||
if session.group_id:
|
||||
return f"群聊{session.group_id}"
|
||||
if session.user_id:
|
||||
return f"用户{session.user_id}"
|
||||
return chat_id
|
||||
except Exception:
|
||||
return chat_id
|
||||
|
||||
|
||||
def expression_to_response(expression: Expression, db_session: Optional[Any] = None) -> ExpressionResponse:
|
||||
"""将表达方式模型转换为响应对象。
|
||||
|
||||
Args:
|
||||
@@ -101,38 +156,21 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
"""
|
||||
last_active_time = expression.last_active_time.timestamp() if expression.last_active_time else 0.0
|
||||
create_date = expression.create_time.timestamp() if expression.create_time else None
|
||||
chat_id = expression.session_id or ""
|
||||
return ExpressionResponse(
|
||||
id=expression.id if expression.id is not None else 0,
|
||||
situation=expression.situation,
|
||||
style=expression.style,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=expression.session_id or "",
|
||||
chat_id=chat_id,
|
||||
chat_name=get_chat_name(chat_id, db_session) if chat_id else None,
|
||||
create_date=create_date,
|
||||
checked=False,
|
||||
rejected=False,
|
||||
modified_by=None,
|
||||
checked=expression.checked,
|
||||
rejected=expression.rejected,
|
||||
modified_by=expression.modified_by.value if expression.modified_by else None,
|
||||
)
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""根据聊天 ID 获取聊天名称。
|
||||
|
||||
Args:
|
||||
chat_id: 聊天会话 ID。
|
||||
|
||||
Returns:
|
||||
str: 聊天显示名称,获取失败时返回原始聊天 ID。
|
||||
"""
|
||||
try:
|
||||
session = _chat_manager.get_session_by_session_id(chat_id)
|
||||
if not session:
|
||||
return chat_id
|
||||
name = _chat_manager.get_session_name(chat_id)
|
||||
return name or chat_id
|
||||
except Exception:
|
||||
return chat_id
|
||||
|
||||
|
||||
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
|
||||
"""批量获取聊天名称。
|
||||
|
||||
@@ -145,8 +183,7 @@ def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
|
||||
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
|
||||
try:
|
||||
for chat_id in chat_ids:
|
||||
if name := _chat_manager.get_session_name(chat_id):
|
||||
result[chat_id] = name
|
||||
result[chat_id] = get_chat_name(chat_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"批量获取聊天名称失败: {e}")
|
||||
return result
|
||||
@@ -176,19 +213,43 @@ async def get_chat_list() -> ChatListResponse:
|
||||
ChatListResponse: 可用于下拉选择的聊天列表。
|
||||
"""
|
||||
try:
|
||||
chat_list = []
|
||||
chat_by_id: Dict[str, ChatInfo] = {}
|
||||
for session_id, session in _chat_manager.sessions.items():
|
||||
chat_name = _chat_manager.get_session_name(session_id) or session_id
|
||||
chat_list.append(
|
||||
ChatInfo(
|
||||
chat_id=session_id,
|
||||
chat_name=chat_name,
|
||||
platform=session.platform,
|
||||
is_group=session.is_group_session,
|
||||
)
|
||||
chat_by_id[session_id] = ChatInfo(
|
||||
chat_id=session_id,
|
||||
chat_name=chat_name,
|
||||
platform=session.platform,
|
||||
is_group=session.is_group_session,
|
||||
)
|
||||
|
||||
with get_db_session() as session:
|
||||
for chat_session in session.exec(select(ChatSession)).all():
|
||||
if chat_session.session_id in chat_by_id:
|
||||
continue
|
||||
chat_name = get_chat_name_from_latest_message(chat_session.session_id, session)
|
||||
chat_by_id[chat_session.session_id] = ChatInfo(
|
||||
chat_id=chat_session.session_id,
|
||||
chat_name=chat_name or get_chat_name_from_session_record(chat_session),
|
||||
platform=chat_session.platform,
|
||||
is_group=bool(chat_session.group_id),
|
||||
)
|
||||
|
||||
expression_chat_ids = {
|
||||
chat_id for chat_id in session.exec(select(Expression.session_id)).all() if chat_id
|
||||
}
|
||||
for session_id in expression_chat_ids:
|
||||
if session_id in chat_by_id:
|
||||
continue
|
||||
chat_by_id[session_id] = ChatInfo(
|
||||
chat_id=session_id,
|
||||
chat_name=get_chat_name(session_id, session),
|
||||
platform=None,
|
||||
is_group=False,
|
||||
)
|
||||
|
||||
# 按名称排序
|
||||
chat_list = list(chat_by_id.values())
|
||||
chat_list.sort(key=lambda x: x.chat_name)
|
||||
|
||||
return ChatListResponse(success=True, data=chat_list)
|
||||
@@ -252,7 +313,7 @@ async def get_expression_list(
|
||||
if chat_id:
|
||||
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
|
||||
total = len(session.exec(count_statement).all())
|
||||
data = [expression_to_response(expr) for expr in expressions]
|
||||
data = [expression_to_response(expr, session) for expr in expressions]
|
||||
|
||||
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
@@ -281,7 +342,7 @@ async def get_expression_detail(expression_id: int) -> ExpressionDetailResponse:
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
data = expression_to_response(expression)
|
||||
data = expression_to_response(expression, session)
|
||||
|
||||
return ExpressionDetailResponse(success=True, data=data)
|
||||
|
||||
@@ -321,7 +382,7 @@ async def create_expression(
|
||||
session.add(expression)
|
||||
session.flush()
|
||||
expression_id = expression.id
|
||||
data = expression_to_response(expression)
|
||||
data = expression_to_response(expression, session)
|
||||
|
||||
logger.info(f"表达方式已创建: ID={expression_id}, situation={request.situation}")
|
||||
|
||||
@@ -375,7 +436,7 @@ async def update_expression(
|
||||
db_expression.session_id = update_data["session_id"]
|
||||
db_expression.last_active_time = update_data["last_active_time"]
|
||||
session.add(db_expression)
|
||||
data = expression_to_response(db_expression)
|
||||
data = expression_to_response(db_expression, session)
|
||||
|
||||
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
@@ -524,6 +585,22 @@ class ReviewStatsResponse(BaseModel):
|
||||
user_checked: int
|
||||
|
||||
|
||||
def apply_review_filter(statement: Any, filter_type: str) -> Any:
|
||||
"""按审核状态过滤表达方式查询。"""
|
||||
if filter_type == "unchecked":
|
||||
return statement.where(col(Expression.checked).is_(False))
|
||||
if filter_type == "passed":
|
||||
return statement.where(col(Expression.checked).is_(True), col(Expression.rejected).is_(False))
|
||||
if filter_type == "rejected":
|
||||
return statement.where(col(Expression.checked).is_(True), col(Expression.rejected).is_(True))
|
||||
return statement
|
||||
|
||||
|
||||
def count_expressions(session: Any, statement: Any) -> int:
|
||||
"""统计表达方式查询结果数量。"""
|
||||
return len(session.exec(statement).all())
|
||||
|
||||
|
||||
@router.get("/review/stats", response_model=ReviewStatsResponse)
|
||||
async def get_review_stats() -> ReviewStatsResponse:
|
||||
"""获取审核统计数据。
|
||||
@@ -533,12 +610,24 @@ async def get_review_stats() -> ReviewStatsResponse:
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
total = len(session.exec(select(Expression.id)).all())
|
||||
unchecked = 0
|
||||
passed = 0
|
||||
rejected = 0
|
||||
ai_checked = 0
|
||||
user_checked = 0
|
||||
total = count_expressions(session, select(Expression.id))
|
||||
unchecked = count_expressions(session, apply_review_filter(select(Expression.id), "unchecked"))
|
||||
passed = count_expressions(session, apply_review_filter(select(Expression.id), "passed"))
|
||||
rejected = count_expressions(session, apply_review_filter(select(Expression.id), "rejected"))
|
||||
ai_checked = count_expressions(
|
||||
session,
|
||||
select(Expression.id).where(
|
||||
col(Expression.checked).is_(True),
|
||||
col(Expression.modified_by) == ModifiedBy.AI,
|
||||
),
|
||||
)
|
||||
user_checked = count_expressions(
|
||||
session,
|
||||
select(Expression.id).where(
|
||||
col(Expression.checked).is_(True),
|
||||
col(Expression.modified_by) == ModifiedBy.USER,
|
||||
),
|
||||
)
|
||||
|
||||
return ReviewStatsResponse(
|
||||
total=total,
|
||||
@@ -587,10 +676,7 @@ async def get_review_list(
|
||||
ReviewListResponse: 审核列表响应。
|
||||
"""
|
||||
try:
|
||||
statement = select(Expression)
|
||||
|
||||
if filter_type in {"unchecked", "passed", "rejected"}:
|
||||
statement = statement.where(col(Expression.id) == -1)
|
||||
statement = apply_review_filter(select(Expression), filter_type)
|
||||
# all 不需要额外过滤
|
||||
|
||||
# 搜索过滤
|
||||
@@ -615,9 +701,7 @@ async def get_review_list(
|
||||
with get_db_session() as session:
|
||||
expressions = session.exec(statement).all()
|
||||
|
||||
count_statement = select(Expression.id)
|
||||
if filter_type in {"unchecked", "passed", "rejected"}:
|
||||
count_statement = count_statement.where(col(Expression.id) == -1)
|
||||
count_statement = apply_review_filter(select(Expression.id), filter_type)
|
||||
if search:
|
||||
count_statement = count_statement.where(
|
||||
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
|
||||
@@ -625,7 +709,7 @@ async def get_review_list(
|
||||
if chat_id:
|
||||
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
|
||||
total = len(session.exec(count_statement).all())
|
||||
data = [expression_to_response(expr) for expr in expressions]
|
||||
data = [expression_to_response(expr, session) for expr in expressions]
|
||||
|
||||
return ReviewListResponse(
|
||||
success=True,
|
||||
@@ -706,10 +790,10 @@ async def batch_review_expressions(
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
# 冲突检测
|
||||
if item.require_unchecked:
|
||||
# 冲突检测:未审核列表发起的操作只允许处理仍处于未审核状态的条目。
|
||||
if item.require_unchecked and expression.checked:
|
||||
results.append(
|
||||
BatchReviewResultItem(id=item.id, success=False, message="当前模型不支持审核状态过滤")
|
||||
BatchReviewResultItem(id=item.id, success=False, message="该表达方式已被审核,请刷新列表后重试")
|
||||
)
|
||||
failed += 1
|
||||
continue
|
||||
@@ -727,6 +811,9 @@ async def batch_review_expressions(
|
||||
)
|
||||
failed += 1
|
||||
continue
|
||||
db_expression.checked = True
|
||||
db_expression.rejected = item.rejected
|
||||
db_expression.modified_by = ModifiedBy.USER
|
||||
db_expression.last_active_time = datetime.now()
|
||||
session.add(db_expression)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user