131 lines
4.9 KiB
Python
131 lines
4.9 KiB
Python
from typing import Dict, Optional, Tuple, List
|
||
from collections import Counter, defaultdict
|
||
import pickle
|
||
import os
|
||
|
||
from .tokenizer import Tokenizer
|
||
from .online_nb import OnlineNaiveBayes
|
||
|
||
class ExpressorModel:
    """Rank candidate styles directly with an online naive Bayes model.

    Each candidate is identified by a ``cid`` and carries a style text.
    An optional ``situation`` string can be stored per candidate; it takes
    no part in scoring and is only kept alongside the style.
    """

    def __init__(self,
                 alpha: float = 0.5,
                 beta: float = 0.5,
                 gamma: float = 1.0,
                 vocab_size: int = 200000,
                 use_jieba: bool = True):
        """Create the tokenizer, the naive Bayes scorer and empty stores.

        ``alpha``, ``beta``, ``gamma`` and ``vocab_size`` are forwarded
        unchanged to ``OnlineNaiveBayes``; ``use_jieba`` is forwarded to
        ``Tokenizer`` (which is built with an empty stopword set).
        """
        self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
        self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
        self._candidates: Dict[str, str] = {}  # cid -> text (style)
        self._situations: Dict[str, str] = {}  # cid -> situation (not used for scoring)

    def add_candidate(self, cid: str, text: str, situation: Optional[str] = None):
        """Register a candidate text and, optionally, its situation.

        Re-adding an existing ``cid`` overwrites its text; a ``None``
        situation leaves any previously stored situation untouched.
        """
        self._candidates[cid] = text
        if situation is not None:
            self._situations[cid] = situation

        # Make sure the NB model has count entries for this candidate.
        if cid not in self.nb.cls_counts:
            self.nb.cls_counts[cid] = 0.0
        if cid not in self.nb.token_counts:
            self.nb.token_counts[cid] = defaultdict(float)

    def add_candidates_bulk(self,
                            items: List[Tuple[str, str]],
                            situations: Optional[List[str]] = None):
        """Register many ``(cid, text)`` pairs at once.

        ``situations`` is matched to ``items`` by position. Items beyond
        the end of ``situations`` (or all items, when it is omitted or
        empty) are added without a situation.
        """
        for i, (cid, text) in enumerate(items):
            situation = situations[i] if situations and i < len(situations) else None
            self.add_candidate(cid, text, situation)

    def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]:
        """Score every candidate against ``text`` with naive Bayes.

        Returns ``(best_cid, scores)``. Returns ``(None, {})`` when the
        text yields no tokens, when there are no candidates, or when the
        scorer returns nothing.

        ``k`` is accepted for interface compatibility but is not used by
        this method.
        """
        toks = self.tokenizer.tokenize(text)
        if not toks:
            return None, {}

        if not self._candidates:
            return None, {}

        # Score all candidates on the term frequencies of the input.
        tf = Counter(toks)
        all_cids = list(self._candidates.keys())
        scores = self.nb.score_batch(tf, all_cids)

        # Pick the highest-scoring candidate.
        if not scores:
            return None, {}
        best = max(scores.items(), key=lambda kv: kv[1])[0]
        return best, scores

    def update_positive(self, text: str, cid: str):
        """Feed a positive example (``text`` belongs to ``cid``) to the model.

        Does nothing when ``text`` yields no tokens.
        """
        toks = self.tokenizer.tokenize(text)
        if not toks:
            return
        tf = Counter(toks)
        self.nb.update_positive(tf, cid)

    def decay(self, factor: float):
        """Forward ``factor`` to the underlying model's ``decay``."""
        self.nb.decay(factor=factor)

    def get_situation(self, cid: str) -> Optional[str]:
        """Return the situation stored for ``cid``, or ``None``."""
        return self._situations.get(cid)

    def get_style(self, cid: str) -> Optional[str]:
        """Return the style text stored for ``cid``, or ``None``."""
        return self._candidates.get(cid)

    def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
        """Return ``(style, situation)`` for ``cid``; missing parts are ``None``."""
        return self._candidates.get(cid), self._situations.get(cid)

    def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]:
        """Return ``{cid: (style, situation)}`` for every candidate."""
        return {cid: (style, self._situations.get(cid))
                for cid, style in self._candidates.items()}

    def save(self, path: str):
        """Pickle the candidates, situations and NB state to ``path``.

        The parent directory is created when ``path`` has one. A bare
        filename is written to the current working directory.
        """
        # os.makedirs("") raises FileNotFoundError, so only create the
        # directory when the path actually has a directory component.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump({
                "candidates": self._candidates,
                "situations": self._situations,
                "nb": {
                    "cls_counts": dict(self.nb.cls_counts),
                    # Stored as plain dicts; load() rebuilds the nested
                    # defaultdicts from them.
                    "token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
                    "alpha": self.nb.alpha,
                    "beta": self.nb.beta,
                    "gamma": self.nb.gamma,
                    "V": self.nb.V,
                }
            }, f)

    def load(self, path: str):
        """Restore state previously written by :meth:`save`.

        SECURITY: this uses ``pickle.load``, which can execute arbitrary
        code while unpickling. Only load files from a trusted source.
        """
        with open(path, "rb") as f:
            obj = pickle.load(f)
        # Restore candidate texts.
        self._candidates = obj["candidates"]
        # Restore situations (absent in files written by older versions).
        self._situations = obj.get("situations", {})
        # Restore the naive Bayes model.
        # NOTE(review): cls_counts comes back as a plain dict - confirm
        # OnlineNaiveBayes does not rely on defaultdict behaviour for it.
        self.nb.cls_counts = obj["nb"]["cls_counts"]
        self.nb.token_counts = defaultdict_dict(obj["nb"]["token_counts"])
        self.nb.alpha = obj["nb"]["alpha"]
        self.nb.beta = obj["nb"]["beta"]
        self.nb.gamma = obj["nb"]["gamma"]
        self.nb.V = obj["nb"]["V"]
        # Presumably a cache derived from the counts; it is cleared after
        # the counts are replaced.
        self.nb._logZ.clear()
|
||
|
||
def defaultdict_dict(d: Dict[str, Dict[str, float]]):
    """Rebuild a two-level ``defaultdict`` from plain nested dicts.

    The result maps each outer key of ``d`` to a ``defaultdict(float)``
    holding that key's inner items. Unknown outer keys produce a fresh
    ``defaultdict(float)`` on first access.
    """
    from collections import defaultdict

    def _empty_counts():
        return defaultdict(float)

    restored = defaultdict(_empty_counts)
    for outer_key, counts in d.items():
        restored[outer_key] = defaultdict(float, counts)
    return restored