This commit is contained in:
UnCLAS-Prommer
2025-08-22 16:49:46 +08:00
parent d9bd8a10cb
commit 8862a50452
2 changed files with 34 additions and 64 deletions

View File

@@ -65,7 +65,6 @@ class ExpressionLearner:
self.chat_id = chat_id self.chat_id = chat_id
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
# 维护每个chat的上次学习时间 # 维护每个chat的上次学习时间
self.last_learning_time: float = time.time() self.last_learning_time: float = time.time()
@@ -73,9 +72,6 @@ class ExpressionLearner:
self.min_messages_for_learning = 25 # 触发学习所需的最少消息数 self.min_messages_for_learning = 25 # 触发学习所需的最少消息数
self.min_learning_interval = 300 # 最短学习时间间隔(秒) self.min_learning_interval = 300 # 最短学习时间间隔(秒)
def can_learn_for_chat(self) -> bool: def can_learn_for_chat(self) -> bool:
""" """
检查指定聊天流是否允许学习表达 检查指定聊天流是否允许学习表达
@@ -107,7 +103,9 @@ class ExpressionLearner:
# 获取该聊天流的学习强度 # 获取该聊天流的学习强度
try: try:
_, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat(self.chat_id) _, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat(
self.chat_id
)
except Exception as e: except Exception as e:
logger.error(f"获取聊天流 {self.chat_id} 的学习配置失败: {e}") logger.error(f"获取聊天流 {self.chat_id} 的学习配置失败: {e}")
return False return False
@@ -169,33 +167,6 @@ class ExpressionLearner:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}") logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
return False return False
# def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
# """
# 获取指定chat_id的style表达方式已禁用grammar的获取
# 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
# """
# learnt_style_expressions = []
# # 直接从数据库查询
# style_query = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.type == "style"))
# for expr in style_query:
# # 确保create_date存在如果不存在则使用last_active_time
# create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
# learnt_style_expressions.append(
# {
# "situation": expr.situation,
# "style": expr.style,
# "count": expr.count,
# "last_active_time": expr.last_active_time,
# "source_id": self.chat_id,
# "type": "style",
# "create_date": create_date,
# }
# )
# return learnt_style_expressions
def _apply_global_decay_to_database(self, current_time: float) -> None: def _apply_global_decay_to_database(self, current_time: float) -> None:
""" """
对数据库中的所有表达方式应用全局衰减 对数据库中的所有表达方式应用全局衰减
@@ -414,6 +385,7 @@ class ExpressionLearner:
init_prompt() init_prompt()
class ExpressionLearnerManager: class ExpressionLearnerManager:
def __init__(self): def __init__(self):
self.expression_learners = {} self.expression_learners = {}
@@ -445,7 +417,6 @@ class ExpressionLearnerManager:
except Exception as e: except Exception as e:
logger.error(f"创建目录失败 {directory}: {e}") logger.error(f"创建目录失败 {directory}: {e}")
def _auto_migrate_json_to_db(self): def _auto_migrate_json_to_db(self):
""" """
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。 自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。

View File

@@ -4,9 +4,10 @@ import os
import pickle import pickle
import random import random
import asyncio import asyncio
from typing import List, Dict, Any, TYPE_CHECKING from typing import List, Dict, Any
from src.config.config import global_config from src.config.config import global_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.person_info.relationship_manager import get_relationship_manager from src.person_info.relationship_manager import get_relationship_manager
from src.person_info.person_info import Person, get_person_id from src.person_info.person_info import Person, get_person_id
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
@@ -17,8 +18,6 @@ from src.chat.utils.chat_message_builder import (
num_new_messages_since, num_new_messages_since,
) )
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("relationship_builder") logger = get_logger("relationship_builder")