diff --git a/__pycache__/analysis_pmf.cpython-37.pyc b/__pycache__/analysis_pmf.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b4caf72889d24bf4325fbc9922b9fddd747dec8 Binary files /dev/null and b/__pycache__/analysis_pmf.cpython-37.pyc differ diff --git a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc index b4b56a55e7c7e3ff5cceb8da95a87d51c338cbac..740060e1aee1c905d0d6e98e7fb190b6806d36f6 100644 Binary files a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc and b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc differ diff --git a/analysis_pmf.py b/analysis_pmf.py index b07ba8a8e44f3f518bd19fc3d8e0b2a8aa2073e4..9a47f86eafebdd743f530024e2b605ce65a9b736 100644 --- a/analysis_pmf.py +++ b/analysis_pmf.py @@ -24,6 +24,17 @@ def pmf_type(pmf): if __name__ == "__main__": all_first_pmfs_typeless = pickle.load(open("all_first_pmfs_typeless.p", 'rb')) all_pmfs = pickle.load(open("all_pmfs.p", 'rb')) + all_bias_flips = pickle.load(open("all_bias_flips.p", 'rb')) + + # bias flips + plt.hist(all_bias_flips, bins=np.arange(0, max(all_bias_flips) + 1), color='grey', align='left') + plt.ylabel("# of mice") + plt.xlabel("Bias flips") + sns.despine() + plt.tight_layout() + plt.savefig("./meeting_figures/bias_flips.png") + plt.show() + quit() fewer_states_side = [] for key in all_first_pmfs_typeless: diff --git a/analysis_regression.py b/analysis_regression.py index fe165ae1aded90bb19b0826ec79f7735284d2608..575f234c91dbf83c6f0082d3ac1aa05e743d765e 100644 --- a/analysis_regression.py +++ b/analysis_regression.py @@ -9,6 +9,7 @@ fontsize = 22 if __name__ == "__main__": regressions = np.array(pickle.load(open("regressions.p", 'rb'))) + regression_diffs = np.array(pickle.load(open("regression_diffs.p", 'rb'))) assert (regressions[:, 0] == np.sum(regressions[:, 2:], 1)).all() @@ -50,3 +51,13 @@ if __name__ == "__main__": plt.tight_layout() plt.savefig("# of mice with regressions per type") plt.show() + + # histogram of regression diffs + plt.hist(regression_diffs, color='grey', bins=20) + plt.ylabel("# regressions", size=fontsize) + plt.xlabel("Reward rate diff", size=fontsize) + + sns.despine() + plt.tight_layout() + plt.savefig("Regression diffs") + plt.show() diff --git a/analysis_state_intros.py b/analysis_state_intros.py index c00985af68fa992ce0fba0a42df16274370b7fa9..91b3250a7fee932274c8510534aa0e591912e9b7 100644 --- a/analysis_state_intros.py +++ b/analysis_state_intros.py @@ -6,8 +6,9 @@ fontsize = 16 def type_hist(data, title=''): highest = int(data.max()) + lowest = int(data.min()) if (data % 1 == 0).all(): - bins = np.arange(highest + 2) - 0.5 + bins = np.arange(lowest, highest + 2) - 0.5 else: bins = np.histogram(data)[1] hist_max = 0 @@ -15,31 +16,35 @@ def type_hist(data, title=''): hist_max = max(hist_max, np.histogram(data[:, i], bins)[0].max()) plt.subplot(3, 1, 1) - - if title != '': - plt.title(title, size=fontsize + 4) - assert np.histogram(data[:, 0])[0].sum() == np.histogram(data[:, 0], bins)[0].sum() plt.hist(data[:, 0], alpha=1/3, label="type 1", align='mid', bins=bins, color='grey') - plt.xlim(-0.5, highest + 1) + plt.xlim(lowest - 0.5, highest + 1) plt.ylim(0, hist_max + 1) + ax2 = plt.gca().twinx() plt.ylabel("Type 1", size=fontsize) + ax2.set_yticks([]) plt.subplot(3, 1, 2) assert np.histogram(data[:, 1])[0].sum() == np.histogram(data[:, 1], bins)[0].sum() plt.hist(data[:, 1], alpha=1/3, label="type 2", align='mid', bins=bins, color='grey') - plt.xlim(-0.5, highest + 1) + plt.xlim(lowest - 0.5, highest + 1) plt.ylim(0, hist_max + 1) + ax2 = plt.gca().twinx() plt.ylabel("Type 2", size=fontsize) + ax2.set_yticks([]) plt.subplot(3, 1, 3) assert np.histogram(data[:, 2])[0].sum() == np.histogram(data[:, 2], bins)[0].sum() plt.hist(data[:, 2], alpha=1/3, label="type 3", align='mid', bins=bins, color='grey') - plt.xlim(-0.5, highest + 1) + plt.xlim(lowest - 0.5, highest + 1) plt.ylim(0, hist_max + 1) + plt.xlabel(title, size=fontsize) + plt.ylabel("# of mice", size=fontsize) + ax2 = plt.gca().twinx() plt.ylabel("Type 3", size=fontsize) + ax2.set_yticks([]) - # plt.legend() + plt.savefig(title) plt.show() diff --git a/behaviour_overview.py b/behaviour_overview.py index 24180ed205a00a9cb7e864c39782d62c92f8d658..1a4899e65f47119af82ac806fd7babd966679c33 100644 --- a/behaviour_overview.py +++ b/behaviour_overview.py @@ -5,7 +5,7 @@ from one.api import ONE import pickle -show = 1 +show = 0 def progression(data, contrasts, progression_variable='feedback', windowsize=6, upper_bound=None, title=None): # looks somewhat irregular, red dots are not in middle of bump they cause, this is because distribution of specific contrasts is not uniform @@ -47,13 +47,13 @@ def progression(data, contrasts, progression_variable='feedback', windowsize=6, # plt.title(title, size=22) # plt.savefig("temp {}".format(title).replace('/', '_')) # plt.close() - # if progression_variable == 'rt': - # means = data.groupby('signed_contrast').mean()['rt'] - # stds = data.groupby('signed_contrast').std()['rt'] - # plt.errorbar(means.index, means.values, stds.values) - # plt.title(title, size=22) - # plt.savefig("temp {}".format(title).replace('/', '_')) - # plt.show() + if progression_variable == 'rt': + means = data.groupby('signed_contrast').mean()['rt'] + stds = data.groupby('signed_contrast').sem()['rt'] + plt.errorbar(means.index, means.values, stds.values) + plt.title(title, size=22) + plt.savefig("temp {}".format(title).replace('/', '_')) + plt.close() dataset_types = ['choice', 'contrastLeft', 'contrastRight', \ diff --git a/canonical_infos.json b/canonical_infos.json index 5f8da99817e18fb2831226f46d3018b12ba8b064..c0868bc5f9a444da61e2e4c5010c9c759000ff3c 100644 --- a/canonical_infos.json +++ b/canonical_infos.json @@ -1 +1 @@ -{"SWC_023": {"seeds": ["302", "312", "304", "300", "315", "311", "308", "305", "303", "309", "306", "313", "307", "314", "301", "310"], "fit_nums": ["994", "913", "681", "816", "972", "790", "142", "230", "696", "537", "975", "773", "918", "677", "742", "745"], "chain_num": 4, "ignore": [12, 1, 15, 14, 8, 6, 4, 10]}, "SWC_021": {"seeds": ["415", "403", "412", "407", "409", "408", "405", "404", "410", "414", "401", "413", "402", "400", "406", "411"], "fit_nums": ["773", "615", "107", "583", "564", "354", "142", "184", "549", "185", "924", "907", "105", "531", "9", "812"], "chain_num": 9, "ignore": [14, 12, 0, 10, 9, 4, 5, 1]}, "ibl_witten_15": {"seeds": ["409", "410", "401", "415", "414", "403", "411", "404", "402", "405", "400", "412", "408", "407", "406", "413"], "fit_nums": ["411", "344", "496", "600", "716", "18", "527", "467", "898", "334", "309", "326", "133", "823", "740", "253"], "chain_num": 9, "ignore": [14, 13, 8, 4, 5, 12, 11, 9]}, "ibl_witten_13": {"seeds": ["302", "312", "313", "306", "315", "307", "311", "314", "309", "301", "308", "300", "304", "310", "303", "305"], "fit_nums": ["897", "765", "433", "641", "967", "599", "984", "259", "853", "385", "887", "619", "434", "964", "483", "891"], "chain_num": 4, "ignore": [3, 5, 15, 0, 2, 12, 11, 10]}, "KS016": {"seeds": ["315", "301", "309", "313", "302", "307", "303", "308", "311", "312", "314", "306", "310", "300", "305", "304"], "fit_nums": ["99", "57", "585", "32", "501", "558", "243", "413", "59", "757", "463", "172", "524", "957", "909", "292"], "chain_num": 4, "ignore": [0, 2, 14, 12, 1, 7, 11, 6]}, "KS003": {"seeds": ["404", "407", "413", "403", "414", "405", "400", "401", "402", "410", "415", "408", "411", "409", "406", "412"], "fit_nums": ["846", "256", "845", "945", "293", "406", "420", "109", "690", "421", "54", "866", "784", "81", "997", "665"], "chain_num": 9, "ignore": [8, 15, 0, 13, 7, 12, 11, 1]}, "ibl_witten_19": {"seeds": ["315", "311", "307", "314", "308", "300", "305", "301", "313", "304", "302", "310", "306", "312", "309", "303"], "fit_nums": ["179", "951", "613", "6", "623", "382", "458", "504", "406", "554", "5", "631", "746", "817", "265", "328"], "chain_num": 4, "ignore": [13, 4, 10, 9, 2, 1, 3, 6]}, "SWC_022": {"seeds": ["411", "403", "414", "409", "407", "412", "410", "413", "415", "404", "405", "400", "402", "401", "408", "406"], "fit_nums": ["408", "884", "62", "962", "744", "854", "635", "70", "320", "952", "8", "67", "231", "381", "536", "962"], "chain_num": 9, "ignore": [4, 8, 7, 9, 1, 2, 10, 6]}, "KS022": {"seeds": ["315", "300", "314", "301", "303", "302", "306", "308", "305", "310", "313", "312", "304", "307", "311", "309"], "fit_nums": ["899", "681", "37", "957", "629", "637", "375", "980", "810", "51", "759", "664", "420", "127", "259", "555"], "chain_num": 4, "ignore": [10, 1, 0, 13, 5, 9, 12, 3]}, "CSH_ZAD_017": {"seeds": ["401", "409", "405", "403", "415", "404", "402", "411", "410", "414", "408", "406", "413", "412", "400", "407"], "fit_nums": ["883", "803", "637", "806", "356", "804", "662", "654", "684", "350", "947", "460", "569", "976", "103", "713"], "chain_num": 9, "ignore": [3, 4, 6, 7, 5, 0, 15, 12]}, "CSH_ZAD_025": {"seeds": ["303", "311", "307", "312", "313", "314", "308", "315", "305", "306", "304", "302", "309", "310", "301", "300"], "fit_nums": ["581", "148", "252", "236", "581", "838", "206", "756", "449", "288", "756", "593", "733", "633", "418", "563"], "chain_num": 4, "ignore": [8, 10, 13, 5, 12, 9, 7, 1]}, "ibl_witten_17": {"seeds": ["406", "415", "408", "413", "402", "405", "409", "400", "414", "401", "412", "407", "404", "410", "403", "411"], "fit_nums": ["827", "797", "496", "6", "444", "823", "384", "873", "634", "27", "811", "142", "207", "322", "756", "275"], "chain_num": 9, "ignore": [9, 0, 1, 7, 11, 3, 10, 8]}, "ibl_witten_18": {"seeds": ["311", "310", "303", "314", "302", "309", "305", "307", "312", "300", "308", "306", "315", "313", "304", "301"], "fit_nums": ["236", "26", "838", "762", "826", "409", "496", "944", "280", "704", "930", "419", "637", "896", "876", "297"], "chain_num": 4, "ignore": [11, 0, 4, 2, 12, 13, 8, 3]}, "CSHL_018": {"seeds": ["302", "310", "306", "300", "314", "307", "309", "313", "311", "308", "304", "301", "312", "303", "305", "315"], "fit_nums": ["843", "817", "920", "900", "226", "36", "472", "676", "933", "453", "116", "263", "269", "897", "568", "438"], "chain_num": 4, "ignore": [15, 4, 8, 0, 5, 10, 12, 11]}, "GLM_Sim_06": {"seeds": ["313", "309", "302", "303", "305", "314", "300", "315", "311", "306", "304", "310", "301", "312", "308", "307"], "fit_nums": ["9", "786", "286", "280", "72", "587", "619", "708", "360", "619", "311", "189", "60", "708", "939", "733"], "chain_num": 2, "ignore": [15, 9, 8, 14, 1, 12, 10, 3]}, "ZM_1897": {"seeds": ["304", "308", "305", "311", "315", "314", "307", "306", "300", "303", "313", "310", "301", "312", "302", "309"], "fit_nums": ["549", "96", "368", "509", "424", "897", "287", "426", "968", "93", "725", "513", "837", "581", "989", "374"], "chain_num": 4, "ignore": [0, 14, 5, 8, 7, 11, 13, 10]}, "CSHL_020": {"seeds": ["305", "309", "313", "302", "314", "310", "300", "307", "315", "306", "312", "304", "311", "301", "303", "308"], "fit_nums": ["222", "306", "243", "229", "584", "471", "894", "238", "986", "660", "494", "657", "896", "459", "100", "283"], "chain_num": 4, "ignore": [6, 5, 9, 15, 0, 8, 4, 13]}, "CSHL054": {"seeds": ["401", "415", "409", "410", "414", "413", "407", "405", "406", "408", "411", "400", "412", "402", "403", "404"], "fit_nums": ["901", "734", "609", "459", "574", "793", "978", "66", "954", "906", "954", "111", "292", "850", "266", "967"], "chain_num": 9, "ignore": [5, 12, 7, 10, 11, 2, 6, 4]}, "CSHL_014": {"seeds": ["305", "311", "309", "300", "313", "310", "307", "306", "304", "312", "308", "302", "314", "303", "301", "315"], "fit_nums": ["371", "550", "166", "24", "705", "385", "870", "884", "831", "546", "404", "722", "287", "564", "613", "783"], "chain_num": 4, "ignore": [15, 0, 3, 4, 7, 6, 1, 11]}, "CSHL062": {"seeds": ["307", "313", "310", "303", "306", "312", "308", "305", "311", "314", "304", "302", "300", "301", "315", "309"], "fit_nums": ["846", "371", "94", "888", "499", "229", "546", "432", "71", "989", "986", "91", "935", "314", "975", "481"], "chain_num": 4, "ignore": [14, 6, 3, 11, 15, 13, 4, 12]}, "CSH_ZAD_001": {"seeds": ["313", "309", "311", "312", "305", "310", "315", "300", "314", "304", "301", "302", "308", "303", "306", "307"], "fit_nums": ["468", "343", "314", "544", "38", "120", "916", "170", "305", "569", "502", "496", "452", "336", "559", "572"], "chain_num": 4, "ignore": [12, 8, 5, 1, 9, 3, 13, 15]}, "NYU-06": {"seeds": ["314", "309", "306", "305", "312", "303", "307", "304", "300", "302", "310", "301", "315", "308", "313", "311"], "fit_nums": ["950", "862", "782", "718", "427", "645", "827", "612", "821", "834", "595", "929", "679", "668", "648", "869"], "chain_num": 4, "ignore": [8, 2, 7, 12, 3, 4, 13, 11]}, "KS019": {"seeds": ["404", "401", "411", "408", "400", "403", "410", "413", "402", "407", "415", "409", "406", "414", "412", "405"], "fit_nums": ["682", "4", "264", "200", "250", "267", "737", "703", "132", "855", "922", "686", "85", "176", "54", "366"], "chain_num": 9, "ignore": [12, 14, 1, 2, 4, 7, 10, 15]}, "CSHL049": {"seeds": ["411", "402", "414", "408", "409", "410", "413", "407", "406", "401", "404", "405", "403", "415", "400", "412"], "fit_nums": ["104", "553", "360", "824", "749", "519", "347", "228", "863", "671", "140", "883", "701", "445", "627", "898"], "chain_num": 9, "ignore": [10, 11, 6, 7, 12, 13, 1, 8]}, "ibl_witten_14": {"seeds": ["310", "311", "304", "306", "300", "302", "314", "313", "303", "308", "301", "309", "305", "315", "312", "307"], "fit_nums": ["563", "120", "85", "712", "277", "871", "183", "661", "505", "598", "210", "89", "310", "638", "564", "998"], "chain_num": 4, "ignore": [11, 14, 6, 13, 5, 12, 15, 8]}, "KS014": {"seeds": ["301", "310", "302", "312", "313", "308", "307", "303", "305", "300", "314", "306", "311", "309", "304", "315"], "fit_nums": ["668", "32", "801", "193", "269", "296", "74", "24", "270", "916", "21", "250", "342", "451", "517", "293"], "chain_num": 4, "ignore": [9, 11, 0, 1, 14, 2, 12, 13]}, "CSHL059": {"seeds": ["306", "309", "300", "304", "314", "303", "315", "311", "313", "305", "301", "307", "302", "312", "310", "308"], "fit_nums": ["821", "963", "481", "999", "986", "45", "551", "605", "701", "201", "629", "261", "972", "407", "165", "9"], "chain_num": 4, "ignore": [9, 3, 5, 15, 6, 10, 2, 1]}, "GLM_Sim_13": {"seeds": ["310", "303", "308", "306", "300", "312", "301", "313", "305", "311", "315", "304", "314", "309", "307", "302"], "fit_nums": ["982", "103", "742", "524", "614", "370", "926", "456", "133", "143", "302", "80", "395", "549", "579", "944"], "chain_num": 2, "ignore": [12, 4, 11, 6, 7, 14, 0, 1]}, "CSHL_007": {"seeds": ["314", "303", "308", "313", "301", "300", "302", "305", "315", "306", "310", "309", "311", "304", "307", "312"], "fit_nums": ["462", "703", "345", "286", "480", "313", "986", "165", "201", "102", "322", "894", "960", "438", "330", "169"], "chain_num": 4, "ignore": [3, 12, 4, 5, 2, 0, 13, 1]}, "CSH_ZAD_011": {"seeds": ["314", "311", "303", "300", "305", "310", "306", "301", "302", "315", "304", "309", "308", "312", "313", "307"], "fit_nums": ["320", "385", "984", "897", "315", "120", "320", "945", "475", "403", "210", "412", "695", "564", "664", "411"], "chain_num": 4, "ignore": [0, 2, 14, 11, 7, 10, 13, 9]}, "KS021": {"seeds": ["309", "312", "304", "310", "303", "311", "314", "302", "305", "301", "306", "300", "308", "315", "313", "307"], "fit_nums": ["874", "943", "925", "587", "55", "136", "549", "528", "349", "211", "401", "84", "225", "545", "153", "382"], "chain_num": 4, "ignore": [11, 12, 0, 8, 2, 14, 5, 1]}, "GLM_Sim_15": {"seeds": ["303", "312", "305", "308", "309", "302", "301", "310", "313", "315", "311", "314", "307", "306", "304", "300"], "fit_nums": ["769", "930", "328", "847", "899", "714", "144", "518", "521", "873", "914", "359", "242", "343", "45", "364"], "chain_num": 2, "ignore": [8, 1, 0, 3, 2, 5, 10, 4]}, "CSHL_015": {"seeds": ["301", "302", "307", "310", "309", "311", "304", "312", "300", "308", "313", "305", "314", "315", "306", "303"], "fit_nums": ["717", "705", "357", "539", "604", "971", "669", "76", "45", "413", "510", "122", "190", "821", "368", "472"], "chain_num": 4, "ignore": [7, 6, 10, 2, 15, 13, 1, 3]}, "ibl_witten_16": {"seeds": ["304", "313", "309", "314", "312", "307", "305", "301", "306", "310", "300", "315", "308", "311", "303", "302"], "fit_nums": ["392", "515", "696", "270", "7", "583", "880", "674", "23", "576", "579", "695", "149", "854", "184", "875"], "chain_num": 4, "ignore": [3, 12, 2, 6, 10, 14, 4, 1]}, "KS015": {"seeds": ["315", "305", "309", "303", "314", "310", "311", "312", "313", "300", "307", "308", "304", "301", "302", "306"], "fit_nums": ["257", "396", "387", "435", "133", "164", "403", "8", "891", "650", "111", "557", "473", "229", "842", "196"], "chain_num": 4, "ignore": [7, 8, 0, 10, 2, 3, 12, 9]}, "GLM_Sim_12": {"seeds": ["304", "312", "306", "303", "310", "302", "300", "305", "308", "313", "307", "311", "315", "301", "314", "309"], "fit_nums": ["971", "550", "255", "195", "952", "486", "841", "535", "559", "37", "654", "213", "864", "506", "732", "550"], "chain_num": 2, "ignore": [0, 7, 15, 14, 3, 10, 11, 13]}, "GLM_Sim_11": {"seeds": ["300", "312", "310", "315", "302", "313", "314", "311", "308", "303", "309", "307", "306", "304", "301", "305"], "fit_nums": ["477", "411", "34", "893", "195", "293", "603", "5", "887", "281", "956", "73", "346", "640", "532", "688"], "chain_num": 2}, "GLM_Sim_10": {"seeds": ["301", "300", "306", "305", "307", "309", "312", "314", "311", "315", "304", "313", "303", "308", "302", "310"], "fit_nums": ["391", "97", "897", "631", "239", "652", "19", "448", "807", "35", "972", "469", "280", "562", "42", "706"], "chain_num": 2, "ignore": [1, 9, 15, 3, 13, 12, 7, 11]}, "CSH_ZAD_026": {"seeds": ["312", "313", "308", "310", "303", "307", "302", "305", "300", "315", "306", "301", "311", "304", "314", "309"], "fit_nums": ["699", "87", "537", "628", "797", "511", "459", "770", "969", "240", "504", "948", "295", "506", "25", "378"], "chain_num": 4, "ignore": [12, 13, 4, 11, 8, 3, 15, 0]}, "KS023": {"seeds": ["304", "313", "306", "309", "300", "314", "302", "310", "303", "315", "307", "308", "301", "311", "305", "312"], "fit_nums": ["698", "845", "319", "734", "908", "507", "45", "499", "175", "108", "419", "443", "116", "779", "159", "231"], "chain_num": 4, "ignore": [8, 10, 1, 13, 4, 15, 14, 7]}, "GLM_Sim_05": {"seeds": ["301", "315", "300", "302", "305", "304", "313", "314", "311", "309", "306", "307", "308", "310", "303", "312"], "fit_nums": ["425", "231", "701", "375", "343", "902", "623", "125", "921", "637", "393", "964", "678", "930", "796", "42"], "chain_num": 2, "ignore": [11, 2, 5, 1, 4, 9, 15, 12]}, "CSHL061": {"seeds": ["305", "315", "304", "303", "309", "310", "302", "300", "314", "306", "311", "313", "301", "308", "307", "312"], "fit_nums": ["396", "397", "594", "911", "308", "453", "686", "552", "103", "209", "128", "892", "345", "925", "777", "396"], "chain_num": 4, "ignore": [11, 13, 7, 15, 14, 3, 0, 4]}, "CSHL051": {"seeds": ["303", "310", "306", "302", "309", "305", "313", "308", "300", "314", "311", "307", "312", "304", "315", "301"], "fit_nums": ["69", "186", "49", "435", "103", "910", "705", "367", "303", "474", "596", "334", "929", "796", "616", "790"], "chain_num": 4, "ignore": [15, 12, 8, 13, 0, 2, 4, 5]}, "GLM_Sim_14": {"seeds": ["310", "311", "309", "313", "314", "300", "302", "304", "305", "306", "307", "312", "303", "301", "315", "308"], "fit_nums": ["616", "872", "419", "106", "940", "986", "599", "704", "218", "808", "244", "825", "448", "397", "552", "316"], "chain_num": 2, "ignore": [7, 11, 2, 15, 0, 13, 5, 10]}, "GLM_Sim_11_trick": {"seeds": ["411", "400", "408", "409", "415", "413", "410", "412", "406", "414", "403", "404", "401", "405", "407", "402"], "fit_nums": ["95", "508", "886", "384", "822", "969", "525", "382", "489", "436", "344", "537", "251", "223", "458", "401"], "chain_num": 2, "ignore": [10, 12, 4, 1, 0, 3, 2, 6]}, "GLM_Sim_16": {"seeds": ["302", "311", "303", "307", "313", "308", "309", "300", "305", "315", "304", "310", "312", "301", "314", "306"], "fit_nums": ["914", "377", "173", "583", "870", "456", "611", "697", "13", "713", "159", "248", "617", "37", "770", "780"], "chain_num": 2, "ignore": [4, 10, 5, 0, 13, 8, 6, 7]}, "ZM_3003": {"seeds": ["300", "304", "307", "312", "305", "310", "311", "314", "303", "308", "313", "301", "315", "309", "306", "302"], "fit_nums": ["603", "620", "657", "735", "357", "390", "119", "33", "62", "617", "209", "810", "688", "21", "744", "426"], "chain_num": 4, "ignore": [14, 7, 12, 1, 3, 4, 11, 8]}, "CSH_ZAD_022": {"seeds": ["305", "310", "311", "315", "303", "312", "314", "313", "307", "302", "300", "304", "301", "308", "306", "309"], "fit_nums": ["143", "946", "596", "203", "576", "403", "900", "65", "478", "325", "282", "513", "460", "42", "161", "970"], "chain_num": 4, "ignore": [9, 12, 4, 8, 3, 7, 0, 1]}, "GLM_Sim_07": {"seeds": ["300", "309", "302", "304", "305", "312", "301", "311", "315", "314", "308", "307", "303", "310", "306", "313"], "fit_nums": ["724", "701", "118", "230", "648", "426", "689", "114", "832", "731", "592", "519", "559", "938", "672", "144"], "chain_num": 1}, "KS017": {"seeds": ["311", "310", "306", "309", "303", "302", "308", "300", "313", "301", "314", "307", "315", "304", "312", "305"], "fit_nums": ["97", "281", "808", "443", "352", "890", "703", "468", "780", "708", "674", "27", "345", "23", "939", "457"], "chain_num": 4, "ignore": [0, 13, 8, 1, 12, 5, 10, 9]}, "GLM_Sim_11_sub": {"seeds": ["410", "414", "413", "404", "409", "415", "406", "408", "402", "411", "400", "405", "403", "407", "412", "401"], "fit_nums": ["830", "577", "701", "468", "929", "374", "954", "749", "937", "488", "873", "416", "612", "792", "461", "488"], "chain_num": 2}} \ No newline at end of file +{"SWC_022": {"seeds": ["412", "401", "403", "413", "406", "407", "415", "409", "405", "400", "408", "404", "410", "411", "414", "402"], "fit_nums": ["347", "54", "122", "132", "520", "386", "312", "59", "999", "849", "372", "300", "485", "593", "358", "550"], "chain_num": 19, "ignore": [5, 6, 12, 3, 1, 11, 10]}, "SWC_023": {"seeds": ["302", "312", "304", "300", "315", "311", "308", "305", "303", "309", "306", "313", "307", "314", "301", "310"], "fit_nums": ["994", "913", "681", "816", "972", "790", "142", "230", "696", "537", "975", "773", "918", "677", "742", "745"], "chain_num": 4, "ignore": [12, 1, 15, 14, 8, 6, 4, 10]}, "ibl_witten_15": {"seeds": ["408", "412", "400", "411", "410", "407", "403", "406", "413", "405", "404", "402", "401", "415", "409", "414"], "fit_nums": ["40", "241", "435", "863", "941", "530", "382", "750", "532", "731", "146", "500", "967", "334", "375", "670"], "chain_num": 19, "ignore": [6, 7, 9, 8, 5, 13, 12, 11]}, "ibl_witten_13": {"seeds": ["401", "414", "409", "413", "415", "411", "410", "408", "402", "405", "406", "407", "412", "403", "400", "404"], "fit_nums": ["702", "831", "47", "740", "251", "929", "579", "351", "515", "261", "222", "852", "754", "892", "473", "29"], "chain_num": 19, "ignore": [11, 4, 15, 6, 8, 0]}, "KS016": {"seeds": ["315", "301", "309", "313", "302", "307", "303", "308", "311", "312", "314", "306", "310", "300", "305", "304"], "fit_nums": ["99", "57", "585", "32", "501", "558", "243", "413", "59", "757", "463", "172", "524", "957", "909", "292"], "chain_num": 4, "ignore": [0, 2, 14, 12, 1, 7, 11, 6]}, "ibl_witten_19": {"seeds": ["412", "415", "413", "408", "409", "404", "403", "401", "405", "411", "410", "406", "402", "414", "407", "400"], "fit_nums": ["234", "41", "503", "972", "935", "808", "912", "32", "331", "755", "117", "833", "822", "704", "901", "207"], "chain_num": 19, "ignore": [6, 1, 15, 10, 12]}, "CSH_ZAD_017": {"seeds": ["408", "404", "413", "406", "414", "411", "400", "401", "415", "407", "402", "412", "403", "409", "405", "410"], "fit_nums": ["928", "568", "623", "841", "92", "251", "829", "922", "964", "257", "150", "970", "375", "113", "423", "564"], "chain_num": 19, "ignore": [8, 10, 12, 5, 9, 7, 4]}, "KS022": {"seeds": ["315", "300", "314", "301", "303", "302", "306", "308", "305", "310", "313", "312", "304", "307", "311", "309"], "fit_nums": ["899", "681", "37", "957", "629", "637", "375", "980", "810", "51", "759", "664", "420", "127", "259", "555"], "chain_num": 4, "ignore": [10, 1, 0, 13, 5, 9, 12, 3]}, "CSH_ZAD_025": {"seeds": ["303", "311", "307", "312", "313", "314", "308", "315", "305", "306", "304", "302", "309", "310", "301", "300"], "fit_nums": ["581", "148", "252", "236", "581", "838", "206", "756", "449", "288", "756", "593", "733", "633", "418", "563"], "chain_num": 4, "ignore": [8, 10, 13, 5, 12, 9, 7, 1]}, "ibl_witten_17": {"seeds": ["406", "415", "408", "413", "402", "405", "409", "400", "414", "401", "412", "407", "404", "410", "403", "411"], "fit_nums": ["827", "797", "496", "6", "444", "823", "384", "873", "634", "27", "811", "142", "207", "322", "756", "275"], "chain_num": 9, "ignore": [9, 0, 1, 7, 11, 3, 10, 8]}, "SWC_021": {"seeds": ["404", "413", "406", "412", "403", "401", "410", "409", "400", "414", "415", "402", "405", "408", "411", "407"], "fit_nums": ["840", "978", "224", "38", "335", "500", "83", "509", "441", "9", "135", "890", "358", "460", "844", "30"], "chain_num": 19, "ignore": [2, 13, 1, 9, 0, 4, 6, 5]}, "ibl_witten_18": {"seeds": ["311", "310", "303", "314", "302", "309", "305", "307", "312", "300", "308", "306", "315", "313", "304", "301"], "fit_nums": ["236", "26", "838", "762", "826", "409", "496", "944", "280", "704", "930", "419", "637", "896", "876", "297"], "chain_num": 4, "ignore": [11, 0, 4, 2, 12, 13, 8, 3]}, "CSHL_018": {"seeds": ["302", "310", "306", "300", "314", "307", "309", "313", "311", "308", "304", "301", "312", "303", "305", "315"], "fit_nums": ["843", "817", "920", "900", "226", "36", "472", "676", "933", "453", "116", "263", "269", "897", "568", "438"], "chain_num": 4, "ignore": [15, 4, 8, 0, 5, 10, 12, 11]}, "GLM_Sim_06": {"seeds": ["313", "309", "302", "303", "305", "314", "300", "315", "311", "306", "304", "310", "301", "312", "308", "307"], "fit_nums": ["9", "786", "286", "280", "72", "587", "619", "708", "360", "619", "311", "189", "60", "708", "939", "733"], "chain_num": 2, "ignore": [15, 9, 8, 14, 1, 12, 10, 3]}, "ZM_1897": {"seeds": ["304", "308", "305", "311", "315", "314", "307", "306", "300", "303", "313", "310", "301", "312", "302", "309"], "fit_nums": ["549", "96", "368", "509", "424", "897", "287", "426", "968", "93", "725", "513", "837", "581", "989", "374"], "chain_num": 4, "ignore": [0, 14, 5, 8, 7, 11, 13, 10]}, "CSHL_020": {"seeds": ["305", "309", "313", "302", "314", "310", "300", "307", "315", "306", "312", "304", "311", "301", "303", "308"], "fit_nums": ["222", "306", "243", "229", "584", "471", "894", "238", "986", "660", "494", "657", "896", "459", "100", "283"], "chain_num": 4, "ignore": [6, 5, 9, 15, 0, 8, 4, 13]}, "CSHL054": {"seeds": ["401", "415", "409", "410", "414", "413", "407", "405", "406", "408", "411", "400", "412", "402", "403", "404"], "fit_nums": ["901", "734", "609", "459", "574", "793", "978", "66", "954", "906", "954", "111", "292", "850", "266", "967"], "chain_num": 9, "ignore": [5, 12, 7, 10, 11, 2, 6, 4]}, "CSHL_014": {"seeds": ["305", "311", "309", "300", "313", "310", "307", "306", "304", "312", "308", "302", "314", "303", "301", "315"], "fit_nums": ["371", "550", "166", "24", "705", "385", "870", "884", "831", "546", "404", "722", "287", "564", "613", "783"], "chain_num": 4, "ignore": [15, 0, 3, 4, 7, 6, 1, 11]}, "CSHL062": {"seeds": ["307", "313", "310", "303", "306", "312", "308", "305", "311", "314", "304", "302", "300", "301", "315", "309"], "fit_nums": ["846", "371", "94", "888", "499", "229", "546", "432", "71", "989", "986", "91", "935", "314", "975", "481"], "chain_num": 4, "ignore": [14, 6, 3, 11, 15, 13, 4, 12]}, "CSH_ZAD_001": {"seeds": ["313", "309", "311", "312", "305", "310", "315", "300", "314", "304", "301", "302", "308", "303", "306", "307"], "fit_nums": ["468", "343", "314", "544", "38", "120", "916", "170", "305", "569", "502", "496", "452", "336", "559", "572"], "chain_num": 4, "ignore": [12, 8, 5, 1, 9, 3, 13, 15]}, "KS003": {"seeds": ["405", "401", "414", "415", "410", "404", "409", "413", "412", "408", "411", "407", "402", "406", "403", "400"], "fit_nums": ["858", "464", "710", "285", "665", "857", "990", "438", "233", "177", "43", "509", "780", "254", "523", "695"], "chain_num": 19, "ignore": [1, 8, 9, 2, 4, 15, 0, 7]}, "NYU-06": {"seeds": ["314", "309", "306", "305", "312", "303", "307", "304", "300", "302", "310", "301", "315", "308", "313", "311"], "fit_nums": ["950", "862", "782", "718", "427", "645", "827", "612", "821", "834", "595", "929", "679", "668", "648", "869"], "chain_num": 4, "ignore": [8, 2, 7, 12, 3, 4, 13, 11]}, "KS019": {"seeds": ["404", "401", "411", "408", "400", "403", "410", "413", "402", "407", "415", "409", "406", "414", "412", "405"], "fit_nums": ["682", "4", "264", "200", "250", "267", "737", "703", "132", "855", "922", "686", "85", "176", "54", "366"], "chain_num": 9, "ignore": [12, 14, 1, 2, 4, 7, 10, 15]}, "CSHL049": {"seeds": ["411", "402", "414", "408", "409", "410", "413", "407", "406", "401", "404", "405", "403", "415", "400", "412"], "fit_nums": ["104", "553", "360", "824", "749", "519", "347", "228", "863", "671", "140", "883", "701", "445", "627", "898"], "chain_num": 9, "ignore": [10, 11, 6, 7, 12, 13, 1, 8]}, "ibl_witten_14": {"seeds": ["310", "311", "304", "306", "300", "302", "314", "313", "303", "308", "301", "309", "305", "315", "312", "307"], "fit_nums": ["563", "120", "85", "712", "277", "871", "183", "661", "505", "598", "210", "89", "310", "638", "564", "998"], "chain_num": 4, "ignore": [11, 14, 6, 13, 5, 12, 15, 8]}, "KS014": {"seeds": ["301", "310", "302", "312", "313", "308", "307", "303", "305", "300", "314", "306", "311", "309", "304", "315"], "fit_nums": ["668", "32", "801", "193", "269", "296", "74", "24", "270", "916", "21", "250", "342", "451", "517", "293"], "chain_num": 4, "ignore": [9, 11, 0, 1, 14, 2, 12, 13]}, "CSHL059": {"seeds": ["306", "309", "300", "304", "314", "303", "315", "311", "313", "305", "301", "307", "302", "312", "310", "308"], "fit_nums": ["821", "963", "481", "999", "986", "45", "551", "605", "701", "201", "629", "261", "972", "407", "165", "9"], "chain_num": 4, "ignore": [9, 3, 5, 15, 6, 10, 2, 1]}, "GLM_Sim_13": {"seeds": ["310", "303", "308", "306", "300", "312", "301", "313", "305", "311", "315", "304", "314", "309", "307", "302"], "fit_nums": ["982", "103", "742", "524", "614", "370", "926", "456", "133", "143", "302", "80", "395", "549", "579", "944"], "chain_num": 2, "ignore": [12, 4, 11, 6, 7, 14, 0, 1]}, "CSHL_007": {"seeds": ["314", "303", "308", "313", "301", "300", "302", "305", "315", "306", "310", "309", "311", "304", "307", "312"], "fit_nums": ["462", "703", "345", "286", "480", "313", "986", "165", "201", "102", "322", "894", "960", "438", "330", "169"], "chain_num": 4, "ignore": [3, 12, 4, 5, 2, 0, 13, 1]}, "CSH_ZAD_011": {"seeds": ["314", "311", "303", "300", "305", "310", "306", "301", "302", "315", "304", "309", "308", "312", "313", "307"], "fit_nums": ["320", "385", "984", "897", "315", "120", "320", "945", "475", "403", "210", "412", "695", "564", "664", "411"], "chain_num": 4, "ignore": [0, 2, 14, 11, 7, 10, 13, 9]}, "KS021": {"seeds": ["309", "312", "304", "310", "303", "311", "314", "302", "305", "301", "306", "300", "308", "315", "313", "307"], "fit_nums": ["874", "943", "925", "587", "55", "136", "549", "528", "349", "211", "401", "84", "225", "545", "153", "382"], "chain_num": 4, "ignore": [11, 12, 0, 8, 2, 14, 5, 1]}, "GLM_Sim_15": {"seeds": ["303", "312", "305", "308", "309", "302", "301", "310", "313", "315", "311", "314", "307", "306", "304", "300"], "fit_nums": ["769", "930", "328", "847", "899", "714", "144", "518", "521", "873", "914", "359", "242", "343", "45", "364"], "chain_num": 2, "ignore": [8, 1, 0, 3, 2, 5, 10, 4]}, "CSHL_015": {"seeds": ["301", "302", "307", "310", "309", "311", "304", "312", "300", "308", "313", "305", "314", "315", "306", "303"], "fit_nums": ["717", "705", "357", "539", "604", "971", "669", "76", "45", "413", "510", "122", "190", "821", "368", "472"], "chain_num": 4, "ignore": [7, 6, 10, 2, 15, 13, 1, 3]}, "ibl_witten_16": {"seeds": ["304", "313", "309", "314", "312", "307", "305", "301", "306", "310", "300", "315", "308", "311", "303", "302"], "fit_nums": ["392", "515", "696", "270", "7", "583", "880", "674", "23", "576", "579", "695", "149", "854", "184", "875"], "chain_num": 4, "ignore": [3, 12, 2, 6, 10, 14, 4, 1]}, "KS015": {"seeds": ["315", "305", "309", "303", "314", "310", "311", "312", "313", "300", "307", "308", "304", "301", "302", "306"], "fit_nums": ["257", "396", "387", "435", "133", "164", "403", "8", "891", "650", "111", "557", "473", "229", "842", "196"], "chain_num": 4, "ignore": [7, 8, 0, 10, 2, 3, 12, 9]}, "GLM_Sim_12": {"seeds": ["304", "312", "306", "303", "310", "302", "300", "305", "308", "313", "307", "311", "315", "301", "314", "309"], "fit_nums": ["971", "550", "255", "195", "952", "486", "841", "535", "559", "37", "654", "213", "864", "506", "732", "550"], "chain_num": 2, "ignore": [0, 7, 15, 14, 3, 10, 11, 13]}, "GLM_Sim_11": {"seeds": ["300", "312", "310", "315", "302", "313", "314", "311", "308", "303", "309", "307", "306", "304", "301", "305"], "fit_nums": ["477", "411", "34", "893", "195", "293", "603", "5", "887", "281", "956", "73", "346", "640", "532", "688"], "chain_num": 2}, "GLM_Sim_10": {"seeds": ["301", "300", "306", "305", "307", "309", "312", "314", "311", "315", "304", "313", "303", "308", "302", "310"], "fit_nums": ["391", "97", "897", "631", "239", "652", "19", "448", "807", "35", "972", "469", "280", "562", "42", "706"], "chain_num": 2, "ignore": [1, 9, 15, 3, 13, 12, 7, 11]}, "CSH_ZAD_026": {"seeds": ["312", "313", "308", "310", "303", "307", "302", "305", "300", "315", "306", "301", "311", "304", "314", "309"], "fit_nums": ["699", "87", "537", "628", "797", "511", "459", "770", "969", "240", "504", "948", "295", "506", "25", "378"], "chain_num": 4, "ignore": [12, 13, 4, 11, 8, 3, 15, 0]}, "KS023": {"seeds": ["304", "313", "306", "309", "300", "314", "302", "310", "303", "315", "307", "308", "301", "311", "305", "312"], "fit_nums": ["698", "845", "319", "734", "908", "507", "45", "499", "175", "108", "419", "443", "116", "779", "159", "231"], "chain_num": 4, "ignore": [8, 10, 1, 13, 4, 15, 14, 7]}, "GLM_Sim_05": {"seeds": ["301", "315", "300", "302", "305", "304", "313", "314", "311", "309", "306", "307", "308", "310", "303", "312"], "fit_nums": ["425", "231", "701", "375", "343", "902", "623", "125", "921", "637", "393", "964", "678", "930", "796", "42"], "chain_num": 2, "ignore": [11, 2, 5, 1, 4, 9, 15, 12]}, "CSHL061": {"seeds": ["305", "315", "304", "303", "309", "310", "302", "300", "314", "306", "311", "313", "301", "308", "307", "312"], "fit_nums": ["396", "397", "594", "911", "308", "453", "686", "552", "103", "209", "128", "892", "345", "925", "777", "396"], "chain_num": 4, "ignore": [11, 13, 7, 15, 14, 3, 0, 4]}, "CSHL051": {"seeds": ["303", "310", "306", "302", "309", "305", "313", "308", "300", "314", "311", "307", "312", "304", "315", "301"], "fit_nums": ["69", "186", "49", "435", "103", "910", "705", "367", "303", "474", "596", "334", "929", "796", "616", "790"], "chain_num": 4, "ignore": [15, 12, 8, 13, 0, 2, 4, 5]}, "GLM_Sim_14": {"seeds": ["310", "311", "309", "313", "314", "300", "302", "304", "305", "306", "307", "312", "303", "301", "315", "308"], "fit_nums": ["616", "872", "419", "106", "940", "986", "599", "704", "218", "808", "244", "825", "448", "397", "552", "316"], "chain_num": 2, "ignore": [7, 11, 2, 15, 0, 13, 5, 10]}, "GLM_Sim_11_trick": {"seeds": ["411", "400", "408", "409", "415", "413", "410", "412", "406", "414", "403", "404", "401", "405", "407", "402"], "fit_nums": ["95", "508", "886", "384", "822", "969", "525", "382", "489", "436", "344", "537", "251", "223", "458", "401"], "chain_num": 2}, "GLM_Sim_16": {"seeds": ["302", "311", "303", "307", "313", "308", "309", "300", "305", "315", "304", "310", "312", "301", "314", "306"], "fit_nums": ["914", "377", "173", "583", "870", "456", "611", "697", "13", "713", "159", "248", "617", "37", "770", "780"], "chain_num": 2, "ignore": [4, 10, 5, 0, 13, 8, 6, 7]}, "ZM_3003": {"seeds": ["300", "304", "307", "312", "305", "310", "311", "314", "303", "308", "313", "301", "315", "309", "306", "302"], "fit_nums": ["603", "620", "657", "735", "357", "390", "119", "33", "62", "617", "209", "810", "688", "21", "744", "426"], "chain_num": 4, "ignore": [14, 7, 12, 1, 3, 4, 11, 8]}, "CSH_ZAD_022": {"seeds": ["305", "310", "311", "315", "303", "312", "314", "313", "307", "302", "300", "304", "301", "308", "306", "309"], "fit_nums": ["143", "946", "596", "203", "576", "403", "900", "65", "478", "325", "282", "513", "460", "42", "161", "970"], "chain_num": 4, "ignore": [9, 12, 4, 8, 3, 7, 0, 1]}, "GLM_Sim_07": {"seeds": ["300", "309", "302", "304", "305", "312", "301", "311", "315", "314", "308", "307", "303", "310", "306", "313"], "fit_nums": ["724", "701", "118", "230", "648", "426", "689", "114", "832", "731", "592", "519", "559", "938", "672", "144"], "chain_num": 1}, "KS017": {"seeds": ["311", "310", "306", "309", "303", "302", "308", "300", "313", "301", "314", "307", "315", "304", "312", "305"], "fit_nums": ["97", "281", "808", "443", "352", "890", "703", "468", "780", "708", "674", "27", "345", "23", "939", "457"], "chain_num": 4, "ignore": [0, 13, 8, 1, 12, 5, 10, 9]}, "GLM_Sim_11_sub": {"seeds": ["410", "414", "413", "404", "409", "415", "406", "408", "402", "411", "400", "405", "403", "407", "412", "401"], "fit_nums": ["830", "577", "701", "468", "929", "374", "954", "749", "937", "488", "873", "416", "612", "792", "461", "488"], "chain_num": 2}} \ No newline at end of file diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py index 4924143b4f44716c903d7e798a60125d432c262e..410e2cdc3fb938dc9bea4ed69cb5b5b578531d86 100644 --- a/dyn_glm_chain_analysis.py +++ b/dyn_glm_chain_analysis.py @@ -21,7 +21,7 @@ import time import multiprocessing as mp from mcmc_chain_analysis import state_size_helper, state_num_helper, gamma_func, alpha_func, ll_func, r_hat_array_comp, rank_inv_normal_transform, eval_r_hat, eval_simple_r_hat import pandas as pd -from pmf_analysis import pmf_type, type2color +from analysis_pmf import pmf_type, type2color colors = np.genfromtxt('colors.csv', delimiter=',') @@ -587,6 +587,33 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur return cnas +def bias_flips(states_by_session, pmfs, durs): + state_counter = {} + prev_bias = 0 + bias_flips = 0 + + for sess in range(states_by_session.shape[1]): + states = np.where(states_by_session[:, sess])[0] + + for s in states: + if s not in state_counter: + state_counter[s] = -1 + state_counter[s] += 1 + + if s == np.argmax(states_by_session[:, sess]): + bias = np.mean(pmfs[s][1][state_counter[s]][pmfs[s][0]]) + if bias < 0.45: + if prev_bias == 1: + print("sess {}".format(sess)) + bias_flips += 1 + prev_bias = -1 + elif bias > 0.55: + if prev_bias == -1: + print("sess {}".format(sess)) + bias_flips += 1 + prev_bias = 1 + return bias_flips + def pmf_regressions(states_by_session, pmfs, durs): # find out whether pmfs regressed state_perfs = {} @@ -594,6 +621,7 @@ def pmf_regressions(states_by_session, pmfs, durs): current_best_state = -1 counter = 0 types = [0, 0, 0] + diffs = [] for sess in range(states_by_session.shape[1]): states = np.where(states_by_session[:, sess])[0] @@ -616,8 +644,9 @@ def pmf_regressions(states_by_session, pmfs, durs): types[2] += 1 a, b = state_perfs[np.argmax(states_by_session[:, sess])], state_perfs[current_best_state] print("Regression in session {} during {:.2f}% of session ({:.2f} instead of {:.2f})".format(sess + 1, np.max(states_by_session[:, sess]) * 100, a, b)) + diffs.append(b - a) - return [counter, states_by_session.shape[1], *types] + return [counter, states_by_session.shape[1], *types], diffs def control_flow(test, indices, trials, func_init, first_for, second_for, end_first_for): # Generalised control flow for iterating over samples of mode across individually across sessions @@ -1227,6 +1256,9 @@ def find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_t print("Without removals: {}".format(r_hat)) mins[0] = r_hat + r_hats = [] + solutions = [] + to_del = [] for i in range(delete_n): r_hat_min = 50 @@ -1247,10 +1279,16 @@ def find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_t print("Removed: {}".format(to_del)) mins[i + 1] = r_hat_min + r_hats.append(r_hat_min) + solutions.append(to_del.copy()) + if simple: r_hat_local = eval_r_hat([np.delete(chains1, to_del, axis=0), np.delete(chains2, to_del, axis=0), np.delete(chains3, to_del, axis=0), np.delete(chains4, to_del, axis=0)]) print("Minimum over everything is {} (removed {})".format(r_hat_local, i + 1)) - return to_del, r_hat_min + + best = np.argmin(r_hats) + + return solutions[best], r_hats[best] if __name__ == "__main__": @@ -1261,7 +1299,7 @@ if __name__ == "__main__": elif fit_type == 'prebias': loading_info = json.load(open("canonical_infos.json", 'r')) r_hats = json.load(open("canonical_info_r_hats.json", 'r')) - subjects = list(loading_info.keys()) + subjects = ['SWC_021', 'ibl_witten_15', 'ibl_witten_13', 'KS003', 'ibl_witten_19', 'SWC_022', 'CSH_ZAD_017'] # list(loading_info.keys()) r_hats = {} @@ -1273,7 +1311,6 @@ if __name__ == "__main__": check_r_hats = False if check_r_hats: - subjects = list(loading_info.keys()) for subject in subjects: if subject.startswith('GLM_Sim_07') or subject.startswith('GLM_Sim_11'): continue @@ -1440,7 +1477,7 @@ if __name__ == "__main__": r_hats = json.load(open("canonical_info_r_hats.json", 'r')) no_good_pcas = ['NYU-06', 'SWC_023'] subjects = list(loading_info.keys()) - subjects = ['KS014'] + # subjects = ['SWC_021', 'ibl_witten_15', 'ibl_witten_13', 'KS003', 'ibl_witten_19', 'SWC_022', 'CSH_ZAD_017'] # meh pmfs: KS021 print(subjects) @@ -1479,12 +1516,17 @@ if __name__ == "__main__": all_intros_div = [] all_states_per_type = [] regressions = [] + regression_diffs = [] + all_bias_flips = [] for subject in subjects: if subject.startswith('GLM_Sim_') or subject == 'ibl_witten_18': continue + if subject in ['SWC_021', 'ibl_witten_15', 'ibl_witten_13', 'KS003', 'ibl_witten_19', 'SWC_022', 'CSH_ZAD_017']: + continue + print(subject) try: @@ -1500,9 +1542,14 @@ if __name__ == "__main__": # training overview states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True) - regression = pmf_regressions(states, pmfs, durs) + + # b_flips = bias_flips(states, pmfs, durs) + # all_bias_flips.append(b_flips) + + regression, diffs = pmf_regressions(states, pmfs, durs) + regression_diffs += diffs regressions.append(regression) - quit() + continue # compare_pmfs(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, pmfs, title="{} convergence pmf".format(subject)) # compare_weights(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, title="{} convergence weights".format(subject)) # quit() @@ -1720,7 +1767,6 @@ if __name__ == "__main__": except FileNotFoundError as e: print(e) - continue print('no canoncial result') print(r_hats[subject]) if r_hats[subject] >= 1.05: @@ -1749,17 +1795,12 @@ if __name__ == "__main__": except Exception: break save_id = "{}_fittype_{}_var_{}_{}_{}.p".format(subject, fit_type, fit_variance, seed, fit_num).replace('.', '_') - # for file in os.listdir("./dynamic_GLMiHMM_crossvals/infos_new/"): - # if file.startswith("{}_".format(subject)) and file.endswith("_{}_{}_{}_{}.json".format(fit_type, fit_variance, seed, fit_num)): - # print("Taking fit infos from {}".format(file)) - # fit_infos = json.load(open("./dynamic_GLMiHMM_crossvals/infos_new/" + file, 'r')) - # sample_lls = fit_infos['ll'] print("Loaded") result = MCMC_result(samples[::25], infos=info_dict, data=samples[0].datas, sessions=fit_type, fit_variance=fit_variance, - dur=dur, save_id=save_id) # , sample_lls=sample_lls) + dur=dur, save_id=save_id) results.append(result) test = MCMC_result_list(results, summary_info) pickle.dump(test, open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'wb')) @@ -1789,13 +1830,14 @@ if __name__ == "__main__": # pickle.dump(all_intros_div, open("all_intros_div.p", 'wb')) # pickle.dump(all_states_per_type, open("all_states_per_type.p", 'wb')) # pickle.dump(regressions, open("regressions.p", 'wb')) + # pickle.dump(regression_diffs, open("regression_diffs.p", 'wb')) + # pickle.dump(all_bias_flips, open("all_bias_flips.p", 'wb')) # # a = [x for x, y in zip(all_pmf_asymms, all_pmf_diffs) if y >= 0.2] # b = [y for x, y in zip(all_pmf_asymms, all_pmf_diffs) if y >= 0.2] # plt.hist2d(a, b, bins=40) # plt.show() - if True: abs_state_durs = np.array(abs_state_durs) # pickle.dump(abs_state_durs, open("multi_chain_saves/abs_state_durs.p", 'wb')) diff --git a/index_mice.py b/index_mice.py index ebb58b840677ea1ffc0cdd34584d08c4403b3a75..e821de0e2e86ea42e5859cc88e8c545179cea4e8 100644 --- a/index_mice.py +++ b/index_mice.py @@ -42,7 +42,7 @@ for filename in os.listdir("./dynamic_GLMiHMM_crossvals/"): for s in prebias_subinfo.keys(): - if len(prebias_subinfo[s]["fit_nums"]) == 48: + if len(prebias_subinfo[s]["fit_nums"]) == 48 or len(prebias_subinfo[s]["fit_nums"]) == 32: new_fit_nums = [] new_seeds = [] for fit_num, seed in zip(prebias_subinfo[s]["fit_nums"], prebias_subinfo[s]["seeds"]): @@ -51,23 +51,23 @@ for s in prebias_subinfo.keys(): new_seeds.append(seed) prebias_subinfo[s]["fit_nums"] = new_fit_nums prebias_subinfo[s]["seeds"] = new_seeds - if s == 'ibl_witten_13': - new_fit_nums = [] - new_seeds = [] - for fit_num, seed in zip(prebias_subinfo[s]["fit_nums"], prebias_subinfo[s]["seeds"]): - if int(seed) < 316: - new_fit_nums.append(fit_num) - new_seeds.append(seed) - prebias_subinfo[s]["fit_nums"] = new_fit_nums - prebias_subinfo[s]["seeds"] = new_seeds + # if s == 'ibl_witten_13': + # new_fit_nums = [] + # new_seeds = [] + # for fit_num, seed in zip(prebias_subinfo[s]["fit_nums"], prebias_subinfo[s]["seeds"]): + # if int(seed) < 316: + # new_fit_nums.append(fit_num) + # new_seeds.append(seed) + # prebias_subinfo[s]["fit_nums"] = new_fit_nums + # prebias_subinfo[s]["seeds"] = new_seeds big = [] non_big = [] sim_subjects = [] for s in prebias_subinfo.keys(): - # assert len(sub_info[s]["seeds"]) in [16, 32] - # assert len(sub_info[s]["fit_nums"]) in [16, 32] + assert len(prebias_subinfo[s]["seeds"]) in [16, 32] + assert len(prebias_subinfo[s]["fit_nums"]) in [16, 32] print(s, len(prebias_subinfo[s]["fit_nums"])) if len(prebias_subinfo[s]["fit_nums"]) == 32: big.append(s) diff --git a/process_many_chains.py b/process_many_chains.py index 88840671866206bcdee29ba09c549eb8181bae00..cf8061356316a47aae18dedf91b2d0b4b2b27e9c 100644 --- a/process_many_chains.py +++ b/process_many_chains.py @@ -9,7 +9,7 @@ import json from dyn_glm_chain_analysis import MCMC_result import matplotlib.pyplot as plt import time -from mcmc_chain_analysis import state_size_helper, state_num_helper, ll_helper +from mcmc_chain_analysis import state_size_helper, state_num_helper fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0] @@ -17,7 +17,7 @@ if fit_type == 'bias': loading_info = json.load(open("canonical_infos_bias.json", 'r')) elif fit_type == 'prebias': loading_info = json.load(open("canonical_infos.json", 'r')) -subjects = ['GLM_Sim_11_trick'] # list(loading_info.keys()) +subjects = ['SWC_021', 'ibl_witten_15', 'ibl_witten_13', 'KS003', 'ibl_witten_19', 'SWC_022', 'CSH_ZAD_017'] # list(loading_info.keys()) fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0] func1 = state_num_helper(0.2)