调整对应的调用

This commit is contained in:
UnCLAS-Prommer
2025-07-30 17:07:55 +08:00
parent 3c40ceda4c
commit 6c0edd0ad7
40 changed files with 580 additions and 1236 deletions

View File

@@ -12,6 +12,7 @@ import traceback
from typing import Tuple, Any, Dict, List, Optional
from rich.traceback import install
from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from src.chat.replyer.default_generator import DefaultReplyer
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.utils import process_llm_response
@@ -31,7 +32,7 @@ logger = get_logger("generator_api")
def get_replyer(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
model_configs: Optional[List[Dict[str, Any]]] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer]:
"""获取回复器对象
@@ -58,7 +59,7 @@ def get_replyer(
return replyer_manager.get_replyer(
chat_stream=chat_stream,
chat_id=chat_id,
model_configs=model_configs,
model_set_with_weight=model_set_with_weight,
request_type=request_type,
)
except Exception as e:
@@ -83,7 +84,7 @@ async def generate_reply(
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
return_prompt: bool = False,
model_configs: Optional[List[Dict[str, Any]]] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "generator_api",
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
"""生成回复
@@ -106,7 +107,7 @@ async def generate_reply(
"""
try:
# 获取回复器
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type)
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None
@@ -154,7 +155,7 @@ async def rewrite_reply(
chat_id: Optional[str] = None,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
model_configs: Optional[List[Dict[str, Any]]] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
raw_reply: str = "",
reason: str = "",
reply_to: str = "",
@@ -179,7 +180,7 @@ async def rewrite_reply(
"""
try:
# 获取回复器
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs)
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None
@@ -245,17 +246,17 @@ async def process_human_text(content: str, enable_splitter: bool, enable_chinese
async def generate_response_custom(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
model_configs: Optional[List[Dict[str, Any]]] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
prompt: str = "",
) -> Optional[str]:
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs)
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return None
try:
logger.debug("[GeneratorAPI] 开始生成自定义回复")
response = await replyer.llm_generate_content(prompt)
response, _, _, _ = await replyer.llm_generate_content(prompt)
if response:
logger.debug("[GeneratorAPI] 自定义回复生成成功")
return response