fix typing of utils_model.py

UnCLAS-Prommer
2025-09-17 15:59:02 +08:00
parent 91e716a24c
commit 1260a11b78
2 changed files with 40 additions and 18 deletions

View File

@@ -14,6 +14,7 @@ from src.plugin_system import (
     MaiMessages,
     ToolParamType,
     ReplyContentType,
+    emoji_api,
 )

 from src.config.config import global_config
@@ -181,7 +182,26 @@ class ForwardMessages(BaseEventHandler):
             raise ValueError("Failed to forward messages")
         self.messages = []
         return True, True, None, None, None
+
+class RandomEmojis(BaseCommand):
+    command_name = "random_emojis"
+    command_description = "Send multiple random emoji stickers"
+    command_pattern = r"^/random_emojis$"
+    async def execute(self):
+        emojis = await emoji_api.get_random(5)
+        if not emojis:
+            return False, "No emoji stickers found", False
+        emoji_base64_list = []
+        for emoji in emojis:
+            emoji_base64_list.append(emoji[0])
+        return await self.forward_images(emoji_base64_list)
+    async def forward_images(self, images: List[str]):
+        """
+        Send multiple images to the user as a single merged-forward message.
+        """
+        success = await self.send_forward([("0", "Mystery User", [(ReplyContentType.IMAGE, img)]) for img in images])
+        return (True, "Random emoji stickers sent", True) if success else (False, "Failed to send random emoji stickers", False)
 # ===== Plugin registration =====
@@ -225,6 +245,7 @@ class HelloWorldPlugin(BasePlugin):
             (TimeCommand.get_command_info(), TimeCommand),
             (PrintMessage.get_handler_info(), PrintMessage),
             (ForwardMessages.get_handler_info(), ForwardMessages),
+            (RandomEmojis.get_command_info(), RandomEmojis),
         ]
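
A note on the new command's data flow: this diff only assumes that `emoji_api.get_random(5)` returns tuples whose first element is a base64-encoded image, and that `send_forward` accepts `(sender_id, sender_name, [(content_type, payload)])` nodes. Below is a minimal standalone sketch of that shaping step; the sample values and the `"image"` stand-in for `ReplyContentType.IMAGE` are hypothetical:

```python
from typing import List, Tuple

# Stand-ins for emoji_api.get_random() results; RandomEmojis only relies
# on index 0 holding the base64 image payload.
emojis: List[Tuple[str, str]] = [("aGFwcHk=", "happy"), ("c2FkIGNhdA==", "sad cat")]

images = [emoji[0] for emoji in emojis]

# One forward node per image: (sender_id, sender_name, content list).
nodes = [("0", "Mystery User", [("image", img)]) for img in images]

print(len(nodes), nodes[0][2][0][0])  # 2 image
```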

View File: utils_model.py

@@ -4,7 +4,7 @@ import time
 from enum import Enum
 from rich.traceback import install
-from typing import Tuple, List, Dict, Optional, Callable, Any
+from typing import Tuple, List, Dict, Optional, Callable, Any, Set
 import traceback

 from src.common.logger import get_logger
@@ -82,9 +82,7 @@ class LLMRequest:
             message_builder = MessageBuilder()
             message_builder.add_text_content(prompt)
             message_builder.add_image_content(
-                image_base64=image_base64,
-                image_format=image_format,
-                support_formats=client.get_support_image_formats()
+                image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats()
             )
             return [message_builder.build()]
@@ -145,7 +143,7 @@ class LLMRequest:
             (Tuple[str, str, str, Optional[List[ToolCall]]]): response content, reasoning content, model name, list of tool calls
         """
         start_time = time.time()

         def message_factory(client: BaseClient) -> List[Message]:
             message_builder = MessageBuilder()
             message_builder.add_text_content(prompt)
@@ -177,7 +175,7 @@ class LLMRequest:
endpoint="/chat/completions", endpoint="/chat/completions",
time_cost=time.time() - start_time, time_cost=time.time() - start_time,
) )
return content, (reasoning_content, model_info.name, tool_calls) return content or "", (reasoning_content, model_info.name, tool_calls)
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
""" """
@@ -206,7 +204,7 @@ class LLMRequest:
raise RuntimeError("获取embedding失败") raise RuntimeError("获取embedding失败")
return embedding, model_info.name return embedding, model_info.name
def _select_model(self, exclude_models: set = None) -> Tuple[ModelInfo, APIProvider, BaseClient]: def _select_model(self, exclude_models: Optional[Set[str]] = None) -> Tuple[ModelInfo, APIProvider, BaseClient]:
""" """
根据总tokens和惩罚值选择的模型 根据总tokens和惩罚值选择的模型
""" """
@@ -224,7 +222,7 @@ class LLMRequest:
         )
         model_info = model_config.get_model_info(least_used_model_name)
         api_provider = model_config.get_provider(model_info.api_provider)
-        force_new_client = (self.request_type == "embedding")
+        force_new_client = self.request_type == "embedding"
         client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
         logger.debug(f"Selected model for request: {model_info.name}")
         total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
@@ -246,13 +244,13 @@ class LLMRequest:
         max_tokens: Optional[int],
         embedding_input: str | None,
         audio_base64: str | None,
+        compressed_messages: Optional[List[Message]] = None,
     ) -> APIResponse:
         """
         Execute the request on a single model, with retry logic for transient errors.
         Returns an APIResponse on success; raises ModelAttemptFailed if retries are exhausted or a hard error occurs.
         """
         retry_remain = api_provider.max_retry
-        compressed_messages: Optional[List[Message]] = None

         while retry_remain > 0:
             try:
@@ -299,7 +297,9 @@ class LLMRequest:
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。") logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(f"模型 '{model_info.name}' 遇到可重试的HTTP错误: {str(e)}。剩余重试次数: {retry_remain}") logger.warning(
f"模型 '{model_info.name}' 遇到可重试的HTTP错误: {str(e)}。剩余重试次数: {retry_remain}"
)
await asyncio.sleep(api_provider.retry_interval) await asyncio.sleep(api_provider.retry_interval)
continue continue
@@ -315,8 +315,8 @@ class LLMRequest:
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
except Exception as e: except Exception as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
logger.warning(f"模型 '{model_info.name}' 遇到未知的不可重试错误: {str(e)}") logger.warning(f"模型 '{model_info.name}' 遇到未知的不可重试错误: {str(e)}")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
@@ -338,12 +338,11 @@ class LLMRequest:
""" """
调度器函数,负责模型选择、故障切换。 调度器函数,负责模型选择、故障切换。
""" """
failed_models_this_request = set() failed_models_this_request: Set[str] = set()
max_attempts = len(self.model_for_task.model_list) max_attempts = len(self.model_for_task.model_list)
last_exception: Optional[Exception] = None last_exception: Optional[Exception] = None
compressed_messages: Optional[List[Message]] = None
for _attempt in range(max_attempts): for _ in range(max_attempts):
model_info, api_provider, client = self._select_model(exclude_models=failed_models_this_request) model_info, api_provider, client = self._select_model(exclude_models=failed_models_this_request)
message_list = [] message_list = []
@@ -352,7 +351,10 @@ class LLMRequest:
             try:
                 response = await self._attempt_request_on_model(
-                    model_info, api_provider, client, request_type,
+                    model_info,
+                    api_provider,
+                    client,
+                    request_type,
                     message_list=message_list,
                     tool_options=tool_options,
                     response_format=response_format,
@@ -362,7 +364,6 @@ class LLMRequest:
                     max_tokens=max_tokens,
                     embedding_input=embedding_input,
                     audio_base64=audio_base64,
-                    compressed_messages=compressed_messages,
                 )
                 return response, model_info
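
Taken together, the `compressed_messages` edits in this file move the variable from a local in the scheduler loop to a parameter of `_attempt_request_on_model` defaulting to `None`, so each model attempt starts with fresh compression state rather than inheriting it across failover. A toy illustration of that ownership change (names hypothetical):

```python
from typing import List, Optional

def attempt(payload: List[str], cache: Optional[List[str]] = None) -> List[str]:
    # A None default means every call builds its own cache unless one is passed in.
    cache = cache if cache is not None else []
    cache.extend(payload)
    return cache

first = attempt(["msg-1"])
second = attempt(["msg-2"])
print(first, second)  # ['msg-1'] ['msg-2'] -- no state leaks between attempts
```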
@@ -430,4 +431,4 @@ class LLMRequest:
         match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
         content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
         reasoning = match[1].strip() if match else ""
         return content, reasoning
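
The final hunk's `<think>`-tag handling is easiest to see on a concrete input; here is a quick standalone run of the same two regexes (the sample string is made up):

```python
import re

content = "<think>weigh both options first</think>Option B is better."
match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
reasoning = match[1].strip() if match else ""

print(content)    # Option B is better.
print(reasoning)  # weigh both options first
```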