feat:添加数据提取脚本
This commit is contained in:
327
scripts/build_io_pairs.py
Normal file
327
scripts/build_io_pairs.py
Normal file
@@ -0,0 +1,327 @@
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
# 确保可从任意工作目录运行:将项目根目录加入 sys.path(scripts 的上一级)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.message_repository import find_messages
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, build_readable_messages_anonymized
|
||||
|
||||
|
||||
SECONDS_5_MINUTES = 5 * 60
|
||||
|
||||
|
||||
def parse_datetime_to_timestamp(value: str) -> float:
|
||||
"""
|
||||
接受多种常见格式并转换为时间戳(秒)
|
||||
支持示例:
|
||||
- 2025-09-29
|
||||
- 2025-09-29 00:00:00
|
||||
- 2025/09/29 00:00
|
||||
- 2025-09-29T00:00:00
|
||||
"""
|
||||
value = value.strip()
|
||||
fmts = [
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y/%m/%d %H:%M:%S",
|
||||
"%Y/%m/%d %H:%M",
|
||||
"%Y-%m-%d",
|
||||
"%Y/%m/%d",
|
||||
"%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%dT%H:%M",
|
||||
]
|
||||
last_err: Optional[Exception] = None
|
||||
for fmt in fmts:
|
||||
try:
|
||||
dt = datetime.strptime(value, fmt)
|
||||
return dt.timestamp()
|
||||
except Exception as e: # noqa: BLE001
|
||||
last_err = e
|
||||
raise ValueError(f"无法解析时间: {value} ({last_err})")
|
||||
|
||||
|
||||
def fetch_messages_between(
|
||||
start_ts: float,
|
||||
end_ts: float,
|
||||
platform: Optional[str] = None,
|
||||
) -> List[DatabaseMessages]:
|
||||
"""使用 find_messages 获取指定区间的消息,可选按 chat_info_platform 过滤。按时间升序返回。"""
|
||||
filter_query: Dict[str, object] = {"time": {"$gt": start_ts, "$lt": end_ts}}
|
||||
if platform:
|
||||
filter_query["chat_info_platform"] = platform
|
||||
# 当 limit==0 时,sort 生效,这里按时间升序
|
||||
return find_messages(message_filter=filter_query, sort=[("time", 1)], limit=0)
|
||||
|
||||
|
||||
def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[DatabaseMessages]]:
|
||||
groups: Dict[str, List[DatabaseMessages]] = {}
|
||||
for msg in messages:
|
||||
groups.setdefault(msg.chat_id, []).append(msg)
|
||||
# 保证每个分组内按时间升序
|
||||
for chat_id, msgs in groups.items():
|
||||
msgs.sort(key=lambda m: m.time or 0)
|
||||
return groups
|
||||
|
||||
|
||||
def _merge_bucket_to_message(bucket: List[DatabaseMessages]) -> DatabaseMessages:
|
||||
"""
|
||||
将相邻、同一 user_id 且 5 分钟内的消息 bucket 合并为一条。
|
||||
processed_plain_text 合并(以换行连接),其余字段取最新一条(时间最大)。
|
||||
"""
|
||||
if not bucket:
|
||||
raise ValueError("bucket 为空,无法合并")
|
||||
|
||||
latest = bucket[-1]
|
||||
merged_texts: List[str] = []
|
||||
for m in bucket:
|
||||
text = m.processed_plain_text or ""
|
||||
if text:
|
||||
merged_texts.append(text)
|
||||
|
||||
merged = DatabaseMessages(
|
||||
# 其他信息采用最新消息
|
||||
message_id=latest.message_id,
|
||||
time=latest.time,
|
||||
chat_id=latest.chat_id,
|
||||
reply_to=latest.reply_to,
|
||||
interest_value=latest.interest_value,
|
||||
key_words=latest.key_words,
|
||||
key_words_lite=latest.key_words_lite,
|
||||
is_mentioned=latest.is_mentioned,
|
||||
is_at=latest.is_at,
|
||||
reply_probability_boost=latest.reply_probability_boost,
|
||||
processed_plain_text="\n".join(merged_texts) if merged_texts else latest.processed_plain_text,
|
||||
display_message=latest.display_message,
|
||||
priority_mode=latest.priority_mode,
|
||||
priority_info=latest.priority_info,
|
||||
additional_config=latest.additional_config,
|
||||
is_emoji=latest.is_emoji,
|
||||
is_picid=latest.is_picid,
|
||||
is_command=latest.is_command,
|
||||
is_notify=latest.is_notify,
|
||||
selected_expressions=latest.selected_expressions,
|
||||
user_id=latest.user_info.user_id,
|
||||
user_nickname=latest.user_info.user_nickname,
|
||||
user_cardname=latest.user_info.user_cardname,
|
||||
user_platform=latest.user_info.platform,
|
||||
chat_info_group_id=(latest.group_info.group_id if latest.group_info else None),
|
||||
chat_info_group_name=(latest.group_info.group_name if latest.group_info else None),
|
||||
chat_info_group_platform=(latest.group_info.group_platform if latest.group_info else None),
|
||||
chat_info_user_id=latest.chat_info.user_info.user_id,
|
||||
chat_info_user_nickname=latest.chat_info.user_info.user_nickname,
|
||||
chat_info_user_cardname=latest.chat_info.user_info.user_cardname,
|
||||
chat_info_user_platform=latest.chat_info.user_info.platform,
|
||||
chat_info_stream_id=latest.chat_info.stream_id,
|
||||
chat_info_platform=latest.chat_info.platform,
|
||||
chat_info_create_time=latest.chat_info.create_time,
|
||||
chat_info_last_active_time=latest.chat_info.last_active_time,
|
||||
)
|
||||
return merged
|
||||
|
||||
|
||||
def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
|
||||
"""按 5 分钟窗口合并相邻同 user_id 的消息。输入需按时间升序。"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
merged: List[DatabaseMessages] = []
|
||||
bucket: List[DatabaseMessages] = []
|
||||
|
||||
def flush_bucket() -> None:
|
||||
nonlocal bucket
|
||||
if bucket:
|
||||
merged.append(_merge_bucket_to_message(bucket))
|
||||
bucket = []
|
||||
|
||||
for msg in messages:
|
||||
if not bucket:
|
||||
bucket = [msg]
|
||||
continue
|
||||
|
||||
last = bucket[-1]
|
||||
same_user = (msg.user_info.user_id == last.user_info.user_id)
|
||||
close_enough = ((msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES)
|
||||
|
||||
if same_user and close_enough:
|
||||
bucket.append(msg)
|
||||
else:
|
||||
flush_bucket()
|
||||
bucket = [msg]
|
||||
|
||||
flush_bucket()
|
||||
return merged
|
||||
|
||||
|
||||
def build_pairs_for_chat(
|
||||
merged_messages: List[DatabaseMessages],
|
||||
min_ctx: int,
|
||||
max_ctx: int,
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
对每条消息作为 output,从其前面取 20-30 条(可配置)的消息作为 input。
|
||||
input 使用 chat_message_builder.build_readable_messages 构建为字符串。
|
||||
output 使用该消息的 processed_plain_text。
|
||||
"""
|
||||
pairs: List[Tuple[str, str, str]] = []
|
||||
n = len(merged_messages)
|
||||
if n == 0:
|
||||
return pairs
|
||||
|
||||
for i in range(n):
|
||||
# 选择上下文窗口大小
|
||||
window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx
|
||||
start = max(0, i - window)
|
||||
context_msgs = merged_messages[start:i]
|
||||
|
||||
# 使用匿名化构建 input,并拿到原始显示名 -> 匿名名的映射
|
||||
input_str, name_mapping = build_readable_messages_anonymized(
|
||||
messages=context_msgs,
|
||||
timestamp_mode="relative",
|
||||
show_actions=False,
|
||||
show_pic=True,
|
||||
)
|
||||
|
||||
# 输出取 processed_plain_text(不再额外替换)
|
||||
output_text = merged_messages[i].processed_plain_text or ""
|
||||
output_id = merged_messages[i].message_id or ""
|
||||
pairs.append((input_str, output_text, output_id))
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
def build_pairs(
|
||||
start_ts: float,
|
||||
end_ts: float,
|
||||
platform: Optional[str],
|
||||
min_ctx: int,
|
||||
max_ctx: int,
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
messages = fetch_messages_between(start_ts, end_ts, platform)
|
||||
groups = group_by_chat(messages)
|
||||
|
||||
all_pairs: List[Tuple[str, str, str]] = []
|
||||
for chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用
|
||||
merged = merge_adjacent_same_user(msgs)
|
||||
pairs = build_pairs_for_chat(merged, min_ctx, max_ctx)
|
||||
all_pairs.extend(pairs)
|
||||
|
||||
return all_pairs
|
||||
|
||||
|
||||
def main(argv: Optional[List[str]] = None) -> int:
|
||||
# 若未提供参数,则进入交互模式
|
||||
if argv is None:
|
||||
argv = sys.argv[1:]
|
||||
|
||||
if len(argv) == 0:
|
||||
return run_interactive()
|
||||
|
||||
parser = argparse.ArgumentParser(description="构建 (input_str, output_str, message_id) 列表")
|
||||
parser.add_argument("start", help="起始时间,如 2025-09-28 00:00:00")
|
||||
parser.add_argument("end", help="结束时间,如 2025-09-29 00:00:00")
|
||||
parser.add_argument("--platform", default=None, help="仅选择 chat_info_platform 为该值的消息")
|
||||
parser.add_argument("--min_ctx", type=int, default=20, help="输入上下文的最少条数,默认20")
|
||||
parser.add_argument("--max_ctx", type=int, default=30, help="输入上下文的最多条数,默认30")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default=None,
|
||||
help="输出保存路径,支持 .jsonl(每行 {input, output}),若不指定则打印到stdout",
|
||||
)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
start_ts = parse_datetime_to_timestamp(args.start)
|
||||
end_ts = parse_datetime_to_timestamp(args.end)
|
||||
if end_ts <= start_ts:
|
||||
raise ValueError("结束时间必须大于起始时间")
|
||||
|
||||
if args.max_ctx < args.min_ctx:
|
||||
raise ValueError("max_ctx 不能小于 min_ctx")
|
||||
|
||||
pairs = build_pairs(start_ts, end_ts, args.platform, args.min_ctx, args.max_ctx)
|
||||
|
||||
if args.output:
|
||||
# 保存为 JSONL,每行一个 {input, output, message_id}
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
for input_str, output_str, message_id in pairs:
|
||||
obj = {"input": input_str, "output": output_str, "message_id": message_id}
|
||||
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
|
||||
print(f"已保存 {len(pairs)} 条到 {args.output}")
|
||||
else:
|
||||
# 打印到 stdout
|
||||
for input_str, output_str, message_id in pairs:
|
||||
print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def _prompt_with_default(prompt_text: str, default: Optional[str]) -> str:
|
||||
suffix = f"[{default}]" if default not in (None, "") else ""
|
||||
value = input(f"{prompt_text}{' ' + suffix if suffix else ''}: ").strip()
|
||||
if value == "" and default is not None:
|
||||
return default
|
||||
return value
|
||||
|
||||
|
||||
def run_interactive() -> int:
|
||||
print("进入交互模式(直接回车采用默认值)。时间格式例如:2025-09-28 00:00:00 或 2025-09-28")
|
||||
start_str = _prompt_with_default("请输入起始时间", None)
|
||||
end_str = _prompt_with_default("请输入结束时间", None)
|
||||
platform = _prompt_with_default("平台(可留空表示不限)", "")
|
||||
try:
|
||||
min_ctx = int(_prompt_with_default("输入上下文最少条数", "20"))
|
||||
max_ctx = int(_prompt_with_default("输入上下文最多条数", "30"))
|
||||
except Exception:
|
||||
print("上下文条数输入有误,使用默认 20/30")
|
||||
min_ctx, max_ctx = 20, 30
|
||||
output_path = _prompt_with_default("输出路径(.jsonl,可留空打印到控制台)", "")
|
||||
|
||||
if not start_str or not end_str:
|
||||
print("必须提供起始与结束时间。")
|
||||
return 2
|
||||
|
||||
try:
|
||||
start_ts = parse_datetime_to_timestamp(start_str)
|
||||
end_ts = parse_datetime_to_timestamp(end_str)
|
||||
except Exception as e: # noqa: BLE001
|
||||
print(f"时间解析失败:{e}")
|
||||
return 2
|
||||
|
||||
if end_ts <= start_ts:
|
||||
print("结束时间必须大于起始时间。")
|
||||
return 2
|
||||
|
||||
if max_ctx < min_ctx:
|
||||
print("最多条数不能小于最少条数。")
|
||||
return 2
|
||||
|
||||
platform_val = platform if platform != "" else None
|
||||
pairs = build_pairs(start_ts, end_ts, platform_val, min_ctx, max_ctx)
|
||||
|
||||
if output_path:
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
for input_str, output_str, message_id in pairs:
|
||||
obj = {"input": input_str, "output": output_str, "message_id": message_id}
|
||||
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
|
||||
print(f"已保存 {len(pairs)} 条到 {output_path}")
|
||||
else:
|
||||
for input_str, output_str, message_id in pairs:
|
||||
print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
|
||||
print(f"总计 {len(pairs)} 条。")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user