diff --git a/pybasicbayes/models/factor_analysis.py b/pybasicbayes/models/factor_analysis.py
index 69a0624742be8dfb58e215cb3949c3decf8a3cae..23aa648611b11b7e1a6a0a63c175eff98e01d1eb 100644
--- a/pybasicbayes/models/factor_analysis.py
+++ b/pybasicbayes/models/factor_analysis.py
@@ -26,9 +26,9 @@ class FactorAnalysisStates(object):
     def __init__(self, model, data, mask=None, **kwargs):
         self.model = model
         self.X = data
-        self.mask = mask
         if mask is None:
             mask = np.ones_like(data, dtype=bool)
+        self.mask = mask
         assert data.shape == mask.shape and mask.dtype == bool
         assert self.X.shape[1] == self.D_obs
 
@@ -58,8 +58,18 @@ class FactorAnalysisStates(object):
 
 
     def log_likelihood(self):
-        mu = np.dot(self.Z, self.W.T)
-        return -0.5 * np.sum(((self.X - mu) * self.mask) ** 2 / self.sigmasq)
+        # mu = np.dot(self.Z, self.W.T)
+        # return -0.5 * np.sum(((self.X - mu) * self.mask) ** 2 / self.sigmasq)
+
+        # Compute the marginal likelihood, integrating out z
+        mu_x = np.zeros(self.D_obs)
+        Sigma_x = self.W.dot(self.W.T) + np.diag(self.sigmasq)
+
+        if not np.all(self.mask):
+            raise Exception("Need to implement this!")
+        else:
+            from scipy.stats import multivariate_normal
+            return multivariate_normal(mu_x, Sigma_x).logpdf(self.X)
 
     ## Gibbs
     def resample(self):
@@ -184,10 +194,16 @@ class _FactorAnalysisBase(Model):
         data.Z = Z
         if keep:
             self.data_list.append(data)
-        return data
+        return data.X, data.Z
+
+    def _log_likelihoods(self, x, mask=None, **kwargs):
+        self.add_data(x, mask=mask, **kwargs)
+        states = self.data_list.pop()
+        return states.log_likelihood()
 
     def log_likelihood(self):
-        return np.sum([d.log_likelihood() for d in self.data_list])
+        return sum([d.log_likelihood().sum() for d in self.data_list])
+
 
     def log_probability(self):
         lp = 0