fix:优化记忆提取和聊天压缩
This commit is contained in:
@@ -9,6 +9,7 @@ from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.common.database.database_model import ChatHistory
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.utils.utils import parse_keywords_string
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
from .tool_utils import parse_datetime_to_timestamp, parse_time_range
|
||||
|
||||
@@ -18,51 +19,46 @@ logger = get_logger("memory_retrieval_tools")
|
||||
async def query_chat_history(
|
||||
chat_id: str,
|
||||
keyword: Optional[str] = None,
|
||||
time_point: Optional[str] = None,
|
||||
time_range: Optional[str] = None
|
||||
) -> str:
|
||||
"""根据时间或关键词在chat_history表中查询聊天记录概述
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
keyword: 关键词(可选)
|
||||
time_point: 时间点,格式:YYYY-MM-DD HH:MM:SS(可选)
|
||||
time_range: 时间范围,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"(可选)
|
||||
keyword: 关键词(可选,支持多个关键词,可用空格、逗号等分隔)
|
||||
time_range: 时间范围或时间点,格式:
|
||||
- 时间范围:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
|
||||
- 时间点:"YYYY-MM-DD HH:MM:SS"(查询包含该时间点的记录)
|
||||
|
||||
Returns:
|
||||
str: 查询结果
|
||||
"""
|
||||
try:
|
||||
# 检查参数
|
||||
if not keyword and not time_point and not time_range:
|
||||
return "未指定查询参数(需要提供keyword、time_point或time_range之一)"
|
||||
if not keyword and not time_range:
|
||||
return "未指定查询参数(需要提供keyword或time_range之一)"
|
||||
|
||||
# 构建查询条件
|
||||
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
|
||||
|
||||
# 时间过滤条件
|
||||
time_conditions = []
|
||||
if time_point:
|
||||
# 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time)
|
||||
target_timestamp = parse_datetime_to_timestamp(time_point)
|
||||
time_conditions.append(
|
||||
(ChatHistory.start_time <= target_timestamp) &
|
||||
(ChatHistory.end_time >= target_timestamp)
|
||||
)
|
||||
elif time_range:
|
||||
# 时间范围:查询与时间范围有交集的记录
|
||||
start_timestamp, end_timestamp = parse_time_range(time_range)
|
||||
# 交集条件:start_time < end_timestamp AND end_time > start_timestamp
|
||||
time_conditions.append(
|
||||
(ChatHistory.start_time < end_timestamp) &
|
||||
(ChatHistory.end_time > start_timestamp)
|
||||
)
|
||||
|
||||
if time_conditions:
|
||||
# 合并所有时间条件(OR关系)
|
||||
time_filter = time_conditions[0]
|
||||
for condition in time_conditions[1:]:
|
||||
time_filter = time_filter | condition
|
||||
if time_range:
|
||||
# 判断是时间点还是时间范围
|
||||
if " - " in time_range:
|
||||
# 时间范围:查询与时间范围有交集的记录
|
||||
start_timestamp, end_timestamp = parse_time_range(time_range)
|
||||
# 交集条件:start_time < end_timestamp AND end_time > start_timestamp
|
||||
time_filter = (
|
||||
(ChatHistory.start_time < end_timestamp) &
|
||||
(ChatHistory.end_time > start_timestamp)
|
||||
)
|
||||
else:
|
||||
# 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time)
|
||||
target_timestamp = parse_datetime_to_timestamp(time_range)
|
||||
time_filter = (
|
||||
(ChatHistory.start_time <= target_timestamp) &
|
||||
(ChatHistory.end_time >= target_timestamp)
|
||||
)
|
||||
query = query.where(time_filter)
|
||||
|
||||
# 执行查询
|
||||
@@ -73,7 +69,17 @@ async def query_chat_history(
|
||||
|
||||
# 如果有关键词,进一步过滤
|
||||
if keyword:
|
||||
keyword_lower = keyword.lower()
|
||||
# 解析多个关键词(支持空格、逗号等分隔符)
|
||||
keywords_list = parse_keywords_string(keyword)
|
||||
if not keywords_list:
|
||||
keywords_list = [keyword.strip()] if keyword.strip() else []
|
||||
|
||||
# 转换为小写以便匹配
|
||||
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
|
||||
|
||||
if not keywords_lower:
|
||||
return "关键词为空"
|
||||
|
||||
filtered_records = []
|
||||
|
||||
for record in records:
|
||||
@@ -82,25 +88,32 @@ async def query_chat_history(
|
||||
summary = (record.summary or "").lower()
|
||||
original_text = (record.original_text or "").lower()
|
||||
|
||||
# 解析keywords JSON
|
||||
keywords_list = []
|
||||
# 解析record中的keywords JSON
|
||||
record_keywords_list = []
|
||||
if record.keywords:
|
||||
try:
|
||||
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
if isinstance(keywords_data, list):
|
||||
keywords_list = [str(k).lower() for k in keywords_data]
|
||||
record_keywords_list = [str(k).lower() for k in keywords_data]
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# 检查是否包含关键词
|
||||
if (keyword_lower in theme or
|
||||
keyword_lower in summary or
|
||||
keyword_lower in original_text or
|
||||
any(keyword_lower in k for k in keywords_list)):
|
||||
# 检查是否包含任意一个关键词(OR关系)
|
||||
matched = False
|
||||
for kw in keywords_lower:
|
||||
if (kw in theme or
|
||||
kw in summary or
|
||||
kw in original_text or
|
||||
any(kw in k for k in record_keywords_list)):
|
||||
matched = True
|
||||
break
|
||||
|
||||
if matched:
|
||||
filtered_records.append(record)
|
||||
|
||||
if not filtered_records:
|
||||
return f"未找到包含关键词'{keyword}'的聊天记录概述"
|
||||
keywords_str = "、".join(keywords_list)
|
||||
return f"未找到包含关键词'{keywords_str}'的聊天记录概述"
|
||||
|
||||
records = filtered_records
|
||||
|
||||
@@ -146,11 +159,18 @@ async def query_chat_history(
|
||||
|
||||
query_desc = []
|
||||
if keyword:
|
||||
query_desc.append(f"关键词:{keyword}")
|
||||
if time_point:
|
||||
query_desc.append(f"时间点:{time_point}")
|
||||
# 解析关键词列表用于显示
|
||||
keywords_list = parse_keywords_string(keyword)
|
||||
if keywords_list:
|
||||
keywords_str = "、".join(keywords_list)
|
||||
query_desc.append(f"关键词:{keywords_str}")
|
||||
else:
|
||||
query_desc.append(f"关键词:{keyword}")
|
||||
if time_range:
|
||||
query_desc.append(f"时间范围:{time_range}")
|
||||
if " - " in time_range:
|
||||
query_desc.append(f"时间范围:{time_range}")
|
||||
else:
|
||||
query_desc.append(f"时间点:{time_range}")
|
||||
|
||||
query_info = ",".join(query_desc) if query_desc else "聊天记录概述"
|
||||
|
||||
@@ -201,19 +221,13 @@ def register_tool():
|
||||
{
|
||||
"name": "keyword",
|
||||
"type": "string",
|
||||
"description": "关键词(可选,用于在主题、关键词、概括、原文中搜索)",
|
||||
"required": False
|
||||
},
|
||||
{
|
||||
"name": "time_point",
|
||||
"type": "string",
|
||||
"description": "时间点,格式:YYYY-MM-DD HH:MM:SS(可选,与time_range二选一)。用于查询包含该时间点的聊天记录概述",
|
||||
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘' 或 '麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索,只要包含任意一个关键词即匹配)",
|
||||
"required": False
|
||||
},
|
||||
{
|
||||
"name": "time_range",
|
||||
"type": "string",
|
||||
"description": "时间范围,格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(可选,与time_point二选一)。用于查询与时间范围有交集的聊天记录概述",
|
||||
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
|
||||
"required": False
|
||||
}
|
||||
],
|
||||
|
||||
@@ -12,8 +12,7 @@ logger = get_logger("memory_retrieval_tools")
|
||||
async def query_jargon(
|
||||
keyword: str,
|
||||
chat_id: str,
|
||||
fuzzy: bool = False,
|
||||
search_all: bool = False
|
||||
fuzzy: bool = False
|
||||
) -> str:
|
||||
"""根据关键词在jargon库中查询
|
||||
|
||||
@@ -21,7 +20,6 @@ async def query_jargon(
|
||||
keyword: 关键词(黑话/俚语/缩写)
|
||||
chat_id: 聊天ID
|
||||
fuzzy: 是否使用模糊搜索,默认False(精确匹配)
|
||||
search_all: 是否搜索全库(不限chat_id),默认False(仅搜索当前会话或global)
|
||||
|
||||
Returns:
|
||||
str: 查询结果
|
||||
@@ -31,11 +29,10 @@ async def query_jargon(
|
||||
if not content:
|
||||
return "关键词为空"
|
||||
|
||||
# 根据参数执行搜索
|
||||
search_chat_id = None if search_all else chat_id
|
||||
# 执行搜索(仅搜索当前会话或全局)
|
||||
results = search_jargon(
|
||||
keyword=content,
|
||||
chat_id=search_chat_id,
|
||||
chat_id=chat_id,
|
||||
limit=1,
|
||||
case_sensitive=False,
|
||||
fuzzy=fuzzy
|
||||
@@ -46,15 +43,13 @@ async def query_jargon(
|
||||
translation = result.get("translation", "").strip()
|
||||
meaning = result.get("meaning", "").strip()
|
||||
search_type = "模糊搜索" if fuzzy else "精确匹配"
|
||||
search_scope = "全库" if search_all else "当前会话或全局"
|
||||
output = f"“{content}可能为黑话或者网络简写,翻译为:{translation},含义为:{meaning}”"
|
||||
logger.info(f"在jargon库中找到匹配({search_scope},{search_type}): {content}")
|
||||
output = f'"{content}可能为黑话或者网络简写,翻译为:{translation},含义为:{meaning}"'
|
||||
logger.info(f"在jargon库中找到匹配(当前会话或全局,{search_type}): {content}")
|
||||
return output
|
||||
|
||||
# 未命中
|
||||
search_type = "模糊搜索" if fuzzy else "精确匹配"
|
||||
search_scope = "全库" if search_all else "当前会话或全局"
|
||||
logger.info(f"在jargon库中未找到匹配({search_scope},{search_type}): {content}")
|
||||
logger.info(f"在jargon库中未找到匹配(当前会话或全局,{search_type}): {content}")
|
||||
return f"未在jargon库中找到'{content}'的解释"
|
||||
|
||||
except Exception as e:
|
||||
@@ -66,7 +61,7 @@ def register_tool():
|
||||
"""注册工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="query_jargon",
|
||||
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索和模糊搜索。默认优先搜索当前会话或全局jargon,可以设置为搜索全库。",
|
||||
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索和模糊搜索。仅搜索当前会话或全局jargon。",
|
||||
parameters=[
|
||||
{
|
||||
"name": "keyword",
|
||||
@@ -79,12 +74,6 @@ def register_tool():
|
||||
"type": "boolean",
|
||||
"description": "是否使用模糊搜索(部分匹配),默认False(精确匹配)。当精确匹配找不到时,可以尝试使用模糊搜索。",
|
||||
"required": False
|
||||
},
|
||||
{
|
||||
"name": "search_all",
|
||||
"type": "boolean",
|
||||
"description": "是否搜索全库(不限chat_id),默认False(仅搜索当前会话或global的jargon)。当在当前会话中找不到时,可以尝试搜索全库。",
|
||||
"required": False
|
||||
}
|
||||
],
|
||||
execute_func=query_jargon
|
||||
|
||||
Reference in New Issue
Block a user