feat:将记忆配置项添加到配置文件

This commit is contained in:
SengokuCola
2025-10-08 18:45:06 +08:00
parent e0a5cd5922
commit 16ae212adc
6 changed files with 253 additions and 282 deletions

View File

@@ -1,6 +1,7 @@
import argparse
import json
import random
import re
import sys
import os
from datetime import datetime
@@ -13,12 +14,33 @@ if PROJECT_ROOT not in sys.path:
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages
from src.chat.utils.chat_message_builder import build_readable_messages, build_readable_messages_anonymized
from src.chat.utils.chat_message_builder import build_readable_messages
SECONDS_5_MINUTES = 5 * 60
def clean_output_text(text: str) -> str:
    """Strip sticker tags and reply prefixes from an output message.

    Three substitutions are applied in order:

    1. every ``[表情包:...]`` tag is deleted;
    2. every reply prefix matching ``[回复...],说:...@...:`` is deleted;
    3. each run of whitespace (including newlines) becomes one space.

    The result is finally trimmed at both ends. Falsy input (an empty
    string or ``None``) is handed back unchanged.
    """
    if not text:
        return text

    # (pattern, replacement) pairs; order matters because whitespace is
    # normalised only after the bracketed fragments have been removed.
    substitutions = (
        (r'\[表情包:[^\]]*\]', ''),
        (r'\[回复[^\]]*\],说:[^@]*@[^:]*:', ''),
        (r'\s+', ' '),
    )

    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
def parse_datetime_to_timestamp(value: str) -> float:
"""
接受多种常见格式并转换为时间戳(秒)
@@ -162,37 +184,70 @@ def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseM
def build_pairs_for_chat(
original_messages: List[DatabaseMessages],
merged_messages: List[DatabaseMessages],
min_ctx: int,
max_ctx: int,
target_user_id: Optional[str] = None,
) -> List[Tuple[str, str, str]]:
"""
对每条消息作为 output从其前面取 20-30 条(可配置)的消息作为 input。
input 使用 chat_message_builder.build_readable_messages 构建为字符串
output 使用消息的 processed_plain_text。
对每条合并后的消息作为 output从其前面取 20-30 条(可配置)的原始消息作为 input。
input 使用原始未合并的消息构建上下文
output 使用合并后消息的 processed_plain_text。
如果指定了 target_user_id则只处理该用户的消息作为 output。
"""
pairs: List[Tuple[str, str, str]] = []
n = len(merged_messages)
if n == 0:
n_merged = len(merged_messages)
n_original = len(original_messages)
if n_merged == 0 or n_original == 0:
return pairs
for i in range(n):
# 为每个合并后的消息找到对应的原始消息位置
merged_to_original_map = {}
original_idx = 0
for merged_idx, merged_msg in enumerate(merged_messages):
# 找到这个合并消息对应的第一个原始消息
while (original_idx < n_original and
original_messages[original_idx].time < merged_msg.time):
original_idx += 1
# 如果找到了时间匹配的原始消息,建立映射
if (original_idx < n_original and
original_messages[original_idx].time == merged_msg.time):
merged_to_original_map[merged_idx] = original_idx
for merged_idx in range(n_merged):
merged_msg = merged_messages[merged_idx]
# 如果指定了 target_user_id只处理该用户的消息作为 output
if target_user_id and merged_msg.user_info.user_id != target_user_id:
continue
# 找到对应的原始消息位置
if merged_idx not in merged_to_original_map:
continue
original_idx = merged_to_original_map[merged_idx]
# 选择上下文窗口大小
window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx
start = max(0, i - window)
context_msgs = merged_messages[start:i]
start = max(0, original_idx - window)
context_msgs = original_messages[start:original_idx]
# 使用匿名化构建 input,并拿到原始显示名 -> 匿名名的映射
input_str, name_mapping = build_readable_messages_anonymized(
# 使用原始未合并消息构建 input
input_str = build_readable_messages(
messages=context_msgs,
timestamp_mode="relative",
timestamp_mode="normal_no_YMD",
show_actions=False,
show_pic=True,
)
# 输出取 processed_plain_text(不再额外替换)
output_text = merged_messages[i].processed_plain_text or ""
output_id = merged_messages[i].message_id or ""
# 输出取合并后消息的 processed_plain_text 并清理表情包和回复内容
output_text = merged_msg.processed_plain_text or ""
output_text = clean_output_text(output_text)
output_id = merged_msg.message_id or ""
pairs.append((input_str, output_text, output_id))
return pairs
@@ -202,16 +257,20 @@ def build_pairs(
start_ts: float,
end_ts: float,
platform: Optional[str],
user_id: Optional[str],
min_ctx: int,
max_ctx: int,
) -> List[Tuple[str, str, str]]:
# 获取所有消息不按user_id过滤这样input上下文可以包含所有用户的消息
messages = fetch_messages_between(start_ts, end_ts, platform)
groups = group_by_chat(messages)
all_pairs: List[Tuple[str, str, str]] = []
for chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用
# 对消息进行合并用于output
merged = merge_adjacent_same_user(msgs)
pairs = build_pairs_for_chat(merged, min_ctx, max_ctx)
# 传递原始消息和合并后消息input使用原始消息output使用合并后消息
pairs = build_pairs_for_chat(msgs, merged, min_ctx, max_ctx, user_id)
all_pairs.extend(pairs)
return all_pairs
@@ -225,10 +284,11 @@ def main(argv: Optional[List[str]] = None) -> int:
if len(argv) == 0:
return run_interactive()
parser = argparse.ArgumentParser(description="构建 (input_str, output_str, message_id) 列表")
parser = argparse.ArgumentParser(description="构建 (input_str, output_str, message_id) 列表支持按用户ID筛选消息")
parser.add_argument("start", help="起始时间,如 2025-09-28 00:00:00")
parser.add_argument("end", help="结束时间,如 2025-09-29 00:00:00")
parser.add_argument("--platform", default=None, help="仅选择 chat_info_platform 为该值的消息")
parser.add_argument("--user_id", default=None, help="仅选择指定 user_id 的消息")
parser.add_argument("--min_ctx", type=int, default=20, help="输入上下文的最少条数默认20")
parser.add_argument("--max_ctx", type=int, default=30, help="输入上下文的最多条数默认30")
parser.add_argument(
@@ -247,7 +307,7 @@ def main(argv: Optional[List[str]] = None) -> int:
if args.max_ctx < args.min_ctx:
raise ValueError("max_ctx 不能小于 min_ctx")
pairs = build_pairs(start_ts, end_ts, args.platform, args.min_ctx, args.max_ctx)
pairs = build_pairs(start_ts, end_ts, args.platform, args.user_id, args.min_ctx, args.max_ctx)
if args.output:
# 保存为 JSONL每行一个 {input, output, message_id}
@@ -277,6 +337,7 @@ def run_interactive() -> int:
start_str = _prompt_with_default("请输入起始时间", None)
end_str = _prompt_with_default("请输入结束时间", None)
platform = _prompt_with_default("平台(可留空表示不限)", "")
user_id = _prompt_with_default("用户ID可留空表示不限", "")
try:
min_ctx = int(_prompt_with_default("输入上下文最少条数", "20"))
max_ctx = int(_prompt_with_default("输入上下文最多条数", "30"))
@@ -305,7 +366,8 @@ def run_interactive() -> int:
return 2
platform_val = platform if platform != "" else None
pairs = build_pairs(start_ts, end_ts, platform_val, min_ctx, max_ctx)
user_id_val = user_id if user_id != "" else None
pairs = build_pairs(start_ts, end_ts, platform_val, user_id_val, min_ctx, max_ctx)
if output_path:
with open(output_path, "w", encoding="utf-8") as f: