Files
mai-bot/src/chat/express/expression_selector.py
2025-06-25 12:39:39 +00:00

257 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from .exprssion_learner import get_expression_learner
import random
from typing import List, Dict, Tuple
from json_repair import repair_json
import json
import os
import time
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
# Module-level logger obtained from the project's get_logger under the name "expression_selector".
logger = get_logger("expression_selector")
def init_prompt():
    """Register the situation-selection prompt template.

    The template is registered under the name "expression_evaluation_prompt".
    It expects the placeholders ``bot_name``, ``chat_observe_info``,
    ``all_situations``, ``min_num`` and ``max_num``; the doubled braces keep
    the JSON examples literal when the template is formatted.
    """
    situation_selection_template = """
你的名字是{bot_name}
以下是正在进行的聊天内容:
{chat_observe_info}
以下是可选的表达情境:
{all_situations}
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的{min_num}-{max_num}个情境。
考虑因素包括:
1. 聊天的情绪氛围(轻松、严肃、幽默等)
2. 话题类型(日常、技术、游戏、情感等)
3. 情境与当前语境的匹配度
请以JSON格式输出只需要输出选中的情境编号
例如:
{{
"selected_situations": [2, 3, 5, 7, 19, 22, 25, 38, 39, 45, 48 , 64]
}}
例如:
{{
"selected_situations": [1, 4, 7, 9, 23, 38, 44]
}}
请严格按照JSON格式输出不要包含其他内容
"""
    # Constructing the Prompt with a name is what makes it retrievable later
    # via global_prompt_manager.get_prompt_async("expression_evaluation_prompt").
    Prompt(situation_selection_template, "expression_evaluation_prompt")
def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
    """Draw up to ``k`` distinct items from ``population`` without replacement.

    Each draw picks one of the remaining items with probability proportional
    to its weight, then removes it (and its weight) from the pool.

    Args:
        population: Candidate items. The list itself is never mutated.
        weights: One non-negative weight per item, in the same order.
        k: Maximum number of items to return.

    Returns:
        A new list of selected items. Empty when ``population`` or ``weights``
        is empty or ``k <= 0``; a shallow copy of ``population`` when it holds
        no more than ``k`` items.

    Raises:
        ValueError: If an actual draw is required and ``weights`` does not
            have the same length as ``population``.
    """
    if not population or not weights or k <= 0:
        return []
    if len(population) <= k:
        return population.copy()
    if len(weights) != len(population):
        raise ValueError("The number of weights does not match the population")

    remaining_items = population.copy()
    remaining_weights = weights.copy()
    selected = []
    for _ in range(k):
        if not remaining_items:
            break
        if sum(remaining_weights) > 0:
            chosen_idx = random.choices(range(len(remaining_items)), weights=remaining_weights)[0]
        else:
            # random.choices raises ValueError when the total weight is zero,
            # which happens with all-zero weights or once every positively
            # weighted item has already been drawn. Fall back to a uniform
            # draw so the caller still receives k items.
            chosen_idx = random.randrange(len(remaining_items))
        selected.append(remaining_items.pop(chosen_idx))
        remaining_weights.pop(chosen_idx)
    return selected
class ExpressionSelector:
    """Choose learnt expressions that fit the current chat with the help of an LLM.

    Candidate expressions come from the expression learner and are sampled
    using their ``count`` value as weight. The LLM is then asked which
    situations match the ongoing chat, and every expression it picks gets its
    stored ``count`` increased.
    """

    def __init__(self):
        self.expression_learner = get_expression_learner()
        # TODO: API-Adapter modification marker
        self.llm_model = LLMRequest(
            model=global_config.model.utils_small,
            request_type="expression.selector",
        )

    @staticmethod
    def _sample_by_count(expressions: List[Dict[str, str]], num: int) -> List[Dict[str, str]]:
        """Sample up to ``num`` expressions, weighting each by its ``count`` (default 1)."""
        if not expressions:
            return []
        count_weights = [expr.get("count", 1) for expr in expressions]
        return weighted_sample(expressions, count_weights, num)

    def get_random_expressions(
        self, chat_id: str, style_num: int, grammar_num: int, personality_num: int
    ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]]:
        """Sample style, grammar and personality expressions for a chat.

        Args:
            chat_id: Chat whose expressions are requested from the learner.
            style_num: Maximum number of style expressions to sample.
            grammar_num: Maximum number of grammar expressions to sample.
            personality_num: Maximum number of personality expressions to sample.

        Returns:
            A 3-tuple ``(style, grammar, personality)`` of sampled expression
            lists; a list is empty when the learner returned nothing for it.
        """
        (
            learnt_style_expressions,
            learnt_grammar_expressions,
            personality_expressions,
        ) = self.expression_learner.get_expression_by_chat_id(chat_id)

        selected_style = self._sample_by_count(learnt_style_expressions, style_num)
        selected_grammar = self._sample_by_count(learnt_grammar_expressions, grammar_num)
        selected_personality = self._sample_by_count(personality_expressions, personality_num)
        return selected_style, selected_grammar, selected_personality

    def update_expression_count(self, chat_id: str, expression: Dict[str, str], increment: float = 0.1):
        """Increase the stored ``count`` of one expression.

        The expression is located in its JSON file by matching both
        ``situation`` and ``style``. Only the first match is updated; its
        ``count`` is capped at 5.0 and ``last_active_time`` is set to now.
        The file is rewritten only when a match was found. Any error while
        reading, updating or writing is logged and swallowed.

        Args:
            chat_id: Chat ID; selects the directory for style/grammar expressions.
            expression: Expression dict. Its ``type`` (default ``"style"``)
                selects the file; unknown types are ignored.
            increment: Amount added to ``count``, 0.1 by default.
        """
        expr_type = expression.get("type", "style")
        if expr_type == "style_personality":
            # Personality expressions live in one global file.
            file_path = os.path.join("data", "expression", "personality", "expressions.json")
        elif expr_type == "style":
            # Style and grammar expressions live in a per-chat directory.
            file_path = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
        elif expr_type == "grammar":
            file_path = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
        else:
            return

        if not os.path.exists(file_path):
            return

        try:
            with open(file_path, "r", encoding="utf-8") as f:
                expressions = json.load(f)

            matched = False
            for expr in expressions:
                if expr.get("situation") == expression.get("situation") and expr.get("style") == expression.get(
                    "style"
                ):
                    current_count = expr.get("count", 1)
                    # Plain additive bump, capped at 5.
                    new_count = min(current_count + increment, 5.0)
                    expr["count"] = new_count
                    expr["last_active_time"] = time.time()
                    logger.info(f"表达方式激活: 原count={current_count:.2f}, 增量={increment}, 新count={new_count:.2f}")
                    matched = True
                    break

            if not matched:
                # Nothing changed, so leave the file (content and formatting) untouched.
                return

            with open(file_path, "w", encoding="utf-8") as f:
                json.dump(expressions, f, ensure_ascii=False, indent=2)
        except Exception as e:
            logger.error(f"更新表达方式count失败: {e}")

    async def select_suitable_expressions_llm(
        self, chat_id: str, chat_info: str, max_num: int = 10, min_num: int = 5
    ) -> List[Dict[str, str]]:
        """Ask the LLM which candidate expressions suit the current chat.

        Args:
            chat_id: Chat whose expressions are considered.
            chat_info: Chat content shown to the LLM.
            max_num: Upper bound mentioned in the prompt. It is not enforced
                on the LLM's answer.
            min_num: Lower bound mentioned in the prompt. It is not enforced
                on the LLM's answer.

        Returns:
            The selected expressions, each a copy tagged with a ``type`` of
            ``"style"``, ``"grammar"`` or ``"style_personality"``, in the order
            the LLM listed them. Empty when there are no candidates, the LLM
            returns nothing or a malformed answer, or an error occurs.
        """
        # 1. Sample candidates by weight: up to 25 style, 25 grammar and 10 personality.
        style_exprs, grammar_exprs, personality_exprs = self.get_random_expressions(chat_id, 25, 25, 10)

        # 2. Build the candidate list and the numbered situation list (1-based).
        all_expressions = []
        all_situations = []
        typed_groups = (
            (style_exprs, "style"),
            (grammar_exprs, "grammar"),
            (personality_exprs, "style_personality"),
        )
        for group, type_label in typed_groups:
            for expr in group:
                if isinstance(expr, dict) and "situation" in expr and "style" in expr:
                    expr_with_type = expr.copy()
                    expr_with_type["type"] = type_label
                    all_expressions.append(expr_with_type)
                    all_situations.append(f"{len(all_expressions)}.{expr['situation']}")

        if not all_expressions:
            logger.warning("没有找到可用的表达方式")
            return []

        all_situations_str = "\n".join(all_situations)

        # 3. Build the prompt; it lists only the situations, not the full expressions.
        prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
            bot_name=global_config.bot.nickname,
            chat_observe_info=chat_info,
            all_situations=all_situations_str,
            min_num=min_num,
            max_num=max_num,
        )

        # 4. Call the LLM.
        try:
            content, (_, _) = await self.llm_model.generate_response_async(prompt=prompt)

            if not content:
                logger.warning("LLM返回空结果")
                return []

            # 5. Parse the answer.
            result = repair_json(content)
            if isinstance(result, str):
                result = json.loads(result)

            if not isinstance(result, dict) or "selected_situations" not in result:
                logger.error("LLM返回格式错误")
                return []

            selected_indices = result["selected_situations"]

            # Map the 1-based indices back to the full expressions.
            valid_expressions = []
            seen_indices = set()
            for idx in selected_indices:
                # bool is a subclass of int; a JSON `true` must not count as index 1.
                if isinstance(idx, bool) or not isinstance(idx, int):
                    continue
                if not 1 <= idx <= len(all_expressions):
                    continue
                # A repeated index would return the same expression twice and
                # bump its count twice for a single selection.
                if idx in seen_indices:
                    continue
                seen_indices.add(idx)

                expression = all_expressions[idx - 1]
                valid_expressions.append(expression)
                # Reward the selected expression with +0.1 on its count.
                self.update_expression_count(chat_id, expression, 0.1)

            return valid_expressions

        except Exception as e:
            logger.error(f"LLM处理表达方式选择时出错: {e}")
            return []
# Register the prompt template at import time so it is available by name
# before any selector method asks the prompt manager for it.
init_prompt()

# Build the module-level selector instance.
# NOTE: a construction failure is only printed; in that case the name
# `expression_selector` is never bound in this module.
try:
    expression_selector = ExpressionSelector()
except Exception as init_error:
    print(f"ExpressionSelector初始化失败: {init_error}")