diff --git a/pybasicbayes/distributions/regression.py b/pybasicbayes/distributions/regression.py index 9e540272ab1e863718929a5966a3281698277bcf..5ae35abadf178c13227fdd7187b9e1c09100b46a 100644 --- a/pybasicbayes/distributions/regression.py +++ b/pybasicbayes/distributions/regression.py @@ -564,6 +564,10 @@ class DiagonalRegression(Regression, MeanFieldSVI): # Cache the standard parameters for A as well self._mf_A_cache = {} + # Store the natural hypparams. These correspond to the suff. stats + # (y^2, yxT, xxT, n) + # self.natural_hypparam = (2 * self.beta_0, self.h_0, self.J_0, 1.0) + @property def D_out(self): return self._D_out @@ -761,7 +765,6 @@ class DiagonalRegression(Regression, MeanFieldSVI): ysq, yxT, xxT, n = stats - assert np.all(n > 0), "Cannot perform max likelihood with zero data points!" self.A = np.array([ np.linalg.solve(self.J_0 + xxTd, self.h_0 + yxTd) for xxTd, yxTd in zip(xxT, yxT) @@ -787,6 +790,11 @@ class DiagonalRegression(Regression, MeanFieldSVI): self._meanfieldupdate_A(stats) self._meanfieldupdate_sigma(stats) + # Update A and sigmasq_flat + A, _, sigmasq_inv, _ = self.mf_expectations + self.A = A.copy() + self.sigmasq_flat = 1. / sigmasq_inv + def _meanfieldupdate_A(self, stats, prob=1.0, stepsize=1.0): E_sigmasq_inv = self.mf_alpha / self.mf_beta _, E_yxT, E_xxT, _ = stats / prob @@ -890,6 +898,8 @@ class DiagonalRegression(Regression, MeanFieldSVI): self._meanfieldupdate_sigma(stats, prob=prob, stepsize=stepsize) + + class _ARMixin(object): @property def nlags(self): diff --git a/pybasicbayes/models/factor_analysis.py b/pybasicbayes/models/factor_analysis.py index 0fa11857b16f01d0f4f12123e80115260686cc93..69a0624742be8dfb58e215cb3949c3decf8a3cae 100644 --- a/pybasicbayes/models/factor_analysis.py +++ b/pybasicbayes/models/factor_analysis.py @@ -23,7 +23,7 @@ class FactorAnalysisStates(object): """ Wrapper for the latent states of a factor analysis model """ - def __init__(self, model, data, mask=None): + def __init__(self, model, data, mask=None, **kwargs): self.model = model self.X = data self.mask = mask @@ -142,6 +142,7 @@ class FactorAnalysisStates(object): class _FactorAnalysisBase(Model): __metaclass__ = abc.ABCMeta + _states_class = FactorAnalysisStates def __init__(self, D_obs, D_latent, W=None, sigmasq=None, @@ -169,8 +170,8 @@ class _FactorAnalysisBase(Model): def sigmasq(self): return self.regression.sigmasq_flat - def add_data(self, data, mask=None): - self.data_list.append(FactorAnalysisStates(self, data, mask=mask)) + def add_data(self, data, mask=None, **kwargs): + self.data_list.append(self._states_class(self, data, mask=mask, **kwargs)) return self.data_list[-1] def generate(self, keep=True, N=1, mask=None, **kwargs): @@ -179,7 +180,7 @@ class _FactorAnalysisBase(Model): Z = np.random.randn(N, self.D_latent) X = np.dot(Z, W.T) + np.sqrt(sigmasq) * np.random.randn(N, self.D_obs) - data = FactorAnalysisStates(self, X, mask=mask) + data = self._states_class(self, X, mask=mask, **kwargs) data.Z = Z if keep: self.data_list.append(data) diff --git a/pybasicbayes/models/mixture.py b/pybasicbayes/models/mixture.py index 1675f01299e1d024cd4f2228d7d3dfdb4d82b7ca..43461c4a7c78ae06921d73ac8c2d4f188700311d 100644 --- a/pybasicbayes/models/mixture.py +++ b/pybasicbayes/models/mixture.py @@ -477,7 +477,7 @@ class Mixture(ModelGibbsSampling, ModelMeanField, ModelEM, ModelParallelTemperin [l.expectations[:,idx] for l in self.labels_list]) # mixture weights - self.weights.max_likelihood(np.arange(len(self.components)), + self.weights.max_likelihood(None, [l.expectations for l in self.labels_list]) @property