diff --git a/pybasicbayes/models/factor_analysis.py b/pybasicbayes/models/factor_analysis.py index 69a0624742be8dfb58e215cb3949c3decf8a3cae..23aa648611b11b7e1a6a0a63c175eff98e01d1eb 100644 --- a/pybasicbayes/models/factor_analysis.py +++ b/pybasicbayes/models/factor_analysis.py @@ -26,9 +26,9 @@ class FactorAnalysisStates(object): def __init__(self, model, data, mask=None, **kwargs): self.model = model self.X = data - self.mask = mask if mask is None: mask = np.ones_like(data, dtype=bool) + self.mask = mask assert data.shape == mask.shape and mask.dtype == bool assert self.X.shape[1] == self.D_obs @@ -58,8 +58,18 @@ class FactorAnalysisStates(object): def log_likelihood(self): - mu = np.dot(self.Z, self.W.T) - return -0.5 * np.sum(((self.X - mu) * self.mask) ** 2 / self.sigmasq) + # mu = np.dot(self.Z, self.W.T) + # return -0.5 * np.sum(((self.X - mu) * self.mask) ** 2 / self.sigmasq) + + # Compute the marginal likelihood, integrating out z + mu_x = np.zeros(self.D_obs) + Sigma_x = self.W.dot(self.W.T) + np.diag(self.sigmasq) + + if not np.all(self.mask): + raise Exception("Need to implement this!") + else: + from scipy.stats import multivariate_normal + return multivariate_normal(mu_x, Sigma_x).logpdf(self.X) ## Gibbs def resample(self): @@ -184,10 +194,16 @@ class _FactorAnalysisBase(Model): data.Z = Z if keep: self.data_list.append(data) - return data + return data.X, data.Z + + def _log_likelihoods(self, x, mask=None, **kwargs): + self.add_data(x, mask=mask, **kwargs) + states = self.data_list.pop() + return states.log_likelihood() def log_likelihood(self): - return np.sum([d.log_likelihood() for d in self.data_list]) + return sum([d.log_likelihood().sum() for d in self.data_list]) + def log_probability(self): lp = 0