fix:修复引用回复逻辑
This commit is contained in:
@@ -6,13 +6,13 @@ from collections import Counter
|
||||
import jieba
|
||||
import numpy as np
|
||||
from src.common.logger import get_module_logger
|
||||
from pymongo.errors import PyMongoError
|
||||
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ..utils.typo_generator import ChineseTypoGenerator
|
||||
from ...config.config import global_config
|
||||
from .message import MessageRecv, Message
|
||||
from .message import MessageRecv
|
||||
from maim_message import UserInfo
|
||||
from .chat_stream import ChatStream
|
||||
from ..moods.moods import MoodManager
|
||||
from ...common.database import db
|
||||
|
||||
@@ -107,8 +107,6 @@ async def get_embedding(text, request_type="embedding"):
|
||||
return embedding
|
||||
|
||||
|
||||
|
||||
|
||||
def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, combine=False):
|
||||
recent_messages = list(
|
||||
db.messages.find(
|
||||
@@ -566,93 +564,45 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
|
||||
"""计算两个时间点之间的消息数量和文本总长度
|
||||
|
||||
Args:
|
||||
start_time (float): 起始时间戳
|
||||
end_time (float): 结束时间戳
|
||||
start_time (float): 起始时间戳 (不包含)
|
||||
end_time (float): 结束时间戳 (包含)
|
||||
stream_id (str): 聊天流ID
|
||||
|
||||
Returns:
|
||||
tuple[int, int]: (消息数量, 文本总长度)
|
||||
- 消息数量:包含起始时间的消息,不包含结束时间的消息
|
||||
- 文本总长度:所有消息的processed_plain_text长度之和
|
||||
"""
|
||||
count = 0
|
||||
total_length = 0
|
||||
|
||||
# 参数校验 (可选但推荐)
|
||||
if start_time >= end_time:
|
||||
# logger.debug(f"开始时间 {start_time} 大于或等于结束时间 {end_time},返回 0, 0")
|
||||
return 0, 0
|
||||
if not stream_id:
|
||||
logger.error("stream_id 不能为空")
|
||||
return 0, 0
|
||||
|
||||
# 直接查询时间范围内的消息
|
||||
# time > start_time AND time <= end_time
|
||||
query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
|
||||
|
||||
try:
|
||||
# 获取开始时间之前最新的一条消息
|
||||
start_message = db.messages.find_one(
|
||||
{"chat_id": stream_id, "time": {"$lte": start_time}},
|
||||
sort=[("time", -1), ("_id", -1)], # 按时间倒序,_id倒序(最后插入的在前)
|
||||
)
|
||||
# 执行查询
|
||||
messages_cursor = db.messages.find(query)
|
||||
|
||||
# 获取结束时间最近的一条消息
|
||||
# 先找到结束时间点的所有消息
|
||||
end_time_messages = list(
|
||||
db.messages.find(
|
||||
{"chat_id": stream_id, "time": {"$lte": end_time}},
|
||||
sort=[("time", -1)], # 先按时间倒序
|
||||
).limit(10)
|
||||
) # 限制查询数量,避免性能问题
|
||||
|
||||
if not end_time_messages:
|
||||
logger.warning(f"未找到结束时间 {end_time} 之前的消息")
|
||||
return 0, 0
|
||||
|
||||
# 找到最大时间
|
||||
max_time = end_time_messages[0]["time"]
|
||||
# 在最大时间的消息中找最后插入的(_id最大的)
|
||||
end_message = max([msg for msg in end_time_messages if msg["time"] == max_time], key=lambda x: x["_id"])
|
||||
|
||||
if not start_message:
|
||||
logger.warning(f"未找到开始时间 {start_time} 之前的消息")
|
||||
return 0, 0
|
||||
|
||||
# 调试输出
|
||||
# print("\n=== 消息范围信息 ===")
|
||||
# print("Start message:", {
|
||||
# "message_id": start_message.get("message_id"),
|
||||
# "time": start_message.get("time"),
|
||||
# "text": start_message.get("processed_plain_text", ""),
|
||||
# "_id": str(start_message.get("_id"))
|
||||
# })
|
||||
# print("End message:", {
|
||||
# "message_id": end_message.get("message_id"),
|
||||
# "time": end_message.get("time"),
|
||||
# "text": end_message.get("processed_plain_text", ""),
|
||||
# "_id": str(end_message.get("_id"))
|
||||
# })
|
||||
# print("Stream ID:", stream_id)
|
||||
|
||||
# 如果结束消息的时间等于开始时间,返回0
|
||||
if end_message["time"] == start_message["time"]:
|
||||
return 0, 0
|
||||
|
||||
# 获取并打印这个时间范围内的所有消息
|
||||
# print("\n=== 时间范围内的所有消息 ===")
|
||||
all_messages = list(
|
||||
db.messages.find(
|
||||
{"chat_id": stream_id, "time": {"$gte": start_message["time"], "$lte": end_message["time"]}},
|
||||
sort=[("time", 1), ("_id", 1)], # 按时间正序,_id正序
|
||||
)
|
||||
)
|
||||
|
||||
count = 0
|
||||
total_length = 0
|
||||
for msg in all_messages:
|
||||
# 遍历结果计算数量和长度
|
||||
for msg in messages_cursor:
|
||||
count += 1
|
||||
text_length = len(msg.get("processed_plain_text", ""))
|
||||
total_length += text_length
|
||||
# print(f"\n消息 {count}:")
|
||||
# print({
|
||||
# "message_id": msg.get("message_id"),
|
||||
# "time": msg.get("time"),
|
||||
# "text": msg.get("processed_plain_text", ""),
|
||||
# "text_length": text_length,
|
||||
# "_id": str(msg.get("_id"))
|
||||
# })
|
||||
total_length += len(msg.get("processed_plain_text", ""))
|
||||
|
||||
# 如果时间不同,需要把end_message本身也计入
|
||||
return count - 1, total_length
|
||||
# logger.debug(f"查询范围 ({start_time}, {end_time}] 内找到 {count} 条消息,总长度 {total_length}")
|
||||
return count, total_length
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算消息数量时出错: {str(e)}")
|
||||
except PyMongoError as e:
|
||||
logger.error(f"查询 stream_id={stream_id} 在 ({start_time}, {end_time}] 范围内的消息时出错: {e}")
|
||||
return 0, 0
|
||||
except Exception as e: # 保留一个通用异常捕获以防万一
|
||||
logger.error(f"计算消息数量时发生意外错误: {e}")
|
||||
return 0, 0
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user