This commit is contained in:
SengokuCola
2025-09-28 02:04:49 +08:00
30 changed files with 203 additions and 188 deletions

View File

@@ -10,11 +10,14 @@ from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages, build_bare_messages
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat_inclusive,
build_anonymous_messages,
build_bare_messages,
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from json_repair import repair_json
from src.chat.utils.utils import get_embedding
MAX_EXPRESSION_COUNT = 300
@@ -99,7 +102,9 @@ class ExpressionLearner:
self.last_learning_time: float = time.time()
# 学习参数
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(self.chat_id)
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(
self.chat_id
)
self.min_messages_for_learning = 15 / self.learning_intensity # 触发学习所需的最少消息数
self.min_learning_interval = 150 / self.learning_intensity
@@ -237,17 +242,42 @@ class ExpressionLearner:
return []
learnt_expressions = res
learnt_expressions_str = ""
for _chat_id, situation, style, context, context_words, full_context, full_context_embedding in learnt_expressions:
for (
_chat_id,
situation,
style,
_context,
_context_words,
_full_context,
_full_context_embedding,
) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
# 按chat_id分组
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
for chat_id, situation, style, context, context_words, full_context, full_context_embedding in learnt_expressions:
for (
chat_id,
situation,
style,
context,
context_words,
full_context,
full_context_embedding,
) in learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
chat_dict[chat_id].append({"situation": situation, "style": style, "context": context, "context_words": context_words, "full_context": full_context, "full_context_embedding": full_context_embedding})
chat_dict[chat_id].append(
{
"situation": situation,
"style": style,
"context": context,
"context_words": context_words,
"full_context": full_context,
"full_context_embedding": full_context_embedding,
}
)
current_time = time.time()
@@ -300,11 +330,13 @@ class ExpressionLearner:
expr.delete_instance()
return learnt_expressions
async def match_expression_context(self, expression_pairs: List[Tuple[str, str]], random_msg_match_str: str) -> List[Tuple[str, str, str]]:
async def match_expression_context(
self, expression_pairs: List[Tuple[str, str]], random_msg_match_str: str
) -> List[Tuple[str, str, str]]:
# 为expression_pairs逐个条目赋予编号并构建成字符串
numbered_pairs = []
for i, (situation, style) in enumerate(expression_pairs, 1):
numbered_pairs.append(f"{i}. 当\"{situation}\"时,使用\"{style}\"")
numbered_pairs.append(f'{i}. 当"{situation}"时,使用"{style}"')
expression_pairs_str = "\n".join(numbered_pairs)
@@ -319,20 +351,20 @@ class ExpressionLearner:
print(f"match_expression_context_prompt: {prompt}")
print(f"random_msg_match_str: {response}")
# 解析JSON响应
match_responses = []
try:
response = response.strip()
# 检查是否已经是标准JSON数组格式
if response.startswith('[') and response.endswith(']'):
if response.startswith("[") and response.endswith("]"):
match_responses = json.loads(response)
else:
# 尝试直接解析多个JSON对象
try:
# 如果是多个JSON对象用逗号分隔包装成数组
if response.startswith('{') and not response.startswith('['):
response = '[' + response + ']'
if response.startswith("{") and not response.startswith("["):
response = "[" + response + "]"
match_responses = json.loads(response)
else:
# 使用repair_json处理响应
@@ -394,7 +426,9 @@ class ExpressionLearner:
return matched_expressions
async def learn_expression(self, num: int = 10) -> Optional[List[Tuple[str, str, str, List[str], str, List[float]]]]:
async def learn_expression(
self, num: int = 10
) -> Optional[List[Tuple[str, str, str, List[str], str, List[float]]]]:
"""从指定聊天流学习表达方式
Args:
@@ -416,11 +450,10 @@ class ExpressionLearner:
if not random_msg or random_msg == []:
return None
# 转化成str
chat_id: str = random_msg[0].chat_id
_chat_id: str = random_msg[0].chat_id
# random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal")
random_msg_str: str = await build_anonymous_messages(random_msg)
random_msg_match_str: str = await build_bare_messages(random_msg)
prompt: str = await global_prompt_manager.format_prompt(
prompt,
@@ -440,24 +473,31 @@ class ExpressionLearner:
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(expressions, random_msg_match_str)
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
expressions, random_msg_match_str
)
split_matched_expressions: List[Tuple[str, str, str, List[str]]] = self.split_expression_context(
matched_expressions
)
split_matched_expressions: List[Tuple[str, str, str, List[str]]] = self.split_expression_context(matched_expressions)
split_matched_expressions_w_emb = []
full_context_embedding: List[float] = await self.get_full_context_embedding(random_msg_match_str)
for situation, style, context, context_words in split_matched_expressions:
split_matched_expressions_w_emb.append((self.chat_id, situation, style, context, context_words, random_msg_match_str,full_context_embedding))
for situation, style, context, context_words in split_matched_expressions:
split_matched_expressions_w_emb.append(
(self.chat_id, situation, style, context, context_words, random_msg_match_str, full_context_embedding)
)
return split_matched_expressions_w_emb
async def get_full_context_embedding(self, context: str) -> List[float]:
    """Return the embedding vector for ``context``.

    Awaits ``self.embedding_model.get_embedding(context)``, which must
    yield exactly two values, and returns the first one.

    Args:
        context: Text to embed.

    Returns:
        The first element of the pair produced by the embedding model.
    """
    embedding_result = await self.embedding_model.get_embedding(context)
    # The second element is not used here (presumably model metadata —
    # TODO confirm against the embedding model's API).
    vector, _unused = embedding_result
    return vector
def split_expression_context(self, matched_expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str, List[str]]]:
def split_expression_context(
self, matched_expressions: List[Tuple[str, str, str]]
) -> List[Tuple[str, str, str, List[str]]]:
"""
对matched_expressions中的context部分进行jieba分词