From d06a0b171b6c57a074e5e6d4cf862896140a5e08 Mon Sep 17 00:00:00 2001
From: SebastianBruijns <>
Date: Sat, 11 Feb 2023 22:08:44 +0100
Subject: [PATCH] code cleanup and commenting

---
 .../mcmc_chain_analysis.cpython-37.pyc        | Bin 5313 -> 5771 bytes
 dyn_glm_chain_analysis.py                     | 111 +++++-------------
 dyn_glm_chain_analysis_unused_funcs.py        |  40 +++++++
 mcmc_chain_analysis.py                        |   4 +-
 4 files changed, 72 insertions(+), 83 deletions(-)

diff --git a/__pycache__/mcmc_chain_analysis.cpython-37.pyc b/__pycache__/mcmc_chain_analysis.cpython-37.pyc
index 9f093ef55cf22048cd989c21378e8d782850879c..dfcc898f4df0897c89643b35287844e38e54193b 100644
GIT binary patch
delta 613
zcma)(zlsz=5XPO|1N(NI{euTaLPZ|RXsQT@|1LPpJTdf^Vz%a{y`Ju&yJwf(#LUF7
zA3!b>g9F(IIE}>QF!B|Q#8<Fpc2^J+D`}|i{_6LA_4(%8+x^E$(!0cU>(`I$`?Zh#
za}m7{7vf>jgH;`pudRl|qXIoh??~k-pq{`KZPPodRP9oj6*4P8I?zNDnm}$W_=14)
zoH8$`1P;qe@gWxk8LF&S=%vx2zkMEcmlxY_qUiqeMSK=T`^&9l(21Ufoq@e~`oxq~
z?O92CY56~SF!_pS*R^w$bZsvuWL5&S7U$u^;Ku&O#9G4vDz&yUCF^)RqoV^;RmL&3
zH$YQVV<#Kdu^qF)@NJOCFF7LD2GtFZhI@AxpEg*pY979N|9?>Q50uer-iEc7@$-uQ
zcuYYXcBv*(9|MPKQJ-+jmdLibmqLjr&7Jg6*1UrpohKV!4<E!uL%r3kw;R3=pFdq3
zY^d{^gZMkj6pndKx&^_ivZk^UJvC@N)A2$}PADr-YIi4GAI)OZ$csk4vwSz&?L@Qi
S<?7F5BzhtdS41o>3-KFuWytLS

delta 185
zcmeCyJ*dg&#LLUY00bxAJWu{1x{+@`2cyH}a~uwn%Q*KlGHo{FYGY)y+`N~29V4Uj
z<{VyrCdPotHT*>?p3eEX1*Ija3PEuhi6!dA3TgR83dN}<#R~ap3Tdf{C8b5Flhp*&
zCQlGZ<BkDZPz<()X)>drEu-HgA$#sLkRS`lZpO)8LcxrFlb;Kfb7zA@Il!WuGlkWe
X7>g!P6n)Dpz`@1A%fSOASvlANxmhnp

diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py
index 71e06373..04907a5a 100644
--- a/dyn_glm_chain_analysis.py
+++ b/dyn_glm_chain_analysis.py
@@ -72,52 +72,11 @@ class MCMC_result_list:
             state_nums[i] = state_num_helper(0.05)(self.results[i])
         return state_nums
 
-    def return_chains(self, func, model_for_loop, rank_norm=True, mode_indices=None):
-        # Following Gelman page 284
-
-        if mode_indices is None:
-            if func.__name__ == "temp_state_glm_func":
-                chains = np.zeros((self.m, self.n, 4))
-            else:
-                chains = np.zeros((self.m, self.n))
-            if model_for_loop:
-                for j, result in enumerate(self.results):
-                    for i in range(self.n):
-                        chains[j, i] = func(result.models[i])
-            else:
-                for i in range(self.m):
-                    chains[i] = func(self.results[i])
-        else:
-            inds = []
-            for lims in [range(i * self.n, (i + 1) * self.n) for i in range(self.m)]:
-                inds.append([ind - lims[0] for ind in mode_indices if ind in lims])
-            lens = [len(ind) for ind in inds]
-            min_len = min([l for l in lens if l != 0])
-            n_remaining_chains = len([l for l in lens if l != 0])
-            print("{} chains left with a len of {}".format(n_remaining_chains, min_len))
-            # print(lens)
-            chains = np.zeros((n_remaining_chains, min_len))
-            print(func.__name__)
-            if func.__name__ == "temp_state_glm_func":
-                chains = np.zeros((n_remaining_chains, min_len, 4))
-            else:
-                chains = np.zeros((n_remaining_chains, min_len))
-            counter = -1
-            for i, ind in enumerate(inds):
-                if len(ind) == 0:
-                    continue
-                counter += 1
-                step = len(ind) // min_len
-                up_to = min_len * step
-                # print([j + i * 640 for j in ind[:up_to:step]])
-                chains[counter] = func(self.results[i], ind[:up_to:step])
-        return chains
-
     def r_hat_and_ess(self, func, model_for_loop, rank_norm=True, mode_indices=None):
-        # Following Gelman page 284
+        """Compute all kinds of R^hat's and the effective sample size with the intermediate steps
+           Following Gelman page 284f"""
 
         if mode_indices is None:
-
             chains = np.zeros((self.m, self.n))
             if model_for_loop:
                 for j, result in enumerate(self.results):
@@ -130,17 +89,11 @@ class MCMC_result_list:
             inds = []
             for lims in [range(i * self.n, (i + 1) * self.n) for i in range(self.m)]:
                 inds.append([ind - lims[0] for ind in mode_indices if ind in lims])
-            # inds = [[] for i in range(self.m)]
-            # for i, m in enumerate([item for sublist in self.results for item in sublist.models]):
-            #     if i not in mode_indices:
-            #         continue
-            #     inds[i // self.n].append(i % self.n)
             lens = [len(ind) for ind in inds]
-            min_len = min([l for l in lens if l != 0])
-            n_remaining_chains = len([l for l in lens if l != 0])
+            min_len = min([li for li in lens if li != 0])
+            n_remaining_chains = len([li for li in lens if li != 0])
             self.m, self.n = n_remaining_chains, min_len
             print("{} chains left with a len of {}".format(n_remaining_chains, min_len))
-            # print(lens)
             chains = np.zeros((n_remaining_chains, min_len))
             counter = -1
             for i, ind in enumerate(inds):
@@ -152,59 +105,54 @@ class MCMC_result_list:
                 chains[counter] = func(self.results[i], ind[:up_to:step])
 
         self.chains = chains
-
         self.rank_normalised, self.folded_rank_normalised, self.ranked, self.folded_ranked = rank_inv_normal_transform(self.chains)
 
+        # Compute all R^hats, use the worst
         self.lame_r_hat, self.lame_var_hat_plus = r_hat_array_comp(self.chains)
-        self.rank_normalised_r_hat, self.var_hat_plus = r_hat_array_comp(self.rank_normalised)
+        self.rank_normalised_r_hat, self.rank_normed_var_hat_plus = r_hat_array_comp(self.rank_normalised)
         self.folded_rank_normalised_r_hat, self.folded_rank_normalised_var_hat_plus = r_hat_array_comp(self.folded_rank_normalised)
         self.r_hat = max(self.lame_r_hat, self.rank_normalised_r_hat, self.folded_rank_normalised_r_hat)
         print("r_hat is {:.4f} (max of normal ({:.4f}), rank normed ({:.4f}), folded rank normed ({:.4f}))".format(self.r_hat, self.lame_r_hat, self.rank_normalised_r_hat, self.folded_rank_normalised_r_hat))
 
         t = 1
         rhos = []
+        # use chains and var_hat_plus as desired to compute effective sample size
         local_chains = self.rank_normalised if rank_norm else self.chains
+        local_var_hat_plus = self.rank_normed_var_hat_plus if rank_norm else self.lame_var_hat_plus
+
+        # Estimate sample auto-correlation (could be done with Fourier, but Gelman doesn't elaborate)
         while True:
             V_t = np.sum((local_chains[:, t:] - local_chains[:, :-t]) ** 2)
-            rho_t = 1 - (V_t / (self.m * (self.n - t))) / (2 * self.var_hat_plus)
+            rho_t = 1 - (V_t / (self.m * (self.n - t))) / (2 * local_var_hat_plus)
             t += 1
             rhos.append(rho_t)
             if (t > 2 and t % 2 == 1 and rhos[-1] + rhos[-2] < 0) or t == self.n:
                 break
         self.n_eff = self.m * self.n / (1 + 2 * sum(rhos[:-2]))
         print("Effective number of samples is {}".format(self.n_eff))
-        return chains
 
-    def rank_histos(self):
-        count_max = 0
-        for i in range(self.m):
-            counts, _ = np.histogram(self.ranked[i])
-            count_max = max(count_max, counts.max())
-        for i in range(self.m):
-            plt.subplot(self.m / 2, 2, i+1)
-            plt.hist(self.ranked[i])
-            plt.ylim(top=count_max)
-        plt.show()
-
-    def folded_rank_histos(self):
-        count_max = 0
-        for i in range(self.m):
-            counts, _ = np.histogram(self.ranked[i])
-            count_max = max(count_max, counts.max())
-        for i in range(self.m):
-            plt.subplot(self.m / 2, 2, i+1)
-            plt.hist(self.folded_ranked[i])
-            plt.ylim(top=count_max)
-        plt.show()
+        return chains
 
-    def histos(self):
+    def histos(self, type='normal'):
+        """Plot histograms of different way to represent the features
+           TODO: It would also be interesting to look at all mode_indices, not just the ones left after reduction
+           Using 'ranked' gives Gelman's beloved rank histograms
+           Need to use r_hat_and_ess before this to initialise chain variables with a feature
+        """
+        if type == 'normal':
+            local_chain = self.chains
+        elif type == 'folded_rank':
+            local_chain = self.folded_ranked
+        elif type == 'ranked':
+            local_chain = self.ranked
         count_max = 0
         for i in range(self.m):
-            counts, _ = np.histogram(self.ranked[i])
+            counts, _ = np.histogram(local_chain[i])
             count_max = max(count_max, counts.max())
+        _, bins = np.histogram(local_chain)
         for i in range(self.m):
-            plt.subplot(self.m / 2, 2, i+1)
-            plt.hist(self.chains[i])
+            plt.subplot(self.m // 2, 2, i+1)
+            plt.hist(local_chain[i], bins=bins)
             plt.ylim(top=count_max)
         plt.show()
 
@@ -1161,14 +1109,13 @@ if __name__ == "__main__":
     # test.r_hat_and_ess(return_ascending_shuffled, False)
     # quit()
 
-    check_r_hats = True
+    check_r_hats = False
     if check_r_hats:
         subjects = list(loading_info.keys())
         subjects = ['KS014']
         for subject in subjects:
             # if subject.startswith('GLM'):
             #     continue
-            print()
             print("_________________________________")
             print(subject)
             fit_num = loading_info[subject]['fit_nums'][-1]
diff --git a/dyn_glm_chain_analysis_unused_funcs.py b/dyn_glm_chain_analysis_unused_funcs.py
index abd755e3..ac1a9df7 100644
--- a/dyn_glm_chain_analysis_unused_funcs.py
+++ b/dyn_glm_chain_analysis_unused_funcs.py
@@ -36,6 +36,46 @@ class MCMC_result_list:
             plt.plot(res.assign_counts.max(axis=1))
         plt.show()
 
+    def return_chains(self, func, model_for_loop, rank_norm=True, mode_indices=None):
+        # Following Gelman page 284
+
+        if mode_indices is None:
+            if func.__name__ == "temp_state_glm_func":
+                chains = np.zeros((self.m, self.n, 4))
+            else:
+                chains = np.zeros((self.m, self.n))
+            if model_for_loop:
+                for j, result in enumerate(self.results):
+                    for i in range(self.n):
+                        chains[j, i] = func(result.models[i])
+            else:
+                for i in range(self.m):
+                    chains[i] = func(self.results[i])
+        else:
+            # this should probably be functioned up, exists elsewhere
+            inds = []
+            for lims in [range(i * self.n, (i + 1) * self.n) for i in range(self.m)]:
+                inds.append([ind - lims[0] for ind in mode_indices if ind in lims])
+            lens = [len(ind) for ind in inds]
+            min_len = min([l for l in lens if l != 0])
+            n_remaining_chains = len([l for l in lens if l != 0])
+            print("{} chains left with a len of {}".format(n_remaining_chains, min_len))
+            chains = np.zeros((n_remaining_chains, min_len))
+            print(func.__name__)
+            if func.__name__ == "temp_state_glm_func":
+                chains = np.zeros((n_remaining_chains, min_len, 4))
+            else:
+                chains = np.zeros((n_remaining_chains, min_len))
+            counter = -1
+            for i, ind in enumerate(inds):
+                if len(ind) == 0:
+                    continue
+                counter += 1
+                step = len(ind) // min_len
+                up_to = min_len * step
+                chains[counter] = func(self.results[i], ind[:up_to:step])
+        return chains
+
 
 class MCMC_result:
 
diff --git a/mcmc_chain_analysis.py b/mcmc_chain_analysis.py
index c8de1259..7dc376b6 100644
--- a/mcmc_chain_analysis.py
+++ b/mcmc_chain_analysis.py
@@ -39,7 +39,8 @@ def ll_func(x): return x.sample_lls[-x.n_samples:]
 
 
 def r_hat_array_comp(chains):
-    """Computes R^hat on an array of features, following Gelman p. 284f"""
+    """Computes R^hat on an array of features, following Gelman p. 284f
+       Return R^hat itself, and var^hat^plus, which is needed for the effective sample size calculation"""
     m, n = chains.shape  # number of chains, length of chains
     psi_dot_j = np.mean(chains, axis=1)
     psi_dot_dot = np.mean(psi_dot_j)
@@ -108,6 +109,7 @@ def eval_simple_r_hat(chains):
 
 
 def comp_multi_r_hat(chains, rank_normalised, folded_rank_normalised):
+    """Compute full set of R^hat's, given the appropriately transformed chains."""
     lame_r_hat, _ = r_hat_array_comp(chains)
     rank_normalised_r_hat, _ = r_hat_array_comp(rank_normalised)
     folded_rank_normalised_r_hat, _ = r_hat_array_comp(folded_rank_normalised)
-- 
GitLab