From 0f72e7e5ef4d2fd0d00d792c4d0b0913e634034c Mon Sep 17 00:00:00 2001
From: SebastianBruijns <>
Date: Sat, 11 Feb 2023 20:37:50 +0100
Subject: [PATCH] code updates, commenting work, clean up of testing code

---
 .../mcmc_chain_analysis.cpython-37.pyc        | Bin 4481 -> 5313 bytes
 __pycache__/simplex_plot.cpython-37.pyc       | Bin 2949 -> 2989 bytes
 dyn_glm_chain_analysis.py                     | 284 ++++++++++++++----
 dynamic_GLMiHMM_consistency.py                |  11 +-
 dynamic_GLMiHMM_fit.py                        |  12 +-
 mcmc_chain_analysis.py                        |  29 +-
 simplex_plot.py                               |   7 +-
 test_codes/pymc_compare/call.py               |  16 +-
 .../pymc_compare/dynglm_optimisation_test.py  |   2 +-
 test_codes/pymc_compare/gibbs_sample.py       |   4 +-
 10 files changed, 279 insertions(+), 86 deletions(-)

diff --git a/__pycache__/mcmc_chain_analysis.cpython-37.pyc b/__pycache__/mcmc_chain_analysis.cpython-37.pyc
index bdea663ae5d36ec0de32219c100df06016b9244f..9f093ef55cf22048cd989c21378e8d782850879c 100644
Binary files a/__pycache__/mcmc_chain_analysis.cpython-37.pyc and b/__pycache__/mcmc_chain_analysis.cpython-37.pyc differ
diff --git a/__pycache__/simplex_plot.cpython-37.pyc b/__pycache__/simplex_plot.cpython-37.pyc
index 1c26491f573cb6355e901c949147f3d921c87a30..8fa16cdde21c4a7d3d6daa3797043fc2e294e842 100644
Binary files a/__pycache__/simplex_plot.cpython-37.pyc and b/__pycache__/simplex_plot.cpython-37.pyc differ
diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py
index 3f1e31cf..71e06373 100644
--- a/dyn_glm_chain_analysis.py
+++ b/dyn_glm_chain_analysis.py
@@ -574,10 +574,10 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur
 
         noise = np.zeros(len(c_n_a))  # np.random.rand(len(c_n_a)) * 0.4 - 0.2
         mask = c_n_a[:, -1] == 0
-        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='b', ms=ms, label='Leftward')
+        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='b', ms=ms, label='Leftward', alpha=0.6)
         mask = c_n_a[:, -1] == 1
-        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='r', ms=ms, label='Rightward')
+        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='r', ms=ms, label='Rightward', alpha=0.6)
 
         plt.title("session #{} / {}".format(1+seq_num, test.results[0].n_sessions), size=26)
         # plt.yticks(*self.cont_ticks, size=22-2)
@@ -852,7 +852,7 @@ def state_development_single_sample(test, indices, save=True, save_append='', sh
     return states_by_session, all_pmfs
 
 
-def state_development(test, state_sets, indices, save=True, save_append='', show=True, dpi='figure', separate_pmf=False):
+def state_development(test, state_sets, indices, save=True, save_append='', show=True, dpi='figure', separate_pmf=False, type_coloring=True):
     # Now also returns the durations of the state types and a state type summary array
 
     state_sets = [np.array(s) for s in state_sets]
@@ -1038,7 +1038,8 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     introductions_by_stage = np.zeros(3)
     covered_states = []
     for i, d in enumerate(durs):
-        ax0.fill_between(range(dur_counter, 1 + dur_counter + d), 0.5, -0.5, color=type_colours[i], zorder=0, alpha=0.3)
+        if type_coloring:
+            ax0.fill_between(range(dur_counter, 1 + dur_counter + d), 0.5, -0.5, color=type_colours[i], zorder=0, alpha=0.3)
         dur_counter += d
 
     # find out during which state type which contrast was introduced
@@ -1098,7 +1099,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     else:
         plt.close()
 
-    return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs))
+    return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage
 
 
 def smart_divide(a, b):
@@ -1154,14 +1155,13 @@ if __name__ == "__main__":
 
     r_hats = []
 
+    # R^hat tests
     # test = MCMC_result_list([fake_result(100) for i in range(8)])
    # test.r_hat_and_ess(return_ascending, False)
    # test.r_hat_and_ess(return_ascending_shuffled, False)
    # quit()
 
-    good = []
-    bad = []
-    check_r_hats = False
+    check_r_hats = True
     if check_r_hats:
         subjects = list(loading_info.keys())
         subjects = ['KS014']
@@ -1188,11 +1188,6 @@ if __name__ == "__main__":
             r_hats.append((subject, final_r_hat))
             loading_info[subject]['ignore'] = sol
 
-            if final_r_hat < 1.05:
-                good.append(subject)
-            else:
-                bad.append(subject)
-
         print(r_hats)
         if fit_type == 'bias':
             json.dump(loading_info, open("canonical_infos_bias.json", 'w'))
@@ -1221,10 +1216,15 @@ def state_type_durs(states, pmfs):
             pmf_counter += 1
             state_types[pmf_type(pmf[1][pmf_counter][pmf[0]]), i] += s[i]  # indexing horror
 
-    durs = (np.where(state_types[1] > 0.5)[0][0],
-            np.where(state_types[2] > 0.5)[0][0] - np.where(state_types[1] > 0.5)[0][0],
-            states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
-    if np.where(state_types[2] > 0.5)[0][0] < np.where(state_types[1] > 0.5)[0][0]:
+    if np.any(state_types[1] > 0.5):
+        durs = (np.where(state_types[1] > 0.5)[0][0],
+                np.where(state_types[2] > 0.5)[0][0] - np.where(state_types[1] > 0.5)[0][0],
+                states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
+        if np.where(state_types[2] > 0.5)[0][0] < np.where(state_types[1] > 0.5)[0][0]:
+            durs = (np.where(state_types[2] > 0.5)[0][0],
+                    0,
+                    states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
+    else:
         durs = (np.where(state_types[2] > 0.5)[0][0],
                 0,
                 states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
@@ -1251,6 +1251,24 @@ def state_cluster_interpolation(states, pmfs):
     return state_types, state_trans, pmf_examples
 
 
+def get_first_pmfs(states, pmfs):
+    # get the first PMF of every type, the contrast mask on which it is defined, and whether it is the first PMF of its state
+    earliest_sessions = [1000, 1000, 1000]
+    first_pmfs = [0, 0, 0, 0, 0, 0, 0, 0, 0]
+    changing_pmfs = [[0, 0], [0, 0]]
+    for state, pmf in zip(states, pmfs):
+        sessions = np.where(state)[0]
+        for i, (sess_pmf, sess) in enumerate(zip(pmf[1], sessions)):
+            if earliest_sessions[pmf_type(sess_pmf[pmf[0]])] > sess:
+                earliest_sessions[pmf_type(sess_pmf[pmf[0]])] = sess
+                first_pmfs[3 * pmf_type(sess_pmf[pmf[0]])] = sess_pmf
+                first_pmfs[1 + 3 * pmf_type(sess_pmf[pmf[0]])] = pmf[0]
+                first_pmfs[2 + 3 * pmf_type(sess_pmf[pmf[0]])] = i
+                if i != 0:
+                    changing_pmfs[pmf_type(sess_pmf[pmf[0]]) - 1] = [pmf[0], pmf[1]]
+    return first_pmfs, changing_pmfs
+
+
 def plot_pmf_types(pmf_types, subject, fit_type, save=True, show=False):
     # Plot the different types of PMFs, all split up by their different types
     for i, pmfs in enumerate(pmf_types):
@@ -1267,15 +1285,155 @@ def plot_pmf_types(pmf_types, subject, fit_type, save=True, show=False):
     else:
         plt.close()
 
+
 def pmf_type(pmf):
-    if pmf[-1] - pmf[0] <= 0.15:
+    if pmf[-1] - pmf[0] < 0.2:
         return 0
-    elif pmf[-1] - pmf[0] < 0.6 and np.abs(pmf[0] + pmf[-1] - 1) > 0.1:
+    elif pmf[-1] - pmf[0] < 0.6:  # and np.abs(pmf[0] + pmf[-1] - 1) > 0.1:
         return 1
     else:
        return 2
 
 
+type2color = {0: 'green', 1: 'blue', 2: 'red'}
+
+if False:
+
+    all_changing_pmfs = pickle.load(open("changing_pmfs.p", 'rb'))
+    plt.figure(figsize=(16, 9))
+    for i, pmf in enumerate(all_changing_pmfs):
+        plt.subplot(4, 7, i + 1)
+        for p in pmf[1]:
+            plt.plot(np.where(pmf[0])[0], p[pmf[0]], color=type2color[pmf_type(p)])
+        plt.ylim(0, 1)
+
+        sns.despine()
+        if i + 1 != 22:
+            plt.gca().set_xticks([])
+            plt.gca().set_yticks([])
+        else:
+            plt.xlabel("Contrasts", size=22)
+            plt.ylabel("P(rightwards)", size=22)
+            plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=16)
+            plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=16)
+
+        if i + 1 == 30:
+            break
+
+    plt.tight_layout()
+    plt.savefig("changing pmfs")
+    plt.show()
+    quit()
+
+    type_2_assyms = []
+    tick_size = 14
+    label_size = 26
+    all_first_pmfs = pickle.load(open("pmfs_temp.p", 'rb'))
+    plt.figure(figsize=(16, 9))
+    plt.subplot(1, 3, 1)
+    counter = [[0, 0], [0, 0]]
+    save_title = "all types" if False else "KS014 types"
+    if save_title == "KS014 types":
+        all_first_pmfs = {'KS014': all_first_pmfs['KS014']}
+
+    for key in all_first_pmfs:
+        x = all_first_pmfs[key]
+        if type(x[0]) == int:
+            continue
+        linestyle = '-' if x[2] == 0 else '--'
+        plt.plot(np.where(x[1])[0], x[0][x[1]], linestyle=linestyle, c='g')
+    plt.ylim(0, 1)
+    plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size)
+    plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
+    plt.gca().spines[['right', 'top']].set_visible(False)
+    plt.xlim(0, 10)
+    plt.xticks(rotation=45)
+    plt.gca().set_ylabel("P(rightwards)", size=label_size)
+
+    plt.subplot(1, 3, 2)
+    for key in all_first_pmfs:
+        x = all_first_pmfs[key]
+        if type(x[3]) == int:
+            continue
+        type_2_assyms.append(np.abs(x[3][0] + x[3][-1] - 1))
+        linestyle = '-' if x[5] == 0 else '--'
+        counter[0][0 if x[5] == 0 else 1] += 1
+        if linestyle == '--':
+            continue
+        plt.plot(np.where(x[4])[0], x[3][x[4]], linestyle=linestyle, c='b')
+    plt.gca().set_yticks([])
+    plt.ylim(0, 1)
+    plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size)
+    plt.gca().spines[['right', 'top']].set_visible(False)
+    plt.xticks(rotation=45)
+    plt.xlim(0, 10)
+    plt.gca().set_xlabel("Contrasts", size=label_size)
+
+    plt.subplot(1, 3, 3)
+    for key in all_first_pmfs:
+        x = all_first_pmfs[key]
+        if type(x[6]) == int:
+            continue
+        linestyle = '-' if x[8] == 0 else '--'
+        counter[1][0 if x[8] == 0 else 1] += 1
+        if linestyle == '--':
+            continue
+        plt.plot(np.where(x[7])[0], x[6][x[7]], linestyle=linestyle, c='r')
+    plt.gca().set_yticks([])
+    plt.ylim(0, 1)
+    plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size)
+    plt.gca().spines[['right', 'top']].set_visible(False)
+    plt.xlim(0, 10)
+    plt.xticks(rotation=45)
+
+    print(counter)
+    plt.tight_layout()
+    plt.savefig(save_title)
+    plt.show()
+    if save_title == "KS014 types":
+        quit()
+
+    counter = 0
+    fig, ax = plt.subplots(1, 3, figsize=(16, 9))
+    for key in all_first_pmfs:
+        x = all_first_pmfs[key]
+        if type(x[3]) == int:
+            continue
+        linestyle = '-' if x[5] == 0 else '--'
+        if linestyle == '--':
+            continue
+        if np.abs(x[3][0] + x[3][-1] - 1) <= 0.1:
+            counter += 1
+            use_ax = 2
+        else:
+            use_ax = int(x[3][0] > 1 - x[3][-1])
+
+        ax[use_ax].plot(np.where(x[4])[0], x[3][x[4]], linestyle=linestyle, c='b')
+    ax[0].set_ylim(0, 1)
+    ax[0].set_xlim(0, 10)
+    ax[0].spines[['right', 'top']].set_visible(False)
+    ax[0].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
+    ax[0].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    ax[0].set_ylabel("P(rightwards)", size=label_size)
+
+    ax[1].set_ylim(0, 1)
+    ax[1].set_xlim(0, 10)
+    ax[1].set_yticks([])
+    ax[1].spines[['right', 'top']].set_visible(False)
+    ax[1].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    ax[1].set_xlabel("Contrasts", size=label_size)
+
+    ax[2].set_ylim(0, 1)
+    ax[2].set_xlim(0, 10)
+    ax[2].set_yticks([])
+    ax[2].spines[['right', 'top']].set_visible(False)
+    ax[2].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    print(counter)
+    plt.tight_layout()
+    plt.savefig("differentiate type 2")
+    plt.show()
+    quit()
+
== "__main__": # visualise pmf types @@ -1301,40 +1459,11 @@ if __name__ == "__main__": loading_info = json.load(open("canonical_infos.json", 'r')) r_hats = json.load(open("canonical_info_r_hats.json", 'r')) no_good_pcas = ['NYU-06', 'SWC_023'] # no good rhat: 'ibl_witten_13' - subjects = list(loading_info.keys()) + subjects = ['KS014'] # list(loading_info.keys()) print(subjects) fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0] dur = 'yes' - fist_good_pmf = {'CSHL051': (6, 'left'), - 'CSHL059': (3, 'left'), - 'CSHL061': (4, 'right'), - 'CSHL062': (1, 'left'), - 'CSHL_007': (2, 'skip'), - 'CSHL_014': (4, 'left'), - 'CSHL_015': (2, 'right'), - 'CSHL_018': (3, 'skip'), - 'CSHL_020': (5, 'left'), # tough case - 'CSH_ZAD_001': (4, 'right'), - 'CSH_ZAD_011': (0, 'right'), - 'CSH_ZAD_022': (4, 'right'), - 'CSH_ZAD_025': (2, 'skip'), # eieiei - 'CSH_ZAD_026': (6, 'skip'), # gradual - 'ibl_witten_14': (3, 'left'), - 'ibl_witten_16': (3, 'right'), - # 'ibl_witten_18': (3, 'very weird'), # probably shouldn't be analysed - 'ibl_witten_19': (4, 'right'), - 'KS014': (4, 'right'), - 'KS015': (4, 'left'), - 'KS016': (3, 'skip'), # non-trivial - 'KS017': (3, 'right'), - 'KS021': (5, 'left'), - 'KS022': (4, 'right'), - 'KS023': (3, 'left'), - 'NYU-06': (9, 'right'), - 'SWC_023': (9, 'skip'), # gradual - 'ZM_1897': (5, 'right'), - 'ZM_3003': (1, 'skip')} # fig, ax = plt.subplots(1, 3, sharey=True, figsize=(16, 9)) thinning = 25 @@ -1356,6 +1485,12 @@ if __name__ == "__main__": not_yet = True abs_state_durs = [] + all_first_pmfs = {} + all_pmf_diffs = [] + all_pmf_asymms = [] + all_pmfs = [] + all_changing_pmfs = [] + all_intros = [] for subject in subjects: @@ -1366,30 +1501,44 @@ if __name__ == "__main__": results = [] try: - - print('loading canonical') test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb')) print('loaded canoncial result') mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb')) + quit() state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb')) # lapse differential # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices) # training overview - states, pmfs, durs, _, contrast_intro_type, intros_by_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=1, separate_pmf=True) + states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True) + all_intros.append(undiv_intros) intros_by_type_sum += intros_by_type - continue + first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs) + for pmf in changing_pmfs: + if type(pmf[0]) == int: + continue + all_changing_pmfs.append(pmf) + all_first_pmfs[subject] = first_pmfs + for pmf in pmfs: + for p in pmf[1]: + all_pmf_diffs.append(p[-1] - p[0]) + all_pmf_asymms.append(np.abs(p[0] + p[-1] - 1)) + all_pmfs.append(p) contrast_intro_types.append(contrast_intro_type) # state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False) # session overview - # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) - # consistencies /= consistencies[0, 0] - # contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies) + consistencies = 
pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) + consistencies /= consistencies[0, 0] + contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies) # duration of different state types (and also percentage of type activities) abs_state_durs.append(durs) + simplex_durs = np.array(durs).reshape(1, 3) + print(simplex_durs / np.sum(simplex_durs)) + from simplex_plot import projectSimplex + print(projectSimplex(simplex_durs / simplex_durs.sum(1)[:, None])) continue # compute state type proportions and split the pmfs accordingly @@ -1567,6 +1716,7 @@ if __name__ == "__main__": plt.show() except FileNotFoundError as e: + continue print(e) r_hat = 1.5 for r in r_hats: @@ -1630,11 +1780,33 @@ if __name__ == "__main__": # state_appear.append(b[a == i][0] / (test.results[0].n_sessions - 1)) # state_dur.append(b[a == i].shape[0]) + # pickle.dump(all_first_pmfs, open("pmfs_temp.p", 'wb')) + # pickle.dump(all_changing_pmfs, open("changing_pmfs.p", 'wb')) + # + # a = [x for x, y in zip(all_pmf_asymms, all_pmf_diffs) if y >= 0.2] + # b = [y for x, y in zip(all_pmf_asymms, all_pmf_diffs) if y >= 0.2] + # plt.hist2d(a, b, bins=40) + # plt.show() + + if True: abs_state_durs = np.array(abs_state_durs) # pickle.dump(abs_state_durs, open("multi_chain_saves/abs_state_durs.p", 'wb')) abs_state_durs = pickle.load(open("multi_chain_saves/abs_state_durs.p", 'rb')) + print("Correlations") + from scipy.stats import pearsonr + print(pearsonr(abs_state_durs[:, 0], abs_state_durs[:, 1])) + print(pearsonr(abs_state_durs[:, 2], abs_state_durs[:, 1])) + print(pearsonr(abs_state_durs[:, 0], abs_state_durs[:, 2])) + + print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 0])) + # (0.7338297529946006, 2.6332570579118393e-06) + print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 1])) + # (0.35094585023228597, 0.052897046343413114) + print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 2])) + # (0.7210260323745921, 4.747833912452452e-06) + from simplex_plot import plotSimplex plotSimplex(np.array(abs_state_durs), c='k', show=True) diff --git a/dynamic_GLMiHMM_consistency.py b/dynamic_GLMiHMM_consistency.py index a8b2cf93..8bf1924c 100644 --- a/dynamic_GLMiHMM_consistency.py +++ b/dynamic_GLMiHMM_consistency.py @@ -42,10 +42,10 @@ for subject in subjects: from_session = info_dict['bias_start'] if fit_type == 'bias' else 0 models = [] - n_inputs = 5 + n_regressors = 5 T = till_session - from_session + (fit_type != 'prebias') - obs_hypparams = {'n_inputs': n_inputs, 'T': T, 'prior_mean': np.zeros(n_inputs), - 'P_0': 2 * np.eye(n_inputs), 'Q': fit_variance * np.tile(np.eye(n_inputs), (T, 1, 1))} + obs_hypparams = {'n_regressors': n_regressors, 'T': T, 'prior_mean': np.zeros(n_regressors), 'jumplimit': 3, + 'P_0': 2 * np.eye(n_regressors), 'Q': fit_variance * np.tile(np.eye(n_regressors), (T, 1, 1))} dur_hypparams = dict(r_support=np.array([1, 2, 3, 5, 7, 10, 15, 21, 28, 36, 45, 55, 150]), r_probs=np.ones(13)/13., alpha_0=1, beta_0=1) @@ -78,7 +78,7 @@ for subject in subjects: bad_trials = data[:, 1] == 1 bad_trials[0] = True - mega_data = np.empty((np.sum(~bad_trials), n_inputs + 1)) + mega_data = np.empty((np.sum(~bad_trials), n_regressors + 1)) mega_data[:, 0] = np.maximum(data[~bad_trials, 0], 0) mega_data[:, 1] = np.abs(np.minimum(data[~bad_trials, 0], 0)) @@ -123,7 +123,10 @@ for subject in subjects: prev_res = pickle.load(open(save_title, 'rb')) +counter = 0 for p, m in zip(prev_res, models): + 
diff --git a/dynamic_GLMiHMM_consistency.py b/dynamic_GLMiHMM_consistency.py
index a8b2cf93..8bf1924c 100644
--- a/dynamic_GLMiHMM_consistency.py
+++ b/dynamic_GLMiHMM_consistency.py
@@ -42,10 +42,10 @@ for subject in subjects:
     from_session = info_dict['bias_start'] if fit_type == 'bias' else 0
 
     models = []
-    n_inputs = 5
+    n_regressors = 5
     T = till_session - from_session + (fit_type != 'prebias')
-    obs_hypparams = {'n_inputs': n_inputs, 'T': T, 'prior_mean': np.zeros(n_inputs),
-                     'P_0': 2 * np.eye(n_inputs), 'Q': fit_variance * np.tile(np.eye(n_inputs), (T, 1, 1))}
+    obs_hypparams = {'n_regressors': n_regressors, 'T': T, 'prior_mean': np.zeros(n_regressors), 'jumplimit': 3,
+                     'P_0': 2 * np.eye(n_regressors), 'Q': fit_variance * np.tile(np.eye(n_regressors), (T, 1, 1))}
 
     dur_hypparams = dict(r_support=np.array([1, 2, 3, 5, 7, 10, 15, 21, 28, 36, 45, 55, 150]),
                          r_probs=np.ones(13)/13., alpha_0=1, beta_0=1)
@@ -78,7 +78,7 @@ for subject in subjects:
         bad_trials = data[:, 1] == 1
         bad_trials[0] = True
 
-        mega_data = np.empty((np.sum(~bad_trials), n_inputs + 1))
+        mega_data = np.empty((np.sum(~bad_trials), n_regressors + 1))
 
         mega_data[:, 0] = np.maximum(data[~bad_trials, 0], 0)
         mega_data[:, 1] = np.abs(np.minimum(data[~bad_trials, 0], 0))
@@ -123,7 +123,10 @@ for subject in subjects:
 
 prev_res = pickle.load(open(save_title, 'rb'))
 
+counter = 0
 for p, m in zip(prev_res, models):
+    print(counter)
+    counter += 1
     for od, nd in zip(p.obs_distns, m.obs_distns):
         assert np.allclose(od.weights, nd.weights)
 prof._prof.print_stats()
diff --git a/dynamic_GLMiHMM_fit.py b/dynamic_GLMiHMM_fit.py
index b05ad929..dbbe447a 100644
--- a/dynamic_GLMiHMM_fit.py
+++ b/dynamic_GLMiHMM_fit.py
@@ -204,14 +204,14 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
 
     models = []
     if params['obs_dur'] == 'glm':
-        n_inputs = len(params['regressors'])
+        n_regressors = len(params['regressors'])
         T = till_session - from_session + (params['fit_type'] != 'prebias')
-        obs_hypparams = {'n_inputs': n_inputs, 'T': T, 'jumplimit': params['jumplimit'], 'prior_mean': params['init_mean'],
-                         'P_0': params['init_var'] * np.eye(n_inputs), 'Q': params['fit_variance'] * np.tile(np.eye(n_inputs), (T, 1, 1))}
+        obs_hypparams = {'n_regressors': n_regressors, 'T': T, 'jumplimit': params['jumplimit'], 'prior_mean': params['init_mean'],
+                         'P_0': params['init_var'] * np.eye(n_regressors), 'Q': params['fit_variance'] * np.tile(np.eye(n_regressors), (T, 1, 1))}
         obs_distns = [distributions.Dynamic_GLM(**obs_hypparams) for state in range(params['n_states'])]
     else:
-        n_inputs = 9 if params['fit_type'] == 'bias' else 11
-        obs_hypparams = {'n_inputs': n_inputs * (1 + (params['conditioned_on'] != 'nothing')), 'n_outputs': 2, 'T': till_session - from_session + (params['fit_type'] != 'prebias'),
+        n_regressors = 9 if params['fit_type'] == 'bias' else 11
+        obs_hypparams = {'n_regressors': n_regressors * (1 + (params['conditioned_on'] != 'nothing')), 'n_outputs': 2, 'T': till_session - from_session + (params['fit_type'] != 'prebias'),
                          'jumplimit': params['jumplimit'], 'sigmasq_states': params['fit_variance']}
         obs_distns = [Dynamic_Input_Categorical(**obs_hypparams) for state in range(params['n_states'])]
 
@@ -280,7 +280,7 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
         mask[0] = False
         if params['fit_type'] == 'zoe_style':
             mask[90:] = False
-        mega_data = np.empty((np.sum(mask), n_inputs + 1))
+        mega_data = np.empty((np.sum(mask), n_regressors + 1))
 
         for i, reg in enumerate(params['regressors']):
             # positive numbers are contrast on the right
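The two diffs above rename n_inputs to n_regressors throughout, and the consistency script now passes an explicit jumplimit. For orientation, a sketch of the observation-hyperparameter dictionary these scripts assemble; the concrete numbers are invented, and reading jumplimit as a cap on how far weights may jump across sessions is an assumption from context, not confirmed by the source:

import numpy as np

n_regressors = 5       # e.g. right contrast, left contrast, history terms, bias (assumed)
T = 12                 # number of sessions covered by the fit (invented)
fit_variance = 0.03    # session-to-session drift variance of the GLM weights

obs_hypparams = {
    'n_regressors': n_regressors,
    'T': T,
    'jumplimit': 3,                        # assumption: limit on cross-session jumps
    'prior_mean': np.zeros(n_regressors),  # prior mean of the weights
    'P_0': 2 * np.eye(n_regressors),       # prior covariance at the first session
    'Q': fit_variance * np.tile(np.eye(n_regressors), (T, 1, 1)),  # per-session drift covariances
}
# obs_distns = [distributions.Dynamic_GLM(**obs_hypparams) for _ in range(n_states)]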
284f""" + m, n = chains.shape # number of chains, length of chains psi_dot_j = np.mean(chains, axis=1) psi_dot_dot = np.mean(psi_dot_j) B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot) ** 2) s_j_squared = np.sum((chains - psi_dot_j[:, None]) ** 2, axis=1) / (n - 1) W = np.mean(s_j_squared) var_hat_plus = (n - 1) / n * W + B / n - if W == 0: + if W == 0: # sometimes a feature has 0 variance # print("all the same value") return 1, 0 r_hat = np.sqrt(var_hat_plus / W) @@ -50,6 +55,7 @@ def r_hat_array_comp(chains): def eval_amortized_r_hat(chains, psi_dot_j, s_j_squared, m, n): + """Unused version in which some things were computed ahead of function to save time.""" psi_dot_dot = np.mean(psi_dot_j, axis=1) B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot[:, None]) ** 2, axis=1) W = np.mean(s_j_squared, axis=1) @@ -59,6 +65,7 @@ def eval_amortized_r_hat(chains, psi_dot_j, s_j_squared, m, n): def r_hat_array_comp_mult(chains): + """Compute R^hat of multiple features at once.""" _, m, n = chains.shape psi_dot_j = np.mean(chains, axis=2) psi_dot_dot = np.mean(psi_dot_j, axis=1) @@ -71,8 +78,8 @@ def r_hat_array_comp_mult(chains): def rank_inv_normal_transform(chains): - # Gelman paper Rank-normalization, folding, and localization: An improved R_hat for assessing convergence of MCMC - # ranking with average rank for ties + """Gelman paper Rank-normalization, folding, and localization: An improved R_hat for assessing convergence of MCMC + ranking with average rank for ties""" folded_chains = np.abs(chains - np.median(chains)) ranked = rankdata(chains).reshape(chains.shape) folded_ranked = rankdata(folded_chains).reshape(folded_chains.shape) @@ -83,6 +90,8 @@ def rank_inv_normal_transform(chains): def eval_r_hat(chains): + """Compute entire set of R^hat's for list of feature arrays, and return maximum across features. + Computes all R^hat versions, as opposed to eval_simple_r_hat""" r_hats = [] for chain in chains: rank_normalised, folded_rank_normalised, _, _ = rank_inv_normal_transform(chain) @@ -92,6 +101,8 @@ def eval_r_hat(chains): def eval_simple_r_hat(chains): + """Compute just simple R^hat's for list of feature arrays, and return maximum across features. 
diff --git a/simplex_plot.py b/simplex_plot.py
index bdaad788..dd8ba57d 100644
--- a/simplex_plot.py
+++ b/simplex_plot.py
@@ -37,6 +37,7 @@ def plotSimplex(points, fig=None,
     fig.gca().text(0.43, np.sqrt(3) / 2 + 0.025, vertexlabels[2], size=24)
     # Project and draw the actual points
     projected = projectSimplex(points / points.sum(1)[:, None])
+    print(projected)
     P.scatter(projected[:, 0], projected[:, 1], s=points.sum(1) * 3.5, **kwargs)
 
     # plot center with average size
@@ -90,14 +91,14 @@ if __name__ == '__main__':
     labels = ('[0.1  0.1  0.8]',
               '[0.8  0.1  0.1]',
               '[0.5  0.4  0.1]',
+              '[0.17  0.33  0.5]',
               '[0.33  0.34  0.33]')
     testpoints = np.array([[0.1, 0.1, 0.8],
                            [0.8, 0.1, 0.1],
                            [0.5, 0.4, 0.1],
+                           [0.17, 0.33, 0.5],
                            [0.33, 0.34, 0.33]])
     # Define different colors for each label
     c = range(len(labels))
     # Do scatter plot
-    fig = plotSimplex(testpoints, s=25, c='k')
-
-    P.show()
+    fig = plotSimplex(testpoints, c='k', show=1)
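The analysis script places each subject's three state-type durations on a 2-simplex via simplex_plot, as exercised in the diff above. A usage sketch with invented duration counts; it assumes simplex_plot.py is importable from the working directory, and the note on marker size follows the s=points.sum(1) * 3.5 line in the hunk:

import numpy as np
from simplex_plot import plotSimplex, projectSimplex  # assumes simplex_plot.py is on the path

# One row per subject: sessions spent in state types 0, 1 and 2 (invented numbers)
abs_state_durs = np.array([[4.0, 9.0, 17.0],
                           [7.0, 3.0, 21.0]])

proportions = abs_state_durs / abs_state_durs.sum(1)[:, None]
print(projectSimplex(proportions))             # 2D triangle coordinates, one row per subject
plotSimplex(abs_state_durs, c='k', show=True)  # marker size reflects each subject's total sessions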
diff --git a/test_codes/pymc_compare/call.py b/test_codes/pymc_compare/call.py
index aeeedd8a..bc521538 100644
--- a/test_codes/pymc_compare/call.py
+++ b/test_codes/pymc_compare/call.py
@@ -1,14 +1,19 @@
-import pymc, bayes_fit  # load the model file
+"""Perform a pymc sampling of the test data."""
+import pymc, bayes_fit  # load the model file
 import numpy as np
 import pickle
 
+# Data params
+T = 14
+n_inputs = 3
+
+# Sampling params
 n_samples = 400000
 
-R = pymc.MCMC(bayes_fit)  # build the model
-R.sample(n_samples)  # populate and run it
+R = pymc.MCMC(bayes_fit)  # build the model
+R.sample(n_samples)  # populate and run it
 
-T = 14
-n_inputs = 3
+# Extract weights
 weights = np.zeros((T, n_samples, n_inputs))
 for t in range(T):
     try:
@@ -16,4 +21,5 @@ for t in range(T):
     except KeyError:
         weights[t] = R.trace('ws'.format(t))
 
+# Save everything
 pickle.dump(weights, open('pymc_posterior', 'wb'))
diff --git a/test_codes/pymc_compare/dynglm_optimisation_test.py b/test_codes/pymc_compare/dynglm_optimisation_test.py
index 05df6e0c..e17222ea 100644
--- a/test_codes/pymc_compare/dynglm_optimisation_test.py
+++ b/test_codes/pymc_compare/dynglm_optimisation_test.py
@@ -2,7 +2,7 @@
 Need to find out whether loglikelihood is computed correctly.
 Or whether a bug here allows states to invade each other more easily.
 
-We'll test this by maximising the likelihood directly.
+We'll test this by comparing to pymc results.
 """
 import numpy as np
 import pyhsmm.basic.distributions as distributions
diff --git a/test_codes/pymc_compare/gibbs_sample.py b/test_codes/pymc_compare/gibbs_sample.py
index 31ed5919..5b91c1ab 100644
--- a/test_codes/pymc_compare/gibbs_sample.py
+++ b/test_codes/pymc_compare/gibbs_sample.py
@@ -8,7 +8,7 @@ import pyhsmm.basic.distributions as distributions
 from scipy.optimize import minimize
 import pickle
 
-# Data Params
+# Data params
 T = 16
 n_inputs = 3
 step_size = 0.2
@@ -41,5 +41,5 @@ LL_weights = np.zeros((T, n_inputs))
 for t in range(T):
     LL_weights[t] = minimize(lambda w: wrapper(w, t), np.zeros(n_inputs)).x
 
-# save everything
+# Save everything
 pickle.dump((samples, LL_weights), open('gibbs_posterior', 'wb'))
--
GitLab
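For completeness, a sketch of how the two posterior dumps written by call.py ('pymc_posterior') and gibbs_sample.py ('gibbs_posterior') could be compared. Note that, as patched, the scripts use different T (14 vs. 16), so they must be pointed at the same test data before a session-by-session comparison is meaningful; the tuple layout of the Gibbs dump follows the pickle.dump call above, everything else here is assumed:

import pickle
import numpy as np

pymc_weights = pickle.load(open('pymc_posterior', 'rb'))        # shape (T, n_samples, n_inputs)
gibbs_samples, ll_weights = pickle.load(open('gibbs_posterior', 'rb'))

burnin = pymc_weights.shape[1] // 2                             # discard the first half as burn-in
pymc_mean = pymc_weights[:, burnin:].mean(axis=1)               # posterior mean per session

# Compare posterior means against the directly optimised likelihood weights;
# zip truncates to the shorter array if the session counts differ.
for t, (pm, lw) in enumerate(zip(pymc_mean, ll_weights)):
    print(t, pm, lw)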