feat:表达方式更新,现在会训练朴素贝叶斯模型来预测使用什么表达
This commit is contained in:
60
src/express/expressor_model/online_nb.py
Normal file
60
src/express/expressor_model/online_nb.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import math
|
||||
from typing import Dict, List
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
class OnlineNaiveBayes:
|
||||
def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.gamma = gamma
|
||||
self.V = vocab_size
|
||||
|
||||
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
|
||||
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count
|
||||
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
|
||||
|
||||
def _invalidate(self, cid: str):
|
||||
if cid in self._logZ:
|
||||
del self._logZ[cid]
|
||||
|
||||
def _logZ_c(self, cid: str) -> float:
|
||||
if cid not in self._logZ:
|
||||
Z = self.cls_counts[cid] + self.V * self.alpha
|
||||
self._logZ[cid] = math.log(max(Z, 1e-12))
|
||||
return self._logZ[cid]
|
||||
|
||||
def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
|
||||
total_cls = sum(self.cls_counts.values())
|
||||
n_cls = max(1, len(self.cls_counts))
|
||||
denom_prior = math.log(total_cls + self.beta * n_cls)
|
||||
|
||||
out: Dict[str, float] = {}
|
||||
for cid in cids:
|
||||
prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
|
||||
s = prior
|
||||
logZ = self._logZ_c(cid)
|
||||
tc = self.token_counts[cid]
|
||||
for term, qtf in tf.items():
|
||||
num = tc.get(term, 0.0) + self.alpha
|
||||
s += qtf * (math.log(num) - logZ)
|
||||
out[cid] = s
|
||||
return out
|
||||
|
||||
def update_positive(self, tf: Counter, cid: str):
|
||||
inc = 0.0
|
||||
tc = self.token_counts[cid]
|
||||
for term, c in tf.items():
|
||||
tc[term] += float(c)
|
||||
inc += float(c)
|
||||
self.cls_counts[cid] += inc
|
||||
self._invalidate(cid)
|
||||
|
||||
def decay(self, factor: float = None):
|
||||
g = self.gamma if factor is None else factor
|
||||
if g >= 1.0:
|
||||
return
|
||||
for cid in list(self.cls_counts.keys()):
|
||||
self.cls_counts[cid] *= g
|
||||
for term in list(self.token_counts[cid].keys()):
|
||||
self.token_counts[cid][term] *= g
|
||||
self._invalidate(cid)
|
||||
Reference in New Issue
Block a user