Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision

Target

Select target project
  • sbruijns/ihmm_behav_states
1 result
Select Git revision
Show changes
Commits on Source (2)
No preview for this file type
No preview for this file type
......@@ -21,111 +21,197 @@ def pmf_type(pmf):
return 2
if __name__ == "__main__":
state_types_interpolation = pickle.load(open("state_types_interpolation.p", 'rb'))
state_types_interpolation = state_types_interpolation / state_types_interpolation.max() * 100
fs = 18
# plt.plot(np.linspace(0, 1, 150), state_types_interpolation[0], color=type2color[0])
# plt.ylabel("% of type across population", size=fs)
# plt.xlabel("Interpolated session time", size=fs)
# plt.ylim(0, 100)
# sns.despine()
# plt.tight_layout()
# plt.savefig("type hist 1")
# plt.show()
#
# plt.plot(np.linspace(0, 1, 150), state_types_interpolation[1], color=type2color[1])
# plt.ylabel("% of type across population", size=fs)
# plt.xlabel("Interpolated session time", size=fs)
# plt.ylim(0, 100)
# sns.despine()
# plt.tight_layout()
# plt.savefig("type hist 2")
# plt.show()
#
# plt.plot(np.linspace(0, 1, 150), state_types_interpolation[2], color=type2color[2])
# plt.ylabel("% of type across population", size=fs)
# plt.xlabel("Interpolated session time", size=fs)
# plt.ylim(0, 100)
# sns.despine()
# plt.tight_layout()
# plt.savefig("type hist 3")
# plt.show()
all_first_pmfs_typeless = pickle.load(open("all_first_pmfs_typeless.p", 'rb'))
all_pmfs = pickle.load(open("all_pmfs.p", 'rb'))
all_bias_flips = pickle.load(open("all_bias_flips.p", 'rb'))
# bias flips
plt.hist(all_bias_flips, bins=np.arange(0, max(all_bias_flips) + 1), color='grey', align='left')
plt.ylabel("# of mice")
plt.xlabel("Bias flips")
# plt.hist(all_bias_flips, bins=np.arange(0, max(all_bias_flips) + 1), color='grey', align='left')
# plt.ylabel("# of mice")
# plt.xlabel("Bias flips")
# sns.despine()
# plt.tight_layout()
# plt.savefig("./meeting_figures/bias_flips.png")
# plt.show()
# fewer_states_side = []
# for key in all_first_pmfs_typeless:
# animal_biases = np.zeros(2)
# for defined_points, pmf in all_first_pmfs_typeless[key]:
# bias = np.mean(pmf[defined_points])
# if bias > 0.55:
# animal_biases[0] += 1
# elif bias < 0.45:
# animal_biases[1] += 1
# fewer_states_side.append(np.min(animal_biases / animal_biases.sum()))
# plt.hist(fewer_states_side)
# plt.title("Mixed biases")
# plt.ylabel("# of mice")
# plt.xlabel("min(% left biased states, % right biased states)")
# sns.despine()
# plt.tight_layout()
# plt.savefig("./meeting_figures/proportion_other_bias")
# plt.show()
# total_counter = 0
# bias_counter = 0
# tendency_counter = 0
# lapse_counter = 0
# for pmf in all_pmfs:
# max_b, min_b = 0, 1
# max_tendency, min_tendency = 0, 1
# max_lapse_diff, min_lapse_diff = 0, 1
# for p in pmf[1]:
# if pmf[0][5]: # if this part of the pmf is defined, just take it
# bias = p[5]
# deviation = 0
# while True: # just take the closest thing
# deviation += 1
# if pmf[0][5 - deviation] and pmf[0][5 + deviation]:
# bias = (p[5 - deviation] + p[5 + deviation]) / 2
# break
# max_b = max(max_b, bias)
# min_b = min(min_b, bias)
# max_tendency = max(max_tendency, np.mean(p[pmf[0]]))
# min_tendency = min(min_tendency, np.mean(p[pmf[0]]))
# max_lapse_diff = max(max_lapse_diff, p[0] + p[-1] - 1)
# min_lapse_diff = min(min_lapse_diff, p[0] + p[-1] - 1)
# bias_changed = max_b > 0.55 and min_b < 0.45
# tendency_changed = max_tendency > 0.55 and min_tendency < 0.45
# lapse_changed = max_lapse_diff > 0.1 and min_lapse_diff < -0.1
# bias_counter += bias_changed
# tendency_counter += tendency_changed
# lapse_counter += lapse_changed
# if bias_changed or tendency_changed or lapse_changed:
# total_counter += 1
# for p in pmf[1]:
# plt.plot(np.where(pmf[0])[0], p[pmf[0]])
# plt.title("bias: {}, tendency: {}, lapse: {}".format(bias_changed, tendency_changed, lapse_changed))
# plt.ylim(0, 1)
# plt.axvline(5, color='grey')
# plt.axhline(0.5, color='grey')
# plt.savefig("./meeting_figures/bias_change_{}".format(total_counter))
# plt.close()
# print("Bias counters")
# print(total_counter)
# print(bias_counter)
# print(tendency_counter)
# print(lapse_counter)
# pmf_ranges = []
# for key in all_first_pmfs_typeless:
# for defined_points, pmf in all_first_pmfs_typeless[key]:
# if pmf_type(pmf) == 2:
# pmf_ranges.append(pmf[-1] - pmf[0])
# # if pmf_ranges[-1] < 0.6:
# # plt.plot(pmf)
# # plt.title(pmf_ranges[-1])
# # plt.ylim(0, 1)
# # plt.show()
# plt.hist(pmf_ranges, bins=40)
# plt.title("Type 2 ranges")
# plt.ylabel("# of type 2 states")
# plt.xlabel("PMF range")
# plt.show()
# lapses = []
# for key in all_first_pmfs_typeless:
# for defined_points, pmf in all_first_pmfs_typeless[key]:
# if pmf_type(pmf) != 0:
# lapses.append(max(pmf[0], 1 - pmf[-1]))
# plt.hist(lapses, bins=40)
# plt.title("Higher lapse rate of type != 0")
# plt.ylabel("# of non-type-0 states")
# plt.xlabel("Higher lapse rate")
# plt.show()
lw = 4
# Simplex example pmfs
state_num = 7
defined_points, pmf = all_first_pmfs_typeless['NYU-06'][state_num][0], all_first_pmfs_typeless['NYU-06'][state_num][1]
plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
state_num = 6
defined_points, pmf = all_first_pmfs_typeless['CSHL061'][state_num][0], all_first_pmfs_typeless['CSHL061'][state_num][1]
plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
plt.ylim(0, 1)
plt.xlim(0, 10)
plt.ylabel("P(rightwards)", size=32)
plt.xlabel("Contrast", size=32)
plt.yticks([0, 1], size=27)
plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27)
sns.despine()
plt.tight_layout()
plt.savefig("./meeting_figures/bias_flips.png")
plt.savefig("example type 1")
plt.show()
quit()
fewer_states_side = []
for key in all_first_pmfs_typeless:
animal_biases = np.zeros(2)
for defined_points, pmf in all_first_pmfs_typeless[key]:
bias = np.mean(pmf[defined_points])
if bias > 0.55:
animal_biases[0] += 1
elif bias < 0.45:
animal_biases[1] += 1
fewer_states_side.append(np.min(animal_biases / animal_biases.sum()))
plt.hist(fewer_states_side)
plt.title("Mixed biases")
plt.ylabel("# of mice")
plt.xlabel("min(% left biased states, % right biased states)")
state_num = 1
defined_points, pmf = all_first_pmfs_typeless['CSHL_018'][state_num][0], all_first_pmfs_typeless['CSHL_018'][state_num][1]
plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
state_num = 3
defined_points, pmf = all_first_pmfs_typeless['ibl_witten_14'][state_num][0], all_first_pmfs_typeless['ibl_witten_14'][state_num][1]
plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
plt.ylim(0, 1)
plt.xlim(0, 10)
plt.ylabel("P(rightwards)", size=32)
plt.xlabel("Contrast", size=32)
plt.yticks([0, 1], size=27)
plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27)
sns.despine()
plt.tight_layout()
plt.savefig("./meeting_figures/proportion_other_bias")
plt.savefig("example type 2")
plt.show()
total_counter = 0
bias_counter = 0
tendency_counter = 0
lapse_counter = 0
for pmf in all_pmfs:
max_b, min_b = 0, 1
max_tendency, min_tendency = 0, 1
max_lapse_diff, min_lapse_diff = 0, 1
for p in pmf[1]:
if pmf[0][5]: # if this part of the pmf is defined, just take it
bias = p[5]
deviation = 0
while True: # just take the closest thing
deviation += 1
if pmf[0][5 - deviation] and pmf[0][5 + deviation]:
bias = (p[5 - deviation] + p[5 + deviation]) / 2
break
max_b = max(max_b, bias)
min_b = min(min_b, bias)
max_tendency = max(max_tendency, np.mean(p[pmf[0]]))
min_tendency = min(min_tendency, np.mean(p[pmf[0]]))
max_lapse_diff = max(max_lapse_diff, p[0] + p[-1] - 1)
min_lapse_diff = min(min_lapse_diff, p[0] + p[-1] - 1)
bias_changed = max_b > 0.55 and min_b < 0.45
tendency_changed = max_tendency > 0.55 and min_tendency < 0.45
lapse_changed = max_lapse_diff > 0.1 and min_lapse_diff < -0.1
bias_counter += bias_changed
tendency_counter += tendency_changed
lapse_counter += lapse_changed
if bias_changed or tendency_changed or lapse_changed:
total_counter += 1
for p in pmf[1]:
plt.plot(np.where(pmf[0])[0], p[pmf[0]])
plt.title("bias: {}, tendency: {}, lapse: {}".format(bias_changed, tendency_changed, lapse_changed))
plt.ylim(0, 1)
plt.axvline(5, color='grey')
plt.axhline(0.5, color='grey')
plt.savefig("./meeting_figures/bias_change_{}".format(total_counter))
plt.close()
print("Bias counters")
print(total_counter)
print(bias_counter)
print(tendency_counter)
print(lapse_counter)
pmf_ranges = []
for key in all_first_pmfs_typeless:
for defined_points, pmf in all_first_pmfs_typeless[key]:
if pmf_type(pmf) == 2:
pmf_ranges.append(pmf[-1] - pmf[0])
# if pmf_ranges[-1] < 0.6:
# plt.plot(pmf)
# plt.title(pmf_ranges[-1])
# plt.ylim(0, 1)
# plt.show()
plt.hist(pmf_ranges, bins=40)
plt.title("Type 2 ranges")
plt.ylabel("# of type 2 states")
plt.xlabel("PMF range")
plt.show()
state_num = 4
defined_points, pmf = all_first_pmfs_typeless['ibl_witten_17'][state_num][0], all_first_pmfs_typeless['ibl_witten_17'][state_num][1]
plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
lapses = []
for key in all_first_pmfs_typeless:
for defined_points, pmf in all_first_pmfs_typeless[key]:
if pmf_type(pmf) != 0:
lapses.append(max(pmf[0], 1 - pmf[-1]))
plt.hist(lapses, bins=40)
plt.title("Higher lapse rate of type != 0")
plt.ylabel("# of non-type-0 states")
plt.xlabel("Higher lapse rate")
plt.ylim(0, 1)
plt.xlim(0, 10)
plt.ylabel("P(rightwards)", size=32)
plt.xlabel("Contrast", size=32)
plt.yticks([0, 1], size=27)
plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27)
sns.despine()
plt.tight_layout()
plt.savefig("example type 3")
plt.show()
quit()
n_rows, n_cols = 5, 6
_, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9))
......
This diff is collapsed.
import numpy as np
import matplotlib.pyplot as plt
import pickle
from scipy.stats import gaussian_kde
from analysis_pmf import pmf_type, type2color
# all pmf weights
apw = np.array(pickle.load(open("all_pmf_weights.p", 'rb')))
xy = np.vstack([apw[:, i] for i in range(4)])
z = gaussian_kde(xy)(xy)
plt.subplot(1, 3, 1)
plt.scatter(apw[:, 0], apw[:, 1], c=z)
plt.xlabel("Cont right")
plt.ylabel("Cont left")
plt.subplot(1, 3, 2)
plt.scatter(apw[:, 3], apw[:, 1], c=z)
plt.xlabel("Bias")
plt.ylabel("Cont left")
plt.subplot(1, 3, 3)
plt.scatter(apw[:, 3], apw[:, 0], c=z)
plt.xlabel("Bias")
plt.ylabel("Cont right")
plt.show()
contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
def weights_to_pmf(weights, with_bias=1):
psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1]
return 1 / (1 + np.exp(psi))
colors = [type2color[pmf_type(weights_to_pmf(x))] for x in apw]
plt.subplot(1, 3, 1)
sc = plt.scatter(apw[:, 0], apw[:, 1], c=colors)
fig, ax = plt.gcf(), plt.gca()
plt.xlabel("Cont right")
plt.ylabel("Cont left")
annot = ax.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
bbox=dict(boxstyle="round", fc="w"),
arrowprops=dict(arrowstyle="->"))
annot.set_visible(False)
def update_annot(ind):
print(ind)
pos = sc.get_offsets()[ind["ind"][0]]
annot.xy = pos
text = "{}".format(np.round(apw[ind["ind"][0]], 2))
annot.set_text(text)
def hover(event):
vis = annot.get_visible()
if event.inaxes == ax:
cont, ind = sc.contains(event)
if cont:
update_annot(ind)
annot.set_visible(True)
fig.canvas.draw_idle()
else:
if vis:
annot.set_visible(False)
fig.canvas.draw_idle()
fig.canvas.mpl_connect("motion_notify_event", hover)
plt.subplot(1, 3, 2)
plt.scatter(apw[:, 3], apw[:, 1], c=colors)
plt.xlabel("Bias")
plt.ylabel("Cont left")
plt.subplot(1, 3, 3)
plt.scatter(apw[:, 3], apw[:, 0], c=colors)
plt.xlabel("Bias")
plt.ylabel("Cont right")
plt.show()
from mpl_toolkits import mplot3d
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.scatter3D(apw[:, 0], apw[:, 1], apw[:, 3], c=colors)
plt.show()
from simplex_plot import plotSimplex
import numpy as np
import pickle
from analysis_pmf import type2color
import math
import matplotlib.pyplot as plt
import copy
all_state_types = pickle.load(open("all_state_types.p", 'rb'))
# create a list of fixed offsets, to apply to mice when they are in a corner
offsets = [(0, 0)]
hex_size = 0
angle_add = math.pi / 3
while len(offsets) < len(all_state_types):
hex_size += 1
corner = (hex_size, 0)
offsets.append(corner)
local_angle = 4 * math.pi / 3
for i in range(6):
next_corner = (corner[0] + hex_size * math.cos(local_angle), corner[1] + hex_size * math.sin(local_angle))
local_angle -= angle_add
for j in range(hex_size - 1):
offsets.append((corner[0] * (j + 1) / hex_size + next_corner[0] * (hex_size - j - 1) / hex_size,
corner[1] * (j + 1) / hex_size + next_corner[1] * (hex_size - j - 1) / hex_size))
corner = next_corner
if i != 5:
offsets.append(copy.copy(corner))
_, test_count = np.unique(offsets, return_counts=1, axis=0)
assert (test_count == 1).all()
# for i, o in enumerate(offsets):
# plt.scatter(o[0], o[1])
# plt.show()
# quit()
session_counter = 0
alph = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'za', 'zb', 'zc', 'zd', 'ze', 'zf', 'zg', 'zh', 'zi', 'zj']
# do as many sessions as it takes
while True:
session_counter += 1
not_ended = 0
type_proportions = np.zeros((3, len(all_state_types)))
x_offset, y_offset = np.zeros(len(all_state_types)), np.zeros(len(all_state_types))
# iterate through all mice, count how many are in each corner
for i, sts in enumerate(all_state_types):
# check whether this mouse still has sessions left
if sts.shape[1] > session_counter:
not_ended += 1
temp_counter = session_counter
else:
temp_counter = -1
assert np.sum(sts[:, temp_counter]) <= 1
if (sts[:, temp_counter] == 1).any():
x_offset[i] = offsets[i][0] * 0.01
y_offset[i] = offsets[i][1] * 0.01
type_proportions[:, i] = sts[:, temp_counter]
plotSimplex(type_proportions.T, x_offset=x_offset, y_offset=y_offset, c=np.arange(len(all_state_types)), show=False, title="Session {}".format(session_counter),
vertexcolors=[type2color[i] for i in range(3)], vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title="simplex_{}.png".format(alph[session_counter]))
plt.close()
if not_ended == 0:
break
......@@ -14,8 +14,8 @@ import matplotlib.patches as PA
def plotSimplex(points, fig=None,
vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title="test.png",
show=False, **kwargs):
vertexlabels=['1: initial flat PMFs', '2: intermediate unilateral PMFs', '3: final bilateral PMFs'],
save_title="test.png", show=False, vertexcolors=['k', 'k', 'k'], x_offset=0, y_offset=0, **kwargs):
"""
Plot Nx3 points array on the 3-simplex
(with optionally labeled vertices)
......@@ -32,9 +32,9 @@ def plotSimplex(points, fig=None,
fig.gca().xaxis.set_major_locator(MT.NullLocator())
fig.gca().yaxis.set_major_locator(MT.NullLocator())
# Draw vertex labels
fig.gca().text(-0.06, -0.05, vertexlabels[0], size=24)
fig.gca().text(0.95, -0.05, vertexlabels[1], size=24)
fig.gca().text(0.43, np.sqrt(3) / 2 + 0.025, vertexlabels[2], size=24)
# fig.gca().annotate(vertexlabels[0], (-0.35, -0.05), size=24, color=vertexcolors[0], annotation_clip=False)
# fig.gca().annotate(vertexlabels[1], (0.6, -0.05), size=24, color=vertexcolors[1], annotation_clip=False)
# fig.gca().annotate(vertexlabels[2], (0.1, np.sqrt(3) / 2 + 0.025), size=24, color=vertexcolors[2], annotation_clip=False)
# Project and draw the actual points
projected = projectSimplex(points / points.sum(1)[:, None])
print(projected)
......@@ -50,6 +50,7 @@ def plotSimplex(points, fig=None,
P.axis('off')
P.tight_layout()
P.savefig("dur_simplex.png", bbox_inches='tight', dpi=300, transparent=True)
if show:
P.show()
......