feat:添加ReAct记忆提取系统

This commit is contained in:
SengokuCola
2025-11-09 14:02:29 +08:00
parent d761d42dd7
commit 7a3f260cc3
12 changed files with 1463 additions and 45 deletions

View File

@@ -1,13 +1,14 @@
import time
import json
import asyncio
from typing import List
from typing import List, Dict, Optional
from json_repair import repair_json
from peewee import fn
from src.common.logger import get_logger
from src.common.database.database_model import Jargon
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.config.config import model_config, global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import (
build_anonymous_messages,
@@ -21,28 +22,27 @@ logger = get_logger("jargon")
def _init_prompt() -> None:
prompt_str = """
**聊天内容**
**聊天内容其中的SELF是你自己的发言**
{chat_str}
请从上面这段聊天内容中提取"可能是黑话"的候选项(黑话/俚语/网络缩写/口头禅)。
- 必须为对话中真实出现过的短词或短语
- 必须是你无法理解含义的词语,或者出现频率较高的词语
- 必须是你无法理解含义的词语,没有明确含义的词语
- 请不要选择有明确含义,或者含义清晰的词语
- 必须是这几种类别之一:英文或中文缩写、中文拼音短语、字母数字混合
- 排除:人名、@、明显的表情/图片占位、纯标点、常规功能词(如的、了、呢、啊等)
- 必须是这几种类别之一:英文或中文缩写、中文拼音短语
- 排除:人名、@、表情/图片中的内容、纯标点、常规功能词(如的、了、呢、啊等)
- 每个词条长度建议 2-8 个字符(不强制),尽量短小
- 合并重复项,去重
分类规则,type必须根据规则填写
- p拼音缩写由字母或字母和汉字构成的,用汉语拼音简写词,或汉语拼音首字母的简写词例如nb、yyds、xswl
- c中文缩写中文词语的缩写用几个汉字概括一个词汇或含义例如社死、内卷
- p拼音缩写由字母构成的,汉语拼音首字母的简写词例如nb、yyds、xswl
- e英文缩写英文词语的缩写用英文字母概括一个词汇或含义例如CPU、GPU、API
- x谐音梗谐音梗用谐音词概括一个词汇或含义,例如:好似,难崩
- c中文缩写中文词语的缩写用几个汉字概括一个词汇或含义,例如:社死、内卷
以 JSON 数组输出,元素为对象(严格按以下结构):
[
{{"content": "词条", "raw_content": "包含该词条的完整对话原文", "type": "p"}},
{{"content": "词条2", "raw_content": "包含该词条的完整对话原文", "type": "c"}}
{{"content": "词条", "raw_content": "包含该词条的完整对话上下文原文", "type": "p"}},
{{"content": "词条2", "raw_content": "包含该词条的完整对话上下文原文", "type": "c"}}
]
现在请输出:
@@ -57,7 +57,7 @@ def _init_inference_prompts() -> None:
**词条内容**
{content}
**词条出现的上下文raw_content**
**词条出现的上下文raw_content其中的SELF是你自己的发言**
{raw_content_list}
请根据以上词条内容和上下文,推断这个词条的含义。
@@ -66,8 +66,8 @@ def _init_inference_prompts() -> None:
以 JSON 格式输出:
{{
"meaning": "含义说明",
"translation": "翻译或解释"
"meaning": "详细含义说明(包含使用场景、来源、具体解释等)",
"translation": "原文(用一个词语写明这个词的实际含义)"
}}
"""
Prompt(prompt1_str, "jargon_inference_with_context_prompt")
@@ -83,8 +83,8 @@ def _init_inference_prompts() -> None:
以 JSON 格式输出:
{{
"meaning": "含义说明",
"translation": "翻译或解释"
"meaning": "详细含义说明(包含使用场景、来源、具体解释等)",
"translation": "原文(用一个词语写明这个词的实际含义)"
}}
"""
Prompt(prompt2_str, "jargon_inference_content_only_prompt")
@@ -117,7 +117,7 @@ _init_inference_prompts()
def _should_infer_meaning(jargon_obj: Jargon) -> bool:
"""
判断是否需要进行含义推断
在 count 达到 5, 10, 20, 40, 60, 100 时进行推断
在 count 达到 3,6, 10, 20, 40, 60, 100 时进行推断
并且count必须大于last_inference_count避免重启后重复判定
如果is_complete为True不再进行推断
"""
@@ -128,8 +128,8 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool:
count = jargon_obj.count or 0
last_inference = jargon_obj.last_inference_count or 0
# 阈值列表:5, 10, 20, 40, 60, 100
thresholds = [5, 10, 20, 40, 60, 100]
# 阈值列表:3,6, 10, 20, 40, 60, 100
thresholds = [3,6, 10, 20, 40, 60, 100]
if count < thresholds[0]:
return False
@@ -165,6 +165,11 @@ class JargonMiner:
model_set=model_config.model_task_config.utils,
request_type="jargon.extract",
)
# 初始化stream_name作为类属性避免重复提取
chat_manager = get_chat_manager()
stream_name = chat_manager.get_stream_name(self.chat_id)
self.stream_name = stream_name if stream_name else self.chat_id
async def _infer_meaning_by_id(self, jargon_id: int) -> None:
"""通过ID加载对象并推断"""
@@ -255,12 +260,14 @@ class JargonMiner:
except Exception as e:
logger.error(f"jargon {content} 推断2解析失败: {e}")
return
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
logger.info(f"jargon {content} 推断2结果: {response2}")
# logger.info(f"jargon {content} 推断2结果: {inference2}")
logger.info(f"jargon {content} 推断1提示词: {prompt1}")
logger.info(f"jargon {content} 推断1结果: {response1}")
# logger.info(f"jargon {content} 推断1结果: {inference1}")
if global_config.debug.show_jargon_prompt:
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
logger.info(f"jargon {content} 推断2结果: {response2}")
# logger.info(f"jargon {content} 推断2结果: {inference2}")
logger.info(f"jargon {content} 推断1提示词: {prompt1}")
logger.info(f"jargon {content} 推断1结果: {response1}")
# logger.info(f"jargon {content} 推断1结果: {inference1}")
# 步骤3: 比较两个推断结果
prompt3 = await global_prompt_manager.format_prompt(
@@ -269,7 +276,8 @@ class JargonMiner:
inference2=json.dumps(inference2, ensure_ascii=False),
)
logger.info(f"jargon {content} 比较提示词: {prompt3}")
if global_config.debug.show_jargon_prompt:
logger.info(f"jargon {content} 比较提示词: {prompt3}")
response3, _ = await self.llm.generate_response_async(prompt3, temperature=0.3)
if not response3:
@@ -317,6 +325,20 @@ class JargonMiner:
jargon_obj.save()
logger.info(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}")
# 固定输出推断结果,格式化为可读形式
if is_jargon:
# 是黑话,输出格式:[聊天名]xxx (translation)的含义是 xxxxxxxxxxx
translation = jargon_obj.translation or "未知"
meaning = jargon_obj.meaning or "无详细说明"
is_global = jargon_obj.is_global
if is_global:
logger.info(f"[通用黑话]{content} ({translation})的含义是 {meaning}")
else:
logger.info(f"[{self.stream_name}]{content} ({translation})的含义是 {meaning}")
else:
# 不是黑话,输出格式:[聊天名]xxx 不是黑话
logger.info(f"[{self.stream_name}]{content} 不是黑话")
except Exception as e:
logger.error(f"jargon推断失败: {e}")
import traceback
@@ -371,8 +393,9 @@ class JargonMiner:
if not response:
return
logger.info(f"jargon提取提示词: {prompt}")
logger.info(f"jargon提取结果: {response}")
if global_config.debug.show_jargon_prompt:
logger.info(f"jargon提取提示词: {prompt}")
logger.info(f"jargon提取结果: {response}")
# 解析为JSON
entries: List[dict] = []
@@ -404,6 +427,8 @@ class JargonMiner:
raw_content_list = []
if isinstance(raw_content_value, list):
raw_content_list = [str(rc).strip() for rc in raw_content_value if str(rc).strip()]
# 去重
raw_content_list = list(dict.fromkeys(raw_content_list))
elif isinstance(raw_content_value, str):
raw_content_str = raw_content_value.strip()
if raw_content_str:
@@ -585,10 +610,20 @@ class JargonMiner:
logger.error(f"保存jargon失败: chat_id={self.chat_id}, content={content}, err={e}")
continue
if saved or updated or merged:
logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated}合并为global {merged}chat_id={self.chat_id}")
# 固定输出提取的jargon结果格式化为可读形式只要有提取结果就输出
if uniq_entries:
# 收集所有提取的jargon内容
jargon_list = [entry["content"] for entry in uniq_entries]
jargon_str = ",".join(jargon_list)
# 输出格式化的结果使用logger.info会自动应用jargon模块的颜色
logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}")
# 更新为本次提取的结束时间,确保不会重复提取相同的消息窗口
self.last_learning_time = extraction_end_time
if saved or updated or merged:
logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated}合并为global {merged}chat_id={self.chat_id}")
except Exception as e:
logger.error(f"JargonMiner 运行失败: {e}")
@@ -611,3 +646,88 @@ async def extract_and_store_jargon(chat_id: str) -> None:
await miner.run_once()
def search_jargon(
keyword: str,
chat_id: Optional[str] = None,
limit: int = 10,
case_sensitive: bool = False,
fuzzy: bool = True
) -> List[Dict[str, str]]:
"""
搜索jargon支持大小写不敏感和模糊搜索
Args:
keyword: 搜索关键词
chat_id: 可选的聊天ID如果提供则优先搜索该聊天或global的jargon
limit: 返回结果数量限制默认10
case_sensitive: 是否大小写敏感默认False不敏感
fuzzy: 是否模糊搜索默认True使用LIKE匹配
Returns:
List[Dict[str, str]]: 包含content, translation, meaning的字典列表
"""
if not keyword or not keyword.strip():
return []
keyword = keyword.strip()
# 构建查询
query = Jargon.select(
Jargon.content,
Jargon.translation,
Jargon.meaning
)
# 构建搜索条件
if case_sensitive:
# 大小写敏感
if fuzzy:
# 模糊搜索
search_condition = Jargon.content.contains(keyword)
else:
# 精确匹配
search_condition = (Jargon.content == keyword)
else:
# 大小写不敏感
if fuzzy:
# 模糊搜索使用LOWER函数
search_condition = fn.LOWER(Jargon.content).contains(keyword.lower())
else:
# 精确匹配使用LOWER函数
search_condition = (fn.LOWER(Jargon.content) == keyword.lower())
query = query.where(search_condition)
# 如果提供了chat_id优先搜索该聊天或global的jargon
if chat_id:
query = query.where(
(Jargon.chat_id == chat_id) | Jargon.is_global
)
# 只返回有translation或meaning的记录
query = query.where(
(
(Jargon.translation.is_null(False)) & (Jargon.translation != "")
) | (
(Jargon.meaning.is_null(False)) & (Jargon.meaning != "")
)
)
# 按count降序排序优先返回出现频率高的
query = query.order_by(Jargon.count.desc())
# 限制结果数量
query = query.limit(limit)
# 执行查询并返回结果
results = []
for jargon in query:
results.append({
"content": jargon.content or "",
"translation": jargon.translation or "",
"meaning": jargon.meaning or ""
})
return results