diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py
index 3f1e31cffdccbaeedd50667e8840170d2e8499b0..71e06373a2a4dc0a114fb5407e8de9f2a9479899 100644
--- a/dyn_glm_chain_analysis.py
+++ b/dyn_glm_chain_analysis.py
@@ -574,10 +574,10 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur
         noise = np.zeros(len(c_n_a))# np.random.rand(len(c_n_a)) * 0.4 - 0.2
 
         mask = c_n_a[:, -1] == 0
-        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='b', ms=ms, label='Leftward')
+        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='b', ms=ms, label='Leftward', alpha=0.6)
 
         mask = c_n_a[:, -1] == 1
-        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='r', ms=ms, label='Rightward')
+        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='r', ms=ms, label='Rightward', alpha=0.6)
 
         plt.title("session #{} / {}".format(1+seq_num, test.results[0].n_sessions), size=26)
         # plt.yticks(*self.cont_ticks, size=22-2)
@@ -852,7 +852,7 @@ def state_development_single_sample(test, indices, save=True, save_append='', sh
     return states_by_session, all_pmfs
 
 
-def state_development(test, state_sets, indices, save=True, save_append='', show=True, dpi='figure', separate_pmf=False):
+def state_development(test, state_sets, indices, save=True, save_append='', show=True, dpi='figure', separate_pmf=False, type_coloring=True):
     # Now also returns durs of state types and state type summary array
     state_sets = [np.array(s) for s in state_sets]
 
@@ -1038,7 +1038,8 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     introductions_by_stage = np.zeros(3)
     covered_states = []
     for i, d in enumerate(durs):
-        ax0.fill_between(range(dur_counter, 1 + dur_counter + d), 0.5, -0.5, color=type_colours[i], zorder=0, alpha=0.3)
+        if type_coloring:
+            ax0.fill_between(range(dur_counter, 1 + dur_counter + d), 0.5, -0.5, color=type_colours[i], zorder=0, alpha=0.3)
         dur_counter += d
 
         # find out during which state type which contrast was introduced
@@ -1098,7 +1099,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     else:
         plt.close()
 
-    return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs))
+    return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage
 
 
 def smart_divide(a, b):
@@ -1154,14 +1155,13 @@ if __name__ == "__main__":
 
     r_hats = []
 
+    # R^hat tests
     # test = MCMC_result_list([fake_result(100) for i in range(8)])
     # test.r_hat_and_ess(return_ascending, False)
     # test.r_hat_and_ess(return_ascending_shuffled, False)
     # quit()
-    good = []
-    bad = []
 
-    check_r_hats = False
+    check_r_hats = True
     if check_r_hats:
         subjects = list(loading_info.keys())
         subjects = ['KS014']
@@ -1188,11 +1188,6 @@ if __name__ == "__main__":
             r_hats.append((subject, final_r_hat))
             loading_info[subject]['ignore'] = sol
 
-            if final_r_hat < 1.05:
-                good.append(subject)
-            else:
-                bad.append(subject)
-
         print(r_hats)
         if fit_type == 'bias':
             json.dump(loading_info, open("canonical_infos_bias.json", 'w'))
@@ -1221,10 +1216,15 @@ def state_type_durs(states, pmfs):
                 pmf_counter += 1
                 state_types[pmf_type(pmf[1][pmf_counter][pmf[0]]), i] += s[i]  # indexing horror
 
-    durs = (np.where(state_types[1] > 0.5)[0][0],
-            np.where(state_types[2] > 0.5)[0][0] - np.where(state_types[1] > 0.5)[0][0],
-            states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
-    if np.where(state_types[2] > 0.5)[0][0] < np.where(state_types[1] > 0.5)[0][0]:
+    if np.any(state_types[1] > 0.5):
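+        # stage boundaries: the first session dominated by a type-1 state, then the first dominated by a type-2 state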
+        durs = (np.where(state_types[1] > 0.5)[0][0],
+                np.where(state_types[2] > 0.5)[0][0] - np.where(state_types[1] > 0.5)[0][0],
+                states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
+        if np.where(state_types[2] > 0.5)[0][0] < np.where(state_types[1] > 0.5)[0][0]:
+            durs = (np.where(state_types[2] > 0.5)[0][0],
+                    0,
+                    states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
+    else:
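+        # no session is dominated by a type-1 state: the intermediate stage gets length 0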
         durs = (np.where(state_types[2] > 0.5)[0][0],
                 0,
                 states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
@@ -1251,6 +1251,24 @@ def state_cluster_interpolation(states, pmfs):
     return state_types, state_trans, pmf_examples
 
 
+def get_first_pmfs(states, pmfs):
+    # Get the first PMF of every type, the contrasts on which each is defined,
+    # and whether it is the first PMF of its state.
+    earliest_sessions = [1000, 1000, 1000]
+    first_pmfs = [0, 0, 0, 0, 0, 0, 0, 0, 0]
+    changing_pmfs = [[0, 0], [0, 0]]
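+    # Layout: earliest_sessions holds a per-type sentinel (1000 = type not yet seen);
+    # first_pmfs is flat, 3 slots per type: [pmf weights, defined-contrast mask, index of the pmf within its state];
+    # changing_pmfs keeps [mask, pmf series] for types 1 and 2 whenever the first pmf of that type is not its state's first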
+    for state, pmf in zip(states, pmfs):
+        sessions = np.where(state)[0]
+        for i, (sess_pmf, sess) in enumerate(zip(pmf[1], sessions)):
+            if earliest_sessions[pmf_type(sess_pmf[pmf[0]])] > sess:
+                earliest_sessions[pmf_type(sess_pmf[pmf[0]])] = sess
+                first_pmfs[3 * pmf_type(sess_pmf[pmf[0]])] = sess_pmf
+                first_pmfs[1 + 3 * pmf_type(sess_pmf[pmf[0]])] = pmf[0]
+                first_pmfs[2 + 3 * pmf_type(sess_pmf[pmf[0]])] = i
+                if i != 0:
+                    changing_pmfs[pmf_type(sess_pmf[pmf[0]]) - 1] = [pmf[0], pmf[1]]
+    return first_pmfs, changing_pmfs
+
+
 def plot_pmf_types(pmf_types, subject, fit_type, save=True, show=False):
     # Plot the different types of PMFs, all split up by their different types
     for i, pmfs in enumerate(pmf_types):
@@ -1267,15 +1285,155 @@ def plot_pmf_types(pmf_types, subject, fit_type, save=True, show=False):
     else:
         plt.close()
 
+
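+# pmf_type buckets a PMF by its rise from the leftmost to the rightmost contrast:
+# type 0 (rise < 0.2) is near-flat, type 1 (rise < 0.6) intermediate, type 2 steep;
+# the additional asymmetry criterion for type 1 is currently disabled.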
 def pmf_type(pmf):
-    if pmf[-1] - pmf[0] <= 0.15:
+    if pmf[-1] - pmf[0] < 0.2:
         return 0
-    elif pmf[-1] - pmf[0] < 0.6 and np.abs(pmf[0] + pmf[-1] - 1) > 0.1:
+    elif pmf[-1] - pmf[0] < 0.6:# and np.abs(pmf[0] + pmf[-1] - 1) > 0.1:
         return 1
     else:
         return 2
 
 
+type2color = {0: 'green', 1: 'blue', 2: 'red'}
+
+if False:
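+    # One-off figure blocks, toggled by hand; they read the pickles ("changing_pmfs.p", "pmfs_temp.p")
+    # that the subject loop below writes out.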
+
+    all_changing_pmfs = pickle.load(open("changing_pmfs.p", 'rb'))
+    plt.figure(figsize=(16, 9))
+    for i, pmf in enumerate(all_changing_pmfs):
+        plt.subplot(4, 7, i + 1)
+        for p in pmf[1]:
+            plt.plot(np.where(pmf[0])[0], p[pmf[0]], color=type2color[pmf_type(p)])
+        plt.ylim(0, 1)
+
+        sns.despine()
+        if i+1 != 22:
+            plt.gca().set_xticks([])
+            plt.gca().set_yticks([])
+        else:
+            plt.xlabel("Contrasts", size=22)
+            plt.ylabel("P(rightwards)", size=22)
+            plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=16)
+            plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=16)
+
+        if i + 1 == 30:
+            break
+
+    plt.tight_layout()
+    plt.savefig("changing pmfs")
+    plt.show()
+    quit()
+
+    type_2_asymms = []
+    tick_size = 14
+    label_size = 26
+    all_first_pmfs = pickle.load(open("pmfs_temp.p", 'rb'))
+    plt.figure(figsize=(16, 9))
+    plt.subplot(1, 3, 1)
+    counter = [[0, 0], [0, 0]]
+    save_title = "all types" if False else "KS014 types"
+    if save_title == "KS014 types":
+        all_first_pmfs = {'KS014': all_first_pmfs['KS014']}
+
+    for key in all_first_pmfs:
+        x = all_first_pmfs[key]
+        if type(x[0]) == int:
+            continue
+        linestyle = '-' if x[2] == 0 else '--'
+        plt.plot(np.where(x[1])[0], x[0][x[1]], linestyle=linestyle, c='g')
+    plt.ylim(0, 1)
+    plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size)
+    plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
+    plt.gca().spines[['right', 'top']].set_visible(False)
+    plt.xlim(0, 10)
+    plt.xticks(rotation=45)
+    plt.gca().set_ylabel("P(rightwards)", size=label_size)
+
+    plt.subplot(1, 3, 2)
+    for key in all_first_pmfs:
+        x = all_first_pmfs[key]
+        if type(x[3]) == int:
+            continue
+        type_2_asymms.append(np.abs(x[3][0] + x[3][-1] - 1))
+        linestyle = '-' if x[5] == 0 else '--'
+        counter[0][0 if x[5] == 0 else 1] += 1
+        if linestyle == '--':
+            continue
+        plt.plot(np.where(x[4])[0], x[3][x[4]], linestyle=linestyle, c='b')
+    plt.gca().set_yticks([])
+    plt.ylim(0, 1)
+    plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size)
+    plt.gca().spines[['right', 'top']].set_visible(False)
+    plt.xticks(rotation=45)
+    plt.xlim(0, 10)
+    plt.gca().set_xlabel("Contrasts", size=label_size)
+
+    plt.subplot(1, 3, 3)
+    for key in all_first_pmfs:
+        x = all_first_pmfs[key]
+        if type(x[6]) == int:
+            continue
+        linestyle = '-' if x[8] == 0 else '--'
+        counter[1][0 if x[8] == 0 else 1] += 1
+        if linestyle == '--':
+            continue
+        plt.plot(np.where(x[7])[0], x[6][x[7]], linestyle=linestyle, c='r')
+    plt.gca().set_yticks([])
+    plt.ylim(0, 1)
+    plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size)
+    plt.gca().spines[['right', 'top']].set_visible(False)
+    plt.xlim(0, 10)
+    plt.xticks(rotation=45)
+
+    print(counter)
+    plt.tight_layout()
+    plt.savefig(save_title)
+    plt.show()
+    if save_title == "KS014 types":
+        quit()
+
+    counter = 0
+    fig, ax = plt.subplots(1, 3, figsize=(16, 9))
+    for key in all_first_pmfs:
+        x = all_first_pmfs[key]
+        if type(x[3]) == int:
+            continue
+        linestyle = '-' if x[5] == 0 else '--'
+        if linestyle == '--':
+            continue
+        if np.abs(x[3][0] + x[3][-1] - 1) <= 0.1:
+            counter += 1
+            use_ax = 2
+        else:
+            use_ax = int(x[3][0] > 1 - x[3][-1])
+
+        ax[use_ax].plot(np.where(x[4])[0], x[3][x[4]], linestyle=linestyle, c='b')
+    ax[0].set_ylim(0, 1)
+    ax[0].set_xlim(0, 10)
+    ax[0].spines[['right', 'top']].set_visible(False)
+    ax[0].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
+    ax[0].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    ax[0].set_ylabel("P(rightwards)", size=label_size)
+
+    ax[1].set_ylim(0, 1)
+    ax[1].set_xlim(0, 10)
+    ax[1].set_yticks([])
+    ax[1].spines[['right', 'top']].set_visible(False)
+    ax[1].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    ax[1].set_xlabel("Contrasts", size=label_size)
+
+    ax[2].set_ylim(0, 1)
+    ax[2].set_xlim(0, 10)
+    ax[2].set_yticks([])
+    ax[2].spines[['right', 'top']].set_visible(False)
+    ax[2].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    print(counter)
+    plt.tight_layout()
+    plt.savefig("differentiate type 2")
+    plt.show()
+    quit()
+
 if __name__ == "__main__":
 
     # visualise pmf types
@@ -1301,40 +1459,11 @@ if __name__ == "__main__":
         loading_info = json.load(open("canonical_infos.json", 'r'))
         r_hats = json.load(open("canonical_info_r_hats.json", 'r'))
     no_good_pcas = ['NYU-06', 'SWC_023']  # no good rhat: 'ibl_witten_13'
-    subjects = list(loading_info.keys())
+    subjects = ['KS014']  # list(loading_info.keys())
     print(subjects)
     fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0]
     dur = 'yes'
 
-    fist_good_pmf = {'CSHL051': (6, 'left'),
-                     'CSHL059': (3, 'left'),
-                     'CSHL061': (4, 'right'),
-                     'CSHL062': (1, 'left'),
-                     'CSHL_007': (2, 'skip'),
-                     'CSHL_014': (4, 'left'),
-                     'CSHL_015': (2, 'right'),
-                     'CSHL_018': (3, 'skip'),
-                     'CSHL_020': (5, 'left'),  # tough case
-                     'CSH_ZAD_001': (4, 'right'),
-                     'CSH_ZAD_011': (0, 'right'),
-                     'CSH_ZAD_022': (4, 'right'),
-                     'CSH_ZAD_025': (2, 'skip'),  # eieiei
-                     'CSH_ZAD_026': (6, 'skip'),  # gradual
-                     'ibl_witten_14': (3, 'left'),
-                     'ibl_witten_16': (3, 'right'),
-                     # 'ibl_witten_18': (3, 'very weird'),  # probably shouldn't be analysed
-                     'ibl_witten_19': (4, 'right'),
-                     'KS014': (4, 'right'),
-                     'KS015': (4, 'left'),
-                     'KS016': (3, 'skip'),  # non-trivial
-                     'KS017': (3, 'right'),
-                     'KS021': (5, 'left'),
-                     'KS022': (4, 'right'),
-                     'KS023': (3, 'left'),
-                     'NYU-06': (9, 'right'),
-                     'SWC_023': (9, 'skip'),  # gradual
-                     'ZM_1897': (5, 'right'),
-                     'ZM_3003': (1, 'skip')}
     # fig, ax = plt.subplots(1, 3, sharey=True, figsize=(16, 9))
 
     thinning = 25
@@ -1356,6 +1485,12 @@ if __name__ == "__main__":
     not_yet = True
 
     abs_state_durs = []
+    all_first_pmfs = {}
+    all_pmf_diffs = []
+    all_pmf_asymms = []
+    all_pmfs = []
+    all_changing_pmfs = []
+    all_intros = []
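+    # per-subject accumulators for the cross-subject figures and pickles further down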
 
     for subject in subjects:
 
@@ -1366,30 +1501,44 @@ if __name__ == "__main__":
         results = []
 
         try:
-
-            print('loading canonical')
             test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb'))
             print('loaded canoncial result')
 
             mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb'))
             state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb'))
             # lapse differential
             # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices)
 
             # training overview
-            states, pmfs, durs, _, contrast_intro_type, intros_by_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=1, separate_pmf=True)
+            states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True)
+            all_intros.append(undiv_intros)
             intros_by_type_sum += intros_by_type
-            continue
+            first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs)
+            for pmf in changing_pmfs:
+                if type(pmf[0]) == int:
+                    continue
+                all_changing_pmfs.append(pmf)
+            all_first_pmfs[subject] = first_pmfs
+            for pmf in pmfs:
+                for p in pmf[1]:
+                    all_pmf_diffs.append(p[-1] - p[0])
+                    all_pmf_asymms.append(np.abs(p[0] + p[-1] - 1))
+                    all_pmfs.append(p)
             contrast_intro_types.append(contrast_intro_type)
             # state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False)
 
             # session overview
-            # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
-            # consistencies /= consistencies[0, 0]
-            # contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies)
+            consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
+            consistencies /= consistencies[0, 0]
+            contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies)
 
             # duration of different state types (and also percentage of type activities)
             abs_state_durs.append(durs)
+            simplex_durs = np.array(durs).reshape(1, 3)
+            print(simplex_durs / np.sum(simplex_durs))
+            from simplex_plot import projectSimplex
+            print(projectSimplex(simplex_durs / simplex_durs.sum(1)[:, None]))
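+            # rows are normalised to proportions, then projectSimplex maps them to 2D coordinates within the triangle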
             continue
 
             # compute state type proportions and split the pmfs accordingly
@@ -1567,6 +1716,7 @@ if __name__ == "__main__":
             plt.show()
 
         except FileNotFoundError as e:
+            continue  # skip subjects without saved results; the fallback handling below is disabled
             print(e)
             r_hat = 1.5
             for r in r_hats:
@@ -1630,11 +1780,33 @@ if __name__ == "__main__":
         #     state_appear.append(b[a == i][0] / (test.results[0].n_sessions - 1))
         #     state_dur.append(b[a == i].shape[0])
 
+    # pickle.dump(all_first_pmfs, open("pmfs_temp.p", 'wb'))
+    # pickle.dump(all_changing_pmfs, open("changing_pmfs.p", 'wb'))
+    #
+    # a = [x for x, y in zip(all_pmf_asymms, all_pmf_diffs) if y >= 0.2]
+    # b = [y for x, y in zip(all_pmf_asymms, all_pmf_diffs) if y >= 0.2]
+    # plt.hist2d(a, b, bins=40)
+    # plt.show()
+
+
     if True:
         abs_state_durs = np.array(abs_state_durs)
         # pickle.dump(abs_state_durs, open("multi_chain_saves/abs_state_durs.p", 'wb'))
         abs_state_durs = pickle.load(open("multi_chain_saves/abs_state_durs.p", 'rb'))
 
+        print("Correlations")
+        from scipy.stats import pearsonr
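+        # pearsonr returns (Pearson correlation coefficient, two-sided p-value)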
+        print(pearsonr(abs_state_durs[:, 0], abs_state_durs[:, 1]))
+        print(pearsonr(abs_state_durs[:, 2], abs_state_durs[:, 1]))
+        print(pearsonr(abs_state_durs[:, 0], abs_state_durs[:, 2]))
+
+        print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 0]))
+        # (0.7338297529946006, 2.6332570579118393e-06)
+        print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 1]))
+        # (0.35094585023228597, 0.052897046343413114)
+        print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 2]))
+        # (0.7210260323745921, 4.747833912452452e-06)
+
         from simplex_plot import plotSimplex
 
         plotSimplex(np.array(abs_state_durs), c='k', show=True)
diff --git a/dynamic_GLMiHMM_consistency.py b/dynamic_GLMiHMM_consistency.py
index a8b2cf936d5caca6bbc9dd302c0553ecad1edd26..8bf1924c1a08c92b50d8db433853faa7376093cb 100644
--- a/dynamic_GLMiHMM_consistency.py
+++ b/dynamic_GLMiHMM_consistency.py
@@ -42,10 +42,10 @@ for subject in subjects:
     from_session = info_dict['bias_start'] if fit_type == 'bias' else 0
 
     models = []
-    n_inputs = 5
+    n_regressors = 5
     T = till_session - from_session + (fit_type != 'prebias')
-    obs_hypparams = {'n_inputs': n_inputs, 'T': T, 'prior_mean': np.zeros(n_inputs),
-                     'P_0': 2 * np.eye(n_inputs), 'Q': fit_variance * np.tile(np.eye(n_inputs), (T, 1, 1))}
+    obs_hypparams = {'n_regressors': n_regressors, 'T': T, 'prior_mean': np.zeros(n_regressors), 'jumplimit': 3,
+                     'P_0': 2 * np.eye(n_regressors), 'Q': fit_variance * np.tile(np.eye(n_regressors), (T, 1, 1))}
     dur_hypparams = dict(r_support=np.array([1, 2, 3, 5, 7, 10, 15, 21, 28, 36, 45, 55, 150]),
                          r_probs=np.ones(13)/13., alpha_0=1, beta_0=1)
 
@@ -78,7 +78,7 @@ for subject in subjects:
 
         bad_trials = data[:, 1] == 1
         bad_trials[0] = True
-        mega_data = np.empty((np.sum(~bad_trials), n_inputs + 1))
+        mega_data = np.empty((np.sum(~bad_trials), n_regressors + 1))
 
         mega_data[:, 0] = np.maximum(data[~bad_trials, 0], 0)
         mega_data[:, 1] = np.abs(np.minimum(data[~bad_trials, 0], 0))
@@ -123,7 +123,10 @@ for subject in subjects:
 
 prev_res = pickle.load(open(save_title, 'rb'))
 
-for p, m in zip(prev_res, models):
+for counter, (p, m) in enumerate(zip(prev_res, models)):
+    print(counter)  # progress over the saved models
     for od, nd in zip(p.obs_distns, m.obs_distns):
         assert np.allclose(od.weights, nd.weights)
 prof._prof.print_stats()
diff --git a/dynamic_GLMiHMM_fit.py b/dynamic_GLMiHMM_fit.py
index b05ad929cf8ee0cc6558c0b6948875564ded52e6..dbbe447a55e51e03d20aef3e840ccd600d76f653 100644
--- a/dynamic_GLMiHMM_fit.py
+++ b/dynamic_GLMiHMM_fit.py
@@ -204,14 +204,14 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
     models = []
 
     if params['obs_dur'] == 'glm':
-        n_inputs = len(params['regressors'])
+        n_regressors = len(params['regressors'])
         T = till_session - from_session + (params['fit_type'] != 'prebias')
-        obs_hypparams = {'n_inputs': n_inputs, 'T': T, 'jumplimit': params['jumplimit'], 'prior_mean': params['init_mean'],
-                         'P_0': params['init_var'] * np.eye(n_inputs), 'Q': params['fit_variance'] * np.tile(np.eye(n_inputs), (T, 1, 1))}
+        obs_hypparams = {'n_regressors': n_regressors, 'T': T, 'jumplimit': params['jumplimit'], 'prior_mean': params['init_mean'],
+                         'P_0': params['init_var'] * np.eye(n_regressors), 'Q': params['fit_variance'] * np.tile(np.eye(n_regressors), (T, 1, 1))}
         obs_distns = [distributions.Dynamic_GLM(**obs_hypparams) for state in range(params['n_states'])]
     else:
-        n_inputs = 9 if params['fit_type'] == 'bias' else 11
-        obs_hypparams = {'n_inputs': n_inputs * (1 + (params['conditioned_on'] != 'nothing')), 'n_outputs': 2, 'T': till_session - from_session + (params['fit_type'] != 'prebias'),
+        n_regressors = 9 if params['fit_type'] == 'bias' else 11
+        obs_hypparams = {'n_regressors': n_regressors * (1 + (params['conditioned_on'] != 'nothing')), 'n_outputs': 2, 'T': till_session - from_session + (params['fit_type'] != 'prebias'),
                          'jumplimit': params['jumplimit'], 'sigmasq_states': params['fit_variance']}
         obs_distns = [Dynamic_Input_Categorical(**obs_hypparams) for state in range(params['n_states'])]
 
@@ -280,7 +280,7 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
             mask[0] = False
             if params['fit_type'] == 'zoe_style':
                 mask[90:] = False
-            mega_data = np.empty((np.sum(mask), n_inputs + 1))
+            mega_data = np.empty((np.sum(mask), n_regressors + 1))
 
             for i, reg in enumerate(params['regressors']):
                 # positive numbers are contrast on the right
diff --git a/mcmc_chain_analysis.py b/mcmc_chain_analysis.py
index 2a99a67602392e961761ef6ead4b9f40650e0d93..c8de12596afbb348ca65a77d28db0726031a52f5 100644
--- a/mcmc_chain_analysis.py
+++ b/mcmc_chain_analysis.py
@@ -8,6 +8,8 @@ import pickle
 
 
 def state_size_helper(n=0, mode_specific=False):
+    """Returns a function that returns the # of trials associated to the nth largest state in a sample
+       can be further specified to only look at specific samples, those of a mode"""
     if not mode_specific:
         def nth_largest_state_func(x):
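+            # np.partition places the (n + 1)-th largest count at index -1 - n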
             return np.partition(x.assign_counts, -1 - n, axis=1)[:, -1 - n]
@@ -18,6 +20,8 @@ def state_size_helper(n=0, mode_specific=False):
 
 
 def state_num_helper(t, mode_specific=False):
+    """Returns a function that returns the # of states which have more trials than a percentage threshold t in a sample
+       can be further specified to only look at specific samples, those of a mode"""
     if not mode_specific:
         def state_num_func(x): return ((x.assign_counts / x.n_datapoints) > t).sum(1)
     else:
@@ -35,14 +39,15 @@ def ll_func(x): return x.sample_lls[-x.n_samples:]
 
 
 def r_hat_array_comp(chains):
-    m, n = chains.shape
+    """Computes R^hat on an array of features, following Gelman p. 284f"""
+    m, n = chains.shape  # number of chains, length of chains
     psi_dot_j = np.mean(chains, axis=1)
     psi_dot_dot = np.mean(psi_dot_j)
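+    # B: between-chain variance of the chain means; W: mean within-chain variance;
+    # var_hat_plus pools them into an estimate of the posterior variance, and R^hat = sqrt(var_hat_plus / W)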
     B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot) ** 2)
     s_j_squared = np.sum((chains - psi_dot_j[:, None]) ** 2, axis=1) / (n - 1)
     W = np.mean(s_j_squared)
     var_hat_plus = (n - 1) / n * W + B / n
-    if W == 0:
+    if W == 0:  # sometimes a feature has 0 variance
         # print("all the same value")
         return 1, 0
     r_hat = np.sqrt(var_hat_plus / W)
@@ -50,6 +55,7 @@ def r_hat_array_comp(chains):
 
 
 def eval_amortized_r_hat(chains, psi_dot_j, s_j_squared, m, n):
+    """Unused version in which some things were computed ahead of function to save time."""
     psi_dot_dot = np.mean(psi_dot_j, axis=1)
     B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot[:, None]) ** 2, axis=1)
     W = np.mean(s_j_squared, axis=1)
@@ -59,6 +65,7 @@ def eval_amortized_r_hat(chains, psi_dot_j, s_j_squared, m, n):
 
 
 def r_hat_array_comp_mult(chains):
+    """Compute R^hat of multiple features at once."""
     _, m, n = chains.shape
     psi_dot_j = np.mean(chains, axis=2)
     psi_dot_dot = np.mean(psi_dot_j, axis=1)
@@ -71,8 +78,8 @@ def r_hat_array_comp_mult(chains):
 
 
 def rank_inv_normal_transform(chains):
-    # Gelman paper Rank-normalization, folding, and localization: An improved R_hat for assessing convergence of MCMC
-    # ranking with average rank for ties
+    """Gelman paper Rank-normalization, folding, and localization: An improved R_hat for assessing convergence of MCMC
+       ranking with average rank for ties"""
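+    # folding around the median makes the resulting R^hat sensitive to scale differences between chains, not just location shifts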
     folded_chains = np.abs(chains - np.median(chains))
     ranked = rankdata(chains).reshape(chains.shape)
     folded_ranked = rankdata(folded_chains).reshape(folded_chains.shape)
@@ -83,6 +90,8 @@ def rank_inv_normal_transform(chains):
 
 
 def eval_r_hat(chains):
+    """Compute entire set of R^hat's for list of feature arrays, and return maximum across features.
+       Computes all R^hat versions, as opposed to eval_simple_r_hat"""
     r_hats = []
     for chain in chains:
         rank_normalised, folded_rank_normalised, _, _ = rank_inv_normal_transform(chain)
@@ -92,6 +101,8 @@ def eval_r_hat(chains):
 
 
 def eval_simple_r_hat(chains):
+    """Compute just simple R^hat's for list of feature arrays, and return maximum across features.
+       Computes only the simple type of R^hat, no folding or rank normalising, making it much faster"""
     r_hats, _ = r_hat_array_comp_mult(chains)
     return max(r_hats)
 
@@ -103,21 +114,21 @@ def comp_multi_r_hat(chains, rank_normalised, folded_rank_normalised):
     return max(lame_r_hat, rank_normalised_r_hat, folded_rank_normalised_r_hat)
 
 
-def sample_statistics(test, mode_indices, subject):
+def sample_statistics(mode_indices, subject, period='prebias'):
     # prints out r_hats and sample sizes for given sample
-    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
+    test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, period), 'rb'))
     test.r_hat_and_ess(state_size_helper(1), False)
     test.r_hat_and_ess(state_size_helper(1, mode_specific=True), False, mode_indices=mode_indices)
     print()
-    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
+    test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, period), 'rb'))
     test.r_hat_and_ess(state_size_helper(), False)
     test.r_hat_and_ess(state_size_helper(mode_specific=True), False, mode_indices=mode_indices)
     print()
-    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
+    test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, period), 'rb'))
     test.r_hat_and_ess(state_num_helper(0.05), False)
     test.r_hat_and_ess(state_num_helper(0.05, mode_specific=True), False, mode_indices=mode_indices)
     print()
-    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
+    test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, period), 'rb'))
     test.r_hat_and_ess(state_num_helper(0.02), False)
     test.r_hat_and_ess(state_num_helper(0.02, mode_specific=True), False, mode_indices=mode_indices)
     print()
diff --git a/simplex_plot.py b/simplex_plot.py
index bdaad7885fa71bbbddcdd58b45447f4acf442ae3..dd8ba57d9f3d356cda8e04524404b4833a627a60 100644
--- a/simplex_plot.py
+++ b/simplex_plot.py
@@ -37,6 +37,7 @@ def plotSimplex(points, fig=None,
     fig.gca().text(0.43, np.sqrt(3) / 2 + 0.025, vertexlabels[2], size=24)
     # Project and draw the actual points
     projected = projectSimplex(points / points.sum(1)[:, None])
+    print(projected)
     P.scatter(projected[:, 0], projected[:, 1], s=points.sum(1) * 3.5, **kwargs)
 
     # plot center with average size
@@ -90,14 +91,14 @@ if __name__ == '__main__':
     labels = ('[0.1  0.1  0.8]',
               '[0.8  0.1  0.1]',
               '[0.5  0.4  0.1]',
+              '[0.17  0.33  0.5]',
               '[0.33  0.34  0.33]')
     testpoints = np.array([[0.1, 0.1, 0.8],
                            [0.8, 0.1, 0.1],
                            [0.5, 0.4, 0.1],
+                           [0.17, 0.33, 0.5],
                            [0.33, 0.34, 0.33]])
     # Define different colors for each label
     c = range(len(labels))
     # Do scatter plot
-    fig = plotSimplex(testpoints, s=25, c='k')
-
-    P.show()
+    fig = plotSimplex(testpoints, c='k', show=1)
diff --git a/test_codes/pymc_compare/call.py b/test_codes/pymc_compare/call.py
index aeeedd8a5becdffd6824fc457a4578f8b2729a05..bc52153835ea5302a9483829c3a459a6c13240f3 100644
--- a/test_codes/pymc_compare/call.py
+++ b/test_codes/pymc_compare/call.py
@@ -1,14 +1,19 @@
-import pymc, bayes_fit              # load the model file
+"""Perform a pymc sampling of the test data."""
+import pymc, bayes_fit  # load the model file
 import numpy as np
 import pickle
 
+# Data params
+T = 14
+n_inputs = 3
+
+# Sampling params
 n_samples = 400000
 
-R = pymc.MCMC(bayes_fit)    #  build the model
-R.sample(n_samples)              # populate and run it
+R = pymc.MCMC(bayes_fit)  # build the model
+R.sample(n_samples)  # populate and run it
 
-T = 14
-n_inputs = 3
+# Extract weights
 weights = np.zeros((T, n_samples, n_inputs))
 for t in range(T):
     try:
@@ -16,4 +21,5 @@ for t in range(T):
     except KeyError:
         weights[t] = R.trace('ws'.format(t))
 
+# Save everything
 pickle.dump(weights, open('pymc_posterior', 'wb'))
diff --git a/test_codes/pymc_compare/dynglm_optimisation_test.py b/test_codes/pymc_compare/dynglm_optimisation_test.py
index 05df6e0cd7e1e02aa462861346bcc1e721d82e8f..e17222eae9ed15efd967844f79fa70adcc34cd99 100644
--- a/test_codes/pymc_compare/dynglm_optimisation_test.py
+++ b/test_codes/pymc_compare/dynglm_optimisation_test.py
@@ -2,7 +2,7 @@
 Need to find out whether loglikelihood is computed correctly.
 
 Or whether a bug here allows states to invade each other more easily.
-We'll test this by maximising the likelihood directly.
+We'll test this by comparing to pymc results.
 """
 import numpy as np
 import pyhsmm.basic.distributions as distributions
diff --git a/test_codes/pymc_compare/gibbs_sample.py b/test_codes/pymc_compare/gibbs_sample.py
index 31ed5919143fdc6e74eb79107117c2b0d603285b..5b91c1ab0a72c404452b14d725e24225be2942c9 100644
--- a/test_codes/pymc_compare/gibbs_sample.py
+++ b/test_codes/pymc_compare/gibbs_sample.py
@@ -8,7 +8,7 @@ import pyhsmm.basic.distributions as distributions
 from scipy.optimize import minimize
 import pickle
 
-# Data Params
+# Data params
 T = 16
 n_inputs = 3
 step_size = 0.2
@@ -41,5 +41,5 @@ LL_weights = np.zeros((T, n_inputs))
 for t in range(T):
     LL_weights[t] = minimize(lambda w: wrapper(w, t), np.zeros(n_inputs)).x
 
-# save everything
+# Save everything
 pickle.dump((samples, LL_weights), open('gibbs_posterior', 'wb'))