From 6df25f7a42e7ce8e511a43d85fe54868347ab0a9 Mon Sep 17 00:00:00 2001
From: SebastianBruijns <>
Date: Thu, 13 Apr 2023 15:38:39 +0200
Subject: [PATCH] code updates and plotting code

---
 __pycache__/analysis_pmf.cpython-37.pyc | Bin 6551 -> 6865 bytes
 __pycache__/simplex_plot.cpython-37.pyc | Bin 3159 -> 2990 bytes
 analysis_pmf.py                         | 124 +++++++++++++-----------
 dyn_glm_chain_analysis.py               |  42 +++++---
 pmf_weight_analysis.py                  |  91 +++++++++++++++++
 simplex_animation.py                    |   7 +-
 simplex_plot.py                         |  19 ++--
 7 files changed, 198 insertions(+), 85 deletions(-)
 create mode 100644 pmf_weight_analysis.py

diff --git a/__pycache__/analysis_pmf.cpython-37.pyc b/__pycache__/analysis_pmf.cpython-37.pyc
index 34a4499ac8d78ac0139d1f5675718e1ab52cae6d..e1ab5f64794b712fd3d7753bfc894d24535085b8 100644
GIT binary patch
literal 6865
zcmdT}O>7&-72ZFRONtU@MYbqevFZHSVjM}7toScZ<iw8bIF75-4w5dMu9un-wbJr$
zc4bp+S`<*wn+u{Rx9CACf&eXwv}k*1fg%ZdN_+C5K#Msgs1N0(=%MH#m-fAx<&P98
zKS+YQBzAY+n>TOXeBYbdN9oz#-h=|5KfIM8zdo-he?y`5>4nBwxcqPKQ52#O)l)3_
zREwIWtFmoa5z!a5Vnic4F-YXcx@FSHMw~=RjF==&5@69olBAdPkyN2)Q?(MLpX_-<
zv3h8d46JC4J<OuYMlbCn50Jri4J=X=z8n3U%6^6H9aO9o*=OwuZH7Xdfly9|@&lnf
z9Lj@a#M(>tTl?q$868&WzF~+No*gMQtZ0fgWTmOTJPa|9(4kFL*re$YSemqlrq@+y
z?IjQ1Hek&MgtvoRytSmy?TzXi`vIHOrlu;;9))&BwwE6yhsZ<Z;oG8T`2cyOOFD>S
z3L1H|)xVM1RF)5c_Atp5bif_^b<qzz(b&dAn~Ip>;qdtgIkKq=&7-IxDY2@tEvsrt
zpL<m#E4)3n#am1I+}{3&7#-b$(TG^p<6BnMlK!7oHQv1{dUzfApSZ4BN31c}1zCE8
zj*&@{yRC!nG3#jUSM+f@u&$n0UV95P<JJTnr$@-tpaT2WV59R2Jxa$(c2J>XBTB2a
zPg&H4mE}n~L8fJv)(9P-6T=Glc@_NNo@<R0du*tCd|MOx`NX|Rj@^bVZy43fui)Ky
zMXCKBVm>Nj{-uaHS9=pQIia}&nkj3VPA?xLGvtW^#JMp;jtf+#H=e+~BlaM!GL=?V
z^hU%_&?zEXEJLR^4i%Kp`nYVB&XN<>3GjFb&XAMj<fcj!^aMFYPUD#Y`s3ggProN&
zU*dV$ZiCgsVf23z(f{@pJPA7c^h`HfpUBg_Hcx-u#*@;Bu=g90>d#9D1a~9A9qR72
zm`p3ycu4~$JIqGdzwaAnqt(NJ4aEPYGaL2qh0I3TM_&=M(dwHJeYE<6ZOlIP1(}sB
z#@J69vFi6iW)BEvW9*$4&s&(q5n3b0?%p@dnpFj`f%w04W=;Q{kXe)c`72`9tbPZg
zH>=(@W}p6o%t{vH?Da;xdLv|ZRxlf9cLcLwv4vS2p*7;{XZH=W@hS&wApS3%*@S;5
zWH!Ox{)(7QR9}VY6V<tG%%1ILHmo+E39HRhq5NzppN6;Tlj2P(E3Gr+9FA_NwHY|6
zpQ5Me>2(d?NDcM144&9<rLCvo^q++~Za=(9Q$xxo)^Ga887gwBc9y0I)Ee*(ow3fg
zYt}SYtYO`X79c^Dg0<^etX;`@@rLQCHpFT=6=$n16f=?M;7ylgCd>xaWcwW_rPh0n
zz?=uj9C@x)d5Su<E$vc=UXbrN>3R2eA!*skcrHRl_ps~pN~5QGeH#zYck@6lStrN~
z)-xggmqYxYB`*S^yHqvgYbA}$Tj%IGG7s}4*+0lj@VnrpVZ~Lgf!Y!s0E}-yts`m{
z>wImVo(IiEQMEvIf-E4m;je~n21MT$7TVTfrP5oCLUtvq|AJK=MYg)Cj-%*#i#)YW
z2K!y=3^8`sKJc?^ANu^%8U<p7*glgT&h=UR;k+NO!sI+V>Cx{_dc^QKsH^0!Cy8OV
zj2L#yIPWAmM4t)sSg^1>N9Nj^t{RMreGAzuek+9;lVT%b?W|ZAux=(Pl5N+`7p+`v
zL*&pMYogX@SH5pRMq=eF>RXb69hYJ^8ma0W)Q!)}oPm3ESQ&yUG>v=u>nkeQe4kz*
zsS_|0nX)dG()0ysOH$Grs#IfLwJy;dy#y=Up;{(a!`y?pAoo($laPD4ZMi4q4o*Ve
z0dJ}Qbbh7X&NZ>@>*gT0E%T+^9S!b`@FarRF&96c3sS$^oH(YN<s8Ax=+yjD?w%9-
zJ9F-we`3@wd+Rd2urAO4PECJp%ju<FRkANP@vkmdc%;B6Epa37tx~QtO0qgrp{FTs
z8n*rep8{Gd$_MX;jk~SJd+!!rJ@uRSZ@%-3c4IcQd9T&D+iH9;%Oc=q0e#~KpQtzo
zm;YV3y$*l$8OA&OL`o|}Ur|_^XjhdW)Dc5kq?L7zL{<!cZrg|`jHt=+*w2E_2q_0D
ziEU`FL!u){=$jaDnQ$opOZ*CpK^y{t!jFE4pFhf=3vXZeus-<qUuHw~hqLvp-T7Q&
z<<-1L{c@3J=3ksUo%z<cJDoO*g43+d^`cYbkqUE50go=0S<wlyD%UH;#h)rH2{s?i
zOs=fuC!IotO!|4Y>Moc3Nq4DOwDWYyxoP`>6HtHBDLLMn@A@{3oT#jE6P|XkR-rxM
z!|#Wiq7JHtHldEHeX0R1)(1Ln+IG=#OSWC#FFc7cz763}Rw;W<;Fe1h6|S*-lYo75
z0B-eR$Mft(m-zus=Zi70Qh%aSPhb}+q1}(|yzBV(qUTmX(=^Zp!dZ`A`NqrRx#L-l
zs|#H7ZZ#1PD|YVc#ml+lQ%(GnV7+&KjJbuS;FiOPpKWG92iO?CTc_MSS1tw2@q^}m
zP&YAW>V3546e}Lh;JPzYmWkf%+|<dY4&7F-oA>NnZV=FtJvFn<dU|X3iEZ7-vKo63
zoUkFdWwj;(+Kgjc<*`M#fH4V5oA8H9P_n)7;1Ll%*Mo9pu8Cwek&Wg!nrLhUO-xn>
z0s-W3i~N1I9(PJ^(eeDuRHr%JX&!5uFt8)f)Kd#fPO0FQ3N7FVyA^iHDN>|k;2Ofx
z>cX7#A^ONd*(<Z=KJaP9TcG;vAjr)FI8kqV1k&2XtZ8EC^&ZJREmkHfupFd-;fQMx
zo977<{Cdv|S94P{Qs5CGgxlqJvs}&B2Z_75NEs~!F66OfZaVuSk5=4;6_4^gupug5
zIq=*(q&2pnOFs3A7w|}FwFv31l`1^uQzu_$C9eA=pC>R0rp4lT#9_=?;~18`0D0rL
zb04v>ZEkpFhj61@frD>&*L}bcc&;t_d<Z@5umQM<9f%9>!&WExnFLt+hM0LDp^I*b
z61!4{&E_-A3~q=8Ms6(uBHUQ>+#)w>c#p1mPM&%^S__|R0pR0vy|Cc$giiyzCKWM4
z{R;Ty$pE*G?Kx}Z)c~+@Zo+&8*gZ?-TimFwx&UO!X60KxJ^?46oBnDZ1Ms<ObKR-=
zJWe50%77WJx?EqOYjUzM8V$}ab(yOUkBUis-h-jsbbZ)rut!l4ix3gYkgYfdYKu)A
zbr>WrkSvJlPGY`qyGo`poRMx}6~KuEZU9s0b+^I|WD%AqqvvXcCj?4z5<x_VD=UFe
z*mAg9lR*lMC5#f^*g|;~k3g;?Xu6a4t>~&xZCJ1lbTQoe;&|@5;Xf<5d8Sw<s~&wC
z-(>zj;clu))l^MAsb+MuFQN6T{hBcfceQ`f&<yn<bqvm9xNAwU?1S%sn!yojLNz~I
z|4t*-sAj^6KLGwT^)NgwxCTIzRMX%&37*qxE}W%h8O8ZE?S2}gBVI{}I0MlqAhH>V
z-NZOSkM|59*9S-mymWZ80kXIrl#?1@2+u2sHhf3LcMK38gVp0#>H!%!9&DrFV-TK*
coM9jnQ7sMcO|(qIxKXghb{|Mv+mf^Y0?m)68UO$Q

literal 6551
zcmd5=-ESP#6`!yDw7u&DCtk<FSsRkrnAnbEk`PFo4-yhcf=eQSMn%(E-|O|*v$L~z
z#>sk1B_#7yHPVU)L@OTFQiaqPimIw=B~(x=RkZ2@Pi=LDDroz_{tLXc=iEE{VPiK)
z1SvDxJNKS@&bjBD-??+o&c(jIv;x=P{@6#(UQ(2Q<Dhl*LFYXDf_%TC5QV6&V#=pl
z(o9{IeZ!22xwx4i8qtYCV(;o^lExY-5+@0gBq@>xjvmrW`bdUki#?mFnI`>Y&zp+b
zLwm{IRjuC74pODjM>FIRGO(cmN0!3f=-*U^6|!$YF|%a9xrgd2dqMLfpf5`XH+7mN
z%8){n*pl2mun9N@$PnE_4&2kkXdfBwj2inva^EJ-BpQR83UV}Z(Q}9#+*AeSfTVnF
zt*lrJ+6S=KqxV^BM63nv16b?OmbGM@^HH3KqqvVo;}#BM(fCpFm^lK!Lv)xt4mKSE
zi!kP$5yV4N%)B`op)*GFo2uX$rFry%_R!G{6?!A&@I6CB7;(r79Q4Lx&_5FOA1^AA
zt%sxDPB<Oi;?I!{z>$o9Uo(%I<6y%CJxa&P1bO0~4zwrC$?C7@lXQ4PeNK7pJ)j&j
zr|2<yl#C20ka-&W=Ym2f=`k`npwRIHN~^bDS<;4-m1#Ofo|M^CC3Kih4Jn}KRnUVT
zx8#I}gWd9NN~8z#cGt*|W0<>jy^?uFscV%#y`oe<1e+&?%^wJxXR1F3%8a1A1C-<D
zES+6BL8i!b5$tTtkr@${*~U}t*rJ!&oU*Fd_294cI1xSeFFM;e8AW%taZ2{Q^W?aB
z8Z@2+JD(=An<`Dy)8qu16Nt^OoC2-5f1iRKj61&F2d)E=^*;gY^-ArPt-I=EH(LJ}
zr~kBZ`o}h$l)AxwTsJD7HzKfm46sAmdo3h$ih0b_0Lc!r2K((pgRD{c5@G}PKX7DY
z!Fv(1G4|niglw#`1lGqYx3?jC>RTc!QH(QEk5{fn$etF+#@QP!oVOs05?VdZ-g;<|
zjaU8vu>t!ZII@Z0jR@HUd;2>=Hc>eb)+Z`c+mL<wTOun_OtN$JWMwQu_KZL_$xMMP
zaBM*qCA4~yxepDp$;!JB8?gU@Bby4$2-y_#zawN*l>=aXs*>D>>@(fSMrZ%C(OG{c
z8lR2E&%kN=tT;(!U38Y5L+QF&or5>^IeLaZv!UUMR99ci;~N`(qvkx+ROiV&IS(gk
zc2L<ARaxUKok!W~c{)m<Vt_Mr&b-jB7)G&Ph^mGQzyYJEa=3t%gRFt->B`H(n@;u6
zst&|T<N};@Y4#Ub|Kg_7KI7!5b<Po3b04`#F0?De0dlD`LdlhraK6oxBSjr5Sg2Fr
z`J6oCq~_hv!Zx3@=Q3n;nvE<dVEM>49IkZZK%O^GldI;XDE`-?_&-NpfQar=)sX9+
zMsApw>1A>Q*6Cp%keg7I-5gRJ)f}!a(_x7551`f&HH&$rx<Ic0<*KM!pgKVmFt(z~
z9g{I{)onPp1%)<uRH^h-K7s7&sr&}Knnbj^s*XvN-ZF+NcFP#}cd0Xk?5=a5XV*E@
zdEXiZYz5oC5g)GgP4wZqU$4UCIy=!(?@n}#;q#DfvJ#dt>>gtbyT`chG`U1CMR_bx
zSh-3rwkcizZ0P0gY>P6ck9|gE?W~yBux^I`RFmzx`2};P+7LPPjyYAWw=3T_AtSN!
z74<F2!j8+Zv3jP0|M*^%IRkBcNEw8klfymz(y9tP3-lVvo`RLgh<V-1(HoMNWF<FL
zsm6wCUZ*qkI(W82wM-VG+=I0s_xdVdLGI0L%RM>n#7T^Ih_|HQA79D0b4@Jgx-pp9
zmicnrZ4K^>=uHH+V=jI@7bJbRHBqJ;<r#vR(W&|6xO+{U@2t6V{t2mF_SQ{$Z9~5Q
zJO7lw*PTy9y`@aN+{Et&H+ihcDD}8eaMvi;86^`sQ(>kl8V!5@g{z3%it_nKQRmZE
z=i`rxuRi_jPwxEwm+j7c#Pe~h^J%N|`8<n(mPOQ!7e0{ypE8u-4QO76&lW@2!e6AE
zLiAOI<%o7m2_qUYBu7r!&`5062(E0Ch=W8`%Huo>I}&mnswC0SUWY`-Afawzz%L0u
z1;UcL$r50P2tn~@KgH|M(&*y*&wo)Hc>nM75&4Vx+Jx1)ZnN@Q!KFdDMDq(buAIsL
z@JF3Vi^W0dgwFMnUF9*KIbO))OJ!EF!wHq^erf3!3hM=)ujZy#*9+5j(I?YEfvq_!
zUNG$}mr7QFF57pkAhbgoOxvFAt_MzFf#j6G&Xe%8!gZhafQ|sZ|3GX9R70Cm52_i}
zfF8>LjVCRuWILW^)rN&8Arn|&4rPAXwL_=uP5E47g(d<1=3Z>xh6b#REtcIfs|_E{
zmzVOmR^DD@<siuWWq(bI0+$guHCh|G*49W#KA=GWiu0jUqO~3&y6gmD{&@FbR>Y{A
za-ue5yRNn5upmT717R)1IGFNlX`DhL^!u@2aO}WZavdKiO#@6e`<#Mn-F3o{de-q%
z6B<_+x#r$&;wQ9PhF0y8?^5Z>i3zQViP#i4X-$k^GX)KJlvr|#=u}fj;vPX{gYe)n
z<`kDhu7_p+N)yp*V)&cKkVRwrkR{2=fEYL*?TIhvYbo1vO12w-`|a*58wT#q_(U_g
zWjt5QE-u?%(ea8crw6(TmS>kJhGOq+1g15GnHaDGpsI;k-P{k2nYvW=LS_eHQ0rS5
z7cShj83`uXLF5^Mrq<WifFW*T?le(Ntw-WOOa7D(K4Q$+A!u%c?XD|=8Ps}Syft%t
zP7chN7(|=qck^5=)CP#Nv_u*8LI<*0A}~Ag5|8`N;;Kve9@rMXTMk{P0LhI#n37N3
zcS9cY)=H4{n&<OGK<z@Ad0Y>?fTuCHW<{Rzn9Z2I&e1Gu5j+msI{;&4S=?~THsMCu
zhu7cGuDgLD5L{adxFHNlx{eq4_9Er+b=NLXm&dE}X{@_Wi6;mRe8;1_H^k(yTzkE|
z7J{Aj9SVsA`(wF$mk**<J0vzw;snN+XRy~vYL){#%#^ef{O1`0`aDW3zYN<qV3>M5
zc6S-#%8e?T5?>D?o&gvrpGIZPfk=85E8hjG2Vrpoo($Fs=vly3i|ckZ;3*3FDFbM@
z>TrFPuJd$2Lo4!EgIAck!d08cg{J}U!Pz@b0Q-%CKEVp>3!!LjsX*9arto3mmf~@n
zK~@9e0;TQ@=li#lWRAmY(<!cjkFn4RVL83#_}oB9A;!`ZuKHXrF4{aTfRQVSaCRbY
zrV$WJ>QJlFY7sSQwS+XWSYE>y8P^dy-7W-Xd@Z0BL_mk2Gkimc*LGKfpW%7(*;1LT
zx%3>KaKUHLCe>awsV4PaHLoW#X{}%F*9>(G-p(00qo&oQcE2%5sr^gF7)p+7y&#!J
zIk>TnYe{(P?*)yTdIX*({PqH+SIxnEFRYhSXC!Z%GmaK$+QTHHV{Fi~JXoIw%jUrL
zB-#mdZ1doI2BIar*5P!6sG#RK?$t07Y4s*X3hsmA9#`?siCfYkGEyFR<Dg>zo)|g9
a$b)=b%fZ=+oH>v?2)x+Oz-Vh<M*e>`lyK1i

diff --git a/__pycache__/simplex_plot.cpython-37.pyc b/__pycache__/simplex_plot.cpython-37.pyc
index 8417f61340255c4910e14c7ab0a3f0d48670b60d..d40c9990e4c7cc9a2781a81174008f4cb3e8b339 100644
GIT binary patch
delta 1223
zcmY*Z&2Jk;6yKTs@WyL9PGZMNqtJrXgWD3dX{A;O1hfbtC8Uw4I9Ni<WM}Mbvff>0
z$Bi9X3aElx5Y-~_ak7uOfDnH`|AB(U3F%!p_r{qMZ@dW~Z+6~$>-T=|YsUNg$`7sT
z{c5#rAU*k_&ext+KlNOWt7me;$<uEcaew;jF1dQ_Dv_y;H153%<S=kG0?EUz(ObB)
zN&XQ%Dy!<)7#ous_wrq{S(+c0pAcf?&&{<huS-{rF6EWNJ|(=`FhVkRzoB2EkE$x`
zVBrFdR$QM`PBd!G;1;i)(#azCcxg)WkIh#8hv}ZoV)A)jSJnivI`5Pek@>=@O^h*_
zI8$S?#25L}l>YPHxOi8_q{$moo!^)&PYo>(TD-}Zd5f>`)l&;G)>lM+jrJ97k6qMO
zwZ@6{fziH{e`x(kAb0G$Coc$20!Y?FfmC^`O~7KqFavT8h-{nS^qD*qQkKvx=aBWj
z|4!4{v)(VyHU0S}m`S&LJkJNx&%%BhiJ=}m*edMptMCS<c!Ytxy}e}UhjFkkGB9}>
zf~WeD#aYUvh*kR%kX>-PVQ(M>kZo|b!&q#-12j&-%07__%wE92Wt{sFn!p~;3SG#A
z^84%{k$#i}OeGTN`;Wle8AQ?c`49F;yXK!c%@eRw8O91q8Tu+A!RiYZLrIEkpQQp^
z9c~zf{X!d|CiRtp%B&qfVxwfB;q0-%VzIRS<Ou8_N;1**z)2H~jty#C<*|@T3?sHD
zqO4QakoSBQsz}Idvyg>7NTNh$9e3#C95H>z9i40793C;*%Ya6=I&-N^4n&{?KQrWJ
z{;;%lQ%5o8TvDMm(jYF;yy3ir+H2-b^14|kby}t_StJfAQ+em-UrWyjaq_y|+>kb6
zUpUC5{(J1SUD-kj_OXyj26izH8&9Yw<VybAT;tviw09gO!&G+|@*xm>MnKbBxHmgA
zy)8F&z!|%(wJXTp$k)qP$(8&;`Ro;)s!f3evk9~>L&^7f68JvY@t~iMWGk<EmtS5&
zH(2O1l}3q*!o3X@26zj>DWqJ_cf57GslD>lTR!=pZ!u4B&Lw#Oi~os~GBvw`r3;1b
zzAf)yq*<-t`hLv%0t<4Aa|4@6$wr;B{-PA8EAQ#VW+Y;HUjtulX#Q{KMf>?`<@(Zf
Qcip9yUBM&Kt41sIZ!tD)XaE2J

delta 1459
zcmZWoOK%)S5T5RtogMGn>&JRyoUDl=w8pZCjT8<@D+tOZwq+?1hrW!))4l6)=0#6C
z*|ioe65?nengb*USRVrt65_<ke*mr^^(hJ`pYVZGPN^Q-u_J1xdaA4HtLmz*`K9sq
zT>0a2xu604`@BqVJ}!Sz%+q{NrwAVF@%IkO|N8P4TD~LEgZ(Pq<96Z&vCsM|aoBS$
z1q3Pr=Yj76C;z3N;zy$)8k>jOZh45Q(bQ;>mWD<qm1*T5hqO!fD%fk8T}l79TcDj3
z#yCgwhj6HQjn*4l7iS1%J41-dP#bFFDy@Ro2+sc%rl^aX+PwB9`n({t-P*9A>H@u?
z^ut_3d#$Nv8Z!$#MW;P9w7$b%BczFnFm@*|g?=M4HR#L%9@l9V4u>!Dq^>?Q%3{)4
z+Msi^c~SIuie91fhxoh+<LUFg(8e?LtwWVn8_ynU;F13p@fUFR#xPG8=v8`+whj!?
z>!3dYuQ}yK(LJrbxc@_;1+?WyqP5vyYo$fwM?}(}ja#EO?sYTz>!|niiK1sm$8&I7
zciOr%5^tYz6LzHPgkdBc;YFeC`d)mDKnvitjQ;7o2A>JPpKG3UBub92D{mwXr|%^a
zElI?WSCwj{6Ekr51IDv7Tn*ess6YQubkZSu)a{a-@OnGKTZYW{+d3(3Z%2LG3*8--
z5S_*zDGKg{N$hYIigr_?En;@P-X3QXtrK(I3)$)?5{EI=^%)mL@43#1SPrGO4>pqP
ztJ9Imgb6!vb|Y^4k?V+vOT2NP6gT&LfBjTJa)Zp(I%^7(m6%7n%oX=e{{g8ddjZKO
z02hoCBVbNQ@|-0*PRxj<GW2~fAl8x+2O#niR8C4+iMH<yqCEk4=OKd%;JiE0qqJi#
ztnmi8$<jl{1?#(!AMr$5eH+>XfuuD!mkJ7M3s3kg{l>h47So^1yYH_6Sevj=30KfG
zvXEM*aSOH3D!z(8(BDJLdJWZZ0b9sKbzpw|w;TB*w3Gv2PoXvKdFM<x>2L$~$7(+K
z1T33vi`Uc2L~~;bc2akvMB><kW_Tf()Kn2ua4!~cVIJ*E1n+?0_$Kfw@ETuKv9jfD
zCC$L{VJaqWpqcdP#IvHx(#8_Gype7eW>Gz5g=?c4Xvt*Yh!_Uk_qJDr=fXrmuI$UF
zmA!V!uC6<gx$MIym9>1y>IM?K0l~L7-UxcyZDO8v9o;7NQ&5oQj2)3pV*XsfTOV;X
z_E{0qvhC0bm~9i&cea_IketK0Gav=Q5)nsUC=&h-xJ%t<Ay+rWn`*h}$Z1>BxcKhW
RlC@xABUgeuR8Iw$@GBIUo?8F_

diff --git a/analysis_pmf.py b/analysis_pmf.py
index 6729289c..637f06cd 100644
--- a/analysis_pmf.py
+++ b/analysis_pmf.py
@@ -28,32 +28,32 @@ if __name__ == "__main__":
     state_types_interpolation = state_types_interpolation / state_types_interpolation.max() * 100
 
     fs = 18
-    plt.plot(np.linspace(0, 1, 150), state_types_interpolation[0], color=type2color[0])
-    plt.ylabel("% of type across population", size=fs)
-    plt.xlabel("Interpolated session time", size=fs)
-    plt.ylim(0, 100)
-    sns.despine()
-    plt.tight_layout()
-    plt.savefig("type hist 1")
-    plt.show()
-
-    plt.plot(np.linspace(0, 1, 150), state_types_interpolation[1], color=type2color[1])
-    plt.ylabel("% of type across population", size=fs)
-    plt.xlabel("Interpolated session time", size=fs)
-    plt.ylim(0, 100)
-    sns.despine()
-    plt.tight_layout()
-    plt.savefig("type hist 2")
-    plt.show()
-
-    plt.plot(np.linspace(0, 1, 150), state_types_interpolation[2], color=type2color[2])
-    plt.ylabel("% of type across population", size=fs)
-    plt.xlabel("Interpolated session time", size=fs)
-    plt.ylim(0, 100)
-    sns.despine()
-    plt.tight_layout()
-    plt.savefig("type hist 3")
-    plt.show()
+    # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[0], color=type2color[0])
+    # plt.ylabel("% of type across population", size=fs)
+    # plt.xlabel("Interpolated session time", size=fs)
+    # plt.ylim(0, 100)
+    # sns.despine()
+    # plt.tight_layout()
+    # plt.savefig("type hist 1")
+    # plt.show()
+    #
+    # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[1], color=type2color[1])
+    # plt.ylabel("% of type across population", size=fs)
+    # plt.xlabel("Interpolated session time", size=fs)
+    # plt.ylim(0, 100)
+    # sns.despine()
+    # plt.tight_layout()
+    # plt.savefig("type hist 2")
+    # plt.show()
+    #
+    # plt.plot(np.linspace(0, 1, 150), state_types_interpolation[2], color=type2color[2])
+    # plt.ylabel("% of type across population", size=fs)
+    # plt.xlabel("Interpolated session time", size=fs)
+    # plt.ylim(0, 100)
+    # sns.despine()
+    # plt.tight_layout()
+    # plt.savefig("type hist 3")
+    # plt.show()
 
     all_first_pmfs_typeless = pickle.load(open("all_first_pmfs_typeless.p", 'rb'))
     all_pmfs = pickle.load(open("all_pmfs.p", 'rb'))
@@ -161,37 +161,41 @@ if __name__ == "__main__":
 
     lw = 4
     # Simplex example pmfs
-    # state_num = 7
-    # defined_points, pmf = all_first_pmfs_typeless['NYU-06'][state_num][0], all_first_pmfs_typeless['NYU-06'][state_num][1]
-    # plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
-    # state_num = 6
-    # defined_points, pmf = all_first_pmfs_typeless['CSHL061'][state_num][0], all_first_pmfs_typeless['CSHL061'][state_num][1]
-    # plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
-    #
-    # plt.ylim(0, 1)
-    # plt.xlim(0, 10)
-    # plt.yticks([])
-    # plt.xticks([])
-    # sns.despine()
-    # plt.tight_layout()
-    # plt.savefig("example type 1")
-    # plt.show()
-    #
-    # state_num = 1
-    # defined_points, pmf = all_first_pmfs_typeless['CSHL_018'][state_num][0], all_first_pmfs_typeless['CSHL_018'][state_num][1]
-    # plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
-    # state_num = 3
-    # defined_points, pmf = all_first_pmfs_typeless['ibl_witten_14'][state_num][0], all_first_pmfs_typeless['ibl_witten_14'][state_num][1]
-    # plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
-    #
-    # plt.ylim(0, 1)
-    # plt.xlim(0, 10)
-    # plt.yticks([])
-    # plt.xticks([])
-    # sns.despine()
-    # plt.tight_layout()
-    # plt.savefig("example type 2")
-    # plt.show()
+    state_num = 7
+    defined_points, pmf = all_first_pmfs_typeless['NYU-06'][state_num][0], all_first_pmfs_typeless['NYU-06'][state_num][1]
+    plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
+    state_num = 6
+    defined_points, pmf = all_first_pmfs_typeless['CSHL061'][state_num][0], all_first_pmfs_typeless['CSHL061'][state_num][1]
+    plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
+
+    plt.ylim(0, 1)
+    plt.xlim(0, 10)
+    plt.ylabel("P(rightwards)", size=32)
+    plt.xlabel("Contrast", size=32)
+    plt.yticks([0, 1], size=27)
+    plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27)
+    sns.despine()
+    plt.tight_layout()
+    plt.savefig("example type 1")
+    plt.show()
+
+    state_num = 1
+    defined_points, pmf = all_first_pmfs_typeless['CSHL_018'][state_num][0], all_first_pmfs_typeless['CSHL_018'][state_num][1]
+    plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
+    state_num = 3
+    defined_points, pmf = all_first_pmfs_typeless['ibl_witten_14'][state_num][0], all_first_pmfs_typeless['ibl_witten_14'][state_num][1]
+    plt.plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)], lw=lw)
+
+    plt.ylim(0, 1)
+    plt.xlim(0, 10)
+    plt.ylabel("P(rightwards)", size=32)
+    plt.xlabel("Contrast", size=32)
+    plt.yticks([0, 1], size=27)
+    plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27)
+    sns.despine()
+    plt.tight_layout()
+    plt.savefig("example type 2")
+    plt.show()
 
     state_num = 4
     defined_points, pmf = all_first_pmfs_typeless['ibl_witten_17'][state_num][0], all_first_pmfs_typeless['ibl_witten_17'][state_num][1]
@@ -199,8 +203,10 @@ if __name__ == "__main__":
 
     plt.ylim(0, 1)
     plt.xlim(0, 10)
-    plt.yticks([])
-    plt.xticks([])
+    plt.ylabel("P(rightwards)", size=32)
+    plt.xlabel("Contrast", size=32)
+    plt.yticks([0, 1], size=27)
+    plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=27)
     sns.despine()
     plt.tight_layout()
     plt.savefig("example type 3")
diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py
index f99904a9..c61d4c7c 100644
--- a/dyn_glm_chain_analysis.py
+++ b/dyn_glm_chain_analysis.py
@@ -526,7 +526,7 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur
                 # print("fix this by taking the whole array, multiply by n, subtract n, divide by n-1")
                 # input()
 
-            label = "State {}".format(state) if np.sum(relevant_trials) > 0.02 * len(test.results[0].models[0].stateseqs[seq_num]) else None
+            label = "State {}".format(len(state_sets) - test.state_mapping[state]) if np.sum(relevant_trials) > 0.02 * len(test.results[0].models[0].stateseqs[seq_num]) else None
 
             # state_c_n_a = c_n_a[relevant_trials - trial_counter]
             # print(state)
@@ -571,7 +571,7 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur
         plt.xlabel('Trial', size=28)
         sns.despine()
 
-        # plt.xlim(left=250, right=450)
+        plt.xlim(left=250, right=450)
         plt.legend(frameon=False, fontsize=22, bbox_to_anchor=(0.8, 0.5))
         plt.tight_layout()
         if save:
@@ -854,22 +854,25 @@ def state_set_and_plot(test, mode_prefix, subject, fit_type):
 
 
 def state_pmfs(test, trials, indices):
-    def func_init(): return {'pmfs': [], 'session_js': []}
+    def func_init(): return {'pmfs': [], 'session_js': [], 'pmf_weights': []}
 
     def first_for(test, results):
         results['pmf'] = np.zeros(test.results[0].n_contrasts)
+        results['pmf_weight'] = np.zeros(4)
 
     def second_for(m, j, session_trials, trial_counter, results):
         states, counts = np.unique(m.stateseqs[j][session_trials - trial_counter], return_counts=True)
         for sub_state, c in zip(states, counts):
             results['pmf'] += weights_to_pmf(m.obs_distns[sub_state].weights[j]) * c / session_trials.shape[0]
+            results['pmf_weight'] += m.obs_distns[sub_state].weights[j] * c / session_trials.shape[0]
 
     def end_first_for(results, indices, j, **kwargs):
         results['pmfs'].append(results['pmf'] / len(indices))
+        results['pmf_weights'].append(results['pmf_weight'] / len(indices))
         results['session_js'].append(j)
 
     results = control_flow(test, indices, trials, func_init, first_for, second_for, end_first_for)
-    return results['session_js'], results['pmfs']
+    return results['session_js'], results['pmfs'], results['pmf_weights']
 
 
 def state_weights(test, trials, indices):
@@ -1138,6 +1141,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
         ax0.annotate('Bias', (test.results[0].infos['bias_start'] + 1 - 0.5, 0.68), fontsize=22)
 
     all_pmfs = []
+    all_pmf_weights = []
     cmaps = ['Greys', 'Purples', 'Blues', 'Greens', 'Oranges', 'Reds', 'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu']
     np.random.seed(8)
     np.random.shuffle(cmaps)
@@ -1147,7 +1151,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     for state, trials in enumerate(state_sets):
         if separate_pmf:
             n_trials = len(trials)
-            session_js, pmfs = state_pmfs(test, trials, indices)
+            session_js, pmfs, _ = state_pmfs(test, trials, indices)
         else:
             pmfs = np.zeros((len(indices), test.results[0].n_contrasts))
             n_trials = len(trials)
@@ -1173,9 +1177,10 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
 
         if separate_pmf:
             n_trials = len(trials)
-            session_js, pmfs = state_pmfs(test, trials, indices)
+            session_js, pmfs, pmf_weights = state_pmfs(test, trials, indices)
         else:
             pmfs = np.zeros((len(indices), test.results[0].n_contrasts))
+            pmf_weights = np.zeros((len(indices), test.results[0].obs_distns[0].weights.shape[0]))
             n_trials = len(trials)
             counter = 0
             for i, m in enumerate([item for sublist in test.results for item in sublist.models]):
@@ -1188,6 +1193,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
                         states, counts = np.unique(state_seq[session_trials - trial_counter], return_counts=True)
                         for sub_state, c in zip(states, counts):
                             pmfs[counter] += weights_to_pmf(m.obs_distns[sub_state].weights[j]) * c / n_trials
+                            pmf_weights[counter] += m.obs_distns[sub_state].weights[j] * c / n_trials
                     trial_counter += len(state_seq)
                 counter += 1
 
@@ -1218,7 +1224,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
                 if not test.state_mapping[state] in dont_plot:
                     ax1.fill_between([points[k], points[k+1]],
                                      test.state_mapping[state] - 0.5, [test.state_mapping[state] + interpolation[k] - 0.5, test.state_mapping[state] + interpolation[k+1] - 0.5], color=cmap(0.3 + 0.7 * k / n_points))
-        ax1.annotate(test.state_mapping[state] + 1, (test.results[0].n_sessions + 0.1, test.state_mapping[state] - 0.15), fontsize=22, annotation_clip=False)
+        ax1.annotate(len(state_sets) - test.state_mapping[state], (test.results[0].n_sessions + 0.1, test.state_mapping[state] - 0.15), fontsize=22, annotation_clip=False)
 
         if test.results[0].name.startswith('GLM_Sim_'):
             ax1.plot(range(1, 1 + test.results[0].n_sessions), truth['state_map'][test.state_mapping[state]] + truth['state_posterior'][:, state] - 0.5, color='r')
@@ -1235,11 +1241,12 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
         #         defined_points = np.zeros(test.results[0].n_contrasts, dtype=bool)
         #         defined_points[[0, 1, -2, -1]] = True
         if separate_pmf:
-            for j, pmf in zip(session_js, pmfs):
+            for j, pmf, pmf_weight in zip(session_js, pmfs, pmf_weights):
                 if not test.state_mapping[state] in dont_plot:
                     ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + test.state_mapping[state], color=cmap(0.2 + 0.8 * j / test.results[0].n_sessions))
                     ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + test.state_mapping[state], ls='', ms=7, marker='*', color=cmap(j / test.results[0].n_sessions))
             all_pmfs.append((defined_points, pmfs))
+            all_pmf_weights += pmf_weights
         else:
             temp = np.percentile(pmfs, [2.5, 97.5], axis=0)
             if not test.state_mapping[state] in dont_plot:
@@ -1313,7 +1320,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     ax2.set_title('Psychometric\nfunction', size=16)
     ax1.set_ylabel('Proportion of trials', size=28, labelpad=-20)
     ax0.set_ylabel('% correct', size=18)
-    ax2.set_ylabel('Probability', size=26, labelpad=-20)
+    ax2.set_ylabel('P(rightwards answer)', size=26, labelpad=-20)
     ax1.set_xlabel('Session', size=28)
     ax2.set_xlabel('Contrast', size=26)
     ax1.set_xlim(left=1, right=test.results[0].n_sessions)
@@ -1352,7 +1359,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     else:
         plt.close()
 
-    return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage, states_per_type
+    return states_by_session, all_pmfs, all_pmf_weights, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage, states_per_type
 
 def compare_pmfs(test, state_sets, indices, states2compare, states_by_session, all_pmfs, title=""):
     """
@@ -1670,7 +1677,6 @@ if __name__ == "__main__":
     # subjects = ['SWC_021', 'ibl_witten_15', 'ibl_witten_13', 'KS003', 'ibl_witten_19', 'SWC_022', 'CSH_ZAD_017']
     # subjects = ['KS014']
 
-    # meh pmfs: KS021
     print(subjects)
     fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0]
     dur = 'yes'
@@ -1709,6 +1715,7 @@ if __name__ == "__main__":
     regressions = []
     regression_diffs = []
     all_bias_flips = []
+    all_pmf_weights = []
 
     temp_counter = 0
 
@@ -1744,15 +1751,20 @@ if __name__ == "__main__":
             # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 1', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(7)), plot_until=2)
             # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 2', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(6)), plot_until=7)
             # states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 3', show=1, separate_pmf=1, type_coloring=True, dont_plot=list(range(4)), plot_until=13)
-            states, pmfs, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True)
-            all_state_types.append(state_types)
+            states, pmfs, pmf_weights, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True)
+            all_pmf_weights += pmf_weights
             continue
+
+            consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
+            consistencies /= consistencies[0, 0]
+            contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=False)
+
+            all_state_types.append(state_types)
             # state_types_interpolation[0] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[0])
             # state_types_interpolation[1] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[1])
             # state_types_interpolation[2] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[2])
             # temp_counter += 1
 
-            continue
 
             # b_flips = bias_flips(states, pmfs, durs)
             # all_bias_flips.append(b_flips)
@@ -1915,6 +1927,8 @@ if __name__ == "__main__":
     # pickle.dump(regression_diffs, open("regression_diffs.p", 'wb'))
     # pickle.dump(all_bias_flips, open("all_bias_flips.p", 'wb'))
     # pickle.dump(all_state_types, open("all_state_types.p", 'wb'))
+    pickle.dump(all_pmf_weights, open("all_pmf_weights.p", 'wb'))
+    quit()
 
 
     if True:
diff --git a/pmf_weight_analysis.py b/pmf_weight_analysis.py
new file mode 100644
index 00000000..00611bcc
--- /dev/null
+++ b/pmf_weight_analysis.py
@@ -0,0 +1,91 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import pickle
+from scipy.stats import gaussian_kde
+from analysis_pmf import pmf_type, type2color
+
+# all pmf weights
+apw = np.array(pickle.load(open("all_pmf_weights.p", 'rb')))
+
+xy = np.vstack([apw[:, i] for i in range(4)])
+z = gaussian_kde(xy)(xy)
+
+plt.subplot(1, 3, 1)
+plt.scatter(apw[:, 0], apw[:, 1], c=z)
+plt.xlabel("Cont right")
+plt.ylabel("Cont left")
+
+plt.subplot(1, 3, 2)
+plt.scatter(apw[:, 3], apw[:, 1], c=z)
+plt.xlabel("Bias")
+plt.ylabel("Cont left")
+
+plt.subplot(1, 3, 3)
+plt.scatter(apw[:, 3], apw[:, 0], c=z)
+plt.xlabel("Bias")
+plt.ylabel("Cont right")
+
+plt.show()
+
+
+contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
+contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
+
+def weights_to_pmf(weights, with_bias=1):
+    psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1]
+    return 1 / (1 + np.exp(psi))
+
+colors = [type2color[pmf_type(weights_to_pmf(x))] for x in apw]
+
+plt.subplot(1, 3, 1)
+sc = plt.scatter(apw[:, 0], apw[:, 1], c=colors)
+fig, ax = plt.gcf(), plt.gca()
+plt.xlabel("Cont right")
+plt.ylabel("Cont left")
+
+annot = ax.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
+                    bbox=dict(boxstyle="round", fc="w"),
+                    arrowprops=dict(arrowstyle="->"))
+annot.set_visible(False)
+
+def update_annot(ind):
+    print(ind)
+    pos = sc.get_offsets()[ind["ind"][0]]
+    annot.xy = pos
+    text = "{}".format(np.round(apw[ind["ind"][0]], 2))
+    annot.set_text(text)
+
+def hover(event):
+    vis = annot.get_visible()
+    if event.inaxes == ax:
+        cont, ind = sc.contains(event)
+        if cont:
+            update_annot(ind)
+            annot.set_visible(True)
+            fig.canvas.draw_idle()
+        else:
+            if vis:
+                annot.set_visible(False)
+                fig.canvas.draw_idle()
+
+fig.canvas.mpl_connect("motion_notify_event", hover)
+
+plt.subplot(1, 3, 2)
+plt.scatter(apw[:, 3], apw[:, 1], c=colors)
+plt.xlabel("Bias")
+plt.ylabel("Cont left")
+
+plt.subplot(1, 3, 3)
+plt.scatter(apw[:, 3], apw[:, 0], c=colors)
+plt.xlabel("Bias")
+plt.ylabel("Cont right")
+
+plt.show()
+
+
+from mpl_toolkits import mplot3d
+
+fig = plt.figure()
+ax = plt.axes(projection='3d')
+ax.scatter3D(apw[:, 0], apw[:, 1], apw[:, 3], c=colors)
+plt.show()
diff --git a/simplex_animation.py b/simplex_animation.py
index 2dd78ee1..89dfe945 100644
--- a/simplex_animation.py
+++ b/simplex_animation.py
@@ -38,6 +38,7 @@ assert (test_count == 1).all()
 # quit()
 
 session_counter = 0
+alph = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'za', 'zb', 'zc', 'zd', 'ze', 'zf', 'zg', 'zh', 'zi', 'zj']
 # do as many sessions as it takes
 while True:
 
@@ -64,9 +65,9 @@ while True:
 
         type_proportions[:, i] = sts[:, temp_counter]
 
-    plotSimplex(type_proportions.T, x_offset=x_offset, y_offset=y_offset, c=np.arange(len(all_state_types)), show=False, vertexcolors=[type2color[i] for i in range(3)], vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title=None)
-    plt.title(session_counter)
-    plt.show()
+    plotSimplex(type_proportions.T, x_offset=x_offset, y_offset=y_offset, c=np.arange(len(all_state_types)), show=False, title="Session {}".format(session_counter),
+                vertexcolors=[type2color[i] for i in range(3)], vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title="simplex_{}.png".format(alph[session_counter]))
+    plt.close()
 
     if not_ended == 0:
         break
diff --git a/simplex_plot.py b/simplex_plot.py
index 62bf3627..635885d0 100644
--- a/simplex_plot.py
+++ b/simplex_plot.py
@@ -15,7 +15,7 @@ import matplotlib.patches as PA
 
 def plotSimplex(points, fig=None,
                 vertexlabels=['1: initial flat PMFs', '2: intermediate unilateral PMFs', '3: final bilateral PMFs'],
-                show=False, vertexcolors=['k', 'k', 'k'], x_offset=0, y_offset=0, save_title="dur_simplex.png", **kwargs):
+                save_title="test.png", show=False, vertexcolors=['k', 'k', 'k'], x_offset=0, y_offset=0, **kwargs):
     """
     Plot Nx3 points array on the 3-simplex
     (with optionally labeled vertices)
@@ -32,17 +32,17 @@ def plotSimplex(points, fig=None,
     fig.gca().xaxis.set_major_locator(MT.NullLocator())
     fig.gca().yaxis.set_major_locator(MT.NullLocator())
     # Draw vertex labels
-    fig.gca().annotate(vertexlabels[0], (-0.35, -0.05), size=24, color=vertexcolors[0], annotation_clip=False)
-    fig.gca().annotate(vertexlabels[1], (0.6, -0.05), size=24, color=vertexcolors[1], annotation_clip=False)
-    fig.gca().annotate(vertexlabels[2], (0.1, np.sqrt(3) / 2 + 0.025), size=24, color=vertexcolors[2], annotation_clip=False)
+    # fig.gca().annotate(vertexlabels[0], (-0.35, -0.05), size=24, color=vertexcolors[0], annotation_clip=False)
+    # fig.gca().annotate(vertexlabels[1], (0.6, -0.05), size=24, color=vertexcolors[1], annotation_clip=False)
+    # fig.gca().annotate(vertexlabels[2], (0.1, np.sqrt(3) / 2 + 0.025), size=24, color=vertexcolors[2], annotation_clip=False)
     # Project and draw the actual points
     projected = projectSimplex(points / points.sum(1)[:, None])
-    # print(projected)
-    P.scatter(projected[:, 0] + x_offset, projected[:, 1] + y_offset, s=35, **kwargs)#s=points.sum(1) * 3.5
+    print(projected)
+    P.scatter(projected[:, 0], projected[:, 1], s=points.sum(1) * 3.5, **kwargs)
 
     # plot center with average size
     projected = projectSimplex(np.mean(points / points.sum(1)[:, None], axis=0).reshape(1, 3))
-    P.scatter(projected[:, 0], projected[:, 1], marker='*', color='r', s=50)#np.mean(points.sum(1)) * 3.5)
+    P.scatter(projected[:, 0], projected[:, 1], marker='*', color='r', s=np.mean(points.sum(1)) * 3.5)
 
     # Leave some buffer around the triangle for vertex labels
     fig.gca().set_xlim(-0.05, 1.05)
@@ -51,10 +51,11 @@ def plotSimplex(points, fig=None,
     P.axis('off')
 
     P.tight_layout()
-    if save_title:
-        P.savefig(save_title, bbox_inches='tight', dpi=300, transparent=True)
+    P.savefig("dur_simplex.png", bbox_inches='tight', dpi=300, transparent=True)
     if show:
         P.show()
+    else:
+        P.close()
 
 
 def projectSimplex(points):
-- 
GitLab