更改generator的返回值为一个数据模型稳定api

This commit is contained in:
UnCLAS-Prommer
2025-08-22 23:40:24 +08:00
parent 2d4fd08ac5
commit 1eeabe76ba
6 changed files with 90 additions and 89 deletions

View File

@@ -21,6 +21,7 @@ from src.plugin_system.base.component_types import ActionInfo
if TYPE_CHECKING:
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.llm_data_model import LLMGenerationDataModel
install(extra_lines=3)
@@ -85,11 +86,9 @@ async def generate_reply(
enable_tool: bool = False,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
return_prompt: bool = False,
request_type: str = "generator_api",
from_plugin: bool = True,
return_expressions: bool = False,
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str], Optional[List[int]]]:
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
"""生成回复
Args:
@@ -117,7 +116,7 @@ async def generate_reply(
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None, None
return False, None
if not extra_info and action_data:
extra_info = action_data.get("extra_info", "")
@@ -126,7 +125,7 @@ async def generate_reply(
reply_reason = action_data.get("reason", "")
# 调用回复器生成回复
success, llm_response_dict, prompt, selected_expressions = await replyer.generate_reply_with_context(
success, llm_response = await replyer.generate_reply_with_context(
extra_info=extra_info,
available_actions=available_actions,
chosen_actions=chosen_actions,
@@ -138,43 +137,27 @@ async def generate_reply(
)
if not success:
logger.warning("[GeneratorAPI] 回复生成失败")
return False, [], None, None
assert llm_response_dict is not None, "llm_response_dict不应为None" # 虽然说不会出现llm_response为空的情况
if content := llm_response_dict.get("content", ""):
return False, None
if content := llm_response.content:
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
else:
reply_set = []
llm_response.reply_set = reply_set
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")
# if return_prompt:
# if return_expressions:
# return success, reply_set, prompt, selected_expressions
# else:
# return success, reply_set, prompt, None
# else:
# if return_expressions:
# return success, reply_set, (None, selected_expressions)
# else:
# return success, reply_set, None
return (
success,
reply_set,
prompt if return_prompt else None,
selected_expressions if return_expressions else None,
)
return success, llm_response
except ValueError as ve:
raise ve
except UserWarning as uw:
logger.warning(f"[GeneratorAPI] 中断了生成: {uw}")
return False, [], None, None
return False, None
except Exception as e:
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
logger.error(traceback.format_exc())
return False, [], None, None
return False, None
async def rewrite_reply(
chat_stream: Optional[ChatStream] = None,
@@ -185,9 +168,8 @@ async def rewrite_reply(
raw_reply: str = "",
reason: str = "",
reply_to: str = "",
return_prompt: bool = False,
request_type: str = "generator_api",
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
"""重写回复
Args:
@@ -210,7 +192,7 @@ async def rewrite_reply(
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None
return False, None
logger.info("[GeneratorAPI] 开始重写回复")
@@ -221,29 +203,28 @@ async def rewrite_reply(
reply_to = reply_to or reply_data.get("reply_to", "")
# 调用回复器重写回复
success, content, prompt = await replyer.rewrite_reply_with_context(
success, llm_response = await replyer.rewrite_reply_with_context(
raw_reply=raw_reply,
reason=reason,
reply_to=reply_to,
return_prompt=return_prompt,
)
reply_set = []
if content:
if success and llm_response and (content := llm_response.content):
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
llm_response.reply_set = reply_set
if success:
logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set)} 个回复项")
else:
logger.warning("[GeneratorAPI] 重写回复失败")
return success, reply_set, prompt if return_prompt else None
return success, llm_response
except ValueError as ve:
raise ve
except Exception as e:
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
return False, [], None
return False, None
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]: