更改generator的返回值为一个数据模型稳定api
This commit is contained in:
@@ -21,6 +21,7 @@ from src.plugin_system.base.component_types import ActionInfo
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -85,11 +86,9 @@ async def generate_reply(
|
||||
enable_tool: bool = False,
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
return_prompt: bool = False,
|
||||
request_type: str = "generator_api",
|
||||
from_plugin: bool = True,
|
||||
return_expressions: bool = False,
|
||||
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str], Optional[List[int]]]:
|
||||
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
|
||||
"""生成回复
|
||||
|
||||
Args:
|
||||
@@ -117,7 +116,7 @@ async def generate_reply(
|
||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return False, [], None, None
|
||||
return False, None
|
||||
|
||||
if not extra_info and action_data:
|
||||
extra_info = action_data.get("extra_info", "")
|
||||
@@ -126,7 +125,7 @@ async def generate_reply(
|
||||
reply_reason = action_data.get("reason", "")
|
||||
|
||||
# 调用回复器生成回复
|
||||
success, llm_response_dict, prompt, selected_expressions = await replyer.generate_reply_with_context(
|
||||
success, llm_response = await replyer.generate_reply_with_context(
|
||||
extra_info=extra_info,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_actions,
|
||||
@@ -138,43 +137,27 @@ async def generate_reply(
|
||||
)
|
||||
if not success:
|
||||
logger.warning("[GeneratorAPI] 回复生成失败")
|
||||
return False, [], None, None
|
||||
assert llm_response_dict is not None, "llm_response_dict不应为None" # 虽然说不会出现llm_response为空的情况
|
||||
if content := llm_response_dict.get("content", ""):
|
||||
return False, None
|
||||
if content := llm_response.content:
|
||||
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
|
||||
else:
|
||||
reply_set = []
|
||||
llm_response.reply_set = reply_set
|
||||
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")
|
||||
|
||||
# if return_prompt:
|
||||
# if return_expressions:
|
||||
# return success, reply_set, prompt, selected_expressions
|
||||
# else:
|
||||
# return success, reply_set, prompt, None
|
||||
# else:
|
||||
# if return_expressions:
|
||||
# return success, reply_set, (None, selected_expressions)
|
||||
# else:
|
||||
# return success, reply_set, None
|
||||
return (
|
||||
success,
|
||||
reply_set,
|
||||
prompt if return_prompt else None,
|
||||
selected_expressions if return_expressions else None,
|
||||
)
|
||||
return success, llm_response
|
||||
|
||||
except ValueError as ve:
|
||||
raise ve
|
||||
|
||||
except UserWarning as uw:
|
||||
logger.warning(f"[GeneratorAPI] 中断了生成: {uw}")
|
||||
return False, [], None, None
|
||||
return False, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False, [], None, None
|
||||
|
||||
return False, None
|
||||
|
||||
async def rewrite_reply(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
@@ -185,9 +168,8 @@ async def rewrite_reply(
|
||||
raw_reply: str = "",
|
||||
reason: str = "",
|
||||
reply_to: str = "",
|
||||
return_prompt: bool = False,
|
||||
request_type: str = "generator_api",
|
||||
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
|
||||
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
|
||||
"""重写回复
|
||||
|
||||
Args:
|
||||
@@ -210,7 +192,7 @@ async def rewrite_reply(
|
||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return False, [], None
|
||||
return False, None
|
||||
|
||||
logger.info("[GeneratorAPI] 开始重写回复")
|
||||
|
||||
@@ -221,29 +203,28 @@ async def rewrite_reply(
|
||||
reply_to = reply_to or reply_data.get("reply_to", "")
|
||||
|
||||
# 调用回复器重写回复
|
||||
success, content, prompt = await replyer.rewrite_reply_with_context(
|
||||
success, llm_response = await replyer.rewrite_reply_with_context(
|
||||
raw_reply=raw_reply,
|
||||
reason=reason,
|
||||
reply_to=reply_to,
|
||||
return_prompt=return_prompt,
|
||||
)
|
||||
reply_set = []
|
||||
if content:
|
||||
if success and llm_response and (content := llm_response.content):
|
||||
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
|
||||
|
||||
llm_response.reply_set = reply_set
|
||||
if success:
|
||||
logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set)} 个回复项")
|
||||
else:
|
||||
logger.warning("[GeneratorAPI] 重写回复失败")
|
||||
|
||||
return success, reply_set, prompt if return_prompt else None
|
||||
return success, llm_response
|
||||
|
||||
except ValueError as ve:
|
||||
raise ve
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
|
||||
return False, [], None
|
||||
return False, None
|
||||
|
||||
|
||||
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
from typing import List, Dict, Optional, Type, Tuple, Any, Coroutine
|
||||
from typing import List, Dict, Optional, Type, Tuple, Any, TYPE_CHECKING
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
@@ -9,6 +9,9 @@ from src.plugin_system.base.component_types import EventType, EventHandlerInfo,
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from .global_announcement_manager import global_announcement_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
|
||||
logger = get_logger("events_manager")
|
||||
|
||||
|
||||
@@ -47,7 +50,7 @@ class EventsManager:
|
||||
event_type: EventType,
|
||||
message: Optional[MessageRecv] = None,
|
||||
llm_prompt: Optional[str] = None,
|
||||
llm_response: Optional[Dict[str, Any]] = None,
|
||||
llm_response: Optional["LLMGenerationDataModel"] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
action_usage: Optional[List[str]] = None,
|
||||
) -> Optional[MaiMessages]:
|
||||
@@ -97,7 +100,7 @@ class EventsManager:
|
||||
event_type: EventType,
|
||||
message: Optional[MessageRecv] = None,
|
||||
llm_prompt: Optional[str] = None,
|
||||
llm_response: Optional[Dict[str, Any]] = None,
|
||||
llm_response: Optional["LLMGenerationDataModel"] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
action_usage: Optional[List[str]] = None,
|
||||
) -> bool:
|
||||
@@ -175,16 +178,16 @@ class EventsManager:
|
||||
return False
|
||||
|
||||
def _transform_event_message(
|
||||
self, message: MessageRecv, llm_prompt: Optional[str] = None, llm_response: Optional[Dict[str, Any]] = None
|
||||
self, message: MessageRecv, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None
|
||||
) -> MaiMessages:
|
||||
"""转换事件消息格式"""
|
||||
# 直接赋值部分内容
|
||||
transformed_message = MaiMessages(
|
||||
llm_prompt=llm_prompt,
|
||||
llm_response_content=llm_response.get("content") if llm_response else None,
|
||||
llm_response_reasoning=llm_response.get("reasoning") if llm_response else None,
|
||||
llm_response_model=llm_response.get("model") if llm_response else None,
|
||||
llm_response_tool_call=llm_response.get("tool_calls") if llm_response else None,
|
||||
llm_response_content=llm_response.content if llm_response else None,
|
||||
llm_response_reasoning=llm_response.reasoning if llm_response else None,
|
||||
llm_response_model=llm_response.model if llm_response else None,
|
||||
llm_response_tool_call=llm_response.tool_calls if llm_response else None,
|
||||
raw_message=message.raw_message,
|
||||
additional_data=message.message_info.additional_config or {},
|
||||
)
|
||||
@@ -228,7 +231,7 @@ class EventsManager:
|
||||
return transformed_message
|
||||
|
||||
def _build_message_from_stream(
|
||||
self, stream_id: str, llm_prompt: Optional[str] = None, llm_response: Optional[Dict[str, Any]] = None
|
||||
self, stream_id: str, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None
|
||||
) -> MaiMessages:
|
||||
"""从流ID构建消息"""
|
||||
chat_stream = get_chat_manager().get_stream(stream_id)
|
||||
@@ -240,7 +243,7 @@ class EventsManager:
|
||||
self,
|
||||
stream_id: str,
|
||||
llm_prompt: Optional[str] = None,
|
||||
llm_response: Optional[Dict[str, Any]] = None,
|
||||
llm_response: Optional["LLMGenerationDataModel"] = None,
|
||||
action_usage: Optional[List[str]] = None,
|
||||
) -> MaiMessages:
|
||||
"""没有message对象时进行转换"""
|
||||
@@ -249,10 +252,10 @@ class EventsManager:
|
||||
return MaiMessages(
|
||||
stream_id=stream_id,
|
||||
llm_prompt=llm_prompt,
|
||||
llm_response_content=(llm_response.get("content") if llm_response else None),
|
||||
llm_response_reasoning=(llm_response.get("reasoning") if llm_response else None),
|
||||
llm_response_model=llm_response.get("model") if llm_response else None,
|
||||
llm_response_tool_call=(llm_response.get("tool_calls") if llm_response else None),
|
||||
llm_response_content=(llm_response.content if llm_response else None),
|
||||
llm_response_reasoning=(llm_response.reasoning if llm_response else None),
|
||||
llm_response_model=(llm_response.model if llm_response else None),
|
||||
llm_response_tool_call=(llm_response.tool_calls if llm_response else None),
|
||||
is_group_message=(not (not chat_stream.group_info)),
|
||||
is_private_message=(not chat_stream.group_info),
|
||||
action_usage=action_usage,
|
||||
|
||||
Reference in New Issue
Block a user