diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py
index 3f1e31cffdccbaeedd50667e8840170d2e8499b0..71e06373a2a4dc0a114fb5407e8de9f2a9479899 100644
--- a/dyn_glm_chain_analysis.py
+++ b/dyn_glm_chain_analysis.py
@@ -574,10 +574,10 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur
         noise = np.zeros(len(c_n_a))# np.random.rand(len(c_n_a)) * 0.4 - 0.2
 
         mask = c_n_a[:, -1] == 0
-        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='b', ms=ms, label='Leftward')
+        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='b', ms=ms, label='Leftward', alpha=0.6)
 
         mask = c_n_a[:, -1] == 1
-        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='r', ms=ms, label='Rightward')
+        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='r', ms=ms, label='Rightward', alpha=0.6)
 
         plt.title("session #{} / {}".format(1+seq_num, test.results[0].n_sessions), size=26)
         # plt.yticks(*self.cont_ticks, size=22-2)
@@ -852,7 +852,7 @@ def state_development_single_sample(test, indices, save=True, save_append='', sh
     return states_by_session, all_pmfs
 
 
-def state_development(test, state_sets, indices, save=True, save_append='', show=True, dpi='figure', separate_pmf=False):
+def state_development(test, state_sets, indices, save=True, save_append='', show=True, dpi='figure', separate_pmf=False, type_coloring=True):
     # Now also returns durs of state types and state type summary array
 
     state_sets = [np.array(s) for s in state_sets]
@@ -1038,7 +1038,8 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     introductions_by_stage = np.zeros(3)
     covered_states = []
     for i, d in enumerate(durs):
-        ax0.fill_between(range(dur_counter, 1 + dur_counter + d), 0.5, -0.5, color=type_colours[i], zorder=0, alpha=0.3)
+        if type_coloring:
+            ax0.fill_between(range(dur_counter, 1 + dur_counter + d), 0.5, -0.5, color=type_colours[i], zorder=0, alpha=0.3)
         dur_counter += d
 
     # find out during which state type which contrast was introduced
@@ -1098,7 +1099,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     else:
         plt.close()
 
-    return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs))
+    return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage
 
 
 def smart_divide(a, b):
@@ -1154,14 +1155,13 @@ if __name__ == "__main__":
 
     r_hats = []
 
+    # R^hat tests
     # test = MCMC_result_list([fake_result(100) for i in range(8)])
     # test.r_hat_and_ess(return_ascending, False)
     # test.r_hat_and_ess(return_ascending_shuffled, False)
     # quit()
 
-    good = []
-    bad = []
-    check_r_hats = False
+    check_r_hats = True
     if check_r_hats:
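+        # when enabled, recompute the R^hat statistics for each subject's chains (collected in r_hats below)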
         subjects = list(loading_info.keys())
         subjects = ['KS014']
@@ -1188,11 +1188,6 @@ if __name__ == "__main__":
             r_hats.append((subject, final_r_hat))
             loading_info[subject]['ignore'] = sol
 
-            if final_r_hat < 1.05:
-                good.append(subject)
-            else:
-                bad.append(subject)
-
     print(r_hats)
     if fit_type == 'bias':
         json.dump(loading_info, open("canonical_infos_bias.json", 'w'))
@@ -1221,10 +1216,15 @@ def state_type_durs(states, pmfs):
             pmf_counter += 1
             state_types[pmf_type(pmf[1][pmf_counter][pmf[0]]), i] += s[i]  # indexing horror
 
-    durs = (np.where(state_types[1] > 0.5)[0][0],
-            np.where(state_types[2] > 0.5)[0][0] - np.where(state_types[1] > 0.5)[0][0],
-            states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
-    if np.where(state_types[2] > 0.5)[0][0] < np.where(state_types[1] > 0.5)[0][0]:
+    if np.any(state_types[1] > 0.5):
+        durs = (np.where(state_types[1] > 0.5)[0][0],
+                np.where(state_types[2] > 0.5)[0][0] - np.where(state_types[1] > 0.5)[0][0],
+                states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
+        if np.where(state_types[2] > 0.5)[0][0] < np.where(state_types[1] > 0.5)[0][0]:
+            durs = (np.where(state_types[2] > 0.5)[0][0],
+                    0,
+                    states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
+    else:
         durs = (np.where(state_types[2] > 0.5)[0][0],
                 0,
                 states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
@@ -1251,6 +1251,24 @@ def state_cluster_interpolation(states, pmfs):
     return state_types, state_trans, pmf_examples
 
 
+def get_first_pmfs(states, pmfs):
+    # get the first pmf of every type, also where they are defined, and whether they are the first pmf of that state
+    earliest_sessions = [1000, 1000, 1000]
+    first_pmfs = [0, 0, 0, 0, 0, 0, 0, 0, 0]
+    changing_pmfs = [[0, 0], [0, 0]]
+    for state, pmf in zip(states, pmfs):
+        sessions = np.where(state)[0]
+        for i, (sess_pmf, sess) in enumerate(zip(pmf[1], sessions)):
+            if earliest_sessions[pmf_type(sess_pmf[pmf[0]])] > sess:
+                earliest_sessions[pmf_type(sess_pmf[pmf[0]])] = sess
+                first_pmfs[3 * pmf_type(sess_pmf[pmf[0]])] = sess_pmf
+                first_pmfs[1 + 3 * pmf_type(sess_pmf[pmf[0]])] = pmf[0]
+                first_pmfs[2 + 3 * pmf_type(sess_pmf[pmf[0]])] = i
+                if i != 0:
+                    changing_pmfs[pmf_type(sess_pmf[pmf[0]]) - 1] = [pmf[0], pmf[1]]
+    return first_pmfs, changing_pmfs
+
+
 def plot_pmf_types(pmf_types, subject, fit_type, save=True, show=False):
     # Plot the different types of PMFs, all split up by their different types
     for i, pmfs in enumerate(pmf_types):
@@ -1267,15 +1285,155 @@ def plot_pmf_types(pmf_types, subject, fit_type, save=True, show=False):
     else:
         plt.close()
 
+
 def pmf_type(pmf):
-    if pmf[-1] - pmf[0] <= 0.15:
+    if pmf[-1] - pmf[0] < 0.2:
         return 0
-    elif pmf[-1] - pmf[0] < 0.6 and np.abs(pmf[0] + pmf[-1] - 1) > 0.1:
+    elif pmf[-1] - pmf[0] < 0.6:  # and np.abs(pmf[0] + pmf[-1] - 1) > 0.1:
         return 1
     else:
         return 2
 
 
+type2color = {0: 'green', 1: 'blue', 2: 'red'}
+
+if False:
+
+    all_changing_pmfs = pickle.load(open("changing_pmfs.p", 'rb'))
+    plt.figure(figsize=(16, 9))
+    for i, pmf in enumerate(all_changing_pmfs):
+        plt.subplot(4, 7, i + 1)
+        for p in pmf[1]:
+            plt.plot(np.where(pmf[0])[0], p[pmf[0]], color=type2color[pmf_type(p)])
+        plt.ylim(0, 1)
+
+        sns.despine()
+        if i+1 != 22:
+            plt.gca().set_xticks([])
+            plt.gca().set_yticks([])
+        else:
+            plt.xlabel("Contrasts", size=22)
+            plt.ylabel("P(rightwards)", size=22)
+            plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=16)
+            plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=16)
+
+        if i + 1 == 30:
+            break
+
+    plt.tight_layout()
+    plt.savefig("changing pmfs")
+    plt.show()
+    quit()
+
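+    # The block below plots, per subject, the first PMF of each state type
+    # (solid: it was that state's first PMF; dashed: the state only later changed into this type, cf. get_first_pmfs)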
+    type_2_assyms = []
+    tick_size = 14
+    label_size = 26
+    all_first_pmfs = pickle.load(open("pmfs_temp.p", 'rb'))
+    plt.figure(figsize=(16, 9))
+    plt.subplot(1, 3, 1)
+    counter = [[0, 0], [0, 0]]
+    save_title = "all types" if False else "KS014 types"
+    if save_title == "KS014 types":
+        all_first_pmfs = {'KS014': all_first_pmfs['KS014']}
+
+    for key in all_first_pmfs:
+        x = all_first_pmfs[key]
+        if type(x[0]) == int:
+            continue
+        linestyle = '-' if x[2] == 0 else '--'
+        plt.plot(np.where(x[1])[0], x[0][x[1]], linestyle=linestyle, c='g')
+    plt.ylim(0, 1)
+    plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size)
+    plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
+    plt.gca().spines[['right', 'top']].set_visible(False)
+    plt.xlim(0, 10)
+    plt.xticks(rotation=45)
+    plt.gca().set_ylabel("P(rightwards)", size=label_size)
+
+    plt.subplot(1, 3, 2)
+    for key in all_first_pmfs:
+        x = all_first_pmfs[key]
+        if type(x[3]) == int:
+            continue
+        type_2_assyms.append(np.abs(x[3][0] + x[3][-1] - 1))
+        linestyle = '-' if x[5] == 0 else '--'
+        counter[0][0 if x[5] == 0 else 1] += 1
+        if linestyle == '--':
+            continue
+        plt.plot(np.where(x[4])[0], x[3][x[4]], linestyle=linestyle, c='b')
+    plt.gca().set_yticks([])
+    plt.ylim(0, 1)
+    plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size)
+    plt.gca().spines[['right', 'top']].set_visible(False)
+    plt.xticks(rotation=45)
+    plt.xlim(0, 10)
+    plt.gca().set_xlabel("Contrasts", size=label_size)
+
+    plt.subplot(1, 3, 3)
+    for key in all_first_pmfs:
+        x = all_first_pmfs[key]
+        if type(x[6]) == int:
+            continue
+        linestyle = '-' if x[8] == 0 else '--'
+        counter[1][0 if x[8] == 0 else 1] += 1
+        if linestyle == '--':
+            continue
+        plt.plot(np.where(x[7])[0], x[6][x[7]], linestyle=linestyle, c='r')
+    plt.gca().set_yticks([])
+    plt.ylim(0, 1)
+    plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size)
+    plt.gca().spines[['right', 'top']].set_visible(False)
+    plt.xlim(0, 10)
+    plt.xticks(rotation=45)
+
+    print(counter)
+    plt.tight_layout()
+    plt.savefig(save_title)
+    plt.show()
+    if save_title == "KS014 types":
+        quit()
+
+    counter = 0
+    fig, ax = plt.subplots(1, 3, figsize=(16, 9))
+    for key in all_first_pmfs:
+        x = all_first_pmfs[key]
+        if type(x[3]) == int:
+            continue
+        linestyle = '-' if x[5] == 0 else '--'
+        if linestyle == '--':
+            continue
+        if np.abs(x[3][0] + x[3][-1] - 1) <= 0.1:
+            counter += 1
+            use_ax = 2
+        else:
+            use_ax = int(x[3][0] > 1 - x[3][-1])
+
+        ax[use_ax].plot(np.where(x[4])[0], x[3][x[4]], linestyle=linestyle, c='b')
+    ax[0].set_ylim(0, 1)
+    ax[0].set_xlim(0, 10)
+    ax[0].spines[['right', 'top']].set_visible(False)
+    ax[0].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
+    ax[0].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    ax[0].set_ylabel("P(rightwards)", size=label_size)
+
+    ax[1].set_ylim(0, 1)
+    ax[1].set_xlim(0, 10)
+    ax[1].set_yticks([])
+    ax[1].spines[['right', 'top']].set_visible(False)
+    ax[1].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    ax[1].set_xlabel("Contrasts", size=label_size)
+
+    ax[2].set_ylim(0, 1)
+    ax[2].set_xlim(0, 10)
+    ax[2].set_yticks([])
+    ax[2].spines[['right', 'top']].set_visible(False)
+    ax[2].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    print(counter)
+    plt.tight_layout()
+    plt.savefig("differentiate type 2")
+    plt.show()
+    quit()
+
 
 if __name__ == "__main__":
     # visualise pmf types
@@ -1301,40 +1459,11 @@ if __name__ == "__main__":
     loading_info = json.load(open("canonical_infos.json", 'r'))
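+    # loading_info: per-subject fit metadata, keyed by subject name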
json.load(open("canonical_infos.json", 'r')) r_hats = json.load(open("canonical_info_r_hats.json", 'r')) no_good_pcas = ['NYU-06', 'SWC_023'] # no good rhat: 'ibl_witten_13' - subjects = list(loading_info.keys()) + subjects = ['KS014'] # list(loading_info.keys()) print(subjects) fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0] dur = 'yes' - fist_good_pmf = {'CSHL051': (6, 'left'), - 'CSHL059': (3, 'left'), - 'CSHL061': (4, 'right'), - 'CSHL062': (1, 'left'), - 'CSHL_007': (2, 'skip'), - 'CSHL_014': (4, 'left'), - 'CSHL_015': (2, 'right'), - 'CSHL_018': (3, 'skip'), - 'CSHL_020': (5, 'left'), # tough case - 'CSH_ZAD_001': (4, 'right'), - 'CSH_ZAD_011': (0, 'right'), - 'CSH_ZAD_022': (4, 'right'), - 'CSH_ZAD_025': (2, 'skip'), # eieiei - 'CSH_ZAD_026': (6, 'skip'), # gradual - 'ibl_witten_14': (3, 'left'), - 'ibl_witten_16': (3, 'right'), - # 'ibl_witten_18': (3, 'very weird'), # probably shouldn't be analysed - 'ibl_witten_19': (4, 'right'), - 'KS014': (4, 'right'), - 'KS015': (4, 'left'), - 'KS016': (3, 'skip'), # non-trivial - 'KS017': (3, 'right'), - 'KS021': (5, 'left'), - 'KS022': (4, 'right'), - 'KS023': (3, 'left'), - 'NYU-06': (9, 'right'), - 'SWC_023': (9, 'skip'), # gradual - 'ZM_1897': (5, 'right'), - 'ZM_3003': (1, 'skip')} # fig, ax = plt.subplots(1, 3, sharey=True, figsize=(16, 9)) thinning = 25 @@ -1356,6 +1485,12 @@ if __name__ == "__main__": not_yet = True abs_state_durs = [] + all_first_pmfs = {} + all_pmf_diffs = [] + all_pmf_asymms = [] + all_pmfs = [] + all_changing_pmfs = [] + all_intros = [] for subject in subjects: @@ -1366,30 +1501,44 @@ if __name__ == "__main__": results = [] try: - - print('loading canonical') test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb')) print('loaded canoncial result') mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb')) + quit() state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb')) # lapse differential # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices) # training overview - states, pmfs, durs, _, contrast_intro_type, intros_by_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=1, separate_pmf=True) + states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True) + all_intros.append(undiv_intros) intros_by_type_sum += intros_by_type - continue + first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs) + for pmf in changing_pmfs: + if type(pmf[0]) == int: + continue + all_changing_pmfs.append(pmf) + all_first_pmfs[subject] = first_pmfs + for pmf in pmfs: + for p in pmf[1]: + all_pmf_diffs.append(p[-1] - p[0]) + all_pmf_asymms.append(np.abs(p[0] + p[-1] - 1)) + all_pmfs.append(p) contrast_intro_types.append(contrast_intro_type) # state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False) # session overview - # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) - # consistencies /= consistencies[0, 0] - # contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies) + consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) + consistencies /= consistencies[0, 0] + 
+            contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies)
 
             # duration of different state types (and also percentage of type activities)
             abs_state_durs.append(durs)
+            simplex_durs = np.array(durs).reshape(1, 3)
+            print(simplex_durs / np.sum(simplex_durs))
+            from simplex_plot import projectSimplex
+            print(projectSimplex(simplex_durs / simplex_durs.sum(1)[:, None]))
             continue
 
             # compute state type proportions and split the pmfs accordingly
@@ -1567,6 +1716,7 @@ if __name__ == "__main__":
             plt.show()
 
         except FileNotFoundError as e:
+            continue
             print(e)
             r_hat = 1.5
             for r in r_hats:
@@ -1630,11 +1780,33 @@ if __name__ == "__main__":
             # state_appear.append(b[a == i][0] / (test.results[0].n_sessions - 1))
             # state_dur.append(b[a == i].shape[0])
 
+    # pickle.dump(all_first_pmfs, open("pmfs_temp.p", 'wb'))
+    # pickle.dump(all_changing_pmfs, open("changing_pmfs.p", 'wb'))
+    #
+    # a = [x for x, y in zip(all_pmf_asymms, all_pmf_diffs) if y >= 0.2]
+    # b = [y for x, y in zip(all_pmf_asymms, all_pmf_diffs) if y >= 0.2]
+    # plt.hist2d(a, b, bins=40)
+    # plt.show()
+
+
     if True:
         abs_state_durs = np.array(abs_state_durs)
         # pickle.dump(abs_state_durs, open("multi_chain_saves/abs_state_durs.p", 'wb'))
         abs_state_durs = pickle.load(open("multi_chain_saves/abs_state_durs.p", 'rb'))
 
+        print("Correlations")
+        from scipy.stats import pearsonr
+        print(pearsonr(abs_state_durs[:, 0], abs_state_durs[:, 1]))
+        print(pearsonr(abs_state_durs[:, 2], abs_state_durs[:, 1]))
+        print(pearsonr(abs_state_durs[:, 0], abs_state_durs[:, 2]))
+
+        print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 0]))
+        # (0.7338297529946006, 2.6332570579118393e-06)
+        print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 1]))
+        # (0.35094585023228597, 0.052897046343413114)
+        print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 2]))
+        # (0.7210260323745921, 4.747833912452452e-06)
+
         from simplex_plot import plotSimplex
         plotSimplex(np.array(abs_state_durs), c='k', show=True)
diff --git a/dynamic_GLMiHMM_consistency.py b/dynamic_GLMiHMM_consistency.py
index a8b2cf936d5caca6bbc9dd302c0553ecad1edd26..8bf1924c1a08c92b50d8db433853faa7376093cb 100644
--- a/dynamic_GLMiHMM_consistency.py
+++ b/dynamic_GLMiHMM_consistency.py
@@ -42,10 +42,10 @@ for subject in subjects:
     from_session = info_dict['bias_start'] if fit_type == 'bias' else 0
 
     models = []
-    n_inputs = 5
+    n_regressors = 5
     T = till_session - from_session + (fit_type != 'prebias')
-    obs_hypparams = {'n_inputs': n_inputs, 'T': T, 'prior_mean': np.zeros(n_inputs),
-                     'P_0': 2 * np.eye(n_inputs), 'Q': fit_variance * np.tile(np.eye(n_inputs), (T, 1, 1))}
+    obs_hypparams = {'n_regressors': n_regressors, 'T': T, 'prior_mean': np.zeros(n_regressors), 'jumplimit': 3,
+                     'P_0': 2 * np.eye(n_regressors), 'Q': fit_variance * np.tile(np.eye(n_regressors), (T, 1, 1))}
     dur_hypparams = dict(r_support=np.array([1, 2, 3, 5, 7, 10, 15, 21, 28, 36, 45, 55, 150]),
                          r_probs=np.ones(13)/13., alpha_0=1, beta_0=1)
 
@@ -78,7 +78,7 @@ for subject in subjects:
         bad_trials = data[:, 1] == 1
         bad_trials[0] = True
 
-        mega_data = np.empty((np.sum(~bad_trials), n_inputs + 1))
+        mega_data = np.empty((np.sum(~bad_trials), n_regressors + 1))
 
         mega_data[:, 0] = np.maximum(data[~bad_trials, 0], 0)
         mega_data[:, 1] = np.abs(np.minimum(data[~bad_trials, 0], 0))
@@ -123,7 +123,10 @@ for subject in subjects:
 
 prev_res = pickle.load(open(save_title, 'rb'))
 
+counter = 0
 for p, m in zip(prev_res, models):
+    print(counter)
+    counter += 1
     for od, nd in zip(p.obs_distns, m.obs_distns):
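+        # the re-run must reproduce the previously saved weights (up to numerical tolerance)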
         assert np.allclose(od.weights, nd.weights)
 
 prof._prof.print_stats()
diff --git a/dynamic_GLMiHMM_fit.py b/dynamic_GLMiHMM_fit.py
index b05ad929cf8ee0cc6558c0b6948875564ded52e6..dbbe447a55e51e03d20aef3e840ccd600d76f653 100644
--- a/dynamic_GLMiHMM_fit.py
+++ b/dynamic_GLMiHMM_fit.py
@@ -204,14 +204,14 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
     models = []
 
     if params['obs_dur'] == 'glm':
-        n_inputs = len(params['regressors'])
+        n_regressors = len(params['regressors'])
         T = till_session - from_session + (params['fit_type'] != 'prebias')
-        obs_hypparams = {'n_inputs': n_inputs, 'T': T, 'jumplimit': params['jumplimit'], 'prior_mean': params['init_mean'],
-                         'P_0': params['init_var'] * np.eye(n_inputs), 'Q': params['fit_variance'] * np.tile(np.eye(n_inputs), (T, 1, 1))}
+        obs_hypparams = {'n_regressors': n_regressors, 'T': T, 'jumplimit': params['jumplimit'], 'prior_mean': params['init_mean'],
+                         'P_0': params['init_var'] * np.eye(n_regressors), 'Q': params['fit_variance'] * np.tile(np.eye(n_regressors), (T, 1, 1))}
         obs_distns = [distributions.Dynamic_GLM(**obs_hypparams) for state in range(params['n_states'])]
     else:
-        n_inputs = 9 if params['fit_type'] == 'bias' else 11
-        obs_hypparams = {'n_inputs': n_inputs * (1 + (params['conditioned_on'] != 'nothing')), 'n_outputs': 2, 'T': till_session - from_session + (params['fit_type'] != 'prebias'),
+        n_regressors = 9 if params['fit_type'] == 'bias' else 11
+        obs_hypparams = {'n_regressors': n_regressors * (1 + (params['conditioned_on'] != 'nothing')), 'n_outputs': 2, 'T': till_session - from_session + (params['fit_type'] != 'prebias'),
                          'jumplimit': params['jumplimit'], 'sigmasq_states': params['fit_variance']}
         obs_distns = [Dynamic_Input_Categorical(**obs_hypparams) for state in range(params['n_states'])]
 
@@ -280,7 +280,7 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
         mask[0] = False
         if params['fit_type'] == 'zoe_style':
             mask[90:] = False
-        mega_data = np.empty((np.sum(mask), n_inputs + 1))
+        mega_data = np.empty((np.sum(mask), n_regressors + 1))
 
         for i, reg in enumerate(params['regressors']):
             # positive numbers are contrast on the right
diff --git a/mcmc_chain_analysis.py b/mcmc_chain_analysis.py
index 2a99a67602392e961761ef6ead4b9f40650e0d93..c8de12596afbb348ca65a77d28db0726031a52f5 100644
--- a/mcmc_chain_analysis.py
+++ b/mcmc_chain_analysis.py
@@ -8,6 +8,8 @@ import pickle
 
 
 def state_size_helper(n=0, mode_specific=False):
+    """Returns a function that fetches the number of trials assigned to the nth largest state of a sample.
+    Can optionally be restricted to specific samples, e.g. those of a mode."""
     if not mode_specific:
         def nth_largest_state_func(x): return np.partition(x.assign_counts, -1 - n, axis=1)[:, -1 - n]
     else:
@@ -18,6 +20,8 @@ def state_size_helper(n=0, mode_specific=False):
 
 
 def state_num_helper(t, mode_specific=False):
+    """Returns a function that counts the states holding more than a fraction t of a sample's trials.
+    Can optionally be restricted to specific samples, e.g. those of a mode."""
     if not mode_specific:
         def state_num_func(x): return ((x.assign_counts / x.n_datapoints) > t).sum(1)
     else:
@@ -35,14 +39,15 @@ def ll_func(x): return x.sample_lls[-x.n_samples:]
 
 
 def r_hat_array_comp(chains):
-    m, n = chains.shape
+    """Computes R^hat on an array of features, following Gelman et al., p. 284f."""
+    m, n = chains.shape  # number of chains, length of chains
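+    # notation as in Gelman et al.: psi_dot_j are the per-chain means, psi_dot_dot their grand mean;
+    # B is the between-chain variance, W the within-chain variance; var_hat_plus pools the two,
+    # and R^hat = sqrt(var_hat_plus / W) approaches 1 as the chains converge to the same distribution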
     psi_dot_j = np.mean(chains, axis=1)
     psi_dot_dot = np.mean(psi_dot_j)
     B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot) ** 2)
     s_j_squared = np.sum((chains - psi_dot_j[:, None]) ** 2, axis=1) / (n - 1)
     W = np.mean(s_j_squared)
     var_hat_plus = (n - 1) / n * W + B / n
-    if W == 0:
+    if W == 0:  # sometimes a feature has 0 variance
         # print("all the same value")
         return 1, 0
     r_hat = np.sqrt(var_hat_plus / W)
@@ -50,6 +55,7 @@ def r_hat_array_comp(chains):
 
 
 def eval_amortized_r_hat(chains, psi_dot_j, s_j_squared, m, n):
+    """Unused variant in which some quantities are precomputed outside this function, to save time."""
     psi_dot_dot = np.mean(psi_dot_j, axis=1)
     B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot[:, None]) ** 2, axis=1)
     W = np.mean(s_j_squared, axis=1)
@@ -59,6 +65,7 @@ def eval_amortized_r_hat(chains, psi_dot_j, s_j_squared, m, n):
 
 
 def r_hat_array_comp_mult(chains):
+    """Computes R^hat for multiple features at once."""
     _, m, n = chains.shape
     psi_dot_j = np.mean(chains, axis=2)
     psi_dot_dot = np.mean(psi_dot_j, axis=1)
@@ -71,8 +78,8 @@ def r_hat_array_comp_mult(chains):
 
 
 def rank_inv_normal_transform(chains):
-    # Gelman paper Rank-normalization, folding, and localization: An improved R_hat for assessing convergence of MCMC
-    # ranking with average rank for ties
+    """Rank-normalises the chains (ties receive their average rank), following the Gelman et al. paper
+    "Rank-normalization, folding, and localization: An improved R_hat for assessing convergence of MCMC"."""
    folded_chains = np.abs(chains - np.median(chains))
    ranked = rankdata(chains).reshape(chains.shape)
    folded_ranked = rankdata(folded_chains).reshape(folded_chains.shape)
@@ -83,6 +90,8 @@ def rank_inv_normal_transform(chains):
 
 
 def eval_r_hat(chains):
+    """Computes the full set of R^hats for a list of feature arrays, and returns the maximum across features.
+    Computes all R^hat versions (plain, rank-normalised, and folded), as opposed to eval_simple_r_hat."""
     r_hats = []
     for chain in chains:
         rank_normalised, folded_rank_normalised, _, _ = rank_inv_normal_transform(chain)
@@ -92,6 +101,8 @@ def eval_r_hat(chains):
 
 
 def eval_simple_r_hat(chains):
+    """Computes just the simple R^hats for a list of feature arrays, and returns the maximum across features.
+    Skips folding and rank-normalisation, which makes it much faster than eval_r_hat."""
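+    # here chains is a stacked array of shape (n_features, n_chains, n_samples), as unpacked by r_hat_array_comp_mult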
     r_hats, _ = r_hat_array_comp_mult(chains)
     return max(r_hats)
 
@@ -103,21 +114,21 @@ def comp_multi_r_hat(chains, rank_normalised, folded_rank_normalised):
     return max(lame_r_hat, rank_normalised_r_hat, folded_rank_normalised_r_hat)
 
 
-def sample_statistics(test, mode_indices, subject):
+def sample_statistics(mode_indices, subject, period='prebias'):
     # prints out r_hats and sample sizes for given sample
-    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
+    test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, period), 'rb'))
     test.r_hat_and_ess(state_size_helper(1), False)
     test.r_hat_and_ess(state_size_helper(1, mode_specific=True), False, mode_indices=mode_indices)
     print()
-    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
+    test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, period), 'rb'))
     test.r_hat_and_ess(state_size_helper(), False)
     test.r_hat_and_ess(state_size_helper(mode_specific=True), False, mode_indices=mode_indices)
     print()
-    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
+    test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, period), 'rb'))
     test.r_hat_and_ess(state_num_helper(0.05), False)
     test.r_hat_and_ess(state_num_helper(0.05, mode_specific=True), False, mode_indices=mode_indices)
     print()
-    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
+    test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, period), 'rb'))
     test.r_hat_and_ess(state_num_helper(0.02), False)
     test.r_hat_and_ess(state_num_helper(0.02, mode_specific=True), False, mode_indices=mode_indices)
     print()
diff --git a/simplex_plot.py b/simplex_plot.py
index bdaad7885fa71bbbddcdd58b45447f4acf442ae3..dd8ba57d9f3d356cda8e04524404b4833a627a60 100644
--- a/simplex_plot.py
+++ b/simplex_plot.py
@@ -37,6 +37,7 @@ def plotSimplex(points, fig=None,
     fig.gca().text(0.43, np.sqrt(3) / 2 + 0.025, vertexlabels[2], size=24)
     # Project and draw the actual points
     projected = projectSimplex(points / points.sum(1)[:, None])
+    print(projected)
     P.scatter(projected[:, 0], projected[:, 1], s=points.sum(1) * 3.5, **kwargs)
 
     # plot center with average size
@@ -90,14 +91,14 @@ if __name__ == '__main__':
     labels = ('[0.1 0.1 0.8]',
               '[0.8 0.1 0.1]',
               '[0.5 0.4 0.1]',
+              '[0.17 0.33 0.5]',
               '[0.33 0.34 0.33]')
     testpoints = np.array([[0.1, 0.1, 0.8],
                            [0.8, 0.1, 0.1],
                            [0.5, 0.4, 0.1],
+                           [0.17, 0.33, 0.5],
                            [0.33, 0.34, 0.33]])
     # Define different colors for each label
     c = range(len(labels))
     # Do scatter plot
-    fig = plotSimplex(testpoints, s=25, c='k')
-
-    P.show()
+    fig = plotSimplex(testpoints, c='k', show=1)
diff --git a/test_codes/pymc_compare/call.py b/test_codes/pymc_compare/call.py
index aeeedd8a5becdffd6824fc457a4578f8b2729a05..bc52153835ea5302a9483829c3a459a6c13240f3 100644
--- a/test_codes/pymc_compare/call.py
+++ b/test_codes/pymc_compare/call.py
@@ -1,14 +1,19 @@
-import pymc, bayes_fit # load the model file
+"""Perform a pymc sampling of the test data."""
+import pymc, bayes_fit  # load the model file
 import numpy as np
 import pickle
 
+# Data params
+T = 14
+n_inputs = 3
+
+# Sampling params
 n_samples = 400000
 
-R = pymc.MCMC(bayes_fit) # build the model
-R.sample(n_samples) # populate and run it
+R = pymc.MCMC(bayes_fit)  # build the model
+R.sample(n_samples)  # populate and run it
 
-T = 14
-n_inputs = 3
+# Extract weights
 weights = np.zeros((T, n_samples, n_inputs))
 for t in range(T):
     try:
@@ -16,4 +21,5 @@ for t in range(T):
     except KeyError:
         weights[t] = R.trace('ws'.format(t))
 
+# Save everything
 pickle.dump(weights, open('pymc_posterior', 'wb'))
diff --git a/test_codes/pymc_compare/dynglm_optimisation_test.py b/test_codes/pymc_compare/dynglm_optimisation_test.py
index 05df6e0cd7e1e02aa462861346bcc1e721d82e8f..e17222eae9ed15efd967844f79fa70adcc34cd99 100644
--- a/test_codes/pymc_compare/dynglm_optimisation_test.py
+++ b/test_codes/pymc_compare/dynglm_optimisation_test.py
@@ -2,7 +2,7 @@
 Need to find out whether loglikelihood is computed correctly.
 Or whether a bug here allows states to invade each other more easily.
 
-We'll test this by maximising the likelihood directly.
+We'll test this by comparing to pymc results.
 """
 import numpy as np
 import pyhsmm.basic.distributions as distributions
diff --git a/test_codes/pymc_compare/gibbs_sample.py b/test_codes/pymc_compare/gibbs_sample.py
index 31ed5919143fdc6e74eb79107117c2b0d603285b..5b91c1ab0a72c404452b14d725e24225be2942c9 100644
--- a/test_codes/pymc_compare/gibbs_sample.py
+++ b/test_codes/pymc_compare/gibbs_sample.py
@@ -8,7 +8,7 @@ import pyhsmm.basic.distributions as distributions
 from scipy.optimize import minimize
 import pickle
 
-# Data Params
+# Data params
 T = 16
 n_inputs = 3
 step_size = 0.2
@@ -41,5 +41,5 @@ LL_weights = np.zeros((T, n_inputs))
 for t in range(T):
     LL_weights[t] = minimize(lambda w: wrapper(w, t), np.zeros(n_inputs)).x
 
-# save everything
+# Save everything
 pickle.dump((samples, LL_weights), open('gibbs_posterior', 'wb'))