diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py
index 9b977369a2029101d52d48de0845217e14b6bdf6..fc7be25aa6fab9564c2d1753081bc62797bfad4c 100644
--- a/dyn_glm_chain_analysis.py
+++ b/dyn_glm_chain_analysis.py
@@ -719,7 +719,7 @@ def lapse_sides(test, state_sets, indices):
     """Compute and plot a lapse differential across sessions.
 
     Takes a single mouse and plots (1 - lapse_left) - lapse_right across sessions, with sessions boundaries shown."""
-    
+
     def func_init(): return {'lapse_side': np.zeros(test.results[0].n_datapoints) + 10, 'session_bounds': []}
 
     def first_for(test, results):
@@ -1384,24 +1384,46 @@ len_to_bools = {
 }
 
 
+def state_type_durs(states, pmfs):
+    # Takes states and pmfs, first creates an array of when which type is how active, then computes the number of sessions each type lasts.
+    # A type lasts until a more advanced type takes up more than 50% of a session (and cannot return)
+    # Returns the durations for all the different state types, and an array which holds the state percentages
+    state_types = np.zeros((3, states.shape[1]))
+    for s, pmf in zip(states, pmfs):
+        pmf_counter = -1
+        for i in range(states.shape[1]):
+            if s[i]:
+                pmf_counter += 1
+                state_types[pmf_type(pmf[1][pmf_counter][pmf[0]]), i] += s[i]  # indexing horror
+
+    durs = (np.where(state_types[1] > 0.5)[0][0],
+            np.where(state_types[2] > 0.5)[0][0] - np.where(state_types[1] > 0.5)[0][0],
+            states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
+    if np.where(state_types[2] > 0.5)[0][0] < np.where(state_types[1] > 0.5)[0][0]:
+        durs = (np.where(state_types[2] > 0.5)[0][0],
+                0,
+                states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
+    return durs, state_types
+
+
 def state_cluster_interpolation(states, pmfs):
+    #
+    # Used to contain a first_type_count variable, which seemed to just count the number of states?
     pmf_examples = [[], [], []]
     state_trans = np.zeros((3, 3))
     state_types = np.zeros((3, states.shape[1]))
-    first_type_count = 0
     for state, pmf in zip(states, pmfs):
         sessions = np.where(state)[0]
         for i, sess_pmf in enumerate(pmf[1]):
             if i == 0:
                 state_type = pmf_type(sess_pmf[pmf[0]])
-                first_type_count += state_type == 0
             if i > 0 and state_type != pmf_type(sess_pmf[pmf[0]]):
                 state_trans[state_type, pmf_type(sess_pmf[pmf[0]])] += 1
                 state_type = pmf_type(sess_pmf[pmf[0]])
             pmf_examples[pmf_type(sess_pmf[pmf[0]])].append(sess_pmf[pmf[0]])
             state_types[pmf_type(sess_pmf[pmf[0]]), sessions[i]] += state[sessions[i]]
 
-    return state_types, state_trans, first_type_count, pmf_examples
+    return state_types, state_trans, pmf_examples
 
 
 def pmf_type(pmf):
@@ -1477,7 +1499,6 @@ if __name__ == "__main__":
     n_points = 150
     state_trajs = np.zeros((3, n_points))
     state_trans = np.zeros((3, 3))
-    first_type_count = 0
 
     not_yet = True
 
@@ -1496,69 +1517,48 @@ if __name__ == "__main__":
 
             mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb'))
             state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb'))
-            lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices)
-            continue
-            #
-            state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb'))
-            states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=True, separate_pmf=True)
+
+            # lapse differential
+            # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices)
+
+            # training overview
+            # states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=True, separate_pmf=True)
             # state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False)
-            consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
-            consistencies /= consistencies[0, 0]
-            contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save=True, show=True, consistencies=consistencies)
-            quit()
 
+            # session overview
+            # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
+            # consistencies /= consistencies[0, 0]
+            # contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies)
 
-            # state_types = np.zeros((3, states.shape[1]))
-            # for s, pmf in zip(states, pmfs):
-            #     pmf_counter = -1
-            #     for i in range(states.shape[1]):
-            #         if s[i]:
-            #             pmf_counter += 1
-            #             state_types[pmf_type(pmf[1][pmf_counter][pmf[0]]), i] += s[i]  # indexing horror
-            #
-            # durs = (np.where(state_types[1] > 0.5)[0][0],
-            #         np.where(state_types[2] > 0.5)[0][0] - np.where(state_types[1] > 0.5)[0][0],
-            #         states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
-            # if np.where(state_types[2] > 0.5)[0][0] < np.where(state_types[1] > 0.5)[0][0]:
-            #     durs = (np.where(state_types[2] > 0.5)[0][0],
-            #             0,
-            #             states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
-            # if durs[0] < 0 or durs[1] < 0 or durs[2] < 0:
-            #     quit()
+            # duration of different state types (and also percentage of type activities)
+            # durs, _ = state_type_durs(states, pmfs)
             # abs_state_durs.append(durs)
-            # if durs[1] < 0:
-            #     state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=False, separate_pmf=True)
             # rel_state_durs.append((durs[0] / states.shape[1], durs[1] / states.shape[1], durs[2] / states.shape[1]))
-            # continue
-            # contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True)
-            # continue
-            # ret, trans, count, pmf_types = state_cluster_interpolation(states, pmfs)
-            # state_trans += trans
-            # first_type_count += count
 
-            # points = np.linspace(1, test.results[0].n_sessions, n_points)
-            #
-            # # plt.plot(np.arange(1, 1 + test.results[0].n_sessions), ret.T)
-            # for i, r in enumerate(ret):
-            #     state_trajs[i] += np.interp(points, np.arange(1, 1 + test.results[0].n_sessions), r)
-            #     # plt.plot(points, np.interp(points, np.arange(1, 1 + test.results[0].n_sessions), r))
-            # # plt.show()
-            # continue
 
+            ret, trans, count, pmf_types = state_cluster_interpolation(states, pmfs)
+            state_trans += trans
+            points = np.linspace(1, test.results[0].n_sessions, n_points)
+            # plt.plot(np.arange(1, 1 + test.results[0].n_sessions), ret.T)
+            for i, r in enumerate(ret):
+                state_trajs[i] += np.interp(points, np.arange(1, 1 + test.results[0].n_sessions), r)
+                plt.plot(points, np.interp(points, np.arange(1, 1 + test.results[0].n_sessions), r))
+            plt.show()
+            continue
 
-            # plt.plot(ret.T, label=[0, 1, 2])
-            # plt.legend()
-            # plt.show()
-            # continue
+            plt.plot(ret.T, label=[0, 1, 2])
+            plt.legend()
+            plt.show()
+            continue
 
-            # for i, pmfs in enumerate(pmf_types):
-            #     plt.subplot(1, 3, i + 1)
-            #     plt.plot([0, 11], [0.5, 0.5], 'grey', alpha=1/3)
-            #     for pmf in pmfs:
-            #         plt.plot(len_to_bools[len(pmf)], pmf)
-            #     plt.ylim(0, 1)
-            # plt.show()
-            # continue
+            for i, pmfs in enumerate(pmf_types):
+                plt.subplot(1, 3, i + 1)
+                plt.plot([0, 11], [0.5, 0.5], 'grey', alpha=1/3)
+                for pmf in pmfs:
+                    plt.plot(len_to_bools[len(pmf)], pmf)
+                plt.ylim(0, 1)
+            plt.show()
+            continue
 
             # quit()
             # for pmf in pmfs: