diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py
index bcae85556847df87a88bf37d6af5159780ae8123..3389c9473732e38b5e8adbdf3fa277eeb275f7f0 100644
--- a/dyn_glm_chain_analysis.py
+++ b/dyn_glm_chain_analysis.py
@@ -441,9 +441,6 @@ class MCMC_result:
         self.n_contrasts = 11
         self.cont_ticks = all_cont_ticks
 
-        self.state_to_color = {}
-        self.state_to_ls = {}
-
         self.session_contrasts = [np.unique(cont_mapping(d[:, 0] - d[:, 1])) for d in self.data]
 
         self.count_assigns()
@@ -617,9 +614,11 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur
             if consistencies is None:
                 active_trials[relevant_trials - trial_counter] = 1
             else:
-                active_trials[relevant_trials - trial_counter] = np.mean(consistencies[tuple(np.meshgrid(relevant_trials, trials))], axis=0)
-                print("fix this by taking the whole array, multiply by n, subtract n, divide by n-1")
-                input()
+                active_trials[relevant_trials - trial_counter] = np.sum(consistencies[tuple(np.meshgrid(relevant_trials, trials))], axis=0)
+                active_trials[relevant_trials - trial_counter] -= 1
+                active_trials[relevant_trials - trial_counter] = active_trials[relevant_trials - trial_counter] / (trials.shape[0] - 1)
+                # print("fix this by taking the whole array, multiply by n, subtract n, divide by n-1")
+                # input()
 
             label = "State {}".format(state) if np.sum(relevant_trials) > 0.02 * len(test.results[0].models[0].stateseqs[seq_num]) else None
 
@@ -747,7 +746,6 @@ def lapse_sides(test, state_sets, indices):
 
 
 def state_development_single_sample(test, indices, save=True, save_append='', show=True, dpi='figure', separate_pmf=False):
-    session_contrasts = [np.unique(cont_mapping(d[:, 0] - d[:, 1])) for d in test.results[0].data]
 
     for i, m in enumerate([item for sublist in test.results for item in sublist.models]):
         if i not in indices:
@@ -809,7 +807,7 @@ def state_development_single_sample(test, indices, save=True, save_append='', sh
                 session_max = i
 
         defined_points = np.zeros(test.results[0].n_contrasts, dtype=bool)
-        defined_points[session_contrasts[session_max]] = True
+        defined_points[test.results[0].session_contrasts[session_max]] = True
 
         n_points = 150
         points = np.linspace(1, test.results[0].n_sessions, n_points)
@@ -913,7 +911,6 @@ def state_development_single_sample(test, indices, save=True, save_append='', sh
 
 def state_development(test, state_sets, indices, save=True, save_append='', show=True, dpi='figure', separate_pmf=False):
     state_sets = [np.array(s) for s in state_sets]
-    session_contrasts = [np.unique(cont_mapping(d[:, 0] - d[:, 1])) for d in test.results[0].data]
 
     if test.results[0].name.startswith('GLM_Sim_'):
         print("./glm sim mice/truth_{}.p".format(test.results[0].name))
@@ -1019,7 +1016,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
             trial_counter += len(state_seq)
 
         defined_points = np.zeros(test.results[0].n_contrasts, dtype=bool)
-        defined_points[session_contrasts[session_max]] = True
+        defined_points[test.results[0].session_contrasts[session_max]] = True
         if not separate_pmf:
             temp = np.sum(pmfs[:, defined_points]) / (np.sum(defined_points))
             state_color = colors[int(temp * 101 - 1)]
@@ -1195,41 +1192,6 @@ def find_good_chains_unsplit(chains1, chains2, chains3, chains4, reduce_to=8, si
     return sol, r_hat_min
 
 
-def find_good_chains_unsplit_fast(chains1, chains2, chains3, chains4, reduce_to=8):
-    delete_n = - reduce_to + chains1.shape[0]
-    mins = np.zeros(delete_n + 1)
-    n_chains = chains1.shape[0]
-    chains = np.stack([chains1, chains2, chains3, chains4])
-
-    print("Without removals: {}".format(eval_simple_r_hat(chains)))
-    r_hat = eval_simple_r_hat(chains)
-
-    mins[0] = r_hat
-
-    l, m, n = chains.shape
-    psi_dot_j = np.mean(chains, axis=2)
-    s_j_squared = np.sum((chains - psi_dot_j[:, :, None]) ** 2, axis=2) / (n - 1)
-
-    r_hat_min = 10
-    sol = 0
-    for x in combinations(range(n_chains), n_chains - delete_n):
-        temp1 = chains[:, x]
-        temp2 = psi_dot_j[:, x]
-        temp3 = s_j_squared[:, x]
-        r_hat = eval_amortized_r_hat(temp1, temp2, temp3, l, m - delete_n, n)
-        if r_hat < r_hat_min:
-            sol = x
-        r_hat_min = min(r_hat, r_hat_min)
-    print("Minimum is {} (removed {})".format(r_hat_min, delete_n))
-    sol = [i for i in range(n_chains) if i not in sol]
-    print("Removed: {}".format(sol))
-
-    r_hat_local = eval_r_hat(np.delete(chains1, sol, axis=0), np.delete(chains2, sol, axis=0), np.delete(chains3, sol, axis=0), np.delete(chains4, sol, axis=0))
-    print("Minimum over everything is {} (removed {})".format(r_hat_local, delete_n))
-
-    return sol, r_hat_min
-
-
 def find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_to=8, simple=False):
     delete_n = - reduce_to + chains1.shape[0]
     mins = np.zeros(delete_n + 1)
@@ -1320,12 +1282,6 @@ if __name__ == "__main__":
             chains3 = chains3[:, 160:]
             chains4 = chains4[:, 160:]
 
-            # chains1 = chains1[:16]
-            # chains2 = chains2[:16]
-            # chains3 = chains3[:16]
-            # chains4 = chains4[:16]
-
-
             # mins = find_good_chains(chains[:, :-1].reshape(32, chains.shape[1] // 2))
             sol, final_r_hat = find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_to=chains1.shape[0] // 2)
             r_hats.append((subject, final_r_hat))
@@ -1352,22 +1308,6 @@ def dist_helper(dist_matrix, state_hists, inds):
     return dist_matrix
 
 
-def _get_leaves_color_list(R):
-    # copied from latest scipy version
-    leaves_color_list = [None] * len(R['leaves'])
-    for link_x, link_y, link_color in zip(R['icoord'],
-                                          R['dcoord'],
-                                          R['color_list']):
-        for (xi, yi) in zip(link_x, link_y):
-            if yi == 0.0:  # if yi is 0.0, the point is a leaf
-                # xi of leaves are      5, 15, 25, 35, ... (see `iv_ticks`)
-                # index of leaves are   0,  1,  2,  3, ... as below
-                leaf_index = (int(xi) - 5) // 10
-                # each leaf has a same color of its link.
-                leaves_color_list[leaf_index] = link_color
-    return leaves_color_list
-
-
 def state_size_helper(n=0, mode_specific=False):
     if not mode_specific:
         def nth_largest_state_func(x):
@@ -1496,7 +1436,7 @@ if __name__ == "__main__":
         loading_info = json.load(open("canonical_infos.json", 'r'))
         r_hats = json.load(open("canonical_info_r_hats.json", 'r'))
     no_good_pcas = ['NYU-06', 'SWC_023']  # no good rhat: 'ibl_witten_13'
-    subjects = list(loading_info.keys())
+    subjects = ['KS014']  # list(loading_info.keys())
     print(subjects)
     fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0]
     dur = 'yes'
@@ -1531,7 +1471,7 @@ if __name__ == "__main__":
                      'SWC_023': (9, 'skip'),  # gradual
                      'ZM_1897': (5, 'right'),
                      'ZM_3003': (1, 'skip')}
-    fig, ax = plt.subplots(1, 3, sharey=True, figsize=(16, 9))
+    # fig, ax = plt.subplots(1, 3, sharey=True, figsize=(16, 9))
 
     thinning = 25
     summary_info = {"thinning": thinning, "contains": [], "seeds": [], "fit_nums": []}
@@ -1569,22 +1509,19 @@ if __name__ == "__main__":
             test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb'))
             print('loaded canoncial result')
 
-            mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}.p".format(subject), 'rb'))
-
-
-            state_sets = pickle.load(open("multi_chain_saves/state_sets_{}.p".format(subject), 'rb'))
+            # state_sets = pickle.load(open("multi_chain_saves/state_sets_{}.p".format(subject), 'rb'))
             # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices)
             # continue
             #
-            # mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb'))
-            # state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb'))
-            # states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=True, separate_pmf=True)
+            mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb'))
+            state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb'))
+            states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=True, separate_pmf=True)
             # state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False)
-            # quit()
-            # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}.p".format(subject), 'rb'))
-            # consistencies /= consistencies[0, 0]
-            # contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save=True, show=True, consistencies=consistencies)
-            # quit()
+            consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
+            consistencies /= consistencies[0, 0]
+            contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save=True, show=True, consistencies=consistencies)
+            quit()
+
 
             # state_types = np.zeros((3, states.shape[1]))
             # for s, pmf in zip(states, pmfs):
@@ -1802,7 +1739,6 @@ if __name__ == "__main__":
                     contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save_append='_{}{}'.format(string_prefix, criterion), save=True, show=True)
 
                 # R = hc.dendrogram(linkage, no_plot=True, get_leaves=True, color_threshold=color_threshold)
-                # R["leaves_color_list"] = _get_leaves_color_list(R)
 
                 # leaves_color_list = np.array(R["leaves_color_list"])
                 # leaves = np.array(R["leaves"])
diff --git a/dyn_glm_chain_analysis_unused_funcs.py b/dyn_glm_chain_analysis_unused_funcs.py
index 38015be96f29c92060dfad32e07934207ef9adee..d1ae35cc0c469ea18daf07886ff472923eb05840 100644
--- a/dyn_glm_chain_analysis_unused_funcs.py
+++ b/dyn_glm_chain_analysis_unused_funcs.py
@@ -64,6 +64,41 @@ class MCMC_result:
         return consistency_mat
 
 
+def find_good_chains_unsplit_fast(chains1, chains2, chains3, chains4, reduce_to=8):
+    delete_n = - reduce_to + chains1.shape[0]
+    mins = np.zeros(delete_n + 1)
+    n_chains = chains1.shape[0]
+    chains = np.stack([chains1, chains2, chains3, chains4])
+
+    print("Without removals: {}".format(eval_simple_r_hat(chains)))
+    r_hat = eval_simple_r_hat(chains)
+
+    mins[0] = r_hat
+
+    l, m, n = chains.shape
+    psi_dot_j = np.mean(chains, axis=2)
+    s_j_squared = np.sum((chains - psi_dot_j[:, :, None]) ** 2, axis=2) / (n - 1)
+
+    r_hat_min = 10
+    sol = 0
+    for x in combinations(range(n_chains), n_chains - delete_n):
+        temp1 = chains[:, x]
+        temp2 = psi_dot_j[:, x]
+        temp3 = s_j_squared[:, x]
+        r_hat = eval_amortized_r_hat(temp1, temp2, temp3, l, m - delete_n, n)
+        if r_hat < r_hat_min:
+            sol = x
+        r_hat_min = min(r_hat, r_hat_min)
+    print("Minimum is {} (removed {})".format(r_hat_min, delete_n))
+    sol = [i for i in range(n_chains) if i not in sol]
+    print("Removed: {}".format(sol))
+
+    r_hat_local = eval_r_hat(np.delete(chains1, sol, axis=0), np.delete(chains2, sol, axis=0), np.delete(chains3, sol, axis=0), np.delete(chains4, sol, axis=0))
+    print("Minimum over everything is {} (removed {})".format(r_hat_local, delete_n))
+
+    return sol, r_hat_min
+
+
 if __name__ == "__main__":
 
     subjects = list(loading_info.keys())