Merge remote-tracking branch 'upstream/main-fix' into refactor
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
import re
|
||||
@@ -11,7 +10,7 @@ from src.common.logger import get_module_logger
|
||||
|
||||
from ..models.utils_model import LLM_request
|
||||
from ..utils.typo_generator import ChineseTypoGenerator
|
||||
from .config import global_config
|
||||
from ..config.config import global_config
|
||||
from .message import MessageRecv, Message
|
||||
from ..message.message_base import UserInfo
|
||||
from .chat_stream import ChatStream
|
||||
@@ -59,61 +58,6 @@ async def get_embedding(text, request_type="embedding"):
|
||||
return await llm.get_embedding(text)
|
||||
|
||||
|
||||
def calculate_information_content(text):
|
||||
"""计算文本的信息量(熵)"""
|
||||
char_count = Counter(text)
|
||||
total_chars = len(text)
|
||||
|
||||
entropy = 0
|
||||
for count in char_count.values():
|
||||
probability = count / total_chars
|
||||
entropy -= probability * math.log2(probability)
|
||||
|
||||
return entropy
|
||||
|
||||
|
||||
def get_closest_chat_from_db(length: int, timestamp: str):
|
||||
# print(f"获取最接近指定时间戳的聊天记录,长度: {length}, 时间戳: {timestamp}")
|
||||
# print(f"当前时间: {timestamp},转换后时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))}")
|
||||
chat_records = []
|
||||
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[("time", -1)])
|
||||
# print(f"最接近的记录: {closest_record}")
|
||||
if closest_record:
|
||||
closest_time = closest_record["time"]
|
||||
chat_id = closest_record["chat_id"] # 获取chat_id
|
||||
# 获取该时间戳之后的length条消息,保持相同的chat_id
|
||||
chat_records = list(
|
||||
db.messages.find(
|
||||
{
|
||||
"time": {"$gt": closest_time},
|
||||
"chat_id": chat_id, # 添加chat_id过滤
|
||||
}
|
||||
)
|
||||
.sort("time", 1)
|
||||
.limit(length)
|
||||
)
|
||||
# print(f"获取到的记录: {chat_records}")
|
||||
length = len(chat_records)
|
||||
# print(f"获取到的记录长度: {length}")
|
||||
# 转换记录格式
|
||||
formatted_records = []
|
||||
for record in chat_records:
|
||||
# 兼容行为,前向兼容老数据
|
||||
formatted_records.append(
|
||||
{
|
||||
"_id": record["_id"],
|
||||
"time": record["time"],
|
||||
"chat_id": record["chat_id"],
|
||||
"detailed_plain_text": record.get("detailed_plain_text", ""), # 添加文本内容
|
||||
"memorized_times": record.get("memorized_times", 0), # 添加记忆次数
|
||||
}
|
||||
)
|
||||
|
||||
return formatted_records
|
||||
|
||||
return []
|
||||
|
||||
|
||||
async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list:
|
||||
"""从数据库获取群组最近的消息记录
|
||||
|
||||
@@ -241,21 +185,17 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
|
||||
List[str]: 分割后的句子列表
|
||||
"""
|
||||
len_text = len(text)
|
||||
if len_text < 5:
|
||||
if len_text < 4:
|
||||
if random.random() < 0.01:
|
||||
return list(text) # 如果文本很短且触发随机条件,直接按字符分割
|
||||
else:
|
||||
return [text]
|
||||
if len_text < 12:
|
||||
split_strength = 0.3
|
||||
split_strength = 0.2
|
||||
elif len_text < 32:
|
||||
split_strength = 0.7
|
||||
split_strength = 0.6
|
||||
else:
|
||||
split_strength = 0.9
|
||||
# 先移除换行符
|
||||
# print(f"split_strength: {split_strength}")
|
||||
|
||||
# print(f"处理前的文本: {text}")
|
||||
split_strength = 0.7
|
||||
|
||||
# 检查是否为西文字符段落
|
||||
if not is_western_paragraph(text):
|
||||
@@ -345,7 +285,7 @@ def random_remove_punctuation(text: str) -> str:
|
||||
|
||||
for i, char in enumerate(text):
|
||||
if char == "。" and i == text_len - 1: # 结尾的句号
|
||||
if random.random() > 0.4: # 80%概率删除结尾句号
|
||||
if random.random() > 0.1: # 90%概率删除结尾句号
|
||||
continue
|
||||
elif char == ",":
|
||||
rand = random.random()
|
||||
@@ -361,7 +301,9 @@ def random_remove_punctuation(text: str) -> str:
|
||||
def process_llm_response(text: str) -> List[str]:
|
||||
# processed_response = process_text_with_typos(content)
|
||||
# 对西文字符段落的回复长度设置为汉字字符的两倍
|
||||
if len(text) > 100 and not is_western_paragraph(text):
|
||||
max_length = global_config.response_max_length
|
||||
max_sentence_num = global_config.response_max_sentence_num
|
||||
if len(text) > max_length and not is_western_paragraph(text):
|
||||
logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复")
|
||||
return ["懒得说"]
|
||||
elif len(text) > 200:
|
||||
@@ -374,7 +316,10 @@ def process_llm_response(text: str) -> List[str]:
|
||||
tone_error_rate=global_config.chinese_typo_tone_error_rate,
|
||||
word_replace_rate=global_config.chinese_typo_word_replace_rate,
|
||||
)
|
||||
split_sentences = split_into_sentences_w_remove_punctuation(text)
|
||||
if global_config.enable_response_spliter:
|
||||
split_sentences = split_into_sentences_w_remove_punctuation(text)
|
||||
else:
|
||||
split_sentences = [text]
|
||||
sentences = []
|
||||
for sentence in split_sentences:
|
||||
if global_config.chinese_typo_enable:
|
||||
@@ -386,14 +331,14 @@ def process_llm_response(text: str) -> List[str]:
|
||||
sentences.append(sentence)
|
||||
# 检查分割后的消息数量是否过多(超过3条)
|
||||
|
||||
if len(sentences) > 3:
|
||||
if len(sentences) > max_sentence_num:
|
||||
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
|
||||
return [f"{global_config.BOT_NICKNAME}不知道哦"]
|
||||
|
||||
return sentences
|
||||
|
||||
|
||||
def calculate_typing_time(input_string: str, chinese_time: float = 0.4, english_time: float = 0.2) -> float:
|
||||
def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_time: float = 0.1) -> float:
|
||||
"""
|
||||
计算输入字符串所需的时间,中文和英文字符有不同的输入时间
|
||||
input_string (str): 输入的字符串
|
||||
|
||||
Reference in New Issue
Block a user