diff --git a/.gitignore b/.gitignore
index b84a8f82fc35d730fb072a10736910aaca487db3..e0d406ba7e6131d35c436488d857b2167bbaaacf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,4 @@
-
+summary_figures/
 old_ana_code/
 figures/
 dynamic_figures/
diff --git a/__pycache__/analysis_pmf.cpython-37.pyc b/__pycache__/analysis_pmf.cpython-37.pyc
index f9fd33bf94c4befd1c5937c83c151a084e749d91..899a197bee3c2ffcdbb2f048e7d0088d934a594e 100644
Binary files a/__pycache__/analysis_pmf.cpython-37.pyc and b/__pycache__/analysis_pmf.cpython-37.pyc differ
diff --git a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc
index 797f6ea4ca71054cc38cb50347dc59cdea7fef0b..06557f92116c59895d3c014e28181f1e3c4ae7f7 100644
Binary files a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc and b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc differ
diff --git a/__pycache__/simplex_plot.cpython-37.pyc b/__pycache__/simplex_plot.cpython-37.pyc
index a8537b6f9da6be3317317d69450bbe8640da3809..d1a178aa1d78d9cd467d1b86bd96b0a477756845 100644
Binary files a/__pycache__/simplex_plot.cpython-37.pyc and b/__pycache__/simplex_plot.cpython-37.pyc differ
diff --git a/analysis_pmf.py b/analysis_pmf.py
index 03e55c1d17a94fbdfd7a0b8f2c8e239d4a076e4f..8be3eecf2101ada750f1414224d62ddd5b99491b 100644
--- a/analysis_pmf.py
+++ b/analysis_pmf.py
@@ -50,32 +50,32 @@ if __name__ == "__main__":
     state_types_interpolation = state_types_interpolation / state_types_interpolation.max() * 100
 
     fs = 18
-    # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[0], color=type2color[0])
-    # plt.ylabel("% of type across population", size=fs)
-    # plt.xlabel("Interpolated session time", size=fs)
-    # plt.ylim(0, 100)
-    # sns.despine()
-    # plt.tight_layout()
-    # plt.savefig("type hist 1")
-    # plt.show()
-    #
-    # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[1], color=type2color[1])
-    # plt.ylabel("% of type across population", size=fs)
-    # plt.xlabel("Interpolated session time", size=fs)
-    # plt.ylim(0, 100)
-    # sns.despine()
-    # plt.tight_layout()
-    # plt.savefig("type hist 2")
-    # plt.show()
-    #
-    # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[2], color=type2color[2])
-    # plt.ylabel("% of type across population", size=fs)
-    # plt.xlabel("Interpolated session time", size=fs)
-    # plt.ylim(0, 100)
-    # sns.despine()
-    # plt.tight_layout()
-    # plt.savefig("type hist 3")
-    # plt.show()
+    plt.plot(np.linspace(0, 1, 150), state_types_interpolation[0], color=type2color[0])
+    plt.ylabel("% of type across population", size=fs)
+    plt.xlabel("Interpolated session time", size=fs)
+    plt.ylim(0, 100)
+    sns.despine()
+    plt.tight_layout()
+    plt.savefig("./summary_figures/type hist 1")
+    plt.close()
+    
+    plt.plot(np.linspace(0, 1, 150), state_types_interpolation[1], color=type2color[1])
+    plt.ylabel("% of type across population", size=fs)
+    plt.xlabel("Interpolated session time", size=fs)
+    plt.ylim(0, 100)
+    sns.despine()
+    plt.tight_layout()
+    plt.savefig("./summary_figures/type hist 2")
+    plt.close()
+    
+    plt.plot(np.linspace(0, 1, 150), state_types_interpolation[2], color=type2color[2])
+    plt.ylabel("% of type across population", size=fs)
+    plt.xlabel("Interpolated session time", size=fs)
+    plt.ylim(0, 100)
+    sns.despine()
+    plt.tight_layout()
+    plt.savefig("./summary_figures/type hist 3")
+    plt.close()
 
     all_first_pmfs_typeless = pickle.load(open("all_first_pmfs_typeless.p", 'rb'))
     all_pmfs = pickle.load(open("all_pmfs.p", 'rb'))
@@ -111,27 +111,27 @@ if __name__ == "__main__":
     # plt.xlabel("Bias flips")
     # sns.despine()
     # plt.tight_layout()
-    # plt.savefig("./meeting_figures/bias_flips.png")
+    # plt.savefig("./summary_figures/bias_flips.png")
     # plt.show()
 
-    # fewer_states_side = []
-    # for key in all_first_pmfs_typeless:
-    #     animal_biases = np.zeros(2)
-    #     for defined_points, pmf in all_first_pmfs_typeless[key]:
-    #         bias = np.mean(pmf[defined_points])
-    #         if bias > 0.55:
-    #             animal_biases[0] += 1
-    #         elif bias < 0.45:
-    #             animal_biases[1] += 1
-    #     fewer_states_side.append(np.min(animal_biases / animal_biases.sum()))
-    # plt.hist(fewer_states_side)
-    # plt.title("Mixed biases")
-    # plt.ylabel("# of mice")
-    # plt.xlabel("min(% left biased states, % right biased states)")
-    # sns.despine()
-    # plt.tight_layout()
-    # plt.savefig("./meeting_figures/proportion_other_bias")
-    # plt.show()
+    fewer_states_side = []
+    for key in all_first_pmfs_typeless:
+        animal_biases = np.zeros(2)
+        for defined_points, pmf in all_first_pmfs_typeless[key]:
+            bias = np.mean(pmf[defined_points])
+            if bias > 0.55:
+                animal_biases[0] += 1
+            elif bias < 0.45:
+                animal_biases[1] += 1
+        fewer_states_side.append(np.min(animal_biases / animal_biases.sum()))
+    plt.hist(fewer_states_side)
+    plt.title("Mixed biases")
+    plt.ylabel("# of mice")
+    plt.xlabel("min(% left biased states, % right biased states)")
+    sns.despine()
+    plt.tight_layout()
+    plt.savefig("./summary_figures/proportion_other_bias")
+    plt.close()
 
     # total_counter = 0
     # bias_counter = 0
@@ -287,7 +287,7 @@ if __name__ == "__main__":
     plt.gca().spines['bottom'].set_linewidth(4)
     plt.tight_layout()
     plt.savefig("single exam 5")
-    plt.show()
+    plt.close()
 
 
 
@@ -301,14 +301,14 @@ if __name__ == "__main__":
 
     plt.ylim(0, 1)
     plt.xlim(0, 10)
-    plt.ylabel("P(rightwards)", size=32)
-    plt.xlabel("Contrast", size=32)
+    # plt.ylabel("P(rightwards)", size=32)
+    # plt.xlabel("Contrast", size=32)
     plt.yticks([0, 1], size=27)
     plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27)
     sns.despine()
     plt.tight_layout()
-    plt.savefig("example type 1")
-    plt.show()
+    plt.savefig("example type 1", transparent=True)
+    plt.close()
 
     state_num = 1
     defined_points, pmf = all_first_pmfs_typeless['CSHL_018'][state_num][0], all_first_pmfs_typeless['CSHL_018'][state_num][1]
@@ -319,14 +319,14 @@ if __name__ == "__main__":
 
     plt.ylim(0, 1)
     plt.xlim(0, 10)
-    plt.ylabel("P(rightwards)", size=32)
-    plt.xlabel("Contrast", size=32)
+    # plt.ylabel("P(rightwards)", size=32)
+    # plt.xlabel("Contrast", size=32)
     plt.yticks([0, 1], size=27)
     plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27)
     sns.despine()
     plt.tight_layout()
-    plt.savefig("example type 2")
-    plt.show()
+    plt.savefig("example type 2", transparent=True)
+    plt.close()
 
     state_num = 4
     defined_points, pmf = all_first_pmfs_typeless['ibl_witten_17'][state_num][0], all_first_pmfs_typeless['ibl_witten_17'][state_num][1]
@@ -334,14 +334,14 @@ if __name__ == "__main__":
 
     plt.ylim(0, 1)
     plt.xlim(0, 10)
-    plt.ylabel("P(rightwards)", size=32)
-    plt.xlabel("Contrast", size=32)
+    # plt.ylabel("P(rightwards)", size=32)
+    # plt.xlabel("Contrast", size=32)
     plt.yticks([0, 1], size=27)
     plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27)
     sns.despine()
     plt.tight_layout()
-    plt.savefig("example type 3")
-    plt.show()
+    plt.savefig("example type 3", transparent=True)
+    plt.close()
 
 
     n_rows, n_cols = 5, 6
@@ -364,8 +364,8 @@ if __name__ == "__main__":
                 a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
                 a.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
     plt.tight_layout()
-    plt.savefig("animals 1")
-    plt.show()
+    plt.savefig("./summary_figures/animals 1")
+    plt.close()
 
     n_rows, n_cols = 5, 6
     _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9))
@@ -393,8 +393,8 @@ if __name__ == "__main__":
                 a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
                 a.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
     plt.tight_layout()
-    plt.savefig("animals 2")
-    plt.show()
+    plt.savefig("./summary_figures/animals 2")
+    plt.close()
 
     n_rows, n_cols = 5, 6
     _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9))
@@ -422,8 +422,8 @@ if __name__ == "__main__":
                 a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
                 a.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
     plt.tight_layout()
-    plt.savefig("animals 3")
-    plt.show()
+    plt.savefig("./summary_figures/animals 3")
+    plt.close()
 
     n_rows, n_cols = 5, 6
     _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9))
@@ -451,8 +451,8 @@ if __name__ == "__main__":
                 a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
                 a.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
     plt.tight_layout()
-    plt.savefig("animals 4")
-    plt.show()
+    plt.savefig("./summary_figures/animals 4")
+    plt.close()
 
     # Collection of PMFs which change state type
     all_changing_pmfs = pickle.load(open("changing_pmfs.p", 'rb'))
@@ -477,11 +477,49 @@ if __name__ == "__main__":
             plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=16)
             plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=16)
 
-        if i + 1 == 30:
+        if i + 1 == 35:
             break
 
     plt.tight_layout()
-    plt.savefig("changing pmfs")
+    plt.savefig("./summary_figures/changing pmfs")
+    plt.close()
+
+    # The biases of all first PMFs as a histogram
+    biases = []
+    mouse_counter = 0
+    consistent_bias = 0
+    hm = 0
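+    # consistent_bias counts mice whose type 2 bias direction was already present in type 1; hm counts mice showing both bias directions in both types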
+    for key in all_first_pmfs_typeless:
+        left_1, right_1, left_2, right_2 = False, False, False, False
+        x = all_first_pmfs_typeless[key]
+        for pmf in x:
+            defined_points, pmf = pmf
+            if pmf_type(pmf) == 0:
+                biases.append(np.mean(pmf[[0, 1, -2, -1]]))
+                if np.mean(pmf[[0, 1, -2, -1]]) >= 0.5:
+                    right_1 = True
+                if np.mean(pmf[[0, 1, -2, -1]]) <= 0.5:
+                    left_1 = True
+            if pmf_type(pmf) == 1:
+                if np.mean(pmf[[0, 1, -2, -1]]) >= 0.5:
+                    right_2 = True
+                if np.mean(pmf[[0, 1, -2, -1]]) <= 0.5:
+                    left_2 = True
+        mouse_counter += 1
+        consistent_bias += (left_1 and left_2) or (right_1 and right_2)
+        hm += left_1 and left_2 and right_1 and right_2
+
+
+    print("number of mice is {}".format(mouse_counter))
+    print("of which {} have a previously expressed bias in type 2".format(consistent_bias))
+    print("{} mice show both bias directions in both type 1 and type 2".format(hm))
+    # number of mice is 113
+    # of which 77 have a previously expressed bias in type 2
+    # 17 mice show both bias directions in both type 1 and type 2
+
+    plt.hist(biases)
+    plt.axvline(np.mean(biases))
+    plt.xlim(0, 1)
     plt.show()
 
     # All first PMFs
@@ -497,7 +535,13 @@ if __name__ == "__main__":
     for key in all_first_pmfs_typeless:
         x = all_first_pmfs_typeless[key]
         for pmf in x:
-            axs[pmf_type(pmf[1])].plot(np.where(pmf[0])[0], pmf[1][pmf[0]], c=type2color[pmf_type(pmf[1])])
+
+            defined_points, pmf = pmf
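+            # drop defined points that fall outside the range spanned by the PMF's left and right endpoints (non-monotonic excursions)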
+            pmf_min = min(pmf[0], pmf[1])
+            pmf_max = max(pmf[-2], pmf[-1])
+            defined_points = np.logical_and(np.logical_and(defined_points, ~ (pmf > pmf_max)), ~ (pmf < pmf_min))
+            axs[pmf_type(pmf)].plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)])
+            # axs[pmf_type(pmf[1])].plot(np.where(pmf[0])[0], pmf[1][pmf[0]], c=type2color[pmf_type(pmf[1])])
     axs[0].set_ylim(0, 1)
     axs[1].set_ylim(0, 1)
     axs[2].set_ylim(0, 1)
@@ -517,13 +561,14 @@ if __name__ == "__main__":
     axs[0].set_xlabel("Contrasts", size=label_size)
 
     plt.tight_layout()
-    plt.savefig(save_title)
+    plt.savefig("./summary_figures/" + save_title)
     plt.show()
     if save_title == "KS014 types":
         quit()
 
     # Type 1 PMF specifications
     counter = 0
+    counter_l, counter_r = 0, 0
     fig, ax = plt.subplots(1, 3, figsize=(16, 9))
     for key in all_first_pmfs_typeless:
         for defined_points, pmf in all_first_pmfs_typeless[key]:
@@ -537,6 +582,8 @@ if __name__ == "__main__":
                 use_ax = 2
             else:
                 use_ax = int(pmf[0] > 1 - pmf[-1])
+                counter_l += use_ax == 0
+                counter_r += use_ax == 1
 
             pmf_min = min(pmf[0], pmf[1])
             pmf_max = max(pmf[-2], pmf[-1])
@@ -561,9 +608,9 @@ if __name__ == "__main__":
     ax[2].set_yticks([])
     ax[2].spines[['right', 'top']].set_visible(False)
     ax[2].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
-    print(counter)
+    print("left biased pmfs {}, right biased pmfs {}, neutral pmfs {}".format(counter_l, counter_r, counter))
     plt.tight_layout()
-    plt.savefig("differentiate type 2")
+    plt.savefig("./summary_figures/differentiate type 2")
     plt.show()
 
 
diff --git a/pmf_weight_analysis.py b/analysis_pmf_weights.py
similarity index 61%
rename from pmf_weight_analysis.py
rename to analysis_pmf_weights.py
index 059851501e92abf00fe662b44d35c296ea153c09..a2b0619646b3433f268a7c531bf7dc8a8b94f5b2 100644
--- a/pmf_weight_analysis.py
+++ b/analysis_pmf_weights.py
@@ -32,6 +32,115 @@ def pmf_type_rew(weights):
     else:
         return 2
 
+all_weight_trajectories = pickle.load(open("multi_chain_saves/all_weight_trajectories.p", 'rb'))
+
+
+# for weight_traj in all_weight_trajectories:
+#     if len(weight_traj) == 1:
+#         continue
+#     state_type = pmf_type(weights_to_pmf(weight_traj[0]))
+
+#     if state_type == 2 and weight_traj[0][0] > 0.8:
+#         plt.plot(weights_to_pmf(weight_traj[0]))
+#         plt.ylim(0, 1)
+#         plt.show()
+
+n_weights = all_weight_trajectories[0][0].shape[0]
+ylabels = ["Cont left", "Cont right", "Persevere", "Bias left", "Bias right"]
+
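+# min_dur: minimum number of appearances a state needs before its weight trajectory is included in the plots below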
+for min_dur in [2, 3, 4, 5, 7, 9, 11, 15]:
+    f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9)) # add another space to split bias into starting left versus starting right
+    average = np.zeros((n_weights + 1, 3, 2))
+    counter = np.zeros((n_weights + 1, 3))
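+    # all_datapoints[weight][state_type] collects two lists: the weight's values at first and at last appearance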
+    all_datapoints = [[[[], []] for _ in range(3)] for _ in range(n_weights + 1)]
+    for weight_traj in all_weight_trajectories:
+
+        if len(weight_traj) < min_dur:
+            continue
+
+        state_type = pmf_type(weights_to_pmf(weight_traj[0]))
+
+        for i in range(n_weights):
+            if i == n_weights - 1:
+                axs[i + (weight_traj[0][i] > 0), state_type].plot([0, 1], [weight_traj[0][i], weight_traj[-1][i]], marker="o") # plot how the weight evolves from first to last appearance
+                average[i + (weight_traj[0][i] > 0), state_type] += np.array([weight_traj[0][i], weight_traj[-1][i]])
+                counter[i + (weight_traj[0][i] > 0), state_type] += 1
+                all_datapoints[i + (weight_traj[0][i] > 0)][state_type][0].append(weight_traj[0][i])
+                all_datapoints[i + (weight_traj[0][i] > 0)][state_type][1].append(weight_traj[-1][i])
+            else:
+                axs[i, state_type].plot([0, 1], [weight_traj[0][i], weight_traj[-1][i]], marker="o") # plot how the weight evolves from first to last appearance
+                average[i, state_type] += np.array([weight_traj[0][i], weight_traj[-1][i]])
+                counter[i, state_type] += 1
+                all_datapoints[i][state_type][0].append(weight_traj[0][i])
+                all_datapoints[i][state_type][1].append(weight_traj[-1][i])
+    for i in range(3):
+        for j in range(n_weights + 1):
+            axs[j, i].set_ylim(-9, 9)
+            axs[j, i].spines['top'].set_visible(False)
+            axs[j, i].spines['right'].set_visible(False)
+            axs[j, i].set_xticks([])
+            if i == 0:
+                axs[j, i].set_ylabel(ylabels[j])
+            else:
+                axs[j, i].set_yticks([])
+            if j == 0:
+                axs[j, i].set_title("Type {}".format(i + 1))
+            if j == n_weights:
+                axs[j, i].set_xlabel("Lifetime weight change")
+
+    plt.tight_layout()
+    plt.savefig("./summary_figures/weight_changes/all weight changes min dur {}".format(min_dur))
+    plt.close()
+
+    f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9))
+    for i in range(3):
+        for j in range(n_weights + 1):
+            axs[j, i].plot([0, 1], average[j, i] / counter[j, i], marker="o")
+            if j < 2:
+                axs[j, i].set_ylim(-4.5, 4.5)
+            else:
+                axs[j, i].set_ylim(-2, 2)
+            axs[j, i].spines['top'].set_visible(False)
+            axs[j, i].spines['right'].set_visible(False)
+            axs[j, i].set_xticks([])
+            if i == 0:
+                axs[j, i].set_ylabel(ylabels[j])
+            else:
+                axs[j, i].set_yticks([])
+            if j == 0:
+                axs[j, i].set_title("Type {}".format(i + 1))
+            if j == n_weights:
+                axs[j, i].set_xlabel("Lifetime weight change")
+    plt.savefig("./summary_figures/weight_changes/weight changes min dur {}".format(min_dur))
+    plt.close()
+
+    f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9))
+    for i in range(3):
+        for j in range(n_weights + 1):
+            axs[j, i].plot([0, 1], average[j, i] / counter[j, i])
+            axs[j, i].boxplot(all_datapoints[j][i], positions=[0, 1])
+            if j < 2:
+                axs[j, i].set_ylim(-8, 8)
+            else:
+                axs[j, i].set_ylim(-4, 4)
+            axs[j, i].spines['top'].set_visible(False)
+            axs[j, i].spines['right'].set_visible(False)
+            axs[j, i].set_xticks([])
+            if i == 0:
+                axs[j, i].set_ylabel(ylabels[j])
+            else:
+                axs[j, i].set_yticks([])
+            if j == 0:
+                axs[j, i].set_title("Type {}".format(i + 1))
+            if j == n_weights:
+                axs[j, i].set_xlabel("Lifetime weight change")
+    plt.savefig("./summary_figures/weight_changes/weight change boxplot min dur {}".format(min_dur))
+    plt.close()
+
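+# stop here; the older weight analyses below are kept in the file but not executed in this pass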
+quit()
+
+
+
 # all pmf weights
 apw = np.array(pickle.load(open("all_pmf_weights.p", 'rb')))
 
diff --git a/analysis_regression.py b/analysis_regression.py
index 575f234c91dbf83c6f0082d3ac1aa05e743d765e..8e09fa87c2a109eebd100520dc655a59f25b6590 100644
--- a/analysis_regression.py
+++ b/analysis_regression.py
@@ -11,11 +11,13 @@ if __name__ == "__main__":
     regressions = np.array(pickle.load(open("regressions.p", 'rb')))
     regression_diffs = np.array(pickle.load(open("regression_diffs.p", 'rb')))
 
-    assert (regressions[:, 0] == np.sum(regressions[:, 2:], 1)).all()
+    assert (regressions[:, 0] == np.sum(regressions[:, 2:], 1)).all()  # total # of regressions must be sum of # of regressions per type
 
     print(pearsonr(regressions[:, 0], regressions[:, 1]))
+    # (0.6202414016960471, 2.3666819600330215e-13)
 
     offset = 0.125
+    plt.figure(figsize=(16 * 0.9, 9 * 0.9))
     # which x values exist
     for x in np.unique(regressions[:, 0]):
         # which and how many ys are associated with this x
@@ -28,7 +30,7 @@ if __name__ == "__main__":
     sns.despine()
 
     plt.tight_layout()
-    plt.savefig("Regression vs session length")
+    plt.savefig("./summary_figures/Regression vs session length")
     plt.show()
 
     # histogram of regressions per type
@@ -38,7 +40,7 @@ if __name__ == "__main__":
 
     sns.despine()
     plt.tight_layout()
-    plt.savefig("# of regressions per type")
+    plt.savefig("./summary_figures/# of regressions per type")
     plt.show()
 
     # histogram of number of mice with regressions in the different types
@@ -49,7 +51,7 @@ if __name__ == "__main__":
 
     sns.despine()
     plt.tight_layout()
-    plt.savefig("# of mice with regressions per type")
+    plt.savefig("./summary_figures/# of mice with regressions per type")
     plt.show()
 
     # histogram of regression diffs
@@ -59,5 +61,5 @@ if __name__ == "__main__":
 
     sns.despine()
     plt.tight_layout()
-    plt.savefig("Regression diffs")
+    plt.savefig("./summary_figures/Regression diffs")
     plt.show()
diff --git a/analysis_state_intros.py b/analysis_state_intros.py
index 91b3250a7fee932274c8510534aa0e591912e9b7..b8435d95ce708bd5727133656af66ab64fb8fe48 100644
--- a/analysis_state_intros.py
+++ b/analysis_state_intros.py
@@ -44,14 +44,14 @@ def type_hist(data, title=''):
     plt.ylabel("Type 3", size=fontsize)
     ax2.set_yticks([])
 
-    plt.savefig(title)
+    plt.savefig("./summary_figures/" + title)
     plt.show()
 
 
 if __name__ == "__main__":
-    all_intros = pickle.load(open("all_intros.p", 'rb'))
-    all_intros_div = pickle.load(open("all_intros_div.p", 'rb'))
-    all_states_per_type = pickle.load(open("all_states_per_type.p", 'rb'))
+    all_intros = np.array(pickle.load(open("all_intros.p", 'rb')))
+    all_intros_div = np.array(pickle.load(open("all_intros_div.p", 'rb')))
+    all_states_per_type = np.array(pickle.load(open("all_states_per_type.p", 'rb')))
 
     # There are 5 mice with 0 type 2 intros, but only 3 mice with no type 2 stats.
     # That is because they come up, but don't explain the necessary 50% to start phase 2.
diff --git a/behavioral_state_data_easier.py b/behavioral_state_data_easier.py
index 18a9de78e8aea1c241ff6a474e44e58a204d84c7..aed7c5dcc445fadffa0b0ea1f319ba90aafa67f4 100644
--- a/behavioral_state_data_easier.py
+++ b/behavioral_state_data_easier.py
@@ -5,6 +5,8 @@ import pandas as pd
 import seaborn as sns
 import pickle
 import json
+import os
+import re
 
 one = ONE()
 
@@ -34,20 +36,48 @@ misses = []
 to_introduce = [2, 3, 4, 5]
 
 amiss = ['UCLA034', 'UCLA036', 'UCLA037', 'PL015', 'PL016', 'PL017', 'PL024', 'NR_0017', 'NR_0019', 'NR_0020', 'NR_0021', 'NR_0027']
-subjects = ['ZFM-04019', 'ZFM-05236']
 fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
 if fit_type == 'bias':
     loading_info = json.load(open("canonical_infos_bias.json", 'r'))
 elif fit_type == 'prebias':
     loading_info = json.load(open("canonical_infos.json", 'r'))
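+# bwm: presumably the IBL brain-wide-map animals; subjects on this list are skipped in the loop below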
+bwm = ['NYU-11', 'NYU-12', 'NYU-21', 'NYU-27', 'NYU-30', 'NYU-37',
+    'NYU-39', 'NYU-40', 'NYU-45', 'NYU-46', 'NYU-47', 'NYU-48',
+    'CSHL045', 'CSHL047', 'CSHL049', 'CSHL051', 'CSHL052', 'CSHL053',
+    'CSHL054', 'CSHL055', 'CSHL058', 'CSHL059', 'CSHL060', 'UCLA005',
+    'UCLA006', 'UCLA011', 'UCLA012', 'UCLA014', 'UCLA015', 'UCLA017',
+    'UCLA033', 'UCLA034', 'UCLA035', 'UCLA036', 'UCLA037', 'KS014',
+    'KS016', 'KS022', 'KS023', 'KS042', 'KS043', 'KS044', 'KS045',
+    'KS046', 'KS051', 'KS052', 'KS055', 'KS084', 'KS086', 'KS091',
+    'KS094', 'KS096', 'DY_008', 'DY_009', 'DY_010', 'DY_011', 'DY_013',
+    'DY_014', 'DY_016', 'DY_018', 'DY_020', 'PL015', 'PL016', 'PL017',
+    'PL024', 'SWC_042', 'SWC_043', 'SWC_060', 'SWC_061', 'SWC_066',
+    'ZFM-01576', 'ZFM-01577', 'ZFM-01592', 'ZFM-01935', 'ZFM-01936',
+    'ZFM-01937', 'ZFM-02368', 'ZFM-02369', 'ZFM-02370', 'ZFM-02372',
+    'ZFM-02373', 'ZM_1897', 'ZM_1898', 'ZM_2240', 'ZM_2241', 'ZM_2245',
+    'ZM_3003', 'SWC_038', 'SWC_039', 'SWC_052', 'SWC_053', 'SWC_054',
+    'SWC_058', 'SWC_065', 'NR_0017', 'NR_0019', 'NR_0020', 'NR_0021',
+    'NR_0027', 'ibl_witten_13', 'ibl_witten_17', 'ibl_witten_18',
+    'ibl_witten_19', 'ibl_witten_20', 'ibl_witten_25', 'ibl_witten_26',
+    'ibl_witten_27', 'ibl_witten_29', 'CSH_ZAD_001', 'CSH_ZAD_011',
+    'CSH_ZAD_019', 'CSH_ZAD_022', 'CSH_ZAD_024', 'CSH_ZAD_025',
+    'CSH_ZAD_026', 'CSH_ZAD_029']
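+# build the subject list from the canonical_result fit files present in multi_chain_saves instead of a hard-coded list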
+regexp = re.compile(r'canonical_result_((\w|-)+)_prebias\.p')
+subjects = []
+for filename in os.listdir("./multi_chain_saves/"):
+    if not (filename.startswith('canonical_result_') and filename.endswith('.p')):
+        continue
+    result = regexp.search(filename)
+    if result is None:
+        continue
+    subject = result.group(1)
+    subjects.append(subject)
 already_fit = list(loading_info.keys())
 
-remaining_subs = [s for s in subjects if s not in amiss and s not in already_fit]
-print(remaining_subs)
-
+# remaining_subs = [s for s in subjects if s not in amiss and s not in already_fit]
+# print(remaining_subs)
 
-
-data_folder = 'session_data_test'
+data_folder = 'session_data'
 
 old_style = False
 if old_style:
@@ -64,22 +94,30 @@ names = []
 
 pre_bias = []
 entire_training = []
+training_status_reached = []
+actually_existing = []
 for subject in subjects:
+    if subject in bwm:
+        continue
     print('_____________________')
     print(subject)
 
     # if subject in already_fit or subject in amiss:
     #     continue
-
-    trials = one.load_aggregate('subjects', subject, '_ibl_subjectTrials.table')
+    try:
+        trials = one.load_aggregate('subjects', subject, '_ibl_subjectTrials.table')
 
     # Load training status and join to trials table
-    training = one.load_aggregate('subjects', subject, '_ibl_subjectTraining.table')
-    quit()
-    trials = (trials
-              .set_index('session')
-              .join(training.set_index('session'))
-              .sort_values(by='session_start_time', kind='stable'))
+    
+        training = one.load_aggregate('subjects', subject, '_ibl_subjectTraining.table')
+
+        trials = (trials
+                  .set_index('session')
+                  .join(training.set_index('session'))
+                  .sort_values(by='session_start_time', kind='stable'))
+        actually_existing.append(subject)
+    except Exception:  # skip subjects whose aggregate trials/training tables cannot be loaded
+        continue
 
     start_times, indices = np.unique(trials.session_start_time, return_index=True)
     start_times = [trials.session_start_time[index] for index in sorted(indices)]
@@ -99,16 +137,19 @@ for subject in subjects:
     easy_per = np.zeros(len(eids))
     hard_per = np.zeros(len(eids))
     bias_start = 0
+    ephys_start = 0
 
     info_dict = {'subject': subject, 'dates': [st.to_pydatetime() for st in start_times], 'eids': eids}
     contrast_set = {0, 1, 9, 10}
 
     rel_count = -1
-    quit()
+
     for i, start_time in enumerate(start_times):
 
         rel_count += 1
 
+        assert rel_count == i
+
         df = trials[trials.session_start_time == start_time]
         df.loc[:, 'contrastRight'] = df.loc[:, 'contrastRight'].fillna(0)
         df.loc[:, 'contrastLeft'] = df.loc[:, 'contrastLeft'].fillna(0)
@@ -137,9 +178,15 @@ for subject in subjects:
             bias_start = i
             print("bias start {}".format(rel_count))
             info_dict['bias_start'] = rel_count
+            training_status_reached.append(set(df.training_status))
             if bias_start < 33:
                 short_subjs.append(subject)
 
+        if ephys_start == 0 and df.task_protocol[0].startswith('_iblrig_tasks_ephysChoiceWorld'):
+            ephys_start = i
+            print("ephys start {}".format(rel_count))
+            info_dict['ephys_start'] = rel_count
+
         pickle.dump(df, open("./{}/{}_df_{}.p".format(data_folder, subject, rel_count), "wb"))
 
         side_info = np.zeros((len(df), 2))
diff --git a/canonical_infos.json b/canonical_infos.json
index 99ae626eb5a4f9444331e2cd47bdbe9b3c2131ce..9e26dfeeb6e641a33dae4961196235bdb965b21b 100644
--- a/canonical_infos.json
+++ b/canonical_infos.json
@@ -1 +1 @@
-{"NYU-45": {"seeds": ["513", "503", "500", "506", "502", "501", "512", "509", "507", "515", "510", "505", "504", "514", "511", "508"], "fit_nums": ["768", "301", "96", "731", "879", "989", "915", "512", "295", "48", "157", "631", "666", "334", "682", "714"], "chain_num": 14}, "UCLA035": {"seeds": ["512", "509", "500", "510", "501", "503", "506", "513", "505", "504", "507", "502", "514", "508", "511", "515"], "fit_nums": ["715", "834", "996", "656", "242", "883", "870", "959", "483", "94", "864", "588", "390", "173", "967", "871"], "chain_num": 14}, "NYU-30": {"seeds": ["505", "508", "515", "507", "504", "513", "512", "503", "509", "500", "510", "514", "501", "502", "511", "506"], "fit_nums": ["885", "637", "318", "98", "209", "171", "472", "823", "956", "89", "762", "260", "76", "319", "139", "785"], "chain_num": 14}, "CSHL047": {"seeds": ["509", "510", "503", "502", "501", "508", "514", "505", "507", "511", "515", "504", "500", "506", "513", "512"], "fit_nums": ["60", "589", "537", "3", "178", "99", "877", "381", "462", "527", "6", "683", "771", "950", "294", "252"], "chain_num": 14}, "NYU-39": {"seeds": ["515", "502", "513", "508", "514", "503", "510", "506", "509", "504", "511", "500", "512", "501", "507", "505"], "fit_nums": ["722", "12", "207", "378", "698", "928", "15", "180", "650", "334", "388", "528", "608", "593", "988", "479"], "chain_num": 14}, "NYU-37": {"seeds": ["508", "509", "506", "512", "507", "503", "515", "504", "500", "505", "513", "501", "510", "511", "502", "514"], "fit_nums": ["94", "97", "793", "876", "483", "878", "886", "222", "66", "59", "601", "994", "526", "694", "304", "615"], "chain_num": 14}, "KS045": {"seeds": ["501", "514", "500", "502", "503", "505", "510", "515", "508", "507", "509", "512", "511", "506", "513", "504"], "fit_nums": ["731", "667", "181", "609", "489", "555", "995", "19", "738", "1", "267", "653", "750", "332", "218", "170"], "chain_num": 14}, "UCLA006": {"seeds": ["509", "505", "513", "515", "511", "500", "503", "507", "508", "514", "504", "510", "502", "501", "506", "512"], "fit_nums": ["807", "849", "196", "850", "293", "874", "216", "542", "400", "632", "781", "219", "331", "730", "740", "32"], "chain_num": 14}, "UCLA033": {"seeds": ["507", "505", "501", "512", "502", "510", "513", "514", "506", "515", "509", "504", "508", "500", "503", "511"], "fit_nums": ["366", "18", "281", "43", "423", "877", "673", "146", "921", "21", "353", "267", "674", "113", "905", "252"], "chain_num": 14}, "NYU-40": {"seeds": ["513", "500", "503", "507", "515", "504", "510", "508", "505", "514", "502", "512", "501", "509", "511", "506"], "fit_nums": ["656", "826", "657", "634", "861", "347", "334", "227", "747", "834", "460", "191", "489", "458", "24", "346"], "chain_num": 14}, "NYU-46": {"seeds": ["507", "509", "512", "508", "503", "500", "515", "501", "511", "506", "502", "505", "510", "513", "514", "504"], "fit_nums": ["503", "523", "5", "819", "190", "917", "707", "609", "145", "416", "376", "603", "655", "271", "223", "149"], "chain_num": 14}, "KS044": {"seeds": ["513", "503", "507", "504", "502", "515", "501", "511", "510", "500", "512", "508", "509", "506", "514", "505"], "fit_nums": ["367", "656", "73", "877", "115", "627", "610", "772", "558", "581", "398", "267", "353", "779", "393", "473"], "chain_num": 14}, "NYU-48": {"seeds": ["502", "507", "503", "515", "513", "505", "512", "501", "510", "504", "508", "511", "506", "514", "500", "509"], "fit_nums": ["854", "52", "218", "963", "249", "901", "248", "322", "566", "768", "256", "101", "303", "485", 
"577", "141"], "chain_num": 14}, "UCLA012": {"seeds": ["508", "506", "504", "513", "512", "502", "507", "505", "510", "503", "515", "509", "511", "501", "500", "514"], "fit_nums": ["492", "519", "577", "417", "717", "60", "130", "186", "725", "83", "841", "65", "441", "534", "856", "735"], "chain_num": 14}, "KS084": {"seeds": ["515", "507", "512", "503", "506", "508", "510", "502", "509", "505", "500", "511", "504", "501", "513", "514"], "fit_nums": ["816", "374", "140", "955", "399", "417", "733", "149", "300", "642", "644", "248", "324", "830", "889", "286"], "chain_num": 14}, "CSHL052": {"seeds": ["502", "508", "509", "514", "507", "515", "506", "500", "501", "505", "504", "511", "513", "512", "503", "510"], "fit_nums": ["572", "630", "784", "813", "501", "738", "517", "461", "690", "203", "40", "202", "412", "755", "837", "917"], "chain_num": 14}, "NYU-11": {"seeds": ["502", "510", "501", "513", "512", "500", "503", "515", "506", "514", "511", "505", "509", "504", "508", "507"], "fit_nums": ["650", "771", "390", "185", "523", "901", "387", "597", "57", "624", "12", "833", "433", "58", "276", "248"], "chain_num": 14}, "KS051": {"seeds": ["502", "505", "508", "515", "512", "511", "501", "509", "500", "506", "513", "503", "504", "507", "514", "510"], "fit_nums": ["620", "548", "765", "352", "402", "699", "370", "445", "159", "746", "449", "342", "642", "204", "726", "605"], "chain_num": 14}, "NYU-27": {"seeds": ["507", "513", "501", "505", "511", "504", "500", "514", "502", "509", "503", "515", "512", "510", "508", "506"], "fit_nums": ["485", "520", "641", "480", "454", "913", "526", "705", "138", "151", "962", "24", "21", "743", "119", "699"], "chain_num": 14}, "UCLA011": {"seeds": ["513", "503", "508", "501", "510", "505", "506", "511", "507", "514", "515", "500", "509", "502", "512", "504"], "fit_nums": ["957", "295", "743", "795", "643", "629", "142", "174", "164", "21", "835", "338", "368", "341", "209", "68"], "chain_num": 14}, "NYU-47": {"seeds": ["515", "504", "510", "503", "505", "506", "502", "501", "507", "500", "514", "513", "508", "509", "511", "512"], "fit_nums": ["169", "941", "329", "51", "788", "654", "224", "434", "385", "46", "712", "84", "930", "571", "273", "312"], "chain_num": 14}, "CSHL045": {"seeds": ["514", "502", "510", "507", "512", "501", "511", "506", "508", "509", "503", "513", "500", "504", "515", "505"], "fit_nums": ["862", "97", "888", "470", "620", "765", "874", "421", "104", "909", "924", "874", "158", "992", "25", "40"], "chain_num": 14}, "UCLA017": {"seeds": ["510", "502", "500", "515", "503", "507", "514", "501", "511", "505", "513", "504", "508", "512", "506", "509"], "fit_nums": ["875", "684", "841", "510", "209", "207", "806", "700", "989", "899", "812", "971", "526", "887", "160", "249"], "chain_num": 14}, "CSHL055": {"seeds": ["509", "503", "505", "512", "504", "508", "510", "511", "507", "501", "500", "514", "513", "515", "506", "502"], "fit_nums": ["957", "21", "710", "174", "689", "796", "449", "183", "193", "209", "437", "827", "990", "705", "540", "835"], "chain_num": 14}, "UCLA005": {"seeds": ["507", "503", "512", "515", "505", "500", "504", "513", "509", "506", "501", "511", "514", "502", "510", "508"], "fit_nums": ["636", "845", "712", "60", "733", "789", "990", "230", "335", "337", "307", "404", "297", "608", "428", "108"], "chain_num": 14}, "CSHL060": {"seeds": ["510", "515", "513", "503", "514", "501", "504", "502", "511", "500", "508", "509", "505", "507", "506", "512"], "fit_nums": ["626", "953", "497", "886", "585", "293", "580", 
"867", "113", "734", "88", "55", "949", "443", "210", "555"], "chain_num": 14}, "UCLA015": {"seeds": ["503", "501", "502", "500"], "fit_nums": ["877", "773", "109", "747"], "chain_num": 14}, "KS055": {"seeds": ["510", "502", "511", "501", "509", "506", "500", "515", "503", "512", "504", "505", "513", "508", "514", "507"], "fit_nums": ["526", "102", "189", "216", "673", "477", "293", "981", "960", "883", "899", "95", "31", "244", "385", "631"], "chain_num": 14}, "UCLA014": {"seeds": ["513", "509", "508", "510", "500", "504", "515", "511", "501", "514", "507", "502", "505", "512", "506", "503"], "fit_nums": ["628", "950", "418", "91", "25", "722", "792", "225", "287", "272", "23", "168", "821", "934", "194", "481"], "chain_num": 14}, "CSHL053": {"seeds": ["505", "513", "501", "508", "502", "515", "510", "507", "506", "509", "514", "511", "500", "503", "504", "512"], "fit_nums": ["56", "674", "979", "0", "221", "577", "25", "679", "612", "185", "464", "751", "648", "715", "344", "348"], "chain_num": 14}, "NYU-12": {"seeds": ["508", "506", "509", "515", "511", "510", "501", "504", "513", "503", "507", "512", "500", "514", "505", "502"], "fit_nums": ["853", "819", "692", "668", "730", "213", "846", "596", "644", "829", "976", "895", "974", "824", "179", "769"], "chain_num": 14}, "KS043": {"seeds": ["505", "504", "507", "500", "502", "503", "508", "511", "512", "509", "515", "513", "510", "514", "501", "506"], "fit_nums": ["997", "179", "26", "741", "476", "502", "597", "477", "511", "181", "233", "330", "299", "939", "542", "113"], "chain_num": 14}, "CSHL058": {"seeds": ["513", "505", "514", "506", "504", "500", "511", "503", "508", "509", "501", "510", "502", "515", "507", "512"], "fit_nums": ["752", "532", "949", "442", "400", "315", "106", "419", "903", "198", "553", "158", "674", "249", "723", "941"], "chain_num": 14}, "KS042": {"seeds": ["503", "502", "506", "511", "501", "505", "512", "509", "513", "508", "507", "500", "510", "504", "515", "514"], "fit_nums": ["241", "895", "503", "880", "283", "267", "944", "204", "921", "514", "392", "241", "28", "905", "334", "894"], "chain_num": 14}}
\ No newline at end of file
+{}
\ No newline at end of file
diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py
index 1b8062228711583bf6024e9b68935547ab36d6ce..37bb5dadb5fc41fd30fa438a8db40c146b2595ca 100644
--- a/dyn_glm_chain_analysis.py
+++ b/dyn_glm_chain_analysis.py
@@ -22,7 +22,24 @@ import multiprocessing as mp
 from mcmc_chain_analysis import state_num_helper, ll_func, r_hat_array_comp, rank_inv_normal_transform
 import pandas as pd
 from analysis_pmf import pmf_type, type2color
+import re
 
+performance_points = np.array([-1, -1, 0, 0])
+
+def pmf_to_perf(pmf):
+    # determine performance of a pmf, but only on the omnipresent strongest contrasts
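+    # |p - 1| on the two strongest left contrasts and |p| on the two strongest right contrasts are per-contrast accuracies, so this is mean accuracy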
+    return np.mean(np.abs(performance_points + pmf[[0, 1, -2, -1]]))
+
+# def pmf_type_temp(pmf):
+#     rew = pmf_to_perf(pmf)
+#     if rew < 0.6:
+#         return 0
+#     elif rew < 0.7827:
+#         if np.abs(pmf[0] + pmf[-1] - 1) <= 0.1:
+#             return 3
+#         return 1
+#     else:
+#         return 2
 
 colors = np.genfromtxt('colors.csv', delimiter=',')
 
@@ -45,15 +62,6 @@ def weights_to_pmf(weights, with_bias=1):
     psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1]
     return 1 / (1 + np.exp(psi))  # we somehow got the answers twisted, so we drop the minus here to get the opposite response probability for plotting
 
-performance_points = np.array([-1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0])
-def pmf_to_perf(pmf, def_points):
-    # determine performance of a pmf
-    # we use this to determine regressions in the behaviour of animals
-    # therefore, we exclude 0 as performance on it is 0.5 regardless of PMF, but it might
-    # overall lower performance. The removal of 0.5 later might also be a problem, let's see
-    relevant_points = def_points
-    relevant_points[5] = False
-    return np.mean(np.abs(performance_points[relevant_points] + pmf[relevant_points]))
 
 
 class MCMC_result_list:
@@ -305,13 +313,13 @@ class MCMC_result_list:
 
         state_appear = []
         state_durs = []
-        state_appear_dist = np.zeros(10)
+        state_appear_dist = np.zeros(11)
         for res in self.results:
             for j in range(take_n):
                 sa, sd = state_appear_and_dur(res.models[self.m // take_n * j], self)
                 state_appear += sa
                 state_durs += sd
-                state_appear_dist[[int(s * 10) for s in sa]] += 1
+                state_appear_dist[[int(s * 11) for s in sa]] += 1
         return state_appear, state_durs, state_appear_dist / self.m / take_n
 
 
@@ -392,7 +400,6 @@ class MCMC_result:
                     state_start_save[20 * index // len(seq)] += 1
                     state_starts[session_bounds[i] + index] += 1
                     observed_states.append(s)
-        print('observed_states {}'.format(observed_states))
         plt.plot(state_starts / self.n_samples)
         return session_bounds, state_start_save / self.n_samples
 
@@ -625,6 +632,7 @@ def bias_flips(states_by_session, pmfs, durs):
 
 def pmf_regressions(states_by_session, pmfs, durs):
     # find out whether pmfs regressed
+    # return: [total # of regressions, # of sessions, # of regressions during type 1, type 2, type 3], the reward differences
     state_perfs = {}
     state_counter = {}
     current_best_state = -1
@@ -1077,7 +1085,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
                     ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + test.state_mapping[state], color=cmap(0.2 + 0.8 * j / test.results[0].n_sessions))
                     ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + test.state_mapping[state], ls='', ms=7, marker='*', color=cmap(j / test.results[0].n_sessions))
             all_pmfs.append((defined_points, pmfs))
-            all_pmf_weights += pmf_weights
+            all_pmf_weights.append(pmf_weights)
         else:
             temp = np.percentile(pmfs, [2.5, 97.5], axis=0)
             if not test.state_mapping[state] in dont_plot:
@@ -1382,7 +1390,7 @@ def state_type_durs(states, pmfs):
     # Takes states and pmfs, first creates an array of when which type is how active, then computes the number of sessions each type lasts.
     # A type lasts until a more advanced type takes up more than 50% of a session (and cannot return)
     # Returns the durations for all the different state types, and an array which holds the state percentages
-    state_types = np.zeros((3, states.shape[1]))
+    state_types = np.zeros((4, states.shape[1]))
     for s, pmf in zip(states, pmfs):
         pmf_counter = -1
         for i in range(states.shape[1]):
@@ -1427,9 +1435,9 @@ def state_cluster_interpolation(states, pmfs):
 
 def get_first_pmfs(states, pmfs):
     # get the first pmf of every type, also where they are defined, and whether they are the first pmf of that state
-    earliest_sessions = [1000, 1000, 1000]
-    first_pmfs = [0, 0, 0, 0, 0, 0, 0, 0, 0]
-    changing_pmfs = [[0, 0], [0, 0]]
+    earliest_sessions = [1000, 1000, 1000]  # high values to compare with min
+    first_pmfs = [0, 0, 0, 0, 0, 0, 0, 0, 0]  # for every type (3), we save: the pmf, the defined points, the session at which they appear
+    changing_pmfs = [[0, 0], [0, 0]]  # if the new type appears first through a slow transition, we remember its defined points and its pmf trajectory
     for state, pmf in zip(states, pmfs):
         sessions = np.where(state)[0]
         for i, (sess_pmf, sess) in enumerate(zip(pmf[1], sessions)):
@@ -1532,26 +1540,25 @@ if __name__ == "__main__":
         loading_info = json.load(open("canonical_infos_bias.json", 'r'))
     elif fit_type == 'prebias':
         loading_info = json.load(open("canonical_infos.json", 'r'))
-    no_good_pcas = ['NYU-06', 'SWC_023']
-    subjects = list(loading_info.keys())
-
-    new_subs = ['KS044', 'NYU-11', 'DY_016', 'SWC_061', 'ZFM-05245', 'CSH_ZAD_029', 'SWC_021', 'CSHL058', 'DY_014', 'DY_009', 'KS094', 'DY_018', 'KS043', 'UCLA014', 'SWC_038', 'SWC_022', 'UCLA012', 'UCLA011', 'CSHL055', 'ZFM-04019', 'NYU-45', 'ZFM-02370', 'ZFM-02373', 'ZFM-02369', 'NYU-40', 'CSHL060', 'NYU-30', 'CSH_ZAD_019', 'UCLA017', 'KS052', 'ibl_witten_25', 'ZFM-02368', 'CSHL045', 'UCLA005', 'SWC_058', 'CSH_ZAD_024', 'SWC_042', 'DY_008', 'ibl_witten_13', 'SWC_043', 'KS046', 'DY_010', 'CSHL053', 'ZM_1898', 'UCLA033', 'NYU-47', 'DY_011', 'CSHL047', 'SWC_054', 'ibl_witten_19', 'ibl_witten_27', 'KS091', 'KS055', 'CSH_ZAD_017', 'UCLA035', 'SWC_060', 'DY_020', 'ZFM-01577', 'ZM_2240', 'ibl_witten_29', 'KS096', 'SWC_066', 'DY_013', 'ZFM-01592', 'GLM_Sim_17', 'NYU-48', 'UCLA006', 'NYU-39', 'KS051', 'NYU-27', 'NYU-46', 'ZFM-01936', 'ZFM-02372', 'ZFM-01935', 'ibl_witten_26', 'ZFM-05236', 'ZM_2241', 'NYU-37', 'KS086', 'KS084', 'ZFM-01576', 'KS042']
-
-    miss_count = 0
-    for s in new_subs:
-        if s not in subjects:
-            print(s)
-            miss_count += 1
-    quit()
+    subjects = []
+    regexp = re.compile(r'canonical_result_((\w|-)+)_prebias\.p')
+    for filename in os.listdir("./multi_chain_saves/"):
+        if not (filename.startswith('canonical_result_') and filename.endswith('.p')):
+            continue
+        result = regexp.search(filename)
+        if result is None:
+            continue
+        subject = result.group(1)
+        subjects.append(subject)
 
-    print(subjects)
+    print(len(subjects))
     fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0]
     dur = 'yes'
 
     # fig, ax = plt.subplots(1, 3, sharey=True, figsize=(16, 9))
 
     pop_state_starts = np.zeros(20)
-    state_appear_dist = np.zeros(10)
+    state_appear_dist = np.zeros(11)
     state_appear_mode = []
     num_states = []
     num_sessions = []
@@ -1581,6 +1588,9 @@ if __name__ == "__main__":
     regression_diffs = []
     all_bias_flips = []
     all_pmf_weights = []
+    all_weight_trajectories = []
+    bias_sessions = []
+    first_and_last_pmf = []
 
     new_counter, transform_counter = 0, 0
     state_types_interpolation = np.zeros((3, 150))
@@ -1589,17 +1599,29 @@ if __name__ == "__main__":
     state_nums_5 = []
     state_nums_10 = []
 
-    for subject in subjects:
+    ultimate_counter = 0
 
-        # NYU-11 is quite weird, super errrativ behaviour, has all contrasts introduced at once, no good session at end
-        if subject.startswith('GLM_Sim_'):
+    for subject in subjects:
+        if subject.startswith('GLM_Sim_') or subject in ['SWC_065', 'ZFM-05245', 'ZFM-04019', 'ibl_witten_18']:
+            # ibl_witten_18 is a weird one: a super good session in the middle ends phase 1, never to re-appear, and behaviour is bad at the end
+            # ZFM-05245 is a neuromodulator mouse and seemingly never reaches ephys; same for ZFM-04019
+            # SWC_065 never reaches type 3
             continue
 
         print()
         print(subject)
+        continue
 
         test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb'))
 
+        # all_state_starts = test.state_appearance_posterior(subject)
+        # pop_state_starts += all_state_starts
+
+        # a, b, c = test.state_start_and_dur()
+        # state_appear += a
+        # state_dur += b
+        # state_appear_dist += c
+
         mode_specifier = 'first'
         try:
             mode_indices = pickle.load(open("multi_chain_saves/{}_mode_indices_{}_{}.p".format(mode_specifier, subject, fit_type), 'rb'))
@@ -1609,6 +1631,9 @@ if __name__ == "__main__":
                 mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb'))
                 state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb'))
             except:
+                print("____________________________________")
+                print("Something quite wrong with {}".format(subject))
+                print("____________________________________")
                 continue
 
         # lapse differential
@@ -1619,18 +1644,37 @@ if __name__ == "__main__":
         # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 1', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(7)), plot_until=2)
         # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 2', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(6)), plot_until=7)
         # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 3', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(4)), plot_until=13)
-        states, pmfs, pmf_weights, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True)
-        quit()
+        states, pmfs, pmf_weights, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save=False, show=0, separate_pmf=1, type_coloring=True)
+        
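+        # record the weights of the dominant state in the first session (at its first appearance) and in the last session (at its last appearance)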
+        first_and_last_pmf.append((pmf_weights[np.argmax(states[:, 0])][0], pmf_weights[np.argmax(states[:, -1])][-1]))
+
+        continue
+        # all_weight_trajectories += pmf_weights
+        abs_state_durs.append(durs)
+        ultimate_counter += 1
+
+        info_dict = pickle.load(open("./{}/{}_info_dict.p".format('session_data', subject), "rb"))
+        if 'ephys_start' in info_dict:
+            bias_sessions.append(info_dict['ephys_start'] - info_dict['bias_start'])
+        else:
+            bias_sessions.append(info_dict['n_sessions'] - info_dict['bias_start'])
+            print(subject, info_dict['n_sessions'], info_dict['bias_start'])
+
+        continue
+
+        state_types_interpolation[0] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[0])
+        state_types_interpolation[1] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[1])
+        state_types_interpolation[2] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[2])
         # state_nums_5.append((states > 0.05).sum(0))
         # state_nums_10.append((states > 0.1).sum(0))
         # continue
         # a = compare_params(test, 26, [s for s in state_sets if len(s) > 40], mode_indices, [3, 5])
         # compare_pmfs(test, [3, 2, 4], states, pmfs, title="{} convergence pmf".format(subject))
-        consistencies = pickle.load(open("multi_chain_saves/first_mode_consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
-        consistencies /= consistencies[0, 0]
-        temp = contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies, CMF=False)
-        continue
-        quit()
+        # consistencies = pickle.load(open("multi_chain_saves/first_mode_consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
+        # consistencies /= consistencies[0, 0]
+        # temp = contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies, CMF=False)
+        # quit()
+
         new = type_2_appearance(states, pmfs)
 
         if new == 2:
@@ -1644,60 +1688,50 @@ if __name__ == "__main__":
         print(new_counter, transform_counter)
 
 
-        state_dict = write_results(test, [s for s in state_sets if len(s) > 40], mode_indices)
-        pickle.dump(state_dict, open("state_dict_{}".format(subject), 'wb'))
-        quit()
-        # abs_state_durs.append(durs)
-        # continue
-        # all_pmf_weights += pmf_weights
-        # all_state_types.append(state_types)
-        #
-        # state_types_interpolation[0] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[0])
-        # state_types_interpolation[1] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[1])
-        # state_types_interpolation[2] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[2])
-        #
-        # all_first_pmfs_typeless[subject] = []
-        # for pmf in pmfs:
-        #     all_first_pmfs_typeless[subject].append((pmf[0], pmf[1][0]))
-        #     all_pmfs.append(pmf)
-        #
-        # first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs)
-        # for pmf in changing_pmfs:
-        #     if type(pmf[0]) == int:
-        #         continue
-        #     all_changing_pmf_names.append(subject)
-        #     all_changing_pmfs.append(pmf)
-        #
-        # all_first_pmfs[subject] = first_pmfs
-        # continue
-        #
-        # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
-        # consistencies /= consistencies[0, 0]
-        # contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=False)
+        # state_dict = write_results(test, [s for s in state_sets if len(s) > 40], mode_indices)
+        # pickle.dump(state_dict, open("state_dict_{}".format(subject), 'wb'))
+
+        all_pmf_weights += [item for sublist in pmf_weights for item in sublist]
+        all_state_types.append(state_types)
+        
+        all_first_pmfs_typeless[subject] = []
+        for pmf in pmfs:
+            all_first_pmfs_typeless[subject].append((pmf[0], pmf[1][0]))
+            all_pmfs.append(pmf)
+        
+        first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs)
+        for pmf in changing_pmfs:
+            if type(pmf[0]) == int:
+                continue
+            all_changing_pmf_names.append(subject)
+            all_changing_pmfs.append(pmf)
+        
+        all_first_pmfs[subject] = first_pmfs
 
+        quit()
 
         # b_flips = bias_flips(states, pmfs, durs)
         # all_bias_flips.append(b_flips)
 
-        # regression, diffs = pmf_regressions(states, pmfs, durs)
-        # regression_diffs += diffs
-        # regressions.append(regression)
-        # continue
+        regression, diffs = pmf_regressions(states, pmfs, durs)
+        regression_diffs += diffs
+        regressions.append(regression)
+
         # compare_pmfs(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, pmfs, title="{} convergence pmf".format(subject))
         # compare_weights(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, title="{} convergence weights".format(subject))
-        # quit()
-        # all_intros.append(undiv_intros)
-        # all_intros_div.append(intros_by_type)
-        # if states_per_type != []:
-        #     all_states_per_type.append(states_per_type)
-        #
-        # intros_by_type_sum += intros_by_type
-        # for pmf in pmfs:
-        #     all_pmfs.append(pmf)
-        #     for p in pmf[1]:
-        #         all_pmf_diffs.append(p[-1] - p[0])
-        #         all_pmf_asymms.append(np.abs(p[0] + p[-1] - 1))
-        # contrast_intro_types.append(contrast_intro_type)
+
+        all_intros.append(undiv_intros)
+        all_intros_div.append(intros_by_type)
+        if states_per_type != []:
+            all_states_per_type.append(states_per_type)
+        
+        intros_by_type_sum += intros_by_type
+        for pmf in pmfs:
+            all_pmfs.append(pmf)
+            for p in pmf[1]:
+                all_pmf_diffs.append(p[-1] - p[0])
+                all_pmf_asymms.append(np.abs(p[0] + p[-1] - 1))
+        contrast_intro_types.append(contrast_intro_type)
 
         # num_states.append(np.mean(test.state_num_dist()))
         # num_sessions.append(test.results[0].n_sessions)
@@ -1745,13 +1779,7 @@ if __name__ == "__main__":
         #         quit()
         # quit()
 
-        # all_state_starts = test.state_appearance_posterior(subject)
-        # pop_state_starts += all_state_starts
         #
-        # a, b, c = test.state_start_and_dur()
-        # state_appear += a
-        # state_dur += b
-        # state_appear_dist += c
 
 
         # all_state_starts = test.state_appearance_posterior(subject)
@@ -1759,7 +1787,6 @@ if __name__ == "__main__":
         # plt.savefig("temp")
         # plt.close()
 
-
     # pickle.dump(all_first_pmfs, open("all_first_pmfs.p", 'wb'))
     # pickle.dump(all_changing_pmfs, open("changing_pmfs.p", 'wb'))
     # pickle.dump(all_changing_pmf_names, open("changing_pmf_names.p", 'wb'))
@@ -1774,24 +1801,50 @@ if __name__ == "__main__":
     # pickle.dump(all_state_types, open("all_state_types.p", 'wb'))
     # pickle.dump(all_pmf_weights, open("all_pmf_weights.p", 'wb'))
     # pickle.dump(state_types_interpolation, open("state_types_interpolation.p", 'wb'))
+    # pickle.dump(state_types_interpolation, open("state_types_interpolation_4_states.p", 'wb'))  # special version, might not want to use
     # abs_state_durs = np.array(abs_state_durs)
     # pickle.dump(abs_state_durs, open("multi_chain_saves/abs_state_durs.p", 'wb'))
+    # pickle.dump(pop_state_starts, open("multi_chain_saves/pop_state_starts.p", 'wb'))
+    # pickle.dump(state_appear, open("multi_chain_saves/state_appear.p", 'wb'))
+    # pickle.dump(state_dur, open("multi_chain_saves/state_dur.p", 'wb'))
+    # pickle.dump(state_appear_dist, open("multi_chain_saves/state_appear_dist.p", 'wb'))
+    # pickle.dump(all_weight_trajectories, open("multi_chain_saves/all_weight_trajectories.p", 'wb'))
+    # pickle.dump(bias_sessions, open("multi_chain_saves/bias_sessions.p", 'wb'))
+    # pickle.dump(first_and_last_pmf, open("multi_chain_saves/first_and_last_pmf.p", 'wb'))
+    
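+    # load previously saved results rather than recomputing them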
+    abs_state_durs = pickle.load(open("multi_chain_saves/abs_state_durs.p", 'rb'))
+    bias_sessions = pickle.load(open("multi_chain_saves/bias_sessions.p", 'rb'))
+
+    print("Ultimate count is {}".format(ultimate_counter))
+    quit()  # stop here; the analysis blocks below are all disabled (if False)
 
-    if True:
+    if False:
         abs_state_durs = pickle.load(open("multi_chain_saves/abs_state_durs.p", 'rb'))
 
         print("Correlations")
         from scipy.stats import pearsonr
         print(pearsonr(abs_state_durs[:, 0], abs_state_durs[:, 1]))
+        # (0.30202072153268694, 0.0011496023265844494)
         print(pearsonr(abs_state_durs[:, 0], abs_state_durs[:, 2]))
+        # (-0.008129156810930915, 0.9318998938147801)
         print(pearsonr(abs_state_durs[:, 1], abs_state_durs[:, 2]))
+        # (0.20293276001562754, 0.03110687851587209)
 
         print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 0]))
-        # (0.7338297529946006, 2.6332570579118393e-06)
+        # (0.3930315541710777, 1.6612856471614082e-05)
         print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 1]))
-        # (0.35094585023228597, 0.052897046343413114)
+        # (0.49183970714755426, 3.16067121387928e-08)
         print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 2]))
-        # (0.7210260323745921, 4.747833912452452e-06)
+        # (0.8936241158982767, 2.0220064236645623e-40)
+
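+        # correlate each state-type duration (and their total) with bias_sessions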
+        print(pearsonr(abs_state_durs[:, 0], bias_sessions))
+        # (-0.048942358403448086, 0.6099754179426583)
+        print(pearsonr(abs_state_durs[:, 1], bias_sessions))
+        # (-0.07904742327893216, 0.4095531301939973)
+        print(pearsonr(abs_state_durs[:, 2], bias_sessions))
+        # (-0.0697415785143183, 0.4670166281153171)
+        print(pearsonr(abs_state_durs.sum(1), bias_sessions))
+        # (-0.09311057488563208, 0.3310557265430209)
 
         from simplex_plot import plotSimplex
 
@@ -1804,9 +1857,11 @@ if __name__ == "__main__":
         plt.ylabel("# of mice", size=40)
         plt.xlabel('# of sessions', size=40)
         plt.tight_layout()
-        plt.savefig("session_num_hist.png", dpi=300, transparent=True)
+        plt.savefig("./summary_figures/session_num_hist.png", dpi=300, transparent=True)
         plt.show()
 
+        quit()
+
     if False:
         ax[0].set_ylim(0, 1)
         ax[1].set_ylim(0, 1)
@@ -1850,11 +1905,12 @@ if __name__ == "__main__":
         # ax[1].legend()
         # ax[2].legend()
         plt.tight_layout()
-        plt.savefig("first_knowledge_state", dpi=300)
+        plt.savefig("./summary_figures/first_knowledge_state", dpi=300)
         plt.show()
         quit()
 
     if False:
+        pop_state_starts = pickle.load(open("multi_chain_saves/pop_state_starts.p", 'rb'))
         f, (ax, ax2) = plt.subplots(2, 1, sharex=True, figsize=(9, 9))
 
         # plot the same data on both axes
@@ -1864,14 +1920,14 @@ if __name__ == "__main__":
         ax2.bar(np.linspace(0, 1, 20), pop_state_starts, align='edge', width=1/19, color='grey')
 
         # zoom-in / limit the view to different portions of the data
-        ax.set_ylim(24.75, 28)
-        ax2.set_ylim(0, 3.25)
+        ax.set_ylim(245, 300)
+        ax2.set_ylim(0, 25)
 
-        ax.set_yticks([25, 27])
-        ax.set_yticklabels([25, 27])
-        ax2.set_yticks([0, 1, 2, 3])
+        ax.set_yticks([250, 275])
+        ax.set_yticklabels([250, 275])
+        ax2.set_yticks([0, 8, 16, 24])
         ax2.set_xticks([0, .25, .5, .75, 1])
-        ax2.set_yticklabels([0, 1, 2, 3])
+        ax2.set_yticklabels([0, 8, 16, 24])
         ax2.set_xticklabels([0, .25, .5, .75, 1])
 
         # hide the spines between ax and ax2
@@ -1896,13 +1952,13 @@ if __name__ == "__main__":
         ax2.plot((-d, +d), (1 - d, 1 + d), **kwargs)  # bottom-left diagonal
 
         plt.tight_layout()
-        plt.savefig('states within session', dpi=300)
+        plt.savefig('./summary_figures/states within session', dpi=300)
         plt.close()
 
     if False:
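+        # assumption: load the saved distribution here, mirroring the pop_state_starts block above
+        state_appear_dist = pickle.load(open("multi_chain_saves/state_appear_dist.p", 'rb'))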
         f, (a1, a2) = plt.subplots(1, 2, figsize=(18, 9), gridspec_kw={'width_ratios': [1.4, 1]})
-        # a1.bar(np.linspace(0, 1, 11), state_appear_dist, align='edge', width=1/10, color='grey')
-        a1.hist(state_appear_mode, color='grey')
+        a1.bar(np.linspace(0, 1, 11), state_appear_dist, align='edge', width=1/10, color='grey')
+        # a1.hist(state_appear_mode, color='grey')
         a1.set_xlim(left=0)
         # plt.title('First appearence of ', fontsize=22)
         a1.set_ylabel('# of states', fontsize=35)
@@ -1932,5 +1988,5 @@ if __name__ == "__main__":
         a2.tick_params(axis='both', labelsize=20)
         sns.despine()
         plt.tight_layout()
-        plt.savefig('states in sessions', dpi=300)
+        plt.savefig('./summary_figures/states in sessions', dpi=300)
         plt.close()
diff --git a/index_mice.py b/index_mice.py
index 96d35de5b11c000cdb713ffea432e4c8b74473e5..6c42c4d28994da619fd3a5cc63535ef9706ddb23 100644
--- a/index_mice.py
+++ b/index_mice.py
@@ -9,14 +9,14 @@ test = ['CSHL059_fittype_prebias_var_0.03_303_45_1.p', 'ibl_witten_16_fittype_pr
 prebias_subinfo = {}
 bias_subinfo = {}
 
-for filename in os.listdir("./dynamic_GLMiHMM_crossvals/"):
+for filename in test:  # os.listdir("./dynamic_GLMiHMM_crossvals/"):
     if not filename.endswith('.p'):
         continue
     regexp = re.compile(r'((\w|-)+)_fittype_(\w+)_var_0.03_(\d+)_(\d+)_(\d+)')
     result = regexp.search(filename)
     subject = result.group(1)
-    if subject == 'ibl_witten_26':
-        print('here')
+    print(subject)
+    continue  # only inspect the parsed subject names; skip the rest of the loop
     fit_type = result.group(3)
     seed = result.group(4)
     fit_num = result.group(5)
diff --git a/simplex_animation.py b/simplex_animation.py
index ddf6df7925ac502d5319f6b4a6986eb27adb6fff..dce03864bee596ead466388d2827318280b647fa 100644
--- a/simplex_animation.py
+++ b/simplex_animation.py
@@ -5,6 +5,7 @@ from analysis_pmf import type2color
 import math
 import matplotlib.pyplot as plt
 import copy
+import string
 
 
 all_state_types = pickle.load(open("all_state_types.p", 'rb'))
@@ -39,7 +40,15 @@ assert (test_count == 1).all()
 # quit()
 
 session_counter = -1
-alph = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'za', 'zb', 'zc', 'zd', 'ze', 'zf', 'zg', 'zh', 'zi', 'zj']
+string_list = []
+alphabet = string.ascii_lowercase
+
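+# build two-letter suffixes ('aa', 'ab', ...) so the saved simplex frames sort in session order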
+for char1 in alphabet:
+    for char2 in alphabet:
+        string_list.append(char1 + char2)
+    if len(string_list) >= 100:  # len grows in steps of 26, so use >= to stop once there are enough labels
+        break
+
 # do as many sessions as it takes
 while True:
 
@@ -58,7 +67,7 @@ while True:
         else:
             temp_counter = -1
 
-        assert np.sum(sts[:, temp_counter]) <= 1
+        assert np.sum(sts[:, temp_counter]) <= 1.000000000000001  # allow a tiny floating-point overshoot above 1
 
         if (sts[:, temp_counter] == 1).any():
             x_offset[i] = offsets[i][0] * 0.01
@@ -66,8 +75,8 @@ while True:
 
         type_proportions[:, i] = sts[:, temp_counter]
 
-    plotSimplex(type_proportions.T, x_offset=x_offset, y_offset=y_offset, c=np.arange(len(all_state_types)), show=False, title="Session {}".format(1 + session_counter),
-                vertexcolors=[type2color[i] for i in range(3)], vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title="simplex_{}.png".format(alph[session_counter]))
+    plotSimplex(type_proportions.T * 10, x_offset=x_offset, y_offset=y_offset, c=np.arange(len(all_state_types)), show=False, title="Session {}".format(1 + session_counter),
+                vertexcolors=[type2color[i] for i in range(3)], vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title="simplex_{}.png".format(string_list[session_counter]))
 
     if not_ended == 0:
         break
diff --git a/simplex_plot.py b/simplex_plot.py
index 420dd26c661da32a2e9f89e2fbb937edea411f25..f9f8a0703b2fddf4610826d5d09604008906a060 100644
--- a/simplex_plot.py
+++ b/simplex_plot.py
@@ -15,7 +15,7 @@ import matplotlib.patches as PA
 
 def plotSimplex(points, fig=None,
                 vertexlabels=['1: initial flat PMFs', '2: intermediate unilateral PMFs', '3: final bilateral PMFs'], title='',
-                save_title="dur_simplex.png", show=False, vertexcolors=['k', 'k', 'k'], x_offset=0, y_offset=0, **kwargs):
+                save_title="./summary_figures/dur_simplex.png", show=False, vertexcolors=['k', 'k', 'k'], x_offset=0, y_offset=0, **kwargs):
     """
     Plot Nx3 points array on the 3-simplex
     (with optionally labeled vertices)
@@ -47,12 +47,12 @@ def plotSimplex(points, fig=None,
     P.scatter(projected[:, 0], projected[:, 1], marker='*', color='r', s=np.mean(points.sum(1)) * 3.5)#s=50
 
     # Leave some buffer around the triangle for vertex labels
-    fig.gca().set_xlim(-0.05, 1.05)
-    fig.gca().set_ylim(-0.05, 1.05)
+    fig.gca().set_xlim(-0.08, 1.08)
+    fig.gca().set_ylim(-0.08, 1.08)
 
     P.axis('off')
     if title != '':
-        P.annotate(title, (0.395, np.sqrt(3) / 2 + 0.025), size=24)
+        P.annotate(title, (0.395, np.sqrt(3) / 2 + 0.075), size=24)
 
     P.tight_layout()
     P.savefig(save_title, bbox_inches='tight', dpi=300, transparent=True)