From cb6433077d18f74cb065e890fabe05e95f14670d Mon Sep 17 00:00:00 2001 From: SebastianBruijns <> Date: Thu, 13 Jul 2023 11:56:11 +0200 Subject: [PATCH] code update --- __pycache__/analysis_pmf.cpython-37.pyc | Bin 11418 -> 11748 bytes .../dyn_glm_chain_analysis.cpython-37.pyc | Bin 53918 -> 53764 bytes analysis_pmf.py | 94 +++++++++++++----- analysis_pmf_weights.py | 27 ++++- dyn_glm_chain_analysis.py | 35 ++++--- 5 files changed, 107 insertions(+), 49 deletions(-) diff --git a/__pycache__/analysis_pmf.cpython-37.pyc b/__pycache__/analysis_pmf.cpython-37.pyc index 899a197bee3c2ffcdbb2f048e7d0088d934a594e..f1b5b0e7abdf15bba7ca9ffc05dffe4b341f0893 100644 GIT binary patch delta 1539 zcmaJ>O-y4|6u$ShWuOc#VX%LUOi_N?p)g=@2#SJ}i7_&cB%?tuVCR){T1w07>w9f! z3uDzf>Ozv@b0e#E=`1EDY&3DD#%Q96aoLp%7jE6SoOo`_6vr6b^uBYybIy0p{qE^I ze)IQr*R{SrC!<gH+xzf?dgSu)?~f+F<R2dWHp0I-QTXM@D#b+{xz0A{#6mHGBe>va zv<r$5hBg$TiZN7S!_R6S!-JcUuX0Xy*0Uzp?lqP;`#MCd_*0APJL3Af!}Y}}*91NL zF;wMZ3%7u!6yp38D>2;iv-<NmfuXq4AbRb9-0qh90Xph{Az*$+YnPvS8Mu9{!J&}i zl>RHmVESL#mkDztwzJ4;<@s33Ugcimd8E<+#;(HZqad9t);is5YYd(hbMS(QBdcG) zLD+*!RccWrplEHl0~IJ~3Y5h*ZbO;+^%@$?;t0(8S(1y98ZW+#5sX#&bynk>izMfi zNYaURFonq~hedf_kw(YPE~bG4zpE^=;xy%umoDK6%DB<F2n{t<NryDvpjBODwTd@; ze~nRkhSK-pAfW(*JBK!*<2X*yseDu{i|GcJvbM1ugRrMzvB6OXHxcPlTzDHQhxUgt z^+|ZA{#Lt5lKS0-((%kv1^2_frF#PkPFcTr$K@mF^Hf@8O$xdxgKFMvnHn6IC3ddD zGR9%-utnBF=HXfgvDJqgjLO*Rm*73z_R<}<cDGw-t}1<PmHoybw)Vh89;&2KLF}Pz zXB6DfI}KN0=Fmo9(X#s|E3ukC`_lgok>99c9^CViEG$ae%zuxLC}9Cha%w5EfO}-q zh?2dBOPv8FJ*$c+qlXYfj%-BEEJ6F!!a?gtQ;z#7yPGu>vdSqYEir*BSh7}Ko(Q}T zPaWDwI_jII(4T#V1YD7F)z7qV?@jAsAj9L<$AQPWY26FF&f}j9NppDUl`D7Y0%e=| zhE4uuN!K%dKlrD%z;CheM<>Ta54tO^WG)C9G=tl!uA`PX{!D(pS}OOe>2x%nG|XI7 z-;JBodRQ?W^ja|%XesxQH<&BwXqJ0dZq3ZjwKoa*;$h6b+%|5%u^i@2yIM?|r^G)l z5#UhM`nFnDZax>jpq!@svD7JT{X819{u<5jJ16n67j2V(@u;3jYB^IlWyz?i7SI$u zIo#Ew5-@tHX46Iy*1b?T-qQ(cIw&MpYTDAMo#v{Bp%S2}1kkyn1vQ;Vv)M-g5pAKC zOQx~eduE%>bL87aRcoG)CrOhS=JpUZ>(9{Kgy_nf7*&hnoaEX~YNv);Z1$!ziR6}= zj;b11^Wi7?>dAWeb7kGhUA&t+`Jm#W)unW|cVq_%izvg2o!faop@Y0zaq{y{4<DeM f0YW^=Afcl?p!D;LJj8um`UZ$Ya&}oOkyrl)!RWIM delta 1044 zcmbVLUr!T35Z~P^q!ekJ3Z>U7(D0{!)Rr0|qDGV$eP|RDH0WA^$}Y9N_9)jD3o0dP zOnf0}voTSgTHlO`FMZKO4WGmKV2mHY_yLS_cc{fjcd|S4o1NL&onQ9bo!#5^nVudS zg0cDf4g0WAw|lT!54^>d`WHXWv<MFOdD&Sy!7kDyM~%b0mzCIHl{7KE$JWJCzpwH@ zC9F)B`7$djq~_Lcuo(BVn2gex`pXUOzl82Sr&HX|r({GE<)=IWo^2ZB!33cJ(9=A~ zkc?n;$f!yv<8H7u!_#SaIPYkz7&E|woGTUwY7(%Rf;x4R2E-nToUggpvhxU5I*<66 zer6-a5OXvM)-jxZ0EZ4@@zLv(AWH0cWzZpAcNnLL&(bhhkiTz1rf%UevE}VIa<;vF zMh@tTSvdujueDWf$9m;<3^&c3Vi$OljqaFt)jgAJ=8s2e3^WnELTU`M4Zhrun%F=~ zhS%)9M|;ojSJ4ng4AC^(g!X2r!lldzbFc`Dvhf`Wd|8o@k9FB<p0)4L(wwNteU*m_ zuF2YIo@S0oJvT&Eb)2u7#4rbFAi{$bOBGX<XSm9cLLf536|fzKt^7ZccAZI+NFL{^ zxNdhJXOGy)9SI}-o<OQ+?JJNP|A$l)r$o`aI{3SJ@qdUmaY_*1O`H_Bd`EFgtoyD; zA2$}@B??bBrK;J?zb)QsZ!g1hD74-B;rnS_a3-_a#g(*HC@$u6D}_itM6?e0KiY`? ziMn8SsJM~mh0X5x+<0taAGjjXLMfL^Y8#N7DrvkBg*>e1=}T(&g+AU*i<hITINx%O zU6#a?P(aLv2ZSp;;nKPxiguO?d@))6B&F4&$c2aTxmG>=id?Z^3wB~F?uBW?hlowr i#EM-YL--7F!mbk^F(im}@fc$IAqryc*eaf#nEeHTi4Ud# diff --git a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc index 06557f92116c59895d3c014e28181f1e3c4ae7f7..306e8d1ace118fb56561c0e4c743f7ff9528fdc1 100644 GIT binary patch delta 734 zcmX|9O-~b16n*ziTc(XPgdk3l7z_)7#7K03#D*9`B5~nITp&Y9fxKHnI)#3b04}Od z{UjzeZ-LQBrqGI7K%jmkZip_3utE}DuyVmheuE3|OcQ1@_nmj{oHO^l%zRPF78HMX zb+wPtwY0S=zKlNiYcRCl(!-%34Sp7i@+EJ9H}X60geCcpw;g{gbTN<x{knK3LSm56 z4e?$EmCkb|VTc&*Z;Np;k!5<nSQZJ9Bt{o2A|=v<1~4QhqdezLhROxy6?mI)xQFFj zgY~l4-KNBFt~~DU0=ee70Kesdr(2$>`ldV+mh4qqGbOx-(EzgyED081RGS0LMiHI| zV3@T#Hp*2c4*MAP;cJ{Xd5G)F`pVQrT<Kwh?Rp7UD|!>~1I?l)ujbT5J3b_J)bV+6 z<U>69EaP2_23aQOPBh>Z+Y~KRA=@>59cOX9gXy#QltgQe{1KsogB}xFbWm}WMHTwg zSA0TT(cxxrk(4u2BREfk=``3;fT`X}IO2o)2Hq#lhRFCS^M;*QV~&AIw96`2II7e? z!9=s&wfmInPoG_XP=;xGcfnyP1ns@nM>@^0$cqJ&=Q*9;CzDZ)iAC!L&gH=*2LZQO z>*`<o|G2gf<wI{|6XNpbc32MA*2{zKel0L;7~#lRGL;U;W06!iJsn2_Qz<#JlZ4OZ zUprfFY3&B&=5CW`-<6mgOUu)tx|)487@H8226h2?r?O?`XAN|Ks%SnJaA<&2;GtE6 vMp{+y(^KO+&$|^Ls9XadXuKZkl*>T(NjbFF04{lFuhDy{HsAq&h8_D4?r7~Y delta 848 zcmYjP&rcIU6rP#>psiAy29XMpLs5dnK%$9~G-^l_xnUv+y3jO@?^T<23*EK`qXxC; z0W=ooAZjEn78MIcXhm;E4+cDe$6ky#55|8$Vtl(b#+~e&?|t*VU$Z+=FwHHPoW1q+ z4o1(=VhMJ_)6OQozS`~=d_l&<5&m307VZ3n{3v?(OSvOD_b!40Lla!-hgUER0YYb? zD1D}$t|f?pO8$9>!^i|vF2FJ*KqE$h5+or-Xb6WP9Ti!-wpK2fo^bnh4Eb5s9MEK+ zrL%m^(!<SJOJcvQSda2A^1HRS<{qy55E#IxIF}Jh5sT19IqI|@s6sBoIYSkrwf(PO z1h1j^jnm*g&Xj42VZ_)p7tR>W48=_2Oh!=jeaOq7bsghNcnhN*mW$yMFb^Z^^)MZU zMP9Emp`Yj|YfJbC84lJu8Wzg3QW`sfB|i&vD@(Xii_dU;ZPYA~eiH9ebkxXm=f6CY zlxG=lV${dla5clI%o#l`<}xNqyrz_K3d=OY6i!muntW$#HQBXBS+zA&!~a?u>#Ma( zE%=lQ-X?7YoHRs5AE1dDg)_vOb8l)~?R*~tuYPKbWl#79S=OFtqo%HscHUs=dysF> ztz#C0Fb~f%BClVY>V`sY79Zxhe)qu7v5uAUJ$w8;kIP>hK~Zgyp^Zj4x?z>d=E%WA zYC4vR1TS6=MdHC^=n5vg!VxW*3WnkkjD-k{s$uzUQ{!*T1D`)x{Ko`iS~LBd+DRRq z(#-?IB~vvg(veivq2VPBlgY?PTzkYx;MG+gLjY-YR3^F`q`I|-&z5g&xi}XxyWQF+ z%)-HKqKP|rlW61(<e%oWb~$w3So&_fRyTJD9UtH>;iTQo&17A|Cam)7c1vBq+w2@? GyZ->Ov<c4u diff --git a/analysis_pmf.py b/analysis_pmf.py index 8be3eecf..ca0bb8d1 100644 --- a/analysis_pmf.py +++ b/analysis_pmf.py @@ -523,46 +523,88 @@ if __name__ == "__main__": plt.show() # All first PMFs + + # I had an issue where 0 contrasts were undefined for all PMFs, check that that's not the case + at_least_once = False + for key in all_first_pmfs_typeless: + for pmf in all_first_pmfs_typeless[key]: + def_points, _ = pmf + if def_points[5]: + at_least_once = True + assert at_least_once + + type_saves = [[], [], [], [], []] + for key in all_first_pmfs_typeless: + for pmf in all_first_pmfs_typeless[key]: + + defined_points, pmf = pmf + temp_type = pmf_type(pmf) + + if temp_type == 0: + defined_points = np.array([ True, True, False, False, False, False, False, False, False, True, True]) + elif temp_type == 1: + defined_points = np.array([ True, True, True, False, False, False, False, False, True, True, True]) + else: + defined_points[:] = True + + if temp_type == 0: + type_saves[temp_type].append((defined_points, pmf)) + elif temp_type == 1: + if np.abs(pmf[0] + pmf[-1] - 1) <= 0.1: + type_saves[3].append((defined_points, pmf)) + else: + type_saves[1 + int(pmf[0] > 1 - pmf[-1])].append((defined_points, pmf)) + else: + type_saves[4].append((defined_points, pmf)) + tick_size = 14 label_size = 26 all_first_pmfs = pickle.load(open("all_first_pmfs.p", 'rb')) - n_rows, n_cols = 1, 3 + n_rows, n_cols = 1, 5 _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9)) save_title = "all types" if True else "KS014 types" + if save_title == "KS014 types": all_first_pmfs_typeless = {'KS014': all_first_pmfs_typeless['KS014']} - for key in all_first_pmfs_typeless: - x = all_first_pmfs_typeless[key] - for pmf in x: - - defined_points, pmf = pmf - pmf_min = min(pmf[0], pmf[1]) - pmf_max = max(pmf[-2], pmf[-1]) - defined_points = np.logical_and(np.logical_and(defined_points, ~ (pmf > pmf_max)), ~ (pmf < pmf_min)) - axs[pmf_type(pmf)].plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)]) - # axs[pmf_type(pmf[1])].plot(np.where(pmf[0])[0], pmf[1][pmf[0]], c=type2color[pmf_type(pmf[1])]) - axs[0].set_ylim(0, 1) - axs[1].set_ylim(0, 1) - axs[2].set_ylim(0, 1) - axs[0].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) - axs[1].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) - axs[2].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) - axs[0].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) - axs[1].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) - axs[2].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) - axs[0].spines[['right', 'top']].set_visible(False) - axs[1].spines[['right', 'top']].set_visible(False) - axs[2].spines[['right', 'top']].set_visible(False) - axs[0].set_xlim(0, 10) - axs[1].set_xlim(0, 10) - axs[2].set_xlim(0, 10) + # for key in all_first_pmfs_typeless: + # x = all_first_pmfs_typeless[key] + # for pmf in x: + # defined_points, pmf = pmf + # pmf_min = min(pmf[0], pmf[1]) + # pmf_max = max(pmf[-2], pmf[-1]) + # defined_points = np.logical_and(np.logical_and(defined_points, ~ (pmf > pmf_max)), ~ (pmf < pmf_min)) + # axs[pmf_type(pmf)].plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)]) + for i, type_save in enumerate(type_saves): + type_array = np.empty((len(type_save), 11)) + type_array[:] = np.nan + for j, pmf in enumerate(type_save): + def_points, pmf_values = pmf + type_array[j][def_points] = pmf_values[def_points] + + percentiles = np.nanpercentile(type_array, [2.5, 97.5], axis=0) + if i == 0: + x = [0, 1, 9, 10] + elif i == 4: + x = np.arange(11) + else: + x = [0, 1, 2, 8, 9, 10] + axs[i].plot(x, np.nanmean(type_array, axis=0)) + axs[i].fill_between(x, percentiles[1], percentiles[0], alpha=0.2) + axs[i].annotate("N=".format(len(type_save)), (0.75, 0.1)) + + axs[i].set_ylim(0, 1) + axs[i].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45) + axs[i].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size) + axs[i].spines[['right', 'top']].set_visible(False) + axs[i].set_xlim(0, 10) axs[0].set_ylabel("P(rightwards)", size=label_size) axs[0].set_xlabel("Contrasts", size=label_size) plt.tight_layout() plt.savefig("./summary_figures/" + save_title) plt.show() + quit() if save_title == "KS014 types": quit() diff --git a/analysis_pmf_weights.py b/analysis_pmf_weights.py index a2b06196..8d2f703b 100644 --- a/analysis_pmf_weights.py +++ b/analysis_pmf_weights.py @@ -33,7 +33,7 @@ def pmf_type_rew(weights): return 2 all_weight_trajectories = pickle.load(open("multi_chain_saves/all_weight_trajectories.p", 'rb')) - +first_and_last_pmf = np.array(pickle.load(open("multi_chain_saves/first_and_last_pmf.p", 'rb'))) # for weight_traj in all_weight_trajectories: # if len(weight_traj) == 1: @@ -82,7 +82,7 @@ for min_dur in [2, 3, 4, 5, 7, 9, 11, 15]: if i == 0: axs[j, i].set_ylabel(ylabels[j]) else: - axs[j, i].set_yticks([]) + axs[j, i].yaxis.set_ticklabels([]) if j == 0: axs[j, i].set_title("Type {}".format(i + 1)) if j == n_weights: @@ -105,14 +105,31 @@ for min_dur in [2, 3, 4, 5, 7, 9, 11, 15]: axs[j, i].set_xticks([]) if i == 0: axs[j, i].set_ylabel(ylabels[j]) + if j < n_weights - 1: + axs[j, i].plot([0], [np.mean(first_and_last_pmf[:, 0, j])], marker='*', c='red') # also plot weights of very first state average + if j == n_weights - 1: + mask = first_and_last_pmf[:, 0, -1] < 0 + axs[j, i].plot([0], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') # separete biases again + if j == n_weights: + mask = first_and_last_pmf[:, 0, -1] > 0 + axs[j, i].plot([0], [np.mean(first_and_last_pmf[mask, 0, -1])], marker='*', c='red') else: - axs[j, i].set_yticks([]) + axs[j, i].yaxis.set_ticklabels([]) + if i == 2: + if j < n_weights - 1: + axs[j, i].plot([1], [np.mean(first_and_last_pmf[:, 1, j])], marker='*', c='red') # also plot weights of very last state average + if j == n_weights - 1: + mask = first_and_last_pmf[:, 1, -1] < 0 + axs[j, i].plot([1], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') # separete biases again + if j == n_weights: + mask = first_and_last_pmf[:, 1, -1] > 0 + axs[j, i].plot([1], [np.mean(first_and_last_pmf[mask, 1, -1])], marker='*', c='red') if j == 0: axs[j, i].set_title("Type {}".format(i + 1)) if j == n_weights: axs[j, i].set_xlabel("Lifetime weight change") plt.savefig("./summary_figures/weight_changes/weight changes min dur {}".format(min_dur)) - plt.close() + plt.show() f, axs = plt.subplots(n_weights + 1, 3, figsize=(12, 9)) for i in range(3): @@ -129,7 +146,7 @@ for min_dur in [2, 3, 4, 5, 7, 9, 11, 15]: if i == 0: axs[j, i].set_ylabel(ylabels[j]) else: - axs[j, i].set_yticks([]) + axs[j, i].yaxis.set_ticklabels([]) if j == 0: axs[j, i].set_title("Type {}".format(i + 1)) if j == n_weights: diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py index 37bb5dad..fa7eb589 100644 --- a/dyn_glm_chain_analysis.py +++ b/dyn_glm_chain_analysis.py @@ -1610,7 +1610,8 @@ if __name__ == "__main__": print() print(subject) - continue + print(ultimate_counter) + test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb')) @@ -1645,10 +1646,9 @@ if __name__ == "__main__": # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 2', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(6)), plot_until=7) # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 3', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(4)), plot_until=13) states, pmfs, pmf_weights, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save=False, show=0, separate_pmf=1, type_coloring=True) - + first_and_last_pmf.append((pmf_weights[np.argmax(states[:, 0])][0], pmf_weights[np.argmax(states[:, -1])][-1])) - continue # all_weight_trajectories += pmf_weights abs_state_durs.append(durs) ultimate_counter += 1 @@ -1660,7 +1660,6 @@ if __name__ == "__main__": bias_sessions.append(info_dict['n_sessions'] - info_dict['bias_start']) print(subject, info_dict['n_sessions'], info_dict['bias_start']) - continue state_types_interpolation[0] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[0]) state_types_interpolation[1] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[1]) @@ -1675,17 +1674,17 @@ if __name__ == "__main__": # temp = contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies, CMF=False) # quit() - new = type_2_appearance(states, pmfs) + # new = type_2_appearance(states, pmfs) - if new == 2: - print('____________________________') - print(subject) - print('____________________________') - if new == 1: - new_counter += 1 - if new == 0: - transform_counter += 1 - print(new_counter, transform_counter) + # if new == 2: + # print('____________________________') + # print(subject) + # print('____________________________') + # if new == 1: + # new_counter += 1 + # if new == 0: + # transform_counter += 1 + # print(new_counter, transform_counter) # state_dict = write_results(test, [s for s in state_sets if len(s) > 40], mode_indices) @@ -1708,7 +1707,7 @@ if __name__ == "__main__": all_first_pmfs[subject] = first_pmfs - quit() + continue # b_flips = bias_flips(states, pmfs, durs) # all_bias_flips.append(b_flips) @@ -1790,7 +1789,7 @@ if __name__ == "__main__": # pickle.dump(all_first_pmfs, open("all_first_pmfs.p", 'wb')) # pickle.dump(all_changing_pmfs, open("changing_pmfs.p", 'wb')) # pickle.dump(all_changing_pmf_names, open("changing_pmf_names.p", 'wb')) - # pickle.dump(all_first_pmfs_typeless, open("all_first_pmfs_typeless.p", 'wb')) + pickle.dump(all_first_pmfs_typeless, open("all_first_pmfs_typeless.p", 'wb')) # pickle.dump(all_pmfs, open("all_pmfs.p", 'wb')) # pickle.dump(all_intros, open("all_intros.p", 'wb')) # pickle.dump(all_intros_div, open("all_intros_div.p", 'wb')) @@ -1812,8 +1811,8 @@ if __name__ == "__main__": # pickle.dump(bias_sessions, open("multi_chain_saves/bias_sessions.p", 'wb')) # pickle.dump(first_and_last_pmf, open("multi_chain_saves/first_and_last_pmf.p", 'wb')) - abs_state_durs = pickle.load(open("multi_chain_saves/abs_state_durs.p", 'rb')) - bias_sessions = pickle.load(open("multi_chain_saves/bias_sessions.p", 'rb')) + # abs_state_durs = pickle.load(open("multi_chain_saves/abs_state_durs.p", 'rb')) + # bias_sessions = pickle.load(open("multi_chain_saves/bias_sessions.p", 'rb')) quit() print("Ultimate count is {}".format(ultimate_counter)) -- GitLab