From d06a0b171b6c57a074e5e6d4cf862896140a5e08 Mon Sep 17 00:00:00 2001
From: SebastianBruijns <>
Date: Sat, 11 Feb 2023 22:08:44 +0100
Subject: [PATCH] code cleanup and commenting

---
 .../mcmc_chain_analysis.cpython-37.pyc  | Bin 5313 -> 5771 bytes
 dyn_glm_chain_analysis.py               | 111 +++++-------------
 dyn_glm_chain_analysis_unused_funcs.py  |  40 +++++++
 mcmc_chain_analysis.py                  |   4 +-
 4 files changed, 72 insertions(+), 83 deletions(-)

diff --git a/__pycache__/mcmc_chain_analysis.cpython-37.pyc b/__pycache__/mcmc_chain_analysis.cpython-37.pyc
index 9f093ef55cf22048cd989c21378e8d782850879c..dfcc898f4df0897c89643b35287844e38e54193b 100644
GIT binary patch
delta 613
zcma)(zlsz=5XPO|1N(NI{euTaLPZ|RXsQT@|1LPpJTdf^Vz%a{y`Ju&yJwf(#LUF7
zA3!b>g9F(IIE}>QF!B|Q#8<Fpc2^J+D`}|i{_6LA_4(%8+x^E$(!0cU>(`I$`?Zh#
za}m7{7vf>jgH;`pudRl|qXIoh??~k-pq{`KZPPodRP9oj6*4P8I?zNDnm}$W_=14)
zoH8$`1P;qe@gWxk8LF&S=%vx2zkMEcmlxY_qUiqeMSK=T`^&9l(21Ufoq@e~`oxq~
z?O92CY56~SF!_pS*R^w$bZsvuWL5&S7U$u^;Ku&O#9G4vDz&yUCF^)RqoV^;RmL&3
zH$YQVV<#Kdu^qF)@NJOCFF7LD2GtFZhI@AxpEg*pY979N|9?>Q50uer-iEc7@$-uQ
zcuYYXcBv*(9|MPKQJ-+jmdLibmqLjr&7Jg6*1UrpohKV!4<E!uL%r3kw;R3=pFdq3
zY^d{^gZMkj6pndKx&^_ivZk^UJvC@N)A2$}PADr-YIi4GAI)OZ$csk4vwSz&?L@Qi
S<?7F5BzhtdS41o>3-KFuWytLS

delta 185
zcmeCyJ*dg&#LLUY00bxAJWu{1x{+@`2cyH}a~uwn%Q*KlGHo{FYGY)y+`N~29V4Uj
z<{VyrCdPotHT*>?p3eEX1*Ija3PEuhi6!dA3TgR83dN}<#R~ap3Tdf{C8b5Flhp*&
zCQlGZ<BkDZPz<()X)>drEu-HgA$#sLkRS`lZpO)8LcxrFlb;Kfb7zA@Il!WuGlkWe
X7>g!P6n)Dpz`@1A%fSOASvlANxmhnp

diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py
index 71e06373..04907a5a 100644
--- a/dyn_glm_chain_analysis.py
+++ b/dyn_glm_chain_analysis.py
@@ -72,52 +72,11 @@ class MCMC_result_list:
             state_nums[i] = state_num_helper(0.05)(self.results[i])
         return state_nums
 
-    def return_chains(self, func, model_for_loop, rank_norm=True, mode_indices=None):
-        # Following Gelman page 284
-
-        if mode_indices is None:
-            if func.__name__ == "temp_state_glm_func":
-                chains = np.zeros((self.m, self.n, 4))
-            else:
-                chains = np.zeros((self.m, self.n))
-            if model_for_loop:
-                for j, result in enumerate(self.results):
-                    for i in range(self.n):
-                        chains[j, i] = func(result.models[i])
-            else:
-                for i in range(self.m):
-                    chains[i] = func(self.results[i])
-        else:
-            inds = []
-            for lims in [range(i * self.n, (i + 1) * self.n) for i in range(self.m)]:
-                inds.append([ind - lims[0] for ind in mode_indices if ind in lims])
-            lens = [len(ind) for ind in inds]
-            min_len = min([l for l in lens if l != 0])
-            n_remaining_chains = len([l for l in lens if l != 0])
-            print("{} chains left with a len of {}".format(n_remaining_chains, min_len))
-            # print(lens)
-            chains = np.zeros((n_remaining_chains, min_len))
-            print(func.__name__)
-            if func.__name__ == "temp_state_glm_func":
-                chains = np.zeros((n_remaining_chains, min_len, 4))
-            else:
-                chains = np.zeros((n_remaining_chains, min_len))
-            counter = -1
-            for i, ind in enumerate(inds):
-                if len(ind) == 0:
-                    continue
-                counter += 1
-                step = len(ind) // min_len
-                up_to = min_len * step
-                # print([j + i * 640 for j in ind[:up_to:step]])
-                chains[counter] = func(self.results[i], ind[:up_to:step])
-            return chains
-
     def r_hat_and_ess(self, func, model_for_loop, rank_norm=True, mode_indices=None):
-        # Following Gelman page 284
+        """Compute all kinds of R^hat's and the effective sample size with the intermediate steps
+        Following Gelman page 284f"""
 
         if mode_indices is None:
-            chains = np.zeros((self.m, self.n))
 
             if model_for_loop:
                 for j, result in enumerate(self.results):
@@ -130,17 +89,11 @@ class MCMC_result_list:
             inds = []
             for lims in [range(i * self.n, (i + 1) * self.n) for i in range(self.m)]:
                 inds.append([ind - lims[0] for ind in mode_indices if ind in lims])
-            # inds = [[] for i in range(self.m)]
-            # for i, m in enumerate([item for sublist in self.results for item in sublist.models]):
-            #     if i not in mode_indices:
-            #         continue
-            #     inds[i // self.n].append(i % self.n)
             lens = [len(ind) for ind in inds]
-            min_len = min([l for l in lens if l != 0])
-            n_remaining_chains = len([l for l in lens if l != 0])
+            min_len = min([li for li in lens if li != 0])
+            n_remaining_chains = len([li for li in lens if li != 0])
             self.m, self.n = n_remaining_chains, min_len
             print("{} chains left with a len of {}".format(n_remaining_chains, min_len))
-            # print(lens)
             chains = np.zeros((n_remaining_chains, min_len))
             counter = -1
             for i, ind in enumerate(inds):
@@ -152,59 +105,54 @@ class MCMC_result_list:
                 chains[counter] = func(self.results[i], ind[:up_to:step])
 
         self.chains = chains
-
         self.rank_normalised, self.folded_rank_normalised, self.ranked, self.folded_ranked = rank_inv_normal_transform(self.chains)
+        # Compute all R^hats, use the worst
         self.lame_r_hat, self.lame_var_hat_plus = r_hat_array_comp(self.chains)
-        self.rank_normalised_r_hat, self.var_hat_plus = r_hat_array_comp(self.rank_normalised)
+        self.rank_normalised_r_hat, self.rank_normed_var_hat_plus = r_hat_array_comp(self.rank_normalised)
         self.folded_rank_normalised_r_hat, self.folded_rank_normalised_var_hat_plus = r_hat_array_comp(self.folded_rank_normalised)
         self.r_hat = max(self.lame_r_hat, self.rank_normalised_r_hat, self.folded_rank_normalised_r_hat)
         print("r_hat is {:.4f} (max of normal ({:.4f}), rank normed ({:.4f}), folded rank normed ({:.4f}))".format(self.r_hat, self.lame_r_hat, self.rank_normalised_r_hat, self.folded_rank_normalised_r_hat))
 
         t = 1
         rhos = []
+        # use chains and var_hat_plus as desired to compute effective sample size
         local_chains = self.rank_normalised if rank_norm else self.chains
+        local_var_hat_plus = self.rank_normed_var_hat_plus if rank_norm else self.lame_var_hat_plus
+
+        # Estimate sample auto-correlation (could be done with Fourier, but Gelman doesn't elaborate)
         while True:
             V_t = np.sum((local_chains[:, t:] - local_chains[:, :-t]) ** 2)
-            rho_t = 1 - (V_t / (self.m * (self.n - t))) / (2 * self.var_hat_plus)
+            rho_t = 1 - (V_t / (self.m * (self.n - t))) / (2 * local_var_hat_plus)
             t += 1
             rhos.append(rho_t)
             if (t > 2 and t % 2 == 1 and rhos[-1] + rhos[-2] < 0) or t == self.n:
                 break
 
         self.n_eff = self.m * self.n / (1 + 2 * sum(rhos[:-2]))
         print("Effective number of samples is {}".format(self.n_eff))
-        return chains
 
-    def rank_histos(self):
-        count_max = 0
-        for i in range(self.m):
-            counts, _ = np.histogram(self.ranked[i])
-            count_max = max(count_max, counts.max())
-        for i in range(self.m):
-            plt.subplot(self.m / 2, 2, i+1)
-            plt.hist(self.ranked[i])
-            plt.ylim(top=count_max)
-        plt.show()
-
-    def folded_rank_histos(self):
-        count_max = 0
-        for i in range(self.m):
-            counts, _ = np.histogram(self.ranked[i])
-            count_max = max(count_max, counts.max())
-        for i in range(self.m):
-            plt.subplot(self.m / 2, 2, i+1)
-            plt.hist(self.folded_ranked[i])
-            plt.ylim(top=count_max)
-        plt.show()
+        return chains
 
-    def histos(self):
+    def histos(self, type='normal'):
+        """Plot histograms of different ways to represent the features
+        TODO: It would also be interesting to look at all mode_indices, not just the ones left after reduction
+        Using 'ranked' gives Gelman's beloved rank histograms
+        Need to use r_hat_and_ess before this to initialise the chain variables with a feature
+        """
+        if type == 'normal':
+            local_chain = self.chains
+        elif type == 'folded_rank':
+            local_chain = self.folded_ranked
+        elif type == 'ranked':
+            local_chain = self.ranked
         count_max = 0
         for i in range(self.m):
-            counts, _ = np.histogram(self.ranked[i])
+            counts, _ = np.histogram(local_chain[i])
             count_max = max(count_max, counts.max())
+        _, bins = np.histogram(local_chain)
         for i in range(self.m):
-            plt.subplot(self.m / 2, 2, i+1)
-            plt.hist(self.chains[i])
+            plt.subplot(self.m // 2, 2, i+1)
+            plt.hist(local_chain[i], bins=bins)
             plt.ylim(top=count_max)
         plt.show()
 
@@ -1161,14 +1109,13 @@ if __name__ == "__main__":
     # test.r_hat_and_ess(return_ascending_shuffled, False)
     # quit()
 
-    check_r_hats = True
+    check_r_hats = False
     if check_r_hats:
         subjects = list(loading_info.keys())
         subjects = ['KS014']
         for subject in subjects:
             # if subject.startswith('GLM'):
             #     continue
-            print()
             print("_________________________________")
             print(subject)
             fit_num = loading_info[subject]['fit_nums'][-1]
diff --git a/dyn_glm_chain_analysis_unused_funcs.py b/dyn_glm_chain_analysis_unused_funcs.py
index abd755e3..ac1a9df7 100644
--- a/dyn_glm_chain_analysis_unused_funcs.py
+++ b/dyn_glm_chain_analysis_unused_funcs.py
@@ -36,6 +36,46 @@ class MCMC_result_list:
         plt.plot(res.assign_counts.max(axis=1))
         plt.show()
 
+    def return_chains(self, func, model_for_loop, rank_norm=True, mode_indices=None):
+        # Following Gelman page 284
+
+        if mode_indices is None:
+            if func.__name__ == "temp_state_glm_func":
+                chains = np.zeros((self.m, self.n, 4))
+            else:
+                chains = np.zeros((self.m, self.n))
+            if model_for_loop:
+                for j, result in enumerate(self.results):
+                    for i in range(self.n):
+                        chains[j, i] = func(result.models[i])
+            else:
+                for i in range(self.m):
+                    chains[i] = func(self.results[i])
+        else:
+            # this should probably be factored out into a function; the same logic exists elsewhere
+            inds = []
+            for lims in [range(i * self.n, (i + 1) * self.n) for i in range(self.m)]:
+                inds.append([ind - lims[0] for ind in mode_indices if ind in lims])
+            lens = [len(ind) for ind in inds]
+            min_len = min([l for l in lens if l != 0])
+            n_remaining_chains = len([l for l in lens if l != 0])
+            print("{} chains left with a len of {}".format(n_remaining_chains, min_len))
+            chains = np.zeros((n_remaining_chains, min_len))
+            print(func.__name__)
+            if func.__name__ == "temp_state_glm_func":
+                chains = np.zeros((n_remaining_chains, min_len, 4))
+            else:
+                chains = np.zeros((n_remaining_chains, min_len))
+            counter = -1
+            for i, ind in enumerate(inds):
+                if len(ind) == 0:
+                    continue
+                counter += 1
+                step = len(ind) // min_len
+                up_to = min_len * step
+                chains[counter] = func(self.results[i], ind[:up_to:step])
+            return chains
+
 
 
 class MCMC_result:
diff --git a/mcmc_chain_analysis.py b/mcmc_chain_analysis.py
index c8de1259..7dc376b6 100644
--- a/mcmc_chain_analysis.py
+++ b/mcmc_chain_analysis.py
@@ -39,7 +39,8 @@ def ll_func(x):
     return x.sample_lls[-x.n_samples:]
 
 def r_hat_array_comp(chains):
-    """Computes R^hat on an array of features, following Gelman p. 284f"""
+    """Computes R^hat on an array of features, following Gelman p. 284f
+    Returns R^hat itself and var^hat^plus, which is needed for the effective sample size calculation"""
     m, n = chains.shape  # number of chains, length of chains
     psi_dot_j = np.mean(chains, axis=1)
     psi_dot_dot = np.mean(psi_dot_j)
@@ -108,6 +109,7 @@ def eval_simple_r_hat(chains):
 
 
 def comp_multi_r_hat(chains, rank_normalised, folded_rank_normalised):
+    """Compute the full set of R^hat's, given the appropriately transformed chains."""
    lame_r_hat, _ = r_hat_array_comp(chains)
     rank_normalised_r_hat, _ = r_hat_array_comp(rank_normalised)
     folded_rank_normalised_r_hat, _ = r_hat_array_comp(folded_rank_normalised)
--
GitLab
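Note, not part of the patch: the docstrings above reference Gelman's R^hat and var^hat^plus (p. 284f). Below is a minimal, self-contained sketch of that computation for reference only; the function and variable names are illustrative assumptions and do not claim to reproduce the repository's r_hat_array_comp line for line.

import numpy as np

def r_hat_sketch(chains):
    """Basic R^hat for an (m, n) array: m chains, n samples per chain (illustrative only)."""
    m, n = chains.shape
    chain_means = chains.mean(axis=1)
    # between-chain variance B and mean within-chain variance W
    B = n / (m - 1) * np.sum((chain_means - chain_means.mean()) ** 2)
    W = chains.var(axis=1, ddof=1).mean()
    # var^hat^plus: (n-1)/n * W + B/n, an overestimate of the marginal posterior variance
    var_hat_plus = (n - 1) / n * W + B / n
    # R^hat approaches 1 from above as the chains mix
    return np.sqrt(var_hat_plus / W), var_hat_plus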