Refactor API response models and improve documentation
- Updated response model functions in emoji, expression, jargon, and person routers to include detailed docstrings. - Enhanced the clarity of function signatures by specifying return types. - Removed redundant comments and improved code readability. - Added error handling and logging improvements in various endpoints. - Deleted outdated code documentation file.
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
"""黑话(俚语)管理路由"""
|
||||
|
||||
import json
|
||||
from typing import Annotated, Any, Dict, List, Optional, Set
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlmodel import Session, col, delete, select
|
||||
|
||||
import json
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ChatSession, Jargon
|
||||
from src.common.logger import get_logger
|
||||
@@ -21,9 +22,13 @@ router = APIRouter(prefix="/jargon", tags=["Jargon"], dependencies=[Depends(requ
|
||||
|
||||
|
||||
def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
|
||||
"""
|
||||
解析 chat_id 字段,提取所有 stream_id
|
||||
chat_id 格式: [["stream_id", user_id], ...] 或直接是 stream_id 字符串
|
||||
"""解析聊天 ID 字段并提取所有 stream_id。
|
||||
|
||||
Args:
|
||||
chat_id_str: JSON 格式或纯字符串格式的聊天 ID。
|
||||
|
||||
Returns:
|
||||
List[str]: 解析出的 stream_id 列表。
|
||||
"""
|
||||
if not chat_id_str:
|
||||
return []
|
||||
@@ -43,9 +48,14 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
|
||||
|
||||
|
||||
def get_display_name_for_chat_id(chat_id_str: str, session: Session) -> str:
|
||||
"""
|
||||
获取 chat_id 的显示名称
|
||||
尝试解析 JSON 并查询 ChatSession 表获取群聊名称
|
||||
"""获取聊天 ID 的显示名称。
|
||||
|
||||
Args:
|
||||
chat_id_str: JSON 格式或纯字符串格式的聊天 ID。
|
||||
session: 当前数据库会话。
|
||||
|
||||
Returns:
|
||||
str: 聊天显示名称,无法查询时返回截断后的 stream_id。
|
||||
"""
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
|
||||
|
||||
@@ -175,7 +185,14 @@ class ChatListResponse(BaseModel):
|
||||
|
||||
|
||||
def parse_session_id_dict(session_id_dict_str: Optional[str]) -> Dict[str, int]:
|
||||
"""解析会话计数字典。"""
|
||||
"""解析会话计数字典。
|
||||
|
||||
Args:
|
||||
session_id_dict_str: 数据库中保存的会话计数字典 JSON 字符串。
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: 解析后的会话计数字典。
|
||||
"""
|
||||
if not session_id_dict_str:
|
||||
return {}
|
||||
|
||||
@@ -202,12 +219,26 @@ def parse_session_id_dict(session_id_dict_str: Optional[str]) -> Dict[str, int]:
|
||||
|
||||
|
||||
def dump_session_id_dict(session_counts: Dict[str, int]) -> str:
|
||||
"""序列化会话计数字典。"""
|
||||
"""序列化会话计数字典。
|
||||
|
||||
Args:
|
||||
session_counts: 会话 ID 与出现次数的映射。
|
||||
|
||||
Returns:
|
||||
str: 可写入数据库的 JSON 字符串。
|
||||
"""
|
||||
return json.dumps(session_counts, ensure_ascii=False)
|
||||
|
||||
|
||||
def get_primary_chat_id(session_id_dict_str: Optional[str]) -> str:
|
||||
"""从会话计数字典中选出主聊天 ID。"""
|
||||
"""从会话计数字典中选出主聊天 ID。
|
||||
|
||||
Args:
|
||||
session_id_dict_str: 数据库中保存的会话计数字典 JSON 字符串。
|
||||
|
||||
Returns:
|
||||
str: 出现次数最多的聊天 ID,没有记录时返回空字符串。
|
||||
"""
|
||||
if not (session_counts := parse_session_id_dict(session_id_dict_str)):
|
||||
return ""
|
||||
|
||||
@@ -215,17 +246,41 @@ def get_primary_chat_id(session_id_dict_str: Optional[str]) -> str:
|
||||
|
||||
|
||||
def has_chat_id(session_id_dict_str: Optional[str], chat_id: str) -> bool:
|
||||
"""判断记录是否包含指定聊天 ID。"""
|
||||
"""判断记录是否包含指定聊天 ID。
|
||||
|
||||
Args:
|
||||
session_id_dict_str: 数据库中保存的会话计数字典 JSON 字符串。
|
||||
chat_id: 需要检查的聊天 ID。
|
||||
|
||||
Returns:
|
||||
bool: 记录包含该聊天 ID 时返回 True。
|
||||
"""
|
||||
return chat_id in parse_session_id_dict(session_id_dict_str)
|
||||
|
||||
|
||||
def build_session_id_dict_for_chat(chat_id: str, count: int = 1) -> str:
|
||||
"""为单个聊天 ID 构建会话计数字典。"""
|
||||
"""为单个聊天 ID 构建会话计数字典。
|
||||
|
||||
Args:
|
||||
chat_id: 聊天 ID。
|
||||
count: 该聊天 ID 的出现次数。
|
||||
|
||||
Returns:
|
||||
str: 可写入数据库的会话计数字典 JSON 字符串。
|
||||
"""
|
||||
return dump_session_id_dict({chat_id: count})
|
||||
|
||||
|
||||
def jargon_to_dict(jargon: Jargon, session: Session) -> Dict[str, Any]:
|
||||
"""将 Jargon ORM 对象转换为字典"""
|
||||
"""将黑话模型转换为字典。
|
||||
|
||||
Args:
|
||||
jargon: 数据库中的黑话记录。
|
||||
session: 当前数据库会话,用于查询聊天显示名称。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: WebUI 可直接序列化的黑话数据。
|
||||
"""
|
||||
chat_id = get_primary_chat_id(jargon.session_id_dict)
|
||||
chat_name = get_display_name_for_chat_id(chat_id, session) if chat_id else None
|
||||
|
||||
@@ -255,8 +310,19 @@ async def get_jargon_list(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="按聊天ID筛选"),
|
||||
is_jargon: Optional[bool] = Query(None, description="按是否是黑话筛选"),
|
||||
):
|
||||
"""获取黑话列表"""
|
||||
) -> JargonListResponse:
|
||||
"""获取黑话列表。
|
||||
|
||||
Args:
|
||||
page: 页码,从 1 开始。
|
||||
page_size: 每页数量,范围为 1-100。
|
||||
search: 搜索关键词。
|
||||
chat_id: 聊天 ID 筛选条件。
|
||||
is_jargon: 是否为黑话的筛选条件。
|
||||
|
||||
Returns:
|
||||
JargonListResponse: 分页后的黑话列表。
|
||||
"""
|
||||
try:
|
||||
statement = select(Jargon)
|
||||
|
||||
@@ -304,18 +370,21 @@ async def get_jargon_list(
|
||||
|
||||
|
||||
@router.get("/chats", response_model=ChatListResponse)
|
||||
async def get_chat_list():
|
||||
"""获取所有有黑话记录的聊天列表"""
|
||||
async def get_chat_list() -> ChatListResponse:
|
||||
"""获取所有有黑话记录的聊天列表。
|
||||
|
||||
Returns:
|
||||
ChatListResponse: 包含黑话记录的聊天列表。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
jargons = session.exec(select(Jargon)).all()
|
||||
|
||||
seen_stream_ids: Set[str] = set()
|
||||
for jargon in jargons:
|
||||
seen_stream_ids.update(parse_session_id_dict(jargon.session_id_dict).keys())
|
||||
seen_stream_ids: Set[str] = set()
|
||||
for jargon in jargons:
|
||||
seen_stream_ids.update(parse_session_id_dict(jargon.session_id_dict).keys())
|
||||
|
||||
result = []
|
||||
with get_db_session() as session:
|
||||
result: List[ChatInfoResponse] = []
|
||||
for stream_id in seen_stream_ids:
|
||||
if chat_session := session.exec(
|
||||
select(ChatSession).where(col(ChatSession.session_id) == stream_id)
|
||||
@@ -347,25 +416,29 @@ async def get_chat_list():
|
||||
|
||||
|
||||
@router.get("/stats/summary", response_model=JargonStatsResponse)
|
||||
async def get_jargon_stats():
|
||||
"""获取黑话统计数据"""
|
||||
async def get_jargon_stats() -> JargonStatsResponse:
|
||||
"""获取黑话统计数据。
|
||||
|
||||
Returns:
|
||||
JargonStatsResponse: 黑话总数、确认状态和聊天分布统计。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
jargons = session.exec(select(Jargon)).all()
|
||||
|
||||
total = len(jargons)
|
||||
confirmed_jargon = sum(jargon.is_jargon is True for jargon in jargons)
|
||||
confirmed_not_jargon = sum(jargon.is_jargon is False for jargon in jargons)
|
||||
pending = sum(jargon.is_jargon is None for jargon in jargons)
|
||||
complete_count = sum(jargon.is_complete for jargon in jargons)
|
||||
total = len(jargons)
|
||||
confirmed_jargon = sum(jargon.is_jargon is True for jargon in jargons)
|
||||
confirmed_not_jargon = sum(jargon.is_jargon is False for jargon in jargons)
|
||||
pending = sum(jargon.is_jargon is None for jargon in jargons)
|
||||
complete_count = sum(jargon.is_complete for jargon in jargons)
|
||||
|
||||
top_chats_counter: Dict[str, int] = {}
|
||||
for jargon in jargons:
|
||||
for session_id in parse_session_id_dict(jargon.session_id_dict):
|
||||
top_chats_counter[session_id] = top_chats_counter.get(session_id, 0) + 1
|
||||
top_chats_counter: Dict[str, int] = {}
|
||||
for jargon in jargons:
|
||||
for session_id in parse_session_id_dict(jargon.session_id_dict):
|
||||
top_chats_counter[session_id] = top_chats_counter.get(session_id, 0) + 1
|
||||
|
||||
top_chats_dict = dict(sorted(top_chats_counter.items(), key=lambda item: item[1], reverse=True)[:5])
|
||||
chat_count = len(top_chats_counter)
|
||||
top_chats_dict = dict(sorted(top_chats_counter.items(), key=lambda item: item[1], reverse=True)[:5])
|
||||
chat_count = len(top_chats_counter)
|
||||
|
||||
return JargonStatsResponse(
|
||||
success=True,
|
||||
@@ -386,8 +459,15 @@ async def get_jargon_stats():
|
||||
|
||||
|
||||
@router.get("/{jargon_id}", response_model=JargonDetailResponse)
|
||||
async def get_jargon_detail(jargon_id: int):
|
||||
"""获取黑话详情"""
|
||||
async def get_jargon_detail(jargon_id: int) -> JargonDetailResponse:
|
||||
"""获取黑话详情。
|
||||
|
||||
Args:
|
||||
jargon_id: 黑话记录 ID。
|
||||
|
||||
Returns:
|
||||
JargonDetailResponse: 指定黑话记录的详细信息。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
if not (jargon := session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()):
|
||||
@@ -404,8 +484,15 @@ async def get_jargon_detail(jargon_id: int):
|
||||
|
||||
|
||||
@router.post("/", response_model=JargonCreateResponse)
|
||||
async def create_jargon(request: JargonCreateRequest):
|
||||
"""创建黑话"""
|
||||
async def create_jargon(request: JargonCreateRequest) -> JargonCreateResponse:
|
||||
"""创建黑话。
|
||||
|
||||
Args:
|
||||
request: 创建黑话所需的请求数据。
|
||||
|
||||
Returns:
|
||||
JargonCreateResponse: 创建结果和新黑话数据。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
same_content_jargons = session.exec(select(Jargon).where(col(Jargon.content) == request.content)).all()
|
||||
@@ -441,8 +528,16 @@ async def create_jargon(request: JargonCreateRequest):
|
||||
|
||||
|
||||
@router.patch("/{jargon_id}", response_model=JargonUpdateResponse)
|
||||
async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
|
||||
"""更新黑话(增量更新)"""
|
||||
async def update_jargon(jargon_id: int, request: JargonUpdateRequest) -> JargonUpdateResponse:
|
||||
"""增量更新黑话。
|
||||
|
||||
Args:
|
||||
jargon_id: 黑话记录 ID。
|
||||
request: 只包含需要更新字段的请求数据。
|
||||
|
||||
Returns:
|
||||
JargonUpdateResponse: 更新结果和更新后的黑话数据。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
|
||||
@@ -450,14 +545,16 @@ async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
|
||||
if update_data := request.model_dump(exclude_unset=True):
|
||||
for field, value in update_data.items():
|
||||
if field == "is_global":
|
||||
continue
|
||||
if field == "chat_id":
|
||||
jargon.session_id_dict = build_session_id_dict_for_chat(value, max(jargon.count, 1))
|
||||
continue
|
||||
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
|
||||
setattr(jargon, field, value)
|
||||
if "chat_id" in update_data and update_data["chat_id"] is not None:
|
||||
jargon.session_id_dict = build_session_id_dict_for_chat(update_data["chat_id"], max(jargon.count, 1))
|
||||
if "content" in update_data and update_data["content"] is not None:
|
||||
jargon.content = update_data["content"]
|
||||
if "raw_content" in update_data:
|
||||
jargon.raw_content = update_data["raw_content"]
|
||||
if "meaning" in update_data:
|
||||
jargon.meaning = update_data["meaning"] or ""
|
||||
if "is_jargon" in update_data:
|
||||
jargon.is_jargon = update_data["is_jargon"]
|
||||
session.add(jargon)
|
||||
|
||||
logger.info(f"更新黑话成功: id={jargon_id}")
|
||||
@@ -473,8 +570,15 @@ async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
|
||||
|
||||
|
||||
@router.delete("/{jargon_id}", response_model=JargonDeleteResponse)
|
||||
async def delete_jargon(jargon_id: int):
|
||||
"""删除黑话"""
|
||||
async def delete_jargon(jargon_id: int) -> JargonDeleteResponse:
|
||||
"""删除黑话。
|
||||
|
||||
Args:
|
||||
jargon_id: 黑话记录 ID。
|
||||
|
||||
Returns:
|
||||
JargonDeleteResponse: 删除结果。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
|
||||
@@ -496,8 +600,15 @@ async def delete_jargon(jargon_id: int):
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=JargonDeleteResponse)
|
||||
async def batch_delete_jargons(request: BatchDeleteRequest):
|
||||
"""批量删除黑话"""
|
||||
async def batch_delete_jargons(request: BatchDeleteRequest) -> JargonDeleteResponse:
|
||||
"""批量删除黑话。
|
||||
|
||||
Args:
|
||||
request: 包含要删除黑话 ID 列表的请求。
|
||||
|
||||
Returns:
|
||||
JargonDeleteResponse: 批量删除结果。
|
||||
"""
|
||||
try:
|
||||
if not request.ids:
|
||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||
@@ -525,8 +636,16 @@ async def batch_delete_jargons(request: BatchDeleteRequest):
|
||||
async def batch_set_jargon_status(
|
||||
ids: Annotated[List[int], Query(description="黑话ID列表")],
|
||||
is_jargon: Annotated[bool, Query(description="是否是黑话")],
|
||||
):
|
||||
"""批量设置黑话状态"""
|
||||
) -> JargonUpdateResponse:
|
||||
"""批量设置黑话状态。
|
||||
|
||||
Args:
|
||||
ids: 需要更新状态的黑话 ID 列表。
|
||||
is_jargon: 目标黑话状态。
|
||||
|
||||
Returns:
|
||||
JargonUpdateResponse: 批量更新结果。
|
||||
"""
|
||||
try:
|
||||
if not ids:
|
||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||
|
||||
Reference in New Issue
Block a user