Ruff fix
This commit is contained in:
4
bot.py
4
bot.py
@@ -1,7 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
import platform
|
import platform
|
||||||
import traceback
|
import traceback
|
||||||
@@ -30,7 +29,7 @@ else:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
||||||
from src.common.logger import initialize_logging, get_logger, shutdown_logging #noqa
|
from src.common.logger import initialize_logging, get_logger, shutdown_logging # noqa
|
||||||
|
|
||||||
initialize_logging()
|
initialize_logging()
|
||||||
|
|
||||||
@@ -215,6 +214,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# 初始化 WebSocket 日志推送
|
# 初始化 WebSocket 日志推送
|
||||||
from src.common.logger import initialize_ws_handler
|
from src.common.logger import initialize_ws_handler
|
||||||
|
|
||||||
initialize_ws_handler(loop)
|
initialize_ws_handler(loop)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -16,8 +16,6 @@ if PROJECT_ROOT not in sys.path:
|
|||||||
sys.path.insert(0, PROJECT_ROOT)
|
sys.path.insert(0, PROJECT_ROOT)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
SECONDS_5_MINUTES = 5 * 60
|
SECONDS_5_MINUTES = 5 * 60
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|||||||
sys.path.insert(0, project_root)
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 设置中文字体
|
# 设置中文字体
|
||||||
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
|
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
|
||||||
plt.rcParams["axes.unicode_minus"] = False
|
plt.rcParams["axes.unicode_minus"] = False
|
||||||
|
|||||||
@@ -57,8 +57,8 @@ from src.common.database.database import db
|
|||||||
from src.common.database.database_model import Emoji
|
from src.common.database.database_model import Emoji
|
||||||
|
|
||||||
# 常量定义
|
# 常量定义
|
||||||
MAGIC = b'MMIP'
|
MAGIC = b"MMIP"
|
||||||
FOOTER_MAGIC = b'MMFF'
|
FOOTER_MAGIC = b"MMFF"
|
||||||
VERSION = 1
|
VERSION = 1
|
||||||
FOOTER_VERSION = 1
|
FOOTER_VERSION = 1
|
||||||
|
|
||||||
@@ -67,7 +67,7 @@ MAX_MANIFEST_SIZE = 200 * 1024 * 1024 # 200 MB
|
|||||||
MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 * 1024 # 10 GB
|
MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 * 1024 # 10 GB
|
||||||
|
|
||||||
# 支持的图片格式
|
# 支持的图片格式
|
||||||
SUPPORTED_FORMATS = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.avif', '.bmp'}
|
SUPPORTED_FORMATS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".avif", ".bmp"}
|
||||||
|
|
||||||
# 创建控制台对象
|
# 创建控制台对象
|
||||||
console = Console()
|
console = Console()
|
||||||
@@ -75,6 +75,7 @@ console = Console()
|
|||||||
|
|
||||||
class MMIPKGError(Exception):
|
class MMIPKGError(Exception):
|
||||||
"""MMIPKG 相关错误"""
|
"""MMIPKG 相关错误"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -97,55 +98,55 @@ def get_image_info(file_path: str) -> Tuple[int, int, str]:
|
|||||||
try:
|
try:
|
||||||
with Image.open(file_path) as img:
|
with Image.open(file_path) as img:
|
||||||
width, height = img.size
|
width, height = img.size
|
||||||
format_lower = img.format.lower() if img.format else 'unknown'
|
format_lower = img.format.lower() if img.format else "unknown"
|
||||||
mime_map = {
|
mime_map = {
|
||||||
'jpeg': 'image/jpeg',
|
"jpeg": "image/jpeg",
|
||||||
'jpg': 'image/jpeg',
|
"jpg": "image/jpeg",
|
||||||
'png': 'image/png',
|
"png": "image/png",
|
||||||
'gif': 'image/gif',
|
"gif": "image/gif",
|
||||||
'webp': 'image/webp',
|
"webp": "image/webp",
|
||||||
'avif': 'image/avif',
|
"avif": "image/avif",
|
||||||
'bmp': 'image/bmp'
|
"bmp": "image/bmp",
|
||||||
}
|
}
|
||||||
mime_type = mime_map.get(format_lower, f'image/{format_lower}')
|
mime_type = mime_map.get(format_lower, f"image/{format_lower}")
|
||||||
return width, height, mime_type
|
return width, height, mime_type
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"警告: 无法读取图片信息 {file_path}: {e}")
|
print(f"警告: 无法读取图片信息 {file_path}: {e}")
|
||||||
return 0, 0, 'image/unknown'
|
return 0, 0, "image/unknown"
|
||||||
|
|
||||||
|
|
||||||
def reencode_image(file_path: str, output_format: str = 'webp', quality: int = 80) -> bytes:
|
def reencode_image(file_path: str, output_format: str = "webp", quality: int = 80) -> bytes:
|
||||||
"""重新编码图片"""
|
"""重新编码图片"""
|
||||||
try:
|
try:
|
||||||
with Image.open(file_path) as img:
|
with Image.open(file_path) as img:
|
||||||
# 转换为 RGB(如果需要)
|
# 转换为 RGB(如果需要)
|
||||||
if img.mode in ('RGBA', 'LA', 'P'):
|
if img.mode in ("RGBA", "LA", "P"):
|
||||||
if output_format.lower() == 'jpeg':
|
if output_format.lower() == "jpeg":
|
||||||
# JPEG 不支持透明度,转为白色背景
|
# JPEG 不支持透明度,转为白色背景
|
||||||
background = Image.new('RGB', img.size, (255, 255, 255))
|
background = Image.new("RGB", img.size, (255, 255, 255))
|
||||||
if img.mode == 'P':
|
if img.mode == "P":
|
||||||
img = img.convert('RGBA')
|
img = img.convert("RGBA")
|
||||||
background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None)
|
background.paste(img, mask=img.split()[-1] if img.mode == "RGBA" else None)
|
||||||
img = background
|
img = background
|
||||||
elif output_format.lower() == 'webp':
|
elif output_format.lower() == "webp":
|
||||||
# WebP 支持透明度
|
# WebP 支持透明度
|
||||||
if img.mode == 'P':
|
if img.mode == "P":
|
||||||
img = img.convert('RGBA')
|
img = img.convert("RGBA")
|
||||||
elif img.mode not in ('RGB', 'RGBA'):
|
elif img.mode not in ("RGB", "RGBA"):
|
||||||
img = img.convert('RGB')
|
img = img.convert("RGB")
|
||||||
|
|
||||||
# 编码图片
|
# 编码图片
|
||||||
output = io.BytesIO()
|
output = io.BytesIO()
|
||||||
save_kwargs = {'format': output_format.upper()}
|
save_kwargs = {"format": output_format.upper()}
|
||||||
|
|
||||||
if output_format.lower() in {'jpeg', 'jpg'}:
|
if output_format.lower() in {"jpeg", "jpg"}:
|
||||||
save_kwargs['quality'] = quality
|
save_kwargs["quality"] = quality
|
||||||
save_kwargs['optimize'] = True
|
save_kwargs["optimize"] = True
|
||||||
elif output_format.lower() == 'webp':
|
elif output_format.lower() == "webp":
|
||||||
save_kwargs['quality'] = quality
|
save_kwargs["quality"] = quality
|
||||||
save_kwargs['method'] = 6 # 更好的压缩
|
save_kwargs["method"] = 6 # 更好的压缩
|
||||||
elif output_format.lower() == 'png':
|
elif output_format.lower() == "png":
|
||||||
save_kwargs['optimize'] = True
|
save_kwargs["optimize"] = True
|
||||||
|
|
||||||
img.save(output, **save_kwargs)
|
img.save(output, **save_kwargs)
|
||||||
return output.getvalue()
|
return output.getvalue()
|
||||||
@@ -156,11 +157,13 @@ def reencode_image(file_path: str, output_format: str = 'webp', quality: int = 8
|
|||||||
class MMIPKGPacker:
|
class MMIPKGPacker:
|
||||||
"""MMIPKG 打包器"""
|
"""MMIPKG 打包器"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
use_compression: bool = True,
|
use_compression: bool = True,
|
||||||
zstd_level: int = 3,
|
zstd_level: int = 3,
|
||||||
reencode: Optional[str] = None,
|
reencode: Optional[str] = None,
|
||||||
reencode_quality: int = 80):
|
reencode_quality: int = 80,
|
||||||
|
):
|
||||||
self.use_compression = use_compression and zstd is not None
|
self.use_compression = use_compression and zstd is not None
|
||||||
self.zstd_level = zstd_level
|
self.zstd_level = zstd_level
|
||||||
self.reencode = reencode
|
self.reencode = reencode
|
||||||
@@ -170,8 +173,9 @@ class MMIPKGPacker:
|
|||||||
print("警告: zstandard 未安装,将不使用压缩")
|
print("警告: zstandard 未安装,将不使用压缩")
|
||||||
self.use_compression = False
|
self.use_compression = False
|
||||||
|
|
||||||
def pack_from_db(self, output_path: str, pack_name: Optional[str] = None,
|
def pack_from_db(
|
||||||
custom_manifest: Optional[Dict] = None) -> bool:
|
self, output_path: str, pack_name: Optional[str] = None, custom_manifest: Optional[Dict] = None
|
||||||
|
) -> bool:
|
||||||
"""从数据库导出已注册的表情包
|
"""从数据库导出已注册的表情包
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -205,12 +209,14 @@ class MMIPKGPacker:
|
|||||||
BarColumn(),
|
BarColumn(),
|
||||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||||
TimeElapsedColumn(),
|
TimeElapsedColumn(),
|
||||||
console=console
|
console=console,
|
||||||
) as progress:
|
) as progress:
|
||||||
task = progress.add_task("[cyan]扫描表情包...", total=emoji_count)
|
task = progress.add_task("[cyan]扫描表情包...", total=emoji_count)
|
||||||
|
|
||||||
for idx, emoji in enumerate(emojis, 1):
|
for idx, emoji in enumerate(emojis, 1):
|
||||||
progress.update(task, description=f"[cyan]处理 {idx}/{emoji_count}: {os.path.basename(emoji.full_path)}")
|
progress.update(
|
||||||
|
task, description=f"[cyan]处理 {idx}/{emoji_count}: {os.path.basename(emoji.full_path)}"
|
||||||
|
)
|
||||||
|
|
||||||
# 检查文件是否存在
|
# 检查文件是否存在
|
||||||
if not os.path.exists(emoji.full_path):
|
if not os.path.exists(emoji.full_path):
|
||||||
@@ -224,10 +230,10 @@ class MMIPKGPacker:
|
|||||||
img_bytes = reencode_image(emoji.full_path, self.reencode, self.reencode_quality)
|
img_bytes = reencode_image(emoji.full_path, self.reencode, self.reencode_quality)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f" [yellow]警告: 重新编码失败,使用原始文件: {e}[/yellow]")
|
console.print(f" [yellow]警告: 重新编码失败,使用原始文件: {e}[/yellow]")
|
||||||
with open(emoji.full_path, 'rb') as f:
|
with open(emoji.full_path, "rb") as f:
|
||||||
img_bytes = f.read()
|
img_bytes = f.read()
|
||||||
else:
|
else:
|
||||||
with open(emoji.full_path, 'rb') as f:
|
with open(emoji.full_path, "rb") as f:
|
||||||
img_bytes = f.read()
|
img_bytes = f.read()
|
||||||
|
|
||||||
# 计算 SHA256
|
# 计算 SHA256
|
||||||
@@ -259,7 +265,7 @@ class MMIPKGPacker:
|
|||||||
"emoji_hash": emoji.emoji_hash or "",
|
"emoji_hash": emoji.emoji_hash or "",
|
||||||
"is_registered": True,
|
"is_registered": True,
|
||||||
"is_banned": emoji.is_banned or False,
|
"is_banned": emoji.is_banned or False,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
items.append(item)
|
items.append(item)
|
||||||
@@ -281,7 +287,7 @@ class MMIPKGPacker:
|
|||||||
"p": pack_id, # pack_id
|
"p": pack_id, # pack_id
|
||||||
"n": pack_name, # pack_name
|
"n": pack_name, # pack_name
|
||||||
"t": datetime.now().isoformat(), # created_at
|
"t": datetime.now().isoformat(), # created_at
|
||||||
"a": items # items array
|
"a": items, # items array
|
||||||
}
|
}
|
||||||
|
|
||||||
# 添加自定义字段
|
# 添加自定义字段
|
||||||
@@ -308,26 +314,28 @@ class MMIPKGPacker:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"打包失败: {e}")
|
print(f"打包失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
finally:
|
finally:
|
||||||
if not db.is_closed():
|
if not db.is_closed():
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
def _write_package(self, output_path: str, manifest_bytes: bytes,
|
def _write_package(
|
||||||
image_data_list: List[bytes], payload_size: int) -> bool:
|
self, output_path: str, manifest_bytes: bytes, image_data_list: List[bytes], payload_size: int
|
||||||
|
) -> bool:
|
||||||
"""写入打包文件"""
|
"""写入打包文件"""
|
||||||
try:
|
try:
|
||||||
with open(output_path, 'wb') as f:
|
with open(output_path, "wb") as f:
|
||||||
# 写入 Header (32 bytes)
|
# 写入 Header (32 bytes)
|
||||||
flags = 0x01 if self.use_compression else 0x00
|
flags = 0x01 if self.use_compression else 0x00
|
||||||
header = MAGIC # 4 bytes
|
header = MAGIC # 4 bytes
|
||||||
header += struct.pack('B', VERSION) # 1 byte
|
header += struct.pack("B", VERSION) # 1 byte
|
||||||
header += struct.pack('B', flags) # 1 byte
|
header += struct.pack("B", flags) # 1 byte
|
||||||
header += b'\x00\x00' # 2 bytes reserved
|
header += b"\x00\x00" # 2 bytes reserved
|
||||||
header += struct.pack('>Q', payload_size) # 8 bytes
|
header += struct.pack(">Q", payload_size) # 8 bytes
|
||||||
header += struct.pack('>Q', len(manifest_bytes)) # 8 bytes
|
header += struct.pack(">Q", len(manifest_bytes)) # 8 bytes
|
||||||
header += b'\x00' * 8 # 8 bytes reserved
|
header += b"\x00" * 8 # 8 bytes reserved
|
||||||
|
|
||||||
assert len(header) == 32, f"Header size mismatch: {len(header)}"
|
assert len(header) == 32, f"Header size mismatch: {len(header)}"
|
||||||
f.write(header)
|
f.write(header)
|
||||||
@@ -342,7 +350,7 @@ class MMIPKGPacker:
|
|||||||
|
|
||||||
with compressor.stream_writer(f, closefd=False) as writer:
|
with compressor.stream_writer(f, closefd=False) as writer:
|
||||||
# 写入 manifest
|
# 写入 manifest
|
||||||
manifest_len_bytes = struct.pack('>I', len(manifest_bytes))
|
manifest_len_bytes = struct.pack(">I", len(manifest_bytes))
|
||||||
writer.write(manifest_len_bytes)
|
writer.write(manifest_len_bytes)
|
||||||
writer.write(manifest_bytes)
|
writer.write(manifest_bytes)
|
||||||
payload_sha.update(manifest_len_bytes)
|
payload_sha.update(manifest_len_bytes)
|
||||||
@@ -355,13 +363,13 @@ class MMIPKGPacker:
|
|||||||
BarColumn(),
|
BarColumn(),
|
||||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||||
TimeRemainingColumn(),
|
TimeRemainingColumn(),
|
||||||
console=console
|
console=console,
|
||||||
) as progress:
|
) as progress:
|
||||||
task = progress.add_task("[green]压缩写入图片...", total=len(image_data_list))
|
task = progress.add_task("[green]压缩写入图片...", total=len(image_data_list))
|
||||||
|
|
||||||
for idx, img_bytes in enumerate(image_data_list, 1):
|
for idx, img_bytes in enumerate(image_data_list, 1):
|
||||||
progress.update(task, description=f"[green]压缩写入 {idx}/{len(image_data_list)}")
|
progress.update(task, description=f"[green]压缩写入 {idx}/{len(image_data_list)}")
|
||||||
img_len_bytes = struct.pack('>I', len(img_bytes))
|
img_len_bytes = struct.pack(">I", len(img_bytes))
|
||||||
writer.write(img_len_bytes)
|
writer.write(img_len_bytes)
|
||||||
writer.write(img_bytes)
|
writer.write(img_bytes)
|
||||||
payload_sha.update(img_len_bytes)
|
payload_sha.update(img_len_bytes)
|
||||||
@@ -370,7 +378,7 @@ class MMIPKGPacker:
|
|||||||
else:
|
else:
|
||||||
# 不压缩,直接写入
|
# 不压缩,直接写入
|
||||||
# 写入 manifest
|
# 写入 manifest
|
||||||
manifest_len_bytes = struct.pack('>I', len(manifest_bytes))
|
manifest_len_bytes = struct.pack(">I", len(manifest_bytes))
|
||||||
f.write(manifest_len_bytes)
|
f.write(manifest_len_bytes)
|
||||||
f.write(manifest_bytes)
|
f.write(manifest_bytes)
|
||||||
payload_sha.update(manifest_len_bytes)
|
payload_sha.update(manifest_len_bytes)
|
||||||
@@ -383,13 +391,13 @@ class MMIPKGPacker:
|
|||||||
BarColumn(),
|
BarColumn(),
|
||||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||||
TimeRemainingColumn(),
|
TimeRemainingColumn(),
|
||||||
console=console
|
console=console,
|
||||||
) as progress:
|
) as progress:
|
||||||
task = progress.add_task("[green]写入图片...", total=len(image_data_list))
|
task = progress.add_task("[green]写入图片...", total=len(image_data_list))
|
||||||
|
|
||||||
for idx, img_bytes in enumerate(image_data_list, 1):
|
for idx, img_bytes in enumerate(image_data_list, 1):
|
||||||
progress.update(task, description=f"[green]写入 {idx}/{len(image_data_list)}")
|
progress.update(task, description=f"[green]写入 {idx}/{len(image_data_list)}")
|
||||||
img_len_bytes = struct.pack('>I', len(img_bytes))
|
img_len_bytes = struct.pack(">I", len(img_bytes))
|
||||||
f.write(img_len_bytes)
|
f.write(img_len_bytes)
|
||||||
f.write(img_bytes)
|
f.write(img_bytes)
|
||||||
payload_sha.update(img_len_bytes)
|
payload_sha.update(img_len_bytes)
|
||||||
@@ -400,8 +408,8 @@ class MMIPKGPacker:
|
|||||||
file_sha256 = payload_sha.digest()
|
file_sha256 = payload_sha.digest()
|
||||||
footer = FOOTER_MAGIC # 4 bytes
|
footer = FOOTER_MAGIC # 4 bytes
|
||||||
footer += file_sha256 # 32 bytes
|
footer += file_sha256 # 32 bytes
|
||||||
footer += struct.pack('B', FOOTER_VERSION) # 1 byte
|
footer += struct.pack("B", FOOTER_VERSION) # 1 byte
|
||||||
footer += b'\x00' * 3 # 3 bytes reserved
|
footer += b"\x00" * 3 # 3 bytes reserved
|
||||||
|
|
||||||
assert len(footer) == 40, f"Footer size mismatch: {len(footer)}"
|
assert len(footer) == 40, f"Footer size mismatch: {len(footer)}"
|
||||||
f.write(footer)
|
f.write(footer)
|
||||||
@@ -419,6 +427,7 @@ class MMIPKGPacker:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"写入文件失败: {e}")
|
print(f"写入文件失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -429,10 +438,9 @@ class MMIPKGUnpacker:
|
|||||||
def __init__(self, verify_sha: bool = True):
|
def __init__(self, verify_sha: bool = True):
|
||||||
self.verify_sha = verify_sha
|
self.verify_sha = verify_sha
|
||||||
|
|
||||||
def import_to_db(self, package_path: str,
|
def import_to_db(
|
||||||
output_dir: Optional[str] = None,
|
self, package_path: str, output_dir: Optional[str] = None, replace_existing: bool = False, batch_size: int = 500
|
||||||
replace_existing: bool = False,
|
) -> bool:
|
||||||
batch_size: int = 500) -> bool:
|
|
||||||
"""导入到数据库"""
|
"""导入到数据库"""
|
||||||
try:
|
try:
|
||||||
if not os.path.exists(package_path):
|
if not os.path.exists(package_path):
|
||||||
@@ -451,7 +459,7 @@ class MMIPKGUnpacker:
|
|||||||
|
|
||||||
print(f"正在读取包: {package_path}")
|
print(f"正在读取包: {package_path}")
|
||||||
|
|
||||||
with open(package_path, 'rb') as f:
|
with open(package_path, "rb") as f:
|
||||||
# 读取 Header
|
# 读取 Header
|
||||||
header = f.read(32)
|
header = f.read(32)
|
||||||
if len(header) != 32:
|
if len(header) != 32:
|
||||||
@@ -461,15 +469,15 @@ class MMIPKGUnpacker:
|
|||||||
if magic != MAGIC:
|
if magic != MAGIC:
|
||||||
raise MMIPKGError(f"无效的 MAGIC: {magic}")
|
raise MMIPKGError(f"无效的 MAGIC: {magic}")
|
||||||
|
|
||||||
version = struct.unpack('B', header[4:5])[0]
|
version = struct.unpack("B", header[4:5])[0]
|
||||||
if version != VERSION:
|
if version != VERSION:
|
||||||
print(f"警告: 包版本 {version} 与当前版本 {VERSION} 不匹配")
|
print(f"警告: 包版本 {version} 与当前版本 {VERSION} 不匹配")
|
||||||
|
|
||||||
flags = struct.unpack('B', header[5:6])[0]
|
flags = struct.unpack("B", header[5:6])[0]
|
||||||
is_compressed = bool(flags & 0x01)
|
is_compressed = bool(flags & 0x01)
|
||||||
|
|
||||||
payload_uncompressed_len = struct.unpack('>Q', header[8:16])[0]
|
payload_uncompressed_len = struct.unpack(">Q", header[8:16])[0]
|
||||||
manifest_uncompressed_len = struct.unpack('>Q', header[16:24])[0]
|
manifest_uncompressed_len = struct.unpack(">Q", header[16:24])[0]
|
||||||
|
|
||||||
# 安全检查
|
# 安全检查
|
||||||
if manifest_uncompressed_len > MAX_MANIFEST_SIZE:
|
if manifest_uncompressed_len > MAX_MANIFEST_SIZE:
|
||||||
@@ -519,7 +527,9 @@ class MMIPKGUnpacker:
|
|||||||
# 方法2:如果流式失败,尝试直接解压(兼容旧格式)
|
# 方法2:如果流式失败,尝试直接解压(兼容旧格式)
|
||||||
print(f" 流式解压失败,尝试直接解压: {e}")
|
print(f" 流式解压失败,尝试直接解压: {e}")
|
||||||
try:
|
try:
|
||||||
payload_data = decompressor.decompress(compressed_data, max_output_size=payload_uncompressed_len)
|
payload_data = decompressor.decompress(
|
||||||
|
compressed_data, max_output_size=payload_uncompressed_len
|
||||||
|
)
|
||||||
except Exception as e2:
|
except Exception as e2:
|
||||||
raise MMIPKGError(f"解压失败: {e2}") from e2
|
raise MMIPKGError(f"解压失败: {e2}") from e2
|
||||||
else:
|
else:
|
||||||
@@ -537,7 +547,7 @@ class MMIPKGUnpacker:
|
|||||||
|
|
||||||
# 读取 manifest
|
# 读取 manifest
|
||||||
manifest_len_bytes = payload_stream.read(4)
|
manifest_len_bytes = payload_stream.read(4)
|
||||||
manifest_len = struct.unpack('>I', manifest_len_bytes)[0]
|
manifest_len = struct.unpack(">I", manifest_len_bytes)[0]
|
||||||
manifest_bytes = payload_stream.read(manifest_len)
|
manifest_bytes = payload_stream.read(manifest_len)
|
||||||
manifest = msgpack.unpackb(manifest_bytes, raw=False)
|
manifest = msgpack.unpackb(manifest_bytes, raw=False)
|
||||||
|
|
||||||
@@ -553,20 +563,21 @@ class MMIPKGUnpacker:
|
|||||||
print(f" 表情包数量: {len(items)}")
|
print(f" 表情包数量: {len(items)}")
|
||||||
|
|
||||||
# 导入表情包
|
# 导入表情包
|
||||||
return self._import_items(payload_stream, items, output_dir,
|
return self._import_items(payload_stream, items, output_dir, replace_existing, batch_size)
|
||||||
replace_existing, batch_size)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"导入失败: {e}")
|
print(f"导入失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
finally:
|
finally:
|
||||||
if not db.is_closed():
|
if not db.is_closed():
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
def _import_items(self, payload_stream: BinaryIO, items: List[Dict],
|
def _import_items(
|
||||||
output_dir: str, replace_existing: bool, batch_size: int) -> bool:
|
self, payload_stream: BinaryIO, items: List[Dict], output_dir: str, replace_existing: bool, batch_size: int
|
||||||
|
) -> bool:
|
||||||
"""导入 items 到数据库"""
|
"""导入 items 到数据库"""
|
||||||
try:
|
try:
|
||||||
imported_count = 0
|
imported_count = 0
|
||||||
@@ -581,7 +592,7 @@ class MMIPKGUnpacker:
|
|||||||
BarColumn(),
|
BarColumn(),
|
||||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||||
TimeRemainingColumn(),
|
TimeRemainingColumn(),
|
||||||
console=console
|
console=console,
|
||||||
) as progress:
|
) as progress:
|
||||||
task = progress.add_task("[cyan]导入表情包...", total=len(items))
|
task = progress.add_task("[cyan]导入表情包...", total=len(items))
|
||||||
|
|
||||||
@@ -597,7 +608,7 @@ class MMIPKGUnpacker:
|
|||||||
progress.advance(task)
|
progress.advance(task)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
img_len = struct.unpack('>I', img_len_bytes)[0]
|
img_len = struct.unpack(">I", img_len_bytes)[0]
|
||||||
img_bytes = payload_stream.read(img_len)
|
img_bytes = payload_stream.read(img_len)
|
||||||
|
|
||||||
if len(img_bytes) != img_len:
|
if len(img_bytes) != img_len:
|
||||||
@@ -641,7 +652,7 @@ class MMIPKGUnpacker:
|
|||||||
file_path = os.path.join(output_dir, filename)
|
file_path = os.path.join(output_dir, filename)
|
||||||
counter += 1
|
counter += 1
|
||||||
|
|
||||||
with open(file_path, 'wb') as img_file:
|
with open(file_path, "wb") as img_file:
|
||||||
img_file.write(img_bytes)
|
img_file.write(img_bytes)
|
||||||
|
|
||||||
# 准备数据库记录
|
# 准备数据库记录
|
||||||
@@ -700,6 +711,7 @@ class MMIPKGUnpacker:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"[red]导入 items 失败: {e}[/red]")
|
console.print(f"[red]导入 items 失败: {e}[/red]")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -719,8 +731,9 @@ def print_menu():
|
|||||||
console.print(" [2] [bold]导入表情包[/bold] (从 .mmipkg 文件导入到数据库)")
|
console.print(" [2] [bold]导入表情包[/bold] (从 .mmipkg 文件导入到数据库)")
|
||||||
console.print(" [0] [bold]退出[/bold]")
|
console.print(" [0] [bold]退出[/bold]")
|
||||||
console.print()
|
console.print()
|
||||||
def get_input(prompt: str, default: Optional[str] = None,
|
|
||||||
choices: Optional[List[str]] = None) -> str:
|
|
||||||
|
def get_input(prompt: str, default: Optional[str] = None, choices: Optional[List[str]] = None) -> str:
|
||||||
"""获取用户输入"""
|
"""获取用户输入"""
|
||||||
if default:
|
if default:
|
||||||
prompt = f"{prompt} (默认: {default})"
|
prompt = f"{prompt} (默认: {default})"
|
||||||
@@ -760,9 +773,9 @@ def get_yes_no(prompt: str, default: bool = False) -> bool:
|
|||||||
if not value:
|
if not value:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
if value in ('y', 'yes', '是'):
|
if value in ("y", "yes", "是"):
|
||||||
return True
|
return True
|
||||||
elif value in ('n', 'no', '否'):
|
elif value in ("n", "no", "否"):
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
console.print(" [yellow]⚠ 请输入 y/yes/是 或 n/no/否[/yellow]")
|
console.print(" [yellow]⚠ 请输入 y/yes/是 或 n/no/否[/yellow]")
|
||||||
@@ -843,8 +856,8 @@ def interactive_export():
|
|||||||
output_path = get_input(" 输出文件路径", default_filename)
|
output_path = get_input(" 输出文件路径", default_filename)
|
||||||
|
|
||||||
# 确保有 .mmipkg 扩展名
|
# 确保有 .mmipkg 扩展名
|
||||||
if not output_path.endswith('.mmipkg'):
|
if not output_path.endswith(".mmipkg"):
|
||||||
output_path += '.mmipkg'
|
output_path += ".mmipkg"
|
||||||
|
|
||||||
# 获取包名称
|
# 获取包名称
|
||||||
default_pack_name = f"MaiBot表情包_{datetime.now().strftime('%Y%m%d')}"
|
default_pack_name = f"MaiBot表情包_{datetime.now().strftime('%Y%m%d')}"
|
||||||
@@ -853,9 +866,7 @@ def interactive_export():
|
|||||||
# 自定义 manifest
|
# 自定义 manifest
|
||||||
console.print("\n[yellow]2. 包信息设置(可选)[/yellow]")
|
console.print("\n[yellow]2. 包信息设置(可选)[/yellow]")
|
||||||
if get_yes_no(" 是否添加包的作者和介绍信息", False):
|
if get_yes_no(" 是否添加包的作者和介绍信息", False):
|
||||||
custom_manifest = {
|
custom_manifest = {"author": author} if (author := input(" 作者名称(可选): ").strip()) else {}
|
||||||
"author": author
|
|
||||||
} if (author := input(" 作者名称(可选): ").strip()) else {}
|
|
||||||
|
|
||||||
# 介绍信息
|
# 介绍信息
|
||||||
console.print(" 包介绍(限制 100 字以内):")
|
console.print(" 包介绍(限制 100 字以内):")
|
||||||
@@ -888,9 +899,9 @@ def interactive_export():
|
|||||||
console.print(" webp: 推荐,体积小且支持透明度")
|
console.print(" webp: 推荐,体积小且支持透明度")
|
||||||
console.print(" jpeg: 最小体积,但不支持透明度")
|
console.print(" jpeg: 最小体积,但不支持透明度")
|
||||||
console.print(" png: 无损,文件较大")
|
console.print(" png: 无损,文件较大")
|
||||||
reencode = get_input(" 选择格式", "webp", ['webp', 'jpeg', 'png'])
|
reencode = get_input(" 选择格式", "webp", ["webp", "jpeg", "png"])
|
||||||
|
|
||||||
quality = get_int(" 编码质量", 80, 1, 100) if reencode in ('webp', 'jpeg') else 80
|
quality = get_int(" 编码质量", 80, 1, 100) if reencode in ("webp", "jpeg") else 80
|
||||||
else:
|
else:
|
||||||
reencode = None
|
reencode = None
|
||||||
quality = 80
|
quality = 80
|
||||||
@@ -920,10 +931,7 @@ def interactive_export():
|
|||||||
# 开始导出
|
# 开始导出
|
||||||
console.print("\n[cyan]开始导出...[/cyan]")
|
console.print("\n[cyan]开始导出...[/cyan]")
|
||||||
packer = MMIPKGPacker(
|
packer = MMIPKGPacker(
|
||||||
use_compression=use_compression,
|
use_compression=use_compression, zstd_level=zstd_level, reencode=reencode, reencode_quality=quality
|
||||||
zstd_level=zstd_level,
|
|
||||||
reencode=reencode,
|
|
||||||
reencode_quality=quality
|
|
||||||
)
|
)
|
||||||
|
|
||||||
success = packer.pack_from_db(output_path, pack_name, custom_manifest)
|
success = packer.pack_from_db(output_path, pack_name, custom_manifest)
|
||||||
@@ -944,11 +952,11 @@ def interactive_import():
|
|||||||
|
|
||||||
# 选择导入模式
|
# 选择导入模式
|
||||||
print_import_mode_selection()
|
print_import_mode_selection()
|
||||||
import_mode = get_input("请选择", "1", ['1', '2'])
|
import_mode = get_input("请选择", "1", ["1", "2"])
|
||||||
|
|
||||||
input_files = []
|
input_files = []
|
||||||
|
|
||||||
if import_mode == '1':
|
if import_mode == "1":
|
||||||
# 自动扫描模式
|
# 自动扫描模式
|
||||||
import_dir = os.path.join(PROJECT_ROOT, "data", "import_emoji")
|
import_dir = os.path.join(PROJECT_ROOT, "data", "import_emoji")
|
||||||
os.makedirs(import_dir, exist_ok=True)
|
os.makedirs(import_dir, exist_ok=True)
|
||||||
@@ -957,7 +965,7 @@ def interactive_import():
|
|||||||
|
|
||||||
# 查找所有 .mmipkg 文件
|
# 查找所有 .mmipkg 文件
|
||||||
for file in os.listdir(import_dir):
|
for file in os.listdir(import_dir):
|
||||||
if file.endswith('.mmipkg'):
|
if file.endswith(".mmipkg"):
|
||||||
file_path = os.path.join(import_dir, file)
|
file_path = os.path.join(import_dir, file)
|
||||||
if os.path.isfile(file_path):
|
if os.path.isfile(file_path):
|
||||||
input_files.append(file_path)
|
input_files.append(file_path)
|
||||||
@@ -1032,7 +1040,7 @@ def interactive_import():
|
|||||||
BarColumn(),
|
BarColumn(),
|
||||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||||
TimeElapsedColumn(),
|
TimeElapsedColumn(),
|
||||||
console=console
|
console=console,
|
||||||
) as progress:
|
) as progress:
|
||||||
task = progress.add_task("[cyan]导入文件...", total=len(input_files))
|
task = progress.add_task("[cyan]导入文件...", total=len(input_files))
|
||||||
|
|
||||||
@@ -1044,10 +1052,7 @@ def interactive_import():
|
|||||||
console.print(f"[bold]{'=' * 70}[/bold]")
|
console.print(f"[bold]{'=' * 70}[/bold]")
|
||||||
|
|
||||||
success = unpacker.import_to_db(
|
success = unpacker.import_to_db(
|
||||||
input_path,
|
input_path, output_dir=output_dir, replace_existing=replace_existing, batch_size=batch_size
|
||||||
output_dir=output_dir,
|
|
||||||
replace_existing=replace_existing,
|
|
||||||
batch_size=batch_size
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
@@ -1076,16 +1081,16 @@ def main():
|
|||||||
while True:
|
while True:
|
||||||
print_menu()
|
print_menu()
|
||||||
try:
|
try:
|
||||||
choice = get_input("请选择", "1", ['0', '1', '2'])
|
choice = get_input("请选择", "1", ["0", "1", "2"])
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
console.print("\n[green]再见![/green]")
|
console.print("\n[green]再见![/green]")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
if choice == '0':
|
if choice == "0":
|
||||||
console.print("\n[green]再见![/green]")
|
console.print("\n[green]再见![/green]")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
elif choice == '1':
|
elif choice == "1":
|
||||||
try:
|
try:
|
||||||
interactive_export()
|
interactive_export()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@@ -1093,6 +1098,7 @@ def main():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"\n[red]✗ 发生错误: {e}[/red]")
|
console.print(f"\n[red]✗ 发生错误: {e}[/red]")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -1100,7 +1106,7 @@ def main():
|
|||||||
except (KeyboardInterrupt, EOFError):
|
except (KeyboardInterrupt, EOFError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
elif choice == '2':
|
elif choice == "2":
|
||||||
try:
|
try:
|
||||||
interactive_import()
|
interactive_import()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@@ -1108,6 +1114,7 @@ def main():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"\n[red]✗ 发生错误: {e}[/red]")
|
console.print(f"\n[red]✗ 发生错误: {e}[/red]")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -1121,5 +1128,5 @@ def main():
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
sys.exit(main())
|
sys.exit(main())
|
||||||
|
|||||||
@@ -334,7 +334,6 @@ class HeartFChatting:
|
|||||||
self.consecutive_no_reply_count = 0
|
self.consecutive_no_reply_count = 0
|
||||||
reason = ""
|
reason = ""
|
||||||
|
|
||||||
|
|
||||||
await database_api.store_action_info(
|
await database_api.store_action_info(
|
||||||
chat_stream=self.chat_stream,
|
chat_stream=self.chat_stream,
|
||||||
action_build_into_prompt=False,
|
action_build_into_prompt=False,
|
||||||
|
|||||||
@@ -30,9 +30,11 @@ DATA_PATH = os.path.join(ROOT_PATH, "data")
|
|||||||
qa_manager = None
|
qa_manager = None
|
||||||
inspire_manager = None
|
inspire_manager = None
|
||||||
|
|
||||||
|
|
||||||
def get_qa_manager():
|
def get_qa_manager():
|
||||||
return qa_manager
|
return qa_manager
|
||||||
|
|
||||||
|
|
||||||
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
|
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
|
||||||
# 检查LPMM知识库是否启用
|
# 检查LPMM知识库是否启用
|
||||||
if global_config.lpmm_knowledge.enable:
|
if global_config.lpmm_knowledge.enable:
|
||||||
|
|||||||
@@ -128,8 +128,7 @@ class QAManager:
|
|||||||
selected_knowledge = knowledge[:limit]
|
selected_knowledge = knowledge[:limit]
|
||||||
|
|
||||||
formatted_knowledge = [
|
formatted_knowledge = [
|
||||||
f"第{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}"
|
f"第{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}" for i, k in enumerate(selected_knowledge)
|
||||||
for i, k in enumerate(selected_knowledge)
|
|
||||||
]
|
]
|
||||||
# if max_score is not None:
|
# if max_score is not None:
|
||||||
# formatted_knowledge.insert(0, f"最高相关系数:{max_score}")
|
# formatted_knowledge.insert(0, f"最高相关系数:{max_score}")
|
||||||
|
|||||||
@@ -226,7 +226,9 @@ class DefaultReplyer:
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False, llm_response
|
return False, llm_response
|
||||||
|
|
||||||
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
|
async def build_expression_habits(
|
||||||
|
self, chat_history: str, target: str, reply_reason: str = ""
|
||||||
|
) -> Tuple[str, List[int]]:
|
||||||
# sourcery skip: for-append-to-extend
|
# sourcery skip: for-append-to-extend
|
||||||
"""构建表达习惯块
|
"""构建表达习惯块
|
||||||
|
|
||||||
|
|||||||
@@ -241,7 +241,9 @@ class PrivateReplyer:
|
|||||||
|
|
||||||
return f"{sender_relation}"
|
return f"{sender_relation}"
|
||||||
|
|
||||||
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
|
async def build_expression_habits(
|
||||||
|
self, chat_history: str, target: str, reply_reason: str = ""
|
||||||
|
) -> Tuple[str, List[int]]:
|
||||||
# sourcery skip: for-append-to-extend
|
# sourcery skip: for-append-to-extend
|
||||||
"""构建表达习惯块
|
"""构建表达习惯块
|
||||||
|
|
||||||
|
|||||||
@@ -204,8 +204,9 @@ class WebSocketLogHandler(logging.Handler):
|
|||||||
message = formatted_msg
|
message = formatted_msg
|
||||||
try:
|
try:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
log_dict = json.loads(formatted_msg)
|
log_dict = json.loads(formatted_msg)
|
||||||
message = log_dict.get('event', formatted_msg)
|
message = log_dict.get("event", formatted_msg)
|
||||||
except (json.JSONDecodeError, ValueError):
|
except (json.JSONDecodeError, ValueError):
|
||||||
# 不是 JSON,直接使用消息
|
# 不是 JSON,直接使用消息
|
||||||
message = formatted_msg
|
message = formatted_msg
|
||||||
@@ -228,10 +229,7 @@ class WebSocketLogHandler(logging.Handler):
|
|||||||
import asyncio
|
import asyncio
|
||||||
from src.webui.logs_ws import broadcast_log
|
from src.webui.logs_ws import broadcast_log
|
||||||
|
|
||||||
asyncio.run_coroutine_threadsafe(
|
asyncio.run_coroutine_threadsafe(broadcast_log(log_data), self.loop)
|
||||||
broadcast_log(log_data),
|
|
||||||
self.loop
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# WebSocket 推送失败不影响日志记录
|
# WebSocket 推送失败不影响日志记录
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -467,11 +467,7 @@ class ExpressionLearner:
|
|||||||
up_content: str,
|
up_content: str,
|
||||||
current_time: float,
|
current_time: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
expr_obj = (
|
expr_obj = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.style == style)).first()
|
||||||
Expression.select()
|
|
||||||
.where((Expression.chat_id == self.chat_id) & (Expression.style == style))
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if expr_obj:
|
if expr_obj:
|
||||||
await self._update_existing_expression(
|
await self._update_existing_expression(
|
||||||
|
|||||||
@@ -42,8 +42,6 @@ def init_prompt():
|
|||||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ExpressionSelector:
|
class ExpressionSelector:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.llm_model = LLMRequest(
|
self.llm_model = LLMRequest(
|
||||||
@@ -262,7 +260,6 @@ class ExpressionSelector:
|
|||||||
# 4. 调用LLM
|
# 4. 调用LLM
|
||||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||||
|
|
||||||
|
|
||||||
# print(prompt)
|
# print(prompt)
|
||||||
|
|
||||||
if not content:
|
if not content:
|
||||||
|
|||||||
@@ -36,10 +36,7 @@ def _contains_bot_self_name(content: str) -> bool:
|
|||||||
|
|
||||||
target = content.strip().lower()
|
target = content.strip().lower()
|
||||||
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
|
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
|
||||||
alias_names = [
|
alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []]
|
||||||
str(alias or "").strip().lower()
|
|
||||||
for alias in getattr(bot_config, "alias_names", []) or []
|
|
||||||
]
|
|
||||||
|
|
||||||
candidates = [name for name in [nickname, *alias_names] if name]
|
candidates = [name for name in [nickname, *alias_names] if name]
|
||||||
|
|
||||||
@@ -188,9 +185,7 @@ async def _enrich_raw_content_if_needed(
|
|||||||
# 获取该消息的前三条消息
|
# 获取该消息的前三条消息
|
||||||
try:
|
try:
|
||||||
previous_messages = get_raw_msg_before_timestamp_with_chat(
|
previous_messages = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id, timestamp=target_message.time, limit=3
|
||||||
timestamp=target_message.time,
|
|
||||||
limit=3
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if previous_messages:
|
if previous_messages:
|
||||||
@@ -245,7 +240,7 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool:
|
|||||||
last_inference = jargon_obj.last_inference_count or 0
|
last_inference = jargon_obj.last_inference_count or 0
|
||||||
|
|
||||||
# 阈值列表:3,6, 10, 20, 40, 60, 100
|
# 阈值列表:3,6, 10, 20, 40, 60, 100
|
||||||
thresholds = [3,6, 10, 20, 40, 60, 100]
|
thresholds = [3, 6, 10, 20, 40, 60, 100]
|
||||||
|
|
||||||
if count < thresholds[0]:
|
if count < thresholds[0]:
|
||||||
return False
|
return False
|
||||||
@@ -311,7 +306,9 @@ class JargonMiner:
|
|||||||
raw_content_list = []
|
raw_content_list = []
|
||||||
if raw_content_str:
|
if raw_content_str:
|
||||||
try:
|
try:
|
||||||
raw_content_list = json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
|
raw_content_list = (
|
||||||
|
json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
|
||||||
|
)
|
||||||
if not isinstance(raw_content_list, list):
|
if not isinstance(raw_content_list, list):
|
||||||
raw_content_list = [raw_content_list] if raw_content_list else []
|
raw_content_list = [raw_content_list] if raw_content_list else []
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
@@ -360,7 +357,6 @@ class JargonMiner:
|
|||||||
jargon_obj.save()
|
jargon_obj.save()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
# 步骤2: 仅基于content推断
|
# 步骤2: 仅基于content推断
|
||||||
prompt2 = await global_prompt_manager.format_prompt(
|
prompt2 = await global_prompt_manager.format_prompt(
|
||||||
"jargon_inference_content_only_prompt",
|
"jargon_inference_content_only_prompt",
|
||||||
@@ -388,7 +384,6 @@ class JargonMiner:
|
|||||||
logger.error(f"jargon {content} 推断2解析失败: {e}")
|
logger.error(f"jargon {content} 推断2解析失败: {e}")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
|
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
|
||||||
logger.info(f"jargon {content} 推断2结果: {response2}")
|
logger.info(f"jargon {content} 推断2结果: {response2}")
|
||||||
logger.info(f"jargon {content} 推断1提示词: {prompt1}")
|
logger.info(f"jargon {content} 推断1提示词: {prompt1}")
|
||||||
@@ -457,7 +452,9 @@ class JargonMiner:
|
|||||||
jargon_obj.is_complete = True
|
jargon_obj.is_complete = True
|
||||||
|
|
||||||
jargon_obj.save()
|
jargon_obj.save()
|
||||||
logger.debug(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}")
|
logger.debug(
|
||||||
|
f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}"
|
||||||
|
)
|
||||||
|
|
||||||
# 固定输出推断结果,格式化为可读形式
|
# 固定输出推断结果,格式化为可读形式
|
||||||
if is_jargon:
|
if is_jargon:
|
||||||
@@ -475,6 +472,7 @@ class JargonMiner:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"jargon推断失败: {e}")
|
logger.error(f"jargon推断失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
def should_trigger(self) -> bool:
|
def should_trigger(self) -> bool:
|
||||||
@@ -571,10 +569,7 @@ class JargonMiner:
|
|||||||
if _contains_bot_self_name(content):
|
if _contains_bot_self_name(content):
|
||||||
logger.debug(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
|
logger.debug(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
|
||||||
continue
|
continue
|
||||||
entries.append({
|
entries.append({"content": content, "raw_content": raw_content_list})
|
||||||
"content": content,
|
|
||||||
"raw_content": raw_content_list
|
|
||||||
})
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"解析jargon JSON失败: {e}; 原始: {response}")
|
logger.error(f"解析jargon JSON失败: {e}; 原始: {response}")
|
||||||
return
|
return
|
||||||
@@ -612,19 +607,10 @@ class JargonMiner:
|
|||||||
# 根据all_global配置决定查询逻辑
|
# 根据all_global配置决定查询逻辑
|
||||||
if global_config.jargon.all_global:
|
if global_config.jargon.all_global:
|
||||||
# 开启all_global:无视chat_id,查询所有content匹配的记录(所有记录都是全局的)
|
# 开启all_global:无视chat_id,查询所有content匹配的记录(所有记录都是全局的)
|
||||||
query = (
|
query = Jargon.select().where(Jargon.content == content)
|
||||||
Jargon.select()
|
|
||||||
.where(Jargon.content == content)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# 关闭all_global:只查询chat_id匹配的记录(不考虑is_global)
|
# 关闭all_global:只查询chat_id匹配的记录(不考虑is_global)
|
||||||
query = (
|
query = Jargon.select().where((Jargon.chat_id == self.chat_id) & (Jargon.content == content))
|
||||||
Jargon.select()
|
|
||||||
.where(
|
|
||||||
(Jargon.chat_id == self.chat_id) &
|
|
||||||
(Jargon.content == content)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if query.exists():
|
if query.exists():
|
||||||
obj = query.get()
|
obj = query.get()
|
||||||
@@ -637,7 +623,9 @@ class JargonMiner:
|
|||||||
existing_raw_content = []
|
existing_raw_content = []
|
||||||
if obj.raw_content:
|
if obj.raw_content:
|
||||||
try:
|
try:
|
||||||
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
existing_raw_content = (
|
||||||
|
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||||
|
)
|
||||||
if not isinstance(existing_raw_content, list):
|
if not isinstance(existing_raw_content, list):
|
||||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
@@ -676,7 +664,7 @@ class JargonMiner:
|
|||||||
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
|
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
is_global=is_global_new,
|
is_global=is_global_new,
|
||||||
count=1
|
count=1,
|
||||||
)
|
)
|
||||||
saved += 1
|
saved += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -720,11 +708,7 @@ async def extract_and_store_jargon(chat_id: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def search_jargon(
|
def search_jargon(
|
||||||
keyword: str,
|
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
|
||||||
chat_id: Optional[str] = None,
|
|
||||||
limit: int = 10,
|
|
||||||
case_sensitive: bool = False,
|
|
||||||
fuzzy: bool = True
|
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
搜索jargon,支持大小写不敏感和模糊搜索
|
搜索jargon,支持大小写不敏感和模糊搜索
|
||||||
@@ -747,10 +731,7 @@ def search_jargon(
|
|||||||
keyword = keyword.strip()
|
keyword = keyword.strip()
|
||||||
|
|
||||||
# 构建查询
|
# 构建查询
|
||||||
query = Jargon.select(
|
query = Jargon.select(Jargon.content, Jargon.meaning)
|
||||||
Jargon.content,
|
|
||||||
Jargon.meaning
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建搜索条件
|
# 构建搜索条件
|
||||||
if case_sensitive:
|
if case_sensitive:
|
||||||
@@ -760,7 +741,7 @@ def search_jargon(
|
|||||||
search_condition = Jargon.content.contains(keyword)
|
search_condition = Jargon.content.contains(keyword)
|
||||||
else:
|
else:
|
||||||
# 精确匹配
|
# 精确匹配
|
||||||
search_condition = (Jargon.content == keyword)
|
search_condition = Jargon.content == keyword
|
||||||
else:
|
else:
|
||||||
# 大小写不敏感
|
# 大小写不敏感
|
||||||
if fuzzy:
|
if fuzzy:
|
||||||
@@ -768,7 +749,7 @@ def search_jargon(
|
|||||||
search_condition = fn.LOWER(Jargon.content).contains(keyword.lower())
|
search_condition = fn.LOWER(Jargon.content).contains(keyword.lower())
|
||||||
else:
|
else:
|
||||||
# 精确匹配(使用LOWER函数)
|
# 精确匹配(使用LOWER函数)
|
||||||
search_condition = (fn.LOWER(Jargon.content) == keyword.lower())
|
search_condition = fn.LOWER(Jargon.content) == keyword.lower()
|
||||||
|
|
||||||
query = query.where(search_condition)
|
query = query.where(search_condition)
|
||||||
|
|
||||||
@@ -779,14 +760,10 @@ def search_jargon(
|
|||||||
else:
|
else:
|
||||||
# 关闭all_global:如果提供了chat_id,优先搜索该聊天或global的jargon
|
# 关闭all_global:如果提供了chat_id,优先搜索该聊天或global的jargon
|
||||||
if chat_id:
|
if chat_id:
|
||||||
query = query.where(
|
query = query.where((Jargon.chat_id == chat_id) | Jargon.is_global)
|
||||||
(Jargon.chat_id == chat_id) | Jargon.is_global
|
|
||||||
)
|
|
||||||
|
|
||||||
# 只返回有meaning的记录
|
# 只返回有meaning的记录
|
||||||
query = query.where(
|
query = query.where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
|
||||||
(Jargon.meaning.is_null(False)) & (Jargon.meaning != "")
|
|
||||||
)
|
|
||||||
|
|
||||||
# 按count降序排序,优先返回出现频率高的
|
# 按count降序排序,优先返回出现频率高的
|
||||||
query = query.order_by(Jargon.count.desc())
|
query = query.order_by(Jargon.count.desc())
|
||||||
@@ -797,10 +774,7 @@ def search_jargon(
|
|||||||
# 执行查询并返回结果
|
# 执行查询并返回结果
|
||||||
results = []
|
results = []
|
||||||
for jargon in query:
|
for jargon in query:
|
||||||
results.append({
|
results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
|
||||||
"content": jargon.content or "",
|
|
||||||
"meaning": jargon.meaning or ""
|
|
||||||
})
|
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -840,10 +814,7 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
|
|||||||
if global_config.jargon.all_global:
|
if global_config.jargon.all_global:
|
||||||
query = Jargon.select().where(Jargon.content == jargon_keyword)
|
query = Jargon.select().where(Jargon.content == jargon_keyword)
|
||||||
else:
|
else:
|
||||||
query = Jargon.select().where(
|
query = Jargon.select().where((Jargon.chat_id == chat_id) & (Jargon.content == jargon_keyword))
|
||||||
(Jargon.chat_id == chat_id) &
|
|
||||||
(Jargon.content == jargon_keyword)
|
|
||||||
)
|
|
||||||
|
|
||||||
if query.exists():
|
if query.exists():
|
||||||
# 更新现有记录
|
# 更新现有记录
|
||||||
@@ -854,7 +825,9 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
|
|||||||
existing_raw_content = []
|
existing_raw_content = []
|
||||||
if obj.raw_content:
|
if obj.raw_content:
|
||||||
try:
|
try:
|
||||||
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
existing_raw_content = (
|
||||||
|
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||||
|
)
|
||||||
if not isinstance(existing_raw_content, list):
|
if not isinstance(existing_raw_content, list):
|
||||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
@@ -877,11 +850,9 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
|
|||||||
raw_content=json.dumps([raw_content], ensure_ascii=False),
|
raw_content=json.dumps([raw_content], ensure_ascii=False),
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
is_global=is_global_new,
|
is_global=is_global_new,
|
||||||
count=1
|
count=1,
|
||||||
)
|
)
|
||||||
logger.info(f"创建新jargon记录: {jargon_keyword}")
|
logger.info(f"创建新jargon记录: {jargon_keyword}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"存储jargon失败: {e}")
|
logger.error(f"存储jargon失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -116,9 +116,7 @@ class MessageBuilder:
|
|||||||
构建消息对象
|
构建消息对象
|
||||||
:return: Message对象
|
:return: Message对象
|
||||||
"""
|
"""
|
||||||
if len(self.__content) == 0 and not (
|
if len(self.__content) == 0 and not (self.__role == RoleType.Assistant and self.__tool_calls):
|
||||||
self.__role == RoleType.Assistant and self.__tool_calls
|
|
||||||
):
|
|
||||||
raise ValueError("内容不能为空")
|
raise ValueError("内容不能为空")
|
||||||
if self.__role == RoleType.Tool and self.__tool_call_id is None:
|
if self.__role == RoleType.Tool and self.__tool_call_id is None:
|
||||||
raise ValueError("Tool角色的工具调用ID不能为空")
|
raise ValueError("Tool角色的工具调用ID不能为空")
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ class MainSystem:
|
|||||||
"""注册 WebUI API 路由"""
|
"""注册 WebUI API 路由"""
|
||||||
try:
|
try:
|
||||||
from src.webui.routes import router as webui_router
|
from src.webui.routes import router as webui_router
|
||||||
|
|
||||||
self.server.register_router(webui_router)
|
self.server.register_router(webui_router)
|
||||||
logger.info("WebUI API 路由已注册")
|
logger.info("WebUI API 路由已注册")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -55,6 +56,7 @@ class MainSystem:
|
|||||||
def _setup_webui(self):
|
def _setup_webui(self):
|
||||||
"""设置 WebUI(根据环境变量决定模式)"""
|
"""设置 WebUI(根据环境变量决定模式)"""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
webui_enabled = os.getenv("WEBUI_ENABLED", "false").lower() == "true"
|
webui_enabled = os.getenv("WEBUI_ENABLED", "false").lower() == "true"
|
||||||
if not webui_enabled:
|
if not webui_enabled:
|
||||||
logger.info("WebUI 已禁用")
|
logger.info("WebUI 已禁用")
|
||||||
@@ -64,6 +66,7 @@ class MainSystem:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from src.webui.manager import setup_webui
|
from src.webui.manager import setup_webui
|
||||||
|
|
||||||
setup_webui(mode=webui_mode)
|
setup_webui(mode=webui_mode)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"设置 WebUI 失败: {e}")
|
logger.error(f"设置 WebUI 失败: {e}")
|
||||||
|
|||||||
@@ -33,10 +33,7 @@ def _cleanup_stale_not_found_thinking_back() -> None:
|
|||||||
try:
|
try:
|
||||||
deleted_rows = (
|
deleted_rows = (
|
||||||
ThinkingBack.delete()
|
ThinkingBack.delete()
|
||||||
.where(
|
.where((ThinkingBack.found_answer == 0) & (ThinkingBack.update_time < threshold_time))
|
||||||
(ThinkingBack.found_answer == 0) &
|
|
||||||
(ThinkingBack.update_time < threshold_time)
|
|
||||||
)
|
|
||||||
.execute()
|
.execute()
|
||||||
)
|
)
|
||||||
if deleted_rows:
|
if deleted_rows:
|
||||||
@@ -45,6 +42,7 @@ def _cleanup_stale_not_found_thinking_back() -> None:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"清理未找到答案的thinking_back记录失败: {e}")
|
logger.error(f"清理未找到答案的thinking_back记录失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
def init_memory_retrieval_prompt():
|
def init_memory_retrieval_prompt():
|
||||||
"""初始化记忆检索相关的 prompt 模板和工具"""
|
"""初始化记忆检索相关的 prompt 模板和工具"""
|
||||||
# 首先注册所有工具
|
# 首先注册所有工具
|
||||||
@@ -221,10 +219,7 @@ def _parse_react_response(response: str) -> Optional[Dict[str, Any]]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def _retrieve_concepts_with_jargon(
|
async def _retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> str:
|
||||||
concepts: List[str],
|
|
||||||
chat_id: str
|
|
||||||
) -> str:
|
|
||||||
"""对概念列表进行jargon检索
|
"""对概念列表进行jargon检索
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -246,25 +241,13 @@ async def _retrieve_concepts_with_jargon(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# 先尝试精确匹配
|
# 先尝试精确匹配
|
||||||
jargon_results = search_jargon(
|
jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
|
||||||
keyword=concept,
|
|
||||||
chat_id=chat_id,
|
|
||||||
limit=10,
|
|
||||||
case_sensitive=False,
|
|
||||||
fuzzy=False
|
|
||||||
)
|
|
||||||
|
|
||||||
is_fuzzy_match = False
|
is_fuzzy_match = False
|
||||||
|
|
||||||
# 如果精确匹配未找到,尝试模糊搜索
|
# 如果精确匹配未找到,尝试模糊搜索
|
||||||
if not jargon_results:
|
if not jargon_results:
|
||||||
jargon_results = search_jargon(
|
jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
|
||||||
keyword=concept,
|
|
||||||
chat_id=chat_id,
|
|
||||||
limit=10,
|
|
||||||
case_sensitive=False,
|
|
||||||
fuzzy=True
|
|
||||||
)
|
|
||||||
is_fuzzy_match = True
|
is_fuzzy_match = True
|
||||||
|
|
||||||
if jargon_results:
|
if jargon_results:
|
||||||
@@ -298,11 +281,7 @@ async def _retrieve_concepts_with_jargon(
|
|||||||
|
|
||||||
|
|
||||||
async def _react_agent_solve_question(
|
async def _react_agent_solve_question(
|
||||||
question: str,
|
question: str, chat_id: str, max_iterations: int = 5, timeout: float = 30.0, initial_info: str = ""
|
||||||
chat_id: str,
|
|
||||||
max_iterations: int = 5,
|
|
||||||
timeout: float = 30.0,
|
|
||||||
initial_info: str = ""
|
|
||||||
) -> Tuple[bool, str, List[Dict[str, Any]], bool]:
|
) -> Tuple[bool, str, List[Dict[str, Any]], bool]:
|
||||||
"""使用ReAct架构的Agent来解决问题
|
"""使用ReAct架构的Agent来解决问题
|
||||||
|
|
||||||
@@ -343,11 +322,12 @@ async def _react_agent_solve_question(
|
|||||||
remaining_iterations = max_iterations - current_iteration
|
remaining_iterations = max_iterations - current_iteration
|
||||||
is_final_iteration = current_iteration >= max_iterations
|
is_final_iteration = current_iteration >= max_iterations
|
||||||
|
|
||||||
|
|
||||||
if is_final_iteration:
|
if is_final_iteration:
|
||||||
# 最后一次迭代,使用最终prompt
|
# 最后一次迭代,使用最终prompt
|
||||||
tool_definitions = []
|
tool_definitions = []
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: 0(最后一次迭代,不提供工具调用)")
|
logger.info(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: 0(最后一次迭代,不提供工具调用)"
|
||||||
|
)
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
"memory_retrieval_react_final_prompt",
|
"memory_retrieval_react_final_prompt",
|
||||||
@@ -370,7 +350,9 @@ async def _react_agent_solve_question(
|
|||||||
else:
|
else:
|
||||||
# 非最终迭代,使用head_prompt
|
# 非最终迭代,使用head_prompt
|
||||||
tool_definitions = tool_registry.get_tool_definitions()
|
tool_definitions = tool_registry.get_tool_definitions()
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: {len(tool_definitions)}")
|
logger.info(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: {len(tool_definitions)}"
|
||||||
|
)
|
||||||
|
|
||||||
head_prompt = await global_prompt_manager.format_prompt(
|
head_prompt = await global_prompt_manager.format_prompt(
|
||||||
"memory_retrieval_react_prompt_head",
|
"memory_retrieval_react_prompt_head",
|
||||||
@@ -401,7 +383,7 @@ async def _react_agent_solve_question(
|
|||||||
# 优化日志展示 - 合并所有消息到一条日志
|
# 优化日志展示 - 合并所有消息到一条日志
|
||||||
log_lines = []
|
log_lines = []
|
||||||
for idx, msg in enumerate(messages, 1):
|
for idx, msg in enumerate(messages, 1):
|
||||||
role_name = msg.role.value if hasattr(msg.role, 'value') else str(msg.role)
|
role_name = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
|
||||||
|
|
||||||
# 处理内容 - 显示完整内容,不截断
|
# 处理内容 - 显示完整内容,不截断
|
||||||
if isinstance(msg.content, str):
|
if isinstance(msg.content, str):
|
||||||
@@ -437,14 +419,22 @@ async def _react_agent_solve_question(
|
|||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
success, response, reasoning_content, model_name, tool_calls = await llm_api.generate_with_model_with_tools_by_message_factory(
|
(
|
||||||
|
success,
|
||||||
|
response,
|
||||||
|
reasoning_content,
|
||||||
|
model_name,
|
||||||
|
tool_calls,
|
||||||
|
) = await llm_api.generate_with_model_with_tools_by_message_factory(
|
||||||
message_factory,
|
message_factory,
|
||||||
model_config=model_config.model_task_config.tool_use,
|
model_config=model_config.model_task_config.tool_use,
|
||||||
tool_options=tool_definitions,
|
tool_options=tool_definitions,
|
||||||
request_type="memory.react",
|
request_type="memory.react",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}")
|
logger.info(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
|
||||||
|
)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
logger.error(f"ReAct Agent LLM调用失败: {response}")
|
logger.error(f"ReAct Agent LLM调用失败: {response}")
|
||||||
@@ -465,12 +455,7 @@ async def _react_agent_solve_question(
|
|||||||
assistant_message = assistant_builder.build()
|
assistant_message = assistant_builder.build()
|
||||||
|
|
||||||
# 记录思考步骤
|
# 记录思考步骤
|
||||||
step = {
|
step = {"iteration": iteration + 1, "thought": response, "actions": [], "observations": []}
|
||||||
"iteration": iteration + 1,
|
|
||||||
"thought": response,
|
|
||||||
"actions": [],
|
|
||||||
"observations": []
|
|
||||||
}
|
|
||||||
|
|
||||||
# 优先从思考内容中提取found_answer或not_enough_info
|
# 优先从思考内容中提取found_answer或not_enough_info
|
||||||
def extract_quoted_content(text, func_name, param_name):
|
def extract_quoted_content(text, func_name, param_name):
|
||||||
@@ -495,14 +480,14 @@ async def _react_agent_solve_question(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 查找参数名和等号
|
# 查找参数名和等号
|
||||||
param_pattern = f'{param_name}='
|
param_pattern = f"{param_name}="
|
||||||
param_pos = text_lower.find(param_pattern, func_pos)
|
param_pos = text_lower.find(param_pattern, func_pos)
|
||||||
if param_pos == -1:
|
if param_pos == -1:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 跳过参数名、等号和空白
|
# 跳过参数名、等号和空白
|
||||||
start_pos = param_pos + len(param_pattern)
|
start_pos = param_pos + len(param_pattern)
|
||||||
while start_pos < len(text) and text[start_pos] in ' \t\n':
|
while start_pos < len(text) and text[start_pos] in " \t\n":
|
||||||
start_pos += 1
|
start_pos += 1
|
||||||
|
|
||||||
if start_pos >= len(text):
|
if start_pos >= len(text):
|
||||||
@@ -518,13 +503,13 @@ async def _react_agent_solve_question(
|
|||||||
while end_pos < len(text):
|
while end_pos < len(text):
|
||||||
if text[end_pos] == quote_char:
|
if text[end_pos] == quote_char:
|
||||||
# 检查是否是转义的引号
|
# 检查是否是转义的引号
|
||||||
if end_pos > start_pos + 1 and text[end_pos - 1] == '\\':
|
if end_pos > start_pos + 1 and text[end_pos - 1] == "\\":
|
||||||
end_pos += 1
|
end_pos += 1
|
||||||
continue
|
continue
|
||||||
# 找到匹配的引号
|
# 找到匹配的引号
|
||||||
content = text[start_pos + 1:end_pos]
|
content = text[start_pos + 1 : end_pos]
|
||||||
# 处理转义字符
|
# 处理转义字符
|
||||||
content = content.replace('\\"', '"').replace("\\'", "'").replace('\\\\', '\\')
|
content = content.replace('\\"', '"').replace("\\'", "'").replace("\\\\", "\\")
|
||||||
return content
|
return content
|
||||||
end_pos += 1
|
end_pos += 1
|
||||||
|
|
||||||
@@ -536,27 +521,35 @@ async def _react_agent_solve_question(
|
|||||||
|
|
||||||
# 只检查response(LLM的直接输出内容),不检查reasoning_content
|
# 只检查response(LLM的直接输出内容),不检查reasoning_content
|
||||||
if response:
|
if response:
|
||||||
found_answer_content = extract_quoted_content(response, 'found_answer', 'answer')
|
found_answer_content = extract_quoted_content(response, "found_answer", "answer")
|
||||||
if not found_answer_content:
|
if not found_answer_content:
|
||||||
not_enough_info_reason = extract_quoted_content(response, 'not_enough_info', 'reason')
|
not_enough_info_reason = extract_quoted_content(response, "not_enough_info", "reason")
|
||||||
|
|
||||||
# 如果从输出内容中找到了答案,直接返回
|
# 如果从输出内容中找到了答案,直接返回
|
||||||
if found_answer_content:
|
if found_answer_content:
|
||||||
step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_content}})
|
step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_content}})
|
||||||
step["observations"] = ["从LLM输出内容中检测到found_answer"]
|
step["observations"] = ["从LLM输出内容中检测到found_answer"]
|
||||||
thinking_steps.append(step)
|
thinking_steps.append(step)
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到found_answer: {found_answer_content[:100]}...")
|
logger.info(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到found_answer: {found_answer_content[:100]}..."
|
||||||
|
)
|
||||||
return True, found_answer_content, thinking_steps, False
|
return True, found_answer_content, thinking_steps, False
|
||||||
|
|
||||||
if not_enough_info_reason:
|
if not_enough_info_reason:
|
||||||
step["actions"].append({"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}})
|
step["actions"].append(
|
||||||
|
{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}
|
||||||
|
)
|
||||||
step["observations"] = ["从LLM输出内容中检测到not_enough_info"]
|
step["observations"] = ["从LLM输出内容中检测到not_enough_info"]
|
||||||
thinking_steps.append(step)
|
thinking_steps.append(step)
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到not_enough_info: {not_enough_info_reason[:100]}...")
|
logger.info(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到not_enough_info: {not_enough_info_reason[:100]}..."
|
||||||
|
)
|
||||||
return False, not_enough_info_reason, thinking_steps, False
|
return False, not_enough_info_reason, thinking_steps, False
|
||||||
|
|
||||||
if is_final_iteration:
|
if is_final_iteration:
|
||||||
step["actions"].append({"action_type": "not_enough_info", "action_params": {"reason": "已到达最后一次迭代,无法找到答案"}})
|
step["actions"].append(
|
||||||
|
{"action_type": "not_enough_info", "action_params": {"reason": "已到达最后一次迭代,无法找到答案"}}
|
||||||
|
)
|
||||||
step["observations"] = ["已到达最后一次迭代,无法找到答案"]
|
step["observations"] = ["已到达最后一次迭代,无法找到答案"]
|
||||||
thinking_steps.append(step)
|
thinking_steps.append(step)
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 已到达最后一次迭代,无法找到答案")
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 已到达最后一次迭代,无法找到答案")
|
||||||
@@ -596,7 +589,9 @@ async def _react_agent_solve_question(
|
|||||||
tool_name = tool_call.func_name
|
tool_name = tool_call.func_name
|
||||||
tool_args = tool_call.args or {}
|
tool_args = tool_call.args or {}
|
||||||
|
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 工具调用 {i+1}/{len(tool_calls)}: {tool_name}({tool_args})")
|
logger.info(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
|
||||||
|
)
|
||||||
|
|
||||||
# 普通工具调用
|
# 普通工具调用
|
||||||
tool = tool_registry.get_tool(tool_name)
|
tool = tool_registry.get_tool(tool_name)
|
||||||
@@ -606,6 +601,7 @@ async def _react_agent_solve_question(
|
|||||||
|
|
||||||
# 如果工具函数签名需要chat_id,添加它
|
# 如果工具函数签名需要chat_id,添加它
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
sig = inspect.signature(tool.execute_func)
|
sig = inspect.signature(tool.execute_func)
|
||||||
if "chat_id" in sig.parameters:
|
if "chat_id" in sig.parameters:
|
||||||
tool_params["chat_id"] = chat_id
|
tool_params["chat_id"] = chat_id
|
||||||
@@ -625,7 +621,7 @@ async def _react_agent_solve_question(
|
|||||||
step["actions"].append({"action_type": tool_name, "action_params": tool_args})
|
step["actions"].append({"action_type": tool_name, "action_params": tool_args})
|
||||||
else:
|
else:
|
||||||
error_msg = f"未知的工具类型: {tool_name}"
|
error_msg = f"未知的工具类型: {tool_name}"
|
||||||
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i+1}/{len(tool_calls)} {error_msg}")
|
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i + 1}/{len(tool_calls)} {error_msg}")
|
||||||
tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result=f"查询{tool_name}失败: {error_msg}")))
|
tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result=f"查询{tool_name}失败: {error_msg}")))
|
||||||
|
|
||||||
# 并行执行所有工具
|
# 并行执行所有工具
|
||||||
@@ -636,7 +632,7 @@ async def _react_agent_solve_question(
|
|||||||
for i, (tool_call_item, observation) in enumerate(zip(tool_calls, observations, strict=False)):
|
for i, (tool_call_item, observation) in enumerate(zip(tool_calls, observations, strict=False)):
|
||||||
if isinstance(observation, Exception):
|
if isinstance(observation, Exception):
|
||||||
observation = f"工具执行异常: {str(observation)}"
|
observation = f"工具执行异常: {str(observation)}"
|
||||||
logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i+1} 执行异常: {observation}")
|
logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i + 1} 执行异常: {observation}")
|
||||||
|
|
||||||
observation_text = observation if isinstance(observation, str) else str(observation)
|
observation_text = observation if isinstance(observation, str) else str(observation)
|
||||||
step["observations"].append(observation_text)
|
step["observations"].append(observation_text)
|
||||||
@@ -655,7 +651,9 @@ async def _react_agent_solve_question(
|
|||||||
# 迭代超时应该直接视为not_enough_info,而不是使用已有信息
|
# 迭代超时应该直接视为not_enough_info,而不是使用已有信息
|
||||||
# 只有Agent明确返回found_answer时,才认为找到了答案
|
# 只有Agent明确返回found_answer时,才认为找到了答案
|
||||||
if collected_info:
|
if collected_info:
|
||||||
logger.warning(f"ReAct Agent达到最大迭代次数或超时,但未明确返回found_answer。已收集信息: {collected_info[:100]}...")
|
logger.warning(
|
||||||
|
f"ReAct Agent达到最大迭代次数或超时,但未明确返回found_answer。已收集信息: {collected_info[:100]}..."
|
||||||
|
)
|
||||||
if is_timeout:
|
if is_timeout:
|
||||||
logger.warning("ReAct Agent超时,直接视为not_enough_info")
|
logger.warning("ReAct Agent超时,直接视为not_enough_info")
|
||||||
else:
|
else:
|
||||||
@@ -680,10 +678,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0)
|
|||||||
# 查询最近时间窗口内的记录,按更新时间倒序
|
# 查询最近时间窗口内的记录,按更新时间倒序
|
||||||
records = (
|
records = (
|
||||||
ThinkingBack.select()
|
ThinkingBack.select()
|
||||||
.where(
|
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.update_time >= start_time))
|
||||||
(ThinkingBack.chat_id == chat_id) &
|
|
||||||
(ThinkingBack.update_time >= start_time)
|
|
||||||
)
|
|
||||||
.order_by(ThinkingBack.update_time.desc())
|
.order_by(ThinkingBack.update_time.desc())
|
||||||
.limit(5) # 最多返回5条最近的记录
|
.limit(5) # 最多返回5条最近的记录
|
||||||
)
|
)
|
||||||
@@ -735,9 +730,9 @@ def _get_cached_memories(chat_id: str, time_window_seconds: float = 300.0) -> Li
|
|||||||
records = (
|
records = (
|
||||||
ThinkingBack.select()
|
ThinkingBack.select()
|
||||||
.where(
|
.where(
|
||||||
(ThinkingBack.chat_id == chat_id) &
|
(ThinkingBack.chat_id == chat_id)
|
||||||
(ThinkingBack.update_time >= start_time) &
|
& (ThinkingBack.update_time >= start_time)
|
||||||
(ThinkingBack.found_answer == 1)
|
& (ThinkingBack.found_answer == 1)
|
||||||
)
|
)
|
||||||
.order_by(ThinkingBack.update_time.desc())
|
.order_by(ThinkingBack.update_time.desc())
|
||||||
.limit(5) # 最多返回5条最近的记录
|
.limit(5) # 最多返回5条最近的记录
|
||||||
@@ -775,10 +770,7 @@ def _query_thinking_back(chat_id: str, question: str) -> Optional[Tuple[bool, st
|
|||||||
# 按更新时间倒序,获取最新的记录
|
# 按更新时间倒序,获取最新的记录
|
||||||
records = (
|
records = (
|
||||||
ThinkingBack.select()
|
ThinkingBack.select()
|
||||||
.where(
|
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question))
|
||||||
(ThinkingBack.chat_id == chat_id) &
|
|
||||||
(ThinkingBack.question == question)
|
|
||||||
)
|
|
||||||
.order_by(ThinkingBack.update_time.desc())
|
.order_by(ThinkingBack.update_time.desc())
|
||||||
.limit(1)
|
.limit(1)
|
||||||
)
|
)
|
||||||
@@ -857,6 +849,7 @@ async def _analyze_question_answer(question: str, answer: str, chat_id: str) ->
|
|||||||
jargon_keyword = analysis_result.get("jargon_keyword", "").strip()
|
jargon_keyword = analysis_result.get("jargon_keyword", "").strip()
|
||||||
if jargon_keyword:
|
if jargon_keyword:
|
||||||
from src.jargon.jargon_miner import store_jargon_from_answer
|
from src.jargon.jargon_miner import store_jargon_from_answer
|
||||||
|
|
||||||
await store_jargon_from_answer(jargon_keyword, answer, chat_id)
|
await store_jargon_from_answer(jargon_keyword, answer, chat_id)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"分析为黑话但未提取到关键词,问题: {question[:50]}...")
|
logger.warning(f"分析为黑话但未提取到关键词,问题: {question[:50]}...")
|
||||||
@@ -882,14 +875,8 @@ async def _analyze_question_answer(question: str, answer: str, chat_id: str) ->
|
|||||||
logger.error(f"分析问题和答案时发生异常: {e}")
|
logger.error(f"分析问题和答案时发生异常: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _store_thinking_back(
|
def _store_thinking_back(
|
||||||
chat_id: str,
|
chat_id: str, question: str, context: str, found_answer: bool, answer: str, thinking_steps: List[Dict[str, Any]]
|
||||||
question: str,
|
|
||||||
context: str,
|
|
||||||
found_answer: bool,
|
|
||||||
answer: str,
|
|
||||||
thinking_steps: List[Dict[str, Any]]
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""存储或更新思考过程到数据库(如果已存在则更新,否则创建)
|
"""存储或更新思考过程到数据库(如果已存在则更新,否则创建)
|
||||||
|
|
||||||
@@ -907,10 +894,7 @@ def _store_thinking_back(
|
|||||||
# 先查询是否已存在相同chat_id和问题的记录
|
# 先查询是否已存在相同chat_id和问题的记录
|
||||||
existing = (
|
existing = (
|
||||||
ThinkingBack.select()
|
ThinkingBack.select()
|
||||||
.where(
|
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question))
|
||||||
(ThinkingBack.chat_id == chat_id) &
|
|
||||||
(ThinkingBack.question == question)
|
|
||||||
)
|
|
||||||
.order_by(ThinkingBack.update_time.desc())
|
.order_by(ThinkingBack.update_time.desc())
|
||||||
.limit(1)
|
.limit(1)
|
||||||
)
|
)
|
||||||
@@ -935,19 +919,14 @@ def _store_thinking_back(
|
|||||||
answer=answer,
|
answer=answer,
|
||||||
thinking_steps=json.dumps(thinking_steps, ensure_ascii=False),
|
thinking_steps=json.dumps(thinking_steps, ensure_ascii=False),
|
||||||
create_time=now,
|
create_time=now,
|
||||||
update_time=now
|
update_time=now,
|
||||||
)
|
)
|
||||||
logger.info(f"已创建思考过程到数据库,问题: {question[:50]}...")
|
logger.info(f"已创建思考过程到数据库,问题: {question[:50]}...")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"存储思考过程失败: {e}")
|
logger.error(f"存储思考过程失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def _process_single_question(
|
async def _process_single_question(question: str, chat_id: str, context: str, initial_info: str = "") -> Optional[str]:
|
||||||
question: str,
|
|
||||||
chat_id: str,
|
|
||||||
context: str,
|
|
||||||
initial_info: str = ""
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""处理单个问题的查询(包含缓存检查逻辑)
|
"""处理单个问题的查询(包含缓存检查逻辑)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -1015,7 +994,7 @@ async def _process_single_question(
|
|||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
max_iterations=global_config.memory.max_agent_iterations,
|
max_iterations=global_config.memory.max_agent_iterations,
|
||||||
timeout=120.0,
|
timeout=120.0,
|
||||||
initial_info=question_initial_info
|
initial_info=question_initial_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 存储到数据库(超时时不存储)
|
# 存储到数据库(超时时不存储)
|
||||||
@@ -1026,7 +1005,7 @@ async def _process_single_question(
|
|||||||
context=context,
|
context=context,
|
||||||
found_answer=found_answer,
|
found_answer=found_answer,
|
||||||
answer=answer,
|
answer=answer,
|
||||||
thinking_steps=thinking_steps
|
thinking_steps=thinking_steps,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(f"ReAct Agent超时,不存储到数据库,问题: {question[:50]}...")
|
logger.info(f"ReAct Agent超时,不存储到数据库,问题: {question[:50]}...")
|
||||||
@@ -1112,7 +1091,6 @@ async def build_memory_retrieval_prompt(
|
|||||||
else:
|
else:
|
||||||
logger.info("概念检索未找到任何结果")
|
logger.info("概念检索未找到任何结果")
|
||||||
|
|
||||||
|
|
||||||
# 获取缓存的记忆(与question时使用相同的时间窗口和数量限制)
|
# 获取缓存的记忆(与question时使用相同的时间窗口和数量限制)
|
||||||
cached_memories = _get_cached_memories(chat_id, time_window_seconds=300.0)
|
cached_memories = _get_cached_memories(chat_id, time_window_seconds=300.0)
|
||||||
|
|
||||||
@@ -1141,12 +1119,7 @@ async def build_memory_retrieval_prompt(
|
|||||||
|
|
||||||
# 并行处理所有问题,将概念检索结果作为初始信息传递
|
# 并行处理所有问题,将概念检索结果作为初始信息传递
|
||||||
question_tasks = [
|
question_tasks = [
|
||||||
_process_single_question(
|
_process_single_question(question=question, chat_id=chat_id, context=message, initial_info=initial_info)
|
||||||
question=question,
|
|
||||||
chat_id=chat_id,
|
|
||||||
context=message,
|
|
||||||
initial_info=initial_info
|
|
||||||
)
|
|
||||||
for question in questions
|
for question in questions
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -1179,7 +1152,9 @@ async def build_memory_retrieval_prompt(
|
|||||||
|
|
||||||
if all_results:
|
if all_results:
|
||||||
retrieved_memory = "\n\n".join(all_results)
|
retrieved_memory = "\n\n".join(all_results)
|
||||||
logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)")
|
logger.info(
|
||||||
|
f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)"
|
||||||
|
)
|
||||||
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
|
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
|
||||||
else:
|
else:
|
||||||
logger.debug("所有问题均未找到答案,且无缓存记忆")
|
logger.debug("所有问题均未找到答案,且无缓存记忆")
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
记忆系统工具函数
|
记忆系统工具函数
|
||||||
包含模糊查找、相似度计算等工具函数
|
包含模糊查找、相似度计算等工具函数
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -14,6 +15,7 @@ from src.common.logger import get_logger
|
|||||||
|
|
||||||
logger = get_logger("memory_utils")
|
logger = get_logger("memory_utils")
|
||||||
|
|
||||||
|
|
||||||
def parse_md_json(json_text: str) -> list[str]:
|
def parse_md_json(json_text: str) -> list[str]:
|
||||||
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
||||||
json_objects = []
|
json_objects = []
|
||||||
@@ -52,6 +54,7 @@ def parse_md_json(json_text: str) -> list[str]:
|
|||||||
|
|
||||||
return json_objects, reasoning_content
|
return json_objects, reasoning_content
|
||||||
|
|
||||||
|
|
||||||
def calculate_similarity(text1: str, text2: str) -> float:
|
def calculate_similarity(text1: str, text2: str) -> float:
|
||||||
"""
|
"""
|
||||||
计算两个文本的相似度
|
计算两个文本的相似度
|
||||||
@@ -97,10 +100,10 @@ def preprocess_text(text: str) -> str:
|
|||||||
text = text.lower()
|
text = text.lower()
|
||||||
|
|
||||||
# 移除标点符号和特殊字符
|
# 移除标点符号和特殊字符
|
||||||
text = re.sub(r'[^\w\s]', '', text)
|
text = re.sub(r"[^\w\s]", "", text)
|
||||||
|
|
||||||
# 移除多余空格
|
# 移除多余空格
|
||||||
text = re.sub(r'\s+', ' ', text).strip()
|
text = re.sub(r"\s+", " ", text).strip()
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
@@ -109,7 +112,6 @@ def preprocess_text(text: str) -> str:
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def parse_datetime_to_timestamp(value: str) -> float:
|
def parse_datetime_to_timestamp(value: str) -> float:
|
||||||
"""
|
"""
|
||||||
接受多种常见格式并转换为时间戳(秒)
|
接受多种常见格式并转换为时间戳(秒)
|
||||||
@@ -164,4 +166,3 @@ def parse_time_range(time_range: str) -> Tuple[float, float]:
|
|||||||
end_timestamp = parse_datetime_to_timestamp(end_str)
|
end_timestamp = parse_datetime_to_timestamp(end_str)
|
||||||
|
|
||||||
return start_timestamp, end_timestamp
|
return start_timestamp, end_timestamp
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
|
|||||||
from .query_person_info import register_tool as register_query_person_info
|
from .query_person_info import register_tool as register_query_person_info
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
|
||||||
def init_all_tools():
|
def init_all_tools():
|
||||||
"""初始化并注册所有记忆检索工具"""
|
"""初始化并注册所有记忆检索工具"""
|
||||||
register_query_jargon()
|
register_query_jargon()
|
||||||
|
|||||||
@@ -15,10 +15,7 @@ logger = get_logger("memory_retrieval_tools")
|
|||||||
|
|
||||||
|
|
||||||
async def query_chat_history(
|
async def query_chat_history(
|
||||||
chat_id: str,
|
chat_id: str, keyword: Optional[str] = None, time_range: Optional[str] = None, fuzzy: bool = True
|
||||||
keyword: Optional[str] = None,
|
|
||||||
time_range: Optional[str] = None,
|
|
||||||
fuzzy: bool = True
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""根据时间或关键词在chat_history表中查询聊天记录概述
|
"""根据时间或关键词在chat_history表中查询聊天记录概述
|
||||||
|
|
||||||
@@ -50,17 +47,11 @@ async def query_chat_history(
|
|||||||
# 时间范围:查询与时间范围有交集的记录
|
# 时间范围:查询与时间范围有交集的记录
|
||||||
start_timestamp, end_timestamp = parse_time_range(time_range)
|
start_timestamp, end_timestamp = parse_time_range(time_range)
|
||||||
# 交集条件:start_time < end_timestamp AND end_time > start_timestamp
|
# 交集条件:start_time < end_timestamp AND end_time > start_timestamp
|
||||||
time_filter = (
|
time_filter = (ChatHistory.start_time < end_timestamp) & (ChatHistory.end_time > start_timestamp)
|
||||||
(ChatHistory.start_time < end_timestamp) &
|
|
||||||
(ChatHistory.end_time > start_timestamp)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time)
|
# 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time)
|
||||||
target_timestamp = parse_datetime_to_timestamp(time_range)
|
target_timestamp = parse_datetime_to_timestamp(time_range)
|
||||||
time_filter = (
|
time_filter = (ChatHistory.start_time <= target_timestamp) & (ChatHistory.end_time >= target_timestamp)
|
||||||
(ChatHistory.start_time <= target_timestamp) &
|
|
||||||
(ChatHistory.end_time >= target_timestamp)
|
|
||||||
)
|
|
||||||
query = query.where(time_filter)
|
query = query.where(time_filter)
|
||||||
|
|
||||||
# 执行查询
|
# 执行查询
|
||||||
@@ -91,7 +82,9 @@ async def query_chat_history(
|
|||||||
record_keywords_list = []
|
record_keywords_list = []
|
||||||
if record.keywords:
|
if record.keywords:
|
||||||
try:
|
try:
|
||||||
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
keywords_data = (
|
||||||
|
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||||
|
)
|
||||||
if isinstance(keywords_data, list):
|
if isinstance(keywords_data, list):
|
||||||
record_keywords_list = [str(k).lower() for k in keywords_data]
|
record_keywords_list = [str(k).lower() for k in keywords_data]
|
||||||
except (json.JSONDecodeError, TypeError, ValueError):
|
except (json.JSONDecodeError, TypeError, ValueError):
|
||||||
@@ -102,20 +95,24 @@ async def query_chat_history(
|
|||||||
if fuzzy:
|
if fuzzy:
|
||||||
# 模糊匹配:只要包含任意一个关键词即匹配(OR关系)
|
# 模糊匹配:只要包含任意一个关键词即匹配(OR关系)
|
||||||
for kw in keywords_lower:
|
for kw in keywords_lower:
|
||||||
if (kw in theme or
|
if (
|
||||||
kw in summary or
|
kw in theme
|
||||||
kw in original_text or
|
or kw in summary
|
||||||
any(kw in k for k in record_keywords_list)):
|
or kw in original_text
|
||||||
|
or any(kw in k for k in record_keywords_list)
|
||||||
|
):
|
||||||
matched = True
|
matched = True
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# 全匹配:必须包含所有关键词才匹配(AND关系)
|
# 全匹配:必须包含所有关键词才匹配(AND关系)
|
||||||
matched = True
|
matched = True
|
||||||
for kw in keywords_lower:
|
for kw in keywords_lower:
|
||||||
kw_matched = (kw in theme or
|
kw_matched = (
|
||||||
kw in summary or
|
kw in theme
|
||||||
kw in original_text or
|
or kw in summary
|
||||||
any(kw in k for k in record_keywords_list))
|
or kw in original_text
|
||||||
|
or any(kw in k for k in record_keywords_list)
|
||||||
|
)
|
||||||
if not kw_matched:
|
if not kw_matched:
|
||||||
matched = False
|
matched = False
|
||||||
break
|
break
|
||||||
@@ -160,6 +157,7 @@ async def query_chat_history(
|
|||||||
|
|
||||||
# 添加时间范围
|
# 添加时间范围
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
|
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
|
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
result_parts.append(f"时间:{start_str} - {end_str}")
|
result_parts.append(f"时间:{start_str} - {end_str}")
|
||||||
@@ -199,20 +197,20 @@ def register_tool():
|
|||||||
"name": "keyword",
|
"name": "keyword",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘' 或 '麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索)",
|
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘' 或 '麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索)",
|
||||||
"required": False
|
"required": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "time_range",
|
"name": "time_range",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
|
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
|
||||||
"required": False
|
"required": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "fuzzy",
|
"name": "fuzzy",
|
||||||
"type": "boolean",
|
"type": "boolean",
|
||||||
"description": "是否使用模糊匹配模式(默认True)。True表示模糊匹配(只要包含任意一个关键词即匹配,OR关系),False表示全匹配(必须包含所有关键词才匹配,AND关系)",
|
"description": "是否使用模糊匹配模式(默认True)。True表示模糊匹配(只要包含任意一个关键词即匹配,OR关系),False表示全匹配(必须包含所有关键词才匹配,AND关系)",
|
||||||
"required": False
|
"required": False,
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
execute_func=query_chat_history
|
execute_func=query_chat_history,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -73,5 +73,3 @@ def register_tool():
|
|||||||
],
|
],
|
||||||
execute_func=query_lpmm_knowledge,
|
execute_func=query_lpmm_knowledge,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,9 @@ def _format_group_nick_names(group_nick_name_field) -> str:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 解析JSON格式的群昵称列表
|
# 解析JSON格式的群昵称列表
|
||||||
group_nick_names_data = json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field
|
group_nick_names_data = (
|
||||||
|
json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field
|
||||||
|
)
|
||||||
|
|
||||||
if not isinstance(group_nick_names_data, list) or not group_nick_names_data:
|
if not isinstance(group_nick_names_data, list) or not group_nick_names_data:
|
||||||
return ""
|
return ""
|
||||||
@@ -71,9 +73,7 @@ async def query_person_info(person_name: str) -> str:
|
|||||||
return "用户名称为空"
|
return "用户名称为空"
|
||||||
|
|
||||||
# 构建查询条件(使用模糊查询)
|
# 构建查询条件(使用模糊查询)
|
||||||
query = PersonInfo.select().where(
|
query = PersonInfo.select().where(PersonInfo.person_name.contains(person_name))
|
||||||
PersonInfo.person_name.contains(person_name)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 执行查询
|
# 执行查询
|
||||||
records = list(query.limit(20)) # 最多返回20条记录
|
records = list(query.limit(20)) # 最多返回20条记录
|
||||||
@@ -137,7 +137,11 @@ async def query_person_info(person_name: str) -> str:
|
|||||||
# 记忆点(memory_points)
|
# 记忆点(memory_points)
|
||||||
if record.memory_points:
|
if record.memory_points:
|
||||||
try:
|
try:
|
||||||
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
|
memory_points_data = (
|
||||||
|
json.loads(record.memory_points)
|
||||||
|
if isinstance(record.memory_points, str)
|
||||||
|
else record.memory_points
|
||||||
|
)
|
||||||
if isinstance(memory_points_data, list) and memory_points_data:
|
if isinstance(memory_points_data, list) and memory_points_data:
|
||||||
# 解析记忆点格式:category:content:weight
|
# 解析记忆点格式:category:content:weight
|
||||||
memory_list = []
|
memory_list = []
|
||||||
@@ -206,7 +210,11 @@ async def query_person_info(person_name: str) -> str:
|
|||||||
# 记忆点(memory_points)
|
# 记忆点(memory_points)
|
||||||
if record.memory_points:
|
if record.memory_points:
|
||||||
try:
|
try:
|
||||||
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
|
memory_points_data = (
|
||||||
|
json.loads(record.memory_points)
|
||||||
|
if isinstance(record.memory_points, str)
|
||||||
|
else record.memory_points
|
||||||
|
)
|
||||||
if isinstance(memory_points_data, list) and memory_points_data:
|
if isinstance(memory_points_data, list) and memory_points_data:
|
||||||
# 解析记忆点格式:category:content:weight
|
# 解析记忆点格式:category:content:weight
|
||||||
memory_list = []
|
memory_list = []
|
||||||
@@ -275,13 +283,7 @@ def register_tool():
|
|||||||
name="query_person_info",
|
name="query_person_info",
|
||||||
description="根据查询某个用户的所有信息。名称、昵称、平台、用户ID、qq号、群昵称等",
|
description="根据查询某个用户的所有信息。名称、昵称、平台、用户ID、qq号、群昵称等",
|
||||||
parameters=[
|
parameters=[
|
||||||
{
|
{"name": "person_name", "type": "string", "description": "用户名称,用于查询用户信息", "required": True}
|
||||||
"name": "person_name",
|
|
||||||
"type": "string",
|
|
||||||
"description": "用户名称,用于查询用户信息",
|
|
||||||
"required": True
|
|
||||||
}
|
|
||||||
],
|
],
|
||||||
execute_func=query_person_info
|
execute_func=query_person_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -82,11 +82,7 @@ class MemoryRetrievalTool:
|
|||||||
param_tuples.append(param_tuple)
|
param_tuples.append(param_tuple)
|
||||||
|
|
||||||
# 构建工具定义,格式与BaseTool.get_tool_definition()一致
|
# 构建工具定义,格式与BaseTool.get_tool_definition()一致
|
||||||
tool_def = {
|
tool_def = {"name": self.name, "description": self.description, "parameters": param_tuples}
|
||||||
"name": self.name,
|
|
||||||
"description": self.description,
|
|
||||||
"parameters": param_tuples
|
|
||||||
}
|
|
||||||
|
|
||||||
return tool_def
|
return tool_def
|
||||||
|
|
||||||
|
|||||||
@@ -162,7 +162,12 @@ def levenshtein_distance(s1: str, s2: str) -> int:
|
|||||||
class Person:
|
class Person:
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_person(
|
def register_person(
|
||||||
cls, platform: str, user_id: str, nickname: str, group_id: Optional[str] = None, group_nick_name: Optional[str] = None
|
cls,
|
||||||
|
platform: str,
|
||||||
|
user_id: str,
|
||||||
|
nickname: str,
|
||||||
|
group_id: Optional[str] = None,
|
||||||
|
group_nick_name: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
注册新用户的类方法
|
注册新用户的类方法
|
||||||
@@ -781,7 +786,11 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
|
|||||||
if len(parts) >= 2:
|
if len(parts) >= 2:
|
||||||
existing_content = parts[1].strip()
|
existing_content = parts[1].strip()
|
||||||
# 简单相似度检查(如果内容相同或非常相似,则跳过)
|
# 简单相似度检查(如果内容相同或非常相似,则跳过)
|
||||||
if existing_content == memory_content or memory_content in existing_content or existing_content in memory_content:
|
if (
|
||||||
|
existing_content == memory_content
|
||||||
|
or memory_content in existing_content
|
||||||
|
or existing_content in memory_content
|
||||||
|
):
|
||||||
is_duplicate = True
|
is_duplicate = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@@ -125,7 +125,6 @@ class ToolExecutor:
|
|||||||
prompt=prompt, tools=tools, raise_when_empty=False
|
prompt=prompt, tools=tools, raise_when_empty=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 执行工具调用
|
# 执行工具调用
|
||||||
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
|
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""表情包管理 API 路由"""
|
"""表情包管理 API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Header, Query
|
from fastapi import APIRouter, HTTPException, Header, Query
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -18,6 +19,7 @@ router = APIRouter(prefix="/emoji", tags=["Emoji"])
|
|||||||
|
|
||||||
class EmojiResponse(BaseModel):
|
class EmojiResponse(BaseModel):
|
||||||
"""表情包响应"""
|
"""表情包响应"""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
full_path: str
|
full_path: str
|
||||||
format: str
|
format: str
|
||||||
@@ -35,6 +37,7 @@ class EmojiResponse(BaseModel):
|
|||||||
|
|
||||||
class EmojiListResponse(BaseModel):
|
class EmojiListResponse(BaseModel):
|
||||||
"""表情包列表响应"""
|
"""表情包列表响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
total: int
|
total: int
|
||||||
page: int
|
page: int
|
||||||
@@ -44,12 +47,14 @@ class EmojiListResponse(BaseModel):
|
|||||||
|
|
||||||
class EmojiDetailResponse(BaseModel):
|
class EmojiDetailResponse(BaseModel):
|
||||||
"""表情包详情响应"""
|
"""表情包详情响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
data: EmojiResponse
|
data: EmojiResponse
|
||||||
|
|
||||||
|
|
||||||
class EmojiUpdateRequest(BaseModel):
|
class EmojiUpdateRequest(BaseModel):
|
||||||
"""表情包更新请求"""
|
"""表情包更新请求"""
|
||||||
|
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
is_registered: Optional[bool] = None
|
is_registered: Optional[bool] = None
|
||||||
is_banned: Optional[bool] = None
|
is_banned: Optional[bool] = None
|
||||||
@@ -58,6 +63,7 @@ class EmojiUpdateRequest(BaseModel):
|
|||||||
|
|
||||||
class EmojiUpdateResponse(BaseModel):
|
class EmojiUpdateResponse(BaseModel):
|
||||||
"""表情包更新响应"""
|
"""表情包更新响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
data: Optional[EmojiResponse] = None
|
data: Optional[EmojiResponse] = None
|
||||||
@@ -65,6 +71,7 @@ class EmojiUpdateResponse(BaseModel):
|
|||||||
|
|
||||||
class EmojiDeleteResponse(BaseModel):
|
class EmojiDeleteResponse(BaseModel):
|
||||||
"""表情包删除响应"""
|
"""表情包删除响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
@@ -120,7 +127,7 @@ async def get_emoji_list(
|
|||||||
is_registered: Optional[bool] = Query(None, description="是否已注册筛选"),
|
is_registered: Optional[bool] = Query(None, description="是否已注册筛选"),
|
||||||
is_banned: Optional[bool] = Query(None, description="是否被禁用筛选"),
|
is_banned: Optional[bool] = Query(None, description="是否被禁用筛选"),
|
||||||
format: Optional[str] = Query(None, description="格式筛选"),
|
format: Optional[str] = Query(None, description="格式筛选"),
|
||||||
authorization: Optional[str] = Header(None)
|
authorization: Optional[str] = Header(None),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取表情包列表
|
获取表情包列表
|
||||||
@@ -145,10 +152,7 @@ async def get_emoji_list(
|
|||||||
|
|
||||||
# 搜索过滤
|
# 搜索过滤
|
||||||
if search:
|
if search:
|
||||||
query = query.where(
|
query = query.where((Emoji.description.contains(search)) | (Emoji.emoji_hash.contains(search)))
|
||||||
(Emoji.description.contains(search)) |
|
|
||||||
(Emoji.emoji_hash.contains(search))
|
|
||||||
)
|
|
||||||
|
|
||||||
# 注册状态过滤
|
# 注册状态过滤
|
||||||
if is_registered is not None:
|
if is_registered is not None:
|
||||||
@@ -164,10 +168,9 @@ async def get_emoji_list(
|
|||||||
|
|
||||||
# 排序:使用次数倒序,然后按记录时间倒序
|
# 排序:使用次数倒序,然后按记录时间倒序
|
||||||
from peewee import Case
|
from peewee import Case
|
||||||
|
|
||||||
query = query.order_by(
|
query = query.order_by(
|
||||||
Emoji.usage_count.desc(),
|
Emoji.usage_count.desc(), Case(None, [(Emoji.record_time.is_null(), 1)], 0), Emoji.record_time.desc()
|
||||||
Case(None, [(Emoji.record_time.is_null(), 1)], 0),
|
|
||||||
Emoji.record_time.desc()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取总数
|
# 获取总数
|
||||||
@@ -180,13 +183,7 @@ async def get_emoji_list(
|
|||||||
# 转换为响应对象
|
# 转换为响应对象
|
||||||
data = [emoji_to_response(emoji) for emoji in emojis]
|
data = [emoji_to_response(emoji) for emoji in emojis]
|
||||||
|
|
||||||
return EmojiListResponse(
|
return EmojiListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||||
success=True,
|
|
||||||
total=total,
|
|
||||||
page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
data=data
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -196,10 +193,7 @@ async def get_emoji_list(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
|
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
|
||||||
async def get_emoji_detail(
|
async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||||
emoji_id: int,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
获取表情包详细信息
|
获取表情包详细信息
|
||||||
|
|
||||||
@@ -218,10 +212,7 @@ async def get_emoji_detail(
|
|||||||
if not emoji:
|
if not emoji:
|
||||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||||
|
|
||||||
return EmojiDetailResponse(
|
return EmojiDetailResponse(success=True, data=emoji_to_response(emoji))
|
||||||
success=True,
|
|
||||||
data=emoji_to_response(emoji)
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -231,11 +222,7 @@ async def get_emoji_detail(
|
|||||||
|
|
||||||
|
|
||||||
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
|
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
|
||||||
async def update_emoji(
|
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||||
emoji_id: int,
|
|
||||||
request: EmojiUpdateRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
增量更新表情包(只更新提供的字段)
|
增量更新表情包(只更新提供的字段)
|
||||||
|
|
||||||
@@ -262,15 +249,15 @@ async def update_emoji(
|
|||||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||||
|
|
||||||
# 处理情感标签(转换为 JSON)
|
# 处理情感标签(转换为 JSON)
|
||||||
if 'emotion' in update_data:
|
if "emotion" in update_data:
|
||||||
if update_data['emotion'] is None:
|
if update_data["emotion"] is None:
|
||||||
update_data['emotion'] = None
|
update_data["emotion"] = None
|
||||||
else:
|
else:
|
||||||
update_data['emotion'] = json.dumps(update_data['emotion'], ensure_ascii=False)
|
update_data["emotion"] = json.dumps(update_data["emotion"], ensure_ascii=False)
|
||||||
|
|
||||||
# 如果注册状态从 False 变为 True,记录注册时间
|
# 如果注册状态从 False 变为 True,记录注册时间
|
||||||
if 'is_registered' in update_data and update_data['is_registered'] and not emoji.is_registered:
|
if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered:
|
||||||
update_data['register_time'] = time.time()
|
update_data["register_time"] = time.time()
|
||||||
|
|
||||||
# 执行更新
|
# 执行更新
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
@@ -281,9 +268,7 @@ async def update_emoji(
|
|||||||
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
|
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
|
||||||
|
|
||||||
return EmojiUpdateResponse(
|
return EmojiUpdateResponse(
|
||||||
success=True,
|
success=True, message=f"成功更新 {len(update_data)} 个字段", data=emoji_to_response(emoji)
|
||||||
message=f"成功更新 {len(update_data)} 个字段",
|
|
||||||
data=emoji_to_response(emoji)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -294,10 +279,7 @@ async def update_emoji(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
|
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
|
||||||
async def delete_emoji(
|
async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||||
emoji_id: int,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
删除表情包
|
删除表情包
|
||||||
|
|
||||||
@@ -324,10 +306,7 @@ async def delete_emoji(
|
|||||||
|
|
||||||
logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}")
|
logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}")
|
||||||
|
|
||||||
return EmojiDeleteResponse(
|
return EmojiDeleteResponse(success=True, message=f"成功删除表情包: {emoji_hash}")
|
||||||
success=True,
|
|
||||||
message=f"成功删除表情包: {emoji_hash}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -337,9 +316,7 @@ async def delete_emoji(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/stats/summary")
|
@router.get("/stats/summary")
|
||||||
async def get_emoji_stats(
|
async def get_emoji_stats(authorization: Optional[str] = Header(None)):
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
获取表情包统计数据
|
获取表情包统计数据
|
||||||
|
|
||||||
@@ -369,7 +346,7 @@ async def get_emoji_stats(
|
|||||||
"id": emoji.id,
|
"id": emoji.id,
|
||||||
"emoji_hash": emoji.emoji_hash,
|
"emoji_hash": emoji.emoji_hash,
|
||||||
"description": emoji.description,
|
"description": emoji.description,
|
||||||
"usage_count": emoji.usage_count
|
"usage_count": emoji.usage_count,
|
||||||
}
|
}
|
||||||
for emoji in top_used
|
for emoji in top_used
|
||||||
]
|
]
|
||||||
@@ -382,8 +359,8 @@ async def get_emoji_stats(
|
|||||||
"banned": banned,
|
"banned": banned,
|
||||||
"unregistered": total - registered,
|
"unregistered": total - registered,
|
||||||
"formats": formats,
|
"formats": formats,
|
||||||
"top_used": top_used_list
|
"top_used": top_used_list,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -394,10 +371,7 @@ async def get_emoji_stats(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
|
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
|
||||||
async def register_emoji(
|
async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||||
emoji_id: int,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
注册表情包(快捷操作)
|
注册表情包(快捷操作)
|
||||||
|
|
||||||
@@ -429,11 +403,7 @@ async def register_emoji(
|
|||||||
|
|
||||||
logger.info(f"表情包已注册: ID={emoji_id}")
|
logger.info(f"表情包已注册: ID={emoji_id}")
|
||||||
|
|
||||||
return EmojiUpdateResponse(
|
return EmojiUpdateResponse(success=True, message="表情包注册成功", data=emoji_to_response(emoji))
|
||||||
success=True,
|
|
||||||
message="表情包注册成功",
|
|
||||||
data=emoji_to_response(emoji)
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -443,10 +413,7 @@ async def register_emoji(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
|
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
|
||||||
async def ban_emoji(
|
async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||||
emoji_id: int,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
禁用表情包(快捷操作)
|
禁用表情包(快捷操作)
|
||||||
|
|
||||||
@@ -472,11 +439,7 @@ async def ban_emoji(
|
|||||||
|
|
||||||
logger.info(f"表情包已禁用: ID={emoji_id}")
|
logger.info(f"表情包已禁用: ID={emoji_id}")
|
||||||
|
|
||||||
return EmojiUpdateResponse(
|
return EmojiUpdateResponse(success=True, message="表情包禁用成功", data=emoji_to_response(emoji))
|
||||||
success=True,
|
|
||||||
message="表情包禁用成功",
|
|
||||||
data=emoji_to_response(emoji)
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -489,7 +452,7 @@ async def ban_emoji(
|
|||||||
async def get_emoji_thumbnail(
|
async def get_emoji_thumbnail(
|
||||||
emoji_id: int,
|
emoji_id: int,
|
||||||
token: Optional[str] = Query(None, description="访问令牌"),
|
token: Optional[str] = Query(None, description="访问令牌"),
|
||||||
authorization: Optional[str] = Header(None)
|
authorization: Optional[str] = Header(None),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取表情包缩略图
|
获取表情包缩略图
|
||||||
@@ -523,25 +486,20 @@ async def get_emoji_thumbnail(
|
|||||||
|
|
||||||
# 根据格式设置 MIME 类型
|
# 根据格式设置 MIME 类型
|
||||||
mime_types = {
|
mime_types = {
|
||||||
'png': 'image/png',
|
"png": "image/png",
|
||||||
'jpg': 'image/jpeg',
|
"jpg": "image/jpeg",
|
||||||
'jpeg': 'image/jpeg',
|
"jpeg": "image/jpeg",
|
||||||
'gif': 'image/gif',
|
"gif": "image/gif",
|
||||||
'webp': 'image/webp',
|
"webp": "image/webp",
|
||||||
'bmp': 'image/bmp'
|
"bmp": "image/bmp",
|
||||||
}
|
}
|
||||||
|
|
||||||
media_type = mime_types.get(emoji.format.lower(), 'application/octet-stream')
|
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
|
||||||
|
|
||||||
return FileResponse(
|
return FileResponse(path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}")
|
||||||
path=emoji.full_path,
|
|
||||||
media_type=media_type,
|
|
||||||
filename=f"{emoji.emoji_hash}.{emoji.format}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"获取表情包缩略图失败: {e}")
|
logger.exception(f"获取表情包缩略图失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"获取表情包缩略图失败: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"获取表情包缩略图失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""表达方式管理 API 路由"""
|
"""表达方式管理 API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Header, Query
|
from fastapi import APIRouter, HTTPException, Header, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
@@ -15,6 +16,7 @@ router = APIRouter(prefix="/expression", tags=["Expression"])
|
|||||||
|
|
||||||
class ExpressionResponse(BaseModel):
|
class ExpressionResponse(BaseModel):
|
||||||
"""表达方式响应"""
|
"""表达方式响应"""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
situation: str
|
situation: str
|
||||||
style: str
|
style: str
|
||||||
@@ -27,6 +29,7 @@ class ExpressionResponse(BaseModel):
|
|||||||
|
|
||||||
class ExpressionListResponse(BaseModel):
|
class ExpressionListResponse(BaseModel):
|
||||||
"""表达方式列表响应"""
|
"""表达方式列表响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
total: int
|
total: int
|
||||||
page: int
|
page: int
|
||||||
@@ -36,12 +39,14 @@ class ExpressionListResponse(BaseModel):
|
|||||||
|
|
||||||
class ExpressionDetailResponse(BaseModel):
|
class ExpressionDetailResponse(BaseModel):
|
||||||
"""表达方式详情响应"""
|
"""表达方式详情响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
data: ExpressionResponse
|
data: ExpressionResponse
|
||||||
|
|
||||||
|
|
||||||
class ExpressionCreateRequest(BaseModel):
|
class ExpressionCreateRequest(BaseModel):
|
||||||
"""表达方式创建请求"""
|
"""表达方式创建请求"""
|
||||||
|
|
||||||
situation: str
|
situation: str
|
||||||
style: str
|
style: str
|
||||||
context: Optional[str] = None
|
context: Optional[str] = None
|
||||||
@@ -51,6 +56,7 @@ class ExpressionCreateRequest(BaseModel):
|
|||||||
|
|
||||||
class ExpressionUpdateRequest(BaseModel):
|
class ExpressionUpdateRequest(BaseModel):
|
||||||
"""表达方式更新请求"""
|
"""表达方式更新请求"""
|
||||||
|
|
||||||
situation: Optional[str] = None
|
situation: Optional[str] = None
|
||||||
style: Optional[str] = None
|
style: Optional[str] = None
|
||||||
context: Optional[str] = None
|
context: Optional[str] = None
|
||||||
@@ -60,6 +66,7 @@ class ExpressionUpdateRequest(BaseModel):
|
|||||||
|
|
||||||
class ExpressionUpdateResponse(BaseModel):
|
class ExpressionUpdateResponse(BaseModel):
|
||||||
"""表达方式更新响应"""
|
"""表达方式更新响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
data: Optional[ExpressionResponse] = None
|
data: Optional[ExpressionResponse] = None
|
||||||
@@ -67,12 +74,14 @@ class ExpressionUpdateResponse(BaseModel):
|
|||||||
|
|
||||||
class ExpressionDeleteResponse(BaseModel):
|
class ExpressionDeleteResponse(BaseModel):
|
||||||
"""表达方式删除响应"""
|
"""表达方式删除响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
class ExpressionCreateResponse(BaseModel):
|
class ExpressionCreateResponse(BaseModel):
|
||||||
"""表达方式创建响应"""
|
"""表达方式创建响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
data: ExpressionResponse
|
data: ExpressionResponse
|
||||||
@@ -112,7 +121,7 @@ async def get_expression_list(
|
|||||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||||
authorization: Optional[str] = Header(None)
|
authorization: Optional[str] = Header(None),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取表达方式列表
|
获取表达方式列表
|
||||||
@@ -136,9 +145,9 @@ async def get_expression_list(
|
|||||||
# 搜索过滤
|
# 搜索过滤
|
||||||
if search:
|
if search:
|
||||||
query = query.where(
|
query = query.where(
|
||||||
(Expression.situation.contains(search)) |
|
(Expression.situation.contains(search))
|
||||||
(Expression.style.contains(search)) |
|
| (Expression.style.contains(search))
|
||||||
(Expression.context.contains(search))
|
| (Expression.context.contains(search))
|
||||||
)
|
)
|
||||||
|
|
||||||
# 聊天ID过滤
|
# 聊天ID过滤
|
||||||
@@ -147,9 +156,9 @@ async def get_expression_list(
|
|||||||
|
|
||||||
# 排序:最后活跃时间倒序(NULL 值放在最后)
|
# 排序:最后活跃时间倒序(NULL 值放在最后)
|
||||||
from peewee import Case
|
from peewee import Case
|
||||||
|
|
||||||
query = query.order_by(
|
query = query.order_by(
|
||||||
Case(None, [(Expression.last_active_time.is_null(), 1)], 0),
|
Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc()
|
||||||
Expression.last_active_time.desc()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取总数
|
# 获取总数
|
||||||
@@ -162,13 +171,7 @@ async def get_expression_list(
|
|||||||
# 转换为响应对象
|
# 转换为响应对象
|
||||||
data = [expression_to_response(expr) for expr in expressions]
|
data = [expression_to_response(expr) for expr in expressions]
|
||||||
|
|
||||||
return ExpressionListResponse(
|
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||||
success=True,
|
|
||||||
total=total,
|
|
||||||
page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
data=data
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -178,10 +181,7 @@ async def get_expression_list(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
||||||
async def get_expression_detail(
|
async def get_expression_detail(expression_id: int, authorization: Optional[str] = Header(None)):
|
||||||
expression_id: int,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
获取表达方式详细信息
|
获取表达方式详细信息
|
||||||
|
|
||||||
@@ -200,10 +200,7 @@ async def get_expression_detail(
|
|||||||
if not expression:
|
if not expression:
|
||||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||||
|
|
||||||
return ExpressionDetailResponse(
|
return ExpressionDetailResponse(success=True, data=expression_to_response(expression))
|
||||||
success=True,
|
|
||||||
data=expression_to_response(expression)
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -213,10 +210,7 @@ async def get_expression_detail(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/", response_model=ExpressionCreateResponse)
|
@router.post("/", response_model=ExpressionCreateResponse)
|
||||||
async def create_expression(
|
async def create_expression(request: ExpressionCreateRequest, authorization: Optional[str] = Header(None)):
|
||||||
request: ExpressionCreateRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
创建新的表达方式
|
创建新的表达方式
|
||||||
|
|
||||||
@@ -246,9 +240,7 @@ async def create_expression(
|
|||||||
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
|
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
|
||||||
|
|
||||||
return ExpressionCreateResponse(
|
return ExpressionCreateResponse(
|
||||||
success=True,
|
success=True, message="表达方式创建成功", data=expression_to_response(expression)
|
||||||
message="表达方式创建成功",
|
|
||||||
data=expression_to_response(expression)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -260,9 +252,7 @@ async def create_expression(
|
|||||||
|
|
||||||
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
||||||
async def update_expression(
|
async def update_expression(
|
||||||
expression_id: int,
|
expression_id: int, request: ExpressionUpdateRequest, authorization: Optional[str] = Header(None)
|
||||||
request: ExpressionUpdateRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
增量更新表达方式(只更新提供的字段)
|
增量更新表达方式(只更新提供的字段)
|
||||||
@@ -290,7 +280,7 @@ async def update_expression(
|
|||||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||||
|
|
||||||
# 更新最后活跃时间
|
# 更新最后活跃时间
|
||||||
update_data['last_active_time'] = time.time()
|
update_data["last_active_time"] = time.time()
|
||||||
|
|
||||||
# 执行更新
|
# 执行更新
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
@@ -301,9 +291,7 @@ async def update_expression(
|
|||||||
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
|
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
|
||||||
|
|
||||||
return ExpressionUpdateResponse(
|
return ExpressionUpdateResponse(
|
||||||
success=True,
|
success=True, message=f"成功更新 {len(update_data)} 个字段", data=expression_to_response(expression)
|
||||||
message=f"成功更新 {len(update_data)} 个字段",
|
|
||||||
data=expression_to_response(expression)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -314,10 +302,7 @@ async def update_expression(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
||||||
async def delete_expression(
|
async def delete_expression(expression_id: int, authorization: Optional[str] = Header(None)):
|
||||||
expression_id: int,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
删除表达方式
|
删除表达方式
|
||||||
|
|
||||||
@@ -344,10 +329,7 @@ async def delete_expression(
|
|||||||
|
|
||||||
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
|
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
|
||||||
|
|
||||||
return ExpressionDeleteResponse(
|
return ExpressionDeleteResponse(success=True, message=f"成功删除表达方式: {situation}")
|
||||||
success=True,
|
|
||||||
message=f"成功删除表达方式: {situation}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -357,9 +339,7 @@ async def delete_expression(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/stats/summary")
|
@router.get("/stats/summary")
|
||||||
async def get_expression_stats(
|
async def get_expression_stats(authorization: Optional[str] = Header(None)):
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
获取表达方式统计数据
|
获取表达方式统计数据
|
||||||
|
|
||||||
@@ -382,10 +362,11 @@ async def get_expression_stats(
|
|||||||
|
|
||||||
# 获取最近创建的记录数(7天内)
|
# 获取最近创建的记录数(7天内)
|
||||||
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
|
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
|
||||||
recent = Expression.select().where(
|
recent = (
|
||||||
(Expression.create_date.is_null(False)) &
|
Expression.select()
|
||||||
(Expression.create_date >= seven_days_ago)
|
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
|
||||||
).count()
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
@@ -393,8 +374,8 @@ async def get_expression_stats(
|
|||||||
"total": total,
|
"total": total,
|
||||||
"recent_7days": recent,
|
"recent_7days": recent,
|
||||||
"chat_count": len(chat_stats),
|
"chat_count": len(chat_stats),
|
||||||
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10])
|
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10]),
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
|
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
|
||||||
|
|
||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import httpx
|
import httpx
|
||||||
@@ -15,6 +16,7 @@ logger = get_logger("webui.git_mirror")
|
|||||||
# 导入进度更新函数(避免循环导入)
|
# 导入进度更新函数(避免循环导入)
|
||||||
_update_progress = None
|
_update_progress = None
|
||||||
|
|
||||||
|
|
||||||
def set_update_progress_callback(callback):
|
def set_update_progress_callback(callback):
|
||||||
"""设置进度更新回调函数"""
|
"""设置进度更新回调函数"""
|
||||||
global _update_progress
|
global _update_progress
|
||||||
@@ -23,6 +25,7 @@ def set_update_progress_callback(callback):
|
|||||||
|
|
||||||
class MirrorType(str, Enum):
|
class MirrorType(str, Enum):
|
||||||
"""镜像源类型"""
|
"""镜像源类型"""
|
||||||
|
|
||||||
GH_PROXY = "gh-proxy" # gh-proxy 主节点
|
GH_PROXY = "gh-proxy" # gh-proxy 主节点
|
||||||
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
|
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
|
||||||
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
|
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
|
||||||
@@ -47,7 +50,7 @@ class GitMirrorConfig:
|
|||||||
"clone_prefix": "https://gh-proxy.org/https://github.com",
|
"clone_prefix": "https://gh-proxy.org/https://github.com",
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"priority": 1,
|
"priority": 1,
|
||||||
"created_at": None
|
"created_at": None,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "hk-gh-proxy",
|
"id": "hk-gh-proxy",
|
||||||
@@ -56,7 +59,7 @@ class GitMirrorConfig:
|
|||||||
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
|
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"priority": 2,
|
"priority": 2,
|
||||||
"created_at": None
|
"created_at": None,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "cdn-gh-proxy",
|
"id": "cdn-gh-proxy",
|
||||||
@@ -65,7 +68,7 @@ class GitMirrorConfig:
|
|||||||
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
|
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"priority": 3,
|
"priority": 3,
|
||||||
"created_at": None
|
"created_at": None,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "edgeone-gh-proxy",
|
"id": "edgeone-gh-proxy",
|
||||||
@@ -74,7 +77,7 @@ class GitMirrorConfig:
|
|||||||
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
|
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"priority": 4,
|
"priority": 4,
|
||||||
"created_at": None
|
"created_at": None,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "meyzh-github",
|
"id": "meyzh-github",
|
||||||
@@ -83,7 +86,7 @@ class GitMirrorConfig:
|
|||||||
"clone_prefix": "https://meyzh.github.io/https://github.com",
|
"clone_prefix": "https://meyzh.github.io/https://github.com",
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"priority": 5,
|
"priority": 5,
|
||||||
"created_at": None
|
"created_at": None,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "github",
|
"id": "github",
|
||||||
@@ -92,8 +95,8 @@ class GitMirrorConfig:
|
|||||||
"clone_prefix": "https://github.com",
|
"clone_prefix": "https://github.com",
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"priority": 999,
|
"priority": 999,
|
||||||
"created_at": None
|
"created_at": None,
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -106,7 +109,7 @@ class GitMirrorConfig:
|
|||||||
"""加载配置文件"""
|
"""加载配置文件"""
|
||||||
try:
|
try:
|
||||||
if self.config_file.exists():
|
if self.config_file.exists():
|
||||||
with open(self.config_file, 'r', encoding='utf-8') as f:
|
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
# 检查是否有镜像源配置
|
# 检查是否有镜像源配置
|
||||||
@@ -145,14 +148,14 @@ class GitMirrorConfig:
|
|||||||
# 读取现有配置
|
# 读取现有配置
|
||||||
existing_data = {}
|
existing_data = {}
|
||||||
if self.config_file.exists():
|
if self.config_file.exists():
|
||||||
with open(self.config_file, 'r', encoding='utf-8') as f:
|
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||||
existing_data = json.load(f)
|
existing_data = json.load(f)
|
||||||
|
|
||||||
# 更新镜像源配置
|
# 更新镜像源配置
|
||||||
existing_data["git_mirrors"] = self.mirrors
|
existing_data["git_mirrors"] = self.mirrors
|
||||||
|
|
||||||
# 写入文件
|
# 写入文件
|
||||||
with open(self.config_file, 'w', encoding='utf-8') as f:
|
with open(self.config_file, "w", encoding="utf-8") as f:
|
||||||
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
logger.debug(f"配置已保存到 {self.config_file}")
|
logger.debug(f"配置已保存到 {self.config_file}")
|
||||||
@@ -182,7 +185,7 @@ class GitMirrorConfig:
|
|||||||
raw_prefix: str,
|
raw_prefix: str,
|
||||||
clone_prefix: str,
|
clone_prefix: str,
|
||||||
enabled: bool = True,
|
enabled: bool = True,
|
||||||
priority: Optional[int] = None
|
priority: Optional[int] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
添加新的镜像源
|
添加新的镜像源
|
||||||
@@ -209,7 +212,7 @@ class GitMirrorConfig:
|
|||||||
"clone_prefix": clone_prefix,
|
"clone_prefix": clone_prefix,
|
||||||
"enabled": enabled,
|
"enabled": enabled,
|
||||||
"priority": priority,
|
"priority": priority,
|
||||||
"created_at": datetime.now().isoformat()
|
"created_at": datetime.now().isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
self.mirrors.append(new_mirror)
|
self.mirrors.append(new_mirror)
|
||||||
@@ -225,7 +228,7 @@ class GitMirrorConfig:
|
|||||||
raw_prefix: Optional[str] = None,
|
raw_prefix: Optional[str] = None,
|
||||||
clone_prefix: Optional[str] = None,
|
clone_prefix: Optional[str] = None,
|
||||||
enabled: Optional[bool] = None,
|
enabled: Optional[bool] = None,
|
||||||
priority: Optional[int] = None
|
priority: Optional[int] = None,
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
更新镜像源配置
|
更新镜像源配置
|
||||||
@@ -279,12 +282,7 @@ class GitMirrorConfig:
|
|||||||
class GitMirrorService:
|
class GitMirrorService:
|
||||||
"""Git 镜像源服务"""
|
"""Git 镜像源服务"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, max_retries: int = 3, timeout: int = 30, config: Optional[GitMirrorConfig] = None):
|
||||||
self,
|
|
||||||
max_retries: int = 3,
|
|
||||||
timeout: int = 30,
|
|
||||||
config: Optional[GitMirrorConfig] = None
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
初始化 Git 镜像源服务
|
初始化 Git 镜像源服务
|
||||||
|
|
||||||
@@ -323,46 +321,25 @@ class GitMirrorService:
|
|||||||
|
|
||||||
if not git_path:
|
if not git_path:
|
||||||
logger.warning("未找到 Git 可执行文件")
|
logger.warning("未找到 Git 可执行文件")
|
||||||
return {
|
return {"installed": False, "error": "系统中未找到 Git,请先安装 Git"}
|
||||||
"installed": False,
|
|
||||||
"error": "系统中未找到 Git,请先安装 Git"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 获取 Git 版本
|
# 获取 Git 版本
|
||||||
result = subprocess.run(
|
result = subprocess.run(["git", "--version"], capture_output=True, text=True, timeout=5)
|
||||||
["git", "--version"],
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
timeout=5
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.returncode == 0:
|
if result.returncode == 0:
|
||||||
version = result.stdout.strip()
|
version = result.stdout.strip()
|
||||||
logger.info(f"检测到 Git: {version} at {git_path}")
|
logger.info(f"检测到 Git: {version} at {git_path}")
|
||||||
return {
|
return {"installed": True, "version": version, "path": git_path}
|
||||||
"installed": True,
|
|
||||||
"version": version,
|
|
||||||
"path": git_path
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Git 命令执行失败: {result.stderr}")
|
logger.warning(f"Git 命令执行失败: {result.stderr}")
|
||||||
return {
|
return {"installed": False, "error": f"Git 命令执行失败: {result.stderr}"}
|
||||||
"installed": False,
|
|
||||||
"error": f"Git 命令执行失败: {result.stderr}"
|
|
||||||
}
|
|
||||||
|
|
||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
logger.error("Git 版本检测超时")
|
logger.error("Git 版本检测超时")
|
||||||
return {
|
return {"installed": False, "error": "Git 版本检测超时"}
|
||||||
"installed": False,
|
|
||||||
"error": "Git 版本检测超时"
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"检测 Git 时发生错误: {e}")
|
logger.error(f"检测 Git 时发生错误: {e}")
|
||||||
return {
|
return {"installed": False, "error": f"检测 Git 时发生错误: {str(e)}"}
|
||||||
"installed": False,
|
|
||||||
"error": f"检测 Git 时发生错误: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
async def fetch_raw_file(
|
async def fetch_raw_file(
|
||||||
self,
|
self,
|
||||||
@@ -371,7 +348,7 @@ class GitMirrorService:
|
|||||||
branch: str,
|
branch: str,
|
||||||
file_path: str,
|
file_path: str,
|
||||||
mirror_id: Optional[str] = None,
|
mirror_id: Optional[str] = None,
|
||||||
custom_url: Optional[str] = None
|
custom_url: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取 GitHub 仓库的 Raw 文件内容
|
获取 GitHub 仓库的 Raw 文件内容
|
||||||
@@ -403,12 +380,7 @@ class GitMirrorService:
|
|||||||
# 使用指定的镜像源
|
# 使用指定的镜像源
|
||||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||||
if not mirror:
|
if not mirror:
|
||||||
return {
|
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||||
"success": False,
|
|
||||||
"error": f"未找到镜像源: {mirror_id}",
|
|
||||||
"mirror_used": None,
|
|
||||||
"attempts": 0
|
|
||||||
}
|
|
||||||
mirrors_to_try = [mirror]
|
mirrors_to_try = [mirror]
|
||||||
else:
|
else:
|
||||||
# 使用所有启用的镜像源
|
# 使用所有启用的镜像源
|
||||||
@@ -427,14 +399,12 @@ class GitMirrorService:
|
|||||||
progress=progress,
|
progress=progress,
|
||||||
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
|
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
|
||||||
total_plugins=0,
|
total_plugins=0,
|
||||||
loaded_plugins=0
|
loaded_plugins=0,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"推送进度失败: {e}")
|
logger.warning(f"推送进度失败: {e}")
|
||||||
|
|
||||||
result = await self._fetch_raw_from_mirror(
|
result = await self._fetch_raw_from_mirror(owner, repo, branch, file_path, mirror)
|
||||||
owner, repo, branch, file_path, mirror
|
|
||||||
)
|
|
||||||
|
|
||||||
if result["success"]:
|
if result["success"]:
|
||||||
# 成功,推送进度
|
# 成功,推送进度
|
||||||
@@ -445,7 +415,7 @@ class GitMirrorService:
|
|||||||
progress=70,
|
progress=70,
|
||||||
message=f"成功从 {mirror['name']} 获取数据",
|
message=f"成功从 {mirror['name']} 获取数据",
|
||||||
total_plugins=0,
|
total_plugins=0,
|
||||||
loaded_plugins=0
|
loaded_plugins=0,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"推送进度失败: {e}")
|
logger.warning(f"推送进度失败: {e}")
|
||||||
@@ -461,26 +431,16 @@ class GitMirrorService:
|
|||||||
progress=30 + int(index / total_mirrors * 40),
|
progress=30 + int(index / total_mirrors * 40),
|
||||||
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
|
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
|
||||||
total_plugins=0,
|
total_plugins=0,
|
||||||
loaded_plugins=0
|
loaded_plugins=0,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"推送进度失败: {e}")
|
logger.warning(f"推送进度失败: {e}")
|
||||||
|
|
||||||
# 所有镜像源都失败
|
# 所有镜像源都失败
|
||||||
return {
|
return {"success": False, "error": "所有镜像源均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
|
||||||
"success": False,
|
|
||||||
"error": "所有镜像源均失败",
|
|
||||||
"mirror_used": None,
|
|
||||||
"attempts": len(mirrors_to_try)
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _fetch_raw_from_mirror(
|
async def _fetch_raw_from_mirror(
|
||||||
self,
|
self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
|
||||||
owner: str,
|
|
||||||
repo: str,
|
|
||||||
branch: str,
|
|
||||||
file_path: str,
|
|
||||||
mirror: Dict[str, Any]
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""从指定镜像源获取文件"""
|
"""从指定镜像源获取文件"""
|
||||||
# 构建 URL
|
# 构建 URL
|
||||||
@@ -508,7 +468,7 @@ class GitMirrorService:
|
|||||||
"data": response.text,
|
"data": response.text,
|
||||||
"mirror_used": mirror_type,
|
"mirror_used": mirror_type,
|
||||||
"attempts": attempts,
|
"attempts": attempts,
|
||||||
"url": url
|
"url": url,
|
||||||
}
|
}
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
last_error = f"HTTP {e.response.status_code}: {e}"
|
last_error = f"HTTP {e.response.status_code}: {e}"
|
||||||
@@ -520,13 +480,7 @@ class GitMirrorService:
|
|||||||
last_error = f"未知错误: {e}"
|
last_error = f"未知错误: {e}"
|
||||||
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||||
|
|
||||||
return {
|
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
|
||||||
"success": False,
|
|
||||||
"error": last_error,
|
|
||||||
"mirror_used": mirror_type,
|
|
||||||
"attempts": attempts,
|
|
||||||
"url": url
|
|
||||||
}
|
|
||||||
|
|
||||||
async def clone_repository(
|
async def clone_repository(
|
||||||
self,
|
self,
|
||||||
@@ -536,7 +490,7 @@ class GitMirrorService:
|
|||||||
branch: Optional[str] = None,
|
branch: Optional[str] = None,
|
||||||
mirror_id: Optional[str] = None,
|
mirror_id: Optional[str] = None,
|
||||||
custom_url: Optional[str] = None,
|
custom_url: Optional[str] = None,
|
||||||
depth: Optional[int] = None
|
depth: Optional[int] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
克隆 GitHub 仓库
|
克隆 GitHub 仓库
|
||||||
@@ -569,12 +523,7 @@ class GitMirrorService:
|
|||||||
# 使用指定的镜像源
|
# 使用指定的镜像源
|
||||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||||
if not mirror:
|
if not mirror:
|
||||||
return {
|
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||||
"success": False,
|
|
||||||
"error": f"未找到镜像源: {mirror_id}",
|
|
||||||
"mirror_used": None,
|
|
||||||
"attempts": 0
|
|
||||||
}
|
|
||||||
mirrors_to_try = [mirror]
|
mirrors_to_try = [mirror]
|
||||||
else:
|
else:
|
||||||
# 使用所有启用的镜像源
|
# 使用所有启用的镜像源
|
||||||
@@ -582,20 +531,13 @@ class GitMirrorService:
|
|||||||
|
|
||||||
# 依次尝试每个镜像源
|
# 依次尝试每个镜像源
|
||||||
for mirror in mirrors_to_try:
|
for mirror in mirrors_to_try:
|
||||||
result = await self._clone_from_mirror(
|
result = await self._clone_from_mirror(owner, repo, target_path, branch, depth, mirror)
|
||||||
owner, repo, target_path, branch, depth, mirror
|
|
||||||
)
|
|
||||||
if result["success"]:
|
if result["success"]:
|
||||||
return result
|
return result
|
||||||
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
|
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
|
||||||
|
|
||||||
# 所有镜像源都失败
|
# 所有镜像源都失败
|
||||||
return {
|
return {"success": False, "error": "所有镜像源克隆均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
|
||||||
"success": False,
|
|
||||||
"error": "所有镜像源克隆均失败",
|
|
||||||
"mirror_used": None,
|
|
||||||
"attempts": len(mirrors_to_try)
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _clone_from_mirror(
|
async def _clone_from_mirror(
|
||||||
self,
|
self,
|
||||||
@@ -604,7 +546,7 @@ class GitMirrorService:
|
|||||||
target_path: Path,
|
target_path: Path,
|
||||||
branch: Optional[str],
|
branch: Optional[str],
|
||||||
depth: Optional[int],
|
depth: Optional[int],
|
||||||
mirror: Dict[str, Any]
|
mirror: Dict[str, Any],
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""从指定镜像源克隆仓库"""
|
"""从指定镜像源克隆仓库"""
|
||||||
# 构建克隆 URL
|
# 构建克隆 URL
|
||||||
@@ -614,12 +556,7 @@ class GitMirrorService:
|
|||||||
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
|
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
|
||||||
|
|
||||||
async def _clone_with_url(
|
async def _clone_with_url(
|
||||||
self,
|
self, url: str, target_path: Path, branch: Optional[str], depth: Optional[int], mirror_type: str
|
||||||
url: str,
|
|
||||||
target_path: Path,
|
|
||||||
branch: Optional[str],
|
|
||||||
depth: Optional[int],
|
|
||||||
mirror_type: str
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""使用指定 URL 克隆仓库,支持重试"""
|
"""使用指定 URL 克隆仓库,支持重试"""
|
||||||
attempts = 0
|
attempts = 0
|
||||||
@@ -657,7 +594,7 @@ class GitMirrorService:
|
|||||||
stage="loading",
|
stage="loading",
|
||||||
progress=20 + attempt * 10,
|
progress=20 + attempt * 10,
|
||||||
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
|
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
|
||||||
operation="install"
|
operation="install",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"推送进度失败: {e}")
|
logger.warning(f"推送进度失败: {e}")
|
||||||
@@ -670,7 +607,7 @@ class GitMirrorService:
|
|||||||
cmd,
|
cmd,
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
text=True,
|
text=True,
|
||||||
timeout=300 # 5分钟超时
|
timeout=300, # 5分钟超时
|
||||||
)
|
)
|
||||||
|
|
||||||
process = await loop.run_in_executor(None, run_git_clone)
|
process = await loop.run_in_executor(None, run_git_clone)
|
||||||
@@ -683,7 +620,7 @@ class GitMirrorService:
|
|||||||
"mirror_used": mirror_type,
|
"mirror_used": mirror_type,
|
||||||
"attempts": attempts,
|
"attempts": attempts,
|
||||||
"url": url,
|
"url": url,
|
||||||
"branch": branch or "default"
|
"branch": branch or "default",
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
last_error = f"Git 克隆失败: {process.stderr}"
|
last_error = f"Git 克隆失败: {process.stderr}"
|
||||||
@@ -710,13 +647,7 @@ class GitMirrorService:
|
|||||||
if target_path.exists():
|
if target_path.exists():
|
||||||
shutil.rmtree(target_path, ignore_errors=True)
|
shutil.rmtree(target_path, ignore_errors=True)
|
||||||
|
|
||||||
return {
|
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
|
||||||
"success": False,
|
|
||||||
"error": last_error,
|
|
||||||
"mirror_used": mirror_type,
|
|
||||||
"attempts": attempts,
|
|
||||||
"url": url
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# 全局服务实例
|
# 全局服务实例
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""WebSocket 日志推送模块"""
|
"""WebSocket 日志推送模块"""
|
||||||
|
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||||
from typing import Set
|
from typing import Set
|
||||||
import json
|
import json
|
||||||
@@ -49,7 +50,9 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
|
|||||||
log_entry = json.loads(line.strip())
|
log_entry = json.loads(line.strip())
|
||||||
# 转换为前端期望的格式
|
# 转换为前端期望的格式
|
||||||
# 使用时间戳 + 计数器生成唯一 ID
|
# 使用时间戳 + 计数器生成唯一 ID
|
||||||
timestamp_id = log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
|
timestamp_id = (
|
||||||
|
log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
|
||||||
|
)
|
||||||
formatted_log = {
|
formatted_log = {
|
||||||
"id": f"{timestamp_id}_{log_counter}",
|
"id": f"{timestamp_id}_{log_counter}",
|
||||||
"timestamp": log_entry.get("timestamp", ""),
|
"timestamp": log_entry.get("timestamp", ""),
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""WebUI 管理器 - 处理开发/生产环境的 WebUI 启动"""
|
"""WebUI 管理器 - 处理开发/生产环境的 WebUI 启动"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -55,10 +56,10 @@ def setup_production_mode() -> bool:
|
|||||||
|
|
||||||
# 确保正确的 MIME 类型映射
|
# 确保正确的 MIME 类型映射
|
||||||
mimetypes.init()
|
mimetypes.init()
|
||||||
mimetypes.add_type('application/javascript', '.js')
|
mimetypes.add_type("application/javascript", ".js")
|
||||||
mimetypes.add_type('application/javascript', '.mjs')
|
mimetypes.add_type("application/javascript", ".mjs")
|
||||||
mimetypes.add_type('text/css', '.css')
|
mimetypes.add_type("text/css", ".css")
|
||||||
mimetypes.add_type('application/json', '.json')
|
mimetypes.add_type("application/json", ".json")
|
||||||
|
|
||||||
server = get_global_server()
|
server = get_global_server()
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""人物信息管理 API 路由"""
|
"""人物信息管理 API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Header, Query
|
from fastapi import APIRouter, HTTPException, Header, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional, List, Dict
|
from typing import Optional, List, Dict
|
||||||
@@ -16,6 +17,7 @@ router = APIRouter(prefix="/person", tags=["Person"])
|
|||||||
|
|
||||||
class PersonInfoResponse(BaseModel):
|
class PersonInfoResponse(BaseModel):
|
||||||
"""人物信息响应"""
|
"""人物信息响应"""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
is_known: bool
|
is_known: bool
|
||||||
person_id: str
|
person_id: str
|
||||||
@@ -33,6 +35,7 @@ class PersonInfoResponse(BaseModel):
|
|||||||
|
|
||||||
class PersonListResponse(BaseModel):
|
class PersonListResponse(BaseModel):
|
||||||
"""人物列表响应"""
|
"""人物列表响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
total: int
|
total: int
|
||||||
page: int
|
page: int
|
||||||
@@ -42,12 +45,14 @@ class PersonListResponse(BaseModel):
|
|||||||
|
|
||||||
class PersonDetailResponse(BaseModel):
|
class PersonDetailResponse(BaseModel):
|
||||||
"""人物详情响应"""
|
"""人物详情响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
data: PersonInfoResponse
|
data: PersonInfoResponse
|
||||||
|
|
||||||
|
|
||||||
class PersonUpdateRequest(BaseModel):
|
class PersonUpdateRequest(BaseModel):
|
||||||
"""人物信息更新请求"""
|
"""人物信息更新请求"""
|
||||||
|
|
||||||
person_name: Optional[str] = None
|
person_name: Optional[str] = None
|
||||||
name_reason: Optional[str] = None
|
name_reason: Optional[str] = None
|
||||||
nickname: Optional[str] = None
|
nickname: Optional[str] = None
|
||||||
@@ -57,6 +62,7 @@ class PersonUpdateRequest(BaseModel):
|
|||||||
|
|
||||||
class PersonUpdateResponse(BaseModel):
|
class PersonUpdateResponse(BaseModel):
|
||||||
"""人物信息更新响应"""
|
"""人物信息更新响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
data: Optional[PersonInfoResponse] = None
|
data: Optional[PersonInfoResponse] = None
|
||||||
@@ -64,6 +70,7 @@ class PersonUpdateResponse(BaseModel):
|
|||||||
|
|
||||||
class PersonDeleteResponse(BaseModel):
|
class PersonDeleteResponse(BaseModel):
|
||||||
"""人物删除响应"""
|
"""人物删除响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
@@ -118,7 +125,7 @@ async def get_person_list(
|
|||||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||||
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
|
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
|
||||||
platform: Optional[str] = Query(None, description="平台筛选"),
|
platform: Optional[str] = Query(None, description="平台筛选"),
|
||||||
authorization: Optional[str] = Header(None)
|
authorization: Optional[str] = Header(None),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取人物信息列表
|
获取人物信息列表
|
||||||
@@ -143,9 +150,9 @@ async def get_person_list(
|
|||||||
# 搜索过滤
|
# 搜索过滤
|
||||||
if search:
|
if search:
|
||||||
query = query.where(
|
query = query.where(
|
||||||
(PersonInfo.person_name.contains(search)) |
|
(PersonInfo.person_name.contains(search))
|
||||||
(PersonInfo.nickname.contains(search)) |
|
| (PersonInfo.nickname.contains(search))
|
||||||
(PersonInfo.user_id.contains(search))
|
| (PersonInfo.user_id.contains(search))
|
||||||
)
|
)
|
||||||
|
|
||||||
# 已认识状态过滤
|
# 已认识状态过滤
|
||||||
@@ -159,10 +166,8 @@ async def get_person_list(
|
|||||||
# 排序:最后更新时间倒序(NULL 值放在最后)
|
# 排序:最后更新时间倒序(NULL 值放在最后)
|
||||||
# Peewee 不支持 nulls_last,使用 CASE WHEN 来实现
|
# Peewee 不支持 nulls_last,使用 CASE WHEN 来实现
|
||||||
from peewee import Case
|
from peewee import Case
|
||||||
query = query.order_by(
|
|
||||||
Case(None, [(PersonInfo.last_know.is_null(), 1)], 0),
|
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
|
||||||
PersonInfo.last_know.desc()
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取总数
|
# 获取总数
|
||||||
total = query.count()
|
total = query.count()
|
||||||
@@ -174,13 +179,7 @@ async def get_person_list(
|
|||||||
# 转换为响应对象
|
# 转换为响应对象
|
||||||
data = [person_to_response(person) for person in persons]
|
data = [person_to_response(person) for person in persons]
|
||||||
|
|
||||||
return PersonListResponse(
|
return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||||
success=True,
|
|
||||||
total=total,
|
|
||||||
page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
data=data
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -190,10 +189,7 @@ async def get_person_list(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
||||||
async def get_person_detail(
|
async def get_person_detail(person_id: str, authorization: Optional[str] = Header(None)):
|
||||||
person_id: str,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
获取人物详细信息
|
获取人物详细信息
|
||||||
|
|
||||||
@@ -212,10 +208,7 @@ async def get_person_detail(
|
|||||||
if not person:
|
if not person:
|
||||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||||
|
|
||||||
return PersonDetailResponse(
|
return PersonDetailResponse(success=True, data=person_to_response(person))
|
||||||
success=True,
|
|
||||||
data=person_to_response(person)
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -225,11 +218,7 @@ async def get_person_detail(
|
|||||||
|
|
||||||
|
|
||||||
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
|
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
|
||||||
async def update_person(
|
async def update_person(person_id: str, request: PersonUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||||
person_id: str,
|
|
||||||
request: PersonUpdateRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
增量更新人物信息(只更新提供的字段)
|
增量更新人物信息(只更新提供的字段)
|
||||||
|
|
||||||
@@ -256,7 +245,7 @@ async def update_person(
|
|||||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||||
|
|
||||||
# 更新最后修改时间
|
# 更新最后修改时间
|
||||||
update_data['last_know'] = time.time()
|
update_data["last_know"] = time.time()
|
||||||
|
|
||||||
# 执行更新
|
# 执行更新
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
@@ -267,9 +256,7 @@ async def update_person(
|
|||||||
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
|
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
|
||||||
|
|
||||||
return PersonUpdateResponse(
|
return PersonUpdateResponse(
|
||||||
success=True,
|
success=True, message=f"成功更新 {len(update_data)} 个字段", data=person_to_response(person)
|
||||||
message=f"成功更新 {len(update_data)} 个字段",
|
|
||||||
data=person_to_response(person)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -280,10 +267,7 @@ async def update_person(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
||||||
async def delete_person(
|
async def delete_person(person_id: str, authorization: Optional[str] = Header(None)):
|
||||||
person_id: str,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
删除人物信息
|
删除人物信息
|
||||||
|
|
||||||
@@ -310,10 +294,7 @@ async def delete_person(
|
|||||||
|
|
||||||
logger.info(f"人物信息已删除: {person_id} ({person_name})")
|
logger.info(f"人物信息已删除: {person_id} ({person_name})")
|
||||||
|
|
||||||
return PersonDeleteResponse(
|
return PersonDeleteResponse(success=True, message=f"成功删除人物信息: {person_name}")
|
||||||
success=True,
|
|
||||||
message=f"成功删除人物信息: {person_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -323,9 +304,7 @@ async def delete_person(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/stats/summary")
|
@router.get("/stats/summary")
|
||||||
async def get_person_stats(
|
async def get_person_stats(authorization: Optional[str] = Header(None)):
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
获取人物信息统计数据
|
获取人物信息统计数据
|
||||||
|
|
||||||
@@ -348,15 +327,7 @@ async def get_person_stats(
|
|||||||
platform = person.platform
|
platform = person.platform
|
||||||
platforms[platform] = platforms.get(platform, 0) + 1
|
platforms[platform] = platforms.get(platform, 0) + 1
|
||||||
|
|
||||||
return {
|
return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}}
|
||||||
"success": True,
|
|
||||||
"data": {
|
|
||||||
"total": total,
|
|
||||||
"known": known,
|
|
||||||
"unknown": unknown,
|
|
||||||
"platforms": platforms
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""WebSocket 插件加载进度推送模块"""
|
"""WebSocket 插件加载进度推送模块"""
|
||||||
|
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||||
from typing import Set, Dict, Any
|
from typing import Set, Dict, Any
|
||||||
import json
|
import json
|
||||||
@@ -22,7 +23,7 @@ current_progress: Dict[str, Any] = {
|
|||||||
"error": None,
|
"error": None,
|
||||||
"plugin_id": None, # 当前操作的插件 ID
|
"plugin_id": None, # 当前操作的插件 ID
|
||||||
"total_plugins": 0,
|
"total_plugins": 0,
|
||||||
"loaded_plugins": 0
|
"loaded_plugins": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -57,7 +58,7 @@ async def update_progress(
|
|||||||
error: str = None,
|
error: str = None,
|
||||||
plugin_id: str = None,
|
plugin_id: str = None,
|
||||||
total_plugins: int = 0,
|
total_plugins: int = 0,
|
||||||
loaded_plugins: int = 0
|
loaded_plugins: int = 0,
|
||||||
):
|
):
|
||||||
"""更新并广播进度
|
"""更新并广播进度
|
||||||
|
|
||||||
@@ -80,7 +81,7 @@ async def update_progress(
|
|||||||
"plugin_id": plugin_id,
|
"plugin_id": plugin_id,
|
||||||
"total_plugins": total_plugins,
|
"total_plugins": total_plugins,
|
||||||
"loaded_plugins": loaded_plugins,
|
"loaded_plugins": loaded_plugins,
|
||||||
"timestamp": asyncio.get_event_loop().time()
|
"timestamp": asyncio.get_event_loop().time(),
|
||||||
}
|
}
|
||||||
|
|
||||||
await broadcast_progress(progress_data)
|
await broadcast_progress(progress_data)
|
||||||
|
|||||||
@@ -30,12 +30,12 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
|
|||||||
(major, minor, patch) 三元组
|
(major, minor, patch) 三元组
|
||||||
"""
|
"""
|
||||||
# 移除 snapshot 等后缀
|
# 移除 snapshot 等后缀
|
||||||
base_version = version_str.split('.snapshot')[0].split('.dev')[0].split('.alpha')[0].split('.beta')[0]
|
base_version = version_str.split(".snapshot")[0].split(".dev")[0].split(".alpha")[0].split(".beta")[0]
|
||||||
|
|
||||||
parts = base_version.split('.')
|
parts = base_version.split(".")
|
||||||
if len(parts) < 3:
|
if len(parts) < 3:
|
||||||
# 补齐到 3 位
|
# 补齐到 3 位
|
||||||
parts.extend(['0'] * (3 - len(parts)))
|
parts.extend(["0"] * (3 - len(parts)))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
major = int(parts[0])
|
major = int(parts[0])
|
||||||
@@ -49,8 +49,10 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
|
|||||||
|
|
||||||
# ============ 请求/响应模型 ============
|
# ============ 请求/响应模型 ============
|
||||||
|
|
||||||
|
|
||||||
class FetchRawFileRequest(BaseModel):
|
class FetchRawFileRequest(BaseModel):
|
||||||
"""获取 Raw 文件请求"""
|
"""获取 Raw 文件请求"""
|
||||||
|
|
||||||
owner: str = Field(..., description="仓库所有者", example="MaiM-with-u")
|
owner: str = Field(..., description="仓库所有者", example="MaiM-with-u")
|
||||||
repo: str = Field(..., description="仓库名称", example="plugin-repo")
|
repo: str = Field(..., description="仓库名称", example="plugin-repo")
|
||||||
branch: str = Field(..., description="分支名称", example="main")
|
branch: str = Field(..., description="分支名称", example="main")
|
||||||
@@ -61,6 +63,7 @@ class FetchRawFileRequest(BaseModel):
|
|||||||
|
|
||||||
class FetchRawFileResponse(BaseModel):
|
class FetchRawFileResponse(BaseModel):
|
||||||
"""获取 Raw 文件响应"""
|
"""获取 Raw 文件响应"""
|
||||||
|
|
||||||
success: bool = Field(..., description="是否成功")
|
success: bool = Field(..., description="是否成功")
|
||||||
data: Optional[str] = Field(None, description="文件内容")
|
data: Optional[str] = Field(None, description="文件内容")
|
||||||
error: Optional[str] = Field(None, description="错误信息")
|
error: Optional[str] = Field(None, description="错误信息")
|
||||||
@@ -71,6 +74,7 @@ class FetchRawFileResponse(BaseModel):
|
|||||||
|
|
||||||
class CloneRepositoryRequest(BaseModel):
|
class CloneRepositoryRequest(BaseModel):
|
||||||
"""克隆仓库请求"""
|
"""克隆仓库请求"""
|
||||||
|
|
||||||
owner: str = Field(..., description="仓库所有者", example="MaiM-with-u")
|
owner: str = Field(..., description="仓库所有者", example="MaiM-with-u")
|
||||||
repo: str = Field(..., description="仓库名称", example="plugin-repo")
|
repo: str = Field(..., description="仓库名称", example="plugin-repo")
|
||||||
target_path: str = Field(..., description="目标路径(相对于插件目录)")
|
target_path: str = Field(..., description="目标路径(相对于插件目录)")
|
||||||
@@ -82,6 +86,7 @@ class CloneRepositoryRequest(BaseModel):
|
|||||||
|
|
||||||
class CloneRepositoryResponse(BaseModel):
|
class CloneRepositoryResponse(BaseModel):
|
||||||
"""克隆仓库响应"""
|
"""克隆仓库响应"""
|
||||||
|
|
||||||
success: bool = Field(..., description="是否成功")
|
success: bool = Field(..., description="是否成功")
|
||||||
path: Optional[str] = Field(None, description="克隆路径")
|
path: Optional[str] = Field(None, description="克隆路径")
|
||||||
error: Optional[str] = Field(None, description="错误信息")
|
error: Optional[str] = Field(None, description="错误信息")
|
||||||
@@ -93,6 +98,7 @@ class CloneRepositoryResponse(BaseModel):
|
|||||||
|
|
||||||
class MirrorConfigResponse(BaseModel):
|
class MirrorConfigResponse(BaseModel):
|
||||||
"""镜像源配置响应"""
|
"""镜像源配置响应"""
|
||||||
|
|
||||||
id: str = Field(..., description="镜像源 ID")
|
id: str = Field(..., description="镜像源 ID")
|
||||||
name: str = Field(..., description="镜像源名称")
|
name: str = Field(..., description="镜像源名称")
|
||||||
raw_prefix: str = Field(..., description="Raw 文件前缀")
|
raw_prefix: str = Field(..., description="Raw 文件前缀")
|
||||||
@@ -103,12 +109,14 @@ class MirrorConfigResponse(BaseModel):
|
|||||||
|
|
||||||
class AvailableMirrorsResponse(BaseModel):
|
class AvailableMirrorsResponse(BaseModel):
|
||||||
"""可用镜像源列表响应"""
|
"""可用镜像源列表响应"""
|
||||||
|
|
||||||
mirrors: List[MirrorConfigResponse] = Field(..., description="镜像源列表")
|
mirrors: List[MirrorConfigResponse] = Field(..., description="镜像源列表")
|
||||||
default_priority: List[str] = Field(..., description="默认优先级顺序(ID 列表)")
|
default_priority: List[str] = Field(..., description="默认优先级顺序(ID 列表)")
|
||||||
|
|
||||||
|
|
||||||
class AddMirrorRequest(BaseModel):
|
class AddMirrorRequest(BaseModel):
|
||||||
"""添加镜像源请求"""
|
"""添加镜像源请求"""
|
||||||
|
|
||||||
id: str = Field(..., description="镜像源 ID", example="custom-mirror")
|
id: str = Field(..., description="镜像源 ID", example="custom-mirror")
|
||||||
name: str = Field(..., description="镜像源名称", example="自定义镜像源")
|
name: str = Field(..., description="镜像源名称", example="自定义镜像源")
|
||||||
raw_prefix: str = Field(..., description="Raw 文件前缀", example="https://example.com/raw")
|
raw_prefix: str = Field(..., description="Raw 文件前缀", example="https://example.com/raw")
|
||||||
@@ -119,6 +127,7 @@ class AddMirrorRequest(BaseModel):
|
|||||||
|
|
||||||
class UpdateMirrorRequest(BaseModel):
|
class UpdateMirrorRequest(BaseModel):
|
||||||
"""更新镜像源请求"""
|
"""更新镜像源请求"""
|
||||||
|
|
||||||
name: Optional[str] = Field(None, description="镜像源名称")
|
name: Optional[str] = Field(None, description="镜像源名称")
|
||||||
raw_prefix: Optional[str] = Field(None, description="Raw 文件前缀")
|
raw_prefix: Optional[str] = Field(None, description="Raw 文件前缀")
|
||||||
clone_prefix: Optional[str] = Field(None, description="克隆前缀")
|
clone_prefix: Optional[str] = Field(None, description="克隆前缀")
|
||||||
@@ -128,6 +137,7 @@ class UpdateMirrorRequest(BaseModel):
|
|||||||
|
|
||||||
class GitStatusResponse(BaseModel):
|
class GitStatusResponse(BaseModel):
|
||||||
"""Git 安装状态响应"""
|
"""Git 安装状态响应"""
|
||||||
|
|
||||||
installed: bool = Field(..., description="是否已安装 Git")
|
installed: bool = Field(..., description="是否已安装 Git")
|
||||||
version: Optional[str] = Field(None, description="Git 版本号")
|
version: Optional[str] = Field(None, description="Git 版本号")
|
||||||
path: Optional[str] = Field(None, description="Git 可执行文件路径")
|
path: Optional[str] = Field(None, description="Git 可执行文件路径")
|
||||||
@@ -136,6 +146,7 @@ class GitStatusResponse(BaseModel):
|
|||||||
|
|
||||||
class InstallPluginRequest(BaseModel):
|
class InstallPluginRequest(BaseModel):
|
||||||
"""安装插件请求"""
|
"""安装插件请求"""
|
||||||
|
|
||||||
plugin_id: str = Field(..., description="插件 ID")
|
plugin_id: str = Field(..., description="插件 ID")
|
||||||
repository_url: str = Field(..., description="插件仓库 URL")
|
repository_url: str = Field(..., description="插件仓库 URL")
|
||||||
branch: Optional[str] = Field("main", description="分支名称")
|
branch: Optional[str] = Field("main", description="分支名称")
|
||||||
@@ -144,6 +155,7 @@ class InstallPluginRequest(BaseModel):
|
|||||||
|
|
||||||
class VersionResponse(BaseModel):
|
class VersionResponse(BaseModel):
|
||||||
"""麦麦版本响应"""
|
"""麦麦版本响应"""
|
||||||
|
|
||||||
version: str = Field(..., description="麦麦版本号")
|
version: str = Field(..., description="麦麦版本号")
|
||||||
version_major: int = Field(..., description="主版本号")
|
version_major: int = Field(..., description="主版本号")
|
||||||
version_minor: int = Field(..., description="次版本号")
|
version_minor: int = Field(..., description="次版本号")
|
||||||
@@ -152,11 +164,13 @@ class VersionResponse(BaseModel):
|
|||||||
|
|
||||||
class UninstallPluginRequest(BaseModel):
|
class UninstallPluginRequest(BaseModel):
|
||||||
"""卸载插件请求"""
|
"""卸载插件请求"""
|
||||||
|
|
||||||
plugin_id: str = Field(..., description="插件 ID")
|
plugin_id: str = Field(..., description="插件 ID")
|
||||||
|
|
||||||
|
|
||||||
class UpdatePluginRequest(BaseModel):
|
class UpdatePluginRequest(BaseModel):
|
||||||
"""更新插件请求"""
|
"""更新插件请求"""
|
||||||
|
|
||||||
plugin_id: str = Field(..., description="插件 ID")
|
plugin_id: str = Field(..., description="插件 ID")
|
||||||
repository_url: str = Field(..., description="插件仓库 URL")
|
repository_url: str = Field(..., description="插件仓库 URL")
|
||||||
branch: Optional[str] = Field("main", description="分支名称")
|
branch: Optional[str] = Field("main", description="分支名称")
|
||||||
@@ -165,6 +179,7 @@ class UpdatePluginRequest(BaseModel):
|
|||||||
|
|
||||||
# ============ API 路由 ============
|
# ============ API 路由 ============
|
||||||
|
|
||||||
|
|
||||||
@router.get("/version", response_model=VersionResponse)
|
@router.get("/version", response_model=VersionResponse)
|
||||||
async def get_maimai_version() -> VersionResponse:
|
async def get_maimai_version() -> VersionResponse:
|
||||||
"""
|
"""
|
||||||
@@ -174,12 +189,7 @@ async def get_maimai_version() -> VersionResponse:
|
|||||||
"""
|
"""
|
||||||
major, minor, patch = parse_version(MMC_VERSION)
|
major, minor, patch = parse_version(MMC_VERSION)
|
||||||
|
|
||||||
return VersionResponse(
|
return VersionResponse(version=MMC_VERSION, version_major=major, version_minor=minor, version_patch=patch)
|
||||||
version=MMC_VERSION,
|
|
||||||
version_major=major,
|
|
||||||
version_minor=minor,
|
|
||||||
version_patch=patch
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/git-status", response_model=GitStatusResponse)
|
@router.get("/git-status", response_model=GitStatusResponse)
|
||||||
@@ -196,9 +206,7 @@ async def check_git_status() -> GitStatusResponse:
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
|
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
|
||||||
async def get_available_mirrors(
|
async def get_available_mirrors(authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse:
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> AvailableMirrorsResponse:
|
|
||||||
"""
|
"""
|
||||||
获取所有可用的镜像源配置
|
获取所有可用的镜像源配置
|
||||||
"""
|
"""
|
||||||
@@ -219,22 +227,16 @@ async def get_available_mirrors(
|
|||||||
raw_prefix=m["raw_prefix"],
|
raw_prefix=m["raw_prefix"],
|
||||||
clone_prefix=m["clone_prefix"],
|
clone_prefix=m["clone_prefix"],
|
||||||
enabled=m["enabled"],
|
enabled=m["enabled"],
|
||||||
priority=m["priority"]
|
priority=m["priority"],
|
||||||
)
|
)
|
||||||
for m in all_mirrors
|
for m in all_mirrors
|
||||||
]
|
]
|
||||||
|
|
||||||
return AvailableMirrorsResponse(
|
return AvailableMirrorsResponse(mirrors=mirrors, default_priority=config.get_default_priority_list())
|
||||||
mirrors=mirrors,
|
|
||||||
default_priority=config.get_default_priority_list()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/mirrors", response_model=MirrorConfigResponse)
|
@router.post("/mirrors", response_model=MirrorConfigResponse)
|
||||||
async def add_mirror(
|
async def add_mirror(request: AddMirrorRequest, authorization: Optional[str] = Header(None)) -> MirrorConfigResponse:
|
||||||
request: AddMirrorRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> MirrorConfigResponse:
|
|
||||||
"""
|
"""
|
||||||
添加新的镜像源
|
添加新的镜像源
|
||||||
"""
|
"""
|
||||||
@@ -254,7 +256,7 @@ async def add_mirror(
|
|||||||
raw_prefix=request.raw_prefix,
|
raw_prefix=request.raw_prefix,
|
||||||
clone_prefix=request.clone_prefix,
|
clone_prefix=request.clone_prefix,
|
||||||
enabled=request.enabled,
|
enabled=request.enabled,
|
||||||
priority=request.priority
|
priority=request.priority,
|
||||||
)
|
)
|
||||||
|
|
||||||
return MirrorConfigResponse(
|
return MirrorConfigResponse(
|
||||||
@@ -263,7 +265,7 @@ async def add_mirror(
|
|||||||
raw_prefix=mirror["raw_prefix"],
|
raw_prefix=mirror["raw_prefix"],
|
||||||
clone_prefix=mirror["clone_prefix"],
|
clone_prefix=mirror["clone_prefix"],
|
||||||
enabled=mirror["enabled"],
|
enabled=mirror["enabled"],
|
||||||
priority=mirror["priority"]
|
priority=mirror["priority"],
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||||
@@ -274,9 +276,7 @@ async def add_mirror(
|
|||||||
|
|
||||||
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
|
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
|
||||||
async def update_mirror(
|
async def update_mirror(
|
||||||
mirror_id: str,
|
mirror_id: str, request: UpdateMirrorRequest, authorization: Optional[str] = Header(None)
|
||||||
request: UpdateMirrorRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> MirrorConfigResponse:
|
) -> MirrorConfigResponse:
|
||||||
"""
|
"""
|
||||||
更新镜像源配置
|
更新镜像源配置
|
||||||
@@ -297,7 +297,7 @@ async def update_mirror(
|
|||||||
raw_prefix=request.raw_prefix,
|
raw_prefix=request.raw_prefix,
|
||||||
clone_prefix=request.clone_prefix,
|
clone_prefix=request.clone_prefix,
|
||||||
enabled=request.enabled,
|
enabled=request.enabled,
|
||||||
priority=request.priority
|
priority=request.priority,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not mirror:
|
if not mirror:
|
||||||
@@ -309,7 +309,7 @@ async def update_mirror(
|
|||||||
raw_prefix=mirror["raw_prefix"],
|
raw_prefix=mirror["raw_prefix"],
|
||||||
clone_prefix=mirror["clone_prefix"],
|
clone_prefix=mirror["clone_prefix"],
|
||||||
enabled=mirror["enabled"],
|
enabled=mirror["enabled"],
|
||||||
priority=mirror["priority"]
|
priority=mirror["priority"],
|
||||||
)
|
)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -319,10 +319,7 @@ async def update_mirror(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/mirrors/{mirror_id}")
|
@router.delete("/mirrors/{mirror_id}")
|
||||||
async def delete_mirror(
|
async def delete_mirror(mirror_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||||
mirror_id: str,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
删除镜像源
|
删除镜像源
|
||||||
"""
|
"""
|
||||||
@@ -340,16 +337,12 @@ async def delete_mirror(
|
|||||||
if not success:
|
if not success:
|
||||||
raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}")
|
raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}")
|
||||||
|
|
||||||
return {
|
return {"success": True, "message": f"已删除镜像源: {mirror_id}"}
|
||||||
"success": True,
|
|
||||||
"message": f"已删除镜像源: {mirror_id}"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
|
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
|
||||||
async def fetch_raw_file(
|
async def fetch_raw_file(
|
||||||
request: FetchRawFileRequest,
|
request: FetchRawFileRequest, authorization: Optional[str] = Header(None)
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> FetchRawFileResponse:
|
) -> FetchRawFileResponse:
|
||||||
"""
|
"""
|
||||||
获取 GitHub 仓库的 Raw 文件内容
|
获取 GitHub 仓库的 Raw 文件内容
|
||||||
@@ -376,7 +369,7 @@ async def fetch_raw_file(
|
|||||||
progress=10,
|
progress=10,
|
||||||
message=f"正在获取插件列表: {request.file_path}",
|
message=f"正在获取插件列表: {request.file_path}",
|
||||||
total_plugins=0,
|
total_plugins=0,
|
||||||
loaded_plugins=0
|
loaded_plugins=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -389,22 +382,19 @@ async def fetch_raw_file(
|
|||||||
branch=request.branch,
|
branch=request.branch,
|
||||||
file_path=request.file_path,
|
file_path=request.file_path,
|
||||||
mirror_id=request.mirror_id,
|
mirror_id=request.mirror_id,
|
||||||
custom_url=request.custom_url
|
custom_url=request.custom_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
if result.get("success"):
|
if result.get("success"):
|
||||||
# 更新进度:成功获取
|
# 更新进度:成功获取
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading",
|
stage="loading", progress=70, message="正在解析插件数据...", total_plugins=0, loaded_plugins=0
|
||||||
progress=70,
|
|
||||||
message="正在解析插件数据...",
|
|
||||||
total_plugins=0,
|
|
||||||
loaded_plugins=0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 尝试解析插件数量
|
# 尝试解析插件数量
|
||||||
try:
|
try:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
data = json.loads(result.get("data", "[]"))
|
data = json.loads(result.get("data", "[]"))
|
||||||
total = len(data) if isinstance(data, list) else 0
|
total = len(data) if isinstance(data, list) else 0
|
||||||
|
|
||||||
@@ -414,16 +404,12 @@ async def fetch_raw_file(
|
|||||||
progress=100,
|
progress=100,
|
||||||
message=f"成功加载 {total} 个插件",
|
message=f"成功加载 {total} 个插件",
|
||||||
total_plugins=total,
|
total_plugins=total,
|
||||||
loaded_plugins=total
|
loaded_plugins=total,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
# 如果解析失败,仍然发送成功状态
|
# 如果解析失败,仍然发送成功状态
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="success",
|
stage="success", progress=100, message="加载完成", total_plugins=0, loaded_plugins=0
|
||||||
progress=100,
|
|
||||||
message="加载完成",
|
|
||||||
total_plugins=0,
|
|
||||||
loaded_plugins=0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return FetchRawFileResponse(**result)
|
return FetchRawFileResponse(**result)
|
||||||
@@ -433,12 +419,7 @@ async def fetch_raw_file(
|
|||||||
|
|
||||||
# 发送错误进度
|
# 发送错误进度
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="error",
|
stage="error", progress=0, message="加载失败", error=str(e), total_plugins=0, loaded_plugins=0
|
||||||
progress=0,
|
|
||||||
message="加载失败",
|
|
||||||
error=str(e),
|
|
||||||
total_plugins=0,
|
|
||||||
loaded_plugins=0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
@@ -446,8 +427,7 @@ async def fetch_raw_file(
|
|||||||
|
|
||||||
@router.post("/clone", response_model=CloneRepositoryResponse)
|
@router.post("/clone", response_model=CloneRepositoryResponse)
|
||||||
async def clone_repository(
|
async def clone_repository(
|
||||||
request: CloneRepositoryRequest,
|
request: CloneRepositoryRequest, authorization: Optional[str] = Header(None)
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> CloneRepositoryResponse:
|
) -> CloneRepositoryResponse:
|
||||||
"""
|
"""
|
||||||
克隆 GitHub 仓库到本地
|
克隆 GitHub 仓库到本地
|
||||||
@@ -460,9 +440,7 @@ async def clone_repository(
|
|||||||
if not token or not token_manager.verify_token(token):
|
if not token or not token_manager.verify_token(token):
|
||||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}")
|
||||||
f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# TODO: 验证 target_path 的安全性,防止路径遍历攻击
|
# TODO: 验证 target_path 的安全性,防止路径遍历攻击
|
||||||
@@ -478,7 +456,7 @@ async def clone_repository(
|
|||||||
branch=request.branch,
|
branch=request.branch,
|
||||||
mirror_id=request.mirror_id,
|
mirror_id=request.mirror_id,
|
||||||
custom_url=request.custom_url,
|
custom_url=request.custom_url,
|
||||||
depth=request.depth
|
depth=request.depth,
|
||||||
)
|
)
|
||||||
|
|
||||||
return CloneRepositoryResponse(**result)
|
return CloneRepositoryResponse(**result)
|
||||||
@@ -489,10 +467,7 @@ async def clone_repository(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/install")
|
@router.post("/install")
|
||||||
async def install_plugin(
|
async def install_plugin(request: InstallPluginRequest, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||||
request: InstallPluginRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
安装插件
|
安装插件
|
||||||
|
|
||||||
@@ -513,16 +488,16 @@ async def install_plugin(
|
|||||||
progress=5,
|
progress=5,
|
||||||
message=f"开始安装插件: {request.plugin_id}",
|
message=f"开始安装插件: {request.plugin_id}",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. 解析仓库 URL
|
# 1. 解析仓库 URL
|
||||||
# repository_url 格式: https://github.com/owner/repo
|
# repository_url 格式: https://github.com/owner/repo
|
||||||
repo_url = request.repository_url.rstrip('/')
|
repo_url = request.repository_url.rstrip("/")
|
||||||
if repo_url.endswith('.git'):
|
if repo_url.endswith(".git"):
|
||||||
repo_url = repo_url[:-4]
|
repo_url = repo_url[:-4]
|
||||||
|
|
||||||
parts = repo_url.split('/')
|
parts = repo_url.split("/")
|
||||||
if len(parts) < 2:
|
if len(parts) < 2:
|
||||||
raise HTTPException(status_code=400, detail="无效的仓库 URL")
|
raise HTTPException(status_code=400, detail="无效的仓库 URL")
|
||||||
|
|
||||||
@@ -534,7 +509,7 @@ async def install_plugin(
|
|||||||
progress=10,
|
progress=10,
|
||||||
message=f"解析仓库信息: {owner}/{repo}",
|
message=f"解析仓库信息: {owner}/{repo}",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 确定插件安装路径
|
# 2. 确定插件安装路径
|
||||||
@@ -548,10 +523,10 @@ async def install_plugin(
|
|||||||
await update_progress(
|
await update_progress(
|
||||||
stage="error",
|
stage="error",
|
||||||
progress=0,
|
progress=0,
|
||||||
message=f"插件已存在",
|
message="插件已存在",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error="插件已安装,请先卸载"
|
error="插件已安装,请先卸载",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail="插件已安装")
|
raise HTTPException(status_code=400, detail="插件已安装")
|
||||||
|
|
||||||
@@ -560,31 +535,26 @@ async def install_plugin(
|
|||||||
progress=15,
|
progress=15,
|
||||||
message=f"准备克隆到: {target_path}",
|
message=f"准备克隆到: {target_path}",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 克隆仓库(这里会自动推送 20%-80% 的进度)
|
# 3. 克隆仓库(这里会自动推送 20%-80% 的进度)
|
||||||
service = get_git_mirror_service()
|
service = get_git_mirror_service()
|
||||||
|
|
||||||
# 如果是 GitHub 仓库,使用镜像源
|
# 如果是 GitHub 仓库,使用镜像源
|
||||||
if 'github.com' in repo_url:
|
if "github.com" in repo_url:
|
||||||
result = await service.clone_repository(
|
result = await service.clone_repository(
|
||||||
owner=owner,
|
owner=owner,
|
||||||
repo=repo,
|
repo=repo,
|
||||||
target_path=target_path,
|
target_path=target_path,
|
||||||
branch=request.branch,
|
branch=request.branch,
|
||||||
mirror_id=request.mirror_id,
|
mirror_id=request.mirror_id,
|
||||||
depth=1 # 浅克隆,节省时间和空间
|
depth=1, # 浅克隆,节省时间和空间
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 自定义仓库,直接使用 URL
|
# 自定义仓库,直接使用 URL
|
||||||
result = await service.clone_repository(
|
result = await service.clone_repository(
|
||||||
owner=owner,
|
owner=owner, repo=repo, target_path=target_path, branch=request.branch, custom_url=repo_url, depth=1
|
||||||
repo=repo,
|
|
||||||
target_path=target_path,
|
|
||||||
branch=request.branch,
|
|
||||||
custom_url=repo_url,
|
|
||||||
depth=1
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result.get("success"):
|
if not result.get("success"):
|
||||||
@@ -595,23 +565,20 @@ async def install_plugin(
|
|||||||
message="克隆仓库失败",
|
message="克隆仓库失败",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error=error_msg
|
error=error_msg,
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=500, detail=error_msg)
|
raise HTTPException(status_code=500, detail=error_msg)
|
||||||
|
|
||||||
# 4. 验证插件完整性
|
# 4. 验证插件完整性
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading",
|
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=request.plugin_id
|
||||||
progress=85,
|
|
||||||
message="验证插件文件...",
|
|
||||||
operation="install",
|
|
||||||
plugin_id=request.plugin_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
manifest_path = target_path / "_manifest.json"
|
manifest_path = target_path / "_manifest.json"
|
||||||
if not manifest_path.exists():
|
if not manifest_path.exists():
|
||||||
# 清理失败的安装
|
# 清理失败的安装
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
shutil.rmtree(target_path, ignore_errors=True)
|
shutil.rmtree(target_path, ignore_errors=True)
|
||||||
|
|
||||||
await update_progress(
|
await update_progress(
|
||||||
@@ -620,26 +587,23 @@ async def install_plugin(
|
|||||||
message="插件缺少 _manifest.json",
|
message="插件缺少 _manifest.json",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error="无效的插件格式"
|
error="无效的插件格式",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||||
|
|
||||||
# 5. 读取并验证 manifest
|
# 5. 读取并验证 manifest
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading",
|
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=request.plugin_id
|
||||||
progress=90,
|
|
||||||
message="读取插件配置...",
|
|
||||||
operation="install",
|
|
||||||
plugin_id=request.plugin_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import json as json_module
|
import json as json_module
|
||||||
with open(manifest_path, 'r', encoding='utf-8') as f:
|
|
||||||
|
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||||
manifest = json_module.load(f)
|
manifest = json_module.load(f)
|
||||||
|
|
||||||
# 基本验证
|
# 基本验证
|
||||||
required_fields = ['manifest_version', 'name', 'version', 'author']
|
required_fields = ["manifest_version", "name", "version", "author"]
|
||||||
for field in required_fields:
|
for field in required_fields:
|
||||||
if field not in manifest:
|
if field not in manifest:
|
||||||
raise ValueError(f"缺少必需字段: {field}")
|
raise ValueError(f"缺少必需字段: {field}")
|
||||||
@@ -647,6 +611,7 @@ async def install_plugin(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 清理失败的安装
|
# 清理失败的安装
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
shutil.rmtree(target_path, ignore_errors=True)
|
shutil.rmtree(target_path, ignore_errors=True)
|
||||||
|
|
||||||
await update_progress(
|
await update_progress(
|
||||||
@@ -655,7 +620,7 @@ async def install_plugin(
|
|||||||
message="_manifest.json 无效",
|
message="_manifest.json 无效",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error=str(e)
|
error=str(e),
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||||
|
|
||||||
@@ -665,16 +630,16 @@ async def install_plugin(
|
|||||||
progress=100,
|
progress=100,
|
||||||
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
|
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": "插件安装成功",
|
"message": "插件安装成功",
|
||||||
"plugin_id": request.plugin_id,
|
"plugin_id": request.plugin_id,
|
||||||
"plugin_name": manifest['name'],
|
"plugin_name": manifest["name"],
|
||||||
"version": manifest['version'],
|
"version": manifest["version"],
|
||||||
"path": str(target_path)
|
"path": str(target_path),
|
||||||
}
|
}
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -688,7 +653,7 @@ async def install_plugin(
|
|||||||
message="安装失败",
|
message="安装失败",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error=str(e)
|
error=str(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
@@ -696,8 +661,7 @@ async def install_plugin(
|
|||||||
|
|
||||||
@router.post("/uninstall")
|
@router.post("/uninstall")
|
||||||
async def uninstall_plugin(
|
async def uninstall_plugin(
|
||||||
request: UninstallPluginRequest,
|
request: UninstallPluginRequest, authorization: Optional[str] = Header(None)
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
卸载插件
|
卸载插件
|
||||||
@@ -719,7 +683,7 @@ async def uninstall_plugin(
|
|||||||
progress=10,
|
progress=10,
|
||||||
message=f"开始卸载插件: {request.plugin_id}",
|
message=f"开始卸载插件: {request.plugin_id}",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. 检查插件是否存在
|
# 1. 检查插件是否存在
|
||||||
@@ -733,7 +697,7 @@ async def uninstall_plugin(
|
|||||||
message="插件不存在",
|
message="插件不存在",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error="插件未安装或已被删除"
|
error="插件未安装或已被删除",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=404, detail="插件未安装")
|
raise HTTPException(status_code=404, detail="插件未安装")
|
||||||
|
|
||||||
@@ -742,7 +706,7 @@ async def uninstall_plugin(
|
|||||||
progress=30,
|
progress=30,
|
||||||
message=f"正在删除插件文件: {plugin_path}",
|
message=f"正在删除插件文件: {plugin_path}",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 读取插件信息(用于日志)
|
# 2. 读取插件信息(用于日志)
|
||||||
@@ -752,7 +716,8 @@ async def uninstall_plugin(
|
|||||||
if manifest_path.exists():
|
if manifest_path.exists():
|
||||||
try:
|
try:
|
||||||
import json as json_module
|
import json as json_module
|
||||||
with open(manifest_path, 'r', encoding='utf-8') as f:
|
|
||||||
|
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||||
manifest = json_module.load(f)
|
manifest = json_module.load(f)
|
||||||
plugin_name = manifest.get("name", request.plugin_id)
|
plugin_name = manifest.get("name", request.plugin_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -763,7 +728,7 @@ async def uninstall_plugin(
|
|||||||
progress=50,
|
progress=50,
|
||||||
message=f"正在删除 {plugin_name}...",
|
message=f"正在删除 {plugin_name}...",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 删除插件目录
|
# 3. 删除插件目录
|
||||||
@@ -773,6 +738,7 @@ async def uninstall_plugin(
|
|||||||
def remove_readonly(func, path, _):
|
def remove_readonly(func, path, _):
|
||||||
"""清除只读属性并删除文件"""
|
"""清除只读属性并删除文件"""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.chmod(path, stat.S_IWRITE)
|
os.chmod(path, stat.S_IWRITE)
|
||||||
func(path)
|
func(path)
|
||||||
|
|
||||||
@@ -786,15 +752,10 @@ async def uninstall_plugin(
|
|||||||
progress=100,
|
progress=100,
|
||||||
message=f"成功卸载插件: {plugin_name}",
|
message=f"成功卸载插件: {plugin_name}",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {"success": True, "message": "插件卸载成功", "plugin_id": request.plugin_id, "plugin_name": plugin_name}
|
||||||
"success": True,
|
|
||||||
"message": "插件卸载成功",
|
|
||||||
"plugin_id": request.plugin_id,
|
|
||||||
"plugin_name": plugin_name
|
|
||||||
}
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -807,7 +768,7 @@ async def uninstall_plugin(
|
|||||||
message="卸载失败",
|
message="卸载失败",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error="权限不足,无法删除插件文件"
|
error="权限不足,无法删除插件文件",
|
||||||
)
|
)
|
||||||
|
|
||||||
raise HTTPException(status_code=500, detail="权限不足,无法删除插件文件") from e
|
raise HTTPException(status_code=500, detail="权限不足,无法删除插件文件") from e
|
||||||
@@ -820,17 +781,14 @@ async def uninstall_plugin(
|
|||||||
message="卸载失败",
|
message="卸载失败",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error=str(e)
|
error=str(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/update")
|
@router.post("/update")
|
||||||
async def update_plugin(
|
async def update_plugin(request: UpdatePluginRequest, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||||
request: UpdatePluginRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
更新插件
|
更新插件
|
||||||
|
|
||||||
@@ -851,7 +809,7 @@ async def update_plugin(
|
|||||||
progress=5,
|
progress=5,
|
||||||
message=f"开始更新插件: {request.plugin_id}",
|
message=f"开始更新插件: {request.plugin_id}",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. 检查插件是否已安装
|
# 1. 检查插件是否已安装
|
||||||
@@ -865,7 +823,7 @@ async def update_plugin(
|
|||||||
message="插件不存在",
|
message="插件不存在",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error="插件未安装,请先安装"
|
error="插件未安装,请先安装",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=404, detail="插件未安装")
|
raise HTTPException(status_code=404, detail="插件未安装")
|
||||||
|
|
||||||
@@ -877,10 +835,11 @@ async def update_plugin(
|
|||||||
if manifest_path.exists():
|
if manifest_path.exists():
|
||||||
try:
|
try:
|
||||||
import json as json_module
|
import json as json_module
|
||||||
with open(manifest_path, 'r', encoding='utf-8') as f:
|
|
||||||
|
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||||
manifest = json_module.load(f)
|
manifest = json_module.load(f)
|
||||||
old_version = manifest.get("version", "unknown")
|
old_version = manifest.get("version", "unknown")
|
||||||
plugin_name = manifest.get("name", request.plugin_id)
|
_plugin_name = manifest.get("name", request.plugin_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -889,16 +848,12 @@ async def update_plugin(
|
|||||||
progress=10,
|
progress=10,
|
||||||
message=f"当前版本: {old_version},准备更新...",
|
message=f"当前版本: {old_version},准备更新...",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 删除旧版本
|
# 3. 删除旧版本
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading",
|
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=request.plugin_id
|
||||||
progress=20,
|
|
||||||
message="正在删除旧版本...",
|
|
||||||
operation="update",
|
|
||||||
plugin_id=request.plugin_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
@@ -907,6 +862,7 @@ async def update_plugin(
|
|||||||
def remove_readonly(func, path, _):
|
def remove_readonly(func, path, _):
|
||||||
"""清除只读属性并删除文件"""
|
"""清除只读属性并删除文件"""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.chmod(path, stat.S_IWRITE)
|
os.chmod(path, stat.S_IWRITE)
|
||||||
func(path)
|
func(path)
|
||||||
|
|
||||||
@@ -920,14 +876,14 @@ async def update_plugin(
|
|||||||
progress=30,
|
progress=30,
|
||||||
message="正在准备下载新版本...",
|
message="正在准备下载新版本...",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
repo_url = request.repository_url.rstrip('/')
|
repo_url = request.repository_url.rstrip("/")
|
||||||
if repo_url.endswith('.git'):
|
if repo_url.endswith(".git"):
|
||||||
repo_url = repo_url[:-4]
|
repo_url = repo_url[:-4]
|
||||||
|
|
||||||
parts = repo_url.split('/')
|
parts = repo_url.split("/")
|
||||||
if len(parts) < 2:
|
if len(parts) < 2:
|
||||||
raise HTTPException(status_code=400, detail="无效的仓库 URL")
|
raise HTTPException(status_code=400, detail="无效的仓库 URL")
|
||||||
|
|
||||||
@@ -937,23 +893,18 @@ async def update_plugin(
|
|||||||
# 5. 克隆新版本(这里会推送 35%-85% 的进度)
|
# 5. 克隆新版本(这里会推送 35%-85% 的进度)
|
||||||
service = get_git_mirror_service()
|
service = get_git_mirror_service()
|
||||||
|
|
||||||
if 'github.com' in repo_url:
|
if "github.com" in repo_url:
|
||||||
result = await service.clone_repository(
|
result = await service.clone_repository(
|
||||||
owner=owner,
|
owner=owner,
|
||||||
repo=repo,
|
repo=repo,
|
||||||
target_path=plugin_path,
|
target_path=plugin_path,
|
||||||
branch=request.branch,
|
branch=request.branch,
|
||||||
mirror_id=request.mirror_id,
|
mirror_id=request.mirror_id,
|
||||||
depth=1
|
depth=1,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = await service.clone_repository(
|
result = await service.clone_repository(
|
||||||
owner=owner,
|
owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, custom_url=repo_url, depth=1
|
||||||
repo=repo,
|
|
||||||
target_path=plugin_path,
|
|
||||||
branch=request.branch,
|
|
||||||
custom_url=repo_url,
|
|
||||||
depth=1
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result.get("success"):
|
if not result.get("success"):
|
||||||
@@ -964,17 +915,13 @@ async def update_plugin(
|
|||||||
message="下载新版本失败",
|
message="下载新版本失败",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error=error_msg
|
error=error_msg,
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=500, detail=error_msg)
|
raise HTTPException(status_code=500, detail=error_msg)
|
||||||
|
|
||||||
# 6. 验证新版本
|
# 6. 验证新版本
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading",
|
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=request.plugin_id
|
||||||
progress=90,
|
|
||||||
message="验证新版本...",
|
|
||||||
operation="update",
|
|
||||||
plugin_id=request.plugin_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
new_manifest_path = plugin_path / "_manifest.json"
|
new_manifest_path = plugin_path / "_manifest.json"
|
||||||
@@ -983,6 +930,7 @@ async def update_plugin(
|
|||||||
def remove_readonly(func, path, _):
|
def remove_readonly(func, path, _):
|
||||||
"""清除只读属性并删除文件"""
|
"""清除只读属性并删除文件"""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.chmod(path, stat.S_IWRITE)
|
os.chmod(path, stat.S_IWRITE)
|
||||||
func(path)
|
func(path)
|
||||||
|
|
||||||
@@ -994,13 +942,13 @@ async def update_plugin(
|
|||||||
message="新版本缺少 _manifest.json",
|
message="新版本缺少 _manifest.json",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error="无效的插件格式"
|
error="无效的插件格式",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||||
|
|
||||||
# 7. 读取新版本信息
|
# 7. 读取新版本信息
|
||||||
try:
|
try:
|
||||||
with open(new_manifest_path, 'r', encoding='utf-8') as f:
|
with open(new_manifest_path, "r", encoding="utf-8") as f:
|
||||||
new_manifest = json_module.load(f)
|
new_manifest = json_module.load(f)
|
||||||
|
|
||||||
new_version = new_manifest.get("version", "unknown")
|
new_version = new_manifest.get("version", "unknown")
|
||||||
@@ -1014,7 +962,7 @@ async def update_plugin(
|
|||||||
progress=100,
|
progress=100,
|
||||||
message=f"成功更新 {new_name}: {old_version} → {new_version}",
|
message=f"成功更新 {new_name}: {old_version} → {new_version}",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id
|
plugin_id=request.plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -1023,7 +971,7 @@ async def update_plugin(
|
|||||||
"plugin_id": request.plugin_id,
|
"plugin_id": request.plugin_id,
|
||||||
"plugin_name": new_name,
|
"plugin_name": new_name,
|
||||||
"old_version": old_version,
|
"old_version": old_version,
|
||||||
"new_version": new_version
|
"new_version": new_version,
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1036,7 +984,7 @@ async def update_plugin(
|
|||||||
message="_manifest.json 无效",
|
message="_manifest.json 无效",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=request.plugin_id,
|
||||||
error=str(e)
|
error=str(e),
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||||
|
|
||||||
@@ -1046,21 +994,14 @@ async def update_plugin(
|
|||||||
logger.error(f"更新插件失败: {e}", exc_info=True)
|
logger.error(f"更新插件失败: {e}", exc_info=True)
|
||||||
|
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="error",
|
stage="error", progress=0, message="更新失败", operation="update", plugin_id=request.plugin_id, error=str(e)
|
||||||
progress=0,
|
|
||||||
message="更新失败",
|
|
||||||
operation="update",
|
|
||||||
plugin_id=request.plugin_id,
|
|
||||||
error=str(e)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.get("/installed")
|
@router.get("/installed")
|
||||||
async def get_installed_plugins(
|
async def get_installed_plugins(authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
获取已安装的插件列表
|
获取已安装的插件列表
|
||||||
|
|
||||||
@@ -1081,10 +1022,7 @@ async def get_installed_plugins(
|
|||||||
if not plugins_dir.exists():
|
if not plugins_dir.exists():
|
||||||
logger.info("插件目录不存在,创建目录")
|
logger.info("插件目录不存在,创建目录")
|
||||||
plugins_dir.mkdir(exist_ok=True)
|
plugins_dir.mkdir(exist_ok=True)
|
||||||
return {
|
return {"success": True, "plugins": []}
|
||||||
"success": True,
|
|
||||||
"plugins": []
|
|
||||||
}
|
|
||||||
|
|
||||||
installed_plugins = []
|
installed_plugins = []
|
||||||
|
|
||||||
@@ -1098,7 +1036,7 @@ async def get_installed_plugins(
|
|||||||
plugin_id = plugin_path.name
|
plugin_id = plugin_path.name
|
||||||
|
|
||||||
# 跳过隐藏目录和特殊目录
|
# 跳过隐藏目录和特殊目录
|
||||||
if plugin_id.startswith('.') or plugin_id.startswith('__'):
|
if plugin_id.startswith(".") or plugin_id.startswith("__"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 读取 _manifest.json
|
# 读取 _manifest.json
|
||||||
@@ -1110,20 +1048,23 @@ async def get_installed_plugins(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import json as json_module
|
import json as json_module
|
||||||
with open(manifest_path, 'r', encoding='utf-8') as f:
|
|
||||||
|
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||||
manifest = json_module.load(f)
|
manifest = json_module.load(f)
|
||||||
|
|
||||||
# 基本验证
|
# 基本验证
|
||||||
if 'name' not in manifest or 'version' not in manifest:
|
if "name" not in manifest or "version" not in manifest:
|
||||||
logger.warning(f"插件 {plugin_id} 的 _manifest.json 格式无效,跳过")
|
logger.warning(f"插件 {plugin_id} 的 _manifest.json 格式无效,跳过")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 添加到已安装列表(返回完整的 manifest 信息)
|
# 添加到已安装列表(返回完整的 manifest 信息)
|
||||||
installed_plugins.append({
|
installed_plugins.append(
|
||||||
|
{
|
||||||
"id": plugin_id,
|
"id": plugin_id,
|
||||||
"manifest": manifest, # 返回完整的 manifest 对象
|
"manifest": manifest, # 返回完整的 manifest 对象
|
||||||
"path": str(plugin_path.absolute())
|
"path": str(plugin_path.absolute()),
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.warning(f"插件 {plugin_id} 的 _manifest.json 解析失败: {e}")
|
logger.warning(f"插件 {plugin_id} 的 _manifest.json 解析失败: {e}")
|
||||||
@@ -1134,11 +1075,7 @@ async def get_installed_plugins(
|
|||||||
|
|
||||||
logger.info(f"找到 {len(installed_plugins)} 个已安装插件")
|
logger.info(f"找到 {len(installed_plugins)} 个已安装插件")
|
||||||
|
|
||||||
return {
|
return {"success": True, "plugins": installed_plugins, "total": len(installed_plugins)}
|
||||||
"success": True,
|
|
||||||
"plugins": installed_plugins,
|
|
||||||
"total": len(installed_plugins)
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取已安装插件列表失败: {e}", exc_info=True)
|
logger.error(f"获取已安装插件列表失败: {e}", exc_info=True)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
提供系统重启、状态查询等功能
|
提供系统重启、状态查询等功能
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
@@ -19,12 +20,14 @@ _start_time = time.time()
|
|||||||
|
|
||||||
class RestartResponse(BaseModel):
|
class RestartResponse(BaseModel):
|
||||||
"""重启响应"""
|
"""重启响应"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
class StatusResponse(BaseModel):
|
class StatusResponse(BaseModel):
|
||||||
"""状态响应"""
|
"""状态响应"""
|
||||||
|
|
||||||
running: bool
|
running: bool
|
||||||
uptime: float
|
uptime: float
|
||||||
version: str
|
version: str
|
||||||
@@ -52,15 +55,9 @@ async def restart_maibot():
|
|||||||
# 但我们仍然返回它以保持 API 一致性
|
# 但我们仍然返回它以保持 API 一致性
|
||||||
os.execv(python, args)
|
os.execv(python, args)
|
||||||
|
|
||||||
return RestartResponse(
|
return RestartResponse(success=True, message="麦麦正在重启中...")
|
||||||
success=True,
|
|
||||||
message="麦麦正在重启中..."
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=500, detail=f"重启失败: {str(e)}") from e
|
||||||
status_code=500,
|
|
||||||
detail=f"重启失败: {str(e)}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/status", response_model=StatusResponse)
|
@router.get("/status", response_model=StatusResponse)
|
||||||
@@ -77,20 +74,15 @@ async def get_maibot_status():
|
|||||||
version = MMC_VERSION # 可以从配置或常量中读取
|
version = MMC_VERSION # 可以从配置或常量中读取
|
||||||
|
|
||||||
return StatusResponse(
|
return StatusResponse(
|
||||||
running=True,
|
running=True, uptime=uptime, version=version, start_time=datetime.fromtimestamp(_start_time).isoformat()
|
||||||
uptime=uptime,
|
|
||||||
version=version,
|
|
||||||
start_time=datetime.fromtimestamp(_start_time).isoformat()
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") from e
|
||||||
status_code=500,
|
|
||||||
detail=f"获取状态失败: {str(e)}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
# 可选:添加更多系统控制功能
|
# 可选:添加更多系统控制功能
|
||||||
|
|
||||||
|
|
||||||
@router.post("/reload-config")
|
@router.post("/reload-config")
|
||||||
async def reload_config():
|
async def reload_config():
|
||||||
"""
|
"""
|
||||||
@@ -102,7 +94,4 @@ async def reload_config():
|
|||||||
# 这里需要调用主程序的配置重载函数
|
# 这里需要调用主程序的配置重载函数
|
||||||
# 示例:await app_instance.reload_config()
|
# 示例:await app_instance.reload_config()
|
||||||
|
|
||||||
return {
|
return {"success": True, "message": "配置重载功能待实现"}
|
||||||
"success": True,
|
|
||||||
"message": "配置重载功能待实现"
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""WebUI API 路由"""
|
"""WebUI API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Header
|
from fastapi import APIRouter, HTTPException, Header
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -38,28 +39,33 @@ router.include_router(system_router)
|
|||||||
|
|
||||||
class TokenVerifyRequest(BaseModel):
|
class TokenVerifyRequest(BaseModel):
|
||||||
"""Token 验证请求"""
|
"""Token 验证请求"""
|
||||||
|
|
||||||
token: str = Field(..., description="访问令牌")
|
token: str = Field(..., description="访问令牌")
|
||||||
|
|
||||||
|
|
||||||
class TokenVerifyResponse(BaseModel):
|
class TokenVerifyResponse(BaseModel):
|
||||||
"""Token 验证响应"""
|
"""Token 验证响应"""
|
||||||
|
|
||||||
valid: bool = Field(..., description="Token 是否有效")
|
valid: bool = Field(..., description="Token 是否有效")
|
||||||
message: str = Field(..., description="验证结果消息")
|
message: str = Field(..., description="验证结果消息")
|
||||||
|
|
||||||
|
|
||||||
class TokenUpdateRequest(BaseModel):
|
class TokenUpdateRequest(BaseModel):
|
||||||
"""Token 更新请求"""
|
"""Token 更新请求"""
|
||||||
|
|
||||||
new_token: str = Field(..., description="新的访问令牌", min_length=10)
|
new_token: str = Field(..., description="新的访问令牌", min_length=10)
|
||||||
|
|
||||||
|
|
||||||
class TokenUpdateResponse(BaseModel):
|
class TokenUpdateResponse(BaseModel):
|
||||||
"""Token 更新响应"""
|
"""Token 更新响应"""
|
||||||
|
|
||||||
success: bool = Field(..., description="是否更新成功")
|
success: bool = Field(..., description="是否更新成功")
|
||||||
message: str = Field(..., description="更新结果消息")
|
message: str = Field(..., description="更新结果消息")
|
||||||
|
|
||||||
|
|
||||||
class TokenRegenerateResponse(BaseModel):
|
class TokenRegenerateResponse(BaseModel):
|
||||||
"""Token 重新生成响应"""
|
"""Token 重新生成响应"""
|
||||||
|
|
||||||
success: bool = Field(..., description="是否生成成功")
|
success: bool = Field(..., description="是否生成成功")
|
||||||
token: str = Field(..., description="新生成的令牌")
|
token: str = Field(..., description="新生成的令牌")
|
||||||
message: str = Field(..., description="生成结果消息")
|
message: str = Field(..., description="生成结果消息")
|
||||||
@@ -67,18 +73,21 @@ class TokenRegenerateResponse(BaseModel):
|
|||||||
|
|
||||||
class FirstSetupStatusResponse(BaseModel):
|
class FirstSetupStatusResponse(BaseModel):
|
||||||
"""首次配置状态响应"""
|
"""首次配置状态响应"""
|
||||||
|
|
||||||
is_first_setup: bool = Field(..., description="是否为首次配置")
|
is_first_setup: bool = Field(..., description="是否为首次配置")
|
||||||
message: str = Field(..., description="状态消息")
|
message: str = Field(..., description="状态消息")
|
||||||
|
|
||||||
|
|
||||||
class CompleteSetupResponse(BaseModel):
|
class CompleteSetupResponse(BaseModel):
|
||||||
"""完成配置响应"""
|
"""完成配置响应"""
|
||||||
|
|
||||||
success: bool = Field(..., description="是否成功")
|
success: bool = Field(..., description="是否成功")
|
||||||
message: str = Field(..., description="结果消息")
|
message: str = Field(..., description="结果消息")
|
||||||
|
|
||||||
|
|
||||||
class ResetSetupResponse(BaseModel):
|
class ResetSetupResponse(BaseModel):
|
||||||
"""重置配置响应"""
|
"""重置配置响应"""
|
||||||
|
|
||||||
success: bool = Field(..., description="是否成功")
|
success: bool = Field(..., description="是否成功")
|
||||||
message: str = Field(..., description="结果消息")
|
message: str = Field(..., description="结果消息")
|
||||||
|
|
||||||
@@ -105,25 +114,16 @@ async def verify_token(request: TokenVerifyRequest):
|
|||||||
is_valid = token_manager.verify_token(request.token)
|
is_valid = token_manager.verify_token(request.token)
|
||||||
|
|
||||||
if is_valid:
|
if is_valid:
|
||||||
return TokenVerifyResponse(
|
return TokenVerifyResponse(valid=True, message="Token 验证成功")
|
||||||
valid=True,
|
|
||||||
message="Token 验证成功"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return TokenVerifyResponse(
|
return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
|
||||||
valid=False,
|
|
||||||
message="Token 无效或已过期"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Token 验证失败: {e}")
|
logger.error(f"Token 验证失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/auth/update", response_model=TokenUpdateResponse)
|
@router.post("/auth/update", response_model=TokenUpdateResponse)
|
||||||
async def update_token(
|
async def update_token(request: TokenUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||||
request: TokenUpdateRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
更新访问令牌(需要当前有效的 token)
|
更新访问令牌(需要当前有效的 token)
|
||||||
|
|
||||||
@@ -148,10 +148,7 @@ async def update_token(
|
|||||||
# 更新 token
|
# 更新 token
|
||||||
success, message = token_manager.update_token(request.new_token)
|
success, message = token_manager.update_token(request.new_token)
|
||||||
|
|
||||||
return TokenUpdateResponse(
|
return TokenUpdateResponse(success=success, message=message)
|
||||||
success=success,
|
|
||||||
message=message
|
|
||||||
)
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -184,11 +181,7 @@ async def regenerate_token(authorization: Optional[str] = Header(None)):
|
|||||||
# 重新生成 token
|
# 重新生成 token
|
||||||
new_token = token_manager.regenerate_token()
|
new_token = token_manager.regenerate_token()
|
||||||
|
|
||||||
return TokenRegenerateResponse(
|
return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成")
|
||||||
success=True,
|
|
||||||
token=new_token,
|
|
||||||
message="Token 已重新生成"
|
|
||||||
)
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -221,10 +214,7 @@ async def get_setup_status(authorization: Optional[str] = Header(None)):
|
|||||||
# 检查是否为首次配置
|
# 检查是否为首次配置
|
||||||
is_first = token_manager.is_first_setup()
|
is_first = token_manager.is_first_setup()
|
||||||
|
|
||||||
return FirstSetupStatusResponse(
|
return FirstSetupStatusResponse(is_first_setup=is_first, message="首次配置" if is_first else "已完成配置")
|
||||||
is_first_setup=is_first,
|
|
||||||
message="首次配置" if is_first else "已完成配置"
|
|
||||||
)
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -257,10 +247,7 @@ async def complete_setup(authorization: Optional[str] = Header(None)):
|
|||||||
# 标记配置完成
|
# 标记配置完成
|
||||||
success = token_manager.mark_setup_completed()
|
success = token_manager.mark_setup_completed()
|
||||||
|
|
||||||
return CompleteSetupResponse(
|
return CompleteSetupResponse(success=success, message="配置已完成" if success else "标记失败")
|
||||||
success=success,
|
|
||||||
message="配置已完成" if success else "标记失败"
|
|
||||||
)
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -293,10 +280,7 @@ async def reset_setup(authorization: Optional[str] = Header(None)):
|
|||||||
# 重置配置状态
|
# 重置配置状态
|
||||||
success = token_manager.reset_setup_status()
|
success = token_manager.reset_setup_status()
|
||||||
|
|
||||||
return ResetSetupResponse(
|
return ResetSetupResponse(success=success, message="配置状态已重置" if success else "重置失败")
|
||||||
success=success,
|
|
||||||
message="配置状态已重置" if success else "重置失败"
|
|
||||||
)
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""统计数据 API 路由"""
|
"""统计数据 API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Dict, Any, List
|
from typing import Dict, Any, List
|
||||||
@@ -15,6 +16,7 @@ router = APIRouter(prefix="/statistics", tags=["statistics"])
|
|||||||
|
|
||||||
class StatisticsSummary(BaseModel):
|
class StatisticsSummary(BaseModel):
|
||||||
"""统计数据摘要"""
|
"""统计数据摘要"""
|
||||||
|
|
||||||
total_requests: int = Field(0, description="总请求数")
|
total_requests: int = Field(0, description="总请求数")
|
||||||
total_cost: float = Field(0.0, description="总花费")
|
total_cost: float = Field(0.0, description="总花费")
|
||||||
total_tokens: int = Field(0, description="总token数")
|
total_tokens: int = Field(0, description="总token数")
|
||||||
@@ -28,6 +30,7 @@ class StatisticsSummary(BaseModel):
|
|||||||
|
|
||||||
class ModelStatistics(BaseModel):
|
class ModelStatistics(BaseModel):
|
||||||
"""模型统计"""
|
"""模型统计"""
|
||||||
|
|
||||||
model_name: str
|
model_name: str
|
||||||
request_count: int
|
request_count: int
|
||||||
total_cost: float
|
total_cost: float
|
||||||
@@ -37,6 +40,7 @@ class ModelStatistics(BaseModel):
|
|||||||
|
|
||||||
class TimeSeriesData(BaseModel):
|
class TimeSeriesData(BaseModel):
|
||||||
"""时间序列数据"""
|
"""时间序列数据"""
|
||||||
|
|
||||||
timestamp: str
|
timestamp: str
|
||||||
requests: int = 0
|
requests: int = 0
|
||||||
cost: float = 0.0
|
cost: float = 0.0
|
||||||
@@ -45,6 +49,7 @@ class TimeSeriesData(BaseModel):
|
|||||||
|
|
||||||
class DashboardData(BaseModel):
|
class DashboardData(BaseModel):
|
||||||
"""仪表盘数据"""
|
"""仪表盘数据"""
|
||||||
|
|
||||||
summary: StatisticsSummary
|
summary: StatisticsSummary
|
||||||
model_stats: List[ModelStatistics]
|
model_stats: List[ModelStatistics]
|
||||||
hourly_data: List[TimeSeriesData]
|
hourly_data: List[TimeSeriesData]
|
||||||
@@ -88,7 +93,7 @@ async def get_dashboard_data(hours: int = 24):
|
|||||||
model_stats=model_stats,
|
model_stats=model_stats,
|
||||||
hourly_data=hourly_data,
|
hourly_data=hourly_data,
|
||||||
daily_data=daily_data,
|
daily_data=daily_data,
|
||||||
recent_activity=recent_activity
|
recent_activity=recent_activity,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取仪表盘数据失败: {e}")
|
logger.error(f"获取仪表盘数据失败: {e}")
|
||||||
@@ -100,11 +105,7 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
|
|||||||
summary = StatisticsSummary()
|
summary = StatisticsSummary()
|
||||||
|
|
||||||
# 查询 LLM 使用记录
|
# 查询 LLM 使用记录
|
||||||
llm_records = list(
|
llm_records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time))
|
||||||
LLMUsage.select()
|
|
||||||
.where(LLMUsage.timestamp >= start_time)
|
|
||||||
.where(LLMUsage.timestamp <= end_time)
|
|
||||||
)
|
|
||||||
|
|
||||||
total_time_cost = 0.0
|
total_time_cost = 0.0
|
||||||
time_cost_count = 0
|
time_cost_count = 0
|
||||||
@@ -124,11 +125,7 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
|
|||||||
|
|
||||||
# 查询在线时间
|
# 查询在线时间
|
||||||
online_records = list(
|
online_records = list(
|
||||||
OnlineTime.select()
|
OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time))
|
||||||
.where(
|
|
||||||
(OnlineTime.start_timestamp >= start_time) |
|
|
||||||
(OnlineTime.end_timestamp >= start_time)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for record in online_records:
|
for record in online_records:
|
||||||
@@ -139,9 +136,7 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
|
|||||||
|
|
||||||
# 查询消息数量
|
# 查询消息数量
|
||||||
messages = list(
|
messages = list(
|
||||||
Messages.select()
|
Messages.select().where(Messages.time >= start_time.timestamp()).where(Messages.time <= end_time.timestamp())
|
||||||
.where(Messages.time >= start_time.timestamp())
|
|
||||||
.where(Messages.time <= end_time.timestamp())
|
|
||||||
)
|
)
|
||||||
|
|
||||||
summary.total_messages = len(messages)
|
summary.total_messages = len(messages)
|
||||||
@@ -159,38 +154,32 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
|
|||||||
|
|
||||||
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
|
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
|
||||||
"""获取模型统计数据"""
|
"""获取模型统计数据"""
|
||||||
model_data = defaultdict(lambda: {
|
model_data = defaultdict(lambda: {"request_count": 0, "total_cost": 0.0, "total_tokens": 0, "time_costs": []})
|
||||||
'request_count': 0,
|
|
||||||
'total_cost': 0.0,
|
|
||||||
'total_tokens': 0,
|
|
||||||
'time_costs': []
|
|
||||||
})
|
|
||||||
|
|
||||||
records = list(
|
records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time))
|
||||||
LLMUsage.select()
|
|
||||||
.where(LLMUsage.timestamp >= start_time)
|
|
||||||
)
|
|
||||||
|
|
||||||
for record in records:
|
for record in records:
|
||||||
model_name = record.model_assign_name or record.model_name or "unknown"
|
model_name = record.model_assign_name or record.model_name or "unknown"
|
||||||
model_data[model_name]['request_count'] += 1
|
model_data[model_name]["request_count"] += 1
|
||||||
model_data[model_name]['total_cost'] += record.cost or 0.0
|
model_data[model_name]["total_cost"] += record.cost or 0.0
|
||||||
model_data[model_name]['total_tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
model_data[model_name]["total_tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||||
|
|
||||||
if record.time_cost and record.time_cost > 0:
|
if record.time_cost and record.time_cost > 0:
|
||||||
model_data[model_name]['time_costs'].append(record.time_cost)
|
model_data[model_name]["time_costs"].append(record.time_cost)
|
||||||
|
|
||||||
# 转换为列表并排序
|
# 转换为列表并排序
|
||||||
result = []
|
result = []
|
||||||
for model_name, data in model_data.items():
|
for model_name, data in model_data.items():
|
||||||
avg_time = sum(data['time_costs']) / len(data['time_costs']) if data['time_costs'] else 0.0
|
avg_time = sum(data["time_costs"]) / len(data["time_costs"]) if data["time_costs"] else 0.0
|
||||||
result.append(ModelStatistics(
|
result.append(
|
||||||
|
ModelStatistics(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
request_count=data['request_count'],
|
request_count=data["request_count"],
|
||||||
total_cost=data['total_cost'],
|
total_cost=data["total_cost"],
|
||||||
total_tokens=data['total_tokens'],
|
total_tokens=data["total_tokens"],
|
||||||
avg_response_time=avg_time
|
avg_response_time=avg_time,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# 按请求数排序
|
# 按请求数排序
|
||||||
result.sort(key=lambda x: x.request_count, reverse=True)
|
result.sort(key=lambda x: x.request_count, reverse=True)
|
||||||
@@ -200,35 +189,28 @@ async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
|
|||||||
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||||
"""获取小时级统计数据"""
|
"""获取小时级统计数据"""
|
||||||
# 创建小时桶
|
# 创建小时桶
|
||||||
hourly_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
hourly_buckets = defaultdict(lambda: {"requests": 0, "cost": 0.0, "tokens": 0})
|
||||||
|
|
||||||
records = list(
|
records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time))
|
||||||
LLMUsage.select()
|
|
||||||
.where(LLMUsage.timestamp >= start_time)
|
|
||||||
.where(LLMUsage.timestamp <= end_time)
|
|
||||||
)
|
|
||||||
|
|
||||||
for record in records:
|
for record in records:
|
||||||
# 获取小时键(去掉分钟和秒)
|
# 获取小时键(去掉分钟和秒)
|
||||||
hour_key = record.timestamp.replace(minute=0, second=0, microsecond=0)
|
hour_key = record.timestamp.replace(minute=0, second=0, microsecond=0)
|
||||||
hour_str = hour_key.isoformat()
|
hour_str = hour_key.isoformat()
|
||||||
|
|
||||||
hourly_buckets[hour_str]['requests'] += 1
|
hourly_buckets[hour_str]["requests"] += 1
|
||||||
hourly_buckets[hour_str]['cost'] += record.cost or 0.0
|
hourly_buckets[hour_str]["cost"] += record.cost or 0.0
|
||||||
hourly_buckets[hour_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
hourly_buckets[hour_str]["tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||||
|
|
||||||
# 填充所有小时(包括没有数据的)
|
# 填充所有小时(包括没有数据的)
|
||||||
result = []
|
result = []
|
||||||
current = start_time.replace(minute=0, second=0, microsecond=0)
|
current = start_time.replace(minute=0, second=0, microsecond=0)
|
||||||
while current <= end_time:
|
while current <= end_time:
|
||||||
hour_str = current.isoformat()
|
hour_str = current.isoformat()
|
||||||
data = hourly_buckets.get(hour_str, {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
data = hourly_buckets.get(hour_str, {"requests": 0, "cost": 0.0, "tokens": 0})
|
||||||
result.append(TimeSeriesData(
|
result.append(
|
||||||
timestamp=hour_str,
|
TimeSeriesData(timestamp=hour_str, requests=data["requests"], cost=data["cost"], tokens=data["tokens"])
|
||||||
requests=data['requests'],
|
)
|
||||||
cost=data['cost'],
|
|
||||||
tokens=data['tokens']
|
|
||||||
))
|
|
||||||
current += timedelta(hours=1)
|
current += timedelta(hours=1)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -236,35 +218,28 @@ async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> Li
|
|||||||
|
|
||||||
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||||
"""获取日级统计数据"""
|
"""获取日级统计数据"""
|
||||||
daily_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
daily_buckets = defaultdict(lambda: {"requests": 0, "cost": 0.0, "tokens": 0})
|
||||||
|
|
||||||
records = list(
|
records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time))
|
||||||
LLMUsage.select()
|
|
||||||
.where(LLMUsage.timestamp >= start_time)
|
|
||||||
.where(LLMUsage.timestamp <= end_time)
|
|
||||||
)
|
|
||||||
|
|
||||||
for record in records:
|
for record in records:
|
||||||
# 获取日期键
|
# 获取日期键
|
||||||
day_key = record.timestamp.replace(hour=0, minute=0, second=0, microsecond=0)
|
day_key = record.timestamp.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
day_str = day_key.isoformat()
|
day_str = day_key.isoformat()
|
||||||
|
|
||||||
daily_buckets[day_str]['requests'] += 1
|
daily_buckets[day_str]["requests"] += 1
|
||||||
daily_buckets[day_str]['cost'] += record.cost or 0.0
|
daily_buckets[day_str]["cost"] += record.cost or 0.0
|
||||||
daily_buckets[day_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
daily_buckets[day_str]["tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||||
|
|
||||||
# 填充所有天
|
# 填充所有天
|
||||||
result = []
|
result = []
|
||||||
current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
|
current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
while current <= end_time:
|
while current <= end_time:
|
||||||
day_str = current.isoformat()
|
day_str = current.isoformat()
|
||||||
data = daily_buckets.get(day_str, {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
data = daily_buckets.get(day_str, {"requests": 0, "cost": 0.0, "tokens": 0})
|
||||||
result.append(TimeSeriesData(
|
result.append(
|
||||||
timestamp=day_str,
|
TimeSeriesData(timestamp=day_str, requests=data["requests"], cost=data["cost"], tokens=data["tokens"])
|
||||||
requests=data['requests'],
|
)
|
||||||
cost=data['cost'],
|
|
||||||
tokens=data['tokens']
|
|
||||||
))
|
|
||||||
current += timedelta(days=1)
|
current += timedelta(days=1)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -272,23 +247,21 @@ async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> Lis
|
|||||||
|
|
||||||
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
||||||
"""获取最近活动"""
|
"""获取最近活动"""
|
||||||
records = list(
|
records = list(LLMUsage.select().order_by(LLMUsage.timestamp.desc()).limit(limit))
|
||||||
LLMUsage.select()
|
|
||||||
.order_by(LLMUsage.timestamp.desc())
|
|
||||||
.limit(limit)
|
|
||||||
)
|
|
||||||
|
|
||||||
activities = []
|
activities = []
|
||||||
for record in records:
|
for record in records:
|
||||||
activities.append({
|
activities.append(
|
||||||
'timestamp': record.timestamp.isoformat(),
|
{
|
||||||
'model': record.model_assign_name or record.model_name,
|
"timestamp": record.timestamp.isoformat(),
|
||||||
'request_type': record.request_type,
|
"model": record.model_assign_name or record.model_name,
|
||||||
'tokens': (record.prompt_tokens or 0) + (record.completion_tokens or 0),
|
"request_type": record.request_type,
|
||||||
'cost': record.cost or 0.0,
|
"tokens": (record.prompt_tokens or 0) + (record.completion_tokens or 0),
|
||||||
'time_cost': record.time_cost or 0.0,
|
"cost": record.cost or 0.0,
|
||||||
'status': record.status
|
"time_cost": record.time_cost or 0.0,
|
||||||
})
|
"status": record.status,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return activities
|
return activities
|
||||||
|
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class TokenManager:
|
|||||||
"access_token": token,
|
"access_token": token,
|
||||||
"created_at": self._get_current_timestamp(),
|
"created_at": self._get_current_timestamp(),
|
||||||
"updated_at": self._get_current_timestamp(),
|
"updated_at": self._get_current_timestamp(),
|
||||||
"first_setup_completed": False # 标记首次配置未完成
|
"first_setup_completed": False, # 标记首次配置未完成
|
||||||
}
|
}
|
||||||
|
|
||||||
self._save_config(config)
|
self._save_config(config)
|
||||||
@@ -91,6 +91,7 @@ class TokenManager:
|
|||||||
def _get_current_timestamp(self) -> str:
|
def _get_current_timestamp(self) -> str:
|
||||||
"""获取当前时间戳字符串"""
|
"""获取当前时间戳字符串"""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
return datetime.now().isoformat()
|
return datetime.now().isoformat()
|
||||||
|
|
||||||
def get_token(self) -> str:
|
def get_token(self) -> str:
|
||||||
|
|||||||
Reference in New Issue
Block a user