diff --git a/analysis_pmf_weights.py b/analysis_pmf_weights.py index cbcb0bdd52535fc80d81c1946b153c44ce783873..807e94f05084eb6da22b723f56f21e02366982a1 100644 --- a/analysis_pmf_weights.py +++ b/analysis_pmf_weights.py @@ -7,15 +7,47 @@ from mpl_toolkits import mplot3d from matplotlib.patches import ConnectionPatch +show_weight_augmentations = True + +all_weight_trajectories = pickle.load(open("multi_chain_saves/all_weight_trajectories.p", 'rb')) +first_and_last_pmf = np.array(pickle.load(open("multi_chain_saves/first_and_last_pmf.p", 'rb'))) +all_sudden_changes = pickle.load(open("multi_chain_saves/all_sudden_changes.p", 'rb')) +all_sudden_transition_changes = pickle.load(open("multi_chain_saves/all_sudden_transition_changes.p", 'rb')) + +if show_weight_augmentations: + aug_all_sudden_changes = pickle.load(open("multi_chain_saves/aug_all_sudden_changes.p", 'rb')) + aug_all_sudden_transition_changes = pickle.load(open("multi_chain_saves/aug_all_sudden_transition_changes.p", 'rb')) + performance_points = np.array([-1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0]) reduced_points = np.array([1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=bool) weight_colours = ['blue', 'red', 'green', 'goldenrod', 'darkorange'] -weight_colours_aug = ['blue', 'red', 'purple', 'green', 'goldenrod', 'darkorange'] +weight_colours_aug = ['blue', 'red', 'green', 'goldenrod', 'darkorange', 'purple'] ylabels = ["Cont left", "Cont right", "Persevere", "Bias left", "Bias right"] -ylabels_aug = ["Cont left", "Cont right", "Cont diff", "Persevere", "Bias left", "Bias right"] +ylabels_aug = ["Cont left", "Cont right", "Persevere", "Bias left", "Bias right", "PMF span"] folder = "./reward_analysis/" +local_ylabels = ylabels_aug if show_weight_augmentations else ylabels +local_weight_colours = weight_colours_aug if show_weight_augmentations else weight_colours + +n_weights = all_weight_trajectories[0][0].shape[0] +n_types = 3 + + +def create_nested_list(list_of_ns): + """ + Create a complex nested list, according to the specifications of list_of_ns. + E.g. [3, 2, 4] will return a list with 3 sublists, each containing 2 sublists again, each containing 4 sublists yet again. + """ + ret = [] + if len(list_of_ns) == 0: + return ret + + for _ in range(list_of_ns[0]): + ret.append(create_nested_list(list_of_ns=list_of_ns[1:])) + + return ret + def pmf_to_perf(pmf): # determine performance of a pmf, but only on the omnipresent strongest contrasts @@ -41,331 +73,216 @@ def pmf_type_rew(weights): return 2 -if True: - all_weight_trajectories = pickle.load(open("multi_chain_saves/all_weight_trajectories.p", 'rb')) - first_and_last_pmf = np.array(pickle.load(open("multi_chain_saves/first_and_last_pmf.p", 'rb'))) - all_sudden_changes = pickle.load(open("multi_chain_saves/all_sudden_changes.p", 'rb')) - - n_weights = all_weight_trajectories[0][0].shape[0] - - all_sudden_transition_changes = pickle.load(open("multi_chain_saves/all_sudden_transition_changes.p", 'rb')) +def plot_traces_and_collate_data(data, augmented_data=None, title=""): + """Plot all the individual weight change traces, ann collect the data we need for the other plots""" + show_weight_augmentations = augmented_data is not None + sudden_changes = len(data) == n_types - 1 + average = np.zeros((n_weights + 1 + show_weight_augmentations, n_types - sudden_changes, 2)) + counter = np.zeros((n_weights + 1 + show_weight_augmentations, n_types - sudden_changes)) + all_datapoints = create_nested_list([n_weights + 1 + show_weight_augmentations, n_types - sudden_changes, 2]) - average = np.zeros((n_weights + 1, 3, 2)) - counter = np.zeros((n_weights + 1, 3)) - all_datapoints = [[[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]]] - f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9)) # We add another weight slot to split the bias - for state_type, weight_gaps in enumerate(all_sudden_transition_changes): + f, axs = plt.subplots(n_weights + 1 + show_weight_augmentations, n_types - sudden_changes, figsize=(4 * (3 - sudden_changes), 9)) # We add another weight slot to split the bias + for state_type, weight_gaps in enumerate(data): for gap in weight_gaps: for i in range(n_weights): if i == n_weights - 1: - axs[i + (gap[1][i] > 0), state_type + 1].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o") # plot how the weight evolves from first to last appearance - average[i + (gap[1][i] > 0), state_type + 1] += np.array([gap[0][i], gap[-1][i]]) - counter[i + (gap[1][i] > 0), state_type + 1] += 1 - all_datapoints[i + (gap[1][i] > 0)][state_type + 1][0].append(gap[0][i]) - all_datapoints[i + (gap[1][i] > 0)][state_type + 1][1].append(gap[-1][i]) - else: - axs[i, state_type + 1].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o") # plot how the weight evolves from first to last appearance - average[i, state_type + 1] += np.array([gap[0][i], gap[-1][i]]) - counter[i, state_type + 1] += 1 - all_datapoints[i][state_type + 1][0].append(gap[0][i]) - all_datapoints[i][state_type + 1][1].append(gap[-1][i]) - - plt.tight_layout() - plt.savefig("./summary_figures/weight_changes/sudden_weight_change at transitions") - plt.show() - - - f, axs = plt.subplots(1, 3, figsize=(12, 6)) - for i in range(1, 3): - for j in range(n_weights + 1): - axs[i].plot([0, 1], average[j, i] / counter[j, i], marker="o", color=weight_colours[j], label=ylabels[j]) - axs[i].set_ylim(-3.5, 3.5) - axs[i].spines['top'].set_visible(False) - axs[i].spines['right'].set_visible(False) - axs[i].set_xticks([]) - if i == 0: - pass - # axs[i].set_ylabel("Weights", size=24) - # if j < n_weights - 1: - # axs[i].plot([0.1], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', color=weight_colours[j]) # also plot weights of very first state average - # if j == n_weights - 1: - # mask = first_and_last_pmf[:, 0, -1] < 0 - # axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j]) # separete biases again - # if j == n_weights: - # mask = first_and_last_pmf[:, 0, -1] > 0 - # axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j]) - else: - axs[i].yaxis.set_ticklabels([]) - if i == 2: - pass - # if j < n_weights - 1: - # axs[i].plot([0.9], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', color=weight_colours[j]) # also plot weights of very last state average - # if j == n_weights - 1: - # mask = first_and_last_pmf[:, 1, -1] < 0 - # axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j]) # separete biases again - # if j == n_weights: - # mask = first_and_last_pmf[:, 1, -1] > 0 - # axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j]) - if j == 0: - axs[i].set_title("Type {}".format(i + 1), size=26) - if j == n_weights and i == 1: - axs[i].set_xlabel("Lifetime weight change", size=24) - axs[0].legend(frameon=False, fontsize=14) - plt.tight_layout() - plt.savefig("./summary_figures/weight_changes/compact sudden_weight_changes at transitions") - plt.close() - - average = np.zeros((n_weights + 1, 3, 2)) - counter = np.zeros((n_weights + 1, 3)) - all_datapoints = [[[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]]] - f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9)) # I added a third pseudo weight for distance between the two sensitivities, and do another extra to split the bias - for state_type, weight_gaps in enumerate(all_sudden_changes): - - for gap in weight_gaps: - for i in range(n_weights): - if i == n_weights - 1: - axs[i + (gap[1][i] > 0), state_type].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o") # plot how the weight evolves from first to last appearance + axs[i + (gap[1][i] > 0), state_type].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o") # plot how the weight of the new state differs from the previous closest weight average[i + (gap[1][i] > 0), state_type] += np.array([gap[0][i], gap[-1][i]]) counter[i + (gap[1][i] > 0), state_type] += 1 all_datapoints[i + (gap[1][i] > 0)][state_type][0].append(gap[0][i]) all_datapoints[i + (gap[1][i] > 0)][state_type][1].append(gap[-1][i]) else: - axs[i, state_type].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o") # plot how the weight evolves from first to last appearance + axs[i, state_type].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o") # plot how the weight of the new state differs from the previous closest weight average[i, state_type] += np.array([gap[0][i], gap[-1][i]]) counter[i, state_type] += 1 all_datapoints[i][state_type][0].append(gap[0][i]) all_datapoints[i][state_type][1].append(gap[-1][i]) + if show_weight_augmentations: + for state_type, weight_gaps in enumerate(augmented_data): + for gap in weight_gaps: + axs[n_weights + 1, state_type].plot([0, 1], [gap[0], gap[-1]], marker="o") # plot how the weight of the new state differs from the previous closest weight + average[n_weights + 1, state_type] += np.array([gap[0], gap[-1]]) + counter[n_weights + 1, state_type] += 1 + all_datapoints[n_weights + 1][state_type][0].append(gap[0]) + all_datapoints[n_weights + 1][state_type][1].append(gap[-1]) + plt.tight_layout() - plt.savefig("./summary_figures/weight_changes/sudden_weight_change") + plt.savefig("./summary_figures/weight_changes/" + title + " augmented" * show_weight_augmentations) plt.close() - f, axs = plt.subplots(1, 3, figsize=(12, 6)) - for i in range(3): - for j in range(n_weights + 1): - axs[i].plot([0, 1], average[j, i] / counter[j, i], marker="o", color=weight_colours[j], label=ylabels[j]) + return average, counter, all_datapoints + + +def plot_compact(average, counter, title, show_first_and_last=False, show_weight_augmentations=False): + """Plot the means of the weight change traces, split by the different weight types, possibly with augmentations""" + sudden_changes = average.shape[1] == n_types - 1 + f, axs = plt.subplots(1, n_types - sudden_changes, figsize=(4 * (3 - sudden_changes), 6)) + for i in range(n_types - sudden_changes): + for j in range(n_weights + 1 + show_weight_augmentations): + axs[i].plot([0, 1], average[j, i] / counter[j, i], marker="o", color=local_weight_colours[j], label=local_ylabels[j]) axs[i].set_ylim(-3.5, 3.5) axs[i].spines['top'].set_visible(False) axs[i].spines['right'].set_visible(False) axs[i].set_xticks([]) - if i == 0: - pass - # axs[i].set_ylabel("Weights", size=24) - # if j < n_weights - 1: - # axs[i].plot([0.1], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', color=weight_colours[j]) # also plot weights of very first state average - # if j == n_weights - 1: - # mask = first_and_last_pmf[:, 0, -1] < 0 - # axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j]) # separete biases again - # if j == n_weights: - # mask = first_and_last_pmf[:, 0, -1] > 0 - # axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j]) + if i == 0 and show_first_and_last: + axs[i].set_ylabel("Weights", size=24) + if j < n_weights - 1: + axs[i].plot([0.1], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', color=local_weight_colours[j]) # also plot weights of very first state average + if j == n_weights - 1: + mask = first_and_last_pmf[:, 0, -1] < 0 + axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=local_weight_colours[j]) # separete biases again + if j == n_weights: + mask = first_and_last_pmf[:, 0, -1] > 0 + axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=local_weight_colours[j]) else: axs[i].yaxis.set_ticklabels([]) - if i == 2: - pass - # if j < n_weights - 1: - # axs[i].plot([0.9], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', color=weight_colours[j]) # also plot weights of very last state average - # if j == n_weights - 1: - # mask = first_and_last_pmf[:, 1, -1] < 0 - # axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j]) # separete biases again - # if j == n_weights: - # mask = first_and_last_pmf[:, 1, -1] > 0 - # axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j]) + if i == 2 and show_first_and_last: + if j < n_weights - 1: + axs[i].plot([0.9], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', color=local_weight_colours[j]) # also plot weights of very last state average + if j == n_weights - 1: + mask = first_and_last_pmf[:, 1, -1] < 0 + axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=local_weight_colours[j]) # separete biases again + if j == n_weights: + mask = first_and_last_pmf[:, 1, -1] > 0 + axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=local_weight_colours[j]) if j == 0: - axs[i].set_title("Type {}".format(i + 1), size=26) + axs[i].set_title("Type {}".format(i + 1 + sudden_changes), size=26) if j == n_weights and i == 1: axs[i].set_xlabel("Lifetime weight change", size=24) axs[0].legend(frameon=False, fontsize=14) plt.tight_layout() - plt.savefig("./summary_figures/weight_changes/compact sudden_weight_changes") + plt.savefig("./summary_figures/weight_changes/" + title + " augmented" * show_weight_augmentations) plt.close() - bin_sets = [np.linspace(-6.5, 6.5, 30), np.linspace(-1, 2, 30), np.linspace(-3.5, 3.5, 30)] - if True: - x_lim_used_full, x_lim_used_half = 56, 28 - f, axs = plt.subplots(n_weights + 1, 3 * 2, figsize=(12, 9)) - for i in range(3): - for j in range(n_weights + 1): - # axs[j, i].plot([0, 1], average[j, i] / counter[j, i], marker="o") - if j < 2: - bins = bin_sets[0] - elif j == 2: - bins = bin_sets[1] - else: - bins = bin_sets[2] - axs[j, i * 2].hist(all_datapoints[j][i][0], orientation='horizontal', bins=bins, color='grey', alpha=0.5) + +def plot_histogram_diffs(all_datapoints, x_lim_used_normal, x_lim_used_bias, bin_sets, title, x_lim_used_augment=0, show_deltas=True, show_first_and_last=False, show_weight_augmentations=False): + """Plot histograms over the weights, and the mean changes connecting them. + Might have to mess quite a bit with the y-axis""" + sudden_changes = len(all_datapoints[0]) == 2 + f, axs = plt.subplots(n_weights + 1 + show_weight_augmentations, (n_types - sudden_changes) * 2, figsize=(4 * (3 - sudden_changes), 9)) + for i in range(n_types - sudden_changes): + for j in range(n_weights + 1 + show_weight_augmentations): + if j < 2: + bins = bin_sets[0] + elif j == 2: + bins = bin_sets[1] + elif j in [3, 4]: + bins = bin_sets[2] + else: + bins = bin_sets[3] + axs[j, i * 2].hist(all_datapoints[j][i][0], orientation='horizontal', bins=bins, color='grey', alpha=0.5) + if show_deltas: + axs[j, i * 2 + 1].hist(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]), orientation='horizontal', bins=bins, color='red', alpha=0.5) + else: axs[j, i * 2 + 1].hist(all_datapoints[j][i][1], orientation='horizontal', bins=bins, color='grey', alpha=0.5) - if j < 2: - axs[j, i * 2].set_ylim(-6.5, 6.5) - axs[j, i * 2 + 1].set_ylim(-6.5, 6.5) - elif j == 2: - axs[j, i * 2].set_ylim(-1, 2) - axs[j, i * 2 + 1].set_ylim(-1, 2) - else: - axs[j, i * 2].set_ylim(-3.5, 3.5) - axs[j, i * 2 + 1].set_ylim(-3.5, 3.5) - axs[j, i * 2].spines['top'].set_visible(False) - axs[j, i * 2].spines['right'].set_visible(False) - axs[j, i * 2].set_xticks([]) - axs[j, i * 2 + 1].spines['top'].set_visible(False) - axs[j, i * 2 + 1].spines['right'].set_visible(False) - axs[j, i * 2 + 1].set_xticks([]) - - axs[j, i * 2].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][0])), xy=(0.65, 0.8), xycoords='axes fraction') + axs[j, i * 2].set_ylim(bins[0], bins[-1]) + axs[j, i * 2 + 1].set_ylim(bins[0], bins[-1]) + + axs[j, i * 2].spines['top'].set_visible(False) + axs[j, i * 2].spines['right'].set_visible(False) + axs[j, i * 2].set_xticks([]) + axs[j, i * 2 + 1].spines['top'].set_visible(False) + axs[j, i * 2 + 1].spines['right'].set_visible(False) + axs[j, i * 2 + 1].set_xticks([]) + + axs[j, i * 2].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][0])), xy=(0.65, 0.8), xycoords='axes fraction') + if show_deltas: + axs[j, i * 2 + 1].annotate("Var {:.2f}".format(np.var(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]))), xy=(0.65, 0.8), xycoords='axes fraction') + else: axs[j, i * 2 + 1].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][1])), xy=(0.65, 0.8), xycoords='axes fraction') + if j < n_weights - 1: + assert x_lim_used_normal > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_normal, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1])) + axs[j, i * 2].set_xlim(0, x_lim_used_normal) + axs[j, i * 2 + 1].set_xlim(0, x_lim_used_normal) + means = average[j, i] / counter[j, i] + con = ConnectionPatch(xyA=(x_lim_used_normal / 12, means[0]), xyB=(0, means[1]), coordsA="data", coordsB="data", + axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue") + axs[j, i * 2 + 1].add_artist(con) + elif j < n_weights + 1: + assert x_lim_used_bias > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_bias, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1])) + axs[j, i * 2].set_xlim(0, x_lim_used_bias) + axs[j, i * 2 + 1].set_xlim(0, x_lim_used_bias) + means = average[j, i] / counter[j, i] + con = ConnectionPatch(xyA=(x_lim_used_bias / 12, means[0]), xyB=(0, means[1]), coordsA="data", coordsB="data", + axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue") + axs[j, i * 2 + 1].add_artist(con) + else: + assert x_lim_used_augment > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_augment, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1])) + axs[j, i * 2].set_xlim(0, x_lim_used_augment) + axs[j, i * 2 + 1].set_xlim(0, x_lim_used_augment) + means = average[j, i] / counter[j, i] + con = ConnectionPatch(xyA=(x_lim_used_augment / 12, means[0]), xyB=(0, means[1]), coordsA="data", coordsB="data", + axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue") + axs[j, i * 2 + 1].add_artist(con) + + if i == 0: + axs[j, i].set_ylabel(local_ylabels[j]) + axs[j, i * 2 + 1].yaxis.set_ticklabels([]) + if show_first_and_last: + if j < n_weights - 1: + axs[j, i * 2].plot([x_lim_used_normal / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red') # also plot weights of very first state average + if j == n_weights - 1: + mask = first_and_last_pmf[:, 0, -1] < 0 + axs[j, i * 2].plot([x_lim_used_bias / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') # separete biases again + if j == n_weights: + mask = first_and_last_pmf[:, 0, -1] > 0 + axs[j, i * 2].plot([x_lim_used_augment / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') + else: + axs[j, i * 2].yaxis.set_ticklabels([]) + axs[j, i * 2 + 1].yaxis.set_ticklabels([]) + if i == n_types - 1 and show_first_and_last: if j < n_weights - 1: - assert x_lim_used_full > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_full, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1])) - axs[j, i * 2].set_xlim(0, x_lim_used_full) - axs[j, i * 2 + 1].set_xlim(0, x_lim_used_full) - means = average[j, i] / counter[j, i] - con = ConnectionPatch(xyA=(x_lim_used_full / 8, means[0]), xyB=(x_lim_used_full / 8, means[1]), coordsA="data", coordsB="data", - axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue") - axs[j, i * 2 + 1].add_artist(con) + axs[j, i * 2 + 1].plot([x_lim_used_normal / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red') # also plot weights of very last state average + if j == n_weights - 1: + mask = first_and_last_pmf[:, 1, -1] < 0 + axs[j, i * 2 + 1].plot([x_lim_used_bias / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') # separete biases again + if j == n_weights: + mask = first_and_last_pmf[:, 1, -1] > 0 + axs[j, i * 2 + 1].plot([x_lim_used_augment / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') + if j == 0: + axs[j, i * 2].set_title("Type {}".format(i + 1), loc='right') + if j == n_weights + show_weight_augmentations and i == 0: + axs[j, 0].set_xlabel("Initial distribution") + if show_deltas: + axs[j, 1].set_xlabel("Weight change") else: - assert x_lim_used_half > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_half, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1])) - axs[j, i * 2].set_xlim(0, x_lim_used_half) - axs[j, i * 2 + 1].set_xlim(0, x_lim_used_half) - means = average[j, i] / counter[j, i] - con = ConnectionPatch(xyA=(x_lim_used_half / 8, means[0]), xyB=(x_lim_used_half / 8, means[1]), coordsA="data", coordsB="data", - axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue") - axs[j, i * 2 + 1].add_artist(con) + axs[j, 1].set_xlabel("Changed distribution") - if i == 0: - axs[j, i].set_ylabel(ylabels[j]) - axs[j, i * 2 + 1].yaxis.set_ticklabels([]) - # if j < n_weights - 1: - # axs[j, i * 2].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red') # also plot weights of very first state average - # if j == n_weights - 1: - # mask = first_and_last_pmf[:, 0, -1] < 0 - # axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') # separete biases again - # if j == n_weights: - # mask = first_and_last_pmf[:, 0, -1] > 0 - # axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') - else: - axs[j, i * 2].yaxis.set_ticklabels([]) - axs[j, i * 2 + 1].yaxis.set_ticklabels([]) - if i == 2: - pass - # if j < n_weights - 1: - # axs[j, i * 2 + 1].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red') # also plot weights of very last state average - # if j == n_weights - 1: - # mask = first_and_last_pmf[:, 1, -1] < 0 - # axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') # separete biases again - # if j == n_weights: - # mask = first_and_last_pmf[:, 1, -1] > 0 - # axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') - if j == 0: - axs[j, i * 2].set_title("Type {}".format(i + 1), loc='right') - if j == n_weights: - axs[j, i * 2].set_xlabel("Lifetime weight change", loc='right') + plt.savefig("./summary_figures/weight_changes/" + title) + plt.close() - plt.savefig("./summary_figures/weight_changes/weight changes sudden hists") - plt.close() - x_lim_used_full, x_lim_used_half = 120, 60 - bin_sets = [np.linspace(-6.5, 6.5, 30), np.linspace(-1, 2, 30), np.linspace(-3.5, 3.5, 30)] - f, axs = plt.subplots(n_weights + 1, 3 * 2, figsize=(12, 9)) - for i in range(3): - for j in range(n_weights + 1): - # axs[j, i].plot([0, 1], average[j, i] / counter[j, i], marker="o") - if j < 2: - bins = bin_sets[0] - elif j == 2: - bins = bin_sets[1] - else: - bins = bin_sets[2] - axs[j, i * 2].hist(all_datapoints[j][i][0], orientation='horizontal', bins=bins, color='grey', alpha=0.5) - axs[j, i * 2 + 1].hist(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]), orientation='horizontal', bins=bins, color='grey', alpha=0.5) +if True: - if j < 2: - axs[j, i * 2].set_ylim(-6.5, 6.5) - axs[j, i * 2 + 1].set_ylim(-6.5, 6.5) - elif j == 2: - axs[j, i * 2].set_ylim(-1, 2) - axs[j, i * 2 + 1].set_ylim(-1, 2) - else: - axs[j, i * 2].set_ylim(-3.5, 3.5) - axs[j, i * 2 + 1].set_ylim(-3.5, 3.5) - axs[j, i * 2].spines['top'].set_visible(False) - axs[j, i * 2].spines['right'].set_visible(False) - axs[j, i * 2].set_xticks([]) - axs[j, i * 2 + 1].spines['top'].set_visible(False) - axs[j, i * 2 + 1].spines['right'].set_visible(False) - axs[j, i * 2 + 1].set_xticks([]) - - axs[j, i * 2].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][0])), xy=(0.65, 0.8), xycoords='axes fraction') - axs[j, i * 2 + 1].annotate("Var {:.2f}".format(np.var(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]))), xy=(0.65, 0.8), xycoords='axes fraction') + bin_sets = [np.linspace(-6.5, 6.5, 30), np.linspace(-1, 2, 30), np.linspace(-3.5, 3.5, 30), np.linspace(-0.2, 1, 30)] # TODO: make show_augmentation robuse - if j < n_weights - 1: - assert x_lim_used_full > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_full, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1])) - axs[j, i * 2].set_xlim(0, x_lim_used_full) - axs[j, i * 2 + 1].set_xlim(0, x_lim_used_full) - means = average[j, i] / counter[j, i] - con = ConnectionPatch(xyA=(x_lim_used_full / 8, means[0]), xyB=(x_lim_used_full / 8, means[1]), coordsA="data", coordsB="data", - axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue") - axs[j, i * 2 + 1].add_artist(con) - else: - assert x_lim_used_half > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_half, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1])) - axs[j, i * 2].set_xlim(0, x_lim_used_half) - axs[j, i * 2 + 1].set_xlim(0, x_lim_used_half) - means = average[j, i] / counter[j, i] - con = ConnectionPatch(xyA=(x_lim_used_half / 8, means[0]), xyB=(x_lim_used_half / 8, means[1]), coordsA="data", coordsB="data", - axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue") - axs[j, i * 2 + 1].add_artist(con) + average, counter, all_datapoints = plot_traces_and_collate_data(data=all_sudden_transition_changes, augmented_data=aug_all_sudden_transition_changes, title="sudden_weight_change at transitions") - if i == 0: - axs[j, i].set_ylabel(ylabels[j]) - axs[j, i * 2 + 1].yaxis.set_ticklabels([]) - # if j < n_weights - 1: - # axs[j, i * 2].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red') # also plot weights of very first state average - # if j == n_weights - 1: - # mask = first_and_last_pmf[:, 0, -1] < 0 - # axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') # separete biases again - # if j == n_weights: - # mask = first_and_last_pmf[:, 0, -1] > 0 - # axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') - else: - axs[j, i * 2].yaxis.set_ticklabels([]) - axs[j, i * 2 + 1].yaxis.set_ticklabels([]) - if i == 2: - pass - # if j < n_weights - 1: - # axs[j, i * 2 + 1].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red') # also plot weights of very last state average - # if j == n_weights - 1: - # mask = first_and_last_pmf[:, 1, -1] < 0 - # axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') # separete biases again - # if j == n_weights: - # mask = first_and_last_pmf[:, 1, -1] > 0 - # axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') - if j == 0: - axs[j, i * 2].set_title("Type {}".format(i + 1), loc='right') - if j == n_weights: - axs[j, i * 2].set_xlabel("Lifetime weight change", loc='right') + plot_compact(average, counter, title="compact sudden_weight_change at transitions", show_weight_augmentations=show_weight_augmentations) - plt.savefig("./summary_figures/weight_changes/weight changes sudden delta hists") - plt.close() + plot_histogram_diffs(all_datapoints, x_lim_used_normal=22, x_lim_used_bias=11, x_lim_used_augment=40, bin_sets=bin_sets, title="weight changes sudden at transitions hists", show_deltas=False, show_weight_augmentations=show_weight_augmentations) + plot_histogram_diffs(all_datapoints, x_lim_used_normal=22, x_lim_used_bias=11, x_lim_used_augment=40, bin_sets=bin_sets, title="weight changes sudden at transitions delta hists", show_deltas=True, show_weight_augmentations=show_weight_augmentations) - # for weight_traj in all_weight_trajectories: - # if len(weight_traj) == 1: - # continue - # state_type = pmf_type(weights_to_pmf(weight_traj[0])) + average, counter, all_datapoints = plot_traces_and_collate_data(data=all_sudden_changes, augmented_data=aug_all_sudden_changes, title="sudden_weight_change") - # if state_type == 2 and weight_traj[0][0] > 0.8: - # plt.plot(weights_to_pmf(weight_traj[0])) - # plt.ylim(0, 1) - # plt.show() + plot_compact(average, counter, title="compact sudden_weight_change", show_weight_augmentations=show_weight_augmentations) - dur_lims = [(52, 25), (46, 23), (38, 19), (35, 17), (30, 15), (25, 12), (20, 10), (18, 9)] + plot_histogram_diffs(all_datapoints, x_lim_used_normal=56, x_lim_used_bias=28, x_lim_used_augment=140, bin_sets=bin_sets, title="weight changes sudden hists", show_deltas=False, show_weight_augmentations=show_weight_augmentations) + plot_histogram_diffs(all_datapoints, x_lim_used_normal=120, x_lim_used_bias=60, x_lim_used_augment=180, bin_sets=bin_sets, title="weight changes sudden delta hists", show_deltas=True, show_weight_augmentations=show_weight_augmentations) + + dur_lims = [(52, 25), (46, 23), (38, 19), (35, 17), (30, 15), (25, 12), (20, 10), (18, 9)] for min_dur_counter, min_dur in enumerate([2, 3, 4, 5, 7, 9, 11, 15]): - f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9)) # add another space to split bias into stating left versus starting right - average = np.zeros((n_weights + 1, 3, 2)) - counter = np.zeros((n_weights + 1, 3)) - all_datapoints = [[[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]]] + average = np.zeros((n_weights + 1, n_types, 2)) + counter = np.zeros((n_weights + 1, n_types)) + all_datapoints = create_nested_list([n_weights + 1, n_types, 2]) + + f, axs = plt.subplots(n_weights + 1, n_types, figsize=(12, 9)) # add another space to split bias into stating left versus starting right for weight_traj in all_weight_trajectories: if len(weight_traj) < min_dur or len(weight_traj) > 15: # take out too short trajectories, and too long ones @@ -388,61 +305,26 @@ if True: all_datapoints[i][state_type][1].append(weight_traj[-1][i]) for i in range(3): for j in range(n_weights + 1): - axs[j, i].set_ylim(-9, 9) - axs[j, i].spines['top'].set_visible(False) - axs[j, i].spines['right'].set_visible(False) - axs[j, i].set_xticks([]) - if i == 0: - axs[j, i].set_ylabel(ylabels[j]) - else: - axs[j, i].yaxis.set_ticklabels([]) - if j == 0: - axs[j, i].set_title("Type {}".format(i + 1)) - if j == n_weights: - axs[j, i].set_xlabel("Lifetime weight change") - - plt.tight_layout() - plt.savefig("./summary_figures/weight_changes/all weight changes min dur {}".format(min_dur)) - plt.close() - - f, axs = plt.subplots(1, 3, figsize=(12, 6)) - for i in range(3): - for j in range(n_weights + 1): - axs[i].plot([0, 1], average[j, i] / counter[j, i], marker="o", color=weight_colours[j], label=ylabels[j]) - axs[i].set_ylim(-3.5, 3.5) - axs[i].spines['top'].set_visible(False) - axs[i].spines['right'].set_visible(False) - axs[i].set_xticks([]) + axs[j, i].set_ylim(-9, 9) + axs[j, i].spines['top'].set_visible(False) + axs[j, i].spines['right'].set_visible(False) + axs[j, i].set_xticks([]) if i == 0: - axs[i].set_ylabel("Weights", size=24) - if j < n_weights - 1: - axs[i].plot([0.1], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', color=weight_colours[j]) # also plot weights of very first state average - if j == n_weights - 1: - mask = first_and_last_pmf[:, 0, -1] < 0 - axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j]) # separete biases again - if j == n_weights: - mask = first_and_last_pmf[:, 0, -1] > 0 - axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j]) + axs[j, i].set_ylabel(ylabels[j]) else: - axs[i].yaxis.set_ticklabels([]) - if i == 2: - if j < n_weights - 1: - axs[i].plot([0.9], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', color=weight_colours[j]) # also plot weights of very last state average - if j == n_weights - 1: - mask = first_and_last_pmf[:, 1, -1] < 0 - axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j]) # separete biases again - if j == n_weights: - mask = first_and_last_pmf[:, 1, -1] > 0 - axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j]) + axs[j, i].yaxis.set_ticklabels([]) if j == 0: - axs[i].set_title("Type {}".format(i + 1), size=26) - if j == n_weights and i == 1: - axs[i].set_xlabel("Lifetime weight change", size=24) - axs[0].legend(frameon=False, fontsize=14) + axs[j, i].set_title("Type {}".format(i + 1)) + if j == n_weights: + axs[j, i].set_xlabel("Lifetime weight change") + plt.tight_layout() - plt.savefig("./summary_figures/weight_changes/compact changes min dur {}".format(min_dur)) + plt.savefig("./summary_figures/weight_changes/all weight changes min dur {}".format(min_dur)) plt.close() + plot_compact(average, counter, title="compact changes min dur {}".format(min_dur), show_first_and_last=True) + + # means all split up f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9)) for i in range(3): for j in range(n_weights + 1): @@ -482,168 +364,14 @@ if True: plt.savefig("./summary_figures/weight_changes/weight changes min dur {}".format(min_dur)) plt.close() - x_lim_used_full, x_lim_used_half = dur_lims[min_dur_counter] - - # also histograms - f, axs = plt.subplots(n_weights + 1, 3 * 2, figsize=(12, 9)) - for i in range(3): - for j in range(n_weights + 1): - if j < 2: - bins = bin_sets[0] - elif j == 2: - bins = bin_sets[1] - else: - bins = bin_sets[2] - axs[j, i * 2].hist(all_datapoints[j][i][0], orientation='horizontal', bins=bins, color='grey', alpha=0.5) - axs[j, i * 2 + 1].hist(all_datapoints[j][i][1], orientation='horizontal', bins=bins, color='grey', alpha=0.5) - - if j < 2: - axs[j, i * 2].set_ylim(-6.5, 6.5) - axs[j, i * 2 + 1].set_ylim(-6.5, 6.5) - elif j == 2: - axs[j, i * 2].set_ylim(-1, 2) - axs[j, i * 2 + 1].set_ylim(-1, 2) - else: - axs[j, i * 2].set_ylim(-3.5, 3.5) - axs[j, i * 2 + 1].set_ylim(-3.5, 3.5) - axs[j, i * 2].spines['top'].set_visible(False) - axs[j, i * 2].spines['right'].set_visible(False) - axs[j, i * 2].set_xticks([]) - axs[j, i * 2 + 1].spines['top'].set_visible(False) - axs[j, i * 2 + 1].spines['right'].set_visible(False) - axs[j, i * 2 + 1].set_xticks([]) - - axs[j, i * 2].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][0])), xy=(0.65, 0.8), xycoords='axes fraction') - axs[j, i * 2 + 1].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][1])), xy=(0.65, 0.8), xycoords='axes fraction') - - if j < n_weights - 1: - assert x_lim_used_full > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_full, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1])) - axs[j, i * 2].set_xlim(0, x_lim_used_full) - axs[j, i * 2 + 1].set_xlim(0, x_lim_used_full) - means = average[j, i] / counter[j, i] - con = ConnectionPatch(xyA=(x_lim_used_full / 8, means[0]), xyB=(x_lim_used_full / 8, means[1]), coordsA="data", coordsB="data", - axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue") - axs[j, i * 2 + 1].add_artist(con) - else: - assert x_lim_used_half > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_half, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1])) - axs[j, i * 2].set_xlim(0, x_lim_used_half) - axs[j, i * 2 + 1].set_xlim(0, x_lim_used_half) - means = average[j, i] / counter[j, i] - con = ConnectionPatch(xyA=(x_lim_used_half / 8, means[0]), xyB=(x_lim_used_half / 8, means[1]), coordsA="data", coordsB="data", - axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue") - axs[j, i * 2 + 1].add_artist(con) - - if i == 0: - axs[j, i].set_ylabel(ylabels[j]) - axs[j, i * 2 + 1].yaxis.set_ticklabels([]) - if j < n_weights - 1: - axs[j, i * 2].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red') # also plot weights of very first state average - if j == n_weights - 1: - mask = first_and_last_pmf[:, 0, -1] < 0 - axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') # separete biases again - if j == n_weights: - mask = first_and_last_pmf[:, 0, -1] > 0 - axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') - else: - axs[j, i * 2].yaxis.set_ticklabels([]) - axs[j, i * 2 + 1].yaxis.set_ticklabels([]) - if i == 2: - if j < n_weights - 1: - axs[j, i * 2 + 1].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red') # also plot weights of very last state average - if j == n_weights - 1: - mask = first_and_last_pmf[:, 1, -1] < 0 - axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') # separete biases again - if j == n_weights: - mask = first_and_last_pmf[:, 1, -1] > 0 - axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') - if j == 0: - axs[j, i * 2].set_title("Type {}".format(i + 1), loc='right') - if j == n_weights: - axs[j, i * 2].set_xlabel("Lifetime weight change", loc='right') - - plt.savefig("./summary_figures/weight_changes/weight changes min dur {} hists".format(min_dur)) - plt.close() - - # also histograms of deltas - x_lim_used_full, x_lim_used_half = x_lim_used_full * 2.5, x_lim_used_half * 2.5 - f, axs = plt.subplots(n_weights + 1, 3 * 2, figsize=(12, 9)) - for i in range(3): - for j in range(n_weights + 1): - if j < 2: - bins = bin_sets[0] - elif j == 2: - bins = bin_sets[1] - else: - bins = bin_sets[2] - axs[j, i * 2].hist(all_datapoints[j][i][0], orientation='horizontal', bins=bins, color='grey', alpha=0.5) - axs[j, i * 2 + 1].hist(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]), orientation='horizontal', bins=bins, color='grey', alpha=0.5) - - if j < 2: - axs[j, i * 2].set_ylim(-6.5, 6.5) - axs[j, i * 2 + 1].set_ylim(-6.5, 6.5) - elif j == 2: - axs[j, i * 2].set_ylim(-1, 2) - axs[j, i * 2 + 1].set_ylim(-1, 2) - else: - axs[j, i * 2].set_ylim(-3.5, 3.5) - axs[j, i * 2 + 1].set_ylim(-3.5, 3.5) - axs[j, i * 2].spines['top'].set_visible(False) - axs[j, i * 2].spines['right'].set_visible(False) - axs[j, i * 2].set_xticks([]) - axs[j, i * 2 + 1].spines['top'].set_visible(False) - axs[j, i * 2 + 1].spines['right'].set_visible(False) - axs[j, i * 2 + 1].set_xticks([]) - - axs[j, i * 2].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][0])), xy=(0.65, 0.8), xycoords='axes fraction') - axs[j, i * 2 + 1].annotate("Var {:.2f}".format(np.var(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]))), xy=(0.65, 0.8), xycoords='axes fraction') + x_lim_used_normal, x_lim_used_bias = dur_lims[min_dur_counter] + plot_histogram_diffs(all_datapoints=all_datapoints, x_lim_used_normal=x_lim_used_normal, x_lim_used_bias=x_lim_used_bias, bin_sets=bin_sets, + title="weight changes min dur {} hists".format(min_dur), show_deltas=False, show_first_and_last=True) - if j < n_weights - 1: - assert x_lim_used_full > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_full, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1])) - axs[j, i * 2].set_xlim(0, x_lim_used_full) - axs[j, i * 2 + 1].set_xlim(0, x_lim_used_full) - means = average[j, i] / counter[j, i] - con = ConnectionPatch(xyA=(x_lim_used_full / 8, means[0]), xyB=(x_lim_used_full / 8, means[1]), coordsA="data", coordsB="data", - axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue") - axs[j, i * 2 + 1].add_artist(con) - else: - assert x_lim_used_half > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_half, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1])) - axs[j, i * 2].set_xlim(0, x_lim_used_half) - axs[j, i * 2 + 1].set_xlim(0, x_lim_used_half) - means = average[j, i] / counter[j, i] - con = ConnectionPatch(xyA=(x_lim_used_half / 8, means[0]), xyB=(x_lim_used_half / 8, means[1]), coordsA="data", coordsB="data", - axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue") - axs[j, i * 2 + 1].add_artist(con) + x_lim_used_normal, x_lim_used_bias = x_lim_used_normal * 2.5, x_lim_used_bias * 2.5 + plot_histogram_diffs(all_datapoints=all_datapoints, x_lim_used_normal=x_lim_used_normal, x_lim_used_bias=x_lim_used_bias, bin_sets=bin_sets, + title="weight changes min dur {} delta hists".format(min_dur), show_deltas=True, show_first_and_last=True) - if i == 0: - axs[j, i].set_ylabel(ylabels[j]) - axs[j, i * 2 + 1].yaxis.set_ticklabels([]) - if j < n_weights - 1: - axs[j, i * 2].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red') # also plot weights of very first state average - if j == n_weights - 1: - mask = first_and_last_pmf[:, 0, -1] < 0 - axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') # separete biases again - if j == n_weights: - mask = first_and_last_pmf[:, 0, -1] > 0 - axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') - else: - axs[j, i * 2].yaxis.set_ticklabels([]) - axs[j, i * 2 + 1].yaxis.set_ticklabels([]) - if i == 2: - if j < n_weights - 1: - axs[j, i * 2 + 1].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red') # also plot weights of very last state average - if j == n_weights - 1: - mask = first_and_last_pmf[:, 1, -1] < 0 - axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') # separete biases again - if j == n_weights: - mask = first_and_last_pmf[:, 1, -1] > 0 - axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') - if j == 0: - axs[j, i * 2].set_title("Type {}".format(i + 1), loc='right') - if j == n_weights: - axs[j, i * 2].set_xlabel("Lifetime weight change", loc='right') - - plt.savefig("./summary_figures/weight_changes/weight changes min dur {} delta hists".format(min_dur)) - plt.close() quit() # all pmf weights