diff --git a/examples/factor_analysis.py b/examples/factor_analysis.py
index 5e3ac803e03d3a667cd10f62a62a682c35a144c0..bbc93d9094b260d18523fa169360f168810e74dc 100644
--- a/examples/factor_analysis.py
+++ b/examples/factor_analysis.py
@@ -6,7 +6,6 @@ import matplotlib.pyplot as plt
 from matplotlib.cm import get_cmap
 
 import pybasicbayes.models.factor_analysis
-reload(pybasicbayes.models.factor_analysis)
 from pybasicbayes.models.factor_analysis import FactorAnalysis
 
 N = 2000
@@ -29,8 +28,8 @@ def generate_synth_data():
     # Create a true model and sample from it
     mask = np.random.rand(N,D_obs) < 0.9
     true_model = FactorAnalysis(D_obs, D_latent)
-    true_data = true_model.generate(N=N, mask=mask, keep=True)
-    return true_model, true_data
+    X, Z_true = true_model.generate(N=N, mask=mask, keep=True)
+    return true_model, X, Z_true, mask
 
 
 def plot_results(lls, angles, Ztrue, Zinf):
@@ -69,15 +68,14 @@ def plot_results(lls, angles, Ztrue, Zinf):
 
     plt.show()
 
-def gibbs_example(true_model, true_data):
-    X, mask = true_data.X, true_data.mask
-
+def gibbs_example(true_model, X, Z_true, mask):
     # Fit a test model
     model = FactorAnalysis(
         D_obs, D_latent,
         # W=true_model.W, sigmasq=true_model.sigmasq
         )
     inf_data = model.add_data(X, mask=mask)
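+    # Initialize the model mean from the observed (unmasked) entries of the data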
+    model.set_empirical_mean()
 
     lps = []
     angles = []
@@ -87,17 +85,16 @@ def gibbs_example(true_model, true_data):
         lps.append(model.log_likelihood())
         angles.append(principal_angle(true_model.W, model.W))
 
-    plot_results(lps, angles, true_data.Z, inf_data.Z)
-
-def em_example(true_model, true_data):
-    X, mask = true_data.X, true_data.mask
+    plot_results(lps, angles, Z_true, inf_data.Z)
 
+def em_example(true_model, X, Z_true, mask):
     # Fit a test model
     model = FactorAnalysis(
         D_obs, D_latent,
         # W=true_model.W, sigmasq=true_model.sigmasq
         )
     inf_data = model.add_data(X, mask=mask)
+    model.set_empirical_mean()
 
     lps = []
     angles = []
@@ -107,17 +104,16 @@ def em_example(true_model, true_data):
         lps.append(model.log_likelihood())
         angles.append(principal_angle(true_model.W, model.W))
 
-    plot_results(lps, angles, true_data.Z, inf_data.E_Z)
-
-def meanfield_example(true_model, true_data):
-    X, mask = true_data.X, true_data.mask
+    plot_results(lps, angles, Z_true, inf_data.E_Z)
 
+def meanfield_example(true_model, X, Z_true, mask):
     # Fit a test model
     model = FactorAnalysis(
         D_obs, D_latent,
         # W=true_model.W, sigmasq=true_model.sigmasq
         )
     inf_data = model.add_data(X, mask=mask)
+    model.set_empirical_mean()
 
     lps = []
     angles = []
@@ -128,11 +124,9 @@ def meanfield_example(true_model, true_data):
         E_W, _, _, _ = model.regression.mf_expectations
         angles.append(principal_angle(true_model.W, E_W))
 
-    plot_results(lps, angles, true_data.Z, inf_data.Z)
-
-def svi_example(true_model, true_data):
-    X, mask = true_data.X, true_data.mask
+    plot_results(lps, angles, Z_true, inf_data.Z)
 
+def svi_example(true_model, X, Z_true, mask):
     # Fit a test model
     model = FactorAnalysis(
         D_obs, D_latent,
@@ -158,16 +152,15 @@ def svi_example(true_model, true_data):
         angles.append(principal_angle(true_model.W, E_W))
 
-    # Compute the expected states for the first minibatch of data
-    model.add_data(X[:minibatchsize], mask[:minibatchsize])
+    # Compute the expected states for the full dataset
+    model.add_data(X, mask)
     statesobj = model.data_list.pop()
     statesobj.meanfieldupdate()
     Z_inf = statesobj.E_Z
-    Z_true = true_data.Z[:minibatchsize]
     plot_results(lps, angles, Z_true, Z_inf)
 
 if __name__ == "__main__":
-    true_model, true_data = generate_synth_data()
-    gibbs_example(true_model, true_data)
-    em_example(true_model, true_data)
-    meanfield_example(true_model, true_data)
-    svi_example(true_model, true_data)
+    true_model, X, Z_true, mask = generate_synth_data()
+    gibbs_example(true_model, X, Z_true, mask)
+    em_example(true_model, X, Z_true, mask)
+    meanfield_example(true_model, X, Z_true, mask)
+    svi_example(true_model, X, Z_true, mask)
diff --git a/pybasicbayes/models/factor_analysis.py b/pybasicbayes/models/factor_analysis.py
index 23aa648611b11b7e1a6a0a63c175eff98e01d1eb..6a8eba488ea5a7a3b921357da1d93d5f35be1f6e 100644
--- a/pybasicbayes/models/factor_analysis.py
+++ b/pybasicbayes/models/factor_analysis.py
@@ -48,6 +48,10 @@ class FactorAnalysisStates(object):
     def W(self):
         return self.model.W
 
+    @property
+    def mean(self):
+        return self.model.mean
+
     @property
     def sigmasq(self):
         return self.model.sigmasq
@@ -56,20 +60,31 @@ class FactorAnalysisStates(object):
     def regression(self):
         return self.model.regression
 
-
     def log_likelihood(self):
         # mu = np.dot(self.Z, self.W.T)
         # return -0.5 * np.sum(((self.X - mu) * self.mask) ** 2 / self.sigmasq)
 
         # Compute the marginal likelihood, integrating out z
-        mu_x = np.zeros(self.D_obs)
+        mu_x = self.mean
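+        # With z_n ~ N(0, I) integrated out, x_n ~ N(mean, W W^T + diag(sigmasq))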
         Sigma_x = self.W.dot(self.W.T) + np.diag(self.sigmasq)
 
+        from scipy.stats import multivariate_normal
         if not np.all(self.mask):
-            raise Exception("Need to implement this!")
+            # Find the unique patterns of missing data
+            missing_patterns = np.unique(self.mask, axis=0)
+
+            # Evaluate the likelihood for each missing pattern
+            lls = np.zeros(self.N)
+            for pat in missing_patterns:
+                inds = np.all(self.mask == pat, axis=1)
+                lls[inds] = \
+                    multivariate_normal(mu_x[pat], Sigma_x[np.ix_(pat, pat)])\
+                    .logpdf(self.X[np.ix_(inds, pat)])
+
         else:
-            from scipy.stats import multivariate_normal
-            return multivariate_normal(mu_x, Sigma_x).logpdf(self.X)
+            lls = multivariate_normal(mu_x, Sigma_x).logpdf(self.X)
+
+        return lls
 
     ## Gibbs
     def resample(self):
@@ -81,7 +96,7 @@ class FactorAnalysisStates(object):
         for n in range(self.N):
             Jobs = self.mask[n] / sigmasq
             Jpost = J0 + (W * Jobs[:, None]).T.dot(W)
-            hpost = h0 + (self.X[n] * Jobs).dot(W)
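+            # Center by the model mean, which is handled outside of the regression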
+            hpost = h0 + ((self.X[n] - self.mean) * Jobs).dot(W)
             self.Z[n] = sample_gaussian(J=Jpost, h=hpost)
 
     ## Mean field
@@ -98,6 +113,9 @@ class FactorAnalysisStates(object):
         E_W, E_WWT, E_sigmasq_inv, _ = self.regression.mf_expectations
         self._meanfieldupdate(E_W, E_WWT, E_sigmasq_inv)
 
+        # Use the expected latent states as the point estimate for Z
+        self.Z = self.E_Z
+
     def _meanfieldupdate(self, E_W, E_WWT, E_sigmasq_inv):
         N, D_obs, D_lat = self.N, self.D_obs, self.D_latent
         E_WWT_vec = E_WWT.reshape(D_obs, -1)
@@ -113,7 +131,7 @@ class FactorAnalysisStates(object):
             Jobs = self.mask[n] * E_sigmasq_inv
             # Faster than Jpost = J0 + np.sum(E_WWT * Jobs[:,None,None], axis=0)
             Jpost = J0 + (np.dot(Jobs, E_WWT_vec)).reshape((D_lat, D_lat))
-            hpost = h0 + (self.X[n] * Jobs).dot(E_W)
+            hpost = h0 + ((self.X[n] - self.mean) * Jobs).dot(E_W)
 
             # Get the expectations for this set of indices
             Sigma_post = np.linalg.inv(Jpost)
@@ -124,8 +142,9 @@ class FactorAnalysisStates(object):
 
     def _set_expected_stats(self):
         D_lat = self.D_latent
-        E_Xsq = np.sum(self.X**2 * self.mask, axis=0)
-        E_XZT = (self.X * self.mask).T.dot(self.E_Z)
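+        # Center the data; the mean is handled outside of DiagonalRegression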
+        Xc = self.X - self.mean
+        E_Xsq = np.sum(Xc**2 * self.mask, axis=0)
+        E_XZT = (Xc * self.mask).T.dot(self.E_Z)
         E_ZZT_vec = self.E_ZZT.reshape((self.E_ZZT.shape[0], D_lat ** 2))
         E_ZZT = np.array([np.dot(self.mask[:, d], E_ZZT_vec).reshape((D_lat, D_lat))
                           for d in range(self.D_obs)])
@@ -170,6 +189,9 @@ class _FactorAnalysisBase(Model):
                 alpha_0=alpha_0, beta_0=beta_0,
                 A=W, sigmasq=sigmasq)
 
+        # Handle the mean separately since DiagonalRegression doesn't support affine :-/
+        self.mean = np.zeros(D_obs)
+
         self.data_list = []
 
     @property
@@ -180,6 +202,11 @@ class _FactorAnalysisBase(Model):
     def sigmasq(self):
         return self.regression.sigmasq_flat
 
+    def set_empirical_mean(self):
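+        # Per-dimension empirical mean, using only the observed (unmasked) entries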
+        self.mean = np.zeros(self.D_obs)
+        for n in range(self.D_obs):
+            self.mean[n] = np.concatenate([d.X[d.mask[:,n] == 1, n] for d in self.data_list]).mean()
+
     def add_data(self, data, mask=None, **kwargs):
         self.data_list.append(self._states_class(self, data, mask=mask, **kwargs))
         return self.data_list[-1]
@@ -188,7 +215,7 @@ class _FactorAnalysisBase(Model):
         # Sample from the factor analysis model
         W, sigmasq = self.W, self.sigmasq
         Z = np.random.randn(N, self.D_latent)
-        X = np.dot(Z, W.T) + np.sqrt(sigmasq) * np.random.randn(N, self.D_obs)
+        X = self.mean + np.dot(Z, W.T) + np.sqrt(sigmasq) * np.random.randn(N, self.D_obs)
 
         data = self._states_class(self, X, mask=mask, **kwargs)
         data.Z = Z
@@ -204,7 +231,6 @@ class _FactorAnalysisBase(Model):
     def log_likelihood(self):
         return sum([d.log_likelihood().sum() for d in self.data_list])
 
-
     def log_probability(self):
         lp = 0
 
@@ -216,7 +242,7 @@ class _FactorAnalysisBase(Model):
         return lp
 
 
-class _FactorAnalysisGibbs(ModelGibbsSampling, _FactorAnalysisBase):
+class _FactorAnalysisGibbs(_FactorAnalysisBase, ModelGibbsSampling):
     __metaclass__ = abc.ABCMeta
 
     def resample_model(self):
@@ -229,7 +255,7 @@ class _FactorAnalysisGibbs(ModelGibbsSampling, _FactorAnalysisBase):
         self.regression.resample((Zs, Xs), mask=mask)
 
 
-class _FactorAnalysisEM(ModelEM, _FactorAnalysisBase):
+class _FactorAnalysisEM(_FactorAnalysisBase, ModelEM):
 
     def _null_stats(self):
         return objarray(
@@ -238,19 +264,16 @@ class _FactorAnalysisEM(ModelEM, _FactorAnalysisBase):
              np.zeros((self.D_obs, self.D_latent, self.D_latent)),
              np.zeros(self.D_obs)])
 
-    def log_likelihood(self):
-        # TODO: Fix inheritance issues...
-        return np.sum([d.log_likelihood() for d in self.data_list])
-
     def EM_step(self):
         for data in self.data_list:
             data.E_step()
 
         stats = self._null_stats() + sum([d.E_emission_stats for d in self.data_list])
         self.regression.max_likelihood(data=None, weights=None, stats=stats)
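+        # The M-step should always produce finite noise variances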
+        assert np.all(np.isfinite(self.sigmasq))
 
 
-class _FactorAnalysisMeanField(ModelMeanField, ModelMeanFieldSVI, _FactorAnalysisBase):
+class _FactorAnalysisMeanField(_FactorAnalysisBase, ModelMeanField, ModelMeanFieldSVI):
     __metaclass__ = abc.ABCMeta
 
     def _null_stats(self):