Skip to content
Snippets Groups Projects
Commit 73b8223f authored by Scott Linderman's avatar Scott Linderman
Browse files

add color code to state usage

parent 2bdecbea
No related branches found
No related tags found
No related merge requests found
...@@ -26,9 +26,9 @@ class FactorAnalysisStates(object): ...@@ -26,9 +26,9 @@ class FactorAnalysisStates(object):
def __init__(self, model, data, mask=None, **kwargs): def __init__(self, model, data, mask=None, **kwargs):
self.model = model self.model = model
self.X = data self.X = data
self.mask = mask
if mask is None: if mask is None:
mask = np.ones_like(data, dtype=bool) mask = np.ones_like(data, dtype=bool)
self.mask = mask
assert data.shape == mask.shape and mask.dtype == bool assert data.shape == mask.shape and mask.dtype == bool
assert self.X.shape[1] == self.D_obs assert self.X.shape[1] == self.D_obs
...@@ -58,8 +58,18 @@ class FactorAnalysisStates(object): ...@@ -58,8 +58,18 @@ class FactorAnalysisStates(object):
def log_likelihood(self): def log_likelihood(self):
mu = np.dot(self.Z, self.W.T) # mu = np.dot(self.Z, self.W.T)
return -0.5 * np.sum(((self.X - mu) * self.mask) ** 2 / self.sigmasq) # return -0.5 * np.sum(((self.X - mu) * self.mask) ** 2 / self.sigmasq)
# Compute the marginal likelihood, integrating out z
mu_x = np.zeros(self.D_obs)
Sigma_x = self.W.dot(self.W.T) + np.diag(self.sigmasq)
if not np.all(self.mask):
raise Exception("Need to implement this!")
else:
from scipy.stats import multivariate_normal
return multivariate_normal(mu_x, Sigma_x).logpdf(self.X)
## Gibbs ## Gibbs
def resample(self): def resample(self):
...@@ -184,10 +194,16 @@ class _FactorAnalysisBase(Model): ...@@ -184,10 +194,16 @@ class _FactorAnalysisBase(Model):
data.Z = Z data.Z = Z
if keep: if keep:
self.data_list.append(data) self.data_list.append(data)
return data return data.X, data.Z
def _log_likelihoods(self, x, mask=None, **kwargs):
self.add_data(x, mask=mask, **kwargs)
states = self.data_list.pop()
return states.log_likelihood()
def log_likelihood(self): def log_likelihood(self):
return np.sum([d.log_likelihood() for d in self.data_list]) return sum([d.log_likelihood().sum() for d in self.data_list])
def log_probability(self): def log_probability(self):
lp = 0 lp = 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment