调整对应的调用
This commit is contained in:
@@ -12,6 +12,7 @@ import traceback
|
||||
from typing import Tuple, Any, Dict, List, Optional
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
@@ -31,7 +32,7 @@ logger = get_logger("generator_api")
|
||||
def get_replyer(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
request_type: str = "replyer",
|
||||
) -> Optional[DefaultReplyer]:
|
||||
"""获取回复器对象
|
||||
@@ -58,7 +59,7 @@ def get_replyer(
|
||||
return replyer_manager.get_replyer(
|
||||
chat_stream=chat_stream,
|
||||
chat_id=chat_id,
|
||||
model_configs=model_configs,
|
||||
model_set_with_weight=model_set_with_weight,
|
||||
request_type=request_type,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -83,7 +84,7 @@ async def generate_reply(
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
return_prompt: bool = False,
|
||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
request_type: str = "generator_api",
|
||||
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
|
||||
"""生成回复
|
||||
@@ -106,7 +107,7 @@ async def generate_reply(
|
||||
"""
|
||||
try:
|
||||
# 获取回复器
|
||||
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type)
|
||||
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight, request_type=request_type)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return False, [], None
|
||||
@@ -154,7 +155,7 @@ async def rewrite_reply(
|
||||
chat_id: Optional[str] = None,
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
raw_reply: str = "",
|
||||
reason: str = "",
|
||||
reply_to: str = "",
|
||||
@@ -179,7 +180,7 @@ async def rewrite_reply(
|
||||
"""
|
||||
try:
|
||||
# 获取回复器
|
||||
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs)
|
||||
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return False, [], None
|
||||
@@ -245,17 +246,17 @@ async def process_human_text(content: str, enable_splitter: bool, enable_chinese
|
||||
async def generate_response_custom(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
prompt: str = "",
|
||||
) -> Optional[str]:
|
||||
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs)
|
||||
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return None
|
||||
|
||||
try:
|
||||
logger.debug("[GeneratorAPI] 开始生成自定义回复")
|
||||
response = await replyer.llm_generate_content(prompt)
|
||||
response, _, _, _ = await replyer.llm_generate_content(prompt)
|
||||
if response:
|
||||
logger.debug("[GeneratorAPI] 自定义回复生成成功")
|
||||
return response
|
||||
|
||||
Reference in New Issue
Block a user