feat:聊天特定的额外prompt支持
This commit is contained in:
@@ -1075,6 +1075,24 @@ class ExperimentalConfig(ConfigBase):
|
||||
)
|
||||
"""_wrap_私聊说话规则,行为风格(实验性功能)"""
|
||||
|
||||
    # Extra system-prompt text appended for group chats (experimental feature).
    # Empty by default; the json_schema_extra keys drive the admin-UI widget
    # (multi-line textarea with a "users" icon).
    group_chat_prompt: str = Field(
        default="",
        json_schema_extra={
            "x-widget": "textarea",
            "x-icon": "users",
        },
    )
    """_wrap_群聊通用注意事项(实验性功能)"""
|
||||
|
||||
    # Extra system-prompt text appended for private (1:1) chats (experimental
    # feature). Empty by default; json_schema_extra drives the admin-UI widget.
    # NOTE(review): field name is plural ("prompts") while its group-chat
    # counterpart `group_chat_prompt` is singular — consider aligning the names
    # in a future migration (renaming now would break existing configs/readers).
    private_chat_prompts: str = Field(
        default="",
        json_schema_extra={
            "x-widget": "textarea",
            "x-icon": "user",
        },
    )
    """_wrap_私聊通用注意事项(实验性功能)"""
|
||||
|
||||
chat_prompts: list[ExtraPromptItem] = Field(
|
||||
default_factory=lambda: [],
|
||||
json_schema_extra={
|
||||
|
||||
@@ -18,6 +18,7 @@ from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolRegistry, ToolSpec
|
||||
from src.know_u.knowledge import extract_category_ids_from_result
|
||||
@@ -63,6 +64,8 @@ class MaisakaChatLoopService:
|
||||
def __init__(
|
||||
self,
|
||||
chat_system_prompt: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_group_chat: Optional[bool] = None,
|
||||
temperature: float = 0.5,
|
||||
max_tokens: int = 2048,
|
||||
) -> None:
|
||||
@@ -70,12 +73,16 @@ class MaisakaChatLoopService:
|
||||
|
||||
Args:
|
||||
chat_system_prompt: 可选的系统提示词。
|
||||
session_id: 当前会话 ID,用于匹配会话级额外提示。
|
||||
is_group_chat: 当前会话是否为群聊。
|
||||
temperature: 规划器温度参数。
|
||||
max_tokens: 规划器最大输出长度。
|
||||
"""
|
||||
|
||||
self._temperature = temperature
|
||||
self._max_tokens = max_tokens
|
||||
self._is_group_chat = is_group_chat
|
||||
self._session_id = session_id or ""
|
||||
self._extra_tools: List[ToolOption] = []
|
||||
self._interrupt_flag: asyncio.Event | None = None
|
||||
self._tool_registry: ToolRegistry | None = None
|
||||
@@ -141,6 +148,7 @@ class MaisakaChatLoopService:
|
||||
"maisaka_chat",
|
||||
file_tools_section=tools_section,
|
||||
bot_name=global_config.bot.nickname,
|
||||
group_chat_attention_block=self._build_group_chat_attention_block(),
|
||||
identity=self._personality_prompt,
|
||||
)
|
||||
except Exception:
|
||||
@@ -148,6 +156,74 @@ class MaisakaChatLoopService:
|
||||
|
||||
self._prompts_loaded = True
|
||||
|
||||
def _build_group_chat_attention_block(self) -> str:
|
||||
"""构建当前聊天场景下的额外注意事项块。"""
|
||||
|
||||
prompt_lines: List[str] = []
|
||||
|
||||
if self._is_group_chat is True:
|
||||
if group_chat_prompt := str(global_config.experimental.group_chat_prompt or "").strip():
|
||||
prompt_lines.append(group_chat_prompt)
|
||||
elif self._is_group_chat is False:
|
||||
if private_chat_prompt := str(global_config.experimental.private_chat_prompts or "").strip():
|
||||
prompt_lines.append(private_chat_prompt)
|
||||
|
||||
if self._session_id:
|
||||
if chat_prompt := self._get_chat_prompt_for_chat(self._session_id, self._is_group_chat).strip():
|
||||
prompt_lines.append(chat_prompt)
|
||||
|
||||
if not prompt_lines:
|
||||
return ""
|
||||
|
||||
return f"在该聊天中的注意事项:\n" + "\n".join(prompt_lines) + "\n"
|
||||
|
||||
@staticmethod
|
||||
def _get_chat_prompt_for_chat(chat_id: str, is_group_chat: Optional[bool]) -> str:
|
||||
"""根据聊天流 ID 获取匹配的额外提示。"""
|
||||
|
||||
if not global_config.experimental.chat_prompts:
|
||||
return ""
|
||||
|
||||
for chat_prompt_item in global_config.experimental.chat_prompts:
|
||||
if hasattr(chat_prompt_item, "platform"):
|
||||
platform = str(chat_prompt_item.platform or "").strip()
|
||||
item_id = str(chat_prompt_item.item_id or "").strip()
|
||||
rule_type = str(chat_prompt_item.rule_type or "").strip()
|
||||
prompt_content = str(chat_prompt_item.prompt or "").strip()
|
||||
elif isinstance(chat_prompt_item, str):
|
||||
parts = chat_prompt_item.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
continue
|
||||
|
||||
platform, item_id, rule_type, prompt_content = parts
|
||||
platform = platform.strip()
|
||||
item_id = item_id.strip()
|
||||
rule_type = rule_type.strip()
|
||||
prompt_content = prompt_content.strip()
|
||||
else:
|
||||
continue
|
||||
|
||||
if not platform or not item_id or not prompt_content:
|
||||
continue
|
||||
|
||||
if rule_type == "group":
|
||||
config_is_group = True
|
||||
config_chat_id = SessionUtils.calculate_session_id(platform, group_id=item_id)
|
||||
elif rule_type == "private":
|
||||
config_is_group = False
|
||||
config_chat_id = SessionUtils.calculate_session_id(platform, user_id=item_id)
|
||||
else:
|
||||
continue
|
||||
|
||||
if is_group_chat is not None and config_is_group != is_group_chat:
|
||||
continue
|
||||
|
||||
if config_chat_id == chat_id:
|
||||
logger.debug(f"匹配到 Maisaka 聊天额外提示,chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
|
||||
return prompt_content
|
||||
|
||||
return ""
|
||||
|
||||
def set_extra_tools(self, tools: Sequence[ToolDefinitionInput]) -> None:
|
||||
"""设置额外工具定义。
|
||||
|
||||
|
||||
@@ -49,7 +49,10 @@ class MaisakaHeartFlowChatting:
|
||||
|
||||
session_name = chat_manager.get_session_name(session_id) or session_id
|
||||
self.log_prefix = f"[{session_name}]"
|
||||
self._chat_loop_service = MaisakaChatLoopService()
|
||||
self._chat_loop_service = MaisakaChatLoopService(
|
||||
session_id=session_id,
|
||||
is_group_chat=self.chat_stream.is_group_session,
|
||||
)
|
||||
self._chat_history: list[LLMContextMessage] = []
|
||||
self.history_loop: list[CycleDetail] = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user