From 73b8223fd2c89c29c2dc84c2158bed9b33d2d95d Mon Sep 17 00:00:00 2001 From: Scott Linderman <scott.linderman@gmail.com> Date: Wed, 20 Dec 2017 14:32:10 -0500 Subject: [PATCH] add color code to state usage --- pybasicbayes/models/factor_analysis.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/pybasicbayes/models/factor_analysis.py b/pybasicbayes/models/factor_analysis.py index 69a0624..23aa648 100644 --- a/pybasicbayes/models/factor_analysis.py +++ b/pybasicbayes/models/factor_analysis.py @@ -26,9 +26,9 @@ class FactorAnalysisStates(object): def __init__(self, model, data, mask=None, **kwargs): self.model = model self.X = data - self.mask = mask if mask is None: mask = np.ones_like(data, dtype=bool) + self.mask = mask assert data.shape == mask.shape and mask.dtype == bool assert self.X.shape[1] == self.D_obs @@ -58,8 +58,18 @@ class FactorAnalysisStates(object): def log_likelihood(self): - mu = np.dot(self.Z, self.W.T) - return -0.5 * np.sum(((self.X - mu) * self.mask) ** 2 / self.sigmasq) + # mu = np.dot(self.Z, self.W.T) + # return -0.5 * np.sum(((self.X - mu) * self.mask) ** 2 / self.sigmasq) + + # Compute the marginal likelihood, integrating out z + mu_x = np.zeros(self.D_obs) + Sigma_x = self.W.dot(self.W.T) + np.diag(self.sigmasq) + + if not np.all(self.mask): + raise Exception("Need to implement this!") + else: + from scipy.stats import multivariate_normal + return multivariate_normal(mu_x, Sigma_x).logpdf(self.X) ## Gibbs def resample(self): @@ -184,10 +194,16 @@ class _FactorAnalysisBase(Model): data.Z = Z if keep: self.data_list.append(data) - return data + return data.X, data.Z + + def _log_likelihoods(self, x, mask=None, **kwargs): + self.add_data(x, mask=mask, **kwargs) + states = self.data_list.pop() + return states.log_likelihood() def log_likelihood(self): - return np.sum([d.log_likelihood() for d in self.data_list]) + return sum([d.log_likelihood().sum() for d in self.data_list]) + def log_probability(self): lp = 0 -- GitLab