Ruff format

This commit is contained in:
墨梓柒
2025-12-13 17:14:09 +08:00
parent ef377bb0cd
commit e680a4d1f5
60 changed files with 1546 additions and 1532 deletions

View File

@@ -176,19 +176,19 @@ class BrainChatting:
# 如果有新消息,更新 last_read_time
if len(recent_messages_list) >= 1:
self.last_read_time = time.time()
# 总是执行一次思考迭代(不管有没有新消息)
# wait 动作会在其内部等待,不需要在这里处理
should_continue = await self._observe(recent_messages_list=recent_messages_list)
if not should_continue:
# 选择了 complete_talk返回 False 表示需要等待新消息
return False
# 继续下一次迭代(除非选择了 complete_talk
# 短暂等待后再继续,避免过于频繁的循环
await asyncio.sleep(0.1)
return True
async def _send_and_store_reply(
@@ -328,9 +328,7 @@ class BrainChatting:
)
# 检查是否有 complete_talk 动作(会停止后续迭代)
has_complete_talk = any(
action.action_type == "complete_talk" for action in action_to_use_info
)
has_complete_talk = any(action.action_type == "complete_talk" for action in action_to_use_info)
# 并行执行所有动作
action_tasks = [
@@ -430,12 +428,12 @@ class BrainChatting:
await asyncio.sleep(3)
self._loop_task = asyncio.create_task(self._main_chat_loop())
logger.error(f"{self.log_prefix} 结束了当前聊天循环")
async def _wait_for_new_message(self):
"""等待新消息到达"""
last_check_time = self.last_read_time
check_interval = 1.0 # 每秒检查一次
while self.running:
# 检查是否有新消息
recent_messages_list = message_api.get_messages_by_time_in_chat(
@@ -448,13 +446,13 @@ class BrainChatting:
filter_command=False,
filter_intercept_message_level=1,
)
# 如果有新消息,更新 last_read_time 并返回
if len(recent_messages_list) >= 1:
self.last_read_time = time.time()
logger.info(f"{self.log_prefix} 检测到新消息,恢复循环")
return
# 等待一段时间后再次检查
await asyncio.sleep(check_interval)
@@ -660,9 +658,9 @@ class BrainChatting:
except (ValueError, TypeError):
logger.warning(f"{self.log_prefix} wait_seconds 参数格式错误,使用默认值 5 秒")
wait_seconds = 5
logger.info(f"{self.log_prefix} 执行 wait 动作,等待 {wait_seconds}")
# 记录动作信息
await database_api.store_action_info(
chat_stream=self.chat_stream,
@@ -673,12 +671,12 @@ class BrainChatting:
action_data={"reason": reason, "wait_seconds": wait_seconds},
action_name="wait",
)
# 等待指定时间
await asyncio.sleep(wait_seconds)
logger.info(f"{self.log_prefix} wait 动作完成,继续下一次思考")
# 这些动作本身不产生文本回复
self._last_successful_reply = False
return {
@@ -693,9 +691,9 @@ class BrainChatting:
logger.debug(f"{self.log_prefix} 检测到 listening 动作,已合并到 wait自动转换")
# 使用默认等待时间
wait_seconds = 3
logger.info(f"{self.log_prefix} 执行 listening转换为 wait动作等待 {wait_seconds}")
# 记录动作信息
await database_api.store_action_info(
chat_stream=self.chat_stream,
@@ -706,12 +704,12 @@ class BrainChatting:
action_data={"reason": reason, "wait_seconds": wait_seconds},
action_name="listening",
)
# 等待指定时间
await asyncio.sleep(wait_seconds)
logger.info(f"{self.log_prefix} listening 动作完成,继续下一次思考")
# 这些动作本身不产生文本回复
self._last_successful_reply = False
return {

View File

@@ -147,7 +147,7 @@ class BrainPlanner:
) # 用于动作规划
self.last_obs_time_mark = 0.0
# 计划日志记录
self.plan_log: List[Tuple[str, float, List[ActionPlannerInfo]]] = []
@@ -203,9 +203,11 @@ class BrainPlanner:
# 内部保留动作(不依赖插件系统)
# 注意listening 已合并到 wait 中,如果遇到 listening 则转换为 wait
internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"]
logger.debug(f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}")
logger.debug(
f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}"
)
# 将 listening 转换为 wait向后兼容
if action == "listening":
logger.debug(f"{self.log_prefix}检测到 listening 动作,已合并到 wait自动转换")
@@ -521,7 +523,7 @@ class BrainPlanner:
if json_objects:
logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
for i, json_obj in enumerate(json_objects):
logger.info(f"{self.log_prefix}解析第{i+1}个JSON对象: {json_obj}")
logger.info(f"{self.log_prefix}解析第{i + 1}个JSON对象: {json_obj}")
filtered_actions_list = list(filtered_actions.items())
for json_obj in json_objects:
parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list)
@@ -553,7 +555,9 @@ class BrainPlanner:
return extracted_reasoning, actions
def _create_complete_talk(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
def _create_complete_talk(
self, reasoning: str, available_actions: Dict[str, ActionInfo]
) -> List[ActionPlannerInfo]:
"""创建complete_talk"""
return [
ActionPlannerInfo(
@@ -564,7 +568,7 @@ class BrainPlanner:
available_actions=available_actions,
)
]
def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]):
"""添加计划日志"""
self.plan_log.append((reasoning, time.time(), actions))

View File

@@ -271,7 +271,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
emoji.description = emoji_data.description
# Deserialize emotion string from DB to list
emoji.emotion = emoji_data.emotion.replace("",",").split(",") if emoji_data.emotion else []
emoji.emotion = emoji_data.emotion.replace("", ",").split(",") if emoji_data.emotion else []
emoji.usage_count = emoji_data.usage_count
db_last_used_time = emoji_data.last_used_time
@@ -732,7 +732,7 @@ class EmojiManager:
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
if emoji_record and emoji_record.emotion:
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
return emoji_record.emotion.replace("",",").split(",")
return emoji_record.emotion.replace("", ",").split(",")
except Exception as e:
logger.error(f"从数据库查询表情包情感标签时出错: {e}")
@@ -993,7 +993,7 @@ class EmojiManager:
)
# 处理情感列表
emotions = [e.strip() for e in emotions_text.replace("",",").split(",") if e.strip()]
emotions = [e.strip() for e in emotions_text.replace("", ",").split(",") if e.strip()]
# 根据情感标签数量随机选择 - 超过5个选3个超过2个选2个
if len(emotions) > 5:

View File

@@ -619,13 +619,13 @@ class HeartFChatting:
think_level = 0
# 使用 action_reasoningplanner 的整体思考理由)作为 reply_reason
planner_reasoning = action_planner_info.action_reasoning or reason
record_replyer_action_temp(
chat_id=self.stream_id,
reason=reason,
think_level=think_level,
)
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,

View File

@@ -123,7 +123,11 @@ class ChatBot:
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
# 根据命令的拦截设置决定是否继续处理消息
return True, response, not bool(intercept_message_level) # 找到命令根据intercept_message决定是否继续
return (
True,
response,
not bool(intercept_message_level),
) # 找到命令根据intercept_message决定是否继续
except Exception as e:
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")

View File

@@ -263,7 +263,7 @@ class MessageRecv(Message):
desc = segment.data.get("desc", "") # 内容描述
source_url = segment.data.get("source_url", "") # 原始链接
url = segment.data.get("url", "") # 小程序链接
text = f"[小程序分享"
text = "[小程序分享"
if title:
text += f" - {title}"
text += "]"

View File

@@ -42,22 +42,21 @@ def is_webui_virtual_group(group_id: str) -> bool:
def parse_message_segments(segment) -> list:
"""解析消息段,转换为 WebUI 可用的格式
参考 NapCat 适配器的消息解析逻辑
Args:
segment: Seg 消息段对象
Returns:
list: 消息段列表,每个元素为 {"type": "...", "data": ...}
"""
from maim_message import Seg
result = []
if segment is None:
return result
if segment.type == "seglist":
# 处理消息段列表
if segment.data:
@@ -112,15 +111,19 @@ def parse_message_segments(segment) -> list:
forward_items = []
if segment.data:
for item in segment.data:
forward_items.append({
"content": parse_message_segments(item.get("message_segment", {})) if isinstance(item, dict) else []
})
forward_items.append(
{
"content": parse_message_segments(item.get("message_segment", {}))
if isinstance(item, dict)
else []
}
)
result.append({"type": "forward", "data": forward_items})
else:
# 未知类型,尝试作为文本处理
if segment.data:
result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
return result
@@ -134,7 +137,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
chat_manager, webui_platform = get_webui_chat_broadcaster()
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
if is_webui_message and chat_manager is not None:
# WebUI 聊天室消息(包括虚拟身份模式),通过 WebSocket 广播
import time
@@ -142,7 +145,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
# 解析消息段,获取富文本内容
message_segments = parse_message_segments(message.message_segment)
# 判断消息类型
# 如果只有一个文本段,使用简单的 text 类型
# 否则使用 rich 类型,包含完整的消息段

View File

@@ -77,8 +77,7 @@ target_message_id为必填表示触发消息的id
```""",
"planner_prompt",
)
Prompt(
"""
{action_name}

View File

@@ -250,7 +250,12 @@ class DefaultReplyer:
# 使用从处理器传来的选中表达方式
# 使用模型预测选择表达方式
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason, think_level=think_level
self.chat_stream.stream_id,
chat_history,
max_num=8,
target_message=target,
reply_reason=reply_reason,
think_level=think_level,
)
if selected_expressions:
@@ -273,7 +278,6 @@ class DefaultReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块
@@ -788,7 +792,8 @@ class DefaultReplyer:
# 并行执行八个构建任务(包括黑话解释)
task_results = await asyncio.gather(
self._time_and_run_task(
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level), "expression_habits"
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level),
"expression_habits",
),
self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
@@ -980,7 +985,6 @@ class DefaultReplyer:
else:
reply_target_block = ""
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")

View File

@@ -287,7 +287,6 @@ class PrivateReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块
@@ -907,16 +906,11 @@ class PrivateReplyer:
else:
reply_target_block = ""
chat_target_name = "对方"
if self.chat_target_info:
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
chat_target_1 = await global_prompt_manager.format_prompt(
"chat_target_private1", sender_name=chat_target_name
)
chat_target_2 = await global_prompt_manager.format_prompt(
"chat_target_private2", sender_name=chat_target_name
)
chat_target_1 = await global_prompt_manager.format_prompt("chat_target_private1", sender_name=chat_target_name)
chat_target_2 = await global_prompt_manager.format_prompt("chat_target_private2", sender_name=chat_target_name)
template_name = "default_expressor_prompt"

View File

@@ -1,8 +1,9 @@
from src.chat.utils.prompt_builder import Prompt
def init_replyer_private_prompt():
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation}
你正在和{sender_name}聊天,这是你们之前聊的内容:
@@ -17,9 +18,9 @@ def init_replyer_private_prompt():
{reply_style}
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。""",
"private_replyer_prompt",
)
"private_replyer_prompt",
)
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation}
@@ -37,4 +38,4 @@ def init_replyer_private_prompt():
{moderation_prompt}不要输出多余内容(包括冒号和引号括号表情包at或 @等 )。
""",
"private_replyer_self_prompt",
)
)

View File

@@ -23,7 +23,7 @@ def init_replyer_prompt():
现在,你说:""",
"replyer_prompt_0",
)
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation}
@@ -44,4 +44,3 @@ def init_replyer_prompt():
现在,你说:""",
"replyer_prompt",
)

View File

@@ -311,7 +311,10 @@ def get_raw_msg_before_timestamp_with_chat(
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
sort_order = [("time", 1)]
return find_messages(
message_filter=filter_query, sort=sort_order, limit=limit, filter_intercept_message_level=filter_intercept_message_level
message_filter=filter_query,
sort=sort_order,
limit=limit,
filter_intercept_message_level=filter_intercept_message_level,
)

View File

@@ -746,7 +746,7 @@ class StatisticOutputTask(AsyncTask):
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
output = [
"按模型分类统计:",
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
@@ -759,11 +759,11 @@ class StatisticOutputTask(AsyncTask):
cost = stats[COST_BY_MODEL][model_name]
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
# 计算每次回复平均值
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
# 格式化大数字
formatted_count = _format_large_number(count)
formatted_in_tokens = _format_large_number(in_tokens)
@@ -771,7 +771,7 @@ class StatisticOutputTask(AsyncTask):
formatted_tokens = _format_large_number(tokens)
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
output.append(
data_fmt.format(
name,
@@ -800,7 +800,7 @@ class StatisticOutputTask(AsyncTask):
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
output = [
"按模块分类统计:",
" 模块名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
@@ -813,11 +813,11 @@ class StatisticOutputTask(AsyncTask):
cost = stats[COST_BY_MODULE][module_name]
avg_time_cost = stats[AVG_TIME_COST_BY_MODULE][module_name]
std_time_cost = stats[STD_TIME_COST_BY_MODULE][module_name]
# 计算每次回复平均值
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
# 格式化大数字
formatted_count = _format_large_number(count)
formatted_in_tokens = _format_large_number(in_tokens)
@@ -825,7 +825,7 @@ class StatisticOutputTask(AsyncTask):
formatted_tokens = _format_large_number(tokens)
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
output.append(
data_fmt.format(
name,

View File

@@ -646,7 +646,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> None:
"""
临时记录replyer动作被选择的信息仅群聊
Args:
chat_id: 聊天ID
reason: 选择理由
@@ -656,7 +656,7 @@ def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> N
# 确保data/temp目录存在
temp_dir = "data/temp"
os.makedirs(temp_dir, exist_ok=True)
# 创建记录数据
record_data = {
"chat_id": chat_id,
@@ -664,16 +664,16 @@ def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> N
"think_level": think_level,
"timestamp": datetime.now().isoformat(),
}
# 生成文件名(使用时间戳避免冲突)
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"replyer_action_{timestamp_str}.json"
filepath = os.path.join(temp_dir, filename)
# 写入文件
with open(filepath, "w", encoding="utf-8") as f:
json.dump(record_data, f, ensure_ascii=False, indent=2)
logger.debug(f"已记录replyer动作选择: chat_id={chat_id}, think_level={think_level}")
except Exception as e:
logger.warning(f"记录replyer动作选择失败: {e}")

View File

@@ -130,12 +130,10 @@ class ImageManager:
try:
# 清理Images表中type为emoji的记录
deleted_images = Images.delete().where(Images.type == "emoji").execute()
# 清理ImageDescriptions表中type为emoji的记录
deleted_descriptions = (
ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
)
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
total_deleted = deleted_images + deleted_descriptions
if total_deleted > 0:
logger.info(
@@ -166,7 +164,7 @@ class ImageManager:
async def _save_emoji_file_if_needed(self, image_base64: str, image_hash: str, image_format: str) -> None:
"""如果启用了steal_emoji且表情包未注册保存文件到data/emoji目录
Args:
image_base64: 图片的base64编码
image_hash: 图片的MD5哈希值
@@ -174,7 +172,7 @@ class ImageManager:
"""
if not global_config.emoji.steal_emoji:
return
try:
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
from src.chat.emoji_system.emoji_manager import get_emoji_manager
@@ -236,12 +234,16 @@ class ImageManager:
# 优先使用情感标签,如果没有则使用详细描述
result_text = ""
if cache_record.emotion_tags:
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...")
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
)
result_text = f"[表情包:{cache_record.emotion_tags}]"
elif cache_record.description:
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...")
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
)
result_text = f"[表情包:{cache_record.description}]"
# 即使缓存命中如果启用了steal_emoji也检查是否需要保存文件
if result_text:
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)