Compare revisions
Target project: sbruijns/ihmm_behav_states

Commits on Source (17)
Showing with 1375 additions and 1999 deletions
......@@ -20,5 +20,7 @@ peter_fiugres/
consistency_data/
dynamic_GLM_figures/
dynamic_GLMiHMM_fits2/
glm sim mice/
dynamic_GLMiHMM_crossvals/
__pycache__/
fibre_data/
sofiya_data/
"""
Code for creating figures 6, 12, A16, and A17 (and many slight variations on them)
"""
import numpy as np
import matplotlib.pyplot as plt
import pickle
......@@ -6,11 +9,19 @@ from analysis_pmf import pmf_type, type2color
from mpl_toolkits import mplot3d
from matplotlib.patches import ConnectionPatch
contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
def weights_to_pmf(weights, with_bias=1):
psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1]
return 1 / (1 + np.exp(psi))
# if True, also show an additional metric inferred from the weights: the range of the PMF (the distance between its minimum and maximum)
show_weight_augmentations = False
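# a minimal sketch of the PMF and its range, with a made-up weight vector [w_right, w_left, bias]:
#   example_weights = np.array([2.5, -2.5, 0.3])  # hypothetical values
#   pmf = weights_to_pmf(example_weights)         # 11-point psychometric curve
#   pmf_range = pmf.max() - pmf.min()             # the "range" augmentation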
all_weight_trajectories = pickle.load(open("multi_chain_saves/all_weight_trajectories.p", 'rb'))
first_and_last_pmf = np.array(pickle.load(open("multi_chain_saves/first_and_last_pmf.p", 'rb')))
all_sudden_changes = pickle.load(open("multi_chain_saves/all_sudden_changes.p", 'rb'))
all_sudden_transition_changes = pickle.load(open("multi_chain_saves/all_sudden_transition_changes.p", 'rb'))
......@@ -53,15 +64,6 @@ def pmf_to_perf(pmf):
return np.mean(np.abs(performance_points[reduced_points] + pmf[reduced_points]))
contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
def weights_to_pmf(weights, with_bias=1):
psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1]
return 1 / (1 + np.exp(psi))
def pmf_type_rew(weights):
rew = pmf_to_perf(weights_to_pmf(weights))
if rew < 0.6:
......@@ -158,7 +160,7 @@ def plot_compact(average, counter, title, show_first_and_last=False, show_weight
def plot_compact_all(average_slow, counter_slow, average_sudden, counter_sudden, title, all_data_sudden=[], all_data_slow=[]):
"""Plot a whole bunch of changes"""
titles = ["Type 1", r'Type $1 \rightarrow 2$', "Type 2", r'Type $2 \rightarrow 3$', "Type 3"]
f, axs = plt.subplots(1, 5, figsize=(4 * 5, 8))
f, axs = plt.subplots(1, 5, width_ratios=[1.1, 1, 1, 1, 1.1], figsize=(4 * 5, 8))
average = np.zeros((average_slow.shape[0], average_slow.shape[1] + average_sudden.shape[1], 2))
counter = np.zeros((counter_slow.shape[0], counter_slow.shape[1] + counter_sudden.shape[1]))
all_data = create_nested_list([n_weights + 1, average_slow.shape[1] + average_sudden.shape[1], 2])
......@@ -178,24 +180,24 @@ def plot_compact_all(average_slow, counter_slow, average_sudden, counter_sudden,
if i == 0:
axs[i].set_ylabel("Weights", size=38)
if j < n_weights - 1:
axs[i].plot([0.1], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', color=local_weight_colours[j]) # also plot weights of very first state average
axs[i].plot([-0.1], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', color=local_weight_colours[j]) # also plot weights of very first state average
if j == n_weights - 1:
mask = first_and_last_pmf[:, 0, -1] < 0
axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=local_weight_colours[j]) # separate biases again
axs[i].plot([-0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=local_weight_colours[j]) # separate biases again
if j == n_weights:
mask = first_and_last_pmf[:, 0, -1] > 0
axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=local_weight_colours[j])
axs[i].plot([-0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=local_weight_colours[j])
else:
axs[i].yaxis.set_ticklabels([])
if i == 4:
if j < n_weights - 1:
axs[i].plot([0.9], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', color=local_weight_colours[j]) # also plot weights of very last state average
axs[i].plot([1.1], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', color=local_weight_colours[j]) # also plot weights of very last state average
if j == n_weights - 1:
mask = first_and_last_pmf[:, 1, -1] < 0
axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=local_weight_colours[j]) # separate biases again
axs[i].plot([1.1], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=local_weight_colours[j]) # separate biases again
if j == n_weights:
mask = first_and_last_pmf[:, 1, -1] > 0
axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=local_weight_colours[j])
axs[i].plot([1.1], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=local_weight_colours[j])
axs[i].annotate("n={}".format(int(counter[0, i])), (0.06, 0.025), xycoords='axes fraction', size=26)
if j == 0:
axs[i].set_title(titles[i], size=38)
......@@ -258,6 +260,8 @@ def plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal, x_
Might have to mess quite a bit with the y-axis"""
sudden_changes = len(all_datapoints[0]) == 2
x_steps = x_lim_used_bias - x_lim_used_bias % 5  # round the bias x-limit down to the nearest multiple of 5
delta_dists = {} # save the deltas, to see whether bias moves more in sudden changes
f, axs = plt.subplots(n_weights + 1 + show_weight_augmentations, (n_types - sudden_changes) * 2, figsize=(4 * (3 - sudden_changes), 9))
for i in range(n_types - sudden_changes):
for j in range(n_weights + 1 + show_weight_augmentations):
......@@ -273,6 +277,7 @@ def plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal, x_
means = average[j, i] / counter[j, i]
axs[j, i * 2].hist(all_datapoints[j][i][0], orientation='horizontal', bins=bins, color='grey', alpha=0.5)
if show_deltas:
delta_dists[(i, j)] = np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0])
axs[j, i * 2 + 1].hist(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]), orientation='horizontal', bins=bins, color='red', alpha=0.5)
else:
axs[j, i * 2 + 1].hist(all_datapoints[j][i][1], orientation='horizontal', bins=bins, color='grey', alpha=0.5)
......@@ -315,24 +320,24 @@ def plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal, x_
if show_first_and_last:
if j < n_weights - 1:
axs[j, i * 2].plot([x_lim_used_normal / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red') # also plot weights of very first state average
elif j < n_weights + 1:
elif j == n_weights - 1:
mask = first_and_last_pmf[:, 0, -1] < 0
axs[j, i * 2].plot([x_lim_used_bias / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') # separate biases again
else:
elif j == n_weights:
mask = first_and_last_pmf[:, 0, -1] > 0
axs[j, i * 2].plot([x_lim_used_augment / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red')
axs[j, i * 2].plot([x_lim_used_bias / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red')
# else:
# axs[j, i * 2].yaxis.set_ticklabels([])
# axs[j, i * 2 + 1].yaxis.set_ticklabels([])
if i == n_types - 1 and show_first_and_last:
if j < n_weights - 1:
axs[j, i * 2 + 1].plot([x_lim_used_normal / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red') # also plot weights of very last state average
elif j < n_weights + 1:
elif j == n_weights - 1:
mask = first_and_last_pmf[:, 1, -1] < 0
axs[j, i * 2 + 1].plot([x_lim_used_bias / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') # separate biases again
else:
elif j == n_weights:
mask = first_and_last_pmf[:, 1, -1] > 0
axs[j, i * 2 + 1].plot([x_lim_used_augment / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')
axs[j, i * 2 + 1].plot([x_lim_used_bias / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')
if j == 0:
if 'sudden' in title:
axs[j, i * 2].set_title(r'Type ${} \rightarrow {}$'.format(i + 1, i + 2), loc='right', size=16, position=(1.45, 1))
......@@ -375,6 +380,8 @@ def plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal, x_
plt.savefig("./summary_figures/weight_changes/" + title, dpi=300)
plt.close()
return delta_dists
if True:
......@@ -430,7 +437,7 @@ if True:
plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal=23, x_lim_used_bias=11, x_lim_used_augment=40, bin_sets=bin_sets, title="weight changes sudden at transitions hists", show_deltas=False, show_weight_augmentations=show_weight_augmentations)
plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal=23, x_lim_used_bias=11, x_lim_used_augment=40, bin_sets=bin_sets, title="weight changes sudden at transitions delta hists", show_deltas=True, show_weight_augmentations=show_weight_augmentations)
sudden_dists = plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal=23, x_lim_used_bias=11, x_lim_used_augment=40, bin_sets=bin_sets, title="weight changes sudden at transitions delta hists", show_deltas=True, show_weight_augmentations=show_weight_augmentations)
average, counter, all_datapoints = plot_traces_and_collate_data(data=all_sudden_changes, augmented_data=aug_all_sudden_changes, title="sudden_weight_change")
......@@ -535,13 +542,26 @@ if True:
title="weight changes min dur {} hists".format(min_dur), show_deltas=False, show_first_and_last=True)
if min_dur == 5:
x_lim_used_normal, x_lim_used_bias = 25, 8
x_lim_used_normal, x_lim_used_bias = int(x_lim_used_normal * 2.5), int(x_lim_used_bias * 2.5)
plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal=x_lim_used_normal, x_lim_used_bias=x_lim_used_bias, bin_sets=bin_sets,
slow_dists = plot_histogram_diffs(all_datapoints, average, counter, x_lim_used_normal=x_lim_used_normal, x_lim_used_bias=x_lim_used_bias, bin_sets=bin_sets,
title="weight changes min dur {} delta hists".format(min_dur), show_deltas=True, show_first_and_last=True)
quit()
if min_dur == 5:
import scipy
from itertools import product
sig_counter, insig_counter = 0, 0
for i, j, k in product([3, 4], [0, 1], [0, 1, 2]): # for weights 3 and 4, we compare over all combos of transitions
print()
print(j, i, k, i)
print(np.mean(np.abs(sudden_dists[(j, i)])), np.mean(np.abs(slow_dists[(k, i)])))
res = scipy.stats.mannwhitneyu(np.abs(sudden_dists[(j, i)]), np.abs(slow_dists[(k, i)]))
print(res)
if res.pvalue < 0.05:
sig_counter += 1
else:
insig_counter += 1
print(sig_counter, insig_counter)
quit()
# all pmf weights
......@@ -561,18 +581,6 @@ else:
colors_rew = [type2color[pmf_type_rew(x)] for x in apw]
if False:
for i, weights in enumerate(apw):
if pmf_type(weights_to_pmf(weights)) != pmf_type_rew(weights):
plt.plot(weights_to_pmf(weights))
plt.ylim(0, 1)
plt.title("Classic says {}, reward says {}, ({})".format(1 + pmf_type(weights_to_pmf(weights)), 1 + pmf_type_rew(weights), pmf_to_perf(weights_to_pmf(weights))))
print(weights_to_pmf(weights))
plt.tight_layout()
plt.savefig(folder + "divergent classification {}".format(i))
plt.close()
type1_rews = []
type2_rews = []
type3_rews = []
......@@ -607,67 +615,7 @@ else:
type1_rews, type2_rews, type3_rews = np.array(type1_rews), np.array(type2_rews), np.array(type3_rews)
bound1 = np.linspace(0.55, 0.65, 100)
bound2 = np.linspace(0.75, 0.85, 100)
opt_bound1, errors1 = 0, 100
for b in bound1:
if errors1 > (np.sum(type1_rews > b) + np.sum(type2_rews < b)):
opt_bound1 = b
errors1 = np.sum(type1_rews > b) + np.sum(type2_rews < b)
opt_bound2, errors2 = 0, 100
for b in bound2:
if errors2 > (np.sum(type2_rews > b) + np.sum(type3_rews < b)):
opt_bound2 = b
errors2 = np.sum(type2_rews > b) + np.sum(type3_rews < b)
print(opt_bound1, errors1, opt_bound2, errors2)
opt_bound2 = 0.780303
print("Optimal bound 2: {}, {}".format(np.sum(type2_rews > opt_bound2), np.sum(type3_rews < opt_bound2)))
man_bound = 0.7827
print("Manual bound 2: {}, {}".format(np.sum(type2_rews > man_bound), np.sum(type3_rews < man_bound)))
bins = np.linspace(0, 1, 30)
plt.hist([type1_rews, type2_rews, type3_rews], bins=bins, label=["Type 1", "Type 2", "Type 3"])
plt.axvline(0.6, c='k')
plt.axvline(0.7827, c='k')
plt.legend()
plt.savefig(folder + "hist 1" + " only first states" * only_first_states)
plt.show()
plt.hist([type1_rews, type2_rews, type3_rews], bins=bins, label=["Type 1", "Type 2", "Type 3"], stacked=True)
plt.axvline(0.6, c='k')
plt.axvline(0.7827, c='k')
plt.legend()
plt.savefig(folder + "hist 2" + " only first states" * only_first_states)
plt.show()
plt.subplot(1, 3, 1)
plt.title("Bins = 25")
plt.hist(all_rews, bins=25)
plt.axvline(0.6, c='k')
plt.axvline(0.7827, c='k')
plt.subplot(1, 3, 2)
plt.title("Bins = 40")
plt.hist(all_rews, bins=40)
plt.axvline(0.6, c='k')
plt.axvline(0.7827, c='k')
plt.subplot(1, 3, 3)
plt.title("Bins = 55")
plt.hist(all_rews, bins=55)
plt.axvline(0.6, c='k')
plt.axvline(0.7827, c='k')
plt.savefig(folder + "hists compare" + " only first states" * only_first_states)
plt.show()
# reward rate and boundaries, figure 12
fig = plt.figure(figsize=(13 * 3 / 5, 9 * 3 / 5))
plt.hist(all_rews, bins=40, color='grey')
plt.axvline(0.6, c='k')
......@@ -680,129 +628,3 @@ plt.gca().spines[['right', 'top']].set_visible(False)
plt.tight_layout()
plt.savefig(folder + "single hist" + " only first states" * only_first_states)
plt.show()
quit()
xy = np.vstack([apw[:, i] for i in range(4)])
z = gaussian_kde(xy)(xy)
plt.subplot(1, 3, 1)
plt.scatter(apw[:, 0], apw[:, 1], c=z)
plt.xlabel("Cont right")
plt.ylabel("Cont left")
plt.subplot(1, 3, 2)
plt.scatter(apw[:, 3], apw[:, 1], c=z)
plt.xlabel("Bias")
plt.ylabel("Cont left")
plt.subplot(1, 3, 3)
plt.scatter(apw[:, 3], apw[:, 0], c=z)
plt.xlabel("Bias")
plt.ylabel("Cont right")
plt.savefig(folder + "density scatter")
plt.show()
plt.subplot(1, 4, 1)
sc = plt.scatter(apw[:, 0], apw[:, 1], c=colors)
fig, ax1 = plt.gcf(), plt.gca()
plt.xlabel("Cont right")
plt.ylabel("Cont left")
annot1 = ax1.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
bbox=dict(boxstyle="round", fc="w"),
arrowprops=dict(arrowstyle="->"))
annot1.set_visible(False)
plt.subplot(1, 4, 2)
plt.scatter(apw[:, 3], apw[:, 1], c=colors)
ax2 = plt.gca()
plt.xlabel("Bias")
plt.ylabel("Cont left")
annot2 = ax2.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
bbox=dict(boxstyle="round", fc="w"),
arrowprops=dict(arrowstyle="->"))
annot2.set_visible(False)
plt.subplot(1, 4, 3)
plt.scatter(apw[:, 3], apw[:, 0], c=colors)
ax3 = plt.gca()
plt.xlabel("Bias")
plt.ylabel("Cont right")
annot3 = ax3.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
bbox=dict(boxstyle="round", fc="w"),
arrowprops=dict(arrowstyle="->"))
annot3.set_visible(False)
def update_annot(ind):
pos = sc.get_offsets()[ind["ind"][0]]
annot1.xy = pos
text = "{}".format(np.round(apw[ind["ind"][0]], 2))
annot1.set_text(text)
annot2.xy = apw[ind["ind"][0]][[3, 1]]
text = "{}".format(np.round(apw[ind["ind"][0]], 2))
annot2.set_text(text)
annot3.xy = apw[ind["ind"][0]][[3, 0]]
text = "{}".format(np.round(apw[ind["ind"][0]], 2))
annot3.set_text(text)
plt.subplot(1, 4, 4)
plt.cla()
plt.plot(weights_to_pmf(apw[ind["ind"][0]]))
plt.ylim(0, 1)
plt.ylabel("P(rightwards)")
plt.xlabel("Contrasts")
def hover(event):
vis = annot1.get_visible()
if event.inaxes == ax1:
cont, ind = sc.contains(event)
if cont:
update_annot(ind)
annot1.set_visible(True)
annot2.set_visible(True)
annot3.set_visible(True)
fig.canvas.draw_idle()
else:
if vis:
annot1.set_visible(False)
annot2.set_visible(False)
annot3.set_visible(False)
fig.canvas.draw_idle()
fig.canvas.mpl_connect("motion_notify_event", hover)
plt.savefig(folder + "type scatter")
plt.show()
fig = plt.figure(figsize=(16, 9))
ax = plt.axes(projection='3d')
ax.scatter3D(apw[:, 0], apw[:, 1], apw[:, 3], c=colors)
ax.view_init(27.5, -137)
plt.savefig(folder + "3d types")
plt.show()
fig = plt.figure(figsize=(16, 9))
ax = plt.axes(projection='3d')
ax.scatter3D(apw[:, 0], apw[:, 1], apw[:, 3], c=colors_rew)
ax.view_init(27.5, -137)
plt.savefig(folder + "3d types new")
plt.show()
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.scatter3D(apw[:, 0], apw[:, 1], apw[:, 3], c=z)
plt.savefig(folder + "3d density")
plt.show()
"""Cool function for offsetting points in a scatter."""
"""
Code for figure 8 and variations, as well as some side info.
Also: Cool function for offsetting points in a scatter.
"""
import pickle
import matplotlib.pyplot as plt
import numpy as np
......@@ -16,7 +19,9 @@ if __name__ == "__main__":
assert (regressions[:, 0] == np.sum(regressions[:, 2:], 1)).all() # total # of regressions must be sum of # of regressions per type
print(pearsonr(regressions[:, 0], regressions[:, 1]))
# (0.7025694973739075, 5.979216179591e-18)
print("Percentage of mice with regressions across population: {}".format(100 * np.sum(regressions[:, 0] != 0) / regressions.shape[0]))
# (0.7776378443719886, 2.4699096615011557e-25)
# Percentage of mice with regressions across population: 86.5546218487395
offset = 0.25
plt.figure(figsize=(16 * 0.9, 9 * 0.9))
......
"""
Unused figures (kinda weird)
"""
import numpy as np
import matplotlib.pyplot as plt
import pickle
......
"""
Unused figures on the connection between regressions and passed time to last session.
TODO: check whether dates are correct
"""
import numpy as np
import matplotlib.pyplot as plt
import pickle
regressed_or_not_list = pickle.load(open("multi_chain_saves/regressed_or_not_list.p", 'rb'))
regression_magnitude_list = pickle.load(open("multi_chain_saves/regression_magnitude_list.p", 'rb'))
dates_list = pickle.load(open("multi_chain_saves/dates_list.p", 'rb'))
i, j = 0, 0
dates_fixed = []
regressed_or_not_list_fixed = []
regression_magnitude_fixed = []
while True:
if i >= len(dates_list) or j >= len(regressed_or_not_list):
break
if len(dates_list[i]) - 1 == len(regressed_or_not_list[j]):
dates_fixed.append(dates_list[i])
regression_magnitude_fixed.append(regression_magnitude_list[i])
regressed_or_not_list_fixed.append(regressed_or_not_list[i])
i += 1
j += 1
else:
i += 1
j += 1
print(False, i - 1)
regressed_or_not_list = regressed_or_not_list_fixed
regression_magnitude_list = regression_magnitude_fixed
dates_diff = []
for sub_dates in dates_fixed:
temp = []
for i in range(len(sub_dates) - 1):
temp.append(sub_dates[i+1] - sub_dates[i])
dates_diff.append(temp)
regressed_or_not_list = [item for sublist in regressed_or_not_list for item in sublist]
regression_magnitude_list = [item for sublist in regression_magnitude_list for item in sublist]
dates_diff = [item for sublist in dates_diff for item in sublist]
plt.scatter([x.total_seconds() for x in dates_diff], regressed_or_not_list)
plt.show()
plt.scatter([x.total_seconds() for x in dates_diff], regression_magnitude_list)
plt.show()
from scipy.stats import pearsonr, mannwhitneyu
print(pearsonr([x.total_seconds() for x in dates_diff], np.array(regression_magnitude_list)))
a = np.array(regressed_or_not_list)
b = np.array([x.total_seconds() for x in dates_diff])
print(mannwhitneyu(b[a == 0], b[a == 1]))
captured_states = pickle.load(open("captured_states.p", 'rb'))
# captured states is: captured_states.append((len([item for sublist in state_sets for item in sublist if len(sublist) > 40]), test.results[0].n_datapoints, len([s for s in state_sets if len(s) > 40])))
......@@ -10,7 +63,8 @@ num_covered_trials = np.array([x for x, _, _, _ in captured_states])
num_states = np.array([x for _, _, x, _ in captured_states])
num_sessions = np.array([x for _, _, _, x in captured_states])
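# note: the append quoted in the comment above shows an older 3-tuple; the entries unpacked
# here have 4 fields, presumably (num covered trials, total num trials, num large states, num sessions)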
print(num_covered_trials / num_trials)
# print(num_covered_trials / num_trials)
print("Mean fraction of accounted trials: {}".format(np.mean(num_covered_trials / num_trials)))
print("Minimum fraction of accounted trials: {}".format(np.min(num_covered_trials / num_trials)))
print(np.unique(num_states, return_counts=True))
......@@ -24,7 +78,7 @@ counter = 0
all_prev_biases = []
all_biases = []
neutral_counter, symm_counter = 0, 0
boring_type_2, no_previous_bias, all_previous_biases = 0, 0, 0
boring_type_2 = 0
for i, pmf_lists in enumerate(type_1_to_2_save):
if pmf_lists == [[], []]:
continue
......@@ -49,34 +103,18 @@ for i, pmf_lists in enumerate(type_1_to_2_save):
all_prev_biases.append(expressed_biases)
all_biases.append(expressed_biases + [type_2_bias])
# print()
# if type_2_bias == 0:
# print("neutral bias")
# neutral_counter += 1
# if np.abs(pmf_lists[1][0][0] + pmf_lists[1][0][-1] - 1) <= 0.1:
# print("symmetric")
# symm_counter += 1
if type_2_bias == 0:
boring_type_2 += 1
# continue
# if expressed_biases == [0]:
# no_previous_bias += 1
# continue
# if -1 in expressed_biases and 1 in expressed_biases:
# all_previous_biases += 1
# continue
if type_2_bias in expressed_biases:
previously_expressed += 1
if type_2_bias == 0:
print('prev exp')
else:
not_expressed += 1
if type_2_bias == 0:
print('not exp')
print(boring_type_2, no_previous_bias, all_previous_biases)
print(boring_type_2)
print(previously_expressed, not_expressed)
from scipy.stats import binomtest
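# a sketch of the presumable test in the elided lines below: is the type-2 bias one of the
# previously expressed biases more often than chance?
#   res = binomtest(previously_expressed, previously_expressed + not_expressed, p=0.5)
#   print(res.pvalue)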
......@@ -102,4 +140,18 @@ plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.tight_layout()
plt.savefig("./summary_figures/state_num_regression.png", dpi=300)
plt.show()
session_time_at_sudden_changes = pickle.load(open("multi_chain_saves/session_time_at_sudden_changes.p", 'rb'))
f, axs = plt.subplots(2, 1, figsize=(16 * 0.75, 9 * 0.75), sharex=True, sharey=True)
axs[0].hist(session_time_at_sudden_changes[0])
axs[1].hist(session_time_at_sudden_changes[1])
axs[0].set_ylabel("First type 2 intro")
axs[1].set_ylabel("First type 3 intro")
axs[1].set_xlabel("Interpolated session time")
plt.tight_layout()
plt.savefig("./summary_figures/new type intro points.png", dpi=300)
plt.show()
\ No newline at end of file
......@@ -102,7 +102,6 @@ for subject in subjects:
eids, sess_info = one.search(subject=subject, date_range=['2015-01-01', '2025-01-01'], details=True)
start_times = [sess['date'] for sess in sess_info]
protocols = [sess['task_protocol'] for sess in sess_info]
nums = [sess['number'] for sess in sess_info]
print("original # of eids {}".format(len(eids)))
......@@ -113,17 +112,20 @@ for subject in subjects:
eids = [x for _, x in sorted(zip(start_times, eids))]
dates = [x for x, _ in sorted(zip(start_times, eids))]
nums = [x for _, x in sorted(zip(start_times, nums))]
start_times = sorted(start_times)
prev_date = None
prev_num = -1
fixed_dates = []
fixed_eids = []
fixed_start_times = []
additional_eids = []
for d, e, n in zip(dates, eids, nums):
for d, e, n, st in zip(dates, eids, nums, start_times):
if d != prev_date:
fixed_dates.append(d)
fixed_eids.append(e)
additional_eids.append([])
fixed_start_times.append(st)
else:
assert n > prev_num
if n == 1:
......@@ -133,8 +135,6 @@ for subject in subjects:
prev_date = d
prev_num = n
protocols = [x for _, x in sorted(zip(start_times, protocols))]
# in case you want it
if old_style:
fixed_eids = eids
......@@ -148,12 +148,13 @@ for subject in subjects:
bias_start = 0
info_dict = {'subject': subject, 'dates': fixed_dates, 'eids': fixed_eids}
info_dict = {'subject': subject, 'dates': [st for st in sorted(start_times)], 'eids': eids, 'date_and_session_num': {}}
info_dict = {'subject': subject, 'dates': [st for st in sorted(fixed_start_times)], 'eids': eids, 'date_and_session_num': {}}
contrast_set = {0, 1, 9, 10}
rel_count = -1
for i, (eid, extra_eids, start_time) in enumerate(zip(fixed_eids, additional_eids, sorted(start_times))):
assert len(fixed_eids) == len(additional_eids) == len(fixed_start_times)
for i, (eid, extra_eids, start_time) in enumerate(zip(fixed_eids, additional_eids, fixed_start_times)):
try:
trials = one.load_object(eid, 'trials')
......
"""
Script for downloading the data of the mice used in the paper "Dissecting the Complexities of Learning With Infinite Hidden Markov Models".
Run this using the IBL environment: https://github.com/int-brain-lab/iblenv
"""
from one.api import ONE
import matplotlib.pyplot as plt
import numpy as np
......@@ -5,153 +9,48 @@ import pandas as pd
import seaborn as sns
import pickle
import json
import os
import re
one = ONE()
one = ONE(base_url='https://openalyx.internationalbrainlab.org', password='*****')
regexp = re.compile(r'Subjects/\w*/((\w|-)+)/_ibl')
datasets = one.alyx.rest('datasets', 'list', tag='2023_Q4_Bruijns_et_al')
contrast_to_num = {-1.: 0, -0.5: 1, -0.25: 2, -0.125: 3, -0.0625: 4, 0: 5, 0.0625: 6, 0.125: 7, 0.25: 8, 0.5: 9, 1.: 10}
dataset_types = ['choice', 'contrastLeft', 'contrastRight',
'feedbackType', 'probabilityLeft', 'response_times',
'goCue_times']
def get_df(trials):
if np.all(None == trials['choice']) or np.all(None == trials['contrastLeft']) or np.all(None == trials['contrastRight']) or np.all(None == trials['feedbackType']) or np.all(None == trials['probabilityLeft']): # or np.all(None == data_dict['response_times']):
return None  # a required field is entirely missing (single return value, matching the success path below)
d = {'response': trials['choice'], 'contrastL': trials['contrastLeft'], 'contrastR': trials['contrastRight'], 'feedback': trials['feedbackType']}
df = pd.DataFrame(data=d, index=range(len(trials['choice']))).fillna(0)
df['feedback'] = df['feedback'].replace(-1, 0)
df['signed_contrast'] = df['contrastR'] - df['contrastL']
df['signed_contrast'] = df['signed_contrast'].map(contrast_to_num)
df['response'] += 1 # this is coded most unintuitively: 0 is rightwards and 1 is leftwards (which is why I invert this variable in other programs)
df['block'] = trials['probabilityLeft']
# df['rt'] = data_dict['response_times'] - data_dict['goCue_times'] # RTODO
return df
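# minimal usage sketch (the eid here is hypothetical):
#   trials = one.load_object(eid, 'trials')
#   df = get_df(trials)
#   if df is not None:
#       ...  # proceed with the cleaned per-session dataframe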
misses = []
to_introduce = [2, 3, 4, 5]
amiss = ['UCLA034', 'UCLA036', 'UCLA037', 'PL015', 'PL016', 'PL017', 'PL024', 'NR_0017', 'NR_0019', 'NR_0020', 'NR_0021', 'NR_0027']
fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
if fit_type == 'bias':
loading_info = json.load(open("canonical_infos_bias.json", 'r'))
elif fit_type == 'prebias':
loading_info = json.load(open("canonical_infos.json", 'r'))
bwm = ['NYU-11', 'NYU-12', 'NYU-21', 'NYU-27', 'NYU-30', 'NYU-37',
'NYU-39', 'NYU-40', 'NYU-45', 'NYU-46', 'NYU-47', 'NYU-48',
'CSHL045', 'CSHL047', 'CSHL049', 'CSHL051', 'CSHL052', 'CSHL053',
'CSHL054', 'CSHL055', 'CSHL058', 'CSHL059', 'CSHL060', 'UCLA005',
'UCLA006', 'UCLA011', 'UCLA012', 'UCLA014', 'UCLA015', 'UCLA017',
'UCLA033', 'UCLA034', 'UCLA035', 'UCLA036', 'UCLA037', 'KS014',
'KS016', 'KS022', 'KS023', 'KS042', 'KS043', 'KS044', 'KS045',
'KS046', 'KS051', 'KS052', 'KS055', 'KS084', 'KS086', 'KS091',
'KS094', 'KS096', 'DY_008', 'DY_009', 'DY_010', 'DY_011', 'DY_013',
'DY_014', 'DY_016', 'DY_018', 'DY_020', 'PL015', 'PL016', 'PL017',
'PL024', 'SWC_042', 'SWC_043', 'SWC_060', 'SWC_061', 'SWC_066',
'ZFM-01576', 'ZFM-01577', 'ZFM-01592', 'ZFM-01935', 'ZFM-01936',
'ZFM-01937', 'ZFM-02368', 'ZFM-02369', 'ZFM-02370', 'ZFM-02372',
'ZFM-02373', 'ZM_1897', 'ZM_1898', 'ZM_2240', 'ZM_2241', 'ZM_2245',
'ZM_3003', 'SWC_038', 'SWC_039', 'SWC_052', 'SWC_053', 'SWC_054',
'SWC_058', 'SWC_065', 'NR_0017', 'NR_0019', 'NR_0020', 'NR_0021',
'NR_0027', 'ibl_witten_13', 'ibl_witten_17', 'ibl_witten_18',
'ibl_witten_19', 'ibl_witten_20', 'ibl_witten_25', 'ibl_witten_26',
'ibl_witten_27', 'ibl_witten_29', 'CSH_ZAD_001', 'CSH_ZAD_011',
'CSH_ZAD_019', 'CSH_ZAD_022', 'CSH_ZAD_024', 'CSH_ZAD_025',
'CSH_ZAD_026', 'CSH_ZAD_029']
regexp = re.compile(r'canonical_result_((\w|-)+)_prebias.p')
subjects = []
for filename in os.listdir("./multi_chain_saves/"):
if not (filename.startswith('canonical_result_') and filename.endswith('.p')):
continue
result = regexp.search(filename)
if result is None:
continue
subject = result.group(1)
subjects.append(subject)
already_fit = list(loading_info.keys())
subjects = ["KS055"]
# remaining_subs = [s for s in subjects if s not in amiss and s not in already_fit]
# print(remaining_subs)
# extract subject names
subjects = [regexp.search(ds['file_records'][0]['relative_path']).group(1) for ds in datasets]
# reduce to list of unique names
subjects = list(set(subjects))
data_folder = 'session_data'
contrast_to_num = {-1.: 0, -0.5: 1, -0.25: 2, -0.125: 3, -0.0625: 4, 0: 5, 0.0625: 6, 0.125: 7, 0.25: 8, 0.5: 9, 1.: 10}
old_style = False
if old_style:
print("Warning, data can have splits")
data_folder = 'session_data_old'
bias_eids = []
print("#########################################")
print("Waring, rt's removed, find with # RTODO")
print("#########################################")
short_subjs = []
names = []
pre_bias = []
entire_training = []
training_status_reached = []
actually_existing = []
for subject in subjects:
print('_____________________')
print(subject)
# if subject in already_fit or subject in amiss:
# continue
try:
trials = one.load_aggregate('subjects', subject, '_ibl_subjectTrials.table')
trials = one.load_aggregate('subjects', subject, '_ibl_subjectTrials.table')
# Load training status and join to trials table
training = one.load_aggregate('subjects', subject, '_ibl_subjectTraining.table')
trials = (trials
.set_index('session')
.join(training.set_index('session'))
.sort_values(by='session_start_time', kind='stable'))
actually_existing.append(subject)
except:
print("Not working {}".format(subject))
continue
training = one.load_aggregate('subjects', subject, '_ibl_subjectTraining.table')
trials = (trials
.set_index('session')
.join(training.set_index('session'))
.sort_values(by='session_start_time', kind='stable'))
# use np.unique to find unique start_times and eids, each session only has one. Then sort those.
start_times, indices = np.unique(trials.session_start_time, return_index=True)
start_times = [trials.session_start_time[index] for index in sorted(indices)]
task_protocol, indices = np.unique(trials.task_protocol, return_index=True)
task_protocol = [trials.task_protocol[index] for index in sorted(indices)]
nums, indices = np.unique(trials.session_number, return_index=True)
nums = [trials.session_number[index] for index in sorted(indices)]
eids, indices = np.unique(trials.index, return_index=True)
eids = [trials.index[index] for index in sorted(indices)]
print("original # of eids {}".format(len(eids)))
test = [(y, x) for y, x in sorted(zip(start_times, eids))]
pickle.dump(test, open("./{}/{}_session_names.p".format(data_folder, subject), "wb"))
performance = np.zeros(len(eids))
easy_per = np.zeros(len(eids))
hard_per = np.zeros(len(eids))
bias_start = 0
ephys_start = 0
info_dict = {'subject': subject, 'dates': [st.to_pydatetime() for st in start_times], 'eids': eids, 'date_and_session_num': {}}
contrast_set = {0, 1, 9, 10}
rel_count = -1
contrast_set = {0, 1, 9, 10} # starting contrasts
to_introduce = [2, 3, 4, 5] # these contrasts need to be introduced, keep track when that happens
for i, start_time in enumerate(start_times):
rel_count += 1
assert rel_count == i
df = trials[trials.session_start_time == start_time]
df.loc[:, 'contrastRight'] = df.loc[:, 'contrastRight'].fillna(0)
df.loc[:, 'contrastLeft'] = df.loc[:, 'contrastLeft'].fillna(0)
......@@ -163,66 +62,54 @@ for subject in subjects:
if any([df[x].isnull().any() for x in ['signed_contrast', 'choice', 'feedbackType', 'probabilityLeft']]):
quit()
assert len(np.unique(df['session_start_time'])) == 1
# check whether new contrasts got introduced
current_contrasts = set(df['signed_contrast'])
diff = current_contrasts.difference(contrast_set)
for c in to_introduce:
if c in diff:
info_dict[c] = rel_count
info_dict[c] = i
contrast_set.update(diff)
# document performance for plotting
performance[i] = np.mean(df['feedbackType'])
easy_per[i] = np.mean(df['feedbackType'][np.logical_or(df['signed_contrast'] == 0, df['signed_contrast'] == 10)])
hard_per[i] = np.mean(df['feedbackType'][df['signed_contrast'] == 5])
if bias_start == 0 and df.task_protocol[0].startswith('_iblrig_tasks_biasedChoiceWorld'):
bias_start = i
print("bias start {}".format(rel_count))
info_dict['bias_start'] = rel_count
training_status_reached.append(set(df.training_status))
if bias_start < 33:
short_subjs.append(subject)
if 'bias_start' not in info_dict and df.task_protocol[0].startswith('_iblrig_tasks_biasedChoiceWorld'):
info_dict['bias_start'] = i
if ephys_start == 0 and df.task_protocol[0].startswith('_iblrig_tasks_ephysChoiceWorld'):
ephys_start = i
print("ephys start {}".format(rel_count))
info_dict['ephys_start'] = rel_count
if 'ephys_start' not in info_dict and df.task_protocol[0].startswith('_iblrig_tasks_ephysChoiceWorld'):
info_dict['ephys_start'] = i
pickle.dump(df, open("./{}/{}_df_{}.p".format(data_folder, subject, rel_count), "wb"))
info_dict['date_and_session_num'][rel_count] = start_time
info_dict['date_and_session_num'][start_time] = rel_count
pickle.dump(df, open("./{}/{}_df_{}.p".format(data_folder, subject, i), "wb"))
info_dict['date_and_session_num'][i] = start_time
info_dict['date_and_session_num'][start_time] = i
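# the dict now maps both ways (session index -> start_time and start_time -> index),
# so a session can be looked up by either key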
side_info = np.zeros((len(df), 2))
side_info[:, 0] = df['probabilityLeft']
side_info[:, 1] = df['feedbackType']
pickle.dump(side_info, open("./{}/{}_side_info_{}.p".format(data_folder, subject, rel_count), "wb"))
pickle.dump(side_info, open("./{}/{}_side_info_{}.p".format(data_folder, subject, i), "wb"))
fit_info = np.zeros((len(df), 3))
fit_info = np.zeros((len(df), 2))
fit_info[:, 0] = df['signed_contrast']
fit_info[:, 1] = df['choice']
print(len(df))
# fit_info[:, 2] = df['rt'] # RTODO
pickle.dump(fit_info, open("./{}/{}_fit_info_{}.p".format(data_folder, subject, rel_count), "wb"))
pickle.dump(fit_info, open("./{}/{}_fit_info_{}.p".format(data_folder, subject, i), "wb"))
if rel_count == -1:
continue
info_dict['n_sessions'] = i
pickle.dump(info_dict, open("./{}/{}_info_dict.p".format(data_folder, subject), "wb"))
# optional plotting of the evolution of performance across sessions
plt.figure(figsize=(11, 8))
print(performance)
plt.plot(performance, label='Overall')
plt.plot(easy_per, label='100% contrasts')
plt.plot(hard_per, label='0% contrasts')
plt.axvline(bias_start - 0.5)
skip_count = 0
for p in performance:
if p == 0.:
skip_count += 1
else:
break
plt.axvline(info_dict['bias_start'] - 0.5)
for c in to_introduce:
plt.axvline(info_dict[c] + skip_count, ymax=0.85, c='grey')
plt.annotate('Pre-bias', (bias_start / 2, 1.), size=20, ha='center')
plt.annotate('Bias', (bias_start + (i - bias_start) / 2, 1.), size=20, ha='center')
plt.axvline(info_dict[c], ymax=0.85, c='grey')
plt.annotate('Pre-bias', (info_dict['bias_start'] / 2, 1.), size=20, ha='center')
plt.annotate('Bias', (info_dict['bias_start'] + (i - info_dict['bias_start']) / 2, 1.), size=20, ha='center')
plt.title(subject, size=22)
plt.ylabel('Performance', size=22)
plt.xlabel('Session', size=22)
......@@ -230,21 +117,7 @@ for subject in subjects:
plt.xticks(size=16)
plt.ylim(bottom=0)
plt.xlim(left=0)
sns.despine()
plt.tight_layout()
if not old_style:
plt.savefig('./figures/behavior/all_of_trainig_{}'.format(subject))
plt.close()
# print(bias_eids)
pre_bias.append(info_dict['bias_start'])
entire_training.append(rel_count + 1)
info_dict['n_sessions'] = rel_count
pickle.dump(info_dict, open("./{}/{}_info_dict.p".format(data_folder, subject), "wb"))
print(misses)
print(short_subjs)
print(pre_bias)
print(entire_training)
plt.savefig('./figures/behavior/all_of_trainig_{}'.format(subject))
plt.show()
import os
from one.api import ONE
import numpy as np
import pandas as pd
one = ONE()
for subject in ['fip_15']: # os.listdir("./fibre_data/"):
eids, sess_info = one.search(subject=subject, date_range=['2015-01-01', '2025-01-01'], details=True)
start_times = [sess['date'] for sess in sess_info]
protocols = [sess['task_protocol'] for sess in sess_info]
nums = [sess['number'] for sess in sess_info]
# print("original # of eids {}".format(len(eids)))
test = [(y, x) for y, x in sorted(zip(start_times, eids))]
eids = [x for _, x in sorted(zip(start_times, eids))]
dates = [x for x, _ in sorted(zip(start_times, eids))]
nums = [x for _, x in sorted(zip(start_times, nums))]
start_times = sorted(start_times)
prev_date = None
prev_num = -1
fixed_dates = []
fixed_eids = []
fixed_start_times = []
additional_eids = []
for d, e, n, st in zip(dates, eids, nums, start_times):
if d != prev_date:
fixed_dates.append(d)
fixed_eids.append(e)
additional_eids.append([])
fixed_start_times.append(st)
else:
assert n > prev_num
if n == 1:
additional_eids.append([e])
elif n > 1:
additional_eids[-1].append(e)
prev_date = d
prev_num = n
direct_dates = sorted(os.listdir("./fibre_data/" + subject))
if len(fixed_dates) != len(direct_dates):
print(len(fixed_dates), len(direct_dates))
rel_count = 0
effective_dates = []
for i, (eid, extra_eids, start_time) in enumerate(zip(fixed_eids, additional_eids, sorted(fixed_start_times))):
try:
trials = one.load_object(eid, 'trials')
if 'choice' not in trials:
continue
except Exception as e:
print(e, 'skipped session')
continue
if trials['probabilityLeft'] is None: # originally also "or df is None", if this gets reintroduced, probably need to download all data again
if rel_count != -1:
print('lost session, problem')
misses.append((subject, i))
continue
rel_count += 1
effective_dates.append(start_time)
print(start_time, "rel_count {}".format(rel_count), len(trials.choice), np.unique(trials.choice, return_counts=1))
else:
for d1, d2 in zip(fixed_dates, direct_dates):
assert str(d1) == d2
print(subject, 'good')
continue
for date in sorted(os.listdir("./fibre_data/" + subject)):
print(subject, date)
quit()
import pickle
import json
fails = ['KS022', 'ibl_witten_29', 'CSHL_015', 'UCLA017', 'KS091', 'KS023', 'KS055', 'CSHL_014', 'CSHL049', 'PL024', 'NYU-11']
for subject in fails:
info_dict = pickle.load(open("./{}/{}_info_dict.p".format('session_data', subject), "rb"))
info_dict['dates'] = [d.isoformat() for d in info_dict['dates']]
del info_dict['date_and_session_num']
json.dump(info_dict, open("./{}/{}_info_dict.json".format('session_data', subject), 'w'))
# datetime.datetime.fromisoformat(info_dict['dates'][0].isoformat())
\ No newline at end of file
from one.api import ONE
import numpy as np
import pandas as pd
import json
import os
one = ONE()
try:
os.mkdir("./fibre_data")
except OSError as error:
print(error)
subjects = ["fip_{}".format(i) for i in list(range(13, 17)) + list(range(26, 43))]
for subject in subjects:
print('_____________________')
print(subject)
try:
os.mkdir("./fibre_data/" + subject)
except OSError as error:
print(error)
eids, sess_info = one.search(subject=subject, date_range=['2015-01-01', '2025-01-01'], details=True)
for eid, sess in zip(eids, sess_info):
try:
fp = one.load_object(eid, 'fpData')
photometry = one.load_object(eid, 'photometry')
trials = one.load_object(eid, 'trials')
trials['intervals_1'] = trials['intervals'][:, 0]
trials['intervals_2'] = trials['intervals'][:, 1]
del trials['intervals']
except Exception as e:
print(e)
continue
try:
os.mkdir("./fibre_data/" + subject + '/' + str(sess['date']))
except OSError as error:
print(error)
fp.raw.to_parquet("./fibre_data/" + subject + '/' + str(sess['date']) + '/' + 'fpData_{}.pqt'.format(subject))
photometry.signal.to_parquet("./fibre_data/" + subject + '/' + str(sess['date']) + '/' + 'photometry_{}.pqt'.format(subject))
pd.DataFrame(trials).to_parquet("./fibre_data/" + subject + '/' + str(sess['date']) + '/' + 'trials_{}.pqt'.format(subject))
\ No newline at end of file
......@@ -18,7 +18,6 @@ from itertools import product
import json
import sys
save_things = False
def eleven2nine(x):
"""Map from 11 possible contrasts to 9, for the non-training phases.
......@@ -35,7 +34,7 @@ def eleven2nine(x):
def eval_cross_val(models, data, unmasked_data, n_all_states):
"""Eval cross_val."""
"""Eval cross_validation performance on held-out datapoints of an instantiated model"""
lls = np.zeros(len(models))
cross_val_n = 0
for sess_time, (d, full_data) in enumerate(zip(data, unmasked_data)):
......@@ -60,39 +59,27 @@ num_to_contrast = {v: k for k, v in contrast_to_num.items()}
cont_mapping = np.vectorize(num_to_contrast.get)
data_folder = 'session_data'
old_style = False
if old_style:
print("Warning, data can have splits")
print("Sure you want to use old data?")
temp = input()
if temp:
data_folder = 'session_data_old'
else:
quit()
# available:
# not great: ibl_witten_18
subjects = ['ibl_witten_15', 'ibl_witten_17', 'ibl_witten_18', 'ibl_witten_19',
'SWC_021', 'ZM_1897', 'KS017', 'KS019', 'NYU-06', 'SWC_023', 'ibl_witten_13',
'KS022', 'KS023', 'CSHL_014', 'CSHL_015', 'CSHL_018', 'CSHL_020', 'CSH_ZAD_001', 'CSH_ZAD_011',
'CSH_ZAD_017', 'CSH_ZAD_025', 'CSH_ZAD_026', 'CSHL049', 'CSHL051', 'CSHL061']
# test subjects:
subjects = ['fip_33']
cv_nums = [15]
cv_nums = [200 + int(sys.argv[1]) % 16]
subjects = [subjects[int(sys.argv[1]) // 16]]
subjects = ['KS014']
num_subjects = len(subjects)
subjects = [a for a in subjects for i in range(3)]
# seeds = [505, 506, 507, 505, 506, 508, 509, 506, 507, 508, 505, 506, 508, 509, 505, 506, 507, 508, 509, 506, 507, 508, 505, 506, 507, 508, 505, 506, 507, 508, 505, 506, 507, 506, 507, 508, 506, 507, 508, 506, 507, 508]
seeds = list(range(506, 509))
seeds = seeds * num_subjects
# lst = list(range(10))
# cv_nums = [a for a in lst for i in range(8)]
cv_nums = [0] * 3 * num_subjects
seeds = [seeds[int(sys.argv[1])]]
cv_nums = [cv_nums[int(sys.argv[1])]]
subjects = [subjects[int(sys.argv[1])]]
print(cv_nums)
print(subjects)
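# presumably launched as an array of cluster jobs: each task index sys.argv[1] picks one
# (subject, seed, cv_num) combination, e.g. `python <this script> 2` (invocation hypothetical)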
for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
# if loop_count_i > 8:
# if loop_count_i <= 8 or loop_count_i > 16:
# if loop_count_i <= 16:
# continue
params = {}
for loop_count_i, (s, cv_num, seed) in enumerate(zip(subjects, cv_nums, seeds)):
params = {} # collect all settings in a dictionary, to be saved to disk later
params['subject'] = s
params['cross_val_num'] = cv_num
params['fit_variance'] = 0.03
......@@ -101,8 +88,7 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
params['regressors'] = [all_regressors[i] for i in [0, 1, 3, 6]]
# default (non-iteration) settings:
params['fit_type'] = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][2]
# params['fit_variance'] = [0.0005, 0.002, 0.008, 0.02, 0.06, 0.1, 0.3, 0.6, 1., 2.4, 10, 16, 30, 'uniform'][6]
params['fit_type'] = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
if 'prevA' in params['regressors'] or 'weighted_prevA' in params['regressors']:
params['exp_decay'], params['exp_length'] = 0.3, 5
params['exp_filter'] = np.exp(- params['exp_decay'] * np.arange(params['exp_length']))
......@@ -110,6 +96,7 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
print(params['exp_filter'])
params['dur'] = 'yes'
params['obs_dur'] = ['glm', 'cat'][0]
# more obscure params:
params['gamma'] = None # 0.005
params['alpha'] = None # 1
......@@ -121,40 +108,37 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
params['gamma_b_0'] = 1000
params['init_var'] = 8
params['init_mean'] = np.zeros(len(params['regressors']))
# normal:
r_support = np.cumsum(np.arange(5, 100, 5))
r_support = np.arange(5, 705, 4)
params['dur_params'] = dict(r_support=r_support,
r_probs=np.ones(len(r_support))/len(r_support), alpha_0=1, beta_0=1)
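# r_support enumerates candidate values for the r parameter of the negative-binomial
# state-duration distribution; r_probs places a uniform prior over them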
# params['dur_params'] = dict(r_support=np.array([1, 2, 3, 5, 7, 10, 15, 21, 28, 36, 45, 55, 150]),
# r_probs=np.ones(13)/13., alpha_0=1, beta_0=1)
# params['dur_params'] = dict(r_support=np.arange(1, 251),
# r_probs=np.ones(250)/250., alpha_0=1, beta_0=1)
params['alpha_a_0'] = 0.1
params['alpha_b_0'] = 10
# trying a smaller value here, should lower the appearance of ephemeral new states on session bounds
# hope this doesn't make real new states at session bound less likely, or hurt mixing...
params['init_state_concentration'] = 3
# cat params
# Parameter needed if one uses a Categorical distribution
params['conditioned_on'] = 'nothing'
params['cross_val'] = False
params['cross_val_fold'] = 10
params['CROSS_VAL_SEED'] = 4 # Do not change this, it's 4
params['seed'] = 100 + params['cross_val_num']
params['seed'] = seed
params['n_states'] = 15
params['n_samples'] = 20000 if params['obs_dur'] == 'glm' else 4000
params['n_samples'] = 60000 if params['obs_dur'] == 'glm' else 12000
if params['cross_val']:
params['n_samples'] = 4000
params['n_samples'] = 12000
if s.startswith("GLM_Sim"):
print("reduced sample size")
params['n_samples'] = 14000
params['n_samples'] = 48000
print(params['n_samples'])
# now actual fit:
# new start names: uniform_start_, bias_fraction_, small_gamma_, high_init_, non_semi_, non_semi_normal_init_, correct_sol_, correct_sol_semi_
# find a unique identifier to save this fit
while True:
folder = "./dynamic_GLMiHMM_crossvals/"
rand_id = np.random.randint(1000)
......@@ -167,8 +151,7 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
if not os.path.isfile(folder + id + '_0.p'):
break
# create placeholder dataset for rand_id purposes
if save_things:
pickle.dump(params, open(folder + id + '_0.p', 'wb'))
pickle.dump(params, open(folder + id + '_0.p', 'wb'))
if params['obs_dur'] == 'glm':
print(params['regressors'])
else:
......@@ -191,53 +174,53 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
models = []
if params['obs_dur'] == 'glm':
n_regressors = len(params['regressors'])
n_inputs = len(params['regressors'])
T = till_session - from_session + (params['fit_type'] != 'prebias')
obs_hypparams = {'n_regressors': n_regressors, 'T': T, 'jumplimit': params['jumplimit'], 'prior_mean': params['init_mean'],
'P_0': params['init_var'] * np.eye(n_regressors), 'Q': params['fit_variance'] * np.tile(np.eye(n_regressors), (T, 1, 1))}
obs_hypparams = {'n_inputs': n_inputs, 'T': T, 'jumplimit': params['jumplimit'], 'prior_mean': params['init_mean'],
'P_0': params['init_var'] * np.eye(n_inputs), 'Q': params['fit_variance'] * np.tile(np.eye(n_inputs), (T, 1, 1))}
obs_distns = [distributions.Dynamic_GLM(**obs_hypparams) for state in range(params['n_states'])]
else:
n_regressors = 9 if params['fit_type'] == 'bias' else 11
obs_hypparams = {'n_regressors': n_regressors * (1 + (params['conditioned_on'] != 'nothing')), 'n_outputs': 2, 'T': till_session - from_session + (params['fit_type'] != 'prebias'),
n_inputs = 9 if params['fit_type'] == 'bias' else 11
obs_hypparams = {'n_inputs': n_inputs * (1 + (params['conditioned_on'] != 'nothing')), 'n_outputs': 2, 'T': till_session - from_session + (params['fit_type'] != 'prebias'),
'jumplimit': params['jumplimit'], 'sigmasq_states': params['fit_variance']}
obs_distns = [Dynamic_Input_Categorical(**obs_hypparams) for state in range(params['n_states'])]
dur_distns = [distributions.NegativeBinomialIntegerR2Duration(**params['dur_params']) for state in range(params['n_states'])]
# Select correct model
if params['dur'] == 'yes':
if params['gamma'] is None:
posteriormodel = pyhsmm.models.WeakLimitHDPHSMM(
# https://math.stackexchange.com/questions/449234/vague-gamma-prior
alpha_a_0=params['alpha_a_0'], alpha_b_0=params['alpha_b_0'], # TODO: gamma vs alpha? gamma steers state number
alpha_a_0=params['alpha_a_0'], alpha_b_0=params['alpha_b_0'], # gamma steers state number
gamma_a_0=params['gamma_a_0'], gamma_b_0=params['gamma_b_0'],
init_state_concentration=params['init_state_concentration'],
obs_distns=obs_distns,
dur_distns=dur_distns,
var_prior=params['fit_variance']) # TODO: I don't think this does anything
var_prior=params['fit_variance'])
else:
posteriormodel = pyhsmm.models.WeakLimitHDPHSMM(
# https://math.stackexchange.com/questions/449234/vague-gamma-prior
alpha=params['alpha'], # TODO: gamma vs alpha? gamma steers state number
alpha=params['alpha'],
gamma=params['gamma'],
init_state_concentration=params['init_state_concentration'],
obs_distns=obs_distns,
dur_distns=dur_distns,
var_prior=params['fit_variance']) # TODO: I don't think this does anything
var_prior=params['fit_variance'])
else:
if params['gamma'] is None:
posteriormodel = pyhsmm.models.WeakLimitHDPHMM(
alpha_a_0=params['alpha_a_0'], alpha_b_0=params['alpha_b_0'], # TODO: gamma vs alpha? gamma steers state number
alpha_a_0=params['alpha_a_0'], alpha_b_0=params['alpha_b_0'],
gamma_a_0=params['gamma_a_0'], gamma_b_0=params['gamma_b_0'],
init_state_concentration=params['init_state_concentration'],
obs_distns=obs_distns,
var_prior=params['fit_variance']) # TODO: I don't think this does anything
var_prior=params['fit_variance'])
else:
posteriormodel = pyhsmm.models.WeakLimitHDPHMM(
alpha=params['alpha'], # TODO: gamma vs alpha? gamma steers state number
alpha=params['alpha'],
gamma=params['gamma'],
init_state_concentration=params['init_state_concentration'],
obs_distns=obs_distns,
var_prior=params['fit_variance']) # TODO: I don't think this does anything
var_prior=params['fit_variance'])
print(from_session, till_session + (params['fit_type'] != 'prebias'))
......@@ -261,7 +244,7 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
mask[0] = False
if params['fit_type'] == 'zoe_style':
mask[90:] = False
mega_data = np.empty((np.sum(mask), n_regressors + 1))
mega_data = np.empty((np.sum(mask), n_inputs + 1))
for i, reg in enumerate(params['regressors']):
# positive numbers are contrast on the right
......@@ -295,7 +278,7 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
# bias is now always active
mega_data[:, i] = 1
mega_data[:, -1] = data[mask, 1] / 2 # 0 is rightwards, and 1 is leftwards, because the original data is weird
mega_data[:, -1] = data[mask, 1] / 2
elif params['obs_dur'] == 'cat':
mask = data[:, 1] != 1
mask[0] = False
......@@ -312,10 +295,9 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
posteriormodel.add_data(mega_data)
# if not os.path.isfile('./{}/data_save_{}.p'.format(data_folder, params['subject'])):
if save_things:
if not os.path.isfile('./{}/data_save_{}.p'.format(data_folder, params['subject'])):
pickle.dump(data_save, open('./{}/data_save_{}.p'.format(data_folder, params['subject']), 'wb'))
# states_solution = pickle.load(open("states_{}_{}_condition_{}_{}.p".format('DY_013', 'all', 'nothing', '0_01'), 'rb')) # todo: remove!
time_save = time.time()
likes = np.zeros(params['n_samples'])
with warnings.catch_warnings(): # ignore the scipy warning
......@@ -336,14 +318,12 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
model_save.delete_dur_data()
models.append(model_save)
# save something in case of crash
# save unfinished results
if j % 400 == 0 and j > 0:
if params['n_samples'] <= 4000:
if save_things:
pickle.dump(models, open(folder + id + '.p', 'wb'))
pickle.dump(models, open(folder + id + '.p', 'wb'))
else:
if save_things:
pickle.dump(models, open(folder + id + '_{}.p'.format(j // 4001), 'wb'))
pickle.dump(models, open(folder + id + '_{}.p'.format(j // 4001), 'wb'))
if j % 4000 == 0:
models = []
print(time.time() - time_save)
......@@ -361,12 +341,9 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
params['ll'] = likes.tolist()
params['init_mean'] = params['init_mean'].tolist()
if params['cross_val']:
if save_things:
json.dump(params, open(folder + "infos_new/" + '{}_{}_cvll_{}_{}_{}_{}_{}.json'.format(params['subject'], params['cross_val_num'], str(np.round(lls_mean, 3)).replace('.', '_'),
json.dump(params, open(folder + "infos_new/" + '{}_{}_cvll_{}_{}_{}_{}_{}.json'.format(params['subject'], params['cross_val_num'], str(np.round(lls_mean, 3)).replace('.', '_'),
params['fit_type'], params['fit_variance'], params['seed'], rand_id), 'w'))
else:
if save_things:
json.dump(params, open(folder + "infos_new/" + '{}_{}_{}_{}_{}.json'.format(params['subject'], params['fit_type'],
json.dump(params, open(folder + "infos_new/" + '{}_{}_{}_{}_{}.json'.format(params['subject'], params['fit_type'],
params['fit_variance'], params['seed'], rand_id), 'w'))
if save_things:
pickle.dump(models, open(folder + id + '_{}.p'.format(j // 4001), 'wb'))
pickle.dump(models, open(folder + id + '_{}.p'.format(j // 4001), 'wb'))
"""
Generate data from a simulated mouse-GLM.
This mouse uses 2 states throughout
They start a certain distance apart, states alternate during the session, duration is negative binomial
"""
import numpy as np
import matplotlib.pyplot as plt
import pickle
from scipy.stats import nbinom
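# for reference: scipy's nbinom(n, p) has mean n * (1 - p) / p, so the state durations
# drawn below (n = 21 or 36, p = 0.3) last 49 or 84 trials on average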
print('BIAS NOT UPDATED')
quit()
subject = 'CSHL059'
new_name = 'GLM_Sim_03_5'
seed = 4
print(new_name)
np.random.seed(seed)
info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
assert info_dict['subject'] == subject
info_dict['subject'] = new_name
pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
till_session = info_dict['n_sessions']
from_session = info_dict['bias_start']
GLM_weights = [np.array([-4.5, 4.3, 0, 0., 1.2, -0.7]),
np.array([-4.5, 3.2, 0, 1.3, 0.3, -1.5])]
contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
num_to_contrast = {v: k for k, v in contrast_to_num.items()}
state_posterior = np.zeros((till_session + 1 - from_session, 2))
for k, j in enumerate(range(from_session, till_session + 1)):
data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
predictors = np.zeros(5)
state_plot = np.zeros(2)
count = 0
curr_state = int(np.random.rand() > 0.5)
dur = nbinom.rvs(21 + curr_state * 15, 0.3)
prev_choice = 2 * int(np.random.rand() > 0.5) - 1
if (contrasts[0] < 0) == (prev_choice > 0):
prev_reward = 1.
elif contrasts[0] > 0:
prev_reward = 0.
else:
prev_reward = int(np.random.rand() > 0.5)
prev_side = - prev_choice if prev_reward == 0. else prev_choice
data[0, 1] = prev_choice + 1
side_info[0, 1] = prev_reward
for i, c in enumerate(contrasts[1:]):
predictors[0] = max(c, 0)
predictors[1] = abs(min(c, 0))
predictors[2] = prev_choice
predictors[3] = prev_reward
predictors[4] = prev_side
data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state][:-1] * predictors) - GLM_weights[curr_state][-1])))
state_plot[curr_state] += 1
if dur == 0:
curr_state = (curr_state + 1) % 2
dur = nbinom.rvs(21 + curr_state * 15, 0.3)
else:
dur -= 1
prev_choice = data[i+1, 1] - 1
if (c < 0) == (prev_choice > 0):
prev_reward = 1.
elif c > 0:
prev_reward = 0.
else:
prev_reward = int(np.random.rand() > side_info[i+1, 0])
prev_side = - prev_choice if prev_reward == 0. else prev_choice
side_info[i+1, 1] = prev_reward
state_posterior[k] = state_plot / len(data[:, 0])
pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
plt.figure(figsize=(16, 9))
for s in range(2):
plt.fill_between(range(till_session + 1 - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
plt.show()
"""
Generate data from a simulated mouse-GLM.
This mouse uses 4 states throughout
They start a certain distance apart, states alternate during the session, duration is negative binomial
"""
import numpy as np
import matplotlib.pyplot as plt
import pickle
from scipy.stats import nbinom
subject = 'CSHL059'
new_name = 'GLM_Sim_04'
seed = 5
print('BIAS NOT UPDATED')
quit()
print(new_name)
np.random.seed(seed)
info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
assert info_dict['subject'] == subject
info_dict['subject'] = new_name
pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
till_session = info_dict['n_sessions']
from_session = info_dict['bias_start']
GLM_weights = [np.array([-4.5, 4.3, 0, 0., 1.2, -0.7]),
np.array([-4.5, 3.2, 1.3, 0., 0.3, -1.5]),
np.array([-1, 1.2, 2.1, 0.5, 0.1, 1]),
np.array([-0.1, 0.2, 0, 0.1, 3.9, -1.])]
contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
num_to_contrast = {v: k for k, v in contrast_to_num.items()}
state_posterior = np.zeros((till_session + 1 - from_session, 4))
for k, j in enumerate(range(from_session, till_session + 1)):
data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
predictors = np.zeros(5)
state_plot = np.zeros(4)
count = 0
curr_state = np.random.choice(4)
dur = nbinom.rvs(21 + (curr_state % 2) * 15, 0.3)
prev_choice = 2 * int(np.random.rand() > 0.5) - 1
if (contrasts[0] < 0) == (prev_choice > 0):
prev_reward = 1.
elif contrasts[0] > 0:
prev_reward = 0.
else:
prev_reward = int(np.random.rand() > 0.5)
prev_side = - prev_choice if prev_reward == 0. else prev_choice
data[0, 1] = prev_choice + 1
side_info[0, 1] = prev_reward
for i, c in enumerate(contrasts[1:]):
predictors[0] = max(c, 0)
predictors[1] = abs(min(c, 0))
predictors[2] = prev_choice
predictors[3] = prev_reward
predictors[4] = prev_side
data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state][:-1] * predictors) - GLM_weights[curr_state][-1])))
state_plot[curr_state] += 1
if dur == 0:
while True:
next_state = np.random.choice(4)
if next_state != curr_state:
curr_state = next_state
break
dur = nbinom.rvs(21 + (curr_state % 2) * 15, 0.3)
else:
dur -= 1
prev_choice = data[i+1, 1] - 1
if (c < 0) == (prev_choice > 0):
prev_reward = 1.
elif c > 0:
prev_reward = 0.
else:
prev_reward = int(np.random.rand() > side_info[i+1, 0])
prev_side = - prev_choice if prev_reward == 0. else prev_choice
side_info[i+1, 1] = prev_reward
state_posterior[k] = state_plot / len(data[:, 0])
pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
plt.figure(figsize=(16, 9))
for s in range(4):
plt.fill_between(range(till_session + 1 - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
plt.show()
"""
Generate data from a simulated mouse-GLM.
This mouse uses 4 states, but unlike the scripts above they do not alternate within a session:
the active state is set by the session number and advances every 5 sessions (curr_state = j // 5)
"""
import numpy as np
import matplotlib.pyplot as plt
import pickle
from scipy.stats import nbinom
subject = 'CSHL059'
new_name = 'GLM_Sim_05'
seed = 6
print(new_name)
np.random.seed(seed)
info_dict = pickle.load(open("../session_data/{}_info_dict.p".format(subject), "rb"))
assert info_dict['subject'] == subject
info_dict['subject'] = new_name
pickle.dump(info_dict, open("../session_data/{}_info_dict.p".format(new_name), "wb"))
till_session = info_dict['bias_start']
from_session = 0
GLM_weights = [np.array([-4.5, 4.3, 0., 1.2, -0.7]),
np.array([-4.5, 3.2, 1.3, 0.3, -1.5]),
np.array([-1, 1.2, 2.1, 0.1, 1]),
np.array([-0.1, 0.2, 0., 3.9, -1.])]
GLM_weights = list(reversed(GLM_weights))
contrast_to_num = {-1.: 0, -0.987: 1, -0.848: 2, -0.555: 3, -0.302: 4, 0.: 5, 0.302: 6, 0.555: 7, 0.848: 8, 0.987: 9, 1.: 10}
num_to_contrast = {v: k for k, v in contrast_to_num.items()}
state_posterior = np.zeros((till_session + 1 - from_session, 4))
for k, j in enumerate(range(from_session, till_session + 1)):
data = pickle.load(open("../session_data/{}_fit_info_{}.p".format(subject, j), "rb"))
side_info = pickle.load(open("../session_data/{}_side_info_{}.p".format(subject, j), "rb"))
contrasts = np.vectorize(num_to_contrast.get)(data[:, 0])
predictors = np.zeros(5)
state_plot = np.zeros(4)
count = 0
curr_state = j // 5
prev_choice = 2 * int(np.random.rand() > 0.5) - 1
if (contrasts[0] < 0) == (prev_choice > 0):
prev_reward = 1.
elif contrasts[0] > 0:
prev_reward = 0.
else:
prev_reward = int(np.random.rand() > 0.5)
prev_side = - prev_choice if prev_reward == 0. else prev_choice
data[0, 1] = prev_choice + 1
side_info[0, 1] = prev_reward
for i, c in enumerate(contrasts[1:]):
predictors[0] = max(c, 0)
predictors[1] = abs(min(c, 0))
predictors[2] = prev_choice
predictors[3] = prev_side
predictors[4] = 1
data[i+1, 1] = 2 * (np.random.rand() < 1 / (1 + np.exp(- np.sum(GLM_weights[curr_state] * predictors))))
state_plot[curr_state] += 1
prev_choice = data[i+1, 1] - 1
if (c < 0) == (prev_choice > 0):
prev_reward = 1.
elif c > 0:
prev_reward = 0.
else:
prev_reward = int(np.random.rand() > side_info[i+1, 0])
prev_side = - prev_choice if prev_reward == 0. else prev_choice
side_info[i+1, 1] = prev_reward
state_posterior[k] = state_plot / len(data[:, 0])
pickle.dump(data, open("../session_data/{}_fit_info_{}.p".format(new_name, j), "wb"))
pickle.dump(side_info, open("../session_data/{}_side_info_{}.p".format(new_name, j), "wb"))
plt.figure(figsize=(16, 9))
for s in range(4):
plt.fill_between(range(till_session + 1 - from_session), s - state_posterior[:, s] / 2, s + state_posterior[:, s] / 2)
plt.savefig('states_05')
plt.show()