fix
This commit is contained in:
@@ -65,24 +65,20 @@ class ExpressionLearner:
|
|||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||||
|
|
||||||
|
|
||||||
# 维护每个chat的上次学习时间
|
# 维护每个chat的上次学习时间
|
||||||
self.last_learning_time: float = time.time()
|
self.last_learning_time: float = time.time()
|
||||||
|
|
||||||
# 学习参数
|
# 学习参数
|
||||||
self.min_messages_for_learning = 25 # 触发学习所需的最少消息数
|
self.min_messages_for_learning = 25 # 触发学习所需的最少消息数
|
||||||
self.min_learning_interval = 300 # 最短学习时间间隔(秒)
|
self.min_learning_interval = 300 # 最短学习时间间隔(秒)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def can_learn_for_chat(self) -> bool:
|
def can_learn_for_chat(self) -> bool:
|
||||||
"""
|
"""
|
||||||
检查指定聊天流是否允许学习表达
|
检查指定聊天流是否允许学习表达
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chat_id: 聊天流ID
|
chat_id: 聊天流ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 是否允许学习
|
bool: 是否允许学习
|
||||||
"""
|
"""
|
||||||
@@ -96,10 +92,10 @@ class ExpressionLearner:
|
|||||||
def should_trigger_learning(self) -> bool:
|
def should_trigger_learning(self) -> bool:
|
||||||
"""
|
"""
|
||||||
检查是否应该触发学习
|
检查是否应该触发学习
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chat_id: 聊天流ID
|
chat_id: 聊天流ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 是否应该触发学习
|
bool: 是否应该触发学习
|
||||||
"""
|
"""
|
||||||
@@ -107,23 +103,25 @@ class ExpressionLearner:
|
|||||||
|
|
||||||
# 获取该聊天流的学习强度
|
# 获取该聊天流的学习强度
|
||||||
try:
|
try:
|
||||||
_, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat(self.chat_id)
|
_, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat(
|
||||||
|
self.chat_id
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取聊天流 {self.chat_id} 的学习配置失败: {e}")
|
logger.error(f"获取聊天流 {self.chat_id} 的学习配置失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查是否允许学习
|
# 检查是否允许学习
|
||||||
if not enable_learning:
|
if not enable_learning:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 根据学习强度计算最短学习时间间隔
|
# 根据学习强度计算最短学习时间间隔
|
||||||
min_interval = self.min_learning_interval / learning_intensity
|
min_interval = self.min_learning_interval / learning_intensity
|
||||||
|
|
||||||
# 检查时间间隔
|
# 检查时间间隔
|
||||||
time_diff = current_time - self.last_learning_time
|
time_diff = current_time - self.last_learning_time
|
||||||
if time_diff < min_interval:
|
if time_diff < min_interval:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查消息数量(只检查指定聊天流的消息)
|
# 检查消息数量(只检查指定聊天流的消息)
|
||||||
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
@@ -133,69 +131,42 @@ class ExpressionLearner:
|
|||||||
|
|
||||||
if not recent_messages or len(recent_messages) < self.min_messages_for_learning:
|
if not recent_messages or len(recent_messages) < self.min_messages_for_learning:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def trigger_learning_for_chat(self) -> bool:
|
async def trigger_learning_for_chat(self) -> bool:
|
||||||
"""
|
"""
|
||||||
为指定聊天流触发学习
|
为指定聊天流触发学习
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chat_id: 聊天流ID
|
chat_id: 聊天流ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 是否成功触发学习
|
bool: 是否成功触发学习
|
||||||
"""
|
"""
|
||||||
if not self.should_trigger_learning():
|
if not self.should_trigger_learning():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"为聊天流 {self.chat_name} 触发表达学习")
|
logger.info(f"为聊天流 {self.chat_name} 触发表达学习")
|
||||||
|
|
||||||
# 学习语言风格
|
# 学习语言风格
|
||||||
learnt_style = await self.learn_and_store(num=25)
|
learnt_style = await self.learn_and_store(num=25)
|
||||||
|
|
||||||
# 更新学习时间
|
# 更新学习时间
|
||||||
self.last_learning_time = time.time()
|
self.last_learning_time = time.time()
|
||||||
|
|
||||||
if learnt_style:
|
if learnt_style:
|
||||||
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
|
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
|
||||||
# """
|
|
||||||
# 获取指定chat_id的style表达方式(已禁用grammar的获取)
|
|
||||||
# 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
|
||||||
# """
|
|
||||||
# learnt_style_expressions = []
|
|
||||||
|
|
||||||
# # 直接从数据库查询
|
|
||||||
# style_query = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.type == "style"))
|
|
||||||
# for expr in style_query:
|
|
||||||
# # 确保create_date存在,如果不存在则使用last_active_time
|
|
||||||
# create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
|
||||||
# learnt_style_expressions.append(
|
|
||||||
# {
|
|
||||||
# "situation": expr.situation,
|
|
||||||
# "style": expr.style,
|
|
||||||
# "count": expr.count,
|
|
||||||
# "last_active_time": expr.last_active_time,
|
|
||||||
# "source_id": self.chat_id,
|
|
||||||
# "type": "style",
|
|
||||||
# "create_date": create_date,
|
|
||||||
# }
|
|
||||||
# )
|
|
||||||
# return learnt_style_expressions
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_global_decay_to_database(self, current_time: float) -> None:
|
def _apply_global_decay_to_database(self, current_time: float) -> None:
|
||||||
"""
|
"""
|
||||||
对数据库中的所有表达方式应用全局衰减
|
对数据库中的所有表达方式应用全局衰减
|
||||||
@@ -345,7 +316,7 @@ class ExpressionLearner:
|
|||||||
prompt = "learn_style_prompt"
|
prompt = "learn_style_prompt"
|
||||||
|
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
# 获取上次学习时间
|
# 获取上次学习时间
|
||||||
random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
|
random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
@@ -414,19 +385,20 @@ class ExpressionLearner:
|
|||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|
||||||
|
|
||||||
class ExpressionLearnerManager:
|
class ExpressionLearnerManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.expression_learners = {}
|
self.expression_learners = {}
|
||||||
|
|
||||||
self._ensure_expression_directories()
|
self._ensure_expression_directories()
|
||||||
self._auto_migrate_json_to_db()
|
self._auto_migrate_json_to_db()
|
||||||
self._migrate_old_data_create_date()
|
self._migrate_old_data_create_date()
|
||||||
|
|
||||||
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
|
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
|
||||||
if chat_id not in self.expression_learners:
|
if chat_id not in self.expression_learners:
|
||||||
self.expression_learners[chat_id] = ExpressionLearner(chat_id)
|
self.expression_learners[chat_id] = ExpressionLearner(chat_id)
|
||||||
return self.expression_learners[chat_id]
|
return self.expression_learners[chat_id]
|
||||||
|
|
||||||
def _ensure_expression_directories(self):
|
def _ensure_expression_directories(self):
|
||||||
"""
|
"""
|
||||||
确保表达方式相关的目录结构存在
|
确保表达方式相关的目录结构存在
|
||||||
@@ -445,7 +417,6 @@ class ExpressionLearnerManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建目录失败 {directory}: {e}")
|
logger.error(f"创建目录失败 {directory}: {e}")
|
||||||
|
|
||||||
|
|
||||||
def _auto_migrate_json_to_db(self):
|
def _auto_migrate_json_to_db(self):
|
||||||
"""
|
"""
|
||||||
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。
|
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。
|
||||||
@@ -564,7 +535,7 @@ class ExpressionLearnerManager:
|
|||||||
try:
|
try:
|
||||||
deleted_count = self.delete_all_grammar_expressions()
|
deleted_count = self.delete_all_grammar_expressions()
|
||||||
logger.info(f"grammar表达删除完成,共删除 {deleted_count} 个表达")
|
logger.info(f"grammar表达删除完成,共删除 {deleted_count} 个表达")
|
||||||
|
|
||||||
# 创建done.done2标记文件
|
# 创建done.done2标记文件
|
||||||
with open(done_flag2, "w", encoding="utf-8") as f:
|
with open(done_flag2, "w", encoding="utf-8") as f:
|
||||||
f.write("done\n")
|
f.write("done\n")
|
||||||
@@ -598,7 +569,7 @@ class ExpressionLearnerManager:
|
|||||||
def delete_all_grammar_expressions(self) -> int:
|
def delete_all_grammar_expressions(self) -> int:
|
||||||
"""
|
"""
|
||||||
检查expression库中所有type为"grammar"的表达并全部删除
|
检查expression库中所有type为"grammar"的表达并全部删除
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int: 删除的grammar表达数量
|
int: 删除的grammar表达数量
|
||||||
"""
|
"""
|
||||||
@@ -606,13 +577,13 @@ class ExpressionLearnerManager:
|
|||||||
# 查询所有type为"grammar"的表达
|
# 查询所有type为"grammar"的表达
|
||||||
grammar_expressions = Expression.select().where(Expression.type == "grammar")
|
grammar_expressions = Expression.select().where(Expression.type == "grammar")
|
||||||
grammar_count = grammar_expressions.count()
|
grammar_count = grammar_expressions.count()
|
||||||
|
|
||||||
if grammar_count == 0:
|
if grammar_count == 0:
|
||||||
logger.info("expression库中没有找到grammar类型的表达")
|
logger.info("expression库中没有找到grammar类型的表达")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
logger.info(f"找到 {grammar_count} 个grammar类型的表达,开始删除...")
|
logger.info(f"找到 {grammar_count} 个grammar类型的表达,开始删除...")
|
||||||
|
|
||||||
# 删除所有grammar类型的表达
|
# 删除所有grammar类型的表达
|
||||||
deleted_count = 0
|
deleted_count = 0
|
||||||
for expr in grammar_expressions:
|
for expr in grammar_expressions:
|
||||||
@@ -622,10 +593,10 @@ class ExpressionLearnerManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"删除grammar表达失败: {e}")
|
logger.error(f"删除grammar表达失败: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.info(f"成功删除 {deleted_count} 个grammar类型的表达")
|
logger.info(f"成功删除 {deleted_count} 个grammar类型的表达")
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"删除grammar表达过程中发生错误: {e}")
|
logger.error(f"删除grammar表达过程中发生错误: {e}")
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -4,9 +4,10 @@ import os
|
|||||||
import pickle
|
import pickle
|
||||||
import random
|
import random
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Dict, Any, TYPE_CHECKING
|
from typing import List, Dict, Any
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.person_info.relationship_manager import get_relationship_manager
|
from src.person_info.relationship_manager import get_relationship_manager
|
||||||
from src.person_info.person_info import Person, get_person_id
|
from src.person_info.person_info import Person, get_person_id
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
@@ -17,8 +18,6 @@ from src.chat.utils.chat_message_builder import (
|
|||||||
num_new_messages_since,
|
num_new_messages_since,
|
||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
|
||||||
|
|
||||||
logger = get_logger("relationship_builder")
|
logger = get_logger("relationship_builder")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user