Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision

Target

Select target project
  • sbruijns/ihmm_behav_states
1 result
Select Git revision
Show changes
Commits on Source (2)
......@@ -441,13 +441,9 @@ class MCMC_result:
self.n_contrasts = 11
self.cont_ticks = all_cont_ticks
self.state_to_color = {}
self.state_to_ls = {}
self.session_contrasts = [np.unique(cont_mapping(d[:, 0] - d[:, 1])) for d in self.data]
self.count_assigns()
# self.calc_state_by_sess()
def count_assigns(self):
self.assign_counts = np.zeros((self.n_samples, self.n_all_states))
......@@ -457,14 +453,6 @@ class MCMC_result:
for s in range(self.n_all_states):
self.assign_counts[i, s] = np.sum(flat_list == s)
def calc_state_by_sess(self):
    """Populate self.states_by_session with, per session, the fraction of
    trials assigned to each proto state, averaged over all sampled models.

    Result shape: (n_pstates, n_sessions); entries are accumulated across
    self.models and normalised by self.n_samples.
    """
    counts = np.zeros((self.n_pstates, self.n_sessions))
    for model in self.models:
        for sess_idx, state_seq in enumerate(model.stateseqs):
            seq_len = len(state_seq)
            for proto in self.proto_states:
                # Fraction of this session's trials spent in `proto`,
                # folded into the row chosen by the proto->row mapping.
                counts[self.state_map[proto], sess_idx] += np.sum(state_seq == proto) / seq_len
    counts /= self.n_samples
    self.states_by_session = counts
def state_appearance_posterior(self):
# posterior over when new states appear
state_starts = np.zeros(self.n_datapoints)
......@@ -626,9 +614,11 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur
if consistencies is None:
active_trials[relevant_trials - trial_counter] = 1
else:
active_trials[relevant_trials - trial_counter] = np.mean(consistencies[tuple(np.meshgrid(relevant_trials, trials))], axis=0)
print("fix this by taking the whole array, multiply by n, subtract n, divide by n-1")
input()
active_trials[relevant_trials - trial_counter] = np.sum(consistencies[tuple(np.meshgrid(relevant_trials, trials))], axis=0)
active_trials[relevant_trials - trial_counter] -= 1
active_trials[relevant_trials - trial_counter] = active_trials[relevant_trials - trial_counter] / (trials.shape[0] - 1)
# print("fix this by taking the whole array, multiply by n, subtract n, divide by n-1")
# input()
label = "State {}".format(state) if np.sum(relevant_trials) > 0.02 * len(test.results[0].models[0].stateseqs[seq_num]) else None
......@@ -756,7 +746,6 @@ def lapse_sides(test, state_sets, indices):
def state_development_single_sample(test, indices, save=True, save_append='', show=True, dpi='figure', separate_pmf=False):
session_contrasts = [np.unique(cont_mapping(d[:, 0] - d[:, 1])) for d in test.results[0].data]
for i, m in enumerate([item for sublist in test.results for item in sublist.models]):
if i not in indices:
......@@ -818,7 +807,7 @@ def state_development_single_sample(test, indices, save=True, save_append='', sh
session_max = i
defined_points = np.zeros(test.results[0].n_contrasts, dtype=bool)
defined_points[session_contrasts[session_max]] = True
defined_points[test.results[0].session_contrasts[session_max]] = True
n_points = 150
points = np.linspace(1, test.results[0].n_sessions, n_points)
......@@ -922,7 +911,6 @@ def state_development_single_sample(test, indices, save=True, save_append='', sh
def state_development(test, state_sets, indices, save=True, save_append='', show=True, dpi='figure', separate_pmf=False):
state_sets = [np.array(s) for s in state_sets]
session_contrasts = [np.unique(cont_mapping(d[:, 0] - d[:, 1])) for d in test.results[0].data]
if test.results[0].name.startswith('GLM_Sim_'):
print("./glm sim mice/truth_{}.p".format(test.results[0].name))
......@@ -1028,7 +1016,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
trial_counter += len(state_seq)
defined_points = np.zeros(test.results[0].n_contrasts, dtype=bool)
defined_points[session_contrasts[session_max]] = True
defined_points[test.results[0].session_contrasts[session_max]] = True
if not separate_pmf:
temp = np.sum(pmfs[:, defined_points]) / (np.sum(defined_points))
state_color = colors[int(temp * 101 - 1)]
......@@ -1204,41 +1192,6 @@ def find_good_chains_unsplit(chains1, chains2, chains3, chains4, reduce_to=8, si
return sol, r_hat_min
def find_good_chains_unsplit_fast(chains1, chains2, chains3, chains4, reduce_to=8):
    """Exhaustively search all ways of keeping `reduce_to` chains and return
    the removal set that minimises the amortized r-hat statistic.

    Parameters
    ----------
    chains1, chains2, chains3, chains4 : arrays of shape (n_chains, n_samples),
        one per tracked quantity (stacked along a new leading axis internally).
    reduce_to : int, number of chains to keep.

    Returns
    -------
    sol : list of chain indices that were removed.
    r_hat_min : float, the minimal amortized r-hat found over all subsets.
    """
    n_chains = chains1.shape[0]
    delete_n = n_chains - reduce_to
    chains = np.stack([chains1, chains2, chains3, chains4])
    # Baseline statistic with every chain kept — computed once and reused
    # (original code evaluated eval_simple_r_hat twice).
    r_hat = eval_simple_r_hat(chains)
    print("Without removals: {}".format(r_hat))
    l, m, n = chains.shape
    # Precompute per-chain means and variances so each candidate subset only
    # needs cheap slicing inside the combinatorial loop.
    psi_dot_j = np.mean(chains, axis=2)
    s_j_squared = np.sum((chains - psi_dot_j[:, :, None]) ** 2, axis=2) / (n - 1)
    r_hat_min = 10
    # Kept-chain indices of the best subset. Starts as an empty tuple so the
    # inversion below is still valid if no subset beats the initial r_hat_min
    # (the original `sol = 0` sentinel made `i not in sol` raise TypeError).
    sol = ()
    for keep in combinations(range(n_chains), n_chains - delete_n):
        r_hat = eval_amortized_r_hat(chains[:, keep], psi_dot_j[:, keep],
                                     s_j_squared[:, keep], l, m - delete_n, n)
        if r_hat < r_hat_min:
            sol = keep
        r_hat_min = min(r_hat, r_hat_min)
    print("Minimum is {} (removed {})".format(r_hat_min, delete_n))
    # Invert the kept set: report the chains that were removed.
    sol = [i for i in range(n_chains) if i not in sol]
    print("Removed: {}".format(sol))
    # Cross-check with the full (non-amortized) r-hat on the surviving chains.
    r_hat_local = eval_r_hat(np.delete(chains1, sol, axis=0), np.delete(chains2, sol, axis=0), np.delete(chains3, sol, axis=0), np.delete(chains4, sol, axis=0))
    print("Minimum over everything is {} (removed {})".format(r_hat_local, delete_n))
    return sol, r_hat_min
def find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_to=8, simple=False):
delete_n = - reduce_to + chains1.shape[0]
mins = np.zeros(delete_n + 1)
......@@ -1329,12 +1282,6 @@ if __name__ == "__main__":
chains3 = chains3[:, 160:]
chains4 = chains4[:, 160:]
# chains1 = chains1[:16]
# chains2 = chains2[:16]
# chains3 = chains3[:16]
# chains4 = chains4[:16]
# mins = find_good_chains(chains[:, :-1].reshape(32, chains.shape[1] // 2))
sol, final_r_hat = find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_to=chains1.shape[0] // 2)
r_hats.append((subject, final_r_hat))
......@@ -1361,22 +1308,6 @@ def dist_helper(dist_matrix, state_hists, inds):
return dist_matrix
def _get_leaves_color_list(R):
# copied from latest scipy version
leaves_color_list = [None] * len(R['leaves'])
for link_x, link_y, link_color in zip(R['icoord'],
R['dcoord'],
R['color_list']):
for (xi, yi) in zip(link_x, link_y):
if yi == 0.0: # if yi is 0.0, the point is a leaf
# xi of leaves are 5, 15, 25, 35, ... (see `iv_ticks`)
# index of leaves are 0, 1, 2, 3, ... as below
leaf_index = (int(xi) - 5) // 10
# each leaf has a same color of its link.
leaves_color_list[leaf_index] = link_color
return leaves_color_list
def state_size_helper(n=0, mode_specific=False):
if not mode_specific:
def nth_largest_state_func(x):
......@@ -1505,7 +1436,7 @@ if __name__ == "__main__":
loading_info = json.load(open("canonical_infos.json", 'r'))
r_hats = json.load(open("canonical_info_r_hats.json", 'r'))
no_good_pcas = ['NYU-06', 'SWC_023'] # no good rhat: 'ibl_witten_13'
subjects = list(loading_info.keys())
subjects = ['KS014'] # list(loading_info.keys())
print(subjects)
fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0]
dur = 'yes'
......@@ -1540,7 +1471,7 @@ if __name__ == "__main__":
'SWC_023': (9, 'skip'), # gradual
'ZM_1897': (5, 'right'),
'ZM_3003': (1, 'skip')}
fig, ax = plt.subplots(1, 3, sharey=True, figsize=(16, 9))
# fig, ax = plt.subplots(1, 3, sharey=True, figsize=(16, 9))
thinning = 25
summary_info = {"thinning": thinning, "contains": [], "seeds": [], "fit_nums": []}
......@@ -1578,22 +1509,19 @@ if __name__ == "__main__":
test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb'))
print('loaded canoncial result')
mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}.p".format(subject), 'rb'))
state_sets = pickle.load(open("multi_chain_saves/state_sets_{}.p".format(subject), 'rb'))
# state_sets = pickle.load(open("multi_chain_saves/state_sets_{}.p".format(subject), 'rb'))
# lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices)
# continue
#
# mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb'))
# state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb'))
# states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=True, separate_pmf=True)
mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb'))
state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb'))
states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=True, separate_pmf=True)
# state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False)
# quit()
# consistencies = pickle.load(open("multi_chain_saves/consistencies_{}.p".format(subject), 'rb'))
# consistencies /= consistencies[0, 0]
# contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save=True, show=True, consistencies=consistencies)
# quit()
consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
consistencies /= consistencies[0, 0]
contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save=True, show=True, consistencies=consistencies)
quit()
# state_types = np.zeros((3, states.shape[1]))
# for s, pmf in zip(states, pmfs):
......@@ -1811,7 +1739,6 @@ if __name__ == "__main__":
contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save_append='_{}{}'.format(string_prefix, criterion), save=True, show=True)
# R = hc.dendrogram(linkage, no_plot=True, get_leaves=True, color_threshold=color_threshold)
# R["leaves_color_list"] = _get_leaves_color_list(R)
# leaves_color_list = np.array(R["leaves_color_list"])
# leaves = np.array(R["leaves"])
......
......@@ -64,6 +64,41 @@ class MCMC_result:
return consistency_mat
def find_good_chains_unsplit_fast(chains1, chains2, chains3, chains4, reduce_to=8):
    """Exhaustively search all ways of keeping `reduce_to` chains and return
    the removal set minimising the amortized r-hat statistic.

    NOTE(review): this is a duplicate of the find_good_chains_unsplit_fast
    defined earlier in this diff; keep only one copy.
    """
    # Number of chains to delete so that `reduce_to` remain.
    delete_n = - reduce_to + chains1.shape[0]
    # NOTE(review): only mins[0] is ever written and the array is never read —
    # looks like a leftover; confirm before relying on it.
    mins = np.zeros(delete_n + 1)
    n_chains = chains1.shape[0]
    # Stack the four tracked quantities along a new leading axis: (4, n_chains, n_samples).
    chains = np.stack([chains1, chains2, chains3, chains4])
    # NOTE(review): eval_simple_r_hat is evaluated twice here (once for the print,
    # once for r_hat) — could be computed once.
    print("Without removals: {}".format(eval_simple_r_hat(chains)))
    r_hat = eval_simple_r_hat(chains)
    mins[0] = r_hat
    l, m, n = chains.shape
    # Precompute per-chain means and variances so each candidate subset only
    # needs slicing inside the combinatorial loop.
    psi_dot_j = np.mean(chains, axis=2)
    s_j_squared = np.sum((chains - psi_dot_j[:, :, None]) ** 2, axis=2) / (n - 1)
    r_hat_min = 10
    # NOTE(review): if no subset achieves r_hat < 10, `sol` stays 0 and the
    # list comprehension below (`i not in sol`) raises TypeError; an empty
    # tuple would be a safer sentinel.
    sol = 0
    # Enumerate every subset of chains of size (n_chains - delete_n) to keep.
    for x in combinations(range(n_chains), n_chains - delete_n):
        temp1 = chains[:, x]
        temp2 = psi_dot_j[:, x]
        temp3 = s_j_squared[:, x]
        r_hat = eval_amortized_r_hat(temp1, temp2, temp3, l, m - delete_n, n)
        if r_hat < r_hat_min:
            sol = x
        r_hat_min = min(r_hat, r_hat_min)
    print("Minimum is {} (removed {})".format(r_hat_min, delete_n))
    # Invert the kept set: report the chains that were removed.
    sol = [i for i in range(n_chains) if i not in sol]
    print("Removed: {}".format(sol))
    # Cross-check with the full (non-amortized) r-hat on the surviving chains.
    r_hat_local = eval_r_hat(np.delete(chains1, sol, axis=0), np.delete(chains2, sol, axis=0), np.delete(chains3, sol, axis=0), np.delete(chains4, sol, axis=0))
    print("Minimum over everything is {} (removed {})".format(r_hat_local, delete_n))
    return sol, r_hat_min
if __name__ == "__main__":
subjects = list(loading_info.keys())
......