Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • sbruijns/ihmm_behav_states
1 result
Show changes
Commits on Source (5)
Showing
with 339 additions and 213 deletions
No preview for this file type
No preview for this file type
This diff is collapsed.
......@@ -132,6 +132,29 @@ def find_good_chains_unsplit(chains1, chains2, chains3, chains4, reduce_to=8, si
return sol, r_hat_min
def state_glm_func_helper(t, mode_specific=False):
if not mode_specific:
def temp_state_glm_func(x):
glms = np.zeros((x.n_samples, 4))
args = np.argsort(x.assign_counts, 1)[:, -1 - t]
for i, (m, s, n) in enumerate(zip(x.models, args, x.assign_counts[np.arange(x.assign_counts.shape[0]), args])):
for j, seq in enumerate(m.stateseqs):
if s in seq:
glms[i] += np.sum(s == seq) * m.obs_distns[s].weights[j]
glms[i] /= n
return glms
else:
def temp_state_glm_func(x, ind):
glms = np.zeros((x.n_samples, 4))
np.argsort(x.assign_counts)[:-1 - n]
for i, (m, s, n) in enumerate(zip(x.models, x.assign_counts.argmax(1), x.assign_counts.max(1))):
for j, seq in enumerate(m.stateseqs):
if s in seq:
glms[i] += np.sum(s == seq) * m.obs_distns[s].weights[j]
glms[i] /= n
return glms[ind]
return temp_state_glm_func
def params_to_pmf(params):
return params[2] + (1 - params[2] - params[3]) / (1 + np.exp(- params[0] * (all_conts - params[1])))
......
......@@ -42,10 +42,10 @@ for subject in subjects:
from_session = info_dict['bias_start'] if fit_type == 'bias' else 0
models = []
n_inputs = 5
n_regressors = 5
T = till_session - from_session + (fit_type != 'prebias')
obs_hypparams = {'n_inputs': n_inputs, 'T': T, 'prior_mean': np.zeros(n_inputs),
'P_0': 2 * np.eye(n_inputs), 'Q': fit_variance * np.tile(np.eye(n_inputs), (T, 1, 1))}
obs_hypparams = {'n_regressors': n_regressors, 'T': T, 'prior_mean': np.zeros(n_regressors), 'jumplimit': 3,
'P_0': 2 * np.eye(n_regressors), 'Q': fit_variance * np.tile(np.eye(n_regressors), (T, 1, 1))}
dur_hypparams = dict(r_support=np.array([1, 2, 3, 5, 7, 10, 15, 21, 28, 36, 45, 55, 150]),
r_probs=np.ones(13)/13., alpha_0=1, beta_0=1)
......@@ -78,7 +78,7 @@ for subject in subjects:
bad_trials = data[:, 1] == 1
bad_trials[0] = True
mega_data = np.empty((np.sum(~bad_trials), n_inputs + 1))
mega_data = np.empty((np.sum(~bad_trials), n_regressors + 1))
mega_data[:, 0] = np.maximum(data[~bad_trials, 0], 0)
mega_data[:, 1] = np.abs(np.minimum(data[~bad_trials, 0], 0))
......@@ -123,7 +123,10 @@ for subject in subjects:
prev_res = pickle.load(open(save_title, 'rb'))
counter = 0
for p, m in zip(prev_res, models):
print(counter)
counter += 1
for od, nd in zip(p.obs_distns, m.obs_distns):
assert np.allclose(od.weights, nd.weights)
prof._prof.print_stats()
......
......@@ -204,14 +204,14 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
models = []
if params['obs_dur'] == 'glm':
n_inputs = len(params['regressors'])
n_regressors = len(params['regressors'])
T = till_session - from_session + (params['fit_type'] != 'prebias')
obs_hypparams = {'n_inputs': n_inputs, 'T': T, 'jumplimit': params['jumplimit'], 'prior_mean': params['init_mean'],
'P_0': params['init_var'] * np.eye(n_inputs), 'Q': params['fit_variance'] * np.tile(np.eye(n_inputs), (T, 1, 1))}
obs_hypparams = {'n_regressors': n_regressors, 'T': T, 'jumplimit': params['jumplimit'], 'prior_mean': params['init_mean'],
'P_0': params['init_var'] * np.eye(n_regressors), 'Q': params['fit_variance'] * np.tile(np.eye(n_regressors), (T, 1, 1))}
obs_distns = [distributions.Dynamic_GLM(**obs_hypparams) for state in range(params['n_states'])]
else:
n_inputs = 9 if params['fit_type'] == 'bias' else 11
obs_hypparams = {'n_inputs': n_inputs * (1 + (params['conditioned_on'] != 'nothing')), 'n_outputs': 2, 'T': till_session - from_session + (params['fit_type'] != 'prebias'),
n_regressors = 9 if params['fit_type'] == 'bias' else 11
obs_hypparams = {'n_regressors': n_regressors * (1 + (params['conditioned_on'] != 'nothing')), 'n_outputs': 2, 'T': till_session - from_session + (params['fit_type'] != 'prebias'),
'jumplimit': params['jumplimit'], 'sigmasq_states': params['fit_variance']}
obs_distns = [Dynamic_Input_Categorical(**obs_hypparams) for state in range(params['n_states'])]
......@@ -280,7 +280,7 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
mask[0] = False
if params['fit_type'] == 'zoe_style':
mask[90:] = False
mega_data = np.empty((np.sum(mask), n_inputs + 1))
mega_data = np.empty((np.sum(mask), n_regressors + 1))
for i, reg in enumerate(params['regressors']):
# positive numbers are contrast on the right
......
......@@ -8,6 +8,8 @@ import pickle
def state_size_helper(n=0, mode_specific=False):
"""Returns a function that returns the # of trials associated to the nth largest state in a sample
can be further specified to only look at specific samples, those of a mode"""
if not mode_specific:
def nth_largest_state_func(x):
return np.partition(x.assign_counts, -1 - n, axis=1)[:, -1 - n]
......@@ -18,6 +20,8 @@ def state_size_helper(n=0, mode_specific=False):
def state_num_helper(t, mode_specific=False):
"""Returns a function that returns the # of states which have more trials than a percentage threshold t in a sample
can be further specified to only look at specific samples, those of a mode"""
if not mode_specific:
def state_num_func(x): return ((x.assign_counts / x.n_datapoints) > t).sum(1)
else:
......@@ -35,14 +39,15 @@ def ll_func(x): return x.sample_lls[-x.n_samples:]
def r_hat_array_comp(chains):
m, n = chains.shape
"""Computes R^hat on an array of features, following Gelman p. 284f"""
m, n = chains.shape # number of chains, length of chains
psi_dot_j = np.mean(chains, axis=1)
psi_dot_dot = np.mean(psi_dot_j)
B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot) ** 2)
s_j_squared = np.sum((chains - psi_dot_j[:, None]) ** 2, axis=1) / (n - 1)
W = np.mean(s_j_squared)
var_hat_plus = (n - 1) / n * W + B / n
if W == 0:
if W == 0: # sometimes a feature has 0 variance
# print("all the same value")
return 1, 0
r_hat = np.sqrt(var_hat_plus / W)
......@@ -50,6 +55,7 @@ def r_hat_array_comp(chains):
def eval_amortized_r_hat(chains, psi_dot_j, s_j_squared, m, n):
"""Unused version in which some things were computed ahead of function to save time."""
psi_dot_dot = np.mean(psi_dot_j, axis=1)
B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot[:, None]) ** 2, axis=1)
W = np.mean(s_j_squared, axis=1)
......@@ -59,6 +65,7 @@ def eval_amortized_r_hat(chains, psi_dot_j, s_j_squared, m, n):
def r_hat_array_comp_mult(chains):
"""Compute R^hat of multiple features at once."""
_, m, n = chains.shape
psi_dot_j = np.mean(chains, axis=2)
psi_dot_dot = np.mean(psi_dot_j, axis=1)
......@@ -71,8 +78,8 @@ def r_hat_array_comp_mult(chains):
def rank_inv_normal_transform(chains):
# Gelman paper Rank-normalization, folding, and localization: An improved R_hat for assessing convergence of MCMC
# ranking with average rank for ties
"""Gelman paper Rank-normalization, folding, and localization: An improved R_hat for assessing convergence of MCMC
ranking with average rank for ties"""
folded_chains = np.abs(chains - np.median(chains))
ranked = rankdata(chains).reshape(chains.shape)
folded_ranked = rankdata(folded_chains).reshape(folded_chains.shape)
......@@ -83,6 +90,8 @@ def rank_inv_normal_transform(chains):
def eval_r_hat(chains):
"""Compute entire set of R^hat's for list of feature arrays, and return maximum across features.
Computes all R^hat versions, as opposed to eval_simple_r_hat"""
r_hats = []
for chain in chains:
rank_normalised, folded_rank_normalised, _, _ = rank_inv_normal_transform(chain)
......@@ -92,6 +101,8 @@ def eval_r_hat(chains):
def eval_simple_r_hat(chains):
"""Compute just simple R^hat's for list of feature arrays, and return maximum across features.
Computes only the simple type of R^hat, no folding or rank normalising, making it much faster"""
r_hats, _ = r_hat_array_comp_mult(chains)
return max(r_hats)
......@@ -103,21 +114,21 @@ def comp_multi_r_hat(chains, rank_normalised, folded_rank_normalised):
return max(lame_r_hat, rank_normalised_r_hat, folded_rank_normalised_r_hat)
def sample_statistics(test, mode_indices, subject):
def sample_statistics(mode_indices, subject, period='prebias'):
# prints out r_hats and sample sizes for given sample
test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, period), 'rb'))
test.r_hat_and_ess(state_size_helper(1), False)
test.r_hat_and_ess(state_size_helper(1, mode_specific=True), False, mode_indices=mode_indices)
print()
test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, period), 'rb'))
test.r_hat_and_ess(state_size_helper(), False)
test.r_hat_and_ess(state_size_helper(mode_specific=True), False, mode_indices=mode_indices)
print()
test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, period), 'rb'))
test.r_hat_and_ess(state_num_helper(0.05), False)
test.r_hat_and_ess(state_num_helper(0.05, mode_specific=True), False, mode_indices=mode_indices)
print()
test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, period), 'rb'))
test.r_hat_and_ess(state_num_helper(0.02), False)
test.r_hat_and_ess(state_num_helper(0.02, mode_specific=True), False, mode_indices=mode_indices)
print()
......@@ -37,6 +37,7 @@ def plotSimplex(points, fig=None,
fig.gca().text(0.43, np.sqrt(3) / 2 + 0.025, vertexlabels[2], size=24)
# Project and draw the actual points
projected = projectSimplex(points / points.sum(1)[:, None])
print(projected)
P.scatter(projected[:, 0], projected[:, 1], s=points.sum(1) * 3.5, **kwargs)
# plot center with average size
......@@ -90,14 +91,14 @@ if __name__ == '__main__':
labels = ('[0.1 0.1 0.8]',
'[0.8 0.1 0.1]',
'[0.5 0.4 0.1]',
'[0.17 0.33 0.5]',
'[0.33 0.34 0.33]')
testpoints = np.array([[0.1, 0.1, 0.8],
[0.8, 0.1, 0.1],
[0.5, 0.4, 0.1],
[0.17, 0.33, 0.5],
[0.33, 0.34, 0.33]])
# Define different colors for each label
c = range(len(labels))
# Do scatter plot
fig = plotSimplex(testpoints, s=25, c='k')
P.show()
fig = plotSimplex(testpoints, c='k', show=1)
import matplotlib
#matplotlib.use('Agg')
import numpy as np
import pyhsmm.basic.distributions as distributions
import time
......
import pymc, bayes_fit # load the model file
"""Perform a pymc sampling of the test data."""
import pymc, bayes_fit # load the model file
import numpy as np
import pickle
# Data params
T = 14
n_inputs = 3
# Sampling params
n_samples = 400000
R = pymc.MCMC(bayes_fit) # build the model
R.sample(n_samples) # populate and run it
R = pymc.MCMC(bayes_fit) # build the model
R.sample(n_samples) # populate and run it
T = 14
n_inputs = 3
# Extract weights
weights = np.zeros((T, n_samples, n_inputs))
for t in range(T):
try:
......@@ -16,4 +21,5 @@ for t in range(T):
except KeyError:
weights[t] = R.trace('ws'.format(t))
# Save everything
pickle.dump(weights, open('pymc_posterior', 'wb'))
"""
Need to find out whether loglikelihood is computed correctly.
Or whether a bug here allows states to invade each other more easily.
We'll test this by maximising the likelihood directly.
"""
import numpy as np
import pyhsmm.basic.distributions as distributions
from scipy.optimize import minimize
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
# Testing of Dynamic_GLM implementation
np.set_printoptions(suppress=True)
seed = np.random.randint(10000) # 215
print(seed)
seed = 2268
np.random.seed(seed)
T = 12
n_inputs = 3
step_size = 0.03
Q = np.tile(np.eye(n_inputs), (T, 1, 1))
test = distributions.Dynamic_GLM(n_inputs=n_inputs, T=T, P_0=4 * np.eye(n_inputs), Q=Q * step_size, prior_mean=np.zeros(n_inputs))
w = np.zeros(n_inputs)
# w = np.array([4.23061493, 2.14425199, -2.1125851])
# test.weights = w.reshape(T, n_inputs)
test_points = [0]
predictors = []
a, b = np.zeros(1000), np.zeros(1000)
a[250:500] = 1
a[750:1000] = 1
b[500:] = 1
for _ in range(T):
t = 1000
pred = np.empty((t, n_inputs))
pred[:, 0] = a
pred[:, 1] = b
pred[:, 2] = 1
predictors.append(pred)
sample = test.rvs(predictors, list(range(T)))
pickle.dump(sample, open('test_data', 'wb'))
learn = distributions.Dynamic_GLM(n_inputs=n_inputs, T=T, P_0=4 * np.eye(n_inputs), Q=Q * step_size, prior_mean=np.zeros(n_inputs))
def wrapper(w, t):
learn.weights = np.tile(w, (T, 1))
return - np.sum(learn.log_likelihood(sample[t], t))
print(test.weights)
n_samples = 150000
samples = []
pseudo_samples = []
for _ in range(n_samples):
if _ % 100 == 0:
print(_)
temp = learn.resample(sample)
pseudo_samples.append(temp)
samples.append(learn.weights.copy())
samples = np.array(samples[30000:])
LL_weights = np.zeros((T, n_inputs))
for t in range(T):
f = lambda w: wrapper(w, t)
LL_weights[t] = minimize(f, np.zeros(n_inputs)).x
pickle.dump((samples, LL_weights), open('gibbs_posterior', 'wb'))
......@@ -2,7 +2,7 @@
Need to find out whether loglikelihood is computed correctly.
Or whether a bug here allows states to invade each other more easily.
We'll test this by maximising the likelihood directly.
We'll test this by comparing to pymc results.
"""
import numpy as np
import pyhsmm.basic.distributions as distributions
......
"""
Perform a Gibbs sampling using my function of the test data.
Also perform maximum likelihood estimation of the same weights.
"""
import numpy as np
import pyhsmm.basic.distributions as distributions
from scipy.optimize import minimize
import pickle
n_samples = 100000
# Data params
T = 16
n_inputs = 3
step_size = 0.2
# Sampling params
n_samples = 100000
# Setup
Q = np.tile(np.eye(n_inputs), (T, 1, 1))
sample = pickle.load(open('test_data', 'rb'))
learn = distributions.Dynamic_GLM(n_inputs=n_inputs, T=T, P_0=4 * np.eye(n_inputs), Q=Q * step_size, prior_mean=np.zeros(n_inputs))
def wrapper(w, t):
learn.weights = np.tile(w, (T, 1))
return - np.sum(learn.log_likelihood(sample[t], t))
# Draw samples
samples = []
pseudo_samples = []
for _ in range(n_samples):
if _ % 1000 == 0:
print(_)
learn.resample(sample)
samples.append(learn.weights.copy())
def wrapper(w, t):
"""Reshape weight vector w into the correct shape, then compute the max ll estimate for the desired time t."""
learn.weights = np.tile(w, (T, 1))
return - np.sum(learn.log_likelihood(sample[t], t))
# Compute max ll estimates
LL_weights = np.zeros((T, n_inputs))
for t in range(T):
f = lambda w: wrapper(w, t)
LL_weights[t] = minimize(f, np.zeros(n_inputs)).x
LL_weights[t] = minimize(lambda w: wrapper(w, t), np.zeros(n_inputs)).x
# Save everything
pickle.dump((samples, LL_weights), open('gibbs_posterior', 'wb'))
......@@ -16,6 +16,7 @@ samples = np.array(samples)[gibbs_burnin:]
test = pickle.load(open('truth', 'rb'))
pymc_weights = pickle.load(open('pymc_posterior', 'rb'))
def temp(x):
if x < 11:
return x
......@@ -24,6 +25,7 @@ def temp(x):
else:
return x - 2
m = np.zeros((T, n_inputs))
u = np.zeros((T, n_inputs))
low = np.zeros((T, n_inputs))
......@@ -35,22 +37,17 @@ for t in range(T):
low[t, i] = np.percentile(w[:, i], 2.5)
plt.figure(figsize=(16, 9))
for i in range(n_inputs):
plt.subplot(n_inputs, 1, i+1)
label = 'Truth' if i == 0 else None
plt.plot(np.arange(T), test.weights[:, i], label=label)
label = 'LL' if i == 0 else None
# plt.plot(np.arange(T), LL_weights[:, i], label=label)
plt.plot(np.arange(T), test.weights[:, i], label='Truth')
plt.plot(np.arange(T), LL_weights[:, i], label='LL')
sample_mean = np.mean(samples[:, :, i], axis=0)
label = 'Posterior mean' if i == 0 else None
credible_interval = np.percentile(samples[:, :, i], [2.5, 97.5], axis=0)
plt.plot(np.arange(T), sample_mean, label=label, c='g')
plt.plot(np.arange(T), sample_mean, label='Posterior mean', c='g')
plt.fill_between(np.arange(T), credible_interval[1], credible_interval[0], alpha=0.2, color='g')
label = 'pymc mean' if i == 0 else None
plt.plot(np.arange(T), m[:, i], label=label, c='r')
plt.plot(np.arange(T), m[:, i], label='pymc mean', c='r')
plt.fill_between(np.arange(T), u[:, i], low[:, i], alpha=0.2, color='r')
sns.despine()
......
File deleted