Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • sbruijns/ihmm_behav_states
1 result
Show changes
Commits on Source (4)
......@@ -594,9 +594,9 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur
n = test.results[0].n_sessions
trial_counter = 0
for seq_num in range(n):
if seq_num + 1 != 12:
trial_counter += len(test.results[0].models[0].stateseqs[seq_num])
continue
# if seq_num + 1 != 12:
# trial_counter += len(test.results[0].models[0].stateseqs[seq_num])
# continue
c_n_a = test.results[0].data[seq_num]
plt.figure(figsize=(18, 9))
......@@ -661,7 +661,7 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur
plt.xlabel('Trial', size=28)
sns.despine()
plt.xlim(left=250, right=450)
# plt.xlim(left=250, right=450)
plt.legend(frameon=False, fontsize=22, bbox_to_anchor=(0.8, 0.5))
plt.tight_layout()
if save:
......@@ -716,6 +716,10 @@ def state_pmfs(test, trials, indices):
def lapse_sides(test, state_sets, indices):
"""Compute and plot a lapse differential across sessions.
Takes a single mouse and plots (1 - lapse_left) - lapse_right across sessions, with sessions boundaries shown."""
def func_init(): return {'lapse_side': np.zeros(test.results[0].n_datapoints) + 10, 'session_bounds': []}
def first_for(test, results):
......@@ -1204,7 +1208,6 @@ def find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_t
to_del = []
for i in range(delete_n):
# print()
r_hat_min = 50
sol = 0
for x in range(n_chains):
......@@ -1218,7 +1221,7 @@ def find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_t
sol = x
r_hat_min = min(r_hat, r_hat_min)
to_del.append(sol)
# if i == delete_n - 1:
print("Minimum is {} (removed {})".format(r_hat_min, i + 1))
print("Removed: {}".format(to_del))
mins[i + 1] = r_hat_min
......@@ -1240,6 +1243,7 @@ def rank_inv_normal_transform(chains):
folded_rank_normalised = norm.ppf((folded_ranked - 3/8) / (folded_chains.size + 1/4))
return rank_normalised, folded_rank_normalised
if __name__ == "__main__":
fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
if fit_type == 'bias':
......@@ -1380,42 +1384,46 @@ len_to_bools = {
}
def params_to_pmf(params):
return params[2] + (1 - params[2] - params[3]) / (1 + np.exp(- params[0] * (all_conts - params[1])))
def four_param_loss(params, pmf, offset=False):
# if params[2] + params[3] > 1:
# return 10
contrast_bools = len_to_bools[len(pmf)]
fit = params_to_pmf(params)[contrast_bools]
return np.sum((pmf - fit) ** 2) + offset * (np.abs(params[1]) * 0.0001 + np.clip(np.abs(params[0] - 1), 0, 0.35) * 0.0002)
def four_param_pmf(pmf, s1=1, s2=0, offset=False):
res = minimize(lambda x: four_param_loss(x, pmf, offset=offset), method='Nelder-Mead', x0=np.array([s1, s2, pmf[0], pmf[-1]]))
# can give [s1, s2, pmf.min(), 1 - pmf.max()] as starting points, but gives weird error
return res
def state_type_durs(states, pmfs):
# Takes states and pmfs, first creates an array of when which type is how active, then computes the number of sessions each type lasts.
# A type lasts until a more advanced type takes up more than 50% of a session (and cannot return)
# Returns the durations for all the different state types, and an array which holds the state percentages
state_types = np.zeros((3, states.shape[1]))
for s, pmf in zip(states, pmfs):
pmf_counter = -1
for i in range(states.shape[1]):
if s[i]:
pmf_counter += 1
state_types[pmf_type(pmf[1][pmf_counter][pmf[0]]), i] += s[i] # indexing horror
durs = (np.where(state_types[1] > 0.5)[0][0],
np.where(state_types[2] > 0.5)[0][0] - np.where(state_types[1] > 0.5)[0][0],
states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
if np.where(state_types[2] > 0.5)[0][0] < np.where(state_types[1] > 0.5)[0][0]:
durs = (np.where(state_types[2] > 0.5)[0][0],
0,
states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
return durs, state_types
def state_cluster_interpolation(states, pmfs):
#
# Used to contain a first_type_count variable, which seemed to just count the number of states?
pmf_examples = [[], [], []]
state_trans = np.zeros((3, 3))
state_types = np.zeros((3, states.shape[1]))
first_type_count = 0
for state, pmf in zip(states, pmfs):
sessions = np.where(state)[0]
for i, sess_pmf in enumerate(pmf[1]):
if i == 0:
state_type = pmf_type(sess_pmf[pmf[0]])
first_type_count += state_type == 0
if i > 0 and state_type != pmf_type(sess_pmf[pmf[0]]):
state_trans[state_type, pmf_type(sess_pmf[pmf[0]])] += 1
state_type = pmf_type(sess_pmf[pmf[0]])
pmf_examples[pmf_type(sess_pmf[pmf[0]])].append(sess_pmf[pmf[0]])
state_types[pmf_type(sess_pmf[pmf[0]]), sessions[i]] += state[sessions[i]]
return state_types, state_trans, first_type_count, pmf_examples
return state_types, state_trans, pmf_examples
def pmf_type(pmf):
......@@ -1436,11 +1444,10 @@ if __name__ == "__main__":
loading_info = json.load(open("canonical_infos.json", 'r'))
r_hats = json.load(open("canonical_info_r_hats.json", 'r'))
no_good_pcas = ['NYU-06', 'SWC_023'] # no good rhat: 'ibl_witten_13'
subjects = ['KS014'] # list(loading_info.keys())
subjects = list(loading_info.keys())
print(subjects)
fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0]
dur = 'yes'
chain_num = ''
fist_good_pmf = {'CSHL051': (6, 'left'),
'CSHL059': (3, 'left'),
......@@ -1492,7 +1499,6 @@ if __name__ == "__main__":
n_points = 150
state_trajs = np.zeros((3, n_points))
state_trans = np.zeros((3, 3))
first_type_count = 0
not_yet = True
......@@ -1509,71 +1515,50 @@ if __name__ == "__main__":
test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb'))
print('loaded canoncial result')
# state_sets = pickle.load(open("multi_chain_saves/state_sets_{}.p".format(subject), 'rb'))
# lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices)
# continue
#
mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb'))
state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb'))
states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=True, separate_pmf=True)
# lapse differential
# lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices)
# training overview
# states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=True, separate_pmf=True)
# state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False)
consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
consistencies /= consistencies[0, 0]
contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save=True, show=True, consistencies=consistencies)
quit()
# session overview
# consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
# consistencies /= consistencies[0, 0]
# contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies)
# state_types = np.zeros((3, states.shape[1]))
# for s, pmf in zip(states, pmfs):
# pmf_counter = -1
# for i in range(states.shape[1]):
# if s[i]:
# pmf_counter += 1
# state_types[pmf_type(pmf[1][pmf_counter][pmf[0]]), i] += s[i] # indexing horror
#
# durs = (np.where(state_types[1] > 0.5)[0][0],
# np.where(state_types[2] > 0.5)[0][0] - np.where(state_types[1] > 0.5)[0][0],
# states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
# if np.where(state_types[2] > 0.5)[0][0] < np.where(state_types[1] > 0.5)[0][0]:
# durs = (np.where(state_types[2] > 0.5)[0][0],
# 0,
# states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
# if durs[0] < 0 or durs[1] < 0 or durs[2] < 0:
# quit()
# duration of different state types (and also percentage of type activities)
# durs, _ = state_type_durs(states, pmfs)
# abs_state_durs.append(durs)
# if durs[1] < 0:
# state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=False, separate_pmf=True)
# rel_state_durs.append((durs[0] / states.shape[1], durs[1] / states.shape[1], durs[2] / states.shape[1]))
# continue
# contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True)
# continue
# ret, trans, count, pmf_types = state_cluster_interpolation(states, pmfs)
# state_trans += trans
# first_type_count += count
# points = np.linspace(1, test.results[0].n_sessions, n_points)
#
# # plt.plot(np.arange(1, 1 + test.results[0].n_sessions), ret.T)
# for i, r in enumerate(ret):
# state_trajs[i] += np.interp(points, np.arange(1, 1 + test.results[0].n_sessions), r)
# # plt.plot(points, np.interp(points, np.arange(1, 1 + test.results[0].n_sessions), r))
# # plt.show()
# continue
ret, trans, count, pmf_types = state_cluster_interpolation(states, pmfs)
state_trans += trans
points = np.linspace(1, test.results[0].n_sessions, n_points)
# plt.plot(np.arange(1, 1 + test.results[0].n_sessions), ret.T)
for i, r in enumerate(ret):
state_trajs[i] += np.interp(points, np.arange(1, 1 + test.results[0].n_sessions), r)
plt.plot(points, np.interp(points, np.arange(1, 1 + test.results[0].n_sessions), r))
plt.show()
continue
# plt.plot(ret.T, label=[0, 1, 2])
# plt.legend()
# plt.show()
# continue
plt.plot(ret.T, label=[0, 1, 2])
plt.legend()
plt.show()
continue
# for i, pmfs in enumerate(pmf_types):
# plt.subplot(1, 3, i + 1)
# plt.plot([0, 11], [0.5, 0.5], 'grey', alpha=1/3)
# for pmf in pmfs:
# plt.plot(len_to_bools[len(pmf)], pmf)
# plt.ylim(0, 1)
# plt.show()
# continue
for i, pmfs in enumerate(pmf_types):
plt.subplot(1, 3, i + 1)
plt.plot([0, 11], [0.5, 0.5], 'grey', alpha=1/3)
for pmf in pmfs:
plt.plot(len_to_bools[len(pmf)], pmf)
plt.ylim(0, 1)
plt.show()
continue
# quit()
# for pmf in pmfs:
......@@ -1585,23 +1570,6 @@ if __name__ == "__main__":
# sec_mode = pickle.load(open("multi_chain_saves/second_mode_indices_{}.p".format(subject), 'rb'))
# consistencies = pickle.load(open("multi_chain_saves/second_mode_consistencies_{}.p".format(subject), 'rb'))
#
# import scipy.cluster.hierarchy as hc
# consistencies /= consistencies[0, 0]
# linkage = hc.linkage(consistencies[0, 0] - consistencies[np.triu_indices(consistencies.shape[0], k=1)], method='complete')
#
# a = hc.fcluster(linkage, 0.95, criterion='distance')
# b, c = np.unique(a, return_counts=1)
# print(b.shape)
# print(np.sort(c))
#
# state_sets = []
# for x, y in zip(b, c):
# state_sets.append(np.where(a == x)[0])
#
# states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], sec_mode, save_append='_{}{}'.format('second_mode_', 0.95), show=True, separate_pmf=True)
# contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save_append='_{}{}'.format('second_mode_', 0.95), save=True, show=True)
# quit()
#
# single_sample = [np.argmax(z)]
# for position, index in enumerate(np.where(np.logical_and(z > 2.7e-7, xy[0] < -500))[0]):
# print(position, index, index // test.n, index % test.n)
......@@ -1679,7 +1647,6 @@ if __name__ == "__main__":
mode_indices = mode_indices[::factor]
print(mode_indices.shape)
loading_info[subject]['mode prob level'] = prob_level
json.dump([subject, prob_level], open("canonical_infos_prob_lev{}_{}.json".format(subject, fit_type), 'w'))
pickle.dump(mode_indices, open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'wb'))
consistencies = test.consistency_rsa(indices=mode_indices)
......@@ -1689,9 +1656,6 @@ if __name__ == "__main__":
mode_indices = pickle.load(open("multi_chain_saves/{}mode_indices_{}_{}.p".format(string_prefix, subject, fit_type), 'rb'))
consistencies = pickle.load(open("multi_chain_saves/{}consistencies_{}_{}.p".format(string_prefix, subject, fit_type), 'rb'))
# xy, z = pickle.load(open("multi_chain_saves/xyz_{}.p".format(subject), 'rb'))
# quit()
# consistencies = pickle.load(open("multi_chain_saves/general_consistencies_{}.p".format(subject), 'rb'))
import scipy.cluster.hierarchy as hc
consistencies /= consistencies[0, 0]
......@@ -1704,20 +1668,6 @@ if __name__ == "__main__":
session_bounds = list(np.cumsum([len(s) for s in test.results[0].models[-1].stateseqs]))
# I think this is about finding out where states start and how long they last
# a = hc.fcluster(linkage, 0.9, criterion='distance')
# b, c = np.unique(a, return_counts=1)
# state_sets = []
# for x, y in zip(b, c):
# state_sets.append(np.where(a == x)[0])
# states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=True, separate_pmf=True)
#
# state_id, session_appears = np.where(states)
# for s in np.unique(state_id):
# state_appear_mode.append(session_appears[state_id == s][0] / (test.results[0].n_sessions - 1))
# continue
fig, ax = plt.subplots(ncols=5, sharey=True, gridspec_kw={'width_ratios': [10, 1, 1, 1, 1]}, figsize=(13, 8))
from matplotlib import colors
from matplotlib.pyplot import cm
......@@ -1738,12 +1688,10 @@ if __name__ == "__main__":
states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='_{}{}'.format(string_prefix, criterion), show=True, separate_pmf=True)
contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save_append='_{}{}'.format(string_prefix, criterion), save=True, show=True)
# R = hc.dendrogram(linkage, no_plot=True, get_leaves=True, color_threshold=color_threshold)
# leaves_color_list = np.array(R["leaves_color_list"])
# leaves = np.array(R["leaves"])
# for c in np.unique(R["leaves_color_list"]):
# clustering_colors[:, leaves[leaves_color_list == c]] = colors.to_rgb(c)
# I think this is about finding out where states start and how long they last
# state_id, session_appears = np.where(states)
# for s in np.unique(state_id):
# state_appear_mode.append(session_appears[state_id == s][0] / (test.results[0].n_sessions - 1))
cmap = cm.rainbow(np.linspace(0, 1, 17))#len([x for x, y in zip(b, c) if y > 50])))
rank_to_color_place = dict(zip(range(17), [0, 16, 8, 4, 12, 2, 6, 10, 14, 1, 3, 5, 7, 9, 11, 13, 15])) # handcrafted to maximise color distance, I think
......@@ -1777,23 +1725,8 @@ if __name__ == "__main__":
plt.savefig("peter figures/{}clustered_trials_{}_{}".format(string_prefix, subject, 'criteria comp').replace('.', '_'))
plt.show()
quit()
# plt.imshow(consistencies)
# plt.savefig("peter figures/mode_consistencies_{}".format(subject))
# plt.close()
continue
quit()
consistencies = test.consistency_rsa(indices=list(range(0, test.m * test.n, 8)))
pickle.dump(consistencies, open("multi_chain_saves/general_consistencies_{}.p".format(subject), 'wb'))
plt.imshow(consistencies)
plt.savefig("dynamic_GLM_figures/general_consistencies_{}".format(subject))
plt.close()
# state_sets = fuse_states(test, consistencies > consistencies[0, 0] * 0.95)
except FileNotFoundError as e:
print(e)
print('here')
r_hat = 1.5
for r in r_hats:
if r[0] == subject:
......
......@@ -99,6 +99,24 @@ def find_good_chains_unsplit_fast(chains1, chains2, chains3, chains4, reduce_to=
return sol, r_hat_min
def params_to_pmf(params):
return params[2] + (1 - params[2] - params[3]) / (1 + np.exp(- params[0] * (all_conts - params[1])))
def four_param_loss(params, pmf, offset=False):
# if params[2] + params[3] > 1:
# return 10
contrast_bools = len_to_bools[len(pmf)]
fit = params_to_pmf(params)[contrast_bools]
return np.sum((pmf - fit) ** 2) + offset * (np.abs(params[1]) * 0.0001 + np.clip(np.abs(params[0] - 1), 0, 0.35) * 0.0002)
def four_param_pmf(pmf, s1=1, s2=0, offset=False):
res = minimize(lambda x: four_param_loss(x, pmf, offset=offset), method='Nelder-Mead', x0=np.array([s1, s2, pmf[0], pmf[-1]]))
# can give [s1, s2, pmf.min(), 1 - pmf.max()] as starting points, but gives weird error
return res
if __name__ == "__main__":
subjects = list(loading_info.keys())
......
import os
i = 0
for filename in os.listdir("./"):
if not filename.endswith('.p'):
continue
if 'bias' in filename:
continue
if not filename.endswith('bias.p'):
i += 1
print("Rename {} into {}".format(filename, filename[:-2] + '_prebias.p'))
os.rename(filename, filename[:-2] + '_prebias.p')
print(i)