diff --git a/analysis_pmf_weights.py b/analysis_pmf_weights.py
index cbcb0bdd52535fc80d81c1946b153c44ce783873..807e94f05084eb6da22b723f56f21e02366982a1 100644
--- a/analysis_pmf_weights.py
+++ b/analysis_pmf_weights.py
@@ -7,15 +7,47 @@ from mpl_toolkits import mplot3d
 from matplotlib.patches import ConnectionPatch
 
 
+show_weight_augmentations = True
+
+all_weight_trajectories = pickle.load(open("multi_chain_saves/all_weight_trajectories.p", 'rb'))
+first_and_last_pmf = np.array(pickle.load(open("multi_chain_saves/first_and_last_pmf.p", 'rb')))
+all_sudden_changes = pickle.load(open("multi_chain_saves/all_sudden_changes.p", 'rb'))
+all_sudden_transition_changes = pickle.load(open("multi_chain_saves/all_sudden_transition_changes.p", 'rb'))
+
+if show_weight_augmentations:
+    aug_all_sudden_changes = pickle.load(open("multi_chain_saves/aug_all_sudden_changes.p", 'rb'))
+    aug_all_sudden_transition_changes = pickle.load(open("multi_chain_saves/aug_all_sudden_transition_changes.p", 'rb'))
+
 performance_points = np.array([-1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0])
 reduced_points = np.array([1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=bool)
 weight_colours = ['blue', 'red', 'green', 'goldenrod', 'darkorange']
-weight_colours_aug = ['blue', 'red', 'purple', 'green', 'goldenrod', 'darkorange']
+weight_colours_aug = ['blue', 'red', 'green', 'goldenrod', 'darkorange', 'purple']
 ylabels = ["Cont left", "Cont right", "Persevere", "Bias left", "Bias right"]
-ylabels_aug = ["Cont left", "Cont right", "Cont diff", "Persevere", "Bias left", "Bias right"]
+ylabels_aug = ["Cont left", "Cont right", "Persevere", "Bias left", "Bias right", "PMF span"]
 
 folder = "./reward_analysis/"
 
+local_ylabels = ylabels_aug if show_weight_augmentations else ylabels
+local_weight_colours = weight_colours_aug if show_weight_augmentations else weight_colours
+
+n_weights = all_weight_trajectories[0][0].shape[0]
+n_types = 3
+
+
+def create_nested_list(list_of_ns):
+    """
+    Create a complex nested list, according to the specifications of list_of_ns.
+    E.g. [3, 2, 4] will return a list with 3 sublists, each containing 2 sublists again, each containing 4 sublists yet again.
+    """
+    ret = []
+    if len(list_of_ns) == 0:
+        return ret
+
+    for _ in range(list_of_ns[0]):
+        ret.append(create_nested_list(list_of_ns=list_of_ns[1:]))
+
+    return ret
+
 
 def pmf_to_perf(pmf):
     # determine performance of a pmf, but only on the omnipresent strongest contrasts
@@ -41,331 +73,216 @@ def pmf_type_rew(weights):
         return 2
 
 
-if True:
-    all_weight_trajectories = pickle.load(open("multi_chain_saves/all_weight_trajectories.p", 'rb'))
-    first_and_last_pmf = np.array(pickle.load(open("multi_chain_saves/first_and_last_pmf.p", 'rb')))
-    all_sudden_changes = pickle.load(open("multi_chain_saves/all_sudden_changes.p", 'rb'))
-
-    n_weights = all_weight_trajectories[0][0].shape[0]
-
-    all_sudden_transition_changes = pickle.load(open("multi_chain_saves/all_sudden_transition_changes.p", 'rb'))
+def plot_traces_and_collate_data(data, augmented_data=None, title=""):
+    """Plot all the individual weight change traces, ann collect the data we need for the other plots"""
+    show_weight_augmentations = augmented_data is not None
+    sudden_changes = len(data) == n_types - 1
+    average = np.zeros((n_weights + 1 + show_weight_augmentations, n_types - sudden_changes, 2))
+    counter = np.zeros((n_weights + 1 + show_weight_augmentations, n_types - sudden_changes))
+    all_datapoints = create_nested_list([n_weights + 1 + show_weight_augmentations, n_types - sudden_changes, 2])
 
-    average = np.zeros((n_weights + 1, 3, 2))
-    counter = np.zeros((n_weights + 1, 3))
-    all_datapoints = [[[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]]]
-    f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9))  # We add another weight slot to split the bias
-    for state_type, weight_gaps in enumerate(all_sudden_transition_changes):
+    f, axs = plt.subplots(n_weights + 1 + show_weight_augmentations, n_types - sudden_changes, figsize=(4 * (3 - sudden_changes), 9))  # We add another weight slot to split the bias
+    for state_type, weight_gaps in enumerate(data):
 
         for gap in weight_gaps:
             for i in range(n_weights):
                 if i == n_weights - 1:
-                    axs[i + (gap[1][i] > 0), state_type + 1].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o") # plot how the weight evolves from first to last appearance
-                    average[i + (gap[1][i] > 0), state_type + 1] += np.array([gap[0][i], gap[-1][i]])
-                    counter[i + (gap[1][i] > 0), state_type + 1] += 1
-                    all_datapoints[i + (gap[1][i] > 0)][state_type + 1][0].append(gap[0][i])
-                    all_datapoints[i + (gap[1][i] > 0)][state_type + 1][1].append(gap[-1][i])
-                else:
-                    axs[i, state_type + 1].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o") # plot how the weight evolves from first to last appearance
-                    average[i, state_type + 1] += np.array([gap[0][i], gap[-1][i]])
-                    counter[i, state_type + 1] += 1
-                    all_datapoints[i][state_type + 1][0].append(gap[0][i])
-                    all_datapoints[i][state_type + 1][1].append(gap[-1][i])
-
-    plt.tight_layout()
-    plt.savefig("./summary_figures/weight_changes/sudden_weight_change at transitions")
-    plt.show()
-
-
-    f, axs = plt.subplots(1, 3, figsize=(12, 6))
-    for i in range(1, 3):
-        for j in range(n_weights + 1):
-            axs[i].plot([0, 1], average[j, i] / counter[j, i], marker="o", color=weight_colours[j], label=ylabels[j])
-            axs[i].set_ylim(-3.5, 3.5)
-            axs[i].spines['top'].set_visible(False)
-            axs[i].spines['right'].set_visible(False)
-            axs[i].set_xticks([])
-            if i == 0:
-                pass
-            #     axs[i].set_ylabel("Weights", size=24)
-            #     if j < n_weights - 1:
-            #         axs[i].plot([0.1], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', color=weight_colours[j])  # also plot weights of very first state average
-            #     if j == n_weights - 1:
-            #         mask = first_and_last_pmf[:, 0, -1] < 0
-            #         axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j])  # separete biases again
-            #     if j == n_weights:
-            #         mask = first_and_last_pmf[:, 0, -1] > 0
-            #         axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j])
-            else:
-                axs[i].yaxis.set_ticklabels([])
-            if i == 2:
-                pass
-            #     if j < n_weights - 1:
-            #         axs[i].plot([0.9], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', color=weight_colours[j])  # also plot weights of very last state average
-            #     if j == n_weights - 1:
-            #         mask = first_and_last_pmf[:, 1, -1] < 0
-            #         axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j])  # separete biases again
-            #     if j == n_weights:
-            #         mask = first_and_last_pmf[:, 1, -1] > 0
-            #         axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j])
-            if j == 0:
-                axs[i].set_title("Type {}".format(i + 1), size=26)
-            if j == n_weights and i == 1:
-                axs[i].set_xlabel("Lifetime weight change", size=24)
-    axs[0].legend(frameon=False, fontsize=14)
-    plt.tight_layout()
-    plt.savefig("./summary_figures/weight_changes/compact sudden_weight_changes at transitions")
-    plt.close()
-
-    average = np.zeros((n_weights + 1, 3, 2))
-    counter = np.zeros((n_weights + 1, 3))
-    all_datapoints = [[[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]]]
-    f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9))  # I added a third pseudo weight for distance between the two sensitivities, and do another extra to split the bias
-    for state_type, weight_gaps in enumerate(all_sudden_changes):
-
-        for gap in weight_gaps:
-            for i in range(n_weights):
-                if i == n_weights - 1:
-                    axs[i + (gap[1][i] > 0), state_type].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o") # plot how the weight evolves from first to last appearance
+                    axs[i + (gap[1][i] > 0), state_type].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o")  # plot how the weight of the new state differs from the previous closest weight
                     average[i + (gap[1][i] > 0), state_type] += np.array([gap[0][i], gap[-1][i]])
                     counter[i + (gap[1][i] > 0), state_type] += 1
                     all_datapoints[i + (gap[1][i] > 0)][state_type][0].append(gap[0][i])
                     all_datapoints[i + (gap[1][i] > 0)][state_type][1].append(gap[-1][i])
                 else:
-                    axs[i, state_type].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o") # plot how the weight evolves from first to last appearance
+                    axs[i, state_type].plot([0, 1], [gap[0][i], gap[-1][i]], marker="o")  # plot how the weight of the new state differs from the previous closest weight
                     average[i, state_type] += np.array([gap[0][i], gap[-1][i]])
                     counter[i, state_type] += 1
                     all_datapoints[i][state_type][0].append(gap[0][i])
                     all_datapoints[i][state_type][1].append(gap[-1][i])
 
+        if show_weight_augmentations:
+            for state_type, weight_gaps in enumerate(augmented_data):
+                for gap in weight_gaps:
+                    axs[n_weights + 1, state_type].plot([0, 1], [gap[0], gap[-1]], marker="o")  # plot how the weight of the new state differs from the previous closest weight
+                    average[n_weights + 1, state_type] += np.array([gap[0], gap[-1]])
+                    counter[n_weights + 1, state_type] += 1
+                    all_datapoints[n_weights + 1][state_type][0].append(gap[0])
+                    all_datapoints[n_weights + 1][state_type][1].append(gap[-1])
+
     plt.tight_layout()
-    plt.savefig("./summary_figures/weight_changes/sudden_weight_change")
+    plt.savefig("./summary_figures/weight_changes/" + title + " augmented" * show_weight_augmentations)
     plt.close()
 
-    f, axs = plt.subplots(1, 3, figsize=(12, 6))
-    for i in range(3):
-        for j in range(n_weights + 1):
-            axs[i].plot([0, 1], average[j, i] / counter[j, i], marker="o", color=weight_colours[j], label=ylabels[j])
+    return average, counter, all_datapoints
+
+
+def plot_compact(average, counter, title, show_first_and_last=False, show_weight_augmentations=False):
+    """Plot the means of the weight change traces, split by the different weight types, possibly with augmentations"""
+    sudden_changes = average.shape[1] == n_types - 1
+    f, axs = plt.subplots(1, n_types - sudden_changes, figsize=(4 * (3 - sudden_changes), 6))
+    for i in range(n_types - sudden_changes):
+        for j in range(n_weights + 1 + show_weight_augmentations):
+            axs[i].plot([0, 1], average[j, i] / counter[j, i], marker="o", color=local_weight_colours[j], label=local_ylabels[j])
             axs[i].set_ylim(-3.5, 3.5)
             axs[i].spines['top'].set_visible(False)
             axs[i].spines['right'].set_visible(False)
             axs[i].set_xticks([])
-            if i == 0:
-                pass
-            #     axs[i].set_ylabel("Weights", size=24)
-            #     if j < n_weights - 1:
-            #         axs[i].plot([0.1], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', color=weight_colours[j])  # also plot weights of very first state average
-            #     if j == n_weights - 1:
-            #         mask = first_and_last_pmf[:, 0, -1] < 0
-            #         axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j])  # separete biases again
-            #     if j == n_weights:
-            #         mask = first_and_last_pmf[:, 0, -1] > 0
-            #         axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j])
+            if i == 0 and show_first_and_last:
+                axs[i].set_ylabel("Weights", size=24)
+                if j < n_weights - 1:
+                    axs[i].plot([0.1], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', color=local_weight_colours[j])  # also plot weights of very first state average
+                if j == n_weights - 1:
+                    mask = first_and_last_pmf[:, 0, -1] < 0
+                    axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=local_weight_colours[j])  # separete biases again
+                if j == n_weights:
+                    mask = first_and_last_pmf[:, 0, -1] > 0
+                    axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=local_weight_colours[j])
             else:
                 axs[i].yaxis.set_ticklabels([])
-            if i == 2:
-                pass
-            #     if j < n_weights - 1:
-            #         axs[i].plot([0.9], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', color=weight_colours[j])  # also plot weights of very last state average
-            #     if j == n_weights - 1:
-            #         mask = first_and_last_pmf[:, 1, -1] < 0
-            #         axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j])  # separete biases again
-            #     if j == n_weights:
-            #         mask = first_and_last_pmf[:, 1, -1] > 0
-            #         axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j])
+            if i == 2 and show_first_and_last:
+                if j < n_weights - 1:
+                    axs[i].plot([0.9], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', color=local_weight_colours[j])  # also plot weights of very last state average
+                if j == n_weights - 1:
+                    mask = first_and_last_pmf[:, 1, -1] < 0
+                    axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=local_weight_colours[j])  # separete biases again
+                if j == n_weights:
+                    mask = first_and_last_pmf[:, 1, -1] > 0
+                    axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=local_weight_colours[j])
             if j == 0:
-                axs[i].set_title("Type {}".format(i + 1), size=26)
+                axs[i].set_title("Type {}".format(i + 1 + sudden_changes), size=26)
             if j == n_weights and i == 1:
                 axs[i].set_xlabel("Lifetime weight change", size=24)
     axs[0].legend(frameon=False, fontsize=14)
     plt.tight_layout()
-    plt.savefig("./summary_figures/weight_changes/compact sudden_weight_changes")
+    plt.savefig("./summary_figures/weight_changes/" + title + " augmented" * show_weight_augmentations)
     plt.close()
 
-    bin_sets = [np.linspace(-6.5, 6.5, 30), np.linspace(-1, 2, 30), np.linspace(-3.5, 3.5, 30)]
-    if True:
-        x_lim_used_full, x_lim_used_half = 56, 28
-        f, axs = plt.subplots(n_weights + 1, 3 * 2, figsize=(12, 9))
-        for i in range(3):
-            for j in range(n_weights + 1):
-                # axs[j, i].plot([0, 1], average[j, i] / counter[j, i], marker="o")
-                if j < 2:
-                    bins = bin_sets[0]
-                elif j == 2:
-                    bins = bin_sets[1]
-                else:
-                    bins = bin_sets[2]
-                axs[j, i * 2].hist(all_datapoints[j][i][0], orientation='horizontal', bins=bins, color='grey', alpha=0.5)
+
+def plot_histogram_diffs(all_datapoints, x_lim_used_normal, x_lim_used_bias, bin_sets, title, x_lim_used_augment=0, show_deltas=True, show_first_and_last=False, show_weight_augmentations=False):
+    """Plot histograms over the weights, and the mean changes connecting them.
+    Might have to mess quite a bit with the y-axis"""
+    sudden_changes = len(all_datapoints[0]) == 2
+    f, axs = plt.subplots(n_weights + 1 + show_weight_augmentations, (n_types - sudden_changes) * 2, figsize=(4 * (3 - sudden_changes), 9))
+    for i in range(n_types - sudden_changes):
+        for j in range(n_weights + 1 + show_weight_augmentations):
+            if j < 2:
+                bins = bin_sets[0]
+            elif j == 2:
+                bins = bin_sets[1]
+            elif j in [3, 4]:
+                bins = bin_sets[2]
+            else:
+                bins = bin_sets[3]
+            axs[j, i * 2].hist(all_datapoints[j][i][0], orientation='horizontal', bins=bins, color='grey', alpha=0.5)
+            if show_deltas:
+                axs[j, i * 2 + 1].hist(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]), orientation='horizontal', bins=bins, color='red', alpha=0.5)
+            else:
                 axs[j, i * 2 + 1].hist(all_datapoints[j][i][1], orientation='horizontal', bins=bins, color='grey', alpha=0.5)
 
-                if j < 2:
-                    axs[j, i * 2].set_ylim(-6.5, 6.5)
-                    axs[j, i * 2 + 1].set_ylim(-6.5, 6.5)
-                elif j == 2:
-                    axs[j, i * 2].set_ylim(-1, 2)
-                    axs[j, i * 2 + 1].set_ylim(-1, 2)
-                else:
-                    axs[j, i * 2].set_ylim(-3.5, 3.5)
-                    axs[j, i * 2 + 1].set_ylim(-3.5, 3.5)
-                axs[j, i * 2].spines['top'].set_visible(False)
-                axs[j, i * 2].spines['right'].set_visible(False)
-                axs[j, i * 2].set_xticks([])
-                axs[j, i * 2 + 1].spines['top'].set_visible(False)
-                axs[j, i * 2 + 1].spines['right'].set_visible(False)
-                axs[j, i * 2 + 1].set_xticks([])
-
-                axs[j, i * 2].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][0])), xy=(0.65, 0.8), xycoords='axes fraction')
+            axs[j, i * 2].set_ylim(bins[0], bins[-1])
+            axs[j, i * 2 + 1].set_ylim(bins[0], bins[-1])
+  
+            axs[j, i * 2].spines['top'].set_visible(False)
+            axs[j, i * 2].spines['right'].set_visible(False)
+            axs[j, i * 2].set_xticks([])
+            axs[j, i * 2 + 1].spines['top'].set_visible(False)
+            axs[j, i * 2 + 1].spines['right'].set_visible(False)
+            axs[j, i * 2 + 1].set_xticks([])
+
+            axs[j, i * 2].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][0])), xy=(0.65, 0.8), xycoords='axes fraction')
+            if show_deltas:
+                axs[j, i * 2 + 1].annotate("Var {:.2f}".format(np.var(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]))), xy=(0.65, 0.8), xycoords='axes fraction')
+            else:
                 axs[j, i * 2 + 1].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][1])), xy=(0.65, 0.8), xycoords='axes fraction')
 
+            if j < n_weights - 1:
+                assert x_lim_used_normal > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_normal, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
+                axs[j, i * 2].set_xlim(0, x_lim_used_normal)
+                axs[j, i * 2 + 1].set_xlim(0, x_lim_used_normal)
+                means = average[j, i] / counter[j, i]
+                con = ConnectionPatch(xyA=(x_lim_used_normal / 12, means[0]), xyB=(0, means[1]), coordsA="data", coordsB="data",
+                                      axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
+                axs[j, i * 2 + 1].add_artist(con)
+            elif j < n_weights + 1:
+                assert x_lim_used_bias > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_bias, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
+                axs[j, i * 2].set_xlim(0, x_lim_used_bias)
+                axs[j, i * 2 + 1].set_xlim(0, x_lim_used_bias)
+                means = average[j, i] / counter[j, i]
+                con = ConnectionPatch(xyA=(x_lim_used_bias / 12, means[0]), xyB=(0, means[1]), coordsA="data", coordsB="data",
+                                      axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
+                axs[j, i * 2 + 1].add_artist(con)
+            else:
+                assert x_lim_used_augment > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_augment, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
+                axs[j, i * 2].set_xlim(0, x_lim_used_augment)
+                axs[j, i * 2 + 1].set_xlim(0, x_lim_used_augment)
+                means = average[j, i] / counter[j, i]
+                con = ConnectionPatch(xyA=(x_lim_used_augment / 12, means[0]), xyB=(0, means[1]), coordsA="data", coordsB="data",
+                                      axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
+                axs[j, i * 2 + 1].add_artist(con)
+
+            if i == 0:
+                axs[j, i].set_ylabel(local_ylabels[j])
+                axs[j, i * 2 + 1].yaxis.set_ticklabels([])
+                if show_first_and_last:
+                    if j < n_weights - 1:
+                        axs[j, i * 2].plot([x_lim_used_normal / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red')  # also plot weights of very first state average
+                    if j == n_weights - 1:
+                        mask = first_and_last_pmf[:, 0, -1] < 0
+                        axs[j, i * 2].plot([x_lim_used_bias / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red')  # separete biases again
+                    if j == n_weights:
+                        mask = first_and_last_pmf[:, 0, -1] > 0
+                        axs[j, i * 2].plot([x_lim_used_augment / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red')
+            else:
+                axs[j, i * 2].yaxis.set_ticklabels([])
+                axs[j, i * 2 + 1].yaxis.set_ticklabels([])
+            if i == n_types - 1 and show_first_and_last:
                 if j < n_weights - 1:
-                    assert x_lim_used_full > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_full, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
-                    axs[j, i * 2].set_xlim(0, x_lim_used_full)
-                    axs[j, i * 2 + 1].set_xlim(0, x_lim_used_full)
-                    means = average[j, i] / counter[j, i]
-                    con = ConnectionPatch(xyA=(x_lim_used_full / 8, means[0]), xyB=(x_lim_used_full / 8, means[1]), coordsA="data", coordsB="data",
-                                          axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
-                    axs[j, i * 2 + 1].add_artist(con)
+                    axs[j, i * 2 + 1].plot([x_lim_used_normal / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red')  # also plot weights of very last state average
+                if j == n_weights - 1:
+                    mask = first_and_last_pmf[:, 1, -1] < 0
+                    axs[j, i * 2 + 1].plot([x_lim_used_bias / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')  # separete biases again
+                if j == n_weights:
+                    mask = first_and_last_pmf[:, 1, -1] > 0
+                    axs[j, i * 2 + 1].plot([x_lim_used_augment / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')
+            if j == 0:
+                axs[j, i * 2].set_title("Type {}".format(i + 1), loc='right')
+            if j == n_weights + show_weight_augmentations and i == 0:
+                axs[j, 0].set_xlabel("Initial distribution")
+                if show_deltas:
+                    axs[j, 1].set_xlabel("Weight change")
                 else:
-                    assert x_lim_used_half > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_half, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
-                    axs[j, i * 2].set_xlim(0, x_lim_used_half)
-                    axs[j, i * 2 + 1].set_xlim(0, x_lim_used_half)
-                    means = average[j, i] / counter[j, i]
-                    con = ConnectionPatch(xyA=(x_lim_used_half / 8, means[0]), xyB=(x_lim_used_half / 8, means[1]), coordsA="data", coordsB="data",
-                                          axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
-                    axs[j, i * 2 + 1].add_artist(con)
+                    axs[j, 1].set_xlabel("Changed distribution")
 
-                if i == 0:
-                    axs[j, i].set_ylabel(ylabels[j])
-                    axs[j, i * 2 + 1].yaxis.set_ticklabels([])
-                    # if j < n_weights - 1:
-                    #     axs[j, i * 2].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red')  # also plot weights of very first state average
-                    # if j == n_weights - 1:
-                    #     mask = first_and_last_pmf[:, 0, -1] < 0
-                    #     axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red')  # separete biases again
-                    # if j == n_weights:
-                    #     mask = first_and_last_pmf[:, 0, -1] > 0
-                    #     axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red')
-                else:
-                    axs[j, i * 2].yaxis.set_ticklabels([])
-                    axs[j, i * 2 + 1].yaxis.set_ticklabels([])
-                if i == 2:
-                    pass
-                    # if j < n_weights - 1:
-                    #     axs[j, i * 2 + 1].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red')  # also plot weights of very last state average
-                    # if j == n_weights - 1:
-                    #     mask = first_and_last_pmf[:, 1, -1] < 0
-                    #     axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')  # separete biases again
-                    # if j == n_weights:
-                    #     mask = first_and_last_pmf[:, 1, -1] > 0
-                    #     axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')
-                if j == 0:
-                    axs[j, i * 2].set_title("Type {}".format(i + 1), loc='right')
-                if j == n_weights:
-                    axs[j, i * 2].set_xlabel("Lifetime weight change", loc='right')
+    plt.savefig("./summary_figures/weight_changes/" + title)
+    plt.close()
 
-        plt.savefig("./summary_figures/weight_changes/weight changes sudden hists")
-        plt.close()
 
-        x_lim_used_full, x_lim_used_half = 120, 60
-        bin_sets = [np.linspace(-6.5, 6.5, 30), np.linspace(-1, 2, 30), np.linspace(-3.5, 3.5, 30)]
-        f, axs = plt.subplots(n_weights + 1, 3 * 2, figsize=(12, 9))
-        for i in range(3):
-            for j in range(n_weights + 1):
-                # axs[j, i].plot([0, 1], average[j, i] / counter[j, i], marker="o")
-                if j < 2:
-                    bins = bin_sets[0]
-                elif j == 2:
-                    bins = bin_sets[1]
-                else:
-                    bins = bin_sets[2]
-                axs[j, i * 2].hist(all_datapoints[j][i][0], orientation='horizontal', bins=bins, color='grey', alpha=0.5)
-                axs[j, i * 2 + 1].hist(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]), orientation='horizontal', bins=bins, color='grey', alpha=0.5)
+if True:
 
-                if j < 2:
-                    axs[j, i * 2].set_ylim(-6.5, 6.5)
-                    axs[j, i * 2 + 1].set_ylim(-6.5, 6.5)
-                elif j == 2:
-                    axs[j, i * 2].set_ylim(-1, 2)
-                    axs[j, i * 2 + 1].set_ylim(-1, 2)
-                else:
-                    axs[j, i * 2].set_ylim(-3.5, 3.5)
-                    axs[j, i * 2 + 1].set_ylim(-3.5, 3.5)
-                axs[j, i * 2].spines['top'].set_visible(False)
-                axs[j, i * 2].spines['right'].set_visible(False)
-                axs[j, i * 2].set_xticks([])
-                axs[j, i * 2 + 1].spines['top'].set_visible(False)
-                axs[j, i * 2 + 1].spines['right'].set_visible(False)
-                axs[j, i * 2 + 1].set_xticks([])
-
-                axs[j, i * 2].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][0])), xy=(0.65, 0.8), xycoords='axes fraction')
-                axs[j, i * 2 + 1].annotate("Var {:.2f}".format(np.var(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]))), xy=(0.65, 0.8), xycoords='axes fraction')
+    bin_sets = [np.linspace(-6.5, 6.5, 30), np.linspace(-1, 2, 30), np.linspace(-3.5, 3.5, 30), np.linspace(-0.2, 1, 30)]  # TODO: make show_augmentation robuse
 
-                if j < n_weights - 1:
-                    assert x_lim_used_full > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_full, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
-                    axs[j, i * 2].set_xlim(0, x_lim_used_full)
-                    axs[j, i * 2 + 1].set_xlim(0, x_lim_used_full)
-                    means = average[j, i] / counter[j, i]
-                    con = ConnectionPatch(xyA=(x_lim_used_full / 8, means[0]), xyB=(x_lim_used_full / 8, means[1]), coordsA="data", coordsB="data",
-                                          axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
-                    axs[j, i * 2 + 1].add_artist(con)
-                else:
-                    assert x_lim_used_half > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_half, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
-                    axs[j, i * 2].set_xlim(0, x_lim_used_half)
-                    axs[j, i * 2 + 1].set_xlim(0, x_lim_used_half)
-                    means = average[j, i] / counter[j, i]
-                    con = ConnectionPatch(xyA=(x_lim_used_half / 8, means[0]), xyB=(x_lim_used_half / 8, means[1]), coordsA="data", coordsB="data",
-                                          axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
-                    axs[j, i * 2 + 1].add_artist(con)
+    average, counter, all_datapoints = plot_traces_and_collate_data(data=all_sudden_transition_changes, augmented_data=aug_all_sudden_transition_changes, title="sudden_weight_change at transitions")
 
-                if i == 0:
-                    axs[j, i].set_ylabel(ylabels[j])
-                    axs[j, i * 2 + 1].yaxis.set_ticklabels([])
-                    # if j < n_weights - 1:
-                    #     axs[j, i * 2].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red')  # also plot weights of very first state average
-                    # if j == n_weights - 1:
-                    #     mask = first_and_last_pmf[:, 0, -1] < 0
-                    #     axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red')  # separete biases again
-                    # if j == n_weights:
-                    #     mask = first_and_last_pmf[:, 0, -1] > 0
-                    #     axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red')
-                else:
-                    axs[j, i * 2].yaxis.set_ticklabels([])
-                    axs[j, i * 2 + 1].yaxis.set_ticklabels([])
-                if i == 2:
-                    pass
-                    # if j < n_weights - 1:
-                    #     axs[j, i * 2 + 1].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red')  # also plot weights of very last state average
-                    # if j == n_weights - 1:
-                    #     mask = first_and_last_pmf[:, 1, -1] < 0
-                    #     axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')  # separete biases again
-                    # if j == n_weights:
-                    #     mask = first_and_last_pmf[:, 1, -1] > 0
-                    #     axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')
-                if j == 0:
-                    axs[j, i * 2].set_title("Type {}".format(i + 1), loc='right')
-                if j == n_weights:
-                    axs[j, i * 2].set_xlabel("Lifetime weight change", loc='right')
+    plot_compact(average, counter, title="compact sudden_weight_change at transitions", show_weight_augmentations=show_weight_augmentations)
 
-        plt.savefig("./summary_figures/weight_changes/weight changes sudden delta hists")
-        plt.close()
+    plot_histogram_diffs(all_datapoints, x_lim_used_normal=22, x_lim_used_bias=11, x_lim_used_augment=40, bin_sets=bin_sets, title="weight changes sudden at transitions hists", show_deltas=False, show_weight_augmentations=show_weight_augmentations)
 
+    plot_histogram_diffs(all_datapoints, x_lim_used_normal=22, x_lim_used_bias=11, x_lim_used_augment=40, bin_sets=bin_sets, title="weight changes sudden at transitions delta hists", show_deltas=True, show_weight_augmentations=show_weight_augmentations)
 
-    # for weight_traj in all_weight_trajectories:
-    #     if len(weight_traj) == 1:
-    #         continue
-    #     state_type = pmf_type(weights_to_pmf(weight_traj[0]))
+    average, counter, all_datapoints = plot_traces_and_collate_data(data=all_sudden_changes, augmented_data=aug_all_sudden_changes, title="sudden_weight_change")
 
-    #     if state_type == 2 and weight_traj[0][0] > 0.8:
-    #         plt.plot(weights_to_pmf(weight_traj[0]))
-    #         plt.ylim(0, 1)
-    #         plt.show()
+    plot_compact(average, counter, title="compact sudden_weight_change", show_weight_augmentations=show_weight_augmentations)
 
-    dur_lims = [(52, 25), (46, 23), (38, 19), (35, 17), (30, 15), (25, 12), (20, 10), (18, 9)]
+    plot_histogram_diffs(all_datapoints, x_lim_used_normal=56, x_lim_used_bias=28, x_lim_used_augment=140, bin_sets=bin_sets, title="weight changes sudden hists", show_deltas=False, show_weight_augmentations=show_weight_augmentations)
 
+    plot_histogram_diffs(all_datapoints, x_lim_used_normal=120, x_lim_used_bias=60, x_lim_used_augment=180, bin_sets=bin_sets, title="weight changes sudden delta hists", show_deltas=True, show_weight_augmentations=show_weight_augmentations)
+
+    dur_lims = [(52, 25), (46, 23), (38, 19), (35, 17), (30, 15), (25, 12), (20, 10), (18, 9)]
     for min_dur_counter, min_dur in enumerate([2, 3, 4, 5, 7, 9, 11, 15]):
-        f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9)) # add another space to split bias into stating left versus starting right
-        average = np.zeros((n_weights + 1, 3, 2))
-        counter = np.zeros((n_weights + 1, 3))
-        all_datapoints = [[[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]], [[[], []], [[], []], [[], []]]]
+        average = np.zeros((n_weights + 1, n_types, 2))
+        counter = np.zeros((n_weights + 1, n_types))
+        all_datapoints = create_nested_list([n_weights + 1, n_types, 2])
+
+        f, axs = plt.subplots(n_weights + 1, n_types, figsize=(12, 9)) # add another space to split bias into stating left versus starting right
         for weight_traj in all_weight_trajectories:
 
             if len(weight_traj) < min_dur or len(weight_traj) > 15: # take out too short trajectories, and too long ones
@@ -388,61 +305,26 @@ if True:
                     all_datapoints[i][state_type][1].append(weight_traj[-1][i])
         for i in range(3):
             for j in range(n_weights + 1):
-                    axs[j, i].set_ylim(-9, 9)
-                    axs[j, i].spines['top'].set_visible(False)
-                    axs[j, i].spines['right'].set_visible(False)
-                    axs[j, i].set_xticks([])
-                    if i == 0:
-                        axs[j, i].set_ylabel(ylabels[j])
-                    else:
-                        axs[j, i].yaxis.set_ticklabels([])
-                    if j == 0:
-                        axs[j, i].set_title("Type {}".format(i + 1))
-                    if j == n_weights:
-                        axs[j, i].set_xlabel("Lifetime weight change")
-
-        plt.tight_layout()
-        plt.savefig("./summary_figures/weight_changes/all weight changes min dur {}".format(min_dur))
-        plt.close()
-
-        f, axs = plt.subplots(1, 3, figsize=(12, 6))
-        for i in range(3):
-            for j in range(n_weights + 1):
-                axs[i].plot([0, 1], average[j, i] / counter[j, i], marker="o", color=weight_colours[j], label=ylabels[j])
-                axs[i].set_ylim(-3.5, 3.5)
-                axs[i].spines['top'].set_visible(False)
-                axs[i].spines['right'].set_visible(False)
-                axs[i].set_xticks([])
+                axs[j, i].set_ylim(-9, 9)
+                axs[j, i].spines['top'].set_visible(False)
+                axs[j, i].spines['right'].set_visible(False)
+                axs[j, i].set_xticks([])
                 if i == 0:
-                    axs[i].set_ylabel("Weights", size=24)
-                    if j < n_weights - 1:
-                        axs[i].plot([0.1], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', color=weight_colours[j])  # also plot weights of very first state average
-                    if j == n_weights - 1:
-                        mask = first_and_last_pmf[:, 0, -1] < 0
-                        axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j])  # separete biases again
-                    if j == n_weights:
-                        mask = first_and_last_pmf[:, 0, -1] > 0
-                        axs[i].plot([0.1], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', color=weight_colours[j])
+                    axs[j, i].set_ylabel(ylabels[j])
                 else:
-                    axs[i].yaxis.set_ticklabels([])
-                if i == 2:
-                    if j < n_weights - 1:
-                        axs[i].plot([0.9], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', color=weight_colours[j])  # also plot weights of very last state average
-                    if j == n_weights - 1:
-                        mask = first_and_last_pmf[:, 1, -1] < 0
-                        axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j])  # separete biases again
-                    if j == n_weights:
-                        mask = first_and_last_pmf[:, 1, -1] > 0
-                        axs[i].plot([0.9], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', color=weight_colours[j])
+                    axs[j, i].yaxis.set_ticklabels([])
                 if j == 0:
-                    axs[i].set_title("Type {}".format(i + 1), size=26)
-                if j == n_weights and i == 1:
-                    axs[i].set_xlabel("Lifetime weight change", size=24)
-        axs[0].legend(frameon=False, fontsize=14)
+                    axs[j, i].set_title("Type {}".format(i + 1))
+                if j == n_weights:
+                    axs[j, i].set_xlabel("Lifetime weight change")
+
         plt.tight_layout()
-        plt.savefig("./summary_figures/weight_changes/compact changes min dur {}".format(min_dur))
+        plt.savefig("./summary_figures/weight_changes/all weight changes min dur {}".format(min_dur))
         plt.close()
 
+        plot_compact(average, counter, title="compact changes min dur {}".format(min_dur), show_first_and_last=True)
+
+        # means all split up
         f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9))
         for i in range(3):
             for j in range(n_weights + 1):
@@ -482,168 +364,14 @@ if True:
         plt.savefig("./summary_figures/weight_changes/weight changes min dur {}".format(min_dur))
         plt.close()
 
-        x_lim_used_full, x_lim_used_half = dur_lims[min_dur_counter]
-
-        # also histograms
-        f, axs = plt.subplots(n_weights + 1, 3 * 2, figsize=(12, 9))
-        for i in range(3):
-            for j in range(n_weights + 1):
-                if j < 2:
-                    bins = bin_sets[0]
-                elif j == 2:
-                    bins = bin_sets[1]
-                else:
-                    bins = bin_sets[2]
-                axs[j, i * 2].hist(all_datapoints[j][i][0], orientation='horizontal', bins=bins, color='grey', alpha=0.5)
-                axs[j, i * 2 + 1].hist(all_datapoints[j][i][1], orientation='horizontal', bins=bins, color='grey', alpha=0.5)
-
-                if j < 2:
-                    axs[j, i * 2].set_ylim(-6.5, 6.5)
-                    axs[j, i * 2 + 1].set_ylim(-6.5, 6.5)
-                elif j == 2:
-                    axs[j, i * 2].set_ylim(-1, 2)
-                    axs[j, i * 2 + 1].set_ylim(-1, 2)
-                else:
-                    axs[j, i * 2].set_ylim(-3.5, 3.5)
-                    axs[j, i * 2 + 1].set_ylim(-3.5, 3.5)
-                axs[j, i * 2].spines['top'].set_visible(False)
-                axs[j, i * 2].spines['right'].set_visible(False)
-                axs[j, i * 2].set_xticks([])
-                axs[j, i * 2 + 1].spines['top'].set_visible(False)
-                axs[j, i * 2 + 1].spines['right'].set_visible(False)
-                axs[j, i * 2 + 1].set_xticks([])
-
-                axs[j, i * 2].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][0])), xy=(0.65, 0.8), xycoords='axes fraction')
-                axs[j, i * 2 + 1].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][1])), xy=(0.65, 0.8), xycoords='axes fraction')
-
-                if j < n_weights - 1:
-                    assert x_lim_used_full > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_full, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
-                    axs[j, i * 2].set_xlim(0, x_lim_used_full)
-                    axs[j, i * 2 + 1].set_xlim(0, x_lim_used_full)
-                    means = average[j, i] / counter[j, i]
-                    con = ConnectionPatch(xyA=(x_lim_used_full / 8, means[0]), xyB=(x_lim_used_full / 8, means[1]), coordsA="data", coordsB="data",
-                        axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
-                    axs[j, i * 2 + 1].add_artist(con)
-                else:
-                    assert x_lim_used_half > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_half, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
-                    axs[j, i * 2].set_xlim(0, x_lim_used_half)
-                    axs[j, i * 2 + 1].set_xlim(0, x_lim_used_half)
-                    means = average[j, i] / counter[j, i]
-                    con = ConnectionPatch(xyA=(x_lim_used_half / 8, means[0]), xyB=(x_lim_used_half / 8, means[1]), coordsA="data", coordsB="data",
-                        axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
-                    axs[j, i * 2 + 1].add_artist(con)
-
-                if i == 0:
-                    axs[j, i].set_ylabel(ylabels[j])
-                    axs[j, i * 2 + 1].yaxis.set_ticklabels([])
-                    if j < n_weights - 1:
-                        axs[j, i * 2].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red')  # also plot weights of very first state average
-                    if j == n_weights - 1:
-                        mask = first_and_last_pmf[:, 0, -1] < 0
-                        axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red')  # separete biases again
-                    if j == n_weights:
-                        mask = first_and_last_pmf[:, 0, -1] > 0
-                        axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red')
-                else:
-                    axs[j, i * 2].yaxis.set_ticklabels([])
-                    axs[j, i * 2 + 1].yaxis.set_ticklabels([])
-                if i == 2:
-                    if j < n_weights - 1:
-                        axs[j, i * 2 + 1].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red')  # also plot weights of very last state average
-                    if j == n_weights - 1:
-                        mask = first_and_last_pmf[:, 1, -1] < 0
-                        axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')  # separete biases again
-                    if j == n_weights:
-                        mask = first_and_last_pmf[:, 1, -1] > 0
-                        axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')
-                if j == 0:
-                    axs[j, i * 2].set_title("Type {}".format(i + 1), loc='right')
-                if j == n_weights:
-                    axs[j, i * 2].set_xlabel("Lifetime weight change", loc='right')
-
-        plt.savefig("./summary_figures/weight_changes/weight changes min dur {} hists".format(min_dur))
-        plt.close()
-
-        # also histograms of deltas
-        x_lim_used_full, x_lim_used_half = x_lim_used_full * 2.5, x_lim_used_half * 2.5
-        f, axs = plt.subplots(n_weights + 1, 3 * 2, figsize=(12, 9))
-        for i in range(3):
-            for j in range(n_weights + 1):
-                if j < 2:
-                    bins = bin_sets[0]
-                elif j == 2:
-                    bins = bin_sets[1]
-                else:
-                    bins = bin_sets[2]
-                axs[j, i * 2].hist(all_datapoints[j][i][0], orientation='horizontal', bins=bins, color='grey', alpha=0.5)
-                axs[j, i * 2 + 1].hist(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]), orientation='horizontal', bins=bins, color='grey', alpha=0.5)
-
-                if j < 2:
-                    axs[j, i * 2].set_ylim(-6.5, 6.5)
-                    axs[j, i * 2 + 1].set_ylim(-6.5, 6.5)
-                elif j == 2:
-                    axs[j, i * 2].set_ylim(-1, 2)
-                    axs[j, i * 2 + 1].set_ylim(-1, 2)
-                else:
-                    axs[j, i * 2].set_ylim(-3.5, 3.5)
-                    axs[j, i * 2 + 1].set_ylim(-3.5, 3.5)
-                axs[j, i * 2].spines['top'].set_visible(False)
-                axs[j, i * 2].spines['right'].set_visible(False)
-                axs[j, i * 2].set_xticks([])
-                axs[j, i * 2 + 1].spines['top'].set_visible(False)
-                axs[j, i * 2 + 1].spines['right'].set_visible(False)
-                axs[j, i * 2 + 1].set_xticks([])
-
-                axs[j, i * 2].annotate("Var {:.2f}".format(np.var(all_datapoints[j][i][0])), xy=(0.65, 0.8), xycoords='axes fraction')
-                axs[j, i * 2 + 1].annotate("Var {:.2f}".format(np.var(np.array(all_datapoints[j][i][1]) - np.array(all_datapoints[j][i][0]))), xy=(0.65, 0.8), xycoords='axes fraction')
+        x_lim_used_normal, x_lim_used_bias = dur_lims[min_dur_counter]
+        plot_histogram_diffs(all_datapoints=all_datapoints, x_lim_used_normal=x_lim_used_normal, x_lim_used_bias=x_lim_used_bias, bin_sets=bin_sets,
+                             title="weight changes min dur {} hists".format(min_dur), show_deltas=False, show_first_and_last=True)
 
-                if j < n_weights - 1:
-                    assert x_lim_used_full > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_full, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
-                    axs[j, i * 2].set_xlim(0, x_lim_used_full)
-                    axs[j, i * 2 + 1].set_xlim(0, x_lim_used_full)
-                    means = average[j, i] / counter[j, i]
-                    con = ConnectionPatch(xyA=(x_lim_used_full / 8, means[0]), xyB=(x_lim_used_full / 8, means[1]), coordsA="data", coordsB="data",
-                        axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
-                    axs[j, i * 2 + 1].add_artist(con)
-                else:
-                    assert x_lim_used_half > max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]), "Hists are cut off ({} vs {})".format(x_lim_used_half, max(axs[j, i * 2].set_xlim()[0], axs[j, i * 2 + 1].set_xlim()[1]))
-                    axs[j, i * 2].set_xlim(0, x_lim_used_half)
-                    axs[j, i * 2 + 1].set_xlim(0, x_lim_used_half)
-                    means = average[j, i] / counter[j, i]
-                    con = ConnectionPatch(xyA=(x_lim_used_half / 8, means[0]), xyB=(x_lim_used_half / 8, means[1]), coordsA="data", coordsB="data",
-                        axesA=axs[j, i * 2], axesB=axs[j, i * 2 + 1], color="blue")
-                    axs[j, i * 2 + 1].add_artist(con)
+        x_lim_used_normal, x_lim_used_bias = x_lim_used_normal * 2.5, x_lim_used_bias * 2.5
+        plot_histogram_diffs(all_datapoints=all_datapoints, x_lim_used_normal=x_lim_used_normal, x_lim_used_bias=x_lim_used_bias, bin_sets=bin_sets,
+                             title="weight changes min dur {} delta hists".format(min_dur), show_deltas=True, show_first_and_last=True)
 
-                if i == 0:
-                    axs[j, i].set_ylabel(ylabels[j])
-                    axs[j, i * 2 + 1].yaxis.set_ticklabels([])
-                    if j < n_weights - 1:
-                        axs[j, i * 2].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red')  # also plot weights of very first state average
-                    if j == n_weights - 1:
-                        mask = first_and_last_pmf[:, 0, -1] < 0
-                        axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red')  # separete biases again
-                    if j == n_weights:
-                        mask = first_and_last_pmf[:, 0, -1] > 0
-                        axs[j, i * 2].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red')
-                else:
-                    axs[j, i * 2].yaxis.set_ticklabels([])
-                    axs[j, i * 2 + 1].yaxis.set_ticklabels([])
-                if i == 2:
-                    if j < n_weights - 1:
-                        axs[j, i * 2 + 1].plot([x_lim_used_full / 8], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red')  # also plot weights of very last state average
-                    if j == n_weights - 1:
-                        mask = first_and_last_pmf[:, 1, -1] < 0
-                        axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')  # separete biases again
-                    if j == n_weights:
-                        mask = first_and_last_pmf[:, 1, -1] > 0
-                        axs[j, i * 2 + 1].plot([x_lim_used_half / 8], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red')
-                if j == 0:
-                    axs[j, i * 2].set_title("Type {}".format(i + 1), loc='right')
-                if j == n_weights:
-                    axs[j, i * 2].set_xlabel("Lifetime weight change", loc='right')
-
-        plt.savefig("./summary_figures/weight_changes/weight changes min dur {} delta hists".format(min_dur))
-        plt.close()
     quit()
 
 # all pmf weights