diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py index bcae85556847df87a88bf37d6af5159780ae8123..3389c9473732e38b5e8adbdf3fa277eeb275f7f0 100644 --- a/dyn_glm_chain_analysis.py +++ b/dyn_glm_chain_analysis.py @@ -441,9 +441,6 @@ class MCMC_result: self.n_contrasts = 11 self.cont_ticks = all_cont_ticks - self.state_to_color = {} - self.state_to_ls = {} - self.session_contrasts = [np.unique(cont_mapping(d[:, 0] - d[:, 1])) for d in self.data] self.count_assigns() @@ -617,9 +614,11 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur if consistencies is None: active_trials[relevant_trials - trial_counter] = 1 else: - active_trials[relevant_trials - trial_counter] = np.mean(consistencies[tuple(np.meshgrid(relevant_trials, trials))], axis=0) - print("fix this by taking the whole array, multiply by n, subtract n, divide by n-1") - input() + active_trials[relevant_trials - trial_counter] = np.sum(consistencies[tuple(np.meshgrid(relevant_trials, trials))], axis=0) + active_trials[relevant_trials - trial_counter] -= 1 + active_trials[relevant_trials - trial_counter] = active_trials[relevant_trials - trial_counter] / (trials.shape[0] - 1) + # print("fix this by taking the whole array, multiply by n, subtract n, divide by n-1") + # input() label = "State {}".format(state) if np.sum(relevant_trials) > 0.02 * len(test.results[0].models[0].stateseqs[seq_num]) else None @@ -747,7 +746,6 @@ def lapse_sides(test, state_sets, indices): def state_development_single_sample(test, indices, save=True, save_append='', show=True, dpi='figure', separate_pmf=False): - session_contrasts = [np.unique(cont_mapping(d[:, 0] - d[:, 1])) for d in test.results[0].data] for i, m in enumerate([item for sublist in test.results for item in sublist.models]): if i not in indices: @@ -809,7 +807,7 @@ def state_development_single_sample(test, indices, save=True, save_append='', sh session_max = i defined_points = np.zeros(test.results[0].n_contrasts, dtype=bool) - defined_points[session_contrasts[session_max]] = True + defined_points[test.results[0].session_contrasts[session_max]] = True n_points = 150 points = np.linspace(1, test.results[0].n_sessions, n_points) @@ -913,7 +911,6 @@ def state_development_single_sample(test, indices, save=True, save_append='', sh def state_development(test, state_sets, indices, save=True, save_append='', show=True, dpi='figure', separate_pmf=False): state_sets = [np.array(s) for s in state_sets] - session_contrasts = [np.unique(cont_mapping(d[:, 0] - d[:, 1])) for d in test.results[0].data] if test.results[0].name.startswith('GLM_Sim_'): print("./glm sim mice/truth_{}.p".format(test.results[0].name)) @@ -1019,7 +1016,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show trial_counter += len(state_seq) defined_points = np.zeros(test.results[0].n_contrasts, dtype=bool) - defined_points[session_contrasts[session_max]] = True + defined_points[test.results[0].session_contrasts[session_max]] = True if not separate_pmf: temp = np.sum(pmfs[:, defined_points]) / (np.sum(defined_points)) state_color = colors[int(temp * 101 - 1)] @@ -1195,41 +1192,6 @@ def find_good_chains_unsplit(chains1, chains2, chains3, chains4, reduce_to=8, si return sol, r_hat_min -def find_good_chains_unsplit_fast(chains1, chains2, chains3, chains4, reduce_to=8): - delete_n = - reduce_to + chains1.shape[0] - mins = np.zeros(delete_n + 1) - n_chains = chains1.shape[0] - chains = np.stack([chains1, chains2, chains3, chains4]) - - print("Without removals: {}".format(eval_simple_r_hat(chains))) - r_hat = eval_simple_r_hat(chains) - - mins[0] = r_hat - - l, m, n = chains.shape - psi_dot_j = np.mean(chains, axis=2) - s_j_squared = np.sum((chains - psi_dot_j[:, :, None]) ** 2, axis=2) / (n - 1) - - r_hat_min = 10 - sol = 0 - for x in combinations(range(n_chains), n_chains - delete_n): - temp1 = chains[:, x] - temp2 = psi_dot_j[:, x] - temp3 = s_j_squared[:, x] - r_hat = eval_amortized_r_hat(temp1, temp2, temp3, l, m - delete_n, n) - if r_hat < r_hat_min: - sol = x - r_hat_min = min(r_hat, r_hat_min) - print("Minimum is {} (removed {})".format(r_hat_min, delete_n)) - sol = [i for i in range(n_chains) if i not in sol] - print("Removed: {}".format(sol)) - - r_hat_local = eval_r_hat(np.delete(chains1, sol, axis=0), np.delete(chains2, sol, axis=0), np.delete(chains3, sol, axis=0), np.delete(chains4, sol, axis=0)) - print("Minimum over everything is {} (removed {})".format(r_hat_local, delete_n)) - - return sol, r_hat_min - - def find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_to=8, simple=False): delete_n = - reduce_to + chains1.shape[0] mins = np.zeros(delete_n + 1) @@ -1320,12 +1282,6 @@ if __name__ == "__main__": chains3 = chains3[:, 160:] chains4 = chains4[:, 160:] - # chains1 = chains1[:16] - # chains2 = chains2[:16] - # chains3 = chains3[:16] - # chains4 = chains4[:16] - - # mins = find_good_chains(chains[:, :-1].reshape(32, chains.shape[1] // 2)) sol, final_r_hat = find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_to=chains1.shape[0] // 2) r_hats.append((subject, final_r_hat)) @@ -1352,22 +1308,6 @@ def dist_helper(dist_matrix, state_hists, inds): return dist_matrix -def _get_leaves_color_list(R): - # copied from latest scipy version - leaves_color_list = [None] * len(R['leaves']) - for link_x, link_y, link_color in zip(R['icoord'], - R['dcoord'], - R['color_list']): - for (xi, yi) in zip(link_x, link_y): - if yi == 0.0: # if yi is 0.0, the point is a leaf - # xi of leaves are 5, 15, 25, 35, ... (see `iv_ticks`) - # index of leaves are 0, 1, 2, 3, ... as below - leaf_index = (int(xi) - 5) // 10 - # each leaf has a same color of its link. - leaves_color_list[leaf_index] = link_color - return leaves_color_list - - def state_size_helper(n=0, mode_specific=False): if not mode_specific: def nth_largest_state_func(x): @@ -1496,7 +1436,7 @@ if __name__ == "__main__": loading_info = json.load(open("canonical_infos.json", 'r')) r_hats = json.load(open("canonical_info_r_hats.json", 'r')) no_good_pcas = ['NYU-06', 'SWC_023'] # no good rhat: 'ibl_witten_13' - subjects = list(loading_info.keys()) + subjects = ['KS014'] # list(loading_info.keys()) print(subjects) fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0] dur = 'yes' @@ -1531,7 +1471,7 @@ if __name__ == "__main__": 'SWC_023': (9, 'skip'), # gradual 'ZM_1897': (5, 'right'), 'ZM_3003': (1, 'skip')} - fig, ax = plt.subplots(1, 3, sharey=True, figsize=(16, 9)) + # fig, ax = plt.subplots(1, 3, sharey=True, figsize=(16, 9)) thinning = 25 summary_info = {"thinning": thinning, "contains": [], "seeds": [], "fit_nums": []} @@ -1569,22 +1509,19 @@ if __name__ == "__main__": test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb')) print('loaded canoncial result') - mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}.p".format(subject), 'rb')) - - - state_sets = pickle.load(open("multi_chain_saves/state_sets_{}.p".format(subject), 'rb')) + # state_sets = pickle.load(open("multi_chain_saves/state_sets_{}.p".format(subject), 'rb')) # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices) # continue # - # mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb')) - # state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb')) - # states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=True, separate_pmf=True) + mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb')) + state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb')) + states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=True, separate_pmf=True) # state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False) - # quit() - # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}.p".format(subject), 'rb')) - # consistencies /= consistencies[0, 0] - # contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save=True, show=True, consistencies=consistencies) - # quit() + consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) + consistencies /= consistencies[0, 0] + contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save=True, show=True, consistencies=consistencies) + quit() + # state_types = np.zeros((3, states.shape[1])) # for s, pmf in zip(states, pmfs): @@ -1802,7 +1739,6 @@ if __name__ == "__main__": contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save_append='_{}{}'.format(string_prefix, criterion), save=True, show=True) # R = hc.dendrogram(linkage, no_plot=True, get_leaves=True, color_threshold=color_threshold) - # R["leaves_color_list"] = _get_leaves_color_list(R) # leaves_color_list = np.array(R["leaves_color_list"]) # leaves = np.array(R["leaves"]) diff --git a/dyn_glm_chain_analysis_unused_funcs.py b/dyn_glm_chain_analysis_unused_funcs.py index 38015be96f29c92060dfad32e07934207ef9adee..d1ae35cc0c469ea18daf07886ff472923eb05840 100644 --- a/dyn_glm_chain_analysis_unused_funcs.py +++ b/dyn_glm_chain_analysis_unused_funcs.py @@ -64,6 +64,41 @@ class MCMC_result: return consistency_mat +def find_good_chains_unsplit_fast(chains1, chains2, chains3, chains4, reduce_to=8): + delete_n = - reduce_to + chains1.shape[0] + mins = np.zeros(delete_n + 1) + n_chains = chains1.shape[0] + chains = np.stack([chains1, chains2, chains3, chains4]) + + print("Without removals: {}".format(eval_simple_r_hat(chains))) + r_hat = eval_simple_r_hat(chains) + + mins[0] = r_hat + + l, m, n = chains.shape + psi_dot_j = np.mean(chains, axis=2) + s_j_squared = np.sum((chains - psi_dot_j[:, :, None]) ** 2, axis=2) / (n - 1) + + r_hat_min = 10 + sol = 0 + for x in combinations(range(n_chains), n_chains - delete_n): + temp1 = chains[:, x] + temp2 = psi_dot_j[:, x] + temp3 = s_j_squared[:, x] + r_hat = eval_amortized_r_hat(temp1, temp2, temp3, l, m - delete_n, n) + if r_hat < r_hat_min: + sol = x + r_hat_min = min(r_hat, r_hat_min) + print("Minimum is {} (removed {})".format(r_hat_min, delete_n)) + sol = [i for i in range(n_chains) if i not in sol] + print("Removed: {}".format(sol)) + + r_hat_local = eval_r_hat(np.delete(chains1, sol, axis=0), np.delete(chains2, sol, axis=0), np.delete(chains3, sol, axis=0), np.delete(chains4, sol, axis=0)) + print("Minimum over everything is {} (removed {})".format(r_hat_local, delete_n)) + + return sol, r_hat_min + + if __name__ == "__main__": subjects = list(loading_info.keys())