This commit is contained in:
SengokuCola
2025-09-28 02:04:49 +08:00
30 changed files with 203 additions and 188 deletions

View File

@@ -16,7 +16,6 @@ from src.chat.brain_chat.brain_planner import BrainPlanner
from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail
from src.chat.heart_flow.hfc_utils import send_typing, stop_typing
from src.chat.express.expression_learner import expression_learner_manager
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import EventType, ActionInfo
@@ -96,7 +95,6 @@ class BrainChatting:
self.last_read_time = time.time() - 2
self.more_plan = False
async def start(self):
"""检查是否需要启动主循环,如果未激活则启动。"""
@@ -171,10 +169,8 @@ class BrainChatting:
if len(recent_messages_list) >= 1:
self.last_read_time = time.time()
await self._observe(
recent_messages_list=recent_messages_list
)
await self._observe(recent_messages_list=recent_messages_list)
else:
# Normal模式消息数量不足等待
await asyncio.sleep(0.2)
@@ -233,11 +229,11 @@ class BrainChatting:
async def _observe(
self, # interest_value: float = 0.0,
recent_messages_list: Optional[List["DatabaseMessages"]] = None
recent_messages_list: Optional[List["DatabaseMessages"]] = None,
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
if recent_messages_list is None:
recent_messages_list = []
reply_text = "" # 初始化reply_text变量避免UnboundLocalError
_reply_text = "" # 初始化reply_text变量避免UnboundLocalError
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
await self.expression_learner.trigger_learning_for_chat()
@@ -334,7 +330,7 @@ class BrainChatting:
"taken_time": time.time(),
}
)
reply_text = reply_text_from_reply
_reply_text = reply_text_from_reply
else:
# 没有回复信息构建纯动作的loop_info
loop_info = {
@@ -347,7 +343,7 @@ class BrainChatting:
"taken_time": time.time(),
},
}
reply_text = action_reply_text
_reply_text = action_reply_text
self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers)
@@ -484,7 +480,6 @@ class BrainChatting:
"""执行单个动作的通用函数"""
try:
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
if action_planner_info.action_type == "no_reply":
# 直接处理no_action逻辑不再通过动作系统
reason = action_planner_info.reasoning or "选择不回复"
@@ -517,7 +512,9 @@ class BrainChatting:
if not success or not llm_response or not llm_response.reply_set:
if action_planner_info.action_message:
logger.info(f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败")
logger.info(
f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败"
)
else:
logger.info("回复生成失败")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}

View File

@@ -307,7 +307,9 @@ class BrainPlanner:
if chat_target_info:
# 构建聊天上下文描述
chat_context_description = f"你正在和 {chat_target_info.person_name or chat_target_info.user_nickname or '对方'} 聊天中"
chat_context_description = (
f"你正在和 {chat_target_info.person_name or chat_target_info.user_nickname or '对方'} 聊天中"
)
# 构建动作选项块
action_options_block = await self._build_action_options_block(current_available_actions)

View File

@@ -10,11 +10,14 @@ from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages, build_bare_messages
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat_inclusive,
build_anonymous_messages,
build_bare_messages,
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from json_repair import repair_json
from src.chat.utils.utils import get_embedding
MAX_EXPRESSION_COUNT = 300
@@ -99,7 +102,9 @@ class ExpressionLearner:
self.last_learning_time: float = time.time()
# 学习参数
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(self.chat_id)
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(
self.chat_id
)
self.min_messages_for_learning = 15 / self.learning_intensity # 触发学习所需的最少消息数
self.min_learning_interval = 150 / self.learning_intensity
@@ -237,17 +242,42 @@ class ExpressionLearner:
return []
learnt_expressions = res
learnt_expressions_str = ""
for _chat_id, situation, style, context, context_words, full_context, full_context_embedding in learnt_expressions:
for (
_chat_id,
situation,
style,
_context,
_context_words,
_full_context,
_full_context_embedding,
) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
# 按chat_id分组
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
for chat_id, situation, style, context, context_words, full_context, full_context_embedding in learnt_expressions:
for (
chat_id,
situation,
style,
context,
context_words,
full_context,
full_context_embedding,
) in learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
chat_dict[chat_id].append({"situation": situation, "style": style, "context": context, "context_words": context_words, "full_context": full_context, "full_context_embedding": full_context_embedding})
chat_dict[chat_id].append(
{
"situation": situation,
"style": style,
"context": context,
"context_words": context_words,
"full_context": full_context,
"full_context_embedding": full_context_embedding,
}
)
current_time = time.time()
@@ -300,11 +330,13 @@ class ExpressionLearner:
expr.delete_instance()
return learnt_expressions
async def match_expression_context(self, expression_pairs: List[Tuple[str, str]], random_msg_match_str: str) -> List[Tuple[str, str, str]]:
async def match_expression_context(
self, expression_pairs: List[Tuple[str, str]], random_msg_match_str: str
) -> List[Tuple[str, str, str]]:
# 为expression_pairs逐个条目赋予编号并构建成字符串
numbered_pairs = []
for i, (situation, style) in enumerate(expression_pairs, 1):
numbered_pairs.append(f"{i}. 当\"{situation}\"时,使用\"{style}\"")
numbered_pairs.append(f'{i}. 当"{situation}"时,使用"{style}"')
expression_pairs_str = "\n".join(numbered_pairs)
@@ -319,20 +351,20 @@ class ExpressionLearner:
print(f"match_expression_context_prompt: {prompt}")
print(f"random_msg_match_str: {response}")
# 解析JSON响应
match_responses = []
try:
response = response.strip()
# 检查是否已经是标准JSON数组格式
if response.startswith('[') and response.endswith(']'):
if response.startswith("[") and response.endswith("]"):
match_responses = json.loads(response)
else:
# 尝试直接解析多个JSON对象
try:
# 如果是多个JSON对象用逗号分隔包装成数组
if response.startswith('{') and not response.startswith('['):
response = '[' + response + ']'
if response.startswith("{") and not response.startswith("["):
response = "[" + response + "]"
match_responses = json.loads(response)
else:
# 使用repair_json处理响应
@@ -394,7 +426,9 @@ class ExpressionLearner:
return matched_expressions
async def learn_expression(self, num: int = 10) -> Optional[List[Tuple[str, str, str, List[str], str, List[float]]]]:
async def learn_expression(
self, num: int = 10
) -> Optional[List[Tuple[str, str, str, List[str], str, List[float]]]]:
"""从指定聊天流学习表达方式
Args:
@@ -416,11 +450,10 @@ class ExpressionLearner:
if not random_msg or random_msg == []:
return None
# 转化成str
chat_id: str = random_msg[0].chat_id
_chat_id: str = random_msg[0].chat_id
# random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal")
random_msg_str: str = await build_anonymous_messages(random_msg)
random_msg_match_str: str = await build_bare_messages(random_msg)
prompt: str = await global_prompt_manager.format_prompt(
prompt,
@@ -440,24 +473,31 @@ class ExpressionLearner:
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(expressions, random_msg_match_str)
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
expressions, random_msg_match_str
)
split_matched_expressions: List[Tuple[str, str, str, List[str]]] = self.split_expression_context(
matched_expressions
)
split_matched_expressions: List[Tuple[str, str, str, List[str]]] = self.split_expression_context(matched_expressions)
split_matched_expressions_w_emb = []
full_context_embedding: List[float] = await self.get_full_context_embedding(random_msg_match_str)
for situation, style, context, context_words in split_matched_expressions:
split_matched_expressions_w_emb.append((self.chat_id, situation, style, context, context_words, random_msg_match_str,full_context_embedding))
for situation, style, context, context_words in split_matched_expressions:
split_matched_expressions_w_emb.append(
(self.chat_id, situation, style, context, context_words, random_msg_match_str, full_context_embedding)
)
return split_matched_expressions_w_emb
async def get_full_context_embedding(self, context: str) -> List[float]:
embedding, _ = await self.embedding_model.get_embedding(context)
return embedding
def split_expression_context(self, matched_expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str, List[str]]]:
def split_expression_context(
self, matched_expressions: List[Tuple[str, str, str]]
) -> List[Tuple[str, str, str, List[str]]]:
"""
对matched_expressions中的context部分进行jieba分词

View File

@@ -114,10 +114,10 @@ class ExpressionSelector:
def get_related_chat_ids(self, chat_id: str) -> List[str]:
"""根据expression_groups配置获取与当前chat_id相关的所有chat_id包括自身"""
groups = global_config.expression.expression_groups
# 检查是否存在全局共享组(包含"*"的组)
global_group_exists = any("*" in group for group in groups)
if global_group_exists:
# 如果存在全局共享组则返回所有可用的chat_id
all_chat_ids = set()
@@ -126,7 +126,7 @@ class ExpressionSelector:
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
all_chat_ids.add(chat_id_candidate)
return list(all_chat_ids) if all_chat_ids else [chat_id]
# 否则使用现有的组逻辑
for group in groups:
group_chat_ids = []

View File

@@ -43,4 +43,4 @@ class FrequencyControlManager:
# 创建全局实例
frequency_control_manager = FrequencyControlManager()
frequency_control_manager = FrequencyControlManager()

View File

@@ -208,7 +208,11 @@ class HeartFChatting:
# *控制频率用
if mentioned_message:
await self._observe(recent_messages_list=recent_messages_list, force_reply_message=mentioned_message)
elif random.random() < global_config.chat.talk_value * frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust():
elif (
random.random()
< global_config.chat.talk_value
* frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust()
):
await self._observe(recent_messages_list=recent_messages_list)
else:
# 没有提到继续保持沉默等待5秒防止频繁触发
@@ -278,9 +282,8 @@ class HeartFChatting:
recent_messages_list = []
reply_text = "" # 初始化reply_text变量避免UnboundLocalError
start_time = time.time()
if s4u_config.enable_s4u:
await send_typing()
@@ -356,7 +359,7 @@ class HeartFChatting:
available_actions=available_actions,
)
)
logger.info(
f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
)
@@ -418,7 +421,7 @@ class HeartFChatting:
},
}
reply_text = action_reply_text
self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers)
@@ -429,11 +432,6 @@ class HeartFChatting:
else:
await asyncio.sleep(0.1)
"""S4U内容暂时保留"""
if s4u_config.enable_s4u:
await stop_typing()

View File

@@ -1,10 +1,8 @@
import asyncio
import re
import traceback
from typing import Tuple, TYPE_CHECKING
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow import heartflow
@@ -74,7 +72,7 @@ class HeartFCMessageReceiver:
await self.storage.store_message(message, chat)
heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
_heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
# 3. 日志记录
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
@@ -102,7 +100,7 @@ class HeartFCMessageReceiver:
replace_bot_name=True,
)
# if not processed_plain_text:
# print(message)
# print(message)
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore

View File

@@ -8,7 +8,7 @@ from maim_message import UserInfo, Seg, GroupInfo
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager # 导入情绪管理器
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver

View File

@@ -343,7 +343,6 @@ class ActionPlanner:
interest=interest,
plan_style=global_config.personality.plan_style,
)
return prompt, message_id_list
except Exception as e:
@@ -508,9 +507,7 @@ class ActionPlanner:
action.action_data = action.action_data or {}
action.action_data["loop_start_time"] = loop_start_time
logger.debug(
f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
)
logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
return actions

View File

@@ -28,7 +28,7 @@ from src.chat.utils.chat_message_builder import (
from src.chat.express.expression_selector import expression_selector
# from src.chat.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person, is_person_known
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import ActionInfo, EventType
from src.plugin_system.apis import llm_api
@@ -43,6 +43,7 @@ init_rewrite_prompt()
logger = get_logger("replyer")
class DefaultReplyer:
def __init__(
self,
@@ -216,7 +217,7 @@ class DefaultReplyer:
traceback.print_exc()
return False, llm_response
#移动到 relation插件中构建
# 移动到 relation插件中构建
# async def build_relation_info(self, chat_content: str, sender: str, person_list: List[Person]):
# if not global_config.relationship.enable_relationship:
# return ""
@@ -278,9 +279,7 @@ class DefaultReplyer:
expression_habits_block = ""
expression_habits_title = ""
if style_habits_str.strip():
expression_habits_title = (
"在回复时,你可以参考以下的语言习惯,不要生硬使用:"
)
expression_habits_title = "在回复时,你可以参考以下的语言习惯,不要生硬使用:"
expression_habits_block += f"{style_habits_str}\n"
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
@@ -510,7 +509,6 @@ class DefaultReplyer:
--------------------------------
"""
# 构建背景对话 prompt
all_dialogue_prompt = ""
if message_list_before_now:
@@ -536,7 +534,6 @@ class DefaultReplyer:
time_block: str,
chat_target_1: str,
chat_target_2: str,
identity_block: str,
sender: str,
target: str,
@@ -774,13 +771,9 @@ class DefaultReplyer:
if sender:
if is_group_chat:
reply_target_block = (
f"现在{sender}说的:{target}。引起了你的注意"
)
reply_target_block = f"现在{sender}说的:{target}。引起了你的注意"
else: # private chat
reply_target_block = (
f"现在{sender}说的:{target}。引起了你的注意"
)
reply_target_block = f"现在{sender}说的:{target}。引起了你的注意"
else:
reply_target_block = ""
@@ -1061,6 +1054,3 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
pool.pop(idx)
break
return selected

View File

@@ -1,9 +1,7 @@
from src.chat.utils.prompt_builder import Prompt
# from src.chat.memory_system.memory_activator import MemoryActivator
def init_lpmm_prompt():
Prompt(
"""
@@ -20,5 +18,3 @@ If you need to use the search tool, please directly call the function "lpmm_sear
""",
name="lpmm_get_knowledge_prompt",
)

View File

@@ -1,9 +1,7 @@
from src.chat.utils.prompt_builder import Prompt
# from src.chat.memory_system.memory_activator import MemoryActivator
def init_rewrite_prompt():
Prompt("你正在qq群里聊天下面是群里正在聊的内容:", "chat_target_group1")
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
@@ -31,4 +29,4 @@ def init_rewrite_prompt():
现在,你说:
""",
"default_expressor_prompt",
)
)

View File

@@ -859,7 +859,6 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
# 处理图片ID
content = process_pic_ids(content)
anon_name = get_anon_name(platform, user_id)
# print(f"anon_name:{anon_name}")
@@ -945,11 +944,12 @@ async def build_bare_messages(messages: List[DatabaseMessages]) -> str:
# 获取纯文本内容
content = msg.processed_plain_text or ""
# 处理图片ID
pic_pattern = r"\[picid:[^\]]+\]"
def replace_pic_id(match):
return "[图片]"
content = re.sub(pic_pattern, replace_pic_id, content)
# 处理用户引用格式,移除回复和@标记