将PFC加回来,修复一大堆PFC的神秘报错
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
import json
|
||||
import random
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from src.common.logger import get_module_logger
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.config.config import global_config, model_config
|
||||
from .chat_observer import ChatObserver
|
||||
from maim_message import UserInfo
|
||||
|
||||
logger = get_module_logger("reply_checker")
|
||||
logger = get_logger("reply_checker")
|
||||
|
||||
|
||||
class ReplyChecker:
|
||||
@@ -14,13 +15,30 @@ class ReplyChecker:
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model=global_config.llm_PFC_reply_checker, temperature=0.50, max_tokens=1000, request_type="reply_check"
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="reply_check"
|
||||
)
|
||||
self.name = global_config.BOT_NICKNAME
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
self.name = global_config.bot.nickname
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
self.max_retries = 3 # 最大重试次数
|
||||
|
||||
def _get_personality_prompt(self) -> str:
|
||||
"""获取个性提示信息"""
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (
|
||||
global_config.personality.states
|
||||
and global_config.personality.state_probability > 0
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
return f"你的名字是{bot_name},你{prompt_personality};"
|
||||
|
||||
async def check(
|
||||
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_text: str, retry_count: int = 0
|
||||
) -> Tuple[bool, str, bool]:
|
||||
@@ -43,7 +61,7 @@ class ReplyChecker:
|
||||
bot_messages = []
|
||||
for msg in reversed(chat_history):
|
||||
user_info = UserInfo.from_dict(msg.get("user_info", {}))
|
||||
if str(user_info.user_id) == str(global_config.BOT_QQ): # 确保比较的是字符串
|
||||
if str(user_info.user_id) == str(global_config.bot.qq_account):
|
||||
bot_messages.append(msg.get("processed_plain_text", ""))
|
||||
if len(bot_messages) >= 2: # 只和最近的两条比较
|
||||
break
|
||||
@@ -129,7 +147,7 @@ class ReplyChecker:
|
||||
content = content.strip()
|
||||
try:
|
||||
# 尝试直接解析
|
||||
result = json.loads(content)
|
||||
result: dict = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
# 如果直接解析失败,尝试查找和提取JSON部分
|
||||
import re
|
||||
@@ -138,7 +156,7 @@ class ReplyChecker:
|
||||
json_match = re.search(json_pattern, content)
|
||||
if json_match:
|
||||
try:
|
||||
result = json.loads(json_match.group())
|
||||
result: dict = json.loads(json_match.group())
|
||||
except json.JSONDecodeError:
|
||||
# 如果JSON解析失败,尝试从文本中提取结果
|
||||
is_suitable = "不合适" not in content.lower() and "违规" not in content.lower()
|
||||
|
||||
Reference in New Issue
Block a user