fix:优化记忆提取和聊天压缩

This commit is contained in:
SengokuCola
2025-11-10 12:27:54 +08:00
parent 10cd2474af
commit 71a2a4282b
5 changed files with 212 additions and 139 deletions

View File

@@ -9,6 +9,7 @@ from src.common.logger import get_logger
from src.config.config import model_config
from src.common.database.database_model import ChatHistory
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.utils import parse_keywords_string
from .tool_registry import register_memory_retrieval_tool
from .tool_utils import parse_datetime_to_timestamp, parse_time_range
@@ -18,51 +19,46 @@ logger = get_logger("memory_retrieval_tools")
async def query_chat_history(
chat_id: str,
keyword: Optional[str] = None,
time_point: Optional[str] = None,
time_range: Optional[str] = None
) -> str:
"""根据时间或关键词在chat_history表中查询聊天记录概述
Args:
chat_id: 聊天ID
keyword: 关键词(可选)
time_point: 时间点格式YYYY-MM-DD HH:MM:SS可选
time_range: 时间范围,格式"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"(可选)
keyword: 关键词(可选,支持多个关键词,可用空格、逗号等分隔
time_range: 时间范围或时间点,格式:
- 时间范围:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
- 时间点:"YYYY-MM-DD HH:MM:SS"(查询包含该时间点的记录)
Returns:
str: 查询结果
"""
try:
# 检查参数
if not keyword and not time_point and not time_range:
return "未指定查询参数需要提供keyword、time_point或time_range之一"
if not keyword and not time_range:
return "未指定查询参数需要提供keyword或time_range之一"
# 构建查询条件
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
# 时间过滤条件
time_conditions = []
if time_point:
# 时间点查询包含该时间点的记录start_time <= time_point <= end_time
target_timestamp = parse_datetime_to_timestamp(time_point)
time_conditions.append(
(ChatHistory.start_time <= target_timestamp) &
(ChatHistory.end_time >= target_timestamp)
)
elif time_range:
# 时间范围:查询与时间范围有交集的记录
start_timestamp, end_timestamp = parse_time_range(time_range)
# 交集条件start_time < end_timestamp AND end_time > start_timestamp
time_conditions.append(
(ChatHistory.start_time < end_timestamp) &
(ChatHistory.end_time > start_timestamp)
)
if time_conditions:
# 合并所有时间条件OR关系
time_filter = time_conditions[0]
for condition in time_conditions[1:]:
time_filter = time_filter | condition
if time_range:
# 判断是时间点还是时间范围
if " - " in time_range:
# 时间范围:查询与时间范围有交集的记录
start_timestamp, end_timestamp = parse_time_range(time_range)
# 交集条件start_time < end_timestamp AND end_time > start_timestamp
time_filter = (
(ChatHistory.start_time < end_timestamp) &
(ChatHistory.end_time > start_timestamp)
)
else:
# 时间点查询包含该时间点的记录start_time <= time_point <= end_time
target_timestamp = parse_datetime_to_timestamp(time_range)
time_filter = (
(ChatHistory.start_time <= target_timestamp) &
(ChatHistory.end_time >= target_timestamp)
)
query = query.where(time_filter)
# 执行查询
@@ -73,7 +69,17 @@ async def query_chat_history(
# 如果有关键词,进一步过滤
if keyword:
keyword_lower = keyword.lower()
# 解析多个关键词(支持空格、逗号等分隔符)
keywords_list = parse_keywords_string(keyword)
if not keywords_list:
keywords_list = [keyword.strip()] if keyword.strip() else []
# 转换为小写以便匹配
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
if not keywords_lower:
return "关键词为空"
filtered_records = []
for record in records:
@@ -82,25 +88,32 @@ async def query_chat_history(
summary = (record.summary or "").lower()
original_text = (record.original_text or "").lower()
# 解析keywords JSON
keywords_list = []
# 解析record中的keywords JSON
record_keywords_list = []
if record.keywords:
try:
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
if isinstance(keywords_data, list):
keywords_list = [str(k).lower() for k in keywords_data]
record_keywords_list = [str(k).lower() for k in keywords_data]
except (json.JSONDecodeError, TypeError, ValueError):
pass
# 检查是否包含关键词
if (keyword_lower in theme or
keyword_lower in summary or
keyword_lower in original_text or
any(keyword_lower in k for k in keywords_list)):
# 检查是否包含任意一个关键词OR关系
matched = False
for kw in keywords_lower:
if (kw in theme or
kw in summary or
kw in original_text or
any(kw in k for k in record_keywords_list)):
matched = True
break
if matched:
filtered_records.append(record)
if not filtered_records:
return f"未找到包含关键词'{keyword}'的聊天记录概述"
keywords_str = "".join(keywords_list)
return f"未找到包含关键词'{keywords_str}'的聊天记录概述"
records = filtered_records
@@ -146,11 +159,18 @@ async def query_chat_history(
query_desc = []
if keyword:
query_desc.append(f"关键词:{keyword}")
if time_point:
query_desc.append(f"时间点:{time_point}")
# 解析关键词列表用于显示
keywords_list = parse_keywords_string(keyword)
if keywords_list:
keywords_str = "".join(keywords_list)
query_desc.append(f"关键词:{keywords_str}")
else:
query_desc.append(f"关键词:{keyword}")
if time_range:
query_desc.append(f"时间范围:{time_range}")
if " - " in time_range:
query_desc.append(f"时间范围:{time_range}")
else:
query_desc.append(f"时间点:{time_range}")
query_info = "".join(query_desc) if query_desc else "聊天记录概述"
@@ -201,19 +221,13 @@ def register_tool():
{
"name": "keyword",
"type": "string",
"description": "关键词(可选,用于在主题、关键词、概括、原文中搜索)",
"required": False
},
{
"name": "time_point",
"type": "string",
"description": "时间点格式YYYY-MM-DD HH:MM:SS可选与time_range二选一。用于查询包含该时间点的聊天记录概述",
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘''麦麦,百度网盘'用于在主题、关键词、概括、原文中搜索,只要包含任意一个关键词即匹配",
"required": False
},
{
"name": "time_range",
"type": "string",
"description": "时间范围格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'可选与time_point二选一。用于查询与时间范围有交集的聊天记录概述",
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
"required": False
}
],

View File

@@ -12,8 +12,7 @@ logger = get_logger("memory_retrieval_tools")
async def query_jargon(
keyword: str,
chat_id: str,
fuzzy: bool = False,
search_all: bool = False
fuzzy: bool = False
) -> str:
"""根据关键词在jargon库中查询
@@ -21,7 +20,6 @@ async def query_jargon(
keyword: 关键词(黑话/俚语/缩写)
chat_id: 聊天ID
fuzzy: 是否使用模糊搜索默认False精确匹配
search_all: 是否搜索全库不限chat_id默认False仅搜索当前会话或global
Returns:
str: 查询结果
@@ -31,11 +29,10 @@ async def query_jargon(
if not content:
return "关键词为空"
# 根据参数执行搜索
search_chat_id = None if search_all else chat_id
# 执行搜索(仅搜索当前会话或全局)
results = search_jargon(
keyword=content,
chat_id=search_chat_id,
chat_id=chat_id,
limit=1,
case_sensitive=False,
fuzzy=fuzzy
@@ -46,15 +43,13 @@ async def query_jargon(
translation = result.get("translation", "").strip()
meaning = result.get("meaning", "").strip()
search_type = "模糊搜索" if fuzzy else "精确匹配"
search_scope = "全库" if search_all else "当前会话或全局"
output = f"{content}可能为黑话或者网络简写,翻译为:{translation},含义为:{meaning}"
logger.info(f"在jargon库中找到匹配{search_scope}{search_type}: {content}")
output = f'"{content}可能为黑话或者网络简写,翻译为:{translation},含义为:{meaning}"'
logger.info(f"在jargon库中找到匹配当前会话或全局{search_type}: {content}")
return output
# 未命中
search_type = "模糊搜索" if fuzzy else "精确匹配"
search_scope = "全库" if search_all else "当前会话或全局"
logger.info(f"在jargon库中未找到匹配{search_scope}{search_type}: {content}")
logger.info(f"在jargon库中未找到匹配当前会话或全局{search_type}: {content}")
return f"未在jargon库中找到'{content}'的解释"
except Exception as e:
@@ -66,7 +61,7 @@ def register_tool():
"""注册工具"""
register_memory_retrieval_tool(
name="query_jargon",
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索和模糊搜索。默认优先搜索当前会话或全局jargon,可以设置为搜索全库",
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索和模糊搜索。搜索当前会话或全局jargon。",
parameters=[
{
"name": "keyword",
@@ -79,12 +74,6 @@ def register_tool():
"type": "boolean",
"description": "是否使用模糊搜索部分匹配默认False精确匹配。当精确匹配找不到时可以尝试使用模糊搜索。",
"required": False
},
{
"name": "search_all",
"type": "boolean",
"description": "是否搜索全库不限chat_id默认False仅搜索当前会话或global的jargon。当在当前会话中找不到时可以尝试搜索全库。",
"required": False
}
],
execute_func=query_jargon