Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • sbruijns/ihmm_behav_states
1 result
Show changes
Commits on Source (5)
......@@ -7,15 +7,47 @@ from mpl_toolkits import mplot3d
from matplotlib.patches import ConnectionPatch
show_weight_augmentations = True
all_weight_trajectories = pickle.load(open("multi_chain_saves/all_weight_trajectories.p", 'rb'))
first_and_last_pmf = np.array(pickle.load(open("multi_chain_saves/first_and_last_pmf.p", 'rb')))
all_sudden_changes = pickle.load(open("multi_chain_saves/all_sudden_changes.p", 'rb'))
all_sudden_transition_changes = pickle.load(open("multi_chain_saves/all_sudden_transition_changes.p", 'rb'))
if show_weight_augmentations:
aug_all_sudden_changes = pickle.load(open("multi_chain_saves/aug_all_sudden_changes.p", 'rb'))
aug_all_sudden_transition_changes = pickle.load(open("multi_chain_saves/aug_all_sudden_transition_changes.p", 'rb'))
performance_points = np.array([-1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0])
reduced_points = np.array([1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=bool)
weight_colours = ['blue', 'red', 'green', 'goldenrod', 'darkorange']
weight_colours_aug = ['blue', 'red', 'purple', 'green', 'goldenrod', 'darkorange']
weight_colours_aug = ['blue', 'red', 'green', 'goldenrod', 'darkorange', 'purple']
ylabels = ["Cont left", "Cont right", "Persevere", "Bias left", "Bias right"]
ylabels_aug = ["Cont left", "Cont right", "Cont diff", "Persevere", "Bias left", "Bias right"]
ylabels_aug = ["Cont left", "Cont right", "Persevere", "Bias left", "Bias right", "PMF span"]
folder = "./reward_analysis/"
local_ylabels = ylabels_aug if show_weight_augmentations else ylabels
local_weight_colours = weight_colours_aug if show_weight_augmentations else weight_colours
n_weights = all_weight_trajectories[0][0].shape[0]
n_types = 3
def create_nested_list(list_of_ns):
"""
Create a complex nested list, according to the specifications of list_of_ns.
E.g. [3, 2, 4] will return a list with 3 sublists, each containing 2 sublists again, each containing 4 sublists yet again.
"""
ret = []
if len(list_of_ns) == 0:
return ret
for _ in range(list_of_ns[0]):
ret.append(create_nested_list(list_of_ns=list_of_ns[1:]))
return ret
def pmf_to_perf(pmf):
# determine performance of a pmf, but only on the omnipresent strongest contrasts
......@@ -41,331 +73,216 @@ def pmf_type_rew(weights):
return 2
if True:
all_weight_trajectories = pickle.load(open("multi_chain_saves/all_weight_trajectories.p", 'rb'))
first_and_last_pmf = np.array(pickle.load(open("multi_chain_saves/first_and_last_pmf.p", 'rb')))
all_sudden_changes = pickle.load(open("multi_chain_saves/all_sudden_changes.p", 'rb'))
n_weights = all_weight_trajectories[0][0].shape[0]
all_sudden_transition_changes = pickle.load(open("multi_chain_saves/all_sudden_transition_changes.p", 'rb'))
average = np.zeros((n_weights + 1, 3, 2))
counter = np.zeros((n_weights + 1, 3))
all_datapoints = [[[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]]]
f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9)) # We add another weight slot to split the bias
for state_type, weight_gaps in enumerate(all_sudden_transition_changes):
for gap in weight_gaps:
for i in range(n_weights):
if i == n_weights - 1:
axs[i + (gap[1][i] > 0), state_type + 1].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o") # plot how the weight evolves from first to last appearance
average[i + (gap[1][i] > 0), state_type + 1] += np.array([gap[0][i], gap[-1][i]])
counter[i + (gap[1][i] > 0), state_type + 1] += 1
all_datapoints[i + (gap[1][i] > 0)][state_type + 1][0].append(gap[0][i])
all_datapoints[i + (gap[1][i] > 0)][state_type + 1][1].append(gap[-1][i])
else:
axs[i, state_type + 1].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o") # plot how the weight evolves from first to last appearance
average[i, state_type + 1] += np.array([gap[0][i], gap[-1][i]])
counter[i, state_type + 1] += 1
all_datapoints[i][state_type + 1][0].append(gap[0][i])
all_datapoints[i][state_type + 1][1].append(gap[-1][i])
plt.tight_layout()
plt.savefig("./summary_figures/weight_changes/sudden_weight_change at transitions")
plt.show()
f, axs = plt.subplots(1, 3, figsize=(12, 6))
for i in range(1, 3):
for j in range(n_weights + 1):
axs[i].plot([0, 1], average[j, i] / counter[j, i], marker="o", color=weight_colours[j], label=ylabels[j])
axs[i].set_ylim(-3.5, 3.5)
axs[i].spines['top'].set_visible(False)
axs[i].spines['right'].set_visible(False)
axs[i].set_xticks([])
if i == 0:
pass
# axs[i].set_ylabel("Weights", size=24)
# if j < n_weights - 1:
# axs[i].plot([0.1], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', color=weight_colours[j]) # also plot weights of very first state average
# if j == n_weights - 1:
# mask = first_and_last_pmf[:, 0, -1] < 0
# axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j]) # separete biases again
# if j == n_weights:
# mask = first_and_last_pmf[:, 0, -1] > 0
# axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j])
else:
axs[i].yaxis.set_ticklabels([])
if i == 2:
pass
# if j < n_weights - 1:
# axs[i].plot([0.9], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', color=weight_colours[j]) # also plot weights of very last state average
# if j == n_weights - 1:
# mask = first_and_last_pmf[:, 1, -1] < 0
# axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j]) # separete biases again
# if j == n_weights:
# mask = first_and_last_pmf[:, 1, -1] > 0
# axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j])
if j == 0:
axs[i].set_title("Type {}".format(i + 1), size=26)
if j == n_weights and i == 1:
axs[i].set_xlabel("Lifetime weight change", size=24)
axs[0].legend(frameon=False, fontsize=14)
plt.tight_layout()
plt.savefig("./summary_figures/weight_changes/compact sudden_weight_changes at transitions")
plt.show()
def plot_traces_and_collate_data(data, augmented_data=None, title=""):
"""Plot all the individual weight change traces, ann collect the data we need for the other plots"""
show_weight_augmentations = augmented_data is not None
sudden_changes = len(data) == n_types - 1
average = np.zeros((n_weights + 1 + show_weight_augmentations, n_types - sudden_changes, 2))
counter = np.zeros((n_weights + 1 + show_weight_augmentations, n_types - sudden_changes))
all_datapoints = create_nested_list([n_weights + 1 + show_weight_augmentations, n_types - sudden_changes, 2])
average = np.zeros((n_weights + 1, 3, 2))
counter = np.zeros((n_weights + 1, 3))
all_datapoints = [[[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]]]
f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9)) # I added a third pseudo weight for distance between the two sensitivities, and do another extra to split the bias
for state_type, weight_gaps in enumerate(all_sudden_changes):
f, axs = plt.subplots(n_weights + 1 + show_weight_augmentations, n_types - sudden_changes, figsize=(4 * (3 - sudden_changes), 9)) # We add another weight slot to split the bias
for state_type, weight_gaps in enumerate(data):
for gap in weight_gaps:
for i in range(n_weights):
if i == n_weights - 1:
axs[i + (gap[1][i] > 0), state_type].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o") # plot how the weight evolves from first to last appearance
axs[i + (gap[1][i] > 0), state_type].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o") # plot how the weight of the new state differs from the previous closest weight
average[i + (gap[1][i] > 0), state_type] += np.array([gap[0][i], gap[-1][i]])
counter[i + (gap[1][i] > 0), state_type] += 1
all_datapoints[i + (gap[1][i] > 0)][state_type][0].append(gap[0][i])
all_datapoints[i + (gap[1][i] > 0)][state_type][1].append(gap[-1][i])
else:
axs[i, state_type].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o") # plot how the weight evolves from first to last appearance
axs[i, state_type].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o") # plot how the weight of the new state differs from the previous closest weight
average[i, state_type] += np.array([gap[0][i], gap[-1][i]])
counter[i, state_type] += 1
all_datapoints[i][state_type][0].append(gap[0][i])
all_datapoints[i][state_type][1].append(gap[-1][i])
if show_weight_augmentations:
for state_type, weight_gaps in enumerate(augmented_data):
for gap in weight_gaps:
axs[n_weights + 1, state_type].plot([0, 1], [gap[0], gap[-1]], marker="o") # plot how the weight of the new state differs from the previous closest weight
average[n_weights + 1, state_type] += np.array([gap[0], gap[-1]])
counter[n_weights + 1, state_type] += 1
all_datapoints[n_weights + 1][state_type][0].append(gap[0])
all_datapoints[n_weights + 1][state_type][1].append(gap[-1])
plt.tight_layout()
plt.savefig("./summary_figures/weight_changes/sudden_weight_change")
plt.show()
plt.savefig("./summary_figures/weight_changes/" + title + " augmented" * show_weight_augmentations)
plt.close()
f, axs = plt.subplots(1, 3, figsize=(12, 6))
for i in range(3):
for j in range(n_weights + 1):
axs[i].plot([0, 1], average[j, i] / counter[j, i], marker="o", color=weight_colours[j], label=ylabels[j])
return average, counter, all_datapoints
def plot_compact(average, counter, title, show_first_and_last=False, show_weight_augmentations=False):
"""Plot the means of the weight change traces, split by the different weight types, possibly with augmentations"""
sudden_changes = average.shape[1] == n_types - 1
f, axs = plt.subplots(1, n_types - sudden_changes, figsize=(4 * (3 - sudden_changes), 6))
for i in range(n_types - sudden_changes):
for j in range(n_weights + 1 + show_weight_augmentations):
axs[i].plot([0, 1], average[j, i] / counter[j, i], marker="o", color=local_weight_colours[j], label=local_ylabels[j])
axs[i].set_ylim(-3.5, 3.5)
axs[i].spines['top'].set_visible(False)
axs[i].spines['right'].set_visible(False)
axs[i].set_xticks([])
if i == 0:
pass
# axs[i].set_ylabel("Weights", size=24)
# if j < n_weights - 1:
# axs[i].plot([0.1], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', color=weight_colours[j]) # also plot weights of very first state average
# if j == n_weights - 1:
# mask = first_and_last_pmf[:, 0, -1] < 0
# axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j]) # separete biases again
# if j == n_weights:
# mask = first_and_last_pmf[:, 0, -1] > 0
# axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j])
if i == 0 and show_first_and_last:
axs[i].set_ylabel("Weights", size=24)
if j < n_weights - 1:
axs[i].plot([0.1], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', color=local_weight_colours[j]) # also plot weights of very first state average
if j == n_weights - 1:
mask = first_and_last_pmf[:, 0, -1] < 0
axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=local_weight_colours[j]) # separete biases again
if j == n_weights:
mask = first_and_last_pmf[:, 0, -1] > 0
axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=local_weight_colours[j])
else:
axs[i].yaxis.set_ticklabels([])
if i == 2:
pass
# if j < n_weights - 1:
# axs[i].plot([0.9], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', color=weight_colours[j]) # also plot weights of very last state average
# if j == n_weights - 1:
# mask = first_and_last_pmf[:, 1, -1] < 0
# axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j]) # separete biases again
# if j == n_weights:
# mask = first_and_last_pmf[:, 1, -1] > 0
# axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j])
if i == 2 and show_first_and_last:
if j < n_weights - 1:
axs[i].plot([0.9], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', color=local_weight_colours[j]) # also plot weights of very last state average
if j == n_weights - 1:
mask = first_and_last_pmf[:, 1, -1] < 0
axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=local_weight_colours[j]) # separete biases again
if j == n_weights:
mask = first_and_last_pmf[:, 1, -1] > 0
axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=local_weight_colours[j])
if j == 0:
axs[i].set_title("Type {}".format(i + 1), size=26)
axs[i].set_title("Type {}".format(i + 1 + sudden_changes), size=26)
if j == n_weights and i == 1:
axs[i].set_xlabel("Lifetime weight change", size=24)
axs[0].legend(frameon=False, fontsize=14)
plt.tight_layout()
plt.savefig("./summary_figures/weight_changes/compact sudden_weight_changes")
plt.show()
bin_sets = [np.linspace(-6.5, 6.5, 30), np.linspace(-1, 2, 30), np.linspace(-3.5, 3.5, 30)]
if True:
x_lim_used_full, x_lim_used_half = 56, 28
f, axs = plt.subplots(n_weights + 1, 3 * 2, figsize=(12, 9))
for i in range(3):
for j in range(n_weights + 1):
# axs[j, i].plot([0, 1], average[j, i] / counter[j, i], marker="o")
if j < 2:
bins = bin_sets[0]
elif j == 2:
bins = bin_sets[1]
else:
bins = bin_sets[2]
axs[j, i * 2].hist(all_datapoints[j][i][0], orientation='horizontal', bins=bins, color='grey', alpha=0.5)
plt.savefig("./summary_figures/weight_changes/" + title + " augmented" * show_weight_augmentations)
plt.close()
def plot_histogram_diffs(all_datapoints, x_lim_used_normal, x_lim_used_bias, bin_sets, title, x_lim_used_augment=0, show_deltas=True, show_first_and_last=False, show_weight_augmentations=False):
"""Plot histograms over the weights, and the mean changes connecting them.
Might have to mess quite a bit with the y-axis"""
sudden_changes = len(all_datapoints[0]) == 2
f, axs = plt.subplots(n_weights + 1 + show_weight_augmentations, (n_types - sudden_changes) * 2, figsize=(4 * (3 - sudden_changes), 9))
for i in range(n_types - sudden_changes):
for j in range(n_weights + 1 + show_weight_augmentations):
if j < 2:
bins = bin_sets[0]
elif j == 2:
bins = bin_sets[1]
elif j in [3, 4]:
bins = bin_sets[2]
else:
bins = bin_sets[3]
axs[j, i * 2].hist(all_datapoints[j][i][0], orientation='horizontal', bins=bins, color='grey', alpha=0.5)
if show_deltas:
axs[j, i * 2 + 1].hist(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]), orientation='horizontal', bins=bins, color='red', alpha=0.5)
else:
axs[j, i * 2 + 1].hist(all_datapoints[j][i][1], orientation='horizontal', bins=bins, color='grey', alpha=0.5)
if j < 2:
axs[j, i * 2].set_ylim(-6.5, 6.5)
axs[j, i * 2 + 1].set_ylim(-6.5, 6.5)
elif j == 2:
axs[j, i * 2].set_ylim(-1, 2)
axs[j, i * 2 + 1].set_ylim(-1, 2)
else:
axs[j, i * 2].set_ylim(-3.5, 3.5)
axs[j, i * 2 + 1].set_ylim(-3.5, 3.5)
axs[j, i * 2].spines['top'].set_visible(False)
axs[j, i * 2].spines['right'].set_visible(False)
axs[j, i * 2].set_xticks([])
axs[j, i * 2 + 1].spines['top'].set_visible(False)
axs[j, i * 2 + 1].spines['right'].set_visible(False)
axs[j, i * 2 + 1].set_xticks([])
axs[j, i * 2].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][0])), xy=(0.65, 0.8), xycoords='axes fraction')
axs[j, i * 2].set_ylim(bins[0], bins[-1])
axs[j, i * 2 + 1].set_ylim(bins[0], bins[-1])
axs[j, i * 2].spines['top'].set_visible(False)
axs[j, i * 2].spines['right'].set_visible(False)
axs[j, i * 2].set_xticks([])
axs[j, i * 2 + 1].spines['top'].set_visible(False)
axs[j, i * 2 + 1].spines['right'].set_visible(False)
axs[j, i * 2 + 1].set_xticks([])
axs[j, i * 2].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][0])), xy=(0.65, 0.8), xycoords='axes fraction')
if show_deltas:
axs[j, i * 2 + 1].annotate("Var {:.2f}".format(np.var(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]))), xy=(0.65, 0.8), xycoords='axes fraction')
else:
axs[j, i * 2 + 1].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][1])), xy=(0.65, 0.8), xycoords='axes fraction')
if j < n_weights - 1:
assert x_lim_used_normal > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_normal, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
axs[j, i * 2].set_xlim(0, x_lim_used_normal)
axs[j, i * 2 + 1].set_xlim(0, x_lim_used_normal)
means = average[j, i] / counter[j, i]
con = ConnectionPatch(xyA=(x_lim_used_normal / 12, means[0]), xyB=(0, means[1]), coordsA="data", coordsB="data",
axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
axs[j, i * 2 + 1].add_artist(con)
elif j < n_weights + 1:
assert x_lim_used_bias > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_bias, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
axs[j, i * 2].set_xlim(0, x_lim_used_bias)
axs[j, i * 2 + 1].set_xlim(0, x_lim_used_bias)
means = average[j, i] / counter[j, i]
con = ConnectionPatch(xyA=(x_lim_used_bias / 12, means[0]), xyB=(0, means[1]), coordsA="data", coordsB="data",
axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
axs[j, i * 2 + 1].add_artist(con)
else:
assert x_lim_used_augment > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_augment, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
axs[j, i * 2].set_xlim(0, x_lim_used_augment)
axs[j, i * 2 + 1].set_xlim(0, x_lim_used_augment)
means = average[j, i] / counter[j, i]
con = ConnectionPatch(xyA=(x_lim_used_augment / 12, means[0]), xyB=(0, means[1]), coordsA="data", coordsB="data",
axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
axs[j, i * 2 + 1].add_artist(con)
if i == 0:
axs[j, i].set_ylabel(local_ylabels[j])
axs[j, i * 2 + 1].yaxis.set_ticklabels([])
if show_first_and_last:
if j < n_weights - 1:
axs[j, i * 2].plot([x_lim_used_normal / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red') # also plot weights of very first state average
if j == n_weights - 1:
mask = first_and_last_pmf[:, 0, -1] < 0
axs[j, i * 2].plot([x_lim_used_bias / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') # separete biases again
if j == n_weights:
mask = first_and_last_pmf[:, 0, -1] > 0
axs[j, i * 2].plot([x_lim_used_augment / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red')
else:
axs[j, i * 2].yaxis.set_ticklabels([])
axs[j, i * 2 + 1].yaxis.set_ticklabels([])
if i == n_types - 1 and show_first_and_last:
if j < n_weights - 1:
assert x_lim_used_full > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_full, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
axs[j, i * 2].set_xlim(0, x_lim_used_full)
axs[j, i * 2 + 1].set_xlim(0, x_lim_used_full)
means = average[j, i] / counter[j, i]
con = ConnectionPatch(xyA=(x_lim_used_full / 8, means[0]), xyB=(x_lim_used_full / 8, means[1]), coordsA="data", coordsB="data",
axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
axs[j, i * 2 + 1].add_artist(con)
axs[j, i * 2 + 1].plot([x_lim_used_normal / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red') # also plot weights of very last state average
if j == n_weights - 1:
mask = first_and_last_pmf[:, 1, -1] < 0
axs[j, i * 2 + 1].plot([x_lim_used_bias / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') # separete biases again
if j == n_weights:
mask = first_and_last_pmf[:, 1, -1] > 0
axs[j, i * 2 + 1].plot([x_lim_used_augment / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')
if j == 0:
axs[j, i * 2].set_title("Type {}".format(i + 1), loc='right')
if j == n_weights + show_weight_augmentations and i == 0:
axs[j, 0].set_xlabel("Initial distribution")
if show_deltas:
axs[j, 1].set_xlabel("Weight change")
else:
assert x_lim_used_half > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_half, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
axs[j, i * 2].set_xlim(0, x_lim_used_half)
axs[j, i * 2 + 1].set_xlim(0, x_lim_used_half)
means = average[j, i] / counter[j, i]
con = ConnectionPatch(xyA=(x_lim_used_half / 8, means[0]), xyB=(x_lim_used_half / 8, means[1]), coordsA="data", coordsB="data",
axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
axs[j, i * 2 + 1].add_artist(con)
axs[j, 1].set_xlabel("Changed distribution")
if i == 0:
axs[j, i].set_ylabel(ylabels[j])
axs[j, i * 2 + 1].yaxis.set_ticklabels([])
# if j < n_weights - 1:
# axs[j, i * 2].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red') # also plot weights of very first state average
# if j == n_weights - 1:
# mask = first_and_last_pmf[:, 0, -1] < 0
# axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') # separete biases again
# if j == n_weights:
# mask = first_and_last_pmf[:, 0, -1] > 0
# axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red')
else:
axs[j, i * 2].yaxis.set_ticklabels([])
axs[j, i * 2 + 1].yaxis.set_ticklabels([])
if i == 2:
pass
# if j < n_weights - 1:
# axs[j, i * 2 + 1].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red') # also plot weights of very last state average
# if j == n_weights - 1:
# mask = first_and_last_pmf[:, 1, -1] < 0
# axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') # separete biases again
# if j == n_weights:
# mask = first_and_last_pmf[:, 1, -1] > 0
# axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')
if j == 0:
axs[j, i * 2].set_title("Type {}".format(i + 1), loc='right')
if j == n_weights:
axs[j, i * 2].set_xlabel("Lifetime weight change", loc='right')
plt.savefig("./summary_figures/weight_changes/" + title)
plt.close()
plt.savefig("./summary_figures/weight_changes/weight changes sudden hists")
plt.show()
x_lim_used_full, x_lim_used_half = 120, 60
bin_sets = [np.linspace(-6.5, 6.5, 30), np.linspace(-1, 2, 30), np.linspace(-3.5, 3.5, 30)]
f, axs = plt.subplots(n_weights + 1, 3 * 2, figsize=(12, 9))
for i in range(3):
for j in range(n_weights + 1):
# axs[j, i].plot([0, 1], average[j, i] / counter[j, i], marker="o")
if j < 2:
bins = bin_sets[0]
elif j == 2:
bins = bin_sets[1]
else:
bins = bin_sets[2]
axs[j, i * 2].hist(all_datapoints[j][i][0], orientation='horizontal', bins=bins, color='grey', alpha=0.5)
axs[j, i * 2 + 1].hist(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]), orientation='horizontal', bins=bins, color='grey', alpha=0.5)
if True:
if j < 2:
axs[j, i * 2].set_ylim(-6.5, 6.5)
axs[j, i * 2 + 1].set_ylim(-6.5, 6.5)
elif j == 2:
axs[j, i * 2].set_ylim(-1, 2)
axs[j, i * 2 + 1].set_ylim(-1, 2)
else:
axs[j, i * 2].set_ylim(-3.5, 3.5)
axs[j, i * 2 + 1].set_ylim(-3.5, 3.5)
axs[j, i * 2].spines['top'].set_visible(False)
axs[j, i * 2].spines['right'].set_visible(False)
axs[j, i * 2].set_xticks([])
axs[j, i * 2 + 1].spines['top'].set_visible(False)
axs[j, i * 2 + 1].spines['right'].set_visible(False)
axs[j, i * 2 + 1].set_xticks([])
axs[j, i * 2].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][0])), xy=(0.65, 0.8), xycoords='axes fraction')
axs[j, i * 2 + 1].annotate("Var {:.2f}".format(np.var(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]))), xy=(0.65, 0.8), xycoords='axes fraction')
bin_sets = [np.linspace(-6.5, 6.5, 30), np.linspace(-1, 2, 30), np.linspace(-3.5, 3.5, 30), np.linspace(-0.2, 1, 30)] # TODO: make show_augmentation robuse
if j < n_weights - 1:
assert x_lim_used_full > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_full, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
axs[j, i * 2].set_xlim(0, x_lim_used_full)
axs[j, i * 2 + 1].set_xlim(0, x_lim_used_full)
means = average[j, i] / counter[j, i]
con = ConnectionPatch(xyA=(x_lim_used_full / 8, means[0]), xyB=(x_lim_used_full / 8, means[1]), coordsA="data", coordsB="data",
axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
axs[j, i * 2 + 1].add_artist(con)
else:
assert x_lim_used_half > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_half, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
axs[j, i * 2].set_xlim(0, x_lim_used_half)
axs[j, i * 2 + 1].set_xlim(0, x_lim_used_half)
means = average[j, i] / counter[j, i]
con = ConnectionPatch(xyA=(x_lim_used_half / 8, means[0]), xyB=(x_lim_used_half / 8, means[1]), coordsA="data", coordsB="data",
axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
axs[j, i * 2 + 1].add_artist(con)
average, counter, all_datapoints = plot_traces_and_collate_data(data=all_sudden_transition_changes, augmented_data=aug_all_sudden_transition_changes, title="sudden_weight_change at transitions")
if i == 0:
axs[j, i].set_ylabel(ylabels[j])
axs[j, i * 2 + 1].yaxis.set_ticklabels([])
# if j < n_weights - 1:
# axs[j, i * 2].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red') # also plot weights of very first state average
# if j == n_weights - 1:
# mask = first_and_last_pmf[:, 0, -1] < 0
# axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') # separete biases again
# if j == n_weights:
# mask = first_and_last_pmf[:, 0, -1] > 0
# axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red')
else:
axs[j, i * 2].yaxis.set_ticklabels([])
axs[j, i * 2 + 1].yaxis.set_ticklabels([])
if i == 2:
pass
# if j < n_weights - 1:
# axs[j, i * 2 + 1].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red') # also plot weights of very last state average
# if j == n_weights - 1:
# mask = first_and_last_pmf[:, 1, -1] < 0
# axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') # separete biases again
# if j == n_weights:
# mask = first_and_last_pmf[:, 1, -1] > 0
# axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')
if j == 0:
axs[j, i * 2].set_title("Type {}".format(i + 1), loc='right')
if j == n_weights:
axs[j, i * 2].set_xlabel("Lifetime weight change", loc='right')
plot_compact(average, counter, title="compact sudden_weight_change at transitions", show_weight_augmentations=show_weight_augmentations)
plt.savefig("./summary_figures/weight_changes/weight changes sudden delta hists")
plt.show()
plot_histogram_diffs(all_datapoints, x_lim_used_normal=22, x_lim_used_bias=11, x_lim_used_augment=40, bin_sets=bin_sets, title="weight changes sudden at transitions hists", show_deltas=False, show_weight_augmentations=show_weight_augmentations)
plot_histogram_diffs(all_datapoints, x_lim_used_normal=22, x_lim_used_bias=11, x_lim_used_augment=40, bin_sets=bin_sets, title="weight changes sudden at transitions delta hists", show_deltas=True, show_weight_augmentations=show_weight_augmentations)
# for weight_traj in all_weight_trajectories:
# if len(weight_traj) == 1:
# continue
# state_type = pmf_type(weights_to_pmf(weight_traj[0]))
average, counter, all_datapoints = plot_traces_and_collate_data(data=all_sudden_changes, augmented_data=aug_all_sudden_changes, title="sudden_weight_change")
# if state_type == 2 and weight_traj[0][0] > 0.8:
# plt.plot(weights_to_pmf(weight_traj[0]))
# plt.ylim(0, 1)
# plt.show()
plot_compact(average, counter, title="compact sudden_weight_change", show_weight_augmentations=show_weight_augmentations)
dur_lims = [(52, 25), (46, 23), (38, 19), (35, 17), (30, 15), (25, 12), (20, 10), (18, 9)]
plot_histogram_diffs(all_datapoints, x_lim_used_normal=56, x_lim_used_bias=28, x_lim_used_augment=140, bin_sets=bin_sets, title="weight changes sudden hists", show_deltas=False, show_weight_augmentations=show_weight_augmentations)
plot_histogram_diffs(all_datapoints, x_lim_used_normal=120, x_lim_used_bias=60, x_lim_used_augment=180, bin_sets=bin_sets, title="weight changes sudden delta hists", show_deltas=True, show_weight_augmentations=show_weight_augmentations)
dur_lims = [(52, 25), (46, 23), (38, 19), (35, 17), (30, 15), (25, 12), (20, 10), (18, 9)]
for min_dur_counter, min_dur in enumerate([2, 3, 4, 5, 7, 9, 11, 15]):
f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9)) # add another space to split bias into stating left versus starting right
average = np.zeros((n_weights + 1, 3, 2))
counter = np.zeros((n_weights + 1, 3))
all_datapoints = [[[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]]]
average = np.zeros((n_weights + 1, n_types, 2))
counter = np.zeros((n_weights + 1, n_types))
all_datapoints = create_nested_list([n_weights + 1, n_types, 2])
f, axs = plt.subplots(n_weights + 1, n_types, figsize=(12, 9)) # add another space to split bias into stating left versus starting right
for weight_traj in all_weight_trajectories:
if len(weight_traj) < min_dur or len(weight_traj) > 15: # take out too short trajectories, and too long ones
......@@ -388,61 +305,26 @@ if True:
all_datapoints[i][state_type][1].append(weight_traj[-1][i])
for i in range(3):
for j in range(n_weights + 1):
axs[j, i].set_ylim(-9, 9)
axs[j, i].spines['top'].set_visible(False)
axs[j, i].spines['right'].set_visible(False)
axs[j, i].set_xticks([])
if i == 0:
axs[j, i].set_ylabel(ylabels[j])
else:
axs[j, i].yaxis.set_ticklabels([])
if j == 0:
axs[j, i].set_title("Type {}".format(i + 1))
if j == n_weights:
axs[j, i].set_xlabel("Lifetime weight change")
plt.tight_layout()
plt.savefig("./summary_figures/weight_changes/all weight changes min dur {}".format(min_dur))
plt.close()
f, axs = plt.subplots(1, 3, figsize=(12, 6))
for i in range(3):
for j in range(n_weights + 1):
axs[i].plot([0, 1], average[j, i] / counter[j, i], marker="o", color=weight_colours[j], label=ylabels[j])
axs[i].set_ylim(-3.5, 3.5)
axs[i].spines['top'].set_visible(False)
axs[i].spines['right'].set_visible(False)
axs[i].set_xticks([])
axs[j, i].set_ylim(-9, 9)
axs[j, i].spines['top'].set_visible(False)
axs[j, i].spines['right'].set_visible(False)
axs[j, i].set_xticks([])
if i == 0:
axs[i].set_ylabel("Weights", size=24)
if j < n_weights - 1:
axs[i].plot([0.1], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', color=weight_colours[j]) # also plot weights of very first state average
if j == n_weights - 1:
mask = first_and_last_pmf[:, 0, -1] < 0
axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j]) # separete biases again
if j == n_weights:
mask = first_and_last_pmf[:, 0, -1] > 0
axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j])
axs[j, i].set_ylabel(ylabels[j])
else:
axs[i].yaxis.set_ticklabels([])
if i == 2:
if j < n_weights - 1:
axs[i].plot([0.9], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', color=weight_colours[j]) # also plot weights of very last state average
if j == n_weights - 1:
mask = first_and_last_pmf[:, 1, -1] < 0
axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j]) # separete biases again
if j == n_weights:
mask = first_and_last_pmf[:, 1, -1] > 0
axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j])
axs[j, i].yaxis.set_ticklabels([])
if j == 0:
axs[i].set_title("Type {}".format(i + 1), size=26)
if j == n_weights and i == 1:
axs[i].set_xlabel("Lifetime weight change", size=24)
axs[0].legend(frameon=False, fontsize=14)
axs[j, i].set_title("Type {}".format(i + 1))
if j == n_weights:
axs[j, i].set_xlabel("Lifetime weight change")
plt.tight_layout()
plt.savefig("./summary_figures/weight_changes/compact changes min dur {}".format(min_dur))
plt.savefig("./summary_figures/weight_changes/all weight changes min dur {}".format(min_dur))
plt.close()
plot_compact(average, counter, title="compact changes min dur {}".format(min_dur), show_first_and_last=True)
# means all split up
f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9))
for i in range(3):
for j in range(n_weights + 1):
......@@ -482,168 +364,14 @@ if True:
plt.savefig("./summary_figures/weight_changes/weight changes min dur {}".format(min_dur))
plt.close()
x_lim_used_full, x_lim_used_half = dur_lims[min_dur_counter]
# also histograms
f, axs = plt.subplots(n_weights + 1, 3 * 2, figsize=(12, 9))
for i in range(3):
for j in range(n_weights + 1):
if j < 2:
bins = bin_sets[0]
elif j == 2:
bins = bin_sets[1]
else:
bins = bin_sets[2]
axs[j, i * 2].hist(all_datapoints[j][i][0], orientation='horizontal', bins=bins, color='grey', alpha=0.5)
axs[j, i * 2 + 1].hist(all_datapoints[j][i][1], orientation='horizontal', bins=bins, color='grey', alpha=0.5)
if j < 2:
axs[j, i * 2].set_ylim(-6.5, 6.5)
axs[j, i * 2 + 1].set_ylim(-6.5, 6.5)
elif j == 2:
axs[j, i * 2].set_ylim(-1, 2)
axs[j, i * 2 + 1].set_ylim(-1, 2)
else:
axs[j, i * 2].set_ylim(-3.5, 3.5)
axs[j, i * 2 + 1].set_ylim(-3.5, 3.5)
axs[j, i * 2].spines['top'].set_visible(False)
axs[j, i * 2].spines['right'].set_visible(False)
axs[j, i * 2].set_xticks([])
axs[j, i * 2 + 1].spines['top'].set_visible(False)
axs[j, i * 2 + 1].spines['right'].set_visible(False)
axs[j, i * 2 + 1].set_xticks([])
axs[j, i * 2].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][0])), xy=(0.65, 0.8), xycoords='axes fraction')
axs[j, i * 2 + 1].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][1])), xy=(0.65, 0.8), xycoords='axes fraction')
if j < n_weights - 1:
assert x_lim_used_full > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_full, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
axs[j, i * 2].set_xlim(0, x_lim_used_full)
axs[j, i * 2 + 1].set_xlim(0, x_lim_used_full)
means = average[j, i] / counter[j, i]
con = ConnectionPatch(xyA=(x_lim_used_full / 8, means[0]), xyB=(x_lim_used_full / 8, means[1]), coordsA="data", coordsB="data",
axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
axs[j, i * 2 + 1].add_artist(con)
else:
assert x_lim_used_half > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_half, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
axs[j, i * 2].set_xlim(0, x_lim_used_half)
axs[j, i * 2 + 1].set_xlim(0, x_lim_used_half)
means = average[j, i] / counter[j, i]
con = ConnectionPatch(xyA=(x_lim_used_half / 8, means[0]), xyB=(x_lim_used_half / 8, means[1]), coordsA="data", coordsB="data",
axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
axs[j, i * 2 + 1].add_artist(con)
if i == 0:
axs[j, i].set_ylabel(ylabels[j])
axs[j, i * 2 + 1].yaxis.set_ticklabels([])
if j < n_weights - 1:
axs[j, i * 2].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red') # also plot weights of very first state average
if j == n_weights - 1:
mask = first_and_last_pmf[:, 0, -1] < 0
axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') # separete biases again
if j == n_weights:
mask = first_and_last_pmf[:, 0, -1] > 0
axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red')
else:
axs[j, i * 2].yaxis.set_ticklabels([])
axs[j, i * 2 + 1].yaxis.set_ticklabels([])
if i == 2:
if j < n_weights - 1:
axs[j, i * 2 + 1].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red') # also plot weights of very last state average
if j == n_weights - 1:
mask = first_and_last_pmf[:, 1, -1] < 0
axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') # separete biases again
if j == n_weights:
mask = first_and_last_pmf[:, 1, -1] > 0
axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')
if j == 0:
axs[j, i * 2].set_title("Type {}".format(i + 1), loc='right')
if j == n_weights:
axs[j, i * 2].set_xlabel("Lifetime weight change", loc='right')
plt.savefig("./summary_figures/weight_changes/weight changes min dur {} hists".format(min_dur))
plt.close()
# also histograms of deltas
x_lim_used_full, x_lim_used_half = x_lim_used_full * 2.5, x_lim_used_half * 2.5
f, axs = plt.subplots(n_weights + 1, 3 * 2, figsize=(12, 9))
for i in range(3):
for j in range(n_weights + 1):
if j < 2:
bins = bin_sets[0]
elif j == 2:
bins = bin_sets[1]
else:
bins = bin_sets[2]
axs[j, i * 2].hist(all_datapoints[j][i][0], orientation='horizontal', bins=bins, color='grey', alpha=0.5)
axs[j, i * 2 + 1].hist(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]), orientation='horizontal', bins=bins, color='grey', alpha=0.5)
x_lim_used_normal, x_lim_used_bias = dur_lims[min_dur_counter]
plot_histogram_diffs(all_datapoints=all_datapoints, x_lim_used_normal=x_lim_used_normal, x_lim_used_bias=x_lim_used_bias, bin_sets=bin_sets,
title="weight changes min dur {} hists".format(min_dur), show_deltas=False, show_first_and_last=True)
if j < 2:
axs[j, i * 2].set_ylim(-6.5, 6.5)
axs[j, i * 2 + 1].set_ylim(-6.5, 6.5)
elif j == 2:
axs[j, i * 2].set_ylim(-1, 2)
axs[j, i * 2 + 1].set_ylim(-1, 2)
else:
axs[j, i * 2].set_ylim(-3.5, 3.5)
axs[j, i * 2 + 1].set_ylim(-3.5, 3.5)
axs[j, i * 2].spines['top'].set_visible(False)
axs[j, i * 2].spines['right'].set_visible(False)
axs[j, i * 2].set_xticks([])
axs[j, i * 2 + 1].spines['top'].set_visible(False)
axs[j, i * 2 + 1].spines['right'].set_visible(False)
axs[j, i * 2 + 1].set_xticks([])
axs[j, i * 2].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][0])), xy=(0.65, 0.8), xycoords='axes fraction')
axs[j, i * 2 + 1].annotate("Var {:.2f}".format(np.var(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]))), xy=(0.65, 0.8), xycoords='axes fraction')
if j < n_weights - 1:
assert x_lim_used_full > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_full, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
axs[j, i * 2].set_xlim(0, x_lim_used_full)
axs[j, i * 2 + 1].set_xlim(0, x_lim_used_full)
means = average[j, i] / counter[j, i]
con = ConnectionPatch(xyA=(x_lim_used_full / 8, means[0]), xyB=(x_lim_used_full / 8, means[1]), coordsA="data", coordsB="data",
axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
axs[j, i * 2 + 1].add_artist(con)
else:
assert x_lim_used_half > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_half, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
axs[j, i * 2].set_xlim(0, x_lim_used_half)
axs[j, i * 2 + 1].set_xlim(0, x_lim_used_half)
means = average[j, i] / counter[j, i]
con = ConnectionPatch(xyA=(x_lim_used_half / 8, means[0]), xyB=(x_lim_used_half / 8, means[1]), coordsA="data", coordsB="data",
axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
axs[j, i * 2 + 1].add_artist(con)
if i == 0:
axs[j, i].set_ylabel(ylabels[j])
axs[j, i * 2 + 1].yaxis.set_ticklabels([])
if j < n_weights - 1:
axs[j, i * 2].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red') # also plot weights of very first state average
if j == n_weights - 1:
mask = first_and_last_pmf[:, 0, -1] < 0
axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') # separete biases again
if j == n_weights:
mask = first_and_last_pmf[:, 0, -1] > 0
axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red')
else:
axs[j, i * 2].yaxis.set_ticklabels([])
axs[j, i * 2 + 1].yaxis.set_ticklabels([])
if i == 2:
if j < n_weights - 1:
axs[j, i * 2 + 1].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red') # also plot weights of very last state average
if j == n_weights - 1:
mask = first_and_last_pmf[:, 1, -1] < 0
axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') # separete biases again
if j == n_weights:
mask = first_and_last_pmf[:, 1, -1] > 0
axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')
if j == 0:
axs[j, i * 2].set_title("Type {}".format(i + 1), loc='right')
if j == n_weights:
axs[j, i * 2].set_xlabel("Lifetime weight change", loc='right')
x_lim_used_normal, x_lim_used_bias = x_lim_used_normal * 2.5, x_lim_used_bias * 2.5
plot_histogram_diffs(all_datapoints=all_datapoints, x_lim_used_normal=x_lim_used_normal, x_lim_used_bias=x_lim_used_bias, bin_sets=bin_sets,
title="weight changes min dur {} delta hists".format(min_dur), show_deltas=True, show_first_and_last=True)
plt.savefig("./summary_figures/weight_changes/weight changes min dur {} delta hists".format(min_dur))
plt.close()
quit()
# all pmf weights
......