Ruff fix
This commit is contained in:
@@ -2,6 +2,7 @@ import math
|
||||
from typing import Dict, List
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
|
||||
class OnlineNaiveBayes:
|
||||
def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
|
||||
self.alpha = alpha
|
||||
@@ -9,9 +10,9 @@ class OnlineNaiveBayes:
|
||||
self.gamma = gamma
|
||||
self.V = vocab_size
|
||||
|
||||
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
|
||||
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
|
||||
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count
|
||||
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
|
||||
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
|
||||
|
||||
def _invalidate(self, cid: str):
|
||||
if cid in self._logZ:
|
||||
@@ -57,4 +58,4 @@ class OnlineNaiveBayes:
|
||||
self.cls_counts[cid] *= g
|
||||
for term in list(self.token_counts[cid].keys()):
|
||||
self.token_counts[cid][term] *= g
|
||||
self._invalidate(cid)
|
||||
self._invalidate(cid)
|
||||
|
||||
Reference in New Issue
Block a user