Files
mai-bot/src/learners/jargon_explainer.py
2026-03-24 11:36:26 +08:00

87 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import Optional, Dict, List
from sqlmodel import select, func as fn
import json
from src.common.database.database import get_db_session
from src.common.database.database_model import Jargon
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("jargon_explainer")
def search_jargon(
keyword: str,
chat_id: Optional[str] = None,
limit: int = 10,
case_sensitive: bool = False,
fuzzy: bool = True,
) -> List[Dict[str, str]]:
"""
搜索 jargon支持大小写不敏感和模糊搜索
Args:
keyword: 搜索关键词
chat_id: 可选的聊天 IDsession_id
- 如果开启了 all_global此参数被忽略查询所有 is_global=True 的记录
- 如果关闭了 all_global如果提供则优先搜索该聊天或 global 的 jargon
limit: 返回结果数量限制,默认 10
case_sensitive: 是否大小写敏感,默认 False不敏感
fuzzy: 是否模糊搜索,默认 True使用 LIKE 匹配)
Returns:
List[Dict[str, str]]: 包含 content, meaning 的字典列表
"""
if not keyword or not keyword.strip():
return []
keyword = keyword.strip()
# 构建搜索条件
if case_sensitive: # 大小写敏感
search_condition = Jargon.content.contains(keyword) if fuzzy else Jargon.content == keyword # type: ignore
else:
keyword_lower = keyword.lower()
search_condition = (
fn.LOWER(Jargon.content).contains(keyword_lower) if fuzzy else fn.LOWER(Jargon.content) == keyword_lower
)
# 根据 all_global 配置决定查询逻辑同时,限制结果数量(先多取一些,因为后面可能过滤)
if global_config.expression.all_global_jargon:
# 开启 all_global所有记录都是全局的查询所有 is_global=True 的记录(无视 chat_id
query = select(Jargon).where(search_condition, Jargon.is_global).order_by(Jargon.count.desc()).limit(limit * 2) # type: ignore
else:
# 关闭 all_global查询所有记录chat_id 过滤在 Python 层面进行
query = select(Jargon).where(search_condition).order_by(Jargon.count.desc()).limit(limit * 2) # type: ignore
# 执行查询并返回结果
results: List[Dict[str, str]] = []
with get_db_session() as session:
jargons = session.exec(query).all()
for jargon in jargons:
# 如果提供了 chat_id 且 all_global=False需要检查 session_id_dict 是否包含目标 chat_id
if chat_id and not global_config.expression.all_global_jargon and not jargon.is_global:
try: # 解析 session_id_dict
session_id_dict = json.loads(jargon.session_id_dict) if jargon.session_id_dict else {}
except (json.JSONDecodeError, TypeError):
session_id_dict = {}
logger.warning(
f"解析 session_id_dict 失败jargon_id={jargon.id},原始数据:{jargon.session_id_dict}"
)
# 检查是否包含目标 chat_id
if chat_id not in session_id_dict:
continue
# 只返回有 meaning 的记录
if not jargon.meaning.strip():
continue
results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
# 达到限制数量后停止
if len(results) >= limit:
break
return results