Skip to content
Snippets Groups Projects
Commit d95ae20c authored by SebastianBruijns's avatar SebastianBruijns
Browse files

commit for reproducing KS014 two sample binomial test

parent 70ed7f02
No related branches found
No related tags found
No related merge requests found
No preview for this file type
No preview for this file type
No preview for this file type
......@@ -8,7 +8,9 @@ type2color = {0: 'green', 1: 'blue', 2: 'red'}
all_conts = np.array([-1, -0.5, -.25, -.125, -.062, 0, .062, .125, .25, 0.5, 1])
performance_points = np.array([-1, -1, 0, 0])
np.random.seed(12)
seed = np.random.randint(10000)
print(seed)
np.random.seed(2645) # 14 maybe, 21, 24, 1602 ok, 2645 good, 6429 good
def pmf_to_perf(pmf):
# determine performance of a pmf, but only on the omnipresent strongest contrasts
......@@ -78,6 +80,17 @@ if __name__ == "__main__":
plt.savefig("./summary_figures/type hist 3")
plt.close()
plt.plot(np.linspace(0, 1, 150), state_types_interpolation[0], color=type2color[0])
plt.plot(np.linspace(0, 1, 150), state_types_interpolation[1], color=type2color[1])
plt.plot(np.linspace(0, 1, 150), state_types_interpolation[2], color=type2color[2])
plt.ylabel("% of type across population", size=fs)
plt.xlabel("Interpolated session time", size=fs)
plt.ylim(0, 100)
sns.despine()
plt.tight_layout()
plt.savefig("./summary_figures/state evos")
plt.close()
all_first_pmfs_typeless = pickle.load(open("all_first_pmfs_typeless.p", 'rb'))
type_2_counter = 0
for subject in all_first_pmfs_typeless.keys():
......
......@@ -204,7 +204,7 @@ def plot_compact_all(average_slow, counter_slow, average_sudden, counter_sudden,
axs[0].legend(frameon=False, fontsize=17)
plt.tight_layout()
plt.savefig("./summary_figures/weight_changes/" + title + " augmented" * show_weight_augmentations, dpi=300)
plt.close()
plt.show()
def plot_compact_split(all_datapoints, title, show_first_and_last=False, show_weight_augmentations=False, width_divisor=20):
......@@ -310,7 +310,7 @@ def plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal, x_
axs[j, i * 2 + 1].add_artist(con)
if i == 0:
axs[j, i].set_ylabel(local_ylabels[j])
axs[j, i].set_ylabel(local_ylabels[j], size=15)
# axs[j, i * 2 + 1].yaxis.set_ticklabels([])
if show_first_and_last:
if j < n_weights - 1:
......@@ -334,13 +334,16 @@ def plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal, x_
mask = first_and_last_pmf[:, 1, -1] > 0
axs[j, i * 2 + 1].plot([x_lim_used_augment / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')
if j == 0:
axs[j, i * 2].set_title("Type {}".format(i + 1), loc='right')
if 'sudden' in title:
axs[j, i * 2].set_title(r'Type ${} \rightarrow {}$'.format(i + 1, i + 2), loc='right', size=16, position=(1.45, 1))
else:
axs[j, i * 2].set_title(r'Type ${}$'.format(i + 1), loc='right', size=16, position=(1.45, 1))
if j == n_weights + show_weight_augmentations and i == 0:
axs[j, 0].set_xlabel("Initial distribution")
axs[j, 0].set_xlabel("Pre", size=15)
if show_deltas:
axs[j, 1].set_xlabel("Weight change")
axs[j, 1].set_xlabel("Deltas", size=15)
else:
axs[j, 1].set_xlabel("Changed distribution")
axs[j, 1].set_xlabel("Post", size=15)
if j < n_weights - 1:
axs[j, i * 2].set_xticks(list(range(x_steps, x_lim_used_normal, x_steps)))
......@@ -368,7 +371,8 @@ def plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal, x_
if show_deltas:
axs[j, i * 2 + 1].set_ylim(bins[0] - means[0], bins[-1] - means[0])
plt.savefig("./summary_figures/weight_changes/" + title)
# plt.tight_layout()
plt.savefig("./summary_figures/weight_changes/" + title, dpi=300)
plt.close()
......@@ -404,17 +408,18 @@ if True:
temp_sudden_counter, switch, nonswitch = 0, 0, 0
for temp_sudden_change in all_sudden_transition_changes[0]:
prev_relevant_points, post_relevant_points = weights_to_pmf(temp_sudden_change[0])[[0, 1, -2, -1]], weights_to_pmf(temp_sudden_change[1])[[0, 1, -2, -1]]
print((1 - post_relevant_points[0] + 1 - post_relevant_points[1] + post_relevant_points[-2] + post_relevant_points[-1]) / 4)
# print((1 - post_relevant_points[0] + 1 - post_relevant_points[1] + post_relevant_points[-2] + post_relevant_points[-1]) / 4)
if np.abs(np.mean(post_relevant_points) - 0.5) < 0.05 or np.abs(np.mean(prev_relevant_points) - 0.5) < 0.05:
print("skipped " + str(prev_relevant_points) + " " + str(post_relevant_points))
# print("skipped " + str(prev_relevant_points) + " " + str(post_relevant_points))
continue
print("considered " + str(prev_relevant_points) + " " + str(post_relevant_points))
# print("considered " + str(prev_relevant_points) + " " + str(post_relevant_points))
temp_sudden_counter += 1
if (np.mean(prev_relevant_points) > 0.5 and np.mean(post_relevant_points) < 0.5) or (np.mean(prev_relevant_points) < 0.5 and np.mean(post_relevant_points) > 0.5):
switch += 1
else:
nonswitch += 1
quit()
print(switch, nonswitch)
plot_compact_all(average_slow, counter_slow, average_sudden, counter_sudden, title="all weight changes combined", all_data_sudden=all_data_sudden, all_data_slow=all_data_slow)
average, counter, all_datapoints = plot_traces_and_collate_data(data=all_sudden_transition_changes, augmented_data=aug_all_sudden_transition_changes, title="sudden_weight_change at transitions")
......@@ -423,9 +428,9 @@ if True:
plot_compact_split(all_datapoints, title="compact sudden_weight_change at transitions", show_weight_augmentations=show_weight_augmentations, width_divisor=18)
plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal=22, x_lim_used_bias=11, x_lim_used_augment=40, bin_sets=bin_sets, title="weight changes sudden at transitions hists", show_deltas=False, show_weight_augmentations=show_weight_augmentations)
plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal=23, x_lim_used_bias=11, x_lim_used_augment=40, bin_sets=bin_sets, title="weight changes sudden at transitions hists", show_deltas=False, show_weight_augmentations=show_weight_augmentations)
plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal=22, x_lim_used_bias=11, x_lim_used_augment=40, bin_sets=bin_sets, title="weight changes sudden at transitions delta hists", show_deltas=True, show_weight_augmentations=show_weight_augmentations)
plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal=23, x_lim_used_bias=11, x_lim_used_augment=40, bin_sets=bin_sets, title="weight changes sudden at transitions delta hists", show_deltas=True, show_weight_augmentations=show_weight_augmentations)
average, counter, all_datapoints = plot_traces_and_collate_data(data=all_sudden_changes, augmented_data=aug_all_sudden_changes, title="sudden_weight_change")
......@@ -437,7 +442,7 @@ if True:
plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal=120, x_lim_used_bias=40, x_lim_used_augment=180, bin_sets=bin_sets, title="weight changes sudden delta hists", show_deltas=True, show_weight_augmentations=show_weight_augmentations)
dur_lims = [(52, 25), (46, 23), (38, 19), (35, 17), (30, 15), (25, 12), (20, 10), (18, 9)]
dur_lims = [(56, 28), (48, 25), (35, 20), (35, 17), (30, 15), (25, 12), (20, 10), (18, 9)]
for min_dur_counter, min_dur in enumerate([2, 3, 4, 5, 7, 9, 11, 15]):
average = np.zeros((n_weights + 1, n_types, 2))
counter = np.zeros((n_weights + 1, n_types))
......@@ -529,6 +534,9 @@ if True:
plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal=x_lim_used_normal, x_lim_used_bias=x_lim_used_bias, bin_sets=bin_sets,
title="weight changes min dur {} hists".format(min_dur), show_deltas=False, show_first_and_last=True)
if min_dur == 5:
x_lim_used_normal, x_lim_used_bias = 25, 8
x_lim_used_normal, x_lim_used_bias = int(x_lim_used_normal * 2.5), int(x_lim_used_bias * 2.5)
plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal=x_lim_used_normal, x_lim_used_bias=x_lim_used_bias, bin_sets=bin_sets,
title="weight changes min dur {} delta hists".format(min_dur), show_deltas=True, show_first_and_last=True)
......
......@@ -78,7 +78,7 @@ if __name__ == "__main__":
plt.show()
# histogram of regression diffs
plt.hist(regression_diffs, color='grey', bins=20)
plt.hist([item for row in regression_diffs for item in row], color='grey', bins=20)
plt.ylabel("# regressions", size=fontsize)
plt.xlabel("Reward rate diff", size=fontsize)
......@@ -86,3 +86,22 @@ if __name__ == "__main__":
plt.tight_layout()
plt.savefig("./summary_figures/Regression diffs")
plt.show()
# figure out whether mice with long training use their states for longer
all_state_percentages = pickle.load(open("multi_chain_saves/all_state_percentages.p", 'rb'))
total_lengths = []
mean_state_appearances = []
for asp in all_state_percentages:
total_lengths.append(asp.shape[1])
mean_state_appearances.append(np.mean(np.sum(asp > 0.05, 1)))
plt.scatter(total_lengths, mean_state_appearances)
plt.xlabel("# of sessions")
plt.ylabel("Mean state sessions")
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.tight_layout()
plt.savefig("./summary_figures/training_time_versus_mean_state_sessions.png", dpi=300)
plt.show()
\ No newline at end of file
......@@ -59,18 +59,22 @@ for i, pmf_lists in enumerate(type_1_to_2_save):
if type_2_bias == 0:
boring_type_2 += 1
continue
if expressed_biases == [0]:
no_previous_bias += 1
continue
if -1 in expressed_biases and 1 in expressed_biases:
all_previous_biases += 1
continue
# continue
# if expressed_biases == [0]:
# no_previous_bias += 1
# continue
# if -1 in expressed_biases and 1 in expressed_biases:
# all_previous_biases += 1
# continue
if type_2_bias in expressed_biases:
previously_expressed += 1
if type_2_bias == 0:
print('prev exp')
else:
not_expressed += 1
if type_2_bias == 0:
print('not exp')
print(boring_type_2, no_previous_bias, all_previous_biases)
print(previously_expressed, not_expressed)
......@@ -80,7 +84,7 @@ print(binomtest(previously_expressed, previously_expressed + not_expressed, 0.5)
from scipy.stats import linregress
quantiles = np.linspace(0, 1, 7)[1:]
quantiles = np.linspace(0, 1, 2)[1:]
quant_sessions = np.quantile(num_trials, quantiles)
prev_session_bound = num_trials.min() - 1
......@@ -94,4 +98,8 @@ for quant_session in quant_sessions:
prev_session_bound = quant_session
plt.scatter(num_trials, num_states)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.tight_layout()
plt.savefig("./summary_figures/state_num_regression.png", dpi=300)
plt.show()
\ No newline at end of file
......@@ -74,7 +74,7 @@ for filename in os.listdir("./multi_chain_saves/"):
subjects.append(subject)
already_fit = list(loading_info.keys())
subjects = ["fip_{}".format(i) for i in list(range(13, 17)) + list(range(26, 43))]
subjects = ["KS055"]
# remaining_subs = [s for s in subjects if s not in amiss and s not in already_fit]
# print(remaining_subs)
......
This diff is collapsed.
import json
import numpy as np
from dyn_glm_chain_analysis import MCMC_result_list
from dyn_glm_chain_analysis import MCMC_result_list, MCMC_result
import pickle
import matplotlib.pyplot as plt
import os
......@@ -8,16 +8,16 @@ import os
fit_variance = 0.03
# subjects = list(loading_info.keys())
# error: KS043, KS045, 'NYU-12', ibl_witten_15, NYU-21, CSHL052, KS003
# done: NYU-46, NYU-39, ibl_witten_19, NYU-48
subjects = ['NR_0027', 'NR_0019', 'PL024', 'UCLA037', 'UCLA036', 'UCLA015', 'PL017', 'NR_0020', 'NR_0021', 'UCLA034']
# something wrong with 'GLM_Sim_11', GLM_Sim_13
subjects = ['GLM_Sim_07', 'GLM_Sim_10', 'GLM_Sim_16', 'GLM_Sim_15', 'GLM_Sim_12', 'GLM_Sim_14']
def create_mode_indices(test, subject, fit_type):
dim = 3
try:
xy, z = pickle.load(open("multi_chain_saves/xyz_{}_{}_var_{}.p".format(subject, fit_type, fit_variance), 'rb'))
except Exception:
except Exception as e:
print(e)
# this is cluster work
return
print('Doing PCA')
......@@ -129,7 +129,7 @@ def threshold_search(xy, z, test, mode_prefix, subject, fit_type):
loading_info[subject] = {}
loading_info[subject]['mode prob level'] = prob_level
# pickle.dump(mode_indices, open("multi_chain_saves/{}mode_indices_{}_{}_var_{}.p".format(mode_prefix, subject, fit_type, fit_variance), 'wb'))
pickle.dump(mode_indices, open("multi_chain_saves/{}mode_indices_{}_{}_var_{}.p".format(mode_prefix, subject, fit_type, fit_variance), 'wb'))
# we do this on the cluster now
# consistencies = test.consistency_rsa(indices=mode_indices) # do this on the cluster from now on
# pickle.dump(consistencies, open("multi_chain_saves/{}mode_consistencies_{}_{}.p".format(mode_prefix, subject, fit_type), 'wb', protocol=4))
......@@ -154,8 +154,8 @@ def conditions_fulfilled(z, xy, conds):
def state_set_and_plot(test, mode_prefix, subject, fit_type):
mode_indices = pickle.load(open("multi_chain_saves/{}mode_indices_{}_{}.p".format("first_", subject, fit_type), 'rb'))
consistencies = pickle.load(open("multi_chain_saves/{}consistencies_{}_{}.p".format(mode_prefix, subject, fit_type), 'rb'))
mode_indices = pickle.load(open("multi_chain_saves/{}mode_indices_{}_{}_var_{}.p".format("first_", subject, fit_type, fit_variance), 'rb'))
consistencies = pickle.load(open("multi_chain_saves/{}consistencies_{}_{}_var_{}.p".format(mode_prefix, subject, fit_type, fit_variance), 'rb'))
session_bounds = list(np.cumsum([len(s) for s in test.results[0].models[-1].stateseqs]))
import scipy.cluster.hierarchy as hc
......@@ -236,15 +236,16 @@ if fit_type == 'bias':
elif fit_type == 'prebias':
loading_info = json.load(open("canonical_infos_fitvar_{}.json".format(fit_variance), 'r'))
subjects = ['KS014']
# subjects = ['CSH_ZAD_022']
for subject in subjects:
test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb'))
# print("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type))
test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}_var_{}.p".format(subject, fit_type, fit_variance), 'rb'))
# if os.path.isfile("multi_chain_saves/{}mode_indices_{}_{}.p".format('first_', subject, fit_type)):
# print("It has been done")
# continue
print('Computing sub result')
# create_mode_indices(test, subject, fit_type)
state_set_and_plot(test, '', subject, fit_type)
create_mode_indices(test, subject, fit_type)
# state_set_and_plot(test, '', subject, fit_type)
# print("second mode?")
# if input() in ['y', 'yes']:
# state_set_and_plot(test, 'second_', subject, fit_type)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment