Merge branch 'dev'
This commit is contained in:
794
src/bw_learner/expression_learner.py
Normal file
794
src/bw_learner/expression_learner.py
Normal file
@@ -0,0 +1,794 @@
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import asyncio
|
||||
from typing import List, Optional, Tuple, Any, Dict
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_anonymous_messages,
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.bw_learner.learner_utils import (
|
||||
filter_message_content,
|
||||
is_bot_message,
|
||||
build_context_paragraph,
|
||||
contains_bot_self_name,
|
||||
calculate_style_similarity,
|
||||
)
|
||||
from src.bw_learner.jargon_miner import miner_manager
|
||||
from json_repair import repair_json
|
||||
|
||||
|
||||
# MAX_EXPRESSION_COUNT = 300
|
||||
|
||||
logger = get_logger("expressor")
|
||||
|
||||
|
||||
def init_prompt() -> None:
|
||||
learn_style_prompt = """{chat_str}
|
||||
你的名字是{bot_name},现在请你完成两个提取任务
|
||||
任务1:请从上面这段群聊中用户的语言风格和说话方式
|
||||
1. 只考虑文字,不要考虑表情包和图片
|
||||
2. 不要总结SELF的发言,因为这是你自己的发言,不要重复学习你自己的发言
|
||||
3. 不要涉及具体的人名,也不要涉及具体名词
|
||||
4. 思考有没有特殊的梗,一并总结成语言风格
|
||||
5. 例子仅供参考,请严格根据群聊内容总结!!!
|
||||
注意:总结成如下格式的规律,总结的内容要详细,但具有概括性:
|
||||
例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个场景,不超过20个字。BBBBB代表对应的语言风格,特定句式或表达方式,不超过20个字。
|
||||
表达方式在3-5个左右,不要超过10个
|
||||
|
||||
|
||||
任务2:请从上面这段聊天内容中提取"可能是黑话"的候选项(黑话/俚语/网络缩写/口头禅)。
|
||||
- 必须为对话中真实出现过的短词或短语
|
||||
- 必须是你无法理解含义的词语,没有明确含义的词语,请不要选择有明确含义,或者含义清晰的词语
|
||||
- 排除:人名、@、表情包/图片中的内容、纯标点、常规功能词(如的、了、呢、啊等)
|
||||
- 每个词条长度建议 2-8 个字符(不强制),尽量短小
|
||||
- 请你提取出可能的黑话,最多30个黑话,请尽量提取所有
|
||||
|
||||
黑话必须为以下几种类型:
|
||||
- 由字母构成的,汉语拼音首字母的简写词,例如:nb、yyds、xswl
|
||||
- 英文词语的缩写,用英文字母概括一个词汇或含义,例如:CPU、GPU、API
|
||||
- 中文词语的缩写,用几个汉字概括一个词汇或含义,例如:社死、内卷
|
||||
|
||||
输出要求:
|
||||
将表达方式,语言风格和黑话以 JSON 数组输出,每个元素为一个对象,结构如下(注意字段名):
|
||||
注意请不要输出重复内容,请对表达方式和黑话进行去重。
|
||||
|
||||
[
|
||||
{{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}},
|
||||
{{"situation": "CCCC", "style": "DDDD", "source_id": "7"}}
|
||||
{{"situation": "对某件事表示十分惊叹", "style": "使用 我嘞个xxxx", "source_id": "[消息编号]"}},
|
||||
{{"situation": "表示讽刺的赞同,不讲道理", "style": "对对对", "source_id": "[消息编号]"}},
|
||||
{{"situation": "当涉及游戏相关时,夸赞,略带戏谑意味", "style": "使用 这么强!", "source_id": "[消息编号]"}},
|
||||
{{"content": "词条", "source_id": "12"}},
|
||||
{{"content": "词条2", "source_id": "5"}}
|
||||
]
|
||||
|
||||
其中:
|
||||
表达方式条目:
|
||||
- situation:表示“在什么情境下”的简短概括(不超过20个字)
|
||||
- style:表示对应的语言风格或常用表达(不超过20个字)
|
||||
- source_id:该表达方式对应的“来源行编号”,即上方聊天记录中方括号里的数字(例如 [3]),请只输出数字本身,不要包含方括号
|
||||
黑话jargon条目:
|
||||
- content:表示黑话的内容
|
||||
- source_id:该黑话对应的“来源行编号”,即上方聊天记录中方括号里的数字(例如 [3]),请只输出数字本身,不要包含方括号
|
||||
|
||||
现在请你输出 JSON:
|
||||
"""
|
||||
Prompt(learn_style_prompt, "learn_style_prompt")
|
||||
|
||||
|
||||
class ExpressionLearner:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.express_learn_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils, request_type="expression.learner"
|
||||
)
|
||||
self.summary_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="expression.summary"
|
||||
)
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||
|
||||
# 学习锁,防止并发执行学习任务
|
||||
self._learning_lock = asyncio.Lock()
|
||||
|
||||
async def learn_and_store(
|
||||
self,
|
||||
messages: List[Any],
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
学习并存储表达方式
|
||||
|
||||
Args:
|
||||
messages: 外部传入的消息列表(必需)
|
||||
num: 学习数量
|
||||
timestamp_start: 学习开始的时间戳,如果为None则使用self.last_learning_time
|
||||
"""
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
random_msg = messages
|
||||
|
||||
# 学习用(开启行编号,便于溯源)
|
||||
random_msg_str: str = await build_anonymous_messages(random_msg, show_ids=True)
|
||||
|
||||
prompt: str = await global_prompt_manager.format_prompt(
|
||||
"learn_style_prompt",
|
||||
bot_name=global_config.bot.nickname,
|
||||
chat_str=random_msg_str,
|
||||
)
|
||||
|
||||
# print(f"random_msg_str:{random_msg_str}")
|
||||
# logger.info(f"学习{type_str}的prompt: {prompt}")
|
||||
|
||||
try:
|
||||
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
except Exception as e:
|
||||
logger.error(f"学习表达方式失败,模型生成出错: {e}")
|
||||
return None
|
||||
|
||||
# 解析 LLM 返回的表达方式列表和黑话列表(包含来源行编号)
|
||||
expressions: List[Tuple[str, str, str]]
|
||||
jargon_entries: List[Tuple[str, str]] # (content, source_id)
|
||||
expressions, jargon_entries = self.parse_expression_response(response)
|
||||
expressions = self._filter_self_reference_styles(expressions)
|
||||
|
||||
# 检查表达方式数量,如果超过10个则放弃本次表达学习
|
||||
if len(expressions) > 10:
|
||||
logger.info(f"表达方式提取数量超过10个(实际{len(expressions)}个),放弃本次表达学习")
|
||||
expressions = []
|
||||
|
||||
# 检查黑话数量,如果超过30个则放弃本次黑话学习
|
||||
if len(jargon_entries) > 30:
|
||||
logger.info(f"黑话提取数量超过30个(实际{len(jargon_entries)}个),放弃本次黑话学习")
|
||||
jargon_entries = []
|
||||
|
||||
# 处理黑话条目,路由到 jargon_miner(即使没有表达方式也要处理黑话)
|
||||
if jargon_entries:
|
||||
await self._process_jargon_entries(jargon_entries, random_msg)
|
||||
|
||||
# 如果没有表达方式,直接返回
|
||||
if not expressions:
|
||||
logger.info("过滤后没有可用的表达方式(style 与机器人名称重复)")
|
||||
return []
|
||||
|
||||
logger.info(f"学习的prompt: {prompt}")
|
||||
logger.info(f"学习的expressions: {expressions}")
|
||||
logger.info(f"学习的jargon_entries: {jargon_entries}")
|
||||
logger.info(f"学习的response: {response}")
|
||||
|
||||
# 直接根据 source_id 在 random_msg 中溯源,获取 context
|
||||
filtered_expressions: List[Tuple[str, str, str]] = [] # (situation, style, context)
|
||||
|
||||
for situation, style, source_id in expressions:
|
||||
source_id_str = (source_id or "").strip()
|
||||
if not source_id_str.isdigit():
|
||||
# 无效的来源行编号,跳过
|
||||
continue
|
||||
|
||||
line_index = int(source_id_str) - 1 # build_anonymous_messages 的编号从 1 开始
|
||||
if line_index < 0 or line_index >= len(random_msg):
|
||||
# 超出范围,跳过
|
||||
continue
|
||||
|
||||
# 当前行的原始内容
|
||||
current_msg = random_msg[line_index]
|
||||
|
||||
# 过滤掉从bot自己发言中提取到的表达方式
|
||||
if is_bot_message(current_msg):
|
||||
continue
|
||||
|
||||
context = filter_message_content(current_msg.processed_plain_text or "")
|
||||
if not context:
|
||||
continue
|
||||
|
||||
# 过滤掉包含 SELF 的内容(不学习)
|
||||
if "SELF" in (situation or "") or "SELF" in (style or "") or "SELF" in context:
|
||||
logger.info(
|
||||
f"跳过包含 SELF 的表达方式: situation={situation}, style={style}, source_id={source_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
filtered_expressions.append((situation, style, context))
|
||||
|
||||
learnt_expressions = filtered_expressions
|
||||
|
||||
if learnt_expressions is None:
|
||||
logger.info("没有学习到表达风格")
|
||||
return []
|
||||
|
||||
# 展示学到的表达方式
|
||||
learnt_expressions_str = ""
|
||||
for (
|
||||
situation,
|
||||
style,
|
||||
_context,
|
||||
) in learnt_expressions:
|
||||
learnt_expressions_str += f"{situation}->{style}\n"
|
||||
logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
for (
|
||||
situation,
|
||||
style,
|
||||
context,
|
||||
) in learnt_expressions:
|
||||
await self._upsert_expression_record(
|
||||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
current_time=current_time,
|
||||
)
|
||||
|
||||
return learnt_expressions
|
||||
|
||||
def parse_expression_response(self, response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
|
||||
"""
|
||||
解析 LLM 返回的表达风格总结和黑话 JSON,提取两个列表。
|
||||
|
||||
期望的 JSON 结构:
|
||||
[
|
||||
{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式
|
||||
{"content": "词条", "source_id": "12"}, // 黑话
|
||||
...
|
||||
]
|
||||
|
||||
Returns:
|
||||
Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
|
||||
第一个列表是表达方式 (situation, style, source_id)
|
||||
第二个列表是黑话 (content, source_id)
|
||||
"""
|
||||
if not response:
|
||||
return [], []
|
||||
|
||||
raw = response.strip()
|
||||
|
||||
# 尝试提取 ```json 代码块
|
||||
json_block_pattern = r"```json\s*(.*?)\s*```"
|
||||
match = re.search(json_block_pattern, raw, re.DOTALL)
|
||||
if match:
|
||||
raw = match.group(1).strip()
|
||||
else:
|
||||
# 去掉可能存在的通用 ``` 包裹
|
||||
raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
|
||||
raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
|
||||
raw = raw.strip()
|
||||
|
||||
parsed = None
|
||||
expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id)
|
||||
jargon_entries: List[Tuple[str, str]] = [] # (content, source_id)
|
||||
|
||||
try:
|
||||
# 优先尝试直接解析
|
||||
if raw.startswith("[") and raw.endswith("]"):
|
||||
parsed = json.loads(raw)
|
||||
else:
|
||||
repaired = repair_json(raw)
|
||||
if isinstance(repaired, str):
|
||||
parsed = json.loads(repaired)
|
||||
else:
|
||||
parsed = repaired
|
||||
except Exception as parse_error:
|
||||
# 如果解析失败,尝试修复中文引号问题
|
||||
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
|
||||
try:
|
||||
|
||||
def fix_chinese_quotes_in_json(text):
|
||||
"""使用状态机修复 JSON 字符串值中的中文引号"""
|
||||
result = []
|
||||
i = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
|
||||
while i < len(text):
|
||||
char = text[i]
|
||||
|
||||
if escape_next:
|
||||
# 当前字符是转义字符后的字符,直接添加
|
||||
result.append(char)
|
||||
escape_next = False
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if char == "\\":
|
||||
# 转义字符
|
||||
result.append(char)
|
||||
escape_next = True
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if char == '"' and not escape_next:
|
||||
# 遇到英文引号,切换字符串状态
|
||||
in_string = not in_string
|
||||
result.append(char)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_string:
|
||||
# 在字符串值内部,将中文引号替换为转义的英文引号
|
||||
if char == '"': # 中文左引号 U+201C
|
||||
result.append('\\"')
|
||||
elif char == '"': # 中文右引号 U+201D
|
||||
result.append('\\"')
|
||||
else:
|
||||
result.append(char)
|
||||
else:
|
||||
# 不在字符串内,直接添加
|
||||
result.append(char)
|
||||
|
||||
i += 1
|
||||
|
||||
return "".join(result)
|
||||
|
||||
fixed_raw = fix_chinese_quotes_in_json(raw)
|
||||
|
||||
# 再次尝试解析
|
||||
if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
|
||||
parsed = json.loads(fixed_raw)
|
||||
else:
|
||||
repaired = repair_json(fixed_raw)
|
||||
if isinstance(repaired, str):
|
||||
parsed = json.loads(repaired)
|
||||
else:
|
||||
parsed = repaired
|
||||
except Exception as fix_error:
|
||||
logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}")
|
||||
logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}")
|
||||
logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
|
||||
logger.error(f"处理后的 JSON 字符串(前500字符):{raw[:500]}")
|
||||
return [], []
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
parsed_list = [parsed]
|
||||
elif isinstance(parsed, list):
|
||||
parsed_list = parsed
|
||||
else:
|
||||
logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
|
||||
return [], []
|
||||
|
||||
for item in parsed_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# 检查是否是表达方式条目(有 situation 和 style)
|
||||
situation = str(item.get("situation", "")).strip()
|
||||
style = str(item.get("style", "")).strip()
|
||||
source_id = str(item.get("source_id", "")).strip()
|
||||
|
||||
if situation and style and source_id:
|
||||
# 表达方式条目
|
||||
expressions.append((situation, style, source_id))
|
||||
elif item.get("content"):
|
||||
# 黑话条目(有 content 字段)
|
||||
content = str(item.get("content", "")).strip()
|
||||
source_id = str(item.get("source_id", "")).strip()
|
||||
if content and source_id:
|
||||
jargon_entries.append((content, source_id))
|
||||
|
||||
return expressions, jargon_entries
|
||||
|
||||
def _filter_self_reference_styles(self, expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
过滤掉style与机器人名称/昵称重复的表达
|
||||
"""
|
||||
banned_names = set()
|
||||
bot_nickname = (global_config.bot.nickname or "").strip()
|
||||
if bot_nickname:
|
||||
banned_names.add(bot_nickname)
|
||||
|
||||
alias_names = global_config.bot.alias_names or []
|
||||
for alias in alias_names:
|
||||
alias = alias.strip()
|
||||
if alias:
|
||||
banned_names.add(alias)
|
||||
|
||||
banned_casefold = {name.casefold() for name in banned_names if name}
|
||||
|
||||
filtered: List[Tuple[str, str, str]] = []
|
||||
removed_count = 0
|
||||
for situation, style, source_id in expressions:
|
||||
normalized_style = (style or "").strip()
|
||||
if normalized_style and normalized_style.casefold() not in banned_casefold:
|
||||
filtered.append((situation, style, source_id))
|
||||
else:
|
||||
removed_count += 1
|
||||
|
||||
if removed_count:
|
||||
logger.debug(f"已过滤 {removed_count} 条style与机器人名称重复的表达方式")
|
||||
|
||||
return filtered
|
||||
|
||||
async def _upsert_expression_record(
|
||||
self,
|
||||
situation: str,
|
||||
style: str,
|
||||
context: str,
|
||||
current_time: float,
|
||||
) -> None:
|
||||
# 第一层:检查是否有完全一致的 style(检查 style 字段和 style_list)
|
||||
expr_obj = await self._find_exact_style_match(style)
|
||||
|
||||
if expr_obj:
|
||||
# 找到完全匹配的 style,合并到现有记录(不使用 LLM 总结)
|
||||
await self._update_existing_expression(
|
||||
expr_obj=expr_obj,
|
||||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
current_time=current_time,
|
||||
use_llm_summary=False,
|
||||
)
|
||||
return
|
||||
|
||||
# 第二层:检查是否有相似的 style(相似度 >= 0.75,检查 style 字段和 style_list)
|
||||
similar_expr_obj = await self._find_similar_style_expression(style, similarity_threshold=0.75)
|
||||
|
||||
if similar_expr_obj:
|
||||
# 找到相似的 style,合并到现有记录(使用 LLM 总结)
|
||||
await self._update_existing_expression(
|
||||
expr_obj=similar_expr_obj,
|
||||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
current_time=current_time,
|
||||
use_llm_summary=True,
|
||||
)
|
||||
return
|
||||
|
||||
# 没有找到匹配的记录,创建新记录
|
||||
await self._create_expression_record(
|
||||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
current_time=current_time,
|
||||
)
|
||||
|
||||
async def _create_expression_record(
|
||||
self,
|
||||
situation: str,
|
||||
style: str,
|
||||
context: str,
|
||||
current_time: float,
|
||||
) -> None:
|
||||
content_list = [situation]
|
||||
# 创建新记录时,直接使用原始的 situation,不进行总结
|
||||
formatted_situation = situation
|
||||
|
||||
Expression.create(
|
||||
situation=formatted_situation,
|
||||
style=style,
|
||||
content_list=json.dumps(content_list, ensure_ascii=False),
|
||||
style_list=None, # 新记录初始时 style_list 为空
|
||||
count=1,
|
||||
last_active_time=current_time,
|
||||
chat_id=self.chat_id,
|
||||
create_date=current_time,
|
||||
context=context,
|
||||
)
|
||||
|
||||
async def _update_existing_expression(
|
||||
self,
|
||||
expr_obj: Expression,
|
||||
situation: str,
|
||||
style: str,
|
||||
context: str,
|
||||
current_time: float,
|
||||
use_llm_summary: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
更新现有 Expression 记录(style 完全匹配或相似的情况)
|
||||
将新的 situation 添加到 content_list,将新的 style 添加到 style_list(如果不同)
|
||||
|
||||
Args:
|
||||
use_llm_summary: 是否使用 LLM 进行总结,完全匹配时为 False,相似匹配时为 True
|
||||
"""
|
||||
# 更新 content_list(添加新的 situation)
|
||||
content_list = self._parse_content_list(expr_obj.content_list)
|
||||
content_list.append(situation)
|
||||
expr_obj.content_list = json.dumps(content_list, ensure_ascii=False)
|
||||
|
||||
# 更新 style_list(如果 style 不同,添加到 style_list)
|
||||
style_list = self._parse_style_list(expr_obj.style_list)
|
||||
# 将原有的 style 也加入 style_list(如果还没有的话)
|
||||
if expr_obj.style and expr_obj.style not in style_list:
|
||||
style_list.append(expr_obj.style)
|
||||
# 如果新的 style 不在 style_list 中,添加它
|
||||
if style not in style_list:
|
||||
style_list.append(style)
|
||||
expr_obj.style_list = json.dumps(style_list, ensure_ascii=False)
|
||||
|
||||
# 更新其他字段
|
||||
expr_obj.count = (expr_obj.count or 0) + 1
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.context = context
|
||||
|
||||
if use_llm_summary:
|
||||
# 相似匹配时,使用 LLM 重新组合 situation 和 style
|
||||
new_situation = await self._compose_situation_text(
|
||||
content_list=content_list,
|
||||
count=expr_obj.count,
|
||||
fallback=expr_obj.situation,
|
||||
)
|
||||
expr_obj.situation = new_situation
|
||||
|
||||
new_style = await self._compose_style_text(
|
||||
style_list=style_list,
|
||||
count=expr_obj.count,
|
||||
fallback=expr_obj.style or style,
|
||||
)
|
||||
expr_obj.style = new_style
|
||||
else:
|
||||
# 完全匹配时,不进行 LLM 总结,保持原有的 situation 和 style 不变
|
||||
# 只更新 content_list 和 style_list
|
||||
pass
|
||||
|
||||
expr_obj.save()
|
||||
|
||||
def _parse_content_list(self, stored_list: Optional[str]) -> List[str]:
|
||||
if not stored_list:
|
||||
return []
|
||||
try:
|
||||
data = json.loads(stored_list)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
|
||||
|
||||
def _parse_style_list(self, stored_list: Optional[str]) -> List[str]:
|
||||
"""解析 style_list JSON 字符串为列表,逻辑与 _parse_content_list 相同"""
|
||||
if not stored_list:
|
||||
return []
|
||||
try:
|
||||
data = json.loads(stored_list)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
|
||||
|
||||
async def _find_exact_style_match(self, style: str) -> Optional[Expression]:
|
||||
"""
|
||||
查找具有完全匹配 style 的 Expression 记录
|
||||
只检查 style_list 中的每一项(不检查 style 字段,因为 style 可能是总结后的概括性描述)
|
||||
|
||||
Args:
|
||||
style: 要查找的 style
|
||||
|
||||
Returns:
|
||||
找到的 Expression 对象,如果没有找到则返回 None
|
||||
"""
|
||||
# 查询同一 chat_id 的所有记录
|
||||
all_expressions = Expression.select().where(Expression.chat_id == self.chat_id)
|
||||
|
||||
for expr in all_expressions:
|
||||
# 只检查 style_list 中的每一项
|
||||
style_list = self._parse_style_list(expr.style_list)
|
||||
if style in style_list:
|
||||
return expr
|
||||
|
||||
return None
|
||||
|
||||
async def _find_similar_style_expression(self, style: str, similarity_threshold: float = 0.75) -> Optional[Expression]:
|
||||
"""
|
||||
查找具有相似 style 的 Expression 记录
|
||||
只检查 style_list 中的每一项(不检查 style 字段,因为 style 可能是总结后的概括性描述)
|
||||
|
||||
Args:
|
||||
style: 要查找的 style
|
||||
similarity_threshold: 相似度阈值,默认 0.75
|
||||
|
||||
Returns:
|
||||
找到的最相似的 Expression 对象,如果没有找到则返回 None
|
||||
"""
|
||||
# 查询同一 chat_id 的所有记录
|
||||
all_expressions = Expression.select().where(Expression.chat_id == self.chat_id)
|
||||
|
||||
best_match = None
|
||||
best_similarity = 0.0
|
||||
|
||||
for expr in all_expressions:
|
||||
# 只检查 style_list 中的每一项
|
||||
style_list = self._parse_style_list(expr.style_list)
|
||||
for existing_style in style_list:
|
||||
similarity = calculate_style_similarity(style, existing_style)
|
||||
if similarity >= similarity_threshold and similarity > best_similarity:
|
||||
best_similarity = similarity
|
||||
best_match = expr
|
||||
|
||||
if best_match:
|
||||
logger.debug(f"找到相似的 style: 相似度={best_similarity:.3f}, 现有='{best_match.style}', 新='{style}'")
|
||||
|
||||
return best_match
|
||||
|
||||
async def _compose_situation_text(self, content_list: List[str], count: int, fallback: str = "") -> str:
|
||||
sanitized = [c.strip() for c in content_list if c.strip()]
|
||||
summary = await self._summarize_situations(sanitized)
|
||||
if summary:
|
||||
return summary
|
||||
return "/".join(sanitized) if sanitized else fallback
|
||||
|
||||
async def _compose_style_text(self, style_list: List[str], count: int, fallback: str = "") -> str:
|
||||
"""
|
||||
组合 style 文本,如果 style_list 有多个元素则尝试总结
|
||||
"""
|
||||
sanitized = [s.strip() for s in style_list if s.strip()]
|
||||
if len(sanitized) > 1:
|
||||
# 只有当有多个 style 时才尝试总结
|
||||
summary = await self._summarize_styles(sanitized)
|
||||
if summary:
|
||||
return summary
|
||||
# 如果只有一个或总结失败,返回第一个或 fallback
|
||||
return sanitized[0] if sanitized else fallback
|
||||
|
||||
async def _summarize_styles(self, styles: List[str]) -> Optional[str]:
|
||||
"""总结多个 style,生成一个概括性的 style 描述"""
|
||||
if not styles or len(styles) <= 1:
|
||||
return None
|
||||
|
||||
# 计算输入列表中最长项目的长度
|
||||
max_input_length = max(len(s) for s in styles) if styles else 0
|
||||
max_summary_length = max_input_length * 2
|
||||
|
||||
# 最多重试3次
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
|
||||
while retry_count < max_retries:
|
||||
# 如果是重试,在 prompt 中强调要更简洁
|
||||
length_hint = f"长度不超过{max_summary_length}个字符," if retry_count > 0 else "长度不超过20个字,"
|
||||
|
||||
prompt = (
|
||||
"请阅读以下多个语言风格/表达方式,对其进行总结。"
|
||||
"不要对其进行语义概括,而是尽可能找出其中不变的部分或共同表达,尽量使用原文"
|
||||
f"{length_hint}保留共同特点:\n"
|
||||
f"{chr(10).join(f'- {s}' for s in styles[-10:])}\n只输出概括内容。不要输出其他内容"
|
||||
)
|
||||
|
||||
try:
|
||||
summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2)
|
||||
summary = summary.strip()
|
||||
if summary:
|
||||
# 检查总结长度是否超过限制
|
||||
if len(summary) <= max_summary_length:
|
||||
return summary
|
||||
else:
|
||||
retry_count += 1
|
||||
logger.debug(
|
||||
f"总结长度 {len(summary)} 超过限制 {max_summary_length} "
|
||||
f"(输入最长项长度: {max_input_length}),重试第 {retry_count} 次"
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"概括表达风格失败: {e}")
|
||||
return None
|
||||
|
||||
# 如果重试多次后仍然超过长度,返回 None(不进行总结)
|
||||
logger.warning(
|
||||
f"总结多次后仍超过长度限制,放弃总结。"
|
||||
f"输入最长项长度: {max_input_length}, 最大允许长度: {max_summary_length}"
|
||||
)
|
||||
return None
|
||||
|
||||
async def _summarize_situations(self, situations: List[str]) -> Optional[str]:
|
||||
if not situations:
|
||||
return None
|
||||
|
||||
prompt = (
|
||||
"请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
|
||||
"长度不超过20个字,保留共同特点:\n"
|
||||
f"{chr(10).join(f'- {s}' for s in situations[-10:])}\n只输出概括内容。"
|
||||
)
|
||||
|
||||
try:
|
||||
summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2)
|
||||
summary = summary.strip()
|
||||
if summary:
|
||||
return summary
|
||||
except Exception as e:
|
||||
logger.error(f"概括表达情境失败: {e}")
|
||||
return None
|
||||
|
||||
async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None:
|
||||
"""
|
||||
处理从 expression learner 提取的黑话条目,路由到 jargon_miner
|
||||
|
||||
Args:
|
||||
jargon_entries: 黑话条目列表,每个元素是 (content, source_id)
|
||||
messages: 消息列表,用于构建上下文
|
||||
"""
|
||||
if not jargon_entries or not messages:
|
||||
return
|
||||
|
||||
# 获取 jargon_miner 实例
|
||||
jargon_miner = miner_manager.get_miner(self.chat_id)
|
||||
|
||||
# 构建黑话条目格式,与 jargon_miner.run_once 中的格式一致
|
||||
entries: List[Dict[str, List[str]]] = []
|
||||
|
||||
for content, source_id in jargon_entries:
|
||||
content = content.strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
# 过滤掉包含 SELF 的黑话,不学习
|
||||
if "SELF" in content:
|
||||
logger.info(f"跳过包含 SELF 的黑话: {content}")
|
||||
continue
|
||||
|
||||
# 检查是否包含机器人名称
|
||||
if contains_bot_self_name(content):
|
||||
logger.info(f"跳过包含机器人昵称/别名的黑话: {content}")
|
||||
continue
|
||||
|
||||
# 解析 source_id
|
||||
source_id_str = (source_id or "").strip()
|
||||
if not source_id_str.isdigit():
|
||||
logger.warning(f"黑话条目 source_id 无效: content={content}, source_id={source_id_str}")
|
||||
continue
|
||||
|
||||
# build_anonymous_messages 的编号从 1 开始
|
||||
line_index = int(source_id_str) - 1
|
||||
if line_index < 0 or line_index >= len(messages):
|
||||
logger.warning(f"黑话条目 source_id 超出范围: content={content}, source_id={source_id_str}")
|
||||
continue
|
||||
|
||||
# 检查是否是机器人自己的消息
|
||||
target_msg = messages[line_index]
|
||||
if is_bot_message(target_msg):
|
||||
logger.info(f"跳过引用机器人自身消息的黑话: content={content}, source_id={source_id_str}")
|
||||
continue
|
||||
|
||||
# 构建上下文段落
|
||||
context_paragraph = build_context_paragraph(messages, line_index)
|
||||
if not context_paragraph:
|
||||
logger.warning(f"黑话条目上下文为空: content={content}, source_id={source_id_str}")
|
||||
continue
|
||||
|
||||
entries.append({"content": content, "raw_content": [context_paragraph]})
|
||||
|
||||
if not entries:
|
||||
return
|
||||
|
||||
# 调用 jargon_miner 处理这些条目
|
||||
await jargon_miner.process_extracted_entries(entries)
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
|
||||
class ExpressionLearnerManager:
|
||||
def __init__(self):
|
||||
self.expression_learners = {}
|
||||
|
||||
self._ensure_expression_directories()
|
||||
|
||||
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
|
||||
if chat_id not in self.expression_learners:
|
||||
self.expression_learners[chat_id] = ExpressionLearner(chat_id)
|
||||
return self.expression_learners[chat_id]
|
||||
|
||||
def _ensure_expression_directories(self):
|
||||
"""
|
||||
确保表达方式相关的目录结构存在
|
||||
"""
|
||||
base_dir = os.path.join("data", "expression")
|
||||
directories_to_create = [
|
||||
base_dir,
|
||||
os.path.join(base_dir, "learnt_style"),
|
||||
os.path.join(base_dir, "learnt_grammar"),
|
||||
]
|
||||
|
||||
for directory in directories_to_create:
|
||||
try:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
logger.debug(f"确保目录存在: {directory}")
|
||||
except Exception as e:
|
||||
logger.error(f"创建目录失败 {directory}: {e}")
|
||||
|
||||
|
||||
expression_learner_manager = ExpressionLearnerManager()
|
||||
@@ -82,9 +82,7 @@ class ExpressionReflector:
|
||||
# 获取未检查的表达
|
||||
try:
|
||||
logger.info("[Expression Reflection] 查询未检查且未拒绝的表达")
|
||||
expressions = (
|
||||
Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
|
||||
)
|
||||
expressions = Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
|
||||
|
||||
expr_list = list(expressions)
|
||||
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")
|
||||
@@ -147,7 +145,7 @@ expression_reflector_manager = ExpressionReflectorManager()
|
||||
|
||||
async def _check_tracker_exists(operator_config: str) -> bool:
|
||||
"""检查指定 Operator 是否已有活跃的 Tracker"""
|
||||
from src.express.reflect_tracker import reflect_tracker_manager
|
||||
from src.bw_learner.reflect_tracker import reflect_tracker_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = None
|
||||
@@ -242,7 +240,7 @@ async def _send_to_operator(operator_config: str, text: str, expr: Expression):
|
||||
stream_id = chat_stream.stream_id
|
||||
|
||||
# 注册 Tracker
|
||||
from src.express.reflect_tracker import ReflectTracker, reflect_tracker_manager
|
||||
from src.bw_learner.reflect_tracker import ReflectTracker, reflect_tracker_manager
|
||||
|
||||
tracker = ReflectTracker(chat_stream=chat_stream, expression=expr, created_time=time.time())
|
||||
reflect_tracker_manager.add_tracker(stream_id, tracker)
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import time
|
||||
import hashlib
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from json_repair import repair_json
|
||||
@@ -10,7 +9,8 @@ from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.express.express_utils import weighted_sample
|
||||
from src.bw_learner.learner_utils import weighted_sample
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
@@ -67,7 +67,7 @@ class ExpressionSelector:
|
||||
|
||||
@staticmethod
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
||||
"""解析'platform:id:type'为chat_id,直接使用 ChatManager 提供的接口"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
@@ -76,12 +76,8 @@ class ExpressionSelector:
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
is_group = stream_type == "group"
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
# 统一通过 chat_manager 生成 stream_id,避免各处自行实现哈希逻辑
|
||||
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@@ -111,6 +107,85 @@ class ExpressionSelector:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def _select_expressions_simple(self, chat_id: str, max_num: int) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
简单模式:只选择 count > 1 的项目,要求至少有10个才进行选择,随机选5个,不进行LLM选择
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
max_num: 最大选择数量(此参数在此模式下不使用,固定选择5个)
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
try:
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 查询所有相关chat_id的表达方式,排除 rejected=1 的,且只选择 count > 1 的
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
|
||||
)
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
"count": expr.count if getattr(expr, "count", None) is not None else 1,
|
||||
"checked": expr.checked if getattr(expr, "checked", None) is not None else False,
|
||||
}
|
||||
for expr in style_query
|
||||
]
|
||||
|
||||
# 要求至少有一定数量的 count > 1 的表达方式才进行“完整简单模式”选择
|
||||
min_required = 8
|
||||
if len(style_exprs) < min_required:
|
||||
# 高 count 样本不足:如果还有候选,就降级为随机选 3 个;如果一个都没有,则直接返回空
|
||||
if not style_exprs:
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} 没有满足 count > 1 且未被拒绝的表达方式,简单模式不进行选择"
|
||||
)
|
||||
# 完全没有高 count 样本时,退化为全量随机抽样(不进入LLM流程)
|
||||
fallback_num = min(3, max_num) if max_num > 0 else 3
|
||||
fallback_selected = self._random_expressions(chat_id, fallback_num)
|
||||
if fallback_selected:
|
||||
self.update_expressions_last_active_time(fallback_selected)
|
||||
selected_ids = [expr["id"] for expr in fallback_selected]
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} 使用简单模式降级随机抽选 {len(fallback_selected)} 个表达(无 count>1 样本)"
|
||||
)
|
||||
return fallback_selected, selected_ids
|
||||
return [], []
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),"
|
||||
f"简单模式降级为随机选择 3 个"
|
||||
)
|
||||
select_count = min(3, len(style_exprs))
|
||||
else:
|
||||
# 高 count 数量达标时,固定选择 5 个
|
||||
select_count = 5
|
||||
import random
|
||||
|
||||
selected_style = random.sample(style_exprs, select_count)
|
||||
|
||||
# 更新last_active_time
|
||||
if selected_style:
|
||||
self.update_expressions_last_active_time(selected_style)
|
||||
|
||||
selected_ids = [expr["id"] for expr in selected_style]
|
||||
logger.debug(
|
||||
f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)} 个"
|
||||
)
|
||||
return selected_style, selected_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"简单模式选择表达方式失败: {e}")
|
||||
return [], []
|
||||
|
||||
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
随机选择表达方式
|
||||
@@ -127,9 +202,7 @@ class ExpressionSelector:
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式,排除 rejected=1 的表达
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
|
||||
)
|
||||
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
@@ -164,6 +237,7 @@ class ExpressionSelector:
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
reply_reason: Optional[str] = None,
|
||||
think_level: int = 1,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
选择适合的表达方式(使用classic模式:随机选择+LLM选择)
|
||||
@@ -174,6 +248,7 @@ class ExpressionSelector:
|
||||
max_num: 最大选择数量
|
||||
target_message: 目标消息内容
|
||||
reply_reason: planner给出的回复理由
|
||||
think_level: 思考级别,0/1
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
@@ -184,8 +259,10 @@ class ExpressionSelector:
|
||||
return [], []
|
||||
|
||||
# 使用classic模式(随机选择+LLM选择)
|
||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
|
||||
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message, reply_reason)
|
||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式,think_level={think_level}")
|
||||
return await self._select_expressions_classic(
|
||||
chat_id, chat_info, max_num, target_message, reply_reason, think_level
|
||||
)
|
||||
|
||||
async def _select_expressions_classic(
|
||||
self,
|
||||
@@ -194,6 +271,7 @@ class ExpressionSelector:
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
reply_reason: Optional[str] = None,
|
||||
think_level: int = 1,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
classic模式:随机选择+LLM选择
|
||||
@@ -204,24 +282,91 @@ class ExpressionSelector:
|
||||
max_num: 最大选择数量
|
||||
target_message: 目标消息内容
|
||||
reply_reason: planner给出的回复理由
|
||||
think_level: 思考级别,0/1
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
try:
|
||||
# 1. 使用随机抽样选择表达方式
|
||||
style_exprs = self._random_expressions(chat_id, 20)
|
||||
# think_level == 0: 只选择 count > 1 的项目,随机选10个,不进行LLM选择
|
||||
if think_level == 0:
|
||||
return self._select_expressions_simple(chat_id, max_num)
|
||||
|
||||
if len(style_exprs) < 10:
|
||||
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
|
||||
# think_level == 1: 先选高count,再从所有表达方式中随机抽样
|
||||
# 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
|
||||
|
||||
all_style_exprs = [
|
||||
{
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
"count": expr.count if getattr(expr, "count", None) is not None else 1,
|
||||
"checked": expr.checked if getattr(expr, "checked", None) is not None else False,
|
||||
}
|
||||
for expr in style_query
|
||||
]
|
||||
|
||||
# 分离 count > 1 和 count <= 1 的表达方式
|
||||
high_count_exprs = [expr for expr in all_style_exprs if (expr.get("count", 1) or 1) > 1]
|
||||
|
||||
# 根据 think_level 设置要求(仅支持 0/1,0 已在上方返回)
|
||||
min_high_count = 10
|
||||
min_total_count = 10
|
||||
select_high_count = 5
|
||||
select_random_count = 5
|
||||
|
||||
# 检查数量要求
|
||||
# 对于高 count 表达:如果数量不足,不再直接停止,而是仅跳过“高 count 优先选择”
|
||||
if len(high_count_exprs) < min_high_count:
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),"
|
||||
f"将跳过高 count 优先选择,仅从全部表达中随机抽样"
|
||||
)
|
||||
high_count_valid = False
|
||||
else:
|
||||
high_count_valid = True
|
||||
|
||||
# 总量不足仍然直接返回,避免样本过少导致选择质量过低
|
||||
if len(all_style_exprs) < min_total_count:
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择"
|
||||
)
|
||||
return [], []
|
||||
|
||||
# 先选取高count的表达方式(如果数量达标)
|
||||
if high_count_valid:
|
||||
selected_high = weighted_sample(high_count_exprs, min(len(high_count_exprs), select_high_count))
|
||||
else:
|
||||
selected_high = []
|
||||
|
||||
# 然后从所有表达方式中随机抽样(使用加权抽样)
|
||||
remaining_num = select_random_count
|
||||
selected_random = weighted_sample(all_style_exprs, min(len(all_style_exprs), remaining_num))
|
||||
|
||||
# 合并候选池(去重,避免重复)
|
||||
candidate_exprs = selected_high.copy()
|
||||
candidate_ids = {expr["id"] for expr in candidate_exprs}
|
||||
for expr in selected_random:
|
||||
if expr["id"] not in candidate_ids:
|
||||
candidate_exprs.append(expr)
|
||||
candidate_ids.add(expr["id"])
|
||||
|
||||
# 打乱顺序,避免高count的都在前面
|
||||
import random
|
||||
|
||||
random.shuffle(candidate_exprs)
|
||||
|
||||
# 2. 构建所有表达方式的索引和情境列表
|
||||
all_expressions: List[Dict[str, Any]] = []
|
||||
all_situations: List[str] = []
|
||||
|
||||
# 添加style表达方式
|
||||
for expr in style_exprs:
|
||||
for expr in candidate_exprs:
|
||||
expr = expr.copy()
|
||||
all_expressions.append(expr)
|
||||
all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
|
||||
@@ -233,7 +378,7 @@ class ExpressionSelector:
|
||||
all_situations_str = "\n".join(all_situations)
|
||||
|
||||
if target_message:
|
||||
target_message_str = f",现在你想要对这条消息进行回复:“{target_message}”"
|
||||
target_message_str = f',现在你想要对这条消息进行回复:"{target_message}"'
|
||||
target_message_extra_block = "4.考虑你要回复的目标消息"
|
||||
else:
|
||||
target_message_str = ""
|
||||
@@ -262,7 +407,8 @@ class ExpressionSelector:
|
||||
# 4. 调用LLM
|
||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
# print(prompt)
|
||||
print(prompt)
|
||||
print(content)
|
||||
|
||||
if not content:
|
||||
logger.warning("LLM返回空结果")
|
||||
@@ -7,8 +7,13 @@ from src.common.database.database_model import Jargon
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.jargon.jargon_miner import search_jargon
|
||||
from src.jargon.jargon_utils import is_bot_message, contains_bot_self_name, parse_chat_id_list, chat_id_list_contains
|
||||
from src.bw_learner.jargon_miner import search_jargon
|
||||
from src.bw_learner.learner_utils import (
|
||||
is_bot_message,
|
||||
contains_bot_self_name,
|
||||
parse_chat_id_list,
|
||||
chat_id_list_contains,
|
||||
)
|
||||
|
||||
logger = get_logger("jargon")
|
||||
|
||||
@@ -82,7 +87,7 @@ class JargonExplainer:
|
||||
query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
|
||||
|
||||
# 根据all_global配置决定查询逻辑
|
||||
if global_config.jargon.all_global:
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启all_global:只查询is_global=True的记录
|
||||
query = query.where(Jargon.is_global)
|
||||
else:
|
||||
@@ -107,7 +112,7 @@ class JargonExplainer:
|
||||
continue
|
||||
|
||||
# 检查chat_id(如果all_global=False)
|
||||
if not global_config.jargon.all_global:
|
||||
if not global_config.expression.all_global_jargon:
|
||||
if jargon.is_global:
|
||||
# 全局黑话,包含
|
||||
pass
|
||||
@@ -181,7 +186,7 @@ class JargonExplainer:
|
||||
content = entry["content"]
|
||||
|
||||
# 根据是否开启全局黑话,决定查询方式
|
||||
if global_config.jargon.all_global:
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启全局黑话:查询所有is_global=True的记录
|
||||
results = search_jargon(
|
||||
keyword=content,
|
||||
@@ -265,7 +270,7 @@ def match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
|
||||
return []
|
||||
|
||||
query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
|
||||
if global_config.jargon.all_global:
|
||||
if global_config.expression.all_global_jargon:
|
||||
query = query.where(Jargon.is_global)
|
||||
|
||||
query = query.order_by(Jargon.count.desc())
|
||||
@@ -277,7 +282,7 @@ def match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
|
||||
if not content:
|
||||
continue
|
||||
|
||||
if not global_config.jargon.all_global and not jargon.is_global:
|
||||
if not global_config.expression.all_global_jargon and not jargon.is_global:
|
||||
chat_id_list = parse_chat_id_list(jargon.chat_id)
|
||||
if not chat_id_list_contains(chat_id_list, chat_id):
|
||||
continue
|
||||
@@ -357,4 +362,4 @@ async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> st
|
||||
|
||||
if results:
|
||||
return "【概念检索结果】\n" + "\n".join(results) + "\n"
|
||||
return ""
|
||||
return ""
|
||||
@@ -1,8 +1,8 @@
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from typing import List, Dict, Optional, Any
|
||||
from typing import List, Dict, Optional, Any, Callable
|
||||
from json_repair import repair_json
|
||||
from peewee import fn
|
||||
|
||||
@@ -13,10 +13,9 @@ from src.config.config import model_config, global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.jargon.jargon_utils import (
|
||||
from src.bw_learner.learner_utils import (
|
||||
is_bot_message,
|
||||
build_context_paragraph,
|
||||
contains_bot_self_name,
|
||||
@@ -29,6 +28,29 @@ from src.jargon.jargon_utils import (
|
||||
logger = get_logger("jargon")
|
||||
|
||||
|
||||
def _is_single_char_jargon(content: str) -> bool:
|
||||
"""
|
||||
判断是否是单字黑话(单个汉字、英文或数字)
|
||||
|
||||
Args:
|
||||
content: 词条内容
|
||||
|
||||
Returns:
|
||||
bool: 如果是单字黑话返回True,否则返回False
|
||||
"""
|
||||
if not content or len(content) != 1:
|
||||
return False
|
||||
|
||||
char = content[0]
|
||||
# 判断是否是单个汉字、单个英文字母或单个数字
|
||||
return (
|
||||
"\u4e00" <= char <= "\u9fff" # 汉字
|
||||
or "a" <= char <= "z" # 小写字母
|
||||
or "A" <= char <= "Z" # 大写字母
|
||||
or "0" <= char <= "9" # 数字
|
||||
)
|
||||
|
||||
|
||||
def _init_prompt() -> None:
|
||||
prompt_str = """
|
||||
**聊天内容,其中的{bot_name}的发言内容是你自己的发言,[msg_id] 是消息ID**
|
||||
@@ -36,11 +58,9 @@ def _init_prompt() -> None:
|
||||
|
||||
请从上面这段聊天内容中提取"可能是黑话"的候选项(黑话/俚语/网络缩写/口头禅)。
|
||||
- 必须为对话中真实出现过的短词或短语
|
||||
- 必须是你无法理解含义的词语,没有明确含义的词语
|
||||
- 请不要选择有明确含义,或者含义清晰的词语
|
||||
- 必须是你无法理解含义的词语,没有明确含义的词语,请不要选择有明确含义,或者含义清晰的词语
|
||||
- 排除:人名、@、表情包/图片中的内容、纯标点、常规功能词(如的、了、呢、啊等)
|
||||
- 每个词条长度建议 2-8 个字符(不强制),尽量短小
|
||||
- 合并重复项,去重
|
||||
|
||||
黑话必须为以下几种类型:
|
||||
- 由字母构成的,汉语拼音首字母的简写词,例如:nb、yyds、xswl
|
||||
@@ -48,7 +68,7 @@ def _init_prompt() -> None:
|
||||
- 中文词语的缩写,用几个汉字概括一个词汇或含义,例如:社死、内卷
|
||||
|
||||
以 JSON 数组输出,元素为对象(严格按以下结构):
|
||||
请你提取出可能的黑话,最多10
|
||||
请你提取出可能的黑话,最多30个黑话,请尽量提取所有
|
||||
[
|
||||
{{"content": "词条", "msg_id": "m12"}}, // msg_id 必须与上方聊天中展示的ID完全一致
|
||||
{{"content": "词条2", "msg_id": "m15"}}
|
||||
@@ -67,12 +87,14 @@ def _init_inference_prompts() -> None:
|
||||
{content}
|
||||
**词条出现的上下文。其中的{bot_name}的发言内容是你自己的发言**
|
||||
{raw_content_list}
|
||||
{previous_meaning_section}
|
||||
|
||||
请根据上下文,推断"{content}"这个词条的含义。
|
||||
- 如果这是一个黑话、俚语或网络用语,请推断其含义
|
||||
- 如果含义明确(常规词汇),也请说明
|
||||
- {bot_name} 的发言内容可能包含错误,请不要参考其发言内容
|
||||
- 如果上下文信息不足,无法推断含义,请设置 no_info 为 true
|
||||
{previous_meaning_instruction}
|
||||
|
||||
以 JSON 格式输出:
|
||||
{{
|
||||
@@ -166,23 +188,24 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool:
|
||||
class JargonMiner:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
self.last_learning_time: float = time.time()
|
||||
# 频率控制,可按需调整
|
||||
self.min_messages_for_learning: int = 10
|
||||
self.min_learning_interval: float = 20
|
||||
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="jargon.extract",
|
||||
)
|
||||
|
||||
self.llm_inference = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="jargon.inference",
|
||||
)
|
||||
|
||||
# 初始化stream_name作为类属性,避免重复提取
|
||||
chat_manager = get_chat_manager()
|
||||
stream_name = chat_manager.get_stream_name(self.chat_id)
|
||||
self.stream_name = stream_name if stream_name else self.chat_id
|
||||
self.cache_limit = 100
|
||||
self.cache_limit = 50
|
||||
self.cache: OrderedDict[str, None] = OrderedDict()
|
||||
|
||||
|
||||
# 黑话提取锁,防止并发执行
|
||||
self._extraction_lock = asyncio.Lock()
|
||||
|
||||
@@ -195,6 +218,10 @@ class JargonMiner:
|
||||
if not key:
|
||||
return
|
||||
|
||||
# 单字黑话(单个汉字、英文或数字)不记录到缓存
|
||||
if _is_single_char_jargon(key):
|
||||
return
|
||||
|
||||
if key in self.cache:
|
||||
self.cache.move_to_end(key)
|
||||
else:
|
||||
@@ -267,16 +294,44 @@ class JargonMiner:
|
||||
logger.warning(f"jargon {content} 没有raw_content,跳过推断")
|
||||
return
|
||||
|
||||
# 获取当前count和上一次的meaning
|
||||
current_count = jargon_obj.count or 0
|
||||
previous_meaning = jargon_obj.meaning or ""
|
||||
|
||||
# 当count为24, 60时,随机移除一半的raw_content项目
|
||||
if current_count in [24, 60] and len(raw_content_list) > 1:
|
||||
# 计算要保留的数量(至少保留1个)
|
||||
keep_count = max(1, len(raw_content_list) // 2)
|
||||
raw_content_list = random.sample(raw_content_list, keep_count)
|
||||
logger.info(
|
||||
f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目"
|
||||
)
|
||||
|
||||
# 步骤1: 基于raw_content和content推断
|
||||
raw_content_text = "\n".join(raw_content_list)
|
||||
|
||||
# 当count为24, 60, 100时,在prompt中放入上一次推断出的meaning作为参考
|
||||
previous_meaning_section = ""
|
||||
previous_meaning_instruction = ""
|
||||
if current_count in [24, 60, 100] and previous_meaning:
|
||||
previous_meaning_section = f"""
|
||||
**上一次推断的含义(仅供参考)**
|
||||
{previous_meaning}
|
||||
"""
|
||||
previous_meaning_instruction = (
|
||||
"- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
|
||||
)
|
||||
|
||||
prompt1 = await global_prompt_manager.format_prompt(
|
||||
"jargon_inference_with_context_prompt",
|
||||
content=content,
|
||||
bot_name=global_config.bot.nickname,
|
||||
raw_content_list=raw_content_text,
|
||||
previous_meaning_section=previous_meaning_section,
|
||||
previous_meaning_instruction=previous_meaning_instruction,
|
||||
)
|
||||
|
||||
response1, _ = await self.llm.generate_response_async(prompt1, temperature=0.3)
|
||||
response1, _ = await self.llm_inference.generate_response_async(prompt1, temperature=0.3)
|
||||
if not response1:
|
||||
logger.warning(f"jargon {content} 推断1失败:无响应")
|
||||
return
|
||||
@@ -313,7 +368,7 @@ class JargonMiner:
|
||||
content=content,
|
||||
)
|
||||
|
||||
response2, _ = await self.llm.generate_response_async(prompt2, temperature=0.3)
|
||||
response2, _ = await self.llm_inference.generate_response_async(prompt2, temperature=0.3)
|
||||
if not response2:
|
||||
logger.warning(f"jargon {content} 推断2失败:无响应")
|
||||
return
|
||||
@@ -360,7 +415,7 @@ class JargonMiner:
|
||||
if global_config.debug.show_jargon_prompt:
|
||||
logger.info(f"jargon {content} 比较提示词: {prompt3}")
|
||||
|
||||
response3, _ = await self.llm.generate_response_async(prompt3, temperature=0.3)
|
||||
response3, _ = await self.llm_inference.generate_response_async(prompt3, temperature=0.3)
|
||||
if not response3:
|
||||
logger.warning(f"jargon {content} 比较失败:无响应")
|
||||
return
|
||||
@@ -425,45 +480,21 @@ class JargonMiner:
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
def should_trigger(self) -> bool:
|
||||
# 冷却时间检查
|
||||
if time.time() - self.last_learning_time < self.min_learning_interval:
|
||||
return False
|
||||
async def run_once(
|
||||
self,
|
||||
messages: List[Any],
|
||||
person_name_filter: Optional[Callable[[str], bool]] = None
|
||||
) -> None:
|
||||
"""
|
||||
运行一次黑话提取
|
||||
|
||||
# 拉取最近消息数量是否足够
|
||||
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_learning_time,
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
return bool(recent_messages and len(recent_messages) >= self.min_messages_for_learning)
|
||||
|
||||
async def run_once(self) -> None:
|
||||
Args:
|
||||
messages: 外部传入的消息列表(必需)
|
||||
person_name_filter: 可选的过滤函数,用于检查内容是否包含人物名称
|
||||
"""
|
||||
# 使用异步锁防止并发执行
|
||||
async with self._extraction_lock:
|
||||
try:
|
||||
# 在锁内检查,避免并发触发
|
||||
if not self.should_trigger():
|
||||
return
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(self.chat_id)
|
||||
if not chat_stream:
|
||||
return
|
||||
|
||||
# 记录本次提取的时间窗口,避免重复提取
|
||||
extraction_start_time = self.last_learning_time
|
||||
extraction_end_time = time.time()
|
||||
|
||||
# 立即更新学习时间,防止并发触发
|
||||
self.last_learning_time = extraction_end_time
|
||||
|
||||
# 拉取学习窗口内的消息
|
||||
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=extraction_start_time,
|
||||
timestamp_end=extraction_end_time,
|
||||
limit=20,
|
||||
)
|
||||
if not messages:
|
||||
return
|
||||
|
||||
@@ -538,6 +569,11 @@ class JargonMiner:
|
||||
if contains_bot_self_name(content):
|
||||
logger.info(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
|
||||
continue
|
||||
|
||||
# 检查是否包含人物名称
|
||||
if person_name_filter and person_name_filter(content):
|
||||
logger.info(f"解析阶段跳过包含人物名称的词条: {content}")
|
||||
continue
|
||||
|
||||
msg_id_str = str(msg_id_value or "").strip()
|
||||
if not msg_id_str:
|
||||
@@ -603,7 +639,7 @@ class JargonMiner:
|
||||
# 查找匹配的记录
|
||||
matched_obj = None
|
||||
for obj in query:
|
||||
if global_config.jargon.all_global:
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启all_global:所有content匹配的记录都可以
|
||||
matched_obj = obj
|
||||
break
|
||||
@@ -626,7 +662,9 @@ class JargonMiner:
|
||||
if obj.raw_content:
|
||||
try:
|
||||
existing_raw_content = (
|
||||
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
json.loads(obj.raw_content)
|
||||
if isinstance(obj.raw_content, str)
|
||||
else obj.raw_content
|
||||
)
|
||||
if not isinstance(existing_raw_content, list):
|
||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||
@@ -643,7 +681,7 @@ class JargonMiner:
|
||||
obj.chat_id = json.dumps(updated_chat_id_list, ensure_ascii=False)
|
||||
|
||||
# 开启all_global时,确保记录标记为is_global=True
|
||||
if global_config.jargon.all_global:
|
||||
if global_config.expression.all_global_jargon:
|
||||
obj.is_global = True
|
||||
# 关闭all_global时,保持原有is_global不变(不修改)
|
||||
|
||||
@@ -659,7 +697,7 @@ class JargonMiner:
|
||||
updated += 1
|
||||
else:
|
||||
# 没找到匹配记录,创建新记录
|
||||
if global_config.jargon.all_global:
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启all_global:新记录默认为is_global=True
|
||||
is_global_new = True
|
||||
else:
|
||||
@@ -699,6 +737,158 @@ class JargonMiner:
|
||||
logger.error(f"JargonMiner 运行失败: {e}")
|
||||
# 即使失败也保持时间戳更新,避免频繁重试
|
||||
|
||||
async def process_extracted_entries(
|
||||
self,
|
||||
entries: List[Dict[str, List[str]]],
|
||||
person_name_filter: Optional[Callable[[str], bool]] = None
|
||||
) -> None:
|
||||
"""
|
||||
处理已提取的黑话条目(从 expression_learner 路由过来的)
|
||||
|
||||
Args:
|
||||
entries: 黑话条目列表,每个元素格式为 {"content": "...", "raw_content": [...]}
|
||||
person_name_filter: 可选的过滤函数,用于检查内容是否包含人物名称
|
||||
"""
|
||||
if not entries:
|
||||
return
|
||||
|
||||
try:
|
||||
# 去重并合并raw_content(按 content 聚合)
|
||||
merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict()
|
||||
for entry in entries:
|
||||
content_key = entry["content"]
|
||||
|
||||
# 检查是否包含人物名称
|
||||
# logger.info(f"process_extracted_entries 检查是否包含人物名称: {content_key}")
|
||||
# logger.info(f"person_name_filter: {person_name_filter}")
|
||||
if person_name_filter and person_name_filter(content_key):
|
||||
logger.info(f"process_extracted_entries 跳过包含人物名称的黑话: {content_key}")
|
||||
continue
|
||||
|
||||
raw_list = entry.get("raw_content", []) or []
|
||||
if content_key in merged_entries:
|
||||
merged_entries[content_key]["raw_content"].extend(raw_list)
|
||||
else:
|
||||
merged_entries[content_key] = {
|
||||
"content": content_key,
|
||||
"raw_content": list(raw_list),
|
||||
}
|
||||
|
||||
uniq_entries = []
|
||||
for merged_entry in merged_entries.values():
|
||||
raw_content_list = merged_entry["raw_content"]
|
||||
if raw_content_list:
|
||||
merged_entry["raw_content"] = list(dict.fromkeys(raw_content_list))
|
||||
uniq_entries.append(merged_entry)
|
||||
|
||||
saved = 0
|
||||
updated = 0
|
||||
for entry in uniq_entries:
|
||||
content = entry["content"]
|
||||
raw_content_list = entry["raw_content"] # 已经是列表
|
||||
|
||||
try:
|
||||
# 查询所有content匹配的记录
|
||||
query = Jargon.select().where(Jargon.content == content)
|
||||
|
||||
# 查找匹配的记录
|
||||
matched_obj = None
|
||||
for obj in query:
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启all_global:所有content匹配的记录都可以
|
||||
matched_obj = obj
|
||||
break
|
||||
else:
|
||||
# 关闭all_global:需要检查chat_id列表是否包含目标chat_id
|
||||
chat_id_list = parse_chat_id_list(obj.chat_id)
|
||||
if chat_id_list_contains(chat_id_list, self.chat_id):
|
||||
matched_obj = obj
|
||||
break
|
||||
|
||||
if matched_obj:
|
||||
obj = matched_obj
|
||||
try:
|
||||
obj.count = (obj.count or 0) + 1
|
||||
except Exception:
|
||||
obj.count = 1
|
||||
|
||||
# 合并raw_content列表:读取现有列表,追加新值,去重
|
||||
existing_raw_content = []
|
||||
if obj.raw_content:
|
||||
try:
|
||||
existing_raw_content = (
|
||||
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
)
|
||||
if not isinstance(existing_raw_content, list):
|
||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
existing_raw_content = [obj.raw_content] if obj.raw_content else []
|
||||
|
||||
# 合并并去重
|
||||
merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list))
|
||||
obj.raw_content = json.dumps(merged_list, ensure_ascii=False)
|
||||
|
||||
# 更新chat_id列表:增加当前chat_id的计数
|
||||
chat_id_list = parse_chat_id_list(obj.chat_id)
|
||||
updated_chat_id_list = update_chat_id_list(chat_id_list, self.chat_id, increment=1)
|
||||
obj.chat_id = json.dumps(updated_chat_id_list, ensure_ascii=False)
|
||||
|
||||
# 开启all_global时,确保记录标记为is_global=True
|
||||
if global_config.expression.all_global_jargon:
|
||||
obj.is_global = True
|
||||
# 关闭all_global时,保持原有is_global不变(不修改)
|
||||
|
||||
obj.save()
|
||||
|
||||
# 检查是否需要推断(达到阈值且超过上次判定值)
|
||||
if _should_infer_meaning(obj):
|
||||
# 异步触发推断,不阻塞主流程
|
||||
# 重新加载对象以确保数据最新
|
||||
jargon_id = obj.id
|
||||
asyncio.create_task(self._infer_meaning_by_id(jargon_id))
|
||||
|
||||
updated += 1
|
||||
else:
|
||||
# 没找到匹配记录,创建新记录
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启all_global:新记录默认为is_global=True
|
||||
is_global_new = True
|
||||
else:
|
||||
# 关闭all_global:新记录is_global=False
|
||||
is_global_new = False
|
||||
|
||||
# 使用新格式创建chat_id列表:[[chat_id, count]]
|
||||
chat_id_list = [[self.chat_id, 1]]
|
||||
chat_id_json = json.dumps(chat_id_list, ensure_ascii=False)
|
||||
|
||||
Jargon.create(
|
||||
content=content,
|
||||
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
|
||||
chat_id=chat_id_json,
|
||||
is_global=is_global_new,
|
||||
count=1,
|
||||
)
|
||||
saved += 1
|
||||
except Exception as e:
|
||||
logger.error(f"保存jargon失败: chat_id={self.chat_id}, content={content}, err={e}")
|
||||
continue
|
||||
finally:
|
||||
self._add_to_cache(content)
|
||||
|
||||
# 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出)
|
||||
if uniq_entries:
|
||||
# 收集所有提取的jargon内容
|
||||
jargon_list = [entry["content"] for entry in uniq_entries]
|
||||
jargon_str = ",".join(jargon_list)
|
||||
|
||||
# 输出格式化的结果(使用logger.info会自动应用jargon模块的颜色)
|
||||
logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}")
|
||||
|
||||
if saved or updated:
|
||||
logger.debug(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,chat_id={self.chat_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理已提取的黑话条目失败: {e}")
|
||||
|
||||
|
||||
class JargonMinerManager:
|
||||
def __init__(self) -> None:
|
||||
@@ -713,11 +903,6 @@ class JargonMinerManager:
|
||||
miner_manager = JargonMinerManager()
|
||||
|
||||
|
||||
async def extract_and_store_jargon(chat_id: str) -> None:
|
||||
miner = miner_manager.get_miner(chat_id)
|
||||
await miner.run_once()
|
||||
|
||||
|
||||
def search_jargon(
|
||||
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
|
||||
) -> List[Dict[str, str]]:
|
||||
@@ -765,7 +950,7 @@ def search_jargon(
|
||||
query = query.where(search_condition)
|
||||
|
||||
# 根据all_global配置决定查询逻辑
|
||||
if global_config.jargon.all_global:
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启all_global:所有记录都是全局的,查询所有is_global=True的记录(无视chat_id)
|
||||
query = query.where(Jargon.is_global)
|
||||
# 注意:对于all_global=False的情况,chat_id过滤在Python层面进行,以便兼容新旧格式
|
||||
@@ -782,7 +967,7 @@ def search_jargon(
|
||||
results = []
|
||||
for jargon in query:
|
||||
# 如果提供了chat_id且all_global=False,需要检查chat_id列表是否包含目标chat_id
|
||||
if chat_id and not global_config.jargon.all_global:
|
||||
if chat_id and not global_config.expression.all_global_jargon:
|
||||
chat_id_list = parse_chat_id_list(jargon.chat_id)
|
||||
# 如果记录是is_global=True,或者chat_id列表包含目标chat_id,则包含
|
||||
if not jargon.is_global and not chat_id_list_contains(chat_id_list, chat_id):
|
||||
380
src/bw_learner/learner_utils.py
Normal file
380
src/bw_learner/learner_utils.py
Normal file
@@ -0,0 +1,380 @@
|
||||
import re
|
||||
import difflib
|
||||
import random
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
)
|
||||
from src.chat.utils.utils import parse_platform_accounts
|
||||
|
||||
|
||||
logger = get_logger("learner_utils")
|
||||
|
||||
|
||||
def filter_message_content(content: Optional[str]) -> str:
|
||||
"""
|
||||
过滤消息内容,移除回复、@、图片等格式
|
||||
|
||||
Args:
|
||||
content: 原始消息内容
|
||||
|
||||
Returns:
|
||||
str: 过滤后的内容
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
|
||||
content = re.sub(r"\[回复.*?\],说:\s*", "", content)
|
||||
# 移除@<...>格式的内容
|
||||
content = re.sub(r"@<[^>]*>", "", content)
|
||||
# 移除[picid:...]格式的图片ID
|
||||
content = re.sub(r"\[picid:[^\]]*\]", "", content)
|
||||
# 移除[表情包:...]格式的内容
|
||||
content = re.sub(r"\[表情包:[^\]]*\]", "", content)
|
||||
|
||||
return content.strip()
|
||||
|
||||
|
||||
def calculate_similarity(text1: str, text2: str) -> float:
|
||||
"""
|
||||
计算两个文本的相似度,返回0-1之间的值
|
||||
使用SequenceMatcher计算相似度
|
||||
|
||||
Args:
|
||||
text1: 第一个文本
|
||||
text2: 第二个文本
|
||||
|
||||
Returns:
|
||||
float: 相似度值,范围0-1
|
||||
"""
|
||||
return difflib.SequenceMatcher(None, text1, text2).ratio()
|
||||
|
||||
|
||||
def calculate_style_similarity(style1: str, style2: str) -> float:
|
||||
"""
|
||||
计算两个 style 的相似度,返回0-1之间的值
|
||||
在计算前会移除"使用"和"句式"这两个词(参考 expression_similarity_analysis.py)
|
||||
|
||||
Args:
|
||||
style1: 第一个 style
|
||||
style2: 第二个 style
|
||||
|
||||
Returns:
|
||||
float: 相似度值,范围0-1
|
||||
"""
|
||||
if not style1 or not style2:
|
||||
return 0.0
|
||||
|
||||
# 移除"使用"和"句式"这两个词
|
||||
def remove_ignored_words(text: str) -> str:
|
||||
"""移除需要忽略的词"""
|
||||
text = text.replace("使用", "")
|
||||
text = text.replace("句式", "")
|
||||
return text.strip()
|
||||
|
||||
cleaned_style1 = remove_ignored_words(style1)
|
||||
cleaned_style2 = remove_ignored_words(style2)
|
||||
|
||||
# 如果清理后文本为空,返回0
|
||||
if not cleaned_style1 or not cleaned_style2:
|
||||
return 0.0
|
||||
|
||||
return difflib.SequenceMatcher(None, cleaned_style1, cleaned_style2).ratio()
|
||||
|
||||
|
||||
def format_create_date(timestamp: float) -> str:
|
||||
"""
|
||||
将时间戳格式化为可读的日期字符串
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳
|
||||
|
||||
Returns:
|
||||
str: 格式化后的日期字符串
|
||||
"""
|
||||
try:
|
||||
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "未知时间"
|
||||
|
||||
|
||||
def _compute_weights(population: List[Dict]) -> List[float]:
|
||||
"""
|
||||
根据表达的count计算权重,范围限定在1~5之间。
|
||||
count越高,权重越高,但最多为基础权重的5倍。
|
||||
如果表达已checked,权重会再乘以3倍。
|
||||
"""
|
||||
if not population:
|
||||
return []
|
||||
|
||||
counts = []
|
||||
checked_flags = []
|
||||
for item in population:
|
||||
count = item.get("count", 1)
|
||||
try:
|
||||
count_value = float(count)
|
||||
except (TypeError, ValueError):
|
||||
count_value = 1.0
|
||||
counts.append(max(count_value, 0.0))
|
||||
# 获取checked状态
|
||||
checked = item.get("checked", False)
|
||||
checked_flags.append(bool(checked))
|
||||
|
||||
min_count = min(counts)
|
||||
max_count = max(counts)
|
||||
|
||||
if max_count == min_count:
|
||||
base_weights = [1.0 for _ in counts]
|
||||
else:
|
||||
base_weights = []
|
||||
for count_value in counts:
|
||||
# 线性映射到[1,5]区间
|
||||
normalized = (count_value - min_count) / (max_count - min_count)
|
||||
base_weights.append(1.0 + normalized * 4.0) # 1~5
|
||||
|
||||
# 如果checked,权重乘以3
|
||||
weights = []
|
||||
for base_weight, checked in zip(base_weights, checked_flags, strict=False):
|
||||
if checked:
|
||||
weights.append(base_weight * 3.0)
|
||||
else:
|
||||
weights.append(base_weight)
|
||||
return weights
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
|
||||
"""
|
||||
随机抽样函数
|
||||
|
||||
Args:
|
||||
population: 总体数据列表
|
||||
k: 需要抽取的数量
|
||||
|
||||
Returns:
|
||||
List[Dict]: 抽取的数据列表
|
||||
"""
|
||||
if not population or k <= 0:
|
||||
return []
|
||||
|
||||
if len(population) <= k:
|
||||
return population.copy()
|
||||
|
||||
selected: List[Dict] = []
|
||||
population_copy = population.copy()
|
||||
|
||||
for _ in range(min(k, len(population_copy))):
|
||||
weights = _compute_weights(population_copy)
|
||||
total_weight = sum(weights)
|
||||
if total_weight <= 0:
|
||||
# 回退到均匀随机
|
||||
idx = random.randint(0, len(population_copy) - 1)
|
||||
selected.append(population_copy.pop(idx))
|
||||
continue
|
||||
|
||||
threshold = random.uniform(0, total_weight)
|
||||
cumulative = 0.0
|
||||
for idx, weight in enumerate(weights):
|
||||
cumulative += weight
|
||||
if threshold <= cumulative:
|
||||
selected.append(population_copy.pop(idx))
|
||||
break
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
|
||||
"""
|
||||
解析chat_id字段,兼容旧格式(字符串)和新格式(JSON列表)
|
||||
|
||||
Args:
|
||||
chat_id_value: 可能是字符串(旧格式)或JSON字符串(新格式)
|
||||
|
||||
Returns:
|
||||
List[List[Any]]: 格式为 [[chat_id, count], ...] 的列表
|
||||
"""
|
||||
if not chat_id_value:
|
||||
return []
|
||||
|
||||
# 如果是字符串,尝试解析为JSON
|
||||
if isinstance(chat_id_value, str):
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
parsed = json.loads(chat_id_value)
|
||||
if isinstance(parsed, list):
|
||||
# 新格式:已经是列表
|
||||
return parsed
|
||||
elif isinstance(parsed, str):
|
||||
# 解析后还是字符串,说明是旧格式
|
||||
return [[parsed, 1]]
|
||||
else:
|
||||
# 其他类型,当作旧格式处理
|
||||
return [[str(chat_id_value), 1]]
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# 解析失败,当作旧格式(纯字符串)
|
||||
return [[str(chat_id_value), 1]]
|
||||
elif isinstance(chat_id_value, list):
|
||||
# 已经是列表格式
|
||||
return chat_id_value
|
||||
else:
|
||||
# 其他类型,转换为旧格式
|
||||
return [[str(chat_id_value), 1]]
|
||||
|
||||
|
||||
def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]:
|
||||
"""
|
||||
更新chat_id列表,如果target_chat_id已存在则增加计数,否则添加新条目
|
||||
|
||||
Args:
|
||||
chat_id_list: 当前的chat_id列表,格式为 [[chat_id, count], ...]
|
||||
target_chat_id: 要更新或添加的chat_id
|
||||
increment: 增加的计数,默认为1
|
||||
|
||||
Returns:
|
||||
List[List[Any]]: 更新后的chat_id列表
|
||||
"""
|
||||
item = _find_chat_id_item(chat_id_list, target_chat_id)
|
||||
if item is not None:
|
||||
# 找到匹配的chat_id,增加计数
|
||||
if len(item) >= 2:
|
||||
item[1] = (item[1] if isinstance(item[1], (int, float)) else 0) + increment
|
||||
else:
|
||||
item.append(increment)
|
||||
else:
|
||||
# 未找到,添加新条目
|
||||
chat_id_list.append([target_chat_id, increment])
|
||||
|
||||
return chat_id_list
|
||||
|
||||
|
||||
def _find_chat_id_item(chat_id_list: List[List[Any]], target_chat_id: str) -> Optional[List[Any]]:
|
||||
"""
|
||||
在chat_id列表中查找匹配的项(辅助函数)
|
||||
|
||||
Args:
|
||||
chat_id_list: chat_id列表,格式为 [[chat_id, count], ...]
|
||||
target_chat_id: 要查找的chat_id
|
||||
|
||||
Returns:
|
||||
如果找到则返回匹配的项,否则返回None
|
||||
"""
|
||||
for item in chat_id_list:
|
||||
if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id):
|
||||
return item
|
||||
return None
|
||||
|
||||
|
||||
def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool:
|
||||
"""
|
||||
检查chat_id列表中是否包含指定的chat_id
|
||||
|
||||
Args:
|
||||
chat_id_list: chat_id列表,格式为 [[chat_id, count], ...]
|
||||
target_chat_id: 要查找的chat_id
|
||||
|
||||
Returns:
|
||||
bool: 如果包含则返回True
|
||||
"""
|
||||
return _find_chat_id_item(chat_id_list, target_chat_id) is not None
|
||||
|
||||
|
||||
def contains_bot_self_name(content: str) -> bool:
|
||||
"""
|
||||
判断词条是否包含机器人的昵称或别名
|
||||
"""
|
||||
if not content:
|
||||
return False
|
||||
|
||||
bot_config = getattr(global_config, "bot", None)
|
||||
if not bot_config:
|
||||
return False
|
||||
|
||||
target = content.strip().lower()
|
||||
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
|
||||
alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []]
|
||||
|
||||
candidates = [name for name in [nickname, *alias_names] if name]
|
||||
|
||||
return any(name in target for name in candidates)
|
||||
|
||||
|
||||
def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]:
|
||||
"""
|
||||
构建包含中心消息上下文的段落(前3条+后3条),使用标准的 readable builder 输出
|
||||
"""
|
||||
if not messages or center_index < 0 or center_index >= len(messages):
|
||||
return None
|
||||
|
||||
context_start = max(0, center_index - 3)
|
||||
context_end = min(len(messages), center_index + 1 + 3)
|
||||
context_messages = messages[context_start:context_end]
|
||||
|
||||
if not context_messages:
|
||||
return None
|
||||
|
||||
try:
|
||||
paragraph = build_readable_messages(
|
||||
messages=context_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
show_pic=True,
|
||||
message_id_list=None,
|
||||
remove_emoji_stickers=False,
|
||||
pic_single=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"构建上下文段落失败: {e}")
|
||||
return None
|
||||
|
||||
paragraph = paragraph.strip()
|
||||
return paragraph or None
|
||||
|
||||
|
||||
def is_bot_message(msg: Any) -> bool:
|
||||
"""判断消息是否来自机器人自身"""
|
||||
if msg is None:
|
||||
return False
|
||||
|
||||
bot_config = getattr(global_config, "bot", None)
|
||||
if not bot_config:
|
||||
return False
|
||||
|
||||
platform = (
|
||||
str(getattr(msg, "user_platform", "") or getattr(getattr(msg, "user_info", None), "platform", "") or "")
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
|
||||
|
||||
if not platform or not user_id:
|
||||
return False
|
||||
|
||||
platform_accounts = {}
|
||||
try:
|
||||
platform_accounts = parse_platform_accounts(getattr(bot_config, "platforms", []) or [])
|
||||
except Exception:
|
||||
platform_accounts = {}
|
||||
|
||||
bot_accounts: Dict[str, str] = {}
|
||||
qq_account = str(getattr(bot_config, "qq_account", "") or "").strip()
|
||||
if qq_account:
|
||||
bot_accounts["qq"] = qq_account
|
||||
|
||||
telegram_account = str(getattr(bot_config, "telegram_account", "") or "").strip()
|
||||
if telegram_account:
|
||||
bot_accounts["telegram"] = telegram_account
|
||||
|
||||
for plat, account in platform_accounts.items():
|
||||
if account and plat not in bot_accounts:
|
||||
bot_accounts[plat] = account
|
||||
|
||||
bot_account = bot_accounts.get(platform)
|
||||
return bool(bot_account and user_id == bot_account)
|
||||
212
src/bw_learner/message_recorder.py
Normal file
212
src/bw_learner/message_recorder.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import time
|
||||
import asyncio
|
||||
from typing import List, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.bw_learner.expression_learner import expression_learner_manager
|
||||
from src.bw_learner.jargon_miner import miner_manager
|
||||
|
||||
logger = get_logger("bw_learner")
|
||||
|
||||
|
||||
class MessageRecorder:
|
||||
"""
|
||||
统一的消息记录器,负责管理时间窗口和消息提取,并将消息分发给 expression_learner 和 jargon_miner
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||
|
||||
# 维护每个chat的上次提取时间
|
||||
self.last_extraction_time: float = time.time()
|
||||
|
||||
# 提取锁,防止并发执行
|
||||
self._extraction_lock = asyncio.Lock()
|
||||
|
||||
# 获取 expression 和 jargon 的配置参数
|
||||
self._init_parameters()
|
||||
|
||||
# 获取 expression_learner 和 jargon_miner 实例
|
||||
self.expression_learner = expression_learner_manager.get_expression_learner(chat_id)
|
||||
self.jargon_miner = miner_manager.get_miner(chat_id)
|
||||
|
||||
def _init_parameters(self) -> None:
|
||||
"""初始化提取参数"""
|
||||
# 获取 expression 配置
|
||||
_, self.enable_expression_learning, self.enable_jargon_learning = (
|
||||
global_config.expression.get_expression_config_for_chat(self.chat_id)
|
||||
)
|
||||
self.min_messages_for_extraction = 30
|
||||
self.min_extraction_interval = 60
|
||||
|
||||
logger.debug(
|
||||
f"MessageRecorder 初始化: chat_id={self.chat_id}, "
|
||||
f"min_messages={self.min_messages_for_extraction}, "
|
||||
f"min_interval={self.min_extraction_interval}"
|
||||
)
|
||||
|
||||
def should_trigger_extraction(self) -> bool:
|
||||
"""
|
||||
检查是否应该触发消息提取
|
||||
|
||||
Returns:
|
||||
bool: 是否应该触发提取
|
||||
"""
|
||||
# 检查时间间隔
|
||||
time_diff = time.time() - self.last_extraction_time
|
||||
if time_diff < self.min_extraction_interval:
|
||||
return False
|
||||
|
||||
# 检查消息数量
|
||||
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_extraction_time,
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
|
||||
if not recent_messages or len(recent_messages) < self.min_messages_for_extraction:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def extract_and_distribute(self) -> None:
|
||||
"""
|
||||
提取消息并分发给 expression_learner 和 jargon_miner
|
||||
"""
|
||||
# 使用异步锁防止并发执行
|
||||
async with self._extraction_lock:
|
||||
# 在锁内检查,避免并发触发
|
||||
if not self.should_trigger_extraction():
|
||||
return
|
||||
|
||||
# 检查 chat_stream 是否存在
|
||||
if not self.chat_stream:
|
||||
return
|
||||
|
||||
# 记录本次提取的时间窗口,避免重复提取
|
||||
extraction_start_time = self.last_extraction_time
|
||||
extraction_end_time = time.time()
|
||||
|
||||
# 立即更新提取时间,防止并发触发
|
||||
self.last_extraction_time = extraction_end_time
|
||||
|
||||
try:
|
||||
# logger.info(f"在聊天流 {self.chat_name} 开始统一消息提取和分发")
|
||||
|
||||
# 拉取提取窗口内的消息
|
||||
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=extraction_start_time,
|
||||
timestamp_end=extraction_end_time,
|
||||
)
|
||||
|
||||
if not messages:
|
||||
logger.debug(f"聊天流 {self.chat_name} 没有新消息,跳过提取")
|
||||
return
|
||||
|
||||
# 按时间排序,确保顺序一致
|
||||
messages = sorted(messages, key=lambda msg: msg.time or 0)
|
||||
|
||||
logger.info(
|
||||
f"聊天流 {self.chat_name} 提取到 {len(messages)} 条消息,"
|
||||
f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
|
||||
)
|
||||
|
||||
# 分别触发 expression_learner 和 jargon_miner 的处理
|
||||
# 传递提取的消息,避免它们重复获取
|
||||
# 触发 expression 学习(如果启用)
|
||||
if self.enable_expression_learning:
|
||||
asyncio.create_task(
|
||||
self._trigger_expression_learning(extraction_start_time, extraction_end_time, messages)
|
||||
)
|
||||
|
||||
# 触发 jargon 提取(如果启用),传递消息
|
||||
# if self.enable_jargon_learning:
|
||||
# asyncio.create_task(
|
||||
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
|
||||
# )
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
# 即使失败也保持时间戳更新,避免频繁重试
|
||||
|
||||
async def _trigger_expression_learning(
|
||||
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
|
||||
) -> None:
|
||||
"""
|
||||
触发 expression 学习,使用指定的消息列表
|
||||
|
||||
Args:
|
||||
timestamp_start: 开始时间戳
|
||||
timestamp_end: 结束时间戳
|
||||
messages: 消息列表
|
||||
"""
|
||||
try:
|
||||
# 传递消息给 ExpressionLearner(必需参数)
|
||||
learnt_style = await self.expression_learner.learn_and_store(messages=messages)
|
||||
|
||||
if learnt_style:
|
||||
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
||||
else:
|
||||
logger.debug(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
async def _trigger_jargon_extraction(
|
||||
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
|
||||
) -> None:
|
||||
"""
|
||||
触发 jargon 提取,使用指定的消息列表
|
||||
|
||||
Args:
|
||||
timestamp_start: 开始时间戳
|
||||
timestamp_end: 结束时间戳
|
||||
messages: 消息列表
|
||||
"""
|
||||
try:
|
||||
# 传递消息给 JargonMiner,避免它重复获取
|
||||
await self.jargon_miner.run_once(messages=messages)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
class MessageRecorderManager:
|
||||
"""MessageRecorder 管理器"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._recorders: dict[str, MessageRecorder] = {}
|
||||
|
||||
def get_recorder(self, chat_id: str) -> MessageRecorder:
|
||||
"""获取或创建指定 chat_id 的 MessageRecorder"""
|
||||
if chat_id not in self._recorders:
|
||||
self._recorders[chat_id] = MessageRecorder(chat_id)
|
||||
return self._recorders[chat_id]
|
||||
|
||||
|
||||
# 全局管理器实例
|
||||
recorder_manager = MessageRecorderManager()
|
||||
|
||||
|
||||
async def extract_and_distribute_messages(chat_id: str) -> None:
|
||||
"""
|
||||
统一的消息提取和分发入口函数
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
"""
|
||||
recorder = recorder_manager.get_recorder(chat_id)
|
||||
await recorder.extract_and_distribute()
|
||||
491
src/chat/brain_chat/PFC/action_planner.py
Normal file
491
src/chat/brain_chat/PFC/action_planner.py
Normal file
@@ -0,0 +1,491 @@
|
||||
import time
|
||||
from typing import Tuple, Optional # 增加了 Optional
|
||||
from src.common.logger_manager import get_logger
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
from .chat_observer import ChatObserver
|
||||
from .pfc_utils import get_items_from_json
|
||||
from src.individuality.individuality import Individuality
|
||||
from .observation_info import ObservationInfo
|
||||
from .conversation_info import ConversationInfo
|
||||
from src.plugins.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
|
||||
logger = get_logger("pfc_action_planner")
|
||||
|
||||
|
||||
# --- 定义 Prompt 模板 ---
|
||||
|
||||
# Prompt(1): 首次回复或非连续回复时的决策 Prompt
|
||||
PROMPT_INITIAL_REPLY = """{persona_text}。现在你在参与一场QQ私聊,请根据以下【所有信息】审慎且灵活的决策下一步行动,可以回复,可以倾听,可以调取知识,甚至可以屏蔽对方:
|
||||
|
||||
【当前对话目标】
|
||||
{goals_str}
|
||||
{knowledge_info_str}
|
||||
|
||||
【最近行动历史概要】
|
||||
{action_history_summary}
|
||||
【上一次行动的详细情况和结果】
|
||||
{last_action_context}
|
||||
【时间和超时提示】
|
||||
{time_since_last_bot_message_info}{timeout_context}
|
||||
【最近的对话记录】(包括你已成功发送的消息 和 新收到的消息)
|
||||
{chat_history_text}
|
||||
|
||||
------
|
||||
可选行动类型以及解释:
|
||||
fetch_knowledge: 需要调取知识或记忆,当需要专业知识或特定信息时选择,对方若提到你不太认识的人名或实体也可以尝试选择
|
||||
listening: 倾听对方发言,当你认为对方话才说到一半,发言明显未结束时选择
|
||||
direct_reply: 直接回复对方
|
||||
rethink_goal: 思考一个对话目标,当你觉得目前对话需要目标,或当前目标不再适用,或话题卡住时选择。注意私聊的环境是灵活的,有可能需要经常选择
|
||||
end_conversation: 结束对话,对方长时间没回复或者当你觉得对话告一段落时可以选择
|
||||
block_and_ignore: 更加极端的结束对话方式,直接结束对话并在一段时间内无视对方所有发言(屏蔽),当对话让你感到十分不适,或你遭到各类骚扰时选择
|
||||
|
||||
请以JSON格式输出你的决策:
|
||||
{{
|
||||
"action": "选择的行动类型 (必须是上面列表中的一个)",
|
||||
"reason": "选择该行动的详细原因 (必须有解释你是如何根据“上一次行动结果”、“对话记录”和自身设定人设做出合理判断的)"
|
||||
}}
|
||||
|
||||
注意:请严格按照JSON格式输出,不要包含任何其他内容。"""
|
||||
|
||||
# Prompt(2): 上一次成功回复后,决定继续发言时的决策 Prompt
|
||||
PROMPT_FOLLOW_UP = """{persona_text}。现在你在参与一场QQ私聊,刚刚你已经回复了对方,请根据以下【所有信息】审慎且灵活的决策下一步行动,可以继续发送新消息,可以等待,可以倾听,可以调取知识,甚至可以屏蔽对方:
|
||||
|
||||
【当前对话目标】
|
||||
{goals_str}
|
||||
{knowledge_info_str}
|
||||
|
||||
【最近行动历史概要】
|
||||
{action_history_summary}
|
||||
【上一次行动的详细情况和结果】
|
||||
{last_action_context}
|
||||
【时间和超时提示】
|
||||
{time_since_last_bot_message_info}{timeout_context}
|
||||
【最近的对话记录】(包括你已成功发送的消息 和 新收到的消息)
|
||||
{chat_history_text}
|
||||
|
||||
------
|
||||
可选行动类型以及解释:
|
||||
fetch_knowledge: 需要调取知识,当需要专业知识或特定信息时选择,对方若提到你不太认识的人名或实体也可以尝试选择
|
||||
wait: 暂时不说话,留给对方交互空间,等待对方回复(尤其是在你刚发言后、或上次发言因重复、发言过多被拒时、或不确定做什么时,这是不错的选择)
|
||||
listening: 倾听对方发言(虽然你刚发过言,但如果对方立刻回复且明显话没说完,可以选择这个)
|
||||
send_new_message: 发送一条新消息继续对话,允许适当的追问、补充、深入话题,或开启相关新话题。**但是避免在因重复被拒后立即使用,也不要在对方没有回复的情况下过多的“消息轰炸”或重复发言**
|
||||
rethink_goal: 思考一个对话目标,当你觉得目前对话需要目标,或当前目标不再适用,或话题卡住时选择。注意私聊的环境是灵活的,有可能需要经常选择
|
||||
end_conversation: 结束对话,对方长时间没回复或者当你觉得对话告一段落时可以选择
|
||||
block_and_ignore: 更加极端的结束对话方式,直接结束对话并在一段时间内无视对方所有发言(屏蔽),当对话让你感到十分不适,或你遭到各类骚扰时选择
|
||||
|
||||
请以JSON格式输出你的决策:
|
||||
{{
|
||||
"action": "选择的行动类型 (必须是上面列表中的一个)",
|
||||
"reason": "选择该行动的详细原因 (必须有解释你是如何根据“上一次行动结果”、“对话记录”和自身设定人设做出合理判断的。请说明你为什么选择继续发言而不是等待,以及打算发送什么类型的新消息连续发言,必须记录已经发言了几次)"
|
||||
}}
|
||||
|
||||
注意:请严格按照JSON格式输出,不要包含任何其他内容。"""
|
||||
|
||||
# 新增:Prompt(3): 决定是否在结束对话前发送告别语
|
||||
PROMPT_END_DECISION = """{persona_text}。刚刚你决定结束一场 QQ 私聊。
|
||||
|
||||
【你们之前的聊天记录】
|
||||
{chat_history_text}
|
||||
|
||||
你觉得你们的对话已经完整结束了吗?有时候,在对话自然结束后再说点什么可能会有点奇怪,但有时也可能需要一条简短的消息来圆满结束。
|
||||
如果觉得确实有必要再发一条简短、自然、符合你人设的告别消息(比如 "好,下次再聊~" 或 "嗯,先这样吧"),就输出 "yes"。
|
||||
如果觉得当前状态下直接结束对话更好,没有必要再发消息,就输出 "no"。
|
||||
|
||||
请以 JSON 格式输出你的选择:
|
||||
{{
|
||||
"say_bye": "yes/no",
|
||||
"reason": "选择 yes 或 no 的原因和内心想法 (简要说明)"
|
||||
}}
|
||||
|
||||
注意:请严格按照 JSON 格式输出,不要包含任何其他内容。"""
|
||||
|
||||
|
||||
# ActionPlanner 类定义,顶格
|
||||
class ActionPlanner:
|
||||
"""行动规划器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model=global_config.llm_PFC_action_planner,
|
||||
temperature=global_config.llm_PFC_action_planner["temp"],
|
||||
max_tokens=1500,
|
||||
request_type="action_planning",
|
||||
)
|
||||
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
|
||||
self.name = global_config.BOT_NICKNAME
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
# self.action_planner_info = ActionPlannerInfo() # 移除未使用的变量
|
||||
|
||||
# 修改 plan 方法签名,增加 last_successful_reply_action 参数
|
||||
async def plan(
|
||||
self,
|
||||
observation_info: ObservationInfo,
|
||||
conversation_info: ConversationInfo,
|
||||
last_successful_reply_action: Optional[str],
|
||||
) -> Tuple[str, str]:
|
||||
"""规划下一步行动
|
||||
|
||||
Args:
|
||||
observation_info: 决策信息
|
||||
conversation_info: 对话信息
|
||||
last_successful_reply_action: 上一次成功的回复动作类型 ('direct_reply' 或 'send_new_message' 或 None)
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (行动类型, 行动原因)
|
||||
"""
|
||||
# --- 获取 Bot 上次发言时间信息 ---
|
||||
# (这部分逻辑不变)
|
||||
time_since_last_bot_message_info = ""
|
||||
try:
|
||||
bot_id = str(global_config.BOT_QQ)
|
||||
if hasattr(observation_info, "chat_history") and observation_info.chat_history:
|
||||
for i in range(len(observation_info.chat_history) - 1, -1, -1):
|
||||
msg = observation_info.chat_history[i]
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
sender_info = msg.get("user_info", {})
|
||||
sender_id = str(sender_info.get("user_id")) if isinstance(sender_info, dict) else None
|
||||
msg_time = msg.get("time")
|
||||
if sender_id == bot_id and msg_time:
|
||||
time_diff = time.time() - msg_time
|
||||
if time_diff < 60.0:
|
||||
time_since_last_bot_message_info = (
|
||||
f"提示:你上一条成功发送的消息是在 {time_diff:.1f} 秒前。\n"
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]Observation info chat history is empty or not available for bot time check."
|
||||
)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo object might not have chat_history attribute yet for bot time check."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}")
|
||||
|
||||
# --- 获取超时提示信息 ---
|
||||
# (这部分逻辑不变)
|
||||
timeout_context = ""
|
||||
try:
|
||||
if hasattr(conversation_info, "goal_list") and conversation_info.goal_list:
|
||||
last_goal_dict = conversation_info.goal_list[-1]
|
||||
if isinstance(last_goal_dict, dict) and "goal" in last_goal_dict:
|
||||
last_goal_text = last_goal_dict["goal"]
|
||||
if isinstance(last_goal_text, str) and "分钟,思考接下来要做什么" in last_goal_text:
|
||||
try:
|
||||
timeout_minutes_text = last_goal_text.split(",")[0].replace("你等待了", "")
|
||||
timeout_context = f"重要提示:对方已经长时间({timeout_minutes_text})没有回复你的消息了(这可能代表对方繁忙/不想回复/没注意到你的消息等情况,或在对方看来本次聊天已告一段落),请基于此情况规划下一步。\n"
|
||||
except Exception:
|
||||
timeout_context = "重要提示:对方已经长时间没有回复你的消息了(这可能代表对方繁忙/不想回复/没注意到你的消息等情况,或在对方看来本次聊天已告一段落),请基于此情况规划下一步。\n"
|
||||
else:
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]Conversation info goal_list is empty or not available for timeout check."
|
||||
)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo object might not have goal_list attribute yet for timeout check."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]检查超时目标时出错: {e}")
|
||||
|
||||
# --- 构建通用 Prompt 参数 ---
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]开始规划行动:当前目标: {getattr(conversation_info, 'goal_list', '不可用')}"
|
||||
)
|
||||
|
||||
# 构建对话目标 (goals_str)
|
||||
goals_str = ""
|
||||
try:
|
||||
if hasattr(conversation_info, "goal_list") and conversation_info.goal_list:
|
||||
for goal_reason in conversation_info.goal_list:
|
||||
if isinstance(goal_reason, dict):
|
||||
goal = goal_reason.get("goal", "目标内容缺失")
|
||||
reasoning = goal_reason.get("reasoning", "没有明确原因")
|
||||
else:
|
||||
goal = str(goal_reason)
|
||||
reasoning = "没有明确原因"
|
||||
|
||||
goal = str(goal) if goal is not None else "目标内容缺失"
|
||||
reasoning = str(reasoning) if reasoning is not None else "没有明确原因"
|
||||
goals_str += f"- 目标:{goal}\n 原因:{reasoning}\n"
|
||||
|
||||
if not goals_str:
|
||||
goals_str = "- 目前没有明确对话目标,请考虑设定一个。\n"
|
||||
else:
|
||||
goals_str = "- 目前没有明确对话目标,请考虑设定一个。\n"
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo object might not have goal_list attribute yet."
|
||||
)
|
||||
goals_str = "- 获取对话目标时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建对话目标字符串时出错: {e}")
|
||||
goals_str = "- 构建对话目标时出错。\n"
|
||||
|
||||
# --- 知识信息字符串构建开始 ---
|
||||
knowledge_info_str = "【已获取的相关知识和记忆】\n"
|
||||
try:
|
||||
# 检查 conversation_info 是否有 knowledge_list 并且不为空
|
||||
if hasattr(conversation_info, "knowledge_list") and conversation_info.knowledge_list:
|
||||
# 最多只显示最近的 5 条知识,防止 Prompt 过长
|
||||
recent_knowledge = conversation_info.knowledge_list[-5:]
|
||||
for i, knowledge_item in enumerate(recent_knowledge):
|
||||
if isinstance(knowledge_item, dict):
|
||||
query = knowledge_item.get("query", "未知查询")
|
||||
knowledge = knowledge_item.get("knowledge", "无知识内容")
|
||||
source = knowledge_item.get("source", "未知来源")
|
||||
# 只取知识内容的前 2000 个字,避免太长
|
||||
knowledge_snippet = knowledge[:2000] + "..." if len(knowledge) > 2000 else knowledge
|
||||
knowledge_info_str += (
|
||||
f"{i + 1}. 关于 '{query}' 的知识 (来源: {source}):\n {knowledge_snippet}\n"
|
||||
)
|
||||
else:
|
||||
# 处理列表里不是字典的异常情况
|
||||
knowledge_info_str += f"{i + 1}. 发现一条格式不正确的知识记录。\n"
|
||||
|
||||
if not recent_knowledge: # 如果 knowledge_list 存在但为空
|
||||
knowledge_info_str += "- 暂无相关知识和记忆。\n"
|
||||
|
||||
else:
|
||||
# 如果 conversation_info 没有 knowledge_list 属性,或者列表为空
|
||||
knowledge_info_str += "- 暂无相关知识记忆。\n"
|
||||
except AttributeError:
|
||||
logger.warning(f"[私聊][{self.private_name}]ConversationInfo 对象可能缺少 knowledge_list 属性。")
|
||||
knowledge_info_str += "- 获取知识列表时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建知识信息字符串时出错: {e}")
|
||||
knowledge_info_str += "- 处理知识列表时出错。\n"
|
||||
# --- 知识信息字符串构建结束 ---
|
||||
|
||||
# 获取聊天历史记录 (chat_history_text)
|
||||
try:
|
||||
if hasattr(observation_info, "chat_history") and observation_info.chat_history:
|
||||
chat_history_text = observation_info.chat_history_str
|
||||
if not chat_history_text:
|
||||
chat_history_text = "还没有聊天记录。\n"
|
||||
else:
|
||||
chat_history_text = "还没有聊天记录。\n"
|
||||
|
||||
if hasattr(observation_info, "new_messages_count") and observation_info.new_messages_count > 0:
|
||||
if hasattr(observation_info, "unprocessed_messages") and observation_info.unprocessed_messages:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
new_messages_str = await build_readable_messages(
|
||||
new_messages_list,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
chat_history_text += (
|
||||
f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo has new_messages_count > 0 but unprocessed_messages is empty or missing."
|
||||
)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo object might be missing expected attributes for chat history."
|
||||
)
|
||||
chat_history_text = "获取聊天记录时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]处理聊天记录时发生未知错误: {e}")
|
||||
chat_history_text = "处理聊天记录时出错。\n"
|
||||
|
||||
# 构建 Persona 文本 (persona_text)
|
||||
persona_text = f"你的名字是{self.name},{self.personality_info}。"
|
||||
|
||||
# 构建行动历史和上一次行动结果 (action_history_summary, last_action_context)
|
||||
# (这部分逻辑不变)
|
||||
action_history_summary = "你最近执行的行动历史:\n"
|
||||
last_action_context = "关于你【上一次尝试】的行动:\n"
|
||||
action_history_list = []
|
||||
try:
|
||||
if hasattr(conversation_info, "done_action") and conversation_info.done_action:
|
||||
action_history_list = conversation_info.done_action[-5:]
|
||||
else:
|
||||
logger.debug(f"[私聊][{self.private_name}]Conversation info done_action is empty or not available.")
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo object might not have done_action attribute yet."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]访问行动历史时出错: {e}")
|
||||
|
||||
if not action_history_list:
|
||||
action_history_summary += "- 还没有执行过行动。\n"
|
||||
last_action_context += "- 这是你规划的第一个行动。\n"
|
||||
else:
|
||||
for i, action_data in enumerate(action_history_list):
|
||||
action_type = "未知"
|
||||
plan_reason = "未知"
|
||||
status = "未知"
|
||||
final_reason = ""
|
||||
action_time = ""
|
||||
|
||||
if isinstance(action_data, dict):
|
||||
action_type = action_data.get("action", "未知")
|
||||
plan_reason = action_data.get("plan_reason", "未知规划原因")
|
||||
status = action_data.get("status", "未知")
|
||||
final_reason = action_data.get("final_reason", "")
|
||||
action_time = action_data.get("time", "")
|
||||
elif isinstance(action_data, tuple):
|
||||
# 假设旧格式兼容
|
||||
if len(action_data) > 0:
|
||||
action_type = action_data[0]
|
||||
if len(action_data) > 1:
|
||||
plan_reason = action_data[1] # 可能是规划原因或最终原因
|
||||
if len(action_data) > 2:
|
||||
status = action_data[2]
|
||||
if status == "recall" and len(action_data) > 3:
|
||||
final_reason = action_data[3]
|
||||
elif status == "done" and action_type in ["direct_reply", "send_new_message"]:
|
||||
plan_reason = "成功发送" # 简化显示
|
||||
|
||||
reason_text = f", 失败/取消原因: {final_reason}" if final_reason else ""
|
||||
summary_line = f"- 时间:{action_time}, 尝试行动:'{action_type}', 状态:{status}{reason_text}"
|
||||
action_history_summary += summary_line + "\n"
|
||||
|
||||
if i == len(action_history_list) - 1:
|
||||
last_action_context += f"- 上次【规划】的行动是: '{action_type}'\n"
|
||||
last_action_context += f"- 当时规划的【原因】是: {plan_reason}\n"
|
||||
if status == "done":
|
||||
last_action_context += "- 该行动已【成功执行】。\n"
|
||||
# 记录这次成功的行动类型,供下次决策
|
||||
# self.last_successful_action_type = action_type # 不在这里记录,由 conversation 控制
|
||||
elif status == "recall":
|
||||
last_action_context += "- 但该行动最终【未能执行/被取消】。\n"
|
||||
if final_reason:
|
||||
last_action_context += f"- 【重要】失败/取消的具体原因是: “{final_reason}”\n"
|
||||
else:
|
||||
last_action_context += "- 【重要】失败/取消原因未明确记录。\n"
|
||||
# self.last_successful_action_type = None # 行动失败,清除记录
|
||||
else:
|
||||
last_action_context += f"- 该行动当前状态: {status}\n"
|
||||
# self.last_successful_action_type = None # 非完成状态,清除记录
|
||||
|
||||
# --- 选择 Prompt ---
|
||||
if last_successful_reply_action in ["direct_reply", "send_new_message"]:
|
||||
prompt_template = PROMPT_FOLLOW_UP
|
||||
logger.debug(f"[私聊][{self.private_name}]使用 PROMPT_FOLLOW_UP (追问决策)")
|
||||
else:
|
||||
prompt_template = PROMPT_INITIAL_REPLY
|
||||
logger.debug(f"[私聊][{self.private_name}]使用 PROMPT_INITIAL_REPLY (首次/非连续回复决策)")
|
||||
|
||||
# --- 格式化最终的 Prompt ---
|
||||
prompt = prompt_template.format(
|
||||
persona_text=persona_text,
|
||||
goals_str=goals_str if goals_str.strip() else "- 目前没有明确对话目标,请考虑设定一个。",
|
||||
action_history_summary=action_history_summary,
|
||||
last_action_context=last_action_context,
|
||||
time_since_last_bot_message_info=time_since_last_bot_message_info,
|
||||
timeout_context=timeout_context,
|
||||
chat_history_text=chat_history_text if chat_history_text.strip() else "还没有聊天记录。",
|
||||
knowledge_info_str=knowledge_info_str,
|
||||
)
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]发送到LLM的最终提示词:\n------\n{prompt}\n------")
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM (行动规划) 原始返回内容: {content}")
|
||||
|
||||
# --- 初始行动规划解析 ---
|
||||
success, initial_result = get_items_from_json(
|
||||
content,
|
||||
self.private_name,
|
||||
"action",
|
||||
"reason",
|
||||
default_values={"action": "wait", "reason": "LLM返回格式错误或未提供原因,默认等待"},
|
||||
)
|
||||
|
||||
initial_action = initial_result.get("action", "wait")
|
||||
initial_reason = initial_result.get("reason", "LLM未提供原因,默认等待")
|
||||
|
||||
# 检查是否需要进行结束对话决策 ---
|
||||
if initial_action == "end_conversation":
|
||||
logger.info(f"[私聊][{self.private_name}]初步规划结束对话,进入告别决策...")
|
||||
|
||||
# 使用新的 PROMPT_END_DECISION
|
||||
end_decision_prompt = PROMPT_END_DECISION.format(
|
||||
persona_text=persona_text, # 复用之前的 persona_text
|
||||
chat_history_text=chat_history_text, # 复用之前的 chat_history_text
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]发送到LLM的结束决策提示词:\n------\n{end_decision_prompt}\n------"
|
||||
)
|
||||
try:
|
||||
end_content, _ = await self.llm.generate_response_async(end_decision_prompt) # 再次调用LLM
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM (结束决策) 原始返回内容: {end_content}")
|
||||
|
||||
# 解析结束决策的JSON
|
||||
end_success, end_result = get_items_from_json(
|
||||
end_content,
|
||||
self.private_name,
|
||||
"say_bye",
|
||||
"reason",
|
||||
default_values={"say_bye": "no", "reason": "结束决策LLM返回格式错误,默认不告别"},
|
||||
required_types={"say_bye": str, "reason": str}, # 明确类型
|
||||
)
|
||||
|
||||
say_bye_decision = end_result.get("say_bye", "no").lower() # 转小写方便比较
|
||||
end_decision_reason = end_result.get("reason", "未提供原因")
|
||||
|
||||
if end_success and say_bye_decision == "yes":
|
||||
# 决定要告别,返回新的 'say_goodbye' 动作
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]结束决策: yes, 准备生成告别语. 原因: {end_decision_reason}"
|
||||
)
|
||||
# 注意:这里的 reason 可以考虑拼接初始原因和结束决策原因,或者只用结束决策原因
|
||||
final_action = "say_goodbye"
|
||||
final_reason = f"决定发送告别语。决策原因: {end_decision_reason} (原结束理由: {initial_reason})"
|
||||
return final_action, final_reason
|
||||
else:
|
||||
# 决定不告别 (包括解析失败或明确说no)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]结束决策: no, 直接结束对话. 原因: {end_decision_reason}"
|
||||
)
|
||||
# 返回原始的 'end_conversation' 动作
|
||||
final_action = "end_conversation"
|
||||
final_reason = initial_reason # 保持原始的结束理由
|
||||
return final_action, final_reason
|
||||
|
||||
except Exception as end_e:
|
||||
logger.error(f"[私聊][{self.private_name}]调用结束决策LLM或处理结果时出错: {str(end_e)}")
|
||||
# 出错时,默认执行原始的结束对话
|
||||
logger.warning(f"[私聊][{self.private_name}]结束决策出错,将按原计划执行 end_conversation")
|
||||
return "end_conversation", initial_reason # 返回原始动作和原因
|
||||
|
||||
else:
|
||||
action = initial_action
|
||||
reason = initial_reason
|
||||
|
||||
# 验证action类型 (保持不变)
|
||||
valid_actions = [
|
||||
"direct_reply",
|
||||
"send_new_message",
|
||||
"fetch_knowledge",
|
||||
"wait",
|
||||
"listening",
|
||||
"rethink_goal",
|
||||
"end_conversation", # 仍然需要验证,因为可能从上面决策后返回
|
||||
"block_and_ignore",
|
||||
"say_goodbye", # 也要验证这个新动作
|
||||
]
|
||||
if action not in valid_actions:
|
||||
logger.warning(f"[私聊][{self.private_name}]LLM返回了未知的行动类型: '{action}',强制改为 wait")
|
||||
reason = f"(原始行动'{action}'无效,已强制改为wait) {reason}"
|
||||
action = "wait"
|
||||
|
||||
logger.info(f"[私聊][{self.private_name}]规划的行动: {action}")
|
||||
logger.info(f"[私聊][{self.private_name}]行动原因: {reason}")
|
||||
return action, reason
|
||||
|
||||
except Exception as e:
|
||||
# 外层异常处理保持不变
|
||||
logger.error(f"[私聊][{self.private_name}]规划行动时调用 LLM 或处理结果出错: {str(e)}")
|
||||
return "wait", f"行动规划处理中发生错误,暂时等待: {str(e)}"
|
||||
379
src/chat/brain_chat/PFC/chat_observer.py
Normal file
379
src/chat/brain_chat/PFC/chat_observer.py
Normal file
@@ -0,0 +1,379 @@
|
||||
import time
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Optional, Dict, Any, List
|
||||
from src.common.logger import get_module_logger
|
||||
from maim_message import UserInfo
|
||||
from ...config.config import global_config
|
||||
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
|
||||
from .message_storage import MongoDBMessageStorage
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_module_logger("chat_observer")
|
||||
|
||||
|
||||
class ChatObserver:
|
||||
"""聊天状态观察器"""
|
||||
|
||||
# 类级别的实例管理
|
||||
_instances: Dict[str, "ChatObserver"] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, stream_id: str, private_name: str) -> "ChatObserver":
|
||||
"""获取或创建观察器实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
private_name: 私聊名称
|
||||
|
||||
Returns:
|
||||
ChatObserver: 观察器实例
|
||||
"""
|
||||
if stream_id not in cls._instances:
|
||||
cls._instances[stream_id] = cls(stream_id, private_name)
|
||||
return cls._instances[stream_id]
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
"""初始化观察器
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
"""
|
||||
self.last_check_time = None
|
||||
self.last_bot_speak_time = None
|
||||
self.last_user_speak_time = None
|
||||
if stream_id in self._instances:
|
||||
raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.")
|
||||
|
||||
self.stream_id = stream_id
|
||||
self.private_name = private_name
|
||||
self.message_storage = MongoDBMessageStorage()
|
||||
|
||||
# self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
|
||||
# self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
|
||||
# self.last_check_time: float = time.time() # 上次查看聊天记录时间
|
||||
self.last_message_read: Optional[Dict[str, Any]] = None # 最后读取的消息ID
|
||||
self.last_message_time: float = time.time()
|
||||
|
||||
self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
|
||||
|
||||
# 运行状态
|
||||
self._running: bool = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._update_event = asyncio.Event() # 触发更新的事件
|
||||
self._update_complete = asyncio.Event() # 更新完成的事件
|
||||
|
||||
# 通知管理器
|
||||
self.notification_manager = NotificationManager()
|
||||
|
||||
# 冷场检查配置
|
||||
self.cold_chat_threshold: float = 60.0 # 60秒无消息判定为冷场
|
||||
self.last_cold_chat_check: float = time.time()
|
||||
self.is_cold_chat_state: bool = False
|
||||
|
||||
self.update_event = asyncio.Event()
|
||||
self.update_interval = 2 # 更新间隔(秒)
|
||||
self.message_cache = []
|
||||
self.update_running = False
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""检查距离上一次观察之后是否有了新消息
|
||||
|
||||
Returns:
|
||||
bool: 是否有新消息
|
||||
"""
|
||||
logger.debug(f"[私聊][{self.private_name}]检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
|
||||
|
||||
new_message_exists = await self.message_storage.has_new_messages(self.stream_id, self.last_check_time)
|
||||
|
||||
if new_message_exists:
|
||||
logger.debug(f"[私聊][{self.private_name}]发现新消息")
|
||||
self.last_check_time = time.time()
|
||||
|
||||
return new_message_exists
|
||||
|
||||
async def _add_message_to_history(self, message: Dict[str, Any]):
|
||||
"""添加消息到历史记录并发送通知
|
||||
|
||||
Args:
|
||||
message: 消息数据
|
||||
"""
|
||||
try:
|
||||
# 发送新消息通知
|
||||
notification = create_new_message_notification(
|
||||
sender="chat_observer", target="observation_info", message=message
|
||||
)
|
||||
# print(self.notification_manager)
|
||||
await self.notification_manager.send_notification(notification)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]添加消息到历史记录时出错: {e}")
|
||||
print(traceback.format_exc())
|
||||
|
||||
# 检查并更新冷场状态
|
||||
await self._check_cold_chat()
|
||||
|
||||
async def _check_cold_chat(self):
|
||||
"""检查是否处于冷场状态并发送通知"""
|
||||
current_time = time.time()
|
||||
|
||||
# 每10秒检查一次冷场状态
|
||||
if current_time - self.last_cold_chat_check < 10:
|
||||
return
|
||||
|
||||
self.last_cold_chat_check = current_time
|
||||
|
||||
# 判断是否冷场
|
||||
is_cold = (
|
||||
True
|
||||
if self.last_message_time is None
|
||||
else (current_time - self.last_message_time) > self.cold_chat_threshold
|
||||
)
|
||||
|
||||
# 如果冷场状态发生变化,发送通知
|
||||
if is_cold != self.is_cold_chat_state:
|
||||
self.is_cold_chat_state = is_cold
|
||||
notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold)
|
||||
await self.notification_manager.send_notification(notification)
|
||||
|
||||
def new_message_after(self, time_point: float) -> bool:
|
||||
"""判断是否在指定时间点后有新消息
|
||||
|
||||
Args:
|
||||
time_point: 时间戳
|
||||
|
||||
Returns:
|
||||
bool: 是否有新消息
|
||||
"""
|
||||
|
||||
if self.last_message_time is None:
|
||||
logger.debug(f"[私聊][{self.private_name}]没有最后消息时间,返回 False")
|
||||
return False
|
||||
|
||||
has_new = self.last_message_time > time_point
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point} = {has_new}"
|
||||
)
|
||||
return has_new
|
||||
|
||||
def get_message_history(
|
||||
self,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
limit: Optional[int] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取消息历史
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
limit: 限制返回消息数量
|
||||
user_id: 指定用户ID
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
"""
|
||||
filtered_messages = self.message_history
|
||||
|
||||
if start_time is not None:
|
||||
filtered_messages = [m for m in filtered_messages if m["time"] >= start_time]
|
||||
|
||||
if end_time is not None:
|
||||
filtered_messages = [m for m in filtered_messages if m["time"] <= end_time]
|
||||
|
||||
if user_id is not None:
|
||||
filtered_messages = [
|
||||
m for m in filtered_messages if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
|
||||
]
|
||||
|
||||
if limit is not None:
|
||||
filtered_messages = filtered_messages[-limit:]
|
||||
|
||||
return filtered_messages
|
||||
|
||||
async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
|
||||
"""获取新消息
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 新消息列表
|
||||
"""
|
||||
new_messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_time)
|
||||
|
||||
if new_messages:
|
||||
self.last_message_read = new_messages[-1]
|
||||
self.last_message_time = new_messages[-1]["time"]
|
||||
|
||||
# print(f"获取数据库中找到的新消息: {new_messages}")
|
||||
|
||||
return new_messages
|
||||
|
||||
async def _fetch_new_messages_before(self, time_point: float) -> List[Dict[str, Any]]:
|
||||
"""获取指定时间点之前的消息
|
||||
|
||||
Args:
|
||||
time_point: 时间戳
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 最多5条消息
|
||||
"""
|
||||
new_messages = await self.message_storage.get_messages_before(self.stream_id, time_point)
|
||||
|
||||
if new_messages:
|
||||
self.last_message_read = new_messages[-1]["message_id"]
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]获取指定时间点111之前的消息: {new_messages}")
|
||||
|
||||
return new_messages
|
||||
|
||||
"""主要观察循环"""
|
||||
|
||||
async def _update_loop(self):
|
||||
"""更新循环"""
|
||||
# try:
|
||||
# start_time = time.time()
|
||||
# messages = await self._fetch_new_messages_before(start_time)
|
||||
# for message in messages:
|
||||
# await self._add_message_to_history(message)
|
||||
# logger.debug(f"[私聊][{self.private_name}]缓冲消息: {messages}")
|
||||
# except Exception as e:
|
||||
# logger.error(f"[私聊][{self.private_name}]缓冲消息出错: {e}")
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# 等待事件或超时(1秒)
|
||||
try:
|
||||
# print("等待事件")
|
||||
await asyncio.wait_for(self._update_event.wait(), timeout=1)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# print("超时")
|
||||
pass # 超时后也执行一次检查
|
||||
|
||||
self._update_event.clear() # 重置触发事件
|
||||
self._update_complete.clear() # 重置完成事件
|
||||
|
||||
# 获取新消息
|
||||
new_messages = await self._fetch_new_messages()
|
||||
|
||||
if new_messages:
|
||||
# 处理新消息
|
||||
for message in new_messages:
|
||||
await self._add_message_to_history(message)
|
||||
|
||||
# 设置完成事件
|
||||
self._update_complete.set()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]更新循环出错: {e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
self._update_complete.set() # 即使出错也要设置完成事件
|
||||
|
||||
def trigger_update(self):
|
||||
"""触发一次立即更新"""
|
||||
self._update_event.set()
|
||||
|
||||
async def wait_for_update(self, timeout: float = 5.0) -> bool:
|
||||
"""等待更新完成
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功完成更新(False表示超时)
|
||||
"""
|
||||
try:
|
||||
await asyncio.wait_for(self._update_complete.wait(), timeout=timeout)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[私聊][{self.private_name}]等待更新完成超时({timeout}秒)")
|
||||
return False
|
||||
|
||||
def start(self):
|
||||
"""启动观察器"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._task = asyncio.create_task(self._update_loop())
|
||||
logger.debug(f"[私聊][{self.private_name}]ChatObserver for {self.stream_id} started")
|
||||
|
||||
def stop(self):
|
||||
"""停止观察器"""
|
||||
self._running = False
|
||||
self._update_event.set() # 设置事件以解除等待
|
||||
self._update_complete.set() # 设置完成事件以解除等待
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
logger.debug(f"[私聊][{self.private_name}]ChatObserver for {self.stream_id} stopped")
|
||||
|
||||
async def process_chat_history(self, messages: list):
|
||||
"""处理聊天历史
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
"""
|
||||
self.update_check_time()
|
||||
|
||||
for msg in messages:
|
||||
try:
|
||||
user_info = UserInfo.from_dict(msg.get("user_info", {}))
|
||||
if user_info.user_id == global_config.BOT_QQ:
|
||||
self.update_bot_speak_time(msg["time"])
|
||||
else:
|
||||
self.update_user_speak_time(msg["time"])
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]处理消息时间时出错: {e}")
|
||||
continue
|
||||
|
||||
def update_check_time(self):
|
||||
"""更新查看时间"""
|
||||
self.last_check_time = time.time()
|
||||
|
||||
def update_bot_speak_time(self, speak_time: Optional[float] = None):
|
||||
"""更新机器人说话时间"""
|
||||
self.last_bot_speak_time = speak_time or time.time()
|
||||
|
||||
def update_user_speak_time(self, speak_time: Optional[float] = None):
|
||||
"""更新用户说话时间"""
|
||||
self.last_user_speak_time = speak_time or time.time()
|
||||
|
||||
def get_time_info(self) -> str:
|
||||
"""获取时间信息文本"""
|
||||
current_time = time.time()
|
||||
time_info = ""
|
||||
|
||||
if self.last_bot_speak_time:
|
||||
bot_speak_ago = current_time - self.last_bot_speak_time
|
||||
time_info += f"\n距离你上次发言已经过去了{int(bot_speak_ago)}秒"
|
||||
|
||||
if self.last_user_speak_time:
|
||||
user_speak_ago = current_time - self.last_user_speak_time
|
||||
time_info += f"\n距离对方上次发言已经过去了{int(user_speak_ago)}秒"
|
||||
|
||||
return time_info
|
||||
|
||||
def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]:
|
||||
"""获取缓存的消息历史
|
||||
|
||||
Args:
|
||||
limit: 获取的最大消息数量,默认50
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 缓存的消息历史列表
|
||||
"""
|
||||
return self.message_cache[-limit:]
|
||||
|
||||
def get_last_message(self) -> Optional[Dict[str, Any]]:
|
||||
"""获取最后一条消息
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 最后一条消息,如果没有则返回None
|
||||
"""
|
||||
if not self.message_cache:
|
||||
return None
|
||||
return self.message_cache[-1]
|
||||
|
||||
def __str__(self):
|
||||
return f"ChatObserver for {self.stream_id}"
|
||||
290
src/chat/brain_chat/PFC/chat_states.py
Normal file
290
src/chat/brain_chat/PFC/chat_states.py
Normal file
@@ -0,0 +1,290 @@
|
||||
from enum import Enum, auto
|
||||
from typing import Optional, Dict, Any, List, Set
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ChatState(Enum):
|
||||
"""聊天状态枚举"""
|
||||
|
||||
NORMAL = auto() # 正常状态
|
||||
NEW_MESSAGE = auto() # 有新消息
|
||||
COLD_CHAT = auto() # 冷场状态
|
||||
ACTIVE_CHAT = auto() # 活跃状态
|
||||
BOT_SPEAKING = auto() # 机器人正在说话
|
||||
USER_SPEAKING = auto() # 用户正在说话
|
||||
SILENT = auto() # 沉默状态
|
||||
ERROR = auto() # 错误状态
|
||||
|
||||
|
||||
class NotificationType(Enum):
|
||||
"""通知类型枚举"""
|
||||
|
||||
NEW_MESSAGE = auto() # 新消息通知
|
||||
COLD_CHAT = auto() # 冷场通知
|
||||
ACTIVE_CHAT = auto() # 活跃通知
|
||||
BOT_SPEAKING = auto() # 机器人说话通知
|
||||
USER_SPEAKING = auto() # 用户说话通知
|
||||
MESSAGE_DELETED = auto() # 消息删除通知
|
||||
USER_JOINED = auto() # 用户加入通知
|
||||
USER_LEFT = auto() # 用户离开通知
|
||||
ERROR = auto() # 错误通知
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatStateInfo:
|
||||
"""聊天状态信息"""
|
||||
|
||||
state: ChatState
|
||||
last_message_time: Optional[float] = None
|
||||
last_message_content: Optional[str] = None
|
||||
last_speaker: Optional[str] = None
|
||||
message_count: int = 0
|
||||
cold_duration: float = 0.0 # 冷场持续时间(秒)
|
||||
active_duration: float = 0.0 # 活跃持续时间(秒)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Notification:
|
||||
"""通知基类"""
|
||||
|
||||
type: NotificationType
|
||||
timestamp: float
|
||||
sender: str # 发送者标识
|
||||
target: str # 接收者标识
|
||||
data: Dict[str, Any]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {"type": self.type.name, "timestamp": self.timestamp, "data": self.data}
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateNotification(Notification):
|
||||
"""持续状态通知"""
|
||||
|
||||
is_active: bool = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
base_dict = super().to_dict()
|
||||
base_dict["is_active"] = self.is_active
|
||||
return base_dict
|
||||
|
||||
|
||||
class NotificationHandler(ABC):
|
||||
"""通知处理器接口"""
|
||||
|
||||
@abstractmethod
|
||||
async def handle_notification(self, notification: Notification):
|
||||
"""处理通知"""
|
||||
pass
|
||||
|
||||
|
||||
class NotificationManager:
|
||||
"""通知管理器"""
|
||||
|
||||
def __init__(self):
|
||||
# 按接收者和通知类型存储处理器
|
||||
self._handlers: Dict[str, Dict[NotificationType, List[NotificationHandler]]] = {}
|
||||
self._active_states: Set[NotificationType] = set()
|
||||
self._notification_history: List[Notification] = []
|
||||
|
||||
def register_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
|
||||
"""注册通知处理器
|
||||
|
||||
Args:
|
||||
target: 接收者标识(例如:"pfc")
|
||||
notification_type: 要处理的通知类型
|
||||
handler: 处理器实例
|
||||
"""
|
||||
if target not in self._handlers:
|
||||
self._handlers[target] = {}
|
||||
if notification_type not in self._handlers[target]:
|
||||
self._handlers[target][notification_type] = []
|
||||
# print(self._handlers[target][notification_type])
|
||||
self._handlers[target][notification_type].append(handler)
|
||||
# print(self._handlers[target][notification_type])
|
||||
|
||||
def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
|
||||
"""注销通知处理器
|
||||
|
||||
Args:
|
||||
target: 接收者标识
|
||||
notification_type: 通知类型
|
||||
handler: 要注销的处理器实例
|
||||
"""
|
||||
if target in self._handlers and notification_type in self._handlers[target]:
|
||||
handlers = self._handlers[target][notification_type]
|
||||
if handler in handlers:
|
||||
handlers.remove(handler)
|
||||
# 如果该类型的处理器列表为空,删除该类型
|
||||
if not handlers:
|
||||
del self._handlers[target][notification_type]
|
||||
# 如果该目标没有任何处理器,删除该目标
|
||||
if not self._handlers[target]:
|
||||
del self._handlers[target]
|
||||
|
||||
async def send_notification(self, notification: Notification):
|
||||
"""发送通知"""
|
||||
self._notification_history.append(notification)
|
||||
|
||||
# 如果是状态通知,更新活跃状态
|
||||
if isinstance(notification, StateNotification):
|
||||
if notification.is_active:
|
||||
self._active_states.add(notification.type)
|
||||
else:
|
||||
self._active_states.discard(notification.type)
|
||||
|
||||
# 调用目标接收者的处理器
|
||||
target = notification.target
|
||||
if target in self._handlers:
|
||||
handlers = self._handlers[target].get(notification.type, [])
|
||||
# print(handlers)
|
||||
for handler in handlers:
|
||||
# print(f"调用处理器: {handler}")
|
||||
await handler.handle_notification(notification)
|
||||
|
||||
def get_active_states(self) -> Set[NotificationType]:
|
||||
"""获取当前活跃的状态"""
|
||||
return self._active_states.copy()
|
||||
|
||||
def is_state_active(self, state_type: NotificationType) -> bool:
|
||||
"""检查特定状态是否活跃"""
|
||||
return state_type in self._active_states
|
||||
|
||||
def get_notification_history(
|
||||
self, sender: Optional[str] = None, target: Optional[str] = None, limit: Optional[int] = None
|
||||
) -> List[Notification]:
|
||||
"""获取通知历史
|
||||
|
||||
Args:
|
||||
sender: 过滤特定发送者的通知
|
||||
target: 过滤特定接收者的通知
|
||||
limit: 限制返回数量
|
||||
"""
|
||||
history = self._notification_history
|
||||
|
||||
if sender:
|
||||
history = [n for n in history if n.sender == sender]
|
||||
if target:
|
||||
history = [n for n in history if n.target == target]
|
||||
|
||||
if limit is not None:
|
||||
history = history[-limit:]
|
||||
|
||||
return history
|
||||
|
||||
def __str__(self):
|
||||
str = ""
|
||||
for target, handlers in self._handlers.items():
|
||||
for notification_type, handler_list in handlers.items():
|
||||
str += f"NotificationManager for {target} {notification_type} {handler_list}"
|
||||
return str
|
||||
|
||||
|
||||
# 一些常用的通知创建函数
|
||||
def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification:
|
||||
"""创建新消息通知"""
|
||||
return Notification(
|
||||
type=NotificationType.NEW_MESSAGE,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
sender=sender,
|
||||
target=target,
|
||||
data={
|
||||
"message_id": message.get("message_id"),
|
||||
"processed_plain_text": message.get("processed_plain_text"),
|
||||
"detailed_plain_text": message.get("detailed_plain_text"),
|
||||
"user_info": message.get("user_info"),
|
||||
"time": message.get("time"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification:
|
||||
"""创建冷场状态通知"""
|
||||
return StateNotification(
|
||||
type=NotificationType.COLD_CHAT,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
sender=sender,
|
||||
target=target,
|
||||
data={"is_cold": is_cold},
|
||||
is_active=is_cold,
|
||||
)
|
||||
|
||||
|
||||
def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification:
|
||||
"""创建活跃状态通知"""
|
||||
return StateNotification(
|
||||
type=NotificationType.ACTIVE_CHAT,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
sender=sender,
|
||||
target=target,
|
||||
data={"is_active": is_active},
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
|
||||
class ChatStateManager:
|
||||
"""聊天状态管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.current_state = ChatState.NORMAL
|
||||
self.state_info = ChatStateInfo(state=ChatState.NORMAL)
|
||||
self.state_history: list[ChatStateInfo] = []
|
||||
|
||||
def update_state(self, new_state: ChatState, **kwargs):
|
||||
"""更新聊天状态
|
||||
|
||||
Args:
|
||||
new_state: 新的状态
|
||||
**kwargs: 其他状态信息
|
||||
"""
|
||||
self.current_state = new_state
|
||||
self.state_info.state = new_state
|
||||
|
||||
# 更新其他状态信息
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self.state_info, key):
|
||||
setattr(self.state_info, key, value)
|
||||
|
||||
# 记录状态历史
|
||||
self.state_history.append(self.state_info)
|
||||
|
||||
def get_current_state_info(self) -> ChatStateInfo:
|
||||
"""获取当前状态信息"""
|
||||
return self.state_info
|
||||
|
||||
def get_state_history(self) -> list[ChatStateInfo]:
|
||||
"""获取状态历史"""
|
||||
return self.state_history
|
||||
|
||||
def is_cold_chat(self, threshold: float = 60.0) -> bool:
|
||||
"""判断是否处于冷场状态
|
||||
|
||||
Args:
|
||||
threshold: 冷场阈值(秒)
|
||||
|
||||
Returns:
|
||||
bool: 是否冷场
|
||||
"""
|
||||
if not self.state_info.last_message_time:
|
||||
return True
|
||||
|
||||
current_time = datetime.now().timestamp()
|
||||
return (current_time - self.state_info.last_message_time) > threshold
|
||||
|
||||
def is_active_chat(self, threshold: float = 5.0) -> bool:
|
||||
"""判断是否处于活跃状态
|
||||
|
||||
Args:
|
||||
threshold: 活跃阈值(秒)
|
||||
|
||||
Returns:
|
||||
bool: 是否活跃
|
||||
"""
|
||||
if not self.state_info.last_message_time:
|
||||
return False
|
||||
|
||||
current_time = datetime.now().timestamp()
|
||||
return (current_time - self.state_info.last_message_time) <= threshold
|
||||
701
src/chat/brain_chat/PFC/conversation.py
Normal file
701
src/chat/brain_chat/PFC/conversation.py
Normal file
@@ -0,0 +1,701 @@
|
||||
import time
|
||||
import asyncio
|
||||
import datetime
|
||||
|
||||
# from .message_storage import MongoDBMessageStorage
|
||||
from src.plugins.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
|
||||
# from ...config.config import global_config
|
||||
from typing import Dict, Any, Optional
|
||||
from ..chat.message import Message
|
||||
from .pfc_types import ConversationState
|
||||
from .pfc import ChatObserver, GoalAnalyzer
|
||||
from .message_sender import DirectMessageSender
|
||||
from src.common.logger_manager import get_logger
|
||||
from .action_planner import ActionPlanner
|
||||
from .observation_info import ObservationInfo
|
||||
from .conversation_info import ConversationInfo # 确保导入 ConversationInfo
|
||||
from .reply_generator import ReplyGenerator
|
||||
from ..chat.chat_stream import ChatStream
|
||||
from maim_message import UserInfo
|
||||
from src.plugins.chat.chat_stream import chat_manager
|
||||
from .pfc_KnowledgeFetcher import KnowledgeFetcher
|
||||
from .waiter import Waiter
|
||||
|
||||
import traceback
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("pfc")
|
||||
|
||||
|
||||
class Conversation:
|
||||
"""对话类,负责管理单个对话的状态和行为"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
"""初始化对话实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
"""
|
||||
self.stream_id = stream_id
|
||||
self.private_name = private_name
|
||||
self.state = ConversationState.INIT
|
||||
self.should_continue = False
|
||||
self.ignore_until_timestamp: Optional[float] = None
|
||||
|
||||
# 回复相关
|
||||
self.generated_reply = ""
|
||||
|
||||
async def _initialize(self):
|
||||
"""初始化实例,注册所有组件"""
|
||||
|
||||
try:
|
||||
self.action_planner = ActionPlanner(self.stream_id, self.private_name)
|
||||
self.goal_analyzer = GoalAnalyzer(self.stream_id, self.private_name)
|
||||
self.reply_generator = ReplyGenerator(self.stream_id, self.private_name)
|
||||
self.knowledge_fetcher = KnowledgeFetcher(self.private_name)
|
||||
self.waiter = Waiter(self.stream_id, self.private_name)
|
||||
self.direct_sender = DirectMessageSender(self.private_name)
|
||||
|
||||
# 获取聊天流信息
|
||||
self.chat_stream = chat_manager.get_stream(self.stream_id)
|
||||
|
||||
self.stop_action_planner = False
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]初始化对话实例:注册运行组件失败: {e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
raise
|
||||
|
||||
try:
|
||||
# 决策所需要的信息,包括自身自信和观察信息两部分
|
||||
# 注册观察器和观测信息
|
||||
self.chat_observer = ChatObserver.get_instance(self.stream_id, self.private_name)
|
||||
self.chat_observer.start()
|
||||
self.observation_info = ObservationInfo(self.private_name)
|
||||
self.observation_info.bind_to_chat_observer(self.chat_observer)
|
||||
# print(self.chat_observer.get_cached_messages(limit=)
|
||||
|
||||
self.conversation_info = ConversationInfo()
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]初始化对话实例:注册信息组件失败: {e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
raise
|
||||
try:
|
||||
logger.info(f"[私聊][{self.private_name}]为 {self.stream_id} 加载初始聊天记录...")
|
||||
initial_messages = get_raw_msg_before_timestamp_with_chat( #
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=30, # 加载最近30条作为初始上下文,可以调整
|
||||
)
|
||||
chat_talking_prompt = await build_readable_messages(
|
||||
initial_messages,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
if initial_messages:
|
||||
# 将加载的消息填充到 ObservationInfo 的 chat_history
|
||||
self.observation_info.chat_history = initial_messages
|
||||
self.observation_info.chat_history_str = chat_talking_prompt + "\n"
|
||||
self.observation_info.chat_history_count = len(initial_messages)
|
||||
|
||||
# 更新 ObservationInfo 中的时间戳等信息
|
||||
last_msg = initial_messages[-1]
|
||||
self.observation_info.last_message_time = last_msg.get("time")
|
||||
last_user_info = UserInfo.from_dict(last_msg.get("user_info", {}))
|
||||
self.observation_info.last_message_sender = last_user_info.user_id
|
||||
self.observation_info.last_message_content = last_msg.get("processed_plain_text", "")
|
||||
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]成功加载 {len(initial_messages)} 条初始聊天记录。最后一条消息时间: {self.observation_info.last_message_time}"
|
||||
)
|
||||
|
||||
# 让 ChatObserver 从加载的最后一条消息之后开始同步
|
||||
self.chat_observer.last_message_time = self.observation_info.last_message_time
|
||||
self.chat_observer.last_message_read = last_msg # 更新 observer 的最后读取记录
|
||||
else:
|
||||
logger.info(f"[私聊][{self.private_name}]没有找到初始聊天记录。")
|
||||
|
||||
except Exception as load_err:
|
||||
logger.error(f"[私聊][{self.private_name}]加载初始聊天记录时出错: {load_err}")
|
||||
# 出错也要继续,只是没有历史记录而已
|
||||
# 组件准备完成,启动该论对话
|
||||
self.should_continue = True
|
||||
asyncio.create_task(self.start())
|
||||
|
||||
async def start(self):
|
||||
"""开始对话流程"""
|
||||
try:
|
||||
logger.info(f"[私聊][{self.private_name}]对话系统启动中...")
|
||||
asyncio.create_task(self._plan_and_action_loop())
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]启动对话系统失败: {e}")
|
||||
raise
|
||||
|
||||
async def _plan_and_action_loop(self):
|
||||
"""思考步,PFC核心循环模块"""
|
||||
while self.should_continue:
|
||||
# 忽略逻辑
|
||||
if self.ignore_until_timestamp and time.time() < self.ignore_until_timestamp:
|
||||
await asyncio.sleep(30)
|
||||
continue
|
||||
elif self.ignore_until_timestamp and time.time() >= self.ignore_until_timestamp:
|
||||
logger.info(f"[私聊][{self.private_name}]忽略时间已到 {self.stream_id},准备结束对话。")
|
||||
self.ignore_until_timestamp = None
|
||||
self.should_continue = False
|
||||
continue
|
||||
try:
|
||||
# --- 在规划前记录当前新消息数量 ---
|
||||
initial_new_message_count = 0
|
||||
if hasattr(self.observation_info, "new_messages_count"):
|
||||
initial_new_message_count = self.observation_info.new_messages_count + 1 # 算上麦麦自己发的那一条
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo missing 'new_messages_count' before planning."
|
||||
)
|
||||
|
||||
# --- 调用 Action Planner ---
|
||||
# 传递 self.conversation_info.last_successful_reply_action
|
||||
action, reason = await self.action_planner.plan(
|
||||
self.observation_info, self.conversation_info, self.conversation_info.last_successful_reply_action
|
||||
)
|
||||
|
||||
# --- 规划后检查是否有 *更多* 新消息到达 ---
|
||||
current_new_message_count = 0
|
||||
if hasattr(self.observation_info, "new_messages_count"):
|
||||
current_new_message_count = self.observation_info.new_messages_count
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo missing 'new_messages_count' after planning."
|
||||
)
|
||||
|
||||
if current_new_message_count > initial_new_message_count + 2:
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]规划期间发现新增消息 ({initial_new_message_count} -> {current_new_message_count}),跳过本次行动,重新规划"
|
||||
)
|
||||
# 如果规划期间有新消息,也应该重置上次回复状态,因为现在要响应新消息了
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
# 包含 send_new_message
|
||||
if initial_new_message_count > 0 and action in ["direct_reply", "send_new_message"]:
|
||||
if hasattr(self.observation_info, "clear_unprocessed_messages"):
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]准备执行 {action},清理 {initial_new_message_count} 条规划时已知的新消息。"
|
||||
)
|
||||
await self.observation_info.clear_unprocessed_messages()
|
||||
if hasattr(self.observation_info, "new_messages_count"):
|
||||
self.observation_info.new_messages_count = 0
|
||||
else:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]无法清理未处理消息: ObservationInfo 缺少 clear_unprocessed_messages 方法!"
|
||||
)
|
||||
|
||||
await self._handle_action(action, reason, self.observation_info, self.conversation_info)
|
||||
|
||||
# 检查是否需要结束对话 (逻辑不变)
|
||||
goal_ended = False
|
||||
if hasattr(self.conversation_info, "goal_list") and self.conversation_info.goal_list:
|
||||
for goal_item in self.conversation_info.goal_list:
|
||||
if isinstance(goal_item, dict):
|
||||
current_goal = goal_item.get("goal")
|
||||
|
||||
if current_goal == "结束对话":
|
||||
goal_ended = True
|
||||
break
|
||||
|
||||
if goal_ended:
|
||||
self.should_continue = False
|
||||
logger.info(f"[私聊][{self.private_name}]检测到'结束对话'目标,停止循环。")
|
||||
|
||||
except Exception as loop_err:
|
||||
logger.error(f"[私聊][{self.private_name}]PFC主循环出错: {loop_err}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if self.should_continue:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
logger.info(f"[私聊][{self.private_name}]PFC 循环结束 for stream_id: {self.stream_id}")
|
||||
|
||||
def _check_new_messages_after_planning(self):
|
||||
"""检查在规划后是否有新消息"""
|
||||
# 检查 ObservationInfo 是否已初始化并且有 new_messages_count 属性
|
||||
if not hasattr(self, "observation_info") or not hasattr(self.observation_info, "new_messages_count"):
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo 未初始化或缺少 'new_messages_count' 属性,无法检查新消息。"
|
||||
)
|
||||
return False # 或者根据需要抛出错误
|
||||
|
||||
if self.observation_info.new_messages_count > 2:
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]生成/执行动作期间收到 {self.observation_info.new_messages_count} 条新消息,取消当前动作并重新规划"
|
||||
)
|
||||
# 如果有新消息,也应该重置上次回复状态
|
||||
if hasattr(self, "conversation_info"): # 确保 conversation_info 已初始化
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo 未初始化,无法重置 last_successful_reply_action。"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
|
||||
"""将消息字典转换为Message对象"""
|
||||
try:
|
||||
# 尝试从 msg_dict 直接获取 chat_stream,如果失败则从全局 chat_manager 获取
|
||||
chat_info = msg_dict.get("chat_info")
|
||||
if chat_info and isinstance(chat_info, dict):
|
||||
chat_stream = ChatStream.from_dict(chat_info)
|
||||
elif self.chat_stream: # 使用实例变量中的 chat_stream
|
||||
chat_stream = self.chat_stream
|
||||
else: # Fallback: 尝试从 manager 获取 (可能需要 stream_id)
|
||||
chat_stream = chat_manager.get_stream(self.stream_id)
|
||||
if not chat_stream:
|
||||
raise ValueError(f"无法确定 ChatStream for stream_id {self.stream_id}")
|
||||
|
||||
user_info = UserInfo.from_dict(msg_dict.get("user_info", {}))
|
||||
|
||||
return Message(
|
||||
message_id=msg_dict.get("message_id", f"gen_{time.time()}"), # 提供默认 ID
|
||||
chat_stream=chat_stream, # 使用确定的 chat_stream
|
||||
time=msg_dict.get("time", time.time()), # 提供默认时间
|
||||
user_info=user_info,
|
||||
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||
detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]转换消息时出错: {e}")
|
||||
# 可以选择返回 None 或重新抛出异常,这里选择重新抛出以指示问题
|
||||
raise ValueError(f"无法将字典转换为 Message 对象: {e}") from e
|
||||
|
||||
async def _handle_action(
|
||||
self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo
|
||||
):
|
||||
"""处理规划的行动"""
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]执行行动: {action}, 原因: {reason}")
|
||||
|
||||
# 记录action历史 (逻辑不变)
|
||||
current_action_record = {
|
||||
"action": action,
|
||||
"plan_reason": reason,
|
||||
"status": "start",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"final_reason": None,
|
||||
}
|
||||
# 确保 done_action 列表存在
|
||||
if not hasattr(conversation_info, "done_action"):
|
||||
conversation_info.done_action = []
|
||||
conversation_info.done_action.append(current_action_record)
|
||||
action_index = len(conversation_info.done_action) - 1
|
||||
|
||||
action_successful = False # 用于标记动作是否成功完成
|
||||
|
||||
# --- 根据不同的 action 执行 ---
|
||||
|
||||
# send_new_message 失败后执行 wait
|
||||
if action == "send_new_message":
|
||||
max_reply_attempts = 3
|
||||
reply_attempt_count = 0
|
||||
is_suitable = False
|
||||
need_replan = False
|
||||
check_reason = "未进行尝试"
|
||||
final_reply_to_send = ""
|
||||
|
||||
while reply_attempt_count < max_reply_attempts and not is_suitable:
|
||||
reply_attempt_count += 1
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]尝试生成追问回复 (第 {reply_attempt_count}/{max_reply_attempts} 次)..."
|
||||
)
|
||||
self.state = ConversationState.GENERATING
|
||||
|
||||
# 1. 生成回复 (调用 generate 时传入 action_type)
|
||||
self.generated_reply = await self.reply_generator.generate(
|
||||
observation_info, conversation_info, action_type="send_new_message"
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次生成的追问回复: {self.generated_reply}"
|
||||
)
|
||||
|
||||
# 2. 检查回复 (逻辑不变)
|
||||
self.state = ConversationState.CHECKING
|
||||
try:
|
||||
current_goal_str = conversation_info.goal_list[0]["goal"] if conversation_info.goal_list else ""
|
||||
is_suitable, check_reason, need_replan = await self.reply_generator.check_reply(
|
||||
reply=self.generated_reply,
|
||||
goal=current_goal_str,
|
||||
chat_history=observation_info.chat_history,
|
||||
chat_history_str=observation_info.chat_history_str,
|
||||
retry_count=reply_attempt_count - 1,
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次追问检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}"
|
||||
)
|
||||
if is_suitable:
|
||||
final_reply_to_send = self.generated_reply
|
||||
break
|
||||
elif need_replan:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次追问检查建议重新规划,停止尝试。原因: {check_reason}"
|
||||
)
|
||||
break
|
||||
except Exception as check_err:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次调用 ReplyChecker (追问) 时出错: {check_err}"
|
||||
)
|
||||
check_reason = f"第 {reply_attempt_count} 次检查过程出错: {check_err}"
|
||||
break
|
||||
|
||||
# 循环结束,处理最终结果
|
||||
if is_suitable:
|
||||
# 检查是否有新消息
|
||||
if self._check_new_messages_after_planning():
|
||||
logger.info(f"[私聊][{self.private_name}]生成追问回复期间收到新消息,取消发送,重新规划行动")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"有新消息,取消发送追问: {final_reply_to_send}"}
|
||||
)
|
||||
return # 直接返回,重新规划
|
||||
|
||||
# 发送合适的回复
|
||||
self.generated_reply = final_reply_to_send
|
||||
# --- 在这里调用 _send_reply ---
|
||||
await self._send_reply() # <--- 调用恢复后的函数
|
||||
|
||||
# 更新状态: 标记上次成功是 send_new_message
|
||||
self.conversation_info.last_successful_reply_action = "send_new_message"
|
||||
action_successful = True # 标记动作成功
|
||||
|
||||
elif need_replan:
|
||||
# 打回动作决策
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,追问回复决定打回动作决策。打回原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"追问尝试{reply_attempt_count}次后打回: {check_reason}"}
|
||||
)
|
||||
|
||||
else:
|
||||
# 追问失败
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,未能生成合适的追问回复。最终原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"追问尝试{reply_attempt_count}次后失败: {check_reason}"}
|
||||
)
|
||||
# 重置状态: 追问失败,下次用初始 prompt
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
|
||||
# 执行 Wait 操作
|
||||
logger.info(f"[私聊][{self.private_name}]由于无法生成合适追问回复,执行 'wait' 操作...")
|
||||
self.state = ConversationState.WAITING
|
||||
await self.waiter.wait(self.conversation_info)
|
||||
wait_action_record = {
|
||||
"action": "wait",
|
||||
"plan_reason": "因 send_new_message 多次尝试失败而执行的后备等待",
|
||||
"status": "done",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"final_reason": None,
|
||||
}
|
||||
conversation_info.done_action.append(wait_action_record)
|
||||
|
||||
elif action == "direct_reply":
|
||||
max_reply_attempts = 3
|
||||
reply_attempt_count = 0
|
||||
is_suitable = False
|
||||
need_replan = False
|
||||
check_reason = "未进行尝试"
|
||||
final_reply_to_send = ""
|
||||
|
||||
while reply_attempt_count < max_reply_attempts and not is_suitable:
|
||||
reply_attempt_count += 1
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]尝试生成首次回复 (第 {reply_attempt_count}/{max_reply_attempts} 次)..."
|
||||
)
|
||||
self.state = ConversationState.GENERATING
|
||||
|
||||
# 1. 生成回复
|
||||
self.generated_reply = await self.reply_generator.generate(
|
||||
observation_info, conversation_info, action_type="direct_reply"
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次生成的首次回复: {self.generated_reply}"
|
||||
)
|
||||
|
||||
# 2. 检查回复
|
||||
self.state = ConversationState.CHECKING
|
||||
try:
|
||||
current_goal_str = conversation_info.goal_list[0]["goal"] if conversation_info.goal_list else ""
|
||||
is_suitable, check_reason, need_replan = await self.reply_generator.check_reply(
|
||||
reply=self.generated_reply,
|
||||
goal=current_goal_str,
|
||||
chat_history=observation_info.chat_history,
|
||||
chat_history_str=observation_info.chat_history_str,
|
||||
retry_count=reply_attempt_count - 1,
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次首次回复检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}"
|
||||
)
|
||||
if is_suitable:
|
||||
final_reply_to_send = self.generated_reply
|
||||
break
|
||||
elif need_replan:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次首次回复检查建议重新规划,停止尝试。原因: {check_reason}"
|
||||
)
|
||||
break
|
||||
except Exception as check_err:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次调用 ReplyChecker (首次回复) 时出错: {check_err}"
|
||||
)
|
||||
check_reason = f"第 {reply_attempt_count} 次检查过程出错: {check_err}"
|
||||
break
|
||||
|
||||
# 循环结束,处理最终结果
|
||||
if is_suitable:
|
||||
# 检查是否有新消息
|
||||
if self._check_new_messages_after_planning():
|
||||
logger.info(f"[私聊][{self.private_name}]生成首次回复期间收到新消息,取消发送,重新规划行动")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"有新消息,取消发送首次回复: {final_reply_to_send}"}
|
||||
)
|
||||
return # 直接返回,重新规划
|
||||
|
||||
# 发送合适的回复
|
||||
self.generated_reply = final_reply_to_send
|
||||
# --- 在这里调用 _send_reply ---
|
||||
await self._send_reply() # <--- 调用恢复后的函数
|
||||
|
||||
# 更新状态: 标记上次成功是 direct_reply
|
||||
self.conversation_info.last_successful_reply_action = "direct_reply"
|
||||
action_successful = True # 标记动作成功
|
||||
|
||||
elif need_replan:
|
||||
# 打回动作决策
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,首次回复决定打回动作决策。打回原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"首次回复尝试{reply_attempt_count}次后打回: {check_reason}"}
|
||||
)
|
||||
|
||||
else:
|
||||
# 首次回复失败
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,未能生成合适的首次回复。最终原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"首次回复尝试{reply_attempt_count}次后失败: {check_reason}"}
|
||||
)
|
||||
# 重置状态: 首次回复失败,下次还是用初始 prompt
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
|
||||
# 执行 Wait 操作 (保持原有逻辑)
|
||||
logger.info(f"[私聊][{self.private_name}]由于无法生成合适首次回复,执行 'wait' 操作...")
|
||||
self.state = ConversationState.WAITING
|
||||
await self.waiter.wait(self.conversation_info)
|
||||
wait_action_record = {
|
||||
"action": "wait",
|
||||
"plan_reason": "因 direct_reply 多次尝试失败而执行的后备等待",
|
||||
"status": "done",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"final_reason": None,
|
||||
}
|
||||
conversation_info.done_action.append(wait_action_record)
|
||||
|
||||
elif action == "fetch_knowledge":
|
||||
self.state = ConversationState.FETCHING
|
||||
knowledge_query = reason
|
||||
try:
|
||||
# 检查 knowledge_fetcher 是否存在
|
||||
if not hasattr(self, "knowledge_fetcher"):
|
||||
logger.error(f"[私聊][{self.private_name}]KnowledgeFetcher 未初始化,无法获取知识。")
|
||||
raise AttributeError("KnowledgeFetcher not initialized")
|
||||
|
||||
knowledge, source = await self.knowledge_fetcher.fetch(knowledge_query, observation_info.chat_history)
|
||||
logger.info(f"[私聊][{self.private_name}]获取到知识: {knowledge[:100]}..., 来源: {source}")
|
||||
if knowledge:
|
||||
# 确保 knowledge_list 存在
|
||||
if not hasattr(conversation_info, "knowledge_list"):
|
||||
conversation_info.knowledge_list = []
|
||||
conversation_info.knowledge_list.append(
|
||||
{"query": knowledge_query, "knowledge": knowledge, "source": source}
|
||||
)
|
||||
action_successful = True
|
||||
except Exception as fetch_err:
|
||||
logger.error(f"[私聊][{self.private_name}]获取知识时出错: {str(fetch_err)}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"获取知识失败: {str(fetch_err)}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
elif action == "rethink_goal":
|
||||
self.state = ConversationState.RETHINKING
|
||||
try:
|
||||
# 检查 goal_analyzer 是否存在
|
||||
if not hasattr(self, "goal_analyzer"):
|
||||
logger.error(f"[私聊][{self.private_name}]GoalAnalyzer 未初始化,无法重新思考目标。")
|
||||
raise AttributeError("GoalAnalyzer not initialized")
|
||||
await self.goal_analyzer.analyze_goal(conversation_info, observation_info)
|
||||
action_successful = True
|
||||
except Exception as rethink_err:
|
||||
logger.error(f"[私聊][{self.private_name}]重新思考目标时出错: {rethink_err}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"重新思考目标失败: {rethink_err}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
elif action == "listening":
|
||||
self.state = ConversationState.LISTENING
|
||||
logger.info(f"[私聊][{self.private_name}]倾听对方发言...")
|
||||
try:
|
||||
# 检查 waiter 是否存在
|
||||
if not hasattr(self, "waiter"):
|
||||
logger.error(f"[私聊][{self.private_name}]Waiter 未初始化,无法倾听。")
|
||||
raise AttributeError("Waiter not initialized")
|
||||
await self.waiter.wait_listening(conversation_info)
|
||||
action_successful = True # Listening 完成就算成功
|
||||
except Exception as listen_err:
|
||||
logger.error(f"[私聊][{self.private_name}]倾听时出错: {listen_err}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"倾听失败: {listen_err}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
elif action == "say_goodbye":
|
||||
self.state = ConversationState.GENERATING # 也可以定义一个新的状态,如 ENDING
|
||||
logger.info(f"[私聊][{self.private_name}]执行行动: 生成并发送告别语...")
|
||||
try:
|
||||
# 1. 生成告别语 (使用 'say_goodbye' action_type)
|
||||
self.generated_reply = await self.reply_generator.generate(
|
||||
observation_info, conversation_info, action_type="say_goodbye"
|
||||
)
|
||||
logger.info(f"[私聊][{self.private_name}]生成的告别语: {self.generated_reply}")
|
||||
|
||||
# 2. 直接发送告别语 (不经过检查)
|
||||
if self.generated_reply: # 确保生成了内容
|
||||
await self._send_reply() # 调用发送方法
|
||||
# 发送成功后,标记动作成功
|
||||
action_successful = True
|
||||
logger.info(f"[私聊][{self.private_name}]告别语已发送。")
|
||||
else:
|
||||
logger.warning(f"[私聊][{self.private_name}]未能生成告别语内容,无法发送。")
|
||||
action_successful = False # 标记动作失败
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": "未能生成告别语内容"}
|
||||
)
|
||||
|
||||
# 3. 无论是否发送成功,都准备结束对话
|
||||
self.should_continue = False
|
||||
logger.info(f"[私聊][{self.private_name}]发送告别语流程结束,即将停止对话实例。")
|
||||
|
||||
except Exception as goodbye_err:
|
||||
logger.error(f"[私聊][{self.private_name}]生成或发送告别语时出错: {goodbye_err}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
# 即使出错,也结束对话
|
||||
self.should_continue = False
|
||||
action_successful = False # 标记动作失败
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"生成或发送告别语时出错: {goodbye_err}"}
|
||||
)
|
||||
|
||||
elif action == "end_conversation":
|
||||
# 这个分支现在只会在 action_planner 最终决定不告别时被调用
|
||||
self.should_continue = False
|
||||
logger.info(f"[私聊][{self.private_name}]收到最终结束指令,停止对话...")
|
||||
action_successful = True # 标记这个指令本身是成功的
|
||||
|
||||
elif action == "block_and_ignore":
|
||||
logger.info(f"[私聊][{self.private_name}]不想再理你了...")
|
||||
ignore_duration_seconds = 10 * 60
|
||||
self.ignore_until_timestamp = time.time() + ignore_duration_seconds
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]将忽略此对话直到: {datetime.datetime.fromtimestamp(self.ignore_until_timestamp)}"
|
||||
)
|
||||
self.state = ConversationState.IGNORED
|
||||
action_successful = True # 标记动作成功
|
||||
|
||||
else: # 对应 'wait' 动作
|
||||
self.state = ConversationState.WAITING
|
||||
logger.info(f"[私聊][{self.private_name}]等待更多信息...")
|
||||
try:
|
||||
# 检查 waiter 是否存在
|
||||
if not hasattr(self, "waiter"):
|
||||
logger.error(f"[私聊][{self.private_name}]Waiter 未初始化,无法等待。")
|
||||
raise AttributeError("Waiter not initialized")
|
||||
_timeout_occurred = await self.waiter.wait(self.conversation_info)
|
||||
action_successful = True # Wait 完成就算成功
|
||||
except Exception as wait_err:
|
||||
logger.error(f"[私聊][{self.private_name}]等待时出错: {wait_err}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"等待失败: {wait_err}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
# --- 更新 Action History 状态 ---
|
||||
# 只有当动作本身成功时,才更新状态为 done
|
||||
if action_successful:
|
||||
conversation_info.done_action[action_index].update(
|
||||
{
|
||||
"status": "done",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
}
|
||||
)
|
||||
# 重置状态: 对于非回复类动作的成功,清除上次回复状态
|
||||
if action not in ["direct_reply", "send_new_message"]:
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
logger.debug(f"[私聊][{self.private_name}]动作 {action} 成功完成,重置 last_successful_reply_action")
|
||||
# 如果动作是 recall 状态,在各自的处理逻辑中已经更新了 done_action
|
||||
|
||||
async def _send_reply(self):
|
||||
"""发送回复"""
|
||||
if not self.generated_reply:
|
||||
logger.warning(f"[私聊][{self.private_name}]没有生成回复内容,无法发送。")
|
||||
return
|
||||
|
||||
try:
|
||||
_current_time = time.time()
|
||||
reply_content = self.generated_reply
|
||||
|
||||
# 发送消息 (确保 direct_sender 和 chat_stream 有效)
|
||||
if not hasattr(self, "direct_sender") or not self.direct_sender:
|
||||
logger.error(f"[私聊][{self.private_name}]DirectMessageSender 未初始化,无法发送回复。")
|
||||
return
|
||||
if not self.chat_stream:
|
||||
logger.error(f"[私聊][{self.private_name}]ChatStream 未初始化,无法发送回复。")
|
||||
return
|
||||
|
||||
await self.direct_sender.send_message(chat_stream=self.chat_stream, content=reply_content)
|
||||
|
||||
# 发送成功后,手动触发 observer 更新可能导致重复处理自己发送的消息
|
||||
# 更好的做法是依赖 observer 的自动轮询或数据库触发器(如果支持)
|
||||
# 暂时注释掉,观察是否影响 ObservationInfo 的更新
|
||||
# self.chat_observer.trigger_update()
|
||||
# if not await self.chat_observer.wait_for_update():
|
||||
# logger.warning(f"[私聊][{self.private_name}]等待 ChatObserver 更新完成超时")
|
||||
|
||||
self.state = ConversationState.ANALYZING # 更新状态
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]发送消息或更新状态时失败: {str(e)}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
self.state = ConversationState.ANALYZING
|
||||
|
||||
async def _send_timeout_message(self):
|
||||
"""发送超时结束消息"""
|
||||
try:
|
||||
messages = self.chat_observer.get_cached_messages(limit=1)
|
||||
if not messages:
|
||||
return
|
||||
|
||||
latest_message = self._convert_to_message(messages[0])
|
||||
await self.direct_sender.send_message(
|
||||
chat_stream=self.chat_stream, content="TODO:超时消息", reply_to_message=latest_message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]发送超时消息失败: {str(e)}")
|
||||
10
src/chat/brain_chat/PFC/conversation_info.py
Normal file
10
src/chat/brain_chat/PFC/conversation_info.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ConversationInfo:
|
||||
def __init__(self):
|
||||
self.done_action = []
|
||||
self.goal_list = []
|
||||
self.knowledge_list = []
|
||||
self.memory_list = []
|
||||
self.last_successful_reply_action: Optional[str] = None
|
||||
81
src/chat/brain_chat/PFC/message_sender.py
Normal file
81
src/chat/brain_chat/PFC/message_sender.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
from src.common.logger import get_module_logger
|
||||
from ..chat.chat_stream import ChatStream
|
||||
from ..chat.message import Message
|
||||
from maim_message import UserInfo, Seg
|
||||
from src.plugins.chat.message import MessageSending, MessageSet
|
||||
from src.plugins.chat.message_sender import message_manager
|
||||
from ..storage.storage import MessageStorage
|
||||
from ...config.config import global_config
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
logger = get_module_logger("message_sender")
|
||||
|
||||
|
||||
class DirectMessageSender:
|
||||
"""直接消息发送器"""
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
self.private_name = private_name
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
chat_stream: ChatStream,
|
||||
content: str,
|
||||
reply_to_message: Optional[Message] = None,
|
||||
) -> None:
|
||||
"""发送消息到聊天流
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流
|
||||
content: 消息内容
|
||||
reply_to_message: 要回复的消息(可选)
|
||||
"""
|
||||
try:
|
||||
# 创建消息内容
|
||||
segments = Seg(type="seglist", data=[Seg(type="text", data=content)])
|
||||
|
||||
# 获取麦麦的信息
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.BOT_QQ,
|
||||
user_nickname=global_config.BOT_NICKNAME,
|
||||
platform=chat_stream.platform,
|
||||
)
|
||||
|
||||
# 用当前时间作为message_id,和之前那套sender一样
|
||||
message_id = f"dm{round(time.time(), 2)}"
|
||||
|
||||
# 构建消息对象
|
||||
message = MessageSending(
|
||||
message_id=message_id,
|
||||
chat_stream=chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
sender_info=reply_to_message.message_info.user_info if reply_to_message else None,
|
||||
message_segment=segments,
|
||||
reply=reply_to_message,
|
||||
is_head=True,
|
||||
is_emoji=False,
|
||||
thinking_start_time=time.time(),
|
||||
)
|
||||
|
||||
# 处理消息
|
||||
await message.process()
|
||||
|
||||
# 不知道有什么用,先留下来了,和之前那套sender一样
|
||||
_message_json = message.to_dict()
|
||||
|
||||
# 发送消息
|
||||
message_set = MessageSet(chat_stream, message_id)
|
||||
message_set.add_message(message)
|
||||
await message_manager.add_message(message_set)
|
||||
await self.storage.store_message(message, chat_stream)
|
||||
logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}")
|
||||
raise
|
||||
119
src/chat/brain_chat/PFC/message_storage.py
Normal file
119
src/chat/brain_chat/PFC/message_storage.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any
|
||||
from src.common.database import db
|
||||
|
||||
|
||||
class MessageStorage(ABC):
|
||||
"""消息存储接口"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_messages_after(self, chat_id: str, message: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""获取指定消息ID之后的所有消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
message: 消息
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
"""获取指定时间点之前的消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
time_point: 时间戳
|
||||
limit: 最大消息数量
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
|
||||
"""检查是否有新消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
after_time: 时间戳
|
||||
|
||||
Returns:
|
||||
bool: 是否有新消息
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class MongoDBMessageStorage(MessageStorage):
|
||||
"""MongoDB消息存储实现"""
|
||||
|
||||
async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
|
||||
query = {"chat_id": chat_id, "time": {"$gt": message_time}}
|
||||
# print(f"storage_check_message: {message_time}")
|
||||
|
||||
return list(db.messages.find(query).sort("time", 1))
|
||||
|
||||
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
query = {"chat_id": chat_id, "time": {"$lt": time_point}}
|
||||
|
||||
messages = list(db.messages.find(query).sort("time", -1).limit(limit))
|
||||
|
||||
# 将消息按时间正序排列
|
||||
messages.reverse()
|
||||
return messages
|
||||
|
||||
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
|
||||
query = {"chat_id": chat_id, "time": {"$gt": after_time}}
|
||||
|
||||
return db.messages.find_one(query) is not None
|
||||
|
||||
|
||||
# # 创建一个内存消息存储实现,用于测试
|
||||
# class InMemoryMessageStorage(MessageStorage):
|
||||
# """内存消息存储实现,主要用于测试"""
|
||||
|
||||
# def __init__(self):
|
||||
# self.messages: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
# async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
# if chat_id not in self.messages:
|
||||
# return []
|
||||
|
||||
# messages = self.messages[chat_id]
|
||||
# if not message_id:
|
||||
# return messages
|
||||
|
||||
# # 找到message_id的索引
|
||||
# try:
|
||||
# index = next(i for i, m in enumerate(messages) if m["message_id"] == message_id)
|
||||
# return messages[index + 1:]
|
||||
# except StopIteration:
|
||||
# return []
|
||||
|
||||
# async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
# if chat_id not in self.messages:
|
||||
# return []
|
||||
|
||||
# messages = [
|
||||
# m for m in self.messages[chat_id]
|
||||
# if m["time"] < time_point
|
||||
# ]
|
||||
|
||||
# return messages[-limit:]
|
||||
|
||||
# async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
|
||||
# if chat_id not in self.messages:
|
||||
# return False
|
||||
|
||||
# return any(m["time"] > after_time for m in self.messages[chat_id])
|
||||
|
||||
# # 测试辅助方法
|
||||
# def add_message(self, chat_id: str, message: Dict[str, Any]):
|
||||
# """添加测试消息"""
|
||||
# if chat_id not in self.messages:
|
||||
# self.messages[chat_id] = []
|
||||
# self.messages[chat_id].append(message)
|
||||
# self.messages[chat_id].sort(key=lambda m: m["time"])
|
||||
389
src/chat/brain_chat/PFC/observation_info.py
Normal file
389
src/chat/brain_chat/PFC/observation_info.py
Normal file
@@ -0,0 +1,389 @@
|
||||
from typing import List, Optional, Dict, Any, Set
|
||||
from maim_message import UserInfo
|
||||
import time
|
||||
from src.common.logger import get_module_logger
|
||||
from .chat_observer import ChatObserver
|
||||
from .chat_states import NotificationHandler, NotificationType, Notification
|
||||
from src.plugins.utils.chat_message_builder import build_readable_messages
|
||||
import traceback # 导入 traceback 用于调试
|
||||
|
||||
logger = get_module_logger("observation_info")
|
||||
|
||||
|
||||
class ObservationInfoHandler(NotificationHandler):
|
||||
"""ObservationInfo的通知处理器"""
|
||||
|
||||
def __init__(self, observation_info: "ObservationInfo", private_name: str):
|
||||
"""初始化处理器
|
||||
|
||||
Args:
|
||||
observation_info: 要更新的ObservationInfo实例
|
||||
private_name: 私聊对象的名称,用于日志记录
|
||||
"""
|
||||
self.observation_info = observation_info
|
||||
# 将 private_name 存储在 handler 实例中
|
||||
self.private_name = private_name
|
||||
|
||||
async def handle_notification(self, notification: Notification): # 添加类型提示
|
||||
# 获取通知类型和数据
|
||||
notification_type = notification.type
|
||||
data = notification.data
|
||||
|
||||
try: # 添加错误处理块
|
||||
if notification_type == NotificationType.NEW_MESSAGE:
|
||||
# 处理新消息通知
|
||||
# logger.debug(f"[私聊][{self.private_name}]收到新消息通知data: {data}") # 可以在需要时取消注释
|
||||
message_id = data.get("message_id")
|
||||
processed_plain_text = data.get("processed_plain_text")
|
||||
detailed_plain_text = data.get("detailed_plain_text")
|
||||
user_info_dict = data.get("user_info") # 先获取字典
|
||||
time_value = data.get("time")
|
||||
|
||||
# 确保 user_info 是字典类型再创建 UserInfo 对象
|
||||
user_info = None
|
||||
if isinstance(user_info_dict, dict):
|
||||
try:
|
||||
user_info = UserInfo.from_dict(user_info_dict)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]从字典创建 UserInfo 时出错: {e}, 字典内容: {user_info_dict}"
|
||||
)
|
||||
# 可以选择在这里返回或记录错误,避免后续代码出错
|
||||
return
|
||||
elif user_info_dict is not None:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]收到的 user_info 不是预期的字典类型: {type(user_info_dict)}"
|
||||
)
|
||||
# 根据需要处理非字典情况,这里暂时返回
|
||||
return
|
||||
|
||||
message = {
|
||||
"message_id": message_id,
|
||||
"processed_plain_text": processed_plain_text,
|
||||
"detailed_plain_text": detailed_plain_text,
|
||||
"user_info": user_info_dict, # 存储原始字典或 UserInfo 对象,取决于你的 update_from_message 如何处理
|
||||
"time": time_value,
|
||||
}
|
||||
# 传递 UserInfo 对象(如果成功创建)或原始字典
|
||||
await self.observation_info.update_from_message(message, user_info) # 修改:传递 user_info 对象
|
||||
|
||||
elif notification_type == NotificationType.COLD_CHAT:
|
||||
# 处理冷场通知
|
||||
is_cold = data.get("is_cold", False)
|
||||
await self.observation_info.update_cold_chat_status(is_cold, time.time()) # 修改:改为 await 调用
|
||||
|
||||
elif notification_type == NotificationType.ACTIVE_CHAT:
|
||||
# 处理活跃通知 (通常由 COLD_CHAT 的反向状态处理)
|
||||
is_active = data.get("is_active", False)
|
||||
self.observation_info.is_cold = not is_active
|
||||
|
||||
elif notification_type == NotificationType.BOT_SPEAKING:
|
||||
# 处理机器人说话通知 (按需实现)
|
||||
self.observation_info.is_typing = False
|
||||
self.observation_info.last_bot_speak_time = time.time()
|
||||
|
||||
elif notification_type == NotificationType.USER_SPEAKING:
|
||||
# 处理用户说话通知
|
||||
self.observation_info.is_typing = False
|
||||
self.observation_info.last_user_speak_time = time.time()
|
||||
|
||||
elif notification_type == NotificationType.MESSAGE_DELETED:
|
||||
# 处理消息删除通知
|
||||
message_id = data.get("message_id")
|
||||
# 从 unprocessed_messages 中移除被删除的消息
|
||||
original_count = len(self.observation_info.unprocessed_messages)
|
||||
self.observation_info.unprocessed_messages = [
|
||||
msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id
|
||||
]
|
||||
if len(self.observation_info.unprocessed_messages) < original_count:
|
||||
logger.info(f"[私聊][{self.private_name}]移除了未处理的消息 (ID: {message_id})")
|
||||
|
||||
elif notification_type == NotificationType.USER_JOINED:
|
||||
# 处理用户加入通知 (如果适用私聊场景)
|
||||
user_id = data.get("user_id")
|
||||
if user_id:
|
||||
self.observation_info.active_users.add(str(user_id)) # 确保是字符串
|
||||
|
||||
elif notification_type == NotificationType.USER_LEFT:
|
||||
# 处理用户离开通知 (如果适用私聊场景)
|
||||
user_id = data.get("user_id")
|
||||
if user_id:
|
||||
self.observation_info.active_users.discard(str(user_id)) # 确保是字符串
|
||||
|
||||
elif notification_type == NotificationType.ERROR:
|
||||
# 处理错误通知
|
||||
error_msg = data.get("error", "未提供错误信息")
|
||||
logger.error(f"[私聊][{self.private_name}]收到错误通知: {error_msg}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]处理通知时发生错误: {e}")
|
||||
logger.error(traceback.format_exc()) # 打印详细堆栈信息
|
||||
|
||||
|
||||
# @dataclass <-- 这个,不需要了(递黄瓜)
|
||||
class ObservationInfo:
|
||||
"""决策信息类,用于收集和管理来自chat_observer的通知信息 (手动实现 __init__)"""
|
||||
|
||||
# 类型提示保留,可用于文档和静态分析
|
||||
private_name: str
|
||||
chat_history: List[Dict[str, Any]]
|
||||
chat_history_str: str
|
||||
unprocessed_messages: List[Dict[str, Any]]
|
||||
active_users: Set[str]
|
||||
last_bot_speak_time: Optional[float]
|
||||
last_user_speak_time: Optional[float]
|
||||
last_message_time: Optional[float]
|
||||
last_message_id: Optional[str]
|
||||
last_message_content: str
|
||||
last_message_sender: Optional[str]
|
||||
bot_id: Optional[str]
|
||||
chat_history_count: int
|
||||
new_messages_count: int
|
||||
cold_chat_start_time: Optional[float]
|
||||
cold_chat_duration: float
|
||||
is_typing: bool
|
||||
is_cold_chat: bool
|
||||
changed: bool
|
||||
chat_observer: Optional[ChatObserver]
|
||||
handler: Optional[ObservationInfoHandler]
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
"""
|
||||
手动初始化 ObservationInfo 的所有实例变量。
|
||||
"""
|
||||
|
||||
# 接收的参数
|
||||
self.private_name: str = private_name
|
||||
|
||||
# data_list
|
||||
self.chat_history: List[Dict[str, Any]] = []
|
||||
self.chat_history_str: str = ""
|
||||
self.unprocessed_messages: List[Dict[str, Any]] = []
|
||||
self.active_users: Set[str] = set()
|
||||
|
||||
# data
|
||||
self.last_bot_speak_time: Optional[float] = None
|
||||
self.last_user_speak_time: Optional[float] = None
|
||||
self.last_message_time: Optional[float] = None
|
||||
self.last_message_id: Optional[str] = None
|
||||
self.last_message_content: str = ""
|
||||
self.last_message_sender: Optional[str] = None
|
||||
self.bot_id: Optional[str] = None
|
||||
self.chat_history_count: int = 0
|
||||
self.new_messages_count: int = 0
|
||||
self.cold_chat_start_time: Optional[float] = None
|
||||
self.cold_chat_duration: float = 0.0
|
||||
|
||||
# state
|
||||
self.is_typing: bool = False
|
||||
self.is_cold_chat: bool = False
|
||||
self.changed: bool = False
|
||||
|
||||
# 关联对象
|
||||
self.chat_observer: Optional[ChatObserver] = None
|
||||
|
||||
self.handler: ObservationInfoHandler = ObservationInfoHandler(self, self.private_name)
|
||||
|
||||
def bind_to_chat_observer(self, chat_observer: ChatObserver):
|
||||
"""绑定到指定的chat_observer
|
||||
|
||||
Args:
|
||||
chat_observer: 要绑定的 ChatObserver 实例
|
||||
"""
|
||||
if self.chat_observer:
|
||||
logger.warning(f"[私聊][{self.private_name}]尝试重复绑定 ChatObserver")
|
||||
return
|
||||
|
||||
self.chat_observer = chat_observer
|
||||
try:
|
||||
if not self.handler: # 确保 handler 已经被创建
|
||||
logger.error(f"[私聊][{self.private_name}] 尝试绑定时 handler 未初始化!")
|
||||
self.chat_observer = None # 重置,防止后续错误
|
||||
return
|
||||
|
||||
# 注册关心的通知类型
|
||||
self.chat_observer.notification_manager.register_handler(
|
||||
target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler
|
||||
)
|
||||
self.chat_observer.notification_manager.register_handler(
|
||||
target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler
|
||||
)
|
||||
# 可以根据需要注册更多通知类型
|
||||
# self.chat_observer.notification_manager.register_handler(
|
||||
# target="observation_info", notification_type=NotificationType.MESSAGE_DELETED, handler=self.handler
|
||||
# )
|
||||
logger.info(f"[私聊][{self.private_name}]成功绑定到 ChatObserver")
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]绑定到 ChatObserver 时出错: {e}")
|
||||
self.chat_observer = None # 绑定失败,重置
|
||||
|
||||
def unbind_from_chat_observer(self):
|
||||
"""解除与chat_observer的绑定"""
|
||||
if (
|
||||
self.chat_observer and hasattr(self.chat_observer, "notification_manager") and self.handler
|
||||
): # 增加 handler 检查
|
||||
try:
|
||||
self.chat_observer.notification_manager.unregister_handler(
|
||||
target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler
|
||||
)
|
||||
self.chat_observer.notification_manager.unregister_handler(
|
||||
target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler
|
||||
)
|
||||
# 如果注册了其他类型,也要在这里注销
|
||||
# self.chat_observer.notification_manager.unregister_handler(
|
||||
# target="observation_info", notification_type=NotificationType.MESSAGE_DELETED, handler=self.handler
|
||||
# )
|
||||
logger.info(f"[私聊][{self.private_name}]成功从 ChatObserver 解绑")
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]从 ChatObserver 解绑时出错: {e}")
|
||||
finally: # 确保 chat_observer 被重置
|
||||
self.chat_observer = None
|
||||
else:
|
||||
logger.warning(f"[私聊][{self.private_name}]尝试解绑时 ChatObserver 不存在、无效或 handler 未设置")
|
||||
|
||||
# 修改:update_from_message 接收 UserInfo 对象
|
||||
async def update_from_message(self, message: Dict[str, Any], user_info: Optional[UserInfo]):
|
||||
"""从消息更新信息
|
||||
|
||||
Args:
|
||||
message: 消息数据字典
|
||||
user_info: 解析后的 UserInfo 对象 (可能为 None)
|
||||
"""
|
||||
message_time = message.get("time")
|
||||
message_id = message.get("message_id")
|
||||
processed_text = message.get("processed_plain_text", "")
|
||||
|
||||
# 只有在新消息到达时才更新 last_message 相关信息
|
||||
if message_time and message_time > (self.last_message_time or 0):
|
||||
self.last_message_time = message_time
|
||||
self.last_message_id = message_id
|
||||
self.last_message_content = processed_text
|
||||
# 重置冷场计时器
|
||||
self.is_cold_chat = False
|
||||
self.cold_chat_start_time = None
|
||||
self.cold_chat_duration = 0.0
|
||||
|
||||
if user_info:
|
||||
sender_id = str(user_info.user_id) # 确保是字符串
|
||||
self.last_message_sender = sender_id
|
||||
# 更新发言时间
|
||||
if sender_id == self.bot_id:
|
||||
self.last_bot_speak_time = message_time
|
||||
else:
|
||||
self.last_user_speak_time = message_time
|
||||
self.active_users.add(sender_id) # 用户发言则认为其活跃
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]处理消息更新时缺少有效的 UserInfo 对象, message_id: {message_id}"
|
||||
)
|
||||
self.last_message_sender = None # 发送者未知
|
||||
|
||||
# 将原始消息字典添加到未处理列表
|
||||
self.unprocessed_messages.append(message)
|
||||
self.new_messages_count = len(self.unprocessed_messages) # 直接用列表长度
|
||||
|
||||
# logger.debug(f"[私聊][{self.private_name}]消息更新: last_time={self.last_message_time}, new_count={self.new_messages_count}")
|
||||
self.update_changed() # 标记状态已改变
|
||||
else:
|
||||
# 如果消息时间戳不是最新的,可能不需要处理,或者记录一个警告
|
||||
pass
|
||||
# logger.warning(f"[私聊][{self.private_name}]收到过时或无效时间戳的消息: ID={message_id}, time={message_time}")
|
||||
|
||||
def update_changed(self):
|
||||
"""标记状态已改变,并重置标记"""
|
||||
# logger.debug(f"[私聊][{self.private_name}]状态标记为已改变 (changed=True)")
|
||||
self.changed = True
|
||||
|
||||
async def update_cold_chat_status(self, is_cold: bool, current_time: float):
|
||||
"""更新冷场状态
|
||||
|
||||
Args:
|
||||
is_cold: 是否处于冷场状态
|
||||
current_time: 当前时间戳
|
||||
"""
|
||||
if is_cold != self.is_cold_chat: # 仅在状态变化时更新
|
||||
self.is_cold_chat = is_cold
|
||||
if is_cold:
|
||||
# 进入冷场状态
|
||||
self.cold_chat_start_time = (
|
||||
self.last_message_time or current_time
|
||||
) # 从最后消息时间开始算,或从当前时间开始
|
||||
logger.info(f"[私聊][{self.private_name}]进入冷场状态,开始时间: {self.cold_chat_start_time}")
|
||||
else:
|
||||
# 结束冷场状态
|
||||
if self.cold_chat_start_time:
|
||||
self.cold_chat_duration = current_time - self.cold_chat_start_time
|
||||
logger.info(f"[私聊][{self.private_name}]结束冷场状态,持续时间: {self.cold_chat_duration:.2f} 秒")
|
||||
self.cold_chat_start_time = None # 重置开始时间
|
||||
self.update_changed() # 状态变化,标记改变
|
||||
|
||||
# 即使状态没变,如果是冷场状态,也更新持续时间
|
||||
if self.is_cold_chat and self.cold_chat_start_time:
|
||||
self.cold_chat_duration = current_time - self.cold_chat_start_time
|
||||
|
||||
def get_active_duration(self) -> float:
|
||||
"""获取当前活跃时长 (距离最后一条消息的时间)
|
||||
|
||||
Returns:
|
||||
float: 最后一条消息到现在的时长(秒)
|
||||
"""
|
||||
if not self.last_message_time:
|
||||
return 0.0
|
||||
return time.time() - self.last_message_time
|
||||
|
||||
def get_user_response_time(self) -> Optional[float]:
|
||||
"""获取用户最后响应时间 (距离用户最后发言的时间)
|
||||
|
||||
Returns:
|
||||
Optional[float]: 用户最后发言到现在的时长(秒),如果没有用户发言则返回None
|
||||
"""
|
||||
if not self.last_user_speak_time:
|
||||
return None
|
||||
return time.time() - self.last_user_speak_time
|
||||
|
||||
def get_bot_response_time(self) -> Optional[float]:
|
||||
"""获取机器人最后响应时间 (距离机器人最后发言的时间)
|
||||
|
||||
Returns:
|
||||
Optional[float]: 机器人最后发言到现在的时长(秒),如果没有机器人发言则返回None
|
||||
"""
|
||||
if not self.last_bot_speak_time:
|
||||
return None
|
||||
return time.time() - self.last_bot_speak_time
|
||||
|
||||
async def clear_unprocessed_messages(self):
|
||||
"""将未处理消息移入历史记录,并更新相关状态"""
|
||||
if not self.unprocessed_messages:
|
||||
return # 没有未处理消息,直接返回
|
||||
|
||||
# logger.debug(f"[私聊][{self.private_name}]处理 {len(self.unprocessed_messages)} 条未处理消息...")
|
||||
# 将未处理消息添加到历史记录中 (确保历史记录有长度限制,避免无限增长)
|
||||
max_history_len = 100 # 示例:最多保留100条历史记录
|
||||
self.chat_history.extend(self.unprocessed_messages)
|
||||
if len(self.chat_history) > max_history_len:
|
||||
self.chat_history = self.chat_history[-max_history_len:]
|
||||
|
||||
# 更新历史记录字符串 (只使用最近一部分生成,例如20条)
|
||||
history_slice_for_str = self.chat_history[-20:]
|
||||
try:
|
||||
self.chat_history_str = await build_readable_messages(
|
||||
history_slice_for_str,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0, # read_mark 可能需要根据逻辑调整
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建聊天记录字符串时出错: {e}")
|
||||
self.chat_history_str = "[构建聊天记录出错]" # 提供错误提示
|
||||
|
||||
# 清空未处理消息列表和计数
|
||||
# cleared_count = len(self.unprocessed_messages)
|
||||
self.unprocessed_messages.clear()
|
||||
self.new_messages_count = 0
|
||||
# self.has_unread_messages = False # 这个状态可以通过 new_messages_count 判断
|
||||
|
||||
self.chat_history_count = len(self.chat_history) # 更新历史记录总数
|
||||
# logger.debug(f"[私聊][{self.private_name}]已处理 {cleared_count} 条消息,当前历史记录 {self.chat_history_count} 条。")
|
||||
|
||||
self.update_changed() # 状态改变
|
||||
345
src/chat/brain_chat/PFC/pfc.py
Normal file
345
src/chat/brain_chat/PFC/pfc.py
Normal file
@@ -0,0 +1,345 @@
|
||||
from typing import List, Tuple, TYPE_CHECKING
|
||||
from src.common.logger import get_module_logger
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
from .chat_observer import ChatObserver
|
||||
from .pfc_utils import get_items_from_json
|
||||
from src.individuality.individuality import Individuality
|
||||
from .conversation_info import ConversationInfo
|
||||
from .observation_info import ObservationInfo
|
||||
from src.plugins.utils.chat_message_builder import build_readable_messages
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = get_module_logger("pfc")
|
||||
|
||||
|
||||
def _calculate_similarity(goal1: str, goal2: str) -> float:
|
||||
"""简单计算两个目标之间的相似度
|
||||
|
||||
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
|
||||
|
||||
Args:
|
||||
goal1: 第一个目标
|
||||
goal2: 第二个目标
|
||||
|
||||
Returns:
|
||||
float: 相似度得分 (0-1)
|
||||
"""
|
||||
# 简单实现:检查重叠字数比例
|
||||
words1 = set(goal1)
|
||||
words2 = set(goal2)
|
||||
overlap = len(words1.intersection(words2))
|
||||
total = len(words1.union(words2))
|
||||
return overlap / total if total > 0 else 0
|
||||
|
||||
|
||||
class GoalAnalyzer:
|
||||
"""对话目标分析器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
|
||||
)
|
||||
|
||||
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
|
||||
self.name = global_config.BOT_NICKNAME
|
||||
self.nick_name = global_config.BOT_ALIAS_NAMES
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
|
||||
# 多目标存储结构
|
||||
self.goals = [] # 存储多个目标
|
||||
self.max_goals = 3 # 同时保持的最大目标数量
|
||||
self.current_goal_and_reason = None
|
||||
|
||||
async def analyze_goal(self, conversation_info: ConversationInfo, observation_info: ObservationInfo):
|
||||
"""分析对话历史并设定目标
|
||||
|
||||
Args:
|
||||
conversation_info: 对话信息
|
||||
observation_info: 观察信息
|
||||
|
||||
Returns:
|
||||
Tuple[str, str, str]: (目标, 方法, 原因)
|
||||
"""
|
||||
# 构建对话目标
|
||||
goals_str = ""
|
||||
if conversation_info.goal_list:
|
||||
for goal_reason in conversation_info.goal_list:
|
||||
if isinstance(goal_reason, dict):
|
||||
goal = goal_reason.get("goal", "目标内容缺失")
|
||||
reasoning = goal_reason.get("reasoning", "没有明确原因")
|
||||
else:
|
||||
goal = str(goal_reason)
|
||||
reasoning = "没有明确原因"
|
||||
|
||||
goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
||||
goals_str += goal_str
|
||||
else:
|
||||
goal = "目前没有明确对话目标"
|
||||
reasoning = "目前没有明确对话目标,最好思考一个对话目标"
|
||||
goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
||||
|
||||
# 获取聊天历史记录
|
||||
chat_history_text = observation_info.chat_history_str
|
||||
|
||||
if observation_info.new_messages_count > 0:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
new_messages_str = await build_readable_messages(
|
||||
new_messages_list,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
chat_history_text += f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
|
||||
|
||||
# await observation_info.clear_unprocessed_messages()
|
||||
|
||||
persona_text = f"你的名字是{self.name},{self.personality_info}。"
|
||||
# 构建action历史文本
|
||||
action_history_list = conversation_info.done_action
|
||||
action_history_text = "你之前做的事情是:"
|
||||
for action in action_history_list:
|
||||
action_history_text += f"{action}\n"
|
||||
|
||||
prompt = f"""{persona_text}。现在你在参与一场QQ聊天,请分析以下聊天记录,并根据你的性格特征确定多个明确的对话目标。
|
||||
这些目标应该反映出对话的不同方面和意图。
|
||||
|
||||
{action_history_text}
|
||||
当前对话目标:
|
||||
{goals_str}
|
||||
|
||||
聊天记录:
|
||||
{chat_history_text}
|
||||
|
||||
请分析当前对话并确定最适合的对话目标。你可以:
|
||||
1. 保持现有目标不变
|
||||
2. 修改现有目标
|
||||
3. 添加新目标
|
||||
4. 删除不再相关的目标
|
||||
5. 如果你想结束对话,请设置一个目标,目标goal为"结束对话",原因reasoning为你希望结束对话
|
||||
|
||||
请以JSON数组格式输出当前的所有对话目标,每个目标包含以下字段:
|
||||
1. goal: 对话目标(简短的一句话)
|
||||
2. reasoning: 对话原因,为什么设定这个目标(简要解释)
|
||||
|
||||
输出格式示例:
|
||||
[
|
||||
{{
|
||||
"goal": "回答用户关于Python编程的具体问题",
|
||||
"reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答"
|
||||
}},
|
||||
{{
|
||||
"goal": "回答用户关于python安装的具体问题",
|
||||
"reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答"
|
||||
}}
|
||||
]"""
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]发送到LLM的提示词: {prompt}")
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}")
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]分析对话目标时出错: {str(e)}")
|
||||
content = ""
|
||||
|
||||
# 使用改进后的get_items_from_json函数处理JSON数组
|
||||
success, result = get_items_from_json(
|
||||
content,
|
||||
self.private_name,
|
||||
"goal",
|
||||
"reasoning",
|
||||
required_types={"goal": str, "reasoning": str},
|
||||
allow_array=True,
|
||||
)
|
||||
|
||||
if success:
|
||||
# 判断结果是单个字典还是字典列表
|
||||
if isinstance(result, list):
|
||||
# 清空现有目标列表并添加新目标
|
||||
conversation_info.goal_list = []
|
||||
for item in result:
|
||||
conversation_info.goal_list.append(item)
|
||||
|
||||
# 返回第一个目标作为当前主要目标(如果有)
|
||||
if result:
|
||||
first_goal = result[0]
|
||||
return first_goal.get("goal", ""), "", first_goal.get("reasoning", "")
|
||||
else:
|
||||
# 单个目标的情况
|
||||
conversation_info.goal_list.append(result)
|
||||
return goal, "", reasoning
|
||||
|
||||
# 如果解析失败,返回默认值
|
||||
return "", "", ""
|
||||
|
||||
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
|
||||
"""更新目标列表
|
||||
|
||||
Args:
|
||||
new_goal: 新的目标
|
||||
method: 实现目标的方法
|
||||
reasoning: 目标的原因
|
||||
"""
|
||||
# 检查新目标是否与现有目标相似
|
||||
for i, (existing_goal, _, _) in enumerate(self.goals):
|
||||
if _calculate_similarity(new_goal, existing_goal) > 0.7: # 相似度阈值
|
||||
# 更新现有目标
|
||||
self.goals[i] = (new_goal, method, reasoning)
|
||||
# 将此目标移到列表前面(最主要的位置)
|
||||
self.goals.insert(0, self.goals.pop(i))
|
||||
return
|
||||
|
||||
# 添加新目标到列表前面
|
||||
self.goals.insert(0, (new_goal, method, reasoning))
|
||||
|
||||
# 限制目标数量
|
||||
if len(self.goals) > self.max_goals:
|
||||
self.goals.pop() # 移除最老的目标
|
||||
|
||||
async def get_all_goals(self) -> List[Tuple[str, str, str]]:
|
||||
"""获取所有当前目标
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 目标列表,每项为(目标, 方法, 原因)
|
||||
"""
|
||||
return self.goals.copy()
|
||||
|
||||
async def get_alternative_goals(self) -> List[Tuple[str, str, str]]:
|
||||
"""获取除了当前主要目标外的其他备选目标
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 备选目标列表
|
||||
"""
|
||||
if len(self.goals) <= 1:
|
||||
return []
|
||||
return self.goals[1:].copy()
|
||||
|
||||
async def analyze_conversation(self, goal, reasoning):
|
||||
messages = self.chat_observer.get_cached_messages()
|
||||
chat_history_text = await build_readable_messages(
|
||||
messages,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
|
||||
persona_text = f"你的名字是{self.name},{self.personality_info}。"
|
||||
# ===> Persona 文本构建结束 <===
|
||||
|
||||
# --- 修改 Prompt 字符串,使用 persona_text ---
|
||||
prompt = f"""{persona_text}。现在你在参与一场QQ聊天,
|
||||
当前对话目标:{goal}
|
||||
产生该对话目标的原因:{reasoning}
|
||||
|
||||
请分析以下聊天记录,并根据你的性格特征评估该目标是否已经达到,或者你是否希望停止该次对话。
|
||||
聊天记录:
|
||||
{chat_history_text}
|
||||
请以JSON格式输出,包含以下字段:
|
||||
1. goal_achieved: 对话目标是否已经达到(true/false)
|
||||
2. stop_conversation: 是否希望停止该次对话(true/false)
|
||||
3. reason: 为什么希望停止该次对话(简要解释)
|
||||
|
||||
输出格式示例:
|
||||
{{
|
||||
"goal_achieved": true,
|
||||
"stop_conversation": false,
|
||||
"reason": "虽然目标已达成,但对话仍然有继续的价值"
|
||||
}}"""
|
||||
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}")
|
||||
|
||||
# 尝试解析JSON
|
||||
success, result = get_items_from_json(
|
||||
content,
|
||||
self.private_name,
|
||||
"goal_achieved",
|
||||
"stop_conversation",
|
||||
"reason",
|
||||
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str},
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error(f"[私聊][{self.private_name}]无法解析对话分析结果JSON")
|
||||
return False, False, "解析结果失败"
|
||||
|
||||
goal_achieved = result["goal_achieved"]
|
||||
stop_conversation = result["stop_conversation"]
|
||||
reason = result["reason"]
|
||||
|
||||
return goal_achieved, stop_conversation, reason
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]分析对话状态时出错: {str(e)}")
|
||||
return False, False, f"分析出错: {str(e)}"
|
||||
|
||||
|
||||
# 先注释掉,万一以后出问题了还能开回来(((
|
||||
# class DirectMessageSender:
|
||||
# """直接发送消息到平台的发送器"""
|
||||
|
||||
# def __init__(self, private_name: str):
|
||||
# self.logger = get_module_logger("direct_sender")
|
||||
# self.storage = MessageStorage()
|
||||
# self.private_name = private_name
|
||||
|
||||
# async def send_via_ws(self, message: MessageSending) -> None:
|
||||
# try:
|
||||
# await global_api.send_message(message)
|
||||
# except Exception as e:
|
||||
# raise ValueError(f"未找到平台:{message.message_info.platform} 的url配置,请检查配置文件") from e
|
||||
|
||||
# async def send_message(
|
||||
# self,
|
||||
# chat_stream: ChatStream,
|
||||
# content: str,
|
||||
# reply_to_message: Optional[Message] = None,
|
||||
# ) -> None:
|
||||
# """直接发送消息到平台
|
||||
|
||||
# Args:
|
||||
# chat_stream: 聊天流
|
||||
# content: 消息内容
|
||||
# reply_to_message: 要回复的消息
|
||||
# """
|
||||
# # 构建消息对象
|
||||
# message_segment = Seg(type="text", data=content)
|
||||
# bot_user_info = UserInfo(
|
||||
# user_id=global_config.BOT_QQ,
|
||||
# user_nickname=global_config.BOT_NICKNAME,
|
||||
# platform=chat_stream.platform,
|
||||
# )
|
||||
|
||||
# message = MessageSending(
|
||||
# message_id=f"dm{round(time.time(), 2)}",
|
||||
# chat_stream=chat_stream,
|
||||
# bot_user_info=bot_user_info,
|
||||
# sender_info=reply_to_message.message_info.user_info if reply_to_message else None,
|
||||
# message_segment=message_segment,
|
||||
# reply=reply_to_message,
|
||||
# is_head=True,
|
||||
# is_emoji=False,
|
||||
# thinking_start_time=time.time(),
|
||||
# )
|
||||
|
||||
# # 处理消息
|
||||
# await message.process()
|
||||
|
||||
# _message_json = message.to_dict()
|
||||
|
||||
# # 发送消息
|
||||
# try:
|
||||
# await self.send_via_ws(message)
|
||||
# await self.storage.store_message(message, chat_stream)
|
||||
# logger.success(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
|
||||
# except Exception as e:
|
||||
# logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}")
|
||||
85
src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py
Normal file
85
src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from typing import List, Tuple
|
||||
from src.common.logger import get_module_logger
|
||||
from src.plugins.memory_system.Hippocampus import HippocampusManager
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
from ..chat.message import Message
|
||||
from ..knowledge.knowledge_lib import qa_manager
|
||||
from ..utils.chat_message_builder import build_readable_messages
|
||||
|
||||
logger = get_module_logger("knowledge_fetcher")
|
||||
|
||||
|
||||
class KnowledgeFetcher:
|
||||
"""知识调取器"""
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model=global_config.llm_normal,
|
||||
temperature=global_config.llm_normal["temp"],
|
||||
max_tokens=1000,
|
||||
request_type="knowledge_fetch",
|
||||
)
|
||||
self.private_name = private_name
|
||||
|
||||
def _lpmm_get_knowledge(self, query: str) -> str:
|
||||
"""获取相关知识
|
||||
|
||||
Args:
|
||||
query: 查询内容
|
||||
|
||||
Returns:
|
||||
str: 构造好的,带相关度的知识
|
||||
"""
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]正在从LPMM知识库中获取知识")
|
||||
try:
|
||||
knowledge_info = qa_manager.get_knowledge(query)
|
||||
logger.debug(f"[私聊][{self.private_name}]LPMM知识库查询结果: {knowledge_info:150}")
|
||||
return knowledge_info
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]LPMM知识库搜索工具执行失败: {str(e)}")
|
||||
return "未找到匹配的知识"
|
||||
|
||||
async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]:
|
||||
"""获取相关知识
|
||||
|
||||
Args:
|
||||
query: 查询内容
|
||||
chat_history: 聊天历史
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (获取的知识, 知识来源)
|
||||
"""
|
||||
# 构建查询上下文
|
||||
chat_history_text = await build_readable_messages(
|
||||
chat_history,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
|
||||
# 从记忆中获取相关知识
|
||||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||
text=f"{query}\n{chat_history_text}",
|
||||
max_memory_num=3,
|
||||
max_memory_length=2,
|
||||
max_depth=3,
|
||||
fast_retrieval=False,
|
||||
)
|
||||
knowledge_text = ""
|
||||
sources_text = "无记忆匹配" # 默认值
|
||||
if related_memory:
|
||||
sources = []
|
||||
for memory in related_memory:
|
||||
knowledge_text += memory[1] + "\n"
|
||||
sources.append(f"记忆片段{memory[0]}")
|
||||
knowledge_text = knowledge_text.strip()
|
||||
sources_text = ",".join(sources)
|
||||
|
||||
knowledge_text += "\n现在有以下**知识**可供参考:\n "
|
||||
knowledge_text += self._lpmm_get_knowledge(query)
|
||||
knowledge_text += "\n请记住这些**知识**,并根据**知识**回答问题。\n"
|
||||
|
||||
return knowledge_text or "未找到相关知识", sources_text or "无记忆匹配"
|
||||
115
src/chat/brain_chat/PFC/pfc_manager.py
Normal file
115
src/chat/brain_chat/PFC/pfc_manager.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
from src.common.logger import get_module_logger
|
||||
from .conversation import Conversation
|
||||
import traceback
|
||||
|
||||
logger = get_module_logger("pfc_manager")
|
||||
|
||||
|
||||
class PFCManager:
|
||||
"""PFC对话管理器,负责管理所有对话实例"""
|
||||
|
||||
# 单例模式
|
||||
_instance = None
|
||||
|
||||
# 会话实例管理
|
||||
_instances: Dict[str, Conversation] = {}
|
||||
_initializing: Dict[str, bool] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "PFCManager":
|
||||
"""获取管理器单例
|
||||
|
||||
Returns:
|
||||
PFCManager: 管理器实例
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = PFCManager()
|
||||
return cls._instance
|
||||
|
||||
async def get_or_create_conversation(self, stream_id: str, private_name: str) -> Optional[Conversation]:
|
||||
"""获取或创建对话实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
private_name: 私聊名称
|
||||
|
||||
Returns:
|
||||
Optional[Conversation]: 对话实例,创建失败则返回None
|
||||
"""
|
||||
# 检查是否已经有实例
|
||||
if stream_id in self._initializing and self._initializing[stream_id]:
|
||||
logger.debug(f"[私聊][{private_name}]会话实例正在初始化中: {stream_id}")
|
||||
return None
|
||||
|
||||
if stream_id in self._instances and self._instances[stream_id].should_continue:
|
||||
logger.debug(f"[私聊][{private_name}]使用现有会话实例: {stream_id}")
|
||||
return self._instances[stream_id]
|
||||
if stream_id in self._instances:
|
||||
instance = self._instances[stream_id]
|
||||
if (
|
||||
hasattr(instance, "ignore_until_timestamp")
|
||||
and instance.ignore_until_timestamp
|
||||
and time.time() < instance.ignore_until_timestamp
|
||||
):
|
||||
logger.debug(f"[私聊][{private_name}]会话实例当前处于忽略状态: {stream_id}")
|
||||
# 返回 None 阻止交互。或者可以返回实例但标记它被忽略了喵?
|
||||
# 还是返回 None 吧喵。
|
||||
return None
|
||||
|
||||
# 检查 should_continue 状态
|
||||
if instance.should_continue:
|
||||
logger.debug(f"[私聊][{private_name}]使用现有会话实例: {stream_id}")
|
||||
return instance
|
||||
# else: 实例存在但不应继续
|
||||
try:
|
||||
# 创建新实例
|
||||
logger.info(f"[私聊][{private_name}]创建新的对话实例: {stream_id}")
|
||||
self._initializing[stream_id] = True
|
||||
# 创建实例
|
||||
conversation_instance = Conversation(stream_id, private_name)
|
||||
self._instances[stream_id] = conversation_instance
|
||||
|
||||
# 启动实例初始化
|
||||
await self._initialize_conversation(conversation_instance)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{private_name}]创建会话实例失败: {stream_id}, 错误: {e}")
|
||||
return None
|
||||
|
||||
return conversation_instance
|
||||
|
||||
async def _initialize_conversation(self, conversation: Conversation):
|
||||
"""初始化会话实例
|
||||
|
||||
Args:
|
||||
conversation: 要初始化的会话实例
|
||||
"""
|
||||
stream_id = conversation.stream_id
|
||||
private_name = conversation.private_name
|
||||
|
||||
try:
|
||||
logger.info(f"[私聊][{private_name}]开始初始化会话实例: {stream_id}")
|
||||
# 启动初始化流程
|
||||
await conversation._initialize()
|
||||
|
||||
# 标记初始化完成
|
||||
self._initializing[stream_id] = False
|
||||
|
||||
logger.info(f"[私聊][{private_name}]会话实例 {stream_id} 初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{private_name}]管理器初始化会话实例失败: {stream_id}, 错误: {e}")
|
||||
logger.error(f"[私聊][{private_name}]{traceback.format_exc()}")
|
||||
# 清理失败的初始化
|
||||
|
||||
async def get_conversation(self, stream_id: str) -> Optional[Conversation]:
|
||||
"""获取已存在的会话实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
Optional[Conversation]: 会话实例,不存在则返回None
|
||||
"""
|
||||
return self._instances.get(stream_id)
|
||||
23
src/chat/brain_chat/PFC/pfc_types.py
Normal file
23
src/chat/brain_chat/PFC/pfc_types.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class ConversationState(Enum):
|
||||
"""对话状态"""
|
||||
|
||||
INIT = "初始化"
|
||||
RETHINKING = "重新思考"
|
||||
ANALYZING = "分析历史"
|
||||
PLANNING = "规划目标"
|
||||
GENERATING = "生成回复"
|
||||
CHECKING = "检查回复"
|
||||
SENDING = "发送消息"
|
||||
FETCHING = "获取知识"
|
||||
WAITING = "等待"
|
||||
LISTENING = "倾听"
|
||||
ENDED = "结束"
|
||||
JUDGING = "判断"
|
||||
IGNORED = "屏蔽"
|
||||
|
||||
|
||||
ActionType = Literal["direct_reply", "fetch_knowledge", "wait"]
|
||||
127
src/chat/brain_chat/PFC/pfc_utils.py
Normal file
127
src/chat/brain_chat/PFC/pfc_utils.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, Any, Optional, Tuple, List, Union
|
||||
from src.common.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger("pfc_utils")
|
||||
|
||||
|
||||
def get_items_from_json(
|
||||
content: str,
|
||||
private_name: str,
|
||||
*items: str,
|
||||
default_values: Optional[Dict[str, Any]] = None,
|
||||
required_types: Optional[Dict[str, type]] = None,
|
||||
allow_array: bool = True,
|
||||
) -> Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]:
|
||||
"""从文本中提取JSON内容并获取指定字段
|
||||
|
||||
Args:
|
||||
content: 包含JSON的文本
|
||||
private_name: 私聊名称
|
||||
*items: 要提取的字段名
|
||||
default_values: 字段的默认值,格式为 {字段名: 默认值}
|
||||
required_types: 字段的必需类型,格式为 {字段名: 类型}
|
||||
allow_array: 是否允许解析JSON数组
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]: (是否成功, 提取的字段字典或字典列表)
|
||||
"""
|
||||
content = content.strip()
|
||||
result = {}
|
||||
|
||||
# 设置默认值
|
||||
if default_values:
|
||||
result.update(default_values)
|
||||
|
||||
# 首先尝试解析为JSON数组
|
||||
if allow_array:
|
||||
try:
|
||||
# 尝试找到文本中的JSON数组
|
||||
array_pattern = r"\[[\s\S]*\]"
|
||||
array_match = re.search(array_pattern, content)
|
||||
if array_match:
|
||||
array_content = array_match.group()
|
||||
json_array = json.loads(array_content)
|
||||
|
||||
# 确认是数组类型
|
||||
if isinstance(json_array, list):
|
||||
# 验证数组中的每个项目是否包含所有必需字段
|
||||
valid_items = []
|
||||
for item in json_array:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# 检查是否有所有必需字段
|
||||
if all(field in item for field in items):
|
||||
# 验证字段类型
|
||||
if required_types:
|
||||
type_valid = True
|
||||
for field, expected_type in required_types.items():
|
||||
if field in item and not isinstance(item[field], expected_type):
|
||||
type_valid = False
|
||||
break
|
||||
|
||||
if not type_valid:
|
||||
continue
|
||||
|
||||
# 验证字符串字段不为空
|
||||
string_valid = True
|
||||
for field in items:
|
||||
if isinstance(item[field], str) and not item[field].strip():
|
||||
string_valid = False
|
||||
break
|
||||
|
||||
if not string_valid:
|
||||
continue
|
||||
|
||||
valid_items.append(item)
|
||||
|
||||
if valid_items:
|
||||
return True, valid_items
|
||||
except json.JSONDecodeError:
|
||||
logger.debug(f"[私聊][{private_name}]JSON数组解析失败,尝试解析单个JSON对象")
|
||||
except Exception as e:
|
||||
logger.debug(f"[私聊][{private_name}]尝试解析JSON数组时出错: {str(e)}")
|
||||
|
||||
# 尝试解析JSON对象
|
||||
try:
|
||||
json_data = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
# 如果直接解析失败,尝试查找和提取JSON部分
|
||||
json_pattern = r"\{[^{}]*\}"
|
||||
json_match = re.search(json_pattern, content)
|
||||
if json_match:
|
||||
try:
|
||||
json_data = json.loads(json_match.group())
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"[私聊][{private_name}]提取的JSON内容解析失败")
|
||||
return False, result
|
||||
else:
|
||||
logger.error(f"[私聊][{private_name}]无法在返回内容中找到有效的JSON")
|
||||
return False, result
|
||||
|
||||
# 提取字段
|
||||
for item in items:
|
||||
if item in json_data:
|
||||
result[item] = json_data[item]
|
||||
|
||||
# 验证必需字段
|
||||
if not all(item in result for item in items):
|
||||
logger.error(f"[私聊][{private_name}]JSON缺少必要字段,实际内容: {json_data}")
|
||||
return False, result
|
||||
|
||||
# 验证字段类型
|
||||
if required_types:
|
||||
for field, expected_type in required_types.items():
|
||||
if field in result and not isinstance(result[field], expected_type):
|
||||
logger.error(f"[私聊][{private_name}]{field} 必须是 {expected_type.__name__} 类型")
|
||||
return False, result
|
||||
|
||||
# 验证字符串字段不为空
|
||||
for field in items:
|
||||
if isinstance(result[field], str) and not result[field].strip():
|
||||
logger.error(f"[私聊][{private_name}]{field} 不能为空")
|
||||
return False, result
|
||||
|
||||
return True, result
|
||||
183
src/chat/brain_chat/PFC/reply_checker.py
Normal file
183
src/chat/brain_chat/PFC/reply_checker.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import json
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from src.common.logger import get_module_logger
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
from .chat_observer import ChatObserver
|
||||
from maim_message import UserInfo
|
||||
|
||||
logger = get_module_logger("reply_checker")
|
||||
|
||||
|
||||
class ReplyChecker:
|
||||
"""回复检查器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model=global_config.llm_PFC_reply_checker, temperature=0.50, max_tokens=1000, request_type="reply_check"
|
||||
)
|
||||
self.name = global_config.BOT_NICKNAME
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
self.max_retries = 3 # 最大重试次数
|
||||
|
||||
async def check(
|
||||
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_text: str, retry_count: int = 0
|
||||
) -> Tuple[bool, str, bool]:
|
||||
"""检查生成的回复是否合适
|
||||
|
||||
Args:
|
||||
reply: 生成的回复
|
||||
goal: 对话目标
|
||||
chat_history: 对话历史记录
|
||||
chat_history_text: 对话历史记录文本
|
||||
retry_count: 当前重试次数
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
|
||||
"""
|
||||
# 不再从 observer 获取,直接使用传入的 chat_history
|
||||
# messages = self.chat_observer.get_cached_messages(limit=20)
|
||||
try:
|
||||
# 筛选出最近由 Bot 自己发送的消息
|
||||
bot_messages = []
|
||||
for msg in reversed(chat_history):
|
||||
user_info = UserInfo.from_dict(msg.get("user_info", {}))
|
||||
if str(user_info.user_id) == str(global_config.BOT_QQ): # 确保比较的是字符串
|
||||
bot_messages.append(msg.get("processed_plain_text", ""))
|
||||
if len(bot_messages) >= 2: # 只和最近的两条比较
|
||||
break
|
||||
# 进行比较
|
||||
if bot_messages:
|
||||
# 可以用简单比较,或者更复杂的相似度库 (如 difflib)
|
||||
# 简单比较:是否完全相同
|
||||
if reply == bot_messages[0]: # 和最近一条完全一样
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ReplyChecker 检测到回复与上一条 Bot 消息完全相同: '{reply}'"
|
||||
)
|
||||
return (
|
||||
False,
|
||||
"被逻辑检查拒绝:回复内容与你上一条发言完全相同,可以选择深入话题或寻找其它话题或等待",
|
||||
True,
|
||||
) # 不合适,需要返回至决策层
|
||||
# 2. 相似度检查 (如果精确匹配未通过)
|
||||
import difflib # 导入 difflib 库
|
||||
|
||||
# 计算编辑距离相似度,ratio() 返回 0 到 1 之间的浮点数
|
||||
similarity_ratio = difflib.SequenceMatcher(None, reply, bot_messages[0]).ratio()
|
||||
logger.debug(f"[私聊][{self.private_name}]ReplyChecker - 相似度: {similarity_ratio:.2f}")
|
||||
|
||||
# 设置一个相似度阈值
|
||||
similarity_threshold = 0.9
|
||||
if similarity_ratio > similarity_threshold:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ReplyChecker 检测到回复与上一条 Bot 消息高度相似 (相似度 {similarity_ratio:.2f}): '{reply}'"
|
||||
)
|
||||
return (
|
||||
False,
|
||||
f"被逻辑检查拒绝:回复内容与你上一条发言高度相似 (相似度 {similarity_ratio:.2f}),可以选择深入话题或寻找其它话题或等待。",
|
||||
True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error(f"[私聊][{self.private_name}]检查回复时出错: 类型={type(e)}, 值={e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}") # 打印详细的回溯信息
|
||||
|
||||
prompt = f"""你是一个聊天逻辑检查器,请检查以下回复或消息是否合适:
|
||||
|
||||
当前对话目标:{goal}
|
||||
最新的对话记录:
|
||||
{chat_history_text}
|
||||
|
||||
待检查的消息:
|
||||
{reply}
|
||||
|
||||
请结合聊天记录检查以下几点:
|
||||
1. 这条消息是否依然符合当前对话目标和实现方式
|
||||
2. 这条消息是否与最新的对话记录保持一致性
|
||||
3. 是否存在重复发言,或重复表达同质内容(尤其是只是换一种方式表达了相同的含义)
|
||||
4. 这条消息是否包含违规内容(例如血腥暴力,政治敏感等)
|
||||
5. 这条消息是否以发送者的角度发言(不要让发送者自己回复自己的消息)
|
||||
6. 这条消息是否通俗易懂
|
||||
7. 这条消息是否有些多余,例如在对方没有回复的情况下,依然连续多次“消息轰炸”(尤其是已经连续发送3条信息的情况,这很可能不合理,需要着重判断)
|
||||
8. 这条消息是否使用了完全没必要的修辞
|
||||
9. 这条消息是否逻辑通顺
|
||||
10. 这条消息是否太过冗长了(通常私聊的每条消息长度在20字以内,除非特殊情况)
|
||||
11. 在连续多次发送消息的情况下,这条消息是否衔接自然,会不会显得奇怪(例如连续两条消息中部分内容重叠)
|
||||
|
||||
请以JSON格式输出,包含以下字段:
|
||||
1. suitable: 是否合适 (true/false)
|
||||
2. reason: 原因说明
|
||||
3. need_replan: 是否需要重新决策 (true/false),当你认为此时已经不适合发消息,需要规划其它行动时,设为true
|
||||
|
||||
输出格式示例:
|
||||
{{
|
||||
"suitable": true,
|
||||
"reason": "回复符合要求,虽然有可能略微偏离目标,但是整体内容流畅得体",
|
||||
"need_replan": false
|
||||
}}
|
||||
|
||||
注意:请严格按照JSON格式输出,不要包含任何其他内容。"""
|
||||
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"[私聊][{self.private_name}]检查回复的原始返回: {content}")
|
||||
|
||||
# 清理内容,尝试提取JSON部分
|
||||
content = content.strip()
|
||||
try:
|
||||
# 尝试直接解析
|
||||
result = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
# 如果直接解析失败,尝试查找和提取JSON部分
|
||||
import re
|
||||
|
||||
json_pattern = r"\{[^{}]*\}"
|
||||
json_match = re.search(json_pattern, content)
|
||||
if json_match:
|
||||
try:
|
||||
result = json.loads(json_match.group())
|
||||
except json.JSONDecodeError:
|
||||
# 如果JSON解析失败,尝试从文本中提取结果
|
||||
is_suitable = "不合适" not in content.lower() and "违规" not in content.lower()
|
||||
reason = content[:100] if content else "无法解析响应"
|
||||
need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower()
|
||||
return is_suitable, reason, need_replan
|
||||
else:
|
||||
# 如果找不到JSON,从文本中判断
|
||||
is_suitable = "不合适" not in content.lower() and "违规" not in content.lower()
|
||||
reason = content[:100] if content else "无法解析响应"
|
||||
need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower()
|
||||
return is_suitable, reason, need_replan
|
||||
|
||||
# 验证JSON字段
|
||||
suitable = result.get("suitable", None)
|
||||
reason = result.get("reason", "未提供原因")
|
||||
need_replan = result.get("need_replan", False)
|
||||
|
||||
# 如果suitable字段是字符串,转换为布尔值
|
||||
if isinstance(suitable, str):
|
||||
suitable = suitable.lower() == "true"
|
||||
|
||||
# 如果suitable字段不存在或不是布尔值,从reason中判断
|
||||
if suitable is None:
|
||||
suitable = "不合适" not in reason.lower() and "违规" not in reason.lower()
|
||||
|
||||
# 如果不合适且未达到最大重试次数,返回需要重试
|
||||
if not suitable and retry_count < self.max_retries:
|
||||
return False, reason, False
|
||||
|
||||
# 如果不合适且已达到最大重试次数,返回需要重新规划
|
||||
if not suitable and retry_count >= self.max_retries:
|
||||
return False, f"多次重试后仍不合适: {reason}", True
|
||||
|
||||
return suitable, reason, need_replan
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]检查回复时出错: {e}")
|
||||
# 如果出错且已达到最大重试次数,建议重新规划
|
||||
if retry_count >= self.max_retries:
|
||||
return False, "多次检查失败,建议重新规划", True
|
||||
return False, f"检查过程出错,建议重试: {str(e)}", False
|
||||
228
src/chat/brain_chat/PFC/reply_generator.py
Normal file
228
src/chat/brain_chat/PFC/reply_generator.py
Normal file
@@ -0,0 +1,228 @@
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from src.common.logger import get_module_logger
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
from .chat_observer import ChatObserver
|
||||
from .reply_checker import ReplyChecker
|
||||
from src.individuality.individuality import Individuality
|
||||
from .observation_info import ObservationInfo
|
||||
from .conversation_info import ConversationInfo
|
||||
from src.plugins.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
logger = get_module_logger("reply_generator")
|
||||
|
||||
# --- 定义 Prompt 模板 ---
|
||||
|
||||
# Prompt for direct_reply (首次回复)
|
||||
PROMPT_DIRECT_REPLY = """{persona_text}。现在你在参与一场QQ私聊,请根据以下信息生成一条回复:
|
||||
|
||||
当前对话目标:{goals_str}
|
||||
|
||||
{knowledge_info_str}
|
||||
|
||||
最近的聊天记录:
|
||||
{chat_history_text}
|
||||
|
||||
|
||||
请根据上述信息,结合聊天记录,回复对方。该回复应该:
|
||||
1. 符合对话目标,以"你"的角度发言(不要自己与自己对话!)
|
||||
2. 符合你的性格特征和身份细节
|
||||
3. 通俗易懂,自然流畅,像正常聊天一样,简短(通常20字以内,除非特殊情况)
|
||||
4. 可以适当利用相关知识,但不要生硬引用
|
||||
5. 自然、得体,结合聊天记录逻辑合理,且没有重复表达同质内容
|
||||
|
||||
请注意把握聊天内容,不要回复的太有条理,可以有个性。请分清"你"和对方说的话,不要把"你"说的话当做对方说的话,这是你自己说的话。
|
||||
可以回复得自然随意自然一些,就像真人一样,注意把握聊天内容,整体风格可以平和、简短,不要刻意突出自身学科背景,不要说你说过的话,可以简短,多简短都可以,但是避免冗长。
|
||||
请你注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。
|
||||
|
||||
请直接输出回复内容,不需要任何额外格式。"""
|
||||
|
||||
# Prompt for send_new_message (追问/补充)
|
||||
PROMPT_SEND_NEW_MESSAGE = """{persona_text}。现在你在参与一场QQ私聊,**刚刚你已经发送了一条或多条消息**,现在请根据以下信息再发一条新消息:
|
||||
|
||||
当前对话目标:{goals_str}
|
||||
|
||||
{knowledge_info_str}
|
||||
|
||||
最近的聊天记录:
|
||||
{chat_history_text}
|
||||
|
||||
|
||||
请根据上述信息,结合聊天记录,继续发一条新消息(例如对之前消息的补充,深入话题,或追问等等)。该消息应该:
|
||||
1. 符合对话目标,以"你"的角度发言(不要自己与自己对话!)
|
||||
2. 符合你的性格特征和身份细节
|
||||
3. 通俗易懂,自然流畅,像正常聊天一样,简短(通常20字以内,除非特殊情况)
|
||||
4. 可以适当利用相关知识,但不要生硬引用
|
||||
5. 跟之前你发的消息自然的衔接,逻辑合理,且没有重复表达同质内容或部分重叠内容
|
||||
|
||||
请注意把握聊天内容,不用太有条理,可以有个性。请分清"你"和对方说的话,不要把"你"说的话当做对方说的话,这是你自己说的话。
|
||||
这条消息可以自然随意自然一些,就像真人一样,注意把握聊天内容,整体风格可以平和、简短,不要刻意突出自身学科背景,不要说你说过的话,可以简短,多简短都可以,但是避免冗长。
|
||||
请你注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出消息内容。
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。
|
||||
|
||||
请直接输出回复内容,不需要任何额外格式。"""
|
||||
|
||||
# Prompt for say_goodbye (告别语生成)
|
||||
PROMPT_FAREWELL = """{persona_text}。你在参与一场 QQ 私聊,现在对话似乎已经结束,你决定再发一条最后的消息来圆满结束。
|
||||
|
||||
最近的聊天记录:
|
||||
{chat_history_text}
|
||||
|
||||
请根据上述信息,结合聊天记录,构思一条**简短、自然、符合你人设**的最后的消息。
|
||||
这条消息应该:
|
||||
1. 从你自己的角度发言。
|
||||
2. 符合你的性格特征和身份细节。
|
||||
3. 通俗易懂,自然流畅,通常很简短。
|
||||
4. 自然地为这场对话画上句号,避免开启新话题或显得冗长、刻意。
|
||||
|
||||
请像真人一样随意自然,**简洁是关键**。
|
||||
不要输出多余内容(包括前后缀、冒号、引号、括号、表情包、at或@等)。
|
||||
|
||||
请直接输出最终的告别消息内容,不需要任何额外格式。"""
|
||||
|
||||
|
||||
class ReplyGenerator:
|
||||
"""回复生成器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model=global_config.llm_PFC_chat,
|
||||
temperature=global_config.llm_PFC_chat["temp"],
|
||||
max_tokens=300,
|
||||
request_type="reply_generation",
|
||||
)
|
||||
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
|
||||
self.name = global_config.BOT_NICKNAME
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
self.reply_checker = ReplyChecker(stream_id, private_name)
|
||||
|
||||
# 修改 generate 方法签名,增加 action_type 参数
|
||||
async def generate(
|
||||
self, observation_info: ObservationInfo, conversation_info: ConversationInfo, action_type: str
|
||||
) -> str:
|
||||
"""生成回复
|
||||
|
||||
Args:
|
||||
observation_info: 观察信息
|
||||
conversation_info: 对话信息
|
||||
action_type: 当前执行的动作类型 ('direct_reply' 或 'send_new_message')
|
||||
|
||||
Returns:
|
||||
str: 生成的回复
|
||||
"""
|
||||
# 构建提示词
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]开始生成回复 (动作类型: {action_type}):当前目标: {conversation_info.goal_list}"
|
||||
)
|
||||
|
||||
# --- 构建通用 Prompt 参数 ---
|
||||
# (这部分逻辑基本不变)
|
||||
|
||||
# 构建对话目标 (goals_str)
|
||||
goals_str = ""
|
||||
if conversation_info.goal_list:
|
||||
for goal_reason in conversation_info.goal_list:
|
||||
if isinstance(goal_reason, dict):
|
||||
goal = goal_reason.get("goal", "目标内容缺失")
|
||||
reasoning = goal_reason.get("reasoning", "没有明确原因")
|
||||
else:
|
||||
goal = str(goal_reason)
|
||||
reasoning = "没有明确原因"
|
||||
|
||||
goal = str(goal) if goal is not None else "目标内容缺失"
|
||||
reasoning = str(reasoning) if reasoning is not None else "没有明确原因"
|
||||
goals_str += f"- 目标:{goal}\n 原因:{reasoning}\n"
|
||||
else:
|
||||
goals_str = "- 目前没有明确对话目标\n" # 简化无目标情况
|
||||
|
||||
# --- 新增:构建知识信息字符串 ---
|
||||
knowledge_info_str = "【供参考的相关知识和记忆】\n" # 稍微改下标题,表明是供参考
|
||||
try:
|
||||
# 检查 conversation_info 是否有 knowledge_list 并且不为空
|
||||
if hasattr(conversation_info, "knowledge_list") and conversation_info.knowledge_list:
|
||||
# 最多只显示最近的 5 条知识
|
||||
recent_knowledge = conversation_info.knowledge_list[-5:]
|
||||
for i, knowledge_item in enumerate(recent_knowledge):
|
||||
if isinstance(knowledge_item, dict):
|
||||
query = knowledge_item.get("query", "未知查询")
|
||||
knowledge = knowledge_item.get("knowledge", "无知识内容")
|
||||
source = knowledge_item.get("source", "未知来源")
|
||||
# 只取知识内容的前 2000 个字
|
||||
knowledge_snippet = knowledge[:2000] + "..." if len(knowledge) > 2000 else knowledge
|
||||
knowledge_info_str += (
|
||||
f"{i + 1}. 关于 '{query}' (来源: {source}): {knowledge_snippet}\n" # 格式微调,更简洁
|
||||
)
|
||||
else:
|
||||
knowledge_info_str += f"{i + 1}. 发现一条格式不正确的知识记录。\n"
|
||||
|
||||
if not recent_knowledge:
|
||||
knowledge_info_str += "- 暂无。\n" # 更简洁的提示
|
||||
|
||||
else:
|
||||
knowledge_info_str += "- 暂无。\n"
|
||||
except AttributeError:
|
||||
logger.warning(f"[私聊][{self.private_name}]ConversationInfo 对象可能缺少 knowledge_list 属性。")
|
||||
knowledge_info_str += "- 获取知识列表时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建知识信息字符串时出错: {e}")
|
||||
knowledge_info_str += "- 处理知识列表时出错。\n"
|
||||
|
||||
# 获取聊天历史记录 (chat_history_text)
|
||||
chat_history_text = observation_info.chat_history_str
|
||||
if observation_info.new_messages_count > 0 and observation_info.unprocessed_messages:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
new_messages_str = await build_readable_messages(
|
||||
new_messages_list,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
chat_history_text += f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
|
||||
elif not chat_history_text:
|
||||
chat_history_text = "还没有聊天记录。"
|
||||
|
||||
# 构建 Persona 文本 (persona_text)
|
||||
persona_text = f"你的名字是{self.name},{self.personality_info}。"
|
||||
|
||||
# --- 选择 Prompt ---
|
||||
if action_type == "send_new_message":
|
||||
prompt_template = PROMPT_SEND_NEW_MESSAGE
|
||||
logger.info(f"[私聊][{self.private_name}]使用 PROMPT_SEND_NEW_MESSAGE (追问生成)")
|
||||
elif action_type == "say_goodbye": # 处理告别动作
|
||||
prompt_template = PROMPT_FAREWELL
|
||||
logger.info(f"[私聊][{self.private_name}]使用 PROMPT_FAREWELL (告别语生成)")
|
||||
else: # 默认使用 direct_reply 的 prompt (包括 'direct_reply' 或其他未明确处理的类型)
|
||||
prompt_template = PROMPT_DIRECT_REPLY
|
||||
logger.info(f"[私聊][{self.private_name}]使用 PROMPT_DIRECT_REPLY (首次/非连续回复生成)")
|
||||
|
||||
# --- 格式化最终的 Prompt ---
|
||||
prompt = prompt_template.format(
|
||||
persona_text=persona_text,
|
||||
goals_str=goals_str,
|
||||
chat_history_text=chat_history_text,
|
||||
knowledge_info_str=knowledge_info_str,
|
||||
)
|
||||
|
||||
# --- 调用 LLM 生成 ---
|
||||
logger.debug(f"[私聊][{self.private_name}]发送到LLM的生成提示词:\n------\n{prompt}\n------")
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"[私聊][{self.private_name}]生成的回复: {content}")
|
||||
# 移除旧的检查新消息逻辑,这应该由 conversation 控制流处理
|
||||
return content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]生成回复时出错: {e}")
|
||||
return "抱歉,我现在有点混乱,让我重新思考一下..."
|
||||
|
||||
# check_reply 方法保持不变
|
||||
async def check_reply(
|
||||
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_str: str, retry_count: int = 0
|
||||
) -> Tuple[bool, str, bool]:
|
||||
"""检查回复是否合适
|
||||
(此方法逻辑保持不变)
|
||||
"""
|
||||
return await self.reply_checker.check(reply, goal, chat_history, chat_history_str, retry_count)
|
||||
79
src/chat/brain_chat/PFC/waiter.py
Normal file
79
src/chat/brain_chat/PFC/waiter.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from src.common.logger import get_module_logger
|
||||
from .chat_observer import ChatObserver
|
||||
from .conversation_info import ConversationInfo
|
||||
|
||||
# from src.individuality.individuality import Individuality # 不再需要
|
||||
from ...config.config import global_config
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
logger = get_module_logger("waiter")
|
||||
|
||||
# --- 在这里设定你想要的超时时间(秒) ---
|
||||
# 例如: 120 秒 = 2 分钟
|
||||
DESIRED_TIMEOUT_SECONDS = 300
|
||||
|
||||
|
||||
class Waiter:
|
||||
"""等待处理类"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
self.name = global_config.BOT_NICKNAME
|
||||
self.private_name = private_name
|
||||
# self.wait_accumulated_time = 0 # 不再需要累加计时
|
||||
|
||||
async def wait(self, conversation_info: ConversationInfo) -> bool:
|
||||
"""等待用户新消息或超时"""
|
||||
wait_start_time = time.time()
|
||||
logger.info(f"[私聊][{self.private_name}]进入常规等待状态 (超时: {DESIRED_TIMEOUT_SECONDS} 秒)...")
|
||||
|
||||
while True:
|
||||
# 检查是否有新消息
|
||||
if self.chat_observer.new_message_after(wait_start_time):
|
||||
logger.info(f"[私聊][{self.private_name}]等待结束,收到新消息")
|
||||
return False # 返回 False 表示不是超时
|
||||
|
||||
# 检查是否超时
|
||||
elapsed_time = time.time() - wait_start_time
|
||||
if elapsed_time > DESIRED_TIMEOUT_SECONDS:
|
||||
logger.info(f"[私聊][{self.private_name}]等待超过 {DESIRED_TIMEOUT_SECONDS} 秒...添加思考目标。")
|
||||
wait_goal = {
|
||||
"goal": f"你等待了{elapsed_time / 60:.1f}分钟,注意可能在对方看来聊天已经结束,思考接下来要做什么",
|
||||
"reasoning": "对方很久没有回复你的消息了",
|
||||
}
|
||||
conversation_info.goal_list.append(wait_goal)
|
||||
logger.info(f"[私聊][{self.private_name}]添加目标: {wait_goal}")
|
||||
return True # 返回 True 表示超时
|
||||
|
||||
await asyncio.sleep(5) # 每 5 秒检查一次
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]等待中..."
|
||||
) # 可以考虑把这个频繁日志注释掉,只在超时或收到消息时输出
|
||||
|
||||
async def wait_listening(self, conversation_info: ConversationInfo) -> bool:
|
||||
"""倾听用户发言或超时"""
|
||||
wait_start_time = time.time()
|
||||
logger.info(f"[私聊][{self.private_name}]进入倾听等待状态 (超时: {DESIRED_TIMEOUT_SECONDS} 秒)...")
|
||||
|
||||
while True:
|
||||
# 检查是否有新消息
|
||||
if self.chat_observer.new_message_after(wait_start_time):
|
||||
logger.info(f"[私聊][{self.private_name}]倾听等待结束,收到新消息")
|
||||
return False # 返回 False 表示不是超时
|
||||
|
||||
# 检查是否超时
|
||||
elapsed_time = time.time() - wait_start_time
|
||||
if elapsed_time > DESIRED_TIMEOUT_SECONDS:
|
||||
logger.info(f"[私聊][{self.private_name}]倾听等待超过 {DESIRED_TIMEOUT_SECONDS} 秒...添加思考目标。")
|
||||
wait_goal = {
|
||||
# 保持 goal 文本一致
|
||||
"goal": f"你等待了{elapsed_time / 60:.1f}分钟,对方似乎话说一半突然消失了,可能忙去了?也可能忘记了回复?要问问吗?还是结束对话?或继续等待?思考接下来要做什么",
|
||||
"reasoning": "对方话说一半消失了,很久没有回复",
|
||||
}
|
||||
conversation_info.goal_list.append(wait_goal)
|
||||
logger.info(f"[私聊][{self.private_name}]添加目标: {wait_goal}")
|
||||
return True # 返回 True 表示超时
|
||||
|
||||
await asyncio.sleep(5) # 每 5 秒检查一次
|
||||
logger.debug(f"[私聊][{self.private_name}]倾听等待中...") # 同上,可以考虑注释掉
|
||||
@@ -16,7 +16,8 @@ from src.chat.brain_chat.brain_planner import BrainPlanner
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
||||
from src.express.expression_learner import expression_learner_manager
|
||||
from src.bw_learner.expression_learner import expression_learner_manager
|
||||
from src.bw_learner.message_recorder import extract_and_distribute_messages
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||
from src.plugin_system.core import events_manager
|
||||
@@ -96,6 +97,9 @@ class BrainChatting:
|
||||
|
||||
self.more_plan = False
|
||||
|
||||
# 最近一次是否成功进行了 reply,用于选择 BrainPlanner 的 Prompt
|
||||
self._last_successful_reply: bool = False
|
||||
|
||||
async def start(self):
|
||||
"""检查是否需要启动主循环,如果未激活则启动。"""
|
||||
|
||||
@@ -157,6 +161,7 @@ class BrainChatting:
|
||||
)
|
||||
|
||||
async def _loopbody(self): # sourcery skip: hoist-if-from-if
|
||||
# 获取最新消息(用于上下文,但不影响是否调用 observe)
|
||||
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
start_time=self.last_read_time,
|
||||
@@ -165,17 +170,25 @@ class BrainChatting:
|
||||
limit_mode="latest",
|
||||
filter_mai=True,
|
||||
filter_command=False,
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
# 如果有新消息,更新 last_read_time
|
||||
if len(recent_messages_list) >= 1:
|
||||
self.last_read_time = time.time()
|
||||
await self._observe(recent_messages_list=recent_messages_list)
|
||||
|
||||
else:
|
||||
# Normal模式:消息数量不足,等待
|
||||
await asyncio.sleep(0.2)
|
||||
return True
|
||||
# 总是执行一次思考迭代(不管有没有新消息)
|
||||
# wait 动作会在其内部等待,不需要在这里处理
|
||||
should_continue = await self._observe(recent_messages_list=recent_messages_list)
|
||||
|
||||
if not should_continue:
|
||||
# 选择了 complete_talk,返回 False 表示需要等待新消息
|
||||
return False
|
||||
|
||||
# 继续下一次迭代(除非选择了 complete_talk)
|
||||
# 短暂等待后再继续,避免过于频繁的循环
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
return True
|
||||
|
||||
async def _send_and_store_reply(
|
||||
@@ -240,7 +253,7 @@ class BrainChatting:
|
||||
# ReflectTracker Check
|
||||
# 在每次回复前检查一次上下文,看是否有反思问题得到了解答
|
||||
# -------------------------------------------------------------------------
|
||||
from src.express.reflect_tracker import reflect_tracker_manager
|
||||
from src.bw_learner.reflect_tracker import reflect_tracker_manager
|
||||
|
||||
tracker = reflect_tracker_manager.get_tracker(self.stream_id)
|
||||
if tracker:
|
||||
@@ -253,13 +266,15 @@ class BrainChatting:
|
||||
# Expression Reflection Check
|
||||
# 检查是否需要提问表达反思
|
||||
# -------------------------------------------------------------------------
|
||||
from src.express.expression_reflector import expression_reflector_manager
|
||||
from src.bw_learner.expression_reflector import expression_reflector_manager
|
||||
|
||||
reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id)
|
||||
asyncio.create_task(reflector.check_and_ask())
|
||||
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
|
||||
# 通过 MessageRecorder 统一提取消息并分发给 expression_learner 和 jargon_miner
|
||||
# 在 replyer 执行时触发,统一管理时间窗口,避免重复获取消息
|
||||
asyncio.create_task(extract_and_distribute_messages(self.stream_id))
|
||||
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||
@@ -272,14 +287,16 @@ class BrainChatting:
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
|
||||
|
||||
# 执行planner
|
||||
# 获取必要信息
|
||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||
|
||||
# 一次思考迭代:Think - Act - Observe
|
||||
# 获取聊天上下文
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
@@ -290,12 +307,11 @@ class BrainChatting:
|
||||
)
|
||||
|
||||
prompt_info = await self.action_planner.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=available_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
interest=global_config.personality.interest,
|
||||
prompt_key="brain_planner_prompt_react",
|
||||
)
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
|
||||
@@ -311,7 +327,10 @@ class BrainChatting:
|
||||
available_actions=available_actions,
|
||||
)
|
||||
|
||||
# 3. 并行执行所有动作
|
||||
# 检查是否有 complete_talk 动作(会停止后续迭代)
|
||||
has_complete_talk = any(action.action_type == "complete_talk" for action in action_to_use_info)
|
||||
|
||||
# 并行执行所有动作
|
||||
action_tasks = [
|
||||
asyncio.create_task(
|
||||
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
|
||||
@@ -343,7 +362,14 @@ class BrainChatting:
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 回复动作执行失败")
|
||||
|
||||
# 构建最终的循环信息
|
||||
# 更新观察时间标记
|
||||
self.action_planner.last_obs_time_mark = time.time()
|
||||
|
||||
# 如果选择了 complete_talk,标记为完成,不再继续迭代
|
||||
if has_complete_talk:
|
||||
logger.info(f"{self.log_prefix} 检测到 complete_talk 动作,本次思考完成")
|
||||
|
||||
# 构建循环信息
|
||||
if reply_loop_info:
|
||||
# 如果有回复信息,使用回复的loop_info作为基础
|
||||
loop_info = reply_loop_info
|
||||
@@ -369,10 +395,16 @@ class BrainChatting:
|
||||
}
|
||||
_reply_text = action_reply_text
|
||||
|
||||
# 如果选择了 complete_talk,返回 False 以停止 _loopbody 的循环
|
||||
# 否则返回 True,让 _loopbody 继续下一次迭代
|
||||
should_continue = not has_complete_talk
|
||||
|
||||
self.end_cycle(loop_info, cycle_timers)
|
||||
self.print_cycle_info(cycle_timers)
|
||||
|
||||
return True
|
||||
# 如果选择了 complete_talk,返回 False 停止循环
|
||||
# 否则返回 True,继续下一次思考迭代
|
||||
return should_continue
|
||||
|
||||
async def _main_chat_loop(self):
|
||||
"""主循环,持续进行计划并可能回复消息,直到被外部取消。"""
|
||||
@@ -380,9 +412,13 @@ class BrainChatting:
|
||||
while self.running:
|
||||
# 主循环
|
||||
success = await self._loopbody()
|
||||
await asyncio.sleep(0.1)
|
||||
if not success:
|
||||
break
|
||||
# 选择了 complete,等待新消息
|
||||
logger.info(f"{self.log_prefix} 选择了 complete,等待新消息...")
|
||||
await self._wait_for_new_message()
|
||||
# 有新消息后继续循环
|
||||
continue
|
||||
await asyncio.sleep(0.1)
|
||||
except asyncio.CancelledError:
|
||||
# 设置了关闭标志位后被取消是正常流程
|
||||
logger.info(f"{self.log_prefix} 麦麦已关闭聊天")
|
||||
@@ -393,6 +429,33 @@ class BrainChatting:
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
logger.error(f"{self.log_prefix} 结束了当前聊天循环")
|
||||
|
||||
async def _wait_for_new_message(self):
|
||||
"""等待新消息到达"""
|
||||
last_check_time = self.last_read_time
|
||||
check_interval = 1.0 # 每秒检查一次
|
||||
|
||||
while self.running:
|
||||
# 检查是否有新消息
|
||||
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
start_time=last_check_time,
|
||||
end_time=time.time(),
|
||||
limit=20,
|
||||
limit_mode="latest",
|
||||
filter_mai=True,
|
||||
filter_command=False,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
# 如果有新消息,更新 last_read_time 并返回
|
||||
if len(recent_messages_list) >= 1:
|
||||
self.last_read_time = time.time()
|
||||
logger.info(f"{self.log_prefix} 检测到新消息,恢复循环")
|
||||
return
|
||||
|
||||
# 等待一段时间后再次检查
|
||||
await asyncio.sleep(check_interval)
|
||||
|
||||
async def _handle_action(
|
||||
self,
|
||||
action: str,
|
||||
@@ -506,12 +569,12 @@ class BrainChatting:
|
||||
"""执行单个动作的通用函数"""
|
||||
try:
|
||||
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
|
||||
if action_planner_info.action_type == "no_reply":
|
||||
# 直接处理no_reply逻辑,不再通过动作系统
|
||||
reason = action_planner_info.reasoning or "选择不回复"
|
||||
# logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
||||
if action_planner_info.action_type == "complete_talk":
|
||||
# 直接处理complete_talk逻辑,不再通过动作系统
|
||||
reason = action_planner_info.reasoning or "选择完成对话"
|
||||
logger.info(f"{self.log_prefix} 选择完成对话,原因: {reason}")
|
||||
|
||||
# 存储no_reply信息到数据库
|
||||
# 存储complete_talk信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
@@ -519,18 +582,33 @@ class BrainChatting:
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason},
|
||||
action_name="no_reply",
|
||||
action_name="complete_talk",
|
||||
)
|
||||
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
|
||||
return {"action_type": "complete_talk", "success": True, "reply_text": "", "command": ""}
|
||||
|
||||
elif action_planner_info.action_type == "reply":
|
||||
try:
|
||||
# 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用)
|
||||
unknown_words = None
|
||||
if isinstance(action_planner_info.action_data, dict):
|
||||
uw = action_planner_info.action_data.get("unknown_words")
|
||||
if isinstance(uw, list):
|
||||
cleaned_uw: List[str] = []
|
||||
for item in uw:
|
||||
if isinstance(item, str):
|
||||
s = item.strip()
|
||||
if s:
|
||||
cleaned_uw.append(s)
|
||||
if cleaned_uw:
|
||||
unknown_words = cleaned_uw
|
||||
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message=action_planner_info.action_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_action_plan_infos,
|
||||
reply_reason=action_planner_info.reasoning or "",
|
||||
unknown_words=unknown_words,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
@@ -543,11 +621,17 @@ class BrainChatting:
|
||||
)
|
||||
else:
|
||||
logger.info("回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"loop_info": None,
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
|
||||
response_set = llm_response.reply_set
|
||||
selected_expressions = llm_response.selected_expressions
|
||||
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||
@@ -558,6 +642,8 @@ class BrainChatting:
|
||||
actions=chosen_action_plan_infos,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
# 标记这次循环已经成功进行了回复
|
||||
self._last_successful_reply = True
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
@@ -567,7 +653,88 @@ class BrainChatting:
|
||||
|
||||
# 其他动作
|
||||
else:
|
||||
# 执行普通动作
|
||||
# 内建 wait / listening:不通过插件系统,直接在这里处理
|
||||
if action_planner_info.action_type in ["wait", "listening"]:
|
||||
reason = action_planner_info.reasoning or ""
|
||||
action_data = action_planner_info.action_data or {}
|
||||
|
||||
if action_planner_info.action_type == "wait":
|
||||
# 获取等待时间(必填)
|
||||
wait_seconds = action_data.get("wait_seconds")
|
||||
if wait_seconds is None:
|
||||
logger.warning(f"{self.log_prefix} wait 动作缺少 wait_seconds 参数,使用默认值 5 秒")
|
||||
wait_seconds = 5
|
||||
else:
|
||||
try:
|
||||
wait_seconds = float(wait_seconds)
|
||||
if wait_seconds < 0:
|
||||
logger.warning(f"{self.log_prefix} wait_seconds 不能为负数,使用默认值 5 秒")
|
||||
wait_seconds = 5
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"{self.log_prefix} wait_seconds 参数格式错误,使用默认值 5 秒")
|
||||
wait_seconds = 5
|
||||
|
||||
logger.info(f"{self.log_prefix} 执行 wait 动作,等待 {wait_seconds} 秒")
|
||||
|
||||
# 记录动作信息
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason or f"等待 {wait_seconds} 秒",
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason, "wait_seconds": wait_seconds},
|
||||
action_name="wait",
|
||||
)
|
||||
|
||||
# 等待指定时间
|
||||
await asyncio.sleep(wait_seconds)
|
||||
|
||||
logger.info(f"{self.log_prefix} wait 动作完成,继续下一次思考")
|
||||
|
||||
# 这些动作本身不产生文本回复
|
||||
self._last_successful_reply = False
|
||||
return {
|
||||
"action_type": "wait",
|
||||
"success": True,
|
||||
"reply_text": "",
|
||||
"command": "",
|
||||
}
|
||||
|
||||
# listening 已合并到 wait,如果遇到则转换为 wait(向后兼容)
|
||||
elif action_planner_info.action_type == "listening":
|
||||
logger.debug(f"{self.log_prefix} 检测到 listening 动作,已合并到 wait,自动转换")
|
||||
# 使用默认等待时间
|
||||
wait_seconds = 3
|
||||
|
||||
logger.info(f"{self.log_prefix} 执行 listening(转换为 wait)动作,等待 {wait_seconds} 秒")
|
||||
|
||||
# 记录动作信息
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason or f"倾听并等待 {wait_seconds} 秒",
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason, "wait_seconds": wait_seconds},
|
||||
action_name="listening",
|
||||
)
|
||||
|
||||
# 等待指定时间
|
||||
await asyncio.sleep(wait_seconds)
|
||||
|
||||
logger.info(f"{self.log_prefix} listening 动作完成,继续下一次思考")
|
||||
|
||||
# 这些动作本身不产生文本回复
|
||||
self._last_successful_reply = False
|
||||
return {
|
||||
"action_type": "listening",
|
||||
"success": True,
|
||||
"reply_text": "",
|
||||
"command": "",
|
||||
}
|
||||
|
||||
# 其余动作:走原有插件 Action 体系
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, reply_text, command = await self._handle_action(
|
||||
action_planner_info.action_type,
|
||||
@@ -577,6 +744,10 @@ class BrainChatting:
|
||||
thinking_id,
|
||||
action_planner_info.action_message,
|
||||
)
|
||||
# 非 reply 类动作执行成功时,清空最近成功回复标记,让下一轮回到 initial Prompt
|
||||
if success and action_planner_info.action_type != "reply":
|
||||
self._last_successful_reply = False
|
||||
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": success,
|
||||
|
||||
@@ -35,12 +35,13 @@ install(extra_lines=3)
|
||||
|
||||
|
||||
def init_prompt():
|
||||
# ReAct 形式的 Planner Prompt
|
||||
Prompt(
|
||||
"""
|
||||
{time_block}
|
||||
{name_block}
|
||||
你的兴趣是:{interest}
|
||||
{chat_context_description},以下是具体的聊天内容
|
||||
|
||||
**聊天内容**
|
||||
{chat_content_block}
|
||||
|
||||
@@ -57,11 +58,35 @@ reply
|
||||
"reason":"回复的原因"
|
||||
}}
|
||||
|
||||
no_reply
|
||||
wait
|
||||
动作描述:
|
||||
等待,保持沉默,等待对方发言
|
||||
暂时不再发言,等待指定时间。适用于以下情况:
|
||||
- 你已经表达清楚一轮,想给对方留出空间
|
||||
- 你感觉对方的话还没说完,或者自己刚刚发了好几条连续消息
|
||||
- 你想要等待一定时间来让对方把话说完,或者等待对方反应
|
||||
- 你想保持安静,专注"听"而不是马上回复
|
||||
请你根据上下文来判断要等待多久,请你灵活判断:
|
||||
- 如果你们交流间隔时间很短,聊的很频繁,不宜等待太久
|
||||
- 如果你们交流间隔时间很长,聊的很少,可以等待较长时间
|
||||
{{
|
||||
"action": "no_reply",
|
||||
"action": "wait",
|
||||
"target_message_id":"想要作为这次等待依据的消息id(通常是对方的最新消息)",
|
||||
"wait_seconds": 等待的秒数(必填,例如:5 表示等待5秒),
|
||||
"reason":"选择等待的原因"
|
||||
}}
|
||||
|
||||
complete_talk
|
||||
动作描述:
|
||||
当前聊天暂时结束了,对方离开,没有更多话题了
|
||||
你可以使用该动作来暂时休息,等待对方有新发言再继续:
|
||||
- 多次wait之后,对方迟迟不回复消息才用
|
||||
- 如果对方只是短暂不回复,应该使用wait而不是complete_talk
|
||||
- 聊天内容显示当前聊天已经结束或者没有新内容时候,选择complete_talk
|
||||
选择此动作后,将不再继续循环思考,直到收到对方的新消息
|
||||
{{
|
||||
"action": "complete_talk",
|
||||
"target_message_id":"触发完成对话的消息id(通常是对方的最新消息)",
|
||||
"reason":"选择完成对话的原因"
|
||||
}}
|
||||
|
||||
{action_options_text}
|
||||
@@ -92,7 +117,7 @@ no_reply
|
||||
```
|
||||
|
||||
""",
|
||||
"brain_planner_prompt",
|
||||
"brain_planner_prompt_react",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
@@ -123,6 +148,9 @@ class BrainPlanner:
|
||||
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
# 计划日志记录
|
||||
self.plan_log: List[Tuple[str, float, List[ActionPlannerInfo]]] = []
|
||||
|
||||
def find_message_by_id(
|
||||
self, message_id: str, message_id_list: List[Tuple[str, "DatabaseMessages"]]
|
||||
) -> Optional["DatabaseMessages"]:
|
||||
@@ -152,10 +180,11 @@ class BrainPlanner:
|
||||
action_planner_infos = []
|
||||
|
||||
try:
|
||||
action = action_json.get("action", "no_reply")
|
||||
action = action_json.get("action", "complete_talk")
|
||||
logger.debug(f"{self.log_prefix}解析动作JSON: action={action}, json={action_json}")
|
||||
reasoning = action_json.get("reason", "未提供原因")
|
||||
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
|
||||
# 非no_reply动作需要target_message_id
|
||||
# 非complete_talk动作需要target_message_id
|
||||
target_message = None
|
||||
|
||||
if target_message_id := action_json.get("target_message_id"):
|
||||
@@ -171,16 +200,28 @@ class BrainPlanner:
|
||||
|
||||
# 验证action是否可用
|
||||
available_action_names = [action_name for action_name, _ in current_available_actions]
|
||||
internal_action_names = ["no_reply", "reply", "wait_time"]
|
||||
# 内部保留动作(不依赖插件系统)
|
||||
# 注意:listening 已合并到 wait 中,如果遇到 listening 则转换为 wait
|
||||
internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"]
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}"
|
||||
)
|
||||
|
||||
# 将 listening 转换为 wait(向后兼容)
|
||||
if action == "listening":
|
||||
logger.debug(f"{self.log_prefix}检测到 listening 动作,已合并到 wait,自动转换")
|
||||
action = "wait"
|
||||
|
||||
if action not in internal_action_names and action not in available_action_names:
|
||||
logger.warning(
|
||||
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {available_action_names}),将强制使用 'no_reply'"
|
||||
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (内部动作: {internal_action_names}, 可用插件动作: {available_action_names}),将强制使用 'complete_talk'"
|
||||
)
|
||||
reasoning = (
|
||||
f"LLM 返回了当前不可用的动作 '{action}' (可用: {available_action_names})。原始理由: {reasoning}"
|
||||
)
|
||||
action = "no_reply"
|
||||
action = "complete_talk"
|
||||
logger.warning(f"{self.log_prefix}动作已转换为 complete_talk")
|
||||
|
||||
# 创建ActionPlannerInfo对象
|
||||
# 将列表转换为字典格式
|
||||
@@ -201,7 +242,7 @@ class BrainPlanner:
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
action_planner_infos.append(
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
action_type="complete_talk",
|
||||
reasoning=f"解析单个action时出错: {e}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
@@ -218,7 +259,7 @@ class BrainPlanner:
|
||||
) -> List[ActionPlannerInfo]:
|
||||
# sourcery skip: use-named-expression
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作(ReAct模式)。
|
||||
"""
|
||||
|
||||
# 获取聊天上下文
|
||||
@@ -226,7 +267,7 @@ class BrainPlanner:
|
||||
chat_id=self.chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
message_id_list: list[Tuple[str, "DatabaseMessages"]] = []
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
@@ -257,18 +298,19 @@ class BrainPlanner:
|
||||
|
||||
logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作")
|
||||
|
||||
# 构建包含所有动作的提示词
|
||||
# 构建包含所有动作的提示词:使用统一的 ReAct Prompt
|
||||
prompt_key = "brain_planner_prompt_react"
|
||||
# 这里不记录日志,避免重复打印,由调用方按需控制 log_prompt
|
||||
prompt, message_id_list = await self.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=filtered_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
interest=global_config.personality.interest,
|
||||
prompt_key=prompt_key,
|
||||
)
|
||||
|
||||
# 调用LLM获取决策
|
||||
actions = await self._execute_main_planner(
|
||||
reasoning, actions = await self._execute_main_planner(
|
||||
prompt=prompt,
|
||||
message_id_list=message_id_list,
|
||||
filtered_actions=filtered_actions,
|
||||
@@ -276,16 +318,22 @@ class BrainPlanner:
|
||||
loop_start_time=loop_start_time,
|
||||
)
|
||||
|
||||
# 记录和展示计划日志
|
||||
logger.info(
|
||||
f"{self.log_prefix}Planner: {reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
|
||||
)
|
||||
self.add_plan_log(reasoning, actions)
|
||||
|
||||
return actions
|
||||
|
||||
async def build_planner_prompt(
|
||||
self,
|
||||
is_group_chat: bool,
|
||||
chat_target_info: Optional["TargetPersonInfo"],
|
||||
current_available_actions: Dict[str, ActionInfo],
|
||||
message_id_list: List[Tuple[str, "DatabaseMessages"]],
|
||||
chat_content_block: str = "",
|
||||
interest: str = "",
|
||||
prompt_key: str = "brain_planner_prompt_react",
|
||||
) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]:
|
||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||
try:
|
||||
@@ -321,7 +369,7 @@ class BrainPlanner:
|
||||
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
|
||||
|
||||
# 获取主规划器模板并填充
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("brain_planner_prompt")
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async(prompt_key)
|
||||
prompt = planner_prompt_template.format(
|
||||
time_block=time_block,
|
||||
chat_context_description=chat_context_description,
|
||||
@@ -431,17 +479,18 @@ class BrainPlanner:
|
||||
filtered_actions: Dict[str, ActionInfo],
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float,
|
||||
) -> List[ActionPlannerInfo]:
|
||||
) -> Tuple[str, List[ActionPlannerInfo]]:
|
||||
"""执行主规划器"""
|
||||
llm_content = None
|
||||
actions: List[ActionPlannerInfo] = []
|
||||
extracted_reasoning = ""
|
||||
|
||||
try:
|
||||
# 调用LLM
|
||||
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
# logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
|
||||
if global_config.debug.show_planner_prompt:
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
@@ -456,10 +505,11 @@ class BrainPlanner:
|
||||
|
||||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
return [
|
||||
extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
|
||||
return extracted_reasoning, [
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
action_type="complete_talk",
|
||||
reasoning=extracted_reasoning,
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
@@ -469,24 +519,32 @@ class BrainPlanner:
|
||||
# 解析LLM响应
|
||||
if llm_content:
|
||||
try:
|
||||
if json_objects := self._extract_json_from_markdown(llm_content):
|
||||
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
||||
json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content)
|
||||
if json_objects:
|
||||
logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
||||
for i, json_obj in enumerate(json_objects):
|
||||
logger.info(f"{self.log_prefix}解析第{i + 1}个JSON对象: {json_obj}")
|
||||
filtered_actions_list = list(filtered_actions.items())
|
||||
for json_obj in json_objects:
|
||||
actions.extend(self._parse_single_action(json_obj, message_id_list, filtered_actions_list))
|
||||
parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list)
|
||||
logger.info(f"{self.log_prefix}解析后的动作: {[a.action_type for a in parsed_actions]}")
|
||||
actions.extend(parsed_actions)
|
||||
else:
|
||||
# 尝试解析为直接的JSON
|
||||
logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}")
|
||||
actions = self._create_no_reply("LLM没有返回可用动作", available_actions)
|
||||
extracted_reasoning = extracted_reasoning or "LLM没有返回可用动作"
|
||||
actions = self._create_complete_talk(extracted_reasoning, available_actions)
|
||||
|
||||
except Exception as json_e:
|
||||
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
|
||||
actions = self._create_no_reply(f"解析LLM响应JSON失败: {json_e}", available_actions)
|
||||
extracted_reasoning = f"解析LLM响应JSON失败: {json_e}"
|
||||
actions = self._create_complete_talk(extracted_reasoning, available_actions)
|
||||
traceback.print_exc()
|
||||
else:
|
||||
actions = self._create_no_reply("规划器没有获得LLM响应", available_actions)
|
||||
extracted_reasoning = "规划器没有获得LLM响应"
|
||||
actions = self._create_complete_talk(extracted_reasoning, available_actions)
|
||||
|
||||
# 添加循环开始时间到所有非no_reply动作
|
||||
# 添加循环开始时间到所有动作
|
||||
for action in actions:
|
||||
action.action_data = action.action_data or {}
|
||||
action.action_data["loop_start_time"] = loop_start_time
|
||||
@@ -495,13 +553,15 @@ class BrainPlanner:
|
||||
f"{self.log_prefix}规划器决定执行{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
|
||||
)
|
||||
|
||||
return actions
|
||||
return extracted_reasoning, actions
|
||||
|
||||
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
|
||||
"""创建no_reply"""
|
||||
def _create_complete_talk(
|
||||
self, reasoning: str, available_actions: Dict[str, ActionInfo]
|
||||
) -> List[ActionPlannerInfo]:
|
||||
"""创建complete_talk"""
|
||||
return [
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
action_type="complete_talk",
|
||||
reasoning=reasoning,
|
||||
action_data={},
|
||||
action_message=None,
|
||||
@@ -509,33 +569,122 @@ class BrainPlanner:
|
||||
)
|
||||
]
|
||||
|
||||
def _extract_json_from_markdown(self, content: str) -> List[dict]:
|
||||
def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]):
|
||||
"""添加计划日志"""
|
||||
self.plan_log.append((reasoning, time.time(), actions))
|
||||
if len(self.plan_log) > 20:
|
||||
self.plan_log.pop(0)
|
||||
|
||||
def _extract_json_from_markdown(self, content: str) -> Tuple[List[dict], str]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""从Markdown格式的内容中提取JSON对象"""
|
||||
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
||||
json_objects = []
|
||||
reasoning_content = ""
|
||||
|
||||
# 使用正则表达式查找```json包裹的JSON内容
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, content, re.DOTALL)
|
||||
markdown_matches = re.findall(json_pattern, content, re.DOTALL)
|
||||
|
||||
for match in matches:
|
||||
# 提取JSON之前的内容作为推理文本
|
||||
first_json_pos = len(content)
|
||||
if markdown_matches:
|
||||
# 找到第一个```json的位置
|
||||
first_json_pos = content.find("```json")
|
||||
if first_json_pos > 0:
|
||||
reasoning_content = content[:first_json_pos].strip()
|
||||
# 清理推理内容中的注释标记
|
||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||
reasoning_content = reasoning_content.strip()
|
||||
|
||||
# 处理```json包裹的JSON
|
||||
for match in markdown_matches:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
json_str = re.sub(r"//.*?\n", "\n", match) # 移除单行注释
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
|
||||
if json_str := json_str.strip():
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
# 先尝试将整个块作为一个JSON对象或数组(适用于多行JSON)
|
||||
try:
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
# 如果整个块解析失败,尝试按行分割(适用于多个单行JSON对象)
|
||||
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
||||
for line in lines:
|
||||
try:
|
||||
# 尝试解析每一行作为独立的JSON对象
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
# 单行解析失败,继续下一行
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...")
|
||||
logger.warning(f"{self.log_prefix}解析JSON块失败: {e}, 块内容: {match[:100]}...")
|
||||
continue
|
||||
|
||||
return json_objects
|
||||
# 如果没有找到完整的```json```块,尝试查找不完整的代码块(缺少结尾```)
|
||||
if not json_objects:
|
||||
json_start_pos = content.find("```json")
|
||||
if json_start_pos != -1:
|
||||
# 找到```json之后的内容
|
||||
json_content_start = json_start_pos + 7 # ```json的长度
|
||||
# 提取从```json之后到内容结尾的所有内容
|
||||
incomplete_json_str = content[json_content_start:].strip()
|
||||
|
||||
# 提取JSON之前的内容作为推理文本
|
||||
if json_start_pos > 0:
|
||||
reasoning_content = content[:json_start_pos].strip()
|
||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||
reasoning_content = reasoning_content.strip()
|
||||
|
||||
if incomplete_json_str:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
json_str = re.sub(r"//.*?\n", "\n", incomplete_json_str)
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
|
||||
json_str = json_str.strip()
|
||||
|
||||
if json_str:
|
||||
# 尝试按行分割,每行可能是一个JSON对象
|
||||
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
||||
for line in lines:
|
||||
try:
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 如果按行解析没有成功,尝试将整个块作为一个JSON对象或数组
|
||||
if not json_objects:
|
||||
try:
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except Exception as e:
|
||||
logger.debug(f"尝试解析不完整的JSON代码块失败: {e}")
|
||||
except Exception as e:
|
||||
logger.debug(f"处理不完整的JSON代码块时出错: {e}")
|
||||
|
||||
return json_objects, reasoning_content
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
@@ -271,7 +271,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
|
||||
|
||||
emoji.description = emoji_data.description
|
||||
# Deserialize emotion string from DB to list
|
||||
emoji.emotion = emoji_data.emotion.split(",") if emoji_data.emotion else []
|
||||
emoji.emotion = emoji_data.emotion.replace(",", ",").split(",") if emoji_data.emotion else []
|
||||
emoji.usage_count = emoji_data.usage_count
|
||||
|
||||
db_last_used_time = emoji_data.last_used_time
|
||||
@@ -732,7 +732,7 @@ class EmojiManager:
|
||||
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
||||
if emoji_record and emoji_record.emotion:
|
||||
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
|
||||
return emoji_record.emotion.split(",")
|
||||
return emoji_record.emotion.replace(",", ",").split(",")
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库查询表情包情感标签时出错: {e}")
|
||||
|
||||
@@ -993,7 +993,7 @@ class EmojiManager:
|
||||
)
|
||||
|
||||
# 处理情感列表
|
||||
emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]
|
||||
emotions = [e.strip() for e in emotions_text.replace(",", ",").split(",") if e.strip()]
|
||||
|
||||
# 根据情感标签数量随机选择 - 超过5个选3个,超过2个选2个
|
||||
if len(emotions) > 5:
|
||||
|
||||
@@ -1,163 +0,0 @@
|
||||
from datetime import datetime
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
build_readable_messages,
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import frequency_api
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""{name_block}
|
||||
{time_block}
|
||||
你现在正在聊天,请根据下面的聊天记录判断是否有用户觉得你的发言过于频繁或者发言过少
|
||||
{message_str}
|
||||
|
||||
如果用户觉得你的发言过于频繁,请输出"过于频繁",否则输出"正常"
|
||||
如果用户觉得你的发言过少,请输出"过少",否则输出"正常"
|
||||
**你只能输出以下三个词之一,不要输出任何其他文字、解释或标点:**
|
||||
- 正常
|
||||
- 过于频繁
|
||||
- 过少
|
||||
""",
|
||||
"frequency_adjust_prompt",
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger("frequency_control")
|
||||
|
||||
|
||||
class FrequencyControl:
|
||||
"""简化的频率控制类,仅管理不同chat_id的频率值"""
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id = chat_id
|
||||
# 发言频率调整值
|
||||
self.talk_frequency_adjust: float = 1.0
|
||||
|
||||
self.last_frequency_adjust_time: float = 0.0
|
||||
self.frequency_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="frequency.adjust"
|
||||
)
|
||||
# 频率调整锁,防止并发执行
|
||||
self._adjust_lock = asyncio.Lock()
|
||||
|
||||
def get_talk_frequency_adjust(self) -> float:
|
||||
"""获取发言频率调整值"""
|
||||
return self.talk_frequency_adjust
|
||||
|
||||
def set_talk_frequency_adjust(self, value: float) -> None:
|
||||
"""设置发言频率调整值"""
|
||||
self.talk_frequency_adjust = max(0.1, min(5.0, value))
|
||||
|
||||
async def trigger_frequency_adjust(self) -> None:
|
||||
# 使用异步锁防止并发执行
|
||||
async with self._adjust_lock:
|
||||
# 在锁内检查,避免并发触发
|
||||
current_time = time.time()
|
||||
previous_adjust_time = self.last_frequency_adjust_time
|
||||
|
||||
msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=previous_adjust_time,
|
||||
timestamp_end=current_time,
|
||||
)
|
||||
|
||||
if current_time - previous_adjust_time < 160 or len(msg_list) <= 20:
|
||||
return
|
||||
|
||||
# 立即更新调整时间,防止并发触发
|
||||
self.last_frequency_adjust_time = current_time
|
||||
|
||||
try:
|
||||
new_msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=previous_adjust_time,
|
||||
timestamp_end=current_time,
|
||||
limit=20,
|
||||
limit_mode="latest",
|
||||
)
|
||||
|
||||
message_str = build_readable_messages(
|
||||
new_msg_list,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=False,
|
||||
)
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
bot_name = global_config.bot.nickname
|
||||
bot_nickname = (
|
||||
f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
||||
)
|
||||
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"frequency_adjust_prompt",
|
||||
name_block=name_block,
|
||||
time_block=time_block,
|
||||
message_str=message_str,
|
||||
)
|
||||
response, (reasoning_content, _, _) = await self.frequency_model.generate_response_async(
|
||||
prompt,
|
||||
)
|
||||
|
||||
# logger.info(f"频率调整 prompt: {prompt}")
|
||||
# logger.info(f"频率调整 response: {response}")
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"频率调整 prompt: {prompt}")
|
||||
logger.info(f"频率调整 response: {response}")
|
||||
logger.info(f"频率调整 reasoning_content: {reasoning_content}")
|
||||
|
||||
final_value_by_api = frequency_api.get_current_talk_value(self.chat_id)
|
||||
|
||||
# LLM依然输出过多内容时取消本次调整。合法最多4个字,但有的模型可能会输出一些markdown换行符等,需要长度宽限
|
||||
if len(response) < 20:
|
||||
if "过于频繁" in response:
|
||||
logger.info(f"频率调整: 过于频繁,调整值到{final_value_by_api}")
|
||||
self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 0.8))
|
||||
elif "过少" in response:
|
||||
logger.info(f"频率调整: 过少,调整值到{final_value_by_api}")
|
||||
self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 1.2))
|
||||
except Exception as e:
|
||||
logger.error(f"频率调整失败: {e}")
|
||||
# 即使失败也保持时间戳更新,避免频繁重试
|
||||
|
||||
|
||||
class FrequencyControlManager:
|
||||
"""频率控制管理器,管理多个聊天流的频率控制实例"""
|
||||
|
||||
def __init__(self):
|
||||
self.frequency_control_dict: Dict[str, FrequencyControl] = {}
|
||||
|
||||
def get_or_create_frequency_control(self, chat_id: str) -> FrequencyControl:
|
||||
"""获取或创建指定聊天流的频率控制实例"""
|
||||
if chat_id not in self.frequency_control_dict:
|
||||
self.frequency_control_dict[chat_id] = FrequencyControl(chat_id)
|
||||
return self.frequency_control_dict[chat_id]
|
||||
|
||||
def remove_frequency_control(self, chat_id: str) -> bool:
|
||||
"""移除指定聊天流的频率控制实例"""
|
||||
if chat_id in self.frequency_control_dict:
|
||||
del self.frequency_control_dict[chat_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_all_chat_ids(self) -> list[str]:
|
||||
"""获取所有有频率控制的聊天ID"""
|
||||
return list(self.frequency_control_dict.keys())
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
# 创建全局实例
|
||||
frequency_control_manager = FrequencyControlManager()
|
||||
50
src/chat/heart_flow/frequency_control.py
Normal file
50
src/chat/heart_flow/frequency_control.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("frequency_control")
|
||||
|
||||
|
||||
class FrequencyControl:
|
||||
"""简化的频率控制类,仅管理不同chat_id的频率值"""
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id = chat_id
|
||||
# 发言频率调整值
|
||||
self.talk_frequency_adjust: float = 1.0
|
||||
|
||||
def get_talk_frequency_adjust(self) -> float:
|
||||
"""获取发言频率调整值"""
|
||||
return self.talk_frequency_adjust
|
||||
|
||||
def set_talk_frequency_adjust(self, value: float) -> None:
|
||||
"""设置发言频率调整值"""
|
||||
self.talk_frequency_adjust = max(0.1, min(5.0, value))
|
||||
|
||||
|
||||
class FrequencyControlManager:
|
||||
"""频率控制管理器,管理多个聊天流的频率控制实例"""
|
||||
|
||||
def __init__(self):
|
||||
self.frequency_control_dict: Dict[str, FrequencyControl] = {}
|
||||
|
||||
def get_or_create_frequency_control(self, chat_id: str) -> FrequencyControl:
|
||||
"""获取或创建指定聊天流的频率控制实例"""
|
||||
if chat_id not in self.frequency_control_dict:
|
||||
self.frequency_control_dict[chat_id] = FrequencyControl(chat_id)
|
||||
return self.frequency_control_dict[chat_id]
|
||||
|
||||
def remove_frequency_control(self, chat_id: str) -> bool:
|
||||
"""移除指定聊天流的频率控制实例"""
|
||||
if chat_id in self.frequency_control_dict:
|
||||
del self.frequency_control_dict[chat_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_all_chat_ids(self) -> list[str]:
|
||||
"""获取所有有频率控制的聊天ID"""
|
||||
return list(self.frequency_control_dict.keys())
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
frequency_control_manager = FrequencyControlManager()
|
||||
@@ -16,11 +16,11 @@ from src.chat.planner_actions.planner import ActionPlanner
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
||||
from src.express.expression_learner import expression_learner_manager
|
||||
from src.chat.frequency_control.frequency_control import frequency_control_manager
|
||||
from src.express.reflect_tracker import reflect_tracker_manager
|
||||
from src.express.expression_reflector import expression_reflector_manager
|
||||
from src.jargon import extract_and_store_jargon
|
||||
from src.bw_learner.expression_learner import expression_learner_manager
|
||||
from src.chat.heart_flow.frequency_control import frequency_control_manager
|
||||
from src.bw_learner.reflect_tracker import reflect_tracker_manager
|
||||
from src.bw_learner.expression_reflector import expression_reflector_manager
|
||||
from src.bw_learner.message_recorder import extract_and_distribute_messages
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||
from src.plugin_system.core import events_manager
|
||||
@@ -29,6 +29,7 @@ from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
)
|
||||
from src.chat.utils.utils import record_replyer_action_temp
|
||||
from src.hippo_memorizer.chat_history_summarizer import ChatHistorySummarizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -99,7 +100,6 @@ class HeartFChatting:
|
||||
self._current_cycle_detail: CycleDetail = None # type: ignore
|
||||
|
||||
self.last_read_time = time.time() - 2
|
||||
self.no_reply_until_call = False
|
||||
|
||||
self.is_mute = False
|
||||
|
||||
@@ -190,7 +190,7 @@ class HeartFChatting:
|
||||
limit_mode="latest",
|
||||
filter_mai=True,
|
||||
filter_command=False,
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=0,
|
||||
)
|
||||
|
||||
# 根据连续 no_reply 次数动态调整阈值
|
||||
@@ -207,23 +207,6 @@ class HeartFChatting:
|
||||
if len(recent_messages_list) >= threshold:
|
||||
# for message in recent_messages_list:
|
||||
# print(message.processed_plain_text)
|
||||
# !处理no_reply_until_call逻辑
|
||||
if self.no_reply_until_call:
|
||||
for message in recent_messages_list:
|
||||
if (
|
||||
message.is_mentioned
|
||||
or message.is_at
|
||||
or len(recent_messages_list) >= 8
|
||||
or time.time() - self.last_read_time > 600
|
||||
):
|
||||
self.no_reply_until_call = False
|
||||
self.last_read_time = time.time()
|
||||
break
|
||||
# 没有提到,继续保持沉默
|
||||
if self.no_reply_until_call:
|
||||
# logger.info(f"{self.log_prefix} 没有提到,继续保持沉默")
|
||||
await asyncio.sleep(1)
|
||||
return True
|
||||
|
||||
self.last_read_time = time.time()
|
||||
|
||||
@@ -303,90 +286,6 @@ class HeartFChatting:
|
||||
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def _run_planner_without_reply(
|
||||
self,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
cycle_timers: Dict[str, float],
|
||||
) -> List[ActionPlannerInfo]:
|
||||
"""执行planner,但不包含reply动作(用于并行执行场景,提及时使用简化版提示词)"""
|
||||
try:
|
||||
with Timer("规划器", cycle_timers):
|
||||
action_to_use_info = await self.action_planner.plan(
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
is_mentioned=True, # 标记为提及时,使用简化版提示词
|
||||
)
|
||||
# 过滤掉reply动作(虽然提及时不应该有reply,但为了安全还是过滤一下)
|
||||
return [action for action in action_to_use_info if action.action_type != "reply"]
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} Planner执行失败: {e}")
|
||||
traceback.print_exc()
|
||||
return []
|
||||
|
||||
async def _generate_mentioned_reply(
|
||||
self,
|
||||
force_reply_message: "DatabaseMessages",
|
||||
thinking_id: str,
|
||||
cycle_timers: Dict[str, float],
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
) -> Dict[str, Any]:
|
||||
"""当被提及时,独立生成回复的任务"""
|
||||
try:
|
||||
self.questioned = False
|
||||
# 重置连续 no_reply 计数
|
||||
self.consecutive_no_reply_count = 0
|
||||
reason = ""
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="reply",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
|
||||
with Timer("提及回复生成", cycle_timers):
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message=force_reply_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=[], # 独立回复,不依赖planner的动作
|
||||
reply_reason=reason,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
reply_time_point=self.last_read_time,
|
||||
)
|
||||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
logger.warning(f"{self.log_prefix} 提及回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "result": "提及回复生成失败", "loop_info": None}
|
||||
|
||||
response_set = llm_response.reply_set
|
||||
selected_expressions = llm_response.selected_expressions
|
||||
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||
response_set=response_set,
|
||||
action_message=force_reply_message,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
actions=[], # 独立回复,不依赖planner的动作
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
self.last_active_time = time.time()
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
"result": f"你回复内容{reply_text}",
|
||||
"loop_info": loop_info,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 提及回复生成异常: {e}")
|
||||
traceback.print_exc()
|
||||
return {"action_type": "reply", "success": False, "result": f"提及回复生成异常: {e}", "loop_info": None}
|
||||
|
||||
async def _observe(
|
||||
self, # interest_value: float = 0.0,
|
||||
recent_messages_list: Optional[List["DatabaseMessages"]] = None,
|
||||
@@ -412,15 +311,12 @@ class HeartFChatting:
|
||||
|
||||
start_time = time.time()
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
|
||||
asyncio.create_task(
|
||||
frequency_control_manager.get_or_create_frequency_control(self.stream_id).trigger_frequency_adjust()
|
||||
)
|
||||
# 通过 MessageRecorder 统一提取消息并分发给 expression_learner 和 jargon_miner
|
||||
# 在 replyer 执行时触发,统一管理时间窗口,避免重复获取消息
|
||||
asyncio.create_task(extract_and_distribute_messages(self.stream_id))
|
||||
|
||||
# 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容
|
||||
# asyncio.create_task(check_and_make_question(self.stream_id))
|
||||
# 添加jargon提取任务 - 提取聊天中的黑话/俚语并入库(内部自行取消息并带冷却)
|
||||
asyncio.create_task(extract_and_store_jargon(self.stream_id))
|
||||
# 添加聊天内容概括任务 - 累积、打包和压缩聊天记录
|
||||
# 注意:后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理
|
||||
# asyncio.create_task(self.chat_history_summarizer.process())
|
||||
@@ -438,95 +334,50 @@ class HeartFChatting:
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
|
||||
|
||||
# 如果被提及,让回复生成和planner并行执行
|
||||
if force_reply_message:
|
||||
logger.info(f"{self.log_prefix} 检测到提及,回复生成与planner并行执行")
|
||||
# 执行planner
|
||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||
|
||||
# 并行执行planner和回复生成
|
||||
planner_task = asyncio.create_task(
|
||||
self._run_planner_without_reply(
|
||||
available_actions=available_actions,
|
||||
cycle_timers=cycle_timers,
|
||||
)
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.action_planner.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
prompt_info = await self.action_planner.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=available_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
|
||||
)
|
||||
if not continue_flag:
|
||||
return False
|
||||
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
||||
prompt_info = (modified_message.llm_prompt, prompt_info[1])
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
action_to_use_info = await self.action_planner.plan(
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
force_reply_message=force_reply_message,
|
||||
)
|
||||
reply_task = asyncio.create_task(
|
||||
self._generate_mentioned_reply(
|
||||
force_reply_message=force_reply_message,
|
||||
thinking_id=thinking_id,
|
||||
cycle_timers=cycle_timers,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
)
|
||||
|
||||
# 等待两个任务完成
|
||||
planner_result, reply_result = await asyncio.gather(planner_task, reply_task, return_exceptions=True)
|
||||
|
||||
# 处理planner结果
|
||||
if isinstance(planner_result, BaseException):
|
||||
logger.error(f"{self.log_prefix} Planner执行异常: {planner_result}")
|
||||
action_to_use_info = []
|
||||
else:
|
||||
action_to_use_info = planner_result
|
||||
|
||||
# 处理回复结果
|
||||
if isinstance(reply_result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 回复生成异常: {reply_result}")
|
||||
reply_result = {
|
||||
"action_type": "reply",
|
||||
"success": False,
|
||||
"result": "回复生成异常",
|
||||
"loop_info": None,
|
||||
}
|
||||
else:
|
||||
# 正常流程:只执行planner
|
||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_no_read_command=True,
|
||||
)
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.action_planner.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
prompt_info = await self.action_planner.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=available_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
interest=global_config.personality.interest,
|
||||
)
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
|
||||
)
|
||||
if not continue_flag:
|
||||
return False
|
||||
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
||||
prompt_info = (modified_message.llm_prompt, prompt_info[1])
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
action_to_use_info = await self.action_planner.plan(
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
reply_result = None
|
||||
|
||||
# 只在提及情况下过滤掉planner返回的reply动作(提及时已有独立回复生成)
|
||||
if force_reply_message:
|
||||
action_to_use_info = [action for action in action_to_use_info if action.action_type != "reply"]
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
|
||||
)
|
||||
|
||||
# 3. 并行执行所有动作(不包括reply,reply已经独立执行)
|
||||
# 3. 并行执行所有动作
|
||||
action_tasks = [
|
||||
asyncio.create_task(
|
||||
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
|
||||
@@ -537,10 +388,6 @@ class HeartFChatting:
|
||||
# 并行执行所有任务
|
||||
results = await asyncio.gather(*action_tasks, return_exceptions=True)
|
||||
|
||||
# 如果有独立的回复结果,添加到结果列表中
|
||||
if reply_result:
|
||||
results = list(results) + [reply_result]
|
||||
|
||||
# 处理执行结果
|
||||
reply_loop_info = None
|
||||
reply_text_from_reply = ""
|
||||
@@ -751,31 +598,6 @@ class HeartFChatting:
|
||||
|
||||
return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
|
||||
|
||||
elif action_planner_info.action_type == "no_reply_until_call":
|
||||
# 直接当场执行no_reply_until_call逻辑
|
||||
logger.info(f"{self.log_prefix} 保持沉默,直到有人直接叫的名字")
|
||||
reason = action_planner_info.reasoning or "选择不回复"
|
||||
|
||||
# 增加连续 no_reply 计数
|
||||
self.consecutive_no_reply_count += 1
|
||||
self.no_reply_until_call = True
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="no_reply_until_call",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
return {
|
||||
"action_type": "no_reply_until_call",
|
||||
"success": True,
|
||||
"result": "保持沉默,直到有人直接叫的名字",
|
||||
"command": "",
|
||||
}
|
||||
|
||||
elif action_planner_info.action_type == "reply":
|
||||
# 直接当场执行reply逻辑
|
||||
self.questioned = False
|
||||
@@ -784,8 +606,27 @@ class HeartFChatting:
|
||||
self.consecutive_no_reply_count = 0
|
||||
|
||||
reason = action_planner_info.reasoning or ""
|
||||
# 根据 think_mode 配置决定 think_level 的值
|
||||
think_mode = global_config.chat.think_mode
|
||||
if think_mode == "default":
|
||||
think_level = 0
|
||||
elif think_mode == "deep":
|
||||
think_level = 1
|
||||
elif think_mode == "dynamic":
|
||||
# dynamic 模式:从 planner 返回的 action_data 中获取
|
||||
think_level = action_planner_info.action_data.get("think_level", 1)
|
||||
else:
|
||||
# 默认使用 default 模式
|
||||
think_level = 0
|
||||
# 使用 action_reasoning(planner 的整体思考理由)作为 reply_reason
|
||||
planner_reasoning = action_planner_info.action_reasoning or reason
|
||||
|
||||
record_replyer_action_temp(
|
||||
chat_id=self.stream_id,
|
||||
reason=reason,
|
||||
think_level=think_level,
|
||||
)
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
@@ -797,16 +638,32 @@ class HeartFChatting:
|
||||
action_reasoning=reason,
|
||||
)
|
||||
|
||||
# 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用)
|
||||
unknown_words = None
|
||||
if isinstance(action_planner_info.action_data, dict):
|
||||
uw = action_planner_info.action_data.get("unknown_words")
|
||||
if isinstance(uw, list):
|
||||
cleaned_uw: List[str] = []
|
||||
for item in uw:
|
||||
if isinstance(item, str):
|
||||
s = item.strip()
|
||||
if s:
|
||||
cleaned_uw.append(s)
|
||||
if cleaned_uw:
|
||||
unknown_words = cleaned_uw
|
||||
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message=action_planner_info.action_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_action_plan_infos,
|
||||
reply_reason=planner_reasoning,
|
||||
unknown_words=unknown_words,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()),
|
||||
think_level=think_level,
|
||||
)
|
||||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
|
||||
@@ -42,7 +42,10 @@ def lpmm_start_up(): # sourcery skip: extract-duplicate-method
|
||||
logger.info("创建LLM客户端")
|
||||
|
||||
# 初始化Embedding库
|
||||
embed_manager = EmbeddingManager()
|
||||
embed_manager = EmbeddingManager(
|
||||
max_workers=global_config.lpmm_knowledge.max_embedding_workers,
|
||||
chunk_size=global_config.lpmm_knowledge.embedding_chunk_size,
|
||||
)
|
||||
logger.info("正在从文件加载Embedding库")
|
||||
try:
|
||||
embed_manager.load_from_file()
|
||||
|
||||
@@ -104,7 +104,9 @@ class EmbeddingStore:
|
||||
self.dir = dir_path
|
||||
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
|
||||
self.index_file_path = f"{dir_path}/{namespace}.index"
|
||||
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
|
||||
self.idx2hash_file_path = f"{dir_path}/{namespace}_i2h.json"
|
||||
|
||||
self.dirty = False # 标记是否有新增数据需要重建索引
|
||||
|
||||
# 多线程配置参数验证和设置
|
||||
self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
|
||||
@@ -125,6 +127,11 @@ class EmbeddingStore:
|
||||
self.faiss_index = None
|
||||
self.idx2hash = None
|
||||
|
||||
@staticmethod
|
||||
def hash_texts(namespace: str, texts: List[str]) -> List[str]:
|
||||
"""将原文计算为带前缀的键"""
|
||||
return [f"{namespace}-{get_sha256(t)}" for t in texts]
|
||||
|
||||
def _get_embedding(self, s: str) -> List[float]:
|
||||
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
|
||||
# 创建新的事件循环并在完成后立即关闭
|
||||
@@ -412,6 +419,7 @@ class EmbeddingStore:
|
||||
item_hash = self.namespace + "-" + get_sha256(s)
|
||||
if embedding: # 只有成功获取到嵌入才存入
|
||||
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
|
||||
self.dirty = True
|
||||
else:
|
||||
logger.warning(f"跳过存储失败的嵌入: {s[:50]}...")
|
||||
|
||||
@@ -488,9 +496,17 @@ class EmbeddingStore:
|
||||
self.build_faiss_index()
|
||||
logger.info(f"{self.namespace}嵌入库的FaissIndex重建成功")
|
||||
self.save_to_file()
|
||||
self.dirty = False
|
||||
|
||||
def build_faiss_index(self) -> None:
|
||||
"""重新构建Faiss索引,以余弦相似度为度量"""
|
||||
# 空库直接跳过,清空索引映射
|
||||
if not self.store:
|
||||
self.idx2hash = {}
|
||||
self.faiss_index = None
|
||||
self.dirty = False
|
||||
return
|
||||
|
||||
# 获取所有的embedding
|
||||
array = []
|
||||
self.idx2hash = dict()
|
||||
@@ -498,11 +514,44 @@ class EmbeddingStore:
|
||||
array.append(self.store[key].embedding)
|
||||
self.idx2hash[str(len(array) - 1)] = key
|
||||
embeddings = np.array(array, dtype=np.float32)
|
||||
if embeddings.size == 0:
|
||||
self.idx2hash = {}
|
||||
self.faiss_index = None
|
||||
self.dirty = False
|
||||
return
|
||||
# L2归一化
|
||||
faiss.normalize_L2(embeddings)
|
||||
# 构建索引
|
||||
self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension)
|
||||
self.faiss_index.add(embeddings)
|
||||
self.dirty = False
|
||||
|
||||
def delete_items(self, hashes: List[str]) -> Tuple[int, int]:
|
||||
"""删除指定键的嵌入并重建 idx2hash(不直接重建 faiss)
|
||||
|
||||
Args:
|
||||
hashes: 需要删除的完整键列表(如 paragraph-xxx)
|
||||
|
||||
Returns:
|
||||
(deleted, skipped)
|
||||
"""
|
||||
deleted = 0
|
||||
skipped = 0
|
||||
for h in hashes:
|
||||
if h in self.store:
|
||||
self.store.pop(h)
|
||||
deleted += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
||||
# 重新构建 idx2hash 映射
|
||||
self.idx2hash = {}
|
||||
for idx, key in enumerate(self.store.keys()):
|
||||
self.idx2hash[str(idx)] = key
|
||||
|
||||
# 删除后标记 dirty,faiss 重建由上层统一调用
|
||||
self.dirty = True
|
||||
return deleted, skipped
|
||||
|
||||
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
|
||||
"""搜索最相似的k个项,以余弦相似度为度量
|
||||
@@ -536,7 +585,7 @@ class EmbeddingStore:
|
||||
|
||||
|
||||
class EmbeddingManager:
|
||||
def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
|
||||
def __init__(self, max_workers: int | None = None, chunk_size: int | None = None):
|
||||
"""
|
||||
初始化EmbeddingManager
|
||||
|
||||
@@ -544,6 +593,8 @@ class EmbeddingManager:
|
||||
max_workers: 最大线程数
|
||||
chunk_size: 每个线程处理的数据块大小
|
||||
"""
|
||||
max_workers = max_workers if max_workers is not None else global_config.lpmm_knowledge.max_embedding_workers
|
||||
chunk_size = chunk_size if chunk_size is not None else global_config.lpmm_knowledge.embedding_chunk_size
|
||||
self.paragraphs_embedding_store = EmbeddingStore(
|
||||
"paragraph", # type: ignore
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
@@ -617,7 +668,19 @@ class EmbeddingManager:
|
||||
self.relation_embedding_store.save_to_file()
|
||||
|
||||
def rebuild_faiss_index(self):
|
||||
"""重建Faiss索引(请在添加新数据后调用)"""
|
||||
self.paragraphs_embedding_store.build_faiss_index()
|
||||
self.entities_embedding_store.build_faiss_index()
|
||||
self.relation_embedding_store.build_faiss_index()
|
||||
"""重建Faiss索引,新增数据后调用,带跳过逻辑"""
|
||||
|
||||
def _rebuild_if_needed(store: EmbeddingStore):
|
||||
if (
|
||||
not store.dirty
|
||||
and store.faiss_index is not None
|
||||
and store.idx2hash is not None
|
||||
and getattr(store.faiss_index, "ntotal", 0) == len(store.idx2hash) == len(store.store)
|
||||
):
|
||||
logger.info(f"{store.namespace} FaissIndex 已是最新,跳过重建")
|
||||
return
|
||||
store.build_faiss_index()
|
||||
|
||||
_rebuild_if_needed(self.paragraphs_embedding_store)
|
||||
_rebuild_if_needed(self.entities_embedding_store)
|
||||
_rebuild_if_needed(self.relation_embedding_store)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Tuple, Set
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -98,6 +99,28 @@ class KGManager:
|
||||
# 加载KG
|
||||
self.graph = di_graph.load_from_file(self.graph_data_path)
|
||||
|
||||
def _rebuild_metadata_from_graph(self) -> None:
|
||||
"""根据当前图重建 stored_paragraph_hashes 与 ent_appear_cnt"""
|
||||
nodes = self.graph.get_node_list()
|
||||
edges = self.graph.get_edge_list()
|
||||
|
||||
# 段落 hash:paragraph-{hash}
|
||||
self.stored_paragraph_hashes = set()
|
||||
for node_id in nodes:
|
||||
if node_id.startswith("paragraph-"):
|
||||
self.stored_paragraph_hashes.add(node_id.split("paragraph-", 1)[1])
|
||||
|
||||
# 实体出现次数:基于 entity -> paragraph 的边权
|
||||
ent_appear_cnt: Dict[str, float] = {}
|
||||
for edge_tuple in edges:
|
||||
src, tgt = edge_tuple[0], edge_tuple[1]
|
||||
if src.startswith("entity") and tgt.startswith("paragraph"):
|
||||
edge_data = self.graph[src, tgt]
|
||||
weight = edge_data["weight"] if "weight" in edge_data else 1.0
|
||||
ent_appear_cnt[src] = ent_appear_cnt.get(src, 0.0) + float(weight)
|
||||
|
||||
self.ent_appear_cnt = ent_appear_cnt
|
||||
|
||||
def _build_edges_between_ent(
|
||||
self,
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
@@ -149,6 +172,13 @@ class KGManager:
|
||||
ent_hash_list.add("entity" + "-" + get_sha256(triple[0]))
|
||||
ent_hash_list.add("entity" + "-" + get_sha256(triple[2]))
|
||||
ent_hash_list = list(ent_hash_list)
|
||||
# 性能保护:限制同义连接的实体数量
|
||||
max_synonym_entities = global_config.lpmm_knowledge.max_synonym_entities
|
||||
if max_synonym_entities and len(ent_hash_list) > max_synonym_entities:
|
||||
logger.warning(
|
||||
f"同义连接实体数 {len(ent_hash_list)} 超过阈值 {max_synonym_entities},跳过同义边构建以保护性能"
|
||||
)
|
||||
return 0
|
||||
|
||||
synonym_hash_set = set()
|
||||
synonym_result = {}
|
||||
@@ -328,6 +358,10 @@ class KGManager:
|
||||
paragraph_search_result: ParagraphEmbedding的搜索结果(paragraph_hash, similarity)
|
||||
embed_manager: EmbeddingManager对象
|
||||
"""
|
||||
# 性能保护:关闭时直接返回向量检索结果
|
||||
if not global_config.lpmm_knowledge.enable_ppr:
|
||||
logger.info("PPR 已禁用,使用纯向量检索结果")
|
||||
return paragraph_search_result, None
|
||||
# 图中存在的节点总集
|
||||
existed_nodes = self.graph.get_node_list()
|
||||
|
||||
@@ -357,7 +391,15 @@ class KGManager:
|
||||
ent_mean_scores = {} # 记录实体的平均相似度
|
||||
for ent_hash, scores in ent_sim_scores.items():
|
||||
# 先对相似度进行累加,然后与实体计数相除获取最终权重
|
||||
ent_weights[ent_hash] = float(np.sum(scores)) / self.ent_appear_cnt[ent_hash]
|
||||
# 保护:有些实体在当前图中可能只有实体-实体关系,不会出现在 ent_appear_cnt 中
|
||||
appear_cnt = self.ent_appear_cnt.get(ent_hash)
|
||||
if not appear_cnt or appear_cnt <= 0:
|
||||
logger.debug(
|
||||
f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0,"
|
||||
f"将使用 1.0 作为默认出现次数参与权重计算"
|
||||
)
|
||||
appear_cnt = 1.0
|
||||
ent_weights[ent_hash] = float(np.sum(scores)) / float(appear_cnt)
|
||||
# 记录实体的平均相似度,用于后续的top_k筛选
|
||||
ent_mean_scores[ent_hash] = float(np.mean(scores))
|
||||
del ent_sim_scores
|
||||
@@ -434,3 +476,115 @@ class KGManager:
|
||||
passage_node_res = sorted(passage_node_res, key=lambda item: item[1], reverse=True)
|
||||
|
||||
return passage_node_res, ppr_node_weights
|
||||
|
||||
def delete_paragraphs(
|
||||
self,
|
||||
pg_hashes: List[str],
|
||||
ent_hashes: List[str] | None = None,
|
||||
remove_orphan_entities: bool = False,
|
||||
) -> Dict[str, int]:
|
||||
"""删除段落/实体节点及相关边(基于 GraphML),可选清理孤立实体,并重建元数据"""
|
||||
# 要删除的节点 ID
|
||||
nodes_to_delete: Set[str] = {f"paragraph-{h}" for h in pg_hashes}
|
||||
if ent_hashes:
|
||||
nodes_to_delete.update({f"entity-{h}" for h in ent_hashes})
|
||||
|
||||
if not os.path.exists(self.graph_data_path):
|
||||
raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在")
|
||||
|
||||
tree = ET.parse(self.graph_data_path)
|
||||
root = tree.getroot()
|
||||
|
||||
# GraphML 可能带命名空间,用尾缀判断
|
||||
def is_node(elem: ET.Element) -> bool:
|
||||
return elem.tag.endswith("node")
|
||||
|
||||
def is_edge(elem: ET.Element) -> bool:
|
||||
return elem.tag.endswith("edge")
|
||||
|
||||
graph_elem = None
|
||||
for child in root:
|
||||
if child.tag.endswith("graph"):
|
||||
graph_elem = child
|
||||
break
|
||||
if graph_elem is None:
|
||||
raise RuntimeError("GraphML 中未找到 <graph> 节点")
|
||||
|
||||
# 统计现有节点
|
||||
existing_nodes: Set[str] = set()
|
||||
for elem in graph_elem:
|
||||
if is_node(elem):
|
||||
node_id = elem.get("id")
|
||||
if node_id:
|
||||
existing_nodes.add(node_id)
|
||||
|
||||
deleted_nodes = len(nodes_to_delete & existing_nodes)
|
||||
skipped_nodes = len(nodes_to_delete - existing_nodes)
|
||||
|
||||
# 先删除指定节点及相关边
|
||||
# 删除节点
|
||||
for elem in list(graph_elem):
|
||||
if is_node(elem):
|
||||
node_id = elem.get("id")
|
||||
if node_id and node_id in nodes_to_delete:
|
||||
graph_elem.remove(elem)
|
||||
|
||||
# 删除 incident edges
|
||||
for elem in list(graph_elem):
|
||||
if is_edge(elem):
|
||||
src = elem.get("source")
|
||||
tgt = elem.get("target")
|
||||
if src in nodes_to_delete or tgt in nodes_to_delete:
|
||||
graph_elem.remove(elem)
|
||||
|
||||
orphan_removed = 0
|
||||
if remove_orphan_entities:
|
||||
# 计算仍然参与边的节点
|
||||
used_nodes: Set[str] = set()
|
||||
for elem in graph_elem:
|
||||
if is_edge(elem):
|
||||
src = elem.get("source")
|
||||
tgt = elem.get("target")
|
||||
if src:
|
||||
used_nodes.add(src)
|
||||
if tgt:
|
||||
used_nodes.add(tgt)
|
||||
|
||||
# 找出没有任何边的实体节点
|
||||
orphan_entities: Set[str] = set()
|
||||
for elem in graph_elem:
|
||||
if is_node(elem):
|
||||
node_id = elem.get("id")
|
||||
if node_id and node_id.startswith("entity") and node_id not in used_nodes:
|
||||
orphan_entities.add(node_id)
|
||||
|
||||
orphan_removed = len(orphan_entities)
|
||||
|
||||
if orphan_entities:
|
||||
# 删除孤立实体节点
|
||||
for elem in list(graph_elem):
|
||||
if is_node(elem):
|
||||
node_id = elem.get("id")
|
||||
if node_id in orphan_entities:
|
||||
graph_elem.remove(elem)
|
||||
|
||||
# 删除与孤立实体相关的边(理论上已无,但做一次防御性清理)
|
||||
for elem in list(graph_elem):
|
||||
if is_edge(elem):
|
||||
src = elem.get("source")
|
||||
tgt = elem.get("target")
|
||||
if src in orphan_entities or tgt in orphan_entities:
|
||||
graph_elem.remove(elem)
|
||||
|
||||
# 写回 GraphML
|
||||
tree.write(self.graph_data_path, encoding="utf-8", xml_declaration=True)
|
||||
|
||||
# 重新加载图并重建元数据
|
||||
self.graph = di_graph.load_from_file(self.graph_data_path)
|
||||
self._rebuild_metadata_from_graph()
|
||||
|
||||
return {
|
||||
"deleted": deleted_nodes,
|
||||
"skipped": skipped_nodes,
|
||||
"orphan_removed": orphan_removed,
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ from maim_message import UserInfo, Seg, GroupInfo
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
@@ -73,7 +72,6 @@ class ChatBot:
|
||||
def __init__(self):
|
||||
self.bot = None # bot 实例引用
|
||||
self._started = False
|
||||
self.mood_manager = mood_manager # 获取情绪管理器单例
|
||||
self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
|
||||
|
||||
async def _ensure_started(self):
|
||||
@@ -83,7 +81,7 @@ class ChatBot:
|
||||
|
||||
self._started = True
|
||||
|
||||
async def _process_commands_with_new_system(self, message: MessageRecv):
|
||||
async def _process_commands(self, message: MessageRecv):
|
||||
# sourcery skip: use-named-expression
|
||||
"""使用新插件系统处理命令"""
|
||||
try:
|
||||
@@ -115,17 +113,21 @@ class ChatBot:
|
||||
|
||||
try:
|
||||
# 执行命令
|
||||
success, response, intercept_message = await command_instance.execute()
|
||||
message.is_no_read_command = bool(intercept_message)
|
||||
success, response, intercept_message_level = await command_instance.execute()
|
||||
message.intercept_message_level = intercept_message_level
|
||||
|
||||
# 记录命令执行结果
|
||||
if success:
|
||||
logger.info(f"命令执行成功: {command_class.__name__} (拦截: {intercept_message})")
|
||||
logger.info(f"命令执行成功: {command_class.__name__} (拦截等级: {intercept_message_level})")
|
||||
else:
|
||||
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
|
||||
|
||||
# 根据命令的拦截设置决定是否继续处理消息
|
||||
return True, response, not intercept_message # 找到命令,根据intercept_message决定是否继续
|
||||
return (
|
||||
True,
|
||||
response,
|
||||
not bool(intercept_message_level),
|
||||
) # 找到命令,根据intercept_message决定是否继续
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
|
||||
@@ -295,7 +297,7 @@ class ChatBot:
|
||||
# return
|
||||
|
||||
# 命令处理 - 使用新插件系统检查并处理命令
|
||||
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)
|
||||
is_command, cmd_result, continue_process = await self._process_commands(message)
|
||||
|
||||
# 如果是命令且不需要继续处理,则直接返回
|
||||
if is_command and not continue_process:
|
||||
|
||||
@@ -122,7 +122,7 @@ class MessageRecv(Message):
|
||||
self.is_notify = False
|
||||
|
||||
self.is_command = False
|
||||
self.is_no_read_command = False
|
||||
self.intercept_message_level = 0
|
||||
|
||||
self.priority_mode = "interest"
|
||||
self.priority_info = None
|
||||
@@ -213,6 +213,68 @@ class MessageRecv(Message):
|
||||
}
|
||||
"""
|
||||
return ""
|
||||
elif segment.type == "video_card":
|
||||
# 处理视频卡片消息
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, dict):
|
||||
file_name = segment.data.get("file", "未知视频")
|
||||
file_size = segment.data.get("file_size", "")
|
||||
url = segment.data.get("url", "")
|
||||
text = f"[视频: {file_name}"
|
||||
if file_size:
|
||||
text += f", 大小: {file_size}字节"
|
||||
text += "]"
|
||||
if url:
|
||||
text += f" 链接: {url}"
|
||||
return text
|
||||
return "[视频]"
|
||||
elif segment.type == "music_card":
|
||||
# 处理音乐卡片消息
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, dict):
|
||||
title = segment.data.get("title", "未知歌曲")
|
||||
singer = segment.data.get("singer", "")
|
||||
tag = segment.data.get("tag", "") # 音乐来源,如"网易云音乐"
|
||||
jump_url = segment.data.get("jump_url", "")
|
||||
music_url = segment.data.get("music_url", "")
|
||||
text = f"[音乐: {title}"
|
||||
if singer:
|
||||
text += f" - {singer}"
|
||||
if tag:
|
||||
text += f" ({tag})"
|
||||
text += "]"
|
||||
if jump_url:
|
||||
text += f" 跳转链接: {jump_url}"
|
||||
if music_url:
|
||||
text += f" 音乐链接: {music_url}"
|
||||
return text
|
||||
return "[音乐]"
|
||||
elif segment.type == "miniapp_card":
|
||||
# 处理小程序分享卡片(如B站视频分享)
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, dict):
|
||||
title = segment.data.get("title", "") # 小程序名称
|
||||
desc = segment.data.get("desc", "") # 内容描述
|
||||
source_url = segment.data.get("source_url", "") # 原始链接
|
||||
url = segment.data.get("url", "") # 小程序链接
|
||||
text = "[小程序分享"
|
||||
if title:
|
||||
text += f" - {title}"
|
||||
text += "]"
|
||||
if desc:
|
||||
text += f" {desc}"
|
||||
if source_url:
|
||||
text += f" 链接: {source_url}"
|
||||
elif url:
|
||||
text += f" 链接: {url}"
|
||||
return text
|
||||
return "[小程序分享]"
|
||||
else:
|
||||
return ""
|
||||
except Exception as e:
|
||||
|
||||
@@ -72,7 +72,7 @@ class MessageStorage:
|
||||
key_words = ""
|
||||
key_words_lite = ""
|
||||
selected_expressions = message.selected_expressions
|
||||
is_no_read_command = False
|
||||
intercept_message_level = 0
|
||||
else:
|
||||
filtered_display_message = ""
|
||||
interest_value = message.interest_value
|
||||
@@ -86,7 +86,7 @@ class MessageStorage:
|
||||
is_picid = message.is_picid
|
||||
is_notify = message.is_notify
|
||||
is_command = message.is_command
|
||||
is_no_read_command = getattr(message, "is_no_read_command", False)
|
||||
intercept_message_level = getattr(message, "intercept_message_level", 0)
|
||||
# 序列化关键词列表为JSON字符串
|
||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||
@@ -138,7 +138,7 @@ class MessageStorage:
|
||||
is_picid=is_picid,
|
||||
is_notify=is_notify,
|
||||
is_command=is_command,
|
||||
is_no_read_command=is_no_read_command,
|
||||
intercept_message_level=intercept_message_level,
|
||||
key_words=key_words,
|
||||
key_words_lite=key_words_lite,
|
||||
selected_expressions=selected_expressions,
|
||||
|
||||
@@ -40,6 +40,93 @@ def is_webui_virtual_group(group_id: str) -> bool:
|
||||
return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
|
||||
|
||||
|
||||
def parse_message_segments(segment) -> list:
|
||||
"""解析消息段,转换为 WebUI 可用的格式
|
||||
|
||||
参考 NapCat 适配器的消息解析逻辑
|
||||
|
||||
Args:
|
||||
segment: Seg 消息段对象
|
||||
|
||||
Returns:
|
||||
list: 消息段列表,每个元素为 {"type": "...", "data": ...}
|
||||
"""
|
||||
|
||||
result = []
|
||||
|
||||
if segment is None:
|
||||
return result
|
||||
|
||||
if segment.type == "seglist":
|
||||
# 处理消息段列表
|
||||
if segment.data:
|
||||
for seg in segment.data:
|
||||
result.extend(parse_message_segments(seg))
|
||||
elif segment.type == "text":
|
||||
# 文本消息
|
||||
if segment.data:
|
||||
result.append({"type": "text", "data": segment.data})
|
||||
elif segment.type == "image":
|
||||
# 图片消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"})
|
||||
elif segment.type == "emoji":
|
||||
# 表情包消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"})
|
||||
elif segment.type == "imageurl":
|
||||
# 图片链接消息
|
||||
if segment.data:
|
||||
result.append({"type": "image", "data": segment.data})
|
||||
elif segment.type == "face":
|
||||
# 原生表情
|
||||
result.append({"type": "face", "data": segment.data})
|
||||
elif segment.type == "voice":
|
||||
# 语音消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"})
|
||||
elif segment.type == "voiceurl":
|
||||
# 语音链接
|
||||
if segment.data:
|
||||
result.append({"type": "voice", "data": segment.data})
|
||||
elif segment.type == "video":
|
||||
# 视频消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"})
|
||||
elif segment.type == "videourl":
|
||||
# 视频链接
|
||||
if segment.data:
|
||||
result.append({"type": "video", "data": segment.data})
|
||||
elif segment.type == "music":
|
||||
# 音乐消息
|
||||
result.append({"type": "music", "data": segment.data})
|
||||
elif segment.type == "file":
|
||||
# 文件消息
|
||||
result.append({"type": "file", "data": segment.data})
|
||||
elif segment.type == "reply":
|
||||
# 回复消息
|
||||
result.append({"type": "reply", "data": segment.data})
|
||||
elif segment.type == "forward":
|
||||
# 转发消息
|
||||
forward_items = []
|
||||
if segment.data:
|
||||
for item in segment.data:
|
||||
forward_items.append(
|
||||
{
|
||||
"content": parse_message_segments(item.get("message_segment", {}))
|
||||
if isinstance(item, dict)
|
||||
else []
|
||||
}
|
||||
)
|
||||
result.append({"type": "forward", "data": forward_items})
|
||||
else:
|
||||
# 未知类型,尝试作为文本处理
|
||||
if segment.data:
|
||||
result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
"""合并后的消息发送函数,包含WS发送和日志记录"""
|
||||
message_preview = truncate_message(message.processed_plain_text, max_length=200)
|
||||
@@ -50,17 +137,31 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
|
||||
chat_manager, webui_platform = get_webui_chat_broadcaster()
|
||||
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
|
||||
|
||||
|
||||
if is_webui_message and chat_manager is not None:
|
||||
# WebUI 聊天室消息(包括虚拟身份模式),通过 WebSocket 广播
|
||||
import time
|
||||
from src.config.config import global_config
|
||||
|
||||
# 解析消息段,获取富文本内容
|
||||
message_segments = parse_message_segments(message.message_segment)
|
||||
|
||||
# 判断消息类型
|
||||
# 如果只有一个文本段,使用简单的 text 类型
|
||||
# 否则使用 rich 类型,包含完整的消息段
|
||||
if len(message_segments) == 1 and message_segments[0].get("type") == "text":
|
||||
message_type = "text"
|
||||
segments = None
|
||||
else:
|
||||
message_type = "rich"
|
||||
segments = message_segments
|
||||
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "bot_message",
|
||||
"content": message.processed_plain_text,
|
||||
"message_type": "text",
|
||||
"message_type": message_type,
|
||||
"segments": segments, # 富文本消息段
|
||||
"timestamp": time.time(),
|
||||
"group_id": group_id, # 包含群 ID 以便前端区分不同的聊天标签
|
||||
"sender": {
|
||||
@@ -81,11 +182,70 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室")
|
||||
return True
|
||||
|
||||
# 直接调用API发送消息
|
||||
await get_global_api().send_message(message)
|
||||
if show_log:
|
||||
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
|
||||
return True
|
||||
# Fallback 逻辑: 尝试通过 API Server 发送
|
||||
async def send_with_new_api(legacy_exception=None):
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
# 如果未开启 API Server,直接跳过 Fallback
|
||||
if not global_config.maim_message.enable_api_server:
|
||||
if legacy_exception:
|
||||
raise legacy_exception
|
||||
return False
|
||||
|
||||
global_api = get_global_api()
|
||||
extra_server = getattr(global_api, "extra_server", None)
|
||||
|
||||
if extra_server and extra_server.is_running():
|
||||
# Fallback: 使用极其简单的 Platform -> API Key 映射
|
||||
# 只有收到过该平台的消息,我们才知道该平台的 API Key,才能回传消息
|
||||
platform_map = getattr(global_api, "platform_map", {})
|
||||
target_api_key = platform_map.get(platform)
|
||||
|
||||
if target_api_key:
|
||||
# 构造 APIMessageBase
|
||||
from maim_message.message import APIMessageBase, MessageDim
|
||||
|
||||
msg_dim = MessageDim(api_key=target_api_key, platform=platform)
|
||||
|
||||
api_message = APIMessageBase(
|
||||
message_info=message.message_info,
|
||||
message_segment=message.message_segment,
|
||||
message_dim=msg_dim,
|
||||
)
|
||||
|
||||
# 直接调用 Server 的 send_message 接口,它会自动处理路由
|
||||
results = await extra_server.send_message(api_message)
|
||||
|
||||
# 检查是否有任何连接发送成功
|
||||
if any(results.values()):
|
||||
if show_log:
|
||||
logger.info(
|
||||
f"已通过API Server Fallback将消息 '{message_preview}' 发往平台'{platform}' (key: {target_api_key})"
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 如果 Fallback 失败,且存在 legacy 异常,则抛出 legacy 异常
|
||||
if legacy_exception:
|
||||
raise legacy_exception
|
||||
return False
|
||||
|
||||
try:
|
||||
send_result = await get_global_api().send_message(message)
|
||||
# if send_result:
|
||||
if show_log:
|
||||
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
|
||||
return True
|
||||
|
||||
# Legacy API 返回 False (发送失败但未报错),尝试 Fallback
|
||||
# return await send_with_new_api()
|
||||
|
||||
except Exception as legacy_e:
|
||||
# Legacy API 抛出异常,尝试 Fallback
|
||||
# 如果 Fallback 也失败,将重新抛出 legacy_e
|
||||
return await send_with_new_api(legacy_exception=legacy_e)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}")
|
||||
|
||||
@@ -69,7 +69,7 @@ class ActionModifier:
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
chat_content = build_readable_messages(
|
||||
|
||||
@@ -15,12 +15,15 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
replace_user_references,
|
||||
)
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
from src.person_info.person_info import Person
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||
@@ -36,7 +39,6 @@ def init_prompt():
|
||||
"""
|
||||
{time_block}
|
||||
{name_block}
|
||||
你的兴趣是:{interest}
|
||||
{chat_context_description},以下是具体的聊天内容
|
||||
**聊天内容**
|
||||
{chat_content_block}
|
||||
@@ -46,9 +48,12 @@ reply
|
||||
动作描述:
|
||||
1.你可以选择呼叫了你的名字,但是你没有做出回应的消息进行回复
|
||||
2.你可以自然的顺着正在进行的聊天内容进行回复或自然的提出一个问题
|
||||
3.不要回复你自己发送的消息
|
||||
4.不要单独对表情包进行回复
|
||||
{{"action":"reply", "target_message_id":"消息id(m+数字)", "reason":"原因"}}
|
||||
3.最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。
|
||||
4.不要选择回复你自己发送的消息
|
||||
5.不要单独对表情包进行回复
|
||||
6.将上下文中所有含义不明的,疑似黑话的,缩写词均写入unknown_words中
|
||||
7.用一句简单的话来描述当前回复场景,不超过10个字
|
||||
{reply_action_example}
|
||||
|
||||
no_reply
|
||||
动作描述:
|
||||
@@ -56,75 +61,37 @@ no_reply
|
||||
控制聊天频率,不要太过频繁的发言
|
||||
{{"action":"no_reply"}}
|
||||
|
||||
{no_reply_until_call_block}
|
||||
|
||||
{action_options_text}
|
||||
|
||||
|
||||
**你之前的action执行和思考记录**
|
||||
{actions_before_now_block}
|
||||
|
||||
请选择**可选的**且符合使用条件的action,并说明触发action的消息id(消息id格式:m+数字)
|
||||
不要回复你自己发送的消息
|
||||
先输出你的简短的选择思考理由,再输出你选择的action,理由不要分点,精简。
|
||||
**动作选择要求**
|
||||
请你根据聊天内容,用户的最新消息和以下标准选择合适的动作:
|
||||
{plan_style}
|
||||
{moderation_prompt}
|
||||
|
||||
请选择所有符合使用要求的action,动作用json格式输出,用```json包裹,如果输出多个json,每个json都要单独一行放在同一个```json代码块内,你可以重复使用同一个动作或不同动作:
|
||||
target_message_id为必填,表示触发消息的id
|
||||
请选择所有符合使用要求的action,每个动作最多选择一次,但是可以选择多个动作;
|
||||
动作用json格式输出,用```json包裹,如果输出多个json,每个json都要单独一行放在同一个```json代码块内:
|
||||
**示例**
|
||||
// 理由文本(简短)
|
||||
```json
|
||||
{{"action":"动作名", "target_message_id":"m123", "reason":"原因"}}
|
||||
{{"action":"动作名", "target_message_id":"m456", "reason":"原因"}}
|
||||
{{"action":"动作名", "target_message_id":"m123", .....}}
|
||||
{{"action":"动作名", "target_message_id":"m456", .....}}
|
||||
```""",
|
||||
"planner_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""{time_block}
|
||||
{name_block}
|
||||
{chat_context_description},以下是具体的聊天内容
|
||||
**聊天内容**
|
||||
{chat_content_block}
|
||||
|
||||
**可选的action**
|
||||
no_reply
|
||||
动作描述:
|
||||
没有合适的可以使用的动作,不使用action
|
||||
{{"action":"no_reply"}}
|
||||
|
||||
{action_options_text}
|
||||
|
||||
**你之前的action执行和思考记录**
|
||||
{actions_before_now_block}
|
||||
|
||||
请选择**可选的**且符合使用条件的action,并说明触发action的消息id(消息id格式:m+数字)
|
||||
先输出你的简短的选择思考理由,再输出你选择的action,理由不要分点,精简。
|
||||
**动作选择要求**
|
||||
请你根据聊天内容,用户的最新消息和以下标准选择合适的动作:
|
||||
1.思考**所有**的可用的action中的**每个动作**是否符合当下条件,如果动作使用条件符合聊天内容就使用
|
||||
2.如果相同的内容已经被执行,请不要重复执行
|
||||
{moderation_prompt}
|
||||
|
||||
请选择所有符合使用要求的action,动作用json格式输出,用```json包裹,如果输出多个json,每个json都要单独一行放在同一个```json代码块内,你可以重复使用同一个动作或不同动作:
|
||||
**示例**
|
||||
// 理由文本(简短)
|
||||
```json
|
||||
{{"action":"动作名", "target_message_id":"m123", "reason":"原因"}}
|
||||
{{"action":"动作名", "target_message_id":"m456", "reason":"原因"}}
|
||||
```""",
|
||||
"planner_prompt_mentioned",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
{action_name}
|
||||
动作描述:{action_description}
|
||||
使用条件{parallel_text}:
|
||||
{action_require}
|
||||
{{"action":"{action_name}",{action_parameters}, "target_message_id":"消息id(m+数字)", "reason":"原因"}}
|
||||
{{"action":"{action_name}",{action_parameters}, "target_message_id":"消息id(m+数字)"}}
|
||||
""",
|
||||
"action_prompt",
|
||||
)
|
||||
@@ -195,11 +162,41 @@ class ActionPlanner:
|
||||
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 未找到对应消息,保持原样")
|
||||
return msg_id
|
||||
|
||||
msg_text = (message.processed_plain_text or message.display_message or "").strip()
|
||||
msg_text = (message.processed_plain_text or "").strip()
|
||||
if not msg_text:
|
||||
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 的消息内容为空,保持原样")
|
||||
return msg_id
|
||||
|
||||
# 替换 [picid:xxx] 为 [图片:描述]
|
||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||
def replace_pic_id(pic_match: re.Match) -> str:
|
||||
pic_id = pic_match.group(1)
|
||||
description = translate_pid_to_description(pic_id)
|
||||
return f"[图片:{description}]"
|
||||
msg_text = re.sub(pic_pattern, replace_pic_id, msg_text)
|
||||
|
||||
# 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb>
|
||||
platform = getattr(message, "user_info", None) and message.user_info.platform or getattr(message, "chat_info", None) and message.chat_info.platform or "qq"
|
||||
msg_text = replace_user_references(msg_text, platform, replace_bot_name=True)
|
||||
|
||||
# 替换单独的 <用户名:用户ID> 格式(replace_user_references 已处理回复<和@<格式)
|
||||
# 匹配所有 <aaa:bbb> 格式,由于 replace_user_references 已经替换了回复<和@<格式,
|
||||
# 这里匹配到的应该都是单独的格式
|
||||
user_ref_pattern = r"<([^:<>]+):([^:<>]+)>"
|
||||
def replace_user_ref(user_match: re.Match) -> str:
|
||||
user_name = user_match.group(1)
|
||||
user_id = user_match.group(2)
|
||||
try:
|
||||
# 检查是否是机器人自己
|
||||
if user_id == global_config.bot.qq_account:
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
person = Person(platform=platform, user_id=user_id)
|
||||
return person.person_name or user_name
|
||||
except Exception:
|
||||
# 如果解析失败,使用原始昵称
|
||||
return user_name
|
||||
msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text)
|
||||
|
||||
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
|
||||
logger.info(f"{self.log_prefix}planner理由引用 {msg_id} -> 消息({preview})")
|
||||
return f"消息({msg_text})"
|
||||
@@ -218,11 +215,14 @@ class ActionPlanner:
|
||||
|
||||
try:
|
||||
action = action_json.get("action", "no_reply")
|
||||
original_reasoning = action_json.get("reason", "未提供原因")
|
||||
reasoning = self._replace_message_ids_with_text(original_reasoning, message_id_list)
|
||||
if reasoning is None:
|
||||
reasoning = original_reasoning
|
||||
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
|
||||
# 使用 extracted_reasoning(整体推理文本)作为 reasoning
|
||||
if extracted_reasoning:
|
||||
reasoning = self._replace_message_ids_with_text(extracted_reasoning, message_id_list)
|
||||
if reasoning is None:
|
||||
reasoning = extracted_reasoning
|
||||
else:
|
||||
reasoning = "未提供原因"
|
||||
action_data = {key: value for key, value in action_json.items() if key not in ["action"]}
|
||||
# 非no_reply动作需要target_message_id
|
||||
target_message = None
|
||||
|
||||
@@ -248,7 +248,7 @@ class ActionPlanner:
|
||||
|
||||
# 验证action是否可用
|
||||
available_action_names = [action_name for action_name, _ in current_available_actions]
|
||||
internal_action_names = ["no_reply", "reply", "wait_time", "no_reply_until_call"]
|
||||
internal_action_names = ["no_reply", "reply", "wait_time"]
|
||||
|
||||
if action not in internal_action_names and action not in available_action_names:
|
||||
logger.warning(
|
||||
@@ -304,7 +304,7 @@ class ActionPlanner:
|
||||
self,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float = 0.0,
|
||||
is_mentioned: bool = False,
|
||||
force_reply_message: Optional["DatabaseMessages"] = None,
|
||||
) -> List[ActionPlannerInfo]:
|
||||
# sourcery skip: use-named-expression
|
||||
"""
|
||||
@@ -316,7 +316,7 @@ class ActionPlanner:
|
||||
chat_id=self.chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
message_id_list: list[Tuple[str, "DatabaseMessages"]] = []
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
@@ -345,11 +345,6 @@ class ActionPlanner:
|
||||
|
||||
logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作")
|
||||
|
||||
# 如果是提及时且没有可用动作,直接返回空列表,不调用LLM以节省token
|
||||
if is_mentioned and not filtered_actions:
|
||||
logger.info(f"{self.log_prefix}提及时没有可用动作,跳过plan调用")
|
||||
return []
|
||||
|
||||
# 构建包含所有动作的提示词
|
||||
prompt, message_id_list = await self.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
@@ -357,8 +352,6 @@ class ActionPlanner:
|
||||
current_available_actions=filtered_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
interest=global_config.personality.interest,
|
||||
is_mentioned=is_mentioned,
|
||||
)
|
||||
|
||||
# 调用LLM获取决策
|
||||
@@ -370,6 +363,34 @@ class ActionPlanner:
|
||||
loop_start_time=loop_start_time,
|
||||
)
|
||||
|
||||
# 如果有强制回复消息,确保回复该消息
|
||||
if force_reply_message:
|
||||
# 检查是否已经有回复该消息的 action
|
||||
has_reply_to_force_message = False
|
||||
for action in actions:
|
||||
if action.action_type == "reply" and action.action_message and action.action_message.message_id == force_reply_message.message_id:
|
||||
has_reply_to_force_message = True
|
||||
break
|
||||
|
||||
# 如果没有回复该消息,强制添加回复 action
|
||||
if not has_reply_to_force_message:
|
||||
# 移除所有 no_reply action(如果有)
|
||||
actions = [a for a in actions if a.action_type != "no_reply"]
|
||||
|
||||
# 创建强制回复 action
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
force_reply_action = ActionPlannerInfo(
|
||||
action_type="reply",
|
||||
reasoning="用户提及了我,必须回复该消息",
|
||||
action_data={"loop_start_time": loop_start_time},
|
||||
action_message=force_reply_message,
|
||||
available_actions=available_actions_dict,
|
||||
action_reasoning=None,
|
||||
)
|
||||
# 将强制回复 action 放在最前面
|
||||
actions.insert(0, force_reply_action)
|
||||
logger.info(f"{self.log_prefix} 检测到强制回复消息,已添加回复动作")
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
|
||||
)
|
||||
@@ -430,32 +451,6 @@ class ActionPlanner:
|
||||
|
||||
return plan_log_str
|
||||
|
||||
def _has_consecutive_no_reply(self, min_count: int = 3) -> bool:
|
||||
"""
|
||||
检查是否有连续min_count次以上的no_reply
|
||||
|
||||
Args:
|
||||
min_count: 需要连续的最少次数,默认3
|
||||
|
||||
Returns:
|
||||
如果有连续min_count次以上no_reply返回True,否则返回False
|
||||
"""
|
||||
consecutive_count = 0
|
||||
|
||||
# 从后往前遍历plan_log,检查最新的连续记录
|
||||
for _reasoning, _timestamp, content in reversed(self.plan_log):
|
||||
if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content):
|
||||
# 检查所有action是否都是no_reply
|
||||
if all(action.action_type == "no_reply" for action in content):
|
||||
consecutive_count += 1
|
||||
if consecutive_count >= min_count:
|
||||
return True
|
||||
else:
|
||||
# 如果遇到非no_reply的action,重置计数
|
||||
break
|
||||
|
||||
return False
|
||||
|
||||
async def build_planner_prompt(
|
||||
self,
|
||||
is_group_chat: bool,
|
||||
@@ -464,7 +459,6 @@ class ActionPlanner:
|
||||
message_id_list: List[Tuple[str, "DatabaseMessages"]],
|
||||
chat_content_block: str = "",
|
||||
interest: str = "",
|
||||
is_mentioned: bool = False,
|
||||
) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]:
|
||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||
try:
|
||||
@@ -485,48 +479,35 @@ class ActionPlanner:
|
||||
)
|
||||
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
|
||||
|
||||
# 根据是否是提及时选择不同的模板
|
||||
if is_mentioned:
|
||||
# 提及时使用简化版提示词,不需要reply、no_reply、no_reply_until_call
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt_mentioned")
|
||||
prompt = planner_prompt_template.format(
|
||||
time_block=time_block,
|
||||
chat_context_description=chat_context_description,
|
||||
chat_content_block=chat_content_block,
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
action_options_text=action_options_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
name_block=name_block,
|
||||
interest=interest,
|
||||
plan_style=global_config.personality.plan_style,
|
||||
# 根据 think_mode 配置决定 reply action 的示例 JSON
|
||||
# 在 JSON 中直接作为 action 参数携带 unknown_words
|
||||
if global_config.chat.think_mode == "classic":
|
||||
reply_action_example = (
|
||||
'{{"action":"reply", "target_message_id":"消息id(m+数字)", '
|
||||
'"unknown_words":["词语1","词语2"]}}'
|
||||
)
|
||||
else:
|
||||
# 正常流程使用完整版提示词
|
||||
# 检查是否有连续3次以上no_reply,如果有则添加no_reply_until_call选项
|
||||
no_reply_until_call_block = ""
|
||||
if self._has_consecutive_no_reply(min_count=3):
|
||||
no_reply_until_call_block = """no_reply_until_call
|
||||
动作描述:
|
||||
保持沉默,直到有人直接叫你的名字
|
||||
当前话题不感兴趣时使用,或有人不喜欢你的发言时使用
|
||||
当你频繁选择no_reply时使用,表示话题暂时与你无关
|
||||
{{"action":"no_reply_until_call"}}
|
||||
"""
|
||||
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||
prompt = planner_prompt_template.format(
|
||||
time_block=time_block,
|
||||
chat_context_description=chat_context_description,
|
||||
chat_content_block=chat_content_block,
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
action_options_text=action_options_block,
|
||||
no_reply_until_call_block=no_reply_until_call_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
name_block=name_block,
|
||||
interest=interest,
|
||||
plan_style=global_config.personality.plan_style,
|
||||
reply_action_example = (
|
||||
"5.think_level表示思考深度,0表示该回复不需要思考和回忆,1表示该回复需要进行回忆和思考\n"
|
||||
+ '{{"action":"reply", "think_level":数值等级(0或1), '
|
||||
'"target_message_id":"消息id(m+数字)", '
|
||||
'"unknown_words":["词语1","词语2"]}}'
|
||||
)
|
||||
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||
prompt = planner_prompt_template.format(
|
||||
time_block=time_block,
|
||||
chat_context_description=chat_context_description,
|
||||
chat_content_block=chat_content_block,
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
action_options_text=action_options_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
name_block=name_block,
|
||||
interest=interest,
|
||||
plan_style=global_config.personality.plan_style,
|
||||
reply_action_example=reply_action_example,
|
||||
)
|
||||
|
||||
return prompt, message_id_list
|
||||
except Exception as e:
|
||||
logger.error(f"构建 Planner 提示词时出错: {e}")
|
||||
@@ -696,6 +677,12 @@ class ActionPlanner:
|
||||
action.action_data = action.action_data or {}
|
||||
action.action_data["loop_start_time"] = loop_start_time
|
||||
|
||||
# 去重:如果同一个动作被选择了多次,随机选择其中一个
|
||||
if actions:
|
||||
shuffled = actions.copy()
|
||||
random.shuffle(shuffled)
|
||||
actions = list({a.action_type: a for a in shuffled}.values())
|
||||
|
||||
logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
|
||||
|
||||
return extracted_reasoning, actions
|
||||
@@ -747,23 +734,27 @@ class ActionPlanner:
|
||||
# 尝试解析每一行作为独立的JSON对象
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
# 过滤掉空字典,避免单个 { 字符被错误修复为 {} 的情况
|
||||
if json_obj:
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
if isinstance(item, dict) and item:
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
# 如果单行解析失败,尝试将整个块作为一个JSON对象或数组
|
||||
pass
|
||||
|
||||
# 如果按行解析没有成功,尝试将整个块作为一个JSON对象或数组
|
||||
# 如果按行解析没有成功(或只得到空字典),尝试将整个块作为一个JSON对象或数组
|
||||
if not json_objects:
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
# 过滤掉空字典
|
||||
if json_obj:
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
if isinstance(item, dict) and item:
|
||||
json_objects.append(item)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...")
|
||||
@@ -798,23 +789,27 @@ class ActionPlanner:
|
||||
try:
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
# 过滤掉空字典,避免单个 { 字符被错误修复为 {} 的情况
|
||||
if json_obj:
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
if isinstance(item, dict) and item:
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 如果按行解析没有成功,尝试将整个块作为一个JSON对象或数组
|
||||
# 如果按行解析没有成功(或只得到空字典),尝试将整个块作为一个JSON对象或数组
|
||||
if not json_objects:
|
||||
try:
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
# 过滤掉空字典
|
||||
if json_obj:
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
if isinstance(item, dict) and item:
|
||||
json_objects.append(item)
|
||||
except Exception as e:
|
||||
logger.debug(f"尝试解析不完整的JSON代码块失败: {e}")
|
||||
|
||||
@@ -18,13 +18,12 @@ from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
replace_user_references,
|
||||
)
|
||||
from src.express.expression_selector import expression_selector
|
||||
from src.bw_learner.expression_selector import expression_selector
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
@@ -36,7 +35,7 @@ from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt
|
||||
from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
|
||||
from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
|
||||
from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt
|
||||
from src.jargon.jargon_explainer import explain_jargon_in_context
|
||||
from src.bw_learner.jargon_explainer import explain_jargon_in_context, retrieve_concepts_with_jargon
|
||||
|
||||
init_lpmm_prompt()
|
||||
init_replyer_prompt()
|
||||
@@ -73,6 +72,8 @@ class DefaultReplyer:
|
||||
stream_id: Optional[str] = None,
|
||||
reply_message: Optional[DatabaseMessages] = None,
|
||||
reply_time_point: Optional[float] = time.time(),
|
||||
think_level: int = 1,
|
||||
unknown_words: Optional[List[str]] = None,
|
||||
) -> Tuple[bool, LLMGenerationDataModel]:
|
||||
# sourcery skip: merge-nested-ifs
|
||||
"""
|
||||
@@ -98,8 +99,10 @@ class DefaultReplyer:
|
||||
available_actions = {}
|
||||
try:
|
||||
# 3. 构建 Prompt
|
||||
timing_logs = []
|
||||
almost_zero_str = ""
|
||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||
prompt, selected_expressions = await self.build_prompt_reply_context(
|
||||
prompt, selected_expressions, timing_logs, almost_zero_str = await self.build_prompt_reply_context(
|
||||
extra_info=extra_info,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_actions,
|
||||
@@ -107,6 +110,8 @@ class DefaultReplyer:
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason,
|
||||
reply_time_point=reply_time_point,
|
||||
think_level=think_level,
|
||||
unknown_words=unknown_words,
|
||||
)
|
||||
llm_response.prompt = prompt
|
||||
llm_response.selected_expressions = selected_expressions
|
||||
@@ -135,10 +140,22 @@ class DefaultReplyer:
|
||||
content, reasoning_content, model_name, tool_call = await self.llm_generate_content(prompt)
|
||||
# logger.debug(f"replyer生成内容: {content}")
|
||||
|
||||
logger.info(f"replyer生成内容: {content}")
|
||||
if global_config.debug.show_replyer_reasoning:
|
||||
logger.info(f"replyer生成推理:\n{reasoning_content}")
|
||||
logger.info(f"replyer生成模型: {model_name}")
|
||||
# 统一输出所有日志信息,使用try-except确保即使某个步骤出错也能输出
|
||||
try:
|
||||
# 1. 输出回复准备日志
|
||||
timing_log_str = f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s" if timing_logs or almost_zero_str else "回复准备: 无计时信息"
|
||||
logger.info(timing_log_str)
|
||||
# 2. 输出Prompt日志
|
||||
if global_config.debug.show_replyer_prompt:
|
||||
logger.info(f"\n{prompt}\n")
|
||||
else:
|
||||
logger.debug(f"\nreplyer_Prompt:{prompt}\n")
|
||||
# 3. 输出模型生成内容和推理日志
|
||||
logger.info(f"模型: [{model_name}][思考等级:{think_level}]生成内容: {content}")
|
||||
if global_config.debug.show_replyer_reasoning and reasoning_content:
|
||||
logger.info(f"模型: [{model_name}][思考等级:{think_level}]生成推理:\n{reasoning_content}")
|
||||
except Exception as e:
|
||||
logger.warning(f"输出日志时出错: {e}")
|
||||
|
||||
llm_response.content = content
|
||||
llm_response.reasoning = reasoning_content
|
||||
@@ -162,6 +179,21 @@ class DefaultReplyer:
|
||||
except Exception as llm_e:
|
||||
# 精简报错信息
|
||||
logger.error(f"LLM 生成失败: {llm_e}")
|
||||
# 即使LLM生成失败,也尝试输出已收集的日志信息
|
||||
try:
|
||||
# 1. 输出回复准备日志
|
||||
timing_log_str = f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s" if timing_logs or almost_zero_str else "回复准备: 无计时信息"
|
||||
logger.info(timing_log_str)
|
||||
# 2. 输出Prompt日志
|
||||
if global_config.debug.show_replyer_prompt:
|
||||
logger.info(f"\n{prompt}\n")
|
||||
else:
|
||||
logger.debug(f"\nreplyer_Prompt:{prompt}\n")
|
||||
# 3. 输出模型生成失败信息
|
||||
logger.info("模型生成失败,无法输出生成内容和推理")
|
||||
except Exception as log_e:
|
||||
logger.warning(f"输出日志时出错: {log_e}")
|
||||
|
||||
return False, llm_response # LLM 调用失败则无法生成回复
|
||||
|
||||
return True, llm_response
|
||||
@@ -228,7 +260,7 @@ class DefaultReplyer:
|
||||
return False, llm_response
|
||||
|
||||
async def build_expression_habits(
|
||||
self, chat_history: str, target: str, reply_reason: str = ""
|
||||
self, chat_history: str, target: str, reply_reason: str = "", think_level: int = 1
|
||||
) -> Tuple[str, List[int]]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""构建表达习惯块
|
||||
@@ -237,6 +269,7 @@ class DefaultReplyer:
|
||||
chat_history: 聊天历史记录
|
||||
target: 目标消息内容
|
||||
reply_reason: planner给出的回复理由
|
||||
think_level: 思考级别,0/1/2
|
||||
|
||||
Returns:
|
||||
str: 表达习惯信息字符串
|
||||
@@ -249,14 +282,19 @@ class DefaultReplyer:
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# 使用模型预测选择表达方式
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason
|
||||
self.chat_stream.stream_id,
|
||||
chat_history,
|
||||
max_num=8,
|
||||
target_message=target,
|
||||
reply_reason=reply_reason,
|
||||
think_level=think_level,
|
||||
)
|
||||
|
||||
if selected_expressions:
|
||||
logger.debug(f"使用处理器选中的{len(selected_expressions)}个表达方式")
|
||||
for expr in selected_expressions:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
style_habits.append(f"当{expr['situation']}时:{expr['style']}")
|
||||
else:
|
||||
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
|
||||
# 不再在replyer中进行随机选择,全部交给处理器处理
|
||||
@@ -272,13 +310,6 @@ class DefaultReplyer:
|
||||
|
||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||
|
||||
async def build_mood_state_prompt(self) -> str:
|
||||
"""构建情绪状态提示"""
|
||||
if not global_config.mood.enable_mood:
|
||||
return ""
|
||||
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
|
||||
return f"你现在的心情是:{mood_state}"
|
||||
|
||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||
"""构建工具信息块
|
||||
|
||||
@@ -459,6 +490,57 @@ class DefaultReplyer:
|
||||
duration = end_time - start_time
|
||||
return name, result, duration
|
||||
|
||||
async def _build_disabled_jargon_explanation(self) -> str:
|
||||
"""当关闭黑话解释时使用的占位协程,避免额外的LLM调用"""
|
||||
return ""
|
||||
|
||||
async def _build_unknown_words_jargon(self, unknown_words: Optional[List[str]], chat_id: str) -> str:
|
||||
"""针对 Planner 提供的未知词语列表执行黑话检索"""
|
||||
if not unknown_words:
|
||||
return ""
|
||||
# 清洗未知词语列表,只保留非空字符串
|
||||
concepts: List[str] = []
|
||||
for item in unknown_words:
|
||||
if isinstance(item, str):
|
||||
s = item.strip()
|
||||
if s:
|
||||
concepts.append(s)
|
||||
if not concepts:
|
||||
return ""
|
||||
try:
|
||||
return await retrieve_concepts_with_jargon(concepts, chat_id)
|
||||
except Exception as e:
|
||||
logger.error(f"未知词语黑话检索失败: {e}")
|
||||
return ""
|
||||
|
||||
async def _build_jargon_explanation(
|
||||
self,
|
||||
chat_id: str,
|
||||
messages_short: List[DatabaseMessages],
|
||||
chat_talking_prompt_short: str,
|
||||
unknown_words: Optional[List[str]],
|
||||
) -> str:
|
||||
"""
|
||||
统一的黑话解释构建函数:
|
||||
- 根据 enable_jargon_explanation / jargon_mode 决定具体策略
|
||||
"""
|
||||
enable_jargon_explanation = getattr(global_config.expression, "enable_jargon_explanation", True)
|
||||
if not enable_jargon_explanation:
|
||||
return ""
|
||||
|
||||
jargon_mode = getattr(global_config.expression, "jargon_mode", "context")
|
||||
|
||||
# planner 模式:仅使用 Planner 的 unknown_words
|
||||
if jargon_mode == "planner":
|
||||
return await self._build_unknown_words_jargon(unknown_words, chat_id)
|
||||
|
||||
# 默认 / context 模式:使用上下文自动匹配黑话
|
||||
try:
|
||||
return await explain_jargon_in_context(chat_id, messages_short, chat_talking_prompt_short)
|
||||
except Exception as e:
|
||||
logger.error(f"上下文黑话解释失败: {e}")
|
||||
return ""
|
||||
|
||||
def build_chat_history_prompts(
|
||||
self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
|
||||
) -> Tuple[str, str]:
|
||||
@@ -606,7 +688,7 @@ class DefaultReplyer:
|
||||
# 获取基础personality
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
# 检查是否需要随机替换为状态(personality 本体)
|
||||
if (
|
||||
global_config.personality.states
|
||||
and global_config.personality.state_probability > 0
|
||||
@@ -643,16 +725,10 @@ class DefaultReplyer:
|
||||
# 判断是否为群聊
|
||||
is_group = stream_type == "group"
|
||||
|
||||
# 使用与 ChatStream.get_stream_id 相同的逻辑生成 chat_id
|
||||
import hashlib
|
||||
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
chat_id = hashlib.md5(key.encode()).hexdigest()
|
||||
# 使用 ChatManager 提供的接口生成 chat_id,避免在此重复实现逻辑
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
return chat_id, prompt_content
|
||||
|
||||
except (ValueError, IndexError):
|
||||
@@ -705,7 +781,9 @@ class DefaultReplyer:
|
||||
chosen_actions: Optional[List[ActionPlannerInfo]] = None,
|
||||
enable_tool: bool = True,
|
||||
reply_time_point: Optional[float] = time.time(),
|
||||
) -> Tuple[str, List[int]]:
|
||||
think_level: int = 1,
|
||||
unknown_words: Optional[List[str]] = None,
|
||||
) -> Tuple[str, List[int], List[str], str]:
|
||||
"""
|
||||
构建回复器上下文
|
||||
|
||||
@@ -751,14 +829,14 @@ class DefaultReplyer:
|
||||
chat_id=chat_id,
|
||||
timestamp=reply_time_point,
|
||||
limit=global_config.chat.max_context_size * 1,
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=reply_time_point,
|
||||
limit=int(global_config.chat.max_context_size * 0.33),
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
person_list_short: List[Person] = []
|
||||
@@ -789,10 +867,16 @@ class DefaultReplyer:
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
# 并行执行八个构建任务(包括黑话解释)
|
||||
# 统一黑话解释构建:根据配置选择上下文或 Planner 模式
|
||||
jargon_coroutine = self._build_jargon_explanation(
|
||||
chat_id, message_list_before_short, chat_talking_prompt_short, unknown_words
|
||||
)
|
||||
|
||||
# 并行执行构建任务(包括黑话解释,可配置关闭)
|
||||
task_results = await asyncio.gather(
|
||||
self._time_and_run_task(
|
||||
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason), "expression_habits"
|
||||
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level),
|
||||
"expression_habits",
|
||||
),
|
||||
self._time_and_run_task(
|
||||
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
||||
@@ -800,17 +884,13 @@ class DefaultReplyer:
|
||||
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
|
||||
self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"),
|
||||
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
|
||||
self._time_and_run_task(self.build_mood_state_prompt(), "mood_state_prompt"),
|
||||
self._time_and_run_task(
|
||||
build_memory_retrieval_prompt(
|
||||
chat_talking_prompt_short, sender, target, self.chat_stream, self.tool_executor
|
||||
chat_talking_prompt_short, sender, target, self.chat_stream, think_level=think_level
|
||||
),
|
||||
"memory_retrieval",
|
||||
),
|
||||
self._time_and_run_task(
|
||||
explain_jargon_in_context(chat_id, message_list_before_short, chat_talking_prompt_short),
|
||||
"jargon_explanation",
|
||||
),
|
||||
self._time_and_run_task(jargon_coroutine, "jargon_explanation"),
|
||||
)
|
||||
|
||||
# 任务名称中英文映射
|
||||
@@ -821,7 +901,6 @@ class DefaultReplyer:
|
||||
"prompt_info": "获取知识",
|
||||
"actions_info": "动作信息",
|
||||
"personality_prompt": "人格信息",
|
||||
"mood_state_prompt": "情绪状态",
|
||||
"memory_retrieval": "记忆检索",
|
||||
"jargon_explanation": "黑话解释",
|
||||
}
|
||||
@@ -839,7 +918,8 @@ class DefaultReplyer:
|
||||
continue
|
||||
|
||||
timing_logs.append(f"{chinese_name}: {duration:.1f}s")
|
||||
logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s")
|
||||
# 不再在这里输出日志,而是返回给调用者统一输出
|
||||
# logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s")
|
||||
|
||||
expression_habits_block, selected_expressions = results_dict["expression_habits"]
|
||||
expression_habits_block: str
|
||||
@@ -851,14 +931,8 @@ class DefaultReplyer:
|
||||
personality_prompt: str = results_dict["personality_prompt"]
|
||||
memory_retrieval: str = results_dict["memory_retrieval"]
|
||||
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
||||
mood_state_prompt: str = results_dict["mood_state_prompt"]
|
||||
jargon_explanation: str = results_dict.get("jargon_explanation") or ""
|
||||
|
||||
# 从 chosen_actions 中提取 planner 的整体思考理由
|
||||
planner_reasoning = ""
|
||||
if global_config.chat.include_planner_reasoning and reply_reason:
|
||||
# 如果没有 chosen_actions,使用 reply_reason 作为备选
|
||||
planner_reasoning = f"你的想法是:{reply_reason}"
|
||||
planner_reasoning = f"你的想法是:{reply_reason}"
|
||||
|
||||
if extra_info:
|
||||
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
|
||||
@@ -893,14 +967,31 @@ class DefaultReplyer:
|
||||
chat_prompt_content = self.get_chat_prompt_for_chat(chat_id)
|
||||
chat_prompt_block = f"{chat_prompt_content}\n" if chat_prompt_content else ""
|
||||
|
||||
# 固定使用群聊回复模板
|
||||
# 根据think_level选择不同的回复模板
|
||||
# think_level=0: 轻量回复(简短平淡)
|
||||
# think_level=1: 中等回复(日常口语化)
|
||||
if think_level == 0:
|
||||
prompt_name = "replyer_prompt_0"
|
||||
else: # think_level == 1 或默认
|
||||
prompt_name = "replyer_prompt"
|
||||
|
||||
# 根据配置构建最终的 reply_style:支持 multiple_reply_style 按概率随机替换
|
||||
reply_style = global_config.personality.reply_style
|
||||
multi_styles = getattr(global_config.personality, "multiple_reply_style", None) or []
|
||||
multi_prob = getattr(global_config.personality, "multiple_probability", 0.0) or 0.0
|
||||
if multi_styles and multi_prob > 0 and random.random() < multi_prob:
|
||||
try:
|
||||
reply_style = random.choice(list(multi_styles))
|
||||
except Exception:
|
||||
# 兜底:即使 multiple_reply_style 配置异常也不影响正常回复
|
||||
reply_style = global_config.personality.reply_style
|
||||
|
||||
return await global_prompt_manager.format_prompt(
|
||||
"replyer_prompt",
|
||||
prompt_name,
|
||||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info,
|
||||
bot_name=global_config.bot.nickname,
|
||||
knowledge_prompt=prompt_info,
|
||||
mood_state=mood_state_prompt,
|
||||
# relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
jargon_explanation=jargon_explanation,
|
||||
@@ -910,13 +1001,13 @@ class DefaultReplyer:
|
||||
dialogue_prompt=dialogue_prompt,
|
||||
time_block=time_block,
|
||||
reply_target_block=reply_target_block,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
reply_style=reply_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
memory_retrieval=memory_retrieval,
|
||||
chat_prompt=chat_prompt_block,
|
||||
planner_reasoning=planner_reasoning,
|
||||
), selected_expressions
|
||||
), selected_expressions, timing_logs, almost_zero_str
|
||||
|
||||
async def build_prompt_rewrite_context(
|
||||
self,
|
||||
@@ -926,8 +1017,6 @@ class DefaultReplyer:
|
||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
|
||||
@@ -941,7 +1030,7 @@ class DefaultReplyer:
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
chat_talking_prompt_half = build_readable_messages(
|
||||
message_list_before_now_half,
|
||||
@@ -967,61 +1056,42 @@ class DefaultReplyer:
|
||||
|
||||
if sender and target:
|
||||
# 使用预先分析的内容类型结果
|
||||
if is_group_chat:
|
||||
if sender:
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送的图片:{pic_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}说的:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
if sender:
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送的图片:{pic_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
else:
|
||||
reply_target_block = "现在,你想要在群里发言或者回复消息。"
|
||||
else: # private chat
|
||||
if sender:
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
)
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,针对这条消息回复。"
|
||||
else:
|
||||
reply_target_block = "现在,你想要回复。"
|
||||
# 只包含文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}说的:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
else:
|
||||
reply_target_block = "现在,你想要在群里发言或者回复消息。"
|
||||
else:
|
||||
reply_target_block = ""
|
||||
|
||||
if is_group_chat:
|
||||
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
||||
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
|
||||
else:
|
||||
chat_target_name = "对方"
|
||||
if self.chat_target_info:
|
||||
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
|
||||
chat_target_1 = await global_prompt_manager.format_prompt(
|
||||
"chat_target_private1", sender_name=chat_target_name
|
||||
)
|
||||
chat_target_2 = await global_prompt_manager.format_prompt(
|
||||
"chat_target_private2", sender_name=chat_target_name
|
||||
)
|
||||
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
||||
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
|
||||
|
||||
template_name = "default_expressor_prompt"
|
||||
|
||||
# 根据配置构建最终的 reply_style:支持 multiple_reply_style 按概率随机替换
|
||||
reply_style = global_config.personality.reply_style
|
||||
multi_styles = getattr(global_config.personality, "multiple_reply_style", None) or []
|
||||
multi_prob = getattr(global_config.personality, "multiple_probability", 0.0) or 0.0
|
||||
if multi_styles and multi_prob > 0 and random.random() < multi_prob:
|
||||
try:
|
||||
reply_style = random.choice(list(multi_styles))
|
||||
except Exception:
|
||||
reply_style = global_config.personality.reply_style
|
||||
|
||||
return await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
expression_habits_block=expression_habits_block,
|
||||
@@ -1034,7 +1104,7 @@ class DefaultReplyer:
|
||||
reply_target_block=reply_target_block,
|
||||
raw_reply=raw_reply,
|
||||
reason=reason,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
reply_style=reply_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
)
|
||||
@@ -1078,10 +1148,11 @@ class DefaultReplyer:
|
||||
# 直接使用已初始化的模型实例
|
||||
# logger.info(f"\n{prompt}\n")
|
||||
|
||||
if global_config.debug.show_replyer_prompt:
|
||||
logger.info(f"\n{prompt}\n")
|
||||
else:
|
||||
logger.debug(f"\nreplyer_Prompt:{prompt}\n")
|
||||
# 不再在这里输出日志,而是返回给调用者统一输出
|
||||
# if global_config.debug.show_replyer_prompt:
|
||||
# logger.info(f"\n{prompt}\n")
|
||||
# else:
|
||||
# logger.debug(f"\nreplyer_Prompt:{prompt}\n")
|
||||
|
||||
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
|
||||
prompt
|
||||
@@ -1090,7 +1161,7 @@ class DefaultReplyer:
|
||||
# 移除 content 前后的换行符和空格
|
||||
content = content.strip()
|
||||
|
||||
logger.info(f"使用 {model_name} 生成回复内容: {content}")
|
||||
# logger.info(f"使用 {model_name} 生成回复内容: {content}")
|
||||
return content, reasoning_content, model_name, tool_calls
|
||||
|
||||
async def get_prompt_info(self, message: str, sender: str, target: str):
|
||||
|
||||
@@ -23,9 +23,8 @@ from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
replace_user_references,
|
||||
)
|
||||
from src.express.expression_selector import expression_selector
|
||||
from src.bw_learner.expression_selector import expression_selector
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
|
||||
@@ -34,13 +33,13 @@ from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt
|
||||
from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
|
||||
from src.chat.replyer.prompt.replyer_private_prompt import init_replyer_private_prompt
|
||||
from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
|
||||
from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt
|
||||
from src.jargon.jargon_explainer import explain_jargon_in_context
|
||||
from src.bw_learner.jargon_explainer import explain_jargon_in_context
|
||||
|
||||
init_lpmm_prompt()
|
||||
init_replyer_prompt()
|
||||
init_replyer_private_prompt()
|
||||
init_rewrite_prompt()
|
||||
init_memory_retrieval_prompt()
|
||||
|
||||
@@ -72,9 +71,11 @@ class PrivateReplyer:
|
||||
chosen_actions: Optional[List[ActionPlannerInfo]] = None,
|
||||
enable_tool: bool = True,
|
||||
from_plugin: bool = True,
|
||||
think_level: int = 1,
|
||||
stream_id: Optional[str] = None,
|
||||
reply_message: Optional[DatabaseMessages] = None,
|
||||
reply_time_point: Optional[float] = time.time(),
|
||||
unknown_words: Optional[List[str]] = None,
|
||||
) -> Tuple[bool, LLMGenerationDataModel]:
|
||||
# sourcery skip: merge-nested-ifs
|
||||
"""
|
||||
@@ -271,7 +272,7 @@ class PrivateReplyer:
|
||||
logger.debug(f"使用处理器选中的{len(selected_expressions)}个表达方式")
|
||||
for expr in selected_expressions:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
style_habits.append(f"当{expr['situation']}时:{expr['style']}")
|
||||
else:
|
||||
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
|
||||
# 不再在replyer中进行随机选择,全部交给处理器处理
|
||||
@@ -287,13 +288,6 @@ class PrivateReplyer:
|
||||
|
||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||
|
||||
async def build_mood_state_prompt(self) -> str:
|
||||
"""构建情绪状态提示"""
|
||||
if not global_config.mood.enable_mood:
|
||||
return ""
|
||||
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
|
||||
return f"你现在的心情是:{mood_state}"
|
||||
|
||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||
"""构建工具信息块
|
||||
|
||||
@@ -474,6 +468,10 @@ class PrivateReplyer:
|
||||
duration = end_time - start_time
|
||||
return name, result, duration
|
||||
|
||||
async def _build_disabled_jargon_explanation(self) -> str:
|
||||
"""当关闭黑话解释时使用的占位协程,避免额外的LLM调用"""
|
||||
return ""
|
||||
|
||||
async def build_actions_prompt(
|
||||
self, available_actions: Dict[str, ActionInfo], chosen_actions_info: Optional[List[ActionPlannerInfo]] = None
|
||||
) -> str:
|
||||
@@ -557,16 +555,10 @@ class PrivateReplyer:
|
||||
# 判断是否为群聊
|
||||
is_group = stream_type == "group"
|
||||
|
||||
# 使用与 ChatStream.get_stream_id 相同的逻辑生成 chat_id
|
||||
import hashlib
|
||||
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
chat_id = hashlib.md5(key.encode()).hexdigest()
|
||||
# 使用 ChatManager 提供的接口生成 chat_id,避免在此重复实现逻辑
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
return chat_id, prompt_content
|
||||
|
||||
except (ValueError, IndexError):
|
||||
@@ -663,7 +655,7 @@ class PrivateReplyer:
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.chat.max_context_size,
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
dialogue_prompt = build_readable_messages(
|
||||
@@ -678,7 +670,7 @@ class PrivateReplyer:
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.33),
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
person_list_short: List[Person] = []
|
||||
@@ -709,7 +701,14 @@ class PrivateReplyer:
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
# 并行执行九个构建任务(包括黑话解释)
|
||||
# 根据配置决定是否启用黑话解释
|
||||
enable_jargon_explanation = getattr(global_config.expression, "enable_jargon_explanation", True)
|
||||
if enable_jargon_explanation:
|
||||
jargon_coroutine = explain_jargon_in_context(chat_id, message_list_before_short, chat_talking_prompt_short)
|
||||
else:
|
||||
jargon_coroutine = self._build_disabled_jargon_explanation()
|
||||
|
||||
# 并行执行九个构建任务(包括黑话解释,可配置关闭)
|
||||
task_results = await asyncio.gather(
|
||||
self._time_and_run_task(
|
||||
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason), "expression_habits"
|
||||
@@ -721,17 +720,13 @@ class PrivateReplyer:
|
||||
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
|
||||
self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"),
|
||||
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
|
||||
self._time_and_run_task(self.build_mood_state_prompt(), "mood_state_prompt"),
|
||||
self._time_and_run_task(
|
||||
build_memory_retrieval_prompt(
|
||||
chat_talking_prompt_short, sender, target, self.chat_stream, self.tool_executor
|
||||
),
|
||||
"memory_retrieval",
|
||||
),
|
||||
self._time_and_run_task(
|
||||
explain_jargon_in_context(chat_id, message_list_before_short, chat_talking_prompt_short),
|
||||
"jargon_explanation",
|
||||
),
|
||||
self._time_and_run_task(jargon_coroutine, "jargon_explanation"),
|
||||
)
|
||||
|
||||
# 任务名称中英文映射
|
||||
@@ -742,7 +737,6 @@ class PrivateReplyer:
|
||||
"prompt_info": "获取知识",
|
||||
"actions_info": "动作信息",
|
||||
"personality_prompt": "人格信息",
|
||||
"mood_state_prompt": "情绪状态",
|
||||
"memory_retrieval": "记忆检索",
|
||||
"jargon_explanation": "黑话解释",
|
||||
}
|
||||
@@ -770,16 +764,10 @@ class PrivateReplyer:
|
||||
prompt_info: str = results_dict["prompt_info"] # 直接使用格式化后的结果
|
||||
actions_info: str = results_dict["actions_info"]
|
||||
personality_prompt: str = results_dict["personality_prompt"]
|
||||
mood_state_prompt: str = results_dict["mood_state_prompt"]
|
||||
memory_retrieval: str = results_dict["memory_retrieval"]
|
||||
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
||||
jargon_explanation: str = results_dict.get("jargon_explanation") or ""
|
||||
|
||||
# 从 chosen_actions 中提取 planner 的整体思考理由
|
||||
planner_reasoning = ""
|
||||
if global_config.chat.include_planner_reasoning and reply_reason:
|
||||
# 如果没有 chosen_actions,使用 reply_reason 作为备选
|
||||
planner_reasoning = f"你的想法是:{reply_reason}"
|
||||
planner_reasoning = f"你的想法是:{reply_reason}"
|
||||
|
||||
if extra_info:
|
||||
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
|
||||
@@ -814,7 +802,6 @@ class PrivateReplyer:
|
||||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info,
|
||||
knowledge_prompt=prompt_info,
|
||||
mood_state=mood_state_prompt,
|
||||
relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=personality_prompt,
|
||||
@@ -837,7 +824,6 @@ class PrivateReplyer:
|
||||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info,
|
||||
knowledge_prompt=prompt_info,
|
||||
mood_state=mood_state_prompt,
|
||||
relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=personality_prompt,
|
||||
@@ -878,7 +864,7 @@ class PrivateReplyer:
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||
filter_no_read_command=True,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
chat_talking_prompt_half = build_readable_messages(
|
||||
message_list_before_now_half,
|
||||
@@ -904,59 +890,30 @@ class PrivateReplyer:
|
||||
)
|
||||
|
||||
if sender and target:
|
||||
# 使用预先分析的内容类型结果
|
||||
if is_group_chat:
|
||||
if sender:
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送的图片:{pic_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}说的:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
if sender:
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
)
|
||||
else:
|
||||
reply_target_block = "现在,你想要在群里发言或者回复消息。"
|
||||
else: # private chat
|
||||
if sender:
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
)
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,针对这条消息回复。"
|
||||
else:
|
||||
reply_target_block = "现在,你想要回复。"
|
||||
# 只包含文字
|
||||
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,针对这条消息回复。"
|
||||
else:
|
||||
reply_target_block = "现在,你想要回复。"
|
||||
else:
|
||||
reply_target_block = ""
|
||||
|
||||
if is_group_chat:
|
||||
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
||||
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
|
||||
else:
|
||||
chat_target_name = "对方"
|
||||
if self.chat_target_info:
|
||||
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
|
||||
chat_target_1 = await global_prompt_manager.format_prompt(
|
||||
"chat_target_private1", sender_name=chat_target_name
|
||||
)
|
||||
chat_target_2 = await global_prompt_manager.format_prompt(
|
||||
"chat_target_private2", sender_name=chat_target_name
|
||||
)
|
||||
chat_target_name = "对方"
|
||||
if self.chat_target_info:
|
||||
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
|
||||
chat_target_1 = await global_prompt_manager.format_prompt("chat_target_private1", sender_name=chat_target_name)
|
||||
chat_target_2 = await global_prompt_manager.format_prompt("chat_target_private2", sender_name=chat_target_name)
|
||||
|
||||
template_name = "default_expressor_prompt"
|
||||
|
||||
|
||||
41
src/chat/replyer/prompt/replyer_private_prompt.py
Normal file
41
src/chat/replyer/prompt/replyer_private_prompt.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
|
||||
|
||||
def init_replyer_private_prompt():
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
||||
|
||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||
{time_block}
|
||||
{dialogue_prompt}
|
||||
|
||||
{reply_target_block}。
|
||||
{planner_reasoning}
|
||||
{identity}
|
||||
{chat_prompt}你正在和{sender_name}聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""",
|
||||
"private_replyer_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
||||
|
||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||
{time_block}
|
||||
{dialogue_prompt}
|
||||
|
||||
你现在想补充说明你刚刚自己的发言内容:{target},原因是{reason}
|
||||
请你根据聊天内容,组织一条新回复。注意,{target} 是刚刚你自己的发言,你要在这基础上进一步发言,请按照你自己的角度来继续进行回复。注意保持上下文的连贯性。
|
||||
{identity}
|
||||
{chat_prompt}尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括冒号和引号,括号,表情包,at或 @等 )。
|
||||
""",
|
||||
"private_replyer_self_prompt",
|
||||
)
|
||||
@@ -3,8 +3,27 @@ from src.chat.utils.prompt_builder import Prompt
|
||||
|
||||
|
||||
def init_replyer_prompt():
|
||||
Prompt("正在群里聊天", "chat_target_group2")
|
||||
Prompt("和{sender_name}聊天", "chat_target_private2")
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
||||
|
||||
你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片
|
||||
其中标注 {bot_name}(你) 的发言是你自己的发言,请注意区分:
|
||||
{time_block}
|
||||
{dialogue_prompt}
|
||||
|
||||
{reply_target_block}。
|
||||
{planner_reasoning}
|
||||
{identity}
|
||||
{chat_prompt}你正在群里聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,
|
||||
尽量简短一些。{keywords_reaction_prompt}
|
||||
请注意把握聊天内容,不要回复的太有条理。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,表情包,at或 @等 ),只输出发言内容就好。
|
||||
最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。
|
||||
现在,你说:""",
|
||||
"replyer_prompt_0",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
@@ -18,49 +37,12 @@ def init_replyer_prompt():
|
||||
{reply_target_block}。
|
||||
{planner_reasoning}
|
||||
{identity}
|
||||
{chat_prompt}你正在群里聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,{mood_state}
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{chat_prompt}你正在群里聊天,现在请你读读之前的聊天记录,把握当前的话题,然后给出日常且简短的回复。
|
||||
最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。
|
||||
{keywords_reaction_prompt}
|
||||
请注意把握聊天内容。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出一句回复内容就好。
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。
|
||||
请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,at或 @等 ),只输出发言内容就好。
|
||||
现在,你说:""",
|
||||
"replyer_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
||||
|
||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||
{time_block}
|
||||
{dialogue_prompt}
|
||||
|
||||
{reply_target_block}。
|
||||
{planner_reasoning}
|
||||
{identity}
|
||||
{chat_prompt}你正在和{sender_name}聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,{mood_state}
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""",
|
||||
"private_replyer_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
||||
|
||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||
{time_block}
|
||||
{dialogue_prompt}
|
||||
|
||||
你现在想补充说明你刚刚自己的发言内容:{target},原因是{reason}
|
||||
请你根据聊天内容,组织一条新回复。注意,{target} 是刚刚你自己的发言,你要在这基础上进一步发言,请按照你自己的角度来继续进行回复。注意保持上下文的连贯性。{mood_state}
|
||||
{identity}
|
||||
{chat_prompt}尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括冒号和引号,括号,表情包,at或 @等 )。
|
||||
""",
|
||||
"private_replyer_self_prompt",
|
||||
)
|
||||
|
||||
@@ -120,7 +120,7 @@ def get_raw_msg_by_timestamp_with_chat(
|
||||
limit_mode: str = "latest",
|
||||
filter_bot=False,
|
||||
filter_command=False,
|
||||
filter_no_read_command=False,
|
||||
filter_intercept_message_level: Optional[int] = None,
|
||||
) -> List[DatabaseMessages]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
@@ -138,7 +138,7 @@ def get_raw_msg_by_timestamp_with_chat(
|
||||
limit_mode=limit_mode,
|
||||
filter_bot=filter_bot,
|
||||
filter_command=filter_command,
|
||||
filter_no_read_command=filter_no_read_command,
|
||||
filter_intercept_message_level=filter_intercept_message_level,
|
||||
)
|
||||
|
||||
|
||||
@@ -150,7 +150,7 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
limit_mode: str = "latest",
|
||||
filter_bot=False,
|
||||
filter_command=False,
|
||||
filter_no_read_command=False,
|
||||
filter_intercept_message_level: Optional[int] = None,
|
||||
) -> List[DatabaseMessages]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
@@ -167,7 +167,7 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
limit_mode=limit_mode,
|
||||
filter_bot=filter_bot,
|
||||
filter_command=filter_command,
|
||||
filter_no_read_command=filter_no_read_command,
|
||||
filter_intercept_message_level=filter_intercept_message_level,
|
||||
)
|
||||
|
||||
|
||||
@@ -303,7 +303,7 @@ def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Datab
|
||||
|
||||
|
||||
def get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id: str, timestamp: float, limit: int = 0, filter_no_read_command: bool = False
|
||||
chat_id: str, timestamp: float, limit: int = 0, filter_intercept_message_level: Optional[int] = None
|
||||
) -> List[DatabaseMessages]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
@@ -311,7 +311,10 @@ def get_raw_msg_before_timestamp_with_chat(
|
||||
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
|
||||
sort_order = [("time", 1)]
|
||||
return find_messages(
|
||||
message_filter=filter_query, sort=sort_order, limit=limit, filter_no_read_command=filter_no_read_command
|
||||
message_filter=filter_query,
|
||||
sort=sort_order,
|
||||
limit=limit,
|
||||
filter_intercept_message_level=filter_intercept_message_level,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Any, Dict, Tuple, List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import OnlineTime, LLMUsage, Messages
|
||||
from src.common.database.database_model import OnlineTime, LLMUsage, Messages, ActionRecords
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.manager.local_store_manager import local_storage
|
||||
from src.config.config import global_config
|
||||
@@ -505,13 +505,6 @@ class StatisticOutputTask(AsyncTask):
|
||||
for period_key, _ in collect_period
|
||||
}
|
||||
|
||||
# 获取bot的QQ账号
|
||||
bot_qq_account = (
|
||||
str(global_config.bot.qq_account)
|
||||
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
|
||||
else ""
|
||||
)
|
||||
|
||||
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||
message_time_ts = message.time # This is a float timestamp
|
||||
@@ -537,7 +530,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
if not chat_id: # Should not happen if above logic is correct
|
||||
continue
|
||||
|
||||
# Update name_mapping
|
||||
# Update name_mapping(仅用于展示聊天名称)
|
||||
try:
|
||||
if chat_id in self.name_mapping:
|
||||
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
|
||||
@@ -549,19 +542,30 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 重置为正确的格式
|
||||
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||
|
||||
# 检查是否是bot发送的消息(回复)
|
||||
is_bot_reply = False
|
||||
if bot_qq_account and message.user_id == bot_qq_account:
|
||||
is_bot_reply = True
|
||||
|
||||
for idx, (_, period_start_dt) in enumerate(collect_period):
|
||||
if message_time_ts >= period_start_dt.timestamp():
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
stats[period_key][TOTAL_MSG_CNT] += 1
|
||||
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
|
||||
if is_bot_reply:
|
||||
stats[period_key][TOTAL_REPLY_CNT] += 1
|
||||
break
|
||||
|
||||
# 使用 ActionRecords 中的 reply 动作次数作为回复数基准
|
||||
try:
|
||||
action_query_start_timestamp = collect_period[-1][1].timestamp()
|
||||
for action in ActionRecords.select().where(ActionRecords.time >= action_query_start_timestamp): # type: ignore
|
||||
# 仅统计已完成的 reply 动作
|
||||
if action.action_name != "reply" or not action.action_done:
|
||||
continue
|
||||
|
||||
action_time_ts = action.time
|
||||
for idx, (_, period_start_dt) in enumerate(collect_period):
|
||||
if action_time_ts >= period_start_dt.timestamp():
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
stats[period_key][TOTAL_REPLY_CNT] += 1
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"统计 reply 动作次数失败,将回复数视为 0,错误信息:{e}")
|
||||
|
||||
return stats
|
||||
|
||||
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
||||
@@ -742,7 +746,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
|
||||
|
||||
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
|
||||
|
||||
|
||||
output = [
|
||||
"按模型分类统计:",
|
||||
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
|
||||
@@ -755,11 +759,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
cost = stats[COST_BY_MODEL][model_name]
|
||||
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
|
||||
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
|
||||
|
||||
|
||||
# 计算每次回复平均值
|
||||
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
|
||||
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
|
||||
|
||||
|
||||
# 格式化大数字
|
||||
formatted_count = _format_large_number(count)
|
||||
formatted_in_tokens = _format_large_number(in_tokens)
|
||||
@@ -767,7 +771,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
formatted_tokens = _format_large_number(tokens)
|
||||
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
|
||||
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
|
||||
|
||||
|
||||
output.append(
|
||||
data_fmt.format(
|
||||
name,
|
||||
@@ -796,7 +800,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
|
||||
|
||||
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
|
||||
|
||||
|
||||
output = [
|
||||
"按模块分类统计:",
|
||||
" 模块名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
|
||||
@@ -809,11 +813,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
cost = stats[COST_BY_MODULE][module_name]
|
||||
avg_time_cost = stats[AVG_TIME_COST_BY_MODULE][module_name]
|
||||
std_time_cost = stats[STD_TIME_COST_BY_MODULE][module_name]
|
||||
|
||||
|
||||
# 计算每次回复平均值
|
||||
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
|
||||
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
|
||||
|
||||
|
||||
# 格式化大数字
|
||||
formatted_count = _format_large_number(count)
|
||||
formatted_in_tokens = _format_large_number(in_tokens)
|
||||
@@ -821,7 +825,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
formatted_tokens = _format_large_number(tokens)
|
||||
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
|
||||
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
|
||||
|
||||
|
||||
output.append(
|
||||
data_fmt.format(
|
||||
name,
|
||||
|
||||
@@ -4,6 +4,8 @@ import time
|
||||
import jieba
|
||||
import json
|
||||
import ast
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Optional, Tuple, List, TYPE_CHECKING
|
||||
|
||||
@@ -196,21 +198,54 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
||||
List[str]: 分割和合并后的句子列表
|
||||
"""
|
||||
# 预处理:处理多余的换行符
|
||||
# 1. 将连续的换行符替换为单个换行符
|
||||
# 1. 将连续的换行符替换为单个换行符(保留换行符用于分割)
|
||||
text = re.sub(r"\n\s*\n+", "\n", text)
|
||||
# 2. 处理换行符和其他分隔符的组合
|
||||
text = re.sub(r"\n\s*([,,。;\s])", r"\1", text)
|
||||
text = re.sub(r"([,,。;\s])\s*\n", r"\1", text)
|
||||
# 2. 处理换行符和其他分隔符的组合(保留换行符,删除其他分隔符)
|
||||
text = re.sub(r"\n\s*([,,。;\s])", r"\n\1", text)
|
||||
text = re.sub(r"([,,。;\s])\s*\n", r"\1\n", text)
|
||||
|
||||
# 处理两个汉字中间的换行符
|
||||
text = re.sub(r"([\u4e00-\u9fff])\n([\u4e00-\u9fff])", r"\1。\2", text)
|
||||
# 处理两个汉字中间的换行符(保留换行符,不替换为句号,让换行符强制分割)
|
||||
# text = re.sub(r"([\u4e00-\u9fff])\n([\u4e00-\u9fff])", r"\1。\2", text) # 注释掉,保留换行符用于分割
|
||||
|
||||
len_text = len(text)
|
||||
if len_text < 3:
|
||||
return list(text) if random.random() < 0.01 else [text]
|
||||
|
||||
# 定义分隔符
|
||||
separators = {",", ",", " ", "。", ";"}
|
||||
# 先标记哪些位置位于成对引号内部,避免在引号内部进行句子分割
|
||||
# 支持的引号包括:中英文单/双引号和常见中文书名号/引号
|
||||
quote_chars = {
|
||||
'"',
|
||||
"'",
|
||||
"“",
|
||||
"”",
|
||||
"‘",
|
||||
"’",
|
||||
"「",
|
||||
"」",
|
||||
"『",
|
||||
"』",
|
||||
}
|
||||
inside_quote = [False] * len_text
|
||||
in_quote = False
|
||||
current_quote_char = ""
|
||||
for idx, ch in enumerate(text):
|
||||
if ch in quote_chars:
|
||||
# 遇到引号时切换状态(英文引号本身开闭相同,用同一个字符表示)
|
||||
if not in_quote:
|
||||
in_quote = True
|
||||
current_quote_char = ch
|
||||
inside_quote[idx] = False
|
||||
else:
|
||||
# 只有遇到同一类引号才视为关闭
|
||||
if ch == current_quote_char or ch in {'"', "'"} and current_quote_char in {'"', "'"}:
|
||||
in_quote = False
|
||||
current_quote_char = ""
|
||||
inside_quote[idx] = False
|
||||
else:
|
||||
inside_quote[idx] = in_quote
|
||||
|
||||
# 定义分隔符(包含换行符)
|
||||
separators = {",", ",", " ", "。", ";", "\n"}
|
||||
segments = []
|
||||
current_segment = ""
|
||||
|
||||
@@ -219,24 +254,42 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
||||
while i < len(text):
|
||||
char = text[i]
|
||||
if char in separators:
|
||||
# 检查分割条件:如果空格左右都是英文字母、数字,或数字和英文之间,则不分割(仅对空格应用此规则)
|
||||
can_split = True
|
||||
if 0 < i < len(text) - 1:
|
||||
prev_char = text[i - 1]
|
||||
next_char = text[i + 1]
|
||||
# 只对空格应用"不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格"规则
|
||||
if char == " ":
|
||||
prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char)
|
||||
next_is_alnum = next_char.isdigit() or is_english_letter(next_char)
|
||||
if prev_is_alnum and next_is_alnum:
|
||||
can_split = False
|
||||
# 引号内部一律不作为分割点(包括换行)
|
||||
if inside_quote[i]:
|
||||
can_split = False
|
||||
else:
|
||||
# 换行符在不在引号内时都强制分割
|
||||
if char == "\n":
|
||||
can_split = True
|
||||
else:
|
||||
# 检查分割条件
|
||||
can_split = True
|
||||
# 检查分隔符左右是否有冒号(中英文),如果有则不分割
|
||||
if i > 0:
|
||||
prev_char = text[i - 1]
|
||||
if prev_char in {":", ":"}:
|
||||
can_split = False
|
||||
if i < len(text) - 1:
|
||||
next_char = text[i + 1]
|
||||
if next_char in {":", ":"}:
|
||||
can_split = False
|
||||
|
||||
# 如果左右没有冒号,再检查空格的特殊情况
|
||||
if can_split and char == " " and i > 0 and i < len(text) - 1:
|
||||
prev_char = text[i - 1]
|
||||
next_char = text[i + 1]
|
||||
# 不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格
|
||||
prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char)
|
||||
next_is_alnum = next_char.isdigit() or is_english_letter(next_char)
|
||||
if prev_is_alnum and next_is_alnum:
|
||||
can_split = False
|
||||
|
||||
if can_split:
|
||||
# 只有当当前段不为空时才添加
|
||||
if current_segment:
|
||||
segments.append((current_segment, char))
|
||||
# 如果当前段为空,但分隔符是空格,则也添加一个空段(保留空格)
|
||||
elif char == " ":
|
||||
# 如果当前段为空,但分隔符是空格或换行符,则也添加一个空段(保留分隔符)
|
||||
elif char in {" ", "\n"}:
|
||||
segments.append(("", char))
|
||||
current_segment = ""
|
||||
else:
|
||||
@@ -641,6 +694,42 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
|
||||
return is_group_chat, chat_target_info
|
||||
|
||||
|
||||
def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> None:
|
||||
"""
|
||||
临时记录replyer动作被选择的信息(仅群聊)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
reason: 选择理由
|
||||
think_level: 思考深度等级
|
||||
"""
|
||||
try:
|
||||
# 确保data/temp目录存在
|
||||
temp_dir = "data/temp"
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
# 创建记录数据
|
||||
record_data = {
|
||||
"chat_id": chat_id,
|
||||
"reason": reason,
|
||||
"think_level": think_level,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
# 生成文件名(使用时间戳避免冲突)
|
||||
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
filename = f"replyer_action_{timestamp_str}.json"
|
||||
filepath = os.path.join(temp_dir, filename)
|
||||
|
||||
# 写入文件
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(record_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.debug(f"已记录replyer动作选择: chat_id={chat_id}, think_level={think_level}")
|
||||
except Exception as e:
|
||||
logger.warning(f"记录replyer动作选择失败: {e}")
|
||||
|
||||
|
||||
def assign_message_ids(messages: List[DatabaseMessages]) -> List[Tuple[str, DatabaseMessages]]:
|
||||
"""
|
||||
为消息列表中的每个消息分配唯一的简短随机ID
|
||||
|
||||
@@ -130,12 +130,10 @@ class ImageManager:
|
||||
try:
|
||||
# 清理Images表中type为emoji的记录
|
||||
deleted_images = Images.delete().where(Images.type == "emoji").execute()
|
||||
|
||||
|
||||
# 清理ImageDescriptions表中type为emoji的记录
|
||||
deleted_descriptions = (
|
||||
ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
|
||||
)
|
||||
|
||||
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
|
||||
|
||||
total_deleted = deleted_images + deleted_descriptions
|
||||
if total_deleted > 0:
|
||||
logger.info(
|
||||
@@ -166,7 +164,7 @@ class ImageManager:
|
||||
|
||||
async def _save_emoji_file_if_needed(self, image_base64: str, image_hash: str, image_format: str) -> None:
|
||||
"""如果启用了steal_emoji且表情包未注册,保存文件到data/emoji目录
|
||||
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
image_hash: 图片的MD5哈希值
|
||||
@@ -174,7 +172,7 @@ class ImageManager:
|
||||
"""
|
||||
if not global_config.emoji.steal_emoji:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
@@ -236,12 +234,16 @@ class ImageManager:
|
||||
# 优先使用情感标签,如果没有则使用详细描述
|
||||
result_text = ""
|
||||
if cache_record.emotion_tags:
|
||||
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...")
|
||||
logger.info(
|
||||
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
|
||||
)
|
||||
result_text = f"[表情包:{cache_record.emotion_tags}]"
|
||||
elif cache_record.description:
|
||||
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...")
|
||||
logger.info(
|
||||
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
|
||||
)
|
||||
result_text = f"[表情包:{cache_record.description}]"
|
||||
|
||||
|
||||
# 即使缓存命中,如果启用了steal_emoji,也检查是否需要保存文件
|
||||
if result_text:
|
||||
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)
|
||||
|
||||
@@ -77,7 +77,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
is_emoji: bool = False,
|
||||
is_picid: bool = False,
|
||||
is_command: bool = False,
|
||||
is_no_read_command: bool = False,
|
||||
intercept_message_level: int = 0,
|
||||
is_notify: bool = False,
|
||||
selected_expressions: Optional[str] = None,
|
||||
user_id: str = "",
|
||||
@@ -120,7 +120,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
self.is_emoji = is_emoji
|
||||
self.is_picid = is_picid
|
||||
self.is_command = is_command
|
||||
self.is_no_read_command = is_no_read_command
|
||||
self.intercept_message_level = intercept_message_level
|
||||
self.is_notify = is_notify
|
||||
|
||||
self.selected_expressions = selected_expressions
|
||||
@@ -188,7 +188,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
"is_emoji": self.is_emoji,
|
||||
"is_picid": self.is_picid,
|
||||
"is_command": self.is_command,
|
||||
"is_no_read_command": self.is_no_read_command,
|
||||
"intercept_message_level": self.intercept_message_level,
|
||||
"is_notify": self.is_notify,
|
||||
"selected_expressions": self.selected_expressions,
|
||||
"user_id": self.user_info.user_id,
|
||||
|
||||
@@ -22,7 +22,7 @@ class MessageAndActionModel(BaseDataModel):
|
||||
is_action_record: bool = field(default=False)
|
||||
action_name: Optional[str] = None
|
||||
is_command: bool = field(default=False)
|
||||
is_no_read_command: bool = field(default=False)
|
||||
intercept_message_level: int = field(default=0)
|
||||
|
||||
@classmethod
|
||||
def from_DatabaseMessages(cls, message: "DatabaseMessages"):
|
||||
@@ -37,7 +37,7 @@ class MessageAndActionModel(BaseDataModel):
|
||||
display_message=message.display_message,
|
||||
chat_info_platform=message.chat_info.platform,
|
||||
is_command=message.is_command,
|
||||
is_no_read_command=getattr(message, "is_no_read_command", False),
|
||||
intercept_message_level=getattr(message, "intercept_message_level", 0),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -170,7 +170,7 @@ class Messages(BaseModel):
|
||||
is_emoji = BooleanField(default=False)
|
||||
is_picid = BooleanField(default=False)
|
||||
is_command = BooleanField(default=False)
|
||||
is_no_read_command = BooleanField(default=False)
|
||||
intercept_message_level = IntegerField(default=0)
|
||||
is_notify = BooleanField(default=False)
|
||||
|
||||
selected_expressions = TextField(null=True)
|
||||
@@ -324,9 +324,9 @@ class Expression(BaseModel):
|
||||
|
||||
# new mode fields
|
||||
context = TextField(null=True)
|
||||
up_content = TextField(null=True)
|
||||
|
||||
content_list = TextField(null=True)
|
||||
style_list = TextField(null=True) # 存储相似的 style,格式与 content_list 相同(JSON 数组)
|
||||
count = IntegerField(default=1)
|
||||
last_active_time = FloatField()
|
||||
chat_id = TextField(index=True)
|
||||
@@ -593,22 +593,41 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
|
||||
db.execute_sql(f"CREATE TABLE {backup_table} AS SELECT * FROM {table_name}")
|
||||
logger.info(f"已创建备份表 '{backup_table}'")
|
||||
|
||||
# 2. 删除原表
|
||||
# 2. 获取原始行数(在删除表之前)
|
||||
original_count = db.execute_sql(f"SELECT COUNT(*) FROM {backup_table}").fetchone()[0]
|
||||
logger.info(f"备份表 '{backup_table}' 包含 {original_count} 行数据")
|
||||
|
||||
# 3. 删除原表
|
||||
db.execute_sql(f"DROP TABLE {table_name}")
|
||||
logger.info(f"已删除原表 '{table_name}'")
|
||||
|
||||
# 3. 重新创建表(使用当前模型定义)
|
||||
# 4. 重新创建表(使用当前模型定义)
|
||||
db.create_tables([model])
|
||||
logger.info(f"已重新创建表 '{table_name}' 使用新的约束")
|
||||
|
||||
# 4. 从备份表恢复数据
|
||||
# 获取字段列表
|
||||
# 5. 从备份表恢复数据
|
||||
# 获取字段列表,排除主键字段(让数据库自动生成新的主键)
|
||||
fields = list(model._meta.fields.keys())
|
||||
fields_str = ", ".join(fields)
|
||||
# Peewee 默认使用 'id' 作为主键字段名
|
||||
# 尝试获取主键字段名,如果获取失败则默认使用 'id'
|
||||
primary_key_name = "id" # 默认值
|
||||
try:
|
||||
if hasattr(model._meta, "primary_key") and model._meta.primary_key:
|
||||
if hasattr(model._meta.primary_key, "name"):
|
||||
primary_key_name = model._meta.primary_key.name
|
||||
elif isinstance(model._meta.primary_key, str):
|
||||
primary_key_name = model._meta.primary_key
|
||||
except Exception:
|
||||
pass # 如果获取失败,使用默认值 'id'
|
||||
|
||||
# 对于需要从 NOT NULL 改为 NULL 的字段,直接复制数据
|
||||
# 对于需要从 NULL 改为 NOT NULL 的字段,需要处理 NULL 值
|
||||
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {fields_str} FROM {backup_table}"
|
||||
# 如果字段列表包含主键,则排除它
|
||||
if primary_key_name in fields:
|
||||
fields_without_pk = [f for f in fields if f != primary_key_name]
|
||||
logger.info(f"排除主键字段 '{primary_key_name}',让数据库自动生成新的主键")
|
||||
else:
|
||||
fields_without_pk = fields
|
||||
|
||||
fields_str = ", ".join(fields_without_pk)
|
||||
|
||||
# 检查是否有字段需要从 NULL 改为 NOT NULL
|
||||
null_to_notnull_fields = [
|
||||
@@ -621,7 +640,7 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
|
||||
|
||||
# 构建更复杂的 SELECT 语句来处理 NULL 值
|
||||
select_fields = []
|
||||
for field_name in fields:
|
||||
for field_name in fields_without_pk:
|
||||
if field_name in null_to_notnull_fields:
|
||||
field_obj = model._meta.fields[field_name]
|
||||
# 根据字段类型设置默认值
|
||||
@@ -642,12 +661,13 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
|
||||
|
||||
select_str = ", ".join(select_fields)
|
||||
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {select_str} FROM {backup_table}"
|
||||
else:
|
||||
# 没有需要处理 NULL 的字段,直接复制数据(排除主键)
|
||||
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {fields_str} FROM {backup_table}"
|
||||
|
||||
db.execute_sql(insert_sql)
|
||||
logger.info(f"已从备份表恢复数据到 '{table_name}'")
|
||||
|
||||
# 5. 验证数据完整性
|
||||
original_count = db.execute_sql(f"SELECT COUNT(*) FROM {backup_table}").fetchone()[0]
|
||||
new_count = db.execute_sql(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]
|
||||
|
||||
if original_count == new_count:
|
||||
|
||||
@@ -20,6 +20,9 @@ PROJECT_ROOT = logger_file.parent.parent.parent.resolve()
|
||||
_file_handler = None
|
||||
_console_handler = None
|
||||
_ws_handler = None
|
||||
# 全局标志,防止重复初始化
|
||||
_logging_initialized = False
|
||||
_cleanup_task_started = False
|
||||
|
||||
|
||||
def get_file_handler():
|
||||
@@ -869,29 +872,41 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
|
||||
return logger
|
||||
|
||||
|
||||
def initialize_logging():
|
||||
def initialize_logging(verbose: bool = True):
|
||||
"""手动初始化日志系统,确保所有logger都使用正确的配置
|
||||
|
||||
在应用程序的早期调用此函数,确保所有模块都使用统一的日志配置
|
||||
|
||||
Args:
|
||||
verbose: 是否输出详细的初始化信息。默认为 True。
|
||||
在 Runner 进程中可以设置为 False 以避免重复的初始化日志。
|
||||
"""
|
||||
global LOG_CONFIG
|
||||
global LOG_CONFIG, _logging_initialized
|
||||
|
||||
# 防止重复初始化(在同一进程内)
|
||||
if _logging_initialized:
|
||||
return
|
||||
|
||||
_logging_initialized = True
|
||||
|
||||
LOG_CONFIG = load_log_config()
|
||||
# print(LOG_CONFIG)
|
||||
configure_third_party_loggers()
|
||||
reconfigure_existing_loggers()
|
||||
|
||||
# 启动日志清理任务
|
||||
start_log_cleanup_task()
|
||||
start_log_cleanup_task(verbose=verbose)
|
||||
|
||||
# 输出初始化信息
|
||||
logger = get_logger("logger")
|
||||
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
||||
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
||||
# 只在 verbose=True 时输出详细的初始化信息
|
||||
if verbose:
|
||||
logger = get_logger("logger")
|
||||
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
||||
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
||||
|
||||
logger.info("日志系统已初始化:")
|
||||
logger.info(f" - 控制台级别: {console_level}")
|
||||
logger.info(f" - 文件级别: {file_level}")
|
||||
logger.info(" - 轮转份数: 30个文件|自动清理: 30天前的日志")
|
||||
logger.info("日志系统已初始化:")
|
||||
logger.info(f" - 控制台级别: {console_level}")
|
||||
logger.info(f" - 文件级别: {file_level}")
|
||||
logger.info(" - 轮转份数: 30个文件|自动清理: 30天前的日志")
|
||||
|
||||
|
||||
def cleanup_old_logs():
|
||||
@@ -924,8 +939,19 @@ def cleanup_old_logs():
|
||||
logger.error(f"清理旧日志文件时出错: {e}")
|
||||
|
||||
|
||||
def start_log_cleanup_task():
|
||||
"""启动日志清理任务"""
|
||||
def start_log_cleanup_task(verbose: bool = True):
|
||||
"""启动日志清理任务
|
||||
|
||||
Args:
|
||||
verbose: 是否输出启动信息。默认为 True。
|
||||
"""
|
||||
global _cleanup_task_started
|
||||
|
||||
# 防止重复启动清理任务
|
||||
if _cleanup_task_started:
|
||||
return
|
||||
|
||||
_cleanup_task_started = True
|
||||
|
||||
def cleanup_task():
|
||||
while True:
|
||||
@@ -935,8 +961,9 @@ def start_log_cleanup_task():
|
||||
cleanup_thread = threading.Thread(target=cleanup_task, daemon=True)
|
||||
cleanup_thread.start()
|
||||
|
||||
logger = get_logger("logger")
|
||||
logger.info("已启动日志清理任务,将自动清理30天前的日志文件(轮转份数限制: 30个文件)")
|
||||
if verbose:
|
||||
logger = get_logger("logger")
|
||||
logger.info("已启动日志清理任务,将自动清理30天前的日志文件(轮转份数限制: 30个文件)")
|
||||
|
||||
|
||||
def shutdown_logging():
|
||||
|
||||
@@ -15,14 +15,18 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
# 检查maim_message版本
|
||||
try:
|
||||
maim_message_version = importlib.metadata.version("maim_message")
|
||||
version_compatible = [int(x) for x in maim_message_version.split(".")] >= [0, 3, 3]
|
||||
version_int = [int(x) for x in maim_message_version.split(".")]
|
||||
version_compatible = version_int >= [0, 3, 3]
|
||||
# Check for API Server feature (>= 0.6.0)
|
||||
has_api_server_feature = version_int >= [0, 6, 0]
|
||||
except (importlib.metadata.PackageNotFoundError, ValueError):
|
||||
version_compatible = False
|
||||
has_api_server_feature = False
|
||||
|
||||
# 读取配置项
|
||||
maim_message_config = global_config.maim_message
|
||||
|
||||
# 设置基本参数
|
||||
# 设置基本参数 (Legacy Server Mode)
|
||||
kwargs = {
|
||||
"host": os.environ["HOST"],
|
||||
"port": int(os.environ["PORT"]),
|
||||
@@ -39,21 +43,129 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
if maim_message_config.auth_token and len(maim_message_config.auth_token) > 0:
|
||||
kwargs["enable_token"] = True
|
||||
|
||||
if maim_message_config.use_custom:
|
||||
# 添加WSS模式支持
|
||||
del kwargs["app"]
|
||||
kwargs["host"] = maim_message_config.host
|
||||
kwargs["port"] = maim_message_config.port
|
||||
kwargs["mode"] = maim_message_config.mode
|
||||
if maim_message_config.use_wss:
|
||||
if maim_message_config.cert_file:
|
||||
kwargs["ssl_certfile"] = maim_message_config.cert_file
|
||||
if maim_message_config.key_file:
|
||||
kwargs["ssl_keyfile"] = maim_message_config.key_file
|
||||
kwargs["enable_custom_uvicorn_logger"] = False
|
||||
# Removed legacy custom config block (use_custom) as requested.
|
||||
kwargs["enable_custom_uvicorn_logger"] = False
|
||||
|
||||
global_api = MessageServer(**kwargs)
|
||||
if version_compatible and maim_message_config.auth_token:
|
||||
for token in maim_message_config.auth_token:
|
||||
global_api.add_valid_token(token)
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Additional API Server Configuration (maim_message >= 6.0)
|
||||
# ---------------------------------------------------------------------
|
||||
enable_api_server = maim_message_config.enable_api_server
|
||||
|
||||
# 如果版本支持且启用了API Server,则初始化额外服务器
|
||||
if has_api_server_feature and enable_api_server:
|
||||
try:
|
||||
from maim_message.server import WebSocketServer, ServerConfig
|
||||
from maim_message.message import APIMessageBase
|
||||
|
||||
api_logger = get_logger("maim_message_api_server")
|
||||
|
||||
# 1. Prepare Config
|
||||
api_server_host = maim_message_config.api_server_host
|
||||
api_server_port = maim_message_config.api_server_port
|
||||
use_wss = maim_message_config.api_server_use_wss
|
||||
|
||||
server_config = ServerConfig(
|
||||
host=api_server_host,
|
||||
port=api_server_port,
|
||||
ssl_enabled=use_wss,
|
||||
ssl_certfile=maim_message_config.api_server_cert_file if use_wss else None,
|
||||
ssl_keyfile=maim_message_config.api_server_key_file if use_wss else None,
|
||||
)
|
||||
|
||||
# 2. Setup Auth Handler
|
||||
async def auth_handler(metadata: dict) -> bool:
|
||||
allowed_keys = maim_message_config.api_server_allowed_api_keys
|
||||
# If list is empty/None, allow all (default behavior of returning True)
|
||||
if not allowed_keys:
|
||||
return True
|
||||
|
||||
api_key = metadata.get("api_key")
|
||||
if api_key in allowed_keys:
|
||||
return True
|
||||
|
||||
api_logger.warning(f"Rejected connection with invalid API Key: {api_key}")
|
||||
return False
|
||||
|
||||
server_config.on_auth = auth_handler
|
||||
|
||||
# 3. Setup Message Bridge
|
||||
# Initialize refined route map if not exists
|
||||
if not hasattr(global_api, "platform_map"):
|
||||
global_api.platform_map = {}
|
||||
|
||||
async def bridge_message_handler(message: APIMessageBase, metadata: dict):
|
||||
# Bridge message to the main bot logic
|
||||
# We convert APIMessageBase to dict to be compatible with legacy handlers
|
||||
# that MainBot (ChatManager) expects.
|
||||
msg_dict = message.to_dict()
|
||||
|
||||
# Compatibility Layer: Flatten sender_info to top-level user_info/group_info
|
||||
# Legacy MessageBase expects message_info to have user_info and group_info directly.
|
||||
if "message_info" in msg_dict:
|
||||
msg_info = msg_dict["message_info"]
|
||||
sender_info = msg_info.get("sender_info")
|
||||
if sender_info:
|
||||
# If direct user_info/group_info are missing, populate them from sender_info
|
||||
if "user_info" not in msg_info and (ui := sender_info.get("user_info")):
|
||||
msg_info["user_info"] = ui
|
||||
|
||||
if "group_info" not in msg_info and (gi := sender_info.get("group_info")):
|
||||
msg_info["group_info"] = gi
|
||||
|
||||
# Route Caching Logic: Simply map platform to API Key
|
||||
# This allows us to send messages back to the correct API client for this platform
|
||||
try:
|
||||
api_key = metadata.get("api_key")
|
||||
if api_key:
|
||||
platform = msg_info.get("platform")
|
||||
if platform:
|
||||
global_api.platform_map[platform] = api_key
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Failed to update platform map: {e}")
|
||||
|
||||
# Compatibility Layer: Ensure raw_message exists (even if None) as it's part of MessageBase
|
||||
if "raw_message" not in msg_dict:
|
||||
msg_dict["raw_message"] = None
|
||||
|
||||
await global_api.process_message(msg_dict)
|
||||
|
||||
server_config.on_message = bridge_message_handler
|
||||
|
||||
# 4. Initialize Server
|
||||
extra_server = WebSocketServer(config=server_config)
|
||||
|
||||
# 5. Patch global_api lifecycle methods to manage both servers
|
||||
original_run = global_api.run
|
||||
original_stop = global_api.stop
|
||||
|
||||
async def patched_run():
|
||||
api_logger.info(f"Starting Additional API Server on {api_server_host}:{api_server_port} (WSS: {use_wss})")
|
||||
# Start the extra server (non-blocking start)
|
||||
await extra_server.start()
|
||||
# Run the original legacy server (this usually keeps running)
|
||||
await original_run()
|
||||
|
||||
async def patched_stop():
|
||||
api_logger.info("Stopping Additional API Server...")
|
||||
await extra_server.stop()
|
||||
await original_stop()
|
||||
|
||||
global_api.run = patched_run
|
||||
global_api.stop = patched_stop
|
||||
|
||||
# Attach for reference
|
||||
global_api.extra_server = extra_server
|
||||
|
||||
except ImportError:
|
||||
get_logger("maim_message").error("Cannot import maim_message.server components. Is maim_message >= 0.6.0 installed?")
|
||||
except Exception as e:
|
||||
get_logger("maim_message").error(f"Failed to initialize Additional API Server: {e}")
|
||||
import traceback
|
||||
get_logger("maim_message").debug(traceback.format_exc())
|
||||
|
||||
return global_api
|
||||
|
||||
@@ -25,7 +25,7 @@ def find_messages(
|
||||
limit_mode: str = "latest",
|
||||
filter_bot=False,
|
||||
filter_command=False,
|
||||
filter_no_read_command=False,
|
||||
filter_intercept_message_level: Optional[int] = None,
|
||||
) -> List[DatabaseMessages]:
|
||||
"""
|
||||
根据提供的过滤器、排序和限制条件查找消息。
|
||||
@@ -85,8 +85,9 @@ def find_messages(
|
||||
# 使用按位取反构造 Peewee 的 NOT 条件,避免直接与 False 比较
|
||||
query = query.where(~Messages.is_command)
|
||||
|
||||
if filter_no_read_command:
|
||||
query = query.where(~Messages.is_no_read_command)
|
||||
if filter_intercept_message_level is not None:
|
||||
# 过滤掉所有 intercept_message_level > filter_intercept_message_level 的消息
|
||||
query = query.where(Messages.intercept_message_level <= filter_intercept_message_level)
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "earliest":
|
||||
|
||||
@@ -4,6 +4,7 @@ TOML 工具函数
|
||||
提供 TOML 文件的格式化保存功能,确保数组等元素以美观的多行格式输出。
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
import tomlkit
|
||||
from tomlkit.items import AoT, Table, Array
|
||||
@@ -33,7 +34,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
|
||||
return obj
|
||||
|
||||
# 决定是否多行:仅在顶层且长度超过阈值时
|
||||
should_multiline = (depth == 0 and len(obj) > threshold)
|
||||
should_multiline = depth == 0 and len(obj) > threshold
|
||||
|
||||
# 如果已经是 tomlkit Array,原地修改以保留注释
|
||||
if isinstance(obj, Array):
|
||||
@@ -45,7 +46,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
|
||||
# 普通 list:转换为 tomlkit 数组
|
||||
arr = tomlkit.array()
|
||||
arr.multiline(should_multiline)
|
||||
|
||||
|
||||
for item in obj:
|
||||
arr.append(_format_toml_value(item, threshold, depth + 1))
|
||||
return arr
|
||||
@@ -54,14 +55,71 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
|
||||
return obj
|
||||
|
||||
|
||||
def save_toml_with_format(data: Any, file_path: str, multiline_threshold: int = 1) -> None:
|
||||
"""格式化 TOML 数据并保存到文件"""
|
||||
def _update_toml_doc(target: Any, source: Any) -> None:
|
||||
"""
|
||||
递归合并字典,将 source 的值更新到 target 中,保留 target 的注释和格式。
|
||||
- 已存在的键:更新值(递归处理嵌套字典)
|
||||
- 新增的键:添加到 target
|
||||
- 跳过 version 字段
|
||||
"""
|
||||
if isinstance(source, list) or not isinstance(source, dict) or not isinstance(target, dict):
|
||||
return
|
||||
|
||||
for key, value in source.items():
|
||||
if key == "version":
|
||||
continue
|
||||
if key in target:
|
||||
# 已存在的键:递归更新或直接赋值
|
||||
target_value = target[key]
|
||||
if isinstance(value, dict) and isinstance(target_value, dict):
|
||||
_update_toml_doc(target_value, value)
|
||||
else:
|
||||
try:
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
target[key] = value
|
||||
else:
|
||||
# 新增的键:添加到 target
|
||||
try:
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
target[key] = value
|
||||
|
||||
|
||||
def save_toml_with_format(
|
||||
data: Any, file_path: str, multiline_threshold: int = 1, preserve_comments: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
格式化 TOML 数据并保存到文件。
|
||||
|
||||
Args:
|
||||
data: 要保存的数据(dict 或 tomlkit 文档)
|
||||
file_path: 保存路径
|
||||
multiline_threshold: 数组多行格式化阈值,-1 表示不格式化
|
||||
preserve_comments: 是否保留原文件的注释和格式(默认 True)
|
||||
若为 True 且文件已存在且 data 不是 tomlkit 文档,会先读取原文件,再将 data 合并进去
|
||||
"""
|
||||
import os
|
||||
from tomlkit import TOMLDocument
|
||||
|
||||
# 如果需要保留注释、文件存在、且 data 不是已有的 tomlkit 文档,先读取原文件再合并
|
||||
if preserve_comments and os.path.exists(file_path) and not isinstance(data, TOMLDocument):
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
doc = tomlkit.load(f)
|
||||
_update_toml_doc(doc, data)
|
||||
data = doc
|
||||
|
||||
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
||||
output = tomlkit.dumps(formatted)
|
||||
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
|
||||
output = re.sub(r"\n{3,}", "\n\n", output)
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(formatted, f)
|
||||
f.write(output)
|
||||
|
||||
|
||||
def format_toml_string(data: Any, multiline_threshold: int = 1) -> str:
|
||||
"""格式化 TOML 数据并返回字符串"""
|
||||
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
||||
return tomlkit.dumps(formatted)
|
||||
output = tomlkit.dumps(formatted)
|
||||
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
|
||||
return re.sub(r"\n{3,}", "\n\n", output)
|
||||
|
||||
@@ -60,6 +60,12 @@ class ModelInfo(ConfigBase):
|
||||
price_out: float = field(default=0.0)
|
||||
"""每M token输出价格"""
|
||||
|
||||
temperature: float | None = field(default=None)
|
||||
"""模型级别温度(可选),会覆盖任务配置中的温度"""
|
||||
|
||||
max_tokens: int | None = field(default=None)
|
||||
"""模型级别最大token数(可选),会覆盖任务配置中的max_tokens"""
|
||||
|
||||
force_stream_mode: bool = field(default=False)
|
||||
"""是否强制使用流式输出模式"""
|
||||
|
||||
|
||||
@@ -31,10 +31,10 @@ from src.config.official_configs import (
|
||||
RelationshipConfig,
|
||||
ToolConfig,
|
||||
VoiceConfig,
|
||||
MoodConfig,
|
||||
MemoryConfig,
|
||||
DebugConfig,
|
||||
JargonConfig,
|
||||
DreamConfig,
|
||||
WebUIConfig,
|
||||
)
|
||||
|
||||
from .api_ada_configs import (
|
||||
@@ -57,7 +57,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
||||
|
||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||
MMC_VERSION = "0.11.6"
|
||||
MMC_VERSION = "0.12.0"
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
@@ -348,15 +348,15 @@ class Config(ConfigBase):
|
||||
response_post_process: ResponsePostProcessConfig
|
||||
response_splitter: ResponseSplitterConfig
|
||||
telemetry: TelemetryConfig
|
||||
webui: WebUIConfig
|
||||
experimental: ExperimentalConfig
|
||||
maim_message: MaimMessageConfig
|
||||
lpmm_knowledge: LPMMKnowledgeConfig
|
||||
tool: ToolConfig
|
||||
memory: MemoryConfig
|
||||
debug: DebugConfig
|
||||
mood: MoodConfig
|
||||
voice: VoiceConfig
|
||||
jargon: JargonConfig
|
||||
dream: DreamConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -43,10 +43,13 @@ class PersonalityConfig(ConfigBase):
|
||||
"""人格"""
|
||||
|
||||
reply_style: str = ""
|
||||
"""表达风格"""
|
||||
"""默认表达风格"""
|
||||
|
||||
interest: str = ""
|
||||
"""兴趣"""
|
||||
multiple_reply_style: list[str] = field(default_factory=lambda: [])
|
||||
"""可选的多种表达风格列表,当配置不为空时可按概率随机替换 reply_style"""
|
||||
|
||||
multiple_probability: float = 0.0
|
||||
"""每次构建回复时,从 multiple_reply_style 中随机替换 reply_style 的概率(0.0-1.0)"""
|
||||
|
||||
plan_style: str = ""
|
||||
"""说话规则,行为风格"""
|
||||
@@ -79,12 +82,6 @@ class ChatConfig(ConfigBase):
|
||||
max_context_size: int = 18
|
||||
"""上下文长度"""
|
||||
|
||||
interest_rate_mode: Literal["fast", "accurate"] = "fast"
|
||||
"""兴趣值计算模式,fast为快速计算,accurate为精确计算"""
|
||||
|
||||
planner_size: float = 1.5
|
||||
"""副规划器大小,越小,麦麦的动作执行能力越精细,但是消耗更多token,调大可以缓解429类错误"""
|
||||
|
||||
mentioned_bot_reply: bool = True
|
||||
"""是否启用提及必回复"""
|
||||
|
||||
@@ -117,8 +114,13 @@ class ChatConfig(ConfigBase):
|
||||
时间区间支持跨夜,例如 "23:00-02:00"。
|
||||
"""
|
||||
|
||||
include_planner_reasoning: bool = False
|
||||
"""是否将planner推理加入replyer,默认关闭(不加入)"""
|
||||
think_mode: Literal["classic", "deep", "dynamic"] = "classic"
|
||||
"""
|
||||
思考模式配置
|
||||
- classic: 默认think_level为0(轻量回复,不需要思考和回忆)
|
||||
- deep: 默认think_level为1(深度回复,需要进行回忆和思考)
|
||||
- dynamic: think_level由planner动态给出(根据planner返回的think_level决定)
|
||||
"""
|
||||
|
||||
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
|
||||
"""与 ChatStream.get_stream_id 一致地从 "platform:id:type" 生成 chat_id。"""
|
||||
@@ -133,14 +135,9 @@ class ChatConfig(ConfigBase):
|
||||
|
||||
is_group = stream_type == "group"
|
||||
|
||||
import hashlib
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
@@ -173,7 +170,11 @@ class ChatConfig(ConfigBase):
|
||||
def get_talk_value(self, chat_id: Optional[str]) -> float:
|
||||
"""根据规则返回当前 chat 的动态 talk_value,未匹配则回退到基础值。"""
|
||||
if not self.enable_talk_value_rules or not self.talk_value_rules:
|
||||
return self.talk_value
|
||||
result = self.talk_value
|
||||
# 防止返回0值,自动转换为0.0001
|
||||
if result == 0:
|
||||
return 0.0000001
|
||||
return result
|
||||
|
||||
now_min = self._now_minutes()
|
||||
|
||||
@@ -199,7 +200,11 @@ class ChatConfig(ConfigBase):
|
||||
start_min, end_min = parsed
|
||||
if self._in_range(now_min, start_min, end_min):
|
||||
try:
|
||||
return float(value)
|
||||
result = float(value)
|
||||
# 防止返回0值,自动转换为0.0001
|
||||
if result == 0:
|
||||
return 0.0000001
|
||||
return result
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
@@ -218,12 +223,20 @@ class ChatConfig(ConfigBase):
|
||||
start_min, end_min = parsed
|
||||
if self._in_range(now_min, start_min, end_min):
|
||||
try:
|
||||
return float(value)
|
||||
result = float(value)
|
||||
# 防止返回0值,自动转换为0.0001
|
||||
if result == 0:
|
||||
return 0.0000001
|
||||
return result
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 3) 未命中规则返回基础值
|
||||
return self.talk_value
|
||||
result = self.talk_value
|
||||
# 防止返回0值,自动转换为0.0001
|
||||
if result == 0:
|
||||
return 0.0000001
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -244,13 +257,21 @@ class MemoryConfig(ConfigBase):
|
||||
max_agent_iterations: int = 5
|
||||
"""Agent最多迭代轮数(最低为1)"""
|
||||
|
||||
agent_timeout_seconds: float = 120.0
|
||||
"""Agent超时时间(秒)"""
|
||||
|
||||
enable_jargon_detection: bool = True
|
||||
"""记忆检索过程中是否启用黑话识别"""
|
||||
|
||||
global_memory: bool = False
|
||||
"""是否允许记忆检索在聊天记录中进行全局查询(忽略当前chat_id,仅对 search_chat_history 等工具生效)"""
|
||||
|
||||
def __post_init__(self):
|
||||
"""验证配置值"""
|
||||
if self.max_agent_iterations < 1:
|
||||
raise ValueError(f"max_agent_iterations 必须至少为1,当前值: {self.max_agent_iterations}")
|
||||
if self.agent_timeout_seconds <= 0:
|
||||
raise ValueError(f"agent_timeout_seconds 必须大于0,当前值: {self.agent_timeout_seconds}")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -260,20 +281,20 @@ class ExpressionConfig(ConfigBase):
|
||||
learning_list: list[list] = field(default_factory=lambda: [])
|
||||
"""
|
||||
表达学习配置列表,支持按聊天流配置
|
||||
格式: [["chat_stream_id", "use_expression", "enable_learning", learning_intensity], ...]
|
||||
格式: [["chat_stream_id", "use_expression", "enable_learning", "enable_jargon_learning"], ...]
|
||||
|
||||
示例:
|
||||
[
|
||||
["", "enable", "enable", 1.0], # 全局配置:使用表达,启用学习,学习强度1.0
|
||||
["qq:1919810:private", "enable", "enable", 1.5], # 特定私聊配置:使用表达,启用学习,学习强度1.5
|
||||
["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置:使用表达,禁用学习,学习强度0.5
|
||||
["", "enable", "enable", "enable"], # 全局配置:使用表达,启用学习,启用jargon学习
|
||||
["qq:1919810:private", "enable", "enable", "enable"], # 特定私聊配置:使用表达,启用学习,启用jargon学习
|
||||
["qq:114514:private", "enable", "disable", "disable"], # 特定私聊配置:使用表达,禁用学习,禁用jargon学习
|
||||
]
|
||||
|
||||
说明:
|
||||
- 第一位: chat_stream_id,空字符串表示全局配置
|
||||
- 第二位: 是否使用学到的表达 ("enable"/"disable")
|
||||
- 第三位: 是否学习表达 ("enable"/"disable")
|
||||
- 第四位: 学习强度(浮点数),影响学习频率,最短学习时间间隔 = 300/学习强度(秒)
|
||||
- 第四位: 是否启用jargon学习 ("enable"/"disable")
|
||||
"""
|
||||
|
||||
expression_groups: list[list[str]] = field(default_factory=list)
|
||||
@@ -296,6 +317,19 @@ class ExpressionConfig(ConfigBase):
|
||||
如果列表为空,则所有聊天流都可以进行表达反思(前提是 reflect = true)
|
||||
"""
|
||||
|
||||
all_global_jargon: bool = False
|
||||
"""是否将所有新增的jargon项目默认为全局(is_global=True),chat_id记录第一次存储时的id。注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除"""
|
||||
|
||||
enable_jargon_explanation: bool = True
|
||||
"""是否在回复前尝试对上下文中的黑话进行解释(关闭可减少一次LLM调用,仅影响回复前的黑话匹配与解释,不影响黑话学习)"""
|
||||
|
||||
jargon_mode: Literal["context", "planner"] = "context"
|
||||
"""
|
||||
黑话解释来源模式:
|
||||
- "context": 使用上下文自动匹配黑话并解释(原有模式)
|
||||
- "planner": 仅使用 Planner 在 reply 动作中给出的 unknown_words 列表进行黑话检索
|
||||
"""
|
||||
|
||||
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
|
||||
"""
|
||||
解析流配置字符串并生成对应的 chat_id
|
||||
@@ -318,20 +352,15 @@ class ExpressionConfig(ConfigBase):
|
||||
# 判断是否为群聊
|
||||
is_group = stream_type == "group"
|
||||
|
||||
# 使用与 ChatStream.get_stream_id 相同的逻辑生成 chat_id
|
||||
import hashlib
|
||||
# 使用 ChatManager 提供的接口生成 chat_id,避免在此重复实现逻辑
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, int]:
|
||||
def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
|
||||
"""
|
||||
根据聊天流ID获取表达配置
|
||||
|
||||
@@ -339,11 +368,11 @@ class ExpressionConfig(ConfigBase):
|
||||
chat_stream_id: 聊天流ID,格式为哈希值
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 学习间隔)
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)
|
||||
"""
|
||||
if not self.learning_list:
|
||||
# 如果没有配置,使用默认值:启用表达,启用学习,300秒间隔
|
||||
return True, True, 300
|
||||
# 如果没有配置,使用默认值:启用表达,启用学习,启用jargon学习
|
||||
return True, True, True
|
||||
|
||||
# 优先检查聊天流特定的配置
|
||||
if chat_stream_id:
|
||||
@@ -356,10 +385,10 @@ class ExpressionConfig(ConfigBase):
|
||||
if global_expression_config is not None:
|
||||
return global_expression_config
|
||||
|
||||
# 如果都没有匹配,返回默认值
|
||||
return True, True, 300
|
||||
# 如果都没有匹配,返回默认值:启用表达,启用学习,启用jargon学习
|
||||
return True, True, True
|
||||
|
||||
def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, int]]:
|
||||
def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, bool]]:
|
||||
"""
|
||||
获取特定聊天流的表达配置
|
||||
|
||||
@@ -367,7 +396,7 @@ class ExpressionConfig(ConfigBase):
|
||||
chat_stream_id: 聊天流ID(哈希值)
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习),如果没有配置则返回 None
|
||||
"""
|
||||
for config_item in self.learning_list:
|
||||
if not config_item or len(config_item) < 4:
|
||||
@@ -392,19 +421,19 @@ class ExpressionConfig(ConfigBase):
|
||||
try:
|
||||
use_expression: bool = config_item[1].lower() == "enable"
|
||||
enable_learning: bool = config_item[2].lower() == "enable"
|
||||
learning_intensity: float = float(config_item[3])
|
||||
return use_expression, enable_learning, learning_intensity # type: ignore
|
||||
enable_jargon_learning: bool = config_item[3].lower() == "enable"
|
||||
return use_expression, enable_learning, enable_jargon_learning # type: ignore
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def _get_global_config(self) -> Optional[tuple[bool, bool, int]]:
|
||||
def _get_global_config(self) -> Optional[tuple[bool, bool, bool]]:
|
||||
"""
|
||||
获取全局表达配置
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习),如果没有配置则返回 None
|
||||
"""
|
||||
for config_item in self.learning_list:
|
||||
if not config_item or len(config_item) < 4:
|
||||
@@ -415,8 +444,8 @@ class ExpressionConfig(ConfigBase):
|
||||
try:
|
||||
use_expression: bool = config_item[1].lower() == "enable"
|
||||
enable_learning: bool = config_item[2].lower() == "enable"
|
||||
learning_intensity = float(config_item[3])
|
||||
return use_expression, enable_learning, learning_intensity # type: ignore
|
||||
enable_jargon_learning: bool = config_item[3].lower() == "enable"
|
||||
return use_expression, enable_learning, enable_jargon_learning # type: ignore
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
@@ -431,20 +460,6 @@ class ToolConfig(ConfigBase):
|
||||
"""是否在聊天中启用工具"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoodConfig(ConfigBase):
|
||||
"""情绪配置类"""
|
||||
|
||||
enable_mood: bool = True
|
||||
"""是否启用情绪系统"""
|
||||
|
||||
mood_update_threshold: float = 1
|
||||
"""情绪更新阈值,越高,更新越慢"""
|
||||
|
||||
emotion_style: str = "情绪较为稳定,但遭遇特定事件的时候起伏较大"
|
||||
"""情感特征,影响情绪的变化情况"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceConfig(ConfigBase):
|
||||
"""语音识别配置类"""
|
||||
@@ -582,6 +597,35 @@ class TelemetryConfig(ConfigBase):
|
||||
"""是否启用遥测"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebUIConfig(ConfigBase):
|
||||
"""WebUI配置类
|
||||
|
||||
注意: host 和 port 配置已移至环境变量 WEBUI_HOST 和 WEBUI_PORT
|
||||
"""
|
||||
|
||||
enabled: bool = True
|
||||
"""是否启用WebUI"""
|
||||
|
||||
mode: Literal["development", "production"] = "production"
|
||||
"""运行模式:development(开发) 或 production(生产)"""
|
||||
|
||||
anti_crawler_mode: Literal["false", "strict", "loose", "basic"] = "basic"
|
||||
"""防爬虫模式:false(禁用) / strict(严格) / loose(宽松) / basic(基础-只记录不阻止)"""
|
||||
|
||||
allowed_ips: str = "127.0.0.1"
|
||||
"""IP白名单(逗号分隔,支持精确IP、CIDR格式和通配符)"""
|
||||
|
||||
trusted_proxies: str = ""
|
||||
"""信任的代理IP列表(逗号分隔),只有来自这些IP的X-Forwarded-For才被信任"""
|
||||
|
||||
trust_xff: bool = False
|
||||
"""是否启用X-Forwarded-For代理解析(默认false)"""
|
||||
|
||||
secure_cookie: bool = False
|
||||
"""是否启用安全Cookie(仅通过HTTPS传输,默认false)"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DebugConfig(ConfigBase):
|
||||
"""调试配置类"""
|
||||
@@ -639,29 +683,29 @@ class ExperimentalConfig(ConfigBase):
|
||||
class MaimMessageConfig(ConfigBase):
|
||||
"""maim_message配置类"""
|
||||
|
||||
use_custom: bool = False
|
||||
"""是否使用自定义的maim_message配置"""
|
||||
|
||||
host: str = "127.0.0.1"
|
||||
"""主机地址"""
|
||||
|
||||
port: int = 8090
|
||||
""""端口号"""
|
||||
|
||||
mode: Literal["ws", "tcp"] = "ws"
|
||||
"""连接模式,支持ws和tcp"""
|
||||
|
||||
use_wss: bool = False
|
||||
"""是否使用WSS安全连接"""
|
||||
|
||||
cert_file: str = ""
|
||||
"""SSL证书文件路径,仅在use_wss=True时有效"""
|
||||
|
||||
key_file: str = ""
|
||||
"""SSL密钥文件路径,仅在use_wss=True时有效"""
|
||||
|
||||
auth_token: list[str] = field(default_factory=lambda: [])
|
||||
"""认证令牌,用于API验证,为空则不启用验证"""
|
||||
"""认证令牌,用于旧版API验证,为空则不启用验证"""
|
||||
|
||||
enable_api_server: bool = False
|
||||
"""是否启用额外的新版API Server"""
|
||||
|
||||
api_server_host: str = "0.0.0.0"
|
||||
"""新版API Server主机地址"""
|
||||
|
||||
api_server_port: int = 8090
|
||||
"""新版API Server端口号"""
|
||||
|
||||
api_server_use_wss: bool = False
|
||||
"""新版API Server是否启用WSS"""
|
||||
|
||||
api_server_cert_file: str = ""
|
||||
"""新版API Server SSL证书文件路径"""
|
||||
|
||||
api_server_key_file: str = ""
|
||||
"""新版API Server SSL密钥文件路径"""
|
||||
|
||||
api_server_allowed_api_keys: list[str] = field(default_factory=lambda: [])
|
||||
"""新版API Server允许的API Key列表,为空则允许所有连接"""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -707,10 +751,107 @@ class LPMMKnowledgeConfig(ConfigBase):
|
||||
embedding_dimension: int = 1024
|
||||
"""嵌入向量维度,应该与模型的输出维度一致"""
|
||||
|
||||
max_embedding_workers: int = 3
|
||||
"""嵌入/抽取并发线程数"""
|
||||
|
||||
embedding_chunk_size: int = 4
|
||||
"""每批嵌入的条数"""
|
||||
|
||||
max_synonym_entities: int = 2000
|
||||
"""同义边参与的实体数上限,超限则跳过"""
|
||||
|
||||
enable_ppr: bool = True
|
||||
"""是否启用PPR,低配机器可关闭"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class JargonConfig(ConfigBase):
|
||||
"""Jargon配置类"""
|
||||
class DreamConfig(ConfigBase):
|
||||
"""Dream配置类"""
|
||||
|
||||
all_global: bool = False
|
||||
"""是否将所有新增的jargon项目默认为全局(is_global=True),chat_id记录第一次存储时的id"""
|
||||
interval_minutes: int = 30
|
||||
"""做梦时间间隔(分钟),默认30分钟"""
|
||||
|
||||
max_iterations: int = 20
|
||||
"""做梦最大轮次,默认20轮"""
|
||||
|
||||
first_delay_seconds: int = 60
|
||||
"""程序启动后首次做梦前的延迟时间(秒),默认60秒"""
|
||||
|
||||
dream_send: str = ""
|
||||
"""
|
||||
做梦结果推送目标,格式为 "platform:user_id"
|
||||
例如: "qq:123456" 表示在做梦结束后,将梦境文本额外发送给该QQ私聊用户。
|
||||
为空字符串时不推送。
|
||||
"""
|
||||
|
||||
dream_time_ranges: list[str] = field(default_factory=lambda: [])
|
||||
"""
|
||||
做梦时间段配置列表,格式:["HH:MM-HH:MM", ...]
|
||||
如果列表为空,则表示全天允许做梦。
|
||||
如果配置了时间段,则只有在这些时间段内才会实际执行做梦流程。
|
||||
时间段外,调度器仍会按间隔检查,但不会进入做梦流程。
|
||||
|
||||
示例:
|
||||
[
|
||||
"09:00-22:00", # 白天允许做梦
|
||||
"23:00-02:00", # 跨夜时间段(23:00到次日02:00)
|
||||
]
|
||||
|
||||
支持跨夜区间,例如 "23:00-02:00" 表示从23:00到次日02:00。
|
||||
"""
|
||||
|
||||
def _now_minutes(self) -> int:
|
||||
"""返回本地时间的分钟数(0-1439)。"""
|
||||
lt = time.localtime()
|
||||
return lt.tm_hour * 60 + lt.tm_min
|
||||
|
||||
def _parse_range(self, range_str: str) -> Optional[tuple[int, int]]:
|
||||
"""解析 "HH:MM-HH:MM" 到 (start_min, end_min)。"""
|
||||
try:
|
||||
start_str, end_str = [s.strip() for s in range_str.split("-")]
|
||||
sh, sm = [int(x) for x in start_str.split(":")]
|
||||
eh, em = [int(x) for x in end_str.split(":")]
|
||||
return sh * 60 + sm, eh * 60 + em
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _in_range(self, now_min: int, start_min: int, end_min: int) -> bool:
|
||||
"""
|
||||
判断 now_min 是否在 [start_min, end_min] 区间内。
|
||||
支持跨夜:如果 start > end,则表示跨越午夜。
|
||||
"""
|
||||
if start_min <= end_min:
|
||||
return start_min <= now_min <= end_min
|
||||
# 跨夜:例如 23:00-02:00
|
||||
return now_min >= start_min or now_min <= end_min
|
||||
|
||||
def is_in_dream_time(self) -> bool:
|
||||
"""
|
||||
检查当前时间是否在允许做梦的时间段内。
|
||||
如果 dream_time_ranges 为空,则返回 True(全天允许)。
|
||||
"""
|
||||
if not self.dream_time_ranges:
|
||||
return True
|
||||
|
||||
now_min = self._now_minutes()
|
||||
|
||||
for time_range in self.dream_time_ranges:
|
||||
if not isinstance(time_range, str):
|
||||
continue
|
||||
parsed = self._parse_range(time_range)
|
||||
if not parsed:
|
||||
continue
|
||||
start_min, end_min = parsed
|
||||
if self._in_range(now_min, start_min, end_min):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def __post_init__(self):
|
||||
"""验证配置值"""
|
||||
if self.interval_minutes < 1:
|
||||
raise ValueError(f"interval_minutes 必须至少为1,当前值: {self.interval_minutes}")
|
||||
if self.max_iterations < 1:
|
||||
raise ValueError(f"max_iterations 必须至少为1,当前值: {self.max_iterations}")
|
||||
if self.first_delay_seconds < 0:
|
||||
raise ValueError(f"first_delay_seconds 不能为负数,当前值: {self.first_delay_seconds}")
|
||||
|
||||
580
src/dream/dream_agent.py
Normal file
580
src/dream/dream_agent.py
Normal file
@@ -0,0 +1,580 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from peewee import fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.database.database_model import ChatHistory
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
||||
from src.plugin_system.apis import llm_api
|
||||
from src.dream.dream_generator import generate_dream_summary
|
||||
|
||||
# dream 工具工厂函数
|
||||
from src.dream.tools.search_chat_history_tool import make_search_chat_history
|
||||
from src.dream.tools.get_chat_history_detail_tool import make_get_chat_history_detail
|
||||
from src.dream.tools.delete_chat_history_tool import make_delete_chat_history
|
||||
from src.dream.tools.create_chat_history_tool import make_create_chat_history
|
||||
from src.dream.tools.update_chat_history_tool import make_update_chat_history
|
||||
from src.dream.tools.finish_maintenance_tool import make_finish_maintenance
|
||||
from src.dream.tools.search_jargon_tool import make_search_jargon
|
||||
from src.dream.tools.delete_jargon_tool import make_delete_jargon
|
||||
from src.dream.tools.update_jargon_tool import make_update_jargon
|
||||
|
||||
logger = get_logger("dream_agent")
|
||||
|
||||
|
||||
def init_dream_prompts() -> None:
|
||||
"""初始化 dream agent 的提示词"""
|
||||
Prompt(
|
||||
"""
|
||||
你的名字是{bot_name},你现在处于"梦境维护模式(dream agent)"。
|
||||
你可以自由地在 ChatHistory 库中探索、整理、创建和删改记录,以帮助自己在未来更好地回忆和理解对话历史。
|
||||
|
||||
本轮要维护的聊天ID:{chat_id}
|
||||
本轮随机选中的起始记忆 ID:{start_memory_id}
|
||||
请优先以这条起始记忆为切入点,先理解它的内容与上下文,再决定如何在其附近进行创建新概括、重写或删除等整理操作;如果起始记忆为空,则由你自行选择合适的切入点。
|
||||
|
||||
你可以使用的工具包括:
|
||||
**ChatHistory 维护工具:**
|
||||
- search_chat_history:根据关键词或参与人搜索该 chat_id 下的历史记忆概括列表
|
||||
- get_chat_history_detail:查看某条概括的详细内容
|
||||
- create_chat_history:根据整理后的理解创建一条新的 ChatHistory 概括记录(主题、概括、关键词、关键信息等)
|
||||
- update_chat_history:在不改变事实的前提下重写或精炼主题、概括、关键词、关键信息
|
||||
- delete_chat_history:删除明显冗余、噪声、错误或无意义的记录,或者非常有时效性的信息,或者无太多有用信息的日常互动。
|
||||
你也可以先用 create_chat_history 创建一条新的综合概括,再对旧的冗余记录执行多次 delete_chat_history 来完成“合并”效果。
|
||||
|
||||
**Jargon(黑话)维护工具(只读,禁止修改):**
|
||||
- search_jargon:根据一个或多个关键词搜索Jargon 记录,通常是含义不明确的词条或者特殊的缩写
|
||||
|
||||
**通用工具:**
|
||||
- finish_maintenance:当你认为当前维护工作已经完成,没有更多需要整理的内容时,调用此工具来结束本次运行
|
||||
|
||||
**工作目标**:
|
||||
- 发现冗余、重复或高度相似的记录,并进行合并或删除;
|
||||
- 发现主题/概括过于含糊、啰嗦或缺少关键信息的记录,进行重写和精简;
|
||||
- summary要尽可能保持有用的信息;
|
||||
- 尽量保持信息的真实与可用性,不要凭空捏造事实。
|
||||
|
||||
**合并准则**
|
||||
- 你可以新建一个记录,然后删除旧记录来实现合并。
|
||||
- 如果两个或多个记录的主题相似,内容是对主题不同方面的信息或讨论,且信息量较少,则可以合并为一条记录。
|
||||
- 如果两个记录冲突,可以根据逻辑保留一个或者进行整合,也可以采取更新的记录,删除旧的记录
|
||||
|
||||
**轮次信息**:
|
||||
- 本次维护最多执行 {max_iterations} 轮
|
||||
- 每轮开始时,系统会告知你当前是第几轮,还剩多少轮
|
||||
- 如果提前完成维护工作,可以调用 finish_maintenance 工具主动结束
|
||||
|
||||
**每一轮的执行方式(必须遵守):**
|
||||
- 第一步:先用一小段中文自然语言,写出你的「思考」和本轮计划(例如要查什么、准备怎么合并/修改);
|
||||
- 第二步:在这段思考之后,再通过工具调用来执行你的计划(可以调用 0~N 个工具);
|
||||
- 第三步:收到工具结果后,在下一轮继续先写出新的思考,再视情况继续调用工具。
|
||||
|
||||
请不要在没有先写出思考的情况下直接调用工具。
|
||||
只输出你的思考内容或工具调用结果,由系统负责真正执行工具调用。
|
||||
""",
|
||||
name="dream_react_head_prompt",
|
||||
)
|
||||
|
||||
|
||||
class DreamTool:
|
||||
"""dream 模块内部使用的简易工具封装"""
|
||||
|
||||
def __init__(self, name: str, description: str, parameters: List[Tuple], execute_func):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.parameters = parameters
|
||||
self.execute_func = execute_func
|
||||
|
||||
def get_tool_definition(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters,
|
||||
}
|
||||
|
||||
async def execute(self, **kwargs) -> str:
|
||||
return await self.execute_func(**kwargs)
|
||||
|
||||
|
||||
class DreamToolRegistry:
|
||||
def __init__(self) -> None:
|
||||
self.tools: Dict[str, DreamTool] = {}
|
||||
|
||||
def register_tool(self, tool: DreamTool) -> None:
|
||||
"""
|
||||
注册或更新 dream 工具。
|
||||
注意:dream agent 每个 chat_id 会重新初始化工具,这里允许覆盖已有同名工具。
|
||||
"""
|
||||
self.tools[tool.name] = tool
|
||||
logger.info(f"注册/更新 dream 工具: {tool.name}")
|
||||
|
||||
def get_tool(self, name: str) -> Optional[DreamTool]:
|
||||
return self.tools.get(name)
|
||||
|
||||
def get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
return [tool.get_tool_definition() for tool in self.tools.values()]
|
||||
|
||||
|
||||
_dream_tool_registry = DreamToolRegistry()
|
||||
|
||||
|
||||
def get_dream_tool_registry() -> DreamToolRegistry:
|
||||
return _dream_tool_registry
|
||||
|
||||
|
||||
def init_dream_tools(chat_id: str) -> None:
|
||||
"""注册 dream agent 可用的 ChatHistory / Jargon 相关工具(限定在当前 chat_id 作用域内)"""
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType
|
||||
|
||||
# 通过工厂函数生成绑定当前 chat_id 的工具实现
|
||||
search_chat_history = make_search_chat_history(chat_id)
|
||||
get_chat_history_detail = make_get_chat_history_detail(chat_id)
|
||||
delete_chat_history = make_delete_chat_history(chat_id)
|
||||
create_chat_history = make_create_chat_history(chat_id)
|
||||
update_chat_history = make_update_chat_history(chat_id)
|
||||
finish_maintenance = make_finish_maintenance(chat_id)
|
||||
|
||||
search_jargon = make_search_jargon(chat_id)
|
||||
delete_jargon = make_delete_jargon(chat_id)
|
||||
update_jargon = make_update_jargon(chat_id)
|
||||
|
||||
_dream_tool_registry.register_tool(
|
||||
DreamTool(
|
||||
"search_chat_history",
|
||||
"根据关键词或参与人查询当前 chat_id 下的 ChatHistory 概览,便于快速定位相关记忆。",
|
||||
[
|
||||
(
|
||||
"keyword",
|
||||
ToolParamType.STRING,
|
||||
"关键词(可选,支持多个关键词,可用空格、逗号等分隔)。",
|
||||
False,
|
||||
None,
|
||||
),
|
||||
("participant", ToolParamType.STRING, "参与人昵称(可选)。", False, None),
|
||||
],
|
||||
search_chat_history,
|
||||
)
|
||||
)
|
||||
|
||||
_dream_tool_registry.register_tool(
|
||||
DreamTool(
|
||||
"get_chat_history_detail",
|
||||
"根据 memory_id 获取单条 ChatHistory 的详细内容,包含主题、概括、关键词、关键信息等字段(不包含原文)。",
|
||||
[
|
||||
("memory_id", ToolParamType.INTEGER, "ChatHistory 主键 ID。", True, None),
|
||||
],
|
||||
get_chat_history_detail,
|
||||
)
|
||||
)
|
||||
|
||||
_dream_tool_registry.register_tool(
|
||||
DreamTool(
|
||||
"delete_chat_history",
|
||||
"根据 memory_id 删除一条 ChatHistory 记录(请谨慎使用)。",
|
||||
[
|
||||
("memory_id", ToolParamType.INTEGER, "需要删除的 ChatHistory 主键 ID。", True, None),
|
||||
],
|
||||
delete_chat_history,
|
||||
)
|
||||
)
|
||||
|
||||
_dream_tool_registry.register_tool(
|
||||
DreamTool(
|
||||
"update_chat_history",
|
||||
"按字段更新 ChatHistory 记录,可用于清理、重写或补充信息。",
|
||||
[
|
||||
("memory_id", ToolParamType.INTEGER, "需要更新的 ChatHistory 主键 ID。", True, None),
|
||||
("theme", ToolParamType.STRING, "新的主题标题,如果不需要修改可不填。", False, None),
|
||||
("summary", ToolParamType.STRING, "新的概括内容,如果不需要修改可不填。", False, None),
|
||||
("keywords", ToolParamType.STRING, "新的关键词 JSON 字符串,如 ['关键词1','关键词2']。", False, None),
|
||||
("key_point", ToolParamType.STRING, "新的关键信息 JSON 字符串,如 ['要点1','要点2']。", False, None),
|
||||
],
|
||||
update_chat_history,
|
||||
)
|
||||
)
|
||||
|
||||
_dream_tool_registry.register_tool(
|
||||
DreamTool(
|
||||
"create_chat_history",
|
||||
"根据整理后的理解创建一条新的 ChatHistory 概括记录(主题、概括、关键词、关键信息等)。",
|
||||
[
|
||||
("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None),
|
||||
("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None),
|
||||
(
|
||||
"keywords",
|
||||
ToolParamType.STRING,
|
||||
"新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。",
|
||||
True,
|
||||
None,
|
||||
),
|
||||
(
|
||||
"key_point",
|
||||
ToolParamType.STRING,
|
||||
"新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。",
|
||||
True,
|
||||
None,
|
||||
),
|
||||
("start_time", ToolParamType.STRING, "起始时间戳(秒,Unix 时间,必填)。", True, None),
|
||||
("end_time", ToolParamType.STRING, "结束时间戳(秒,Unix 时间,必填)。", True, None),
|
||||
],
|
||||
create_chat_history,
|
||||
)
|
||||
)
|
||||
|
||||
_dream_tool_registry.register_tool(
|
||||
DreamTool(
|
||||
"finish_maintenance",
|
||||
"结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理、合并或修改的内容时,调用此工具来主动结束本次运行。",
|
||||
[
|
||||
(
|
||||
"reason",
|
||||
ToolParamType.STRING,
|
||||
"结束维护的原因说明(可选),例如 '已完成所有记录的整理' 或 '当前记录质量良好,无需进一步维护'。",
|
||||
False,
|
||||
None,
|
||||
),
|
||||
],
|
||||
finish_maintenance,
|
||||
)
|
||||
)
|
||||
|
||||
# ==================== Jargon 维护工具 ====================
|
||||
# 注册 Jargon 工具
|
||||
_dream_tool_registry.register_tool(
|
||||
DreamTool(
|
||||
"search_jargon",
|
||||
"根据一个或多个关键词搜索当前 chat_id 相关的 Jargon 记录概览(只包含 is_jargon=True,含全局 Jargon),便于快速理解黑话库。",
|
||||
[
|
||||
("keyword", ToolParamType.STRING, "按一个或多个关键词搜索内容/含义/推断结果(必填)。", True, None),
|
||||
],
|
||||
search_jargon,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def run_dream_agent_once(
|
||||
chat_id: str,
|
||||
max_iterations: Optional[int] = None,
|
||||
start_memory_id: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
运行一次 dream agent,对指定 chat_id 的 ChatHistory 进行最多 max_iterations 轮的整理。
|
||||
如果 max_iterations 为 None,则使用配置文件中的默认值。
|
||||
"""
|
||||
if max_iterations is None:
|
||||
max_iterations = global_config.dream.max_iterations
|
||||
|
||||
start_ts = time.time()
|
||||
logger.info(f"[dream] 开始对 chat_id={chat_id} 进行 dream 维护,最多迭代 {max_iterations} 轮")
|
||||
|
||||
# 初始化工具(作用域限定在当前 chat_id)
|
||||
init_dream_tools(chat_id)
|
||||
|
||||
tool_registry = get_dream_tool_registry()
|
||||
tool_defs = tool_registry.get_tool_definitions()
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
|
||||
head_prompt = await global_prompt_manager.format_prompt(
|
||||
"dream_react_head_prompt",
|
||||
bot_name=bot_name,
|
||||
time_now=time_now,
|
||||
chat_id=chat_id,
|
||||
start_memory_id=start_memory_id if start_memory_id is not None else "无(本轮由你自由选择切入点)",
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
|
||||
conversation_messages: List[Message] = []
|
||||
|
||||
# 如果提供了起始记忆 ID,则在对话正式开始前,先把这条记忆的详细信息放入上下文,
|
||||
# 避免 LLM 还需要额外调用一次 get_chat_history_detail 才能看到起始记忆内容。
|
||||
if start_memory_id is not None:
|
||||
try:
|
||||
record = ChatHistory.get_or_none(ChatHistory.id == start_memory_id)
|
||||
if record:
|
||||
start_time_str = (
|
||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time))
|
||||
if record.start_time
|
||||
else "未知"
|
||||
)
|
||||
end_time_str = (
|
||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
|
||||
)
|
||||
detail_text = (
|
||||
f"ID={record.id}\n"
|
||||
f"chat_id={record.chat_id}\n"
|
||||
f"时间范围={start_time_str} 至 {end_time_str}\n"
|
||||
f"主题={record.theme or '无'}\n"
|
||||
f"关键词={record.keywords or '无'}\n"
|
||||
f"参与者={record.participants or '无'}\n"
|
||||
f"概括={record.summary or '无'}\n"
|
||||
f"关键信息={record.key_point or '无'}"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"[dream] 预加载起始记忆详情 memory_id={start_memory_id},"
|
||||
f"预览: {detail_text[:200].replace(chr(10), ' ')}"
|
||||
)
|
||||
|
||||
start_detail_builder = MessageBuilder()
|
||||
start_detail_builder.set_role(RoleType.User)
|
||||
start_detail_builder.add_text_content(
|
||||
"【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n" + detail_text
|
||||
)
|
||||
conversation_messages.append(start_detail_builder.build())
|
||||
else:
|
||||
logger.warning(
|
||||
f"[dream] 提供的 start_memory_id={start_memory_id} 未找到对应 ChatHistory 记录,"
|
||||
"将不预加载起始记忆详情。"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[dream] 预加载起始记忆详情失败 start_memory_id={start_memory_id}: {e}")
|
||||
|
||||
# 注意:message_factory 必须是同步函数,返回消息列表(不能是 async/coroutine)
|
||||
def message_factory(
|
||||
_client,
|
||||
*,
|
||||
_head_prompt: str = head_prompt,
|
||||
_conversation_messages: List[Message] = conversation_messages,
|
||||
) -> List[Message]:
|
||||
messages: List[Message] = []
|
||||
system_builder = MessageBuilder()
|
||||
system_builder.set_role(RoleType.System)
|
||||
system_builder.add_text_content(_head_prompt)
|
||||
messages.append(system_builder.build())
|
||||
messages.extend(_conversation_messages)
|
||||
return messages
|
||||
|
||||
for iteration in range(1, max_iterations + 1):
|
||||
# 在每轮开始时,添加轮次信息到对话中
|
||||
remaining_rounds = max_iterations - iteration + 1
|
||||
round_info_builder = MessageBuilder()
|
||||
round_info_builder.set_role(RoleType.User)
|
||||
round_info_builder.add_text_content(
|
||||
f"【轮次信息】当前是第 {iteration}/{max_iterations} 轮,还剩 {remaining_rounds} 轮。"
|
||||
)
|
||||
conversation_messages.append(round_info_builder.build())
|
||||
|
||||
# 调用 LLM 让其决定是否要使用工具
|
||||
(
|
||||
success,
|
||||
response,
|
||||
reasoning_content,
|
||||
model_name,
|
||||
tool_calls,
|
||||
) = await llm_api.generate_with_model_with_tools_by_message_factory(
|
||||
message_factory,
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
tool_options=tool_defs,
|
||||
request_type="dream.react",
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error(f"[dream] 第 {iteration} 轮 LLM 调用失败: {response}")
|
||||
break
|
||||
|
||||
# 先输出「思考」内容,再输出工具调用信息(思考文本较长,仅在 debug 下输出)
|
||||
thought_log = reasoning_content or (response[:300] if response else "")
|
||||
if thought_log:
|
||||
logger.debug(f"[dream] 第 {iteration} 轮思考内容: {thought_log}")
|
||||
|
||||
logger.info(
|
||||
f"[dream] 第 {iteration} 轮响应,模型={model_name},工具调用数={len(tool_calls) if tool_calls else 0}"
|
||||
)
|
||||
|
||||
assistant_msg: Optional[Message] = None
|
||||
if tool_calls:
|
||||
builder = MessageBuilder()
|
||||
builder.set_role(RoleType.Assistant)
|
||||
if response and response.strip():
|
||||
builder.add_text_content(response)
|
||||
builder.set_tool_calls(tool_calls)
|
||||
assistant_msg = builder.build()
|
||||
elif response and response.strip():
|
||||
builder = MessageBuilder()
|
||||
builder.set_role(RoleType.Assistant)
|
||||
builder.add_text_content(response)
|
||||
assistant_msg = builder.build()
|
||||
|
||||
if assistant_msg:
|
||||
conversation_messages.append(assistant_msg)
|
||||
|
||||
# 如果本轮没有工具调用,仅作为思考记录,继续下一轮
|
||||
if not tool_calls:
|
||||
logger.debug(f"[dream] 第 {iteration} 轮未调用任何工具,仅记录思考。")
|
||||
continue
|
||||
|
||||
# 执行所有工具调用
|
||||
tasks = []
|
||||
finish_maintenance_called = False
|
||||
for tc in tool_calls:
|
||||
tool = tool_registry.get_tool(tc.func_name)
|
||||
if not tool:
|
||||
logger.warning(f"[dream] 未知工具:{tc.func_name}")
|
||||
continue
|
||||
|
||||
# 检测是否调用了 finish_maintenance 工具
|
||||
if tc.func_name == "finish_maintenance":
|
||||
finish_maintenance_called = True
|
||||
|
||||
params = tc.args or {}
|
||||
|
||||
async def _run_single(t: DreamTool, p: Dict[str, Any], call_id: str, it: int):
|
||||
try:
|
||||
result = await t.execute(**p)
|
||||
logger.debug(f"[dream] 第 {it} 轮 工具 {t.name} 执行完成")
|
||||
return call_id, result
|
||||
except Exception as e:
|
||||
logger.error(f"[dream] 工具 {t.name} 执行失败: {e}")
|
||||
return call_id, f"工具 {t.name} 执行失败: {e}"
|
||||
|
||||
tasks.append(_run_single(tool, params, tc.call_id, iteration))
|
||||
|
||||
if not tasks:
|
||||
continue
|
||||
|
||||
tool_results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
# 将工具结果作为 Tool 消息追加
|
||||
for call_id, obs in tool_results:
|
||||
tool_builder = MessageBuilder()
|
||||
tool_builder.set_role(RoleType.Tool)
|
||||
tool_builder.add_text_content(str(obs))
|
||||
tool_builder.add_tool_call(call_id)
|
||||
conversation_messages.append(tool_builder.build())
|
||||
|
||||
# 如果调用了 finish_maintenance 工具,提前结束本次运行
|
||||
if finish_maintenance_called:
|
||||
logger.info(f"[dream] 第 {iteration} 轮检测到 finish_maintenance 工具调用,提前结束本次维护。")
|
||||
break
|
||||
|
||||
cost = time.time() - start_ts
|
||||
logger.info(f"[dream] 对 chat_id={chat_id} 的 dream 维护结束,共迭代 {iteration} 轮,耗时 {cost:.1f} 秒")
|
||||
|
||||
# 生成梦境总结
|
||||
await generate_dream_summary(chat_id, conversation_messages, iteration, cost)
|
||||
|
||||
|
||||
def _pick_random_chat_id() -> Optional[str]:
|
||||
"""从 ChatHistory 中随机选择一个 chat_id,用于 dream agent 本次维护
|
||||
|
||||
规则:
|
||||
- 只在 chat_id 所属的 ChatHistory 记录数 >= 10 时才会参与随机选择;
|
||||
- 记录数不足 10 的 chat_id 将被跳过,不会触发做梦 react。
|
||||
"""
|
||||
try:
|
||||
# 统计每个 chat_id 的记录数,只保留记录数 >= 10 的 chat_id
|
||||
rows = (
|
||||
ChatHistory.select(ChatHistory.chat_id, fn.COUNT(ChatHistory.id).alias("cnt"))
|
||||
.group_by(ChatHistory.chat_id)
|
||||
.having(fn.COUNT(ChatHistory.id) >= 10)
|
||||
.order_by(ChatHistory.chat_id)
|
||||
.limit(200)
|
||||
)
|
||||
eligible_ids = [r.chat_id for r in rows]
|
||||
if not eligible_ids:
|
||||
logger.warning("[dream] ChatHistory 中暂无满足条件(记录数 >= 10)的 chat_id,本轮 dream 任务跳过。")
|
||||
return None
|
||||
chosen = random.choice(eligible_ids)
|
||||
logger.info(f"[dream] 从 {len(eligible_ids)} 个满足条件的 chat_id 中随机选择:{chosen}")
|
||||
return chosen
|
||||
except Exception as e:
|
||||
logger.error(f"[dream] 随机选择 chat_id 失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _pick_random_memory_for_chat(chat_id: str) -> Optional[int]:
|
||||
"""
|
||||
在给定 chat_id 下随机选择一条 ChatHistory 记录,作为本轮整理的起始记忆。
|
||||
"""
|
||||
try:
|
||||
rows = (
|
||||
ChatHistory.select(ChatHistory.id)
|
||||
.where(ChatHistory.chat_id == chat_id)
|
||||
.order_by(ChatHistory.start_time.asc())
|
||||
.limit(200)
|
||||
)
|
||||
ids = [r.id for r in rows]
|
||||
if not ids:
|
||||
logger.warning(f"[dream] chat_id={chat_id} 下暂无 ChatHistory 记录,无法选择起始记忆。")
|
||||
return None
|
||||
return random.choice(ids)
|
||||
except Exception as e:
|
||||
logger.error(f"[dream] 在 chat_id={chat_id} 下随机选择起始记忆失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def run_dream_cycle_once() -> None:
|
||||
"""
|
||||
单次 dream 周期:
|
||||
- 随机选择一个 chat_id
|
||||
- 在该 chat_id 下随机选择一条 ChatHistory 作为起始记忆
|
||||
- 以这条起始记忆为切入点,对该 chat_id 运行一次 dream agent(最多 15 轮)
|
||||
"""
|
||||
chat_id = _pick_random_chat_id()
|
||||
if not chat_id:
|
||||
return
|
||||
|
||||
start_memory_id = _pick_random_memory_for_chat(chat_id)
|
||||
await run_dream_agent_once(
|
||||
chat_id=chat_id,
|
||||
max_iterations=None, # 使用配置文件中的默认值
|
||||
start_memory_id=start_memory_id,
|
||||
)
|
||||
|
||||
|
||||
async def start_dream_scheduler(
|
||||
first_delay_seconds: Optional[int] = None,
|
||||
interval_seconds: Optional[int] = None,
|
||||
stop_event: Optional[asyncio.Event] = None,
|
||||
) -> None:
|
||||
"""
|
||||
dream 调度器:
|
||||
- 程序启动后先等待 first_delay_seconds(如果为 None,则使用配置文件中的值,默认 60s)
|
||||
- 然后每隔 interval_seconds(如果为 None,则使用配置文件中的值,默认 30 分钟)运行一次 dream agent 周期
|
||||
- 如果提供 stop_event,则在 stop_event 被 set() 后优雅退出循环
|
||||
"""
|
||||
if first_delay_seconds is None:
|
||||
first_delay_seconds = global_config.dream.first_delay_seconds
|
||||
|
||||
if interval_seconds is None:
|
||||
interval_seconds = global_config.dream.interval_minutes * 60
|
||||
|
||||
logger.info(
|
||||
f"[dream] dream 调度器启动:首次延迟 {first_delay_seconds}s,之后每隔 {interval_seconds}s ({interval_seconds // 60} 分钟) 运行一次 dream agent"
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.sleep(first_delay_seconds)
|
||||
while True:
|
||||
if stop_event is not None and stop_event.is_set():
|
||||
logger.info("[dream] 收到停止事件,结束 dream 调度器循环。")
|
||||
break
|
||||
|
||||
start_ts = time.time()
|
||||
# 检查当前时间是否在允许做梦的时间段内
|
||||
if not global_config.dream.is_in_dream_time():
|
||||
logger.debug("[dream] 当前时间不在允许做梦的时间段内,跳过本次执行")
|
||||
else:
|
||||
try:
|
||||
await run_dream_cycle_once()
|
||||
except Exception as e:
|
||||
logger.error(f"[dream] 单次 dream 周期执行异常: {e}")
|
||||
|
||||
elapsed = time.time() - start_ts
|
||||
# 保证两次执行之间至少间隔 interval_seconds
|
||||
to_sleep = max(0.0, interval_seconds - elapsed)
|
||||
await asyncio.sleep(to_sleep)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("[dream] dream 调度器任务被取消,准备退出。")
|
||||
raise
|
||||
|
||||
|
||||
# 初始化提示词
|
||||
init_dream_prompts()
|
||||
251
src/dream/dream_generator.py
Normal file
251
src/dream/dream_generator.py
Normal file
@@ -0,0 +1,251 @@
|
||||
import random
|
||||
from typing import List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
from src.llm_models.payload_content.message import RoleType, Message
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
logger = get_logger("dream_generator")
|
||||
|
||||
# 初始化 utils 模型用于生成梦境总结
|
||||
_dream_summary_model: Optional[LLMRequest] = None
|
||||
|
||||
# 梦境风格列表(21种)
|
||||
DREAM_STYLES = [
|
||||
"保持诗意和想象力,自由编写",
|
||||
"诗意朦胧,如薄雾笼罩的清晨",
|
||||
"奇幻冒险,充满未知与探索",
|
||||
"温暖怀旧,带着时光的痕迹",
|
||||
"神秘悬疑,暗藏深意",
|
||||
"浪漫唯美,如诗如画",
|
||||
"科幻未来,科技与想象交织",
|
||||
"自然清新,如山林间的微风",
|
||||
"深沉哲思,引人深思",
|
||||
"轻松幽默,充满趣味",
|
||||
"悲伤忧郁,带着淡淡哀愁",
|
||||
"激昂热烈,充满活力",
|
||||
"宁静平和,如湖面般平静",
|
||||
"荒诞离奇,打破常规",
|
||||
"细腻温柔,如春风拂面",
|
||||
"壮阔宏大,气势磅礴",
|
||||
"简约纯粹,返璞归真",
|
||||
"复杂多变,层次丰富",
|
||||
"梦幻迷离,虚实难辨",
|
||||
"现实写意,贴近生活",
|
||||
"抽象概念,超越具象",
|
||||
]
|
||||
|
||||
|
||||
def get_random_dream_styles(count: int = 2) -> List[str]:
|
||||
"""从梦境风格列表中随机选择指定数量的风格"""
|
||||
return random.sample(DREAM_STYLES, min(count, len(DREAM_STYLES)))
|
||||
|
||||
|
||||
def get_dream_summary_model() -> LLMRequest:
|
||||
"""获取用于生成梦境总结的 utils 模型实例"""
|
||||
global _dream_summary_model
|
||||
if _dream_summary_model is None:
|
||||
_dream_summary_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="dream.summary",
|
||||
)
|
||||
return _dream_summary_model
|
||||
|
||||
|
||||
def init_dream_summary_prompt() -> None:
|
||||
"""初始化梦境总结的提示词"""
|
||||
Prompt(
|
||||
"""
|
||||
你刚刚完成了一次对聊天记录的记忆整理工作。以下是整理过程的摘要:
|
||||
整理过程:
|
||||
{conversation_text}
|
||||
|
||||
请将这次整理涉及的相关信息改写为一个富有诗意和想象力的"梦境",请你仅使用具体的记忆的内容,而不是整理过程编写。
|
||||
要求:
|
||||
1. 使用第一人称视角
|
||||
2. 叙述直白,不要复杂修辞,口语化
|
||||
3. 长度控制在200-800字
|
||||
4. 用中文输出
|
||||
梦境风格:
|
||||
{dream_styles}
|
||||
请直接输出梦境内容,不要添加其他说明:
|
||||
""",
|
||||
name="dream_summary_prompt",
|
||||
)
|
||||
|
||||
|
||||
async def generate_dream_summary(
|
||||
chat_id: str,
|
||||
conversation_messages: List[Message],
|
||||
total_iterations: int,
|
||||
time_cost: float,
|
||||
) -> None:
|
||||
"""生成梦境总结,输出到日志,并根据配置可选地推送给指定用户"""
|
||||
try:
|
||||
import json
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
|
||||
# 第一步:建立工具调用结果映射 (call_id -> result)
|
||||
tool_results_map: dict[str, str] = {}
|
||||
for msg in conversation_messages:
|
||||
if msg.role == RoleType.Tool and msg.tool_call_id:
|
||||
content = ""
|
||||
if msg.content:
|
||||
if isinstance(msg.content, list) and msg.content:
|
||||
content = msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
|
||||
else:
|
||||
content = str(msg.content)
|
||||
tool_results_map[msg.tool_call_id] = content
|
||||
|
||||
# 第二步:详细记录所有工具调用操作和结果到日志
|
||||
tool_call_count = 0
|
||||
logger.info(f"[dream][工具调用详情] 开始记录 chat_id={chat_id} 的所有工具调用操作:")
|
||||
|
||||
for msg in conversation_messages:
|
||||
if msg.role == RoleType.Assistant and msg.tool_calls:
|
||||
tool_call_count += 1
|
||||
# 提取思考内容
|
||||
thought_content = ""
|
||||
if msg.content:
|
||||
if isinstance(msg.content, list) and msg.content:
|
||||
thought_content = (
|
||||
msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
|
||||
)
|
||||
else:
|
||||
thought_content = str(msg.content)
|
||||
|
||||
logger.info(f"[dream][工具调用详情] === 第 {tool_call_count} 组工具调用 ===")
|
||||
if thought_content:
|
||||
logger.info(
|
||||
f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}"
|
||||
)
|
||||
|
||||
# 记录每个工具调用的详细信息
|
||||
for idx, tool_call in enumerate(msg.tool_calls, 1):
|
||||
tool_name = tool_call.func_name
|
||||
tool_args = tool_call.args or {}
|
||||
tool_call_id = tool_call.call_id
|
||||
tool_result = tool_results_map.get(tool_call_id, "未找到执行结果")
|
||||
|
||||
# 格式化参数
|
||||
try:
|
||||
args_str = json.dumps(tool_args, ensure_ascii=False, indent=2) if tool_args else "无参数"
|
||||
except Exception:
|
||||
args_str = str(tool_args)
|
||||
|
||||
logger.info(f"[dream][工具调用详情] --- 工具 {idx}: {tool_name} ---")
|
||||
logger.info(f"[dream][工具调用详情] 调用参数:\n{args_str}")
|
||||
logger.info(f"[dream][工具调用详情] 执行结果:\n{tool_result}")
|
||||
logger.info(f"[dream][工具调用详情] {'-' * 60}")
|
||||
|
||||
logger.info(f"[dream][工具调用详情] 共记录了 {tool_call_count} 组工具调用操作")
|
||||
|
||||
# 第三步:构建对话历史摘要(用于生成梦境)
|
||||
conversation_summary = []
|
||||
for msg in conversation_messages:
|
||||
role = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
|
||||
content = ""
|
||||
if msg.content:
|
||||
content = msg.content[0].text if isinstance(msg.content, list) and msg.content else str(msg.content)
|
||||
|
||||
if role == "user" and "轮次信息" in content:
|
||||
# 跳过轮次信息消息
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
# 只保留思考内容,简化工具调用信息
|
||||
if content:
|
||||
# 截取前500字符,避免过长
|
||||
content_preview = content[:500] + ("..." if len(content) > 500 else "")
|
||||
conversation_summary.append(f"[{role}] {content_preview}")
|
||||
elif role == "tool":
|
||||
# 工具结果,只保留关键信息
|
||||
if content:
|
||||
# 截取前300字符
|
||||
content_preview = content[:300] + ("..." if len(content) > 300 else "")
|
||||
conversation_summary.append(f"[工具执行] {content_preview}")
|
||||
|
||||
conversation_text = "\n".join(conversation_summary[-20:]) # 只保留最后20条消息
|
||||
|
||||
# 随机选择2个梦境风格
|
||||
selected_styles = get_random_dream_styles(2)
|
||||
dream_styles_text = "\n".join([f"{i + 1}. {style}" for i, style in enumerate(selected_styles)])
|
||||
|
||||
# 使用 Prompt 管理器格式化梦境生成 prompt
|
||||
dream_prompt = await global_prompt_manager.format_prompt(
|
||||
"dream_summary_prompt",
|
||||
chat_id=chat_id,
|
||||
total_iterations=total_iterations,
|
||||
time_cost=time_cost,
|
||||
conversation_text=conversation_text,
|
||||
dream_styles=dream_styles_text,
|
||||
)
|
||||
|
||||
# 调用 utils 模型生成梦境
|
||||
summary_model = get_dream_summary_model()
|
||||
dream_content, (reasoning, model_name, _) = await summary_model.generate_response_async(
|
||||
dream_prompt,
|
||||
max_tokens=512,
|
||||
temperature=0.8,
|
||||
)
|
||||
|
||||
if dream_content:
|
||||
logger.info(f"[dream][梦境总结] 对 chat_id={chat_id} 的整理过程梦境:\n{dream_content}")
|
||||
|
||||
# 第五步:根据配置决定是否将梦境发送给指定用户
|
||||
try:
|
||||
dream_send_raw = getattr(global_config.dream, "dream_send", "") or ""
|
||||
dream_send = dream_send_raw.strip()
|
||||
if dream_send:
|
||||
parts = dream_send.split(":")
|
||||
if len(parts) != 2:
|
||||
logger.warning(
|
||||
f"[dream][梦境总结] dream_send 配置格式不正确,应为 'platform:user_id',当前值: {dream_send_raw!r}"
|
||||
)
|
||||
else:
|
||||
platform, user_id = parts[0].strip(), parts[1].strip()
|
||||
if not platform or not user_id:
|
||||
logger.warning(
|
||||
f"[dream][梦境总结] dream_send 平台或用户ID为空,当前值: {dream_send_raw!r}"
|
||||
)
|
||||
else:
|
||||
# 默认为私聊会话
|
||||
stream_id = get_chat_manager().get_stream_id(
|
||||
platform=platform,
|
||||
id=str(user_id),
|
||||
is_group=False,
|
||||
)
|
||||
if not stream_id:
|
||||
logger.error(
|
||||
f"[dream][梦境总结] 无法根据 dream_send 找到有效的聊天流,"
|
||||
f"platform={platform!r}, user_id={user_id!r}"
|
||||
)
|
||||
else:
|
||||
ok = await send_api.text_to_stream(
|
||||
dream_content,
|
||||
stream_id=stream_id,
|
||||
typing=False,
|
||||
storage_message=True,
|
||||
)
|
||||
if ok:
|
||||
logger.info(
|
||||
f"[dream][梦境总结] 已将梦境结果发送给配置的目标用户: {platform}:{user_id}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"[dream][梦境总结] 向 {platform}:{user_id} 发送梦境结果失败"
|
||||
)
|
||||
except Exception as send_exc:
|
||||
logger.error(f"[dream][梦境总结] 发送梦境结果到配置用户时出错: {send_exc}", exc_info=True)
|
||||
else:
|
||||
logger.warning("[dream][梦境总结] 未能生成梦境总结")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[dream][梦境总结] 生成梦境总结失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
init_dream_summary_prompt()
|
||||
7
src/dream/tools/__init__.py
Normal file
7
src/dream/tools/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
dream agent 工具实现模块。
|
||||
|
||||
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
|
||||
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
|
||||
"""
|
||||
|
||||
63
src/dream/tools/create_chat_history_tool.py
Normal file
63
src/dream/tools/create_chat_history_tool.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import time
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import ChatHistory
|
||||
|
||||
logger = get_logger("dream_agent")
|
||||
|
||||
|
||||
def make_create_chat_history(chat_id: str):
|
||||
async def create_chat_history(
|
||||
theme: str,
|
||||
summary: str,
|
||||
keywords: str,
|
||||
key_point: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
) -> str:
|
||||
"""创建一条新的 ChatHistory 概括记录(用于整理/合并后的新记忆)"""
|
||||
try:
|
||||
logger.info(
|
||||
f"[dream][tool] 调用 create_chat_history("
|
||||
f"theme={bool(theme)}, summary={bool(summary)}, "
|
||||
f"keywords={bool(keywords)}, key_point={bool(key_point)}, "
|
||||
f"start_time={start_time}, end_time={end_time}) (chat_id={chat_id})"
|
||||
)
|
||||
|
||||
now_ts = time.time()
|
||||
|
||||
# 将传入的 start_time/end_time(如果有)解析为时间戳;否则回退为当前时间
|
||||
def _parse_ts(value, default):
|
||||
if value is None:
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
start_ts = _parse_ts(start_time, now_ts)
|
||||
end_ts = _parse_ts(end_time, now_ts)
|
||||
|
||||
record = ChatHistory.create(
|
||||
chat_id=chat_id,
|
||||
theme=theme,
|
||||
summary=summary,
|
||||
keywords=keywords,
|
||||
key_point=key_point,
|
||||
# 对于由 dream 整理产生的新概括,时间范围优先使用工具提供的时间,否则使用当前时间占位
|
||||
start_time=start_ts,
|
||||
end_time=end_ts,
|
||||
)
|
||||
|
||||
msg = (
|
||||
f"已创建新的 ChatHistory 记录,ID={record.id},"
|
||||
f"theme={record.theme or '无'},summary={'有' if record.summary else '无'}。"
|
||||
)
|
||||
logger.info(f"[dream][tool] create_chat_history 完成: {msg}")
|
||||
return msg
|
||||
except Exception as e:
|
||||
logger.error(f"create_chat_history 失败: {e}")
|
||||
return f"create_chat_history 执行失败: {e}"
|
||||
|
||||
return create_chat_history
|
||||
|
||||
26
src/dream/tools/delete_chat_history_tool.py
Normal file
26
src/dream/tools/delete_chat_history_tool.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import ChatHistory
|
||||
|
||||
logger = get_logger("dream_agent")
|
||||
|
||||
|
||||
def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,预留以备扩展
|
||||
async def delete_chat_history(memory_id: int) -> str:
|
||||
"""删除一条 chat_history 记录"""
|
||||
try:
|
||||
logger.info(f"[dream][tool] 调用 delete_chat_history(memory_id={memory_id})")
|
||||
record = ChatHistory.get_or_none(ChatHistory.id == memory_id)
|
||||
if not record:
|
||||
msg = f"未找到 ID={memory_id} 的 ChatHistory 记录,无法删除。"
|
||||
logger.info(f"[dream][tool] delete_chat_history 未找到记录: {msg}")
|
||||
return msg
|
||||
rows = ChatHistory.delete().where(ChatHistory.id == memory_id).execute()
|
||||
msg = f"已删除 ID={memory_id} 的 ChatHistory 记录,受影响行数={rows}。"
|
||||
logger.info(f"[dream][tool] delete_chat_history 完成: {msg}")
|
||||
return msg
|
||||
except Exception as e:
|
||||
logger.error(f"delete_chat_history 失败: {e}")
|
||||
return f"delete_chat_history 执行失败: {e}"
|
||||
|
||||
return delete_chat_history
|
||||
|
||||
26
src/dream/tools/delete_jargon_tool.py
Normal file
26
src/dream/tools/delete_jargon_tool.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Jargon
|
||||
|
||||
logger = get_logger("dream_agent")
|
||||
|
||||
|
||||
def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留以备扩展
|
||||
async def delete_jargon(jargon_id: int) -> str:
|
||||
"""删除一条 Jargon 记录"""
|
||||
try:
|
||||
logger.info(f"[dream][tool] 调用 delete_jargon(jargon_id={jargon_id})")
|
||||
record = Jargon.get_or_none(Jargon.id == jargon_id)
|
||||
if not record:
|
||||
msg = f"未找到 ID={jargon_id} 的 Jargon 记录,无法删除。"
|
||||
logger.info(f"[dream][tool] delete_jargon 未找到记录: {msg}")
|
||||
return msg
|
||||
rows = Jargon.delete().where(Jargon.id == jargon_id).execute()
|
||||
msg = f"已删除 ID={jargon_id} 的 Jargon 记录(内容:{record.content}),受影响行数={rows}。"
|
||||
logger.info(f"[dream][tool] delete_jargon 完成: {msg}")
|
||||
return msg
|
||||
except Exception as e:
|
||||
logger.error(f"delete_jargon 失败: {e}")
|
||||
return f"delete_jargon 执行失败: {e}"
|
||||
|
||||
return delete_jargon
|
||||
|
||||
17
src/dream/tools/finish_maintenance_tool.py
Normal file
17
src/dream/tools/finish_maintenance_tool.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("dream_agent")
|
||||
|
||||
|
||||
def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,预留以备扩展
|
||||
async def finish_maintenance(reason: Optional[str] = None) -> str:
|
||||
"""结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理的内容时,调用此工具来结束本次运行。"""
|
||||
reason_text = f",原因:{reason}" if reason else ""
|
||||
msg = f"DREAM_MAINTENANCE_COMPLETE{reason_text}"
|
||||
logger.info(f"[dream][tool] 调用 finish_maintenance,结束本次维护{reason_text}")
|
||||
return msg
|
||||
|
||||
return finish_maintenance
|
||||
|
||||
45
src/dream/tools/get_chat_history_detail_tool.py
Normal file
45
src/dream/tools/get_chat_history_detail_tool.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import time
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import ChatHistory
|
||||
|
||||
logger = get_logger("dream_agent")
|
||||
|
||||
|
||||
def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用,预留以备扩展
|
||||
async def get_chat_history_detail(memory_id: int) -> str:
|
||||
"""获取单条 chat_history 的完整内容"""
|
||||
try:
|
||||
logger.info(f"[dream][tool] 调用 get_chat_history_detail(memory_id={memory_id})")
|
||||
record = ChatHistory.get_or_none(ChatHistory.id == memory_id)
|
||||
if not record:
|
||||
msg = f"未找到 ID={memory_id} 的 ChatHistory 记录。"
|
||||
logger.info(f"[dream][tool] get_chat_history_detail 未找到记录: {msg}")
|
||||
return msg
|
||||
|
||||
# 将时间戳转换为可读时间格式
|
||||
start_time_str = (
|
||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time)) if record.start_time else "未知"
|
||||
)
|
||||
end_time_str = (
|
||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
|
||||
)
|
||||
|
||||
result = (
|
||||
f"ID={record.id}\n"
|
||||
# f"chat_id={record.chat_id}\n"
|
||||
f"时间范围={start_time_str} 至 {end_time_str}\n"
|
||||
f"主题={record.theme or '无'}\n"
|
||||
f"关键词={record.keywords or '无'}\n"
|
||||
f"参与者={record.participants or '无'}\n"
|
||||
f"概括={record.summary or '无'}\n"
|
||||
f"关键信息={record.key_point or '无'}"
|
||||
)
|
||||
logger.debug(f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"get_chat_history_detail 失败: {e}")
|
||||
return f"get_chat_history_detail 执行失败: {e}"
|
||||
|
||||
return get_chat_history_detail
|
||||
|
||||
215
src/dream/tools/search_chat_history_tool.py
Normal file
215
src/dream/tools/search_chat_history_tool.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import ChatHistory
|
||||
from src.chat.utils.utils import parse_keywords_string
|
||||
|
||||
logger = get_logger("dream_agent")
|
||||
|
||||
|
||||
def make_search_chat_history(chat_id: str):
|
||||
async def search_chat_history(
|
||||
keyword: Optional[str] = None,
|
||||
participant: Optional[str] = None,
|
||||
) -> str:
|
||||
"""根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords(dream 维护专用版本)"""
|
||||
try:
|
||||
# 检查参数
|
||||
if not keyword and not participant:
|
||||
return "未指定查询参数(需要提供keyword或participant之一)"
|
||||
|
||||
logger.info(
|
||||
f"[dream][tool] 调用 search_chat_history(keyword={keyword}, participant={participant}) "
|
||||
f"(作用域 chat_id={chat_id})"
|
||||
)
|
||||
|
||||
# 构建查询条件
|
||||
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
|
||||
|
||||
# 执行查询(按时间倒序,最近的在前)
|
||||
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
|
||||
|
||||
filtered_records: List[ChatHistory] = []
|
||||
|
||||
for record in records:
|
||||
participant_matched = True # 如果没有participant条件,默认为True
|
||||
keyword_matched = True # 如果没有keyword条件,默认为True
|
||||
|
||||
# 检查参与人匹配
|
||||
if participant:
|
||||
participant_matched = False
|
||||
participants_list: List[str] = []
|
||||
if record.participants:
|
||||
try:
|
||||
participants_data = (
|
||||
json.loads(record.participants)
|
||||
if isinstance(record.participants, str)
|
||||
else record.participants
|
||||
)
|
||||
if isinstance(participants_data, list):
|
||||
participants_list = [str(p).lower() for p in participants_data]
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
participant_lower = participant.lower().strip()
|
||||
if participant_lower and any(participant_lower in p for p in participants_list):
|
||||
participant_matched = True
|
||||
|
||||
# 检查关键词匹配
|
||||
if keyword:
|
||||
keyword_matched = False
|
||||
# 解析多个关键词(支持空格、逗号等分隔符)
|
||||
keywords_list = parse_keywords_string(keyword)
|
||||
if not keywords_list:
|
||||
keywords_list = [keyword.strip()] if keyword.strip() else []
|
||||
|
||||
# 转换为小写以便匹配
|
||||
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
|
||||
|
||||
if keywords_lower:
|
||||
# 在theme、keywords、summary、original_text中搜索
|
||||
theme = (record.theme or "").lower()
|
||||
summary = (record.summary or "").lower()
|
||||
original_text = (record.original_text or "").lower()
|
||||
|
||||
# 解析record中的keywords JSON
|
||||
record_keywords_list: List[str] = []
|
||||
if record.keywords:
|
||||
try:
|
||||
keywords_data = (
|
||||
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
)
|
||||
if isinstance(keywords_data, list):
|
||||
record_keywords_list = [str(k).lower() for k in keywords_data]
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# 有容错的全匹配:如果关键词数量>2,允许n-1个关键词匹配;否则必须全部匹配
|
||||
matched_count = 0
|
||||
for kw in keywords_lower:
|
||||
kw_matched = (
|
||||
kw in theme
|
||||
or kw in summary
|
||||
or kw in original_text
|
||||
or any(kw in k for k in record_keywords_list)
|
||||
)
|
||||
if kw_matched:
|
||||
matched_count += 1
|
||||
|
||||
# 计算需要匹配的关键词数量
|
||||
total_keywords = len(keywords_lower)
|
||||
if total_keywords > 2:
|
||||
# 关键词数量>2,允许n-1个关键词匹配
|
||||
required_matches = total_keywords - 1
|
||||
else:
|
||||
# 关键词数量<=2,必须全部匹配
|
||||
required_matches = total_keywords
|
||||
|
||||
keyword_matched = matched_count >= required_matches
|
||||
|
||||
# 两者都匹配(如果同时有participant和keyword,需要两者都匹配;如果只有一个条件,只需要该条件匹配)
|
||||
matched = participant_matched and keyword_matched
|
||||
|
||||
if matched:
|
||||
filtered_records.append(record)
|
||||
|
||||
if not filtered_records:
|
||||
if keyword and participant:
|
||||
keywords_str = "、".join(parse_keywords_string(keyword) if keyword else [])
|
||||
return f"未找到包含关键词'{keywords_str}'且参与人包含'{participant}'的聊天记录"
|
||||
elif keyword:
|
||||
keywords_list = parse_keywords_string(keyword)
|
||||
keywords_str = "、".join(keywords_list)
|
||||
if len(keywords_list) > 2:
|
||||
required_count = len(keywords_list) - 1
|
||||
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
|
||||
else:
|
||||
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
|
||||
elif participant:
|
||||
return f"未找到参与人包含'{participant}'的聊天记录"
|
||||
else:
|
||||
return "未找到相关聊天记录"
|
||||
|
||||
# 如果匹配结果超过20条,不返回具体记录,只返回提示和所有相关关键词
|
||||
if len(filtered_records) > 20:
|
||||
all_keywords_set = set()
|
||||
for record in filtered_records:
|
||||
if record.keywords:
|
||||
try:
|
||||
keywords_data = (
|
||||
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
)
|
||||
if isinstance(keywords_data, list):
|
||||
for k in keywords_data:
|
||||
k_str = str(k).strip()
|
||||
if k_str:
|
||||
all_keywords_set.add(k_str)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
continue
|
||||
|
||||
search_label = keyword or participant or "当前条件"
|
||||
|
||||
if all_keywords_set:
|
||||
keywords_str = "、".join(sorted(all_keywords_set))
|
||||
response_text = (
|
||||
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
||||
f'有关"{search_label}"的关键词:\n'
|
||||
f"{keywords_str}"
|
||||
)
|
||||
else:
|
||||
response_text = (
|
||||
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
||||
f'有关"{search_label}"的关键词信息为空'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[dream][tool] search_chat_history 匹配结果超过20条,返回关键词汇总提示,总数={len(filtered_records)}"
|
||||
)
|
||||
return response_text
|
||||
|
||||
# 构建结果文本,返回id、theme和keywords(最多20条)
|
||||
results: List[str] = []
|
||||
for record in filtered_records[:20]:
|
||||
result_parts: List[str] = []
|
||||
|
||||
# 记忆ID
|
||||
result_parts.append(f"记忆ID:{record.id}")
|
||||
|
||||
# 主题
|
||||
if record.theme:
|
||||
result_parts.append(f"主题:{record.theme}")
|
||||
else:
|
||||
result_parts.append("主题:(无)")
|
||||
|
||||
# 关键词
|
||||
if record.keywords:
|
||||
try:
|
||||
keywords_data = (
|
||||
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
)
|
||||
if isinstance(keywords_data, list) and keywords_data:
|
||||
keywords_str = "、".join([str(k) for k in keywords_data])
|
||||
result_parts.append(f"关键词:{keywords_str}")
|
||||
else:
|
||||
result_parts.append("关键词:(无)")
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
result_parts.append("关键词:(无)")
|
||||
else:
|
||||
result_parts.append("关键词:(无)")
|
||||
|
||||
results.append("\n".join(result_parts))
|
||||
|
||||
if not results:
|
||||
return "未找到相关聊天记录"
|
||||
|
||||
response_text = "\n\n---\n\n".join(results)
|
||||
|
||||
logger.info(f"[dream][tool] search_chat_history 返回 {len(filtered_records)} 条匹配记录")
|
||||
return response_text
|
||||
except Exception as e:
|
||||
logger.error(f"search_chat_history 失败: {e}")
|
||||
return f"search_chat_history 执行失败: {e}"
|
||||
|
||||
return search_chat_history
|
||||
|
||||
102
src/dream/tools/search_jargon_tool.py
Normal file
102
src/dream/tools/search_jargon_tool.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from typing import List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Jargon
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.utils import parse_keywords_string
|
||||
from src.bw_learner.learner_utils import parse_chat_id_list, chat_id_list_contains
|
||||
|
||||
logger = get_logger("dream_agent")
|
||||
|
||||
|
||||
def make_search_jargon(chat_id: str):
|
||||
async def search_jargon(keyword: str) -> str:
|
||||
"""根据一个或多个关键词搜索当前 chat_id 相关的 Jargon 记录概览(只包含 is_jargon=True,是否跨 chat_id 由 all_global 决定)"""
|
||||
try:
|
||||
if not keyword or not keyword.strip():
|
||||
return "未指定查询关键词(参数 keyword 为必填,且不能为空)"
|
||||
|
||||
logger.info(f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})")
|
||||
|
||||
# 基础条件:只查 is_jargon=True 的记录
|
||||
query = Jargon.select().where(Jargon.is_jargon)
|
||||
|
||||
# 根据 all_global 配置决定 chat_id 作用域
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启全局黑话:只看 is_global=True 的记录,不区分 chat_id
|
||||
query = query.where(Jargon.is_global)
|
||||
else:
|
||||
# 关闭全局黑话:后续在 Python 层按 chat_id 列表过滤(包含 is_global=True)
|
||||
pass
|
||||
|
||||
# 先按使用次数排序取一批候选,做一个安全上限
|
||||
query = query.order_by(Jargon.count.desc()).limit(200)
|
||||
candidates = list(query)
|
||||
|
||||
if not candidates:
|
||||
msg = "未找到符合条件的 Jargon 记录。"
|
||||
logger.info(f"[dream][tool] search_jargon 无记录: {msg}")
|
||||
return msg
|
||||
|
||||
# 关键词为必填,因此此处必然执行关键词过滤(支持多个关键词,大小写不敏感)
|
||||
keywords_list = parse_keywords_string(keyword) or []
|
||||
if not keywords_list and keyword.strip():
|
||||
keywords_list = [keyword.strip()]
|
||||
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
|
||||
|
||||
# 先按关键词过滤(仅对 content 字段进行匹配)
|
||||
filtered_keyword: List[Jargon] = []
|
||||
for r in candidates:
|
||||
content = (r.content or "").lower()
|
||||
|
||||
# 只要命中任意一个关键词即可视为匹配(OR 逻辑)
|
||||
any_matched = False
|
||||
for kw in keywords_lower:
|
||||
if not kw:
|
||||
continue
|
||||
if kw in content:
|
||||
any_matched = True
|
||||
break
|
||||
|
||||
if any_matched:
|
||||
filtered_keyword.append(r)
|
||||
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 全局黑话模式:不再做 chat_id 过滤,直接使用关键词过滤结果
|
||||
records = filtered_keyword
|
||||
else:
|
||||
# 非全局模式:仅保留全局黑话或 chat_id 列表中包含当前 chat_id 的记录
|
||||
records = []
|
||||
for r in filtered_keyword:
|
||||
if r.is_global:
|
||||
records.append(r)
|
||||
continue
|
||||
chat_id_list = parse_chat_id_list(r.chat_id)
|
||||
if chat_id_list_contains(chat_id_list, chat_id):
|
||||
records.append(r)
|
||||
|
||||
if not records:
|
||||
scope_note = (
|
||||
"(当前为全局黑话模式,仅统计 is_global=True 的条目)"
|
||||
if global_config.expression.all_global_jargon
|
||||
else "(当前为按 chat_id 作用域模式,仅统计全局黑话或与当前 chat_id 相关的条目)"
|
||||
)
|
||||
return f"未找到包含关键词'{keyword}'的 Jargon 记录{scope_note}"
|
||||
|
||||
lines: List[str] = []
|
||||
for r in records:
|
||||
is_jargon_str = "是" if r.is_jargon else "否" if r.is_jargon is False else "未判定"
|
||||
is_global_str = "全局" if r.is_global else "非全局"
|
||||
lines.append(
|
||||
f"ID={r.id} | 内容={r.content} | 含义={r.meaning or '无'} | "
|
||||
f"chat_id={r.chat_id} | {is_global_str} | 是否黑话={is_jargon_str}"
|
||||
)
|
||||
|
||||
result = "\n".join(lines)
|
||||
logger.info(f"[dream][tool] search_jargon 返回 {len(records)} 条记录")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"search_jargon 失败: {e}")
|
||||
return f"search_jargon 执行失败: {e}"
|
||||
|
||||
return search_jargon
|
||||
52
src/dream/tools/update_chat_history_tool.py
Normal file
52
src/dream/tools/update_chat_history_tool.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import ChatHistory
|
||||
from src.plugin_system.apis import database_api
|
||||
|
||||
logger = get_logger("dream_agent")
|
||||
|
||||
|
||||
def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,预留以备扩展
|
||||
async def update_chat_history(
|
||||
memory_id: int,
|
||||
theme: Optional[str] = None,
|
||||
summary: Optional[str] = None,
|
||||
keywords: Optional[str] = None,
|
||||
key_point: Optional[str] = None,
|
||||
) -> str:
|
||||
"""按字段更新 chat_history(字符串字段要求 JSON 的字段须传入已序列化的字符串)"""
|
||||
try:
|
||||
logger.info(
|
||||
f"[dream][tool] 调用 update_chat_history(memory_id={memory_id}, "
|
||||
f"theme={bool(theme)}, summary={bool(summary)}, keywords={bool(keywords)}, key_point={bool(key_point)})"
|
||||
)
|
||||
record = ChatHistory.get_or_none(ChatHistory.id == memory_id)
|
||||
if not record:
|
||||
msg = f"未找到 ID={memory_id} 的 ChatHistory 记录,无法更新。"
|
||||
logger.info(f"[dream][tool] update_chat_history 未找到记录: {msg}")
|
||||
return msg
|
||||
|
||||
data: Dict[str, Any] = {}
|
||||
if theme is not None:
|
||||
data["theme"] = theme
|
||||
if summary is not None:
|
||||
data["summary"] = summary
|
||||
if keywords is not None:
|
||||
data["keywords"] = keywords
|
||||
if key_point is not None:
|
||||
data["key_point"] = key_point
|
||||
|
||||
if not data:
|
||||
return "未提供任何需要更新的字段。"
|
||||
|
||||
await database_api.db_save(ChatHistory, data=data, key_field="id", key_value=memory_id)
|
||||
msg = f"已更新 ChatHistory 记录 ID={memory_id},更新字段={list(data.keys())}。"
|
||||
logger.info(f"[dream][tool] update_chat_history 完成: {msg}")
|
||||
return msg
|
||||
except Exception as e:
|
||||
logger.error(f"update_chat_history 失败: {e}")
|
||||
return f"update_chat_history 执行失败: {e}"
|
||||
|
||||
return update_chat_history
|
||||
|
||||
52
src/dream/tools/update_jargon_tool.py
Normal file
52
src/dream/tools/update_jargon_tool.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Jargon
|
||||
from src.plugin_system.apis import database_api
|
||||
|
||||
logger = get_logger("dream_agent")
|
||||
|
||||
|
||||
def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留以备扩展
|
||||
async def update_jargon(
|
||||
jargon_id: int,
|
||||
meaning: Optional[str] = None,
|
||||
is_global: Optional[bool] = None,
|
||||
is_jargon: Optional[bool] = None,
|
||||
content: Optional[str] = None,
|
||||
) -> str:
|
||||
"""按字段更新 Jargon 记录,可用于修正含义、调整全局性、标记是否为黑话等"""
|
||||
try:
|
||||
logger.info(
|
||||
f"[dream][tool] 调用 update_jargon(jargon_id={jargon_id}, "
|
||||
f"meaning={bool(meaning)}, is_global={is_global}, is_jargon={is_jargon}, content={bool(content)})"
|
||||
)
|
||||
record = Jargon.get_or_none(Jargon.id == jargon_id)
|
||||
if not record:
|
||||
msg = f"未找到 ID={jargon_id} 的 Jargon 记录,无法更新。"
|
||||
logger.info(f"[dream][tool] update_jargon 未找到记录: {msg}")
|
||||
return msg
|
||||
|
||||
data: Dict[str, Any] = {}
|
||||
if meaning is not None:
|
||||
data["meaning"] = meaning
|
||||
if is_global is not None:
|
||||
data["is_global"] = is_global
|
||||
if is_jargon is not None:
|
||||
data["is_jargon"] = is_jargon
|
||||
if content is not None:
|
||||
data["content"] = content
|
||||
|
||||
if not data:
|
||||
return "未提供任何需要更新的字段。"
|
||||
|
||||
await database_api.db_save(Jargon, data=data, key_field="id", key_value=jargon_id)
|
||||
msg = f"已更新 Jargon 记录 ID={jargon_id},更新字段={list(data.keys())}。"
|
||||
logger.info(f"[dream][tool] update_jargon 完成: {msg}")
|
||||
return msg
|
||||
except Exception as e:
|
||||
logger.error(f"update_jargon 失败: {e}")
|
||||
return f"update_jargon 执行失败: {e}"
|
||||
|
||||
return update_jargon
|
||||
|
||||
@@ -1,145 +0,0 @@
|
||||
import re
|
||||
import difflib
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict
|
||||
|
||||
|
||||
def filter_message_content(content: Optional[str]) -> str:
|
||||
"""
|
||||
过滤消息内容,移除回复、@、图片等格式
|
||||
|
||||
Args:
|
||||
content: 原始消息内容
|
||||
|
||||
Returns:
|
||||
str: 过滤后的内容
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
|
||||
content = re.sub(r"\[回复.*?\],说:\s*", "", content)
|
||||
# 移除@<...>格式的内容
|
||||
content = re.sub(r"@<[^>]*>", "", content)
|
||||
# 移除[picid:...]格式的图片ID
|
||||
content = re.sub(r"\[picid:[^\]]*\]", "", content)
|
||||
# 移除[表情包:...]格式的内容
|
||||
content = re.sub(r"\[表情包:[^\]]*\]", "", content)
|
||||
|
||||
return content.strip()
|
||||
|
||||
|
||||
def calculate_similarity(text1: str, text2: str) -> float:
|
||||
"""
|
||||
计算两个文本的相似度,返回0-1之间的值
|
||||
使用SequenceMatcher计算相似度
|
||||
|
||||
Args:
|
||||
text1: 第一个文本
|
||||
text2: 第二个文本
|
||||
|
||||
Returns:
|
||||
float: 相似度值,范围0-1
|
||||
"""
|
||||
return difflib.SequenceMatcher(None, text1, text2).ratio()
|
||||
|
||||
|
||||
def format_create_date(timestamp: float) -> str:
|
||||
"""
|
||||
将时间戳格式化为可读的日期字符串
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳
|
||||
|
||||
Returns:
|
||||
str: 格式化后的日期字符串
|
||||
"""
|
||||
try:
|
||||
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "未知时间"
|
||||
|
||||
|
||||
def _compute_weights(population: List[Dict]) -> List[float]:
|
||||
"""
|
||||
根据表达的count计算权重,范围限定在1~5之间。
|
||||
count越高,权重越高,但最多为基础权重的5倍。
|
||||
如果表达已checked,权重会再乘以3倍。
|
||||
"""
|
||||
if not population:
|
||||
return []
|
||||
|
||||
counts = []
|
||||
checked_flags = []
|
||||
for item in population:
|
||||
count = item.get("count", 1)
|
||||
try:
|
||||
count_value = float(count)
|
||||
except (TypeError, ValueError):
|
||||
count_value = 1.0
|
||||
counts.append(max(count_value, 0.0))
|
||||
# 获取checked状态
|
||||
checked = item.get("checked", False)
|
||||
checked_flags.append(bool(checked))
|
||||
|
||||
min_count = min(counts)
|
||||
max_count = max(counts)
|
||||
|
||||
if max_count == min_count:
|
||||
base_weights = [1.0 for _ in counts]
|
||||
else:
|
||||
base_weights = []
|
||||
for count_value in counts:
|
||||
# 线性映射到[1,5]区间
|
||||
normalized = (count_value - min_count) / (max_count - min_count)
|
||||
base_weights.append(1.0 + normalized * 4.0) # 1~3
|
||||
|
||||
# 如果checked,权重乘以3
|
||||
weights = []
|
||||
for base_weight, checked in zip(base_weights, checked_flags, strict=False):
|
||||
if checked:
|
||||
weights.append(base_weight * 3.0)
|
||||
else:
|
||||
weights.append(base_weight)
|
||||
return weights
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
|
||||
"""
|
||||
随机抽样函数
|
||||
|
||||
Args:
|
||||
population: 总体数据列表
|
||||
k: 需要抽取的数量
|
||||
|
||||
Returns:
|
||||
List[Dict]: 抽取的数据列表
|
||||
"""
|
||||
if not population or k <= 0:
|
||||
return []
|
||||
|
||||
if len(population) <= k:
|
||||
return population.copy()
|
||||
|
||||
selected: List[Dict] = []
|
||||
population_copy = population.copy()
|
||||
|
||||
for _ in range(min(k, len(population_copy))):
|
||||
weights = _compute_weights(population_copy)
|
||||
total_weight = sum(weights)
|
||||
if total_weight <= 0:
|
||||
# 回退到均匀随机
|
||||
idx = random.randint(0, len(population_copy) - 1)
|
||||
selected.append(population_copy.pop(idx))
|
||||
continue
|
||||
|
||||
threshold = random.uniform(0, total_weight)
|
||||
cumulative = 0.0
|
||||
for idx, weight in enumerate(weights):
|
||||
cumulative += weight
|
||||
if threshold <= cumulative:
|
||||
selected.append(population_copy.pop(idx))
|
||||
break
|
||||
|
||||
return selected
|
||||
@@ -1,708 +0,0 @@
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import asyncio
|
||||
from typing import List, Optional, Tuple
|
||||
import traceback
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
build_anonymous_messages,
|
||||
build_bare_messages,
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.express.express_utils import filter_message_content, calculate_similarity
|
||||
from json_repair import repair_json
|
||||
|
||||
|
||||
# MAX_EXPRESSION_COUNT = 300
|
||||
|
||||
logger = get_logger("expressor")
|
||||
|
||||
|
||||
def init_prompt() -> None:
|
||||
learn_style_prompt = """
|
||||
{chat_str}
|
||||
|
||||
请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格
|
||||
1. 只考虑文字,不要考虑表情包和图片
|
||||
2. 不要涉及具体的人名,但是可以涉及具体名词
|
||||
3. 思考有没有特殊的梗,一并总结成语言风格
|
||||
4. 例子仅供参考,请严格根据群聊内容总结!!!
|
||||
注意:总结成如下格式的规律,总结的内容要详细,但具有概括性:
|
||||
例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景,不超过20个字。BBBBB代表对应的语言风格,特定句式或表达方式,不超过20个字。
|
||||
|
||||
例如:
|
||||
当"对某件事表示十分惊叹"时,使用"我嘞个xxxx"
|
||||
当"表示讽刺的赞同,不讲道理"时,使用"对对对"
|
||||
当"想说明某个具体的事实观点,但懒得明说,使用"懂的都懂"
|
||||
当"当涉及游戏相关时,夸赞,略带戏谑意味"时,使用"这么强!"
|
||||
|
||||
请注意:不要总结你自己(SELF)的发言,尽量保证总结内容的逻辑性
|
||||
现在请你概括
|
||||
"""
|
||||
Prompt(learn_style_prompt, "learn_style_prompt")
|
||||
|
||||
match_expression_context_prompt = """
|
||||
**聊天内容**
|
||||
{chat_str}
|
||||
|
||||
**从聊天内容总结的表达方式pairs**
|
||||
{expression_pairs}
|
||||
|
||||
请你为上面的每一条表达方式,找到该表达方式的原文句子,并输出匹配结果,expression_pair不能有重复,每个expression_pair仅输出一个最合适的context。
|
||||
如果找不到原句,就不输出该句的匹配结果。
|
||||
以json格式输出:
|
||||
格式如下:
|
||||
{{
|
||||
"expression_pair": "表达方式pair的序号(数字)",
|
||||
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
|
||||
}},
|
||||
{{
|
||||
"expression_pair": "表达方式pair的序号(数字)",
|
||||
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
|
||||
}},
|
||||
...
|
||||
|
||||
现在请你输出匹配结果:
|
||||
"""
|
||||
Prompt(match_expression_context_prompt, "match_expression_context_prompt")
|
||||
|
||||
|
||||
class ExpressionLearner:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.express_learn_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils, request_type="expression.learner"
|
||||
)
|
||||
self.summary_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="expression.summary"
|
||||
)
|
||||
self.embedding_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.embedding, request_type="expression.embedding"
|
||||
)
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||
|
||||
# 维护每个chat的上次学习时间
|
||||
self.last_learning_time: float = time.time()
|
||||
|
||||
# 学习锁,防止并发执行学习任务
|
||||
self._learning_lock = asyncio.Lock()
|
||||
|
||||
# 学习参数
|
||||
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(
|
||||
self.chat_id
|
||||
)
|
||||
self.min_messages_for_learning = 15 / self.learning_intensity # 触发学习所需的最少消息数
|
||||
self.min_learning_interval = 120 / self.learning_intensity
|
||||
|
||||
def should_trigger_learning(self) -> bool:
|
||||
"""
|
||||
检查是否应该触发学习
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否应该触发学习
|
||||
"""
|
||||
# 检查是否允许学习
|
||||
if not self.enable_learning:
|
||||
return False
|
||||
|
||||
# 检查时间间隔
|
||||
time_diff = time.time() - self.last_learning_time
|
||||
if time_diff < self.min_learning_interval:
|
||||
return False
|
||||
|
||||
# 检查消息数量(只检查指定聊天流的消息)
|
||||
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_learning_time,
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
|
||||
if not recent_messages or len(recent_messages) < self.min_messages_for_learning:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def trigger_learning_for_chat(self):
|
||||
"""
|
||||
为指定聊天流触发学习
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否成功触发学习
|
||||
"""
|
||||
# 使用异步锁防止并发执行
|
||||
async with self._learning_lock:
|
||||
# 在锁内检查,避免并发触发
|
||||
# 如果锁被持有,其他协程会等待,但等待期间条件可能已变化,所以需要再次检查
|
||||
if not self.should_trigger_learning():
|
||||
return
|
||||
|
||||
# 保存学习开始前的时间戳,用于获取消息范围
|
||||
learning_start_timestamp = time.time()
|
||||
previous_learning_time = self.last_learning_time
|
||||
|
||||
# 立即更新学习时间,防止并发触发
|
||||
self.last_learning_time = learning_start_timestamp
|
||||
|
||||
try:
|
||||
logger.info(f"在聊天流 {self.chat_name} 学习表达方式")
|
||||
# 学习语言风格,传递学习开始前的时间戳
|
||||
learnt_style = await self.learn_and_store(num=25, timestamp_start=previous_learning_time)
|
||||
|
||||
if learnt_style:
|
||||
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
||||
else:
|
||||
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||
traceback.print_exc()
|
||||
# 即使失败也保持时间戳更新,避免频繁重试
|
||||
return
|
||||
|
||||
async def learn_and_store(self, num: int = 10, timestamp_start: Optional[float] = None) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
学习并存储表达方式
|
||||
|
||||
Args:
|
||||
num: 学习数量
|
||||
timestamp_start: 学习开始的时间戳,如果为None则使用self.last_learning_time
|
||||
"""
|
||||
learnt_expressions = await self.learn_expression(num, timestamp_start=timestamp_start)
|
||||
|
||||
if learnt_expressions is None:
|
||||
logger.info("没有学习到表达风格")
|
||||
return []
|
||||
|
||||
# 展示学到的表达方式
|
||||
learnt_expressions_str = ""
|
||||
for (
|
||||
situation,
|
||||
style,
|
||||
_context,
|
||||
_up_content,
|
||||
) in learnt_expressions:
|
||||
learnt_expressions_str += f"{situation}->{style}\n"
|
||||
logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
for (
|
||||
situation,
|
||||
style,
|
||||
context,
|
||||
up_content,
|
||||
) in learnt_expressions:
|
||||
await self._upsert_expression_record(
|
||||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
up_content=up_content,
|
||||
current_time=current_time,
|
||||
)
|
||||
|
||||
return learnt_expressions
|
||||
|
||||
async def match_expression_context(
|
||||
self, expression_pairs: List[Tuple[str, str]], random_msg_match_str: str
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
# 为expression_pairs逐个条目赋予编号,并构建成字符串
|
||||
numbered_pairs = []
|
||||
for i, (situation, style) in enumerate(expression_pairs, 1):
|
||||
numbered_pairs.append(f'{i}. 当"{situation}"时,使用"{style}"')
|
||||
|
||||
expression_pairs_str = "\n".join(numbered_pairs)
|
||||
|
||||
prompt = "match_expression_context_prompt"
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
prompt,
|
||||
expression_pairs=expression_pairs_str,
|
||||
chat_str=random_msg_match_str,
|
||||
)
|
||||
|
||||
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
|
||||
# print(f"match_expression_context_prompt: {prompt}")
|
||||
# print(f"{response}")
|
||||
|
||||
# 解析JSON响应
|
||||
match_responses = []
|
||||
try:
|
||||
response = response.strip()
|
||||
|
||||
# 尝试提取JSON代码块(如果存在)
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||
if matches:
|
||||
response = matches[0].strip()
|
||||
|
||||
# 移除可能的markdown代码块标记(如果没有找到```json,但可能有```)
|
||||
if not matches:
|
||||
response = re.sub(r"^```\s*", "", response, flags=re.MULTILINE)
|
||||
response = re.sub(r"```\s*$", "", response, flags=re.MULTILINE)
|
||||
response = response.strip()
|
||||
|
||||
# 检查是否已经是标准JSON数组格式
|
||||
if response.startswith("[") and response.endswith("]"):
|
||||
match_responses = json.loads(response)
|
||||
else:
|
||||
# 尝试直接解析多个JSON对象
|
||||
try:
|
||||
# 如果是多个JSON对象用逗号分隔,包装成数组
|
||||
if response.startswith("{") and not response.startswith("["):
|
||||
response = "[" + response + "]"
|
||||
match_responses = json.loads(response)
|
||||
else:
|
||||
# 使用repair_json处理响应
|
||||
repaired_content = repair_json(response)
|
||||
|
||||
# 确保repaired_content是列表格式
|
||||
if isinstance(repaired_content, str):
|
||||
try:
|
||||
parsed_data = json.loads(repaired_content)
|
||||
if isinstance(parsed_data, dict):
|
||||
# 如果是字典,包装成列表
|
||||
match_responses = [parsed_data]
|
||||
elif isinstance(parsed_data, list):
|
||||
match_responses = parsed_data
|
||||
else:
|
||||
match_responses = []
|
||||
except json.JSONDecodeError:
|
||||
match_responses = []
|
||||
elif isinstance(repaired_content, dict):
|
||||
# 如果是字典,包装成列表
|
||||
match_responses = [repaired_content]
|
||||
elif isinstance(repaired_content, list):
|
||||
match_responses = repaired_content
|
||||
else:
|
||||
match_responses = []
|
||||
except json.JSONDecodeError:
|
||||
# 如果还是失败,尝试repair_json
|
||||
repaired_content = repair_json(response)
|
||||
if isinstance(repaired_content, str):
|
||||
parsed_data = json.loads(repaired_content)
|
||||
match_responses = parsed_data if isinstance(parsed_data, list) else [parsed_data]
|
||||
else:
|
||||
match_responses = repaired_content if isinstance(repaired_content, list) else [repaired_content]
|
||||
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
logger.error(f"解析匹配响应JSON失败: {e}, 响应内容: \n{response}")
|
||||
return []
|
||||
|
||||
# 确保 match_responses 是一个列表
|
||||
if not isinstance(match_responses, list):
|
||||
if isinstance(match_responses, dict):
|
||||
match_responses = [match_responses]
|
||||
else:
|
||||
logger.error(f"match_responses 不是列表或字典类型: {type(match_responses)}, 内容: {match_responses}")
|
||||
return []
|
||||
|
||||
# 清理和规范化 match_responses 中的元素
|
||||
normalized_responses = []
|
||||
for item in match_responses:
|
||||
if isinstance(item, dict):
|
||||
# 已经是字典,直接添加
|
||||
normalized_responses.append(item)
|
||||
elif isinstance(item, str):
|
||||
# 如果是字符串,尝试解析为 JSON
|
||||
try:
|
||||
parsed = json.loads(item)
|
||||
if isinstance(parsed, dict):
|
||||
normalized_responses.append(parsed)
|
||||
elif isinstance(parsed, list):
|
||||
# 如果是列表,递归处理
|
||||
for sub_item in parsed:
|
||||
if isinstance(sub_item, dict):
|
||||
normalized_responses.append(sub_item)
|
||||
else:
|
||||
logger.debug(f"跳过非字典类型的子元素: {type(sub_item)}, 内容: {sub_item}")
|
||||
else:
|
||||
logger.debug(f"跳过无法转换为字典的字符串元素: {item}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.debug(f"跳过无法解析为JSON的字符串元素: {item}")
|
||||
elif isinstance(item, list):
|
||||
# 如果是列表,展开并处理其中的字典
|
||||
for sub_item in item:
|
||||
if isinstance(sub_item, dict):
|
||||
normalized_responses.append(sub_item)
|
||||
elif isinstance(sub_item, str):
|
||||
# 尝试解析字符串
|
||||
try:
|
||||
parsed = json.loads(sub_item)
|
||||
if isinstance(parsed, dict):
|
||||
normalized_responses.append(parsed)
|
||||
else:
|
||||
logger.debug(f"跳过非字典类型的解析结果: {type(parsed)}, 内容: {parsed}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.debug(f"跳过无法解析为JSON的字符串子元素: {sub_item}")
|
||||
else:
|
||||
logger.debug(f"跳过非字典类型的列表元素: {type(sub_item)}, 内容: {sub_item}")
|
||||
else:
|
||||
logger.debug(f"跳过无法处理的元素类型: {type(item)}, 内容: {item}")
|
||||
|
||||
match_responses = normalized_responses
|
||||
|
||||
matched_expressions = []
|
||||
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
|
||||
|
||||
logger.debug(f"规范化后的 match_responses 类型: {type(match_responses)}, 长度: {len(match_responses)}")
|
||||
logger.debug(f"规范化后的 match_responses 内容: {match_responses}")
|
||||
|
||||
for match_response in match_responses:
|
||||
try:
|
||||
# 检查 match_response 的类型(此时应该都是字典)
|
||||
if not isinstance(match_response, dict):
|
||||
logger.error(f"match_response 不是字典类型: {type(match_response)}, 内容: {match_response}")
|
||||
continue
|
||||
|
||||
# 获取表达方式序号
|
||||
if "expression_pair" not in match_response:
|
||||
logger.error(f"match_response 缺少 'expression_pair' 字段: {match_response}")
|
||||
continue
|
||||
|
||||
pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引
|
||||
|
||||
# 检查索引是否有效且未被使用过
|
||||
if 0 <= pair_index < len(expression_pairs) and pair_index not in used_pair_indices:
|
||||
situation, style = expression_pairs[pair_index]
|
||||
context = match_response.get("context", "")
|
||||
matched_expressions.append((situation, style, context))
|
||||
used_pair_indices.add(pair_index) # 标记该索引已使用
|
||||
logger.debug(f"成功匹配表达方式 {pair_index + 1}: {situation} -> {style}")
|
||||
elif pair_index in used_pair_indices:
|
||||
logger.debug(f"跳过重复的表达方式 {pair_index + 1}")
|
||||
except (ValueError, KeyError, IndexError, TypeError) as e:
|
||||
logger.error(f"解析匹配条目失败: {e}, 条目: {match_response}")
|
||||
continue
|
||||
|
||||
return matched_expressions
|
||||
|
||||
async def learn_expression(self, num: int = 10, timestamp_start: Optional[float] = None) -> Optional[List[Tuple[str, str, str, str]]]:
|
||||
"""从指定聊天流学习表达方式
|
||||
|
||||
Args:
|
||||
num: 学习数量
|
||||
timestamp_start: 学习开始的时间戳,如果为None则使用self.last_learning_time
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# 使用传入的时间戳,如果没有则使用self.last_learning_time
|
||||
start_timestamp = timestamp_start if timestamp_start is not None else self.last_learning_time
|
||||
|
||||
# 获取上次学习之后的消息
|
||||
random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=start_timestamp,
|
||||
timestamp_end=current_time,
|
||||
limit=num,
|
||||
)
|
||||
# print(random_msg)
|
||||
if not random_msg or random_msg == []:
|
||||
return None
|
||||
|
||||
# 学习用
|
||||
random_msg_str: str = await build_anonymous_messages(random_msg)
|
||||
# 溯源用
|
||||
random_msg_match_str: str = await build_bare_messages(random_msg)
|
||||
|
||||
prompt: str = await global_prompt_manager.format_prompt(
|
||||
"learn_style_prompt",
|
||||
chat_str=random_msg_str,
|
||||
)
|
||||
|
||||
# print(f"random_msg_str:{random_msg_str}")
|
||||
# logger.info(f"学习{type_str}的prompt: {prompt}")
|
||||
|
||||
try:
|
||||
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
except Exception as e:
|
||||
logger.error(f"学习表达方式失败,模型生成出错: {e}")
|
||||
return None
|
||||
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
|
||||
expressions = self._filter_self_reference_styles(expressions)
|
||||
if not expressions:
|
||||
logger.info("过滤后没有可用的表达方式(style 与机器人名称重复)")
|
||||
return None
|
||||
# logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
# 对表达方式溯源
|
||||
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
|
||||
expressions, random_msg_match_str
|
||||
)
|
||||
# 为每条消息构建精简文本列表,保留到原消息索引的映射
|
||||
bare_lines: List[Tuple[int, str]] = self._build_bare_lines(random_msg)
|
||||
# 将 matched_expressions 结合上一句 up_content(若不存在上一句则跳过)
|
||||
filtered_with_up: List[Tuple[str, str, str, str]] = [] # (situation, style, context, up_content)
|
||||
for situation, style, context in matched_expressions:
|
||||
# 在 bare_lines 中找到第一处相似度达到85%的行
|
||||
pos = None
|
||||
for i, (_, c) in enumerate(bare_lines):
|
||||
similarity = calculate_similarity(c, context)
|
||||
if similarity >= 0.85: # 85%相似度阈值
|
||||
pos = i
|
||||
break
|
||||
|
||||
if pos is None or pos == 0:
|
||||
# 没有匹配到目标句或没有上一句,跳过该表达
|
||||
continue
|
||||
|
||||
# 检查目标句是否为空
|
||||
target_content = bare_lines[pos][1]
|
||||
if not target_content:
|
||||
# 目标句为空,跳过该表达
|
||||
continue
|
||||
|
||||
prev_original_idx = bare_lines[pos - 1][0]
|
||||
up_content = filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
|
||||
if not up_content:
|
||||
# 上一句为空,跳过该表达
|
||||
continue
|
||||
filtered_with_up.append((situation, style, context, up_content))
|
||||
|
||||
if not filtered_with_up:
|
||||
return None
|
||||
|
||||
return filtered_with_up
|
||||
|
||||
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
|
||||
"""
|
||||
expressions: List[Tuple[str, str, str]] = []
|
||||
for line in response.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# 查找"当"和下一个引号
|
||||
idx_when = line.find('当"')
|
||||
if idx_when == -1:
|
||||
continue
|
||||
idx_quote1 = idx_when + 1
|
||||
idx_quote2 = line.find('"', idx_quote1 + 1)
|
||||
if idx_quote2 == -1:
|
||||
continue
|
||||
situation = line[idx_quote1 + 1 : idx_quote2]
|
||||
# 查找"使用"
|
||||
idx_use = line.find('使用"', idx_quote2)
|
||||
if idx_use == -1:
|
||||
continue
|
||||
idx_quote3 = idx_use + 2
|
||||
idx_quote4 = line.find('"', idx_quote3 + 1)
|
||||
if idx_quote4 == -1:
|
||||
continue
|
||||
style = line[idx_quote3 + 1 : idx_quote4]
|
||||
expressions.append((situation, style))
|
||||
return expressions
|
||||
|
||||
def _filter_self_reference_styles(self, expressions: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
过滤掉style与机器人名称/昵称重复的表达
|
||||
"""
|
||||
banned_names = set()
|
||||
bot_nickname = (global_config.bot.nickname or "").strip()
|
||||
if bot_nickname:
|
||||
banned_names.add(bot_nickname)
|
||||
|
||||
alias_names = global_config.bot.alias_names or []
|
||||
for alias in alias_names:
|
||||
alias = alias.strip()
|
||||
if alias:
|
||||
banned_names.add(alias)
|
||||
|
||||
banned_casefold = {name.casefold() for name in banned_names if name}
|
||||
|
||||
filtered: List[Tuple[str, str]] = []
|
||||
removed_count = 0
|
||||
for situation, style in expressions:
|
||||
normalized_style = (style or "").strip()
|
||||
if normalized_style and normalized_style.casefold() not in banned_casefold:
|
||||
filtered.append((situation, style))
|
||||
else:
|
||||
removed_count += 1
|
||||
|
||||
if removed_count:
|
||||
logger.debug(f"已过滤 {removed_count} 条style与机器人名称重复的表达方式")
|
||||
|
||||
return filtered
|
||||
|
||||
async def _upsert_expression_record(
|
||||
self,
|
||||
situation: str,
|
||||
style: str,
|
||||
context: str,
|
||||
up_content: str,
|
||||
current_time: float,
|
||||
) -> None:
|
||||
expr_obj = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.style == style)).first()
|
||||
|
||||
if expr_obj:
|
||||
await self._update_existing_expression(
|
||||
expr_obj=expr_obj,
|
||||
situation=situation,
|
||||
context=context,
|
||||
up_content=up_content,
|
||||
current_time=current_time,
|
||||
)
|
||||
return
|
||||
|
||||
await self._create_expression_record(
|
||||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
up_content=up_content,
|
||||
current_time=current_time,
|
||||
)
|
||||
|
||||
async def _create_expression_record(
|
||||
self,
|
||||
situation: str,
|
||||
style: str,
|
||||
context: str,
|
||||
up_content: str,
|
||||
current_time: float,
|
||||
) -> None:
|
||||
content_list = [situation]
|
||||
formatted_situation = await self._compose_situation_text(content_list, 1, situation)
|
||||
|
||||
Expression.create(
|
||||
situation=formatted_situation,
|
||||
style=style,
|
||||
content_list=json.dumps(content_list, ensure_ascii=False),
|
||||
count=1,
|
||||
last_active_time=current_time,
|
||||
chat_id=self.chat_id,
|
||||
create_date=current_time,
|
||||
context=context,
|
||||
up_content=up_content,
|
||||
)
|
||||
|
||||
async def _update_existing_expression(
|
||||
self,
|
||||
expr_obj: Expression,
|
||||
situation: str,
|
||||
context: str,
|
||||
up_content: str,
|
||||
current_time: float,
|
||||
) -> None:
|
||||
content_list = self._parse_content_list(expr_obj.content_list)
|
||||
content_list.append(situation)
|
||||
|
||||
expr_obj.content_list = json.dumps(content_list, ensure_ascii=False)
|
||||
expr_obj.count = (expr_obj.count or 0) + 1
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.context = context
|
||||
expr_obj.up_content = up_content
|
||||
|
||||
new_situation = await self._compose_situation_text(
|
||||
content_list=content_list,
|
||||
count=expr_obj.count,
|
||||
fallback=expr_obj.situation,
|
||||
)
|
||||
expr_obj.situation = new_situation
|
||||
|
||||
expr_obj.save()
|
||||
|
||||
def _parse_content_list(self, stored_list: Optional[str]) -> List[str]:
|
||||
if not stored_list:
|
||||
return []
|
||||
try:
|
||||
data = json.loads(stored_list)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
|
||||
|
||||
async def _compose_situation_text(self, content_list: List[str], count: int, fallback: str = "") -> str:
|
||||
sanitized = [c.strip() for c in content_list if c.strip()]
|
||||
summary = await self._summarize_situations(sanitized)
|
||||
if summary:
|
||||
return summary
|
||||
return "/".join(sanitized) if sanitized else fallback
|
||||
|
||||
async def _summarize_situations(self, situations: List[str]) -> Optional[str]:
|
||||
if not situations:
|
||||
return None
|
||||
|
||||
prompt = (
|
||||
"请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
|
||||
"长度不超过20个字,保留共同特点:\n"
|
||||
f"{chr(10).join(f'- {s}' for s in situations[-10:])}\n只输出概括内容。"
|
||||
)
|
||||
|
||||
try:
|
||||
summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2)
|
||||
summary = summary.strip()
|
||||
if summary:
|
||||
return summary
|
||||
except Exception as e:
|
||||
logger.error(f"概括表达情境失败: {e}")
|
||||
return None
|
||||
|
||||
def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
|
||||
"""
|
||||
为每条消息构建精简文本列表,保留到原消息索引的映射
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, str]]: (original_index, bare_content) 元组列表
|
||||
"""
|
||||
bare_lines: List[Tuple[int, str]] = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
content = msg.processed_plain_text or ""
|
||||
content = filter_message_content(content)
|
||||
# 即使content为空也要记录,防止错位
|
||||
bare_lines.append((idx, content))
|
||||
|
||||
return bare_lines
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
|
||||
class ExpressionLearnerManager:
|
||||
def __init__(self):
|
||||
self.expression_learners = {}
|
||||
|
||||
self._ensure_expression_directories()
|
||||
|
||||
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
|
||||
if chat_id not in self.expression_learners:
|
||||
self.expression_learners[chat_id] = ExpressionLearner(chat_id)
|
||||
return self.expression_learners[chat_id]
|
||||
|
||||
def _ensure_expression_directories(self):
|
||||
"""
|
||||
确保表达方式相关的目录结构存在
|
||||
"""
|
||||
base_dir = os.path.join("data", "expression")
|
||||
directories_to_create = [
|
||||
base_dir,
|
||||
os.path.join(base_dir, "learnt_style"),
|
||||
os.path.join(base_dir, "learnt_grammar"),
|
||||
]
|
||||
|
||||
for directory in directories_to_create:
|
||||
try:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
logger.debug(f"确保目录存在: {directory}")
|
||||
except Exception as e:
|
||||
logger.error(f"创建目录失败 {directory}: {e}")
|
||||
|
||||
|
||||
expression_learner_manager = ExpressionLearnerManager()
|
||||
@@ -7,6 +7,7 @@ import asyncio
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
import difflib
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from dataclasses import dataclass, field
|
||||
@@ -30,16 +31,18 @@ HIPPO_CACHE_DIR = Path(__file__).resolve().parents[2] / "data" / "hippo_memorize
|
||||
def init_prompt():
|
||||
"""初始化提示词模板"""
|
||||
|
||||
topic_analysis_prompt = """
|
||||
【历史话题标题列表】(仅标题,不含具体内容):
|
||||
topic_analysis_prompt = """【历史话题标题列表】(仅标题,不含具体内容):
|
||||
{history_topics_block}
|
||||
【历史话题标题列表结束】
|
||||
|
||||
【本次聊天记录】(每条消息前有编号,用于后续引用):
|
||||
{messages_block}
|
||||
【本次聊天记录结束】
|
||||
|
||||
请完成以下任务:
|
||||
**识别话题**
|
||||
1. 识别【本次聊天记录】中正在进行的一个或多个话题;
|
||||
2. 【本次聊天记录】的中的消息可能与历史话题有关,也可能毫无关联。
|
||||
2. 判断【历史话题标题列表】中的话题是否在【本次聊天记录】中出现,如果出现,则直接使用该历史话题标题字符串;
|
||||
|
||||
**选取消息**
|
||||
@@ -316,7 +319,9 @@ class ChatHistorySummarizer:
|
||||
before_count = len(self.current_batch.messages)
|
||||
self.current_batch.messages.extend(new_messages)
|
||||
self.current_batch.end_time = current_time
|
||||
logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息"
|
||||
)
|
||||
# 更新批次后持久化
|
||||
self._persist_topic_cache()
|
||||
else:
|
||||
@@ -362,9 +367,7 @@ class ChatHistorySummarizer:
|
||||
else:
|
||||
time_str = f"{time_since_last_check / 3600:.1f}小时"
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}"
|
||||
)
|
||||
logger.debug(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}")
|
||||
|
||||
# 检查“话题检查”触发条件
|
||||
should_check = False
|
||||
@@ -374,10 +377,10 @@ class ChatHistorySummarizer:
|
||||
should_check = True
|
||||
logger.info(f"{self.log_prefix} 触发检查条件: 消息数量达到 {message_count} 条(阈值: 100条)")
|
||||
|
||||
# 条件2: 距离上一次检查 > 3600 秒(1小时),触发一次检查
|
||||
elif time_since_last_check > 2400:
|
||||
# 条件2: 距离上一次检查 > 3600 * 8 秒(8小时)且消息数量 >= 20 条,触发一次检查
|
||||
elif time_since_last_check > 3600 * 8 and message_count >= 20:
|
||||
should_check = True
|
||||
logger.info(f"{self.log_prefix} 触发检查条件: 距上次检查 {time_str}(阈值: 1小时)")
|
||||
logger.info(f"{self.log_prefix} 触发检查条件: 距上次检查 {time_str}(阈值: 8小时)且消息数量达到 {message_count} 条(阈值: 20条)")
|
||||
|
||||
if should_check:
|
||||
await self._run_topic_check_and_update_cache(messages)
|
||||
@@ -414,7 +417,7 @@ class ChatHistorySummarizer:
|
||||
# 说明 bot 没有参与这段对话,不应该记录
|
||||
bot_user_id = str(global_config.bot.qq_account)
|
||||
has_bot_message = False
|
||||
|
||||
|
||||
for msg in messages:
|
||||
if msg.user_info.user_id == bot_user_id:
|
||||
has_bot_message = True
|
||||
@@ -427,20 +430,63 @@ class ChatHistorySummarizer:
|
||||
return
|
||||
|
||||
# 2. 构造编号后的消息字符串和参与者信息
|
||||
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages)
|
||||
|
||||
# 3. 调用 LLM 识别话题,并得到 topic -> indices
|
||||
existing_topics = list(self.topic_cache.keys())
|
||||
success, topic_to_indices = await self._analyze_topics_with_llm(
|
||||
numbered_lines=numbered_lines,
|
||||
existing_topics=existing_topics,
|
||||
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = (
|
||||
self._build_numbered_messages_for_llm(messages)
|
||||
)
|
||||
|
||||
# 3. 调用 LLM 识别话题,并得到 topic -> indices(失败时最多重试 3 次)
|
||||
existing_topics = list(self.topic_cache.keys())
|
||||
max_retries = 3
|
||||
attempt = 0
|
||||
success = False
|
||||
topic_to_indices: Dict[str, List[int]] = {}
|
||||
|
||||
while attempt < max_retries:
|
||||
attempt += 1
|
||||
success, topic_to_indices = await self._analyze_topics_with_llm(
|
||||
numbered_lines=numbered_lines,
|
||||
existing_topics=existing_topics,
|
||||
)
|
||||
|
||||
if success and topic_to_indices:
|
||||
if attempt > 1:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 话题识别在第 {attempt} 次重试后成功 | 话题数: {len(topic_to_indices)}"
|
||||
)
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
f"{self.log_prefix} 话题识别失败或无有效话题,第 {attempt} 次尝试失败"
|
||||
+ ("" if attempt >= max_retries else ",准备重试")
|
||||
)
|
||||
|
||||
if not success or not topic_to_indices:
|
||||
logger.warning(f"{self.log_prefix} 话题识别失败或无有效话题,本次检查忽略")
|
||||
# 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks(保持原状)
|
||||
logger.error(f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃")
|
||||
# 即使识别失败,也认为是一次"检查",但不更新 no_update_checks(保持原状)
|
||||
return
|
||||
|
||||
# 3.5. 检查新话题是否与历史话题相似(相似度>=90%则使用历史标题)
|
||||
topic_mapping = self._build_topic_mapping(topic_to_indices, similarity_threshold=0.9)
|
||||
|
||||
# 应用话题映射:将相似的新话题标题替换为历史话题标题
|
||||
if topic_mapping:
|
||||
new_topic_to_indices: Dict[str, List[int]] = {}
|
||||
for new_topic, indices in topic_to_indices.items():
|
||||
# 如果这个新话题需要映射到历史话题
|
||||
if new_topic in topic_mapping:
|
||||
historical_topic = topic_mapping[new_topic]
|
||||
# 如果历史话题已经存在,合并消息索引
|
||||
if historical_topic in new_topic_to_indices:
|
||||
# 合并索引并去重
|
||||
combined_indices = list(set(new_topic_to_indices[historical_topic] + indices))
|
||||
new_topic_to_indices[historical_topic] = combined_indices
|
||||
else:
|
||||
new_topic_to_indices[historical_topic] = indices
|
||||
else:
|
||||
# 不需要映射,保持原样
|
||||
new_topic_to_indices[new_topic] = indices
|
||||
topic_to_indices = new_topic_to_indices
|
||||
|
||||
# 4. 统计哪些话题在本次检查中有新增内容
|
||||
updated_topics: Set[str] = set()
|
||||
|
||||
@@ -507,6 +553,71 @@ class ChatHistorySummarizer:
|
||||
# 无论成功与否,都从缓存中删除,避免重复
|
||||
self.topic_cache.pop(topic, None)
|
||||
|
||||
def _find_most_similar_topic(
|
||||
self, new_topic: str, existing_topics: List[str], similarity_threshold: float = 0.9
|
||||
) -> Optional[tuple[str, float]]:
|
||||
"""
|
||||
查找与给定新话题最相似的历史话题
|
||||
|
||||
Args:
|
||||
new_topic: 新话题标题
|
||||
existing_topics: 历史话题标题列表
|
||||
similarity_threshold: 相似度阈值,默认0.9(90%)
|
||||
|
||||
Returns:
|
||||
Optional[tuple[str, float]]: 如果找到相似度>=阈值的历史话题,返回(历史话题标题, 相似度),
|
||||
否则返回None
|
||||
"""
|
||||
if not existing_topics:
|
||||
return None
|
||||
|
||||
best_match = None
|
||||
best_similarity = 0.0
|
||||
|
||||
for existing_topic in existing_topics:
|
||||
similarity = difflib.SequenceMatcher(None, new_topic, existing_topic).ratio()
|
||||
if similarity > best_similarity:
|
||||
best_similarity = similarity
|
||||
best_match = existing_topic
|
||||
|
||||
# 如果相似度达到阈值,返回匹配结果
|
||||
if best_match and best_similarity >= similarity_threshold:
|
||||
return (best_match, best_similarity)
|
||||
|
||||
return None
|
||||
|
||||
def _build_topic_mapping(
|
||||
self, topic_to_indices: Dict[str, List[int]], similarity_threshold: float = 0.9
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
构建新话题到历史话题的映射(如果相似度>=阈值)
|
||||
|
||||
Args:
|
||||
topic_to_indices: 新话题到消息索引的映射
|
||||
similarity_threshold: 相似度阈值,默认0.9(90%)
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: 新话题 -> 历史话题的映射字典
|
||||
"""
|
||||
existing_topics_list = list(self.topic_cache.keys())
|
||||
topic_mapping: Dict[str, str] = {}
|
||||
|
||||
for new_topic in topic_to_indices.keys():
|
||||
# 如果新话题已经在历史话题中,不需要检查
|
||||
if new_topic in existing_topics_list:
|
||||
continue
|
||||
|
||||
# 查找最相似的历史话题
|
||||
result = self._find_most_similar_topic(new_topic, existing_topics_list, similarity_threshold)
|
||||
if result:
|
||||
historical_topic, similarity = result
|
||||
topic_mapping[new_topic] = historical_topic
|
||||
logger.info(
|
||||
f"{self.log_prefix} 话题相似度检查: '{new_topic}' 与历史话题 '{historical_topic}' 相似度 {similarity:.2%},使用历史标题"
|
||||
)
|
||||
|
||||
return topic_mapping
|
||||
|
||||
def _build_numbered_messages_for_llm(
|
||||
self, messages: List[DatabaseMessages]
|
||||
) -> tuple[List[str], Dict[int, str], Dict[int, str], Dict[int, Set[str]]]:
|
||||
@@ -589,9 +700,7 @@ class ChatHistorySummarizer:
|
||||
if not numbered_lines:
|
||||
return False, {}
|
||||
|
||||
history_topics_block = (
|
||||
"\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
|
||||
)
|
||||
history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
|
||||
messages_block = "\n".join(numbered_lines)
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
@@ -603,8 +712,7 @@ class ChatHistorySummarizer:
|
||||
try:
|
||||
response, _ = await self.summarizer_llm.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.2,
|
||||
max_tokens=800,
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}")
|
||||
@@ -614,17 +722,17 @@ class ChatHistorySummarizer:
|
||||
json_str = None
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||
|
||||
|
||||
if matches:
|
||||
# 找到JSON代码块,使用第一个匹配
|
||||
json_str = matches[0].strip()
|
||||
else:
|
||||
# 如果没有找到代码块,尝试查找JSON数组的开始和结束位置
|
||||
# 查找第一个 [ 和最后一个 ]
|
||||
start_idx = response.find('[')
|
||||
end_idx = response.rfind(']')
|
||||
start_idx = response.find("[")
|
||||
end_idx = response.rfind("]")
|
||||
if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
|
||||
json_str = response[start_idx:end_idx + 1].strip()
|
||||
json_str = response[start_idx : end_idx + 1].strip()
|
||||
else:
|
||||
# 如果还是找不到,尝试直接使用整个响应(移除可能的markdown标记)
|
||||
json_str = response.strip()
|
||||
@@ -921,4 +1029,3 @@ class ChatHistorySummarizer:
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
|
||||
@@ -28,10 +28,10 @@ class MemoryForgetTask(AsyncTask):
|
||||
# logger.info("[记忆遗忘] 开始遗忘检查...")
|
||||
|
||||
# 执行4个阶段的遗忘检查
|
||||
await self._forget_stage_1(current_time)
|
||||
await self._forget_stage_2(current_time)
|
||||
await self._forget_stage_3(current_time)
|
||||
await self._forget_stage_4(current_time)
|
||||
# await self._forget_stage_1(current_time)
|
||||
# await self._forget_stage_2(current_time)
|
||||
# await self._forget_stage_3(current_time)
|
||||
# await self._forget_stage_4(current_time)
|
||||
|
||||
# logger.info("[记忆遗忘] 遗忘检查完成")
|
||||
except Exception as e:
|
||||
@@ -1,5 +0,0 @@
|
||||
from .jargon_miner import extract_and_store_jargon
|
||||
|
||||
__all__ = [
|
||||
"extract_and_store_jargon",
|
||||
]
|
||||
@@ -1,195 +0,0 @@
|
||||
import json
|
||||
from typing import List, Dict, Optional, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
)
|
||||
from src.chat.utils.utils import parse_platform_accounts
|
||||
|
||||
|
||||
logger = get_logger("jargon")
|
||||
|
||||
|
||||
def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
|
||||
"""
|
||||
解析chat_id字段,兼容旧格式(字符串)和新格式(JSON列表)
|
||||
|
||||
Args:
|
||||
chat_id_value: 可能是字符串(旧格式)或JSON字符串(新格式)
|
||||
|
||||
Returns:
|
||||
List[List[Any]]: 格式为 [[chat_id, count], ...] 的列表
|
||||
"""
|
||||
if not chat_id_value:
|
||||
return []
|
||||
|
||||
# 如果是字符串,尝试解析为JSON
|
||||
if isinstance(chat_id_value, str):
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
parsed = json.loads(chat_id_value)
|
||||
if isinstance(parsed, list):
|
||||
# 新格式:已经是列表
|
||||
return parsed
|
||||
elif isinstance(parsed, str):
|
||||
# 解析后还是字符串,说明是旧格式
|
||||
return [[parsed, 1]]
|
||||
else:
|
||||
# 其他类型,当作旧格式处理
|
||||
return [[str(chat_id_value), 1]]
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# 解析失败,当作旧格式(纯字符串)
|
||||
return [[str(chat_id_value), 1]]
|
||||
elif isinstance(chat_id_value, list):
|
||||
# 已经是列表格式
|
||||
return chat_id_value
|
||||
else:
|
||||
# 其他类型,转换为旧格式
|
||||
return [[str(chat_id_value), 1]]
|
||||
|
||||
|
||||
def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]:
|
||||
"""
|
||||
更新chat_id列表,如果target_chat_id已存在则增加计数,否则添加新条目
|
||||
|
||||
Args:
|
||||
chat_id_list: 当前的chat_id列表,格式为 [[chat_id, count], ...]
|
||||
target_chat_id: 要更新或添加的chat_id
|
||||
increment: 增加的计数,默认为1
|
||||
|
||||
Returns:
|
||||
List[List[Any]]: 更新后的chat_id列表
|
||||
"""
|
||||
# 查找是否已存在该chat_id
|
||||
found = False
|
||||
for item in chat_id_list:
|
||||
if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id):
|
||||
# 找到匹配的chat_id,增加计数
|
||||
if len(item) >= 2:
|
||||
item[1] = (item[1] if isinstance(item[1], (int, float)) else 0) + increment
|
||||
else:
|
||||
item.append(increment)
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
# 未找到,添加新条目
|
||||
chat_id_list.append([target_chat_id, increment])
|
||||
|
||||
return chat_id_list
|
||||
|
||||
|
||||
def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool:
|
||||
"""
|
||||
检查chat_id列表中是否包含指定的chat_id
|
||||
|
||||
Args:
|
||||
chat_id_list: chat_id列表,格式为 [[chat_id, count], ...]
|
||||
target_chat_id: 要查找的chat_id
|
||||
|
||||
Returns:
|
||||
bool: 如果包含则返回True
|
||||
"""
|
||||
for item in chat_id_list:
|
||||
if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def contains_bot_self_name(content: str) -> bool:
|
||||
"""
|
||||
判断词条是否包含机器人的昵称或别名
|
||||
"""
|
||||
if not content:
|
||||
return False
|
||||
|
||||
bot_config = getattr(global_config, "bot", None)
|
||||
if not bot_config:
|
||||
return False
|
||||
|
||||
target = content.strip().lower()
|
||||
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
|
||||
alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []]
|
||||
|
||||
candidates = [name for name in [nickname, *alias_names] if name]
|
||||
|
||||
return any(name in target for name in candidates if target)
|
||||
|
||||
|
||||
def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]:
|
||||
"""
|
||||
构建包含中心消息上下文的段落(前3条+后3条),使用标准的 readable builder 输出
|
||||
"""
|
||||
if not messages or center_index < 0 or center_index >= len(messages):
|
||||
return None
|
||||
|
||||
context_start = max(0, center_index - 3)
|
||||
context_end = min(len(messages), center_index + 1 + 3)
|
||||
context_messages = messages[context_start:context_end]
|
||||
|
||||
if not context_messages:
|
||||
return None
|
||||
|
||||
try:
|
||||
paragraph = build_readable_messages(
|
||||
messages=context_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
show_pic=True,
|
||||
message_id_list=None,
|
||||
remove_emoji_stickers=False,
|
||||
pic_single=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"构建上下文段落失败: {e}")
|
||||
return None
|
||||
|
||||
paragraph = paragraph.strip()
|
||||
return paragraph or None
|
||||
|
||||
|
||||
def is_bot_message(msg: Any) -> bool:
|
||||
"""判断消息是否来自机器人自身"""
|
||||
if msg is None:
|
||||
return False
|
||||
|
||||
bot_config = getattr(global_config, "bot", None)
|
||||
if not bot_config:
|
||||
return False
|
||||
|
||||
platform = (
|
||||
str(getattr(msg, "user_platform", "") or getattr(getattr(msg, "user_info", None), "platform", "") or "")
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
|
||||
|
||||
if not platform or not user_id:
|
||||
return False
|
||||
|
||||
platform_accounts = {}
|
||||
try:
|
||||
platform_accounts = parse_platform_accounts(getattr(bot_config, "platforms", []) or [])
|
||||
except Exception:
|
||||
platform_accounts = {}
|
||||
|
||||
bot_accounts: Dict[str, str] = {}
|
||||
qq_account = str(getattr(bot_config, "qq_account", "") or "").strip()
|
||||
if qq_account:
|
||||
bot_accounts["qq"] = qq_account
|
||||
|
||||
telegram_account = str(getattr(bot_config, "telegram_account", "") or "").strip()
|
||||
if telegram_account:
|
||||
bot_accounts["telegram"] = telegram_account
|
||||
|
||||
for plat, account in platform_accounts.items():
|
||||
if account and plat not in bot_accounts:
|
||||
bot_accounts[plat] = account
|
||||
|
||||
bot_account = bot_accounts.get(platform)
|
||||
return bool(bot_account and user_id == bot_account)
|
||||
@@ -98,7 +98,10 @@ def _convert_messages(
|
||||
content: List[Part] = []
|
||||
for item in message.content:
|
||||
if isinstance(item, tuple):
|
||||
image_format = "jpeg" if item[0].lower() == "jpg" else item[0].lower()
|
||||
image_format = item[0].lower()
|
||||
# 规范 JPEG MIME 类型后缀,统一使用 image/jpeg
|
||||
if image_format in ("jpg", "jpeg"):
|
||||
image_format = "jpeg"
|
||||
content.append(Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{image_format}"))
|
||||
elif isinstance(item, str):
|
||||
content.append(Part.from_text(text=item))
|
||||
@@ -143,10 +146,14 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar
|
||||
:param tool_option_param: 工具参数对象
|
||||
:return: 转换后的工具参数字典
|
||||
"""
|
||||
# JSON Schema要求使用"boolean"而不是"bool"
|
||||
# JSON Schema 类型名称修正:
|
||||
# - 布尔类型使用 "boolean" 而不是 "bool"
|
||||
# - 浮点数使用 "number" 而不是 "float"
|
||||
param_type_value = tool_option_param.param_type.value
|
||||
if param_type_value == "bool":
|
||||
param_type_value = "boolean"
|
||||
elif param_type_value == "float":
|
||||
param_type_value = "number"
|
||||
|
||||
return_dict: dict[str, Any] = {
|
||||
"type": param_type_value,
|
||||
|
||||
@@ -61,10 +61,16 @@ def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessagePara
|
||||
content = []
|
||||
for item in message.content:
|
||||
if isinstance(item, tuple):
|
||||
image_format = item[0].lower()
|
||||
# 规范 JPEG MIME 类型后缀,统一使用 image/jpeg
|
||||
if image_format in ("jpg", "jpeg"):
|
||||
mime_suffix = "jpeg"
|
||||
else:
|
||||
mime_suffix = image_format
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/{item[0].lower()};base64,{item[1]}"},
|
||||
"image_url": {"url": f"data:image/{mime_suffix};base64,{item[1]}"},
|
||||
}
|
||||
)
|
||||
elif isinstance(item, str):
|
||||
@@ -118,10 +124,14 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]
|
||||
:param tool_option_param: 工具参数对象
|
||||
:return: 转换后的工具参数字典
|
||||
"""
|
||||
# JSON Schema要求使用"boolean"而不是"bool"
|
||||
# JSON Schema 类型名称修正:
|
||||
# - 布尔类型使用 "boolean" 而不是 "bool"
|
||||
# - 浮点数使用 "number" 而不是 "float"
|
||||
param_type_value = tool_option_param.param_type.value
|
||||
if param_type_value == "bool":
|
||||
param_type_value = "boolean"
|
||||
elif param_type_value == "float":
|
||||
param_type_value = "number"
|
||||
|
||||
return_dict: dict[str, Any] = {
|
||||
"type": param_type_value,
|
||||
|
||||
@@ -49,7 +49,7 @@ class LLMRequest:
|
||||
|
||||
def _check_slow_request(self, time_cost: float, model_name: str) -> None:
|
||||
"""检查请求是否过慢并输出警告日志
|
||||
|
||||
|
||||
Args:
|
||||
time_cost: 请求耗时(秒)
|
||||
model_name: 使用的模型名称
|
||||
@@ -315,12 +315,30 @@ class LLMRequest:
|
||||
while retry_remain > 0:
|
||||
try:
|
||||
if request_type == RequestType.RESPONSE:
|
||||
# 温度优先级:参数传入 > 模型级别配置 > extra_params > 任务配置
|
||||
effective_temperature = temperature
|
||||
if effective_temperature is None:
|
||||
effective_temperature = model_info.temperature
|
||||
if effective_temperature is None:
|
||||
effective_temperature = (model_info.extra_params or {}).get("temperature")
|
||||
if effective_temperature is None:
|
||||
effective_temperature = self.model_for_task.temperature
|
||||
|
||||
# max_tokens 优先级:参数传入 > 模型级别配置 > extra_params > 任务配置
|
||||
effective_max_tokens = max_tokens
|
||||
if effective_max_tokens is None:
|
||||
effective_max_tokens = model_info.max_tokens
|
||||
if effective_max_tokens is None:
|
||||
effective_max_tokens = (model_info.extra_params or {}).get("max_tokens")
|
||||
if effective_max_tokens is None:
|
||||
effective_max_tokens = self.model_for_task.max_tokens
|
||||
|
||||
return await client.get_response(
|
||||
model_info=model_info,
|
||||
message_list=(compressed_messages or message_list),
|
||||
tool_options=tool_options,
|
||||
max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
|
||||
temperature=temperature if temperature is not None else (model_info.extra_params or {}).get("temperature", self.model_for_task.temperature),
|
||||
max_tokens=effective_max_tokens,
|
||||
temperature=effective_temperature,
|
||||
response_format=response_format,
|
||||
stream_response_handler=stream_response_handler,
|
||||
async_response_parser=async_response_parser,
|
||||
@@ -348,7 +366,9 @@ class LLMRequest:
|
||||
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}")
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||
|
||||
logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}")
|
||||
logger.warning(
|
||||
f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}"
|
||||
)
|
||||
await asyncio.sleep(api_provider.retry_interval)
|
||||
|
||||
except NetworkConnectionError as e:
|
||||
@@ -376,7 +396,9 @@ class LLMRequest:
|
||||
if e.status_code == 429 or e.status_code >= 500:
|
||||
retry_remain -= 1
|
||||
if retry_remain <= 0:
|
||||
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}")
|
||||
logger.error(
|
||||
f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}"
|
||||
)
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||
|
||||
logger.warning(
|
||||
@@ -522,7 +544,5 @@ class LLMRequest:
|
||||
if e.__cause__:
|
||||
original_error_type = type(e.__cause__).__name__
|
||||
original_error_msg = str(e.__cause__)
|
||||
return (
|
||||
f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
|
||||
)
|
||||
return f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
|
||||
return ""
|
||||
|
||||
28
src/main.py
28
src/main.py
@@ -13,9 +13,9 @@ from src.config.config import global_config
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.common.logger import get_logger
|
||||
from src.common.server import get_global_server, Server
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.chat.knowledge import lpmm_start_up
|
||||
from rich.traceback import install
|
||||
|
||||
# from src.api.main import start_api_server
|
||||
|
||||
# 导入新的插件管理器
|
||||
@@ -23,6 +23,7 @@ from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
# 导入消息API和traceback模块
|
||||
from src.common.message import get_global_api
|
||||
from src.dream.dream_agent import start_dream_scheduler
|
||||
|
||||
# 插件系统现在使用统一的插件加载器
|
||||
|
||||
@@ -43,30 +44,17 @@ class MainSystem:
|
||||
|
||||
def _setup_webui_server(self):
|
||||
"""设置独立的 WebUI 服务器"""
|
||||
import os
|
||||
from src.config.config import global_config
|
||||
|
||||
webui_enabled = os.getenv("WEBUI_ENABLED", "false").lower() == "true"
|
||||
if not webui_enabled:
|
||||
if not global_config.webui.enabled:
|
||||
logger.info("WebUI 已禁用")
|
||||
return
|
||||
|
||||
webui_mode = os.getenv("WEBUI_MODE", "production").lower()
|
||||
|
||||
try:
|
||||
from src.webui.webui_server import get_webui_server
|
||||
|
||||
self.webui_server = get_webui_server()
|
||||
|
||||
if webui_mode == "development":
|
||||
logger.info("📝 WebUI 开发模式已启用")
|
||||
logger.info("🌐 后端 API 将运行在配置的地址(默认 http://127.0.0.1:8001)")
|
||||
logger.info("💡 请手动启动前端开发服务器: cd MaiBot-Dashboard && bun dev")
|
||||
logger.info("💡 前端将运行在 http://localhost:7999")
|
||||
else:
|
||||
logger.info("✅ WebUI 生产模式已启用")
|
||||
logger.info("🌐 WebUI 将运行在配置的地址(默认 http://127.0.0.1:8001)")
|
||||
logger.info("💡 请确保已构建前端: cd MaiBot-Dashboard && bun run build")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 初始化 WebUI 服务器失败: {e}")
|
||||
|
||||
@@ -106,7 +94,7 @@ class MainSystem:
|
||||
await async_task_manager.add_task(TelemetryHeartBeatTask())
|
||||
|
||||
# 添加记忆遗忘任务
|
||||
from src.chat.utils.memory_forget_task import MemoryForgetTask
|
||||
from src.hippo_memorizer.memory_forget_task import MemoryForgetTask
|
||||
|
||||
await async_task_manager.add_task(MemoryForgetTask())
|
||||
|
||||
@@ -124,11 +112,6 @@ class MainSystem:
|
||||
get_emoji_manager().initialize()
|
||||
logger.info("表情包管理器初始化成功")
|
||||
|
||||
# 启动情绪管理器
|
||||
if global_config.mood.enable_mood:
|
||||
await mood_manager.start()
|
||||
logger.info("情绪管理器初始化成功")
|
||||
|
||||
# 初始化聊天管理器
|
||||
await get_chat_manager()._initialize()
|
||||
asyncio.create_task(get_chat_manager()._auto_save_task())
|
||||
@@ -159,6 +142,7 @@ class MainSystem:
|
||||
try:
|
||||
tasks = [
|
||||
get_emoji_manager().start_periodic_check_register(),
|
||||
start_dream_scheduler(),
|
||||
self.app.run(),
|
||||
self.server.run(),
|
||||
]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional, Tuple, Set
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
@@ -10,7 +11,7 @@ from src.common.database.database_model import ThinkingBack
|
||||
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
|
||||
from src.memory_system.memory_utils import parse_questions_json
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
||||
from src.jargon.jargon_explainer import match_jargon_from_text, retrieve_concepts_with_jargon
|
||||
from src.bw_learner.jargon_explainer import match_jargon_from_text, retrieve_concepts_with_jargon
|
||||
|
||||
logger = get_logger("memory_retrieval")
|
||||
|
||||
@@ -77,20 +78,12 @@ def init_memory_retrieval_prompt():
|
||||
|
||||
问题要说明前因后果和上下文,使其全面且精准
|
||||
|
||||
输出格式示例(需要检索时):
|
||||
输出格式示例:
|
||||
```json
|
||||
{{
|
||||
"questions": ["张三在前几天干了什么"] #问题数组(字符串数组),如果不需要检索记忆则输出空数组[],如果需要检索则只输出包含一个问题的数组
|
||||
}}
|
||||
```
|
||||
|
||||
输出格式示例(不需要检索时):
|
||||
```json
|
||||
{{
|
||||
"questions": []
|
||||
}}
|
||||
```
|
||||
|
||||
请只输出JSON对象,不要输出其他内容:
|
||||
""",
|
||||
name="memory_retrieval_question_prompt",
|
||||
@@ -104,17 +97,16 @@ def init_memory_retrieval_prompt():
|
||||
已收集的信息:
|
||||
{collected_info}
|
||||
|
||||
**执行步骤:**
|
||||
**工具说明:**
|
||||
- 如果涉及过往事件,或者查询某个过去可能提到过的概念,或者某段时间发生的事件。可以使用聊天记录查询工具查询过往事件
|
||||
- 如果涉及人物,可以使用人物信息查询工具查询人物信息
|
||||
- 如果没有可靠信息,且查询时间充足,或者不确定查询类别,也可以使用lpmm知识库查询,作为辅助信息
|
||||
- **如果信息不足需要使用tool,说明需要查询什么,并输出为纯文本说明,然后调用相应工具查询(可并行调用多个工具)**
|
||||
- **如果当前已收集的信息足够回答问题,且能找到明确答案,调用found_answer工具标记已找到答案**
|
||||
|
||||
**思考**
|
||||
- 你可以对查询思路给出简短的思考:思考要简短,直接切入要点
|
||||
- 如果信息不足,你必须给出使用什么工具进行查询
|
||||
- 如果信息足够,你必须调用found_answer工具
|
||||
- 先思考当前信息是否足够回答问题
|
||||
- 如果信息不足,则需要使用tool查询信息,你必须给出使用什么工具进行查询
|
||||
- 如果当前已收集的信息足够或信息不足确定无法找到答案,你必须调用finish_search工具结束查询
|
||||
""",
|
||||
name="memory_retrieval_react_prompt_head",
|
||||
)
|
||||
@@ -128,14 +120,12 @@ def init_memory_retrieval_prompt():
|
||||
已收集的信息:
|
||||
{collected_info}
|
||||
|
||||
**执行步骤:**
|
||||
分析:
|
||||
- 当前信息是否足够回答问题?
|
||||
- **如果信息足够且能找到明确答案**,在思考中直接给出答案,格式为:found_answer(answer="你的答案内容")
|
||||
- **如果信息不足或无法找到答案**,在思考中给出:not_enough_info(reason="信息不足或无法找到答案的原因")
|
||||
|
||||
**重要规则:**
|
||||
- 你已经经过几轮查询,尝试了信息搜集,现在你需要总结信息,选择回答问题或判断问题无法回答
|
||||
- 必须严格使用检索到的信息回答问题,不要编造信息
|
||||
- 答案必须精简,不要过多解释
|
||||
- **只有在检索到明确、具体的答案时,才使用found_answer**
|
||||
@@ -146,8 +136,6 @@ def init_memory_retrieval_prompt():
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
def _log_conversation_messages(
|
||||
conversation_messages: List[Message],
|
||||
head_prompt: Optional[str] = None,
|
||||
@@ -167,7 +155,7 @@ def _log_conversation_messages(
|
||||
|
||||
# 如果有head_prompt,先添加为第一条消息
|
||||
if head_prompt:
|
||||
msg_info = "========================================\n[消息 1] 角色: System 内容类型: 文本\n-----------------------------"
|
||||
msg_info = "========================================\n[消息 1] 角色: System\n-----------------------------"
|
||||
msg_info += f"\n{head_prompt}"
|
||||
log_lines.append(msg_info)
|
||||
start_idx = 2
|
||||
@@ -180,32 +168,24 @@ def _log_conversation_messages(
|
||||
for idx, msg in enumerate(conversation_messages, start_idx):
|
||||
role_name = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
|
||||
|
||||
# 处理内容 - 显示完整内容,不截断
|
||||
if isinstance(msg.content, str):
|
||||
full_content = msg.content
|
||||
content_type = "文本"
|
||||
elif isinstance(msg.content, list):
|
||||
text_parts = [item for item in msg.content if isinstance(item, str)]
|
||||
image_count = len([item for item in msg.content if isinstance(item, tuple)])
|
||||
full_content = "".join(text_parts) if text_parts else ""
|
||||
content_type = f"混合({len(text_parts)}段文本, {image_count}张图片)"
|
||||
else:
|
||||
full_content = str(msg.content)
|
||||
content_type = "未知"
|
||||
|
||||
# 构建单条消息的日志信息
|
||||
msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------"
|
||||
# msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------"
|
||||
msg_info = (
|
||||
f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------"
|
||||
)
|
||||
|
||||
if full_content:
|
||||
msg_info += f"\n{full_content}"
|
||||
# if full_content:
|
||||
# msg_info += f"\n{full_content}"
|
||||
if msg.content:
|
||||
msg_info += f"\n{msg.content}"
|
||||
|
||||
if msg.tool_calls:
|
||||
msg_info += f"\n 工具调用: {len(msg.tool_calls)}个"
|
||||
for tool_call in msg.tool_calls:
|
||||
msg_info += f"\n - {tool_call}"
|
||||
msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}"
|
||||
|
||||
if msg.tool_call_id:
|
||||
msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
|
||||
# if msg.tool_call_id:
|
||||
# msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
|
||||
|
||||
log_lines.append(msg_info)
|
||||
|
||||
@@ -251,6 +231,7 @@ async def _react_agent_solve_question(
|
||||
conversation_messages: List[Message] = []
|
||||
first_head_prompt: Optional[str] = None # 保存第一次使用的head_prompt(用于日志显示)
|
||||
|
||||
# 正常迭代:max_iterations 次(最终评估单独处理,不算在迭代中)
|
||||
for iteration in range(max_iterations):
|
||||
# 检查超时
|
||||
if time.time() - start_time > timeout:
|
||||
@@ -270,7 +251,6 @@ async def _react_agent_solve_question(
|
||||
# 计算剩余迭代次数
|
||||
current_iteration = iteration + 1
|
||||
remaining_iterations = max_iterations - current_iteration
|
||||
is_final_iteration = current_iteration >= max_iterations
|
||||
|
||||
# 提取函数调用中参数的值,支持单引号和双引号
|
||||
def extract_quoted_content(text, func_name, param_name):
|
||||
@@ -330,114 +310,10 @@ async def _react_agent_solve_question(
|
||||
|
||||
return None
|
||||
|
||||
# 如果是最后一次迭代,使用final_prompt进行总结
|
||||
if is_final_iteration:
|
||||
evaluation_prompt = await global_prompt_manager.format_prompt(
|
||||
"memory_retrieval_react_final_prompt",
|
||||
bot_name=bot_name,
|
||||
time_now=time_now,
|
||||
question=question,
|
||||
collected_info=collected_info if collected_info else "暂无信息",
|
||||
current_iteration=current_iteration,
|
||||
remaining_iterations=remaining_iterations,
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
|
||||
if global_config.debug.show_memory_prompt:
|
||||
logger.info(f"ReAct Agent 最终评估Prompt: {evaluation_prompt}")
|
||||
|
||||
eval_success, eval_response, eval_reasoning_content, eval_model_name, eval_tool_calls = await llm_api.generate_with_model_with_tools(
|
||||
evaluation_prompt,
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
tool_options=[], # 最终评估阶段不提供工具
|
||||
request_type="memory.react.final",
|
||||
)
|
||||
|
||||
if not eval_success:
|
||||
logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 最终评估阶段 LLM调用失败: {eval_response}")
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status="未找到答案:最终评估阶段LLM调用失败",
|
||||
)
|
||||
return False, "最终评估阶段LLM调用失败", thinking_steps, False
|
||||
|
||||
logger.info(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 最终评估响应: {eval_response}"
|
||||
)
|
||||
|
||||
# 从最终评估响应中提取found_answer或not_enough_info
|
||||
found_answer_content = None
|
||||
not_enough_info_reason = None
|
||||
|
||||
if eval_response:
|
||||
found_answer_content = extract_quoted_content(eval_response, "found_answer", "answer")
|
||||
if not found_answer_content:
|
||||
not_enough_info_reason = extract_quoted_content(eval_response, "not_enough_info", "reason")
|
||||
|
||||
# 如果找到答案,返回
|
||||
if found_answer_content:
|
||||
eval_step = {
|
||||
"iteration": iteration + 1,
|
||||
"thought": f"[最终评估] {eval_response}",
|
||||
"actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}],
|
||||
"observations": ["最终评估阶段检测到found_answer"]
|
||||
}
|
||||
thinking_steps.append(eval_step)
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 最终评估阶段找到关于问题{question}的答案: {found_answer_content}")
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status=f"找到答案:{found_answer_content}",
|
||||
)
|
||||
|
||||
return True, found_answer_content, thinking_steps, False
|
||||
|
||||
# 如果评估为not_enough_info,返回空字符串(不返回任何信息)
|
||||
if not_enough_info_reason:
|
||||
eval_step = {
|
||||
"iteration": iteration + 1,
|
||||
"thought": f"[最终评估] {eval_response}",
|
||||
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}],
|
||||
"observations": ["最终评估阶段检测到not_enough_info"]
|
||||
}
|
||||
thinking_steps.append(eval_step)
|
||||
logger.info(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 最终评估阶段判断信息不足: {not_enough_info_reason}"
|
||||
)
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status=f"未找到答案:{not_enough_info_reason}",
|
||||
)
|
||||
|
||||
return False, "", thinking_steps, False
|
||||
|
||||
# 如果没有明确判断,视为not_enough_info,返回空字符串(不返回任何信息)
|
||||
eval_step = {
|
||||
"iteration": iteration + 1,
|
||||
"thought": f"[最终评估] {eval_response}",
|
||||
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": "已到达最后一次迭代,无法找到答案"}}],
|
||||
"observations": ["已到达最后一次迭代,无法找到答案"]
|
||||
}
|
||||
thinking_steps.append(eval_step)
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 已到达最后一次迭代,无法找到答案")
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status="未找到答案:已到达最后一次迭代,无法找到答案",
|
||||
)
|
||||
|
||||
return False, "", thinking_steps, False
|
||||
|
||||
# 前n-1次迭代,使用head_prompt决定调用哪些工具(包含found_answer工具)
|
||||
# 正常迭代:使用head_prompt决定调用哪些工具(包含finish_search工具)
|
||||
tool_definitions = tool_registry.get_tool_definitions()
|
||||
logger.info(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: {len(tool_definitions)}"
|
||||
)
|
||||
# tool_names = [tool_def["name"] for tool_def in tool_definitions]
|
||||
# logger.debug(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具: {', '.join(tool_names)} (共{len(tool_definitions)}个)")
|
||||
|
||||
# head_prompt应该只构建一次,使用初始的collected_info,后续迭代都复用同一个
|
||||
if first_head_prompt is None:
|
||||
@@ -453,7 +329,7 @@ async def _react_agent_solve_question(
|
||||
remaining_iterations=remaining_iterations,
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
|
||||
|
||||
# 后续迭代都复用第一次构建的head_prompt
|
||||
head_prompt = first_head_prompt
|
||||
|
||||
@@ -487,15 +363,15 @@ async def _react_agent_solve_question(
|
||||
request_type="memory.react",
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
|
||||
)
|
||||
# logger.info(
|
||||
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
|
||||
# )
|
||||
|
||||
if not success:
|
||||
logger.error(f"ReAct Agent LLM调用失败: {response}")
|
||||
break
|
||||
|
||||
# 注意:这里会检查found_answer工具调用,如果检测到found_answer工具,会直接返回答案
|
||||
# 注意:这里会检查finish_search工具调用,如果检测到finish_search工具,会根据found_answer参数决定返回答案或退出查询
|
||||
|
||||
assistant_message: Optional[Message] = None
|
||||
if tool_calls:
|
||||
@@ -525,8 +401,113 @@ async def _react_agent_solve_question(
|
||||
|
||||
# 处理工具调用
|
||||
if not tool_calls:
|
||||
# 如果没有工具调用,记录思考过程,继续下一轮迭代(下一轮会再次评估)
|
||||
# 如果没有工具调用,检查响应文本中是否包含finish_search函数调用格式
|
||||
if response and response.strip():
|
||||
# 尝试从文本中解析finish_search函数调用
|
||||
def parse_finish_search_from_text(text: str):
|
||||
"""从文本中解析finish_search函数调用,返回(found_answer, answer)元组,如果未找到则返回(None, None)"""
|
||||
if not text:
|
||||
return None, None
|
||||
|
||||
# 查找finish_search函数调用位置(不区分大小写)
|
||||
func_pattern = "finish_search"
|
||||
text_lower = text.lower()
|
||||
func_pos = text_lower.find(func_pattern)
|
||||
if func_pos == -1:
|
||||
return None, None
|
||||
|
||||
# 查找函数调用的开始和结束位置
|
||||
# 从func_pos开始向后查找左括号
|
||||
start_pos = text.find("(", func_pos)
|
||||
if start_pos == -1:
|
||||
return None, None
|
||||
|
||||
# 查找匹配的右括号(考虑嵌套)
|
||||
paren_count = 0
|
||||
end_pos = start_pos
|
||||
for i in range(start_pos, len(text)):
|
||||
if text[i] == "(":
|
||||
paren_count += 1
|
||||
elif text[i] == ")":
|
||||
paren_count -= 1
|
||||
if paren_count == 0:
|
||||
end_pos = i
|
||||
break
|
||||
else:
|
||||
# 没有找到匹配的右括号
|
||||
return None, None
|
||||
|
||||
# 提取函数参数部分
|
||||
params_text = text[start_pos + 1 : end_pos]
|
||||
|
||||
# 解析found_answer参数(布尔值,可能是true/false/True/False)
|
||||
found_answer = None
|
||||
found_answer_patterns = [
|
||||
r"found_answer\s*=\s*true",
|
||||
r"found_answer\s*=\s*True",
|
||||
r"found_answer\s*=\s*false",
|
||||
r"found_answer\s*=\s*False",
|
||||
]
|
||||
for pattern in found_answer_patterns:
|
||||
match = re.search(pattern, params_text, re.IGNORECASE)
|
||||
if match:
|
||||
found_answer = "true" in match.group(0).lower()
|
||||
break
|
||||
|
||||
# 解析answer参数(字符串,使用extract_quoted_content)
|
||||
answer = extract_quoted_content(text, "finish_search", "answer")
|
||||
|
||||
return found_answer, answer
|
||||
|
||||
parsed_found_answer, parsed_answer = parse_finish_search_from_text(response)
|
||||
|
||||
if parsed_found_answer is not None:
|
||||
# 检测到finish_search函数调用格式
|
||||
if parsed_found_answer:
|
||||
# 找到了答案
|
||||
if parsed_answer:
|
||||
step["actions"].append(
|
||||
{
|
||||
"action_type": "finish_search",
|
||||
"action_params": {"found_answer": True, "answer": parsed_answer},
|
||||
}
|
||||
)
|
||||
step["observations"] = ["检测到finish_search文本格式调用,找到答案"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}"
|
||||
)
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status=f"找到答案:{parsed_answer}",
|
||||
)
|
||||
|
||||
return True, parsed_answer, thinking_steps, False
|
||||
else:
|
||||
# found_answer为True但没有提供answer,视为错误,继续迭代
|
||||
logger.warning(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer"
|
||||
)
|
||||
else:
|
||||
# 未找到答案,直接退出查询
|
||||
step["actions"].append(
|
||||
{"action_type": "finish_search", "action_params": {"found_answer": False}}
|
||||
)
|
||||
step["observations"] = ["检测到finish_search文本格式调用,未找到答案"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案")
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status="未找到答案:通过finish_search文本格式判断未找到答案",
|
||||
)
|
||||
|
||||
return False, "", thinking_steps, False
|
||||
|
||||
# 如果没有检测到finish_search格式,记录思考过程,继续下一轮迭代
|
||||
step["observations"] = [f"思考完成,但未调用工具。响应: {response}"]
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 思考完成但未调用工具: {response}")
|
||||
collected_info += f"思考: {response}"
|
||||
@@ -537,29 +518,60 @@ async def _react_agent_solve_question(
|
||||
continue
|
||||
|
||||
# 处理工具调用
|
||||
# 首先检查是否有found_answer工具调用,如果有则立即返回,不再处理其他工具
|
||||
found_answer_from_tool = None
|
||||
# 首先检查是否有finish_search工具调用,如果有则立即返回,不再处理其他工具
|
||||
finish_search_found = None
|
||||
finish_search_answer = None
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.func_name
|
||||
tool_args = tool_call.args or {}
|
||||
|
||||
if tool_name == "found_answer":
|
||||
found_answer_from_tool = tool_args.get("answer", "")
|
||||
if found_answer_from_tool:
|
||||
step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_from_tool}})
|
||||
step["observations"] = ["检测到found_answer工具调用"]
|
||||
|
||||
if tool_name == "finish_search":
|
||||
finish_search_found = tool_args.get("found_answer", False)
|
||||
finish_search_answer = tool_args.get("answer", "")
|
||||
|
||||
if finish_search_found:
|
||||
# 找到了答案
|
||||
if finish_search_answer:
|
||||
step["actions"].append(
|
||||
{
|
||||
"action_type": "finish_search",
|
||||
"action_params": {"found_answer": True, "answer": finish_search_answer},
|
||||
}
|
||||
)
|
||||
step["observations"] = ["检测到finish_search工具调用,找到答案"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}"
|
||||
)
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status=f"找到答案:{finish_search_answer}",
|
||||
)
|
||||
|
||||
return True, finish_search_answer, thinking_steps, False
|
||||
else:
|
||||
# found_answer为True但没有提供answer,视为错误
|
||||
logger.warning(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer"
|
||||
)
|
||||
else:
|
||||
# 未找到答案,直接退出查询
|
||||
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
|
||||
step["observations"] = ["检测到finish_search工具调用,未找到答案"]
|
||||
thinking_steps.append(step)
|
||||
logger.debug(f"ReAct Agent 第 {iteration + 1} 次迭代 通过found_answer工具找到关于问题{question}的答案: {found_answer_from_tool}")
|
||||
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具判断未找到答案")
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status=f"找到答案:{found_answer_from_tool}",
|
||||
final_status="未找到答案:通过finish_search工具判断未找到答案",
|
||||
)
|
||||
|
||||
return True, found_answer_from_tool, thinking_steps, False
|
||||
|
||||
# 如果没有found_answer工具调用,或者found_answer工具调用没有答案,继续处理其他工具
|
||||
|
||||
return False, "", thinking_steps, False
|
||||
|
||||
# 如果没有finish_search工具调用,继续处理其他工具
|
||||
tool_tasks = []
|
||||
for i, tool_call in enumerate(tool_calls):
|
||||
tool_name = tool_call.func_name
|
||||
@@ -569,8 +581,8 @@ async def _react_agent_solve_question(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
|
||||
)
|
||||
|
||||
# 跳过found_answer工具调用(已经在上面处理过了)
|
||||
if tool_name == "found_answer":
|
||||
# 跳过finish_search工具调用(已经在上面处理过了)
|
||||
if tool_name == "finish_search":
|
||||
continue
|
||||
|
||||
# 普通工具调用
|
||||
@@ -634,7 +646,7 @@ async def _react_agent_solve_question(
|
||||
observation_text += f"\n\n{jargon_info}"
|
||||
collected_info += f"\n{jargon_info}\n"
|
||||
logger.info(f"工具输出触发黑话解析: {new_concepts}")
|
||||
|
||||
|
||||
tool_builder = MessageBuilder()
|
||||
tool_builder.set_role(RoleType.Tool)
|
||||
tool_builder.add_text_content(observation_text)
|
||||
@@ -643,26 +655,196 @@ async def _react_agent_solve_question(
|
||||
|
||||
thinking_steps.append(step)
|
||||
|
||||
# 达到最大迭代次数或超时,但Agent没有明确返回found_answer
|
||||
# 迭代超时应该直接视为not_enough_info,而不是使用已有信息
|
||||
# 只有Agent明确返回found_answer时,才认为找到了答案
|
||||
if collected_info:
|
||||
logger.warning(
|
||||
f"ReAct Agent达到最大迭代次数或超时,但未明确返回found_answer。已收集信息: {collected_info[:100]}..."
|
||||
)
|
||||
# 正常迭代结束后,如果达到最大迭代次数或超时,执行最终评估
|
||||
# 最终评估单独处理,不算在迭代中
|
||||
should_do_final_evaluation = False
|
||||
if is_timeout:
|
||||
logger.warning("ReAct Agent超时,直接视为not_enough_info")
|
||||
else:
|
||||
logger.warning("ReAct Agent达到最大迭代次数,直接视为not_enough_info")
|
||||
|
||||
# React完成时输出消息列表
|
||||
timeout_reason = "超时" if is_timeout else "达到最大迭代次数"
|
||||
should_do_final_evaluation = True
|
||||
logger.warning(f"ReAct Agent超时,已迭代{iteration + 1}次,进入最终评估")
|
||||
elif iteration + 1 >= max_iterations:
|
||||
should_do_final_evaluation = True
|
||||
logger.info(f"ReAct Agent达到最大迭代次数(已迭代{iteration + 1}次),进入最终评估")
|
||||
|
||||
if should_do_final_evaluation:
|
||||
# 获取必要变量用于最终评估
|
||||
tool_registry = get_tool_registry()
|
||||
bot_name = global_config.bot.nickname
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
current_iteration = iteration + 1
|
||||
remaining_iterations = 0
|
||||
|
||||
# 提取函数调用中参数的值,支持单引号和双引号
|
||||
def extract_quoted_content(text, func_name, param_name):
|
||||
"""从文本中提取函数调用中参数的值,支持单引号和双引号
|
||||
|
||||
Args:
|
||||
text: 要搜索的文本
|
||||
func_name: 函数名,如 'found_answer'
|
||||
param_name: 参数名,如 'answer'
|
||||
|
||||
Returns:
|
||||
提取的参数值,如果未找到则返回None
|
||||
"""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
# 查找函数调用位置(不区分大小写)
|
||||
func_pattern = func_name.lower()
|
||||
text_lower = text.lower()
|
||||
func_pos = text_lower.find(func_pattern)
|
||||
if func_pos == -1:
|
||||
return None
|
||||
|
||||
# 查找参数名和等号
|
||||
param_pattern = f"{param_name}="
|
||||
param_pos = text_lower.find(param_pattern, func_pos)
|
||||
if param_pos == -1:
|
||||
return None
|
||||
|
||||
# 跳过参数名、等号和空白
|
||||
start_pos = param_pos + len(param_pattern)
|
||||
while start_pos < len(text) and text[start_pos] in " \t\n":
|
||||
start_pos += 1
|
||||
|
||||
if start_pos >= len(text):
|
||||
return None
|
||||
|
||||
# 确定引号类型
|
||||
quote_char = text[start_pos]
|
||||
if quote_char not in ['"', "'"]:
|
||||
return None
|
||||
|
||||
# 查找匹配的结束引号(考虑转义)
|
||||
end_pos = start_pos + 1
|
||||
while end_pos < len(text):
|
||||
if text[end_pos] == quote_char:
|
||||
# 检查是否是转义的引号
|
||||
if end_pos > start_pos + 1 and text[end_pos - 1] == "\\":
|
||||
end_pos += 1
|
||||
continue
|
||||
# 找到匹配的引号
|
||||
content = text[start_pos + 1 : end_pos]
|
||||
# 处理转义字符
|
||||
content = content.replace('\\"', '"').replace("\\'", "'").replace("\\\\", "\\")
|
||||
return content
|
||||
end_pos += 1
|
||||
|
||||
return None
|
||||
|
||||
# 执行最终评估
|
||||
evaluation_prompt = await global_prompt_manager.format_prompt(
|
||||
"memory_retrieval_react_final_prompt",
|
||||
bot_name=bot_name,
|
||||
time_now=time_now,
|
||||
question=question,
|
||||
collected_info=collected_info if collected_info else "暂无信息",
|
||||
current_iteration=current_iteration,
|
||||
remaining_iterations=remaining_iterations,
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
|
||||
(
|
||||
eval_success,
|
||||
eval_response,
|
||||
eval_reasoning_content,
|
||||
eval_model_name,
|
||||
eval_tool_calls,
|
||||
) = await llm_api.generate_with_model_with_tools(
|
||||
evaluation_prompt,
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
tool_options=[], # 最终评估阶段不提供工具
|
||||
request_type="memory.react.final",
|
||||
)
|
||||
|
||||
if not eval_success:
|
||||
logger.error(f"ReAct Agent 最终评估阶段 LLM调用失败: {eval_response}")
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status="未找到答案:最终评估阶段LLM调用失败",
|
||||
)
|
||||
return False, "最终评估阶段LLM调用失败", thinking_steps, is_timeout
|
||||
|
||||
if global_config.debug.show_memory_prompt:
|
||||
logger.info(f"ReAct Agent 最终评估Prompt: {evaluation_prompt}")
|
||||
logger.info(f"ReAct Agent 最终评估响应: {eval_response}")
|
||||
|
||||
# 从最终评估响应中提取found_answer或not_enough_info
|
||||
found_answer_content = None
|
||||
not_enough_info_reason = None
|
||||
|
||||
if eval_response:
|
||||
found_answer_content = extract_quoted_content(eval_response, "found_answer", "answer")
|
||||
if not found_answer_content:
|
||||
not_enough_info_reason = extract_quoted_content(eval_response, "not_enough_info", "reason")
|
||||
|
||||
# 如果找到答案,返回(找到答案时,无论是否超时,都视为成功完成)
|
||||
if found_answer_content:
|
||||
eval_step = {
|
||||
"iteration": current_iteration,
|
||||
"thought": f"[最终评估] {eval_response}",
|
||||
"actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}],
|
||||
"observations": ["最终评估阶段检测到found_answer"],
|
||||
}
|
||||
thinking_steps.append(eval_step)
|
||||
logger.info(f"ReAct Agent 最终评估阶段找到关于问题{question}的答案: {found_answer_content}")
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status=f"找到答案:{found_answer_content}",
|
||||
)
|
||||
|
||||
return True, found_answer_content, thinking_steps, False
|
||||
|
||||
# 如果评估为not_enough_info,返回空字符串(不返回任何信息)
|
||||
if not_enough_info_reason:
|
||||
eval_step = {
|
||||
"iteration": current_iteration,
|
||||
"thought": f"[最终评估] {eval_response}",
|
||||
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}],
|
||||
"observations": ["最终评估阶段检测到not_enough_info"],
|
||||
}
|
||||
thinking_steps.append(eval_step)
|
||||
logger.info(f"ReAct Agent 最终评估阶段判断信息不足: {not_enough_info_reason}")
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status=f"未找到答案:{not_enough_info_reason}",
|
||||
)
|
||||
|
||||
return False, "", thinking_steps, is_timeout
|
||||
|
||||
# 如果没有明确判断,视为not_enough_info,返回空字符串(不返回任何信息)
|
||||
eval_step = {
|
||||
"iteration": current_iteration,
|
||||
"thought": f"[最终评估] {eval_response}",
|
||||
"actions": [
|
||||
{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}
|
||||
],
|
||||
"observations": ["已到达最大迭代次数,无法找到答案"],
|
||||
}
|
||||
thinking_steps.append(eval_step)
|
||||
logger.info("ReAct Agent 已到达最大迭代次数,无法找到答案")
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status="未找到答案:已到达最大迭代次数,无法找到答案",
|
||||
)
|
||||
|
||||
return False, "", thinking_steps, is_timeout
|
||||
|
||||
# 如果正常迭代过程中提前找到答案返回,不会到达这里
|
||||
# 如果正常迭代结束但没有触发最终评估(理论上不应该发生),直接返回
|
||||
logger.warning("ReAct Agent正常迭代结束,但未触发最终评估")
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status=f"未找到答案:{timeout_reason}",
|
||||
final_status="未找到答案:正常迭代结束",
|
||||
)
|
||||
|
||||
|
||||
return False, "", thinking_steps, is_timeout
|
||||
|
||||
|
||||
@@ -817,6 +999,7 @@ async def _process_single_question(
|
||||
context: str,
|
||||
initial_info: str = "",
|
||||
initial_jargon_concepts: Optional[List[str]] = None,
|
||||
max_iterations: Optional[int] = None,
|
||||
) -> Optional[str]:
|
||||
"""处理单个问题的查询
|
||||
|
||||
@@ -841,11 +1024,15 @@ async def _process_single_question(
|
||||
|
||||
jargon_concepts_for_agent = initial_jargon_concepts if global_config.memory.enable_jargon_detection else None
|
||||
|
||||
# 如果未指定max_iterations,使用配置的默认值
|
||||
if max_iterations is None:
|
||||
max_iterations = global_config.memory.max_agent_iterations
|
||||
|
||||
found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
|
||||
question=question,
|
||||
chat_id=chat_id,
|
||||
max_iterations=global_config.memory.max_agent_iterations,
|
||||
timeout=120.0,
|
||||
max_iterations=max_iterations,
|
||||
timeout=global_config.memory.agent_timeout_seconds,
|
||||
initial_info=question_initial_info,
|
||||
initial_jargon_concepts=jargon_concepts_for_agent,
|
||||
)
|
||||
@@ -874,7 +1061,7 @@ async def build_memory_retrieval_prompt(
|
||||
sender: str,
|
||||
target: str,
|
||||
chat_stream,
|
||||
tool_executor,
|
||||
think_level: int = 1,
|
||||
) -> str:
|
||||
"""构建记忆检索提示
|
||||
使用两段式查询:第一步生成问题,第二步使用ReAct Agent查询答案
|
||||
@@ -961,9 +1148,17 @@ async def build_memory_retrieval_prompt(
|
||||
logger.info(f"无当次查询,不返回任何结果,耗时: {(end_time - start_time):.3f}秒")
|
||||
return ""
|
||||
|
||||
# 第二步:并行处理所有问题(使用配置的最大迭代次数/120秒超时)
|
||||
max_iterations = global_config.memory.max_agent_iterations
|
||||
logger.debug(f"问题数量: {len(questions)},设置最大迭代次数: {max_iterations},超时时间: 120秒")
|
||||
# 第二步:并行处理所有问题(使用配置的最大迭代次数和超时时间)
|
||||
base_max_iterations = global_config.memory.max_agent_iterations
|
||||
# 根据think_level调整迭代次数:think_level=1时不变,think_level=0时减半
|
||||
if think_level == 0:
|
||||
max_iterations = max(1, base_max_iterations // 2) # 至少为1
|
||||
else:
|
||||
max_iterations = base_max_iterations
|
||||
timeout_seconds = global_config.memory.agent_timeout_seconds
|
||||
logger.debug(
|
||||
f"问题数量: {len(questions)},think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}秒"
|
||||
)
|
||||
|
||||
# 并行处理所有问题,将概念检索结果作为初始信息传递
|
||||
question_tasks = [
|
||||
@@ -973,6 +1168,7 @@ async def build_memory_retrieval_prompt(
|
||||
context=message,
|
||||
initial_info=initial_info,
|
||||
initial_jargon_concepts=concepts if enable_jargon_detection else None,
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
for question in questions
|
||||
]
|
||||
@@ -990,10 +1186,10 @@ async def build_memory_retrieval_prompt(
|
||||
|
||||
# 获取最近10分钟内已找到答案的缓存记录
|
||||
cached_answers = _get_recent_found_answers(chat_id, time_window_seconds=600.0)
|
||||
|
||||
|
||||
# 合并当前查询结果和缓存答案(去重:如果当前查询的问题在缓存中已存在,优先使用当前结果)
|
||||
all_results = []
|
||||
|
||||
|
||||
# 先添加当前查询的结果
|
||||
current_questions = set()
|
||||
for result in question_results:
|
||||
@@ -1003,7 +1199,7 @@ async def build_memory_retrieval_prompt(
|
||||
if question_end != -1:
|
||||
current_questions.add(result[4:question_end])
|
||||
all_results.append(result)
|
||||
|
||||
|
||||
# 添加缓存答案(排除当前查询中已存在的问题)
|
||||
for cached_answer in cached_answers:
|
||||
if cached_answer.startswith("问题:"):
|
||||
@@ -1031,4 +1227,3 @@ async def build_memory_retrieval_prompt(
|
||||
except Exception as e:
|
||||
logger.error(f"记忆检索时发生异常: {str(e)}")
|
||||
return ""
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ from src.common.logger import get_logger
|
||||
logger = get_logger("memory_utils")
|
||||
|
||||
|
||||
|
||||
def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
|
||||
"""解析问题JSON,返回概念列表和问题列表
|
||||
|
||||
@@ -68,6 +67,7 @@ def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
|
||||
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
|
||||
return [], []
|
||||
|
||||
|
||||
def parse_datetime_to_timestamp(value: str) -> float:
|
||||
"""
|
||||
接受多种常见格式并转换为时间戳(秒)
|
||||
|
||||
@@ -14,7 +14,7 @@ from .tool_registry import (
|
||||
from .query_chat_history import register_tool as register_query_chat_history
|
||||
from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
|
||||
from .query_person_info import register_tool as register_query_person_info
|
||||
from .found_answer import register_tool as register_found_answer
|
||||
from .found_answer import register_tool as register_finish_search
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ def init_all_tools():
|
||||
"""初始化并注册所有记忆检索工具"""
|
||||
register_query_chat_history()
|
||||
register_query_person_info()
|
||||
register_found_answer() # 注册found_answer工具
|
||||
register_finish_search() # 注册finish_search工具
|
||||
|
||||
if global_config.lpmm_knowledge.lpmm_mode == "agent":
|
||||
register_lpmm_knowledge()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
found_answer工具 - 用于在记忆检索过程中标记找到答案
|
||||
finish_search工具 - 用于在记忆检索过程中结束查询
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -8,33 +8,42 @@ from .tool_registry import register_memory_retrieval_tool
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def found_answer(answer: str) -> str:
|
||||
"""标记已找到问题的答案
|
||||
async def finish_search(found_answer: bool, answer: str = "") -> str:
|
||||
"""结束查询
|
||||
|
||||
Args:
|
||||
answer: 找到的答案内容
|
||||
found_answer: 是否找到了答案
|
||||
answer: 如果找到了答案,提供答案内容;如果未找到,可以为空
|
||||
|
||||
Returns:
|
||||
str: 确认信息
|
||||
"""
|
||||
# 这个工具主要用于标记,实际答案会通过返回值传递
|
||||
logger.info(f"找到答案: {answer}")
|
||||
return f"已确认找到答案: {answer}"
|
||||
if found_answer:
|
||||
logger.info(f"找到答案: {answer}")
|
||||
return f"已确认找到答案: {answer}"
|
||||
else:
|
||||
logger.info("未找到答案,结束查询")
|
||||
return "未找到答案,查询结束"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册found_answer工具"""
|
||||
"""注册finish_search工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="found_answer",
|
||||
description="当你在已收集的信息中找到了问题的明确答案时,调用此工具标记已找到答案。只有在检索到明确、具体的答案时才使用此工具,不要编造信息。",
|
||||
name="finish_search",
|
||||
description="当你决定结束查询时,调用此工具。如果找到了明确答案,设置found_answer为true并在answer中提供答案;如果未找到答案,设置found_answer为false。只有在检索到明确、具体的答案时才设置found_answer为true,不要编造信息。",
|
||||
parameters=[
|
||||
{
|
||||
"name": "found_answer",
|
||||
"type": "boolean",
|
||||
"description": "是否找到了答案",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "answer",
|
||||
"type": "string",
|
||||
"description": "找到的答案内容,必须基于已收集的信息,不要编造",
|
||||
"required": True,
|
||||
"description": "如果found_answer为true,提供找到的答案内容,必须基于已收集的信息,不要编造;如果found_answer为false,可以为空",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
execute_func=found_answer,
|
||||
execute_func=finish_search,
|
||||
)
|
||||
|
||||
|
||||
@@ -5,18 +5,18 @@
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import ChatHistory
|
||||
from src.chat.utils.utils import parse_keywords_string
|
||||
from src.config.config import global_config
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
from datetime import datetime
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def search_chat_history(
|
||||
chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None
|
||||
) -> str:
|
||||
async def search_chat_history(chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None) -> str:
|
||||
"""根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords
|
||||
|
||||
Args:
|
||||
@@ -33,7 +33,18 @@ async def search_chat_history(
|
||||
return "未指定查询参数(需要提供keyword或participant之一)"
|
||||
|
||||
# 构建查询条件
|
||||
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
|
||||
# 根据配置决定是否限制在当前 chat_id 内查询
|
||||
use_global_search = global_config.memory.global_memory
|
||||
|
||||
if use_global_search:
|
||||
# 全局查询所有聊天记录
|
||||
query = ChatHistory.select()
|
||||
logger.debug(
|
||||
f"search_chat_history 启用全局查询模式,忽略 chat_id 过滤,keyword={keyword}, participant={participant}"
|
||||
)
|
||||
else:
|
||||
# 仅在当前聊天流内查询
|
||||
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
|
||||
|
||||
# 执行查询
|
||||
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
|
||||
@@ -104,7 +115,7 @@ async def search_chat_history(
|
||||
)
|
||||
if kw_matched:
|
||||
matched_count += 1
|
||||
|
||||
|
||||
# 计算需要匹配的关键词数量
|
||||
total_keywords = len(keywords_lower)
|
||||
if total_keywords > 2:
|
||||
@@ -113,7 +124,7 @@ async def search_chat_history(
|
||||
else:
|
||||
# 关键词数量<=2,必须全部匹配
|
||||
required_matches = total_keywords
|
||||
|
||||
|
||||
keyword_matched = matched_count >= required_matches
|
||||
|
||||
# 两者都匹配(如果同时有participant和keyword,需要两者都匹配;如果只有一个条件,只需要该条件匹配)
|
||||
@@ -131,7 +142,9 @@ async def search_chat_history(
|
||||
keywords_list = parse_keywords_string(keyword)
|
||||
if len(keywords_list) > 2:
|
||||
required_count = len(keywords_list) - 1
|
||||
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
|
||||
return (
|
||||
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
|
||||
)
|
||||
else:
|
||||
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
|
||||
elif participant:
|
||||
@@ -139,9 +152,42 @@ async def search_chat_history(
|
||||
else:
|
||||
return "未找到相关聊天记录"
|
||||
|
||||
# 构建结果文本,返回id、theme和keywords
|
||||
# 如果匹配结果超过20条,不返回具体记录,只返回提示和所有相关关键词
|
||||
if len(filtered_records) > 15:
|
||||
# 统计所有记录上的关键词并去重
|
||||
all_keywords_set = set()
|
||||
for record in filtered_records:
|
||||
if record.keywords:
|
||||
try:
|
||||
keywords_data = (
|
||||
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
)
|
||||
if isinstance(keywords_data, list):
|
||||
for k in keywords_data:
|
||||
k_str = str(k).strip()
|
||||
if k_str:
|
||||
all_keywords_set.add(k_str)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
continue
|
||||
|
||||
# xxx 使用用户原始查询词,优先 keyword,其次 participant,最后退化成“当前条件”
|
||||
search_label = keyword or participant or "当前条件"
|
||||
|
||||
if all_keywords_set:
|
||||
keywords_str = "、".join(sorted(all_keywords_set))
|
||||
return (
|
||||
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
||||
f'有关"{search_label}"的关键词:\n'
|
||||
f"{keywords_str}"
|
||||
)
|
||||
else:
|
||||
return (
|
||||
f'包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n有关"{search_label}"的关键词信息为空'
|
||||
)
|
||||
|
||||
# 构建结果文本,返回id、theme和keywords(最多20条)
|
||||
results = []
|
||||
for record in filtered_records[:20]: # 最多返回20条记录
|
||||
for record in filtered_records[:20]:
|
||||
result_parts = []
|
||||
|
||||
# 添加记忆ID
|
||||
@@ -173,9 +219,6 @@ async def search_chat_history(
|
||||
return "未找到相关聊天记录"
|
||||
|
||||
response_text = "\n\n---\n\n".join(results)
|
||||
if len(filtered_records) > 20:
|
||||
omitted_count = len(filtered_records) - 20
|
||||
response_text += f"\n\n(还有{omitted_count}条记录已省略,可使用记忆ID查询详细信息)"
|
||||
return response_text
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,230 +0,0 @@
|
||||
import time
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
|
||||
|
||||
logger = get_logger("mood")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是群里正在进行的聊天记录
|
||||
|
||||
{identity_block}
|
||||
你先前的情绪状态是:{mood_state}
|
||||
你的情绪特点是:{emotion_style}
|
||||
|
||||
现在,请你根据先前的情绪状态和现在的聊天内容,总结推断你现在的情绪状态,用简短的词句来描述情绪状态
|
||||
请只输出新的情绪状态,不要输出其他内容:
|
||||
""",
|
||||
"get_mood_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是群里最近的聊天记录
|
||||
|
||||
{identity_block}
|
||||
你之前的情绪状态是:{mood_state}
|
||||
|
||||
距离你上次关注群里消息已经过去了一段时间,你冷静了下来,请你输出一句话或几个词来描述你现在的情绪状态
|
||||
你的情绪特点是:{emotion_style}
|
||||
请只输出新的情绪状态,不要输出其他内容:
|
||||
""",
|
||||
"regress_mood_prompt",
|
||||
)
|
||||
|
||||
|
||||
class ChatMood:
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id: str = chat_id
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
self.chat_stream = chat_manager.get_stream(self.chat_id)
|
||||
|
||||
if not self.chat_stream:
|
||||
raise ValueError(f"Chat stream for chat_id {chat_id} not found")
|
||||
|
||||
self.log_prefix = f"[{self.chat_stream.group_info.group_name if self.chat_stream.group_info else self.chat_stream.user_info.user_nickname}]"
|
||||
|
||||
self.mood_state: str = "感觉很平静"
|
||||
|
||||
self.regression_count: int = 0
|
||||
|
||||
self.mood_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="mood")
|
||||
|
||||
self.last_change_time: float = 0
|
||||
|
||||
async def get_mood(self) -> str:
|
||||
self.regression_count = 0
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
logger.info(f"{self.log_prefix} 获取情绪状态")
|
||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=current_time,
|
||||
limit=int(global_config.chat.max_context_size / 3),
|
||||
limit_mode="last",
|
||||
)
|
||||
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
identity_block = f"你的名字是{bot_name}{bot_nickname}"
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"get_mood_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
identity_block=identity_block,
|
||||
mood_state=self.mood_state,
|
||||
emotion_style=global_config.mood.emotion_style,
|
||||
)
|
||||
|
||||
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"{self.log_prefix} prompt: {prompt}")
|
||||
logger.info(f"{self.log_prefix} response: {response}")
|
||||
logger.info(f"{self.log_prefix} reasoning_content: {reasoning_content}")
|
||||
|
||||
logger.info(f"{self.log_prefix} 情绪状态更新为: {response}")
|
||||
|
||||
self.mood_state = response
|
||||
|
||||
self.last_change_time = current_time
|
||||
|
||||
return response
|
||||
|
||||
async def regress_mood(self):
|
||||
message_time = time.time()
|
||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=15,
|
||||
limit_mode="last",
|
||||
)
|
||||
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
identity_block = f"你的名字是{bot_name}{bot_nickname}"
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"regress_mood_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
identity_block=identity_block,
|
||||
mood_state=self.mood_state,
|
||||
emotion_style=global_config.mood.emotion_style,
|
||||
)
|
||||
|
||||
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"{self.log_prefix} prompt: {prompt}")
|
||||
logger.info(f"{self.log_prefix} response: {response}")
|
||||
logger.info(f"{self.log_prefix} reasoning_content: {reasoning_content}")
|
||||
|
||||
logger.info(f"{self.log_prefix} 情绪状态转变为: {response}")
|
||||
|
||||
self.mood_state = response
|
||||
|
||||
self.regression_count += 1
|
||||
|
||||
|
||||
class MoodRegressionTask(AsyncTask):
|
||||
def __init__(self, mood_manager: "MoodManager"):
|
||||
super().__init__(task_name="MoodRegressionTask", run_interval=45)
|
||||
self.mood_manager = mood_manager
|
||||
|
||||
async def run(self):
|
||||
logger.debug("开始情绪回归任务...")
|
||||
now = time.time()
|
||||
for mood in self.mood_manager.mood_list:
|
||||
if mood.last_change_time == 0:
|
||||
continue
|
||||
|
||||
if now - mood.last_change_time > 200:
|
||||
if mood.regression_count >= 2:
|
||||
continue
|
||||
|
||||
logger.debug(f"{mood.log_prefix} 开始情绪回归, 第 {mood.regression_count + 1} 次")
|
||||
await mood.regress_mood()
|
||||
|
||||
|
||||
class MoodManager:
|
||||
def __init__(self):
|
||||
self.mood_list: list[ChatMood] = []
|
||||
"""当前情绪状态"""
|
||||
self.task_started: bool = False
|
||||
|
||||
async def start(self):
|
||||
"""启动情绪回归后台任务"""
|
||||
if self.task_started:
|
||||
return
|
||||
|
||||
task = MoodRegressionTask(self)
|
||||
await async_task_manager.add_task(task)
|
||||
self.task_started = True
|
||||
logger.info("情绪回归任务已启动")
|
||||
|
||||
def get_mood_by_chat_id(self, chat_id: str) -> ChatMood:
|
||||
for mood in self.mood_list:
|
||||
if mood.chat_id == chat_id:
|
||||
return mood
|
||||
|
||||
new_mood = ChatMood(chat_id)
|
||||
self.mood_list.append(new_mood)
|
||||
return new_mood
|
||||
|
||||
def reset_mood_by_chat_id(self, chat_id: str):
|
||||
for mood in self.mood_list:
|
||||
if mood.chat_id == chat_id:
|
||||
mood.mood_state = "感觉很平静"
|
||||
mood.regression_count = 0
|
||||
return
|
||||
self.mood_list.append(ChatMood(chat_id))
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
mood_manager = MoodManager()
|
||||
"""全局情绪管理器"""
|
||||
@@ -11,6 +11,9 @@ from .base import (
|
||||
BaseCommand,
|
||||
BaseTool,
|
||||
ConfigField,
|
||||
ConfigSection,
|
||||
ConfigLayout,
|
||||
ConfigTab,
|
||||
ComponentType,
|
||||
ActionActivationType,
|
||||
ChatMode,
|
||||
@@ -116,6 +119,9 @@ __all__ = [
|
||||
# 装饰器
|
||||
"register_plugin",
|
||||
"ConfigField",
|
||||
"ConfigSection",
|
||||
"ConfigLayout",
|
||||
"ConfigTab",
|
||||
# 工具函数
|
||||
"ManifestValidator",
|
||||
"get_logger",
|
||||
|
||||
@@ -19,7 +19,6 @@ from src.plugin_system.apis import (
|
||||
send_api,
|
||||
tool_api,
|
||||
frequency_api,
|
||||
mood_api,
|
||||
auto_talk_api,
|
||||
)
|
||||
from .logging_api import get_logger
|
||||
@@ -42,6 +41,5 @@ __all__ = [
|
||||
"register_plugin",
|
||||
"tool_api",
|
||||
"frequency_api",
|
||||
"mood_api",
|
||||
"auto_talk_api",
|
||||
]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.frequency_control.frequency_control import frequency_control_manager
|
||||
from src.chat.heart_flow.frequency_control import frequency_control_manager
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("frequency_api")
|
||||
|
||||
@@ -81,10 +81,12 @@ async def generate_reply(
|
||||
chat_id: Optional[str] = None,
|
||||
action_data: Optional[Dict[str, Any]] = None,
|
||||
reply_message: Optional["DatabaseMessages"] = None,
|
||||
think_level: int = 1,
|
||||
extra_info: str = "",
|
||||
reply_reason: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
chosen_actions: Optional[List["ActionPlannerInfo"]] = None,
|
||||
unknown_words: Optional[List[str]] = None,
|
||||
enable_tool: bool = False,
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
@@ -103,6 +105,7 @@ async def generate_reply(
|
||||
reply_reason: 回复原因
|
||||
available_actions: 可用动作
|
||||
chosen_actions: 已选动作
|
||||
unknown_words: Planner 在 reply 动作中给出的未知词语列表,用于黑话检索
|
||||
enable_tool: 是否启用工具调用
|
||||
enable_splitter: 是否启用消息分割器
|
||||
enable_chinese_typo: 是否启用错字生成器
|
||||
@@ -122,11 +125,24 @@ async def generate_reply(
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return False, None
|
||||
|
||||
if not extra_info and action_data:
|
||||
extra_info = action_data.get("extra_info", "")
|
||||
|
||||
if not reply_reason and action_data:
|
||||
reply_reason = action_data.get("reason", "")
|
||||
if action_data:
|
||||
if not extra_info:
|
||||
extra_info = action_data.get("extra_info", "")
|
||||
if not reply_reason:
|
||||
reply_reason = action_data.get("reason", "")
|
||||
# 仅在 reply 场景下使用的未知词语解析(Planner JSON 中下发)
|
||||
if unknown_words is None:
|
||||
uw = action_data.get("unknown_words")
|
||||
if isinstance(uw, list):
|
||||
# 只保留非空字符串
|
||||
cleaned: List[str] = []
|
||||
for item in uw:
|
||||
if isinstance(item, str):
|
||||
s = item.strip()
|
||||
if s:
|
||||
cleaned.append(s)
|
||||
if cleaned:
|
||||
unknown_words = cleaned
|
||||
|
||||
# 调用回复器生成回复
|
||||
success, llm_response = await replyer.generate_reply_with_context(
|
||||
@@ -136,6 +152,8 @@ async def generate_reply(
|
||||
enable_tool=enable_tool,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason,
|
||||
unknown_words=unknown_words,
|
||||
think_level=think_level,
|
||||
from_plugin=from_plugin,
|
||||
stream_id=chat_stream.stream_id if chat_stream else chat_id,
|
||||
reply_time_point=reply_time_point,
|
||||
|
||||
@@ -108,8 +108,8 @@ async def generate_with_model_with_tools(
|
||||
"""
|
||||
try:
|
||||
model_name_list = model_config.model_list
|
||||
logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容")
|
||||
logger.debug(f"[LLMAPI] 完整提示词: {prompt}")
|
||||
logger.info(f"使用模型{model_name_list}生成内容")
|
||||
logger.debug(f"完整提示词: {prompt}")
|
||||
|
||||
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
||||
|
||||
@@ -147,7 +147,7 @@ async def generate_with_model_with_tools_by_message_factory(
|
||||
"""
|
||||
try:
|
||||
model_name_list = model_config.model_list
|
||||
logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容(消息工厂)")
|
||||
logger.info(f"使用模型 {model_name_list} 生成内容")
|
||||
|
||||
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ def get_messages_by_time_in_chat(
|
||||
limit_mode: str = "latest",
|
||||
filter_mai: bool = False,
|
||||
filter_command: bool = False,
|
||||
filter_no_read_command: bool = False,
|
||||
filter_intercept_message_level: Optional[int] = None,
|
||||
) -> List[DatabaseMessages]:
|
||||
"""
|
||||
获取指定聊天中指定时间范围内的消息
|
||||
@@ -111,7 +111,7 @@ def get_messages_by_time_in_chat(
|
||||
limit_mode=limit_mode,
|
||||
filter_bot=filter_mai,
|
||||
filter_command=filter_command,
|
||||
filter_no_read_command=filter_no_read_command,
|
||||
filter_intercept_message_level=filter_intercept_message_level,
|
||||
)
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ def get_messages_by_time_in_chat_inclusive(
|
||||
limit_mode: str = "latest",
|
||||
filter_mai: bool = False,
|
||||
filter_command: bool = False,
|
||||
filter_no_read_command: bool = False,
|
||||
filter_intercept_message_level: Optional[int] = None,
|
||||
) -> List[DatabaseMessages]:
|
||||
"""
|
||||
获取指定聊天中指定时间范围内的消息(包含边界)
|
||||
@@ -158,7 +158,7 @@ def get_messages_by_time_in_chat_inclusive(
|
||||
limit_mode=limit_mode,
|
||||
filter_bot=filter_mai,
|
||||
filter_command=filter_command,
|
||||
filter_no_read_command=filter_no_read_command,
|
||||
filter_intercept_message_level=filter_intercept_message_level,
|
||||
)
|
||||
if filter_mai:
|
||||
return filter_mai_messages(messages)
|
||||
@@ -284,7 +284,7 @@ def get_messages_before_time_in_chat(
|
||||
timestamp: float,
|
||||
limit: int = 0,
|
||||
filter_mai: bool = False,
|
||||
filter_no_read_command: bool = False,
|
||||
filter_intercept_message_level: Optional[int] = None,
|
||||
) -> List[DatabaseMessages]:
|
||||
"""
|
||||
获取指定聊天中指定时间戳之前的消息
|
||||
@@ -313,7 +313,7 @@ def get_messages_before_time_in_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=timestamp,
|
||||
limit=limit,
|
||||
filter_no_read_command=filter_no_read_command,
|
||||
filter_intercept_message_level=filter_intercept_message_level,
|
||||
)
|
||||
if filter_mai:
|
||||
return filter_mai_messages(messages)
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
logger = get_logger("mood_api")
|
||||
|
||||
|
||||
async def get_mood_by_chat_id(chat_id: str) -> Optional[float]:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
||||
mood = asyncio.create_task(chat_mood.get_mood())
|
||||
return mood
|
||||
@@ -29,7 +29,7 @@ from .component_types import (
|
||||
ForwardNode,
|
||||
ReplySetModel,
|
||||
)
|
||||
from .config_types import ConfigField
|
||||
from .config_types import ConfigField, ConfigSection, ConfigLayout, ConfigTab
|
||||
|
||||
__all__ = [
|
||||
"BasePlugin",
|
||||
@@ -46,6 +46,9 @@ __all__ = [
|
||||
"PluginInfo",
|
||||
"PythonDependency",
|
||||
"ConfigField",
|
||||
"ConfigSection",
|
||||
"ConfigLayout",
|
||||
"ConfigTab",
|
||||
"EventHandlerInfo",
|
||||
"EventType",
|
||||
"BaseEventHandler",
|
||||
|
||||
@@ -55,11 +55,11 @@ class BaseCommand(ABC):
|
||||
self.matched_groups = groups
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
async def execute(self) -> Tuple[bool, Optional[str], int]:
|
||||
"""执行Command的抽象方法,子类必须实现
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str], bool]: (是否执行成功, 可选的回复消息, 是否拦截消息 不进行 后续处理)
|
||||
Tuple[bool, Optional[str], int]: (是否执行成功, 可选的回复消息, 拦截消息力度,0代表不拦截,1代表仅不触发回复,replyer可见,2代表不触发回复,replyer不可见)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -70,6 +70,12 @@ class ConfigField:
|
||||
depends_on: Optional[str] = None # 依赖的字段路径,如 "section.field"
|
||||
depends_value: Any = None # 依赖字段需要的值(当依赖字段等于此值时显示)
|
||||
|
||||
# === 列表类型专用 ===
|
||||
item_type: Optional[str] = None # 数组元素类型: "string", "number", "object"
|
||||
item_fields: Optional[Dict[str, Any]] = None # 当 item_type="object" 时,定义对象的字段结构
|
||||
min_items: Optional[int] = None # 数组最小元素数量
|
||||
max_items: Optional[int] = None # 数组最大元素数量
|
||||
|
||||
def get_ui_type(self) -> str:
|
||||
"""
|
||||
获取 UI 控件类型
|
||||
@@ -132,6 +138,10 @@ class ConfigField:
|
||||
"group": self.group,
|
||||
"depends_on": self.depends_on,
|
||||
"depends_value": self.depends_value,
|
||||
"item_type": self.item_type,
|
||||
"item_fields": self.item_fields,
|
||||
"min_items": self.min_items,
|
||||
"max_items": self.max_items,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class EmojiAction(BaseAction):
|
||||
try:
|
||||
# 1. 获取发送表情的原因
|
||||
# reason = self.action_data.get("reason", "表达当前情绪")
|
||||
reason = self.reasoning
|
||||
reason = self.action_reasoning
|
||||
|
||||
# 2. 随机获取20个表情包
|
||||
sampled_emojis = await emoji_api.get_random(30)
|
||||
|
||||
796
src/webui/anti_crawler.py
Normal file
796
src/webui/anti_crawler.py
Normal file
@@ -0,0 +1,796 @@
|
||||
"""
|
||||
WebUI 防爬虫模块
|
||||
提供爬虫检测和阻止功能,保护 WebUI 不被搜索引擎和恶意爬虫访问
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import ipaddress
|
||||
import re
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui.anti_crawler")
|
||||
|
||||
# 常见爬虫 User-Agent 列表(使用更精确的关键词,避免误报)
|
||||
CRAWLER_USER_AGENTS = {
|
||||
# 搜索引擎爬虫(精确匹配)
|
||||
"googlebot",
|
||||
"bingbot",
|
||||
"baiduspider",
|
||||
"yandexbot",
|
||||
"slurp", # Yahoo
|
||||
"duckduckbot",
|
||||
"sogou",
|
||||
"exabot",
|
||||
"facebot",
|
||||
"ia_archiver", # Internet Archive
|
||||
# 通用爬虫(移除过于宽泛的关键词)
|
||||
"crawler",
|
||||
"spider",
|
||||
"scraper",
|
||||
"wget", # 保留wget,因为通常用于自动化脚本
|
||||
"scrapy", # 保留scrapy,因为这是爬虫框架
|
||||
# 安全扫描工具(这些是明确的扫描工具)
|
||||
"masscan",
|
||||
"nmap",
|
||||
"nikto",
|
||||
"sqlmap",
|
||||
# 注意:移除了以下过于宽泛的关键词以避免误报:
|
||||
# - "bot" (会误匹配GitHub-Robot等)
|
||||
# - "curl" (正常工具)
|
||||
# - "python-requests" (正常库)
|
||||
# - "httpx" (正常库)
|
||||
# - "aiohttp" (正常库)
|
||||
}
|
||||
|
||||
# 资产测绘工具 User-Agent 标识
|
||||
ASSET_SCANNER_USER_AGENTS = {
|
||||
# 知名资产测绘平台
|
||||
"shodan",
|
||||
"censys",
|
||||
"zoomeye",
|
||||
"fofa",
|
||||
"quake",
|
||||
"hunter",
|
||||
"binaryedge",
|
||||
"onyphe",
|
||||
"securitytrails",
|
||||
"virustotal",
|
||||
"passivetotal",
|
||||
# 安全扫描工具
|
||||
"acunetix",
|
||||
"appscan",
|
||||
"burpsuite",
|
||||
"nessus",
|
||||
"openvas",
|
||||
"qualys",
|
||||
"rapid7",
|
||||
"tenable",
|
||||
"veracode",
|
||||
"zap",
|
||||
"awvs", # Acunetix Web Vulnerability Scanner
|
||||
"netsparker",
|
||||
"skipfish",
|
||||
"w3af",
|
||||
"arachni",
|
||||
# 其他扫描工具
|
||||
"masscan",
|
||||
"zmap",
|
||||
"nmap",
|
||||
"whatweb",
|
||||
"wpscan",
|
||||
"joomscan",
|
||||
"dnsenum",
|
||||
"subfinder",
|
||||
"amass",
|
||||
"sublist3r",
|
||||
"theharvester",
|
||||
}
|
||||
|
||||
# 资产测绘工具常用的HTTP头标识
|
||||
ASSET_SCANNER_HEADERS = {
|
||||
# 常见的扫描工具自定义头
|
||||
"x-scan": {"shodan", "censys", "zoomeye", "fofa"},
|
||||
"x-scanner": {"nmap", "masscan", "zmap"},
|
||||
"x-probe": {"masscan", "zmap"},
|
||||
# 其他可疑头(移除反向代理标准头)
|
||||
"x-originating-ip": set(),
|
||||
"x-remote-ip": set(),
|
||||
"x-remote-addr": set(),
|
||||
# 注意:移除了以下反向代理标准头以避免误报:
|
||||
# - "x-forwarded-proto" (反向代理标准头)
|
||||
# - "x-real-ip" (反向代理标准头,已在_get_client_ip中使用)
|
||||
}
|
||||
|
||||
# 仅检查特定HTTP头中的可疑模式(收紧匹配范围)
|
||||
# 只检查这些特定头,不检查所有头
|
||||
SCANNER_SPECIFIC_HEADERS = {
|
||||
"x-scan",
|
||||
"x-scanner",
|
||||
"x-probe",
|
||||
"x-originating-ip",
|
||||
"x-remote-ip",
|
||||
"x-remote-addr",
|
||||
}
|
||||
|
||||
# 防爬虫模式配置
|
||||
# false: 禁用
|
||||
# strict: 严格模式(更严格的检测,更低的频率限制)
|
||||
# loose: 宽松模式(较宽松的检测,较高的频率限制)
|
||||
# basic: 基础模式(只记录恶意访问,不阻止,不限制请求数,不跟踪IP)
|
||||
|
||||
# IP白名单配置(从配置文件读取,逗号分隔)
|
||||
# 支持格式:
|
||||
# - 精确IP:127.0.0.1, 192.168.1.100
|
||||
# - CIDR格式:192.168.1.0/24, 172.17.0.0/16 (适用于Docker网络)
|
||||
# - 通配符:192.168.*.*, 10.*.*.*, *.*.*.* (匹配所有)
|
||||
# - IPv6:::1, 2001:db8::/32
|
||||
def _parse_allowed_ips(ip_string: str) -> list:
|
||||
"""
|
||||
解析IP白名单字符串,支持精确IP、CIDR格式和通配符
|
||||
|
||||
Args:
|
||||
ip_string: 逗号分隔的IP字符串
|
||||
|
||||
Returns:
|
||||
IP白名单列表,每个元素可能是:
|
||||
- ipaddress.IPv4Network/IPv6Network对象(CIDR格式)
|
||||
- ipaddress.IPv4Address/IPv6Address对象(精确IP)
|
||||
- str(通配符模式,已转换为正则表达式)
|
||||
"""
|
||||
allowed = []
|
||||
if not ip_string:
|
||||
return allowed
|
||||
|
||||
for ip_entry in ip_string.split(","):
|
||||
ip_entry = ip_entry.strip() # 去除空格
|
||||
if not ip_entry:
|
||||
continue
|
||||
|
||||
# 跳过注释行(以#开头)
|
||||
if ip_entry.startswith("#"):
|
||||
continue
|
||||
|
||||
# 检查通配符格式(包含*)
|
||||
if "*" in ip_entry:
|
||||
# 处理通配符
|
||||
pattern = _convert_wildcard_to_regex(ip_entry)
|
||||
if pattern:
|
||||
allowed.append(pattern)
|
||||
else:
|
||||
logger.warning(f"无效的通配符IP格式,已忽略: {ip_entry}")
|
||||
continue
|
||||
|
||||
try:
|
||||
# 尝试解析为CIDR格式(包含/)
|
||||
if "/" in ip_entry:
|
||||
allowed.append(ipaddress.ip_network(ip_entry, strict=False))
|
||||
else:
|
||||
# 精确IP地址
|
||||
allowed.append(ipaddress.ip_address(ip_entry))
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(f"无效的IP白名单条目,已忽略: {ip_entry} ({e})")
|
||||
|
||||
return allowed
|
||||
|
||||
|
||||
def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
|
||||
"""
|
||||
将通配符IP模式转换为正则表达式
|
||||
|
||||
支持的格式:
|
||||
- 192.168.*.* 或 192.168.*
|
||||
- 10.*.*.* 或 10.*
|
||||
- *.*.*.* 或 *
|
||||
|
||||
Args:
|
||||
wildcard_pattern: 通配符模式字符串
|
||||
|
||||
Returns:
|
||||
正则表达式字符串,如果格式无效则返回None
|
||||
"""
|
||||
# 去除空格
|
||||
pattern = wildcard_pattern.strip()
|
||||
|
||||
# 处理单个*(匹配所有)
|
||||
if pattern == "*":
|
||||
return r".*"
|
||||
|
||||
# 处理IPv4通配符格式
|
||||
# 支持:192.168.*.*, 192.168.*, 10.*.*.*, 10.* 等
|
||||
parts = pattern.split(".")
|
||||
|
||||
if len(parts) > 4:
|
||||
return None # IPv4最多4段
|
||||
|
||||
# 构建正则表达式
|
||||
regex_parts = []
|
||||
for part in parts:
|
||||
part = part.strip()
|
||||
if part == "*":
|
||||
regex_parts.append(r"\d+") # 匹配任意数字
|
||||
elif part.isdigit():
|
||||
# 验证数字范围(0-255)
|
||||
num = int(part)
|
||||
if 0 <= num <= 255:
|
||||
regex_parts.append(re.escape(part))
|
||||
else:
|
||||
return None # 无效的数字
|
||||
else:
|
||||
return None # 无效的格式
|
||||
|
||||
# 如果部分少于4段,补充.*
|
||||
while len(regex_parts) < 4:
|
||||
regex_parts.append(r"\d+")
|
||||
|
||||
# 组合成正则表达式
|
||||
regex = r"^" + r"\.".join(regex_parts) + r"$"
|
||||
return regex
|
||||
|
||||
|
||||
# 从配置读取防爬虫设置(延迟导入避免循环依赖)
|
||||
def _get_anti_crawler_config():
|
||||
"""获取防爬虫配置"""
|
||||
from src.config.config import global_config
|
||||
return {
|
||||
'mode': global_config.webui.anti_crawler_mode,
|
||||
'allowed_ips': _parse_allowed_ips(global_config.webui.allowed_ips),
|
||||
'trusted_proxies': _parse_allowed_ips(global_config.webui.trusted_proxies),
|
||||
'trust_xff': global_config.webui.trust_xff
|
||||
}
|
||||
|
||||
# 初始化配置(将在模块加载时执行)
|
||||
_config = _get_anti_crawler_config()
|
||||
ANTI_CRAWLER_MODE = _config['mode']
|
||||
ALLOWED_IPS = _config['allowed_ips']
|
||||
TRUSTED_PROXIES = _config['trusted_proxies']
|
||||
TRUST_XFF = _config['trust_xff']
|
||||
|
||||
|
||||
def _get_mode_config(mode: str) -> dict:
|
||||
"""
|
||||
根据模式获取配置参数
|
||||
|
||||
Args:
|
||||
mode: 防爬虫模式 (false/strict/loose/basic)
|
||||
|
||||
Returns:
|
||||
配置字典,包含所有相关参数
|
||||
"""
|
||||
mode = mode.lower()
|
||||
|
||||
if mode == "false":
|
||||
return {
|
||||
"enabled": False,
|
||||
"rate_limit_window": 60,
|
||||
"rate_limit_max_requests": 1000, # 禁用时设置很高的值
|
||||
"max_tracked_ips": 0,
|
||||
"check_user_agent": False,
|
||||
"check_asset_scanner": False,
|
||||
"check_rate_limit": False,
|
||||
"block_on_detect": False, # 不阻止
|
||||
}
|
||||
elif mode == "strict":
|
||||
return {
|
||||
"enabled": True,
|
||||
"rate_limit_window": 60,
|
||||
"rate_limit_max_requests": 15, # 严格模式:更低的请求数
|
||||
"max_tracked_ips": 20000,
|
||||
"check_user_agent": True,
|
||||
"check_asset_scanner": True,
|
||||
"check_rate_limit": True,
|
||||
"block_on_detect": True, # 阻止恶意访问
|
||||
}
|
||||
elif mode == "loose":
|
||||
return {
|
||||
"enabled": True,
|
||||
"rate_limit_window": 60,
|
||||
"rate_limit_max_requests": 60, # 宽松模式:更高的请求数
|
||||
"max_tracked_ips": 5000,
|
||||
"check_user_agent": True,
|
||||
"check_asset_scanner": True,
|
||||
"check_rate_limit": True,
|
||||
"block_on_detect": True, # 阻止恶意访问
|
||||
}
|
||||
else: # basic (默认模式)
|
||||
return {
|
||||
"enabled": True,
|
||||
"rate_limit_window": 60,
|
||||
"rate_limit_max_requests": 1000, # 不限制请求数
|
||||
"max_tracked_ips": 0, # 不跟踪IP
|
||||
"check_user_agent": True, # 检测但不阻止
|
||||
"check_asset_scanner": True, # 检测但不阻止
|
||||
"check_rate_limit": False, # 不限制请求频率
|
||||
"block_on_detect": False, # 只记录,不阻止
|
||||
}
|
||||
|
||||
|
||||
class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
||||
"""防爬虫中间件"""
|
||||
|
||||
def __init__(self, app, mode: str = "standard"):
|
||||
"""
|
||||
初始化防爬虫中间件
|
||||
|
||||
Args:
|
||||
app: FastAPI 应用实例
|
||||
mode: 防爬虫模式 (false/strict/loose/standard)
|
||||
"""
|
||||
super().__init__(app)
|
||||
self.mode = mode.lower()
|
||||
# 根据模式获取配置
|
||||
config = _get_mode_config(self.mode)
|
||||
self.enabled = config["enabled"]
|
||||
self.rate_limit_window = config["rate_limit_window"]
|
||||
self.rate_limit_max_requests = config["rate_limit_max_requests"]
|
||||
self.max_tracked_ips = config["max_tracked_ips"]
|
||||
self.check_user_agent = config["check_user_agent"]
|
||||
self.check_asset_scanner = config["check_asset_scanner"]
|
||||
self.check_rate_limit = config["check_rate_limit"]
|
||||
self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问
|
||||
|
||||
# 用于存储每个IP的请求时间戳(使用deque提高性能)
|
||||
self.request_times: dict[str, deque] = {}
|
||||
# 上次清理时间
|
||||
self.last_cleanup = time.time()
|
||||
# 将关键词列表转换为集合以提高查找性能
|
||||
self.crawler_keywords_set = set(CRAWLER_USER_AGENTS)
|
||||
self.scanner_keywords_set = set(ASSET_SCANNER_USER_AGENTS)
|
||||
|
||||
def _is_crawler_user_agent(self, user_agent: Optional[str]) -> bool:
|
||||
"""
|
||||
检测是否为爬虫 User-Agent
|
||||
|
||||
Args:
|
||||
user_agent: User-Agent 字符串
|
||||
|
||||
Returns:
|
||||
如果是爬虫则返回 True
|
||||
"""
|
||||
if not user_agent:
|
||||
# 没有 User-Agent 的请求记录日志但不直接阻止
|
||||
# 改为只记录,让频率限制来处理
|
||||
logger.debug("请求缺少User-Agent")
|
||||
return False # 不再直接阻止无User-Agent的请求
|
||||
|
||||
user_agent_lower = user_agent.lower()
|
||||
|
||||
# 使用集合查找提高性能(检查是否包含爬虫关键词)
|
||||
for crawler_keyword in self.crawler_keywords_set:
|
||||
if crawler_keyword in user_agent_lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_asset_scanner_header(self, request: Request) -> bool:
|
||||
"""
|
||||
检测是否为资产测绘工具的HTTP头(只检查特定头,收紧匹配)
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
|
||||
Returns:
|
||||
如果检测到资产测绘工具头则返回 True
|
||||
"""
|
||||
# 只检查特定的扫描工具头,不检查所有头
|
||||
for header_name, header_value in request.headers.items():
|
||||
header_name_lower = header_name.lower()
|
||||
header_value_lower = header_value.lower() if header_value else ""
|
||||
|
||||
# 检查已知的扫描工具头
|
||||
if header_name_lower in ASSET_SCANNER_HEADERS:
|
||||
# 如果该头有特定的工具集合,检查值是否匹配
|
||||
expected_tools = ASSET_SCANNER_HEADERS[header_name_lower]
|
||||
if expected_tools:
|
||||
for tool in expected_tools:
|
||||
if tool in header_value_lower:
|
||||
return True
|
||||
else:
|
||||
# 如果没有特定工具集合,只要存在该头就视为可疑
|
||||
if header_value_lower:
|
||||
return True
|
||||
|
||||
# 只检查特定头中的可疑模式(收紧匹配)
|
||||
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
|
||||
# 检查头值中是否包含已知扫描工具名称
|
||||
for tool in self.scanner_keywords_set:
|
||||
if tool in header_value_lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _detect_asset_scanner(self, request: Request) -> tuple[bool, Optional[str]]:
|
||||
"""
|
||||
检测资产测绘工具
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
|
||||
Returns:
|
||||
(是否检测到, 检测到的工具名称)
|
||||
"""
|
||||
user_agent = request.headers.get("User-Agent")
|
||||
|
||||
# 检查 User-Agent(使用集合查找提高性能)
|
||||
if user_agent:
|
||||
user_agent_lower = user_agent.lower()
|
||||
for scanner_keyword in self.scanner_keywords_set:
|
||||
if scanner_keyword in user_agent_lower:
|
||||
return True, scanner_keyword
|
||||
|
||||
# 检查HTTP头
|
||||
if self._is_asset_scanner_header(request):
|
||||
# 尝试从User-Agent或头中提取工具名称
|
||||
detected_tool = None
|
||||
if user_agent:
|
||||
user_agent_lower = user_agent.lower()
|
||||
for tool in self.scanner_keywords_set:
|
||||
if tool in user_agent_lower:
|
||||
detected_tool = tool
|
||||
break
|
||||
|
||||
# 检查HTTP头中的工具标识(只检查特定头)
|
||||
if not detected_tool:
|
||||
for header_name, header_value in request.headers.items():
|
||||
header_name_lower = header_name.lower()
|
||||
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
|
||||
header_value_lower = (header_value or "").lower()
|
||||
for tool in self.scanner_keywords_set:
|
||||
if tool in header_value_lower:
|
||||
detected_tool = tool
|
||||
break
|
||||
if detected_tool:
|
||||
break
|
||||
|
||||
return True, detected_tool or "unknown_scanner"
|
||||
|
||||
return False, None
|
||||
|
||||
def _check_rate_limit(self, client_ip: str) -> bool:
|
||||
"""
|
||||
检查请求频率限制
|
||||
|
||||
Args:
|
||||
client_ip: 客户端IP地址
|
||||
|
||||
Returns:
|
||||
如果超过限制则返回 True(需要阻止)
|
||||
"""
|
||||
# 检查IP白名单
|
||||
if self._is_ip_allowed(client_ip):
|
||||
return False
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 定期清理过期的请求记录(每5分钟清理一次)
|
||||
if current_time - self.last_cleanup > 300:
|
||||
self._cleanup_old_requests(current_time)
|
||||
self.last_cleanup = current_time
|
||||
|
||||
# 限制跟踪的IP数量,防止内存泄漏
|
||||
if self.max_tracked_ips > 0 and len(self.request_times) > self.max_tracked_ips:
|
||||
# 清理最旧的记录(删除最久未访问的IP)
|
||||
self._cleanup_oldest_ips()
|
||||
|
||||
# 获取或创建该IP的请求时间deque(不使用maxlen,避免限流变松)
|
||||
if client_ip not in self.request_times:
|
||||
self.request_times[client_ip] = deque()
|
||||
|
||||
request_times = self.request_times[client_ip]
|
||||
|
||||
# 移除时间窗口外的请求记录(从左侧弹出过期记录)
|
||||
while request_times and current_time - request_times[0] >= self.rate_limit_window:
|
||||
request_times.popleft()
|
||||
|
||||
# 检查是否超过限制
|
||||
if len(request_times) >= self.rate_limit_max_requests:
|
||||
return True
|
||||
|
||||
# 记录当前请求时间
|
||||
request_times.append(current_time)
|
||||
return False
|
||||
|
||||
def _cleanup_old_requests(self, current_time: float):
|
||||
"""清理过期的请求记录(只清理当前需要检查的IP,不全量遍历)"""
|
||||
# 这个方法现在主要用于定期清理,实际清理在_check_rate_limit中按需进行
|
||||
# 清理最久未访问的IP记录
|
||||
if len(self.request_times) > self.max_tracked_ips * 0.8:
|
||||
self._cleanup_oldest_ips()
|
||||
|
||||
def _cleanup_oldest_ips(self):
|
||||
"""清理最久未访问的IP记录(全量遍历找真正的oldest)"""
|
||||
if not self.request_times:
|
||||
return
|
||||
|
||||
# 先收集空deque的IP(优先删除)
|
||||
empty_ips = []
|
||||
# 找到最久未访问的IP(最旧时间戳)
|
||||
oldest_ip = None
|
||||
oldest_time = float("inf")
|
||||
|
||||
# 全量遍历找真正的oldest(超限时性能可接受)
|
||||
for ip, times in self.request_times.items():
|
||||
if not times:
|
||||
# 空deque,记录待删除
|
||||
empty_ips.append(ip)
|
||||
else:
|
||||
# 找到最旧的时间戳
|
||||
if times[0] < oldest_time:
|
||||
oldest_time = times[0]
|
||||
oldest_ip = ip
|
||||
|
||||
# 先删除空deque的IP
|
||||
for ip in empty_ips:
|
||||
del self.request_times[ip]
|
||||
|
||||
# 如果没有空deque可删除,且仍需要清理,删除最旧的一个IP
|
||||
if not empty_ips and oldest_ip:
|
||||
del self.request_times[oldest_ip]
|
||||
|
||||
def _is_trusted_proxy(self, ip: str) -> bool:
|
||||
"""
|
||||
检查IP是否在信任的代理列表中
|
||||
|
||||
Args:
|
||||
ip: IP地址字符串
|
||||
|
||||
Returns:
|
||||
如果是信任的代理则返回 True
|
||||
"""
|
||||
if not TRUSTED_PROXIES or ip == "unknown":
|
||||
return False
|
||||
|
||||
# 检查代理列表中的每个条目
|
||||
for trusted_entry in TRUSTED_PROXIES:
|
||||
# 通配符模式(字符串,正则表达式)
|
||||
if isinstance(trusted_entry, str):
|
||||
try:
|
||||
if re.match(trusted_entry, ip):
|
||||
return True
|
||||
except re.error:
|
||||
continue
|
||||
# CIDR格式(网络对象)
|
||||
elif isinstance(trusted_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
|
||||
try:
|
||||
client_ip_obj = ipaddress.ip_address(ip)
|
||||
if client_ip_obj in trusted_entry:
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
# 精确IP(地址对象)
|
||||
elif isinstance(trusted_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
|
||||
try:
|
||||
client_ip_obj = ipaddress.ip_address(ip)
|
||||
if client_ip_obj == trusted_entry:
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""
|
||||
获取客户端真实IP地址(带基本验证和代理信任检查)
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
|
||||
Returns:
|
||||
客户端IP地址
|
||||
"""
|
||||
# 获取直接连接的客户端IP(用于验证代理)
|
||||
direct_client_ip = None
|
||||
if request.client:
|
||||
direct_client_ip = request.client.host
|
||||
|
||||
# 检查是否信任X-Forwarded-For头
|
||||
# TRUST_XFF 只表示"启用代理解析能力",但仍要求直连 IP 在 TRUSTED_PROXIES 中
|
||||
use_xff = False
|
||||
if TRUST_XFF and TRUSTED_PROXIES and direct_client_ip:
|
||||
# 只有在启用 TRUST_XFF 且直连 IP 在信任列表中时,才信任 XFF
|
||||
use_xff = self._is_trusted_proxy(direct_client_ip)
|
||||
|
||||
# 如果信任代理,优先从 X-Forwarded-For 获取
|
||||
if use_xff:
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For 可能包含多个IP,取第一个
|
||||
ip = forwarded_for.split(",")[0].strip()
|
||||
# 基本验证IP格式
|
||||
if self._validate_ip(ip):
|
||||
return ip
|
||||
|
||||
# 从 X-Real-IP 获取(如果信任代理)
|
||||
if use_xff:
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
ip = real_ip.strip()
|
||||
if self._validate_ip(ip):
|
||||
return ip
|
||||
|
||||
# 使用直接连接的客户端IP
|
||||
if direct_client_ip and self._validate_ip(direct_client_ip):
|
||||
return direct_client_ip
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _validate_ip(self, ip: str) -> bool:
|
||||
"""
|
||||
验证IP地址格式
|
||||
|
||||
Args:
|
||||
ip: IP地址字符串
|
||||
|
||||
Returns:
|
||||
如果格式有效则返回 True
|
||||
"""
|
||||
try:
|
||||
ipaddress.ip_address(ip)
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
def _is_ip_allowed(self, ip: str) -> bool:
|
||||
"""
|
||||
检查IP是否在白名单中(支持精确IP、CIDR格式和通配符)
|
||||
|
||||
Args:
|
||||
ip: 客户端IP地址
|
||||
|
||||
Returns:
|
||||
如果IP在白名单中则返回 True
|
||||
"""
|
||||
if not ALLOWED_IPS or ip == "unknown":
|
||||
return False
|
||||
|
||||
# 检查白名单中的每个条目
|
||||
for allowed_entry in ALLOWED_IPS:
|
||||
# 通配符模式(字符串,正则表达式)
|
||||
if isinstance(allowed_entry, str):
|
||||
try:
|
||||
if re.match(allowed_entry, ip):
|
||||
return True
|
||||
except re.error:
|
||||
# 正则表达式错误,跳过
|
||||
continue
|
||||
# CIDR格式(网络对象)
|
||||
elif isinstance(allowed_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
|
||||
try:
|
||||
client_ip_obj = ipaddress.ip_address(ip)
|
||||
if client_ip_obj in allowed_entry:
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
# IP格式无效,跳过
|
||||
continue
|
||||
# 精确IP(地址对象)
|
||||
elif isinstance(allowed_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
|
||||
try:
|
||||
client_ip_obj = ipaddress.ip_address(ip)
|
||||
if client_ip_obj == allowed_entry:
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
# IP格式无效,跳过
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""
|
||||
处理请求
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
call_next: 下一个中间件或路由处理函数
|
||||
|
||||
Returns:
|
||||
响应对象
|
||||
"""
|
||||
# 如果未启用,直接通过
|
||||
if not self.enabled:
|
||||
return await call_next(request)
|
||||
|
||||
# 允许访问 robots.txt(由专门的路由处理)
|
||||
if request.url.path == "/robots.txt":
|
||||
return await call_next(request)
|
||||
|
||||
# 允许访问静态资源(CSS、JS、图片等)
|
||||
# 注意:.json 已移除,避免 API 路径绕过防护
|
||||
# 静态资源只在特定前缀下放行(/static/、/assets/、/dist/)
|
||||
static_extensions = {
|
||||
".css",
|
||||
".js",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".svg",
|
||||
".ico",
|
||||
".woff",
|
||||
".woff2",
|
||||
".ttf",
|
||||
".eot",
|
||||
}
|
||||
static_prefixes = {"/static/", "/assets/", "/dist/"}
|
||||
|
||||
# 检查是否是静态资源路径(特定前缀下的静态文件)
|
||||
path = request.url.path
|
||||
is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(
|
||||
path.endswith(ext) for ext in static_extensions
|
||||
)
|
||||
|
||||
# 也允许根路径下的静态文件(如 /favicon.ico)
|
||||
is_root_static = path.count("/") == 1 and any(path.endswith(ext) for ext in static_extensions)
|
||||
|
||||
if is_static_path or is_root_static:
|
||||
return await call_next(request)
|
||||
|
||||
# 获取客户端IP(只获取一次,避免重复调用)
|
||||
client_ip = self._get_client_ip(request)
|
||||
|
||||
# 检查IP白名单(优先检查,白名单IP直接通过)
|
||||
if self._is_ip_allowed(client_ip):
|
||||
return await call_next(request)
|
||||
|
||||
# 获取 User-Agent
|
||||
user_agent = request.headers.get("User-Agent")
|
||||
|
||||
# 检测资产测绘工具(优先检测,因为更危险)
|
||||
if self.check_asset_scanner:
|
||||
is_scanner, scanner_name = self._detect_asset_scanner(request)
|
||||
if is_scanner:
|
||||
logger.warning(
|
||||
f"🚫 检测到资产测绘工具请求 - IP: {client_ip}, 工具: {scanner_name}, "
|
||||
f"User-Agent: {user_agent}, Path: {request.url.path}"
|
||||
)
|
||||
# 根据配置决定是否阻止
|
||||
if self.block_on_detect:
|
||||
return PlainTextResponse(
|
||||
"Access Denied: Asset scanning tools are not allowed",
|
||||
status_code=403,
|
||||
)
|
||||
|
||||
# 检测爬虫 User-Agent
|
||||
if self.check_user_agent and self._is_crawler_user_agent(user_agent):
|
||||
logger.warning(f"🚫 检测到爬虫请求 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
|
||||
# 根据配置决定是否阻止
|
||||
if self.block_on_detect:
|
||||
return PlainTextResponse(
|
||||
"Access Denied: Crawlers are not allowed",
|
||||
status_code=403,
|
||||
)
|
||||
|
||||
# 检查请求频率限制
|
||||
if self.check_rate_limit and self._check_rate_limit(client_ip):
|
||||
logger.warning(f"🚫 请求频率过高 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
|
||||
return PlainTextResponse(
|
||||
"Too Many Requests: Rate limit exceeded",
|
||||
status_code=429,
|
||||
)
|
||||
|
||||
# 正常请求,继续处理
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
def create_robots_txt_response() -> PlainTextResponse:
|
||||
"""
|
||||
创建 robots.txt 响应
|
||||
|
||||
Returns:
|
||||
robots.txt 响应对象
|
||||
"""
|
||||
robots_content = """User-agent: *
|
||||
Disallow: /
|
||||
|
||||
# 禁止所有爬虫访问
|
||||
"""
|
||||
return PlainTextResponse(
|
||||
content=robots_content,
|
||||
media_type="text/plain",
|
||||
headers={"Cache-Control": "public, max-age=86400"}, # 缓存24小时
|
||||
)
|
||||
@@ -6,6 +6,7 @@ WebUI 认证模块
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, Cookie, Header, Response, Request
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from .token_manager import get_token_manager
|
||||
|
||||
logger = get_logger("webui.auth")
|
||||
@@ -15,6 +16,28 @@ COOKIE_NAME = "maibot_session"
|
||||
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
|
||||
|
||||
|
||||
def _is_secure_environment() -> bool:
|
||||
"""
|
||||
检测是否应该启用安全 Cookie(HTTPS)
|
||||
|
||||
Returns:
|
||||
bool: 如果应该使用 secure cookie 则返回 True
|
||||
"""
|
||||
# 从配置读取
|
||||
if global_config.webui.secure_cookie:
|
||||
logger.info("配置中启用了 secure_cookie")
|
||||
return True
|
||||
|
||||
# 检查是否是生产环境
|
||||
if global_config.webui.mode == "production":
|
||||
logger.info("WebUI运行在生产模式,启用 secure cookie")
|
||||
return True
|
||||
|
||||
# 默认:开发环境不启用(因为通常是 HTTP)
|
||||
logger.debug("WebUI运行在开发模式,禁用 secure cookie")
|
||||
return False
|
||||
|
||||
|
||||
def get_current_token(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
@@ -22,69 +45,102 @@ def get_current_token(
|
||||
) -> str:
|
||||
"""
|
||||
获取当前请求的 token,优先从 Cookie 获取,其次从 Header 获取
|
||||
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization Header (Bearer token)
|
||||
|
||||
|
||||
Returns:
|
||||
验证通过的 token
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def set_auth_cookie(response: Response, token: str) -> None:
|
||||
def set_auth_cookie(response: Response, token: str, request: Optional[Request] = None) -> None:
|
||||
"""
|
||||
设置认证 Cookie
|
||||
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
token: 要设置的 token
|
||||
request: FastAPI Request 对象(可选,用于检测协议)
|
||||
"""
|
||||
# 根据环境和实际请求协议决定安全设置
|
||||
is_secure = _is_secure_environment()
|
||||
|
||||
# 如果提供了 request,检测实际使用的协议
|
||||
if request:
|
||||
# 检查 X-Forwarded-Proto header(代理/负载均衡器)
|
||||
forwarded_proto = request.headers.get("x-forwarded-proto", "").lower()
|
||||
if forwarded_proto:
|
||||
is_https = forwarded_proto == "https"
|
||||
logger.debug(f"检测到 X-Forwarded-Proto: {forwarded_proto}, is_https={is_https}")
|
||||
else:
|
||||
# 检查 request.url.scheme
|
||||
is_https = request.url.scheme == "https"
|
||||
logger.debug(f"检测到 scheme: {request.url.scheme}, is_https={is_https}")
|
||||
|
||||
# 如果是 HTTP 连接,强制禁用 secure 标志
|
||||
if not is_https and is_secure:
|
||||
logger.warning("=" * 80)
|
||||
logger.warning("检测到 HTTP 连接但环境配置要求 HTTPS (secure cookie)")
|
||||
logger.warning("已自动禁用 secure 标志以允许登录,但建议修改配置:")
|
||||
logger.warning("1. 在配置文件中设置: webui.secure_cookie = false")
|
||||
logger.warning("2. 如果使用反向代理,请确保正确配置 X-Forwarded-Proto 头")
|
||||
logger.warning("=" * 80)
|
||||
is_secure = False
|
||||
|
||||
# 设置 Cookie
|
||||
response.set_cookie(
|
||||
key=COOKIE_NAME,
|
||||
value=token,
|
||||
max_age=COOKIE_MAX_AGE,
|
||||
httponly=True, # 防止 JS 读取
|
||||
samesite="lax", # 允许同站导航时发送 Cookie(兼容开发环境代理)
|
||||
secure=False, # 本地开发不强制 HTTPS,生产环境建议设为 True
|
||||
httponly=True, # 防止 JS 读取,阻止 XSS 窃取
|
||||
samesite="lax", # 使用 lax 以兼容更多场景(开发和生产)
|
||||
secure=is_secure, # 根据实际协议决定
|
||||
path="/", # 确保 Cookie 在所有路径下可用
|
||||
)
|
||||
logger.debug(f"已设置认证 Cookie: {token[:8]}...")
|
||||
|
||||
logger.info(f"已设置认证 Cookie: {token[:8]}... (secure={is_secure}, samesite=lax, httponly=True, path=/, max_age={COOKIE_MAX_AGE})")
|
||||
logger.debug(f"完整 token 前缀: {token[:20]}...")
|
||||
|
||||
|
||||
def clear_auth_cookie(response: Response) -> None:
|
||||
"""
|
||||
清除认证 Cookie
|
||||
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
"""
|
||||
# 保持与 set_auth_cookie 相同的安全设置
|
||||
is_secure = _is_secure_environment()
|
||||
|
||||
response.delete_cookie(
|
||||
key=COOKIE_NAME,
|
||||
httponly=True,
|
||||
samesite="lax",
|
||||
samesite="strict" if is_secure else "lax",
|
||||
secure=is_secure,
|
||||
path="/",
|
||||
)
|
||||
logger.debug("已清除认证 Cookie")
|
||||
@@ -96,32 +152,32 @@ def verify_auth_token_from_cookie_or_header(
|
||||
) -> bool:
|
||||
"""
|
||||
验证认证 Token,支持从 Cookie 或 Header 获取
|
||||
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
|
||||
Returns:
|
||||
验证成功返回 True
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
|
||||
return True
|
||||
|
||||
@@ -8,18 +8,30 @@
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional, List
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Messages, PersonInfo
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
from src.webui.token_manager import get_token_manager
|
||||
from src.webui.ws_auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
|
||||
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# WebUI 聊天的虚拟群组 ID
|
||||
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
|
||||
WEBUI_CHAT_PLATFORM = "webui"
|
||||
@@ -63,14 +75,14 @@ class ChatHistoryManager:
|
||||
|
||||
def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""将数据库消息转换为前端格式
|
||||
|
||||
|
||||
Args:
|
||||
msg: 数据库消息对象
|
||||
group_id: 群 ID,用于判断是否是虚拟群
|
||||
"""
|
||||
# 判断是否是机器人消息
|
||||
user_id = msg.user_id or ""
|
||||
|
||||
|
||||
# 对于虚拟群,通过比较机器人 QQ 账号来判断
|
||||
# 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头
|
||||
if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
|
||||
@@ -256,6 +268,7 @@ async def get_chat_history(
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
|
||||
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""获取聊天历史记录
|
||||
|
||||
@@ -272,7 +285,7 @@ async def get_chat_history(
|
||||
|
||||
|
||||
@router.get("/platforms")
|
||||
async def get_available_platforms():
|
||||
async def get_available_platforms(_auth: bool = Depends(require_auth)):
|
||||
"""获取可用平台列表
|
||||
|
||||
从 PersonInfo 表中获取所有已知的平台
|
||||
@@ -303,6 +316,7 @@ async def get_persons_by_platform(
|
||||
platform: str = Query(..., description="平台名称"),
|
||||
search: Optional[str] = Query(default=None, description="搜索关键词"),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""获取指定平台的用户列表
|
||||
|
||||
@@ -350,7 +364,7 @@ async def get_persons_by_platform(
|
||||
|
||||
|
||||
@router.delete("/history")
|
||||
async def clear_chat_history(group_id: Optional[str] = Query(default=None)):
|
||||
async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)):
|
||||
"""清空聊天历史记录
|
||||
|
||||
Args:
|
||||
@@ -372,6 +386,7 @@ async def websocket_chat(
|
||||
person_id: Optional[str] = Query(default=None),
|
||||
group_name: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
|
||||
token: Optional[str] = Query(default=None), # 认证 token
|
||||
):
|
||||
"""WebSocket 聊天端点
|
||||
|
||||
@@ -382,9 +397,45 @@ async def websocket_chat(
|
||||
person_id: 虚拟身份模式的用户 person_id(可选)
|
||||
group_name: 虚拟身份模式的群名(可选)
|
||||
group_id: 虚拟身份模式的群 ID(可选,由前端生成并持久化)
|
||||
token: 认证 token(可选,也可从 Cookie 获取)
|
||||
|
||||
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
|
||||
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/api/chat/ws?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
# 生成会话 ID(每次连接都是新的)
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
@@ -414,7 +465,9 @@ async def websocket_chat(
|
||||
group_id=virtual_group_id,
|
||||
group_name=group_name or "WebUI虚拟群聊",
|
||||
)
|
||||
logger.info(f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}")
|
||||
logger.info(
|
||||
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
|
||||
|
||||
@@ -710,7 +763,7 @@ async def websocket_chat(
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_chat_info():
|
||||
async def get_chat_info(_auth: bool = Depends(require_auth)):
|
||||
"""获取聊天室信息"""
|
||||
return {
|
||||
"bot_name": global_config.bot.nickname,
|
||||
|
||||
@@ -4,11 +4,12 @@
|
||||
|
||||
import os
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, HTTPException, Body
|
||||
from typing import Any, Annotated
|
||||
from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header
|
||||
from typing import Any, Annotated, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.toml_utils import save_toml_with_format
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
|
||||
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
||||
from src.config.official_configs import (
|
||||
BotConfig,
|
||||
@@ -29,9 +30,7 @@ from src.config.official_configs import (
|
||||
ToolConfig,
|
||||
MemoryConfig,
|
||||
DebugConfig,
|
||||
MoodConfig,
|
||||
VoiceConfig,
|
||||
JargonConfig,
|
||||
)
|
||||
from src.config.api_ada_configs import (
|
||||
ModelTaskConfig,
|
||||
@@ -51,45 +50,19 @@ PathBody = Annotated[dict[str, str], Body()]
|
||||
router = APIRouter(prefix="/config", tags=["config"])
|
||||
|
||||
|
||||
# ===== 辅助函数 =====
|
||||
|
||||
|
||||
def _update_dict_preserve_comments(target: Any, source: Any) -> None:
|
||||
"""
|
||||
递归合并字典,保留 target 中的注释和格式
|
||||
将 source 的值更新到 target 中(仅更新已存在的键)
|
||||
|
||||
Args:
|
||||
target: 目标字典(tomlkit 对象,包含注释)
|
||||
source: 源字典(普通 dict 或 list)
|
||||
"""
|
||||
# 如果 source 是列表,直接替换(数组表没有注释保留的意义)
|
||||
if isinstance(source, list):
|
||||
return # 调用者需要直接赋值
|
||||
|
||||
# 如果都是字典,递归合并
|
||||
if isinstance(source, dict) and isinstance(target, dict):
|
||||
for key, value in source.items():
|
||||
if key == "version":
|
||||
continue # 跳过版本号
|
||||
if key in target:
|
||||
target_value = target[key]
|
||||
# 递归处理嵌套字典
|
||||
if isinstance(value, dict) and isinstance(target_value, dict):
|
||||
_update_dict_preserve_comments(target_value, value)
|
||||
else:
|
||||
# 使用 tomlkit.item 保持类型
|
||||
try:
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
target[key] = value
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# ===== 架构获取接口 =====
|
||||
|
||||
|
||||
@router.get("/schema/bot")
|
||||
async def get_bot_config_schema():
|
||||
async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
|
||||
"""获取麦麦主程序配置架构"""
|
||||
try:
|
||||
# Config 类包含所有子配置
|
||||
@@ -101,7 +74,7 @@ async def get_bot_config_schema():
|
||||
|
||||
|
||||
@router.get("/schema/model")
|
||||
async def get_model_config_schema():
|
||||
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
|
||||
"""获取模型配置架构(包含提供商和模型任务配置)"""
|
||||
try:
|
||||
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
|
||||
@@ -115,7 +88,7 @@ async def get_model_config_schema():
|
||||
|
||||
|
||||
@router.get("/schema/section/{section_name}")
|
||||
async def get_config_section_schema(section_name: str):
|
||||
async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取指定配置节的架构
|
||||
|
||||
@@ -138,7 +111,6 @@ async def get_config_section_schema(section_name: str):
|
||||
- tool: ToolConfig
|
||||
- memory: MemoryConfig
|
||||
- debug: DebugConfig
|
||||
- mood: MoodConfig
|
||||
- voice: VoiceConfig
|
||||
- jargon: JargonConfig
|
||||
- model_task_config: ModelTaskConfig
|
||||
@@ -164,9 +136,7 @@ async def get_config_section_schema(section_name: str):
|
||||
"tool": ToolConfig,
|
||||
"memory": MemoryConfig,
|
||||
"debug": DebugConfig,
|
||||
"mood": MoodConfig,
|
||||
"voice": VoiceConfig,
|
||||
"jargon": JargonConfig,
|
||||
"model_task_config": ModelTaskConfig,
|
||||
"api_provider": APIProvider,
|
||||
"model_info": ModelInfo,
|
||||
@@ -188,7 +158,7 @@ async def get_config_section_schema(section_name: str):
|
||||
|
||||
|
||||
@router.get("/bot")
|
||||
async def get_bot_config():
|
||||
async def get_bot_config(_auth: bool = Depends(require_auth)):
|
||||
"""获取麦麦主程序配置"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
@@ -207,7 +177,7 @@ async def get_bot_config():
|
||||
|
||||
|
||||
@router.get("/model")
|
||||
async def get_model_config():
|
||||
async def get_model_config(_auth: bool = Depends(require_auth)):
|
||||
"""获取模型配置(包含提供商和模型任务配置)"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
@@ -229,7 +199,7 @@ async def get_model_config():
|
||||
|
||||
|
||||
@router.post("/bot")
|
||||
async def update_bot_config(config_data: ConfigBody):
|
||||
async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新麦麦主程序配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
@@ -238,7 +208,7 @@ async def update_bot_config(config_data: ConfigBody):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件(格式化数组为多行)
|
||||
# 保存配置文件(自动保留注释和格式)
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
@@ -252,7 +222,7 @@ async def update_bot_config(config_data: ConfigBody):
|
||||
|
||||
|
||||
@router.post("/model")
|
||||
async def update_model_config(config_data: ConfigBody):
|
||||
async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新模型配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
@@ -261,7 +231,7 @@ async def update_model_config(config_data: ConfigBody):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件(格式化数组为多行)
|
||||
# 保存配置文件(自动保留注释和格式)
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
@@ -278,7 +248,7 @@ async def update_model_config(config_data: ConfigBody):
|
||||
|
||||
|
||||
@router.post("/bot/section/{section_name}")
|
||||
async def update_bot_config_section(section_name: str, section_data: SectionBody):
|
||||
async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
@@ -300,7 +270,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
|
||||
config_data[section_name] = section_data
|
||||
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
|
||||
# 字典递归合并
|
||||
_update_dict_preserve_comments(config_data[section_name], section_data)
|
||||
_update_toml_doc(config_data[section_name], section_data)
|
||||
else:
|
||||
# 其他类型直接替换
|
||||
config_data[section_name] = section_data
|
||||
@@ -327,7 +297,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
|
||||
|
||||
|
||||
@router.get("/bot/raw")
|
||||
async def get_bot_config_raw():
|
||||
async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
|
||||
"""获取麦麦主程序配置的原始 TOML 内容"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
@@ -346,7 +316,7 @@ async def get_bot_config_raw():
|
||||
|
||||
|
||||
@router.post("/bot/raw")
|
||||
async def update_bot_config_raw(raw_content: RawContentBody):
|
||||
async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
||||
try:
|
||||
# 验证 TOML 格式
|
||||
@@ -376,7 +346,9 @@ async def update_bot_config_raw(raw_content: RawContentBody):
|
||||
|
||||
|
||||
@router.post("/model/section/{section_name}")
|
||||
async def update_model_config_section(section_name: str, section_data: SectionBody):
|
||||
async def update_model_config_section(
|
||||
section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)
|
||||
):
|
||||
"""更新模型配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
@@ -398,7 +370,7 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
|
||||
config_data[section_name] = section_data
|
||||
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
|
||||
# 字典递归合并
|
||||
_update_dict_preserve_comments(config_data[section_name], section_data)
|
||||
_update_toml_doc(config_data[section_name], section_data)
|
||||
else:
|
||||
# 其他类型直接替换
|
||||
config_data[section_name] = section_data
|
||||
@@ -407,6 +379,17 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
|
||||
try:
|
||||
APIAdapterConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
|
||||
# 特殊处理:如果是更新 api_providers,检查是否有模型引用了已删除的provider
|
||||
if section_name == "api_providers" and "api_provider" in str(e):
|
||||
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
|
||||
models = config_data.get("models", [])
|
||||
orphaned_models = [
|
||||
m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
|
||||
]
|
||||
if orphaned_models:
|
||||
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"
|
||||
raise HTTPException(status_code=400, detail=error_msg) from e
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置(格式化数组为多行,保留注释)
|
||||
@@ -457,7 +440,7 @@ def _to_relative_path(path: str) -> str:
|
||||
|
||||
|
||||
@router.get("/adapter-config/path")
|
||||
async def get_adapter_config_path():
|
||||
async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
|
||||
"""获取保存的适配器配置文件路径"""
|
||||
try:
|
||||
# 从 data/webui.json 读取路径偏好
|
||||
@@ -496,7 +479,7 @@ async def get_adapter_config_path():
|
||||
|
||||
|
||||
@router.post("/adapter-config/path")
|
||||
async def save_adapter_config_path(data: PathBody):
|
||||
async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)):
|
||||
"""保存适配器配置文件路径偏好"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
@@ -539,7 +522,7 @@ async def save_adapter_config_path(data: PathBody):
|
||||
|
||||
|
||||
@router.get("/adapter-config")
|
||||
async def get_adapter_config(path: str):
|
||||
async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
|
||||
"""从指定路径读取适配器配置文件"""
|
||||
try:
|
||||
if not path:
|
||||
@@ -571,7 +554,7 @@ async def get_adapter_config(path: str):
|
||||
|
||||
|
||||
@router.post("/adapter-config")
|
||||
async def save_adapter_config(data: PathBody):
|
||||
async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)):
|
||||
"""保存适配器配置到指定路径"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user