feat:表达方式更新,现在会训练朴素贝叶斯模型来预测使用什么表达
This commit is contained in:
131
src/express/expressor_model/model.py
Normal file
131
src/express/expressor_model/model.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
from collections import Counter, defaultdict
|
||||
import pickle
|
||||
import os
|
||||
|
||||
from .tokenizer import Tokenizer
|
||||
from .online_nb import OnlineNaiveBayes
|
||||
|
||||
class ExpressorModel:
|
||||
"""
|
||||
直接使用朴素贝叶斯精排(可在线学习)
|
||||
支持存储situation字段,不参与计算,仅与style对应
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
alpha: float = 0.5,
|
||||
beta: float = 0.5,
|
||||
gamma: float = 1.0,
|
||||
vocab_size: int = 200000,
|
||||
use_jieba: bool = True):
|
||||
self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
|
||||
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
|
||||
self._candidates: Dict[str, str] = {} # cid -> text (style)
|
||||
self._situations: Dict[str, str] = {} # cid -> situation (不参与计算)
|
||||
|
||||
def add_candidate(self, cid: str, text: str, situation: str = None):
|
||||
"""添加候选文本和对应的situation"""
|
||||
self._candidates[cid] = text
|
||||
if situation is not None:
|
||||
self._situations[cid] = situation
|
||||
|
||||
# 确保在nb模型中初始化该候选的计数
|
||||
if cid not in self.nb.cls_counts:
|
||||
self.nb.cls_counts[cid] = 0.0
|
||||
if cid not in self.nb.token_counts:
|
||||
self.nb.token_counts[cid] = defaultdict(float)
|
||||
|
||||
def add_candidates_bulk(self, items: List[Tuple[str, str]], situations: List[str] = None):
|
||||
"""批量添加候选文本和对应的situations"""
|
||||
for i, (cid, text) in enumerate(items):
|
||||
situation = situations[i] if situations and i < len(situations) else None
|
||||
self.add_candidate(cid, text, situation)
|
||||
|
||||
def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
"""直接对所有候选进行朴素贝叶斯评分"""
|
||||
toks = self.tokenizer.tokenize(text)
|
||||
if not toks:
|
||||
return None, {}
|
||||
|
||||
if not self._candidates:
|
||||
return None, {}
|
||||
|
||||
# 对所有候选进行评分
|
||||
tf = Counter(toks)
|
||||
all_cids = list(self._candidates.keys())
|
||||
scores = self.nb.score_batch(tf, all_cids)
|
||||
|
||||
# 取最高分
|
||||
if not scores:
|
||||
return None, {}
|
||||
best = max(scores.items(), key=lambda x: x[1])[0]
|
||||
return best, scores
|
||||
|
||||
def update_positive(self, text: str, cid: str):
|
||||
"""更新正反馈学习"""
|
||||
toks = self.tokenizer.tokenize(text)
|
||||
if not toks:
|
||||
return
|
||||
tf = Counter(toks)
|
||||
self.nb.update_positive(tf, cid)
|
||||
|
||||
def decay(self, factor: float):
|
||||
self.nb.decay(factor=factor)
|
||||
|
||||
def get_situation(self, cid: str) -> Optional[str]:
|
||||
"""获取候选对应的situation"""
|
||||
return self._situations.get(cid)
|
||||
|
||||
def get_style(self, cid: str) -> Optional[str]:
|
||||
"""获取候选对应的style"""
|
||||
return self._candidates.get(cid)
|
||||
|
||||
def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""获取候选的style和situation信息"""
|
||||
return self._candidates.get(cid), self._situations.get(cid)
|
||||
|
||||
def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]:
|
||||
"""获取所有候选的style和situation信息"""
|
||||
return {cid: (style, self._situations.get(cid))
|
||||
for cid, style in self._candidates.items()}
|
||||
|
||||
def save(self, path: str):
|
||||
"""保存模型"""
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
with open(path, "wb") as f:
|
||||
pickle.dump({
|
||||
"candidates": self._candidates,
|
||||
"situations": self._situations,
|
||||
"nb": {
|
||||
"cls_counts": dict(self.nb.cls_counts),
|
||||
"token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
|
||||
"alpha": self.nb.alpha,
|
||||
"beta": self.nb.beta,
|
||||
"gamma": self.nb.gamma,
|
||||
"V": self.nb.V,
|
||||
}
|
||||
}, f)
|
||||
|
||||
def load(self, path: str):
|
||||
"""加载模型"""
|
||||
with open(path, "rb") as f:
|
||||
obj = pickle.load(f)
|
||||
# 还原候选文本
|
||||
self._candidates = obj["candidates"]
|
||||
# 还原situations(兼容旧版本)
|
||||
self._situations = obj.get("situations", {})
|
||||
# 还原朴素贝叶斯模型
|
||||
self.nb.cls_counts = obj["nb"]["cls_counts"]
|
||||
self.nb.token_counts = defaultdict_dict(obj["nb"]["token_counts"])
|
||||
self.nb.alpha = obj["nb"]["alpha"]
|
||||
self.nb.beta = obj["nb"]["beta"]
|
||||
self.nb.gamma = obj["nb"]["gamma"]
|
||||
self.nb.V = obj["nb"]["V"]
|
||||
self.nb._logZ.clear()
|
||||
|
||||
def defaultdict_dict(d: Dict[str, Dict[str, float]]):
|
||||
from collections import defaultdict
|
||||
outer = defaultdict(lambda: defaultdict(float))
|
||||
for k, inner in d.items():
|
||||
outer[k].update(inner)
|
||||
return outer
|
||||
Reference in New Issue
Block a user