From 0f72e7e5ef4d2fd0d00d792c4d0b0913e634034c Mon Sep 17 00:00:00 2001
From: SebastianBruijns <>
Date: Sat, 11 Feb 2023 20:37:50 +0100
Subject: [PATCH] Code updates, commenting work, clean-up of testing code

---
 .../mcmc_chain_analysis.cpython-37.pyc        | Bin 4481 -> 5313 bytes
 __pycache__/simplex_plot.cpython-37.pyc       | Bin 2949 -> 2989 bytes
 dyn_glm_chain_analysis.py                     | 284 ++++++++++++++----
 dynamic_GLMiHMM_consistency.py                |  11 +-
 dynamic_GLMiHMM_fit.py                        |  12 +-
 mcmc_chain_analysis.py                        |  29 +-
 simplex_plot.py                               |   7 +-
 test_codes/pymc_compare/call.py               |  16 +-
 .../pymc_compare/dynglm_optimisation_test.py  |   2 +-
 test_codes/pymc_compare/gibbs_sample.py       |   4 +-
 10 files changed, 279 insertions(+), 86 deletions(-)

diff --git a/__pycache__/mcmc_chain_analysis.cpython-37.pyc b/__pycache__/mcmc_chain_analysis.cpython-37.pyc
index bdea663ae5d36ec0de32219c100df06016b9244f..9f093ef55cf22048cd989c21378e8d782850879c 100644
Binary files a/__pycache__/mcmc_chain_analysis.cpython-37.pyc and b/__pycache__/mcmc_chain_analysis.cpython-37.pyc differ

diff --git a/__pycache__/simplex_plot.cpython-37.pyc b/__pycache__/simplex_plot.cpython-37.pyc
index 1c26491f573cb6355e901c949147f3d921c87a30..8fa16cdde21c4a7d3d6daa3797043fc2e294e842 100644
Binary files a/__pycache__/simplex_plot.cpython-37.pyc and b/__pycache__/simplex_plot.cpython-37.pyc differ

diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py
index 3f1e31cf..71e06373 100644
--- a/dyn_glm_chain_analysis.py
+++ b/dyn_glm_chain_analysis.py
@@ -574,10 +574,10 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur
         noise = np.zeros(len(c_n_a))  # np.random.rand(len(c_n_a)) * 0.4 - 0.2
 
         mask = c_n_a[:, -1] == 0
-        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='b', ms=ms, label='Leftward')
+        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='b', ms=ms, label='Leftward', alpha=0.6)
 
         mask = c_n_a[:, -1] == 1
-        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='r', ms=ms, label='Rightward')
+        plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='r', ms=ms, label='Rightward', alpha=0.6)
 
         plt.title("session #{} / {}".format(1+seq_num, test.results[0].n_sessions), size=26)
         # plt.yticks(*self.cont_ticks, size=22-2)
@@ -852,7 +852,7 @@ def state_development_single_sample(test, indices, save=True, save_append='', sh
     return states_by_session, all_pmfs
 
 
-def state_development(test, state_sets, indices, save=True, save_append='', show=True, dpi='figure', separate_pmf=False):
+def state_development(test, state_sets, indices, save=True, save_append='', show=True, dpi='figure', separate_pmf=False, type_coloring=True):
-    # Now also returns durs of state types and state type summary array
+    # Now also returns durs of state types, state type summary array, and raw introduction counts by stage
     state_sets = [np.array(s) for s in state_sets]
 
@@ -1038,7 +1038,8 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     introductions_by_stage = np.zeros(3)
     covered_states = []
     for i, d in enumerate(durs):
-        ax0.fill_between(range(dur_counter, 1 + dur_counter + d), 0.5, -0.5, color=type_colours[i], zorder=0, alpha=0.3)
+        if type_coloring:
+            ax0.fill_between(range(dur_counter, 1 + dur_counter + d), 0.5, -0.5, color=type_colours[i], zorder=0, alpha=0.3)
         dur_counter += d
 
         # find out during which state type which contrast was introduced
@@ -1098,7 +1099,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     else:
         plt.close()
 
-    return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs))
+    return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage
 
 
 def smart_divide(a, b):
@@ -1154,14 +1155,13 @@ if __name__ == "__main__":
 
     r_hats = []
 
+    # R^hat tests
     # test = MCMC_result_list([fake_result(100) for i in range(8)])
     # test.r_hat_and_ess(return_ascending, False)
     # test.r_hat_and_ess(return_ascending_shuffled, False)
     # quit()
-    good = []
-    bad = []
 
-    check_r_hats = False
+    check_r_hats = True
     if check_r_hats:
         subjects = list(loading_info.keys())
         subjects = ['KS014']
@@ -1188,11 +1188,6 @@ if __name__ == "__main__":
             r_hats.append((subject, final_r_hat))
             loading_info[subject]['ignore'] = sol
 
-            if final_r_hat < 1.05:
-                good.append(subject)
-            else:
-                bad.append(subject)
-
         print(r_hats)
         if fit_type == 'bias':
             json.dump(loading_info, open("canonical_infos_bias.json", 'w'))
@@ -1221,10 +1216,15 @@ def state_type_durs(states, pmfs):
                 pmf_counter += 1
                 state_types[pmf_type(pmf[1][pmf_counter][pmf[0]]), i] += s[i]  # indexing horror
 
-    durs = (np.where(state_types[1] > 0.5)[0][0],
-            np.where(state_types[2] > 0.5)[0][0] - np.where(state_types[1] > 0.5)[0][0],
-            states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
-    if np.where(state_types[2] > 0.5)[0][0] < np.where(state_types[1] > 0.5)[0][0]:
+    if np.any(state_types[1] > 0.5):
+        durs = (np.where(state_types[1] > 0.5)[0][0],
+                np.where(state_types[2] > 0.5)[0][0] - np.where(state_types[1] > 0.5)[0][0],
+                states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
+        if np.where(state_types[2] > 0.5)[0][0] < np.where(state_types[1] > 0.5)[0][0]:
+            durs = (np.where(state_types[2] > 0.5)[0][0],
+                    0,
+                    states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
+    else:
         durs = (np.where(state_types[2] > 0.5)[0][0],
                 0,
                 states.shape[1] - np.where(state_types[2] > 0.5)[0][0])
@@ -1251,6 +1251,24 @@ def state_cluster_interpolation(states, pmfs):
     return state_types, state_trans, pmf_examples
 
 
+def get_first_pmfs(states, pmfs):
+    # Get the first PMF of every type: the PMF itself, the contrasts on which it is defined,
+    # and whether it is the first PMF of its state
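+    # return layout (flat lists, convenient for pickling): first_pmfs stores, per PMF type,
+    # the triple [pmf, defined-contrast mask, session index within its state]; changing_pmfs
+    # stores [mask, state's pmf list] for types 1 and 2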
+    earliest_sessions = [1000, 1000, 1000]
+    first_pmfs = [0, 0, 0, 0, 0, 0, 0, 0, 0]
+    changing_pmfs = [[0, 0], [0, 0]]
+    for state, pmf in zip(states, pmfs):
+        sessions = np.where(state)[0]
+        for i, (sess_pmf, sess) in enumerate(zip(pmf[1], sessions)):
+            if earliest_sessions[pmf_type(sess_pmf[pmf[0]])] > sess:
+                earliest_sessions[pmf_type(sess_pmf[pmf[0]])] = sess
+                first_pmfs[3 * pmf_type(sess_pmf[pmf[0]])] = sess_pmf
+                first_pmfs[1 + 3 * pmf_type(sess_pmf[pmf[0]])] = pmf[0]
+                first_pmfs[2 + 3 * pmf_type(sess_pmf[pmf[0]])] = i
+                if i != 0:
+                    changing_pmfs[pmf_type(sess_pmf[pmf[0]]) - 1] = [pmf[0], pmf[1]]
+    return first_pmfs, changing_pmfs
+
+
 def plot_pmf_types(pmf_types, subject, fit_type, save=True, show=False):
     # Plot the different types of PMFs, all split up by their different types
     for i, pmfs in enumerate(pmf_types):
@@ -1267,15 +1285,155 @@ def plot_pmf_types(pmf_types, subject, fit_type, save=True, show=False):
     else:
         plt.close()
 
+
 def pmf_type(pmf):
-    if pmf[-1] - pmf[0] <= 0.15:
+    if pmf[-1] - pmf[0] < 0.2:
         return 0
-    elif pmf[-1] - pmf[0] < 0.6 and np.abs(pmf[0] + pmf[-1] - 1) > 0.1:
+    elif pmf[-1] - pmf[0] < 0.6:  # and np.abs(pmf[0] + pmf[-1] - 1) > 0.1:
         return 1
     else:
         return 2
 
 
+type2color = {0: 'green', 1: 'blue', 2: 'red'}
+
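+# One-off plotting of the first and changing PMFs (kept for reference, disabled):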
+if False:
+
+    all_changing_pmfs = pickle.load(open("changing_pmfs.p", 'rb'))
+    plt.figure(figsize=(16, 9))
+    for i, pmf in enumerate(all_changing_pmfs):
+        plt.subplot(4, 7, i + 1)
+        for p in pmf[1]:
+            plt.plot(np.where(pmf[0])[0], p[pmf[0]], color=type2color[pmf_type(p)])
+        plt.ylim(0, 1)
+
+        sns.despine()
+        if i+1 != 22:
+            plt.gca().set_xticks([])
+            plt.gca().set_yticks([])
+        else:
+            plt.xlabel("Contrasts", size=22)
+            plt.ylabel("P(rightwards)", size=22)
+            plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=16)
+            plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=16)
+
+        if i + 1 == 30:
+            break
+
+    plt.tight_layout()
+    plt.savefig("changing pmfs")
+    plt.show()
+    quit()
+
+    type_2_asymms = []
+    tick_size = 14
+    label_size = 26
+    all_first_pmfs = pickle.load(open("pmfs_temp.p", 'rb'))
+    plt.figure(figsize=(16, 9))
+    plt.subplot(1, 3, 1)
+    counter = [[0, 0], [0, 0]]
+    save_title = "all types" if False else "KS014 types"
+    if save_title == "KS014 types":
+        all_first_pmfs = {'KS014': all_first_pmfs['KS014']}
+
+    for key in all_first_pmfs:
+        x = all_first_pmfs[key]
+        if type(x[0]) == int:
+            continue
+        linestyle = '-' if x[2] == 0 else '--'
+        plt.plot(np.where(x[1])[0], x[0][x[1]], linestyle=linestyle, c='g')
+    plt.ylim(0, 1)
+    plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size)
+    plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
+    plt.gca().spines[['right', 'top']].set_visible(False)
+    plt.xlim(0, 10)
+    plt.xticks(rotation=45)
+    plt.gca().set_ylabel("P(rightwards)", size=label_size)
+
+    plt.subplot(1, 3, 2)
+    for key in all_first_pmfs:
+        x = all_first_pmfs[key]
+        if type(x[3]) == int:
+            continue
+        type_2_asymms.append(np.abs(x[3][0] + x[3][-1] - 1))
+        linestyle = '-' if x[5] == 0 else '--'
+        counter[0][0 if x[5] == 0 else 1] += 1
+        if linestyle == '--':
+            continue
+        plt.plot(np.where(x[4])[0], x[3][x[4]], linestyle=linestyle, c='b')
+    plt.gca().set_yticks([])
+    plt.ylim(0, 1)
+    plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size)
+    plt.gca().spines[['right', 'top']].set_visible(False)
+    plt.xticks(rotation=45)
+    plt.xlim(0, 10)
+    plt.gca().set_xlabel("Contrasts", size=label_size)
+
+    plt.subplot(1, 3, 3)
+    for key in all_first_pmfs:
+        x = all_first_pmfs[key]
+        if type(x[6]) == int:
+            continue
+        linestyle = '-' if x[8] == 0 else '--'
+        counter[1][0 if x[8] == 0 else 1] += 1
+        if linestyle == '--':
+            continue
+        plt.plot(np.where(x[7])[0], x[6][x[7]], linestyle=linestyle, c='r')
+    plt.gca().set_yticks([])
+    plt.ylim(0, 1)
+    plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size)
+    plt.gca().spines[['right', 'top']].set_visible(False)
+    plt.xlim(0, 10)
+    plt.xticks(rotation=45)
+
+    print(counter)
+    plt.tight_layout()
+    plt.savefig(save_title)
+    plt.show()
+    if save_title == "KS014 types":
+        quit()
+
+    counter = 0
+    fig, ax = plt.subplots(1, 3, figsize=(16, 9))
+    for key in all_first_pmfs:
+        x = all_first_pmfs[key]
+        if type(x[3]) == int:
+            continue
+        linestyle = '-' if x[5] == 0 else '--'
+        if linestyle == '--':
+            continue
+        if np.abs(x[3][0] + x[3][-1] - 1) <= 0.1:
+            counter += 1
+            use_ax = 2
+        else:
+            use_ax = int(x[3][0] > 1 - x[3][-1])
+
+        ax[use_ax].plot(np.where(x[4])[0], x[3][x[4]], linestyle=linestyle, c='b')
+    ax[0].set_ylim(0, 1)
+    ax[0].set_xlim(0, 10)
+    ax[0].spines[['right', 'top']].set_visible(False)
+    ax[0].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
+    ax[0].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    ax[0].set_ylabel("P(rightwards)", size=label_size)
+
+    ax[1].set_ylim(0, 1)
+    ax[1].set_xlim(0, 10)
+    ax[1].set_yticks([])
+    ax[1].spines[['right', 'top']].set_visible(False)
+    ax[1].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    ax[1].set_xlabel("Contrasts", size=label_size)
+
+    ax[2].set_ylim(0, 1)
+    ax[2].set_xlim(0, 10)
+    ax[2].set_yticks([])
+    ax[2].spines[['right', 'top']].set_visible(False)
+    ax[2].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    print(counter)
+    plt.tight_layout()
+    plt.savefig("differentiate type 2")
+    plt.show()
+    quit()
+
 if __name__ == "__main__":
 
     # visualise pmf types
@@ -1301,40 +1459,11 @@ if __name__ == "__main__":
         loading_info = json.load(open("canonical_infos.json", 'r'))
         r_hats = json.load(open("canonical_info_r_hats.json", 'r'))
     no_good_pcas = ['NYU-06', 'SWC_023']  # no good rhat: 'ibl_witten_13'
-    subjects = list(loading_info.keys())
+    subjects = ['KS014']  # list(loading_info.keys())
     print(subjects)
     fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0]
     dur = 'yes'
 
-    fist_good_pmf = {'CSHL051': (6, 'left'),
-                     'CSHL059': (3, 'left'),
-                     'CSHL061': (4, 'right'),
-                     'CSHL062': (1, 'left'),
-                     'CSHL_007': (2, 'skip'),
-                     'CSHL_014': (4, 'left'),
-                     'CSHL_015': (2, 'right'),
-                     'CSHL_018': (3, 'skip'),
-                     'CSHL_020': (5, 'left'),  # tough case
-                     'CSH_ZAD_001': (4, 'right'),
-                     'CSH_ZAD_011': (0, 'right'),
-                     'CSH_ZAD_022': (4, 'right'),
-                     'CSH_ZAD_025': (2, 'skip'),  # eieiei
-                     'CSH_ZAD_026': (6, 'skip'),  # gradual
-                     'ibl_witten_14': (3, 'left'),
-                     'ibl_witten_16': (3, 'right'),
-                     # 'ibl_witten_18': (3, 'very weird'),  # probably shouldn't be analysed
-                     'ibl_witten_19': (4, 'right'),
-                     'KS014': (4, 'right'),
-                     'KS015': (4, 'left'),
-                     'KS016': (3, 'skip'),  # non-trivial
-                     'KS017': (3, 'right'),
-                     'KS021': (5, 'left'),
-                     'KS022': (4, 'right'),
-                     'KS023': (3, 'left'),
-                     'NYU-06': (9, 'right'),
-                     'SWC_023': (9, 'skip'),  # gradual
-                     'ZM_1897': (5, 'right'),
-                     'ZM_3003': (1, 'skip')}
     # fig, ax = plt.subplots(1, 3, sharey=True, figsize=(16, 9))
 
     thinning = 25
@@ -1356,6 +1485,12 @@ if __name__ == "__main__":
     not_yet = True
 
     abs_state_durs = []
+    all_first_pmfs = {}
+    all_pmf_diffs = []
+    all_pmf_asymms = []
+    all_pmfs = []
+    all_changing_pmfs = []
+    all_intros = []
 
     for subject in subjects:
 
@@ -1366,30 +1501,44 @@ if __name__ == "__main__":
         results = []
 
         try:
-
-            print('loading canonical')
             test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb'))
             print('loaded canonical result')
 
             mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb'))
             state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb'))
             # lapse differential
             # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices)
 
             # training overview
-            states, pmfs, durs, _, contrast_intro_type, intros_by_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=1, separate_pmf=True)
+            states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True)
+            all_intros.append(undiv_intros)
             intros_by_type_sum += intros_by_type
-            continue
+            first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs)
+            for pmf in changing_pmfs:
+                if type(pmf[0]) == int:
+                    continue
+                all_changing_pmfs.append(pmf)
+            all_first_pmfs[subject] = first_pmfs
+            for pmf in pmfs:
+                for p in pmf[1]:
+                    all_pmf_diffs.append(p[-1] - p[0])
+                    all_pmf_asymms.append(np.abs(p[0] + p[-1] - 1))
+                    all_pmfs.append(p)
             contrast_intro_types.append(contrast_intro_type)
             # state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False)
 
             # session overview
-            # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
-            # consistencies /= consistencies[0, 0]
-            # contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies)
+            consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
+            consistencies /= consistencies[0, 0]
+            contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies)
 
             # duration of different state types (and also percentage of type activities)
             abs_state_durs.append(durs)
+            simplex_durs = np.array(durs).reshape(1, 3)
+            print(simplex_durs / np.sum(simplex_durs))
+            from simplex_plot import projectSimplex
+            print(projectSimplex(simplex_durs / simplex_durs.sum(1)[:, None]))
             continue
 
             # compute state type proportions and split the pmfs accordingly
@@ -1567,6 +1716,7 @@ if __name__ == "__main__":
             plt.show()
 
         except FileNotFoundError as e:
+            continue
             print(e)
             r_hat = 1.5
             for r in r_hats:
@@ -1630,11 +1780,33 @@ if __name__ == "__main__":
         #     state_appear.append(b[a == i][0] / (test.results[0].n_sessions - 1))
         #     state_dur.append(b[a == i].shape[0])
 
+    # pickle.dump(all_first_pmfs, open("pmfs_temp.p", 'wb'))
+    # pickle.dump(all_changing_pmfs, open("changing_pmfs.p", 'wb'))
+    #
+    # a = [x for x, y in zip(all_pmf_asymms, all_pmf_diffs) if y >= 0.2]
+    # b = [y for x, y in zip(all_pmf_asymms, all_pmf_diffs) if y >= 0.2]
+    # plt.hist2d(a, b, bins=40)
+    # plt.show()
+
+
     if True:
         abs_state_durs = np.array(abs_state_durs)
         # pickle.dump(abs_state_durs, open("multi_chain_saves/abs_state_durs.p", 'wb'))
         abs_state_durs = pickle.load(open("multi_chain_saves/abs_state_durs.p", 'rb'))
 
+        print("Correlations")
+        from scipy.stats import pearsonr
+        print(pearsonr(abs_state_durs[:, 0], abs_state_durs[:, 1]))
+        print(pearsonr(abs_state_durs[:, 2], abs_state_durs[:, 1]))
+        print(pearsonr(abs_state_durs[:, 0], abs_state_durs[:, 2]))
+
+        print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 0]))
+        # (0.7338297529946006, 2.6332570579118393e-06)
+        print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 1]))
+        # (0.35094585023228597, 0.052897046343413114)
+        print(pearsonr(abs_state_durs.sum(1), abs_state_durs[:, 2]))
+        # (0.7210260323745921, 4.747833912452452e-06)
+
         from simplex_plot import plotSimplex
 
         plotSimplex(np.array(abs_state_durs), c='k', show=True)
diff --git a/dynamic_GLMiHMM_consistency.py b/dynamic_GLMiHMM_consistency.py
index a8b2cf93..8bf1924c 100644
--- a/dynamic_GLMiHMM_consistency.py
+++ b/dynamic_GLMiHMM_consistency.py
@@ -42,10 +42,10 @@ for subject in subjects:
     from_session = info_dict['bias_start'] if fit_type == 'bias' else 0
 
     models = []
-    n_inputs = 5
+    n_regressors = 5
     T = till_session - from_session + (fit_type != 'prebias')
-    obs_hypparams = {'n_inputs': n_inputs, 'T': T, 'prior_mean': np.zeros(n_inputs),
-                     'P_0': 2 * np.eye(n_inputs), 'Q': fit_variance * np.tile(np.eye(n_inputs), (T, 1, 1))}
+    obs_hypparams = {'n_regressors': n_regressors, 'T': T, 'prior_mean': np.zeros(n_regressors), 'jumplimit': 3,
+                     'P_0': 2 * np.eye(n_regressors), 'Q': fit_variance * np.tile(np.eye(n_regressors), (T, 1, 1))}
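+    # our reading of the hyperparameters: P_0 is the prior covariance of the initial
+    # weights, Q the per-session random-walk covariance of the dynamic GLM weights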
     dur_hypparams = dict(r_support=np.array([1, 2, 3, 5, 7, 10, 15, 21, 28, 36, 45, 55, 150]),
                          r_probs=np.ones(13)/13., alpha_0=1, beta_0=1)
 
@@ -78,7 +78,7 @@ for subject in subjects:
 
         bad_trials = data[:, 1] == 1
         bad_trials[0] = True
-        mega_data = np.empty((np.sum(~bad_trials), n_inputs + 1))
+        mega_data = np.empty((np.sum(~bad_trials), n_regressors + 1))
 
         mega_data[:, 0] = np.maximum(data[~bad_trials, 0], 0)
         mega_data[:, 1] = np.abs(np.minimum(data[~bad_trials, 0], 0))
@@ -123,7 +123,10 @@ for subject in subjects:
 
 prev_res = pickle.load(open(save_title, 'rb'))
 
+counter = 0
 for p, m in zip(prev_res, models):
+    print(counter)
+    counter += 1
     for od, nd in zip(p.obs_distns, m.obs_distns):
         assert np.allclose(od.weights, nd.weights)
 prof._prof.print_stats()
diff --git a/dynamic_GLMiHMM_fit.py b/dynamic_GLMiHMM_fit.py
index b05ad929..dbbe447a 100644
--- a/dynamic_GLMiHMM_fit.py
+++ b/dynamic_GLMiHMM_fit.py
@@ -204,14 +204,14 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
     models = []
 
     if params['obs_dur'] == 'glm':
-        n_inputs = len(params['regressors'])
+        n_regressors = len(params['regressors'])
         T = till_session - from_session + (params['fit_type'] != 'prebias')
-        obs_hypparams = {'n_inputs': n_inputs, 'T': T, 'jumplimit': params['jumplimit'], 'prior_mean': params['init_mean'],
-                         'P_0': params['init_var'] * np.eye(n_inputs), 'Q': params['fit_variance'] * np.tile(np.eye(n_inputs), (T, 1, 1))}
+        obs_hypparams = {'n_regressors': n_regressors, 'T': T, 'jumplimit': params['jumplimit'], 'prior_mean': params['init_mean'],
+                         'P_0': params['init_var'] * np.eye(n_regressors), 'Q': params['fit_variance'] * np.tile(np.eye(n_regressors), (T, 1, 1))}
         obs_distns = [distributions.Dynamic_GLM(**obs_hypparams) for state in range(params['n_states'])]
     else:
-        n_inputs = 9 if params['fit_type'] == 'bias' else 11
-        obs_hypparams = {'n_inputs': n_inputs * (1 + (params['conditioned_on'] != 'nothing')), 'n_outputs': 2, 'T': till_session - from_session + (params['fit_type'] != 'prebias'),
+        n_regressors = 9 if params['fit_type'] == 'bias' else 11
+        obs_hypparams = {'n_regressors': n_regressors * (1 + (params['conditioned_on'] != 'nothing')), 'n_outputs': 2, 'T': till_session - from_session + (params['fit_type'] != 'prebias'),
                          'jumplimit': params['jumplimit'], 'sigmasq_states': params['fit_variance']}
         obs_distns = [Dynamic_Input_Categorical(**obs_hypparams) for state in range(params['n_states'])]
 
@@ -280,7 +280,7 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
             mask[0] = False
             if params['fit_type'] == 'zoe_style':
                 mask[90:] = False
-            mega_data = np.empty((np.sum(mask), n_inputs + 1))
+            mega_data = np.empty((np.sum(mask), n_regressors + 1))
 
             for i, reg in enumerate(params['regressors']):
                 # positive numbers are contrast on the right
diff --git a/mcmc_chain_analysis.py b/mcmc_chain_analysis.py
index 2a99a676..c8de1259 100644
--- a/mcmc_chain_analysis.py
+++ b/mcmc_chain_analysis.py
@@ -8,6 +8,8 @@ import pickle
 
 
 def state_size_helper(n=0, mode_specific=False):
+    """Returns a function that returns the # of trials associated to the nth largest state in a sample
+       can be further specified to only look at specific samples, those of a mode"""
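+    # e.g. state_size_helper(1) gives a function mapping a result to the per-sample
+    # trial count of its second-largest state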
     if not mode_specific:
         def nth_largest_state_func(x):
             return np.partition(x.assign_counts, -1 - n, axis=1)[:, -1 - n]
@@ -18,6 +20,8 @@ def state_size_helper(n=0, mode_specific=False):
 
 
 def state_num_helper(t, mode_specific=False):
+    """Returns a function that returns the # of states which have more trials than a percentage threshold t in a sample
+       can be further specified to only look at specific samples, those of a mode"""
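+    # e.g. state_num_helper(0.05) gives a function counting, per sample, the states
+    # that hold more than 5% of trials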
     if not mode_specific:
         def state_num_func(x): return ((x.assign_counts / x.n_datapoints) > t).sum(1)
     else:
@@ -35,14 +39,15 @@ def ll_func(x): return x.sample_lls[-x.n_samples:]
 
 
 def r_hat_array_comp(chains):
-    m, n = chains.shape
+    """Computes R^hat on an array of features, following Gelman p. 284f"""
+    m, n = chains.shape  # number of chains, length of chains
     psi_dot_j = np.mean(chains, axis=1)
     psi_dot_dot = np.mean(psi_dot_j)
     B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot) ** 2)
     s_j_squared = np.sum((chains - psi_dot_j[:, None]) ** 2, axis=1) / (n - 1)
     W = np.mean(s_j_squared)
     var_hat_plus = (n - 1) / n * W + B / n
-    if W == 0:
+    if W == 0:  # sometimes a feature has 0 variance
         # print("all the same value")
         return 1, 0
     r_hat = np.sqrt(var_hat_plus / W)
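+    # i.e. R^hat = sqrt(((n - 1) / n * W + B / n) / W), with B the between-chain
+    # and W the within-chain variance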
@@ -50,6 +55,7 @@ def r_hat_array_comp(chains):
 
 
 def eval_amortized_r_hat(chains, psi_dot_j, s_j_squared, m, n):
+    """Unused version in which some things were computed ahead of function to save time."""
     psi_dot_dot = np.mean(psi_dot_j, axis=1)
     B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot[:, None]) ** 2, axis=1)
     W = np.mean(s_j_squared, axis=1)
@@ -59,6 +65,7 @@ def eval_amortized_r_hat(chains, psi_dot_j, s_j_squared, m, n):
 
 
 def r_hat_array_comp_mult(chains):
+    """Compute R^hat of multiple features at once."""
     _, m, n = chains.shape
     psi_dot_j = np.mean(chains, axis=2)
     psi_dot_dot = np.mean(psi_dot_j, axis=1)
@@ -71,8 +78,8 @@ def r_hat_array_comp_mult(chains):
 
 
 def rank_inv_normal_transform(chains):
-    # Gelman paper Rank-normalization, folding, and localization: An improved R_hat for assessing convergence of MCMC
-    # ranking with average rank for ties
+    """Gelman paper Rank-normalization, folding, and localization: An improved R_hat for assessing convergence of MCMC
+       ranking with average rank for ties"""
     folded_chains = np.abs(chains - np.median(chains))
     ranked = rankdata(chains).reshape(chains.shape)
     folded_ranked = rankdata(folded_chains).reshape(folded_chains.shape)
@@ -83,6 +90,8 @@ def rank_inv_normal_transform(chains):
 
 
 def eval_r_hat(chains):
+    """Compute entire set of R^hat's for list of feature arrays, and return maximum across features.
+       Computes all R^hat versions, as opposed to eval_simple_r_hat"""
     r_hats = []
     for chain in chains:
         rank_normalised, folded_rank_normalised, _, _ = rank_inv_normal_transform(chain)
@@ -92,6 +101,8 @@ def eval_r_hat(chains):
 
 
 def eval_simple_r_hat(chains):
+    """Compute just simple R^hat's for list of feature arrays, and return maximum across features.
+       Computes only the simple type of R^hat, no folding or rank normalising, making it much faster"""
     r_hats, _ = r_hat_array_comp_mult(chains)
     return max(r_hats)
 
@@ -103,21 +114,21 @@ def comp_multi_r_hat(chains, rank_normalised, folded_rank_normalised):
     return max(lame_r_hat, rank_normalised_r_hat, folded_rank_normalised_r_hat)
 
 
-def sample_statistics(test, mode_indices, subject):
+def sample_statistics(mode_indices, subject, period='prebias'):
     # prints out r_hats and sample sizes for given sample
-    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
+    test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, period), 'rb'))
     test.r_hat_and_ess(state_size_helper(1), False)
     test.r_hat_and_ess(state_size_helper(1, mode_specific=True), False, mode_indices=mode_indices)
     print()
-    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
+    test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, period), 'rb'))
     test.r_hat_and_ess(state_size_helper(), False)
     test.r_hat_and_ess(state_size_helper(mode_specific=True), False, mode_indices=mode_indices)
     print()
-    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
+    test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, period), 'rb'))
     test.r_hat_and_ess(state_num_helper(0.05), False)
     test.r_hat_and_ess(state_num_helper(0.05, mode_specific=True), False, mode_indices=mode_indices)
     print()
-    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
+    test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, period), 'rb'))
     test.r_hat_and_ess(state_num_helper(0.02), False)
     test.r_hat_and_ess(state_num_helper(0.02, mode_specific=True), False, mode_indices=mode_indices)
     print()
diff --git a/simplex_plot.py b/simplex_plot.py
index bdaad788..dd8ba57d 100644
--- a/simplex_plot.py
+++ b/simplex_plot.py
@@ -37,6 +37,7 @@ def plotSimplex(points, fig=None,
     fig.gca().text(0.43, np.sqrt(3) / 2 + 0.025, vertexlabels[2], size=24)
     # Project and draw the actual points
     projected = projectSimplex(points / points.sum(1)[:, None])
+    print(projected)
     P.scatter(projected[:, 0], projected[:, 1], s=points.sum(1) * 3.5, **kwargs)
 
     # plot center with average size
@@ -90,14 +91,14 @@ if __name__ == '__main__':
     labels = ('[0.1  0.1  0.8]',
               '[0.8  0.1  0.1]',
               '[0.5  0.4  0.1]',
+              '[0.17  0.33  0.5]',
               '[0.33  0.34  0.33]')
     testpoints = np.array([[0.1, 0.1, 0.8],
                            [0.8, 0.1, 0.1],
                            [0.5, 0.4, 0.1],
+                           [0.17, 0.33, 0.5],
                            [0.33, 0.34, 0.33]])
     # Define different colors for each label
     c = range(len(labels))
     # Do scatter plot
-    fig = plotSimplex(testpoints, s=25, c='k')
-
-    P.show()
+    fig = plotSimplex(testpoints, c='k', show=1)
diff --git a/test_codes/pymc_compare/call.py b/test_codes/pymc_compare/call.py
index aeeedd8a..bc521538 100644
--- a/test_codes/pymc_compare/call.py
+++ b/test_codes/pymc_compare/call.py
@@ -1,14 +1,19 @@
-import pymc, bayes_fit              # load the model file
+"""Perform a pymc sampling of the test data."""
+import pymc, bayes_fit  # load the model file
 import numpy as np
 import pickle
 
+# Data params
+T = 14
+n_inputs = 3
+
+# Sampling params
 n_samples = 400000
 
-R = pymc.MCMC(bayes_fit)    #  build the model
-R.sample(n_samples)              # populate and run it
+R = pymc.MCMC(bayes_fit)  # build the model
+R.sample(n_samples)  # populate and run it
 
-T = 14
-n_inputs = 3
+# Extract weights
 weights = np.zeros((T, n_samples, n_inputs))
 for t in range(T):
     try:
@@ -16,4 +21,5 @@ for t in range(T):
     except KeyError:
         weights[t] = R.trace('ws')
 
+# Save everything
 pickle.dump(weights, open('pymc_posterior', 'wb'))
diff --git a/test_codes/pymc_compare/dynglm_optimisation_test.py b/test_codes/pymc_compare/dynglm_optimisation_test.py
index 05df6e0c..e17222ea 100644
--- a/test_codes/pymc_compare/dynglm_optimisation_test.py
+++ b/test_codes/pymc_compare/dynglm_optimisation_test.py
@@ -2,7 +2,7 @@
 Need to find out whether loglikelihood is computed correctly.
 
 Or whether a bug here allows states to invade each other more easily.
-We'll test this by maximising the likelihood directly.
+We'll test this by comparing to pymc results.
 """
 import numpy as np
 import pyhsmm.basic.distributions as distributions
diff --git a/test_codes/pymc_compare/gibbs_sample.py b/test_codes/pymc_compare/gibbs_sample.py
index 31ed5919..5b91c1ab 100644
--- a/test_codes/pymc_compare/gibbs_sample.py
+++ b/test_codes/pymc_compare/gibbs_sample.py
@@ -8,7 +8,7 @@ import pyhsmm.basic.distributions as distributions
 from scipy.optimize import minimize
 import pickle
 
-# Data Params
+# Data params
 T = 16
 n_inputs = 3
 step_size = 0.2
@@ -41,5 +41,5 @@ LL_weights = np.zeros((T, n_inputs))
 for t in range(T):
     LL_weights[t] = minimize(lambda w: wrapper(w, t), np.zeros(n_inputs)).x
 
-# save everything
+# Save everything
 pickle.dump((samples, LL_weights), open('gibbs_posterior', 'wb'))
-- 
GitLab