This commit is contained in:
墨梓柒
2025-11-13 13:24:55 +08:00
parent e78a070fbd
commit 7839acd25d
52 changed files with 1322 additions and 1408 deletions

View File

@@ -2,6 +2,7 @@ import math
from typing import Dict, List
from collections import defaultdict, Counter
class OnlineNaiveBayes:
def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
self.alpha = alpha
@@ -9,9 +10,9 @@ class OnlineNaiveBayes:
self.gamma = gamma
self.V = vocab_size
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
def _invalidate(self, cid: str):
if cid in self._logZ:
@@ -57,4 +58,4 @@ class OnlineNaiveBayes:
self.cls_counts[cid] *= g
for term in list(self.token_counts[cid].keys()):
self.token_counts[cid][term] *= g
self._invalidate(cid)
self._invalidate(cid)