给Expression系统的一些准备文件
This commit is contained in:
@@ -1,22 +1,15 @@
|
|||||||
{chat_observe_info}
|
请评估以下表达方式或语言风格以及使用条件或使用情景是否合适:
|
||||||
|
使用条件或使用情景:{situation}
|
||||||
|
表达方式或言语风格:{style}
|
||||||
|
|
||||||
你的名字是{bot_name}{target_message}
|
请从以下方面进行评估:
|
||||||
{reply_reason_block}
|
{criteria_list}
|
||||||
|
|
||||||
以下是可选的表达情境:
|
请以JSON格式输出评估结果:
|
||||||
{all_situations}
|
|
||||||
|
|
||||||
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
|
|
||||||
考虑因素包括:
|
|
||||||
1.聊天的情绪氛围(轻松、严肃、幽默等)
|
|
||||||
2.话题类型(日常、技术、游戏、情感等)
|
|
||||||
3.情境与当前语境的匹配度
|
|
||||||
{target_message_extra_block}
|
|
||||||
|
|
||||||
请以JSON格式输出,只需要输出选中的情境编号:
|
|
||||||
例如:
|
|
||||||
{{
|
{{
|
||||||
"selected_situations": [2, 3, 5, 7, 19]
|
"suitable": true/false,
|
||||||
}}
|
"reason": "评估理由(如果不合适,请说明原因)"
|
||||||
|
|
||||||
请严格按照JSON格式输出,不要包含其他内容:
|
}}
|
||||||
|
如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。
|
||||||
|
请严格按照JSON格式输出,不要包含其他内容。
|
||||||
22
prompts/expression_select.prompt
Normal file
22
prompts/expression_select.prompt
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
{chat_observe_info}
|
||||||
|
|
||||||
|
你的名字是{bot_name}{target_message}
|
||||||
|
{reply_reason_block}
|
||||||
|
|
||||||
|
以下是可选的表达情境:
|
||||||
|
{all_situations}
|
||||||
|
|
||||||
|
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
|
||||||
|
考虑因素包括:
|
||||||
|
1.聊天的情绪氛围(轻松、严肃、幽默等)
|
||||||
|
2.话题类型(日常、技术、游戏、情感等)
|
||||||
|
3.情境与当前语境的匹配度
|
||||||
|
{target_message_extra_block}
|
||||||
|
|
||||||
|
请以JSON格式输出,只需要输出选中的情境编号:
|
||||||
|
例如:
|
||||||
|
{{
|
||||||
|
"selected_situations": [2, 3, 5, 7, 19]
|
||||||
|
}}
|
||||||
|
|
||||||
|
请严格按照JSON格式输出,不要包含其他内容:
|
||||||
@@ -121,7 +121,7 @@ def setup_mocks(monkeypatch):
|
|||||||
db_mod = _stub_module("src.common.database.database")
|
db_mod = _stub_module("src.common.database.database")
|
||||||
db_mod.get_db_session = get_db_session
|
db_mod.get_db_session = get_db_session
|
||||||
db_mod.get_manual_db_session = get_manual_db_session
|
db_mod.get_manual_db_session = get_manual_db_session
|
||||||
|
|
||||||
db_model_mod = _stub_module("src.common.database.database_model")
|
db_model_mod = _stub_module("src.common.database.database_model")
|
||||||
db_model_mod.Messages = None # 可以根据需要添加更多的属性或方法
|
db_model_mod.Messages = None # 可以根据需要添加更多的属性或方法
|
||||||
|
|
||||||
|
|||||||
@@ -378,7 +378,7 @@ class ExpressionSelector:
|
|||||||
reply_reason_block = ""
|
reply_reason_block = ""
|
||||||
|
|
||||||
# 3. 构建prompt(只包含情境,不包含完整的表达方式)
|
# 3. 构建prompt(只包含情境,不包含完整的表达方式)
|
||||||
prompt_template = prompt_manager.get_prompt("expression_evaluation")
|
prompt_template = prompt_manager.get_prompt("expression_select")
|
||||||
prompt_template.add_context("bot_name", global_config.bot.nickname)
|
prompt_template.add_context("bot_name", global_config.bot.nickname)
|
||||||
prompt_template.add_context("chat_observe_info", chat_context)
|
prompt_template.add_context("chat_observe_info", chat_context)
|
||||||
prompt_template.add_context("all_situations", all_situations_str)
|
prompt_template.add_context("all_situations", all_situations_str)
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ class MaiExpression(BaseDatabaseDataModel[Expression]):
|
|||||||
item_id: int,
|
item_id: int,
|
||||||
situation: str,
|
situation: str,
|
||||||
style: str,
|
style: str,
|
||||||
context: str,
|
# context: str,
|
||||||
# up_content: str,
|
# up_content: str,
|
||||||
content: List[str],
|
content: List[str],
|
||||||
count: int,
|
count: int,
|
||||||
@@ -31,8 +31,8 @@ class MaiExpression(BaseDatabaseDataModel[Expression]):
|
|||||||
"""表达方式使用情景"""
|
"""表达方式使用情景"""
|
||||||
self.style = style
|
self.style = style
|
||||||
"""表达方式风格"""
|
"""表达方式风格"""
|
||||||
self.context = context
|
# self.context = context
|
||||||
"""表达方式上下文"""
|
# """表达方式上下文"""
|
||||||
# self.up_content = up_content
|
# self.up_content = up_content
|
||||||
self.content: List[str] = content
|
self.content: List[str] = content
|
||||||
"""内容列表"""
|
"""内容列表"""
|
||||||
@@ -40,7 +40,7 @@ class MaiExpression(BaseDatabaseDataModel[Expression]):
|
|||||||
self.last_active_time: datetime = last_active_time or datetime.now()
|
self.last_active_time: datetime = last_active_time or datetime.now()
|
||||||
self.create_time: datetime = create_time or datetime.now()
|
self.create_time: datetime = create_time or datetime.now()
|
||||||
self.session_id: Optional[str] = session_id
|
self.session_id: Optional[str] = session_id
|
||||||
|
|
||||||
self.checked: bool = checked
|
self.checked: bool = checked
|
||||||
"""是否已经被检查过"""
|
"""是否已经被检查过"""
|
||||||
self.rejected: bool = rejected
|
self.rejected: bool = rejected
|
||||||
@@ -58,7 +58,7 @@ class MaiExpression(BaseDatabaseDataModel[Expression]):
|
|||||||
item_id=db_record.id, # type: ignore
|
item_id=db_record.id, # type: ignore
|
||||||
situation=db_record.situation,
|
situation=db_record.situation,
|
||||||
style=db_record.style,
|
style=db_record.style,
|
||||||
context=db_record.context,
|
# context=db_record.context,
|
||||||
content=content_list,
|
content=content_list,
|
||||||
count=db_record.count,
|
count=db_record.count,
|
||||||
last_active_time=db_record.last_active_time,
|
last_active_time=db_record.last_active_time,
|
||||||
@@ -77,7 +77,7 @@ class MaiExpression(BaseDatabaseDataModel[Expression]):
|
|||||||
id=self.item_id,
|
id=self.item_id,
|
||||||
situation=self.situation,
|
situation=self.situation,
|
||||||
style=self.style,
|
style=self.style,
|
||||||
context=self.context,
|
# context=self.context,
|
||||||
content_list=json.dumps(self.content),
|
content_list=json.dumps(self.content),
|
||||||
count=self.count,
|
count=self.count,
|
||||||
last_active_time=self.last_active_time,
|
last_active_time=self.last_active_time,
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from maim_message import Seg, UserInfo, MessageBase, BaseMessageInfo
|
from maim_message import Seg, UserInfo, MessageBase, BaseMessageInfo
|
||||||
|
from pathlib import Path
|
||||||
|
from sqlmodel import select
|
||||||
from typing import Optional, List, Union, Dict, Any
|
from typing import Optional, List, Union, Dict, Any
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -61,8 +63,21 @@ class ImageComponent(BaseMessageComponentModel, ByteComponent):
|
|||||||
return "image"
|
return "image"
|
||||||
|
|
||||||
async def load_image_binary(self):
|
async def load_image_binary(self):
|
||||||
if not self.binary_data:
|
if self.binary_data:
|
||||||
raise NotImplementedError
|
return
|
||||||
|
from src.common.database.database import get_db_session
|
||||||
|
from src.common.database.database_model import Images, ImageType
|
||||||
|
|
||||||
|
try:
|
||||||
|
with get_db_session() as db:
|
||||||
|
statement = select(Images).filter_by(image_hash=self.binary_hash, image_type=ImageType.IMAGE).limit(1)
|
||||||
|
if image_record := db.exec(statement).first():
|
||||||
|
image_path = Path(image_record.full_path)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"无法通过 image_hash 加载图片二进制数据: {self.binary_hash}")
|
||||||
|
self.binary_data = await asyncio.to_thread(image_path.read_bytes)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"通过 image_hash 加载图片二进制数据时发生错误: {e}") from e
|
||||||
|
|
||||||
async def to_seg(self) -> Seg:
|
async def to_seg(self) -> Seg:
|
||||||
if not self.binary_data:
|
if not self.binary_data:
|
||||||
@@ -85,18 +100,21 @@ class EmojiComponent(BaseMessageComponentModel, ByteComponent):
|
|||||||
ValueError: 如果 binary_data 为空且缺少 emoji_hash
|
ValueError: 如果 binary_data 为空且缺少 emoji_hash
|
||||||
ValueError: 如果无法通过 emoji_hash 加载表情二进制数据
|
ValueError: 如果无法通过 emoji_hash 加载表情二进制数据
|
||||||
"""
|
"""
|
||||||
if not self.binary_data:
|
if self.binary_data:
|
||||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
return
|
||||||
|
from src.common.database.database import get_db_session
|
||||||
|
from src.common.database.database_model import Images, ImageType
|
||||||
|
|
||||||
if not (
|
try:
|
||||||
emoji := emoji_manager.get_emoji_by_hash(self.binary_hash)
|
with get_db_session() as db:
|
||||||
or emoji_manager.get_emoji_by_hash_from_db(self.binary_hash)
|
statement = select(Images).filter_by(image_hash=self.binary_hash, image_type=ImageType.EMOJI).limit(1)
|
||||||
):
|
if image_record := db.exec(statement).first():
|
||||||
raise ValueError(f"无法通过 emoji_hash 加载表情二进制数据: {self.binary_hash}")
|
image_path = Path(image_record.full_path)
|
||||||
try:
|
else:
|
||||||
self.binary_data = await asyncio.to_thread(emoji.full_path.read_bytes)
|
raise ValueError(f"无法通过 emoji_hash 加载表情二进制数据: {self.binary_hash}")
|
||||||
except Exception as e:
|
self.binary_data = await asyncio.to_thread(image_path.read_bytes)
|
||||||
raise ValueError(f"通过 emoji_hash 加载表情二进制数据时发生错误: {e}") from e
|
except Exception as e:
|
||||||
|
raise ValueError(f"通过 emoji_hash 加载表情二进制数据时发生错误: {e}") from e
|
||||||
|
|
||||||
async def to_seg(self) -> Seg:
|
async def to_seg(self) -> Seg:
|
||||||
if not self.binary_data:
|
if not self.binary_data:
|
||||||
|
|||||||
@@ -170,7 +170,7 @@ class Expression(SQLModel, table=True):
|
|||||||
situation: str = Field(index=True, max_length=255, primary_key=True) # 情景
|
situation: str = Field(index=True, max_length=255, primary_key=True) # 情景
|
||||||
style: str = Field(index=True, max_length=255, primary_key=True) # 风格
|
style: str = Field(index=True, max_length=255, primary_key=True) # 风格
|
||||||
|
|
||||||
context: str # 上下文
|
# context: str # 上下文
|
||||||
# up_content: str
|
# up_content: str
|
||||||
|
|
||||||
content_list: str # 内容列表,JSON格式存储
|
content_list: str # 内容列表,JSON格式存储
|
||||||
|
|||||||
Reference in New Issue
Block a user