Skip to content
Snippets Groups Projects
Commit 5b02b404 authored by Scott Linderman's avatar Scott Linderman
Browse files

cleaning up factor analysis model and examples

parent c7dff355
No related branches found
No related tags found
No related merge requests found
...@@ -6,7 +6,6 @@ import matplotlib.pyplot as plt ...@@ -6,7 +6,6 @@ import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap from matplotlib.cm import get_cmap
import pybasicbayes.models.factor_analysis import pybasicbayes.models.factor_analysis
reload(pybasicbayes.models.factor_analysis)
from pybasicbayes.models.factor_analysis import FactorAnalysis from pybasicbayes.models.factor_analysis import FactorAnalysis
N = 2000 N = 2000
...@@ -29,8 +28,8 @@ def generate_synth_data(): ...@@ -29,8 +28,8 @@ def generate_synth_data():
# Create a true model and sample from it # Create a true model and sample from it
mask = np.random.rand(N,D_obs) < 0.9 mask = np.random.rand(N,D_obs) < 0.9
true_model = FactorAnalysis(D_obs, D_latent) true_model = FactorAnalysis(D_obs, D_latent)
true_data = true_model.generate(N=N, mask=mask, keep=True) X, Z_true = true_model.generate(N=N, mask=mask, keep=True)
return true_model, true_data return true_model, X, Z_true, mask
def plot_results(lls, angles, Ztrue, Zinf): def plot_results(lls, angles, Ztrue, Zinf):
...@@ -69,15 +68,14 @@ def plot_results(lls, angles, Ztrue, Zinf): ...@@ -69,15 +68,14 @@ def plot_results(lls, angles, Ztrue, Zinf):
plt.show() plt.show()
def gibbs_example(true_model, true_data): def gibbs_example(true_model, X, Z_true, mask):
X, mask = true_data.X, true_data.mask
# Fit a test model # Fit a test model
model = FactorAnalysis( model = FactorAnalysis(
D_obs, D_latent, D_obs, D_latent,
# W=true_model.W, sigmasq=true_model.sigmasq # W=true_model.W, sigmasq=true_model.sigmasq
) )
inf_data = model.add_data(X, mask=mask) inf_data = model.add_data(X, mask=mask)
model.set_empirical_mean()
lps = [] lps = []
angles = [] angles = []
...@@ -87,17 +85,16 @@ def gibbs_example(true_model, true_data): ...@@ -87,17 +85,16 @@ def gibbs_example(true_model, true_data):
lps.append(model.log_likelihood()) lps.append(model.log_likelihood())
angles.append(principal_angle(true_model.W, model.W)) angles.append(principal_angle(true_model.W, model.W))
plot_results(lps, angles, true_data.Z, inf_data.Z) plot_results(lps, angles, Z_true, inf_data.Z)
def em_example(true_model, true_data):
X, mask = true_data.X, true_data.mask
def em_example(true_model, X, Z_true, mask):
# Fit a test model # Fit a test model
model = FactorAnalysis( model = FactorAnalysis(
D_obs, D_latent, D_obs, D_latent,
# W=true_model.W, sigmasq=true_model.sigmasq # W=true_model.W, sigmasq=true_model.sigmasq
) )
inf_data = model.add_data(X, mask=mask) inf_data = model.add_data(X, mask=mask)
model.set_empirical_mean()
lps = [] lps = []
angles = [] angles = []
...@@ -107,17 +104,16 @@ def em_example(true_model, true_data): ...@@ -107,17 +104,16 @@ def em_example(true_model, true_data):
lps.append(model.log_likelihood()) lps.append(model.log_likelihood())
angles.append(principal_angle(true_model.W, model.W)) angles.append(principal_angle(true_model.W, model.W))
plot_results(lps, angles, true_data.Z, inf_data.E_Z) plot_results(lps, angles, Z_true, inf_data.E_Z)
def meanfield_example(true_model, true_data):
X, mask = true_data.X, true_data.mask
def meanfield_example(true_model, X, Z_true, mask):
# Fit a test model # Fit a test model
model = FactorAnalysis( model = FactorAnalysis(
D_obs, D_latent, D_obs, D_latent,
# W=true_model.W, sigmasq=true_model.sigmasq # W=true_model.W, sigmasq=true_model.sigmasq
) )
inf_data = model.add_data(X, mask=mask) inf_data = model.add_data(X, mask=mask)
model.set_empirical_mean()
lps = [] lps = []
angles = [] angles = []
...@@ -128,11 +124,9 @@ def meanfield_example(true_model, true_data): ...@@ -128,11 +124,9 @@ def meanfield_example(true_model, true_data):
E_W, _, _, _ = model.regression.mf_expectations E_W, _, _, _ = model.regression.mf_expectations
angles.append(principal_angle(true_model.W, E_W)) angles.append(principal_angle(true_model.W, E_W))
plot_results(lps, angles, true_data.Z, inf_data.Z) plot_results(lps, angles, Z_true, inf_data.Z)
def svi_example(true_model, true_data):
X, mask = true_data.X, true_data.mask
def svi_example(true_model, X, Z_true, mask):
# Fit a test model # Fit a test model
model = FactorAnalysis( model = FactorAnalysis(
D_obs, D_latent, D_obs, D_latent,
...@@ -158,16 +152,15 @@ def svi_example(true_model, true_data): ...@@ -158,16 +152,15 @@ def svi_example(true_model, true_data):
angles.append(principal_angle(true_model.W, E_W)) angles.append(principal_angle(true_model.W, E_W))
# Compute the expected states for the first minibatch of data # Compute the expected states for the first minibatch of data
model.add_data(X[:minibatchsize], mask[:minibatchsize]) model.add_data(X, mask)
statesobj = model.data_list.pop() statesobj = model.data_list.pop()
statesobj.meanfieldupdate() statesobj.meanfieldupdate()
Z_inf = statesobj.E_Z Z_inf = statesobj.E_Z
Z_true = true_data.Z[:minibatchsize]
plot_results(lps, angles, Z_true, Z_inf) plot_results(lps, angles, Z_true, Z_inf)
if __name__ == "__main__": if __name__ == "__main__":
true_model, true_data = generate_synth_data() true_model, X, Z_true, mask = generate_synth_data()
gibbs_example(true_model, true_data) gibbs_example(true_model, X, Z_true, mask)
em_example(true_model, true_data) em_example(true_model, X, Z_true, mask)
meanfield_example(true_model, true_data) meanfield_example(true_model, X, Z_true, mask)
svi_example(true_model, true_data) svi_example(true_model, X, Z_true, mask)
...@@ -48,6 +48,10 @@ class FactorAnalysisStates(object): ...@@ -48,6 +48,10 @@ class FactorAnalysisStates(object):
def W(self): def W(self):
return self.model.W return self.model.W
@property
def mean(self):
return self.model.mean
@property @property
def sigmasq(self): def sigmasq(self):
return self.model.sigmasq return self.model.sigmasq
...@@ -56,20 +60,31 @@ class FactorAnalysisStates(object): ...@@ -56,20 +60,31 @@ class FactorAnalysisStates(object):
def regression(self): def regression(self):
return self.model.regression return self.model.regression
def log_likelihood(self): def log_likelihood(self):
# mu = np.dot(self.Z, self.W.T) # mu = np.dot(self.Z, self.W.T)
# return -0.5 * np.sum(((self.X - mu) * self.mask) ** 2 / self.sigmasq) # return -0.5 * np.sum(((self.X - mu) * self.mask) ** 2 / self.sigmasq)
# Compute the marginal likelihood, integrating out z # Compute the marginal likelihood, integrating out z
mu_x = np.zeros(self.D_obs) mu_x = self.mean
Sigma_x = self.W.dot(self.W.T) + np.diag(self.sigmasq) Sigma_x = self.W.dot(self.W.T) + np.diag(self.sigmasq)
from scipy.stats import multivariate_normal
if not np.all(self.mask): if not np.all(self.mask):
raise Exception("Need to implement this!") # Find the patterns of missing dta
missing_patterns = np.unique(self.mask, axis=0)
# Evaluate the likelihood for each missing pattern
lls = np.zeros(self.N)
for pat in missing_patterns:
inds = np.all(self.mask == pat, axis=1)
lls[inds] = \
multivariate_normal(mu_x[pat], Sigma_x[np.ix_(pat, pat)])\
.logpdf(self.X[np.ix_(inds, pat)])
else: else:
from scipy.stats import multivariate_normal lls = multivariate_normal(mu_x, Sigma_x).logpdf(self.X)
return multivariate_normal(mu_x, Sigma_x).logpdf(self.X)
return lls
## Gibbs ## Gibbs
def resample(self): def resample(self):
...@@ -81,7 +96,7 @@ class FactorAnalysisStates(object): ...@@ -81,7 +96,7 @@ class FactorAnalysisStates(object):
for n in range(self.N): for n in range(self.N):
Jobs = self.mask[n] / sigmasq Jobs = self.mask[n] / sigmasq
Jpost = J0 + (W * Jobs[:, None]).T.dot(W) Jpost = J0 + (W * Jobs[:, None]).T.dot(W)
hpost = h0 + (self.X[n] * Jobs).dot(W) hpost = h0 + ((self.X[n] - self.mean) * Jobs).dot(W)
self.Z[n] = sample_gaussian(J=Jpost, h=hpost) self.Z[n] = sample_gaussian(J=Jpost, h=hpost)
## Mean field ## Mean field
...@@ -98,6 +113,9 @@ class FactorAnalysisStates(object): ...@@ -98,6 +113,9 @@ class FactorAnalysisStates(object):
E_W, E_WWT, E_sigmasq_inv, _ = self.regression.mf_expectations E_W, E_WWT, E_sigmasq_inv, _ = self.regression.mf_expectations
self._meanfieldupdate(E_W, E_WWT, E_sigmasq_inv) self._meanfieldupdate(E_W, E_WWT, E_sigmasq_inv)
# Copy over the expected states to Z
self.Z = self.E_Z
def _meanfieldupdate(self, E_W, E_WWT, E_sigmasq_inv): def _meanfieldupdate(self, E_W, E_WWT, E_sigmasq_inv):
N, D_obs, D_lat = self.N, self.D_obs, self.D_latent N, D_obs, D_lat = self.N, self.D_obs, self.D_latent
E_WWT_vec = E_WWT.reshape(D_obs, -1) E_WWT_vec = E_WWT.reshape(D_obs, -1)
...@@ -113,7 +131,7 @@ class FactorAnalysisStates(object): ...@@ -113,7 +131,7 @@ class FactorAnalysisStates(object):
Jobs = self.mask[n] * E_sigmasq_inv Jobs = self.mask[n] * E_sigmasq_inv
# Faster than Jpost = J0 + np.sum(E_WWT * Jobs[:,None,None], axis=0) # Faster than Jpost = J0 + np.sum(E_WWT * Jobs[:,None,None], axis=0)
Jpost = J0 + (np.dot(Jobs, E_WWT_vec)).reshape((D_lat, D_lat)) Jpost = J0 + (np.dot(Jobs, E_WWT_vec)).reshape((D_lat, D_lat))
hpost = h0 + (self.X[n] * Jobs).dot(E_W) hpost = h0 + ((self.X[n] - self.mean) * Jobs).dot(E_W)
# Get the expectations for this set of indices # Get the expectations for this set of indices
Sigma_post = np.linalg.inv(Jpost) Sigma_post = np.linalg.inv(Jpost)
...@@ -124,8 +142,9 @@ class FactorAnalysisStates(object): ...@@ -124,8 +142,9 @@ class FactorAnalysisStates(object):
def _set_expected_stats(self): def _set_expected_stats(self):
D_lat = self.D_latent D_lat = self.D_latent
E_Xsq = np.sum(self.X**2 * self.mask, axis=0) Xc = self.X - self.mean
E_XZT = (self.X * self.mask).T.dot(self.E_Z) E_Xsq = np.sum(Xc**2 * self.mask, axis=0)
E_XZT = (Xc * self.mask).T.dot(self.E_Z)
E_ZZT_vec = self.E_ZZT.reshape((self.E_ZZT.shape[0], D_lat ** 2)) E_ZZT_vec = self.E_ZZT.reshape((self.E_ZZT.shape[0], D_lat ** 2))
E_ZZT = np.array([np.dot(self.mask[:, d], E_ZZT_vec).reshape((D_lat, D_lat)) E_ZZT = np.array([np.dot(self.mask[:, d], E_ZZT_vec).reshape((D_lat, D_lat))
for d in range(self.D_obs)]) for d in range(self.D_obs)])
...@@ -170,6 +189,9 @@ class _FactorAnalysisBase(Model): ...@@ -170,6 +189,9 @@ class _FactorAnalysisBase(Model):
alpha_0=alpha_0, beta_0=beta_0, alpha_0=alpha_0, beta_0=beta_0,
A=W, sigmasq=sigmasq) A=W, sigmasq=sigmasq)
# Handle the mean separately since DiagonalRegression doesn't support affine :-/
self.mean = np.zeros(D_obs)
self.data_list = [] self.data_list = []
@property @property
...@@ -180,6 +202,11 @@ class _FactorAnalysisBase(Model): ...@@ -180,6 +202,11 @@ class _FactorAnalysisBase(Model):
def sigmasq(self): def sigmasq(self):
return self.regression.sigmasq_flat return self.regression.sigmasq_flat
def set_empirical_mean(self):
self.mean = np.zeros(self.D_obs)
for n in range(self.D_obs):
self.mean[n] = np.concatenate([d.X[d.mask[:,n] == 1, n] for d in self.data_list]).mean()
def add_data(self, data, mask=None, **kwargs): def add_data(self, data, mask=None, **kwargs):
self.data_list.append(self._states_class(self, data, mask=mask, **kwargs)) self.data_list.append(self._states_class(self, data, mask=mask, **kwargs))
return self.data_list[-1] return self.data_list[-1]
...@@ -188,7 +215,7 @@ class _FactorAnalysisBase(Model): ...@@ -188,7 +215,7 @@ class _FactorAnalysisBase(Model):
# Sample from the factor analysis model # Sample from the factor analysis model
W, sigmasq = self.W, self.sigmasq W, sigmasq = self.W, self.sigmasq
Z = np.random.randn(N, self.D_latent) Z = np.random.randn(N, self.D_latent)
X = np.dot(Z, W.T) + np.sqrt(sigmasq) * np.random.randn(N, self.D_obs) X = self.mean + np.dot(Z, W.T) + np.sqrt(sigmasq) * np.random.randn(N, self.D_obs)
data = self._states_class(self, X, mask=mask, **kwargs) data = self._states_class(self, X, mask=mask, **kwargs)
data.Z = Z data.Z = Z
...@@ -204,7 +231,6 @@ class _FactorAnalysisBase(Model): ...@@ -204,7 +231,6 @@ class _FactorAnalysisBase(Model):
def log_likelihood(self): def log_likelihood(self):
return sum([d.log_likelihood().sum() for d in self.data_list]) return sum([d.log_likelihood().sum() for d in self.data_list])
def log_probability(self): def log_probability(self):
lp = 0 lp = 0
...@@ -216,7 +242,7 @@ class _FactorAnalysisBase(Model): ...@@ -216,7 +242,7 @@ class _FactorAnalysisBase(Model):
return lp return lp
class _FactorAnalysisGibbs(ModelGibbsSampling, _FactorAnalysisBase): class _FactorAnalysisGibbs(_FactorAnalysisBase, ModelGibbsSampling):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def resample_model(self): def resample_model(self):
...@@ -229,7 +255,7 @@ class _FactorAnalysisGibbs(ModelGibbsSampling, _FactorAnalysisBase): ...@@ -229,7 +255,7 @@ class _FactorAnalysisGibbs(ModelGibbsSampling, _FactorAnalysisBase):
self.regression.resample((Zs, Xs), mask=mask) self.regression.resample((Zs, Xs), mask=mask)
class _FactorAnalysisEM(ModelEM, _FactorAnalysisBase): class _FactorAnalysisEM(_FactorAnalysisBase, ModelEM):
def _null_stats(self): def _null_stats(self):
return objarray( return objarray(
...@@ -238,19 +264,16 @@ class _FactorAnalysisEM(ModelEM, _FactorAnalysisBase): ...@@ -238,19 +264,16 @@ class _FactorAnalysisEM(ModelEM, _FactorAnalysisBase):
np.zeros((self.D_obs, self.D_latent, self.D_latent)), np.zeros((self.D_obs, self.D_latent, self.D_latent)),
np.zeros(self.D_obs)]) np.zeros(self.D_obs)])
def log_likelihood(self):
# TODO: Fix inheritance issues...
return np.sum([d.log_likelihood() for d in self.data_list])
def EM_step(self): def EM_step(self):
for data in self.data_list: for data in self.data_list:
data.E_step() data.E_step()
stats = self._null_stats() + sum([d.E_emission_stats for d in self.data_list]) stats = self._null_stats() + sum([d.E_emission_stats for d in self.data_list])
self.regression.max_likelihood(data=None, weights=None, stats=stats) self.regression.max_likelihood(data=None, weights=None, stats=stats)
assert np.all(np.isfinite(self.sigmasq ))
class _FactorAnalysisMeanField(ModelMeanField, ModelMeanFieldSVI, _FactorAnalysisBase): class _FactorAnalysisMeanField(_FactorAnalysisBase, ModelMeanField, ModelMeanFieldSVI):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def _null_stats(self): def _null_stats(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment