diff --git a/.gitignore b/.gitignore index b84a8f82fc35d730fb072a10736910aaca487db3..e0d406ba7e6131d35c436488d857b2167bbaaacf 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ - +summary_figures/ old_ana_code/ figures/ dynamic_figures/ diff --git a/__pycache__/analysis_pmf.cpython-37.pyc b/__pycache__/analysis_pmf.cpython-37.pyc index f9fd33bf94c4befd1c5937c83c151a084e749d91..899a197bee3c2ffcdbb2f048e7d0088d934a594e 100644 Binary files a/__pycache__/analysis_pmf.cpython-37.pyc and b/__pycache__/analysis_pmf.cpython-37.pyc differ diff --git a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc index 797f6ea4ca71054cc38cb50347dc59cdea7fef0b..06557f92116c59895d3c014e28181f1e3c4ae7f7 100644 Binary files a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc and b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc differ diff --git a/__pycache__/simplex_plot.cpython-37.pyc b/__pycache__/simplex_plot.cpython-37.pyc index a8537b6f9da6be3317317d69450bbe8640da3809..d1a178aa1d78d9cd467d1b86bd96b0a477756845 100644 Binary files a/__pycache__/simplex_plot.cpython-37.pyc and b/__pycache__/simplex_plot.cpython-37.pyc differ diff --git a/analysis_pmf.py b/analysis_pmf.py index 03e55c1d17a94fbdfd7a0b8f2c8e239d4a076e4f..8be3eecf2101ada750f1414224d62ddd5b99491b 100644 --- a/analysis_pmf.py +++ b/analysis_pmf.py @@ -50,32 +50,32 @@ if __name__ == "__main__": state_types_interpolation = state_types_interpolation / state_types_interpolation.max() * 100 fs = 18 - # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[0], color=type2color[0]) - # plt.ylabel("% of type across population", size=fs) - # plt.xlabel("Interpolated session time", size=fs) - # plt.ylim(0, 100) - # sns.despine() - # plt.tight_layout() - # plt.savefig("type hist 1") - # plt.show() - # - # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[1], color=type2color[1]) - # plt.ylabel("% of type across population", size=fs) - # plt.xlabel("Interpolated session time", size=fs) - # plt.ylim(0, 100) - # sns.despine() - # plt.tight_layout() - # plt.savefig("type hist 2") - # plt.show() - # - # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[2], color=type2color[2]) - # plt.ylabel("% of type across population", size=fs) - # plt.xlabel("Interpolated session time", size=fs) - # plt.ylim(0, 100) - # sns.despine() - # plt.tight_layout() - # plt.savefig("type hist 3") - # plt.show() + plt.plot(np.linspace(0, 1, 150), state_types_interpolation[0], color=type2color[0]) + plt.ylabel("% of type across population", size=fs) + plt.xlabel("Interpolated session time", size=fs) + plt.ylim(0, 100) + sns.despine() + plt.tight_layout() + plt.savefig("./summary_figures/type hist 1") + plt.close() + + plt.plot(np.linspace(0, 1, 150), state_types_interpolation[1], color=type2color[1]) + plt.ylabel("% of type across population", size=fs) + plt.xlabel("Interpolated session time", size=fs) + plt.ylim(0, 100) + sns.despine() + plt.tight_layout() + plt.savefig("./summary_figures/type hist 2") + plt.close() + + plt.plot(np.linspace(0, 1, 150), state_types_interpolation[2], color=type2color[2]) + plt.ylabel("% of type across population", size=fs) + plt.xlabel("Interpolated session time", size=fs) + plt.ylim(0, 100) + sns.despine() + plt.tight_layout() + plt.savefig("./summary_figures/type hist 3") + plt.close() all_first_pmfs_typeless = pickle.load(open("all_first_pmfs_typeless.p", 'rb')) all_pmfs = pickle.load(open("all_pmfs.p", 'rb')) @@ -111,27 +111,27 @@ if __name__ == "__main__": # 
plt.xlabel("Bias flips") # sns.despine() # plt.tight_layout() - # plt.savefig("./meeting_figures/bias_flips.png") + # plt.savefig("./summary_figures/bias_flips.png") # plt.show() - # fewer_states_side = [] - # for key in all_first_pmfs_typeless: - # animal_biases = np.zeros(2) - # for defined_points, pmf in all_first_pmfs_typeless[key]: - # bias = np.mean(pmf[defined_points]) - # if bias > 0.55: - # animal_biases[0] += 1 - # elif bias < 0.45: - # animal_biases[1] += 1 - # fewer_states_side.append(np.min(animal_biases / animal_biases.sum())) - # plt.hist(fewer_states_side) - # plt.title("Mixed biases") - # plt.ylabel("# of mice") - # plt.xlabel("min(% left biased states, % right biased states)") - # sns.despine() - # plt.tight_layout() - # plt.savefig("./meeting_figures/proportion_other_bias") - # plt.show() + fewer_states_side = [] + for key in all_first_pmfs_typeless: + animal_biases = np.zeros(2) + for defined_points, pmf in all_first_pmfs_typeless[key]: + bias = np.mean(pmf[defined_points]) + if bias > 0.55: + animal_biases[0] += 1 + elif bias < 0.45: + animal_biases[1] += 1 + fewer_states_side.append(np.min(animal_biases / animal_biases.sum())) + plt.hist(fewer_states_side) + plt.title("Mixed biases") + plt.ylabel("# of mice") + plt.xlabel("min(% left biased states, % right biased states)") + sns.despine() + plt.tight_layout() + plt.savefig("./summary_figures/proportion_other_bias") + plt.close() # total_counter = 0 # bias_counter = 0 @@ -287,7 +287,7 @@ if __name__ == "__main__": plt.gca().spines['bottom'].set_linewidth(4) plt.tight_layout() plt.savefig("single exam 5") - plt.show() + plt.close() @@ -301,14 +301,14 @@ if __name__ == "__main__": plt.ylim(0, 1) plt.xlim(0, 10) - plt.ylabel("P(rightwards)", size=32) - plt.xlabel("Contrast", size=32) + # plt.ylabel("P(rightwards)", size=32) + # plt.xlabel("Contrast", size=32) plt.yticks([0, 1], size=27) plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27) sns.despine() plt.tight_layout() - plt.savefig("example type 1") - plt.show() + plt.savefig("example type 1", transparent=True) + plt.close() state_num = 1 defined_points, pmf = all_first_pmfs_typeless['CSHL_018'][state_num][0], all_first_pmfs_typeless['CSHL_018'][state_num][1] @@ -319,14 +319,14 @@ if __name__ == "__main__": plt.ylim(0, 1) plt.xlim(0, 10) - plt.ylabel("P(rightwards)", size=32) - plt.xlabel("Contrast", size=32) + # plt.ylabel("P(rightwards)", size=32) + # plt.xlabel("Contrast", size=32) plt.yticks([0, 1], size=27) plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27) sns.despine() plt.tight_layout() - plt.savefig("example type 2") - plt.show() + plt.savefig("example type 2", transparent=True) + plt.close() state_num = 4 defined_points, pmf = all_first_pmfs_typeless['ibl_witten_17'][state_num][0], all_first_pmfs_typeless['ibl_witten_17'][state_num][1] @@ -334,14 +334,14 @@ if __name__ == "__main__": plt.ylim(0, 1) plt.xlim(0, 10) - plt.ylabel("P(rightwards)", size=32) - plt.xlabel("Contrast", size=32) + # plt.ylabel("P(rightwards)", size=32) + # plt.xlabel("Contrast", size=32) plt.yticks([0, 1], size=27) plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27) sns.despine() plt.tight_layout() - plt.savefig("example type 3") - plt.show() + plt.savefig("example type 3", transparent=True) + plt.close() n_rows, n_cols = 5, 6 @@ -364,8 +364,8 @@ if __name__ == "__main__": a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) a.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) plt.tight_layout() - plt.savefig("animals 1") - plt.show() + 
plt.savefig("./summary_figures/animals 1") + plt.close() n_rows, n_cols = 5, 6 _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9)) @@ -393,8 +393,8 @@ if __name__ == "__main__": a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) a.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) plt.tight_layout() - plt.savefig("animals 2") - plt.show() + plt.savefig("./summary_figures/animals 2") + plt.close() n_rows, n_cols = 5, 6 _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9)) @@ -422,8 +422,8 @@ if __name__ == "__main__": a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) a.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) plt.tight_layout() - plt.savefig("animals 3") - plt.show() + plt.savefig("./summary_figures/animals 3") + plt.close() n_rows, n_cols = 5, 6 _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9)) @@ -451,8 +451,8 @@ if __name__ == "__main__": a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) a.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) plt.tight_layout() - plt.savefig("animals 4") - plt.show() + plt.savefig("./summary_figures/animals 4") + plt.close() # Collection of PMFs which change state type all_changing_pmfs = pickle.load(open("changing_pmfs.p", 'rb')) @@ -477,11 +477,49 @@ if __name__ == "__main__": plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=16) plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=16) - if i + 1 == 30: + if i + 1 == 35: break plt.tight_layout() - plt.savefig("changing pmfs") + plt.savefig("./summary_figures/changing pmfs") + plt.close() + + # The biases of all first PMFs as a histogram + biases = [] + mouse_counter = 0 + consistent_bias = 0 + hm = 0 + for key in all_first_pmfs_typeless: + left_1, right_1, left_2, right_2 = False, False, False, False + x = all_first_pmfs_typeless[key] + for pmf in x: + defined_points, pmf = pmf + if pmf_type(pmf) == 0: + biases.append(np.mean(pmf[[0, 1, -2, -1]])) + if np.mean(pmf[[0, 1, -2, -1]]) >= 0.5: + right_1 = True + if np.mean(pmf[[0, 1, -2, -1]]) <= 0.5: + left_1 = True + if pmf_type(pmf) == 1: + if np.mean(pmf[[0, 1, -2, -1]]) >= 0.5: + right_2 = True + if np.mean(pmf[[0, 1, -2, -1]]) <= 0.5: + left_2 = True + mouse_counter += 1 + consistent_bias += (left_1 and left_2) or (right_1 and right_2) + hm += left_1 and left_2 and right_1 and right_2 + + + print("number of mice is {}".format(mouse_counter)) + print("of which {} have a previously expressed bias in type 2".format(consistent_bias)) + print(hm) + # number of mice is 113 + # of which 77 have a previously expressed bias in type 2 + # 17 + + plt.hist(biases) + plt.axvline(np.mean(biases)) + plt.xlim(0, 1) plt.show() # All first PMFs @@ -497,7 +535,13 @@ if __name__ == "__main__": for key in all_first_pmfs_typeless: x = all_first_pmfs_typeless[key] for pmf in x: - axs[pmf_type(pmf[1])].plot(np.where(pmf[0])[0], pmf[1][pmf[0]], c=type2color[pmf_type(pmf[1])]) + + defined_points, pmf = pmf + pmf_min = min(pmf[0], pmf[1]) + pmf_max = max(pmf[-2], pmf[-1]) + defined_points = np.logical_and(np.logical_and(defined_points, ~ (pmf > pmf_max)), ~ (pmf < pmf_min)) + axs[pmf_type(pmf)].plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)]) + # axs[pmf_type(pmf[1])].plot(np.where(pmf[0])[0], pmf[1][pmf[0]], c=type2color[pmf_type(pmf[1])]) axs[0].set_ylim(0, 1) axs[1].set_ylim(0, 1) axs[2].set_ylim(0, 1) @@ -517,13 +561,14 @@ if __name__ == "__main__": 
axs[0].set_xlabel("Contrasts", size=label_size) plt.tight_layout() - plt.savefig(save_title) + plt.savefig("./summary_figures/" + save_title) plt.show() if save_title == "KS014 types": quit() # Type 1 PMF specifications counter = 0 + counter_l, counter_r = 0, 0 fig, ax = plt.subplots(1, 3, figsize=(16, 9)) for key in all_first_pmfs_typeless: for defined_points, pmf in all_first_pmfs_typeless[key]: @@ -537,6 +582,8 @@ if __name__ == "__main__": use_ax = 2 else: use_ax = int(pmf[0] > 1 - pmf[-1]) + counter_l += use_ax == 0 + counter_r += use_ax == 1 pmf_min = min(pmf[0], pmf[1]) pmf_max = max(pmf[-2], pmf[-1]) @@ -561,9 +608,9 @@ if __name__ == "__main__": ax[2].set_yticks([]) ax[2].spines[['right', 'top']].set_visible(False) ax[2].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) - print(counter) + print("left biased pmfs {}, right biased pmfs {}, neutral pmfs {}".format(counter_l, counter_r, counter)) plt.tight_layout() - plt.savefig("differentiate type 2") + plt.savefig("./summary_figures/differentiate type 2") plt.show() diff --git a/pmf_weight_analysis.py b/analysis_pmf_weights.py similarity index 61% rename from pmf_weight_analysis.py rename to analysis_pmf_weights.py index 059851501e92abf00fe662b44d35c296ea153c09..a2b0619646b3433f268a7c531bf7dc8a8b94f5b2 100644 --- a/pmf_weight_analysis.py +++ b/analysis_pmf_weights.py @@ -32,6 +32,115 @@ def pmf_type_rew(weights): else: return 2 +all_weight_trajectories = pickle.load(open("multi_chain_saves/all_weight_trajectories.p", 'rb')) + + +# for weight_traj in all_weight_trajectories: +# if len(weight_traj) == 1: +# continue +# state_type = pmf_type(weights_to_pmf(weight_traj[0])) + +# if state_type == 2 and weight_traj[0][0] > 0.8: +# plt.plot(weights_to_pmf(weight_traj[0])) +# plt.ylim(0, 1) +# plt.show() + +n_weights = all_weight_trajectories[0][0].shape[0] +ylabels = ["Cont left", "Cont right", "Persevere", "Bias left", "Bias right"] + +for min_dur in [2, 3, 4, 5, 7, 9, 11, 15]: + f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9)) # add another space to split bias into stating left versus starting right + average = np.zeros((n_weights + 1, 3, 2)) + counter = np.zeros((n_weights + 1, 3)) + all_datapoints = [[[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]]] + for weight_traj in all_weight_trajectories: + + if len(weight_traj) < min_dur: + continue + + state_type = pmf_type(weights_to_pmf(weight_traj[0])) + + for i in range(n_weights): + if i == n_weights - 1: + axs[i + (weight_traj[0][i] > 0), state_type].plot([0, 1], [weight_traj[0][i], weight_traj[-1][i]], marker="o") # plot how the weight evolves from first to last appearance + average[i + (weight_traj[0][i] > 0), state_type] += np.array([weight_traj[0][i], weight_traj[-1][i]]) + counter[i + (weight_traj[0][i] > 0), state_type] += 1 + all_datapoints[i + (weight_traj[0][i] > 0)][state_type][0].append(weight_traj[0][i]) + all_datapoints[i + (weight_traj[0][i] > 0)][state_type][1].append(weight_traj[-1][i]) + else: + axs[i, state_type].plot([0, 1], [weight_traj[0][i], weight_traj[-1][i]], marker="o") # plot how the weight evolves from first to last appearance + average[i, state_type] += np.array([weight_traj[0][i], weight_traj[-1][i]]) + counter[i, state_type] += 1 + all_datapoints[i][state_type][0].append(weight_traj[0][i]) + all_datapoints[i][state_type][1].append(weight_traj[-1][i]) + for i in range(3): + for j in range(n_weights + 1): + axs[j, 
i].set_ylim(-9, 9) + axs[j, i].spines['top'].set_visible(False) + axs[j, i].spines['right'].set_visible(False) + axs[j, i].set_xticks([]) + if i == 0: + axs[j, i].set_ylabel(ylabels[j]) + else: + axs[j, i].set_yticks([]) + if j == 0: + axs[j, i].set_title("Type {}".format(i + 1)) + if j == n_weights: + axs[j, i].set_xlabel("Lifetime weight change") + + plt.tight_layout() + plt.savefig("./summary_figures/weight_changes/all weight changes min dur {}".format(min_dur)) + plt.close() + + f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9)) + for i in range(3): + for j in range(n_weights + 1): + axs[j, i].plot([0, 1], average[j, i] / counter[j, i], marker="o") + if j < 2: + axs[j, i].set_ylim(-4.5, 4.5) + else: + axs[j, i].set_ylim(-2, 2) + axs[j, i].spines['top'].set_visible(False) + axs[j, i].spines['right'].set_visible(False) + axs[j, i].set_xticks([]) + if i == 0: + axs[j, i].set_ylabel(ylabels[j]) + else: + axs[j, i].set_yticks([]) + if j == 0: + axs[j, i].set_title("Type {}".format(i + 1)) + if j == n_weights: + axs[j, i].set_xlabel("Lifetime weight change") + plt.savefig("./summary_figures/weight_changes/weight changes min dur {}".format(min_dur)) + plt.close() + + f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9)) + for i in range(3): + for j in range(n_weights + 1): + axs[j, i].plot([0, 1], average[j, i] / counter[j, i]) + axs[j, i].boxplot(all_datapoints[j][i], positions=[0, 1]) + if j < 2: + axs[j, i].set_ylim(-8, 8) + else: + axs[j, i].set_ylim(-4, 4) + axs[j, i].spines['top'].set_visible(False) + axs[j, i].spines['right'].set_visible(False) + axs[j, i].set_xticks([]) + if i == 0: + axs[j, i].set_ylabel(ylabels[j]) + else: + axs[j, i].set_yticks([]) + if j == 0: + axs[j, i].set_title("Type {}".format(i + 1)) + if j == n_weights: + axs[j, i].set_xlabel("Lifetime weight change") + plt.savefig("./summary_figures/weight_changes/weight change boxplot min dur {}".format(min_dur)) + plt.close() + +quit() + + + # all pmf weights apw = np.array(pickle.load(open("all_pmf_weights.p", 'rb'))) diff --git a/analysis_regression.py b/analysis_regression.py index 575f234c91dbf83c6f0082d3ac1aa05e743d765e..8e09fa87c2a109eebd100520dc655a59f25b6590 100644 --- a/analysis_regression.py +++ b/analysis_regression.py @@ -11,11 +11,13 @@ if __name__ == "__main__": regressions = np.array(pickle.load(open("regressions.p", 'rb'))) regression_diffs = np.array(pickle.load(open("regression_diffs.p", 'rb'))) - assert (regressions[:, 0] == np.sum(regressions[:, 2:], 1)).all() + assert (regressions[:, 0] == np.sum(regressions[:, 2:], 1)).all() # total # of regressions must be sum of # of regressions per type print(pearsonr(regressions[:, 0], regressions[:, 1])) + # (0.6202414016960471, 2.3666819600330215e-13) offset = 0.125 + plt.figure(figsize=(16 * 0.9, 9 * 0.9)) # which x values exist for x in np.unique(regressions[:, 0]): # which and how many ys are associated with this x @@ -28,7 +30,7 @@ if __name__ == "__main__": sns.despine() plt.tight_layout() - plt.savefig("Regression vs session length") + plt.savefig("./summary_figures/Regression vs session length") plt.show() # histogram of regressions per type @@ -38,7 +40,7 @@ if __name__ == "__main__": sns.despine() plt.tight_layout() - plt.savefig("# of regressions per type") + plt.savefig("./summary_figures/# of regressions per type") plt.show() # histogram of number of mice with regressions in the different types @@ -49,7 +51,7 @@ if __name__ == "__main__": sns.despine() plt.tight_layout() - plt.savefig("# of mice with regressions per type") + 
plt.savefig("./summary_figures/# of mice with regressions per type") plt.show() # histogram of regression diffs @@ -59,5 +61,5 @@ if __name__ == "__main__": sns.despine() plt.tight_layout() - plt.savefig("Regression diffs") + plt.savefig("./summary_figures/Regression diffs") plt.show() diff --git a/analysis_state_intros.py b/analysis_state_intros.py index 91b3250a7fee932274c8510534aa0e591912e9b7..b8435d95ce708bd5727133656af66ab64fb8fe48 100644 --- a/analysis_state_intros.py +++ b/analysis_state_intros.py @@ -44,14 +44,14 @@ def type_hist(data, title=''): plt.ylabel("Type 3", size=fontsize) ax2.set_yticks([]) - plt.savefig(title) + plt.savefig("./summary_figures/" + title) plt.show() if __name__ == "__main__": - all_intros = pickle.load(open("all_intros.p", 'rb')) - all_intros_div = pickle.load(open("all_intros_div.p", 'rb')) - all_states_per_type = pickle.load(open("all_states_per_type.p", 'rb')) + all_intros = np.array(pickle.load(open("all_intros.p", 'rb'))) + all_intros_div = np.array(pickle.load(open("all_intros_div.p", 'rb'))) + all_states_per_type = np.array(pickle.load(open("all_states_per_type.p", 'rb'))) # There are 5 mice with 0 type 2 intros, but only 3 mice with no type 2 stats. # That is because they come up, but don't explain the necessary 50% to start phase 2. diff --git a/behavioral_state_data_easier.py b/behavioral_state_data_easier.py index 18a9de78e8aea1c241ff6a474e44e58a204d84c7..aed7c5dcc445fadffa0b0ea1f319ba90aafa67f4 100644 --- a/behavioral_state_data_easier.py +++ b/behavioral_state_data_easier.py @@ -5,6 +5,8 @@ import pandas as pd import seaborn as sns import pickle import json +import os +import re one = ONE() @@ -34,20 +36,48 @@ misses = [] to_introduce = [2, 3, 4, 5] amiss = ['UCLA034', 'UCLA036', 'UCLA037', 'PL015', 'PL016', 'PL017', 'PL024', 'NR_0017', 'NR_0019', 'NR_0020', 'NR_0021', 'NR_0027'] -subjects = ['ZFM-04019', 'ZFM-05236'] fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0] if fit_type == 'bias': loading_info = json.load(open("canonical_infos_bias.json", 'r')) elif fit_type == 'prebias': loading_info = json.load(open("canonical_infos.json", 'r')) +bwm = ['NYU-11', 'NYU-12', 'NYU-21', 'NYU-27', 'NYU-30', 'NYU-37', + 'NYU-39', 'NYU-40', 'NYU-45', 'NYU-46', 'NYU-47', 'NYU-48', + 'CSHL045', 'CSHL047', 'CSHL049', 'CSHL051', 'CSHL052', 'CSHL053', + 'CSHL054', 'CSHL055', 'CSHL058', 'CSHL059', 'CSHL060', 'UCLA005', + 'UCLA006', 'UCLA011', 'UCLA012', 'UCLA014', 'UCLA015', 'UCLA017', + 'UCLA033', 'UCLA034', 'UCLA035', 'UCLA036', 'UCLA037', 'KS014', + 'KS016', 'KS022', 'KS023', 'KS042', 'KS043', 'KS044', 'KS045', + 'KS046', 'KS051', 'KS052', 'KS055', 'KS084', 'KS086', 'KS091', + 'KS094', 'KS096', 'DY_008', 'DY_009', 'DY_010', 'DY_011', 'DY_013', + 'DY_014', 'DY_016', 'DY_018', 'DY_020', 'PL015', 'PL016', 'PL017', + 'PL024', 'SWC_042', 'SWC_043', 'SWC_060', 'SWC_061', 'SWC_066', + 'ZFM-01576', 'ZFM-01577', 'ZFM-01592', 'ZFM-01935', 'ZFM-01936', + 'ZFM-01937', 'ZFM-02368', 'ZFM-02369', 'ZFM-02370', 'ZFM-02372', + 'ZFM-02373', 'ZM_1897', 'ZM_1898', 'ZM_2240', 'ZM_2241', 'ZM_2245', + 'ZM_3003', 'SWC_038', 'SWC_039', 'SWC_052', 'SWC_053', 'SWC_054', + 'SWC_058', 'SWC_065', 'NR_0017', 'NR_0019', 'NR_0020', 'NR_0021', + 'NR_0027', 'ibl_witten_13', 'ibl_witten_17', 'ibl_witten_18', + 'ibl_witten_19', 'ibl_witten_20', 'ibl_witten_25', 'ibl_witten_26', + 'ibl_witten_27', 'ibl_witten_29', 'CSH_ZAD_001', 'CSH_ZAD_011', + 'CSH_ZAD_019', 'CSH_ZAD_022', 'CSH_ZAD_024', 'CSH_ZAD_025', + 'CSH_ZAD_026', 'CSH_ZAD_029'] +regexp = 
re.compile(r'canonical_result_((\w|-)+)_prebias.p') +subjects = [] +for filename in os.listdir("./multi_chain_saves/"): + if not (filename.startswith('canonical_result_') and filename.endswith('.p')): + continue + result = regexp.search(filename) + if result is None: + continue + subject = result.group(1) + subjects.append(subject) already_fit = list(loading_info.keys()) -remaining_subs = [s for s in subjects if s not in amiss and s not in already_fit] -print(remaining_subs) - +# remaining_subs = [s for s in subjects if s not in amiss and s not in already_fit] +# print(remaining_subs) - -data_folder = 'session_data_test' +data_folder = 'session_data' old_style = False if old_style: @@ -64,22 +94,30 @@ names = [] pre_bias = [] entire_training = [] +training_status_reached = [] +actually_existing = [] for subject in subjects: + if subject in bwm: + continue print('_____________________') print(subject) # if subject in already_fit or subject in amiss: # continue - - trials = one.load_aggregate('subjects', subject, '_ibl_subjectTrials.table') + try: + trials = one.load_aggregate('subjects', subject, '_ibl_subjectTrials.table') # Load training status and join to trials table - training = one.load_aggregate('subjects', subject, '_ibl_subjectTraining.table') - quit() - trials = (trials - .set_index('session') - .join(training.set_index('session')) - .sort_values(by='session_start_time', kind='stable')) + + training = one.load_aggregate('subjects', subject, '_ibl_subjectTraining.table') + + trials = (trials + .set_index('session') + .join(training.set_index('session')) + .sort_values(by='session_start_time', kind='stable')) + actually_existing.append(subject) + except: + continue start_times, indices = np.unique(trials.session_start_time, return_index=True) start_times = [trials.session_start_time[index] for index in sorted(indices)] @@ -99,16 +137,19 @@ for subject in subjects: easy_per = np.zeros(len(eids)) hard_per = np.zeros(len(eids)) bias_start = 0 + ephys_start = 0 info_dict = {'subject': subject, 'dates': [st.to_pydatetime() for st in start_times], 'eids': eids} contrast_set = {0, 1, 9, 10} rel_count = -1 - quit() + for i, start_time in enumerate(start_times): rel_count += 1 + assert rel_count == i + df = trials[trials.session_start_time == start_time] df.loc[:, 'contrastRight'] = df.loc[:, 'contrastRight'].fillna(0) df.loc[:, 'contrastLeft'] = df.loc[:, 'contrastLeft'].fillna(0) @@ -137,9 +178,15 @@ for subject in subjects: bias_start = i print("bias start {}".format(rel_count)) info_dict['bias_start'] = rel_count + training_status_reached.append(set(df.training_status)) if bias_start < 33: short_subjs.append(subject) + if ephys_start == 0 and df.task_protocol[0].startswith('_iblrig_tasks_ephysChoiceWorld'): + ephys_start = i + print("ephys start {}".format(rel_count)) + info_dict['ephys_start'] = rel_count + pickle.dump(df, open("./{}/{}_df_{}.p".format(data_folder, subject, rel_count), "wb")) side_info = np.zeros((len(df), 2)) diff --git a/canonical_infos.json b/canonical_infos.json index 99ae626eb5a4f9444331e2cd47bdbe9b3c2131ce..9e26dfeeb6e641a33dae4961196235bdb965b21b 100644 --- a/canonical_infos.json +++ b/canonical_infos.json @@ -1 +1 @@ -{"NYU-45": {"seeds": ["513", "503", "500", "506", "502", "501", "512", "509", "507", "515", "510", "505", "504", "514", "511", "508"], "fit_nums": ["768", "301", "96", "731", "879", "989", "915", "512", "295", "48", "157", "631", "666", "334", "682", "714"], "chain_num": 14}, "UCLA035": {"seeds": ["512", "509", "500", "510", "501", "503", "506", 
"513", "505", "504", "507", "502", "514", "508", "511", "515"], "fit_nums": ["715", "834", "996", "656", "242", "883", "870", "959", "483", "94", "864", "588", "390", "173", "967", "871"], "chain_num": 14}, "NYU-30": {"seeds": ["505", "508", "515", "507", "504", "513", "512", "503", "509", "500", "510", "514", "501", "502", "511", "506"], "fit_nums": ["885", "637", "318", "98", "209", "171", "472", "823", "956", "89", "762", "260", "76", "319", "139", "785"], "chain_num": 14}, "CSHL047": {"seeds": ["509", "510", "503", "502", "501", "508", "514", "505", "507", "511", "515", "504", "500", "506", "513", "512"], "fit_nums": ["60", "589", "537", "3", "178", "99", "877", "381", "462", "527", "6", "683", "771", "950", "294", "252"], "chain_num": 14}, "NYU-39": {"seeds": ["515", "502", "513", "508", "514", "503", "510", "506", "509", "504", "511", "500", "512", "501", "507", "505"], "fit_nums": ["722", "12", "207", "378", "698", "928", "15", "180", "650", "334", "388", "528", "608", "593", "988", "479"], "chain_num": 14}, "NYU-37": {"seeds": ["508", "509", "506", "512", "507", "503", "515", "504", "500", "505", "513", "501", "510", "511", "502", "514"], "fit_nums": ["94", "97", "793", "876", "483", "878", "886", "222", "66", "59", "601", "994", "526", "694", "304", "615"], "chain_num": 14}, "KS045": {"seeds": ["501", "514", "500", "502", "503", "505", "510", "515", "508", "507", "509", "512", "511", "506", "513", "504"], "fit_nums": ["731", "667", "181", "609", "489", "555", "995", "19", "738", "1", "267", "653", "750", "332", "218", "170"], "chain_num": 14}, "UCLA006": {"seeds": ["509", "505", "513", "515", "511", "500", "503", "507", "508", "514", "504", "510", "502", "501", "506", "512"], "fit_nums": ["807", "849", "196", "850", "293", "874", "216", "542", "400", "632", "781", "219", "331", "730", "740", "32"], "chain_num": 14}, "UCLA033": {"seeds": ["507", "505", "501", "512", "502", "510", "513", "514", "506", "515", "509", "504", "508", "500", "503", "511"], "fit_nums": ["366", "18", "281", "43", "423", "877", "673", "146", "921", "21", "353", "267", "674", "113", "905", "252"], "chain_num": 14}, "NYU-40": {"seeds": ["513", "500", "503", "507", "515", "504", "510", "508", "505", "514", "502", "512", "501", "509", "511", "506"], "fit_nums": ["656", "826", "657", "634", "861", "347", "334", "227", "747", "834", "460", "191", "489", "458", "24", "346"], "chain_num": 14}, "NYU-46": {"seeds": ["507", "509", "512", "508", "503", "500", "515", "501", "511", "506", "502", "505", "510", "513", "514", "504"], "fit_nums": ["503", "523", "5", "819", "190", "917", "707", "609", "145", "416", "376", "603", "655", "271", "223", "149"], "chain_num": 14}, "KS044": {"seeds": ["513", "503", "507", "504", "502", "515", "501", "511", "510", "500", "512", "508", "509", "506", "514", "505"], "fit_nums": ["367", "656", "73", "877", "115", "627", "610", "772", "558", "581", "398", "267", "353", "779", "393", "473"], "chain_num": 14}, "NYU-48": {"seeds": ["502", "507", "503", "515", "513", "505", "512", "501", "510", "504", "508", "511", "506", "514", "500", "509"], "fit_nums": ["854", "52", "218", "963", "249", "901", "248", "322", "566", "768", "256", "101", "303", "485", "577", "141"], "chain_num": 14}, "UCLA012": {"seeds": ["508", "506", "504", "513", "512", "502", "507", "505", "510", "503", "515", "509", "511", "501", "500", "514"], "fit_nums": ["492", "519", "577", "417", "717", "60", "130", "186", "725", "83", "841", "65", "441", "534", "856", "735"], "chain_num": 14}, "KS084": {"seeds": ["515", "507", 
"512", "503", "506", "508", "510", "502", "509", "505", "500", "511", "504", "501", "513", "514"], "fit_nums": ["816", "374", "140", "955", "399", "417", "733", "149", "300", "642", "644", "248", "324", "830", "889", "286"], "chain_num": 14}, "CSHL052": {"seeds": ["502", "508", "509", "514", "507", "515", "506", "500", "501", "505", "504", "511", "513", "512", "503", "510"], "fit_nums": ["572", "630", "784", "813", "501", "738", "517", "461", "690", "203", "40", "202", "412", "755", "837", "917"], "chain_num": 14}, "NYU-11": {"seeds": ["502", "510", "501", "513", "512", "500", "503", "515", "506", "514", "511", "505", "509", "504", "508", "507"], "fit_nums": ["650", "771", "390", "185", "523", "901", "387", "597", "57", "624", "12", "833", "433", "58", "276", "248"], "chain_num": 14}, "KS051": {"seeds": ["502", "505", "508", "515", "512", "511", "501", "509", "500", "506", "513", "503", "504", "507", "514", "510"], "fit_nums": ["620", "548", "765", "352", "402", "699", "370", "445", "159", "746", "449", "342", "642", "204", "726", "605"], "chain_num": 14}, "NYU-27": {"seeds": ["507", "513", "501", "505", "511", "504", "500", "514", "502", "509", "503", "515", "512", "510", "508", "506"], "fit_nums": ["485", "520", "641", "480", "454", "913", "526", "705", "138", "151", "962", "24", "21", "743", "119", "699"], "chain_num": 14}, "UCLA011": {"seeds": ["513", "503", "508", "501", "510", "505", "506", "511", "507", "514", "515", "500", "509", "502", "512", "504"], "fit_nums": ["957", "295", "743", "795", "643", "629", "142", "174", "164", "21", "835", "338", "368", "341", "209", "68"], "chain_num": 14}, "NYU-47": {"seeds": ["515", "504", "510", "503", "505", "506", "502", "501", "507", "500", "514", "513", "508", "509", "511", "512"], "fit_nums": ["169", "941", "329", "51", "788", "654", "224", "434", "385", "46", "712", "84", "930", "571", "273", "312"], "chain_num": 14}, "CSHL045": {"seeds": ["514", "502", "510", "507", "512", "501", "511", "506", "508", "509", "503", "513", "500", "504", "515", "505"], "fit_nums": ["862", "97", "888", "470", "620", "765", "874", "421", "104", "909", "924", "874", "158", "992", "25", "40"], "chain_num": 14}, "UCLA017": {"seeds": ["510", "502", "500", "515", "503", "507", "514", "501", "511", "505", "513", "504", "508", "512", "506", "509"], "fit_nums": ["875", "684", "841", "510", "209", "207", "806", "700", "989", "899", "812", "971", "526", "887", "160", "249"], "chain_num": 14}, "CSHL055": {"seeds": ["509", "503", "505", "512", "504", "508", "510", "511", "507", "501", "500", "514", "513", "515", "506", "502"], "fit_nums": ["957", "21", "710", "174", "689", "796", "449", "183", "193", "209", "437", "827", "990", "705", "540", "835"], "chain_num": 14}, "UCLA005": {"seeds": ["507", "503", "512", "515", "505", "500", "504", "513", "509", "506", "501", "511", "514", "502", "510", "508"], "fit_nums": ["636", "845", "712", "60", "733", "789", "990", "230", "335", "337", "307", "404", "297", "608", "428", "108"], "chain_num": 14}, "CSHL060": {"seeds": ["510", "515", "513", "503", "514", "501", "504", "502", "511", "500", "508", "509", "505", "507", "506", "512"], "fit_nums": ["626", "953", "497", "886", "585", "293", "580", "867", "113", "734", "88", "55", "949", "443", "210", "555"], "chain_num": 14}, "UCLA015": {"seeds": ["503", "501", "502", "500"], "fit_nums": ["877", "773", "109", "747"], "chain_num": 14}, "KS055": {"seeds": ["510", "502", "511", "501", "509", "506", "500", "515", "503", "512", "504", "505", "513", "508", "514", "507"], "fit_nums": ["526", 
"102", "189", "216", "673", "477", "293", "981", "960", "883", "899", "95", "31", "244", "385", "631"], "chain_num": 14}, "UCLA014": {"seeds": ["513", "509", "508", "510", "500", "504", "515", "511", "501", "514", "507", "502", "505", "512", "506", "503"], "fit_nums": ["628", "950", "418", "91", "25", "722", "792", "225", "287", "272", "23", "168", "821", "934", "194", "481"], "chain_num": 14}, "CSHL053": {"seeds": ["505", "513", "501", "508", "502", "515", "510", "507", "506", "509", "514", "511", "500", "503", "504", "512"], "fit_nums": ["56", "674", "979", "0", "221", "577", "25", "679", "612", "185", "464", "751", "648", "715", "344", "348"], "chain_num": 14}, "NYU-12": {"seeds": ["508", "506", "509", "515", "511", "510", "501", "504", "513", "503", "507", "512", "500", "514", "505", "502"], "fit_nums": ["853", "819", "692", "668", "730", "213", "846", "596", "644", "829", "976", "895", "974", "824", "179", "769"], "chain_num": 14}, "KS043": {"seeds": ["505", "504", "507", "500", "502", "503", "508", "511", "512", "509", "515", "513", "510", "514", "501", "506"], "fit_nums": ["997", "179", "26", "741", "476", "502", "597", "477", "511", "181", "233", "330", "299", "939", "542", "113"], "chain_num": 14}, "CSHL058": {"seeds": ["513", "505", "514", "506", "504", "500", "511", "503", "508", "509", "501", "510", "502", "515", "507", "512"], "fit_nums": ["752", "532", "949", "442", "400", "315", "106", "419", "903", "198", "553", "158", "674", "249", "723", "941"], "chain_num": 14}, "KS042": {"seeds": ["503", "502", "506", "511", "501", "505", "512", "509", "513", "508", "507", "500", "510", "504", "515", "514"], "fit_nums": ["241", "895", "503", "880", "283", "267", "944", "204", "921", "514", "392", "241", "28", "905", "334", "894"], "chain_num": 14}} \ No newline at end of file +{} \ No newline at end of file diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py index 1b8062228711583bf6024e9b68935547ab36d6ce..37bb5dadb5fc41fd30fa438a8db40c146b2595ca 100644 --- a/dyn_glm_chain_analysis.py +++ b/dyn_glm_chain_analysis.py @@ -22,7 +22,24 @@ import multiprocessing as mp from mcmc_chain_analysis import state_num_helper, ll_func, r_hat_array_comp, rank_inv_normal_transform import pandas as pd from analysis_pmf import pmf_type, type2color +import re +performance_points = np.array([-1, -1, 0, 0]) + +def pmf_to_perf(pmf): + # determine performance of a pmf, but only on the omnipresent strongest contrasts + return np.mean(np.abs(performance_points + pmf[[0, 1, -2, -1]])) + +# def pmf_type_temp(pmf): +# rew = pmf_to_perf(pmf) +# if rew < 0.6: +# return 0 +# elif rew < 0.7827: +# if np.abs(pmf[0] + pmf[-1] - 1) <= 0.1: +# return 3 +# return 1 +# else: +# return 2 colors = np.genfromtxt('colors.csv', delimiter=',') @@ -45,15 +62,6 @@ def weights_to_pmf(weights, with_bias=1): psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1] return 1 / (1 + np.exp(psi)) # we somehow got the answers twisted, so we drop the minus here to get the opposite response probability for plotting -performance_points = np.array([-1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0]) -def pmf_to_perf(pmf, def_points): - # determine performance of a pmf - # we use this to determine regressions in the behaviour of animals - # therefore, we exclude 0 as performance on it is 0.5 regardless of PMF, but it might - # overall lower performance. 
The removal of 0.5 later might also be a problem, let's see - relevant_points = def_points - relevant_points[5] = False - return np.mean(np.abs(performance_points[relevant_points] + pmf[relevant_points])) class MCMC_result_list: @@ -305,13 +313,13 @@ class MCMC_result_list: state_appear = [] state_durs = [] - state_appear_dist = np.zeros(10) + state_appear_dist = np.zeros(11) for res in self.results: for j in range(take_n): sa, sd = state_appear_and_dur(res.models[self.m // take_n * j], self) state_appear += sa state_durs += sd - state_appear_dist[[int(s * 10) for s in sa]] += 1 + state_appear_dist[[int(s * 11) for s in sa]] += 1 return state_appear, state_durs, state_appear_dist / self.m / take_n @@ -392,7 +400,6 @@ class MCMC_result: state_start_save[20 * index // len(seq)] += 1 state_starts[session_bounds[i] + index] += 1 observed_states.append(s) - print('observed_states {}'.format(observed_states)) plt.plot(state_starts / self.n_samples) return session_bounds, state_start_save / self.n_samples @@ -625,6 +632,7 @@ def bias_flips(states_by_session, pmfs, durs): def pmf_regressions(states_by_session, pmfs, durs): # find out whether pmfs regressed + # return: [total # of regressions, # of sessions, # of regressions during type 1, type 2, type 3], the reward differences state_perfs = {} state_counter = {} current_best_state = -1 @@ -1077,7 +1085,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + test.state_mapping[state], color=cmap(0.2 + 0.8 * j / test.results[0].n_sessions)) ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + test.state_mapping[state], ls='', ms=7, marker='*', color=cmap(j / test.results[0].n_sessions)) all_pmfs.append((defined_points, pmfs)) - all_pmf_weights += pmf_weights + all_pmf_weights.append(pmf_weights) else: temp = np.percentile(pmfs, [2.5, 97.5], axis=0) if not test.state_mapping[state] in dont_plot: @@ -1382,7 +1390,7 @@ def state_type_durs(states, pmfs): # Takes states and pmfs, first creates an array of when which type is how active, then computes the number of sessions each type lasts. 
# A type lasts until a more advanced type takes up more than 50% of a session (and cannot return) # Returns the durations for all the different state types, and an array which holds the state percentages - state_types = np.zeros((3, states.shape[1])) + state_types = np.zeros((4, states.shape[1])) for s, pmf in zip(states, pmfs): pmf_counter = -1 for i in range(states.shape[1]): @@ -1427,9 +1435,9 @@ def state_cluster_interpolation(states, pmfs): def get_first_pmfs(states, pmfs): # get the first pmf of every type, also where they are defined, and whether they are the first pmf of that state - earliest_sessions = [1000, 1000, 1000] - first_pmfs = [0, 0, 0, 0, 0, 0, 0, 0, 0] - changing_pmfs = [[0, 0], [0, 0]] + earliest_sessions = [1000, 1000, 1000] # high values to compare with min + first_pmfs = [0, 0, 0, 0, 0, 0, 0, 0, 0] # for every type (3), we save: the pmf, the defined points, the session at which they appear + changing_pmfs = [[0, 0], [0, 0]] # if the new type appears first through a slow transition, we remember its defined points (pmf[1]) and its pmf trajectory pmf[1] for state, pmf in zip(states, pmfs): sessions = np.where(state)[0] for i, (sess_pmf, sess) in enumerate(zip(pmf[1], sessions)): @@ -1532,26 +1540,25 @@ if __name__ == "__main__": loading_info = json.load(open("canonical_infos_bias.json", 'r')) elif fit_type == 'prebias': loading_info = json.load(open("canonical_infos.json", 'r')) - no_good_pcas = ['NYU-06', 'SWC_023'] - subjects = list(loading_info.keys()) - - new_subs = ['KS044', 'NYU-11', 'DY_016', 'SWC_061', 'ZFM-05245', 'CSH_ZAD_029', 'SWC_021', 'CSHL058', 'DY_014', 'DY_009', 'KS094', 'DY_018', 'KS043', 'UCLA014', 'SWC_038', 'SWC_022', 'UCLA012', 'UCLA011', 'CSHL055', 'ZFM-04019', 'NYU-45', 'ZFM-02370', 'ZFM-02373', 'ZFM-02369', 'NYU-40', 'CSHL060', 'NYU-30', 'CSH_ZAD_019', 'UCLA017', 'KS052', 'ibl_witten_25', 'ZFM-02368', 'CSHL045', 'UCLA005', 'SWC_058', 'CSH_ZAD_024', 'SWC_042', 'DY_008', 'ibl_witten_13', 'SWC_043', 'KS046', 'DY_010', 'CSHL053', 'ZM_1898', 'UCLA033', 'NYU-47', 'DY_011', 'CSHL047', 'SWC_054', 'ibl_witten_19', 'ibl_witten_27', 'KS091', 'KS055', 'CSH_ZAD_017', 'UCLA035', 'SWC_060', 'DY_020', 'ZFM-01577', 'ZM_2240', 'ibl_witten_29', 'KS096', 'SWC_066', 'DY_013', 'ZFM-01592', 'GLM_Sim_17', 'NYU-48', 'UCLA006', 'NYU-39', 'KS051', 'NYU-27', 'NYU-46', 'ZFM-01936', 'ZFM-02372', 'ZFM-01935', 'ibl_witten_26', 'ZFM-05236', 'ZM_2241', 'NYU-37', 'KS086', 'KS084', 'ZFM-01576', 'KS042'] - - miss_count = 0 - for s in new_subs: - if s not in subjects: - print(s) - miss_count += 1 - quit() + subjects = [] + regexp = re.compile(r'canonical_result_((\w|-)+)_prebias.p') + for filename in os.listdir("./multi_chain_saves/"): + if not (filename.startswith('canonical_result_') and filename.endswith('.p')): + continue + result = regexp.search(filename) + if result is None: + continue + subject = result.group(1) + subjects.append(subject) - print(subjects) + print(len(subjects)) fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0] dur = 'yes' # fig, ax = plt.subplots(1, 3, sharey=True, figsize=(16, 9)) pop_state_starts = np.zeros(20) - state_appear_dist = np.zeros(10) + state_appear_dist = np.zeros(11) state_appear_mode = [] num_states = [] num_sessions = [] @@ -1581,6 +1588,9 @@ if __name__ == "__main__": regression_diffs = [] all_bias_flips = [] all_pmf_weights = [] + all_weight_trajectories = [] + bias_sessions = [] + first_and_last_pmf = [] new_counter, transform_counter = 0, 0 state_types_interpolation = np.zeros((3, 150)) @@ -1589,17 +1599,29 @@ if __name__ == 
"__main__": state_nums_5 = [] state_nums_10 = [] - for subject in subjects: + ultimate_counter = 0 - # NYU-11 is quite weird, super errrativ behaviour, has all contrasts introduced at once, no good session at end - if subject.startswith('GLM_Sim_'): + for subject in subjects: + if subject.startswith('GLM_Sim_') or subject in ['SWC_065', 'ZFM-05245', 'ZFM-04019', 'ibl_witten_18']: + # ibl_witten_18 is a weird one, super good session in the middle, ending phase 1, never to re-appear, bad at the end + # ZFM-05245 is neuromodulator mouse, never reaches ephys it seems... same for ZFM-04019 + # SWC_065 never reaches type 3 continue print() print(subject) + continue test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb')) + # all_state_starts = test.state_appearance_posterior(subject) + # pop_state_starts += all_state_starts + + # a, b, c = test.state_start_and_dur() + # state_appear += a + # state_dur += b + # state_appear_dist += c + mode_specifier = 'first' try: mode_indices = pickle.load(open("multi_chain_saves/{}_mode_indices_{}_{}.p".format(mode_specifier, subject, fit_type), 'rb')) @@ -1609,6 +1631,9 @@ if __name__ == "__main__": mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb')) state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb')) except: + print("____________________________________") + print("Something quite wrong with {}".format(subject)) + print("____________________________________") continue # lapse differential @@ -1619,18 +1644,37 @@ if __name__ == "__main__": # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 1', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(7)), plot_until=2) # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 2', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(6)), plot_until=7) # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 3', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(4)), plot_until=13) - states, pmfs, pmf_weights, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True) - quit() + states, pmfs, pmf_weights, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save=False, show=0, separate_pmf=1, type_coloring=True) + + first_and_last_pmf.append((pmf_weights[np.argmax(states[:, 0])][0], pmf_weights[np.argmax(states[:, -1])][-1])) + + continue + # all_weight_trajectories += pmf_weights + abs_state_durs.append(durs) + ultimate_counter += 1 + + info_dict = pickle.load(open("./{}/{}_info_dict.p".format('session_data', subject), "rb")) + if 'ephys_start' in info_dict: + bias_sessions.append(info_dict['ephys_start'] - info_dict['bias_start']) + else: + bias_sessions.append(info_dict['n_sessions'] - info_dict['bias_start']) + print(subject, info_dict['n_sessions'], info_dict['bias_start']) + + continue + + state_types_interpolation[0] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[0]) + state_types_interpolation[1] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 
1 + state_types.shape[1]), state_types[1]) + state_types_interpolation[2] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[2]) # state_nums_5.append((states > 0.05).sum(0)) # state_nums_10.append((states > 0.1).sum(0)) # continue # a = compare_params(test, 26, [s for s in state_sets if len(s) > 40], mode_indices, [3, 5]) # compare_pmfs(test, [3, 2, 4], states, pmfs, title="{} convergence pmf".format(subject)) - consistencies = pickle.load(open("multi_chain_saves/first_mode_consistencies_{}_{}.p".format(subject, fit_type), 'rb')) - consistencies /= consistencies[0, 0] - temp = contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies, CMF=False) - continue - quit() + # consistencies = pickle.load(open("multi_chain_saves/first_mode_consistencies_{}_{}.p".format(subject, fit_type), 'rb')) + # consistencies /= consistencies[0, 0] + # temp = contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies, CMF=False) + # quit() + new = type_2_appearance(states, pmfs) if new == 2: @@ -1644,60 +1688,50 @@ if __name__ == "__main__": print(new_counter, transform_counter) - state_dict = write_results(test, [s for s in state_sets if len(s) > 40], mode_indices) - pickle.dump(state_dict, open("state_dict_{}".format(subject), 'wb')) - quit() - # abs_state_durs.append(durs) - # continue - # all_pmf_weights += pmf_weights - # all_state_types.append(state_types) - # - # state_types_interpolation[0] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[0]) - # state_types_interpolation[1] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[1]) - # state_types_interpolation[2] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[2]) - # - # all_first_pmfs_typeless[subject] = [] - # for pmf in pmfs: - # all_first_pmfs_typeless[subject].append((pmf[0], pmf[1][0])) - # all_pmfs.append(pmf) - # - # first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs) - # for pmf in changing_pmfs: - # if type(pmf[0]) == int: - # continue - # all_changing_pmf_names.append(subject) - # all_changing_pmfs.append(pmf) - # - # all_first_pmfs[subject] = first_pmfs - # continue - # - # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) - # consistencies /= consistencies[0, 0] - # contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=False) + # state_dict = write_results(test, [s for s in state_sets if len(s) > 40], mode_indices) + # pickle.dump(state_dict, open("state_dict_{}".format(subject), 'wb')) + + all_pmf_weights += [item for sublist in pmf_weights for item in sublist] + all_state_types.append(state_types) + + all_first_pmfs_typeless[subject] = [] + for pmf in pmfs: + all_first_pmfs_typeless[subject].append((pmf[0], pmf[1][0])) + all_pmfs.append(pmf) + + first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs) + for pmf in changing_pmfs: + if type(pmf[0]) == int: + continue + all_changing_pmf_names.append(subject) + all_changing_pmfs.append(pmf) + + all_first_pmfs[subject] = first_pmfs + quit() # b_flips = bias_flips(states, pmfs, durs) # all_bias_flips.append(b_flips) - # regression, diffs = 
pmf_regressions(states, pmfs, durs) - # regression_diffs += diffs - # regressions.append(regression) - # continue + regression, diffs = pmf_regressions(states, pmfs, durs) + regression_diffs += diffs + regressions.append(regression) + # compare_pmfs(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, pmfs, title="{} convergence pmf".format(subject)) # compare_weights(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, title="{} convergence weights".format(subject)) - # quit() - # all_intros.append(undiv_intros) - # all_intros_div.append(intros_by_type) - # if states_per_type != []: - # all_states_per_type.append(states_per_type) - # - # intros_by_type_sum += intros_by_type - # for pmf in pmfs: - # all_pmfs.append(pmf) - # for p in pmf[1]: - # all_pmf_diffs.append(p[-1] - p[0]) - # all_pmf_asymms.append(np.abs(p[0] + p[-1] - 1)) - # contrast_intro_types.append(contrast_intro_type) + + all_intros.append(undiv_intros) + all_intros_div.append(intros_by_type) + if states_per_type != []: + all_states_per_type.append(states_per_type) + + intros_by_type_sum += intros_by_type + for pmf in pmfs: + all_pmfs.append(pmf) + for p in pmf[1]: + all_pmf_diffs.append(p[-1] - p[0]) + all_pmf_asymms.append(np.abs(p[0] + p[-1] - 1)) + contrast_intro_types.append(contrast_intro_type) # num_states.append(np.mean(test.state_num_dist())) # num_sessions.append(test.results[0].n_sessions) @@ -1745,13 +1779,7 @@ if __name__ == "__main__": # quit() # quit() - # all_state_starts = test.state_appearance_posterior(subject) - # pop_state_starts += all_state_starts # - # a, b, c = test.state_start_and_dur() - # state_appear += a - # state_dur += b - # state_appear_dist += c # all_state_starts = test.state_appearance_posterior(subject) @@ -1759,7 +1787,6 @@ if __name__ == "__main__": # plt.savefig("temp") # plt.close() - # pickle.dump(all_first_pmfs, open("all_first_pmfs.p", 'wb')) # pickle.dump(all_changing_pmfs, open("changing_pmfs.p", 'wb')) # pickle.dump(all_changing_pmf_names, open("changing_pmf_names.p", 'wb')) @@ -1774,24 +1801,50 @@ if __name__ == "__main__": # pickle.dump(all_state_types, open("all_state_types.p", 'wb')) # pickle.dump(all_pmf_weights, open("all_pmf_weights.p", 'wb')) # pickle.dump(state_types_interpolation, open("state_types_interpolation.p", 'wb')) + # pickle.dump(state_types_interpolation, open("state_types_interpolation_4_states.p", 'wb')) # special version, might not want to use # abs_state_durs = np.array(abs_state_durs) # pickle.dump(abs_state_durs, open("multi_chain_saves/abs_state_durs.p", 'wb')) + # pickle.dump(pop_state_starts, open("multi_chain_saves/pop_state_starts.p", 'wb')) + # pickle.dump(state_appear, open("multi_chain_saves/state_appear.p", 'wb')) + # pickle.dump(state_dur, open("multi_chain_saves/state_dur.p", 'wb')) + # pickle.dump(state_appear_dist, open("multi_chain_saves/state_appear_dist.p", 'wb')) + # pickle.dump(all_weight_trajectories, open("multi_chain_saves/all_weight_trajectories.p", 'wb')) + # pickle.dump(bias_sessions, open("multi_chain_saves/bias_sessions.p", 'wb')) + # pickle.dump(first_and_last_pmf, open("multi_chain_saves/first_and_last_pmf.p", 'wb')) + + abs_state_durs = pickle.load(open("multi_chain_saves/abs_state_durs.p", 'rb')) + bias_sessions = pickle.load(open("multi_chain_saves/bias_sessions.p", 'rb')) + + quit() + print("Ultimate count is {}".format(ultimate_counter)) - if True: + if False: abs_state_durs = pickle.load(open("multi_chain_saves/abs_state_durs.p", 'rb')) print("Correlations") from scipy.stats 
import pearsonr print(pearsonr(abs_state_durs[:, 0], abs_state_durs[:, 1])) + # (0.30202072153268694, 0.0011496023265844494) print(pearsonr(abs_state_durs[:, 0], abs_state_durs[:, 2])) + # (-0.008129156810930915, 0.9318998938147801) print(pearsonr(abs_state_durs[:, 1], abs_state_durs[:, 2])) + # (0.20293276001562754, 0.03110687851587209) print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 0])) - # (0.7338297529946006, 2.6332570579118393e-06) + # (0.3930315541710777, 1.6612856471614082e-05) print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 1])) - # (0.35094585023228597, 0.052897046343413114) + # (0.49183970714755426, 3.16067121387928e-08) print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 2])) - # (0.7210260323745921, 4.747833912452452e-06) + # (0.8936241158982767, 2.0220064236645623e-40) + + print(pearsonr(abs_state_durs[:, 0], bias_sessions)) + # (-0.048942358403448086, 0.6099754179426583) + print(pearsonr(abs_state_durs[:, 1], bias_sessions)) + # (-0.07904742327893216, 0.4095531301939973) + print(pearsonr(abs_state_durs[:, 2], bias_sessions)) + # (-0.0697415785143183, 0.4670166281153171) + print(pearsonr(abs_state_durs.sum(1), bias_sessions)) + # (-0.09311057488563208, 0.3310557265430209) from simplex_plot import plotSimplex @@ -1804,9 +1857,11 @@ if __name__ == "__main__": plt.ylabel("# of mice", size=40) plt.xlabel('# of sessions', size=40) plt.tight_layout() - plt.savefig("session_num_hist.png", dpi=300, transparent=True) + plt.savefig("./summary_figures/session_num_hist.png", dpi=300, transparent=True) plt.show() + quit() + if False: ax[0].set_ylim(0, 1) ax[1].set_ylim(0, 1) @@ -1850,11 +1905,12 @@ if __name__ == "__main__": # ax[1].legend() # ax[2].legend() plt.tight_layout() - plt.savefig("first_knowledge_state", dpi=300) + plt.savefig("./summary_figures/first_knowledge_state", dpi=300) plt.show() quit() if False: + pop_state_starts = pickle.load(open("multi_chain_saves/pop_state_starts.p", 'rb')) f, (ax, ax2) = plt.subplots(2, 1, sharex=True, figsize=(9, 9)) # plot the same data on both axes @@ -1864,14 +1920,14 @@ if __name__ == "__main__": ax2.bar(np.linspace(0, 1, 20), pop_state_starts, align='edge', width=1/19, color='grey') # zoom-in / limit the view to different portions of the data - ax.set_ylim(24.75, 28) - ax2.set_ylim(0, 3.25) + ax.set_ylim(245, 300) + ax2.set_ylim(0, 25) - ax.set_yticks([25, 27]) - ax.set_yticklabels([25, 27]) - ax2.set_yticks([0, 1, 2, 3]) + ax.set_yticks([250, 275]) + ax.set_yticklabels([250, 275]) + ax2.set_yticks([0, 8, 16, 24]) ax2.set_xticks([0, .25, .5, .75, 1]) - ax2.set_yticklabels([0, 1, 2, 3]) + ax2.set_yticklabels([0, 8, 16, 24]) ax2.set_xticklabels([0, .25, .5, .75, 1]) # hide the spines between ax and ax2 @@ -1896,13 +1952,13 @@ if __name__ == "__main__": ax2.plot((-d, +d), (1 - d, 1 + d), **kwargs) # bottom-left diagonal plt.tight_layout() - plt.savefig('states within session', dpi=300) + plt.savefig('./summary_figures/states within session', dpi=300) plt.close() if False: f, (a1, a2) = plt.subplots(1, 2, figsize=(18, 9), gridspec_kw={'width_ratios': [1.4, 1]}) - # a1.bar(np.linspace(0, 1, 11), state_appear_dist, align='edge', width=1/10, color='grey') - a1.hist(state_appear_mode, color='grey') + a1.bar(np.linspace(0, 1, 11), state_appear_dist, align='edge', width=1/10, color='grey') + # a1.hist(state_appear_mode, color='grey') a1.set_xlim(left=0) # plt.title('First appearence of ', fontsize=22) a1.set_ylabel('# of states', fontsize=35) @@ -1932,5 +1988,5 @@ if __name__ == "__main__": a2.tick_params(axis='both', 
labelsize=20) sns.despine() plt.tight_layout() - plt.savefig('states in sessions', dpi=300) + plt.savefig('./summary_figures/states in sessions', dpi=300) plt.close() diff --git a/index_mice.py b/index_mice.py index 96d35de5b11c000cdb713ffea432e4c8b74473e5..6c42c4d28994da619fd3a5cc63535ef9706ddb23 100644 --- a/index_mice.py +++ b/index_mice.py @@ -9,14 +9,14 @@ test = ['CSHL059_fittype_prebias_var_0.03_303_45_1.p', 'ibl_witten_16_fittype_pr prebias_subinfo = {} bias_subinfo = {} -for filename in os.listdir("./dynamic_GLMiHMM_crossvals/"): +for filename in test:#os.listdir("./dynamic_GLMiHMM_crossvals/"): if not filename.endswith('.p'): continue regexp = re.compile(r'((\w|-)+)_fittype_(\w+)_var_0.03_(\d+)_(\d+)_(\d+)') result = regexp.search(filename) subject = result.group(1) - if subject == 'ibl_witten_26': - print('here') + print(subject) + continue fit_type = result.group(3) seed = result.group(4) fit_num = result.group(5) diff --git a/simplex_animation.py b/simplex_animation.py index ddf6df7925ac502d5319f6b4a6986eb27adb6fff..dce03864bee596ead466388d2827318280b647fa 100644 --- a/simplex_animation.py +++ b/simplex_animation.py @@ -5,6 +5,7 @@ from analysis_pmf import type2color import math import matplotlib.pyplot as plt import copy +import string all_state_types = pickle.load(open("all_state_types.p", 'rb')) @@ -39,7 +40,15 @@ assert (test_count == 1).all() # quit() session_counter = -1 -alph = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'za', 'zb', 'zc', 'zd', 'ze', 'zf', 'zg', 'zh', 'zi', 'zj'] +string_list = [] +alphabet = string.ascii_lowercase + +for char1 in alphabet: + for char2 in alphabet: + string_list.append(char1 + char2) + if len(string_list) == 100: + break + # do as many sessions as it takes while True: @@ -58,7 +67,7 @@ while True: else: temp_counter = -1 - assert np.sum(sts[:, temp_counter]) <= 1 + assert np.sum(sts[:, temp_counter]) <= 1.000000000000001 if (sts[:, temp_counter] == 1).any(): x_offset[i] = offsets[i][0] * 0.01 @@ -66,8 +75,8 @@ while True: type_proportions[:, i] = sts[:, temp_counter] - plotSimplex(type_proportions.T, x_offset=x_offset, y_offset=y_offset, c=np.arange(len(all_state_types)), show=False, title="Session {}".format(1 + session_counter), - vertexcolors=[type2color[i] for i in range(3)], vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title="simplex_{}.png".format(alph[session_counter])) + plotSimplex(type_proportions.T * 10, x_offset=x_offset, y_offset=y_offset, c=np.arange(len(all_state_types)), show=False, title="Session {}".format(1 + session_counter), + vertexcolors=[type2color[i] for i in range(3)], vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title="simplex_{}.png".format(string_list[session_counter])) if not_ended == 0: break diff --git a/simplex_plot.py b/simplex_plot.py index 420dd26c661da32a2e9f89e2fbb937edea411f25..f9f8a0703b2fddf4610826d5d09604008906a060 100644 --- a/simplex_plot.py +++ b/simplex_plot.py @@ -15,7 +15,7 @@ import matplotlib.patches as PA def plotSimplex(points, fig=None, vertexlabels=['1: initial flat PMFs', '2: intermediate unilateral PMFs', '3: final bilateral PMFs'], title='', - save_title="dur_simplex.png", show=False, vertexcolors=['k', 'k', 'k'], x_offset=0, y_offset=0, **kwargs): + save_title="./summary_figures/dur_simplex.png", show=False, vertexcolors=['k', 'k', 'k'], x_offset=0, y_offset=0, **kwargs): """ Plot Nx3 points array on the 3-simplex (with optionally labeled vertices) @@ -47,12 +47,12 @@ def 
plotSimplex(points, fig=None, P.scatter(projected[:, 0], projected[:, 1], marker='*', color='r', s=np.mean(points.sum(1)) * 3.5)#s=50 # Leave some buffer around the triangle for vertex labels - fig.gca().set_xlim(-0.05, 1.05) - fig.gca().set_ylim(-0.05, 1.05) + fig.gca().set_xlim(-0.08, 1.08) + fig.gca().set_ylim(-0.08, 1.08) P.axis('off') if title != '': - P.annotate(title, (0.395, np.sqrt(3) / 2 + 0.025), size=24) + P.annotate(title, (0.395, np.sqrt(3) / 2 + 0.075), size=24) P.tight_layout() P.savefig(save_title, bbox_inches='tight', dpi=300, transparent=True)
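
A minimal standalone sketch of the subject-discovery pattern that this patch adds in both behavioral_state_data_easier.py and dyn_glm_chain_analysis.py (scanning ./multi_chain_saves/ for canonical_result_<subject>_prebias.p files); the helper name discover_subjects is illustrative only, and the dot in the regex is escaped here for strictness, unlike in the patch:

import os
import re

def discover_subjects(directory="./multi_chain_saves/"):
    # collect subject names from files named canonical_result_<subject>_prebias.p
    regexp = re.compile(r'canonical_result_((\w|-)+)_prebias\.p')
    subjects = []
    for filename in os.listdir(directory):
        # only consider pickled canonical results
        if not (filename.startswith('canonical_result_') and filename.endswith('.p')):
            continue
        result = regexp.search(filename)
        if result is None:
            continue
        subjects.append(result.group(1))  # group(1) is the subject name
    return subjects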