diff --git a/__pycache__/analysis_pmf.cpython-37.pyc b/__pycache__/analysis_pmf.cpython-37.pyc
index 34a4499ac8d78ac0139d1f5675718e1ab52cae6d..e1ab5f64794b712fd3d7753bfc894d24535085b8 100644
Binary files a/__pycache__/analysis_pmf.cpython-37.pyc and b/__pycache__/analysis_pmf.cpython-37.pyc differ
diff --git a/__pycache__/simplex_plot.cpython-37.pyc b/__pycache__/simplex_plot.cpython-37.pyc
index 8417f61340255c4910e14c7ab0a3f0d48670b60d..d40c9990e4c7cc9a2781a81174008f4cb3e8b339 100644
Binary files a/__pycache__/simplex_plot.cpython-37.pyc and b/__pycache__/simplex_plot.cpython-37.pyc differ
diff --git a/analysis_pmf.py b/analysis_pmf.py
index 6729289c14616f8cbf4beb722a909c8f323c8601..637f06cd68c5139c19f7171c90f5b28c4a86c15a 100644
--- a/analysis_pmf.py
+++ b/analysis_pmf.py
@@ -28,32 +28,32 @@ if __name__ == "__main__":
     state_types_interpolation = state_types_interpolation / state_types_interpolation.max() * 100
 
     fs = 18
-    plt.plot(np.linspace(0, 1, 150), state_types_interpolation[0], color=type2color[0])
-    plt.ylabel("% of type across population", size=fs)
-    plt.xlabel("Interpolated session time", size=fs)
-    plt.ylim(0, 100)
-    sns.despine()
-    plt.tight_layout()
-    plt.savefig("type hist 1")
-    plt.show()
-
-    plt.plot(np.linspace(0, 1, 150), state_types_interpolation[1], color=type2color[1])
-    plt.ylabel("% of type across population", size=fs)
-    plt.xlabel("Interpolated session time", size=fs)
-    plt.ylim(0, 100)
-    sns.despine()
-    plt.tight_layout()
-    plt.savefig("type hist 2")
-    plt.show()
-
-    plt.plot(np.linspace(0, 1, 150), state_types_interpolation[2], color=type2color[2])
-    plt.ylabel("% of type across population", size=fs)
-    plt.xlabel("Interpolated session time", size=fs)
-    plt.ylim(0, 100)
-    sns.despine()
-    plt.tight_layout()
-    plt.savefig("type hist 3")
-    plt.show()
+    # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[0], color=type2color[0])
+    # plt.ylabel("% of type across population", size=fs)
+    # plt.xlabel("Interpolated session time", size=fs)
+    # plt.ylim(0, 100)
+    # sns.despine()
+    # plt.tight_layout()
+    # plt.savefig("type hist 1")
+    # plt.show()
+    #
+    # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[1], color=type2color[1])
+    # plt.ylabel("% of type across population", size=fs)
+    # plt.xlabel("Interpolated session time", size=fs)
+    # plt.ylim(0, 100)
+    # sns.despine()
+    # plt.tight_layout()
+    # plt.savefig("type hist 2")
+    # plt.show()
+    #
+    # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[2], color=type2color[2])
+    # plt.ylabel("% of type across population", size=fs)
+    # plt.xlabel("Interpolated session time", size=fs)
+    # plt.ylim(0, 100)
+    # sns.despine()
+    # plt.tight_layout()
+    # plt.savefig("type hist 3")
+    # plt.show()
 
     all_first_pmfs_typeless = pickle.load(open("all_first_pmfs_typeless.p", 'rb'))
     all_pmfs = pickle.load(open("all_pmfs.p", 'rb'))
@@ -161,37 +161,41 @@ if __name__ == "__main__":
 
     lw = 4
     # Simplex example pmfs
-    # state_num = 7
-    # defined_points, pmf = all_first_pmfs_typeless['NYU-06'][state_num][0], all_first_pmfs_typeless['NYU-06'][state_num][1]
-    # plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
-    # state_num = 6
-    # defined_points, pmf = all_first_pmfs_typeless['CSHL061'][state_num][0], all_first_pmfs_typeless['CSHL061'][state_num][1]
-    # plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
-    #
-    # plt.ylim(0, 1)
-    # plt.xlim(0, 10)
-    # plt.yticks([])
-    # plt.xticks([])
-    # sns.despine()
-    # plt.tight_layout()
-    # plt.savefig("example type 1")
-    # plt.show()
-    #
-    # state_num = 1
-    # defined_points, pmf = all_first_pmfs_typeless['CSHL_018'][state_num][0], all_first_pmfs_typeless['CSHL_018'][state_num][1]
-    # plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
-    # state_num = 3
-    # defined_points, pmf = all_first_pmfs_typeless['ibl_witten_14'][state_num][0], all_first_pmfs_typeless['ibl_witten_14'][state_num][1]
-    # plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
-    #
-    # plt.ylim(0, 1)
-    # plt.xlim(0, 10)
-    # plt.yticks([])
-    # plt.xticks([])
-    # sns.despine()
-    # plt.tight_layout()
-    # plt.savefig("example type 2")
-    # plt.show()
+    state_num = 7
+    defined_points, pmf = all_first_pmfs_typeless['NYU-06'][state_num][0], all_first_pmfs_typeless['NYU-06'][state_num][1]
+    plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
+    state_num = 6
+    defined_points, pmf = all_first_pmfs_typeless['CSHL061'][state_num][0], all_first_pmfs_typeless['CSHL061'][state_num][1]
+    plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
+
+    plt.ylim(0, 1)
+    plt.xlim(0, 10)
+    plt.ylabel("P(rightwards)", size=32)
+    plt.xlabel("Contrast", size=32)
+    plt.yticks([0, 1], size=27)
+    plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27)
+    sns.despine()
+    plt.tight_layout()
+    plt.savefig("example type 1")
+    plt.show()
+
+    state_num = 1
+    defined_points, pmf = all_first_pmfs_typeless['CSHL_018'][state_num][0], all_first_pmfs_typeless['CSHL_018'][state_num][1]
+    plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
+    state_num = 3
+    defined_points, pmf = all_first_pmfs_typeless['ibl_witten_14'][state_num][0], all_first_pmfs_typeless['ibl_witten_14'][state_num][1]
+    plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
+
+    plt.ylim(0, 1)
+    plt.xlim(0, 10)
+    plt.ylabel("P(rightwards)", size=32)
+    plt.xlabel("Contrast", size=32)
+    plt.yticks([0, 1], size=27)
+    plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27)
+    sns.despine()
+    plt.tight_layout()
+    plt.savefig("example type 2")
+    plt.show()
 
     state_num = 4
     defined_points, pmf = all_first_pmfs_typeless['ibl_witten_17'][state_num][0], all_first_pmfs_typeless['ibl_witten_17'][state_num][1]
@@ -199,8 +203,10 @@ if __name__ == "__main__":
 
     plt.ylim(0, 1)
     plt.xlim(0, 10)
-    plt.yticks([])
-    plt.xticks([])
+    plt.ylabel("P(rightwards)", size=32)
+    plt.xlabel("Contrast", size=32)
+    plt.yticks([0, 1], size=27)
+    plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27)
     sns.despine()
     plt.tight_layout()
     plt.savefig("example type 3")
diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py
index f99904a98921b921be4f1a181e106a051eab03da..c61d4c7c659c1360646febd6889a35d401ebdc37 100644
--- a/dyn_glm_chain_analysis.py
+++ b/dyn_glm_chain_analysis.py
@@ -526,7 +526,7 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur
                 # print("fix this by taking the whole array, multiply by n, subtract n, divide by n-1")
                 # input()
 
-            label = "State {}".format(state) if np.sum(relevant_trials) > 0.02 * len(test.results[0].models[0].stateseqs[seq_num]) else None
+            label = "State {}".format(len(state_sets) - test.state_mapping[state]) if np.sum(relevant_trials) > 0.02 * len(test.results[0].models[0].stateseqs[seq_num]) else None
 
             # state_c_n_a = c_n_a[relevant_trials - trial_counter]
             # print(state)
@@ -571,7 +571,7 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur
         plt.xlabel('Trial', size=28)
         sns.despine()
 
-        # plt.xlim(left=250, right=450)
+        plt.xlim(left=250, right=450)
         plt.legend(frameon=False, fontsize=22, bbox_to_anchor=(0.8, 0.5))
         plt.tight_layout()
         if save:
@@ -854,22 +854,25 @@ def state_set_and_plot(test, mode_prefix, subject, fit_type):
 
 
 def state_pmfs(test, trials, indices):
-    def func_init(): return {'pmfs': [], 'session_js': []}
+    def func_init(): return {'pmfs': [], 'session_js': [], 'pmf_weights': []}
 
     def first_for(test, results):
         results['pmf'] = np.zeros(test.results[0].n_contrasts)
+        results['pmf_weight'] = np.zeros(4)
 
     def second_for(m, j, session_trials, trial_counter, results):
         states, counts = np.unique(m.stateseqs[j][session_trials - trial_counter], return_counts=True)
         for sub_state, c in zip(states, counts):
             results['pmf'] += weights_to_pmf(m.obs_distns[sub_state].weights[j]) * c / session_trials.shape[0]
+            results['pmf_weight'] += m.obs_distns[sub_state].weights[j] * c / session_trials.shape[0]
 
     def end_first_for(results, indices, j, **kwargs):
         results['pmfs'].append(results['pmf'] / len(indices))
+        results['pmf_weights'].append(results['pmf_weight'] / len(indices))
         results['session_js'].append(j)
 
     results = control_flow(test, indices, trials, func_init, first_for, second_for, end_first_for)
-    return results['session_js'], results['pmfs']
+    return results['session_js'], results['pmfs'], results['pmf_weights']
 
 
 def state_weights(test, trials, indices):
@@ -1138,6 +1141,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
         ax0.annotate('Bias', (test.results[0].infos['bias_start'] + 1 - 0.5, 0.68), fontsize=22)
 
     all_pmfs = []
+    all_pmf_weights = []
     cmaps = ['Greys', 'Purples', 'Blues', 'Greens', 'Oranges', 'Reds', 'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu']
     np.random.seed(8)
     np.random.shuffle(cmaps)
@@ -1147,7 +1151,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     for state, trials in enumerate(state_sets):
         if separate_pmf:
             n_trials = len(trials)
-            session_js, pmfs = state_pmfs(test, trials, indices)
+            session_js, pmfs, _ = state_pmfs(test, trials, indices)
         else:
             pmfs = np.zeros((len(indices), test.results[0].n_contrasts))
             n_trials = len(trials)
@@ -1173,9 +1177,10 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
 
         if separate_pmf:
             n_trials = len(trials)
-            session_js, pmfs = state_pmfs(test, trials, indices)
+            session_js, pmfs, pmf_weights = state_pmfs(test, trials, indices)
         else:
             pmfs = np.zeros((len(indices), test.results[0].n_contrasts))
+            pmf_weights = np.zeros((len(indices), test.results[0].obs_distns[0].weights.shape[0]))
             n_trials = len(trials)
             counter = 0
             for i, m in enumerate([item for sublist in test.results for item in sublist.models]):
@@ -1188,6 +1193,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
                         states, counts = np.unique(state_seq[session_trials - trial_counter], return_counts=True)
                         for sub_state, c in zip(states, counts):
                             pmfs[counter] += weights_to_pmf(m.obs_distns[sub_state].weights[j]) * c / n_trials
+                            pmf_weights[counter] += m.obs_distns[sub_state].weights[j] * c / n_trials
                     trial_counter += len(state_seq)
                 counter += 1
 
@@ -1218,7 +1224,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
                 if not test.state_mapping[state] in dont_plot:
                     ax1.fill_between([points[k], points[k+1]],
                                      test.state_mapping[state] - 0.5, [test.state_mapping[state] + interpolation[k] - 0.5, test.state_mapping[state] + interpolation[k+1] - 0.5], color=cmap(0.3 + 0.7 * k / n_points))
-        ax1.annotate(test.state_mapping[state] + 1, (test.results[0].n_sessions + 0.1, test.state_mapping[state] - 0.15), fontsize=22, annotation_clip=False)
+        ax1.annotate(len(state_sets) - test.state_mapping[state], (test.results[0].n_sessions + 0.1, test.state_mapping[state] - 0.15), fontsize=22, annotation_clip=False)
 
         if test.results[0].name.startswith('GLM_Sim_'):
             ax1.plot(range(1, 1 + test.results[0].n_sessions), truth['state_map'][test.state_mapping[state]] + truth['state_posterior'][:, state] - 0.5, color='r')
@@ -1235,11 +1241,12 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
         #         defined_points = np.zeros(test.results[0].n_contrasts, dtype=bool)
         #         defined_points[[0, 1, -2, -1]] = True
         if separate_pmf:
-            for j, pmf in zip(session_js, pmfs):
+            for j, pmf, pmf_weight in zip(session_js, pmfs, pmf_weights):
                 if not test.state_mapping[state] in dont_plot:
                     ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + test.state_mapping[state], color=cmap(0.2 + 0.8 * j / test.results[0].n_sessions))
                     ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + test.state_mapping[state], ls='', ms=7, marker='*', color=cmap(j / test.results[0].n_sessions))
             all_pmfs.append((defined_points, pmfs))
+            all_pmf_weights += pmf_weights
         else:
             temp = np.percentile(pmfs, [2.5, 97.5], axis=0)
             if not test.state_mapping[state] in dont_plot:
@@ -1313,7 +1320,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     ax2.set_title('Psychometric\nfunction', size=16)
     ax1.set_ylabel('Proportion of trials', size=28, labelpad=-20)
     ax0.set_ylabel('% correct', size=18)
-    ax2.set_ylabel('Probability', size=26, labelpad=-20)
+    ax2.set_ylabel('P(rightwards answer)', size=26, labelpad=-20)
     ax1.set_xlabel('Session', size=28)
     ax2.set_xlabel('Contrast', size=26)
     ax1.set_xlim(left=1, right=test.results[0].n_sessions)
@@ -1352,7 +1359,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     else:
         plt.close()
 
-    return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage, states_per_type
+    return states_by_session, all_pmfs, all_pmf_weights, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage, states_per_type
 
 def compare_pmfs(test, state_sets, indices, states2compare, states_by_session, all_pmfs, title=""):
     """
@@ -1670,7 +1677,6 @@ if __name__ == "__main__":
     # subjects = ['SWC_021', 'ibl_witten_15', 'ibl_witten_13', 'KS003', 'ibl_witten_19', 'SWC_022', 'CSH_ZAD_017']
     # subjects = ['KS014']
 
-    # meh pmfs: KS021
     print(subjects)
     fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0]
     dur = 'yes'
@@ -1709,6 +1715,7 @@ if __name__ == "__main__":
     regressions = []
     regression_diffs = []
     all_bias_flips = []
+    all_pmf_weights = []
 
     temp_counter = 0
 
@@ -1744,15 +1751,20 @@ if __name__ == "__main__":
             # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 1', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(7)), plot_until=2)
             # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 2', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(6)), plot_until=7)
             # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 3', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(4)), plot_until=13)
-            states, pmfs, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True)
-            all_state_types.append(state_types)
+            states, pmfs, pmf_weights, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True)
+            all_pmf_weights += pmf_weights
             continue
+
+            consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
+            consistencies /= consistencies[0, 0]
+            contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=False)
+
+            all_state_types.append(state_types)
             # state_types_interpolation[0] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[0])
             # state_types_interpolation[1] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[1])
             # state_types_interpolation[2] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[2])
             # temp_counter += 1
 
-            continue
 
             # b_flips = bias_flips(states, pmfs, durs)
             # all_bias_flips.append(b_flips)
@@ -1915,6 +1927,8 @@ if __name__ == "__main__":
     # pickle.dump(regression_diffs, open("regression_diffs.p", 'wb'))
     # pickle.dump(all_bias_flips, open("all_bias_flips.p", 'wb'))
     # pickle.dump(all_state_types, open("all_state_types.p", 'wb'))
+    pickle.dump(all_pmf_weights, open("all_pmf_weights.p", 'wb'))
+    quit()
 
 
     if True:
diff --git a/pmf_weight_analysis.py b/pmf_weight_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..00611bcc7e1200cbe4a76a79ebde4510408bea3b
--- /dev/null
+++ b/pmf_weight_analysis.py
@@ -0,0 +1,91 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import pickle
+from scipy.stats import gaussian_kde
+from analysis_pmf import pmf_type, type2color
+
+# all pmf weights
+apw = np.array(pickle.load(open("all_pmf_weights.p", 'rb')))
+
+xy = np.vstack([apw[:, i] for i in range(4)])
+z = gaussian_kde(xy)(xy)
+
+plt.subplot(1, 3, 1)
+plt.scatter(apw[:, 0], apw[:, 1], c=z)
+plt.xlabel("Cont right")
+plt.ylabel("Cont left")
+
+plt.subplot(1, 3, 2)
+plt.scatter(apw[:, 3], apw[:, 1], c=z)
+plt.xlabel("Bias")
+plt.ylabel("Cont left")
+
+plt.subplot(1, 3, 3)
+plt.scatter(apw[:, 3], apw[:, 0], c=z)
+plt.xlabel("Bias")
+plt.ylabel("Cont right")
+
+plt.show()
+
+
+contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
+contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
+
+def weights_to_pmf(weights, with_bias=1):
+    psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1]
+    return 1 / (1 + np.exp(psi))
+
+colors = [type2color[pmf_type(weights_to_pmf(x))] for x in apw]
+
+plt.subplot(1, 3, 1)
+sc = plt.scatter(apw[:, 0], apw[:, 1], c=colors)
+fig, ax = plt.gcf(), plt.gca()
+plt.xlabel("Cont right")
+plt.ylabel("Cont left")
+
+annot = ax.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
+                    bbox=dict(boxstyle="round", fc="w"),
+                    arrowprops=dict(arrowstyle="->"))
+annot.set_visible(False)
+
+def update_annot(ind):
+    print(ind)
+    pos = sc.get_offsets()[ind["ind"][0]]
+    annot.xy = pos
+    text = "{}".format(np.round(apw[ind["ind"][0]], 2))
+    annot.set_text(text)
+
+def hover(event):
+    vis = annot.get_visible()
+    if event.inaxes == ax:
+        cont, ind = sc.contains(event)
+        if cont:
+            update_annot(ind)
+            annot.set_visible(True)
+            fig.canvas.draw_idle()
+        else:
+            if vis:
+                annot.set_visible(False)
+                fig.canvas.draw_idle()
+
+fig.canvas.mpl_connect("motion_notify_event", hover)
+
+plt.subplot(1, 3, 2)
+plt.scatter(apw[:, 3], apw[:, 1], c=colors)
+plt.xlabel("Bias")
+plt.ylabel("Cont left")
+
+plt.subplot(1, 3, 3)
+plt.scatter(apw[:, 3], apw[:, 0], c=colors)
+plt.xlabel("Bias")
+plt.ylabel("Cont right")
+
+plt.show()
+
+
+from mpl_toolkits import mplot3d
+
+fig = plt.figure()
+ax = plt.axes(projection='3d')
+ax.scatter3D(apw[:, 0], apw[:, 1], apw[:, 3], c=colors)
+plt.show()
diff --git a/simplex_animation.py b/simplex_animation.py
index 2dd78ee1c91cd03961b826c39d3ba9604e06a08b..89dfe945c0994f472c74602d89541d12ac1d5fb1 100644
--- a/simplex_animation.py
+++ b/simplex_animation.py
@@ -38,6 +38,7 @@ assert (test_count == 1).all()
 # quit()
 
 session_counter = 0
+alph = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'za', 'zb', 'zc', 'zd', 'ze', 'zf', 'zg', 'zh', 'zi', 'zj']
 # do as many sessions as it takes
 while True:
 
@@ -64,9 +65,9 @@ while True:
 
         type_proportions[:, i] = sts[:, temp_counter]
 
-    plotSimplex(type_proportions.T, x_offset=x_offset, y_offset=y_offset, c=np.arange(len(all_state_types)), show=False, vertexcolors=[type2color[i] for i in range(3)], vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title=None)
-    plt.title(session_counter)
-    plt.show()
+    plotSimplex(type_proportions.T, x_offset=x_offset, y_offset=y_offset, c=np.arange(len(all_state_types)), show=False, title="Session {}".format(session_counter),
+                vertexcolors=[type2color[i] for i in range(3)], vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title="simplex_{}.png".format(alph[session_counter]))
+    plt.close()
 
     if not_ended == 0:
         break
diff --git a/simplex_plot.py b/simplex_plot.py
index 62bf36276da1b9f0f199687e968371d0d6673a46..635885d05417cb2c7810ec74e6d5cb5572e6f0b7 100644
--- a/simplex_plot.py
+++ b/simplex_plot.py
@@ -15,7 +15,7 @@ import matplotlib.patches as PA
 
 def plotSimplex(points, fig=None,
                 vertexlabels=['1: initial flat PMFs', '2: intermediate unilateral PMFs', '3: final bilateral PMFs'],
-                show=False, vertexcolors=['k', 'k', 'k'], x_offset=0, y_offset=0, save_title="dur_simplex.png", **kwargs):
+                save_title="test.png", show=False, vertexcolors=['k', 'k', 'k'], x_offset=0, y_offset=0, **kwargs):
     """
     Plot Nx3 points array on the 3-simplex
     (with optionally labeled vertices)
@@ -32,17 +32,17 @@ def plotSimplex(points, fig=None,
     fig.gca().xaxis.set_major_locator(MT.NullLocator())
     fig.gca().yaxis.set_major_locator(MT.NullLocator())
     # Draw vertex labels
-    fig.gca().annotate(vertexlabels[0], (-0.35, -0.05), size=24, color=vertexcolors[0], annotation_clip=False)
-    fig.gca().annotate(vertexlabels[1], (0.6, -0.05), size=24, color=vertexcolors[1], annotation_clip=False)
-    fig.gca().annotate(vertexlabels[2], (0.1, np.sqrt(3) / 2 + 0.025), size=24, color=vertexcolors[2], annotation_clip=False)
+    # fig.gca().annotate(vertexlabels[0], (-0.35, -0.05), size=24, color=vertexcolors[0], annotation_clip=False)
+    # fig.gca().annotate(vertexlabels[1], (0.6, -0.05), size=24, color=vertexcolors[1], annotation_clip=False)
+    # fig.gca().annotate(vertexlabels[2], (0.1, np.sqrt(3) / 2 + 0.025), size=24, color=vertexcolors[2], annotation_clip=False)
     # Project and draw the actual points
     projected = projectSimplex(points / points.sum(1)[:, None])
-    # print(projected)
-    P.scatter(projected[:, 0] + x_offset, projected[:, 1] + y_offset, s=35, **kwargs)#s=points.sum(1) * 3.5
+    print(projected)
+    P.scatter(projected[:, 0], projected[:, 1], s=points.sum(1) * 3.5, **kwargs)
 
     # plot center with average size
     projected = projectSimplex(np.mean(points / points.sum(1)[:, None], axis=0).reshape(1, 3))
-    P.scatter(projected[:, 0], projected[:, 1], marker='*', color='r', s=50)#np.mean(points.sum(1)) * 3.5)
+    P.scatter(projected[:, 0], projected[:, 1], marker='*', color='r', s=np.mean(points.sum(1)) * 3.5)
 
     # Leave some buffer around the triangle for vertex labels
     fig.gca().set_xlim(-0.05, 1.05)
@@ -51,10 +51,11 @@ def plotSimplex(points, fig=None,
     P.axis('off')
 
     P.tight_layout()
-    if save_title:
-        P.savefig(save_title, bbox_inches='tight', dpi=300, transparent=True)
+    P.savefig("dur_simplex.png", bbox_inches='tight', dpi=300, transparent=True)
     if show:
         P.show()
+    else:
+        P.close()
 
 
 def projectSimplex(points):