Compare revisions: sbruijns/ihmm_behav_states

Changes are shown as if the source revision was being merged into the target revision.
@@ -99,6 +99,40 @@ def find_good_chains_unsplit_fast(chains1, chains2, chains3, chains4, reduce_to=
    return sol, r_hat_min


def find_good_chains_unsplit(chains1, chains2, chains3, chains4, reduce_to=8, simple=False):
    # Greedy search: repeatedly evaluate R^hat over subsets of the chains and keep
    # the subset with the lowest value, until only reduce_to chains remain.
    delete_n = - reduce_to + chains1.shape[0]
    mins = np.zeros(delete_n + 1)
    n_chains = chains1.shape[0]
    chains = np.stack([chains1, chains2, chains3, chains4])
    print("Without removals: {}".format(eval_simple_r_hat(chains)))
    if simple:
        r_hat = eval_simple_r_hat(chains)
    else:
        r_hat = eval_r_hat(chains1, chains2, chains3, chains4)
    mins[0] = r_hat

    for i in range(delete_n):
        print()
        r_hat_min = 10
        sol = 0
        for x in combinations(range(n_chains), n_chains - 1 - i):
            if simple:
                r_hat = eval_simple_r_hat(np.delete(chains, x, axis=1))
            else:
                r_hat = eval_r_hat(np.delete(chains1, x, axis=0), np.delete(chains2, x, axis=0), np.delete(chains3, x, axis=0), np.delete(chains4, x, axis=0))
            if r_hat < r_hat_min:
                sol = x
            r_hat_min = min(r_hat, r_hat_min)
        print("Minimum is {} (removed {})".format(r_hat_min, i + 1))
        sol = [i for i in range(32) if i not in sol]  # NOTE: assumes 32 chains in total
        print("Removed: {}".format(sol))
        mins[i + 1] = r_hat_min
    return sol, r_hat_min
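
# A minimal usage sketch, under assumptions: four statistic arrays of shape
# (n_chains, n_samples) are available (e.g. from the pickles written by the
# chain-extraction script below); the file names here are hypothetical.
#
#     stats = [pickle.load(open("stat_{}.p".format(k), 'rb')) for k in range(4)]
#     sol, r_hat = find_good_chains_unsplit(*stats, reduce_to=8, simple=True)
#     print(sol, r_hat)  # reported chain indices and the best R^hat found
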
def params_to_pmf(params):
    return params[2] + (1 - params[2] - params[3]) / (1 + np.exp(- params[0] * (all_conts - params[1])))
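
# Reading of the parameter vector, as implied by the formula above: params[0] is the
# sigmoid slope, params[1] its inflection point, params[2] the lower lapse rate and
# params[3] the upper lapse rate, so the PMF runs from params[2] at strong negative
# contrasts up to 1 - params[3] at strong positive contrasts in all_conts.
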
@@ -214,7 +248,7 @@ if __name__ == "__main__":
    plt.savefig("pmf fit scatter")
    plt.show()

    # New things
    xy = np.vstack([short_pmfs[:, 0], function_range, short_pmfs[:, 1]])
    z = gaussian_kde(xy)(xy)  # KDE density at each point, e.g. for a density-coloured scatter via c=z

    plt.figure(figsize=(24, 24 / 3))
"""
Functions to extract statistics from a set of chains
Functions to compute R^hat from a set of statistic vectors
"""
import numpy as np
from scipy.stats import rankdata, norm
import pickle
def state_size_helper(n=0, mode_specific=False):
    # returns a function extracting the size of the (n+1)-th largest state per sample
    if not mode_specific:
        def nth_largest_state_func(x):
            return np.partition(x.assign_counts, -1 - n, axis=1)[:, -1 - n]
    else:
        def nth_largest_state_func(x, ind):
            return np.partition(x.assign_counts[ind], -1 - n, axis=1)[:, -1 - n]
    return nth_largest_state_func
def state_num_helper(t, mode_specific=False):
    # returns a function counting, per sample, the states holding more than fraction t of datapoints
    if not mode_specific:
        def state_num_func(x): return ((x.assign_counts / x.n_datapoints) > t).sum(1)
    else:
        def state_num_func(x, ind): return ((x.assign_counts[ind] / x.n_datapoints) > t).sum(1)
    return state_num_func
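
# Sketch of how these factories are used, assuming a result object with an
# .assign_counts array of shape (n_samples, n_states) and a scalar .n_datapoints:
#
#     f_num = state_num_helper(0.05)   # counts states holding > 5% of datapoints
#     f_size = state_size_helper()     # size of the largest state
#     chain_stat = f_num(result)       # -> one value per posterior sample
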
def gamma_func(x): return x.trans_distn.gamma
def alpha_func(x): return x.trans_distn.alpha
def ll_func(x): return x.sample_lls[-x.n_samples:]
def r_hat_array_comp(chains):
    # classic R^hat for an (m chains, n samples) array: between-chain variance B
    # against within-chain variance W (Gelman et al.)
    m, n = chains.shape
    psi_dot_j = np.mean(chains, axis=1)
    psi_dot_dot = np.mean(psi_dot_j)
    B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot) ** 2)
    s_j_squared = np.sum((chains - psi_dot_j[:, None]) ** 2, axis=1) / (n - 1)
    W = np.mean(s_j_squared)
    var_hat_plus = (n - 1) / n * W + B / n
    if W == 0:
        # all chains constant at the same value
        return 1, 0
    r_hat = np.sqrt(var_hat_plus / W)
    return r_hat, var_hat_plus
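
# Quick numeric check on two hand-made chains (m = 2, n = 4): the chain means are
# 2.5 and 12.5, so B = 4 * (25 + 25) = 200, W = 5/3, var_hat_plus = 0.75 * W + B / 4
# = 51.25, and R^hat = sqrt(51.25 / W) ≈ 5.55, correctly flagging non-convergence.
#
#     chains = np.array([[1., 2., 3., 4.], [11., 12., 13., 14.]])
#     r_hat, var_hat_plus = r_hat_array_comp(chains)  # -> (~5.55, 51.25)
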
def eval_amortized_r_hat(chains, psi_dot_j, s_j_squared, m, n):
    # R^hat from precomputed per-chain means and variances, amortising repeated evaluations
    psi_dot_dot = np.mean(psi_dot_j, axis=1)
    B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot[:, None]) ** 2, axis=1)
    W = np.mean(s_j_squared, axis=1)
    var_hat_plus = (n - 1) / n * W + B / n
    r_hat = np.sqrt(var_hat_plus / W)
    return max(r_hat)
def r_hat_array_comp_mult(chains):
    _, m, n = chains.shape
    psi_dot_j = np.mean(chains, axis=2)
    psi_dot_dot = np.mean(psi_dot_j, axis=1)
    B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot[:, None]) ** 2, axis=1)
    s_j_squared = np.sum((chains - psi_dot_j[:, :, None]) ** 2, axis=2) / (n - 1)
    W = np.mean(s_j_squared, axis=1)
    var_hat_plus = (n - 1) / n * W + B / n
    r_hat = np.sqrt(var_hat_plus / W)
    return r_hat, var_hat_plus
def rank_inv_normal_transform(chains):
    # Gelman et al., "Rank-normalization, folding, and localization: An improved
    # R^hat for assessing convergence of MCMC"
    # folding around the median picks up differences in spread; ties get average ranks
    folded_chains = np.abs(chains - np.median(chains))
    ranked = rankdata(chains).reshape(chains.shape)
    folded_ranked = rankdata(folded_chains).reshape(folded_chains.shape)
    # inverse normal transform with fractional offset
    rank_normalised = norm.ppf((ranked - 3 / 8) / (chains.size + 1 / 4))
    folded_rank_normalised = norm.ppf((folded_ranked - 3 / 8) / (folded_chains.size + 1 / 4))
    return rank_normalised, folded_rank_normalised, ranked, folded_ranked
def eval_r_hat(chains):
    r_hats = []
    for chain in chains:
        rank_normalised, folded_rank_normalised, _, _ = rank_inv_normal_transform(chain)
        r_hats.append(comp_multi_r_hat(chain, rank_normalised, folded_rank_normalised))
    return max(r_hats)


def eval_simple_r_hat(chains):
    r_hats, _ = r_hat_array_comp_mult(chains)
    return max(r_hats)
def comp_multi_r_hat(chains, rank_normalised, folded_rank_normalised):
    # report the worst of the plain, rank-normalised and folded rank-normalised R^hat
    lame_r_hat, _ = r_hat_array_comp(chains)
    rank_normalised_r_hat, _ = r_hat_array_comp(rank_normalised)
    folded_rank_normalised_r_hat, _ = r_hat_array_comp(folded_rank_normalised)
    return max(lame_r_hat, rank_normalised_r_hat, folded_rank_normalised_r_hat)
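
# End-to-end sketch for one statistic across chains (shapes illustrative): for
# well-mixed chains the result should sit near 1, i.e. below the usual 1.01
# threshold; diverging chains push it well above.
#
#     rng = np.random.default_rng(0)
#     chain_stat = rng.normal(size=(4, 500))        # (m chains, n samples)
#     rn, frn, _, _ = rank_inv_normal_transform(chain_stat)
#     print(comp_multi_r_hat(chain_stat, rn, frn))  # ~ 1 for i.i.d. draws
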
def sample_statistics(test, mode_indices, subject):
    # prints out r_hats and sample sizes for the given subject; the result object
    # is reloaded from disk before every statistic
    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
    test.r_hat_and_ess(state_size_helper(1), False)
    test.r_hat_and_ess(state_size_helper(1, mode_specific=True), False, mode_indices=mode_indices)
    print()
    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
    test.r_hat_and_ess(state_size_helper(), False)
    test.r_hat_and_ess(state_size_helper(mode_specific=True), False, mode_indices=mode_indices)
    print()
    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
    test.r_hat_and_ess(state_num_helper(0.05), False)
    test.r_hat_and_ess(state_num_helper(0.05, mode_specific=True), False, mode_indices=mode_indices)
    print()
    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
    test.r_hat_and_ess(state_num_helper(0.02), False)
    test.r_hat_and_ess(state_num_helper(0.02, mode_specific=True), False, mode_indices=mode_indices)
    print()
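
# Illustrative call (names hypothetical): the first argument is overwritten by the
# pickle loads inside, so a placeholder suffices.
#
#     sample_statistics(None, mode_indices, 'GLM_Sim_11_trick')
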
"""
Script for taking a list of subects and extracting statistics from the chains
which can be used to assess which chains have converged to the same regions
"""
import numpy as np
import pyhsmm
import pickle
@@ -5,25 +9,7 @@ import json
from dyn_glm_chain_analysis import MCMC_result
import matplotlib.pyplot as plt
import time
def gamma_func(x): return x.trans_distn.gamma
def alpha_func(x): return x.trans_distn.alpha


def state_size_helper(n=0):
    def nth_largest_state_func(x):
        return np.partition(x.assign_counts, -1 - n, axis=1)[:, -1 - n]
    return nth_largest_state_func


def state_num_helper(t):
    def state_num_func(x): return ((x.assign_counts / x.n_datapoints) > t).sum(1)
    return state_num_func


def ll_func(x): return x.sample_lls[-x.n_samples:]
from mcmc_chain_analysis import state_size_helper, state_num_helper, ll_helper
fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
@@ -31,7 +17,7 @@ if fit_type == 'bias':
loading_info = json.load(open("canonical_infos_bias.json", 'r'))
elif fit_type == 'prebias':
loading_info = json.load(open("canonical_infos.json", 'r'))
subjects = list(loading_info.keys())
subjects = ['GLM_Sim_11_trick'] # list(loading_info.keys())
fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0]
func1 = state_num_helper(0.2)
@@ -45,117 +31,79 @@ func8 = state_size_helper(4)
func9 = state_size_helper(5)
func10 = state_size_helper(6)
func11 = state_size_helper(7)
model_for_loop = False
dur = 'yes'
def temp():
    m = 16
    for subject in subjects:
        print(subject)
        n_runs = -1
        counter = -1
        n = (loading_info[subject]['chain_num'] + 1) * 4000 // 25
        chains1 = np.zeros((m, n))
        chains2 = np.zeros((m, n))
        chains3 = np.zeros((m, n))
        chains4 = np.zeros((m, n))
        for j, (seed, fit_num) in enumerate(zip(loading_info[subject]['seeds'], loading_info[subject]['fit_nums'])):
            counter += 1
            print(seed)
            info_dict = pickle.load(open("./session_data/{}_info_dict.p".format(subject), "rb"))
            samples = []
            mini_counter = 0
            # samples are saved in consecutively numbered chunks; load until one is missing
            while True:
                try:
                    file = "./dynamic_GLMiHMM_crossvals/{}_fittype_{}_var_{}_{}_{}{}.p".format(subject, fit_type, fit_variance, seed, fit_num, '_{}'.format(mini_counter))
                    lala = time.time()
                    samples += pickle.load(open(file, "rb"))
                    print("Loading {} took {:.4}".format(mini_counter, time.time() - lala))
                    mini_counter += 1
                except Exception:
                    break
            if n_runs == -1:
                n_runs = mini_counter
            else:
                if n_runs != mini_counter:
                    print("Problem")
                    print(n_runs, mini_counter)
                    quit()
            save_id = "{}_fittype_{}_var_{}_{}_{}.p".format(subject, fit_type, fit_variance, seed, fit_num).replace('.', '_')
            # for file in os.listdir("./dynamic_GLMiHMM_crossvals/infos_new/"):
            #     if file.startswith("{}_".format(subject)) and file.endswith("_{}_{}_{}_{}.json".format(fit_type, fit_variance, seed, fit_num)):
            #         print("Taking fit infos from {}".format(file))
            #         fit_infos = json.load(open("./dynamic_GLMiHMM_crossvals/infos_new/" + file, 'r'))
            #         sample_lls = fit_infos['ll']
            print("loaded seed {}".format(seed))
            lala = time.time()
            result = MCMC_result(samples[::25],
                                 infos=info_dict, data=samples[0].datas,
                                 sessions=fit_type, fit_variance=fit_variance,
                                 dur=dur, save_id=save_id)  # , sample_lls=sample_lls)
            print("Making result {} took {:.4}".format(counter, time.time() - lala))
            # if model_for_loop:
            #     for i in range(n):
            #         chains[j, i] = func(result.models[i])
            res = func1(result)
            chains1[counter] = res
            res = func2(result)
            chains2[counter] = res
            res = func3(result)
            chains3[counter] = res
            res = func4(result)
            chains4[counter] = res
            # res = func5(result)
            # chains5[j] = res
            # res = func6(result)
            # chains6[j] = res
            # res = func7(result)
            # chains7[j] = res
            # res = func8(result)
            # chains8[j] = res
            # res = func9(result)
            # chains9[j] = res
            # res = func10(result)
            # chains10[j] = res
            # res = func11(result)
            # chains11[j] = res

        # func2 = state_size_helper()
        # func5 = state_size_helper(1)
        # func6 = state_size_helper(2)
        # func7 = state_size_helper(3)
        # func8 = state_size_helper(4)
        # func9 = state_size_helper(5)
        # func10 = state_size_helper(6)
        # func11 = state_size_helper(7)

        # plt.plot(chains2.flatten()[::10])
        # plt.plot(chains5.flatten()[::10])
        # plt.plot(chains6.flatten()[::10])
        # plt.plot(chains7.flatten()[::10])
        # plt.plot(chains8.flatten()[::10])
        # plt.plot(chains9.flatten()[::10])
        # plt.plot(chains10.flatten()[::10])
        # plt.plot(chains11.flatten()[::10])
        # plt.axhline(1643, color='r')
        # plt.axhline(3438, color='r')
        # plt.axhline(2216, color='r')
        # plt.axhline(811, color='r')
        # plt.axhline(2721, color='r')
        # plt.axhline(743, color='r')
        # plt.show()

        pickle.dump(chains1, open("multi_chain_saves/{}_state_num_0_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
        pickle.dump(chains2, open("multi_chain_saves/{}_state_num_1_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
        pickle.dump(chains3, open("multi_chain_saves/{}_largest_state_0_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
        pickle.dump(chains4, open("multi_chain_saves/{}_largest_state_1_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
temp()
# import pyhsmm.util.profiling as prof
# prof_func = prof._prof(temp)
# prof_func()
# prof._prof.print_stats()
m = 16
for subject in subjects:
    print(subject)
    n_runs = -1
    counter = -1
    n = (loading_info[subject]['chain_num'] + 1) * 4000 // 25
    chains1 = np.zeros((m, n))
    chains2 = np.zeros((m, n))
    chains3 = np.zeros((m, n))
    chains4 = np.zeros((m, n))
    for j, (seed, fit_num) in enumerate(zip(loading_info[subject]['seeds'], loading_info[subject]['fit_nums'])):
        counter += 1
        print(seed)
        info_dict = pickle.load(open("./session_data/{}_info_dict.p".format(subject), "rb"))
        samples = []
        mini_counter = 0
        while True:
            try:
                file = "./dynamic_GLMiHMM_crossvals/{}_fittype_{}_var_{}_{}_{}{}.p".format(subject, fit_type, fit_variance, seed, fit_num, '_{}'.format(mini_counter))
                lala = time.time()
                samples += pickle.load(open(file, "rb"))
                print("Loading {} took {:.4}".format(mini_counter, time.time() - lala))
                mini_counter += 1
            except Exception:
                break
        if n_runs == -1:
            n_runs = mini_counter
        else:
            if n_runs != mini_counter:
                print("Problem")
                print(n_runs, mini_counter)
                quit()
        save_id = "{}_fittype_{}_var_{}_{}_{}.p".format(subject, fit_type, fit_variance, seed, fit_num).replace('.', '_')
        # for file in os.listdir("./dynamic_GLMiHMM_crossvals/infos_new/"):
        #     if file.startswith("{}_".format(subject)) and file.endswith("_{}_{}_{}_{}.json".format(fit_type, fit_variance, seed, fit_num)):
        #         print("Taking fit infos from {}".format(file))
        #         fit_infos = json.load(open("./dynamic_GLMiHMM_crossvals/infos_new/" + file, 'r'))
        #         sample_lls = fit_infos['ll']
        print("loaded seed {}".format(seed))
        result = MCMC_result(samples[::25],
                             infos=info_dict, data=samples[0].datas,
                             sessions=fit_type, fit_variance=fit_variance,
                             dur=dur, save_id=save_id)  # , sample_lls=sample_lls)
        print("Making result {} took {:.4}".format(counter, time.time() - lala))
        res = func1(result)
        chains1[counter] = res
        res = func2(result)
        chains2[counter] = res
        res = func3(result)
        chains3[counter] = res
        res = func4(result)
        chains4[counter] = res
        # res = func5(result)
        # chains5[j] = res
        # res = func6(result)
        # chains6[j] = res
        # res = func7(result)
        # chains7[j] = res
        # res = func8(result)
        # chains8[j] = res
        # res = func9(result)
        # chains9[j] = res
        # res = func10(result)
        # chains10[j] = res
        # res = func11(result)
        # chains11[j] = res

    pickle.dump(chains1, open("multi_chain_saves/{}_state_num_0_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
    pickle.dump(chains2, open("multi_chain_saves/{}_state_num_1_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
    pickle.dump(chains3, open("multi_chain_saves/{}_largest_state_0_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
    pickle.dump(chains4, open("multi_chain_saves/{}_largest_state_1_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
"""
Visualize points on the 3-simplex (eg, the parameters of a
3-dimensional multinomial distributions) as a scatter plot
contained within a 2D triangle.
Adapted from David Andrzejewski (david.andrzej@gmail.com)
"""
import numpy as np
import matplotlib.pyplot as P
import matplotlib.ticker as MT
import matplotlib.lines as L
import matplotlib.cm as CM
import matplotlib.colors as C
import matplotlib.patches as PA
def plotSimplex(points, fig=None,
                vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title="test.png",
                show=False, **kwargs):
    """
    Plot Nx3 points array on the 3-simplex
    (with optionally labeled vertices)

    kwargs will be passed along directly to matplotlib.pyplot.scatter
    """
    if fig is None:
        fig = P.figure(figsize=(9, 9))
    # Draw the triangle
    l1 = L.Line2D([0, 0.5, 1.0, 0],  # xcoords
                  [0, np.sqrt(3) / 2, 0, 0],  # ycoords
                  color='k')
    fig.gca().add_line(l1)
    fig.gca().xaxis.set_major_locator(MT.NullLocator())
    fig.gca().yaxis.set_major_locator(MT.NullLocator())
    # Draw vertex labels
    fig.gca().text(-0.06, -0.05, vertexlabels[0], size=24)
    fig.gca().text(0.95, -0.05, vertexlabels[1], size=24)
    fig.gca().text(0.43, np.sqrt(3) / 2 + 0.025, vertexlabels[2], size=24)
    # Project and draw the actual points; marker size scales with each row's sum
    projected = projectSimplex(points / points.sum(1)[:, None])
    P.scatter(projected[:, 0], projected[:, 1], s=points.sum(1) * 3.5, **kwargs)
    # Plot the centre of mass with average size
    projected = projectSimplex(np.mean(points / points.sum(1)[:, None], axis=0).reshape(1, 3))
    P.scatter(projected[:, 0], projected[:, 1], marker='*', color='r', s=np.mean(points.sum(1)) * 3.5)
    # Leave some buffer around the triangle for vertex labels
    fig.gca().set_xlim(-0.05, 1.05)
    fig.gca().set_ylim(-0.05, 1.05)
    P.axis('off')
    P.savefig(save_title, bbox_inches='tight')  # save under the name passed via save_title
    if show:
        P.show()
    else:
        P.close()
def projectSimplex(points):
    """
    Project probabilities on the 3-simplex to a 2D triangle

    N points are given as N x 3 array
    """
    # Convert points one at a time
    tripts = np.zeros((points.shape[0], 2))
    for idx in range(points.shape[0]):
        # Init to triangle centroid
        x = 1.0 / 2
        y = 1.0 / (2 * np.sqrt(3))
        # Vector 1 - bisect out of lower left vertex
        p1 = points[idx, 0]
        x = x - (1.0 / np.sqrt(3)) * p1 * np.cos(np.pi / 6)
        y = y - (1.0 / np.sqrt(3)) * p1 * np.sin(np.pi / 6)
        # Vector 2 - bisect out of lower right vertex
        p2 = points[idx, 1]
        x = x + (1.0 / np.sqrt(3)) * p2 * np.cos(np.pi / 6)
        y = y - (1.0 / np.sqrt(3)) * p2 * np.sin(np.pi / 6)
        # Vector 3 - bisect out of top vertex
        p3 = points[idx, 2]
        y = y + (1.0 / np.sqrt(3) * p3)
        tripts[idx, :] = (x, y)
    return tripts
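
# Sanity check for projectSimplex: the three unit vectors land exactly on the
# triangle's corners, (0, 0), (1, 0) and (0.5, sqrt(3)/2).
#
#     print(projectSimplex(np.eye(3)))
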
if __name__ == '__main__':
    # Define a synthetic test dataset
    labels = ('[0.1 0.1 0.8]',
              '[0.8 0.1 0.1]',
              '[0.5 0.4 0.1]',
              '[0.33 0.34 0.33]')
    testpoints = np.array([[0.1, 0.1, 0.8],
                           [0.8, 0.1, 0.1],
                           [0.5, 0.4, 0.1],
                           [0.33, 0.34, 0.33]])
    # Define different colors for each label
    c = range(len(labels))
    # Do scatter plot (show=True, since plotSimplex otherwise closes the figure
    # and a trailing P.show() would display nothing)
    fig = plotSimplex(testpoints, s=25, c='k', show=True)