diff --git a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc
index a82afa9c345c7e3df3fd0494a1e8695fda475388..b4b56a55e7c7e3ff5cceb8da95a87d51c338cbac 100644
Binary files a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc and b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc differ
diff --git a/__pycache__/mcmc_chain_analysis.cpython-37.pyc b/__pycache__/mcmc_chain_analysis.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bdea663ae5d36ec0de32219c100df06016b9244f
Binary files /dev/null and b/__pycache__/mcmc_chain_analysis.cpython-37.pyc differ
diff --git a/__pycache__/process_many_chains.cpython-37.pyc b/__pycache__/process_many_chains.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..44eddf35ca7b5b24cc52c61abb188e42c155fc62
Binary files /dev/null and b/__pycache__/process_many_chains.cpython-37.pyc differ
diff --git a/__pycache__/simplex_plot.cpython-37.pyc b/__pycache__/simplex_plot.cpython-37.pyc
index 1e8de26b8b6d51af0763e873273c06afd448dab6..1c26491f573cb6355e901c949147f3d921c87a30 100644
Binary files a/__pycache__/simplex_plot.cpython-37.pyc and b/__pycache__/simplex_plot.cpython-37.pyc differ
diff --git a/canonical_infos.json b/canonical_infos.json
index 6b9451cec02c20319870b92c795d47ba50ee8cbf..c2ca57ff4832606bb4d5af5dad36768fc4162753 100644
--- a/canonical_infos.json
+++ b/canonical_infos.json
@@ -1 +1 @@
-{"SWC_023": {"seeds": ["302", "312", "304", "300", "315", "311", "308", "305", "303", "309", "306", "313", "307", "314", "301", "310"], "fit_nums": ["994", "913", "681", "816", "972", "790", "142", "230", "696", "537", "975", "773", "918", "677", "742", "745"], "chain_num": 4}, "SWC_021": {"seeds": ["415", "403", "412", "407", "409", "408", "405", "404", "410", "414", "401", "413", "402", "400", "406", "411"], "fit_nums": ["773", "615", "107", "583", "564", "354", "142", "184", "549", "185", "924", "907", "105", "531", "9", "812"], "chain_num": 9}, "ibl_witten_15": {"seeds": ["409", "410", "401", "415", "414", "403", "411", "404", "402", "405", "400", "412", "408", "407", "406", "413"], "fit_nums": ["411", "344", "496", "600", "716", "18", "527", "467", "898", "334", "309", "326", "133", "823", "740", "253"], "chain_num": 9}, "ibl_witten_13": {"seeds": ["302", "312", "313", "306", "315", "307", "311", "314", "309", "301", "308", "300", "304", "310", "303", "305"], "fit_nums": ["897", "765", "433", "641", "967", "599", "984", "259", "853", "385", "887", "619", "434", "964", "483", "891"], "chain_num": 4}, "KS016": {"seeds": ["315", "301", "309", "313", "302", "307", "303", "308", "311", "312", "314", "306", "310", "300", "305", "304"], "fit_nums": ["99", "57", "585", "32", "501", "558", "243", "413", "59", "757", "463", "172", "524", "957", "909", "292"], "chain_num": 4}, "KS003": {"seeds": ["404", "407", "413", "403", "414", "405", "400", "401", "402", "410", "415", "408", "411", "409", "406", "412"], "fit_nums": ["846", "256", "845", "945", "293", "406", "420", "109", "690", "421", "54", "866", "784", "81", "997", "665"], "chain_num": 9}, "ibl_witten_19": {"seeds": ["315", "311", "307", "314", "308", "300", "305", "301", "313", "304", "302", "310", "306", "312", "309", "303"], "fit_nums": ["179", "951", "613", "6", "623", "382", "458", "504", "406", "554", "5", "631", "746", "817", "265", "328"], "chain_num": 4}, "SWC_022": {"seeds": ["411", "403", "414", "409", "407", "412", "410", "413", "415", "404", "405", "400", "402", "401", "408", "406"], "fit_nums": ["408", "884", "62", "962", "744", "854", "635", "70", "320", "952", "8", "67", "231", "381", "536", "962"], "chain_num": 9}, "KS022": {"seeds": ["315", "300", "314", "301", "303", "302", "306", "308", "305", "310", "313", "312", "304", "307", "311", "309"], "fit_nums": ["899", "681", "37", "957", "629", "637", "375", "980", "810", "51", "759", "664", "420", "127", "259", "555"], "chain_num": 4}, "CSH_ZAD_017": {"seeds": ["401", "409", "405", "403", "415", "404", "402", "411", "410", "414", "408", "406", "413", "412", "400", "407"], "fit_nums": ["883", "803", "637", "806", "356", "804", "662", "654", "684", "350", "947", "460", "569", "976", "103", "713"], "chain_num": 9}, "CSH_ZAD_025": {"seeds": ["303", "311", "307", "312", "313", "314", "308", "315", "305", "306", "304", "302", "309", "310", "301", "300"], "fit_nums": ["581", "148", "252", "236", "581", "838", "206", "756", "449", "288", "756", "593", "733", "633", "418", "563"], "chain_num": 4}, "ibl_witten_17": {"seeds": ["406", "415", "408", "413", "402", "405", "409", "400", "414", "401", "412", "407", "404", "410", "403", "411"], "fit_nums": ["827", "797", "496", "6", "444", "823", "384", "873", "634", "27", "811", "142", "207", "322", "756", "275"], "chain_num": 9}, "ibl_witten_18": {"seeds": ["311", "310", "303", "314", "302", "309", "305", "307", "312", "300", "308", "306", "315", "313", "304", "301"], "fit_nums": ["236", "26", "838", "762", "826", "409", "496", "944", "280", "704", "930", "419", "637", "896", "876", "297"], "chain_num": 4}, "CSHL_018": {"seeds": ["302", "310", "306", "300", "314", "307", "309", "313", "311", "308", "304", "301", "312", "303", "305", "315"], "fit_nums": ["843", "817", "920", "900", "226", "36", "472", "676", "933", "453", "116", "263", "269", "897", "568", "438"], "chain_num": 4}, "GLM_Sim_06": {"seeds": ["313", "309", "302", "303", "305", "314", "300", "315", "311", "306", "304", "310", "301", "312", "308", "307"], "fit_nums": ["9", "786", "286", "280", "72", "587", "619", "708", "360", "619", "311", "189", "60", "708", "939", "733"], "chain_num": 2}, "ZM_1897": {"seeds": ["304", "308", "305", "311", "315", "314", "307", "306", "300", "303", "313", "310", "301", "312", "302", "309"], "fit_nums": ["549", "96", "368", "509", "424", "897", "287", "426", "968", "93", "725", "513", "837", "581", "989", "374"], "chain_num": 4}, "CSHL_020": {"seeds": ["305", "309", "313", "302", "314", "310", "300", "307", "315", "306", "312", "304", "311", "301", "303", "308"], "fit_nums": ["222", "306", "243", "229", "584", "471", "894", "238", "986", "660", "494", "657", "896", "459", "100", "283"], "chain_num": 4}, "CSHL054": {"seeds": ["401", "415", "409", "410", "414", "413", "407", "405", "406", "408", "411", "400", "412", "402", "403", "404"], "fit_nums": ["901", "734", "609", "459", "574", "793", "978", "66", "954", "906", "954", "111", "292", "850", "266", "967"], "chain_num": 9}, "CSHL_014": {"seeds": ["305", "311", "309", "300", "313", "310", "307", "306", "304", "312", "308", "302", "314", "303", "301", "315"], "fit_nums": ["371", "550", "166", "24", "705", "385", "870", "884", "831", "546", "404", "722", "287", "564", "613", "783"], "chain_num": 4}, "CSHL062": {"seeds": ["307", "313", "310", "303", "306", "312", "308", "305", "311", "314", "304", "302", "300", "301", "315", "309"], "fit_nums": ["846", "371", "94", "888", "499", "229", "546", "432", "71", "989", "986", "91", "935", "314", "975", "481"], "chain_num": 4}, "CSH_ZAD_001": {"seeds": ["313", "309", "311", "312", "305", "310", "315", "300", "314", "304", "301", "302", "308", "303", "306", "307"], "fit_nums": ["468", "343", "314", "544", "38", "120", "916", "170", "305", "569", "502", "496", "452", "336", "559", "572"], "chain_num": 4}, "NYU-06": {"seeds": ["314", "309", "306", "305", "312", "303", "307", "304", "300", "302", "310", "301", "315", "308", "313", "311"], "fit_nums": ["950", "862", "782", "718", "427", "645", "827", "612", "821", "834", "595", "929", "679", "668", "648", "869"], "chain_num": 4}, "KS019": {"seeds": ["404", "401", "411", "408", "400", "403", "410", "413", "402", "407", "415", "409", "406", "414", "412", "405"], "fit_nums": ["682", "4", "264", "200", "250", "267", "737", "703", "132", "855", "922", "686", "85", "176", "54", "366"], "chain_num": 9}, "CSHL049": {"seeds": ["411", "402", "414", "408", "409", "410", "413", "407", "406", "401", "404", "405", "403", "415", "400", "412"], "fit_nums": ["104", "553", "360", "824", "749", "519", "347", "228", "863", "671", "140", "883", "701", "445", "627", "898"], "chain_num": 9}, "ibl_witten_14": {"seeds": ["310", "311", "304", "306", "300", "302", "314", "313", "303", "308", "301", "309", "305", "315", "312", "307"], "fit_nums": ["563", "120", "85", "712", "277", "871", "183", "661", "505", "598", "210", "89", "310", "638", "564", "998"], "chain_num": 4}, "KS014": {"seeds": ["301", "310", "302", "312", "313", "308", "307", "303", "305", "300", "314", "306", "311", "309", "304", "315"], "fit_nums": ["668", "32", "801", "193", "269", "296", "74", "24", "270", "916", "21", "250", "342", "451", "517", "293"], "chain_num": 4}, "CSHL059": {"seeds": ["306", "309", "300", "304", "314", "303", "315", "311", "313", "305", "301", "307", "302", "312", "310", "308"], "fit_nums": ["821", "963", "481", "999", "986", "45", "551", "605", "701", "201", "629", "261", "972", "407", "165", "9"], "chain_num": 4}, "GLM_Sim_13": {"seeds": ["310", "303", "308", "306", "300", "312", "301", "313", "305", "311", "315", "304", "314", "309", "307", "302"], "fit_nums": ["982", "103", "742", "524", "614", "370", "926", "456", "133", "143", "302", "80", "395", "549", "579", "944"], "chain_num": 2}, "CSHL_007": {"seeds": ["314", "303", "308", "313", "301", "300", "302", "305", "315", "306", "310", "309", "311", "304", "307", "312"], "fit_nums": ["462", "703", "345", "286", "480", "313", "986", "165", "201", "102", "322", "894", "960", "438", "330", "169"], "chain_num": 4}, "CSH_ZAD_011": {"seeds": ["314", "311", "303", "300", "305", "310", "306", "301", "302", "315", "304", "309", "308", "312", "313", "307"], "fit_nums": ["320", "385", "984", "897", "315", "120", "320", "945", "475", "403", "210", "412", "695", "564", "664", "411"], "chain_num": 4}, "KS021": {"seeds": ["309", "312", "304", "310", "303", "311", "314", "302", "305", "301", "306", "300", "308", "315", "313", "307"], "fit_nums": ["874", "943", "925", "587", "55", "136", "549", "528", "349", "211", "401", "84", "225", "545", "153", "382"], "chain_num": 4}, "GLM_Sim_15": {"seeds": ["303", "312", "305", "308", "309", "302", "301", "310", "313", "315", "311", "314", "307", "306", "304", "300"], "fit_nums": ["769", "930", "328", "847", "899", "714", "144", "518", "521", "873", "914", "359", "242", "343", "45", "364"], "chain_num": 2}, "CSHL_015": {"seeds": ["301", "302", "307", "310", "309", "311", "304", "312", "300", "308", "313", "305", "314", "315", "306", "303"], "fit_nums": ["717", "705", "357", "539", "604", "971", "669", "76", "45", "413", "510", "122", "190", "821", "368", "472"], "chain_num": 4}, "ibl_witten_16": {"seeds": ["304", "313", "309", "314", "312", "307", "305", "301", "306", "310", "300", "315", "308", "311", "303", "302"], "fit_nums": ["392", "515", "696", "270", "7", "583", "880", "674", "23", "576", "579", "695", "149", "854", "184", "875"], "chain_num": 4}, "KS015": {"seeds": ["315", "305", "309", "303", "314", "310", "311", "312", "313", "300", "307", "308", "304", "301", "302", "306"], "fit_nums": ["257", "396", "387", "435", "133", "164", "403", "8", "891", "650", "111", "557", "473", "229", "842", "196"], "chain_num": 4}, "GLM_Sim_12": {"seeds": ["304", "312", "306", "303", "310", "302", "300", "305", "308", "313", "307", "311", "315", "301", "314", "309"], "fit_nums": ["971", "550", "255", "195", "952", "486", "841", "535", "559", "37", "654", "213", "864", "506", "732", "550"], "chain_num": 2}, "GLM_Sim_11": {"seeds": ["300", "312", "310", "315", "302", "313", "314", "311", "308", "303", "309", "307", "306", "304", "301", "305"], "fit_nums": ["477", "411", "34", "893", "195", "293", "603", "5", "887", "281", "956", "73", "346", "640", "532", "688"], "chain_num": 2}, "GLM_Sim_10": {"seeds": ["301", "300", "306", "305", "307", "309", "312", "314", "311", "315", "304", "313", "303", "308", "302", "310"], "fit_nums": ["391", "97", "897", "631", "239", "652", "19", "448", "807", "35", "972", "469", "280", "562", "42", "706"], "chain_num": 2}, "CSH_ZAD_026": {"seeds": ["312", "313", "308", "310", "303", "307", "302", "305", "300", "315", "306", "301", "311", "304", "314", "309"], "fit_nums": ["699", "87", "537", "628", "797", "511", "459", "770", "969", "240", "504", "948", "295", "506", "25", "378"], "chain_num": 4}, "KS023": {"seeds": ["304", "313", "306", "309", "300", "314", "302", "310", "303", "315", "307", "308", "301", "311", "305", "312"], "fit_nums": ["698", "845", "319", "734", "908", "507", "45", "499", "175", "108", "419", "443", "116", "779", "159", "231"], "chain_num": 4}, "GLM_Sim_05": {"seeds": ["301", "315", "300", "302", "305", "304", "313", "314", "311", "309", "306", "307", "308", "310", "303", "312"], "fit_nums": ["425", "231", "701", "375", "343", "902", "623", "125", "921", "637", "393", "964", "678", "930", "796", "42"], "chain_num": 2}, "CSHL061": {"seeds": ["305", "315", "304", "303", "309", "310", "302", "300", "314", "306", "311", "313", "301", "308", "307", "312"], "fit_nums": ["396", "397", "594", "911", "308", "453", "686", "552", "103", "209", "128", "892", "345", "925", "777", "396"], "chain_num": 4}, "CSHL051": {"seeds": ["303", "310", "306", "302", "309", "305", "313", "308", "300", "314", "311", "307", "312", "304", "315", "301"], "fit_nums": ["69", "186", "49", "435", "103", "910", "705", "367", "303", "474", "596", "334", "929", "796", "616", "790"], "chain_num": 4}, "GLM_Sim_14": {"seeds": ["310", "311", "309", "313", "314", "300", "302", "304", "305", "306", "307", "312", "303", "301", "315", "308"], "fit_nums": ["616", "872", "419", "106", "940", "986", "599", "704", "218", "808", "244", "825", "448", "397", "552", "316"], "chain_num": 2}, "GLM_Sim_11_trick": {"seeds": ["411", "400", "408", "409", "415", "413", "410", "412", "406", "414", "403", "404", "401", "405", "407", "402"], "fit_nums": ["95", "508", "886", "384", "822", "969", "525", "382", "489", "436", "344", "537", "251", "223", "458", "401"], "chain_num": 2, "ignore": [10, 12, 4, 1, 0, 3, 2, 6]}, "GLM_Sim_16": {"seeds": ["302", "311", "303", "307", "313", "308", "309", "300", "305", "315", "304", "310", "312", "301", "314", "306"], "fit_nums": ["914", "377", "173", "583", "870", "456", "611", "697", "13", "713", "159", "248", "617", "37", "770", "780"], "chain_num": 2}, "ZM_3003": {"seeds": ["300", "304", "307", "312", "305", "310", "311", "314", "303", "308", "313", "301", "315", "309", "306", "302"], "fit_nums": ["603", "620", "657", "735", "357", "390", "119", "33", "62", "617", "209", "810", "688", "21", "744", "426"], "chain_num": 4}, "CSH_ZAD_022": {"seeds": ["305", "310", "311", "315", "303", "312", "314", "313", "307", "302", "300", "304", "301", "308", "306", "309"], "fit_nums": ["143", "946", "596", "203", "576", "403", "900", "65", "478", "325", "282", "513", "460", "42", "161", "970"], "chain_num": 4}, "GLM_Sim_07": {"seeds": ["300", "309", "302", "304", "305", "312", "301", "311", "315", "314", "308", "307", "303", "310", "306", "313"], "fit_nums": ["724", "701", "118", "230", "648", "426", "689", "114", "832", "731", "592", "519", "559", "938", "672", "144"], "chain_num": 1}, "KS017": {"seeds": ["311", "310", "306", "309", "303", "302", "308", "300", "313", "301", "314", "307", "315", "304", "312", "305"], "fit_nums": ["97", "281", "808", "443", "352", "890", "703", "468", "780", "708", "674", "27", "345", "23", "939", "457"], "chain_num": 4}, "GLM_Sim_11_sub": {"seeds": ["410", "414", "413", "404", "409", "415", "406", "408", "402", "411", "400", "405", "403", "407", "412", "401"], "fit_nums": ["830", "577", "701", "468", "929", "374", "954", "749", "937", "488", "873", "416", "612", "792", "461", "488"], "chain_num": 2}}
\ No newline at end of file
+{"SWC_023": {"seeds": ["302", "312", "304", "300", "315", "311", "308", "305", "303", "309", "306", "313", "307", "314", "301", "310"], "fit_nums": ["994", "913", "681", "816", "972", "790", "142", "230", "696", "537", "975", "773", "918", "677", "742", "745"], "chain_num": 4}, "SWC_021": {"seeds": ["415", "403", "412", "407", "409", "408", "405", "404", "410", "414", "401", "413", "402", "400", "406", "411"], "fit_nums": ["773", "615", "107", "583", "564", "354", "142", "184", "549", "185", "924", "907", "105", "531", "9", "812"], "chain_num": 9}, "ibl_witten_15": {"seeds": ["409", "410", "401", "415", "414", "403", "411", "404", "402", "405", "400", "412", "408", "407", "406", "413"], "fit_nums": ["411", "344", "496", "600", "716", "18", "527", "467", "898", "334", "309", "326", "133", "823", "740", "253"], "chain_num": 9}, "ibl_witten_13": {"seeds": ["302", "312", "313", "306", "315", "307", "311", "314", "309", "301", "308", "300", "304", "310", "303", "305"], "fit_nums": ["897", "765", "433", "641", "967", "599", "984", "259", "853", "385", "887", "619", "434", "964", "483", "891"], "chain_num": 4}, "KS016": {"seeds": ["315", "301", "309", "313", "302", "307", "303", "308", "311", "312", "314", "306", "310", "300", "305", "304"], "fit_nums": ["99", "57", "585", "32", "501", "558", "243", "413", "59", "757", "463", "172", "524", "957", "909", "292"], "chain_num": 4}, "KS003": {"seeds": ["404", "407", "413", "403", "414", "405", "400", "401", "402", "410", "415", "408", "411", "409", "406", "412"], "fit_nums": ["846", "256", "845", "945", "293", "406", "420", "109", "690", "421", "54", "866", "784", "81", "997", "665"], "chain_num": 9}, "ibl_witten_19": {"seeds": ["315", "311", "307", "314", "308", "300", "305", "301", "313", "304", "302", "310", "306", "312", "309", "303"], "fit_nums": ["179", "951", "613", "6", "623", "382", "458", "504", "406", "554", "5", "631", "746", "817", "265", "328"], "chain_num": 4}, "SWC_022": {"seeds": ["411", "403", "414", "409", "407", "412", "410", "413", "415", "404", "405", "400", "402", "401", "408", "406"], "fit_nums": ["408", "884", "62", "962", "744", "854", "635", "70", "320", "952", "8", "67", "231", "381", "536", "962"], "chain_num": 9}, "KS022": {"seeds": ["315", "300", "314", "301", "303", "302", "306", "308", "305", "310", "313", "312", "304", "307", "311", "309"], "fit_nums": ["899", "681", "37", "957", "629", "637", "375", "980", "810", "51", "759", "664", "420", "127", "259", "555"], "chain_num": 4}, "CSH_ZAD_017": {"seeds": ["401", "409", "405", "403", "415", "404", "402", "411", "410", "414", "408", "406", "413", "412", "400", "407"], "fit_nums": ["883", "803", "637", "806", "356", "804", "662", "654", "684", "350", "947", "460", "569", "976", "103", "713"], "chain_num": 9}, "CSH_ZAD_025": {"seeds": ["303", "311", "307", "312", "313", "314", "308", "315", "305", "306", "304", "302", "309", "310", "301", "300"], "fit_nums": ["581", "148", "252", "236", "581", "838", "206", "756", "449", "288", "756", "593", "733", "633", "418", "563"], "chain_num": 4}, "ibl_witten_17": {"seeds": ["406", "415", "408", "413", "402", "405", "409", "400", "414", "401", "412", "407", "404", "410", "403", "411"], "fit_nums": ["827", "797", "496", "6", "444", "823", "384", "873", "634", "27", "811", "142", "207", "322", "756", "275"], "chain_num": 9}, "ibl_witten_18": {"seeds": ["311", "310", "303", "314", "302", "309", "305", "307", "312", "300", "308", "306", "315", "313", "304", "301"], "fit_nums": ["236", "26", "838", "762", "826", "409", "496", "944", "280", "704", "930", "419", "637", "896", "876", "297"], "chain_num": 4}, "CSHL_018": {"seeds": ["302", "310", "306", "300", "314", "307", "309", "313", "311", "308", "304", "301", "312", "303", "305", "315"], "fit_nums": ["843", "817", "920", "900", "226", "36", "472", "676", "933", "453", "116", "263", "269", "897", "568", "438"], "chain_num": 4}, "GLM_Sim_06": {"seeds": ["313", "309", "302", "303", "305", "314", "300", "315", "311", "306", "304", "310", "301", "312", "308", "307"], "fit_nums": ["9", "786", "286", "280", "72", "587", "619", "708", "360", "619", "311", "189", "60", "708", "939", "733"], "chain_num": 2}, "ZM_1897": {"seeds": ["304", "308", "305", "311", "315", "314", "307", "306", "300", "303", "313", "310", "301", "312", "302", "309"], "fit_nums": ["549", "96", "368", "509", "424", "897", "287", "426", "968", "93", "725", "513", "837", "581", "989", "374"], "chain_num": 4}, "CSHL_020": {"seeds": ["305", "309", "313", "302", "314", "310", "300", "307", "315", "306", "312", "304", "311", "301", "303", "308"], "fit_nums": ["222", "306", "243", "229", "584", "471", "894", "238", "986", "660", "494", "657", "896", "459", "100", "283"], "chain_num": 4}, "CSHL054": {"seeds": ["401", "415", "409", "410", "414", "413", "407", "405", "406", "408", "411", "400", "412", "402", "403", "404"], "fit_nums": ["901", "734", "609", "459", "574", "793", "978", "66", "954", "906", "954", "111", "292", "850", "266", "967"], "chain_num": 9}, "CSHL_014": {"seeds": ["305", "311", "309", "300", "313", "310", "307", "306", "304", "312", "308", "302", "314", "303", "301", "315"], "fit_nums": ["371", "550", "166", "24", "705", "385", "870", "884", "831", "546", "404", "722", "287", "564", "613", "783"], "chain_num": 4}, "CSHL062": {"seeds": ["307", "313", "310", "303", "306", "312", "308", "305", "311", "314", "304", "302", "300", "301", "315", "309"], "fit_nums": ["846", "371", "94", "888", "499", "229", "546", "432", "71", "989", "986", "91", "935", "314", "975", "481"], "chain_num": 4}, "CSH_ZAD_001": {"seeds": ["313", "309", "311", "312", "305", "310", "315", "300", "314", "304", "301", "302", "308", "303", "306", "307"], "fit_nums": ["468", "343", "314", "544", "38", "120", "916", "170", "305", "569", "502", "496", "452", "336", "559", "572"], "chain_num": 4}, "NYU-06": {"seeds": ["314", "309", "306", "305", "312", "303", "307", "304", "300", "302", "310", "301", "315", "308", "313", "311"], "fit_nums": ["950", "862", "782", "718", "427", "645", "827", "612", "821", "834", "595", "929", "679", "668", "648", "869"], "chain_num": 4}, "KS019": {"seeds": ["404", "401", "411", "408", "400", "403", "410", "413", "402", "407", "415", "409", "406", "414", "412", "405"], "fit_nums": ["682", "4", "264", "200", "250", "267", "737", "703", "132", "855", "922", "686", "85", "176", "54", "366"], "chain_num": 9}, "CSHL049": {"seeds": ["411", "402", "414", "408", "409", "410", "413", "407", "406", "401", "404", "405", "403", "415", "400", "412"], "fit_nums": ["104", "553", "360", "824", "749", "519", "347", "228", "863", "671", "140", "883", "701", "445", "627", "898"], "chain_num": 9}, "ibl_witten_14": {"seeds": ["310", "311", "304", "306", "300", "302", "314", "313", "303", "308", "301", "309", "305", "315", "312", "307"], "fit_nums": ["563", "120", "85", "712", "277", "871", "183", "661", "505", "598", "210", "89", "310", "638", "564", "998"], "chain_num": 4}, "KS014": {"seeds": ["301", "310", "302", "312", "313", "308", "307", "303", "305", "300", "314", "306", "311", "309", "304", "315"], "fit_nums": ["668", "32", "801", "193", "269", "296", "74", "24", "270", "916", "21", "250", "342", "451", "517", "293"], "chain_num": 4, "ignore": [9, 11, 0, 1, 14, 2, 12, 13]}, "CSHL059": {"seeds": ["306", "309", "300", "304", "314", "303", "315", "311", "313", "305", "301", "307", "302", "312", "310", "308"], "fit_nums": ["821", "963", "481", "999", "986", "45", "551", "605", "701", "201", "629", "261", "972", "407", "165", "9"], "chain_num": 4}, "GLM_Sim_13": {"seeds": ["310", "303", "308", "306", "300", "312", "301", "313", "305", "311", "315", "304", "314", "309", "307", "302"], "fit_nums": ["982", "103", "742", "524", "614", "370", "926", "456", "133", "143", "302", "80", "395", "549", "579", "944"], "chain_num": 2}, "CSHL_007": {"seeds": ["314", "303", "308", "313", "301", "300", "302", "305", "315", "306", "310", "309", "311", "304", "307", "312"], "fit_nums": ["462", "703", "345", "286", "480", "313", "986", "165", "201", "102", "322", "894", "960", "438", "330", "169"], "chain_num": 4}, "CSH_ZAD_011": {"seeds": ["314", "311", "303", "300", "305", "310", "306", "301", "302", "315", "304", "309", "308", "312", "313", "307"], "fit_nums": ["320", "385", "984", "897", "315", "120", "320", "945", "475", "403", "210", "412", "695", "564", "664", "411"], "chain_num": 4}, "KS021": {"seeds": ["309", "312", "304", "310", "303", "311", "314", "302", "305", "301", "306", "300", "308", "315", "313", "307"], "fit_nums": ["874", "943", "925", "587", "55", "136", "549", "528", "349", "211", "401", "84", "225", "545", "153", "382"], "chain_num": 4}, "GLM_Sim_15": {"seeds": ["303", "312", "305", "308", "309", "302", "301", "310", "313", "315", "311", "314", "307", "306", "304", "300"], "fit_nums": ["769", "930", "328", "847", "899", "714", "144", "518", "521", "873", "914", "359", "242", "343", "45", "364"], "chain_num": 2}, "CSHL_015": {"seeds": ["301", "302", "307", "310", "309", "311", "304", "312", "300", "308", "313", "305", "314", "315", "306", "303"], "fit_nums": ["717", "705", "357", "539", "604", "971", "669", "76", "45", "413", "510", "122", "190", "821", "368", "472"], "chain_num": 4}, "ibl_witten_16": {"seeds": ["304", "313", "309", "314", "312", "307", "305", "301", "306", "310", "300", "315", "308", "311", "303", "302"], "fit_nums": ["392", "515", "696", "270", "7", "583", "880", "674", "23", "576", "579", "695", "149", "854", "184", "875"], "chain_num": 4}, "KS015": {"seeds": ["315", "305", "309", "303", "314", "310", "311", "312", "313", "300", "307", "308", "304", "301", "302", "306"], "fit_nums": ["257", "396", "387", "435", "133", "164", "403", "8", "891", "650", "111", "557", "473", "229", "842", "196"], "chain_num": 4}, "GLM_Sim_12": {"seeds": ["304", "312", "306", "303", "310", "302", "300", "305", "308", "313", "307", "311", "315", "301", "314", "309"], "fit_nums": ["971", "550", "255", "195", "952", "486", "841", "535", "559", "37", "654", "213", "864", "506", "732", "550"], "chain_num": 2}, "GLM_Sim_11": {"seeds": ["300", "312", "310", "315", "302", "313", "314", "311", "308", "303", "309", "307", "306", "304", "301", "305"], "fit_nums": ["477", "411", "34", "893", "195", "293", "603", "5", "887", "281", "956", "73", "346", "640", "532", "688"], "chain_num": 2}, "GLM_Sim_10": {"seeds": ["301", "300", "306", "305", "307", "309", "312", "314", "311", "315", "304", "313", "303", "308", "302", "310"], "fit_nums": ["391", "97", "897", "631", "239", "652", "19", "448", "807", "35", "972", "469", "280", "562", "42", "706"], "chain_num": 2}, "CSH_ZAD_026": {"seeds": ["312", "313", "308", "310", "303", "307", "302", "305", "300", "315", "306", "301", "311", "304", "314", "309"], "fit_nums": ["699", "87", "537", "628", "797", "511", "459", "770", "969", "240", "504", "948", "295", "506", "25", "378"], "chain_num": 4}, "KS023": {"seeds": ["304", "313", "306", "309", "300", "314", "302", "310", "303", "315", "307", "308", "301", "311", "305", "312"], "fit_nums": ["698", "845", "319", "734", "908", "507", "45", "499", "175", "108", "419", "443", "116", "779", "159", "231"], "chain_num": 4}, "GLM_Sim_05": {"seeds": ["301", "315", "300", "302", "305", "304", "313", "314", "311", "309", "306", "307", "308", "310", "303", "312"], "fit_nums": ["425", "231", "701", "375", "343", "902", "623", "125", "921", "637", "393", "964", "678", "930", "796", "42"], "chain_num": 2}, "CSHL061": {"seeds": ["305", "315", "304", "303", "309", "310", "302", "300", "314", "306", "311", "313", "301", "308", "307", "312"], "fit_nums": ["396", "397", "594", "911", "308", "453", "686", "552", "103", "209", "128", "892", "345", "925", "777", "396"], "chain_num": 4}, "CSHL051": {"seeds": ["303", "310", "306", "302", "309", "305", "313", "308", "300", "314", "311", "307", "312", "304", "315", "301"], "fit_nums": ["69", "186", "49", "435", "103", "910", "705", "367", "303", "474", "596", "334", "929", "796", "616", "790"], "chain_num": 4}, "GLM_Sim_14": {"seeds": ["310", "311", "309", "313", "314", "300", "302", "304", "305", "306", "307", "312", "303", "301", "315", "308"], "fit_nums": ["616", "872", "419", "106", "940", "986", "599", "704", "218", "808", "244", "825", "448", "397", "552", "316"], "chain_num": 2}, "GLM_Sim_11_trick": {"seeds": ["411", "400", "408", "409", "415", "413", "410", "412", "406", "414", "403", "404", "401", "405", "407", "402"], "fit_nums": ["95", "508", "886", "384", "822", "969", "525", "382", "489", "436", "344", "537", "251", "223", "458", "401"], "chain_num": 2, "ignore": [10, 12, 4, 1, 0, 3, 2, 6]}, "GLM_Sim_16": {"seeds": ["302", "311", "303", "307", "313", "308", "309", "300", "305", "315", "304", "310", "312", "301", "314", "306"], "fit_nums": ["914", "377", "173", "583", "870", "456", "611", "697", "13", "713", "159", "248", "617", "37", "770", "780"], "chain_num": 2}, "ZM_3003": {"seeds": ["300", "304", "307", "312", "305", "310", "311", "314", "303", "308", "313", "301", "315", "309", "306", "302"], "fit_nums": ["603", "620", "657", "735", "357", "390", "119", "33", "62", "617", "209", "810", "688", "21", "744", "426"], "chain_num": 4}, "CSH_ZAD_022": {"seeds": ["305", "310", "311", "315", "303", "312", "314", "313", "307", "302", "300", "304", "301", "308", "306", "309"], "fit_nums": ["143", "946", "596", "203", "576", "403", "900", "65", "478", "325", "282", "513", "460", "42", "161", "970"], "chain_num": 4}, "GLM_Sim_07": {"seeds": ["300", "309", "302", "304", "305", "312", "301", "311", "315", "314", "308", "307", "303", "310", "306", "313"], "fit_nums": ["724", "701", "118", "230", "648", "426", "689", "114", "832", "731", "592", "519", "559", "938", "672", "144"], "chain_num": 1}, "KS017": {"seeds": ["311", "310", "306", "309", "303", "302", "308", "300", "313", "301", "314", "307", "315", "304", "312", "305"], "fit_nums": ["97", "281", "808", "443", "352", "890", "703", "468", "780", "708", "674", "27", "345", "23", "939", "457"], "chain_num": 4}, "GLM_Sim_11_sub": {"seeds": ["410", "414", "413", "404", "409", "415", "406", "408", "402", "411", "400", "405", "403", "407", "412", "401"], "fit_nums": ["830", "577", "701", "468", "929", "374", "954", "749", "937", "488", "873", "416", "612", "792", "461", "488"], "chain_num": 2}}
\ No newline at end of file
diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py
index 3fefa03f85e69e152beae64b51db2d3d37475376..7d57a217bcd38325649ce5ada58a2e37f48fea2d 100644
--- a/dyn_glm_chain_analysis.py
+++ b/dyn_glm_chain_analysis.py
@@ -1,5 +1,4 @@
 import matplotlib
-# matplotlib.use('Agg')
 import os
 os.environ["OMP_NUM_THREADS"] = "6" # export OMP_NUM_THREADS=4
 os.environ["OPENBLAS_NUM_THREADS"] = "6" # export OPENBLAS_NUM_THREADS=4
@@ -12,7 +11,7 @@ import pyhsmm
 import pickle
 import seaborn as sns
 import sys
-from scipy.stats import rankdata, norm, zscore
+from scipy.stats import zscore
 from scipy.optimize import minimize
 from itertools import combinations, product
 import matplotlib.gridspec as gridspec
@@ -20,6 +19,7 @@ from matplotlib.ticker import MaxNLocator
 import json
 import time
 import multiprocessing as mp
+from mcmc_chain_analysis import state_size_helper, state_num_helper, gamma_func, alpha_func, ll_func, r_hat_array_comp, rank_inv_normal_transform, eval_r_hat, eval_simple_r_hat
 
 
 colors = np.genfromtxt('colors.csv', delimiter=',')
@@ -69,7 +69,7 @@ class MCMC_result_list:
     def state_num_dist(self):
         state_nums = np.zeros((self.m, self.n))
         for i in range(self.m):
-            state_nums[i] = state_num_func(self.results[i])
+            state_nums[i] = state_num_helper(0.05)(self.results[i])
         return state_nums
 
     def return_chains(self, func, model_for_loop, rank_norm=True, mode_indices=None):
@@ -149,14 +149,11 @@ class MCMC_result_list:
                 counter += 1
                 step = len(ind) // min_len
                 up_to = min_len * step
-                # print([j + i * 640 for j in ind[:up_to:step]])
                 chains[counter] = func(self.results[i], ind[:up_to:step])
-            # print(chains)
-        # print(np.unique(chains.flatten(), return_counts=True))
+
         self.chains = chains
-        self.folded_chains = np.abs(self.chains - np.median(chains))
 
-        self.rank_inv_normal_transform()
+        self.rank_normalised, self.folded_rank_normalised, self.ranked, self.folded_ranked = rank_inv_normal_transform(self.chains)
 
         self.lame_r_hat, self.lame_var_hat_plus = r_hat_array_comp(self.chains)
         self.rank_normalised_r_hat, self.var_hat_plus = r_hat_array_comp(self.rank_normalised)
@@ -178,15 +175,6 @@ class MCMC_result_list:
         print("Effective number of samples is {}".format(self.n_eff))
         return chains
 
-    def rank_inv_normal_transform(self):
-        # Gelman paper Rank-normalization, folding, and localization: An improved R_hat for assessing convergence of MCMC
-        # ranking with average rank for ties
-        self.ranked = rankdata(self.chains).reshape(self.chains.shape)
-        self.folded_ranked = rankdata(self.folded_chains).reshape(self.folded_chains.shape)
-        # inverse normal with fractional offset
-        self.rank_normalised = norm.ppf((self.ranked - 3/8) / (self.chains.size + 1/4))
-        self.folded_rank_normalised = norm.ppf((self.folded_ranked - 3/8) / (self.folded_chains.size + 1/4))
-
     def rank_histos(self):
         count_max = 0
         for i in range(self.m):
@@ -385,42 +373,6 @@ def state_appear_and_dur(m, n_sessions):
     return state_appear, [(1 + session_dict[s][1] - session_dict[s][0]) / n_sessions for s in observed_states]
 
 
-def r_hat_array_comp(chains):
-    m, n = chains.shape
-    psi_dot_j = np.mean(chains, axis=1)
-    psi_dot_dot = np.mean(psi_dot_j)
-    B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot) ** 2)
-    s_j_squared = np.sum((chains - psi_dot_j[:, None]) ** 2, axis=1) / (n - 1)
-    W = np.mean(s_j_squared)
-    var_hat_plus = (n - 1) / n * W + B / n
-    if W == 0:
-        # print("all the same value")
-        return 1, 0
-    r_hat = np.sqrt(var_hat_plus / W)
-    return r_hat, var_hat_plus
-
-
-def eval_amortized_r_hat(chains, psi_dot_j, s_j_squared, l, m, n):
-    psi_dot_dot = np.mean(psi_dot_j, axis=1)
-    B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot[:, None]) ** 2, axis=1)
-    W = np.mean(s_j_squared, axis=1)
-    var_hat_plus = (n - 1) / n * W + B / n
-    r_hat = np.sqrt(var_hat_plus / W)
-    return max(r_hat)
-
-
-def r_hat_array_comp_mult(chains):
-    l, m, n = chains.shape
-    psi_dot_j = np.mean(chains, axis=2)
-    psi_dot_dot = np.mean(psi_dot_j, axis=1)
-    B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot[:, None]) ** 2, axis=1)
-    s_j_squared = np.sum((chains - psi_dot_j[:, :, None]) ** 2, axis=2) / (n - 1)
-    W = np.mean(s_j_squared, axis=1)
-    var_hat_plus = (n - 1) / n * W + B / n
-    r_hat = np.sqrt(var_hat_plus / W)
-    return r_hat, var_hat_plus
-
-
 class MCMC_result:
     def __init__(self, models, infos, data, sessions, fit_variance, save_id, seq_start=0, dur='yes', sample_lls=None):
 
@@ -533,21 +485,6 @@ class MCMC_result:
         return np.array(glm_weights)
 
 
-def gamma_func(x): return x.trans_distn.gamma
-
-
-def alpha_func(x): return x.trans_distn.alpha
-
-
-def largest_state_func(x): return x.assign_counts.max(1)
-
-
-def state_num_func(x): return ((x.assign_counts / x.n_datapoints) > 0.05).sum(1)
-
-
-def ll_func(x): return x.sample_lls[-x.n_samples:]
-
-
 def find_good_chains(chains, reduce_to=8):
     delete_n = chains.shape[0] // 2 - reduce_to
     mins = np.zeros(1 + delete_n)
@@ -943,6 +880,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     ax0.plot(1 + 0.25, 0.6 + 0.2, 'ko', ms=18)
     ax0.plot(1 + 0.25, 0.6 + 0.2, 'wo', ms=16.8)
     ax0.plot(1 + 0.25, 0.6 + 0.2, 'ko', ms=16.8, alpha=abs(num_to_cont[1]))
+
     if test.results[0].type != 'bias':
         current, counter = 0, 0
         for c in [2, 3, 4, 5]:
@@ -1095,12 +1033,30 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
 
     durs, state_types = state_type_durs(states_by_session, all_pmfs)
     dur_counter = 1
+    contrast_intro_types = [0, 0, 0, 0]
+    state, when = np.where(states_by_session > 0.05)
+    introductions_by_stage = np.zeros(3)
+    covered_states = []
     for i, d in enumerate(durs):
         ax0.fill_between(range(dur_counter, 1 + dur_counter + d), 0.5, -0.5, color=type_colours[i], zorder=0, alpha=0.3)
         dur_counter += d
 
-    ax2.set_title('Psychometric\nfunction', size=16)
+        # find out during which state type which contrast was introduced
+        for j, contrast in enumerate([2, 3, 4, 5]):
+            if contrast_intro_types[j] != 0:
+                continue
+            if test.results[0].infos[contrast] + 1 < dur_counter:
+                contrast_intro_types[j] = i+1
+
+        # find out during which stage which state was introduced
+        for s in range(len(state_sets)):
+            if np.sum(state == s) == 0 or s in covered_states:
+                continue
+            if when[state == s][0] + 1 < dur_counter:
+                introductions_by_stage[i] += 1
+                covered_states.append(s)
 
+    ax2.set_title('Psychometric\nfunction', size=16)
     ax1.set_ylabel('Proportion of trials', size=28, labelpad=-20)
     ax0.set_ylabel('% correct', size=18)
     ax2.set_ylabel('Probability', size=26, labelpad=-20)
@@ -1142,68 +1098,13 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     else:
         plt.close()
 
-    return states_by_session, all_pmfs, durs, state_types
-
-
-def comp_multi_r_hat(chains, rank_normalised, folded_rank_normalised):
-    lame_r_hat, _ = r_hat_array_comp(chains)
-    rank_normalised_r_hat, _ = r_hat_array_comp(rank_normalised)
-    folded_rank_normalised_r_hat, _ = r_hat_array_comp(folded_rank_normalised)
-    return max(lame_r_hat, rank_normalised_r_hat, folded_rank_normalised_r_hat)
-
-
-def eval_r_hat(chains1, chains2, chains3, chains4):
-    rank_normalised1, folded_rank_normalised1 = rank_inv_normal_transform(chains1)
-    rank_normalised2, folded_rank_normalised2 = rank_inv_normal_transform(chains2)
-    rank_normalised3, folded_rank_normalised3 = rank_inv_normal_transform(chains3)
-    rank_normalised4, folded_rank_normalised4 = rank_inv_normal_transform(chains4)
-
-    r_hat1 = comp_multi_r_hat(chains1, rank_normalised1, folded_rank_normalised1)
-    r_hat2 = comp_multi_r_hat(chains2, rank_normalised2, folded_rank_normalised2)
-    r_hat3 = comp_multi_r_hat(chains3, rank_normalised3, folded_rank_normalised3)
-    r_hat4 = comp_multi_r_hat(chains4, rank_normalised4, folded_rank_normalised4)
-
-    return max(r_hat1, r_hat2, r_hat3, r_hat4)
-
-
-def eval_simple_r_hat(chains):
+    return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs))
 
-    r_hats, _ = r_hat_array_comp_mult(chains)
-
-    return max(r_hats)
-
-
-def find_good_chains_unsplit(chains1, chains2, chains3, chains4, reduce_to=8, simple=False):
-    delete_n = - reduce_to + chains1.shape[0]
-    mins = np.zeros(delete_n + 1)
-    n_chains = chains1.shape[0]
-    chains = np.stack([chains1, chains2, chains3, chains4])
-
-    print("Without removals: {}".format(eval_simple_r_hat(chains)))
-    if simple:
-        r_hat = eval_simple_r_hat(chains)
-    else:
-        r_hat = eval_r_hat(chains1, chains2, chains3, chains4)
-    mins[0] = r_hat
-
-    for i in range(delete_n):
-        print()
-        r_hat_min = 10
-        sol = 0
-        for x in combinations(range(n_chains), n_chains - 1 - i):
-            if simple:
-                r_hat = eval_simple_r_hat(np.delete(chains, x, axis=1))
-            else:
-                r_hat = eval_r_hat(np.delete(chains1, x, axis=0), np.delete(chains2, x, axis=0), np.delete(chains3, x, axis=0), np.delete(chains4, x, axis=0))
-            if r_hat < r_hat_min:
-                sol = x
-            r_hat_min = min(r_hat, r_hat_min)
-        print("Minimum is {} (removed {})".format(r_hat_min, i + 1))
-        sol = [i for i in range(32) if i not in sol]
-        print("Removed: {}".format(sol))
-        mins[i + 1] = r_hat_min
-
-    return sol, r_hat_min
+def smart_divide(a, b):
+    c = np.zeros_like(a)
+    d = np.logical_and(a == 0, b == 0)
+    c[~d] = a[~d] / b[~d]
+    return c
 
 
 def find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_to=8, simple=False):
@@ -1212,8 +1113,8 @@ def find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_t
     n_chains = chains1.shape[0]
     chains = np.stack([chains1, chains2, chains3, chains4])
 
-    print("Without removals: {}".format(eval_r_hat(chains1, chains2, chains3, chains4)))
-    r_hat = eval_r_hat(chains1, chains2, chains3, chains4)
+    r_hat = eval_r_hat([chains1, chains2, chains3, chains4])
+    print("Without removals: {}".format(r_hat))
     mins[0] = r_hat
 
     to_del = []
@@ -1224,7 +1125,7 @@ def find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_t
             if x in to_del:
                 continue
             if not simple:
-                r_hat = eval_r_hat(np.delete(chains1, to_del + [x], axis=0), np.delete(chains2, to_del + [x], axis=0), np.delete(chains3, to_del + [x], axis=0), np.delete(chains4, to_del + [x], axis=0))
+                r_hat = eval_r_hat([np.delete(chains1, to_del + [x], axis=0), np.delete(chains2, to_del + [x], axis=0), np.delete(chains3, to_del + [x], axis=0), np.delete(chains4, to_del + [x], axis=0)])
             else:
                 r_hat = eval_simple_r_hat(np.delete(chains, to_del + [x], axis=1))
             if r_hat < r_hat_min:
@@ -1237,23 +1138,11 @@ def find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_t
         mins[i + 1] = r_hat_min
 
     if simple:
-        r_hat_local = eval_r_hat(np.delete(chains1, to_del, axis=0), np.delete(chains2, to_del, axis=0), np.delete(chains3, to_del, axis=0), np.delete(chains4, to_del, axis=0))
+        r_hat_local = eval_r_hat([np.delete(chains1, to_del, axis=0), np.delete(chains2, to_del, axis=0), np.delete(chains3, to_del, axis=0), np.delete(chains4, to_del, axis=0)])
         print("Minimum over everything is {} (removed {})".format(r_hat_local, i + 1))
     return to_del, r_hat_min
 
 
-def rank_inv_normal_transform(chains):
-    # Gelman paper Rank-normalization, folding, and localization: An improved R_hat for assessing convergence of MCMC
-    # ranking with average rank for ties
-    folded_chains = np.abs(chains - np.median(chains))
-    ranked = rankdata(chains).reshape(chains.shape)
-    folded_ranked = rankdata(folded_chains).reshape(folded_chains.shape)
-    # inverse normal with fractional offset
-    rank_normalised = norm.ppf((ranked - 3/8) / (chains.size + 1/4))
-    folded_rank_normalised = norm.ppf((folded_ranked - 3/8) / (folded_chains.size + 1/4))
-    return rank_normalised, folded_rank_normalised
-
-
 if __name__ == "__main__":
     fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
     if fit_type == 'bias':
@@ -1268,16 +1157,13 @@ if __name__ == "__main__":
     # test.r_hat_and_ess(return_ascending, False)
     # test.r_hat_and_ess(return_ascending_shuffled, False)
     # quit()
-
-    import pyhsmm.util.profiling as prof
-
     good = []
     bad = []
 
-    check_r_hats = False
+    check_r_hats = True
     if check_r_hats:
         subjects = list(loading_info.keys())
-        subjects = ['GLM_Sim_11_trick']
+        subjects = ['KS014']
         for subject in subjects:
             # if subject.startswith('GLM'):
             #     continue
@@ -1322,24 +1208,6 @@ def dist_helper(dist_matrix, state_hists, inds):
     return dist_matrix
 
 
-def state_size_helper(n=0, mode_specific=False):
-    if not mode_specific:
-        def nth_largest_state_func(x):
-            return np.partition(x.assign_counts, -1 - n, axis=1)[:, -1 - n]
-    else:
-        def nth_largest_state_func(x, ind):
-            return np.partition(x.assign_counts[ind], -1 - n, axis=1)[:, -1 - n]
-    return nth_largest_state_func
-
-
-def state_num_helper(t, mode_specific=False):
-    if not mode_specific:
-        def state_num_func(x): return ((x.assign_counts / x.n_datapoints) > t).sum(1)
-    else:
-        def state_num_func(x, ind): return ((x.assign_counts[ind] / x.n_datapoints) > t).sum(1)
-    return state_num_func
-
-
 def state_glm_func_helper(t, mode_specific=False):
     if not mode_specific:
         def temp_state_glm_func(x):
@@ -1364,26 +1232,6 @@ def state_glm_func_helper(t, mode_specific=False):
     return temp_state_glm_func
 
 
-def sample_statistics(test, mode_indices):
-    # prints out r_hats and sample sizes for given sample
-    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
-    test.r_hat_and_ess(state_size_helper(1), False)
-    test.r_hat_and_ess(state_size_helper(1, mode_specific=True), False, mode_indices=mode_indices)
-    print()
-    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
-    test.r_hat_and_ess(state_size_helper(), False)
-    test.r_hat_and_ess(state_size_helper(mode_specific=True), False, mode_indices=mode_indices)
-    print()
-    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
-    test.r_hat_and_ess(state_num_helper(0.05), False)
-    test.r_hat_and_ess(state_num_helper(0.05, mode_specific=True), False, mode_indices=mode_indices)
-    print()
-    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
-    test.r_hat_and_ess(state_num_helper(0.02), False)
-    test.r_hat_and_ess(state_num_helper(0.02, mode_specific=True), False, mode_indices=mode_indices)
-    print()
-
-
 def state_type_durs(states, pmfs):
     # Takes states and pmfs, first creates an array of when which type is how active, then computes the number of sessions each type lasts.
     # A type lasts until a more advanced type takes up more than 50% of a session (and cannot return)
@@ -1452,6 +1300,22 @@ def pmf_type(pmf):
 
 
 if __name__ == "__main__":
+
+    # visualise pmf types
+    lapses = [0.1, 0.2, 0.25, 0.33, 0.4, 0.45, 0.5, 0.55, 0.66, 0.9]
+    test_pmf = np.zeros(4)
+    for i, lapse_l in enumerate(lapses):
+        plt.subplot(1, 10, 1+i)
+        if i != 0:
+            plt.gca().set_yticklabels([])
+        plt.ylim(0, 1)
+        test_pmf[:2] = lapse_l
+        for lapse_r in np.linspace(0.02, 0.98, 33):
+            test_pmf[2:] = lapse_r
+            plt.plot([0, 1, 9, 10], test_pmf, c=type_colours[pmf_type(test_pmf)])
+    plt.show()
+    quit()
+
     fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
     if fit_type == 'bias':
         loading_info = json.load(open("canonical_infos_bias.json", 'r'))
@@ -1505,6 +1369,8 @@ if __name__ == "__main__":
     num_sessions = []
     state_appear = []
     state_dur = []
+    contrast_intro_types = []  # list to agglomorate in which state type which contrast is introduced
+    intros_by_type_sum = np.zeros(3)  # array to agglomorate how many states where introduced during which type, normalised by length of phase
 
     n_points = 150
     state_trajs = np.zeros((3, n_points))
@@ -1530,12 +1396,14 @@ if __name__ == "__main__":
 
             mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb'))
             state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb'))
-
             # lapse differential
             # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices)
 
             # training overview
-            states, pmfs, durs, _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=True)
+            states, pmfs, durs, _, contrast_intro_type, intros_by_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=True)
+            intros_by_type_sum += intros_by_type
+            continue
+            contrast_intro_types.append(contrast_intro_type)
             # state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False)
 
             # session overview
@@ -1794,12 +1662,6 @@ if __name__ == "__main__":
 
         plotSimplex(np.array(abs_state_durs), c='k', show=True)
 
-        plt.hist(abs_state_durs, label=['1st type', '2nd type', '3rd type'])
-        plt.legend()
-        plt.tight_layout()
-        plt.savefig("rel state type hists")
-        plt.show()
-
     if False:
         ax[0].set_ylim(0, 1)
         ax[1].set_ylim(0, 1)
diff --git a/dyn_glm_chain_analysis_unused_funcs.py b/dyn_glm_chain_analysis_unused_funcs.py
index 17daace56bc61e2dac54bc3d559fe7711872a3be..c2e51ba48a8682acd2c2c49850a80e7e2b0072d0 100644
--- a/dyn_glm_chain_analysis_unused_funcs.py
+++ b/dyn_glm_chain_analysis_unused_funcs.py
@@ -99,6 +99,40 @@ def find_good_chains_unsplit_fast(chains1, chains2, chains3, chains4, reduce_to=
     return sol, r_hat_min
 
 
+def find_good_chains_unsplit(chains1, chains2, chains3, chains4, reduce_to=8, simple=False):
+    delete_n = - reduce_to + chains1.shape[0]
+    mins = np.zeros(delete_n + 1)
+    n_chains = chains1.shape[0]
+    chains = np.stack([chains1, chains2, chains3, chains4])
+
+    print("Without removals: {}".format(eval_simple_r_hat(chains)))
+    if simple:
+        r_hat = eval_simple_r_hat(chains)
+    else:
+        r_hat = eval_r_hat(chains1, chains2, chains3, chains4)
+    mins[0] = r_hat
+
+    for i in range(delete_n):
+        print()
+        r_hat_min = 10
+        sol = 0
+        for x in combinations(range(n_chains), n_chains - 1 - i):
+            if simple:
+                r_hat = eval_simple_r_hat(np.delete(chains, x, axis=1))
+            else:
+                r_hat = eval_r_hat(np.delete(chains1, x, axis=0), np.delete(chains2, x, axis=0), np.delete(chains3, x, axis=0), np.delete(chains4, x, axis=0))
+            if r_hat < r_hat_min:
+                sol = x
+            r_hat_min = min(r_hat, r_hat_min)
+        print("Minimum is {} (removed {})".format(r_hat_min, i + 1))
+        sol = [i for i in range(32) if i not in sol]
+        print("Removed: {}".format(sol))
+        mins[i + 1] = r_hat_min
+
+    return sol, r_hat_min
+
+
+
 def params_to_pmf(params):
     return params[2] + (1 - params[2] - params[3]) / (1 + np.exp(- params[0] * (all_conts - params[1])))
 
@@ -214,7 +248,7 @@ if __name__ == "__main__":
         plt.savefig("pmf fit scatter")
         plt.show()
 
-
+        # New things
         xy = np.vstack([short_pmfs[:, 0], function_range, short_pmfs[:, 1]])
         z = gaussian_kde(xy)(xy)
         plt.figure(figsize=(24, 24 / 3))
diff --git a/mcmc_chain_analysis.py b/mcmc_chain_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a99a67602392e961761ef6ead4b9f40650e0d93
--- /dev/null
+++ b/mcmc_chain_analysis.py
@@ -0,0 +1,123 @@
+"""
+    Functions to extract statistics from a set of chains
+    Functions to compute R^hat from a set of statistic vectors
+"""
+import numpy as np
+from scipy.stats import rankdata, norm
+import pickle
+
+
+def state_size_helper(n=0, mode_specific=False):
+    if not mode_specific:
+        def nth_largest_state_func(x):
+            return np.partition(x.assign_counts, -1 - n, axis=1)[:, -1 - n]
+    else:
+        def nth_largest_state_func(x, ind):
+            return np.partition(x.assign_counts[ind], -1 - n, axis=1)[:, -1 - n]
+    return nth_largest_state_func
+
+
+def state_num_helper(t, mode_specific=False):
+    if not mode_specific:
+        def state_num_func(x): return ((x.assign_counts / x.n_datapoints) > t).sum(1)
+    else:
+        def state_num_func(x, ind): return ((x.assign_counts[ind] / x.n_datapoints) > t).sum(1)
+    return state_num_func
+
+
+def gamma_func(x): return x.trans_distn.gamma
+
+
+def alpha_func(x): return x.trans_distn.alpha
+
+
+def ll_func(x): return x.sample_lls[-x.n_samples:]
+
+
+def r_hat_array_comp(chains):
+    m, n = chains.shape
+    psi_dot_j = np.mean(chains, axis=1)
+    psi_dot_dot = np.mean(psi_dot_j)
+    B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot) ** 2)
+    s_j_squared = np.sum((chains - psi_dot_j[:, None]) ** 2, axis=1) / (n - 1)
+    W = np.mean(s_j_squared)
+    var_hat_plus = (n - 1) / n * W + B / n
+    if W == 0:
+        # print("all the same value")
+        return 1, 0
+    r_hat = np.sqrt(var_hat_plus / W)
+    return r_hat, var_hat_plus
+
+
+def eval_amortized_r_hat(chains, psi_dot_j, s_j_squared, m, n):
+    psi_dot_dot = np.mean(psi_dot_j, axis=1)
+    B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot[:, None]) ** 2, axis=1)
+    W = np.mean(s_j_squared, axis=1)
+    var_hat_plus = (n - 1) / n * W + B / n
+    r_hat = np.sqrt(var_hat_plus / W)
+    return max(r_hat)
+
+
+def r_hat_array_comp_mult(chains):
+    _, m, n = chains.shape
+    psi_dot_j = np.mean(chains, axis=2)
+    psi_dot_dot = np.mean(psi_dot_j, axis=1)
+    B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot[:, None]) ** 2, axis=1)
+    s_j_squared = np.sum((chains - psi_dot_j[:, :, None]) ** 2, axis=2) / (n - 1)
+    W = np.mean(s_j_squared, axis=1)
+    var_hat_plus = (n - 1) / n * W + B / n
+    r_hat = np.sqrt(var_hat_plus / W)
+    return r_hat, var_hat_plus
+
+
+def rank_inv_normal_transform(chains):
+    # Gelman paper Rank-normalization, folding, and localization: An improved R_hat for assessing convergence of MCMC
+    # ranking with average rank for ties
+    folded_chains = np.abs(chains - np.median(chains))
+    ranked = rankdata(chains).reshape(chains.shape)
+    folded_ranked = rankdata(folded_chains).reshape(folded_chains.shape)
+    # inverse normal with fractional offset
+    rank_normalised = norm.ppf((ranked - 3/8) / (chains.size + 1/4))
+    folded_rank_normalised = norm.ppf((folded_ranked - 3/8) / (folded_chains.size + 1/4))
+    return rank_normalised, folded_rank_normalised, ranked, folded_ranked
+
+
+def eval_r_hat(chains):
+    r_hats = []
+    for chain in chains:
+        rank_normalised, folded_rank_normalised, _, _ = rank_inv_normal_transform(chain)
+        r_hats.append(comp_multi_r_hat(chain, rank_normalised, folded_rank_normalised))
+
+    return max(r_hats)
+
+
+def eval_simple_r_hat(chains):
+    r_hats, _ = r_hat_array_comp_mult(chains)
+    return max(r_hats)
+
+
+def comp_multi_r_hat(chains, rank_normalised, folded_rank_normalised):
+    lame_r_hat, _ = r_hat_array_comp(chains)
+    rank_normalised_r_hat, _ = r_hat_array_comp(rank_normalised)
+    folded_rank_normalised_r_hat, _ = r_hat_array_comp(folded_rank_normalised)
+    return max(lame_r_hat, rank_normalised_r_hat, folded_rank_normalised_r_hat)
+
+
+def sample_statistics(test, mode_indices, subject):
+    # prints out r_hats and sample sizes for given sample
+    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
+    test.r_hat_and_ess(state_size_helper(1), False)
+    test.r_hat_and_ess(state_size_helper(1, mode_specific=True), False, mode_indices=mode_indices)
+    print()
+    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
+    test.r_hat_and_ess(state_size_helper(), False)
+    test.r_hat_and_ess(state_size_helper(mode_specific=True), False, mode_indices=mode_indices)
+    print()
+    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
+    test.r_hat_and_ess(state_num_helper(0.05), False)
+    test.r_hat_and_ess(state_num_helper(0.05, mode_specific=True), False, mode_indices=mode_indices)
+    print()
+    test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb'))
+    test.r_hat_and_ess(state_num_helper(0.02), False)
+    test.r_hat_and_ess(state_num_helper(0.02, mode_specific=True), False, mode_indices=mode_indices)
+    print()
diff --git a/process_many_chains.py b/process_many_chains.py
index 7aba22fbac5f7b599568df6aea157e10ca1fb152..88840671866206bcdee29ba09c549eb8181bae00 100644
--- a/process_many_chains.py
+++ b/process_many_chains.py
@@ -1,3 +1,7 @@
+"""
+    Script for taking a list of subects and extracting statistics from the chains
+    which can be used to assess which chains have converged to the same regions
+"""
 import numpy as np
 import pyhsmm
 import pickle
@@ -5,25 +9,7 @@ import json
 from dyn_glm_chain_analysis import MCMC_result
 import matplotlib.pyplot as plt
 import time
-
-def gamma_func(x): return x.trans_distn.gamma
-
-
-def alpha_func(x): return x.trans_distn.alpha
-
-
-def state_size_helper(n=0):
-    def nth_largest_state_func(x):
-        return np.partition(x.assign_counts, -1 - n, axis=1)[:, -1 - n]
-    return nth_largest_state_func
-
-
-def state_num_helper(t):
-    def state_num_func(x): return ((x.assign_counts / x.n_datapoints) > t).sum(1)
-    return state_num_func
-
-
-def ll_func(x): return x.sample_lls[-x.n_samples:]
+from mcmc_chain_analysis import state_size_helper, state_num_helper, ll_helper
 
 
 fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
@@ -45,117 +31,79 @@ func8 = state_size_helper(4)
 func9 = state_size_helper(5)
 func10 = state_size_helper(6)
 func11 = state_size_helper(7)
-model_for_loop = False
 dur = 'yes'
 
-def temp():
-    m = 16
-    for subject in subjects:
-        print(subject)
-        n_runs = -1
-        counter = -1
-        n = (loading_info[subject]['chain_num'] + 1) * 4000 // 25
-        chains1 = np.zeros((m, n))
-        chains2 = np.zeros((m, n))
-        chains3 = np.zeros((m, n))
-        chains4 = np.zeros((m, n))
-        for j, (seed, fit_num) in enumerate(zip(loading_info[subject]['seeds'], loading_info[subject]['fit_nums'])):
-            counter += 1
-            print(seed)
-            info_dict = pickle.load(open("./session_data/{}_info_dict.p".format(subject), "rb"))
-            samples = []
-            mini_counter = 0
-            while True:
-                try:
-                    file = "./dynamic_GLMiHMM_crossvals/{}_fittype_{}_var_{}_{}_{}{}.p".format(subject, fit_type, fit_variance, seed, fit_num, '_{}'.format(mini_counter))
-                    lala = time.time()
-                    samples += pickle.load(open(file, "rb"))
-                    print("Loading {} took {:.4}".format(mini_counter, time.time() - lala))
-                    mini_counter += 1
-                except Exception:
-                    break
-            if n_runs == -1:
-                n_runs = mini_counter
-            else:
-                if n_runs != mini_counter:
-                    print("Problem")
-                    print(n_runs, mini_counter)
-                    quit()
-            save_id = "{}_fittype_{}_var_{}_{}_{}.p".format(subject, fit_type, fit_variance, seed, fit_num).replace('.', '_')
-            # for file in os.listdir("./dynamic_GLMiHMM_crossvals/infos_new/"):
-            #     if file.startswith("{}_".format(subject)) and file.endswith("_{}_{}_{}_{}.json".format(fit_type, fit_variance, seed, fit_num)):
-            #         print("Taking fit infos from {}".format(file))
-            #         fit_infos = json.load(open("./dynamic_GLMiHMM_crossvals/infos_new/" + file, 'r'))
-            #         sample_lls = fit_infos['ll']
-
-            print("loaded seed {}".format(seed))
-
-            lala = time.time()
-            result = MCMC_result(samples[::25],
-                                 infos=info_dict, data=samples[0].datas,
-                                 sessions=fit_type, fit_variance=fit_variance,
-                                 dur=dur, save_id=save_id) # , sample_lls=sample_lls)
-            print("Making result {} took {:.4}".format(counter, time.time() - lala))
-
-            # if model_for_loop:
-            #     for i in range(n):
-            #         chains[j, i] = func(result.models[i])
-            res = func1(result)
-            chains1[counter] = res
-            res = func2(result)
-            chains2[counter] = res
-            res = func3(result)
-            chains3[counter] = res
-            res = func4(result)
-            chains4[counter] = res
-
-            # res = func5(result)
-            # chains5[j] = res
-            # res = func6(result)
-            # chains6[j] = res
-            # res = func7(result)
-            # chains7[j] = res
-            # res = func8(result)
-            # chains8[j] = res
-            # res = func9(result)
-            # chains9[j] = res
-            # res = func10(result)
-            # chains10[j] = res
-            # res = func11(result)
-            # chains11[j] = res
-
-        # func2 = state_size_helper()
-        # func5 = state_size_helper(1)
-        # func6 = state_size_helper(2)
-        # func7 = state_size_helper(3)
-        # func8 = state_size_helper(4)
-        # func9 = state_size_helper(5)
-        # func10 = state_size_helper(6)
-        # func11 = state_size_helper(7)
-        # plt.plot(chains2.flatten()[::10])
-        # plt.plot(chains5.flatten()[::10])
-        # plt.plot(chains6.flatten()[::10])
-        # plt.plot(chains7.flatten()[::10])
-        # plt.plot(chains8.flatten()[::10])
-        # plt.plot(chains9.flatten()[::10])
-        # plt.plot(chains10.flatten()[::10])
-        # plt.plot(chains11.flatten()[::10])
-        # plt.axhline(1643, color='r')
-        # plt.axhline(3438, color='r')
-        # plt.axhline(2216, color='r')
-        # plt.axhline(811, color='r')
-        # plt.axhline(2721, color='r')
-        # plt.axhline(743, color='r')
-        # plt.show()
-
-        pickle.dump(chains1, open("multi_chain_saves/{}_state_num_0_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
-        pickle.dump(chains2, open("multi_chain_saves/{}_state_num_1_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
-        pickle.dump(chains3, open("multi_chain_saves/{}_largest_state_0_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
-        pickle.dump(chains4, open("multi_chain_saves/{}_largest_state_1_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
-
-
-temp()
-# import pyhsmm.util.profiling as prof
-# prof_func = prof._prof(temp)
-# prof_func()
-# prof._prof.print_stats()
+m = 16
+for subject in subjects:
+    print(subject)
+    n_runs = -1
+    counter = -1
+    n = (loading_info[subject]['chain_num'] + 1) * 4000 // 25
+    chains1 = np.zeros((m, n))
+    chains2 = np.zeros((m, n))
+    chains3 = np.zeros((m, n))
+    chains4 = np.zeros((m, n))
+    for j, (seed, fit_num) in enumerate(zip(loading_info[subject]['seeds'], loading_info[subject]['fit_nums'])):
+        counter += 1
+        print(seed)
+        info_dict = pickle.load(open("./session_data/{}_info_dict.p".format(subject), "rb"))
+        samples = []
+        mini_counter = 0
+        while True:
+            try:
+                file = "./dynamic_GLMiHMM_crossvals/{}_fittype_{}_var_{}_{}_{}{}.p".format(subject, fit_type, fit_variance, seed, fit_num, '_{}'.format(mini_counter))
+                lala = time.time()
+                samples += pickle.load(open(file, "rb"))
+                print("Loading {} took {:.4}".format(mini_counter, time.time() - lala))
+                mini_counter += 1
+            except Exception:
+                break
+        if n_runs == -1:
+            n_runs = mini_counter
+        else:
+            if n_runs != mini_counter:
+                print("Problem")
+                print(n_runs, mini_counter)
+                quit()
+        save_id = "{}_fittype_{}_var_{}_{}_{}.p".format(subject, fit_type, fit_variance, seed, fit_num).replace('.', '_')
+        # for file in os.listdir("./dynamic_GLMiHMM_crossvals/infos_new/"):
+        #     if file.startswith("{}_".format(subject)) and file.endswith("_{}_{}_{}_{}.json".format(fit_type, fit_variance, seed, fit_num)):
+        #         print("Taking fit infos from {}".format(file))
+        #         fit_infos = json.load(open("./dynamic_GLMiHMM_crossvals/infos_new/" + file, 'r'))
+        #         sample_lls = fit_infos['ll']
+
+        print("loaded seed {}".format(seed))
+
+        result = MCMC_result(samples[::25],
+                             infos=info_dict, data=samples[0].datas,
+                             sessions=fit_type, fit_variance=fit_variance,
+                             dur=dur, save_id=save_id) # , sample_lls=sample_lls)
+        print("Making result {} took {:.4}".format(counter, time.time() - lala))
+
+        res = func1(result)
+        chains1[counter] = res
+        res = func2(result)
+        chains2[counter] = res
+        res = func3(result)
+        chains3[counter] = res
+        res = func4(result)
+        chains4[counter] = res
+        # res = func5(result)
+        # chains5[j] = res
+        # res = func6(result)
+        # chains6[j] = res
+        # res = func7(result)
+        # chains7[j] = res
+        # res = func8(result)
+        # chains8[j] = res
+        # res = func9(result)
+        # chains9[j] = res
+        # res = func10(result)
+        # chains10[j] = res
+        # res = func11(result)
+        # chains11[j] = res
+
+    pickle.dump(chains1, open("multi_chain_saves/{}_state_num_0_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
+    pickle.dump(chains2, open("multi_chain_saves/{}_state_num_1_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
+    pickle.dump(chains3, open("multi_chain_saves/{}_largest_state_0_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
+    pickle.dump(chains4, open("multi_chain_saves/{}_largest_state_1_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
diff --git a/simplex_plot.py b/simplex_plot.py
index 41db6f5e60a3691585d0e94dacdfb6c319b8ffc7..bdaad7885fa71bbbddcdd58b45447f4acf442ae3 100644
--- a/simplex_plot.py
+++ b/simplex_plot.py
@@ -4,7 +4,7 @@ Visualize points on the 3-simplex (eg, the parameters of a
 contained within a 2D triangle.
 Adapted from David Andrzejewski (david.andrzej@gmail.com)
 """
-import numpy as NP
+import numpy as np
 import matplotlib.pyplot as P
 import matplotlib.ticker as MT
 import matplotlib.lines as L
@@ -14,38 +14,42 @@ import matplotlib.patches as PA
 
 
 def plotSimplex(points, fig=None,
-                vertexlabels=['1', '2', '3'], save_title="test.png",
+                vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title="test.png",
                 show=False, **kwargs):
     """
     Plot Nx3 points array on the 3-simplex
     (with optionally labeled vertices)
 
     kwargs will be passed along directly to matplotlib.pyplot.scatter
-    Returns Figure, caller must .show()
     """
     if fig is None:
         fig = P.figure(figsize=(9, 9))
     # Draw the triangle
     l1 = L.Line2D([0, 0.5, 1.0, 0], # xcoords
-                  [0, NP.sqrt(3) / 2, 0, 0], # ycoords
+                  [0, np.sqrt(3) / 2, 0, 0], # ycoords
                   color='k')
     fig.gca().add_line(l1)
     fig.gca().xaxis.set_major_locator(MT.NullLocator())
     fig.gca().yaxis.set_major_locator(MT.NullLocator())
     # Draw vertex labels
-    fig.gca().text(-0.05, -0.05, vertexlabels[0])
-    fig.gca().text(1.05, -0.05, vertexlabels[1])
-    fig.gca().text(0.5, NP.sqrt(3) / 2 + 0.05, vertexlabels[2])
+    fig.gca().text(-0.06, -0.05, vertexlabels[0], size=24)
+    fig.gca().text(0.95, -0.05, vertexlabels[1], size=24)
+    fig.gca().text(0.43, np.sqrt(3) / 2 + 0.025, vertexlabels[2], size=24)
     # Project and draw the actual points
     projected = projectSimplex(points / points.sum(1)[:, None])
-    P.scatter(projected[:, 0], projected[:, 1], s=points.sum(1), **kwargs)
+    P.scatter(projected[:, 0], projected[:, 1], s=points.sum(1) * 3.5, **kwargs)
+
+    # plot center with average size
+    projected = projectSimplex(np.mean(points / points.sum(1)[:, None], axis=0).reshape(1, 3))
+    P.scatter(projected[:, 0], projected[:, 1], marker='*', color='r', s=np.mean(points.sum(1)) * 3.5)
+
     # Leave some buffer around the triangle for vertex labels
     fig.gca().set_xlim(-0.05, 1.05)
     fig.gca().set_ylim(-0.05, 1.05)
 
     P.axis('off')
 
-    P.savefig("test.png", bbox_inches='tight')
+    P.savefig("dur_simplex.png", bbox_inches='tight')
     if show:
         P.show()
     else:
@@ -59,22 +63,22 @@ def projectSimplex(points):
     N points are given as N x 3 array
     """
     # Convert points one at a time
-    tripts = NP.zeros((points.shape[0], 2))
+    tripts = np.zeros((points.shape[0], 2))
     for idx in range(points.shape[0]):
         # Init to triangle centroid
         x = 1.0 / 2
-        y = 1.0 / (2 * NP.sqrt(3))
+        y = 1.0 / (2 * np.sqrt(3))
         # Vector 1 - bisect out of lower left vertex
         p1 = points[idx, 0]
-        x = x - (1.0 / NP.sqrt(3)) * p1 * NP.cos(NP.pi / 6)
-        y = y - (1.0 / NP.sqrt(3)) * p1 * NP.sin(NP.pi / 6)
+        x = x - (1.0 / np.sqrt(3)) * p1 * np.cos(np.pi / 6)
+        y = y - (1.0 / np.sqrt(3)) * p1 * np.sin(np.pi / 6)
         # Vector 2 - bisect out of lower right vertex
         p2 = points[idx, 1]
-        x = x + (1.0 / NP.sqrt(3)) * p2 * NP.cos(NP.pi / 6)
-        y = y - (1.0 / NP.sqrt(3)) * p2 * NP.sin(NP.pi / 6)
+        x = x + (1.0 / np.sqrt(3)) * p2 * np.cos(np.pi / 6)
+        y = y - (1.0 / np.sqrt(3)) * p2 * np.sin(np.pi / 6)
         # Vector 3 - bisect out of top vertex
         p3 = points[idx, 2]
-        y = y + (1.0 / NP.sqrt(3) * p3)
+        y = y + (1.0 / np.sqrt(3) * p3)
 
         tripts[idx, :] = (x, y)
 
@@ -87,7 +91,7 @@ if __name__ == '__main__':
               '[0.8  0.1  0.1]',
               '[0.5  0.4  0.1]',
               '[0.33  0.34  0.33]')
-    testpoints = NP.array([[0.1, 0.1, 0.8],
+    testpoints = np.array([[0.1, 0.1, 0.8],
                            [0.8, 0.1, 0.1],
                            [0.5, 0.4, 0.1],
                            [0.33, 0.34, 0.33]])