diff --git a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc index a82afa9c345c7e3df3fd0494a1e8695fda475388..b4b56a55e7c7e3ff5cceb8da95a87d51c338cbac 100644 Binary files a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc and b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc differ diff --git a/__pycache__/mcmc_chain_analysis.cpython-37.pyc b/__pycache__/mcmc_chain_analysis.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdea663ae5d36ec0de32219c100df06016b9244f Binary files /dev/null and b/__pycache__/mcmc_chain_analysis.cpython-37.pyc differ diff --git a/__pycache__/process_many_chains.cpython-37.pyc b/__pycache__/process_many_chains.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44eddf35ca7b5b24cc52c61abb188e42c155fc62 Binary files /dev/null and b/__pycache__/process_many_chains.cpython-37.pyc differ diff --git a/__pycache__/simplex_plot.cpython-37.pyc b/__pycache__/simplex_plot.cpython-37.pyc index 1e8de26b8b6d51af0763e873273c06afd448dab6..1c26491f573cb6355e901c949147f3d921c87a30 100644 Binary files a/__pycache__/simplex_plot.cpython-37.pyc and b/__pycache__/simplex_plot.cpython-37.pyc differ diff --git a/canonical_infos.json b/canonical_infos.json index 6b9451cec02c20319870b92c795d47ba50ee8cbf..c2ca57ff4832606bb4d5af5dad36768fc4162753 100644 --- a/canonical_infos.json +++ b/canonical_infos.json @@ -1 +1 @@ -{"SWC_023": {"seeds": ["302", "312", "304", "300", "315", "311", "308", "305", "303", "309", "306", "313", "307", "314", "301", "310"], "fit_nums": ["994", "913", "681", "816", "972", "790", "142", "230", "696", "537", "975", "773", "918", "677", "742", "745"], "chain_num": 4}, "SWC_021": {"seeds": ["415", "403", "412", "407", "409", "408", "405", "404", "410", "414", "401", "413", "402", "400", "406", "411"], "fit_nums": ["773", "615", "107", "583", "564", "354", "142", "184", "549", "185", "924", "907", "105", "531", "9", "812"], "chain_num": 9}, "ibl_witten_15": {"seeds": ["409", "410", "401", "415", "414", "403", "411", "404", "402", "405", "400", "412", "408", "407", "406", "413"], "fit_nums": ["411", "344", "496", "600", "716", "18", "527", "467", "898", "334", "309", "326", "133", "823", "740", "253"], "chain_num": 9}, "ibl_witten_13": {"seeds": ["302", "312", "313", "306", "315", "307", "311", "314", "309", "301", "308", "300", "304", "310", "303", "305"], "fit_nums": ["897", "765", "433", "641", "967", "599", "984", "259", "853", "385", "887", "619", "434", "964", "483", "891"], "chain_num": 4}, "KS016": {"seeds": ["315", "301", "309", "313", "302", "307", "303", "308", "311", "312", "314", "306", "310", "300", "305", "304"], "fit_nums": ["99", "57", "585", "32", "501", "558", "243", "413", "59", "757", "463", "172", "524", "957", "909", "292"], "chain_num": 4}, "KS003": {"seeds": ["404", "407", "413", "403", "414", "405", "400", "401", "402", "410", "415", "408", "411", "409", "406", "412"], "fit_nums": ["846", "256", "845", "945", "293", "406", "420", "109", "690", "421", "54", "866", "784", "81", "997", "665"], "chain_num": 9}, "ibl_witten_19": {"seeds": ["315", "311", "307", "314", "308", "300", "305", "301", "313", "304", "302", "310", "306", "312", "309", "303"], "fit_nums": ["179", "951", "613", "6", "623", "382", "458", "504", "406", "554", "5", "631", "746", "817", "265", "328"], "chain_num": 4}, "SWC_022": {"seeds": ["411", "403", "414", "409", "407", "412", "410", "413", "415", "404", "405", "400", "402", "401", "408", "406"], "fit_nums": ["408", "884", "62", "962", "744", "854", "635", "70", "320", "952", "8", "67", "231", "381", "536", "962"], "chain_num": 9}, "KS022": {"seeds": ["315", "300", "314", "301", "303", "302", "306", "308", "305", "310", "313", "312", "304", "307", "311", "309"], "fit_nums": ["899", "681", "37", "957", "629", "637", "375", "980", "810", "51", "759", "664", "420", "127", "259", "555"], "chain_num": 4}, "CSH_ZAD_017": {"seeds": ["401", "409", "405", "403", "415", "404", "402", "411", "410", "414", "408", "406", "413", "412", "400", "407"], "fit_nums": ["883", "803", "637", "806", "356", "804", "662", "654", "684", "350", "947", "460", "569", "976", "103", "713"], "chain_num": 9}, "CSH_ZAD_025": {"seeds": ["303", "311", "307", "312", "313", "314", "308", "315", "305", "306", "304", "302", "309", "310", "301", "300"], "fit_nums": ["581", "148", "252", "236", "581", "838", "206", "756", "449", "288", "756", "593", "733", "633", "418", "563"], "chain_num": 4}, "ibl_witten_17": {"seeds": ["406", "415", "408", "413", "402", "405", "409", "400", "414", "401", "412", "407", "404", "410", "403", "411"], "fit_nums": ["827", "797", "496", "6", "444", "823", "384", "873", "634", "27", "811", "142", "207", "322", "756", "275"], "chain_num": 9}, "ibl_witten_18": {"seeds": ["311", "310", "303", "314", "302", "309", "305", "307", "312", "300", "308", "306", "315", "313", "304", "301"], "fit_nums": ["236", "26", "838", "762", "826", "409", "496", "944", "280", "704", "930", "419", "637", "896", "876", "297"], "chain_num": 4}, "CSHL_018": {"seeds": ["302", "310", "306", "300", "314", "307", "309", "313", "311", "308", "304", "301", "312", "303", "305", "315"], "fit_nums": ["843", "817", "920", "900", "226", "36", "472", "676", "933", "453", "116", "263", "269", "897", "568", "438"], "chain_num": 4}, "GLM_Sim_06": {"seeds": ["313", "309", "302", "303", "305", "314", "300", "315", "311", "306", "304", "310", "301", "312", "308", "307"], "fit_nums": ["9", "786", "286", "280", "72", "587", "619", "708", "360", "619", "311", "189", "60", "708", "939", "733"], "chain_num": 2}, "ZM_1897": {"seeds": ["304", "308", "305", "311", "315", "314", "307", "306", "300", "303", "313", "310", "301", "312", "302", "309"], "fit_nums": ["549", "96", "368", "509", "424", "897", "287", "426", "968", "93", "725", "513", "837", "581", "989", "374"], "chain_num": 4}, "CSHL_020": {"seeds": ["305", "309", "313", "302", "314", "310", "300", "307", "315", "306", "312", "304", "311", "301", "303", "308"], "fit_nums": ["222", "306", "243", "229", "584", "471", "894", "238", "986", "660", "494", "657", "896", "459", "100", "283"], "chain_num": 4}, "CSHL054": {"seeds": ["401", "415", "409", "410", "414", "413", "407", "405", "406", "408", "411", "400", "412", "402", "403", "404"], "fit_nums": ["901", "734", "609", "459", "574", "793", "978", "66", "954", "906", "954", "111", "292", "850", "266", "967"], "chain_num": 9}, "CSHL_014": {"seeds": ["305", "311", "309", "300", "313", "310", "307", "306", "304", "312", "308", "302", "314", "303", "301", "315"], "fit_nums": ["371", "550", "166", "24", "705", "385", "870", "884", "831", "546", "404", "722", "287", "564", "613", "783"], "chain_num": 4}, "CSHL062": {"seeds": ["307", "313", "310", "303", "306", "312", "308", "305", "311", "314", "304", "302", "300", "301", "315", "309"], "fit_nums": ["846", "371", "94", "888", "499", "229", "546", "432", "71", "989", "986", "91", "935", "314", "975", "481"], "chain_num": 4}, "CSH_ZAD_001": {"seeds": ["313", "309", "311", "312", "305", "310", "315", "300", "314", "304", "301", "302", "308", "303", "306", "307"], "fit_nums": ["468", "343", "314", "544", "38", "120", "916", "170", "305", "569", "502", "496", "452", "336", "559", "572"], "chain_num": 4}, "NYU-06": {"seeds": ["314", "309", "306", "305", "312", "303", "307", "304", "300", "302", "310", "301", "315", "308", "313", "311"], "fit_nums": ["950", "862", "782", "718", "427", "645", "827", "612", "821", "834", "595", "929", "679", "668", "648", "869"], "chain_num": 4}, "KS019": {"seeds": ["404", "401", "411", "408", "400", "403", "410", "413", "402", "407", "415", "409", "406", "414", "412", "405"], "fit_nums": ["682", "4", "264", "200", "250", "267", "737", "703", "132", "855", "922", "686", "85", "176", "54", "366"], "chain_num": 9}, "CSHL049": {"seeds": ["411", "402", "414", "408", "409", "410", "413", "407", "406", "401", "404", "405", "403", "415", "400", "412"], "fit_nums": ["104", "553", "360", "824", "749", "519", "347", "228", "863", "671", "140", "883", "701", "445", "627", "898"], "chain_num": 9}, "ibl_witten_14": {"seeds": ["310", "311", "304", "306", "300", "302", "314", "313", "303", "308", "301", "309", "305", "315", "312", "307"], "fit_nums": ["563", "120", "85", "712", "277", "871", "183", "661", "505", "598", "210", "89", "310", "638", "564", "998"], "chain_num": 4}, "KS014": {"seeds": ["301", "310", "302", "312", "313", "308", "307", "303", "305", "300", "314", "306", "311", "309", "304", "315"], "fit_nums": ["668", "32", "801", "193", "269", "296", "74", "24", "270", "916", "21", "250", "342", "451", "517", "293"], "chain_num": 4}, "CSHL059": {"seeds": ["306", "309", "300", "304", "314", "303", "315", "311", "313", "305", "301", "307", "302", "312", "310", "308"], "fit_nums": ["821", "963", "481", "999", "986", "45", "551", "605", "701", "201", "629", "261", "972", "407", "165", "9"], "chain_num": 4}, "GLM_Sim_13": {"seeds": ["310", "303", "308", "306", "300", "312", "301", "313", "305", "311", "315", "304", "314", "309", "307", "302"], "fit_nums": ["982", "103", "742", "524", "614", "370", "926", "456", "133", "143", "302", "80", "395", "549", "579", "944"], "chain_num": 2}, "CSHL_007": {"seeds": ["314", "303", "308", "313", "301", "300", "302", "305", "315", "306", "310", "309", "311", "304", "307", "312"], "fit_nums": ["462", "703", "345", "286", "480", "313", "986", "165", "201", "102", "322", "894", "960", "438", "330", "169"], "chain_num": 4}, "CSH_ZAD_011": {"seeds": ["314", "311", "303", "300", "305", "310", "306", "301", "302", "315", "304", "309", "308", "312", "313", "307"], "fit_nums": ["320", "385", "984", "897", "315", "120", "320", "945", "475", "403", "210", "412", "695", "564", "664", "411"], "chain_num": 4}, "KS021": {"seeds": ["309", "312", "304", "310", "303", "311", "314", "302", "305", "301", "306", "300", "308", "315", "313", "307"], "fit_nums": ["874", "943", "925", "587", "55", "136", "549", "528", "349", "211", "401", "84", "225", "545", "153", "382"], "chain_num": 4}, "GLM_Sim_15": {"seeds": ["303", "312", "305", "308", "309", "302", "301", "310", "313", "315", "311", "314", "307", "306", "304", "300"], "fit_nums": ["769", "930", "328", "847", "899", "714", "144", "518", "521", "873", "914", "359", "242", "343", "45", "364"], "chain_num": 2}, "CSHL_015": {"seeds": ["301", "302", "307", "310", "309", "311", "304", "312", "300", "308", "313", "305", "314", "315", "306", "303"], "fit_nums": ["717", "705", "357", "539", "604", "971", "669", "76", "45", "413", "510", "122", "190", "821", "368", "472"], "chain_num": 4}, "ibl_witten_16": {"seeds": ["304", "313", "309", "314", "312", "307", "305", "301", "306", "310", "300", "315", "308", "311", "303", "302"], "fit_nums": ["392", "515", "696", "270", "7", "583", "880", "674", "23", "576", "579", "695", "149", "854", "184", "875"], "chain_num": 4}, "KS015": {"seeds": ["315", "305", "309", "303", "314", "310", "311", "312", "313", "300", "307", "308", "304", "301", "302", "306"], "fit_nums": ["257", "396", "387", "435", "133", "164", "403", "8", "891", "650", "111", "557", "473", "229", "842", "196"], "chain_num": 4}, "GLM_Sim_12": {"seeds": ["304", "312", "306", "303", "310", "302", "300", "305", "308", "313", "307", "311", "315", "301", "314", "309"], "fit_nums": ["971", "550", "255", "195", "952", "486", "841", "535", "559", "37", "654", "213", "864", "506", "732", "550"], "chain_num": 2}, "GLM_Sim_11": {"seeds": ["300", "312", "310", "315", "302", "313", "314", "311", "308", "303", "309", "307", "306", "304", "301", "305"], "fit_nums": ["477", "411", "34", "893", "195", "293", "603", "5", "887", "281", "956", "73", "346", "640", "532", "688"], "chain_num": 2}, "GLM_Sim_10": {"seeds": ["301", "300", "306", "305", "307", "309", "312", "314", "311", "315", "304", "313", "303", "308", "302", "310"], "fit_nums": ["391", "97", "897", "631", "239", "652", "19", "448", "807", "35", "972", "469", "280", "562", "42", "706"], "chain_num": 2}, "CSH_ZAD_026": {"seeds": ["312", "313", "308", "310", "303", "307", "302", "305", "300", "315", "306", "301", "311", "304", "314", "309"], "fit_nums": ["699", "87", "537", "628", "797", "511", "459", "770", "969", "240", "504", "948", "295", "506", "25", "378"], "chain_num": 4}, "KS023": {"seeds": ["304", "313", "306", "309", "300", "314", "302", "310", "303", "315", "307", "308", "301", "311", "305", "312"], "fit_nums": ["698", "845", "319", "734", "908", "507", "45", "499", "175", "108", "419", "443", "116", "779", "159", "231"], "chain_num": 4}, "GLM_Sim_05": {"seeds": ["301", "315", "300", "302", "305", "304", "313", "314", "311", "309", "306", "307", "308", "310", "303", "312"], "fit_nums": ["425", "231", "701", "375", "343", "902", "623", "125", "921", "637", "393", "964", "678", "930", "796", "42"], "chain_num": 2}, "CSHL061": {"seeds": ["305", "315", "304", "303", "309", "310", "302", "300", "314", "306", "311", "313", "301", "308", "307", "312"], "fit_nums": ["396", "397", "594", "911", "308", "453", "686", "552", "103", "209", "128", "892", "345", "925", "777", "396"], "chain_num": 4}, "CSHL051": {"seeds": ["303", "310", "306", "302", "309", "305", "313", "308", "300", "314", "311", "307", "312", "304", "315", "301"], "fit_nums": ["69", "186", "49", "435", "103", "910", "705", "367", "303", "474", "596", "334", "929", "796", "616", "790"], "chain_num": 4}, "GLM_Sim_14": {"seeds": ["310", "311", "309", "313", "314", "300", "302", "304", "305", "306", "307", "312", "303", "301", "315", "308"], "fit_nums": ["616", "872", "419", "106", "940", "986", "599", "704", "218", "808", "244", "825", "448", "397", "552", "316"], "chain_num": 2}, "GLM_Sim_11_trick": {"seeds": ["411", "400", "408", "409", "415", "413", "410", "412", "406", "414", "403", "404", "401", "405", "407", "402"], "fit_nums": ["95", "508", "886", "384", "822", "969", "525", "382", "489", "436", "344", "537", "251", "223", "458", "401"], "chain_num": 2, "ignore": [10, 12, 4, 1, 0, 3, 2, 6]}, "GLM_Sim_16": {"seeds": ["302", "311", "303", "307", "313", "308", "309", "300", "305", "315", "304", "310", "312", "301", "314", "306"], "fit_nums": ["914", "377", "173", "583", "870", "456", "611", "697", "13", "713", "159", "248", "617", "37", "770", "780"], "chain_num": 2}, "ZM_3003": {"seeds": ["300", "304", "307", "312", "305", "310", "311", "314", "303", "308", "313", "301", "315", "309", "306", "302"], "fit_nums": ["603", "620", "657", "735", "357", "390", "119", "33", "62", "617", "209", "810", "688", "21", "744", "426"], "chain_num": 4}, "CSH_ZAD_022": {"seeds": ["305", "310", "311", "315", "303", "312", "314", "313", "307", "302", "300", "304", "301", "308", "306", "309"], "fit_nums": ["143", "946", "596", "203", "576", "403", "900", "65", "478", "325", "282", "513", "460", "42", "161", "970"], "chain_num": 4}, "GLM_Sim_07": {"seeds": ["300", "309", "302", "304", "305", "312", "301", "311", "315", "314", "308", "307", "303", "310", "306", "313"], "fit_nums": ["724", "701", "118", "230", "648", "426", "689", "114", "832", "731", "592", "519", "559", "938", "672", "144"], "chain_num": 1}, "KS017": {"seeds": ["311", "310", "306", "309", "303", "302", "308", "300", "313", "301", "314", "307", "315", "304", "312", "305"], "fit_nums": ["97", "281", "808", "443", "352", "890", "703", "468", "780", "708", "674", "27", "345", "23", "939", "457"], "chain_num": 4}, "GLM_Sim_11_sub": {"seeds": ["410", "414", "413", "404", "409", "415", "406", "408", "402", "411", "400", "405", "403", "407", "412", "401"], "fit_nums": ["830", "577", "701", "468", "929", "374", "954", "749", "937", "488", "873", "416", "612", "792", "461", "488"], "chain_num": 2}} \ No newline at end of file +{"SWC_023": {"seeds": ["302", "312", "304", "300", "315", "311", "308", "305", "303", "309", "306", "313", "307", "314", "301", "310"], "fit_nums": ["994", "913", "681", "816", "972", "790", "142", "230", "696", "537", "975", "773", "918", "677", "742", "745"], "chain_num": 4}, "SWC_021": {"seeds": ["415", "403", "412", "407", "409", "408", "405", "404", "410", "414", "401", "413", "402", "400", "406", "411"], "fit_nums": ["773", "615", "107", "583", "564", "354", "142", "184", "549", "185", "924", "907", "105", "531", "9", "812"], "chain_num": 9}, "ibl_witten_15": {"seeds": ["409", "410", "401", "415", "414", "403", "411", "404", "402", "405", "400", "412", "408", "407", "406", "413"], "fit_nums": ["411", "344", "496", "600", "716", "18", "527", "467", "898", "334", "309", "326", "133", "823", "740", "253"], "chain_num": 9}, "ibl_witten_13": {"seeds": ["302", "312", "313", "306", "315", "307", "311", "314", "309", "301", "308", "300", "304", "310", "303", "305"], "fit_nums": ["897", "765", "433", "641", "967", "599", "984", "259", "853", "385", "887", "619", "434", "964", "483", "891"], "chain_num": 4}, "KS016": {"seeds": ["315", "301", "309", "313", "302", "307", "303", "308", "311", "312", "314", "306", "310", "300", "305", "304"], "fit_nums": ["99", "57", "585", "32", "501", "558", "243", "413", "59", "757", "463", "172", "524", "957", "909", "292"], "chain_num": 4}, "KS003": {"seeds": ["404", "407", "413", "403", "414", "405", "400", "401", "402", "410", "415", "408", "411", "409", "406", "412"], "fit_nums": ["846", "256", "845", "945", "293", "406", "420", "109", "690", "421", "54", "866", "784", "81", "997", "665"], "chain_num": 9}, "ibl_witten_19": {"seeds": ["315", "311", "307", "314", "308", "300", "305", "301", "313", "304", "302", "310", "306", "312", "309", "303"], "fit_nums": ["179", "951", "613", "6", "623", "382", "458", "504", "406", "554", "5", "631", "746", "817", "265", "328"], "chain_num": 4}, "SWC_022": {"seeds": ["411", "403", "414", "409", "407", "412", "410", "413", "415", "404", "405", "400", "402", "401", "408", "406"], "fit_nums": ["408", "884", "62", "962", "744", "854", "635", "70", "320", "952", "8", "67", "231", "381", "536", "962"], "chain_num": 9}, "KS022": {"seeds": ["315", "300", "314", "301", "303", "302", "306", "308", "305", "310", "313", "312", "304", "307", "311", "309"], "fit_nums": ["899", "681", "37", "957", "629", "637", "375", "980", "810", "51", "759", "664", "420", "127", "259", "555"], "chain_num": 4}, "CSH_ZAD_017": {"seeds": ["401", "409", "405", "403", "415", "404", "402", "411", "410", "414", "408", "406", "413", "412", "400", "407"], "fit_nums": ["883", "803", "637", "806", "356", "804", "662", "654", "684", "350", "947", "460", "569", "976", "103", "713"], "chain_num": 9}, "CSH_ZAD_025": {"seeds": ["303", "311", "307", "312", "313", "314", "308", "315", "305", "306", "304", "302", "309", "310", "301", "300"], "fit_nums": ["581", "148", "252", "236", "581", "838", "206", "756", "449", "288", "756", "593", "733", "633", "418", "563"], "chain_num": 4}, "ibl_witten_17": {"seeds": ["406", "415", "408", "413", "402", "405", "409", "400", "414", "401", "412", "407", "404", "410", "403", "411"], "fit_nums": ["827", "797", "496", "6", "444", "823", "384", "873", "634", "27", "811", "142", "207", "322", "756", "275"], "chain_num": 9}, "ibl_witten_18": {"seeds": ["311", "310", "303", "314", "302", "309", "305", "307", "312", "300", "308", "306", "315", "313", "304", "301"], "fit_nums": ["236", "26", "838", "762", "826", "409", "496", "944", "280", "704", "930", "419", "637", "896", "876", "297"], "chain_num": 4}, "CSHL_018": {"seeds": ["302", "310", "306", "300", "314", "307", "309", "313", "311", "308", "304", "301", "312", "303", "305", "315"], "fit_nums": ["843", "817", "920", "900", "226", "36", "472", "676", "933", "453", "116", "263", "269", "897", "568", "438"], "chain_num": 4}, "GLM_Sim_06": {"seeds": ["313", "309", "302", "303", "305", "314", "300", "315", "311", "306", "304", "310", "301", "312", "308", "307"], "fit_nums": ["9", "786", "286", "280", "72", "587", "619", "708", "360", "619", "311", "189", "60", "708", "939", "733"], "chain_num": 2}, "ZM_1897": {"seeds": ["304", "308", "305", "311", "315", "314", "307", "306", "300", "303", "313", "310", "301", "312", "302", "309"], "fit_nums": ["549", "96", "368", "509", "424", "897", "287", "426", "968", "93", "725", "513", "837", "581", "989", "374"], "chain_num": 4}, "CSHL_020": {"seeds": ["305", "309", "313", "302", "314", "310", "300", "307", "315", "306", "312", "304", "311", "301", "303", "308"], "fit_nums": ["222", "306", "243", "229", "584", "471", "894", "238", "986", "660", "494", "657", "896", "459", "100", "283"], "chain_num": 4}, "CSHL054": {"seeds": ["401", "415", "409", "410", "414", "413", "407", "405", "406", "408", "411", "400", "412", "402", "403", "404"], "fit_nums": ["901", "734", "609", "459", "574", "793", "978", "66", "954", "906", "954", "111", "292", "850", "266", "967"], "chain_num": 9}, "CSHL_014": {"seeds": ["305", "311", "309", "300", "313", "310", "307", "306", "304", "312", "308", "302", "314", "303", "301", "315"], "fit_nums": ["371", "550", "166", "24", "705", "385", "870", "884", "831", "546", "404", "722", "287", "564", "613", "783"], "chain_num": 4}, "CSHL062": {"seeds": ["307", "313", "310", "303", "306", "312", "308", "305", "311", "314", "304", "302", "300", "301", "315", "309"], "fit_nums": ["846", "371", "94", "888", "499", "229", "546", "432", "71", "989", "986", "91", "935", "314", "975", "481"], "chain_num": 4}, "CSH_ZAD_001": {"seeds": ["313", "309", "311", "312", "305", "310", "315", "300", "314", "304", "301", "302", "308", "303", "306", "307"], "fit_nums": ["468", "343", "314", "544", "38", "120", "916", "170", "305", "569", "502", "496", "452", "336", "559", "572"], "chain_num": 4}, "NYU-06": {"seeds": ["314", "309", "306", "305", "312", "303", "307", "304", "300", "302", "310", "301", "315", "308", "313", "311"], "fit_nums": ["950", "862", "782", "718", "427", "645", "827", "612", "821", "834", "595", "929", "679", "668", "648", "869"], "chain_num": 4}, "KS019": {"seeds": ["404", "401", "411", "408", "400", "403", "410", "413", "402", "407", "415", "409", "406", "414", "412", "405"], "fit_nums": ["682", "4", "264", "200", "250", "267", "737", "703", "132", "855", "922", "686", "85", "176", "54", "366"], "chain_num": 9}, "CSHL049": {"seeds": ["411", "402", "414", "408", "409", "410", "413", "407", "406", "401", "404", "405", "403", "415", "400", "412"], "fit_nums": ["104", "553", "360", "824", "749", "519", "347", "228", "863", "671", "140", "883", "701", "445", "627", "898"], "chain_num": 9}, "ibl_witten_14": {"seeds": ["310", "311", "304", "306", "300", "302", "314", "313", "303", "308", "301", "309", "305", "315", "312", "307"], "fit_nums": ["563", "120", "85", "712", "277", "871", "183", "661", "505", "598", "210", "89", "310", "638", "564", "998"], "chain_num": 4}, "KS014": {"seeds": ["301", "310", "302", "312", "313", "308", "307", "303", "305", "300", "314", "306", "311", "309", "304", "315"], "fit_nums": ["668", "32", "801", "193", "269", "296", "74", "24", "270", "916", "21", "250", "342", "451", "517", "293"], "chain_num": 4, "ignore": [9, 11, 0, 1, 14, 2, 12, 13]}, "CSHL059": {"seeds": ["306", "309", "300", "304", "314", "303", "315", "311", "313", "305", "301", "307", "302", "312", "310", "308"], "fit_nums": ["821", "963", "481", "999", "986", "45", "551", "605", "701", "201", "629", "261", "972", "407", "165", "9"], "chain_num": 4}, "GLM_Sim_13": {"seeds": ["310", "303", "308", "306", "300", "312", "301", "313", "305", "311", "315", "304", "314", "309", "307", "302"], "fit_nums": ["982", "103", "742", "524", "614", "370", "926", "456", "133", "143", "302", "80", "395", "549", "579", "944"], "chain_num": 2}, "CSHL_007": {"seeds": ["314", "303", "308", "313", "301", "300", "302", "305", "315", "306", "310", "309", "311", "304", "307", "312"], "fit_nums": ["462", "703", "345", "286", "480", "313", "986", "165", "201", "102", "322", "894", "960", "438", "330", "169"], "chain_num": 4}, "CSH_ZAD_011": {"seeds": ["314", "311", "303", "300", "305", "310", "306", "301", "302", "315", "304", "309", "308", "312", "313", "307"], "fit_nums": ["320", "385", "984", "897", "315", "120", "320", "945", "475", "403", "210", "412", "695", "564", "664", "411"], "chain_num": 4}, "KS021": {"seeds": ["309", "312", "304", "310", "303", "311", "314", "302", "305", "301", "306", "300", "308", "315", "313", "307"], "fit_nums": ["874", "943", "925", "587", "55", "136", "549", "528", "349", "211", "401", "84", "225", "545", "153", "382"], "chain_num": 4}, "GLM_Sim_15": {"seeds": ["303", "312", "305", "308", "309", "302", "301", "310", "313", "315", "311", "314", "307", "306", "304", "300"], "fit_nums": ["769", "930", "328", "847", "899", "714", "144", "518", "521", "873", "914", "359", "242", "343", "45", "364"], "chain_num": 2}, "CSHL_015": {"seeds": ["301", "302", "307", "310", "309", "311", "304", "312", "300", "308", "313", "305", "314", "315", "306", "303"], "fit_nums": ["717", "705", "357", "539", "604", "971", "669", "76", "45", "413", "510", "122", "190", "821", "368", "472"], "chain_num": 4}, "ibl_witten_16": {"seeds": ["304", "313", "309", "314", "312", "307", "305", "301", "306", "310", "300", "315", "308", "311", "303", "302"], "fit_nums": ["392", "515", "696", "270", "7", "583", "880", "674", "23", "576", "579", "695", "149", "854", "184", "875"], "chain_num": 4}, "KS015": {"seeds": ["315", "305", "309", "303", "314", "310", "311", "312", "313", "300", "307", "308", "304", "301", "302", "306"], "fit_nums": ["257", "396", "387", "435", "133", "164", "403", "8", "891", "650", "111", "557", "473", "229", "842", "196"], "chain_num": 4}, "GLM_Sim_12": {"seeds": ["304", "312", "306", "303", "310", "302", "300", "305", "308", "313", "307", "311", "315", "301", "314", "309"], "fit_nums": ["971", "550", "255", "195", "952", "486", "841", "535", "559", "37", "654", "213", "864", "506", "732", "550"], "chain_num": 2}, "GLM_Sim_11": {"seeds": ["300", "312", "310", "315", "302", "313", "314", "311", "308", "303", "309", "307", "306", "304", "301", "305"], "fit_nums": ["477", "411", "34", "893", "195", "293", "603", "5", "887", "281", "956", "73", "346", "640", "532", "688"], "chain_num": 2}, "GLM_Sim_10": {"seeds": ["301", "300", "306", "305", "307", "309", "312", "314", "311", "315", "304", "313", "303", "308", "302", "310"], "fit_nums": ["391", "97", "897", "631", "239", "652", "19", "448", "807", "35", "972", "469", "280", "562", "42", "706"], "chain_num": 2}, "CSH_ZAD_026": {"seeds": ["312", "313", "308", "310", "303", "307", "302", "305", "300", "315", "306", "301", "311", "304", "314", "309"], "fit_nums": ["699", "87", "537", "628", "797", "511", "459", "770", "969", "240", "504", "948", "295", "506", "25", "378"], "chain_num": 4}, "KS023": {"seeds": ["304", "313", "306", "309", "300", "314", "302", "310", "303", "315", "307", "308", "301", "311", "305", "312"], "fit_nums": ["698", "845", "319", "734", "908", "507", "45", "499", "175", "108", "419", "443", "116", "779", "159", "231"], "chain_num": 4}, "GLM_Sim_05": {"seeds": ["301", "315", "300", "302", "305", "304", "313", "314", "311", "309", "306", "307", "308", "310", "303", "312"], "fit_nums": ["425", "231", "701", "375", "343", "902", "623", "125", "921", "637", "393", "964", "678", "930", "796", "42"], "chain_num": 2}, "CSHL061": {"seeds": ["305", "315", "304", "303", "309", "310", "302", "300", "314", "306", "311", "313", "301", "308", "307", "312"], "fit_nums": ["396", "397", "594", "911", "308", "453", "686", "552", "103", "209", "128", "892", "345", "925", "777", "396"], "chain_num": 4}, "CSHL051": {"seeds": ["303", "310", "306", "302", "309", "305", "313", "308", "300", "314", "311", "307", "312", "304", "315", "301"], "fit_nums": ["69", "186", "49", "435", "103", "910", "705", "367", "303", "474", "596", "334", "929", "796", "616", "790"], "chain_num": 4}, "GLM_Sim_14": {"seeds": ["310", "311", "309", "313", "314", "300", "302", "304", "305", "306", "307", "312", "303", "301", "315", "308"], "fit_nums": ["616", "872", "419", "106", "940", "986", "599", "704", "218", "808", "244", "825", "448", "397", "552", "316"], "chain_num": 2}, "GLM_Sim_11_trick": {"seeds": ["411", "400", "408", "409", "415", "413", "410", "412", "406", "414", "403", "404", "401", "405", "407", "402"], "fit_nums": ["95", "508", "886", "384", "822", "969", "525", "382", "489", "436", "344", "537", "251", "223", "458", "401"], "chain_num": 2, "ignore": [10, 12, 4, 1, 0, 3, 2, 6]}, "GLM_Sim_16": {"seeds": ["302", "311", "303", "307", "313", "308", "309", "300", "305", "315", "304", "310", "312", "301", "314", "306"], "fit_nums": ["914", "377", "173", "583", "870", "456", "611", "697", "13", "713", "159", "248", "617", "37", "770", "780"], "chain_num": 2}, "ZM_3003": {"seeds": ["300", "304", "307", "312", "305", "310", "311", "314", "303", "308", "313", "301", "315", "309", "306", "302"], "fit_nums": ["603", "620", "657", "735", "357", "390", "119", "33", "62", "617", "209", "810", "688", "21", "744", "426"], "chain_num": 4}, "CSH_ZAD_022": {"seeds": ["305", "310", "311", "315", "303", "312", "314", "313", "307", "302", "300", "304", "301", "308", "306", "309"], "fit_nums": ["143", "946", "596", "203", "576", "403", "900", "65", "478", "325", "282", "513", "460", "42", "161", "970"], "chain_num": 4}, "GLM_Sim_07": {"seeds": ["300", "309", "302", "304", "305", "312", "301", "311", "315", "314", "308", "307", "303", "310", "306", "313"], "fit_nums": ["724", "701", "118", "230", "648", "426", "689", "114", "832", "731", "592", "519", "559", "938", "672", "144"], "chain_num": 1}, "KS017": {"seeds": ["311", "310", "306", "309", "303", "302", "308", "300", "313", "301", "314", "307", "315", "304", "312", "305"], "fit_nums": ["97", "281", "808", "443", "352", "890", "703", "468", "780", "708", "674", "27", "345", "23", "939", "457"], "chain_num": 4}, "GLM_Sim_11_sub": {"seeds": ["410", "414", "413", "404", "409", "415", "406", "408", "402", "411", "400", "405", "403", "407", "412", "401"], "fit_nums": ["830", "577", "701", "468", "929", "374", "954", "749", "937", "488", "873", "416", "612", "792", "461", "488"], "chain_num": 2}} \ No newline at end of file diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py index 3fefa03f85e69e152beae64b51db2d3d37475376..7d57a217bcd38325649ce5ada58a2e37f48fea2d 100644 --- a/dyn_glm_chain_analysis.py +++ b/dyn_glm_chain_analysis.py @@ -1,5 +1,4 @@ import matplotlib -# matplotlib.use('Agg') import os os.environ["OMP_NUM_THREADS"] = "6" # export OMP_NUM_THREADS=4 os.environ["OPENBLAS_NUM_THREADS"] = "6" # export OPENBLAS_NUM_THREADS=4 @@ -12,7 +11,7 @@ import pyhsmm import pickle import seaborn as sns import sys -from scipy.stats import rankdata, norm, zscore +from scipy.stats import zscore from scipy.optimize import minimize from itertools import combinations, product import matplotlib.gridspec as gridspec @@ -20,6 +19,7 @@ from matplotlib.ticker import MaxNLocator import json import time import multiprocessing as mp +from mcmc_chain_analysis import state_size_helper, state_num_helper, gamma_func, alpha_func, ll_func, r_hat_array_comp, rank_inv_normal_transform, eval_r_hat, eval_simple_r_hat colors = np.genfromtxt('colors.csv', delimiter=',') @@ -69,7 +69,7 @@ class MCMC_result_list: def state_num_dist(self): state_nums = np.zeros((self.m, self.n)) for i in range(self.m): - state_nums[i] = state_num_func(self.results[i]) + state_nums[i] = state_num_helper(0.05)(self.results[i]) return state_nums def return_chains(self, func, model_for_loop, rank_norm=True, mode_indices=None): @@ -149,14 +149,11 @@ class MCMC_result_list: counter += 1 step = len(ind) // min_len up_to = min_len * step - # print([j + i * 640 for j in ind[:up_to:step]]) chains[counter] = func(self.results[i], ind[:up_to:step]) - # print(chains) - # print(np.unique(chains.flatten(), return_counts=True)) + self.chains = chains - self.folded_chains = np.abs(self.chains - np.median(chains)) - self.rank_inv_normal_transform() + self.rank_normalised, self.folded_rank_normalised, self.ranked, self.folded_ranked = rank_inv_normal_transform(self.chains) self.lame_r_hat, self.lame_var_hat_plus = r_hat_array_comp(self.chains) self.rank_normalised_r_hat, self.var_hat_plus = r_hat_array_comp(self.rank_normalised) @@ -178,15 +175,6 @@ class MCMC_result_list: print("Effective number of samples is {}".format(self.n_eff)) return chains - def rank_inv_normal_transform(self): - # Gelman paper Rank-normalization, folding, and localization: An improved R_hat for assessing convergence of MCMC - # ranking with average rank for ties - self.ranked = rankdata(self.chains).reshape(self.chains.shape) - self.folded_ranked = rankdata(self.folded_chains).reshape(self.folded_chains.shape) - # inverse normal with fractional offset - self.rank_normalised = norm.ppf((self.ranked - 3/8) / (self.chains.size + 1/4)) - self.folded_rank_normalised = norm.ppf((self.folded_ranked - 3/8) / (self.folded_chains.size + 1/4)) - def rank_histos(self): count_max = 0 for i in range(self.m): @@ -385,42 +373,6 @@ def state_appear_and_dur(m, n_sessions): return state_appear, [(1 + session_dict[s][1] - session_dict[s][0]) / n_sessions for s in observed_states] -def r_hat_array_comp(chains): - m, n = chains.shape - psi_dot_j = np.mean(chains, axis=1) - psi_dot_dot = np.mean(psi_dot_j) - B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot) ** 2) - s_j_squared = np.sum((chains - psi_dot_j[:, None]) ** 2, axis=1) / (n - 1) - W = np.mean(s_j_squared) - var_hat_plus = (n - 1) / n * W + B / n - if W == 0: - # print("all the same value") - return 1, 0 - r_hat = np.sqrt(var_hat_plus / W) - return r_hat, var_hat_plus - - -def eval_amortized_r_hat(chains, psi_dot_j, s_j_squared, l, m, n): - psi_dot_dot = np.mean(psi_dot_j, axis=1) - B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot[:, None]) ** 2, axis=1) - W = np.mean(s_j_squared, axis=1) - var_hat_plus = (n - 1) / n * W + B / n - r_hat = np.sqrt(var_hat_plus / W) - return max(r_hat) - - -def r_hat_array_comp_mult(chains): - l, m, n = chains.shape - psi_dot_j = np.mean(chains, axis=2) - psi_dot_dot = np.mean(psi_dot_j, axis=1) - B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot[:, None]) ** 2, axis=1) - s_j_squared = np.sum((chains - psi_dot_j[:, :, None]) ** 2, axis=2) / (n - 1) - W = np.mean(s_j_squared, axis=1) - var_hat_plus = (n - 1) / n * W + B / n - r_hat = np.sqrt(var_hat_plus / W) - return r_hat, var_hat_plus - - class MCMC_result: def __init__(self, models, infos, data, sessions, fit_variance, save_id, seq_start=0, dur='yes', sample_lls=None): @@ -533,21 +485,6 @@ class MCMC_result: return np.array(glm_weights) -def gamma_func(x): return x.trans_distn.gamma - - -def alpha_func(x): return x.trans_distn.alpha - - -def largest_state_func(x): return x.assign_counts.max(1) - - -def state_num_func(x): return ((x.assign_counts / x.n_datapoints) > 0.05).sum(1) - - -def ll_func(x): return x.sample_lls[-x.n_samples:] - - def find_good_chains(chains, reduce_to=8): delete_n = chains.shape[0] // 2 - reduce_to mins = np.zeros(1 + delete_n) @@ -943,6 +880,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show ax0.plot(1 + 0.25, 0.6 + 0.2, 'ko', ms=18) ax0.plot(1 + 0.25, 0.6 + 0.2, 'wo', ms=16.8) ax0.plot(1 + 0.25, 0.6 + 0.2, 'ko', ms=16.8, alpha=abs(num_to_cont[1])) + if test.results[0].type != 'bias': current, counter = 0, 0 for c in [2, 3, 4, 5]: @@ -1095,12 +1033,30 @@ def state_development(test, state_sets, indices, save=True, save_append='', show durs, state_types = state_type_durs(states_by_session, all_pmfs) dur_counter = 1 + contrast_intro_types = [0, 0, 0, 0] + state, when = np.where(states_by_session > 0.05) + introductions_by_stage = np.zeros(3) + covered_states = [] for i, d in enumerate(durs): ax0.fill_between(range(dur_counter, 1 + dur_counter + d), 0.5, -0.5, color=type_colours[i], zorder=0, alpha=0.3) dur_counter += d - ax2.set_title('Psychometric\nfunction', size=16) + # find out during which state type which contrast was introduced + for j, contrast in enumerate([2, 3, 4, 5]): + if contrast_intro_types[j] != 0: + continue + if test.results[0].infos[contrast] + 1 < dur_counter: + contrast_intro_types[j] = i+1 + + # find out during which stage which state was introduced + for s in range(len(state_sets)): + if np.sum(state == s) == 0 or s in covered_states: + continue + if when[state == s][0] + 1 < dur_counter: + introductions_by_stage[i] += 1 + covered_states.append(s) + ax2.set_title('Psychometric\nfunction', size=16) ax1.set_ylabel('Proportion of trials', size=28, labelpad=-20) ax0.set_ylabel('% correct', size=18) ax2.set_ylabel('Probability', size=26, labelpad=-20) @@ -1142,68 +1098,13 @@ def state_development(test, state_sets, indices, save=True, save_append='', show else: plt.close() - return states_by_session, all_pmfs, durs, state_types - - -def comp_multi_r_hat(chains, rank_normalised, folded_rank_normalised): - lame_r_hat, _ = r_hat_array_comp(chains) - rank_normalised_r_hat, _ = r_hat_array_comp(rank_normalised) - folded_rank_normalised_r_hat, _ = r_hat_array_comp(folded_rank_normalised) - return max(lame_r_hat, rank_normalised_r_hat, folded_rank_normalised_r_hat) - - -def eval_r_hat(chains1, chains2, chains3, chains4): - rank_normalised1, folded_rank_normalised1 = rank_inv_normal_transform(chains1) - rank_normalised2, folded_rank_normalised2 = rank_inv_normal_transform(chains2) - rank_normalised3, folded_rank_normalised3 = rank_inv_normal_transform(chains3) - rank_normalised4, folded_rank_normalised4 = rank_inv_normal_transform(chains4) - - r_hat1 = comp_multi_r_hat(chains1, rank_normalised1, folded_rank_normalised1) - r_hat2 = comp_multi_r_hat(chains2, rank_normalised2, folded_rank_normalised2) - r_hat3 = comp_multi_r_hat(chains3, rank_normalised3, folded_rank_normalised3) - r_hat4 = comp_multi_r_hat(chains4, rank_normalised4, folded_rank_normalised4) - - return max(r_hat1, r_hat2, r_hat3, r_hat4) - - -def eval_simple_r_hat(chains): + return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)) - r_hats, _ = r_hat_array_comp_mult(chains) - - return max(r_hats) - - -def find_good_chains_unsplit(chains1, chains2, chains3, chains4, reduce_to=8, simple=False): - delete_n = - reduce_to + chains1.shape[0] - mins = np.zeros(delete_n + 1) - n_chains = chains1.shape[0] - chains = np.stack([chains1, chains2, chains3, chains4]) - - print("Without removals: {}".format(eval_simple_r_hat(chains))) - if simple: - r_hat = eval_simple_r_hat(chains) - else: - r_hat = eval_r_hat(chains1, chains2, chains3, chains4) - mins[0] = r_hat - - for i in range(delete_n): - print() - r_hat_min = 10 - sol = 0 - for x in combinations(range(n_chains), n_chains - 1 - i): - if simple: - r_hat = eval_simple_r_hat(np.delete(chains, x, axis=1)) - else: - r_hat = eval_r_hat(np.delete(chains1, x, axis=0), np.delete(chains2, x, axis=0), np.delete(chains3, x, axis=0), np.delete(chains4, x, axis=0)) - if r_hat < r_hat_min: - sol = x - r_hat_min = min(r_hat, r_hat_min) - print("Minimum is {} (removed {})".format(r_hat_min, i + 1)) - sol = [i for i in range(32) if i not in sol] - print("Removed: {}".format(sol)) - mins[i + 1] = r_hat_min - - return sol, r_hat_min +def smart_divide(a, b): + c = np.zeros_like(a) + d = np.logical_and(a == 0, b == 0) + c[~d] = a[~d] / b[~d] + return c def find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_to=8, simple=False): @@ -1212,8 +1113,8 @@ def find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_t n_chains = chains1.shape[0] chains = np.stack([chains1, chains2, chains3, chains4]) - print("Without removals: {}".format(eval_r_hat(chains1, chains2, chains3, chains4))) - r_hat = eval_r_hat(chains1, chains2, chains3, chains4) + r_hat = eval_r_hat([chains1, chains2, chains3, chains4]) + print("Without removals: {}".format(r_hat)) mins[0] = r_hat to_del = [] @@ -1224,7 +1125,7 @@ def find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_t if x in to_del: continue if not simple: - r_hat = eval_r_hat(np.delete(chains1, to_del + [x], axis=0), np.delete(chains2, to_del + [x], axis=0), np.delete(chains3, to_del + [x], axis=0), np.delete(chains4, to_del + [x], axis=0)) + r_hat = eval_r_hat([np.delete(chains1, to_del + [x], axis=0), np.delete(chains2, to_del + [x], axis=0), np.delete(chains3, to_del + [x], axis=0), np.delete(chains4, to_del + [x], axis=0)]) else: r_hat = eval_simple_r_hat(np.delete(chains, to_del + [x], axis=1)) if r_hat < r_hat_min: @@ -1237,23 +1138,11 @@ def find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_t mins[i + 1] = r_hat_min if simple: - r_hat_local = eval_r_hat(np.delete(chains1, to_del, axis=0), np.delete(chains2, to_del, axis=0), np.delete(chains3, to_del, axis=0), np.delete(chains4, to_del, axis=0)) + r_hat_local = eval_r_hat([np.delete(chains1, to_del, axis=0), np.delete(chains2, to_del, axis=0), np.delete(chains3, to_del, axis=0), np.delete(chains4, to_del, axis=0)]) print("Minimum over everything is {} (removed {})".format(r_hat_local, i + 1)) return to_del, r_hat_min -def rank_inv_normal_transform(chains): - # Gelman paper Rank-normalization, folding, and localization: An improved R_hat for assessing convergence of MCMC - # ranking with average rank for ties - folded_chains = np.abs(chains - np.median(chains)) - ranked = rankdata(chains).reshape(chains.shape) - folded_ranked = rankdata(folded_chains).reshape(folded_chains.shape) - # inverse normal with fractional offset - rank_normalised = norm.ppf((ranked - 3/8) / (chains.size + 1/4)) - folded_rank_normalised = norm.ppf((folded_ranked - 3/8) / (folded_chains.size + 1/4)) - return rank_normalised, folded_rank_normalised - - if __name__ == "__main__": fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0] if fit_type == 'bias': @@ -1268,16 +1157,13 @@ if __name__ == "__main__": # test.r_hat_and_ess(return_ascending, False) # test.r_hat_and_ess(return_ascending_shuffled, False) # quit() - - import pyhsmm.util.profiling as prof - good = [] bad = [] - check_r_hats = False + check_r_hats = True if check_r_hats: subjects = list(loading_info.keys()) - subjects = ['GLM_Sim_11_trick'] + subjects = ['KS014'] for subject in subjects: # if subject.startswith('GLM'): # continue @@ -1322,24 +1208,6 @@ def dist_helper(dist_matrix, state_hists, inds): return dist_matrix -def state_size_helper(n=0, mode_specific=False): - if not mode_specific: - def nth_largest_state_func(x): - return np.partition(x.assign_counts, -1 - n, axis=1)[:, -1 - n] - else: - def nth_largest_state_func(x, ind): - return np.partition(x.assign_counts[ind], -1 - n, axis=1)[:, -1 - n] - return nth_largest_state_func - - -def state_num_helper(t, mode_specific=False): - if not mode_specific: - def state_num_func(x): return ((x.assign_counts / x.n_datapoints) > t).sum(1) - else: - def state_num_func(x, ind): return ((x.assign_counts[ind] / x.n_datapoints) > t).sum(1) - return state_num_func - - def state_glm_func_helper(t, mode_specific=False): if not mode_specific: def temp_state_glm_func(x): @@ -1364,26 +1232,6 @@ def state_glm_func_helper(t, mode_specific=False): return temp_state_glm_func -def sample_statistics(test, mode_indices): - # prints out r_hats and sample sizes for given sample - test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb')) - test.r_hat_and_ess(state_size_helper(1), False) - test.r_hat_and_ess(state_size_helper(1, mode_specific=True), False, mode_indices=mode_indices) - print() - test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb')) - test.r_hat_and_ess(state_size_helper(), False) - test.r_hat_and_ess(state_size_helper(mode_specific=True), False, mode_indices=mode_indices) - print() - test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb')) - test.r_hat_and_ess(state_num_helper(0.05), False) - test.r_hat_and_ess(state_num_helper(0.05, mode_specific=True), False, mode_indices=mode_indices) - print() - test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb')) - test.r_hat_and_ess(state_num_helper(0.02), False) - test.r_hat_and_ess(state_num_helper(0.02, mode_specific=True), False, mode_indices=mode_indices) - print() - - def state_type_durs(states, pmfs): # Takes states and pmfs, first creates an array of when which type is how active, then computes the number of sessions each type lasts. # A type lasts until a more advanced type takes up more than 50% of a session (and cannot return) @@ -1452,6 +1300,22 @@ def pmf_type(pmf): if __name__ == "__main__": + + # visualise pmf types + lapses = [0.1, 0.2, 0.25, 0.33, 0.4, 0.45, 0.5, 0.55, 0.66, 0.9] + test_pmf = np.zeros(4) + for i, lapse_l in enumerate(lapses): + plt.subplot(1, 10, 1+i) + if i != 0: + plt.gca().set_yticklabels([]) + plt.ylim(0, 1) + test_pmf[:2] = lapse_l + for lapse_r in np.linspace(0.02, 0.98, 33): + test_pmf[2:] = lapse_r + plt.plot([0, 1, 9, 10], test_pmf, c=type_colours[pmf_type(test_pmf)]) + plt.show() + quit() + fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0] if fit_type == 'bias': loading_info = json.load(open("canonical_infos_bias.json", 'r')) @@ -1505,6 +1369,8 @@ if __name__ == "__main__": num_sessions = [] state_appear = [] state_dur = [] + contrast_intro_types = [] # list to agglomorate in which state type which contrast is introduced + intros_by_type_sum = np.zeros(3) # array to agglomorate how many states where introduced during which type, normalised by length of phase n_points = 150 state_trajs = np.zeros((3, n_points)) @@ -1530,12 +1396,14 @@ if __name__ == "__main__": mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb')) state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb')) - # lapse differential # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices) # training overview - states, pmfs, durs, _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=True) + states, pmfs, durs, _, contrast_intro_type, intros_by_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=True) + intros_by_type_sum += intros_by_type + continue + contrast_intro_types.append(contrast_intro_type) # state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False) # session overview @@ -1794,12 +1662,6 @@ if __name__ == "__main__": plotSimplex(np.array(abs_state_durs), c='k', show=True) - plt.hist(abs_state_durs, label=['1st type', '2nd type', '3rd type']) - plt.legend() - plt.tight_layout() - plt.savefig("rel state type hists") - plt.show() - if False: ax[0].set_ylim(0, 1) ax[1].set_ylim(0, 1) diff --git a/dyn_glm_chain_analysis_unused_funcs.py b/dyn_glm_chain_analysis_unused_funcs.py index 17daace56bc61e2dac54bc3d559fe7711872a3be..c2e51ba48a8682acd2c2c49850a80e7e2b0072d0 100644 --- a/dyn_glm_chain_analysis_unused_funcs.py +++ b/dyn_glm_chain_analysis_unused_funcs.py @@ -99,6 +99,40 @@ def find_good_chains_unsplit_fast(chains1, chains2, chains3, chains4, reduce_to= return sol, r_hat_min +def find_good_chains_unsplit(chains1, chains2, chains3, chains4, reduce_to=8, simple=False): + delete_n = - reduce_to + chains1.shape[0] + mins = np.zeros(delete_n + 1) + n_chains = chains1.shape[0] + chains = np.stack([chains1, chains2, chains3, chains4]) + + print("Without removals: {}".format(eval_simple_r_hat(chains))) + if simple: + r_hat = eval_simple_r_hat(chains) + else: + r_hat = eval_r_hat(chains1, chains2, chains3, chains4) + mins[0] = r_hat + + for i in range(delete_n): + print() + r_hat_min = 10 + sol = 0 + for x in combinations(range(n_chains), n_chains - 1 - i): + if simple: + r_hat = eval_simple_r_hat(np.delete(chains, x, axis=1)) + else: + r_hat = eval_r_hat(np.delete(chains1, x, axis=0), np.delete(chains2, x, axis=0), np.delete(chains3, x, axis=0), np.delete(chains4, x, axis=0)) + if r_hat < r_hat_min: + sol = x + r_hat_min = min(r_hat, r_hat_min) + print("Minimum is {} (removed {})".format(r_hat_min, i + 1)) + sol = [i for i in range(32) if i not in sol] + print("Removed: {}".format(sol)) + mins[i + 1] = r_hat_min + + return sol, r_hat_min + + + def params_to_pmf(params): return params[2] + (1 - params[2] - params[3]) / (1 + np.exp(- params[0] * (all_conts - params[1]))) @@ -214,7 +248,7 @@ if __name__ == "__main__": plt.savefig("pmf fit scatter") plt.show() - + # New things xy = np.vstack([short_pmfs[:, 0], function_range, short_pmfs[:, 1]]) z = gaussian_kde(xy)(xy) plt.figure(figsize=(24, 24 / 3)) diff --git a/mcmc_chain_analysis.py b/mcmc_chain_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..2a99a67602392e961761ef6ead4b9f40650e0d93 --- /dev/null +++ b/mcmc_chain_analysis.py @@ -0,0 +1,123 @@ +""" + Functions to extract statistics from a set of chains + Functions to compute R^hat from a set of statistic vectors +""" +import numpy as np +from scipy.stats import rankdata, norm +import pickle + + +def state_size_helper(n=0, mode_specific=False): + if not mode_specific: + def nth_largest_state_func(x): + return np.partition(x.assign_counts, -1 - n, axis=1)[:, -1 - n] + else: + def nth_largest_state_func(x, ind): + return np.partition(x.assign_counts[ind], -1 - n, axis=1)[:, -1 - n] + return nth_largest_state_func + + +def state_num_helper(t, mode_specific=False): + if not mode_specific: + def state_num_func(x): return ((x.assign_counts / x.n_datapoints) > t).sum(1) + else: + def state_num_func(x, ind): return ((x.assign_counts[ind] / x.n_datapoints) > t).sum(1) + return state_num_func + + +def gamma_func(x): return x.trans_distn.gamma + + +def alpha_func(x): return x.trans_distn.alpha + + +def ll_func(x): return x.sample_lls[-x.n_samples:] + + +def r_hat_array_comp(chains): + m, n = chains.shape + psi_dot_j = np.mean(chains, axis=1) + psi_dot_dot = np.mean(psi_dot_j) + B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot) ** 2) + s_j_squared = np.sum((chains - psi_dot_j[:, None]) ** 2, axis=1) / (n - 1) + W = np.mean(s_j_squared) + var_hat_plus = (n - 1) / n * W + B / n + if W == 0: + # print("all the same value") + return 1, 0 + r_hat = np.sqrt(var_hat_plus / W) + return r_hat, var_hat_plus + + +def eval_amortized_r_hat(chains, psi_dot_j, s_j_squared, m, n): + psi_dot_dot = np.mean(psi_dot_j, axis=1) + B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot[:, None]) ** 2, axis=1) + W = np.mean(s_j_squared, axis=1) + var_hat_plus = (n - 1) / n * W + B / n + r_hat = np.sqrt(var_hat_plus / W) + return max(r_hat) + + +def r_hat_array_comp_mult(chains): + _, m, n = chains.shape + psi_dot_j = np.mean(chains, axis=2) + psi_dot_dot = np.mean(psi_dot_j, axis=1) + B = n / (m - 1) * np.sum((psi_dot_j - psi_dot_dot[:, None]) ** 2, axis=1) + s_j_squared = np.sum((chains - psi_dot_j[:, :, None]) ** 2, axis=2) / (n - 1) + W = np.mean(s_j_squared, axis=1) + var_hat_plus = (n - 1) / n * W + B / n + r_hat = np.sqrt(var_hat_plus / W) + return r_hat, var_hat_plus + + +def rank_inv_normal_transform(chains): + # Gelman paper Rank-normalization, folding, and localization: An improved R_hat for assessing convergence of MCMC + # ranking with average rank for ties + folded_chains = np.abs(chains - np.median(chains)) + ranked = rankdata(chains).reshape(chains.shape) + folded_ranked = rankdata(folded_chains).reshape(folded_chains.shape) + # inverse normal with fractional offset + rank_normalised = norm.ppf((ranked - 3/8) / (chains.size + 1/4)) + folded_rank_normalised = norm.ppf((folded_ranked - 3/8) / (folded_chains.size + 1/4)) + return rank_normalised, folded_rank_normalised, ranked, folded_ranked + + +def eval_r_hat(chains): + r_hats = [] + for chain in chains: + rank_normalised, folded_rank_normalised, _, _ = rank_inv_normal_transform(chain) + r_hats.append(comp_multi_r_hat(chain, rank_normalised, folded_rank_normalised)) + + return max(r_hats) + + +def eval_simple_r_hat(chains): + r_hats, _ = r_hat_array_comp_mult(chains) + return max(r_hats) + + +def comp_multi_r_hat(chains, rank_normalised, folded_rank_normalised): + lame_r_hat, _ = r_hat_array_comp(chains) + rank_normalised_r_hat, _ = r_hat_array_comp(rank_normalised) + folded_rank_normalised_r_hat, _ = r_hat_array_comp(folded_rank_normalised) + return max(lame_r_hat, rank_normalised_r_hat, folded_rank_normalised_r_hat) + + +def sample_statistics(test, mode_indices, subject): + # prints out r_hats and sample sizes for given sample + test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb')) + test.r_hat_and_ess(state_size_helper(1), False) + test.r_hat_and_ess(state_size_helper(1, mode_specific=True), False, mode_indices=mode_indices) + print() + test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb')) + test.r_hat_and_ess(state_size_helper(), False) + test.r_hat_and_ess(state_size_helper(mode_specific=True), False, mode_indices=mode_indices) + print() + test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb')) + test.r_hat_and_ess(state_num_helper(0.05), False) + test.r_hat_and_ess(state_num_helper(0.05, mode_specific=True), False, mode_indices=mode_indices) + print() + test = pickle.load(open("multi_chain_saves/canonical_result_{}.p".format(subject), 'rb')) + test.r_hat_and_ess(state_num_helper(0.02), False) + test.r_hat_and_ess(state_num_helper(0.02, mode_specific=True), False, mode_indices=mode_indices) + print() diff --git a/process_many_chains.py b/process_many_chains.py index 7aba22fbac5f7b599568df6aea157e10ca1fb152..88840671866206bcdee29ba09c549eb8181bae00 100644 --- a/process_many_chains.py +++ b/process_many_chains.py @@ -1,3 +1,7 @@ +""" + Script for taking a list of subects and extracting statistics from the chains + which can be used to assess which chains have converged to the same regions +""" import numpy as np import pyhsmm import pickle @@ -5,25 +9,7 @@ import json from dyn_glm_chain_analysis import MCMC_result import matplotlib.pyplot as plt import time - -def gamma_func(x): return x.trans_distn.gamma - - -def alpha_func(x): return x.trans_distn.alpha - - -def state_size_helper(n=0): - def nth_largest_state_func(x): - return np.partition(x.assign_counts, -1 - n, axis=1)[:, -1 - n] - return nth_largest_state_func - - -def state_num_helper(t): - def state_num_func(x): return ((x.assign_counts / x.n_datapoints) > t).sum(1) - return state_num_func - - -def ll_func(x): return x.sample_lls[-x.n_samples:] +from mcmc_chain_analysis import state_size_helper, state_num_helper, ll_helper fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0] @@ -45,117 +31,79 @@ func8 = state_size_helper(4) func9 = state_size_helper(5) func10 = state_size_helper(6) func11 = state_size_helper(7) -model_for_loop = False dur = 'yes' -def temp(): - m = 16 - for subject in subjects: - print(subject) - n_runs = -1 - counter = -1 - n = (loading_info[subject]['chain_num'] + 1) * 4000 // 25 - chains1 = np.zeros((m, n)) - chains2 = np.zeros((m, n)) - chains3 = np.zeros((m, n)) - chains4 = np.zeros((m, n)) - for j, (seed, fit_num) in enumerate(zip(loading_info[subject]['seeds'], loading_info[subject]['fit_nums'])): - counter += 1 - print(seed) - info_dict = pickle.load(open("./session_data/{}_info_dict.p".format(subject), "rb")) - samples = [] - mini_counter = 0 - while True: - try: - file = "./dynamic_GLMiHMM_crossvals/{}_fittype_{}_var_{}_{}_{}{}.p".format(subject, fit_type, fit_variance, seed, fit_num, '_{}'.format(mini_counter)) - lala = time.time() - samples += pickle.load(open(file, "rb")) - print("Loading {} took {:.4}".format(mini_counter, time.time() - lala)) - mini_counter += 1 - except Exception: - break - if n_runs == -1: - n_runs = mini_counter - else: - if n_runs != mini_counter: - print("Problem") - print(n_runs, mini_counter) - quit() - save_id = "{}_fittype_{}_var_{}_{}_{}.p".format(subject, fit_type, fit_variance, seed, fit_num).replace('.', '_') - # for file in os.listdir("./dynamic_GLMiHMM_crossvals/infos_new/"): - # if file.startswith("{}_".format(subject)) and file.endswith("_{}_{}_{}_{}.json".format(fit_type, fit_variance, seed, fit_num)): - # print("Taking fit infos from {}".format(file)) - # fit_infos = json.load(open("./dynamic_GLMiHMM_crossvals/infos_new/" + file, 'r')) - # sample_lls = fit_infos['ll'] - - print("loaded seed {}".format(seed)) - - lala = time.time() - result = MCMC_result(samples[::25], - infos=info_dict, data=samples[0].datas, - sessions=fit_type, fit_variance=fit_variance, - dur=dur, save_id=save_id) # , sample_lls=sample_lls) - print("Making result {} took {:.4}".format(counter, time.time() - lala)) - - # if model_for_loop: - # for i in range(n): - # chains[j, i] = func(result.models[i]) - res = func1(result) - chains1[counter] = res - res = func2(result) - chains2[counter] = res - res = func3(result) - chains3[counter] = res - res = func4(result) - chains4[counter] = res - - # res = func5(result) - # chains5[j] = res - # res = func6(result) - # chains6[j] = res - # res = func7(result) - # chains7[j] = res - # res = func8(result) - # chains8[j] = res - # res = func9(result) - # chains9[j] = res - # res = func10(result) - # chains10[j] = res - # res = func11(result) - # chains11[j] = res - - # func2 = state_size_helper() - # func5 = state_size_helper(1) - # func6 = state_size_helper(2) - # func7 = state_size_helper(3) - # func8 = state_size_helper(4) - # func9 = state_size_helper(5) - # func10 = state_size_helper(6) - # func11 = state_size_helper(7) - # plt.plot(chains2.flatten()[::10]) - # plt.plot(chains5.flatten()[::10]) - # plt.plot(chains6.flatten()[::10]) - # plt.plot(chains7.flatten()[::10]) - # plt.plot(chains8.flatten()[::10]) - # plt.plot(chains9.flatten()[::10]) - # plt.plot(chains10.flatten()[::10]) - # plt.plot(chains11.flatten()[::10]) - # plt.axhline(1643, color='r') - # plt.axhline(3438, color='r') - # plt.axhline(2216, color='r') - # plt.axhline(811, color='r') - # plt.axhline(2721, color='r') - # plt.axhline(743, color='r') - # plt.show() - - pickle.dump(chains1, open("multi_chain_saves/{}_state_num_0_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb')) - pickle.dump(chains2, open("multi_chain_saves/{}_state_num_1_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb')) - pickle.dump(chains3, open("multi_chain_saves/{}_largest_state_0_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb')) - pickle.dump(chains4, open("multi_chain_saves/{}_largest_state_1_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb')) - - -temp() -# import pyhsmm.util.profiling as prof -# prof_func = prof._prof(temp) -# prof_func() -# prof._prof.print_stats() +m = 16 +for subject in subjects: + print(subject) + n_runs = -1 + counter = -1 + n = (loading_info[subject]['chain_num'] + 1) * 4000 // 25 + chains1 = np.zeros((m, n)) + chains2 = np.zeros((m, n)) + chains3 = np.zeros((m, n)) + chains4 = np.zeros((m, n)) + for j, (seed, fit_num) in enumerate(zip(loading_info[subject]['seeds'], loading_info[subject]['fit_nums'])): + counter += 1 + print(seed) + info_dict = pickle.load(open("./session_data/{}_info_dict.p".format(subject), "rb")) + samples = [] + mini_counter = 0 + while True: + try: + file = "./dynamic_GLMiHMM_crossvals/{}_fittype_{}_var_{}_{}_{}{}.p".format(subject, fit_type, fit_variance, seed, fit_num, '_{}'.format(mini_counter)) + lala = time.time() + samples += pickle.load(open(file, "rb")) + print("Loading {} took {:.4}".format(mini_counter, time.time() - lala)) + mini_counter += 1 + except Exception: + break + if n_runs == -1: + n_runs = mini_counter + else: + if n_runs != mini_counter: + print("Problem") + print(n_runs, mini_counter) + quit() + save_id = "{}_fittype_{}_var_{}_{}_{}.p".format(subject, fit_type, fit_variance, seed, fit_num).replace('.', '_') + # for file in os.listdir("./dynamic_GLMiHMM_crossvals/infos_new/"): + # if file.startswith("{}_".format(subject)) and file.endswith("_{}_{}_{}_{}.json".format(fit_type, fit_variance, seed, fit_num)): + # print("Taking fit infos from {}".format(file)) + # fit_infos = json.load(open("./dynamic_GLMiHMM_crossvals/infos_new/" + file, 'r')) + # sample_lls = fit_infos['ll'] + + print("loaded seed {}".format(seed)) + + result = MCMC_result(samples[::25], + infos=info_dict, data=samples[0].datas, + sessions=fit_type, fit_variance=fit_variance, + dur=dur, save_id=save_id) # , sample_lls=sample_lls) + print("Making result {} took {:.4}".format(counter, time.time() - lala)) + + res = func1(result) + chains1[counter] = res + res = func2(result) + chains2[counter] = res + res = func3(result) + chains3[counter] = res + res = func4(result) + chains4[counter] = res + # res = func5(result) + # chains5[j] = res + # res = func6(result) + # chains6[j] = res + # res = func7(result) + # chains7[j] = res + # res = func8(result) + # chains8[j] = res + # res = func9(result) + # chains9[j] = res + # res = func10(result) + # chains10[j] = res + # res = func11(result) + # chains11[j] = res + + pickle.dump(chains1, open("multi_chain_saves/{}_state_num_0_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb')) + pickle.dump(chains2, open("multi_chain_saves/{}_state_num_1_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb')) + pickle.dump(chains3, open("multi_chain_saves/{}_largest_state_0_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb')) + pickle.dump(chains4, open("multi_chain_saves/{}_largest_state_1_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb')) diff --git a/simplex_plot.py b/simplex_plot.py index 41db6f5e60a3691585d0e94dacdfb6c319b8ffc7..bdaad7885fa71bbbddcdd58b45447f4acf442ae3 100644 --- a/simplex_plot.py +++ b/simplex_plot.py @@ -4,7 +4,7 @@ Visualize points on the 3-simplex (eg, the parameters of a contained within a 2D triangle. Adapted from David Andrzejewski (david.andrzej@gmail.com) """ -import numpy as NP +import numpy as np import matplotlib.pyplot as P import matplotlib.ticker as MT import matplotlib.lines as L @@ -14,38 +14,42 @@ import matplotlib.patches as PA def plotSimplex(points, fig=None, - vertexlabels=['1', '2', '3'], save_title="test.png", + vertexlabels=['Type 1', 'Type 2', 'Type 3'], save_title="test.png", show=False, **kwargs): """ Plot Nx3 points array on the 3-simplex (with optionally labeled vertices) kwargs will be passed along directly to matplotlib.pyplot.scatter - Returns Figure, caller must .show() """ if fig is None: fig = P.figure(figsize=(9, 9)) # Draw the triangle l1 = L.Line2D([0, 0.5, 1.0, 0], # xcoords - [0, NP.sqrt(3) / 2, 0, 0], # ycoords + [0, np.sqrt(3) / 2, 0, 0], # ycoords color='k') fig.gca().add_line(l1) fig.gca().xaxis.set_major_locator(MT.NullLocator()) fig.gca().yaxis.set_major_locator(MT.NullLocator()) # Draw vertex labels - fig.gca().text(-0.05, -0.05, vertexlabels[0]) - fig.gca().text(1.05, -0.05, vertexlabels[1]) - fig.gca().text(0.5, NP.sqrt(3) / 2 + 0.05, vertexlabels[2]) + fig.gca().text(-0.06, -0.05, vertexlabels[0], size=24) + fig.gca().text(0.95, -0.05, vertexlabels[1], size=24) + fig.gca().text(0.43, np.sqrt(3) / 2 + 0.025, vertexlabels[2], size=24) # Project and draw the actual points projected = projectSimplex(points / points.sum(1)[:, None]) - P.scatter(projected[:, 0], projected[:, 1], s=points.sum(1), **kwargs) + P.scatter(projected[:, 0], projected[:, 1], s=points.sum(1) * 3.5, **kwargs) + + # plot center with average size + projected = projectSimplex(np.mean(points / points.sum(1)[:, None], axis=0).reshape(1, 3)) + P.scatter(projected[:, 0], projected[:, 1], marker='*', color='r', s=np.mean(points.sum(1)) * 3.5) + # Leave some buffer around the triangle for vertex labels fig.gca().set_xlim(-0.05, 1.05) fig.gca().set_ylim(-0.05, 1.05) P.axis('off') - P.savefig("test.png", bbox_inches='tight') + P.savefig("dur_simplex.png", bbox_inches='tight') if show: P.show() else: @@ -59,22 +63,22 @@ def projectSimplex(points): N points are given as N x 3 array """ # Convert points one at a time - tripts = NP.zeros((points.shape[0], 2)) + tripts = np.zeros((points.shape[0], 2)) for idx in range(points.shape[0]): # Init to triangle centroid x = 1.0 / 2 - y = 1.0 / (2 * NP.sqrt(3)) + y = 1.0 / (2 * np.sqrt(3)) # Vector 1 - bisect out of lower left vertex p1 = points[idx, 0] - x = x - (1.0 / NP.sqrt(3)) * p1 * NP.cos(NP.pi / 6) - y = y - (1.0 / NP.sqrt(3)) * p1 * NP.sin(NP.pi / 6) + x = x - (1.0 / np.sqrt(3)) * p1 * np.cos(np.pi / 6) + y = y - (1.0 / np.sqrt(3)) * p1 * np.sin(np.pi / 6) # Vector 2 - bisect out of lower right vertex p2 = points[idx, 1] - x = x + (1.0 / NP.sqrt(3)) * p2 * NP.cos(NP.pi / 6) - y = y - (1.0 / NP.sqrt(3)) * p2 * NP.sin(NP.pi / 6) + x = x + (1.0 / np.sqrt(3)) * p2 * np.cos(np.pi / 6) + y = y - (1.0 / np.sqrt(3)) * p2 * np.sin(np.pi / 6) # Vector 3 - bisect out of top vertex p3 = points[idx, 2] - y = y + (1.0 / NP.sqrt(3) * p3) + y = y + (1.0 / np.sqrt(3) * p3) tripts[idx, :] = (x, y) @@ -87,7 +91,7 @@ if __name__ == '__main__': '[0.8 0.1 0.1]', '[0.5 0.4 0.1]', '[0.33 0.34 0.33]') - testpoints = NP.array([[0.1, 0.1, 0.8], + testpoints = np.array([[0.1, 0.1, 0.8], [0.8, 0.1, 0.1], [0.5, 0.4, 0.1], [0.33, 0.34, 0.33]])