Files
mai-bot/src/express/expressor_model/model.py

131 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import Dict, Optional, Tuple, List
from collections import Counter, defaultdict
import pickle
import os
from .tokenizer import Tokenizer
from .online_nb import OnlineNaiveBayes
class ExpressorModel:
    """Rank candidate styles directly with an online-learnable naive Bayes model.

    Every candidate id (``cid``) maps to a style text and, optionally, to a
    situation string. Situations are only stored next to their style; they
    are never passed to the scorer.
    """

    def __init__(self,
                 alpha: float = 0.5,
                 beta: float = 0.5,
                 gamma: float = 1.0,
                 vocab_size: int = 200000,
                 use_jieba: bool = True):
        """Create an empty model.

        ``alpha``, ``beta``, ``gamma`` and ``vocab_size`` are forwarded
        unchanged to ``OnlineNaiveBayes``; ``use_jieba`` is forwarded to
        ``Tokenizer`` (which is created with an empty stopword set).
        """
        self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
        self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
        self._candidates: Dict[str, str] = {}  # cid -> text (style)
        self._situations: Dict[str, str] = {}  # cid -> situation (not used in scoring)

    def add_candidate(self, cid: str, text: str, situation: Optional[str] = None) -> None:
        """Register (or overwrite) a candidate style and its optional situation.

        Args:
            cid: Candidate identifier.
            text: Style text stored for the candidate.
            situation: Situation stored for the candidate. When ``None``,
                any situation already stored for ``cid`` is left as is.
        """
        self._candidates[cid] = text
        if situation is not None:
            self._situations[cid] = situation
        # Make sure the naive Bayes counters exist for this candidate, without
        # resetting counts that were already accumulated.
        if cid not in self.nb.cls_counts:
            self.nb.cls_counts[cid] = 0.0
        if cid not in self.nb.token_counts:
            self.nb.token_counts[cid] = defaultdict(float)

    def add_candidates_bulk(self,
                            items: List[Tuple[str, str]],
                            situations: Optional[List[str]] = None) -> None:
        """Register several candidates at once.

        Args:
            items: ``(cid, text)`` pairs.
            situations: Situations matched to ``items`` by position. Items
                beyond the end of this list (or all items when it is
                ``None``/empty) are added without a situation.
        """
        for i, (cid, text) in enumerate(items):
            situation = situations[i] if situations and i < len(situations) else None
            self.add_candidate(cid, text, situation)

    def predict(self, text: str, k: Optional[int] = None) -> Tuple[Optional[str], Dict[str, float]]:
        """Score every registered candidate against ``text``.

        Args:
            text: Input text; it is tokenized with ``self.tokenizer``.
            k: Not used by this implementation.

        Returns:
            ``(best_cid, scores)`` where ``scores`` is whatever
            ``self.nb.score_batch`` returned for all candidate ids and
            ``best_cid`` is the key with the highest score. Returns
            ``(None, {})`` when the text yields no tokens, when there are no
            candidates, or when the scorer returns nothing.
        """
        toks = self.tokenizer.tokenize(text)
        if not toks:
            return None, {}
        if not self._candidates:
            return None, {}
        # Score all candidates using the token frequencies of the input.
        tf = Counter(toks)
        all_cids = list(self._candidates.keys())
        scores = self.nb.score_batch(tf, all_cids)
        if not scores:
            return None, {}
        # Highest-scoring candidate wins.
        best = max(scores.items(), key=lambda x: x[1])[0]
        return best, scores

    def update_positive(self, text: str, cid: str) -> None:
        """Apply a positive-feedback update associating ``text`` with ``cid``.

        Does nothing when ``text`` yields no tokens.
        """
        toks = self.tokenizer.tokenize(text)
        if not toks:
            return
        tf = Counter(toks)
        self.nb.update_positive(tf, cid)

    def decay(self, factor: float) -> None:
        """Forward ``factor`` to the naive Bayes model's ``decay``."""
        self.nb.decay(factor=factor)

    def get_situation(self, cid: str) -> Optional[str]:
        """Return the situation stored for ``cid``, or ``None`` if absent."""
        return self._situations.get(cid)

    def get_style(self, cid: str) -> Optional[str]:
        """Return the style text stored for ``cid``, or ``None`` if absent."""
        return self._candidates.get(cid)

    def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
        """Return ``(style, situation)`` for ``cid``; missing parts are ``None``."""
        return self._candidates.get(cid), self._situations.get(cid)

    def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]:
        """Return ``{cid: (style, situation_or_None)}`` for every candidate."""
        return {cid: (style, self._situations.get(cid))
                for cid, style in self._candidates.items()}

    def save(self, path: str) -> None:
        """Pickle the candidates, situations and naive Bayes state to ``path``.

        The parent directory is created when ``path`` has one. A bare file
        name is written to the current working directory.
        """
        parent = os.path.dirname(path)
        # os.makedirs("") raises FileNotFoundError, so only create the parent
        # when the path actually contains a directory component.
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump({
                "candidates": self._candidates,
                "situations": self._situations,
                "nb": {
                    # Stored as plain dicts so the pickle does not depend on
                    # the container types used at runtime.
                    "cls_counts": dict(self.nb.cls_counts),
                    "token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
                    "alpha": self.nb.alpha,
                    "beta": self.nb.beta,
                    "gamma": self.nb.gamma,
                    "V": self.nb.V,
                }
            }, f)

    def load(self, path: str) -> None:
        """Restore state previously written by ``save`` from ``path``.

        Raises:
            KeyError: If the file lacks the ``"candidates"`` or ``"nb"``
                entries (``"situations"`` is optional).
        """
        # SECURITY: pickle.load can execute arbitrary code; only load files
        # that this application wrote itself, never untrusted input.
        with open(path, "rb") as f:
            obj = pickle.load(f)
        # Restore candidate texts.
        self._candidates = obj["candidates"]
        # Restore situations; files written by older versions have none.
        self._situations = obj.get("situations", {})
        # Restore the naive Bayes model.
        # NOTE(review): cls_counts comes back as a plain dict - confirm that
        # OnlineNaiveBayes does not rely on it being a defaultdict.
        self.nb.cls_counts = obj["nb"]["cls_counts"]
        self.nb.token_counts = defaultdict_dict(obj["nb"]["token_counts"])
        self.nb.alpha = obj["nb"]["alpha"]
        self.nb.beta = obj["nb"]["beta"]
        self.nb.gamma = obj["nb"]["gamma"]
        self.nb.V = obj["nb"]["V"]
        # Clear the model's _logZ container, which save() does not persist.
        self.nb._logZ.clear()
def defaultdict_dict(d: Dict[str, Dict[str, float]]):
    """Rebuild a two-level ``defaultdict`` from plain nested dicts.

    The result maps each outer key of ``d`` to a ``defaultdict(float)``
    holding that key's entries. Looking up an unknown outer key creates an
    empty ``defaultdict(float)``, and an unknown inner key yields ``0.0``.
    """
    from collections import defaultdict

    def _empty_counts():
        # Factory for outer keys that are not present yet.
        return defaultdict(float)

    restored = {key: defaultdict(float, counts) for key, counts in d.items()}
    return defaultdict(_empty_counts, restored)