diff --git a/__pycache__/mcmc_chain_analysis.cpython-37.pyc b/__pycache__/mcmc_chain_analysis.cpython-37.pyc
index 9f093ef55cf22048cd989c21378e8d782850879c..dfcc898f4df0897c89643b35287844e38e54193b 100644
Binary files a/__pycache__/mcmc_chain_analysis.cpython-37.pyc and b/__pycache__/mcmc_chain_analysis.cpython-37.pyc differ
diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py
index 71e06373a2a4dc0a114fb5407e8de9f2a9479899..04907a5a632d01e90a49247b007a1febd4f1fe6e 100644
--- a/dyn_glm_chain_analysis.py
+++ b/dyn_glm_chain_analysis.py
@@ -72,52 +72,11 @@ class MCMC_result_list:
             state_nums[i] = state_num_helper(0.05)(self.results[i])
         return state_nums
 
-    def return_chains(self, func, model_for_loop, rank_norm=True, mode_indices=None):
-        # Following Gelman page 284
-
-        if mode_indices is None:
-            if func.__name__ == "temp_state_glm_func":
-                chains = np.zeros((self.m, self.n, 4))
-            else:
-                chains = np.zeros((self.m, self.n))
-            if model_for_loop:
-                for j, result in enumerate(self.results):
-                    for i in range(self.n):
-                        chains[j, i] = func(result.models[i])
-            else:
-                for i in range(self.m):
-                    chains[i] = func(self.results[i])
-        else:
-            inds = []
-            for lims in [range(i * self.n, (i + 1) * self.n) for i in range(self.m)]:
-                inds.append([ind - lims[0] for ind in mode_indices if ind in lims])
-            lens = [len(ind) for ind in inds]
-            min_len = min([l for l in lens if l != 0])
-            n_remaining_chains = len([l for l in lens if l != 0])
-            print("{} chains left with a len of {}".format(n_remaining_chains, min_len))
-            # print(lens)
-            chains = np.zeros((n_remaining_chains, min_len))
-            print(func.__name__)
-            if func.__name__ == "temp_state_glm_func":
-                chains = np.zeros((n_remaining_chains, min_len, 4))
-            else:
-                chains = np.zeros((n_remaining_chains, min_len))
-            counter = -1
-            for i, ind in enumerate(inds):
-                if len(ind) == 0:
-                    continue
-                counter += 1
-                step = len(ind) // min_len
-                up_to = min_len * step
-                # print([j + i * 640 for j in ind[:up_to:step]])
-                chains[counter] = func(self.results[i], ind[:up_to:step])
-        return chains
-
     def r_hat_and_ess(self, func, model_for_loop, rank_norm=True, mode_indices=None):
-        # Following Gelman page 284
+        """Compute all kinds of R^hat's and the effective sample size with the intermediate steps
+           Following Gelman page 284f"""
 
         if mode_indices is None:
-
             chains = np.zeros((self.m, self.n))
             if model_for_loop:
                 for j, result in enumerate(self.results):
@@ -130,17 +89,11 @@ class MCMC_result_list:
             inds = []
             for lims in [range(i * self.n, (i + 1) * self.n) for i in range(self.m)]:
                 inds.append([ind - lims[0] for ind in mode_indices if ind in lims])
-            # inds = [[] for i in range(self.m)]
-            # for i, m in enumerate([item for sublist in self.results for item in sublist.models]):
-            #     if i not in mode_indices:
-            #         continue
-            #     inds[i // self.n].append(i % self.n)
             lens = [len(ind) for ind in inds]
-            min_len = min([l for l in lens if l != 0])
-            n_remaining_chains = len([l for l in lens if l != 0])
+            min_len = min([li for li in lens if li != 0])
+            n_remaining_chains = len([li for li in lens if li != 0])
             self.m, self.n = n_remaining_chains, min_len
             print("{} chains left with a len of {}".format(n_remaining_chains, min_len))
-            # print(lens)
             chains = np.zeros((n_remaining_chains, min_len))
             counter = -1
             for i, ind in enumerate(inds):
@@ -152,59 +105,54 @@ class MCMC_result_list:
                 chains[counter] = func(self.results[i], ind[:up_to:step])
 
         self.chains = chains
-
         self.rank_normalised, self.folded_rank_normalised, self.ranked, self.folded_ranked = rank_inv_normal_transform(self.chains)
 
+        # Compute all R^hats, use the worst
         self.lame_r_hat, self.lame_var_hat_plus = r_hat_array_comp(self.chains)
-        self.rank_normalised_r_hat, self.var_hat_plus = r_hat_array_comp(self.rank_normalised)
+        self.rank_normalised_r_hat, self.rank_normed_var_hat_plus = r_hat_array_comp(self.rank_normalised)
         self.folded_rank_normalised_r_hat, self.folded_rank_normalised_var_hat_plus = r_hat_array_comp(self.folded_rank_normalised)
         self.r_hat = max(self.lame_r_hat, self.rank_normalised_r_hat, self.folded_rank_normalised_r_hat)
         print("r_hat is {:.4f} (max of normal ({:.4f}), rank normed ({:.4f}), folded rank normed ({:.4f}))".format(self.r_hat, self.lame_r_hat, self.rank_normalised_r_hat, self.folded_rank_normalised_r_hat))
 
         t = 1
         rhos = []
+        # select raw or rank-normalised chains and the matching var_hat_plus for the effective sample size
         local_chains = self.rank_normalised if rank_norm else self.chains
+        local_var_hat_plus = self.rank_normed_var_hat_plus if rank_norm else self.lame_var_hat_plus
+
+        # Estimate sample auto-correlation (could be done with Fourier, but Gelman doesn't elaborate)
         while True:
             V_t = np.sum((local_chains[:, t:] - local_chains[:, :-t]) ** 2)
-            rho_t = 1 - (V_t / (self.m * (self.n - t))) / (2 * self.var_hat_plus)
+            rho_t = 1 - (V_t / (self.m * (self.n - t))) / (2 * local_var_hat_plus)
             t += 1
             rhos.append(rho_t)
             if (t > 2 and t % 2 == 1 and rhos[-1] + rhos[-2] < 0) or t == self.n:
                 break
         self.n_eff = self.m * self.n / (1 + 2 * sum(rhos[:-2]))
         print("Effective number of samples is {}".format(self.n_eff))
-        return chains
 
-    def rank_histos(self):
-        count_max = 0
-        for i in range(self.m):
-            counts, _ = np.histogram(self.ranked[i])
-            count_max = max(count_max, counts.max())
-        for i in range(self.m):
-            plt.subplot(self.m / 2, 2, i+1)
-            plt.hist(self.ranked[i])
-            plt.ylim(top=count_max)
-        plt.show()
-
-    def folded_rank_histos(self):
-        count_max = 0
-        for i in range(self.m):
-            counts, _ = np.histogram(self.ranked[i])
-            count_max = max(count_max, counts.max())
-        for i in range(self.m):
-            plt.subplot(self.m / 2, 2, i+1)
-            plt.hist(self.folded_ranked[i])
-            plt.ylim(top=count_max)
-        plt.show()
+        return chains
 
-    def histos(self):
+    def histos(self, type='normal'):
+        """Plot histograms of different way to represent the features
+           TODO: It would also be interesting to look at all mode_indices, not just the ones left after reduction
+           Using 'ranked' gives Gelman's beloved rank histograms
+           Need to use r_hat_and_ess before this to initialise chain variables with a feature
+        """
+        if type == 'normal':
+            local_chain = self.chains
+        elif type == 'folded_rank':
+            local_chain = self.folded_ranked
+        elif type == 'ranked':
+            local_chain = self.ranked
         count_max = 0
         for i in range(self.m):
-            counts, _ = np.histogram(self.ranked[i])
+            counts, _ = np.histogram(local_chain[i])
             count_max = max(count_max, counts.max())
+        _, bins = np.histogram(local_chain)
         for i in range(self.m):
-            plt.subplot(self.m / 2, 2, i+1)
-            plt.hist(self.chains[i])
+            plt.subplot(self.m // 2, 2, i+1)
+            plt.hist(local_chain[i], bins=bins)
             plt.ylim(top=count_max)
         plt.show()
 
@@ -1161,14 +1109,13 @@ if __name__ == "__main__":
     # test.r_hat_and_ess(return_ascending_shuffled, False)
     # quit()
 
-    check_r_hats = True
+    check_r_hats = False
     if check_r_hats:
         subjects = list(loading_info.keys())
         subjects = ['KS014']
         for subject in subjects:
             # if subject.startswith('GLM'):
             #     continue
-            print()
             print("_________________________________")
             print(subject)
             fit_num = loading_info[subject]['fit_nums'][-1]
diff --git a/dyn_glm_chain_analysis_unused_funcs.py b/dyn_glm_chain_analysis_unused_funcs.py
index abd755e3cf7cfbf9d2e95ade4bca9738f05f084d..ac1a9df76aca055744a19ef471485421707c027e 100644
--- a/dyn_glm_chain_analysis_unused_funcs.py
+++ b/dyn_glm_chain_analysis_unused_funcs.py
@@ -36,6 +36,46 @@ class MCMC_result_list:
             plt.plot(res.assign_counts.max(axis=1))
         plt.show()
 
+    def return_chains(self, func, model_for_loop, rank_norm=True, mode_indices=None):
+        # Following Gelman page 284
+
+        if mode_indices is None:
+            if func.__name__ == "temp_state_glm_func":
+                chains = np.zeros((self.m, self.n, 4))
+            else:
+                chains = np.zeros((self.m, self.n))
+            if model_for_loop:
+                for j, result in enumerate(self.results):
+                    for i in range(self.n):
+                        chains[j, i] = func(result.models[i])
+            else:
+                for i in range(self.m):
+                    chains[i] = func(self.results[i])
+        else:
+            # TODO: factor this out into a helper; the same index-grouping logic exists elsewhere
+            inds = []
+            for lims in [range(i * self.n, (i + 1) * self.n) for i in range(self.m)]:
+                inds.append([ind - lims[0] for ind in mode_indices if ind in lims])
+            lens = [len(ind) for ind in inds]
+            min_len = min([l for l in lens if l != 0])
+            n_remaining_chains = len([l for l in lens if l != 0])
+            print("{} chains left with a len of {}".format(n_remaining_chains, min_len))
+            chains = np.zeros((n_remaining_chains, min_len))
+            print(func.__name__)
+            if func.__name__ == "temp_state_glm_func":
+                chains = np.zeros((n_remaining_chains, min_len, 4))
+            else:
+                chains = np.zeros((n_remaining_chains, min_len))
+            counter = -1
+            for i, ind in enumerate(inds):
+                if len(ind) == 0:
+                    continue
+                counter += 1
+                step = len(ind) // min_len
+                up_to = min_len * step
+                chains[counter] = func(self.results[i], ind[:up_to:step])
+        return chains
+
 
 class MCMC_result:
 
diff --git a/mcmc_chain_analysis.py b/mcmc_chain_analysis.py
index c8de12596afbb348ca65a77d28db0726031a52f5..7dc376b63e494f5ca262931f05d2ea558b08b70e 100644
--- a/mcmc_chain_analysis.py
+++ b/mcmc_chain_analysis.py
@@ -39,7 +39,8 @@ def ll_func(x): return x.sample_lls[-x.n_samples:]
 
 
 def r_hat_array_comp(chains):
-    """Computes R^hat on an array of features, following Gelman p. 284f"""
+    """Computes R^hat on an array of features, following Gelman p. 284f
+       Return R^hat itself, and var^hat^plus, which is needed for the effective sample size calculation"""
     m, n = chains.shape  # number of chains, length of chains
     psi_dot_j = np.mean(chains, axis=1)
     psi_dot_dot = np.mean(psi_dot_j)
@@ -108,6 +109,7 @@ def eval_simple_r_hat(chains):
 
 
 def comp_multi_r_hat(chains, rank_normalised, folded_rank_normalised):
+    """Compute full set of R^hat's, given the appropriately transformed chains."""
     lame_r_hat, _ = r_hat_array_comp(chains)
     rank_normalised_r_hat, _ = r_hat_array_comp(rank_normalised)
     folded_rank_normalised_r_hat, _ = r_hat_array_comp(folded_rank_normalised)