diff --git a/pybasicbayes/distributions/multinomial.py b/pybasicbayes/distributions/multinomial.py index f779c84ae757fc74bf01a29da14d55646cc6a1de..e13624fe83aee65bcdbd89a3fd3ab6a8600d7512 100644 --- a/pybasicbayes/distributions/multinomial.py +++ b/pybasicbayes/distributions/multinomial.py @@ -111,13 +111,13 @@ class Categorical(GibbsSampling, MeanField, MeanFieldSVI, MaxLikelihood, MAP): self.weights = np.random.dirichlet(self.alphav_0 + counts) except ZeroDivisionError as e: # print("ZeroDivisionError {}".format(e)) - self.weights = np.random.dirichlet(self.alphav_0 + 0.03 + counts) + self.weights = np.random.dirichlet(self.alphav_0 + 0.01 + counts) except ValueError as e: # print("ValueError {}".format(e)) - self.weights = np.random.dirichlet(self.alphav_0 + 0.03 + counts) + self.weights = np.random.dirichlet(self.alphav_0 + 0.01 + counts) if np.isnan(self.weights).any(): - self.weights = np.random.dirichlet(self.alphav_0 + 0.03 + counts) - np.clip(self.weights, np.spacing(1.), np.inf, out=self.weights) + self.weights = np.random.dirichlet(self.alphav_0 + 0.01 + counts) + np.clip(self.weights, np.spacing(1.), 1-np.spacing(1.), out=self.weights) # NOTE: next line is so we can use Gibbs sampling to initialize mean field self._alpha_mf = self.weights * self.alphav_0.sum() assert (self._alpha_mf >= 0.).all()