Skip to content
Snippets Groups Projects
Commit 73b8223f authored by Scott Linderman's avatar Scott Linderman
Browse files

add color code to state usage

parent 2bdecbea
No related branches found
No related tags found
No related merge requests found
......@@ -26,9 +26,9 @@ class FactorAnalysisStates(object):
def __init__(self, model, data, mask=None, **kwargs):
self.model = model
self.X = data
self.mask = mask
if mask is None:
mask = np.ones_like(data, dtype=bool)
self.mask = mask
assert data.shape == mask.shape and mask.dtype == bool
assert self.X.shape[1] == self.D_obs
......@@ -58,8 +58,18 @@ class FactorAnalysisStates(object):
def log_likelihood(self):
mu = np.dot(self.Z, self.W.T)
return -0.5 * np.sum(((self.X - mu) * self.mask) ** 2 / self.sigmasq)
# mu = np.dot(self.Z, self.W.T)
# return -0.5 * np.sum(((self.X - mu) * self.mask) ** 2 / self.sigmasq)
# Compute the marginal likelihood, integrating out z
mu_x = np.zeros(self.D_obs)
Sigma_x = self.W.dot(self.W.T) + np.diag(self.sigmasq)
if not np.all(self.mask):
raise Exception("Need to implement this!")
else:
from scipy.stats import multivariate_normal
return multivariate_normal(mu_x, Sigma_x).logpdf(self.X)
## Gibbs
def resample(self):
......@@ -184,10 +194,16 @@ class _FactorAnalysisBase(Model):
data.Z = Z
if keep:
self.data_list.append(data)
return data
return data.X, data.Z
def _log_likelihoods(self, x, mask=None, **kwargs):
self.add_data(x, mask=mask, **kwargs)
states = self.data_list.pop()
return states.log_likelihood()
def log_likelihood(self):
return np.sum([d.log_likelihood() for d in self.data_list])
return sum([d.log_likelihood().sum() for d in self.data_list])
def log_probability(self):
lp = 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment