Merge branch 'dev' of https://github.com/Mai-with-u/MaiBot into dev
This commit is contained in:
9
Plan.md
9
Plan.md
@@ -1,9 +0,0 @@
|
||||
Context 在消息接收的时候就进行解析,不再放到 MaiMessage 里面,由消息注册的时候直接进去注册
|
||||
- [ ] 实现`update_chat_context`方法,主要关注`format_info`
|
||||
|
||||
|
||||
1. **预计不对发送的时候进行`accept_format`的格式判断**,希望所有消息适配器接收的时候做一下不兼容内容主动丢弃
|
||||
2. 在发送消息的时候进行`accept_format`的判断,判断不兼容内容是否存在,如果存在则丢弃掉
|
||||
|
||||
- [ ] 实现 status_api
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
"""表情包管理 API 路由"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Cookie, HTTPException, Query
|
||||
@@ -15,6 +10,12 @@ from PIL import Image
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import col, select
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.webui.core import get_token_manager
|
||||
@@ -57,7 +58,15 @@ router = APIRouter(prefix="/emoji", tags=["Emoji"])
|
||||
|
||||
|
||||
def _normalize_emoji_description(description: str = "", emotion: str = "") -> str:
|
||||
"""将上传参数中的描述/情绪标签归一化为可存储 description。"""
|
||||
"""将上传参数中的描述或情绪标签归一化为可存储描述。
|
||||
|
||||
Args:
|
||||
description: 用户输入的表情包描述。
|
||||
emotion: 用户输入的情绪标签。
|
||||
|
||||
Returns:
|
||||
str: 归一化后的描述字符串。
|
||||
"""
|
||||
normalized_description = str(description or "").strip()
|
||||
normalized_emotion = str(emotion or "").strip()
|
||||
if normalized_description:
|
||||
@@ -80,7 +89,21 @@ async def get_emoji_list(
|
||||
sort_order: Optional[str] = Query("desc", description="排序方向"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
) -> EmojiListResponse:
|
||||
"""获取表情包列表。"""
|
||||
"""获取表情包列表。
|
||||
|
||||
Args:
|
||||
page: 页码,从 1 开始。
|
||||
page_size: 每页数量,范围为 1-100。
|
||||
search: 搜索关键词,用于匹配描述或哈希。
|
||||
is_registered: 是否已注册筛选条件。
|
||||
is_banned: 是否被禁用筛选条件。
|
||||
sort_by: 排序字段。
|
||||
sort_order: 排序方向。
|
||||
maibot_session: WebUI 登录会话 Cookie。
|
||||
|
||||
Returns:
|
||||
EmojiListResponse: 分页后的表情包列表。
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session)
|
||||
|
||||
@@ -123,13 +146,14 @@ async def get_emoji_list(
|
||||
if is_banned is not None:
|
||||
count_statement = count_statement.where(col(Images.is_banned) == is_banned)
|
||||
total = session.exec(count_statement).one()
|
||||
data = [emoji_to_response(emoji) for emoji in emojis]
|
||||
|
||||
return EmojiListResponse(
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
data=[emoji_to_response(emoji) for emoji in emojis],
|
||||
data=data,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -140,7 +164,15 @@ async def get_emoji_list(
|
||||
|
||||
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
|
||||
async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie(None)) -> EmojiDetailResponse:
|
||||
"""获取表情包详细信息。"""
|
||||
"""获取表情包详细信息。
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包 ID。
|
||||
maibot_session: WebUI 登录会话 Cookie。
|
||||
|
||||
Returns:
|
||||
EmojiDetailResponse: 表情包详细信息。
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session)
|
||||
|
||||
@@ -166,7 +198,16 @@ async def update_emoji(
|
||||
request: EmojiUpdateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
) -> EmojiUpdateResponse:
|
||||
"""增量更新表情包。"""
|
||||
"""增量更新表情包。
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包 ID。
|
||||
request: 只包含需要更新字段的请求数据。
|
||||
maibot_session: WebUI 登录会话 Cookie。
|
||||
|
||||
Returns:
|
||||
EmojiUpdateResponse: 更新结果和更新后的表情包数据。
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session)
|
||||
|
||||
@@ -195,8 +236,14 @@ async def update_emoji(
|
||||
update_data["description"] = normalized_description
|
||||
update_data.pop("emotion", None)
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(emoji, field, value)
|
||||
if "description" in update_data:
|
||||
emoji.description = update_data["description"]
|
||||
if "is_registered" in update_data:
|
||||
emoji.is_registered = update_data["is_registered"]
|
||||
if "is_banned" in update_data:
|
||||
emoji.is_banned = update_data["is_banned"]
|
||||
if "register_time" in update_data:
|
||||
emoji.register_time = update_data["register_time"]
|
||||
|
||||
session.add(emoji)
|
||||
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
|
||||
@@ -215,7 +262,15 @@ async def update_emoji(
|
||||
|
||||
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
|
||||
async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None)) -> EmojiDeleteResponse:
|
||||
"""删除表情包。"""
|
||||
"""删除表情包。
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包 ID。
|
||||
maibot_session: WebUI 登录会话 Cookie。
|
||||
|
||||
Returns:
|
||||
EmojiDeleteResponse: 删除结果。
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session)
|
||||
|
||||
@@ -242,7 +297,14 @@ async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(Non
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
"""获取表情包统计数据。"""
|
||||
"""获取表情包统计数据。
|
||||
|
||||
Args:
|
||||
maibot_session: WebUI 登录会话 Cookie。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 表情包总数、格式分布和高频使用统计。
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session)
|
||||
|
||||
@@ -312,7 +374,15 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None)) -> Dict[
|
||||
|
||||
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
|
||||
async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None)) -> EmojiUpdateResponse:
|
||||
"""注册表情包。"""
|
||||
"""注册表情包。
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包 ID。
|
||||
maibot_session: WebUI 登录会话 Cookie。
|
||||
|
||||
Returns:
|
||||
EmojiUpdateResponse: 注册结果和更新后的表情包数据。
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session)
|
||||
|
||||
@@ -344,7 +414,15 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N
|
||||
|
||||
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
|
||||
async def ban_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None)) -> EmojiUpdateResponse:
|
||||
"""禁用表情包。"""
|
||||
"""禁用表情包。
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包 ID。
|
||||
maibot_session: WebUI 登录会话 Cookie。
|
||||
|
||||
Returns:
|
||||
EmojiUpdateResponse: 禁用结果和更新后的表情包数据。
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session)
|
||||
|
||||
@@ -378,7 +456,17 @@ async def get_emoji_thumbnail(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
original: bool = Query(False, description="是否返回原图"),
|
||||
) -> FileResponse | JSONResponse:
|
||||
"""获取表情包缩略图。"""
|
||||
"""获取表情包缩略图。
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包 ID。
|
||||
token: URL 中携带的访问令牌。
|
||||
maibot_session: WebUI 登录会话 Cookie。
|
||||
original: 是否返回原图。
|
||||
|
||||
Returns:
|
||||
FileResponse | JSONResponse: 缩略图文件、原图文件或生成中的状态响应。
|
||||
"""
|
||||
try:
|
||||
token_manager = get_token_manager()
|
||||
is_valid = False
|
||||
@@ -456,7 +544,15 @@ async def batch_delete_emojis(
|
||||
request: BatchDeleteRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
) -> BatchDeleteResponse:
|
||||
"""批量删除表情包。"""
|
||||
"""批量删除表情包。
|
||||
|
||||
Args:
|
||||
request: 包含要删除表情包 ID 列表的请求。
|
||||
maibot_session: WebUI 登录会话 Cookie。
|
||||
|
||||
Returns:
|
||||
BatchDeleteResponse: 批量删除结果。
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session)
|
||||
|
||||
@@ -512,7 +608,18 @@ async def upload_emoji(
|
||||
is_registered: IsRegisteredForm = True,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
) -> EmojiUploadResponse:
|
||||
"""上传并注册表情包。"""
|
||||
"""上传并注册表情包。
|
||||
|
||||
Args:
|
||||
file: 上传的表情包文件。
|
||||
description: 表情包描述。
|
||||
emotion: 情绪标签。
|
||||
is_registered: 是否上传后直接注册。
|
||||
maibot_session: WebUI 登录会话 Cookie。
|
||||
|
||||
Returns:
|
||||
EmojiUploadResponse: 上传结果和新表情包数据。
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session)
|
||||
|
||||
@@ -574,7 +681,6 @@ async def upload_emoji(
|
||||
full_path=full_path,
|
||||
image_hash=emoji_hash,
|
||||
description=final_description,
|
||||
emotion=None,
|
||||
query_count=0,
|
||||
is_registered=is_registered,
|
||||
is_banned=False,
|
||||
@@ -605,7 +711,17 @@ async def batch_upload_emoji(
|
||||
is_registered: IsRegisteredForm = True,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
) -> Dict[str, Any]:
|
||||
"""批量上传表情包。"""
|
||||
"""批量上传表情包。
|
||||
|
||||
Args:
|
||||
files: 上传的表情包文件列表。
|
||||
emotion: 批量应用的情绪标签。
|
||||
is_registered: 是否上传后直接注册。
|
||||
maibot_session: WebUI 登录会话 Cookie。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 每个文件的上传结果和汇总统计。
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session)
|
||||
|
||||
@@ -685,7 +801,6 @@ async def batch_upload_emoji(
|
||||
full_path=full_path,
|
||||
image_hash=emoji_hash,
|
||||
description=final_description,
|
||||
emotion=None,
|
||||
query_count=0,
|
||||
is_registered=is_registered,
|
||||
is_banned=False,
|
||||
@@ -713,7 +828,14 @@ async def batch_upload_emoji(
|
||||
|
||||
@router.get("/thumbnail-cache/stats", response_model=ThumbnailCacheStatsResponse)
|
||||
async def get_thumbnail_cache_stats(maibot_session: Optional[str] = Cookie(None)) -> ThumbnailCacheStatsResponse:
|
||||
"""获取缩略图缓存统计信息。"""
|
||||
"""获取缩略图缓存统计信息。
|
||||
|
||||
Args:
|
||||
maibot_session: WebUI 登录会话 Cookie。
|
||||
|
||||
Returns:
|
||||
ThumbnailCacheStatsResponse: 缩略图缓存数量、大小和覆盖率统计。
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session)
|
||||
|
||||
@@ -744,7 +866,14 @@ async def get_thumbnail_cache_stats(maibot_session: Optional[str] = Cookie(None)
|
||||
|
||||
@router.post("/thumbnail-cache/cleanup", response_model=ThumbnailCleanupResponse)
|
||||
async def cleanup_thumbnail_cache(maibot_session: Optional[str] = Cookie(None)) -> ThumbnailCleanupResponse:
|
||||
"""清理孤立的缩略图缓存。"""
|
||||
"""清理孤立的缩略图缓存。
|
||||
|
||||
Args:
|
||||
maibot_session: WebUI 登录会话 Cookie。
|
||||
|
||||
Returns:
|
||||
ThumbnailCleanupResponse: 清理结果和删除数量。
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session)
|
||||
|
||||
@@ -767,7 +896,15 @@ async def preheat_thumbnail_cache(
|
||||
limit: int = Query(100, ge=1, le=1000, description="最多预热数量"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
) -> ThumbnailPreheatResponse:
|
||||
"""预热缩略图缓存。"""
|
||||
"""预热缩略图缓存。
|
||||
|
||||
Args:
|
||||
limit: 最多预热的缩略图数量。
|
||||
maibot_session: WebUI 登录会话 Cookie。
|
||||
|
||||
Returns:
|
||||
ThumbnailPreheatResponse: 预热生成、跳过和失败数量统计。
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session)
|
||||
|
||||
@@ -783,7 +920,13 @@ async def preheat_thumbnail_cache(
|
||||
.order_by(col(Images.query_count).desc())
|
||||
.limit(limit * 2)
|
||||
)
|
||||
emojis = session.exec(statement).all()
|
||||
emojis = [
|
||||
{
|
||||
"image_hash": emoji.image_hash,
|
||||
"full_path": emoji.full_path,
|
||||
}
|
||||
for emoji in session.exec(statement).all()
|
||||
]
|
||||
|
||||
generated = 0
|
||||
skipped = 0
|
||||
@@ -793,22 +936,25 @@ async def preheat_thumbnail_cache(
|
||||
if generated >= limit:
|
||||
break
|
||||
|
||||
cache_path = get_thumbnail_cache_path(emoji.image_hash)
|
||||
image_hash = emoji["image_hash"]
|
||||
full_path = emoji["full_path"]
|
||||
|
||||
cache_path = get_thumbnail_cache_path(image_hash)
|
||||
if cache_path.exists():
|
||||
skipped += 1
|
||||
continue
|
||||
if not os.path.exists(emoji.full_path):
|
||||
if not os.path.exists(full_path):
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
get_thumbnail_executor(), generate_thumbnail, emoji.full_path, emoji.image_hash
|
||||
get_thumbnail_executor(), generate_thumbnail, full_path, image_hash
|
||||
)
|
||||
generated += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"预热缩略图失败 {emoji.image_hash}: {e}")
|
||||
logger.warning(f"预热缩略图失败 {image_hash}: {e}")
|
||||
failed += 1
|
||||
|
||||
return ThumbnailPreheatResponse(
|
||||
@@ -827,7 +973,14 @@ async def preheat_thumbnail_cache(
|
||||
|
||||
@router.delete("/thumbnail-cache/clear", response_model=ThumbnailCleanupResponse)
|
||||
async def clear_all_thumbnail_cache(maibot_session: Optional[str] = Cookie(None)) -> ThumbnailCleanupResponse:
|
||||
"""清空所有缩略图缓存。"""
|
||||
"""清空所有缩略图缓存。
|
||||
|
||||
Args:
|
||||
maibot_session: WebUI 登录会话 Cookie。
|
||||
|
||||
Returns:
|
||||
ThumbnailCleanupResponse: 清空结果和删除数量。
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session)
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Annotated, List, Optional
|
||||
|
||||
from fastapi import File, Form, UploadFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
import re
|
||||
|
||||
from src.common.database.database_model import Images
|
||||
|
||||
EmojiFile = Annotated[UploadFile, File(description="表情包上传文件")]
|
||||
@@ -18,9 +20,11 @@ class EmojiResponse(BaseModel):
|
||||
|
||||
id: int
|
||||
full_path: str
|
||||
format: str
|
||||
emoji_hash: str
|
||||
description: str
|
||||
query_count: int
|
||||
usage_count: int
|
||||
is_registered: bool
|
||||
is_banned: bool
|
||||
emotion: Optional[str]
|
||||
@@ -125,26 +129,35 @@ class ThumbnailPreheatResponse(BaseModel):
|
||||
|
||||
|
||||
def emoji_to_response(image: Images) -> EmojiResponse:
|
||||
"""将表情包模型转换为响应对象。
|
||||
|
||||
Args:
|
||||
image: 数据库中的表情包记录。
|
||||
|
||||
Returns:
|
||||
EmojiResponse: WebUI 可直接序列化的表情包数据。
|
||||
"""
|
||||
emotions: list[str] = []
|
||||
if image.description:
|
||||
emotions.extend(
|
||||
item.strip() for item in re.split(r"[,,、;;\s]+", image.description) if item and item.strip()
|
||||
)
|
||||
if not emotions and image.emotion:
|
||||
emotions.extend(item.strip() for item in re.split(r"[,,、;;\s]+", image.emotion) if item and item.strip())
|
||||
|
||||
deduped_emotions: list[str] = []
|
||||
for item in emotions:
|
||||
if item not in deduped_emotions:
|
||||
deduped_emotions.append(item)
|
||||
emotion = ",".join(deduped_emotions) if deduped_emotions else None
|
||||
image_format = Path(image.full_path).suffix.lower().lstrip(".") or "unknown"
|
||||
|
||||
return EmojiResponse(
|
||||
id=image.id if image.id is not None else 0,
|
||||
full_path=image.full_path,
|
||||
format=image_format,
|
||||
emoji_hash=image.image_hash,
|
||||
description=image.description,
|
||||
query_count=image.query_count,
|
||||
usage_count=image.query_count,
|
||||
is_registered=image.is_registered,
|
||||
is_banned=image.is_banned,
|
||||
emotion=emotion,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""表达方式管理 API 路由"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
@@ -91,7 +91,14 @@ class ExpressionCreateResponse(BaseModel):
|
||||
|
||||
|
||||
def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
"""将 Expression 模型转换为响应对象"""
|
||||
"""将表达方式模型转换为响应对象。
|
||||
|
||||
Args:
|
||||
expression: 数据库中的表达方式记录。
|
||||
|
||||
Returns:
|
||||
ExpressionResponse: WebUI 可直接序列化的响应对象。
|
||||
"""
|
||||
last_active_time = expression.last_active_time.timestamp() if expression.last_active_time else 0.0
|
||||
create_date = expression.create_time.timestamp() if expression.create_time else None
|
||||
return ExpressionResponse(
|
||||
@@ -108,7 +115,14 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""根据 chat_id 获取聊天名称"""
|
||||
"""根据聊天 ID 获取聊天名称。
|
||||
|
||||
Args:
|
||||
chat_id: 聊天会话 ID。
|
||||
|
||||
Returns:
|
||||
str: 聊天显示名称,获取失败时返回原始聊天 ID。
|
||||
"""
|
||||
try:
|
||||
session = _chat_manager.get_session_by_session_id(chat_id)
|
||||
if not session:
|
||||
@@ -120,7 +134,14 @@ def get_chat_name(chat_id: str) -> str:
|
||||
|
||||
|
||||
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
|
||||
"""批量获取聊天名称"""
|
||||
"""批量获取聊天名称。
|
||||
|
||||
Args:
|
||||
chat_ids: 需要查询的聊天会话 ID 列表。
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: 以聊天 ID 为键、显示名称为值的映射。
|
||||
"""
|
||||
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
|
||||
try:
|
||||
for chat_id in chat_ids:
|
||||
@@ -148,12 +169,11 @@ class ChatListResponse(BaseModel):
|
||||
|
||||
|
||||
@router.get("/chats", response_model=ChatListResponse)
|
||||
async def get_chat_list():
|
||||
"""
|
||||
获取所有聊天列表(用于下拉选择)
|
||||
async def get_chat_list() -> ChatListResponse:
|
||||
"""获取所有聊天列表。
|
||||
|
||||
Returns:
|
||||
聊天列表
|
||||
ChatListResponse: 可用于下拉选择的聊天列表。
|
||||
"""
|
||||
try:
|
||||
chat_list = []
|
||||
@@ -186,18 +206,17 @@ async def get_expression_list(
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||
):
|
||||
"""
|
||||
获取表达方式列表
|
||||
) -> ExpressionListResponse:
|
||||
"""获取表达方式列表。
|
||||
|
||||
Args:
|
||||
page: 页码 (从 1 开始)
|
||||
page_size: 每页数量 (1-100)
|
||||
search: 搜索关键词 (匹配 situation, style)
|
||||
chat_id: 聊天ID筛选
|
||||
page: 页码,从 1 开始。
|
||||
page_size: 每页数量,范围为 1-100。
|
||||
search: 搜索关键词,用于匹配情景和风格。
|
||||
chat_id: 聊天 ID 筛选条件。
|
||||
|
||||
Returns:
|
||||
表达方式列表
|
||||
ExpressionListResponse: 分页后的表达方式列表。
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
@@ -233,8 +252,7 @@ async def get_expression_list(
|
||||
if chat_id:
|
||||
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
|
||||
total = len(session.exec(count_statement).all())
|
||||
|
||||
data = [expression_to_response(expr) for expr in expressions]
|
||||
data = [expression_to_response(expr) for expr in expressions]
|
||||
|
||||
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
@@ -246,25 +264,26 @@ async def get_expression_list(
|
||||
|
||||
|
||||
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
||||
async def get_expression_detail(expression_id: int):
|
||||
"""
|
||||
获取表达方式详细信息
|
||||
async def get_expression_detail(expression_id: int) -> ExpressionDetailResponse:
|
||||
"""获取表达方式详细信息。
|
||||
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
expression_id: 表达方式 ID。
|
||||
|
||||
Returns:
|
||||
表达方式详细信息
|
||||
ExpressionDetailResponse: 指定表达方式的详细信息。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
|
||||
expression = session.exec(statement).first()
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
return ExpressionDetailResponse(success=True, data=expression_to_response(expression))
|
||||
data = expression_to_response(expression)
|
||||
|
||||
return ExpressionDetailResponse(success=True, data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -276,15 +295,14 @@ async def get_expression_detail(expression_id: int):
|
||||
@router.post("/", response_model=ExpressionCreateResponse)
|
||||
async def create_expression(
|
||||
request: ExpressionCreateRequest,
|
||||
):
|
||||
"""
|
||||
创建新的表达方式
|
||||
) -> ExpressionCreateResponse:
|
||||
"""创建新的表达方式。
|
||||
|
||||
Args:
|
||||
request: 创建请求
|
||||
request: 创建表达方式所需的请求数据。
|
||||
|
||||
Returns:
|
||||
创建结果
|
||||
ExpressionCreateResponse: 创建结果和新表达方式数据。
|
||||
"""
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
@@ -294,8 +312,6 @@ async def create_expression(
|
||||
expression = Expression(
|
||||
situation=request.situation,
|
||||
style=request.style,
|
||||
context="",
|
||||
up_content="",
|
||||
content_list="[]",
|
||||
count=0,
|
||||
last_active_time=current_time,
|
||||
@@ -303,12 +319,13 @@ async def create_expression(
|
||||
session_id=request.chat_id,
|
||||
)
|
||||
session.add(expression)
|
||||
session.flush()
|
||||
expression_id = expression.id
|
||||
data = expression_to_response(expression)
|
||||
|
||||
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
|
||||
logger.info(f"表达方式已创建: ID={expression_id}, situation={request.situation}")
|
||||
|
||||
return ExpressionCreateResponse(
|
||||
success=True, message="表达方式创建成功", data=expression_to_response(expression)
|
||||
)
|
||||
return ExpressionCreateResponse(success=True, message="表达方式创建成功", data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -321,25 +338,17 @@ async def create_expression(
|
||||
async def update_expression(
|
||||
expression_id: int,
|
||||
request: ExpressionUpdateRequest,
|
||||
):
|
||||
"""
|
||||
增量更新表达方式(只更新提供的字段)
|
||||
) -> ExpressionUpdateResponse:
|
||||
"""增量更新表达方式。
|
||||
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
request: 更新请求(只包含需要更新的字段)
|
||||
expression_id: 表达方式 ID。
|
||||
request: 只包含需要更新字段的请求数据。
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
ExpressionUpdateResponse: 更新结果和更新后的表达方式数据。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
|
||||
expression = session.exec(statement).first()
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
# 只更新提供的字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
|
||||
@@ -358,17 +367,19 @@ async def update_expression(
|
||||
db_expression = session.exec(select(Expression).where(col(Expression.id) == expression_id).limit(1)).first()
|
||||
if not db_expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
for field, value in update_data.items():
|
||||
if hasattr(db_expression, field):
|
||||
setattr(db_expression, field, value)
|
||||
if "situation" in update_data:
|
||||
db_expression.situation = update_data["situation"]
|
||||
if "style" in update_data:
|
||||
db_expression.style = update_data["style"]
|
||||
if "session_id" in update_data:
|
||||
db_expression.session_id = update_data["session_id"]
|
||||
db_expression.last_active_time = update_data["last_active_time"]
|
||||
session.add(db_expression)
|
||||
expression = db_expression
|
||||
data = expression_to_response(db_expression)
|
||||
|
||||
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
return ExpressionUpdateResponse(
|
||||
success=True, message=f"成功更新 {len(update_data)} 个字段", data=expression_to_response(expression)
|
||||
)
|
||||
return ExpressionUpdateResponse(success=True, message=f"成功更新 {len(update_data)} 个字段", data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -378,29 +389,26 @@ async def update_expression(
|
||||
|
||||
|
||||
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
||||
async def delete_expression(expression_id: int):
|
||||
"""
|
||||
删除表达方式
|
||||
async def delete_expression(expression_id: int) -> ExpressionDeleteResponse:
|
||||
"""删除表达方式。
|
||||
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
expression_id: 表达方式 ID。
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
ExpressionDeleteResponse: 删除结果。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
|
||||
expression = session.exec(statement).first()
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
# 记录删除信息
|
||||
situation = expression.situation
|
||||
# 记录删除信息
|
||||
situation = expression.situation
|
||||
|
||||
# 执行删除
|
||||
with get_db_session() as session:
|
||||
session.exec(delete(Expression).where(col(Expression.id) == expression_id))
|
||||
|
||||
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
|
||||
@@ -423,15 +431,14 @@ class BatchDeleteRequest(BaseModel):
|
||||
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
|
||||
async def batch_delete_expressions(
|
||||
request: BatchDeleteRequest,
|
||||
):
|
||||
"""
|
||||
批量删除表达方式
|
||||
) -> ExpressionDeleteResponse:
|
||||
"""批量删除表达方式。
|
||||
|
||||
Args:
|
||||
request: 包含要删除的ID列表的请求
|
||||
request: 包含要删除表达方式 ID 列表的请求。
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
ExpressionDeleteResponse: 批量删除结果。
|
||||
"""
|
||||
try:
|
||||
if not request.ids:
|
||||
@@ -463,12 +470,11 @@ async def batch_delete_expressions(
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_expression_stats():
|
||||
"""
|
||||
获取表达方式统计数据
|
||||
async def get_expression_stats() -> Dict[str, Any]:
|
||||
"""获取表达方式统计数据。
|
||||
|
||||
Returns:
|
||||
统计数据
|
||||
Dict[str, Any]: 表达方式数量、近期新增和聊天分布统计。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
@@ -519,12 +525,11 @@ class ReviewStatsResponse(BaseModel):
|
||||
|
||||
|
||||
@router.get("/review/stats", response_model=ReviewStatsResponse)
|
||||
async def get_review_stats():
|
||||
"""
|
||||
获取审核统计数据
|
||||
async def get_review_stats() -> ReviewStatsResponse:
|
||||
"""获取审核统计数据。
|
||||
|
||||
Returns:
|
||||
审核统计数据
|
||||
ReviewStatsResponse: 审核统计数据。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
@@ -568,19 +573,18 @@ async def get_review_list(
|
||||
filter_type: str = Query("unchecked", description="筛选类型: unchecked/passed/rejected/all"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||
):
|
||||
"""
|
||||
获取待审核/已审核的表达方式列表
|
||||
) -> ReviewListResponse:
|
||||
"""获取待审核或已审核的表达方式列表。
|
||||
|
||||
Args:
|
||||
page: 页码
|
||||
page_size: 每页数量
|
||||
filter_type: 筛选类型 (unchecked/passed/rejected/all)
|
||||
search: 搜索关键词
|
||||
chat_id: 聊天ID筛选
|
||||
page: 页码。
|
||||
page_size: 每页数量。
|
||||
filter_type: 筛选类型,可选 unchecked、passed、rejected 或 all。
|
||||
search: 搜索关键词。
|
||||
chat_id: 聊天 ID 筛选条件。
|
||||
|
||||
Returns:
|
||||
表达方式列表
|
||||
ReviewListResponse: 审核列表响应。
|
||||
"""
|
||||
try:
|
||||
statement = select(Expression)
|
||||
@@ -621,13 +625,14 @@ async def get_review_list(
|
||||
if chat_id:
|
||||
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
|
||||
total = len(session.exec(count_statement).all())
|
||||
data = [expression_to_response(expr) for expr in expressions]
|
||||
|
||||
return ReviewListResponse(
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
data=[expression_to_response(expr) for expr in expressions],
|
||||
data=data,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
@@ -672,15 +677,14 @@ class BatchReviewResponse(BaseModel):
|
||||
@router.post("/review/batch", response_model=BatchReviewResponse)
|
||||
async def batch_review_expressions(
|
||||
request: BatchReviewRequest,
|
||||
):
|
||||
"""
|
||||
批量审核表达方式
|
||||
) -> BatchReviewResponse:
|
||||
"""批量审核表达方式。
|
||||
|
||||
Args:
|
||||
request: 批量审核请求
|
||||
request: 批量审核请求。
|
||||
|
||||
Returns:
|
||||
批量审核结果
|
||||
BatchReviewResponse: 每条表达方式的审核结果。
|
||||
"""
|
||||
try:
|
||||
if not request.items:
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
"""黑话(俚语)管理路由"""
|
||||
|
||||
import json
|
||||
from typing import Annotated, Any, Dict, List, Optional, Set
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlmodel import Session, col, delete, select
|
||||
|
||||
import json
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ChatSession, Jargon
|
||||
from src.common.logger import get_logger
|
||||
@@ -21,9 +22,13 @@ router = APIRouter(prefix="/jargon", tags=["Jargon"], dependencies=[Depends(requ
|
||||
|
||||
|
||||
def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
|
||||
"""
|
||||
解析 chat_id 字段,提取所有 stream_id
|
||||
chat_id 格式: [["stream_id", user_id], ...] 或直接是 stream_id 字符串
|
||||
"""解析聊天 ID 字段并提取所有 stream_id。
|
||||
|
||||
Args:
|
||||
chat_id_str: JSON 格式或纯字符串格式的聊天 ID。
|
||||
|
||||
Returns:
|
||||
List[str]: 解析出的 stream_id 列表。
|
||||
"""
|
||||
if not chat_id_str:
|
||||
return []
|
||||
@@ -43,9 +48,14 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
|
||||
|
||||
|
||||
def get_display_name_for_chat_id(chat_id_str: str, session: Session) -> str:
|
||||
"""
|
||||
获取 chat_id 的显示名称
|
||||
尝试解析 JSON 并查询 ChatSession 表获取群聊名称
|
||||
"""获取聊天 ID 的显示名称。
|
||||
|
||||
Args:
|
||||
chat_id_str: JSON 格式或纯字符串格式的聊天 ID。
|
||||
session: 当前数据库会话。
|
||||
|
||||
Returns:
|
||||
str: 聊天显示名称,无法查询时返回截断后的 stream_id。
|
||||
"""
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
|
||||
|
||||
@@ -175,7 +185,14 @@ class ChatListResponse(BaseModel):
|
||||
|
||||
|
||||
def parse_session_id_dict(session_id_dict_str: Optional[str]) -> Dict[str, int]:
|
||||
"""解析会话计数字典。"""
|
||||
"""解析会话计数字典。
|
||||
|
||||
Args:
|
||||
session_id_dict_str: 数据库中保存的会话计数字典 JSON 字符串。
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: 解析后的会话计数字典。
|
||||
"""
|
||||
if not session_id_dict_str:
|
||||
return {}
|
||||
|
||||
@@ -202,12 +219,26 @@ def parse_session_id_dict(session_id_dict_str: Optional[str]) -> Dict[str, int]:
|
||||
|
||||
|
||||
def dump_session_id_dict(session_counts: Dict[str, int]) -> str:
|
||||
"""序列化会话计数字典。"""
|
||||
"""序列化会话计数字典。
|
||||
|
||||
Args:
|
||||
session_counts: 会话 ID 与出现次数的映射。
|
||||
|
||||
Returns:
|
||||
str: 可写入数据库的 JSON 字符串。
|
||||
"""
|
||||
return json.dumps(session_counts, ensure_ascii=False)
|
||||
|
||||
|
||||
def get_primary_chat_id(session_id_dict_str: Optional[str]) -> str:
|
||||
"""从会话计数字典中选出主聊天 ID。"""
|
||||
"""从会话计数字典中选出主聊天 ID。
|
||||
|
||||
Args:
|
||||
session_id_dict_str: 数据库中保存的会话计数字典 JSON 字符串。
|
||||
|
||||
Returns:
|
||||
str: 出现次数最多的聊天 ID,没有记录时返回空字符串。
|
||||
"""
|
||||
if not (session_counts := parse_session_id_dict(session_id_dict_str)):
|
||||
return ""
|
||||
|
||||
@@ -215,17 +246,41 @@ def get_primary_chat_id(session_id_dict_str: Optional[str]) -> str:
|
||||
|
||||
|
||||
def has_chat_id(session_id_dict_str: Optional[str], chat_id: str) -> bool:
|
||||
"""判断记录是否包含指定聊天 ID。"""
|
||||
"""判断记录是否包含指定聊天 ID。
|
||||
|
||||
Args:
|
||||
session_id_dict_str: 数据库中保存的会话计数字典 JSON 字符串。
|
||||
chat_id: 需要检查的聊天 ID。
|
||||
|
||||
Returns:
|
||||
bool: 记录包含该聊天 ID 时返回 True。
|
||||
"""
|
||||
return chat_id in parse_session_id_dict(session_id_dict_str)
|
||||
|
||||
|
||||
def build_session_id_dict_for_chat(chat_id: str, count: int = 1) -> str:
|
||||
"""为单个聊天 ID 构建会话计数字典。"""
|
||||
"""为单个聊天 ID 构建会话计数字典。
|
||||
|
||||
Args:
|
||||
chat_id: 聊天 ID。
|
||||
count: 该聊天 ID 的出现次数。
|
||||
|
||||
Returns:
|
||||
str: 可写入数据库的会话计数字典 JSON 字符串。
|
||||
"""
|
||||
return dump_session_id_dict({chat_id: count})
|
||||
|
||||
|
||||
def jargon_to_dict(jargon: Jargon, session: Session) -> Dict[str, Any]:
|
||||
"""将 Jargon ORM 对象转换为字典"""
|
||||
"""将黑话模型转换为字典。
|
||||
|
||||
Args:
|
||||
jargon: 数据库中的黑话记录。
|
||||
session: 当前数据库会话,用于查询聊天显示名称。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: WebUI 可直接序列化的黑话数据。
|
||||
"""
|
||||
chat_id = get_primary_chat_id(jargon.session_id_dict)
|
||||
chat_name = get_display_name_for_chat_id(chat_id, session) if chat_id else None
|
||||
|
||||
@@ -255,8 +310,19 @@ async def get_jargon_list(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="按聊天ID筛选"),
|
||||
is_jargon: Optional[bool] = Query(None, description="按是否是黑话筛选"),
|
||||
):
|
||||
"""获取黑话列表"""
|
||||
) -> JargonListResponse:
|
||||
"""获取黑话列表。
|
||||
|
||||
Args:
|
||||
page: 页码,从 1 开始。
|
||||
page_size: 每页数量,范围为 1-100。
|
||||
search: 搜索关键词。
|
||||
chat_id: 聊天 ID 筛选条件。
|
||||
is_jargon: 是否为黑话的筛选条件。
|
||||
|
||||
Returns:
|
||||
JargonListResponse: 分页后的黑话列表。
|
||||
"""
|
||||
try:
|
||||
statement = select(Jargon)
|
||||
|
||||
@@ -304,18 +370,21 @@ async def get_jargon_list(
|
||||
|
||||
|
||||
@router.get("/chats", response_model=ChatListResponse)
|
||||
async def get_chat_list():
|
||||
"""获取所有有黑话记录的聊天列表"""
|
||||
async def get_chat_list() -> ChatListResponse:
|
||||
"""获取所有有黑话记录的聊天列表。
|
||||
|
||||
Returns:
|
||||
ChatListResponse: 包含黑话记录的聊天列表。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
jargons = session.exec(select(Jargon)).all()
|
||||
|
||||
seen_stream_ids: Set[str] = set()
|
||||
for jargon in jargons:
|
||||
seen_stream_ids.update(parse_session_id_dict(jargon.session_id_dict).keys())
|
||||
seen_stream_ids: Set[str] = set()
|
||||
for jargon in jargons:
|
||||
seen_stream_ids.update(parse_session_id_dict(jargon.session_id_dict).keys())
|
||||
|
||||
result = []
|
||||
with get_db_session() as session:
|
||||
result: List[ChatInfoResponse] = []
|
||||
for stream_id in seen_stream_ids:
|
||||
if chat_session := session.exec(
|
||||
select(ChatSession).where(col(ChatSession.session_id) == stream_id)
|
||||
@@ -347,25 +416,29 @@ async def get_chat_list():
|
||||
|
||||
|
||||
@router.get("/stats/summary", response_model=JargonStatsResponse)
|
||||
async def get_jargon_stats():
|
||||
"""获取黑话统计数据"""
|
||||
async def get_jargon_stats() -> JargonStatsResponse:
|
||||
"""获取黑话统计数据。
|
||||
|
||||
Returns:
|
||||
JargonStatsResponse: 黑话总数、确认状态和聊天分布统计。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
jargons = session.exec(select(Jargon)).all()
|
||||
|
||||
total = len(jargons)
|
||||
confirmed_jargon = sum(jargon.is_jargon is True for jargon in jargons)
|
||||
confirmed_not_jargon = sum(jargon.is_jargon is False for jargon in jargons)
|
||||
pending = sum(jargon.is_jargon is None for jargon in jargons)
|
||||
complete_count = sum(jargon.is_complete for jargon in jargons)
|
||||
total = len(jargons)
|
||||
confirmed_jargon = sum(jargon.is_jargon is True for jargon in jargons)
|
||||
confirmed_not_jargon = sum(jargon.is_jargon is False for jargon in jargons)
|
||||
pending = sum(jargon.is_jargon is None for jargon in jargons)
|
||||
complete_count = sum(jargon.is_complete for jargon in jargons)
|
||||
|
||||
top_chats_counter: Dict[str, int] = {}
|
||||
for jargon in jargons:
|
||||
for session_id in parse_session_id_dict(jargon.session_id_dict):
|
||||
top_chats_counter[session_id] = top_chats_counter.get(session_id, 0) + 1
|
||||
top_chats_counter: Dict[str, int] = {}
|
||||
for jargon in jargons:
|
||||
for session_id in parse_session_id_dict(jargon.session_id_dict):
|
||||
top_chats_counter[session_id] = top_chats_counter.get(session_id, 0) + 1
|
||||
|
||||
top_chats_dict = dict(sorted(top_chats_counter.items(), key=lambda item: item[1], reverse=True)[:5])
|
||||
chat_count = len(top_chats_counter)
|
||||
top_chats_dict = dict(sorted(top_chats_counter.items(), key=lambda item: item[1], reverse=True)[:5])
|
||||
chat_count = len(top_chats_counter)
|
||||
|
||||
return JargonStatsResponse(
|
||||
success=True,
|
||||
@@ -386,8 +459,15 @@ async def get_jargon_stats():
|
||||
|
||||
|
||||
@router.get("/{jargon_id}", response_model=JargonDetailResponse)
|
||||
async def get_jargon_detail(jargon_id: int):
|
||||
"""获取黑话详情"""
|
||||
async def get_jargon_detail(jargon_id: int) -> JargonDetailResponse:
|
||||
"""获取黑话详情。
|
||||
|
||||
Args:
|
||||
jargon_id: 黑话记录 ID。
|
||||
|
||||
Returns:
|
||||
JargonDetailResponse: 指定黑话记录的详细信息。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
if not (jargon := session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()):
|
||||
@@ -404,8 +484,15 @@ async def get_jargon_detail(jargon_id: int):
|
||||
|
||||
|
||||
@router.post("/", response_model=JargonCreateResponse)
|
||||
async def create_jargon(request: JargonCreateRequest):
|
||||
"""创建黑话"""
|
||||
async def create_jargon(request: JargonCreateRequest) -> JargonCreateResponse:
|
||||
"""创建黑话。
|
||||
|
||||
Args:
|
||||
request: 创建黑话所需的请求数据。
|
||||
|
||||
Returns:
|
||||
JargonCreateResponse: 创建结果和新黑话数据。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
same_content_jargons = session.exec(select(Jargon).where(col(Jargon.content) == request.content)).all()
|
||||
@@ -441,8 +528,16 @@ async def create_jargon(request: JargonCreateRequest):
|
||||
|
||||
|
||||
@router.patch("/{jargon_id}", response_model=JargonUpdateResponse)
|
||||
async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
|
||||
"""更新黑话(增量更新)"""
|
||||
async def update_jargon(jargon_id: int, request: JargonUpdateRequest) -> JargonUpdateResponse:
|
||||
"""增量更新黑话。
|
||||
|
||||
Args:
|
||||
jargon_id: 黑话记录 ID。
|
||||
request: 只包含需要更新字段的请求数据。
|
||||
|
||||
Returns:
|
||||
JargonUpdateResponse: 更新结果和更新后的黑话数据。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
|
||||
@@ -450,14 +545,16 @@ async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
|
||||
if update_data := request.model_dump(exclude_unset=True):
|
||||
for field, value in update_data.items():
|
||||
if field == "is_global":
|
||||
continue
|
||||
if field == "chat_id":
|
||||
jargon.session_id_dict = build_session_id_dict_for_chat(value, max(jargon.count, 1))
|
||||
continue
|
||||
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
|
||||
setattr(jargon, field, value)
|
||||
if "chat_id" in update_data and update_data["chat_id"] is not None:
|
||||
jargon.session_id_dict = build_session_id_dict_for_chat(update_data["chat_id"], max(jargon.count, 1))
|
||||
if "content" in update_data and update_data["content"] is not None:
|
||||
jargon.content = update_data["content"]
|
||||
if "raw_content" in update_data:
|
||||
jargon.raw_content = update_data["raw_content"]
|
||||
if "meaning" in update_data:
|
||||
jargon.meaning = update_data["meaning"] or ""
|
||||
if "is_jargon" in update_data:
|
||||
jargon.is_jargon = update_data["is_jargon"]
|
||||
session.add(jargon)
|
||||
|
||||
logger.info(f"更新黑话成功: id={jargon_id}")
|
||||
@@ -473,8 +570,15 @@ async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
|
||||
|
||||
|
||||
@router.delete("/{jargon_id}", response_model=JargonDeleteResponse)
|
||||
async def delete_jargon(jargon_id: int):
|
||||
"""删除黑话"""
|
||||
async def delete_jargon(jargon_id: int) -> JargonDeleteResponse:
|
||||
"""删除黑话。
|
||||
|
||||
Args:
|
||||
jargon_id: 黑话记录 ID。
|
||||
|
||||
Returns:
|
||||
JargonDeleteResponse: 删除结果。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
|
||||
@@ -496,8 +600,15 @@ async def delete_jargon(jargon_id: int):
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=JargonDeleteResponse)
|
||||
async def batch_delete_jargons(request: BatchDeleteRequest):
|
||||
"""批量删除黑话"""
|
||||
async def batch_delete_jargons(request: BatchDeleteRequest) -> JargonDeleteResponse:
|
||||
"""批量删除黑话。
|
||||
|
||||
Args:
|
||||
request: 包含要删除黑话 ID 列表的请求。
|
||||
|
||||
Returns:
|
||||
JargonDeleteResponse: 批量删除结果。
|
||||
"""
|
||||
try:
|
||||
if not request.ids:
|
||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||
@@ -525,8 +636,16 @@ async def batch_delete_jargons(request: BatchDeleteRequest):
|
||||
async def batch_set_jargon_status(
|
||||
ids: Annotated[List[int], Query(description="黑话ID列表")],
|
||||
is_jargon: Annotated[bool, Query(description="是否是黑话")],
|
||||
):
|
||||
"""批量设置黑话状态"""
|
||||
) -> JargonUpdateResponse:
|
||||
"""批量设置黑话状态。
|
||||
|
||||
Args:
|
||||
ids: 需要更新状态的黑话 ID 列表。
|
||||
is_jargon: 目标黑话状态。
|
||||
|
||||
Returns:
|
||||
JargonUpdateResponse: 批量更新结果。
|
||||
"""
|
||||
try:
|
||||
if not ids:
|
||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
"""人物信息管理 API 路由"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import case
|
||||
from sqlmodel import col, delete, select
|
||||
|
||||
import json
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.common.logger import get_logger
|
||||
@@ -97,7 +98,14 @@ class BatchDeleteResponse(BaseModel):
|
||||
|
||||
|
||||
def parse_group_nick_name(group_nick_name_str: Optional[str]) -> Optional[List[Dict[str, str]]]:
|
||||
"""解析群昵称 JSON 字符串"""
|
||||
"""解析群昵称 JSON 字符串。
|
||||
|
||||
Args:
|
||||
group_nick_name_str: 数据库中保存的群昵称 JSON 字符串。
|
||||
|
||||
Returns:
|
||||
Optional[List[Dict[str, str]]]: 解析后的群昵称列表,解析失败时返回 None。
|
||||
"""
|
||||
if not group_nick_name_str:
|
||||
return None
|
||||
try:
|
||||
@@ -107,7 +115,14 @@ def parse_group_nick_name(group_nick_name_str: Optional[str]) -> Optional[List[D
|
||||
|
||||
|
||||
def person_to_response(person: PersonInfo) -> PersonInfoResponse:
|
||||
"""将 PersonInfo 模型转换为响应对象"""
|
||||
"""将人物信息模型转换为响应对象。
|
||||
|
||||
Args:
|
||||
person: 数据库中的人物信息记录。
|
||||
|
||||
Returns:
|
||||
PersonInfoResponse: WebUI 可直接序列化的人物信息。
|
||||
"""
|
||||
know_since = person.first_known_time.timestamp() if person.first_known_time else None
|
||||
last_know = person.last_known_time.timestamp() if person.last_known_time else None
|
||||
return PersonInfoResponse(
|
||||
@@ -134,19 +149,18 @@ async def get_person_list(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
|
||||
platform: Optional[str] = Query(None, description="平台筛选"),
|
||||
):
|
||||
"""
|
||||
获取人物信息列表
|
||||
) -> PersonListResponse:
|
||||
"""获取人物信息列表。
|
||||
|
||||
Args:
|
||||
page: 页码 (从 1 开始)
|
||||
page_size: 每页数量 (1-100)
|
||||
search: 搜索关键词 (匹配 person_name, nickname, user_id)
|
||||
is_known: 是否已认识筛选
|
||||
platform: 平台筛选
|
||||
page: 页码,从 1 开始。
|
||||
page_size: 每页数量,范围为 1-100。
|
||||
search: 搜索关键词,用于匹配人物名称、昵称和用户 ID。
|
||||
is_known: 是否已认识筛选条件。
|
||||
platform: 平台筛选条件。
|
||||
|
||||
Returns:
|
||||
人物信息列表
|
||||
PersonListResponse: 分页后的人物信息列表。
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
@@ -193,8 +207,7 @@ async def get_person_list(
|
||||
if platform:
|
||||
count_statement = count_statement.where(col(PersonInfo.platform) == platform)
|
||||
total = len(session.exec(count_statement).all())
|
||||
|
||||
data = [person_to_response(person) for person in persons]
|
||||
data = [person_to_response(person) for person in persons]
|
||||
|
||||
return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
@@ -206,25 +219,26 @@ async def get_person_list(
|
||||
|
||||
|
||||
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
||||
async def get_person_detail(person_id: str):
|
||||
"""
|
||||
获取人物详细信息
|
||||
async def get_person_detail(person_id: str) -> PersonDetailResponse:
|
||||
"""获取人物详细信息。
|
||||
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
person_id: 人物唯一 ID。
|
||||
|
||||
Returns:
|
||||
人物详细信息
|
||||
PersonDetailResponse: 指定人物的详细信息。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||
person = session.exec(statement).first()
|
||||
|
||||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
|
||||
return PersonDetailResponse(success=True, data=person_to_response(person))
|
||||
data = person_to_response(person)
|
||||
|
||||
return PersonDetailResponse(success=True, data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -237,25 +251,17 @@ async def get_person_detail(person_id: str):
|
||||
async def update_person(
|
||||
person_id: str,
|
||||
request: PersonUpdateRequest,
|
||||
):
|
||||
"""
|
||||
增量更新人物信息(只更新提供的字段)
|
||||
) -> PersonUpdateResponse:
|
||||
"""增量更新人物信息。
|
||||
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
request: 更新请求(只包含需要更新的字段)
|
||||
person_id: 人物唯一 ID。
|
||||
request: 只包含需要更新字段的请求数据。
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
PersonUpdateResponse: 更新结果和更新后的人物信息。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||
person = session.exec(statement).first()
|
||||
|
||||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
|
||||
# 只更新提供的字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
|
||||
@@ -270,17 +276,23 @@ async def update_person(
|
||||
db_person = session.exec(select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)).first()
|
||||
if not db_person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
for field, value in update_data.items():
|
||||
if hasattr(db_person, field):
|
||||
setattr(db_person, field, value)
|
||||
if "person_name" in update_data:
|
||||
db_person.person_name = update_data["person_name"]
|
||||
if "name_reason" in update_data:
|
||||
db_person.name_reason = update_data["name_reason"]
|
||||
if "nickname" in update_data:
|
||||
db_person.user_nickname = update_data["nickname"]
|
||||
if "memory_points" in update_data:
|
||||
db_person.memory_points = update_data["memory_points"]
|
||||
if "is_known" in update_data:
|
||||
db_person.is_known = update_data["is_known"]
|
||||
db_person.last_known_time = update_data["last_known_time"]
|
||||
session.add(db_person)
|
||||
person = db_person
|
||||
data = person_to_response(db_person)
|
||||
|
||||
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
return PersonUpdateResponse(
|
||||
success=True, message=f"成功更新 {len(update_data)} 个字段", data=person_to_response(person)
|
||||
)
|
||||
return PersonUpdateResponse(success=True, message=f"成功更新 {len(update_data)} 个字段", data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -290,29 +302,26 @@ async def update_person(
|
||||
|
||||
|
||||
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
||||
async def delete_person(person_id: str):
|
||||
"""
|
||||
删除人物信息
|
||||
async def delete_person(person_id: str) -> PersonDeleteResponse:
|
||||
"""删除人物信息。
|
||||
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
person_id: 人物唯一 ID。
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
PersonDeleteResponse: 删除结果。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||
person = session.exec(statement).first()
|
||||
|
||||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
|
||||
# 记录删除信息
|
||||
person_name = person.person_name or person.user_nickname or person.user_id
|
||||
# 记录删除信息
|
||||
person_name = person.person_name or person.user_nickname or person.user_id
|
||||
|
||||
# 执行删除
|
||||
with get_db_session() as session:
|
||||
session.exec(delete(PersonInfo).where(col(PersonInfo.person_id) == person_id))
|
||||
|
||||
logger.info(f"人物信息已删除: {person_id} ({person_name})")
|
||||
@@ -327,12 +336,11 @@ async def delete_person(person_id: str):
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_person_stats():
|
||||
"""
|
||||
获取人物信息统计数据
|
||||
async def get_person_stats() -> Dict[str, Any]:
|
||||
"""获取人物信息统计数据。
|
||||
|
||||
Returns:
|
||||
统计数据
|
||||
Dict[str, Any]: 人物总数、已认识数量和平台分布统计。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
@@ -359,15 +367,14 @@ async def get_person_stats():
|
||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||
async def batch_delete_persons(
|
||||
request: BatchDeleteRequest,
|
||||
):
|
||||
"""
|
||||
批量删除人物信息
|
||||
) -> BatchDeleteResponse:
|
||||
"""批量删除人物信息。
|
||||
|
||||
Args:
|
||||
request: 包含person_ids列表的请求
|
||||
request: 包含人物 ID 列表的请求。
|
||||
|
||||
Returns:
|
||||
批量删除结果
|
||||
BatchDeleteResponse: 批量删除结果。
|
||||
"""
|
||||
try:
|
||||
if not request.person_ids:
|
||||
|
||||
@@ -8,9 +8,11 @@ class EmojiResponse(BaseModel):
|
||||
|
||||
id: int
|
||||
full_path: str
|
||||
format: str
|
||||
emoji_hash: str
|
||||
description: str
|
||||
query_count: int
|
||||
usage_count: int
|
||||
is_registered: bool
|
||||
is_banned: bool
|
||||
emotion: Optional[str]
|
||||
|
||||
42
代码备忘.md
42
代码备忘.md
@@ -1,42 +0,0 @@
|
||||
# 代码备忘
|
||||
|
||||
.env中的webui配置仍旧在被读取
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# 代码备忘
|
||||
|
||||
- [ ] 检查EmojiManager的replace_an_emoji_by_llm传入的emoji是否真的是没有注册到db的
|
||||
- [ ] According to a comment, MaiMBot's check_types() accesses format_info.accept_format without None check
|
||||
- [ ] 如果需要更多的消息格式支持,更新列表如下:
|
||||
- [ ] `src/common/utils/utils_message.py`中的`_parse_maim_message_segment_to_component`函数
|
||||
- [ ] `src/common/data_models/message_component_model.py`中:
|
||||
- [ ] 增加新的消息组件
|
||||
- [ ] 看情况修改`StandardMessageComponents`的内容
|
||||
- [ ] `MessageSequence`的`_dict_2_item`和`_item_2_dict`函数
|
||||
- [ ] **取消了从chat_manager获取ChatSession时候的deepcopy,看看会不会有问题**
|
||||
|
||||
# 迁移脚本备忘
|
||||
- [ ] 迁移env到新版的bot_config管理
|
||||
- [ ] 对于旧的消息,需要重新计算其Hash(md5 -> sha256),做好映射防止消息丢失
|
||||
- [ ] PersonInfo的group_nickname名字改为group_cardname,做好映射防止数据丢失,同时存储的方式从`[{"group_id": str, "group_nick_name": str}]` -> `[{"group_id": str, "group_cardname": str}]`
|
||||
- [ ] Expression中的`up_content`被移除了
|
||||
- [ ] Jargon现在chat_id(session_id_list,格式为`[["session_id", session_count]]`) -> session_id_dict(`{"session_id": session_count}`),做好映射防止数据丢失
|
||||
|
||||
# 插件开发备忘
|
||||
- [ ] 求各位插件开发不要在Dict里面塞一堆乱七八糟的东西,免得数据库存储的时候一团糟
|
||||
|
||||
# Hack备忘
|
||||
- [ ] 对于不符合内容审查要求的表情包,无法注册到数据库内,因此面对相同的非法表情包时,会导致反复识别。有成功注册的可能。
|
||||
- [ ] 考虑到数据库记录表情包不合规判定有大模型误判的风险,因此保留现有的无法注册的情况,在再次遇到的时候重新识别。
|
||||
- [ ] 目前在匿名化build message的时候,如果一个被回复的消息包含了一个转发消息组件,那么这个转发消息组件中的用户信息是不会被匿名化的,后续需要修复这个问题。(有时候感觉用正则是对的)
|
||||
- [ ] 可以考虑将消息保存的时候就将消息中的用户信息匿名化,这样在后续的处理过程中就不需要担心匿名化的问题了,同时也可以避免在build message的时候进行复杂的递归处理,同时还要保存匿名映射表。
|
||||
|
||||
# 计算备忘
|
||||
- [ ] emoji的emotion比较是基于编辑距离的,考虑更换为基于语义的比较(比如使用emoji的embedding进行比较),以提高准确性和鲁棒性
|
||||
- [ ] expression的相似度比较是基于LCS的(Ratcliff-Obershelp算法),考虑更换为基于语义的比较(比如使用embedding进行比较),以提高准确性和鲁棒性
|
||||
- [ ] 为了保持代码的简洁性,HFC无论任何情况都将初始化ExpressionReflector,ExpressionLearner,JargonMiner实例,无论配置文件中是否在此聊天流启用了他们。
|
||||
- [ ] 可优化方向:将其置为Optional,在不启用的情况下不进行初始化
|
||||
- [ ] 当配置文件重载时,重新分析所有启用判定,所有HFC进行并行检查,将启用的进行实例化。不启用的实例化移除引用,释放内存。
|
||||
Reference in New Issue
Block a user