diff --git a/examples/factor_analysis.py b/examples/factor_analysis.py index 5e3ac803e03d3a667cd10f62a62a682c35a144c0..bbc93d9094b260d18523fa169360f168810e74dc 100644 --- a/examples/factor_analysis.py +++ b/examples/factor_analysis.py @@ -6,7 +6,6 @@ import matplotlib.pyplot as plt from matplotlib.cm import get_cmap import pybasicbayes.models.factor_analysis -reload(pybasicbayes.models.factor_analysis) from pybasicbayes.models.factor_analysis import FactorAnalysis N = 2000 @@ -29,8 +28,8 @@ def generate_synth_data(): # Create a true model and sample from it mask = np.random.rand(N,D_obs) < 0.9 true_model = FactorAnalysis(D_obs, D_latent) - true_data = true_model.generate(N=N, mask=mask, keep=True) - return true_model, true_data + X, Z_true = true_model.generate(N=N, mask=mask, keep=True) + return true_model, X, Z_true, mask def plot_results(lls, angles, Ztrue, Zinf): @@ -69,15 +68,14 @@ def plot_results(lls, angles, Ztrue, Zinf): plt.show() -def gibbs_example(true_model, true_data): - X, mask = true_data.X, true_data.mask - +def gibbs_example(true_model, X, Z_true, mask): # Fit a test model model = FactorAnalysis( D_obs, D_latent, # W=true_model.W, sigmasq=true_model.sigmasq ) inf_data = model.add_data(X, mask=mask) + model.set_empirical_mean() lps = [] angles = [] @@ -87,17 +85,16 @@ def gibbs_example(true_model, true_data): lps.append(model.log_likelihood()) angles.append(principal_angle(true_model.W, model.W)) - plot_results(lps, angles, true_data.Z, inf_data.Z) - -def em_example(true_model, true_data): - X, mask = true_data.X, true_data.mask + plot_results(lps, angles, Z_true, inf_data.Z) +def em_example(true_model, X, Z_true, mask): # Fit a test model model = FactorAnalysis( D_obs, D_latent, # W=true_model.W, sigmasq=true_model.sigmasq ) inf_data = model.add_data(X, mask=mask) + model.set_empirical_mean() lps = [] angles = [] @@ -107,17 +104,16 @@ def em_example(true_model, true_data): lps.append(model.log_likelihood()) 
angles.append(principal_angle(true_model.W, model.W)) - plot_results(lps, angles, true_data.Z, inf_data.E_Z) - -def meanfield_example(true_model, true_data): - X, mask = true_data.X, true_data.mask + plot_results(lps, angles, Z_true, inf_data.E_Z) +def meanfield_example(true_model, X, Z_true, mask): # Fit a test model model = FactorAnalysis( D_obs, D_latent, # W=true_model.W, sigmasq=true_model.sigmasq ) inf_data = model.add_data(X, mask=mask) + model.set_empirical_mean() lps = [] angles = [] @@ -128,11 +124,9 @@ def meanfield_example(true_model, true_data): E_W, _, _, _ = model.regression.mf_expectations angles.append(principal_angle(true_model.W, E_W)) - plot_results(lps, angles, true_data.Z, inf_data.Z) - -def svi_example(true_model, true_data): - X, mask = true_data.X, true_data.mask + plot_results(lps, angles, Z_true, inf_data.Z) +def svi_example(true_model, X, Z_true, mask): # Fit a test model model = FactorAnalysis( D_obs, D_latent, @@ -158,16 +152,15 @@ def svi_example(true_model, true_data): angles.append(principal_angle(true_model.W, E_W)) # Compute the expected states for the first minibatch of data - model.add_data(X[:minibatchsize], mask[:minibatchsize]) + model.add_data(X, mask) statesobj = model.data_list.pop() statesobj.meanfieldupdate() Z_inf = statesobj.E_Z - Z_true = true_data.Z[:minibatchsize] plot_results(lps, angles, Z_true, Z_inf) if __name__ == "__main__": - true_model, true_data = generate_synth_data() - gibbs_example(true_model, true_data) - em_example(true_model, true_data) - meanfield_example(true_model, true_data) - svi_example(true_model, true_data) + true_model, X, Z_true, mask = generate_synth_data() + gibbs_example(true_model, X, Z_true, mask) + em_example(true_model, X, Z_true, mask) + meanfield_example(true_model, X, Z_true, mask) + svi_example(true_model, X, Z_true, mask) diff --git a/pybasicbayes/models/factor_analysis.py b/pybasicbayes/models/factor_analysis.py index 
23aa648611b11b7e1a6a0a63c175eff98e01d1eb..6a8eba488ea5a7a3b921357da1d93d5f35be1f6e 100644 --- a/pybasicbayes/models/factor_analysis.py +++ b/pybasicbayes/models/factor_analysis.py @@ -48,6 +48,10 @@ class FactorAnalysisStates(object): def W(self): return self.model.W + @property + def mean(self): + return self.model.mean + @property def sigmasq(self): return self.model.sigmasq @@ -56,20 +60,31 @@ def regression(self): return self.model.regression - def log_likelihood(self): # mu = np.dot(self.Z, self.W.T) # return -0.5 * np.sum(((self.X - mu) * self.mask) ** 2 / self.sigmasq) # Compute the marginal likelihood, integrating out z - mu_x = np.zeros(self.D_obs) + mu_x = self.mean Sigma_x = self.W.dot(self.W.T) + np.diag(self.sigmasq) + from scipy.stats import multivariate_normal if not np.all(self.mask): - raise Exception("Need to implement this!") + # Find the patterns of missing data + missing_patterns = np.unique(self.mask, axis=0) + + # Evaluate the likelihood for each missing pattern + lls = np.zeros(self.N) + for pat in missing_patterns: + inds = np.all(self.mask == pat, axis=1) + lls[inds] = \ + multivariate_normal(mu_x[pat], Sigma_x[np.ix_(pat, pat)])\ + .logpdf(self.X[np.ix_(inds, pat)]) + else: - from scipy.stats import multivariate_normal - return multivariate_normal(mu_x, Sigma_x).logpdf(self.X) + lls = multivariate_normal(mu_x, Sigma_x).logpdf(self.X) + + return lls ## Gibbs def resample(self): @@ -81,7 +96,7 @@ class FactorAnalysisStates(object): for n in range(self.N): Jobs = self.mask[n] / sigmasq Jpost = J0 + (W * Jobs[:, None]).T.dot(W) - hpost = h0 + (self.X[n] * Jobs).dot(W) + hpost = h0 + ((self.X[n] - self.mean) * Jobs).dot(W) self.Z[n] = sample_gaussian(J=Jpost, h=hpost) ## Mean field @@ -98,6 +113,9 @@ class FactorAnalysisStates(object): E_W, E_WWT, E_sigmasq_inv, _ = self.regression.mf_expectations self._meanfieldupdate(E_W, E_WWT, E_sigmasq_inv) + # Copy over the expected states to Z + self.Z = self.E_Z + def
_meanfieldupdate(self, E_W, E_WWT, E_sigmasq_inv): N, D_obs, D_lat = self.N, self.D_obs, self.D_latent E_WWT_vec = E_WWT.reshape(D_obs, -1) @@ -113,7 +131,7 @@ class FactorAnalysisStates(object): Jobs = self.mask[n] * E_sigmasq_inv # Faster than Jpost = J0 + np.sum(E_WWT * Jobs[:,None,None], axis=0) Jpost = J0 + (np.dot(Jobs, E_WWT_vec)).reshape((D_lat, D_lat)) - hpost = h0 + (self.X[n] * Jobs).dot(E_W) + hpost = h0 + ((self.X[n] - self.mean) * Jobs).dot(E_W) # Get the expectations for this set of indices Sigma_post = np.linalg.inv(Jpost) @@ -124,8 +142,9 @@ class FactorAnalysisStates(object): def _set_expected_stats(self): D_lat = self.D_latent - E_Xsq = np.sum(self.X**2 * self.mask, axis=0) - E_XZT = (self.X * self.mask).T.dot(self.E_Z) + Xc = self.X - self.mean + E_Xsq = np.sum(Xc**2 * self.mask, axis=0) + E_XZT = (Xc * self.mask).T.dot(self.E_Z) E_ZZT_vec = self.E_ZZT.reshape((self.E_ZZT.shape[0], D_lat ** 2)) E_ZZT = np.array([np.dot(self.mask[:, d], E_ZZT_vec).reshape((D_lat, D_lat)) for d in range(self.D_obs)]) @@ -170,6 +189,9 @@ class _FactorAnalysisBase(Model): alpha_0=alpha_0, beta_0=beta_0, A=W, sigmasq=sigmasq) + # Handle the mean separately since DiagonalRegression doesn't support affine :-/ + self.mean = np.zeros(D_obs) + self.data_list = [] @property @@ -180,6 +202,11 @@ class _FactorAnalysisBase(Model): def sigmasq(self): return self.regression.sigmasq_flat + def set_empirical_mean(self): + self.mean = np.zeros(self.D_obs) + for n in range(self.D_obs): + self.mean[n] = np.concatenate([d.X[d.mask[:,n] == 1, n] for d in self.data_list]).mean() + def add_data(self, data, mask=None, **kwargs): self.data_list.append(self._states_class(self, data, mask=mask, **kwargs)) return self.data_list[-1] @@ -188,7 +215,7 @@ class _FactorAnalysisBase(Model): # Sample from the factor analysis model W, sigmasq = self.W, self.sigmasq Z = np.random.randn(N, self.D_latent) - X = np.dot(Z, W.T) + np.sqrt(sigmasq) * np.random.randn(N, self.D_obs) + X = self.mean + 
np.dot(Z, W.T) + np.sqrt(sigmasq) * np.random.randn(N, self.D_obs) data = self._states_class(self, X, mask=mask, **kwargs) data.Z = Z @@ -204,7 +231,6 @@ class _FactorAnalysisBase(Model): def log_likelihood(self): return sum([d.log_likelihood().sum() for d in self.data_list]) - def log_probability(self): lp = 0 @@ -216,7 +242,7 @@ class _FactorAnalysisBase(Model): return lp -class _FactorAnalysisGibbs(ModelGibbsSampling, _FactorAnalysisBase): +class _FactorAnalysisGibbs(_FactorAnalysisBase, ModelGibbsSampling): __metaclass__ = abc.ABCMeta def resample_model(self): @@ -229,7 +255,7 @@ class _FactorAnalysisGibbs(ModelGibbsSampling, _FactorAnalysisBase): self.regression.resample((Zs, Xs), mask=mask) -class _FactorAnalysisEM(ModelEM, _FactorAnalysisBase): +class _FactorAnalysisEM(_FactorAnalysisBase, ModelEM): def _null_stats(self): return objarray( @@ -238,19 +264,16 @@ class _FactorAnalysisEM(ModelEM, _FactorAnalysisBase): np.zeros((self.D_obs, self.D_latent, self.D_latent)), np.zeros(self.D_obs)]) - def log_likelihood(self): - # TODO: Fix inheritance issues... - return np.sum([d.log_likelihood() for d in self.data_list]) - def EM_step(self): for data in self.data_list: data.E_step() stats = self._null_stats() + sum([d.E_emission_stats for d in self.data_list]) self.regression.max_likelihood(data=None, weights=None, stats=stats) + assert np.all(np.isfinite(self.sigmasq )) -class _FactorAnalysisMeanField(ModelMeanField, ModelMeanFieldSVI, _FactorAnalysisBase): +class _FactorAnalysisMeanField(_FactorAnalysisBase, ModelMeanField, ModelMeanFieldSVI): __metaclass__ = abc.ABCMeta def _null_stats(self):