feat:表达方式更新,现在会训练朴素贝叶斯模型来预测使用什么表达

This commit is contained in:
SengokuCola
2025-10-11 02:03:03 +08:00
parent 400296ade1
commit 958d6e04ee
20 changed files with 2372 additions and 443 deletions

View File

@@ -0,0 +1,60 @@
import math
from typing import Dict, List
from collections import defaultdict, Counter
class OnlineNaiveBayes:
def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.V = vocab_size
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
def _invalidate(self, cid: str):
if cid in self._logZ:
del self._logZ[cid]
def _logZ_c(self, cid: str) -> float:
if cid not in self._logZ:
Z = self.cls_counts[cid] + self.V * self.alpha
self._logZ[cid] = math.log(max(Z, 1e-12))
return self._logZ[cid]
def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
total_cls = sum(self.cls_counts.values())
n_cls = max(1, len(self.cls_counts))
denom_prior = math.log(total_cls + self.beta * n_cls)
out: Dict[str, float] = {}
for cid in cids:
prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
s = prior
logZ = self._logZ_c(cid)
tc = self.token_counts[cid]
for term, qtf in tf.items():
num = tc.get(term, 0.0) + self.alpha
s += qtf * (math.log(num) - logZ)
out[cid] = s
return out
def update_positive(self, tf: Counter, cid: str):
inc = 0.0
tc = self.token_counts[cid]
for term, c in tf.items():
tc[term] += float(c)
inc += float(c)
self.cls_counts[cid] += inc
self._invalidate(cid)
def decay(self, factor: float = None):
g = self.gamma if factor is None else factor
if g >= 1.0:
return
for cid in list(self.cls_counts.keys()):
self.cls_counts[cid] *= g
for term in list(self.token_counts[cid].keys()):
self.token_counts[cid][term] *= g
self._invalidate(cid)