Merge remote-tracking branch 'upstream/main-fix' into refactor

This commit is contained in:
tcmofashi
2025-03-28 10:56:47 +08:00
48 changed files with 4258 additions and 3149 deletions

View File

@@ -1,4 +1,3 @@
import math
import random
import time
import re
@@ -11,7 +10,7 @@ from src.common.logger import get_module_logger
from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator
from .config import global_config
from ..config.config import global_config
from .message import MessageRecv, Message
from ..message.message_base import UserInfo
from .chat_stream import ChatStream
@@ -59,61 +58,6 @@ async def get_embedding(text, request_type="embedding"):
return await llm.get_embedding(text)
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
char_count = Counter(text)
total_chars = len(text)
entropy = 0
for count in char_count.values():
probability = count / total_chars
entropy -= probability * math.log2(probability)
return entropy
def get_closest_chat_from_db(length: int, timestamp: str):
# print(f"获取最接近指定时间戳的聊天记录,长度: {length}, 时间戳: {timestamp}")
# print(f"当前时间: {timestamp},转换后时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))}")
chat_records = []
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[("time", -1)])
# print(f"最接近的记录: {closest_record}")
if closest_record:
closest_time = closest_record["time"]
chat_id = closest_record["chat_id"] # 获取chat_id
# 获取该时间戳之后的length条消息保持相同的chat_id
chat_records = list(
db.messages.find(
{
"time": {"$gt": closest_time},
"chat_id": chat_id, # 添加chat_id过滤
}
)
.sort("time", 1)
.limit(length)
)
# print(f"获取到的记录: {chat_records}")
length = len(chat_records)
# print(f"获取到的记录长度: {length}")
# 转换记录格式
formatted_records = []
for record in chat_records:
# 兼容行为,前向兼容老数据
formatted_records.append(
{
"_id": record["_id"],
"time": record["time"],
"chat_id": record["chat_id"],
"detailed_plain_text": record.get("detailed_plain_text", ""), # 添加文本内容
"memorized_times": record.get("memorized_times", 0), # 添加记忆次数
}
)
return formatted_records
return []
async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录
@@ -241,21 +185,17 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
List[str]: 分割后的句子列表
"""
len_text = len(text)
if len_text < 5:
if len_text < 4:
if random.random() < 0.01:
return list(text) # 如果文本很短且触发随机条件,直接按字符分割
else:
return [text]
if len_text < 12:
split_strength = 0.3
split_strength = 0.2
elif len_text < 32:
split_strength = 0.7
split_strength = 0.6
else:
split_strength = 0.9
# 先移除换行符
# print(f"split_strength: {split_strength}")
# print(f"处理前的文本: {text}")
split_strength = 0.7
# 检查是否为西文字符段落
if not is_western_paragraph(text):
@@ -345,7 +285,7 @@ def random_remove_punctuation(text: str) -> str:
for i, char in enumerate(text):
if char == "" and i == text_len - 1: # 结尾的句号
if random.random() > 0.4: # 80%概率删除结尾句号
if random.random() > 0.1: # 90%概率删除结尾句号
continue
elif char == "":
rand = random.random()
@@ -361,7 +301,9 @@ def random_remove_punctuation(text: str) -> str:
def process_llm_response(text: str) -> List[str]:
# processed_response = process_text_with_typos(content)
# 对西文字符段落的回复长度设置为汉字字符的两倍
if len(text) > 100 and not is_western_paragraph(text):
max_length = global_config.response_max_length
max_sentence_num = global_config.response_max_sentence_num
if len(text) > max_length and not is_western_paragraph(text):
logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复")
return ["懒得说"]
elif len(text) > 200:
@@ -374,7 +316,10 @@ def process_llm_response(text: str) -> List[str]:
tone_error_rate=global_config.chinese_typo_tone_error_rate,
word_replace_rate=global_config.chinese_typo_word_replace_rate,
)
split_sentences = split_into_sentences_w_remove_punctuation(text)
if global_config.enable_response_spliter:
split_sentences = split_into_sentences_w_remove_punctuation(text)
else:
split_sentences = [text]
sentences = []
for sentence in split_sentences:
if global_config.chinese_typo_enable:
@@ -386,14 +331,14 @@ def process_llm_response(text: str) -> List[str]:
sentences.append(sentence)
# 检查分割后的消息数量是否过多超过3条
if len(sentences) > 3:
if len(sentences) > max_sentence_num:
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
return [f"{global_config.BOT_NICKNAME}不知道哦"]
return sentences
def calculate_typing_time(input_string: str, chinese_time: float = 0.4, english_time: float = 0.2) -> float:
def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_time: float = 0.1) -> float:
"""
计算输入字符串所需的时间,中文和英文字符有不同的输入时间
input_string (str): 输入的字符串