fix typing of utils_model.py
This commit is contained in:
@@ -14,6 +14,7 @@ from src.plugin_system import (
|
|||||||
MaiMessages,
|
MaiMessages,
|
||||||
ToolParamType,
|
ToolParamType,
|
||||||
ReplyContentType,
|
ReplyContentType,
|
||||||
|
emoji_api,
|
||||||
)
|
)
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
@@ -181,7 +182,26 @@ class ForwardMessages(BaseEventHandler):
|
|||||||
raise ValueError("转发消息失败")
|
raise ValueError("转发消息失败")
|
||||||
self.messages = []
|
self.messages = []
|
||||||
return True, True, None, None, None
|
return True, True, None, None, None
|
||||||
|
class RandomEmojis(BaseCommand):
    """Command that sends several random emojis as one merged-forward message."""

    command_name = "random_emojis"
    command_description = "发送多张随机表情包"
    command_pattern = r"^/random_emojis$"

    async def execute(self):
        """Fetch random emojis and forward them to the user.

        Returns the 3-tuple produced by ``forward_images``, or a failure
        3-tuple when ``emoji_api.get_random`` yields nothing.
        """
        picked = await emoji_api.get_random(5)
        if picked:
            # Only the first element of each entry is forwarded
            # (presumably the base64 image data — confirm against emoji_api).
            return await self.forward_images([entry[0] for entry in picked])
        return False, "未找到表情包", False

    async def forward_images(self, images: List[str]):
        """Send multiple images to the user as a single merged-forward message.

        Each image becomes its own forward node carrying one IMAGE segment.
        Returns a success 3-tuple when ``send_forward`` reports a truthy
        result, otherwise a failure 3-tuple.
        """
        nodes = [("0", "神秘用户", [(ReplyContentType.IMAGE, image)]) for image in images]
        if await self.send_forward(nodes):
            return True, "已发送随机表情包", True
        return False, "发送随机表情包失败", False
|
||||||
|
|
||||||
# ===== 插件注册 =====
|
# ===== 插件注册 =====
|
||||||
|
|
||||||
@@ -225,6 +245,7 @@ class HelloWorldPlugin(BasePlugin):
|
|||||||
(TimeCommand.get_command_info(), TimeCommand),
|
(TimeCommand.get_command_info(), TimeCommand),
|
||||||
(PrintMessage.get_handler_info(), PrintMessage),
|
(PrintMessage.get_handler_info(), PrintMessage),
|
||||||
(ForwardMessages.get_handler_info(), ForwardMessages),
|
(ForwardMessages.get_handler_info(), ForwardMessages),
|
||||||
|
(RandomEmojis.get_command_info(), RandomEmojis),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import time
|
|||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from typing import Tuple, List, Dict, Optional, Callable, Any
|
from typing import Tuple, List, Dict, Optional, Callable, Any, Set
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -82,9 +82,7 @@ class LLMRequest:
|
|||||||
message_builder = MessageBuilder()
|
message_builder = MessageBuilder()
|
||||||
message_builder.add_text_content(prompt)
|
message_builder.add_text_content(prompt)
|
||||||
message_builder.add_image_content(
|
message_builder.add_image_content(
|
||||||
image_base64=image_base64,
|
image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats()
|
||||||
image_format=image_format,
|
|
||||||
support_formats=client.get_support_image_formats()
|
|
||||||
)
|
)
|
||||||
return [message_builder.build()]
|
return [message_builder.build()]
|
||||||
|
|
||||||
@@ -145,7 +143,7 @@ class LLMRequest:
|
|||||||
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
def message_factory(client: BaseClient) -> List[Message]:
|
def message_factory(client: BaseClient) -> List[Message]:
|
||||||
message_builder = MessageBuilder()
|
message_builder = MessageBuilder()
|
||||||
message_builder.add_text_content(prompt)
|
message_builder.add_text_content(prompt)
|
||||||
@@ -177,7 +175,7 @@ class LLMRequest:
|
|||||||
endpoint="/chat/completions",
|
endpoint="/chat/completions",
|
||||||
time_cost=time.time() - start_time,
|
time_cost=time.time() - start_time,
|
||||||
)
|
)
|
||||||
return content, (reasoning_content, model_info.name, tool_calls)
|
return content or "", (reasoning_content, model_info.name, tool_calls)
|
||||||
|
|
||||||
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
|
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
|
||||||
"""
|
"""
|
||||||
@@ -206,7 +204,7 @@ class LLMRequest:
|
|||||||
raise RuntimeError("获取embedding失败")
|
raise RuntimeError("获取embedding失败")
|
||||||
return embedding, model_info.name
|
return embedding, model_info.name
|
||||||
|
|
||||||
def _select_model(self, exclude_models: set = None) -> Tuple[ModelInfo, APIProvider, BaseClient]:
|
def _select_model(self, exclude_models: Optional[Set[str]] = None) -> Tuple[ModelInfo, APIProvider, BaseClient]:
|
||||||
"""
|
"""
|
||||||
根据总tokens和惩罚值选择的模型
|
根据总tokens和惩罚值选择的模型
|
||||||
"""
|
"""
|
||||||
@@ -224,7 +222,7 @@ class LLMRequest:
|
|||||||
)
|
)
|
||||||
model_info = model_config.get_model_info(least_used_model_name)
|
model_info = model_config.get_model_info(least_used_model_name)
|
||||||
api_provider = model_config.get_provider(model_info.api_provider)
|
api_provider = model_config.get_provider(model_info.api_provider)
|
||||||
force_new_client = (self.request_type == "embedding")
|
force_new_client = self.request_type == "embedding"
|
||||||
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
|
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
|
||||||
logger.debug(f"选择请求模型: {model_info.name}")
|
logger.debug(f"选择请求模型: {model_info.name}")
|
||||||
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||||
@@ -246,13 +244,13 @@ class LLMRequest:
|
|||||||
max_tokens: Optional[int],
|
max_tokens: Optional[int],
|
||||||
embedding_input: str | None,
|
embedding_input: str | None,
|
||||||
audio_base64: str | None,
|
audio_base64: str | None,
|
||||||
compressed_messages: Optional[List[Message]] = None,
|
|
||||||
) -> APIResponse:
|
) -> APIResponse:
|
||||||
"""
|
"""
|
||||||
在单个模型上执行请求,包含针对临时错误的重试逻辑。
|
在单个模型上执行请求,包含针对临时错误的重试逻辑。
|
||||||
如果成功,返回APIResponse。如果失败(重试耗尽或硬错误),则抛出ModelAttemptFailed异常。
|
如果成功,返回APIResponse。如果失败(重试耗尽或硬错误),则抛出ModelAttemptFailed异常。
|
||||||
"""
|
"""
|
||||||
retry_remain = api_provider.max_retry
|
retry_remain = api_provider.max_retry
|
||||||
|
compressed_messages: Optional[List[Message]] = None
|
||||||
|
|
||||||
while retry_remain > 0:
|
while retry_remain > 0:
|
||||||
try:
|
try:
|
||||||
@@ -299,7 +297,9 @@ class LLMRequest:
|
|||||||
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。")
|
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。")
|
||||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||||
|
|
||||||
logger.warning(f"模型 '{model_info.name}' 遇到可重试的HTTP错误: {str(e)}。剩余重试次数: {retry_remain}")
|
logger.warning(
|
||||||
|
f"模型 '{model_info.name}' 遇到可重试的HTTP错误: {str(e)}。剩余重试次数: {retry_remain}"
|
||||||
|
)
|
||||||
await asyncio.sleep(api_provider.retry_interval)
|
await asyncio.sleep(api_provider.retry_interval)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -315,8 +315,8 @@ class LLMRequest:
|
|||||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
|
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
logger.warning(f"模型 '{model_info.name}' 遇到未知的不可重试错误: {str(e)}")
|
logger.warning(f"模型 '{model_info.name}' 遇到未知的不可重试错误: {str(e)}")
|
||||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
|
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
|
||||||
|
|
||||||
@@ -338,12 +338,11 @@ class LLMRequest:
|
|||||||
"""
|
"""
|
||||||
调度器函数,负责模型选择、故障切换。
|
调度器函数,负责模型选择、故障切换。
|
||||||
"""
|
"""
|
||||||
failed_models_this_request = set()
|
failed_models_this_request: Set[str] = set()
|
||||||
max_attempts = len(self.model_for_task.model_list)
|
max_attempts = len(self.model_for_task.model_list)
|
||||||
last_exception: Optional[Exception] = None
|
last_exception: Optional[Exception] = None
|
||||||
compressed_messages: Optional[List[Message]] = None
|
|
||||||
|
|
||||||
for _attempt in range(max_attempts):
|
for _ in range(max_attempts):
|
||||||
model_info, api_provider, client = self._select_model(exclude_models=failed_models_this_request)
|
model_info, api_provider, client = self._select_model(exclude_models=failed_models_this_request)
|
||||||
|
|
||||||
message_list = []
|
message_list = []
|
||||||
@@ -352,7 +351,10 @@ class LLMRequest:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self._attempt_request_on_model(
|
response = await self._attempt_request_on_model(
|
||||||
model_info, api_provider, client, request_type,
|
model_info,
|
||||||
|
api_provider,
|
||||||
|
client,
|
||||||
|
request_type,
|
||||||
message_list=message_list,
|
message_list=message_list,
|
||||||
tool_options=tool_options,
|
tool_options=tool_options,
|
||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
@@ -362,7 +364,6 @@ class LLMRequest:
|
|||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
embedding_input=embedding_input,
|
embedding_input=embedding_input,
|
||||||
audio_base64=audio_base64,
|
audio_base64=audio_base64,
|
||||||
compressed_messages=compressed_messages,
|
|
||||||
)
|
)
|
||||||
return response, model_info
|
return response, model_info
|
||||||
|
|
||||||
@@ -430,4 +431,4 @@ class LLMRequest:
|
|||||||
match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
|
match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
|
||||||
content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
|
content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
|
||||||
reasoning = match[1].strip() if match else ""
|
reasoning = match[1].strip() if match else ""
|
||||||
return content, reasoning
|
return content, reasoning
|
||||||
|
|||||||
Reference in New Issue
Block a user