Ruff fix
This commit is contained in:
@@ -6,15 +6,16 @@ import sys
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.message_repository import find_messages
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
# 确保可从任意工作目录运行:将项目根目录加入 sys.path(scripts 的上一级)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.message_repository import find_messages
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
|
||||
|
||||
SECONDS_5_MINUTES = 5 * 60
|
||||
@@ -28,16 +29,16 @@ def clean_output_text(text: str) -> str:
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
|
||||
# 移除表情包内容:[表情包:...]
|
||||
text = re.sub(r'\[表情包:[^\]]*\]', '', text)
|
||||
|
||||
text = re.sub(r"\[表情包:[^\]]*\]", "", text)
|
||||
|
||||
# 移除回复内容:[回复...],说:... 的完整模式
|
||||
text = re.sub(r'\[回复[^\]]*\],说:[^@]*@[^:]*:', '', text)
|
||||
|
||||
text = re.sub(r"\[回复[^\]]*\],说:[^@]*@[^:]*:", "", text)
|
||||
|
||||
# 清理多余的空格和换行
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
@@ -89,7 +90,7 @@ def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[Databa
|
||||
for msg in messages:
|
||||
groups.setdefault(msg.chat_id, []).append(msg)
|
||||
# 保证每个分组内按时间升序
|
||||
for chat_id, msgs in groups.items():
|
||||
for _chat_id, msgs in groups.items():
|
||||
msgs.sort(key=lambda m: m.time or 0)
|
||||
return groups
|
||||
|
||||
@@ -170,8 +171,8 @@ def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseM
|
||||
continue
|
||||
|
||||
last = bucket[-1]
|
||||
same_user = (msg.user_info.user_id == last.user_info.user_id)
|
||||
close_enough = ((msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES)
|
||||
same_user = msg.user_info.user_id == last.user_info.user_id
|
||||
close_enough = (msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES
|
||||
|
||||
if same_user and close_enough:
|
||||
bucket.append(msg)
|
||||
@@ -199,38 +200,36 @@ def build_pairs_for_chat(
|
||||
pairs: List[Tuple[str, str, str]] = []
|
||||
n_merged = len(merged_messages)
|
||||
n_original = len(original_messages)
|
||||
|
||||
|
||||
if n_merged == 0 or n_original == 0:
|
||||
return pairs
|
||||
|
||||
# 为每个合并后的消息找到对应的原始消息位置
|
||||
merged_to_original_map = {}
|
||||
original_idx = 0
|
||||
|
||||
|
||||
for merged_idx, merged_msg in enumerate(merged_messages):
|
||||
# 找到这个合并消息对应的第一个原始消息
|
||||
while (original_idx < n_original and
|
||||
original_messages[original_idx].time < merged_msg.time):
|
||||
while original_idx < n_original and original_messages[original_idx].time < merged_msg.time:
|
||||
original_idx += 1
|
||||
|
||||
|
||||
# 如果找到了时间匹配的原始消息,建立映射
|
||||
if (original_idx < n_original and
|
||||
original_messages[original_idx].time == merged_msg.time):
|
||||
if original_idx < n_original and original_messages[original_idx].time == merged_msg.time:
|
||||
merged_to_original_map[merged_idx] = original_idx
|
||||
|
||||
for merged_idx in range(n_merged):
|
||||
merged_msg = merged_messages[merged_idx]
|
||||
|
||||
|
||||
# 如果指定了 target_user_id,只处理该用户的消息作为 output
|
||||
if target_user_id and merged_msg.user_info.user_id != target_user_id:
|
||||
continue
|
||||
|
||||
|
||||
# 找到对应的原始消息位置
|
||||
if merged_idx not in merged_to_original_map:
|
||||
continue
|
||||
|
||||
|
||||
original_idx = merged_to_original_map[merged_idx]
|
||||
|
||||
|
||||
# 选择上下文窗口大小
|
||||
window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx
|
||||
start = max(0, original_idx - window)
|
||||
@@ -266,7 +265,7 @@ def build_pairs(
|
||||
groups = group_by_chat(messages)
|
||||
|
||||
all_pairs: List[Tuple[str, str, str]] = []
|
||||
for chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用
|
||||
for _chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用
|
||||
# 对消息进行合并,用于output
|
||||
merged = merge_adjacent_same_user(msgs)
|
||||
# 传递原始消息和合并后消息,input使用原始消息,output使用合并后消息
|
||||
@@ -385,5 +384,3 @@ def run_interactive() -> int:
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user