feat: 重构maisaka的消息类型,添加打断功能
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Any, Callable, Coroutine, Generic, Tuple, TypeVar, cast
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.model_configs import ModelInfo
|
||||
|
||||
from .base_client import (
|
||||
@@ -33,12 +34,14 @@ ProviderStreamResponseHandler = Callable[
|
||||
ProviderResponseParser = Callable[[RawResponseT], Tuple[APIResponse, UsageTuple | None]]
|
||||
"""Provider 专用非流式响应解析函数类型。"""
|
||||
|
||||
logger = get_logger("llm_adapter_base")
|
||||
|
||||
|
||||
async def await_task_with_interrupt(
    task: asyncio.Task[TaskResultT],
    interrupt_flag: asyncio.Event | None,
    *,
    interval_seconds: float = 0.02,
) -> TaskResultT:
    """在支持外部中断的前提下等待异步任务完成。

    Polls ``task`` until it finishes, checking ``interrupt_flag`` every
    ``interval_seconds`` seconds. If the flag is set before the task
    completes, the underlying task is cancelled and ``ReqAbortException``
    is raised so callers can distinguish an external abort from a normal
    failure.

    Args:
        task: The already-scheduled asyncio task to await.
        interrupt_flag: Optional event; when set, the wait is aborted.
            ``None`` disables interruption and degrades to a plain poll-wait.
        interval_seconds: Polling period for the interrupt check.
            Keyword-only; defaults to 0.02s (lowered from 0.1s for faster
            interrupt response).

    Returns:
        The task's result once it completes without interruption.

    Raises:
        ReqAbortException: If ``interrupt_flag`` is set before completion.
    """
    # Imported lazily to avoid a circular import between this adapter
    # module and src.llm_models.exceptions.
    from src.llm_models.exceptions import ReqAbortException

    loop = asyncio.get_running_loop()
    started_at = loop.time()
    while not task.done():
        if interrupt_flag is not None and interrupt_flag.is_set():
            elapsed = loop.time() - started_at
            logger.info(f"LLM 请求检测到中断信号,准备取消底层任务,elapsed={elapsed:.3f}s")
            task.cancel()
            raise ReqAbortException("请求被外部信号中断")
        await asyncio.sleep(interval_seconds)
    # NOTE(review): the diff view truncates here; the function is annotated
    # `-> TaskResultT`, so returning the completed task's result is the
    # presumed tail — confirm against the full file.
    return task.result()
||||
@@ -22,6 +22,7 @@ from src.llm_models.exceptions import (
|
||||
EmptyResponseException,
|
||||
ModelAttemptFailed,
|
||||
NetworkConnectionError,
|
||||
ReqAbortException,
|
||||
RespNotOkException,
|
||||
RespParseException,
|
||||
)
|
||||
@@ -326,16 +327,7 @@ class LLMOrchestrator:
|
||||
del raise_when_empty
|
||||
self._refresh_task_config()
|
||||
start_time = time.time()
|
||||
if self.request_type.startswith("maisaka_"):
|
||||
logger.info(
|
||||
f"LLMOrchestrator[{self.request_type}] 开始执行 generate_response_with_message_async "
|
||||
f"(temperature={temperature}, max_tokens={max_tokens}, tools={len(tools or [])})"
|
||||
)
|
||||
|
||||
if self.request_type.startswith("maisaka_"):
|
||||
logger.info(
|
||||
f"LLMOrchestrator[{self.request_type}] 正在根据 {len(tools or [])} 个工具构建内部工具选项"
|
||||
)
|
||||
tool_built = self._build_tool_options(tools)
|
||||
if self.request_type.startswith("maisaka_"):
|
||||
logger.info(f"LLMOrchestrator[{self.request_type}] 已构建 {len(tool_built or [])} 个内部工具选项")
|
||||
@@ -777,6 +769,9 @@ class LLMOrchestrator:
|
||||
)
|
||||
await asyncio.sleep(api_provider.retry_interval)
|
||||
|
||||
except ReqAbortException:
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
@@ -881,6 +876,15 @@ class LLMOrchestrator:
|
||||
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)
|
||||
return LLMExecutionResult(api_response=response, model_info=model_info)
|
||||
|
||||
except ReqAbortException as e:
|
||||
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)
|
||||
if self.request_type.startswith("maisaka_"):
|
||||
logger.info(
|
||||
f"LLMOrchestrator[{self.request_type}] 模型 model={model_info.name} 的请求已被外部信号中断"
|
||||
)
|
||||
raise e
|
||||
|
||||
except ModelAttemptFailed as e:
|
||||
last_exception = e.original_exception or e
|
||||
logger.warning(f"模型 '{model_info.name}' 尝试失败,切换到下一个模型。原因: {e}")
|
||||
|
||||
Reference in New Issue
Block a user