Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: sbruijns/ihmm_behav_states
Commits on Source (3)
Showing 868 additions and 701 deletions
No preview for this file type
File added
No preview for this file type
@@ -73,7 +73,7 @@ to_introduce = [2, 3, 4, 5]
# "ibl_witten_06", "ibl_witten_07", "ibl_witten_12", "ibl_witten_13", "ibl_witten_14", "ibl_witten_15",
# "ibl_witten_16", "KS003", "KS005", "KS019", "NYU-01", "NYU-02", "NYU-04", "NYU-06", "ZM_1367", "ZM_1369",
# "ZM_1371", "ZM_1372", "ZM_1743", "ZM_1745", "ZM_1746"] # zoe's subjects
subjects = ['ibl_witten_14']
subjects = ['ZFM-05236']
data_folder = 'session_data'
# why does CSHL058 not work?
@@ -149,7 +149,8 @@ for subject in subjects:
contrast_set = {0, 1, 9, 10}
rel_count = -1
for i, (eid, extra_eids) in enumerate(zip(fixed_eids, additional_eids)):
quit()
for i, (eid, extra_eids, date) in enumerate(zip(fixed_eids, additional_eids, fixed_dates)):
try:
trials = one.load_object(eid, 'trials')
@@ -182,6 +183,8 @@ for subject in subjects:
df = pd.concat([df, df2], ignore_index=True)
print('new size: {}'.format(len(df)))
pickle.dump(df, open("./sofiya_data/{}_df_{}_{}.p".format(subject, rel_count, date), "wb"))
current_contrasts = set(df['signed_contrast'])
diff = current_contrasts.difference(contrast_set)
for c in to_introduce:
......
@@ -34,14 +34,12 @@ misses = []
to_introduce = [2, 3, 4, 5]
amiss = ['UCLA034', 'UCLA036', 'UCLA037', 'PL015', 'PL016', 'PL017', 'PL024', 'NR_0017', 'NR_0019', 'NR_0020', 'NR_0021', 'NR_0027']
subjects = ['NYU-21']
subjects = ['ZFM-04019', 'ZFM-05236']
fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
if fit_type == 'bias':
loading_info = json.load(open("canonical_infos_bias.json", 'r'))
r_hats = json.load(open("canonical_info_r_hats_bias.json", 'r'))
elif fit_type == 'prebias':
loading_info = json.load(open("canonical_infos.json", 'r'))
r_hats = json.load(open("canonical_info_r_hats.json", 'r'))
already_fit = list(loading_info.keys())
remaining_subs = [s for s in subjects if s not in amiss and s not in already_fit]
@@ -70,13 +68,14 @@ for subject in subjects:
print('_____________________')
print(subject)
if subject in already_fit or subject in amiss:
continue
# if subject in already_fit or subject in amiss:
# continue
trials = one.load_aggregate('subjects', subject, '_ibl_subjectTrials.table')
# Load training status and join to trials table
training = one.load_aggregate('subjects', subject, '_ibl_subjectTraining.table')
quit()
trials = (trials
.set_index('session')
.join(training.set_index('session'))
......
[This diff is collapsed.]
"""
Script for combining local canonical_infos.json and the one from the cluster
"""
import json
dist_info = json.load(open("canonical_infos.json", 'r'))
local_info = json.load(open("canonical_infos_local.json", 'r'))
cluster_subs = ['KS045', 'KS043', 'KS051', 'DY_008', 'KS084', 'KS052', 'KS046', 'KS096', 'KS086', 'UCLA033', 'UCLA005', 'NYU-21', 'KS055', 'KS091']
for key in cluster_subs:
print('ignore' in dist_info[key])
quit()
for key in cluster_subs:
if key not in local_info:
print("Adding all of {} to local info".format(key))
local_info[key] = dist_info[key]
continue
else:
for sub_key in dist_info[key]:
if sub_key not in local_info[key]:
print("Adding {} into local info for {}".format(key))
local_info[key][sub_key] = dist_info[key][sub_key]
else:
if local_info[key][sub_key] == dist_info[key][sub_key]:
continue
else:
assert len(dist_info[key][sub_key]) == 16
for x in dist_info[key][sub_key]:
assert x in local_info[key][sub_key]
local_info[key][sub_key] = dist_info[key][sub_key]
[This diff is collapsed.]
@@ -87,6 +87,7 @@ for subject in subjects:
mega_data[:, 4] = 1
mega_data[:, 5] = data[~bad_trials, 1] - 1
mega_data[:, 5] = (mega_data[:, 5] + 1) / 2
print(mega_data.sum(0))
posteriormodel.add_data(mega_data)
import pyhsmm.util.profiling as prof
......
@@ -19,15 +19,6 @@ import json
import sys
def crp_expec(n, theta):
"""
Return expected number of tables after n customers, given concentration theta.
From Wikipedia
"""
return theta * (digamma(theta + n) - digamma(theta))
def eleven2nine(x):
"""Map from 11 possible contrasts to 9, for the non-training phases.
@@ -87,13 +78,7 @@ subjects = ['ibl_witten_15', 'ibl_witten_17', 'ibl_witten_18', 'ibl_witten_19',
# test subjects:
subjects = ['KS014']
# subjects = ['KS021', 'KS016', 'ibl_witten_16', 'SWC_022', 'KS003', 'CSHL054', 'ZM_3003', 'KS015', 'ibl_witten_13', 'CSHL059', 'CSH_ZAD_022', 'CSHL_007', 'CSHL062', 'NYU-06', 'KS014', 'ibl_witten_14', 'SWC_023']
# subjects = [['GLM_Sim_15', 'GLM_Sim_14', 'GLM_Sim_13', 'GLM_Sim_11', 'GLM_Sim_10', 'GLM_Sim_09', 'GLM_Sim_12'][2]]
# (0.03, 0.3, 5, 'contR', 'contL', 'prevA', 'bias', 1, 0.1):
cv_nums = [15]
# conda activate hdp_pg_env
# python dynamic_GLMiHMM_fit.py
cv_nums = [200 + int(sys.argv[1]) % 16]
subjects = [subjects[int(sys.argv[1]) // 16]]
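# Note (added for clarity): sys.argv[1] enumerates (subject, seed) jobs on a
# 16-wide grid -- argv % 16 picks one of 16 cross-validation seeds (offset by
# 200), argv // 16 picks the subject.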
@@ -267,12 +252,6 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
print("meh, skipped session")
continue
# if j == 15:
# import matplotlib.pyplot as plt
# for i in [0, 2, 3,4,5,6,7,8,10]:
# plt.plot(i, data[data[:, 0] == i, 1].mean(), 'ko')
# plt.show()
if params['obs_dur'] == 'glm':
for i in range(data.shape[0]):
data[i, 0] = num_to_contrast[data[i, 0]]
@@ -291,10 +270,7 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
elif reg == 'cont':
mega_data[:, i] = data[mask, 0]
elif reg == 'prevA':
# prev_ans = data[:, 1].copy()
new_prev_ans = data[:, 1].copy()
# prev_ans[1:] = prev_ans[:-1]
# prev_ans -= 1
new_prev_ans -= 1
new_prev_ans = np.convolve(np.append(0, new_prev_ans), params['exp_filter'])[:-(params['exp_filter'].shape[0])]
mega_data[:, i] = new_prev_ans[mask]
@@ -334,9 +310,6 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
posteriormodel.add_data(mega_data)
# for d in posteriormodel.datas:
# print(d.shape)
# if not os.path.isfile('./{}/data_save_{}.p'.format(data_folder, params['subject'])):
pickle.dump(data_save, open('./{}/data_save_{}.p'.format(data_folder, params['subject']), 'wb'))
quit()
......
@@ -15,6 +15,8 @@ for filename in os.listdir("./dynamic_GLMiHMM_crossvals/"):
regexp = re.compile(r'((\w|-)+)_fittype_(\w+)_var_0.03_(\d+)_(\d+)_(\d+)')
result = regexp.search(filename)
subject = result.group(1)
if subject == 'ibl_witten_26':
print('here')
fit_type = result.group(3)
seed = result.group(4)
fit_num = result.group(5)
@@ -43,23 +45,14 @@ for s in prebias_subinfo.keys():
new_seeds.append(seed)
prebias_subinfo[s]["fit_nums"] = new_fit_nums
prebias_subinfo[s]["seeds"] = new_seeds
# if s == 'ibl_witten_13':
# new_fit_nums = []
# new_seeds = []
# for fit_num, seed in zip(prebias_subinfo[s]["fit_nums"], prebias_subinfo[s]["seeds"]):
# if int(seed) < 316:
# new_fit_nums.append(fit_num)
# new_seeds.append(seed)
# prebias_subinfo[s]["fit_nums"] = new_fit_nums
# prebias_subinfo[s]["seeds"] = new_seeds
big = []
non_big = []
sim_subjects = []
for s in prebias_subinfo.keys():
assert len(prebias_subinfo[s]["seeds"]) in [16, 32]
assert len(prebias_subinfo[s]["fit_nums"]) in [16, 32]
# assert len(prebias_subinfo[s]["seeds"]) in [16, 32], s + " " + str(len(prebias_subinfo[s]["seeds"]))
# assert len(prebias_subinfo[s]["fit_nums"]) in [16, 32]
print(s, len(prebias_subinfo[s]["fit_nums"]))
if len(prebias_subinfo[s]["fit_nums"]) == 32:
big.append(s)
@@ -67,8 +60,6 @@ for s in prebias_subinfo.keys():
non_big.append(s)
if s.startswith("GLM_Sim_"):
sim_subjects.append(s)
else:
non_big.append(s)
print(non_big)
print()
......
@@ -134,3 +134,48 @@ def sample_statistics(mode_indices, subject, period='prebias'):
test.r_hat_and_ess(state_num_helper(0.02), False)
test.r_hat_and_ess(state_num_helper(0.02, mode_specific=True), False, mode_indices=mode_indices)
print()
def find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_to=8, simple=False):
delete_n = chains1.shape[0] - reduce_to
mins = np.zeros(delete_n + 1)
n_chains = chains1.shape[0]
chains = np.stack([chains1, chains2, chains3, chains4])
r_hat = eval_r_hat([chains1, chains2, chains3, chains4])
print("Without removals: {}".format(r_hat))
mins[0] = r_hat
r_hats = []
solutions = []
to_del = []
for i in range(delete_n):
r_hat_min = 50
sol = 0
for x in range(n_chains):
if x in to_del:
continue
if not simple:
r_hat = eval_r_hat([np.delete(chains1, to_del + [x], axis=0), np.delete(chains2, to_del + [x], axis=0), np.delete(chains3, to_del + [x], axis=0), np.delete(chains4, to_del + [x], axis=0)])
else:
r_hat = eval_simple_r_hat(np.delete(chains, to_del + [x], axis=1))
if r_hat < r_hat_min:
sol = x
r_hat_min = min(r_hat, r_hat_min)
to_del.append(sol)
print("Minimum is {} (removed {})".format(r_hat_min, i + 1))
print("Removed: {}".format(to_del))
mins[i + 1] = r_hat_min
r_hats.append(r_hat_min)
solutions.append(to_del.copy())
if simple:
r_hat_local = eval_r_hat([np.delete(chains1, to_del, axis=0), np.delete(chains2, to_del, axis=0), np.delete(chains3, to_del, axis=0), np.delete(chains4, to_del, axis=0)])
print("Minimum over everything is {} (removed {})".format(r_hat_local, i + 1))
best = np.argmin(r_hats)
return solutions[best], r_hats[best]
\ No newline at end of file
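# eval_r_hat and eval_simple_r_hat are defined elsewhere in this module and are
# not part of this diff. The sketch below is an illustrative stand-in showing
# the Gelman-Rubin statistic such a helper would compute, and why greedily
# removing a stuck chain drives it back toward 1. _split_r_hat and the
# synthetic data are assumptions, not the repo's code.
import numpy as np

def _split_r_hat(chains):
    """Split-R-hat for an (n_chains, n_samples) array."""
    n = chains.shape[1] // 2
    halves = np.vstack([chains[:, :n], chains[:, n:2 * n]])  # split each chain in two
    means = halves.mean(axis=1)
    B = n * means.var(ddof=1)                  # between-chain variance
    W = halves.var(axis=1, ddof=1).mean()      # mean within-chain variance
    return np.sqrt(((n - 1) / n * W + B / n) / W)

def _split_r_hat_demo():
    rng = np.random.default_rng(0)
    good = rng.normal(0, 1, size=(7, 400))     # 7 well-mixed chains
    stuck = rng.normal(3, 0.1, size=(1, 400))  # 1 chain stuck in a different mode
    chains = np.vstack([good, stuck])
    print(_split_r_hat(chains))                        # far above 1
    print(_split_r_hat(np.delete(chains, 7, axis=0)))  # close to 1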
old_ana_code/
figures/
dynamic_figures/
iHMM_fits/
dynamic_iHMM_fits/
overview_figures/
beliefs/
session_data/
WAIC/
*.png
*.p
*.npz
*.pdf
*.csv
*.zip
*.json
peter_fiugres/
consistency_data/
dynamic_GLM_figures/
dynamic_GLMiHMM_fits2/
glm sim mice/
dynamic_GLMiHMM_crossvals/
import os
i = 0
for filename in os.listdir("./"):
if not filename.endswith('.p'):
continue
if 'bias' in filename:
continue
if not filename.endswith('bias.p'):
i += 1
print("Rename {} into {}".format(filename, filename[:-2] + '_prebias.p'))
os.rename(filename, filename[:-2] + '_prebias.p')
print(i)
"""
Script for getting data from fits into a state to be analysed
Includes old process_many_chains.py:
Script for taking a list of subjects and extracting statistics from the chains
which can be used to assess which chains have converged to the same regions
This cannot be run in parallel (because the loading_info dict gets changed and dumped)
"""
import numpy as np
import pyhsmm
import pickle
import json
from dyn_glm_chain_analysis import MCMC_result
import matplotlib.pyplot as plt
import time
from mcmc_chain_analysis import state_size_helper, state_num_helper
from mcmc_chain_analysis import state_size_helper, state_num_helper, find_good_chains_unsplit_greedy, gamma_func, alpha_func
import index_mice # executes function for creating dict of available fits
from dyn_glm_chain_analysis import MCMC_result_list
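# state_num_helper and state_size_helper (imported above) live in
# mcmc_chain_analysis and are not shown in this diff. The underscored versions
# below are illustrative reconstructions of their assumed semantics, inferred
# from how chains1..chains4 are filled further down: each helper maps one
# MCMC_result to a per-posterior-sample summary trace.

def _state_num_helper_sketch(threshold):
    """Per sample: number of states occupying more than `threshold` of trials."""
    def f(result):
        out = np.zeros(len(result.models))
        for i, model in enumerate(result.models):
            seq = np.concatenate(model.stateseqs)        # all trials of one sample
            _, counts = np.unique(seq, return_counts=True)
            out[i] = np.sum(counts / seq.size > threshold)
        return out
    return f

def _state_size_helper_sketch(rank=0):
    """Per sample: fraction of trials in the (rank + 1)-th largest state."""
    def f(result):
        out = np.zeros(len(result.models))
        for i, model in enumerate(result.models):
            seq = np.concatenate(model.stateseqs)
            _, counts = np.unique(seq, return_counts=True)
            sizes = np.sort(counts)[::-1] / seq.size
            out[i] = sizes[rank] if rank < sizes.size else 0.
        return out
    return f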
fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
if fit_type == 'bias':
loading_info = json.load(open("canonical_infos_bias.json", 'r'))
r_hats = json.load(open("canonical_info_r_hats_bias.json", 'r'))
elif fit_type == 'prebias':
loading_info = json.load(open("canonical_infos.json", 'r'))
subjects = ['ZFM-04019', 'ZFM-05236'] # list(loading_info.keys())
r_hats = json.load(open("canonical_info_r_hats.json", 'r'))
# done: 'NYU-45', 'UCLA035', 'NYU-30', 'CSHL047', 'NYU-39', 'NYU-37',
# can't: 'UCLA006'
subjects = ['NYU-40', 'NYU-46', 'KS044', 'NYU-48']
# 'UCLA012', 'CSHL052', 'NYU-11', 'UCLA011', 'NYU-47', 'CSHL045', 'UCLA017', 'CSHL055', 'UCLA005', 'CSHL060', 'UCLA015', 'UCLA014', 'CSHL053', 'NYU-12', 'CSHL058', 'KS042']
fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0]
thinning = 50
fit_variance = 0.03
func1 = state_num_helper(0.2)
func2 = state_num_helper(0.1)
func3 = state_size_helper()
func4 = state_size_helper(1)
func5 = state_num_helper(0.01)
func6 = state_size_helper(2)
func7 = state_size_helper(3)
func8 = state_size_helper(4)
func9 = state_size_helper(5)
func10 = state_size_helper(6)
func11 = state_size_helper(7)
dur = 'yes'
m = 16
for subject in subjects:
results = []
summary_info = {"thinning": thinning, "contains": [], "seeds": [], "fit_nums": []}
m = len(loading_info[subject]["fit_nums"])
assert m == 16
print(subject)
n_runs = -1
counter = -1
n = (loading_info[subject]['chain_num'] + 1) * 4000 // 25
n = (loading_info[subject]['chain_num']) * 4000 // thinning
chains1 = np.zeros((m, n))
chains2 = np.zeros((m, n))
chains3 = np.zeros((m, n))
@@ -48,7 +60,8 @@ for subject in subjects:
print(seed)
info_dict = pickle.load(open("./session_data/{}_info_dict.p".format(subject), "rb"))
samples = []
mini_counter = 0
mini_counter = 1 # start at 1, discard first 4000 as burnin
while True:
try:
file = "./dynamic_GLMiHMM_crossvals/{}_fittype_{}_var_{}_{}_{}{}.p".format(subject, fit_type, fit_variance, seed, fit_num, '_{}'.format(mini_counter))
@@ -58,6 +71,7 @@ for subject in subjects:
mini_counter += 1
except Exception:
break
if n_runs == -1:
n_runs = mini_counter
else:
@@ -65,19 +79,17 @@ for subject in subjects:
print("Problem")
print(n_runs, mini_counter)
quit()
save_id = "{}_fittype_{}_var_{}_{}_{}.p".format(subject, fit_type, fit_variance, seed, fit_num).replace('.', '_')
# for file in os.listdir("./dynamic_GLMiHMM_crossvals/infos_new/"):
# if file.startswith("{}_".format(subject)) and file.endswith("_{}_{}_{}_{}.json".format(fit_type, fit_variance, seed, fit_num)):
# print("Taking fit infos from {}".format(file))
# fit_infos = json.load(open("./dynamic_GLMiHMM_crossvals/infos_new/" + file, 'r'))
# sample_lls = fit_infos['ll']
print("loaded seed {}".format(seed))
result = MCMC_result(samples[::25],
result = MCMC_result(samples[::thinning],
infos=info_dict, data=samples[0].datas,
sessions=fit_type, fit_variance=fit_variance,
dur=dur, save_id=save_id) # , sample_lls=sample_lls)
dur=dur, save_id=save_id)
results.append(result)
print("Making result {} took {:.4}".format(counter, time.time() - lala))
res = func1(result)
@@ -88,22 +100,55 @@ for subject in subjects:
chains3[counter] = res
res = func4(result)
chains4[counter] = res
# res = func5(result)
# chains5[j] = res
# res = func6(result)
# chains6[j] = res
# res = func7(result)
# chains7[j] = res
# res = func8(result)
# chains8[j] = res
# res = func9(result)
# chains9[j] = res
# res = func10(result)
# chains10[j] = res
# res = func11(result)
# chains11[j] = res
pickle.dump(chains1, open("multi_chain_saves/{}_state_num_0_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
pickle.dump(chains2, open("multi_chain_saves/{}_state_num_1_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
pickle.dump(chains3, open("multi_chain_saves/{}_largest_state_0_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
pickle.dump(chains4, open("multi_chain_saves/{}_largest_state_1_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
# R^hat tests
# test = MCMC_result_list([fake_result(100) for i in range(8)])
# test.r_hat_and_ess(return_ascending, False)
# test.r_hat_and_ess(return_ascending_shuffled, False)
# quit()
if subject.startswith('GLM_Sim_07') or subject.startswith('GLM_Sim_11'):
continue
print()
print("Checking R^hat, finding best subset of chains")
# mins = find_good_chains(chains[:, :-1].reshape(32, chains.shape[1] // 2))
sol, final_r_hat = find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_to=chains1.shape[0] // 2)
r_hats[subject] = final_r_hat
loading_info[subject]['ignore'] = sol
print(r_hats[subject])
if fit_type == 'bias':
json.dump(loading_info, open("canonical_infos_bias.json", 'w'))
json.dump(r_hats, open("canonical_info_r_hats_bias.json", 'w'))
elif fit_type == 'prebias':
json.dump(loading_info, open("canonical_infos.json", 'w'))
json.dump(r_hats, open("canonical_info_r_hats.json", 'w'))
if r_hats[subject] >= 1.05:
print("Skipping canonical result")
continue
else:
print("Making canonical result")
# subset data
summary_info['contains'] = [i for i in range(m) if i not in sol]
summary_info['seeds'] = [loading_info[subject]['seeds'][i] for i in summary_info['contains']]
summary_info['fit_nums'] = [loading_info[subject]['fit_nums'][i] for i in summary_info['contains']]
results = [results[i] for i in summary_info['contains']]
test = MCMC_result_list(results, summary_info)
pickle.dump(test, open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'wb'))
test.r_hat_and_ess(state_num_helper(0.2), False)
test.r_hat_and_ess(state_num_helper(0.1), False)
test.r_hat_and_ess(state_num_helper(0.05), False)
test.r_hat_and_ess(state_size_helper(), False)
test.r_hat_and_ess(state_size_helper(1), False)
test.r_hat_and_ess(gamma_func, True)
test.r_hat_and_ess(alpha_func, True)
import json
import numpy as np
from dyn_glm_chain_analysis import MCMC_result_list, state_development  # state_development is used below; assumed to come from this module
import pickle
import matplotlib.pyplot as plt
import os
def create_mode_indices(test, subject, fit_type):
dim = 3
try:
xy, z = pickle.load(open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'rb'))
except Exception:
return
print('Doing PCA')
ev, eig, projection_matrix, dimreduc = test.state_pca(subject, pca_type='dists', dim=dim)
xy = np.vstack([dimreduc[i] for i in range(dim)])
from scipy.stats import gaussian_kde
z = gaussian_kde(xy)(xy)
pickle.dump((xy, z), open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'wb'))
print("Mode indices of " + subject)
threshold_search(xy, z, test, 'first_', subject, fit_type)
print("Find another mode?")
if input() not in ['yes', 'y']:
return
threshold_search(xy, z, test, 'second_', subject, fit_type)
print("Find another mode?")
if input() not in ['yes', 'y']:
return
threshold_search(xy, z, test, 'third_', subject, fit_type)
return
def threshold_search(xy, z, test, mode_prefix, subject, fit_type):
happy = False
conds = [0, None, None, None, None, None, None]
x_min, x_max, y_min, y_max, z_min, z_max = None, None, None, None, None, None
while not happy:
print()
print("Pick level")
prob_level = input()
if prob_level == 'cond':
print("x > ?")
resp = input()
if resp not in ['n', 'no']:
x_min = float(resp)
else:
x_min = None
print("x < ?")
resp = input()
if resp not in ['n', 'no']:
x_max = float(resp)
else:
x_max = None
print("y > ?")
resp = input()
if resp not in ['n', 'no']:
y_min = float(resp)
else:
y_min = None
print("y < ?")
resp = input()
if resp not in ['n', 'no']:
y_max = float(resp)
else:
y_max = None
print("z > ?")
resp = input()
if resp not in ['n', 'no']:
z_min = float(resp)
else:
z_min = None
print("z < ?")
resp = input()
if resp not in ['n', 'no']:
z_max = float(resp)
else:
z_max = None
print("Prob level")
prob_level = float(input())
conds = [prob_level, x_min, x_max, y_min, y_max, z_min, z_max]
print("Condtions are {}".format(conds))
else:
try:
prob_level = float(prob_level)
except ValueError:
print('mistake')
prob_level = float(input())
conds[0] = prob_level
print("Level is {}".format(prob_level))
mode = conditions_fulfilled(z, xy, conds)
print("# of samples: {}".format(mode.sum()))
mode_indices = np.where(mode)[0]
if mode.sum() > 0:
print(xy[0][mode_indices].min(), xy[0][mode_indices].max(), xy[1][mode_indices].min(), xy[1][mode_indices].max(), xy[2][mode_indices].min(), xy[2][mode_indices].max())
print("Happy?")
happy = 'yes' == input()
print("Subset by factor?")
if input() == 'yes':
print("Factor?")
print(mode_indices.shape)
factor = int(input())
mode_indices = mode_indices[::factor]
print(mode_indices.shape)
if subject not in loading_info:
loading_info[subject] = {}
loading_info[subject]['mode prob level'] = prob_level
pickle.dump(mode_indices, open("multi_chain_saves/{}mode_indices_{}_{}.p".format(mode_prefix, subject, fit_type), 'wb'))
# consistencies = test.consistency_rsa(indices=mode_indices) # do this on the cluster from now on
# pickle.dump(consistencies, open("multi_chain_saves/{}mode_consistencies_{}_{}.p".format(mode_prefix, subject, fit_type), 'wb', protocol=4))
def conditions_fulfilled(z, xy, conds):
works = z > conds[0]
# explicit None checks, so that a bound of 0.0 still counts
if conds[1] is not None:
works = np.logical_and(works, xy[0] > conds[1])
if conds[2] is not None:
works = np.logical_and(works, xy[0] < conds[2])
if conds[3] is not None:
works = np.logical_and(works, xy[1] > conds[3])
if conds[4] is not None:
works = np.logical_and(works, xy[1] < conds[4])
if conds[5] is not None:
works = np.logical_and(works, xy[2] > conds[5])
if conds[6] is not None:
works = np.logical_and(works, xy[2] < conds[6])
return works
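# Illustrative, not called anywhere: how conditions_fulfilled() cuts a
# high-density mode out of the KDE-smoothed projection. The random data is a
# stand-in for the real dimreduc output of test.state_pca().
def _mode_filter_demo():
    from scipy.stats import gaussian_kde
    rng = np.random.default_rng(0)
    xy = rng.normal(size=(3, 500))             # fake 3d-projected samples
    z = gaussian_kde(xy)(xy)                   # density at each sample
    conds = [np.median(z), 0., None, None, None, None, None]  # z > median, x > 0
    mode_indices = np.where(conditions_fulfilled(z, xy, conds))[0]
    print("# of samples: {}".format(mode_indices.size))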
def state_set_and_plot(test, mode_prefix, subject, fit_type):
mode_indices = pickle.load(open("multi_chain_saves/{}mode_indices_{}_{}.p".format(mode_prefix, subject, fit_type), 'rb'))
consistencies = pickle.load(open("multi_chain_saves/{}mode_consistencies_{}_{}.p".format(mode_prefix, subject, fit_type), 'rb'))
session_bounds = list(np.cumsum([len(s) for s in test.results[0].models[-1].stateseqs]))
import scipy.cluster.hierarchy as hc
consistencies /= consistencies[0, 0]
linkage = hc.linkage(consistencies[0, 0] - consistencies[np.triu_indices(consistencies.shape[0], k=1)], method='complete')
# R = hc.dendrogram(linkage, truncate_mode='lastp', p=150, no_labels=True)
# plt.savefig("peter figures/{}tree_{}_{}".format(mode_prefix, subject, 'complete'))
# plt.close()
session_bounds = list(np.cumsum([len(s) for s in test.results[0].models[-1].stateseqs]))
plot_criterion = 0.95
a = hc.fcluster(linkage, plot_criterion, criterion='distance')
b, c = np.unique(a, return_counts=True)
state_sets = []
for x, y in zip(b, c):
state_sets.append(np.where(a == x)[0])
print("dumping state set")
pickle.dump(state_sets, open("multi_chain_saves/{}state_sets_{}_{}.p".format(mode_prefix, subject, fit_type), 'wb'))
state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='_{}{}'.format(mode_prefix, plot_criterion), show=True, separate_pmf=True, type_coloring=True)
fig, ax = plt.subplots(ncols=5, sharey=True, gridspec_kw={'width_ratios': [10, 1, 1, 1, 1]}, figsize=(13, 8))
from matplotlib.pyplot import cm
for j, criterion in enumerate([0.95, 0.8, 0.5, 0.2]):
clustering_colors = np.zeros((consistencies.shape[0], 100, 4))
a = hc.fcluster(linkage, criterion, criterion='distance')
b, c = np.unique(a, return_counts=True)
print(b.shape)
print(np.sort(c))
cmap = cm.rainbow(np.linspace(0, 1, 17))
rank_to_color_place = dict(zip(range(17), [0, 16, 8, 4, 12, 2, 6, 10, 14, 1, 3, 5, 7, 9, 11, 13, 15])) # handcrafted to maximise color distance, I think
i = -1
b = [x for _, x in sorted(zip(c, b))][::-1]
c = [x for x, _ in sorted(zip(c, b))][::-1]
plot_above = 50
while len([y for y in c if y > plot_above]) > 17:
plot_above += 1
for x, y in zip(b, c):
if y > plot_above:
i += 1
clustering_colors[a == x] = cmap[rank_to_color_place[i]]
# clustering_colors[a == x] = cm.rainbow(np.mean(np.where(a == x)[0]) / a.shape[0])
ax[j+1].imshow(clustering_colors, aspect='auto', origin='upper')
for sb in session_bounds:
ax[j+1].axhline(sb, color='k')
ax[j+1].set_xticks([])
ax[j+1].set_yticks([])
ax[j+1].set_title("{}%".format(int(criterion * 100)), size=20)
ax[0].imshow(consistencies, aspect='auto', origin='upper')
for sb in session_bounds:
ax[0].axhline(sb, color='k')
ax[0].set_xticks([])
ax[0].set_yticks(session_bounds[::-1])
ax[0].set_yticklabels(session_bounds[::-1], size=18)
ax[0].set_ylim(session_bounds[-1], 0)
ax[0].set_ylabel("Trials", size=28)
plt.yticks(rotation=45)
plt.tight_layout()
plt.savefig("peter figures/{}clustered_trials_{}_{}".format(mode_prefix, subject, 'criteria comp').replace('.', '_'))
plt.close()
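# Illustrative, not called anywhere: the fcluster(..., criterion='distance')
# calls above cut the dendrogram at a fixed cophenetic distance, so lowering
# the criterion from 0.95 to 0.2 yields more, smaller clusters. Toy data only.
def _fcluster_demo():
    import scipy.cluster.hierarchy as hc
    from scipy.spatial.distance import pdist
    rng = np.random.default_rng(0)
    points = np.vstack([rng.normal(0, .1, (20, 2)), rng.normal(1, .1, (20, 2))])
    linkage = hc.linkage(pdist(points), method='complete')
    for cut in [0.95, 0.2]:
        labels = hc.fcluster(linkage, cut, criterion='distance')
        print(cut, np.unique(labels).size)   # fewer clusters at the looser cut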
fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
if fit_type == 'bias':
loading_info = json.load(open("canonical_infos_bias.json", 'r'))
elif fit_type == 'prebias':
loading_info = json.load(open("canonical_infos.json", 'r'))
subjects = list(loading_info.keys())
# error: KS043, KS045, 'NYU-12', ibl_witten_15, NYU-21, CSHL052, KS003
# done: NYU-46, NYU-39, ibl_witten_19, NYU-48
subjects = ['DY_013', 'ZFM-01592', 'NYU-39', 'NYU-27', 'NYU-46', 'ZFM-01936', 'ZFM-02372', 'ZFM-01935', 'ibl_witten_26', 'ZM_2241', 'KS084', 'ZFM-01576']
for subject in subjects:
test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb'))
if os.path.isfile("multi_chain_saves/{}mode_indices_{}_{}.p".format('first_', subject, fit_type)):
print("It has been done")
continue
print('Computing sub result')
create_mode_indices(test, subject, fit_type)
# state_set_and_plot(test, 'first_', subject, fit_type)
# print("second mode?")
# if input() in ['y', 'yes']:
# state_set_and_plot(test, 'second_', subject, fit_type)
import numpy as np
import matplotlib.pyplot as plt
import pyhsmm
import pyhsmm.basic.distributions as distributions
import copy
import warnings
import pickle
import time
from scipy.special import digamma
import faulthandler
faulthandler.enable()
np.set_printoptions(suppress=True)
def crp_expec(n, theta):
"""
Return expected number of tables after n customers, given concentration theta.
From Wikipedia
"""
return theta * (digamma(theta + n) - digamma(theta))
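# Quick sanity check (illustrative addition, not in the original script): the
# closed form above equals the Monte-Carlo mean of table counts when customers
# are seated by the Chinese restaurant process.
def _crp_expec_check(n=500, theta=3., n_reps=200, seed=1):
    rng = np.random.default_rng(seed)
    counts = []
    for _ in range(n_reps):
        tables = []                                  # customers per table
        for i in range(n):
            if rng.random() < theta / (i + theta):   # open a new table
                tables.append(1)
            else:                                    # join one, prop. to its size
                tables[rng.choice(len(tables), p=np.array(tables) / i)] += 1
        counts.append(len(tables))
    print(np.mean(counts), crp_expec(n, theta))      # both around 15.9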
def fill_timeouts(xs, skip=1):
"""Just iterate over array, replacing invalid elements."""
if xs[0] == skip:
print('first trial is a timeout, assigning a random choice')
xs[0] = np.random.rand() > 0.5
curr = xs[0]
ret = np.copy(xs)
for i, x in enumerate(xs):
if x == skip:
ret[i] = curr
curr = ret[i]
return ret
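# Example (illustrative): timeouts (the `skip` value, 1 by default) inherit the
# previous valid choice, e.g. fill_timeouts(np.array([0, 1, 1, 2, 1, 0]))
# returns array([0, 0, 0, 2, 2, 0]).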
subject = "ibl_witten_17"
fit_type = ['pre_bias', 'bias', 'all'][0]
with_time = False
conditioned_on = ['nothing', 'reward', 'truth', 'answer'][0]
Nmax = 15
seeds = [7123115]
info_dict = pickle.load(open("./session_data/{}_info_dict.p".format(subject), "rb"))
n = 2000
relevant_states = np.zeros((len(seeds), n))
likes = np.zeros((len(seeds), n))
models = []
till_session = info_dict['bias_start'] if fit_type == 'pre_bias' else info_dict['n_sessions']
from_session = info_dict['bias_start'] if fit_type == 'bias' else 0
print(subject)
print(fit_type)
print(conditioned_on)
for i, seed in enumerate(seeds):
np.random.seed(seed)
print(i+1)
obs_hypparams = {'n_inputs': 11 + 11 * (conditioned_on != 'nothing'), 'n_outputs': 3}
dur_hypparams = dict(r_support=np.array([1, 2, 3, 5, 7, 10, 15, 21, 28, 36, 45, 55]), r_probs=np.ones(12)/12., alpha_0=1, beta_0=1)
obs_distns = [distributions.Input_Categorical(**obs_hypparams) for state in range(Nmax)]
dur_distns = [distributions.NegativeBinomialIntegerR2Duration(**dur_hypparams) for state in range(Nmax)]
posteriormodel = pyhsmm.models.WeakLimitHDPHSMM(
# https://math.stackexchange.com/questions/449234/vague-gamma-prior
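# Note (added): in the HDP, gamma is the top-level concentration and mainly
# controls how many distinct states get instantiated, while alpha controls how
# closely each row of the transition matrix tracks the global state weights.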
alpha_a_0=.5, alpha_b_0=20, # TODO: gamma vs alpha? gamma steers state number
gamma_a_0=1, gamma_b_0=1,
init_state_concentration=6.,
obs_distns=obs_distns,
dur_distns=dur_distns)
# !!!
for j in range(from_session, till_session):
try:
data = pickle.load(open("./session_data/{}_recovery_info_{}.p".format(subject, j), "rb"))
#print(data.shape)
except FileNotFoundError:
continue
if data.shape[0] == 0:
continue
if conditioned_on == 'answer':
prev_ans = fill_timeouts(data[:, 1])
prev_ans[1:] = prev_ans[:-1]
data[:, 0] += (prev_ans == 2) * 11
elif conditioned_on == 'reward':
side_info = pickle.load(open("./session_data/{}_side_info_{}.p".format(subject, j), "rb"))
prev_reward = side_info[:, 1]
prev_reward[1:] = prev_reward[:-1]
prev_reward[0] = np.random.rand() > 0.5
data[:, 0] += (prev_reward == 1) * 11
data = data[:, [0, 1]]
data = data.astype(int)
posteriormodel.add_data(data, trunc=100)
# import pyhsmm.util.profiling as prof
# from pybasicbayes.util.stats import sample_crp_tablecounts
# prof_func = prof._prof(posteriormodel.resample_model)
# prof._prof.add_function(sample_crp_tablecounts)
# later
# !!! call prof_func, not the original func
# later
# prof._prof.print_stats()
time_save = time.time()
with warnings.catch_warnings(): # ignore the scipy warning
warnings.simplefilter("ignore")
for j in range(n):
if j % 25 == 0:
print(j)
posteriormodel.resample_model()
likes[i, j] = posteriormodel.log_likelihood()
model_save = copy.deepcopy(posteriormodel)
if j != n - 1:
# To save on memory:
model_save.delete_data()
models.append(model_save)
# save something in case of crash
if j % 100 == 0:
pickle.dump(models, open("./iHMM_fits/recovery_{}_{}_withtime_{}_condition_{}.p".format(subject, fit_type, with_time, conditioned_on), 'wb'))
print(time.time() - time_save)
pickle.dump(models, open("./iHMM_fits/recovery_{}_{}_withtime_{}_condition_{}.p".format(subject, fit_type, with_time, conditioned_on), 'wb'))
plt.plot(likes.T)
plt.savefig("likelihoods")
plt.show()
import os
import re
memory = {}
name_saves = {}
seed_saves = {}
do_it = True
for filename in os.listdir("./dynamic_GLMiHMM_crossvals/"):
if not filename.endswith('.p'):
continue
regexp = re.compile(r'((\w|-)+)_fittype_(\w+)_var_0.03_(\d+)_(\d+)_(\d+)')
result = regexp.search(filename)
subject = result.group(1)
fit_type = result.group(3)
seed = result.group(4)
fit_num = result.group(5)
chain_num = result.group(6)
if fit_type == 'prebias':
if subject not in seed_saves:
seed_saves[subject] = [seed]
else:
if seed not in seed_saves[subject]:
seed_saves[subject].append(seed)
if (subject, seed) not in name_saves:
name_saves[(subject, seed)] = []
if fit_num not in name_saves[(subject, seed)]:
name_saves[(subject, seed)].append(fit_num)
if (subject, seed, fit_num) not in memory:  # first file of this chain, save some info
memory[(subject, seed, fit_num)] = {"chain_num": int(chain_num), "counter": 1}
else:
memory[(subject, seed, fit_num)]["chain_num"] = max(memory[(subject, seed, fit_num)]["chain_num"], int(chain_num))
memory[(subject, seed, fit_num)]["counter"] += 1
total_move = 0
nyu_11_move = 0
dicts_removed = 0
moved = []
completed = []
incompleted = []
for key in name_saves:
subject = key[0]
seed = key[1]
complete = False
save_fit_num = -1
for fit_num in name_saves[key]:
if memory[(subject, seed, fit_num)]['chain_num'] == 14 and memory[(subject, seed, fit_num)]['counter'] == 15:
save_fit_num = fit_num
complete = True
if len(seed_saves[subject]) == 16:
if subject not in completed:
completed.append(subject)
else:
if subject not in incompleted:
incompleted.append(subject)
if complete and len(name_saves[key]) > 1:
assert save_fit_num != -1
for fit_num in name_saves[key]:
if fit_num != save_fit_num:
for i in range(15):
if do_it:
if os.path.exists("./dynamic_GLMiHMM_crossvals/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i)):
os.rename("./dynamic_GLMiHMM_crossvals/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i),
"./del_test/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i))
if subject not in moved:
moved.append(subject)
total_move += 1
else:
if os.path.exists("./dynamic_GLMiHMM_crossvals/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i)):
print("I would move ")
print("./dynamic_GLMiHMM_crossvals/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i))
print(" to ")
print("./del_test/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i))
total_move += 1
nyu_11_move += subject == "NYU-11"
print(moved)
print(completed)
print(incompleted)
print("Would move {} in total, and {} of NYU-11".format(total_move, nyu_11_move))
"""
Studying how much more easily small changes in the regressors move the pmf around zero contrast than further from zero
"""
import numpy as np
import matplotlib.pyplot as plt
weights = list(np.zeros((17, 3)))
for i, weight in enumerate(weights):
weight[-1] = i * 0.2 - 1.6
contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
bias = np.ones(11)
predictors = np.vstack((contrasts_L, contrasts_R, bias)).T
for weight in weights:
plt.plot(1 / (1 + np.exp(- np.sum(weight * predictors, axis=1))))
plt.ylim(0, 1)
plt.show()
weights = list(np.zeros((17, 3)))
for i, weight in enumerate(weights):
weight[-1] = i * 0.2 - 1.6
weight[0] = 2
for weight in weights:
plt.plot(1 / (1 + np.exp(- np.sum(weight * predictors, axis=1))))
plt.ylim(0, 1)
plt.show()
\ No newline at end of file
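# Why the pmf is easiest to move around zero contrast (illustrative addition):
# for p = sigmoid(w . x), the derivative with respect to the bias weight is
# p * (1 - p), which peaks at p = 0.5 -- the flat middle of the curve, reached
# near zero contrast -- and vanishes where the pmf saturates.
def _sensitivity_demo():
    for weight in weights:
        p = 1 / (1 + np.exp(-np.sum(weight * predictors, axis=1)))
        plt.plot(p * (1 - p))    # sensitivity of each contrast's p to a bias nudge
    plt.ylim(0, 0.3)
    plt.show()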