merge: 同步上游 dev 最新内容

This commit is contained in:
DawnARC
2026-05-06 00:53:11 +08:00
125 changed files with 3069 additions and 1271 deletions

View File

@@ -112,7 +112,7 @@ class ChatHistoryManager:
return {
"id": msg.message_id,
"type": "bot" if is_bot else "user",
"content": msg.processed_plain_text or msg.display_message or "",
"content": msg.processed_plain_text or "",
"timestamp": msg.timestamp.timestamp(),
"sender_name": user_info.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
"sender_id": "bot" if is_bot else user_id,
@@ -175,11 +175,7 @@ class ChatHistoryManager:
user_info = target_msg.message_info.user_info
if not has_content:
content_text = (
target_msg.processed_plain_text
or target_msg.display_message
or ""
)
content_text = target_msg.processed_plain_text or ""
data["target_message_content"] = content_text
if not has_sender:
data["target_message_sender_id"] = user_info.user_id or ""

View File

@@ -2,19 +2,21 @@
配置管理API路由
"""
from pathlib import Path
from typing import Annotated, Any, Dict, List, Tuple, Union, get_args, get_origin
import copy
import os
from pathlib import Path
from typing import Annotated, Any, Dict, List, Tuple
import types
import tomlkit
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
import tomlkit
from src.common.logger import get_logger
from src.common.prompt_i18n import list_prompt_templates
from src.config.config import CONFIG_DIR, PROJECT_ROOT, Config, ModelConfig
from src.config.config_base import AttributeData
from src.config.config_base import AttributeData, ConfigBase
from src.config.model_configs import (
APIProvider,
ModelInfo,
@@ -63,6 +65,9 @@ class PromptFileInfo(BaseModel):
name: str = Field(..., description="Prompt 文件名")
size: int = Field(..., description="文件大小")
modified_at: float = Field(..., description="最后修改时间戳")
display_name: str = Field(default="", description="Prompt 展示名称")
advanced: bool = Field(default=False, description="是否为高级 Prompt")
description: str = Field(default="", description="Prompt 描述")
class PromptCatalogResponse(BaseModel):
@@ -129,6 +134,71 @@ def _toml_to_plain_dict(obj: Any) -> Any:
return obj
def _coerce_numeric_value(value: Any, target_type: Any) -> Any:
"""根据配置字段类型,把旧 WebUI 可能写入的数字字符串还原为数字。"""
if target_type is str:
if isinstance(value, (int, float)):
return str(value)
return value
if target_type is int:
if isinstance(value, str):
try:
parsed_value = float(value.strip())
except ValueError:
return value
if parsed_value.is_integer():
return int(parsed_value)
return value
if target_type is float:
if isinstance(value, str):
try:
return float(value.strip())
except ValueError:
return value
return value
return value
def _coerce_value_by_annotation(value: Any, annotation: Any) -> Any:
"""递归按 ConfigBase 字段注解修正数据类型,避免保存时把数字写成字符串。"""
value = _coerce_numeric_value(value, annotation)
origin = get_origin(annotation)
args = get_args(annotation)
if origin in {Union, types.UnionType}:
for candidate_type in args:
if candidate_type is type(None):
continue
coerced_value = _coerce_value_by_annotation(value, candidate_type)
if coerced_value != value or type(coerced_value) is not type(value):
return coerced_value
return value
if origin in {list, List} and isinstance(value, list) and args:
item_type = args[0]
return [_coerce_value_by_annotation(item, item_type) for item in value]
if origin in {dict, Dict} and isinstance(value, dict) and len(args) >= 2:
value_type = args[1]
return {key: _coerce_value_by_annotation(item, value_type) for key, item in value.items()}
if isinstance(value, dict) and isinstance(annotation, type) and issubclass(annotation, ConfigBase):
return _coerce_config_numeric_values(value, annotation)
return value
def _coerce_config_numeric_values(data: Dict[str, Any], config_type: type[ConfigBase]) -> Dict[str, Any]:
"""按配置类 schema 统一修正所有数字字段类型。"""
for field_name, field_info in config_type.model_fields.items():
if field_name in data:
data[field_name] = _coerce_value_by_annotation(data[field_name], field_info.annotation)
return data
# ===== 架构获取接口 =====
@@ -147,14 +217,20 @@ async def list_prompt_files():
continue
language = language_dir.name
prompt_template_infos = list_prompt_templates(locale=language, prompts_root=PROMPTS_DIR)
prompt_files: List[PromptFileInfo] = []
for prompt_file in sorted(language_dir.glob("*.prompt"), key=lambda item: item.name):
stat = prompt_file.stat()
template_info = prompt_template_infos.get(prompt_file.stem)
metadata = template_info.metadata if template_info and template_info.path == prompt_file else None
prompt_files.append(
PromptFileInfo(
name=prompt_file.name,
size=stat.st_size,
modified_at=stat.st_mtime,
display_name=metadata.display_name if metadata else "",
advanced=metadata.advanced if metadata else False,
description=metadata.description if metadata else "",
)
)
@@ -347,6 +423,8 @@ async def get_model_config():
async def update_bot_config(config_data: ConfigBody):
"""更新麦麦主程序配置"""
try:
config_data = _coerce_config_numeric_values(config_data, Config)
# 验证配置数据
try:
Config.from_dict(AttributeData(), copy.deepcopy(config_data))
@@ -370,6 +448,8 @@ async def update_bot_config(config_data: ConfigBody):
async def update_model_config(config_data: ConfigBody):
"""更新模型配置"""
try:
config_data = _coerce_config_numeric_values(config_data, ModelConfig)
# 验证配置数据
try:
ModelConfig.from_dict(AttributeData(), copy.deepcopy(config_data))
@@ -422,10 +502,13 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
# 验证完整配置
try:
Config.from_dict(AttributeData(), _toml_to_plain_dict(config_data))
plain_config_data = _coerce_config_numeric_values(_toml_to_plain_dict(config_data), Config)
Config.from_dict(AttributeData(), copy.deepcopy(plain_config_data))
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
config_data = plain_config_data
# 保存配置(格式化数组为多行,保留注释)
save_toml_with_format(config_data, config_path)
@@ -520,13 +603,14 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
# 验证完整配置
try:
ModelConfig.from_dict(AttributeData(), _toml_to_plain_dict(config_data))
plain_config_data = _coerce_config_numeric_values(_toml_to_plain_dict(config_data), ModelConfig)
ModelConfig.from_dict(AttributeData(), copy.deepcopy(plain_config_data))
except Exception as e:
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
# 特殊处理:如果是更新 api_providers检查是否有模型引用了已删除的provider
if section_name == "api_providers" and "api_provider" in str(e):
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
models = config_data.get("models", [])
models = plain_config_data.get("models", [])
orphaned_models: List[str] = [
str(model_name)
for m in models
@@ -539,6 +623,8 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
raise HTTPException(status_code=400, detail=error_msg) from e
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
config_data = plain_config_data
# 保存配置(格式化数组为多行,保留注释)
save_toml_with_format(config_data, config_path)

View File

@@ -10,11 +10,12 @@ from sqlmodel import col, delete, select
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.database.database_model import ChatSession, Expression, Messages, ModifiedBy
from src.common.logger import get_logger
from src.webui.dependencies import require_auth
logger = get_logger("webui.expression")
EXCLUDE_IDS_QUERY = Query(None, description="需要排除的表达方式 ID")
# 创建路由器
router = APIRouter(prefix="/expression", tags=["Expression"], dependencies=[Depends(require_auth)])
@@ -28,6 +29,7 @@ class ExpressionResponse(BaseModel):
style: str
last_active_time: float
chat_id: str
chat_name: Optional[str] = None
create_date: Optional[float]
checked: bool
rejected: bool
@@ -90,7 +92,61 @@ class ExpressionCreateResponse(BaseModel):
data: ExpressionResponse
def expression_to_response(expression: Expression) -> ExpressionResponse:
def get_chat_name_from_latest_message(chat_id: str, db_session: Any) -> Optional[str]:
"""从最近消息中解析聊天显示名称。"""
statement = (
select(Messages)
.where(col(Messages.session_id) == chat_id)
.order_by(col(Messages.timestamp).desc())
.limit(1)
)
message = db_session.exec(statement).first()
if not message:
return None
if message.group_id:
return message.group_name or f"群聊{message.group_id}"
return message.user_cardname or message.user_nickname or (f"用户{message.user_id}" if message.user_id else None)
def get_chat_name_from_session_record(chat_session: ChatSession) -> str:
"""从会话记录推断兜底显示名称。"""
if chat_session.group_id:
return f"群聊{chat_session.group_id}"
if chat_session.user_id:
return f"用户{chat_session.user_id}"
return chat_session.session_id
def get_chat_name(chat_id: str, db_session: Optional[Any] = None) -> str:
"""根据聊天 ID 获取聊天名称。
Args:
chat_id: 聊天会话 ID。
db_session: 可选数据库会话,用于从历史消息中解析群名或私聊用户名。
Returns:
str: 聊天显示名称,获取失败时返回原始聊天 ID。
"""
try:
if name := _chat_manager.get_session_name(chat_id):
return name
if db_session and (name := get_chat_name_from_latest_message(chat_id, db_session)):
return name
session = _chat_manager.get_session_by_session_id(chat_id)
if session:
if session.group_id:
return f"群聊{session.group_id}"
if session.user_id:
return f"用户{session.user_id}"
return chat_id
except Exception:
return chat_id
def expression_to_response(expression: Expression, db_session: Optional[Any] = None) -> ExpressionResponse:
"""将表达方式模型转换为响应对象。
Args:
@@ -101,38 +157,21 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
"""
last_active_time = expression.last_active_time.timestamp() if expression.last_active_time else 0.0
create_date = expression.create_time.timestamp() if expression.create_time else None
chat_id = expression.session_id or ""
return ExpressionResponse(
id=expression.id if expression.id is not None else 0,
situation=expression.situation,
style=expression.style,
last_active_time=last_active_time,
chat_id=expression.session_id or "",
chat_id=chat_id,
chat_name=get_chat_name(chat_id, db_session) if chat_id else None,
create_date=create_date,
checked=False,
rejected=False,
modified_by=None,
checked=expression.checked,
rejected=expression.rejected,
modified_by=expression.modified_by.value if expression.modified_by else None,
)
def get_chat_name(chat_id: str) -> str:
"""根据聊天 ID 获取聊天名称。
Args:
chat_id: 聊天会话 ID。
Returns:
str: 聊天显示名称,获取失败时返回原始聊天 ID。
"""
try:
session = _chat_manager.get_session_by_session_id(chat_id)
if not session:
return chat_id
name = _chat_manager.get_session_name(chat_id)
return name or chat_id
except Exception:
return chat_id
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
"""批量获取聊天名称。
@@ -145,8 +184,7 @@ def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
try:
for chat_id in chat_ids:
if name := _chat_manager.get_session_name(chat_id):
result[chat_id] = name
result[chat_id] = get_chat_name(chat_id)
except Exception as e:
logger.warning(f"批量获取聊天名称失败: {e}")
return result
@@ -176,19 +214,43 @@ async def get_chat_list() -> ChatListResponse:
ChatListResponse: 可用于下拉选择的聊天列表。
"""
try:
chat_list = []
chat_by_id: Dict[str, ChatInfo] = {}
for session_id, session in _chat_manager.sessions.items():
chat_name = _chat_manager.get_session_name(session_id) or session_id
chat_list.append(
ChatInfo(
chat_id=session_id,
chat_name=chat_name,
platform=session.platform,
is_group=session.is_group_session,
)
chat_by_id[session_id] = ChatInfo(
chat_id=session_id,
chat_name=chat_name,
platform=session.platform,
is_group=session.is_group_session,
)
with get_db_session() as session:
for chat_session in session.exec(select(ChatSession)).all():
if chat_session.session_id in chat_by_id:
continue
chat_name = get_chat_name_from_latest_message(chat_session.session_id, session)
chat_by_id[chat_session.session_id] = ChatInfo(
chat_id=chat_session.session_id,
chat_name=chat_name or get_chat_name_from_session_record(chat_session),
platform=chat_session.platform,
is_group=bool(chat_session.group_id),
)
expression_chat_ids = {
chat_id for chat_id in session.exec(select(Expression.session_id)).all() if chat_id
}
for session_id in expression_chat_ids:
if session_id in chat_by_id:
continue
chat_by_id[session_id] = ChatInfo(
chat_id=session_id,
chat_name=get_chat_name(session_id, session),
platform=None,
is_group=False,
)
# 按名称排序
chat_list = list(chat_by_id.values())
chat_list.sort(key=lambda x: x.chat_name)
return ChatListResponse(success=True, data=chat_list)
@@ -252,7 +314,7 @@ async def get_expression_list(
if chat_id:
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
total = len(session.exec(count_statement).all())
data = [expression_to_response(expr) for expr in expressions]
data = [expression_to_response(expr, session) for expr in expressions]
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
@@ -281,7 +343,7 @@ async def get_expression_detail(expression_id: int) -> ExpressionDetailResponse:
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
data = expression_to_response(expression)
data = expression_to_response(expression, session)
return ExpressionDetailResponse(success=True, data=data)
@@ -321,7 +383,7 @@ async def create_expression(
session.add(expression)
session.flush()
expression_id = expression.id
data = expression_to_response(expression)
data = expression_to_response(expression, session)
logger.info(f"表达方式已创建: ID={expression_id}, situation={request.situation}")
@@ -375,7 +437,7 @@ async def update_expression(
db_expression.session_id = update_data["session_id"]
db_expression.last_active_time = update_data["last_active_time"]
session.add(db_expression)
data = expression_to_response(db_expression)
data = expression_to_response(db_expression, session)
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
@@ -524,6 +586,22 @@ class ReviewStatsResponse(BaseModel):
user_checked: int
def apply_review_filter(statement: Any, filter_type: str) -> Any:
"""按审核状态过滤表达方式查询。"""
if filter_type == "unchecked":
return statement.where(col(Expression.checked).is_(False))
if filter_type == "passed":
return statement.where(col(Expression.checked).is_(True), col(Expression.rejected).is_(False))
if filter_type == "rejected":
return statement.where(col(Expression.checked).is_(True), col(Expression.rejected).is_(True))
return statement
def count_expressions(session: Any, statement: Any) -> int:
"""统计表达方式查询结果数量。"""
return len(session.exec(statement).all())
@router.get("/review/stats", response_model=ReviewStatsResponse)
async def get_review_stats() -> ReviewStatsResponse:
"""获取审核统计数据。
@@ -533,12 +611,24 @@ async def get_review_stats() -> ReviewStatsResponse:
"""
try:
with get_db_session() as session:
total = len(session.exec(select(Expression.id)).all())
unchecked = 0
passed = 0
rejected = 0
ai_checked = 0
user_checked = 0
total = count_expressions(session, select(Expression.id))
unchecked = count_expressions(session, apply_review_filter(select(Expression.id), "unchecked"))
passed = count_expressions(session, apply_review_filter(select(Expression.id), "passed"))
rejected = count_expressions(session, apply_review_filter(select(Expression.id), "rejected"))
ai_checked = count_expressions(
session,
select(Expression.id).where(
col(Expression.checked).is_(True),
col(Expression.modified_by) == ModifiedBy.AI,
),
)
user_checked = count_expressions(
session,
select(Expression.id).where(
col(Expression.checked).is_(True),
col(Expression.modified_by) == ModifiedBy.USER,
),
)
return ReviewStatsResponse(
total=total,
@@ -571,8 +661,10 @@ async def get_review_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
filter_type: str = Query("unchecked", description="筛选类型: unchecked/passed/rejected/all"),
order: str = Query("latest", description="排序方式: latest/random"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
exclude_ids: Optional[List[int]] = EXCLUDE_IDS_QUERY,
) -> ReviewListResponse:
"""获取待审核或已审核的表达方式列表。
@@ -580,17 +672,16 @@ async def get_review_list(
page: 页码。
page_size: 每页数量。
filter_type: 筛选类型,可选 unchecked、passed、rejected 或 all。
order: 排序方式,可选 latest 或 random。
search: 搜索关键词。
chat_id: 聊天 ID 筛选条件。
exclude_ids: 需要排除的表达方式 ID。
Returns:
ReviewListResponse: 审核列表响应。
"""
try:
statement = select(Expression)
if filter_type in {"unchecked", "passed", "rejected"}:
statement = statement.where(col(Expression.id) == -1)
statement = apply_review_filter(select(Expression), filter_type)
# all 不需要额外过滤
# 搜索过滤
@@ -603,11 +694,17 @@ async def get_review_list(
if chat_id:
statement = statement.where(col(Expression.session_id) == chat_id)
# 排序:创建时间倒序
statement = statement.order_by(
case((col(Expression.create_time).is_(None), 1), else_=0),
col(Expression.create_time).desc(),
)
if exclude_ids:
statement = statement.where(~col(Expression.id).in_(exclude_ids))
if order == "random":
statement = statement.order_by(func.random())
else:
# 排序:创建时间倒序
statement = statement.order_by(
case((col(Expression.create_time).is_(None), 1), else_=0),
col(Expression.create_time).desc(),
)
offset = (page - 1) * page_size
statement = statement.offset(offset).limit(page_size)
@@ -615,9 +712,7 @@ async def get_review_list(
with get_db_session() as session:
expressions = session.exec(statement).all()
count_statement = select(Expression.id)
if filter_type in {"unchecked", "passed", "rejected"}:
count_statement = count_statement.where(col(Expression.id) == -1)
count_statement = apply_review_filter(select(Expression.id), filter_type)
if search:
count_statement = count_statement.where(
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
@@ -625,7 +720,7 @@ async def get_review_list(
if chat_id:
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
total = len(session.exec(count_statement).all())
data = [expression_to_response(expr) for expr in expressions]
data = [expression_to_response(expr, session) for expr in expressions]
return ReviewListResponse(
success=True,
@@ -647,7 +742,7 @@ class BatchReviewItem(BaseModel):
id: int
rejected: bool
require_unchecked: bool = True # 默认要求未检查状态
require_unchecked: bool = True # 前端保留的来源标记,人工审核提交时不再阻断覆盖
class BatchReviewRequest(BaseModel):
@@ -706,14 +801,6 @@ async def batch_review_expressions(
failed += 1
continue
# 冲突检测
if item.require_unchecked:
results.append(
BatchReviewResultItem(id=item.id, success=False, message="当前模型不支持审核状态过滤")
)
failed += 1
continue
# 更新状态
with get_db_session() as session:
db_expression = session.exec(
@@ -727,6 +814,9 @@ async def batch_review_expressions(
)
failed += 1
continue
db_expression.checked = True
db_expression.rejected = item.rejected
db_expression.modified_by = ModifiedBy.USER
db_expression.last_active_time = datetime.now()
session.add(db_expression)