From 6df25f7a42e7ce8e511a43d85fe54868347ab0a9 Mon Sep 17 00:00:00 2001 From: SebastianBruijns <> Date: Thu, 13 Apr 2023 15:38:39 +0200 Subject: [PATCH] code updates and plotting code --- __pycache__/analysis_pmf.cpython-37.pyc | Bin 6551 -> 6865 bytes __pycache__/simplex_plot.cpython-37.pyc | Bin 3159 -> 2990 bytes analysis_pmf.py | 124 +++++++++++++----------- dyn_glm_chain_analysis.py | 42 +++++--- pmf_weight_analysis.py | 91 +++++++++++++++++ simplex_animation.py | 7 +- simplex_plot.py | 19 ++-- 7 files changed, 198 insertions(+), 85 deletions(-) create mode 100644 pmf_weight_analysis.py diff --git a/__pycache__/analysis_pmf.cpython-37.pyc b/__pycache__/analysis_pmf.cpython-37.pyc index 34a4499ac8d78ac0139d1f5675718e1ab52cae6d..e1ab5f64794b712fd3d7753bfc894d24535085b8 100644 GIT binary patch literal 6865 zcmdT}O>7&-72ZFRONtU@MYbqevFZHSVjM}7toScZ<iw8bIF75-4w5dMu9un-wbJr$ zc4bp+S`<*wn+u{Rx9CACf&eXwv}k*1fg%ZdN_+C5K#Msgs1N0(=%MH#m-fAx<&P98 zKS+YQBzAY+n>TOXeBYbdN9oz#-h=|5KfIM8zdo-he?y`5>4nBwxcqPKQ52#O)l)3_ zREwIWtFmoa5z!a5Vnic4F-YXcx@FSHMw~=RjF==&5@69olBAdPkyN2)Q?(MLpX_-< zv3h8d46JC4J<OuYMlbCn50Jri4J=X=z8n3U%6^6H9aO9o*=OwuZH7Xdfly9|@&lnf z9Lj@a#M(>tTl?q$868&WzF~+No*gMQtZ0fgWTmOTJPa|9(4kFL*re$YSemqlrq@+y z?IjQ1Hek&MgtvoRytSmy?TzXi`vIHOrlu;;9))&BwwE6yhsZ<Z;oG8T`2cyOOFD>S z3L1H|)xVM1RF)5c_Atp5bif_^b<qzz(b&dAn~Ip>;qdtgIkKq=&7-IxDY2@tEvsrt zpL<m#E4)3n#am1I+}{3&7#-b$(TG^p<6BnMlK!7oHQv1{dUzfApSZ4BN31c}1zCE8 zj*&@{yRC!nG3#jUSM+f@u&$n0UV95P<JJTnr$@-tpaT2WV59R2Jxa$(c2J>XBTB2a zPg&H4mE}n~L8fJv)(9P-6T=Glc@_NNo@<R0du*tCd|MOx`NX|Rj@^bVZy43fui)Ky zMXCKBVm>Nj{-uaHS9=pQIia}&nkj3VPA?xLGvtW^#JMp;jtf+#H=e+~BlaM!GL=?V z^hU%_&?zEXEJLR^4i%Kp`nYVB&XN<>3GjFb&XAMj<fcj!^aMFYPUD#Y`s3ggProN& zU*dV$ZiCgsVf23z(f{@pJPA7c^h`HfpUBg_Hcx-u#*@;Bu=g90>d#9D1a~9A9qR72 zm`p3ycu4~$JIqGdzwaAnqt(NJ4aEPYGaL2qh0I3TM_&=M(dwHJeYE<6ZOlIP1(}sB z#@J69vFi6iW)BEvW9*$4&s&(q5n3b0?%p@dnpFj`f%w04W=;Q{kXe)c`72`9tbPZg zH>=(@W}p6o%t{vH?Da;xdLv|ZRxlf9cLcLwv4vS2p*7;{XZH=W@hS&wApS3%*@S;5 zWH!Ox{)(7QR9}VY6V<tG%%1ILHmo+E39HRhq5NzppN6;Tlj2P(E3Gr+9FA_NwHY|6 zpQ5Me>2(d?NDcM144&9<rLCvo^q++~Za=(9Q$xxo)^Ga887gwBc9y0I)Ee*(ow3fg zYt}SYtYO`X79c^Dg0<^etX;`@@rLQCHpFT=6=$n16f=?M;7ylgCd>xaWcwW_rPh0n zz?=uj9C@x)d5Su<E$vc=UXbrN>3R2eA!*skcrHRl_ps~pN~5QGeH#zYck@6lStrN~ z)-xggmqYxYB`*S^yHqvgYbA}$Tj%IGG7s}4*+0lj@VnrpVZ~Lgf!Y!s0E}-yts`m{ z>wImVo(IiEQMEvIf-E4m;je~n21MT$7TVTfrP5oCLUtvq|AJK=MYg)Cj-%*#i#)YW z2K!y=3^8`sKJc?^ANu^%8U<p7*glgT&h=UR;k+NO!sI+V>Cx{_dc^QKsH^0!Cy8OV zj2L#yIPWAmM4t)sSg^1>N9Nj^t{RMreGAzuek+9;lVT%b?W|ZAux=(Pl5N+`7p+`v zL*&pMYogX@SH5pRMq=eF>RXb69hYJ^8ma0W)Q!)}oPm3ESQ&yUG>v=u>nkeQe4kz* zsS_|0nX)dG()0ysOH$Grs#IfLwJy;dy#y=Up;{(a!`y?pAoo($laPD4ZMi4q4o*Ve z0dJ}Qbbh7X&NZ>@>*gT0E%T+^9S!b`@FarRF&96c3sS$^oH(YN<s8Ax=+yjD?w%9- zJ9F-we`3@wd+Rd2urAO4PECJp%ju<FRkANP@vkmdc%;B6Epa37tx~QtO0qgrp{FTs z8n*rep8{Gd$_MX;jk~SJd+!!rJ@uRSZ@%-3c4IcQd9T&D+iH9;%Oc=q0e#~KpQtzo zm;YV3y$*l$8OA&OL`o|}Ur|_^XjhdW)Dc5kq?L7zL{<!cZrg|`jHt=+*w2E_2q_0D ziEU`FL!u){=$jaDnQ$opOZ*CpK^y{t!jFE4pFhf=3vXZeus-<qUuHw~hqLvp-T7Q& z<<-1L{c@3J=3ksUo%z<cJDoO*g43+d^`cYbkqUE50go=0S<wlyD%UH;#h)rH2{s?i zOs=fuC!IotO!|4Y>Moc3Nq4DOwDWYyxoP`>6HtHBDLLMn@A@{3oT#jE6P|XkR-rxM z!|#Wiq7JHtHldEHeX0R1)(1Ln+IG=#OSWC#FFc7cz763}Rw;W<;Fe1h6|S*-lYo75 z0B-eR$Mft(m-zus=Zi70Qh%aSPhb}+q1}(|yzBV(qUTmX(=^Zp!dZ`A`NqrRx#L-l zs|#H7ZZ#1PD|YVc#ml+lQ%(GnV7+&KjJbuS;FiOPpKWG92iO?CTc_MSS1tw2@q^}m zP&YAW>V3546e}Lh;JPzYmWkf%+|<dY4&7F-oA>NnZV=FtJvFn<dU|X3iEZ7-vKo63 zoUkFdWwj;(+Kgjc<*`M#fH4V5oA8H9P_n)7;1Ll%*Mo9pu8Cwek&Wg!nrLhUO-xn> z0s-W3i~N1I9(PJ^(eeDuRHr%JX&!5uFt8)f)Kd#fPO0FQ3N7FVyA^iHDN>|k;2Ofx z>cX7#A^ONd*(<Z=KJaP9TcG;vAjr)FI8kqV1k&2XtZ8EC^&ZJREmkHfupFd-;fQMx zo977<{Cdv|S94P{Qs5CGgxlqJvs}&B2Z_75NEs~!F66OfZaVuSk5=4;6_4^gupug5 zIq=*(q&2pnOFs3A7w|}FwFv31l`1^uQzu_$C9eA=pC>R0rp4lT#9_=?;~18`0D0rL zb04v>ZEkpFhj61@frD>&*L}bcc&;t_d<Z@5umQM<9f%9>!&WExnFLt+hM0LDp^I*b z61!4{&E_-A3~q=8Ms6(uBHUQ>+#)w>c#p1mPM&%^S__|R0pR0vy|Cc$giiyzCKWM4 z{R;Ty$pE*G?Kx}Z)c~+@Zo+&8*gZ?-TimFwx&UO!X60KxJ^?46oBnDZ1Ms<ObKR-= zJWe50%77WJx?EqOYjUzM8V$}ab(yOUkBUis-h-jsbbZ)rut!l4ix3gYkgYfdYKu)A zbr>WrkSvJlPGY`qyGo`poRMx}6~KuEZU9s0b+^I|WD%AqqvvXcCj?4z5<x_VD=UFe z*mAg9lR*lMC5#f^*g|;~k3g;?Xu6a4t>~&xZCJ1lbTQoe;&|@5;Xf<5d8Sw<s~&wC z-(>zj;clu))l^MAsb+MuFQN6T{hBcfceQ`f&<yn<bqvm9xNAwU?1S%sn!yojLNz~I z|4t*-sAj^6KLGwT^)NgwxCTIzRMX%&37*qxE}W%h8O8ZE?S2}gBVI{}I0MlqAhH>V z-NZOSkM|59*9S-mymWZ80kXIrl#?1@2+u2sHhf3LcMK38gVp0#>H!%!9&DrFV-TK* coM9jnQ7sMcO|(qIxKXghb{|Mv+mf^Y0?m)68UO$Q literal 6551 zcmd5=-ESP#6`!yDw7u&DCtk<FSsRkrnAnbEk`PFo4-yhcf=eQSMn%(E-|O|*v$L~z z#>sk1B_#7yHPVU)L@OTFQiaqPimIw=B~(x=RkZ2@Pi=LDDroz_{tLXc=iEE{VPiK) z1SvDxJNKS@&bjBD-??+o&c(jIv;x=P{@6#(UQ(2Q<Dhl*LFYXDf_%TC5QV6&V#=pl z(o9{IeZ!22xwx4i8qtYCV(;o^lExY-5+@0gBq@>xjvmrW`bdUki#?mFnI`>Y&zp+b zLwm{IRjuC74pODjM>FIRGO(cmN0!3f=-*U^6|!$YF|%a9xrgd2dqMLfpf5`XH+7mN z%8){n*pl2mun9N@$PnE_4&2kkXdfBwj2inva^EJ-BpQR83UV}Z(Q}9#+*AeSfTVnF zt*lrJ+6S=KqxV^BM63nv16b?OmbGM@^HH3KqqvVo;}#BM(fCpFm^lK!Lv)xt4mKSE zi!kP$5yV4N%)B`op)*GFo2uX$rFry%_R!G{6?!A&@I6CB7;(r79Q4Lx&_5FOA1^AA zt%sxDPB<Oi;?I!{z>$o9Uo(%I<6y%CJxa&P1bO0~4zwrC$?C7@lXQ4PeNK7pJ)j&j zr|2<yl#C20ka-&W=Ym2f=`k`npwRIHN~^bDS<;4-m1#Ofo|M^CC3Kih4Jn}KRnUVT zx8#I}gWd9NN~8z#cGt*|W0<>jy^?uFscV%#y`oe<1e+&?%^wJxXR1F3%8a1A1C-<D zES+6BL8i!b5$tTtkr@${*~U}t*rJ!&oU*Fd_294cI1xSeFFM;e8AW%taZ2{Q^W?aB z8Z@2+JD(=An<`Dy)8qu16Nt^OoC2-5f1iRKj61&F2d)E=^*;gY^-ArPt-I=EH(LJ} zr~kBZ`o}h$l)AxwTsJD7HzKfm46sAmdo3h$ih0b_0Lc!r2K((pgRD{c5@G}PKX7DY z!Fv(1G4|niglw#`1lGqYx3?jC>RTc!QH(QEk5{fn$etF+#@QP!oVOs05?VdZ-g;<| zjaU8vu>t!ZII@Z0jR@HUd;2>=Hc>eb)+Z`c+mL<wTOun_OtN$JWMwQu_KZL_$xMMP zaBM*qCA4~yxepDp$;!JB8?gU@Bby4$2-y_#zawN*l>=aXs*>D>>@(fSMrZ%C(OG{c z8lR2E&%kN=tT;(!U38Y5L+QF&or5>^IeLaZv!UUMR99ci;~N`(qvkx+ROiV&IS(gk zc2L<ARaxUKok!W~c{)m<Vt_Mr&b-jB7)G&Ph^mGQzyYJEa=3t%gRFt->B`H(n@;u6 zst&|T<N};@Y4#Ub|Kg_7KI7!5b<Po3b04`#F0?De0dlD`LdlhraK6oxBSjr5Sg2Fr z`J6oCq~_hv!Zx3@=Q3n;nvE<dVEM>49IkZZK%O^GldI;XDE`-?_&-NpfQar=)sX9+ zMsApw>1A>Q*6Cp%keg7I-5gRJ)f}!a(_x7551`f&HH&$rx<Ic0<*KM!pgKVmFt(z~ z9g{I{)onPp1%)<uRH^h-K7s7&sr&}Knnbj^s*XvN-ZF+NcFP#}cd0Xk?5=a5XV*E@ zdEXiZYz5oC5g)GgP4wZqU$4UCIy=!(?@n}#;q#DfvJ#dt>>gtbyT`chG`U1CMR_bx zSh-3rwkcizZ0P0gY>P6ck9|gE?W~yBux^I`RFmzx`2};P+7LPPjyYAWw=3T_AtSN! z74<F2!j8+Zv3jP0|M*^%IRkBcNEw8klfymz(y9tP3-lVvo`RLgh<V-1(HoMNWF<FL zsm6wCUZ*qkI(W82wM-VG+=I0s_xdVdLGI0L%RM>n#7T^Ih_|HQA79D0b4@Jgx-pp9 zmicnrZ4K^>=uHH+V=jI@7bJbRHBqJ;<r#vR(W&|6xO+{U@2t6V{t2mF_SQ{$Z9~5Q zJO7lw*PTy9y`@aN+{Et&H+ihcDD}8eaMvi;86^`sQ(>kl8V!5@g{z3%it_nKQRmZE z=i`rxuRi_jPwxEwm+j7c#Pe~h^J%N|`8<n(mPOQ!7e0{ypE8u-4QO76&lW@2!e6AE zLiAOI<%o7m2_qUYBu7r!&`5062(E0Ch=W8`%Huo>I}&mnswC0SUWY`-Afawzz%L0u z1;UcL$r50P2tn~@KgH|M(&*y*&wo)Hc>nM75&4Vx+Jx1)ZnN@Q!KFdDMDq(buAIsL z@JF3Vi^W0dgwFMnUF9*KIbO))OJ!EF!wHq^erf3!3hM=)ujZy#*9+5j(I?YEfvq_! zUNG$}mr7QFF57pkAhbgoOxvFAt_MzFf#j6G&Xe%8!gZhafQ|sZ|3GX9R70Cm52_i} zfF8>LjVCRuWILW^)rN&8Arn|&4rPAXwL_=uP5E47g(d<1=3Z>xh6b#REtcIfs|_E{ zmzVOmR^DD@<siuWWq(bI0+$guHCh|G*49W#KA=GWiu0jUqO~3&y6gmD{&@FbR>Y{A za-ue5yRNn5upmT717R)1IGFNlX`DhL^!u@2aO}WZavdKiO#@6e`<#Mn-F3o{de-q% z6B<_+x#r$&;wQ9PhF0y8?^5Z>i3zQViP#i4X-$k^GX)KJlvr|#=u}fj;vPX{gYe)n z<`kDhu7_p+N)yp*V)&cKkVRwrkR{2=fEYL*?TIhvYbo1vO12w-`|a*58wT#q_(U_g zWjt5QE-u?%(ea8crw6(TmS>kJhGOq+1g15GnHaDGpsI;k-P{k2nYvW=LS_eHQ0rS5 z7cShj83`uXLF5^Mrq<WifFW*T?le(Ntw-WOOa7D(K4Q$+A!u%c?XD|=8Ps}Syft%t zP7chN7(|=qck^5=)CP#Nv_u*8LI<*0A}~Ag5|8`N;;Kve9@rMXTMk{P0LhI#n37N3 zcS9cY)=H4{n&<OGK<z@Ad0Y>?fTuCHW<{Rzn9Z2I&e1Gu5j+msI{;&4S=?~THsMCu zhu7cGuDgLD5L{adxFHNlx{eq4_9Er+b=NLXm&dE}X{@_Wi6;mRe8;1_H^k(yTzkE| z7J{Aj9SVsA`(wF$mk**<J0vzw;snN+XRy~vYL){#%#^ef{O1`0`aDW3zYN<qV3>M5 zc6S-#%8e?T5?>D?o&gvrpGIZPfk=85E8hjG2Vrpoo($Fs=vly3i|ckZ;3*3FDFbM@ z>TrFPuJd$2Lo4!EgIAck!d08cg{J}U!Pz@b0Q-%CKEVp>3!!LjsX*9arto3mmf~@n zK~@9e0;TQ@=li#lWRAmY(<!cjkFn4RVL83#_}oB9A;!`ZuKHXrF4{aTfRQVSaCRbY zrV$WJ>QJlFY7sSQwS+XWSYE>y8P^dy-7W-Xd@Z0BL_mk2Gkimc*LGKfpW%7(*;1LT zx%3>KaKUHLCe>awsV4PaHLoW#X{}%F*9>(G-p(00qo&oQcE2%5sr^gF7)p+7y&#!J zIk>TnYe{(P?*)yTdIX*({PqH+SIxnEFRYhSXC!Z%GmaK$+QTHHV{Fi~JXoIw%jUrL zB-#mdZ1doI2BIar*5P!6sG#RK?$t07Y4s*X3hsmA9#`?siCfYkGEyFR<Dg>zo)|g9 a$b)=b%fZ=+oH>v?2)x+Oz-Vh<M*e>`lyK1i diff --git a/__pycache__/simplex_plot.cpython-37.pyc b/__pycache__/simplex_plot.cpython-37.pyc index 8417f61340255c4910e14c7ab0a3f0d48670b60d..d40c9990e4c7cc9a2781a81174008f4cb3e8b339 100644 GIT binary patch delta 1223 zcmY*Z&2Jk;6yKTs@WyL9PGZMNqtJrXgWD3dX{A;O1hfbtC8Uw4I9Ni<WM}Mbvff>0 z$Bi9X3aElx5Y-~_ak7uOfDnH`|AB(U3F%!p_r{qMZ@dW~Z+6~$>-T=|YsUNg$`7sT z{c5#rAU*k_&ext+KlNOWt7me;$<uEcaew;jF1dQ_Dv_y;H153%<S=kG0?EUz(ObB) zN&XQ%Dy!<)7#ous_wrq{S(+c0pAcf?&&{<huS-{rF6EWNJ|(=`FhVkRzoB2EkE$x` zVBrFdR$QM`PBd!G;1;i)(#azCcxg)WkIh#8hv}ZoV)A)jSJnivI`5Pek@>=@O^h*_ zI8$S?#25L}l>YPHxOi8_q{$moo!^)&PYo>(TD-}Zd5f>`)l&;G)>lM+jrJ97k6qMO zwZ@6{fziH{e`x(kAb0G$Coc$20!Y?FfmC^`O~7KqFavT8h-{nS^qD*qQkKvx=aBWj z|4!4{v)(VyHU0S}m`S&LJkJNx&%%BhiJ=}m*edMptMCS<c!Ytxy}e}UhjFkkGB9}> zf~WeD#aYUvh*kR%kX>-PVQ(M>kZo|b!&q#-12j&-%07__%wE92Wt{sFn!p~;3SG#A z^84%{k$#i}OeGTN`;Wle8AQ?c`49F;yXK!c%@eRw8O91q8Tu+A!RiYZLrIEkpQQp^ z9c~zf{X!d|CiRtp%B&qfVxwfB;q0-%VzIRS<Ou8_N;1**z)2H~jty#C<*|@T3?sHD zqO4QakoSBQsz}Idvyg>7NTNh$9e3#C95H>z9i40793C;*%Ya6=I&-N^4n&{?KQrWJ z{;;%lQ%5o8TvDMm(jYF;yy3ir+H2-b^14|kby}t_StJfAQ+em-UrWyjaq_y|+>kb6 zUpUC5{(J1SUD-kj_OXyj26izH8&9Yw<VybAT;tviw09gO!&G+|@*xm>MnKbBxHmgA zy)8F&z!|%(wJXTp$k)qP$(8&;`Ro;)s!f3evk9~>L&^7f68JvY@t~iMWGk<EmtS5& zH(2O1l}3q*!o3X@26zj>DWqJ_cf57GslD>lTR!=pZ!u4B&Lw#Oi~os~GBvw`r3;1b zzAf)yq*<-t`hLv%0t<4Aa|4@6$wr;B{-PA8EAQ#VW+Y;HUjtulX#Q{KMf>?`<@(Zf Qcip9yUBM&Kt41sIZ!tD)XaE2J delta 1459 zcmZWoOK%)S5T5RtogMGn>&JRyoUDl=w8pZCjT8<@D+tOZwq+?1hrW!))4l6)=0#6C z*|ioe65?nengb*USRVrt65_<ke*mr^^(hJ`pYVZGPN^Q-u_J1xdaA4HtLmz*`K9sq zT>0a2xu604`@BqVJ}!Sz%+q{NrwAVF@%IkO|N8P4TD~LEgZ(Pq<96Z&vCsM|aoBS$ z1q3Pr=Yj76C;z3N;zy$)8k>jOZh45Q(bQ;>mWD<qm1*T5hqO!fD%fk8T}l79TcDj3 z#yCgwhj6HQjn*4l7iS1%J41-dP#bFFDy@Ro2+sc%rl^aX+PwB9`n({t-P*9A>H@u? z^ut_3d#$Nv8Z!$#MW;P9w7$b%BczFnFm@*|g?=M4HR#L%9@l9V4u>!Dq^>?Q%3{)4 z+Msi^c~SIuie91fhxoh+<LUFg(8e?LtwWVn8_ynU;F13p@fUFR#xPG8=v8`+whj!? z>!3dYuQ}yK(LJrbxc@_;1+?WyqP5vyYo$fwM?}(}ja#EO?sYTz>!|niiK1sm$8&I7 zciOr%5^tYz6LzHPgkdBc;YFeC`d)mDKnvitjQ;7o2A>JPpKG3UBub92D{mwXr|%^a zElI?WSCwj{6Ekr51IDv7Tn*ess6YQubkZSu)a{a-@OnGKTZYW{+d3(3Z%2LG3*8-- z5S_*zDGKg{N$hYIigr_?En;@P-X3QXtrK(I3)$)?5{EI=^%)mL@43#1SPrGO4>pqP ztJ9Imgb6!vb|Y^4k?V+vOT2NP6gT&LfBjTJa)Zp(I%^7(m6%7n%oX=e{{g8ddjZKO z02hoCBVbNQ@|-0*PRxj<GW2~fAl8x+2O#niR8C4+iMH<yqCEk4=OKd%;JiE0qqJi# ztnmi8$<jl{1?#(!AMr$5eH+>XfuuD!mkJ7M3s3kg{l>h47So^1yYH_6Sevj=30KfG zvXEM*aSOH3D!z(8(BDJLdJWZZ0b9sKbzpw|w;TB*w3Gv2PoXvKdFM<x>2L$~$7(+K z1T33vi`Uc2L~~;bc2akvMB><kW_Tf()Kn2ua4!~cVIJ*E1n+?0_$Kfw@ETuKv9jfD zCC$L{VJaqWpqcdP#IvHx(#8_Gype7eW>Gz5g=?c4Xvt*Yh!_Uk_qJDr=fXrmuI$UF zmA!V!uC6<gx$MIym9>1y>IM?K0l~L7-UxcyZDO8v9o;7NQ&5oQj2)3pV*XsfTOV;X z_E{0qvhC0bm~9i&cea_IketK0Gav=Q5)nsUC=&h-xJ%t<Ay+rWn`*h}$Z1>BxcKhW RlC@xABUgeuR8Iw$@GBIUo?8F_ diff --git a/analysis_pmf.py b/analysis_pmf.py index 6729289c..637f06cd 100644 --- a/analysis_pmf.py +++ b/analysis_pmf.py @@ -28,32 +28,32 @@ if __name__ == "__main__": state_types_interpolation = state_types_interpolation / state_types_interpolation.max() * 100 fs = 18 - plt.plot(np.linspace(0, 1, 150), state_types_interpolation[0], color=type2color[0]) - plt.ylabel("% of type across population", size=fs) - plt.xlabel("Interpolated session time", size=fs) - plt.ylim(0, 100) - sns.despine() - plt.tight_layout() - plt.savefig("type hist 1") - plt.show() - - plt.plot(np.linspace(0, 1, 150), state_types_interpolation[1], color=type2color[1]) - plt.ylabel("% of type across population", size=fs) - plt.xlabel("Interpolated session time", size=fs) - plt.ylim(0, 100) - sns.despine() - plt.tight_layout() - plt.savefig("type hist 2") - plt.show() - - plt.plot(np.linspace(0, 1, 150), state_types_interpolation[2], color=type2color[2]) - plt.ylabel("% of type across population", size=fs) - plt.xlabel("Interpolated session time", size=fs) - plt.ylim(0, 100) - sns.despine() - plt.tight_layout() - plt.savefig("type hist 3") - plt.show() + # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[0], color=type2color[0]) + # plt.ylabel("% of type across population", size=fs) + # plt.xlabel("Interpolated session time", size=fs) + # plt.ylim(0, 100) + # sns.despine() + # plt.tight_layout() + # plt.savefig("type hist 1") + # plt.show() + # + # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[1], color=type2color[1]) + # plt.ylabel("% of type across population", size=fs) + # plt.xlabel("Interpolated session time", size=fs) + # plt.ylim(0, 100) + # sns.despine() + # plt.tight_layout() + # plt.savefig("type hist 2") + # plt.show() + # + # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[2], color=type2color[2]) + # plt.ylabel("% of type across population", size=fs) + # plt.xlabel("Interpolated session time", size=fs) + # plt.ylim(0, 100) + # sns.despine() + # plt.tight_layout() + # plt.savefig("type hist 3") + # plt.show() all_first_pmfs_typeless = pickle.load(open("all_first_pmfs_typeless.p", 'rb')) all_pmfs = pickle.load(open("all_pmfs.p", 'rb')) @@ -161,37 +161,41 @@ if __name__ == "__main__": lw = 4 # Simplex example pmfs - # state_num = 7 - # defined_points, pmf = all_first_pmfs_typeless['NYU-06'][state_num][0], all_first_pmfs_typeless['NYU-06'][state_num][1] - # plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) - # state_num = 6 - # defined_points, pmf = all_first_pmfs_typeless['CSHL061'][state_num][0], all_first_pmfs_typeless['CSHL061'][state_num][1] - # plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) - # - # plt.ylim(0, 1) - # plt.xlim(0, 10) - # plt.yticks([]) - # plt.xticks([]) - # sns.despine() - # plt.tight_layout() - # plt.savefig("example type 1") - # plt.show() - # - # state_num = 1 - # defined_points, pmf = all_first_pmfs_typeless['CSHL_018'][state_num][0], all_first_pmfs_typeless['CSHL_018'][state_num][1] - # plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) - # state_num = 3 - # defined_points, pmf = all_first_pmfs_typeless['ibl_witten_14'][state_num][0], all_first_pmfs_typeless['ibl_witten_14'][state_num][1] - # plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) - # - # plt.ylim(0, 1) - # plt.xlim(0, 10) - # plt.yticks([]) - # plt.xticks([]) - # sns.despine() - # plt.tight_layout() - # plt.savefig("example type 2") - # plt.show() + state_num = 7 + defined_points, pmf = all_first_pmfs_typeless['NYU-06'][state_num][0], all_first_pmfs_typeless['NYU-06'][state_num][1] + plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) + state_num = 6 + defined_points, pmf = all_first_pmfs_typeless['CSHL061'][state_num][0], all_first_pmfs_typeless['CSHL061'][state_num][1] + plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) + + plt.ylim(0, 1) + plt.xlim(0, 10) + plt.ylabel("P(rightwards)", size=32) + plt.xlabel("Contrast", size=32) + plt.yticks([0, 1], size=27) + plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27) + sns.despine() + plt.tight_layout() + plt.savefig("example type 1") + plt.show() + + state_num = 1 + defined_points, pmf = all_first_pmfs_typeless['CSHL_018'][state_num][0], all_first_pmfs_typeless['CSHL_018'][state_num][1] + plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) + state_num = 3 + defined_points, pmf = all_first_pmfs_typeless['ibl_witten_14'][state_num][0], all_first_pmfs_typeless['ibl_witten_14'][state_num][1] + plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw) + + plt.ylim(0, 1) + plt.xlim(0, 10) + plt.ylabel("P(rightwards)", size=32) + plt.xlabel("Contrast", size=32) + plt.yticks([0, 1], size=27) + plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27) + sns.despine() + plt.tight_layout() + plt.savefig("example type 2") + plt.show() state_num = 4 defined_points, pmf = all_first_pmfs_typeless['ibl_witten_17'][state_num][0], all_first_pmfs_typeless['ibl_witten_17'][state_num][1] @@ -199,8 +203,10 @@ if __name__ == "__main__": plt.ylim(0, 1) plt.xlim(0, 10) - plt.yticks([]) - plt.xticks([]) + plt.ylabel("P(rightwards)", size=32) + plt.xlabel("Contrast", size=32) + plt.yticks([0, 1], size=27) + plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27) sns.despine() plt.tight_layout() plt.savefig("example type 3") diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py index f99904a9..c61d4c7c 100644 --- a/dyn_glm_chain_analysis.py +++ b/dyn_glm_chain_analysis.py @@ -526,7 +526,7 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur # print("fix this by taking the whole array, multiply by n, subtract n, divide by n-1") # input() - label = "State {}".format(state) if np.sum(relevant_trials) > 0.02 * len(test.results[0].models[0].stateseqs[seq_num]) else None + label = "State {}".format(len(state_sets) - test.state_mapping[state]) if np.sum(relevant_trials) > 0.02 * len(test.results[0].models[0].stateseqs[seq_num]) else None # state_c_n_a = c_n_a[relevant_trials - trial_counter] # print(state) @@ -571,7 +571,7 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur plt.xlabel('Trial', size=28) sns.despine() - # plt.xlim(left=250, right=450) + plt.xlim(left=250, right=450) plt.legend(frameon=False, fontsize=22, bbox_to_anchor=(0.8, 0.5)) plt.tight_layout() if save: @@ -854,22 +854,25 @@ def state_set_and_plot(test, mode_prefix, subject, fit_type): def state_pmfs(test, trials, indices): - def func_init(): return {'pmfs': [], 'session_js': []} + def func_init(): return {'pmfs': [], 'session_js': [], 'pmf_weights': []} def first_for(test, results): results['pmf'] = np.zeros(test.results[0].n_contrasts) + results['pmf_weight'] = np.zeros(4) def second_for(m, j, session_trials, trial_counter, results): states, counts = np.unique(m.stateseqs[j][session_trials - trial_counter], return_counts=True) for sub_state, c in zip(states, counts): results['pmf'] += weights_to_pmf(m.obs_distns[sub_state].weights[j]) * c / session_trials.shape[0] + results['pmf_weight'] += m.obs_distns[sub_state].weights[j] * c / session_trials.shape[0] def end_first_for(results, indices, j, **kwargs): results['pmfs'].append(results['pmf'] / len(indices)) + results['pmf_weights'].append(results['pmf_weight'] / len(indices)) results['session_js'].append(j) results = control_flow(test, indices, trials, func_init, first_for, second_for, end_first_for) - return results['session_js'], results['pmfs'] + return results['session_js'], results['pmfs'], results['pmf_weights'] def state_weights(test, trials, indices): @@ -1138,6 +1141,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show ax0.annotate('Bias', (test.results[0].infos['bias_start'] + 1 - 0.5, 0.68), fontsize=22) all_pmfs = [] + all_pmf_weights = [] cmaps = ['Greys', 'Purples', 'Blues', 'Greens', 'Oranges', 'Reds', 'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu'] np.random.seed(8) np.random.shuffle(cmaps) @@ -1147,7 +1151,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show for state, trials in enumerate(state_sets): if separate_pmf: n_trials = len(trials) - session_js, pmfs = state_pmfs(test, trials, indices) + session_js, pmfs, _ = state_pmfs(test, trials, indices) else: pmfs = np.zeros((len(indices), test.results[0].n_contrasts)) n_trials = len(trials) @@ -1173,9 +1177,10 @@ def state_development(test, state_sets, indices, save=True, save_append='', show if separate_pmf: n_trials = len(trials) - session_js, pmfs = state_pmfs(test, trials, indices) + session_js, pmfs, pmf_weights = state_pmfs(test, trials, indices) else: pmfs = np.zeros((len(indices), test.results[0].n_contrasts)) + pmf_weights = np.zeros((len(indices), test.results[0].obs_distns[0].weights.shape[0])) n_trials = len(trials) counter = 0 for i, m in enumerate([item for sublist in test.results for item in sublist.models]): @@ -1188,6 +1193,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show states, counts = np.unique(state_seq[session_trials - trial_counter], return_counts=True) for sub_state, c in zip(states, counts): pmfs[counter] += weights_to_pmf(m.obs_distns[sub_state].weights[j]) * c / n_trials + pmf_weights[counter] += m.obs_distns[sub_state].weights[j] * c / n_trials trial_counter += len(state_seq) counter += 1 @@ -1218,7 +1224,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show if not test.state_mapping[state] in dont_plot: ax1.fill_between([points[k], points[k+1]], test.state_mapping[state] - 0.5, [test.state_mapping[state] + interpolation[k] - 0.5, test.state_mapping[state] + interpolation[k+1] - 0.5], color=cmap(0.3 + 0.7 * k / n_points)) - ax1.annotate(test.state_mapping[state] + 1, (test.results[0].n_sessions + 0.1, test.state_mapping[state] - 0.15), fontsize=22, annotation_clip=False) + ax1.annotate(len(state_sets) - test.state_mapping[state], (test.results[0].n_sessions + 0.1, test.state_mapping[state] - 0.15), fontsize=22, annotation_clip=False) if test.results[0].name.startswith('GLM_Sim_'): ax1.plot(range(1, 1 + test.results[0].n_sessions), truth['state_map'][test.state_mapping[state]] + truth['state_posterior'][:, state] - 0.5, color='r') @@ -1235,11 +1241,12 @@ def state_development(test, state_sets, indices, save=True, save_append='', show # defined_points = np.zeros(test.results[0].n_contrasts, dtype=bool) # defined_points[[0, 1, -2, -1]] = True if separate_pmf: - for j, pmf in zip(session_js, pmfs): + for j, pmf, pmf_weight in zip(session_js, pmfs, pmf_weights): if not test.state_mapping[state] in dont_plot: ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + test.state_mapping[state], color=cmap(0.2 + 0.8 * j / test.results[0].n_sessions)) ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + test.state_mapping[state], ls='', ms=7, marker='*', color=cmap(j / test.results[0].n_sessions)) all_pmfs.append((defined_points, pmfs)) + all_pmf_weights += pmf_weights else: temp = np.percentile(pmfs, [2.5, 97.5], axis=0) if not test.state_mapping[state] in dont_plot: @@ -1313,7 +1320,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show ax2.set_title('Psychometric\nfunction', size=16) ax1.set_ylabel('Proportion of trials', size=28, labelpad=-20) ax0.set_ylabel('% correct', size=18) - ax2.set_ylabel('Probability', size=26, labelpad=-20) + ax2.set_ylabel('P(rightwards answer)', size=26, labelpad=-20) ax1.set_xlabel('Session', size=28) ax2.set_xlabel('Contrast', size=26) ax1.set_xlim(left=1, right=test.results[0].n_sessions) @@ -1352,7 +1359,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show else: plt.close() - return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage, states_per_type + return states_by_session, all_pmfs, all_pmf_weights, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage, states_per_type def compare_pmfs(test, state_sets, indices, states2compare, states_by_session, all_pmfs, title=""): """ @@ -1670,7 +1677,6 @@ if __name__ == "__main__": # subjects = ['SWC_021', 'ibl_witten_15', 'ibl_witten_13', 'KS003', 'ibl_witten_19', 'SWC_022', 'CSH_ZAD_017'] # subjects = ['KS014'] - # meh pmfs: KS021 print(subjects) fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0] dur = 'yes' @@ -1709,6 +1715,7 @@ if __name__ == "__main__": regressions = [] regression_diffs = [] all_bias_flips = [] + all_pmf_weights = [] temp_counter = 0 @@ -1744,15 +1751,20 @@ if __name__ == "__main__": # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 1', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(7)), plot_until=2) # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 2', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(6)), plot_until=7) # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 3', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(4)), plot_until=13) - states, pmfs, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True) - all_state_types.append(state_types) + states, pmfs, pmf_weights, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True) + all_pmf_weights += pmf_weights continue + + consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) + consistencies /= consistencies[0, 0] + contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=False) + + all_state_types.append(state_types) # state_types_interpolation[0] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[0]) # state_types_interpolation[1] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[1]) # state_types_interpolation[2] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[2]) # temp_counter += 1 - continue # b_flips = bias_flips(states, pmfs, durs) # all_bias_flips.append(b_flips) @@ -1915,6 +1927,8 @@ if __name__ == "__main__": # pickle.dump(regression_diffs, open("regression_diffs.p", 'wb')) # pickle.dump(all_bias_flips, open("all_bias_flips.p", 'wb')) # pickle.dump(all_state_types, open("all_state_types.p", 'wb')) + pickle.dump(all_pmf_weights, open("all_pmf_weights.p", 'wb')) + quit() if True: diff --git a/pmf_weight_analysis.py b/pmf_weight_analysis.py new file mode 100644 index 00000000..00611bcc --- /dev/null +++ b/pmf_weight_analysis.py @@ -0,0 +1,91 @@ +import numpy as np +import matplotlib.pyplot as plt +import pickle +from scipy.stats import gaussian_kde +from analysis_pmf import pmf_type, type2color + +# all pmf weights +apw = np.array(pickle.load(open("all_pmf_weights.p", 'rb'))) + +xy = np.vstack([apw[:, i] for i in range(4)]) +z = gaussian_kde(xy)(xy) + +plt.subplot(1, 3, 1) +plt.scatter(apw[:, 0], apw[:, 1], c=z) +plt.xlabel("Cont right") +plt.ylabel("Cont left") + +plt.subplot(1, 3, 2) +plt.scatter(apw[:, 3], apw[:, 1], c=z) +plt.xlabel("Bias") +plt.ylabel("Cont left") + +plt.subplot(1, 3, 3) +plt.scatter(apw[:, 3], apw[:, 0], c=z) +plt.xlabel("Bias") +plt.ylabel("Cont right") + +plt.show() + + +contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0]) +contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1] + +def weights_to_pmf(weights, with_bias=1): + psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1] + return 1 / (1 + np.exp(psi)) + +colors = [type2color[pmf_type(weights_to_pmf(x))] for x in apw] + +plt.subplot(1, 3, 1) +sc = plt.scatter(apw[:, 0], apw[:, 1], c=colors) +fig, ax = plt.gcf(), plt.gca() +plt.xlabel("Cont right") +plt.ylabel("Cont left") + +annot = ax.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points", + bbox=dict(boxstyle="round", fc="w"), + arrowprops=dict(arrowstyle="->")) +annot.set_visible(False) + +def update_annot(ind): + print(ind) + pos = sc.get_offsets()[ind["ind"][0]] + annot.xy = pos + text = "{}".format(np.round(apw[ind["ind"][0]], 2)) + annot.set_text(text) + +def hover(event): + vis = annot.get_visible() + if event.inaxes == ax: + cont, ind = sc.contains(event) + if cont: + update_annot(ind) + annot.set_visible(True) + fig.canvas.draw_idle() + else: + if vis: + annot.set_visible(False) + fig.canvas.draw_idle() + +fig.canvas.mpl_connect("motion_notify_event", hover) + +plt.subplot(1, 3, 2) +plt.scatter(apw[:, 3], apw[:, 1], c=colors) +plt.xlabel("Bias") +plt.ylabel("Cont left") + +plt.subplot(1, 3, 3) +plt.scatter(apw[:, 3], apw[:, 0], c=colors) +plt.xlabel("Bias") +plt.ylabel("Cont right") + +plt.show() + + +from mpl_toolkits import mplot3d + +fig = plt.figure() +ax = plt.axes(projection='3d') +ax.scatter3D(apw[:, 0], apw[:, 1], apw[:, 3], c=colors) +plt.show() diff --git a/simplex_animation.py b/simplex_animation.py index 2dd78ee1..89dfe945 100644 --- a/simplex_animation.py +++ b/simplex_animation.py @@ -38,6 +38,7 @@ assert (test_count == 1).all() # quit() session_counter = 0 +alph = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'za', 'zb', 'zc', 'zd', 'ze', 'zf', 'zg', 'zh', 'zi', 'zj'] # do as many sessions as it takes while True: @@ -64,9 +65,9 @@ while True: type_proportions[:, i] = sts[:, temp_counter] - plotSimplex(type_proportions.T, x_offset=x_offset, y_offset=y_offset, c=np.arange(len(all_state_types)), show=False, vertexcolors=[type2color[i] for i in range(3)], vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title=None) - plt.title(session_counter) - plt.show() + plotSimplex(type_proportions.T, x_offset=x_offset, y_offset=y_offset, c=np.arange(len(all_state_types)), show=False, title="Session {}".format(session_counter), + vertexcolors=[type2color[i] for i in range(3)], vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title="simplex_{}.png".format(alph[session_counter])) + plt.close() if not_ended == 0: break diff --git a/simplex_plot.py b/simplex_plot.py index 62bf3627..635885d0 100644 --- a/simplex_plot.py +++ b/simplex_plot.py @@ -15,7 +15,7 @@ import matplotlib.patches as PA def plotSimplex(points, fig=None, vertexlabels=['1: initial flat PMFs', '2: intermediate unilateral PMFs', '3: final bilateral PMFs'], - show=False, vertexcolors=['k', 'k', 'k'], x_offset=0, y_offset=0, save_title="dur_simplex.png", **kwargs): + save_title="test.png", show=False, vertexcolors=['k', 'k', 'k'], x_offset=0, y_offset=0, **kwargs): """ Plot Nx3 points array on the 3-simplex (with optionally labeled vertices) @@ -32,17 +32,17 @@ def plotSimplex(points, fig=None, fig.gca().xaxis.set_major_locator(MT.NullLocator()) fig.gca().yaxis.set_major_locator(MT.NullLocator()) # Draw vertex labels - fig.gca().annotate(vertexlabels[0], (-0.35, -0.05), size=24, color=vertexcolors[0], annotation_clip=False) - fig.gca().annotate(vertexlabels[1], (0.6, -0.05), size=24, color=vertexcolors[1], annotation_clip=False) - fig.gca().annotate(vertexlabels[2], (0.1, np.sqrt(3) / 2 + 0.025), size=24, color=vertexcolors[2], annotation_clip=False) + # fig.gca().annotate(vertexlabels[0], (-0.35, -0.05), size=24, color=vertexcolors[0], annotation_clip=False) + # fig.gca().annotate(vertexlabels[1], (0.6, -0.05), size=24, color=vertexcolors[1], annotation_clip=False) + # fig.gca().annotate(vertexlabels[2], (0.1, np.sqrt(3) / 2 + 0.025), size=24, color=vertexcolors[2], annotation_clip=False) # Project and draw the actual points projected = projectSimplex(points / points.sum(1)[:, None]) - # print(projected) - P.scatter(projected[:, 0] + x_offset, projected[:, 1] + y_offset, s=35, **kwargs)#s=points.sum(1) * 3.5 + print(projected) + P.scatter(projected[:, 0], projected[:, 1], s=points.sum(1) * 3.5, **kwargs) # plot center with average size projected = projectSimplex(np.mean(points / points.sum(1)[:, None], axis=0).reshape(1, 3)) - P.scatter(projected[:, 0], projected[:, 1], marker='*', color='r', s=50)#np.mean(points.sum(1)) * 3.5) + P.scatter(projected[:, 0], projected[:, 1], marker='*', color='r', s=np.mean(points.sum(1)) * 3.5) # Leave some buffer around the triangle for vertex labels fig.gca().set_xlim(-0.05, 1.05) @@ -51,10 +51,11 @@ def plotSimplex(points, fig=None, P.axis('off') P.tight_layout() - if save_title: - P.savefig(save_title, bbox_inches='tight', dpi=300, transparent=True) + P.savefig("dur_simplex.png", bbox_inches='tight', dpi=300, transparent=True) if show: P.show() + else: + P.close() def projectSimplex(points): -- GitLab