diff --git a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc index 4a89c93f30137d03c0b66cc48856b5c8f6b5a111..797f6ea4ca71054cc38cb50347dc59cdea7fef0b 100644 Binary files a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc and b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc differ diff --git a/__pycache__/index_mice.cpython-37.pyc b/__pycache__/index_mice.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc0d0f83258d0e8e4559bfb0f57d89d9f097425c Binary files /dev/null and b/__pycache__/index_mice.cpython-37.pyc differ diff --git a/__pycache__/mcmc_chain_analysis.cpython-37.pyc b/__pycache__/mcmc_chain_analysis.cpython-37.pyc index dfcc898f4df0897c89643b35287844e38e54193b..852f09637cc815a928e81c03093be6c9e08945c3 100644 Binary files a/__pycache__/mcmc_chain_analysis.cpython-37.pyc and b/__pycache__/mcmc_chain_analysis.cpython-37.pyc differ diff --git a/behavioral_state_data.py b/behavioral_state_data.py index be9aef832e5d8d480af615b126d5528cad2f8229..11a8cc41aef7d5e3b48b1fd791604b18415cff46 100644 --- a/behavioral_state_data.py +++ b/behavioral_state_data.py @@ -73,7 +73,7 @@ to_introduce = [2, 3, 4, 5] # "ibl_witten_06", "ibl_witten_07", "ibl_witten_12", "ibl_witten_13", "ibl_witten_14", "ibl_witten_15", # "ibl_witten_16", "KS003", "KS005", "KS019", "NYU-01", "NYU-02", "NYU-04", "NYU-06", "ZM_1367", "ZM_1369", # "ZM_1371", "ZM_1372", "ZM_1743", "ZM_1745", "ZM_1746"] # zoe's subjects -subjects = ['ibl_witten_14'] +subjects = ['ZFM-05236'] data_folder = 'session_data' # why does CSHL058 not work? @@ -149,7 +149,8 @@ for subject in subjects: contrast_set = {0, 1, 9, 10} rel_count = -1 - for i, (eid, extra_eids) in enumerate(zip(fixed_eids, additional_eids)): + quit() + for i, (eid, extra_eids, date) in enumerate(zip(fixed_eids, additional_eids, fixed_dates)): try: trials = one.load_object(eid, 'trials') @@ -182,6 +183,8 @@ for subject in subjects: df = pd.concat([df, df2], ignore_index=1) print('new size: {}'.format(len(df))) + pickle.dump(df, open("./sofiya_data/{}_df_{}_{}.p".format(subject, rel_count, date), "wb")) + current_contrasts = set(df['signed_contrast']) diff = current_contrasts.difference(contrast_set) for c in to_introduce: diff --git a/behavioral_state_data_easier.py b/behavioral_state_data_easier.py index e526d7cc485df51c38469af52b0e669aee793b5c..18a9de78e8aea1c241ff6a474e44e58a204d84c7 100644 --- a/behavioral_state_data_easier.py +++ b/behavioral_state_data_easier.py @@ -34,14 +34,12 @@ misses = [] to_introduce = [2, 3, 4, 5] amiss = ['UCLA034', 'UCLA036', 'UCLA037', 'PL015', 'PL016', 'PL017', 'PL024', 'NR_0017', 'NR_0019', 'NR_0020', 'NR_0021', 'NR_0027'] -subjects = ['NYU-21'] +subjects = ['ZFM-04019', 'ZFM-05236'] fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0] if fit_type == 'bias': loading_info = json.load(open("canonical_infos_bias.json", 'r')) - r_hats = json.load(open("canonical_info_r_hats_bias.json", 'r')) elif fit_type == 'prebias': loading_info = json.load(open("canonical_infos.json", 'r')) - r_hats = json.load(open("canonical_info_r_hats.json", 'r')) already_fit = list(loading_info.keys()) remaining_subs = [s for s in subjects if s not in amiss and s not in already_fit] @@ -70,13 +68,14 @@ for subject in subjects: print('_____________________') print(subject) - if subject in already_fit or subject in amiss: - continue + # if subject in already_fit or subject in amiss: + # continue trials = one.load_aggregate('subjects', subject, '_ibl_subjectTrials.table') # Load training status and join to trials table training = one.load_aggregate('subjects', subject, '_ibl_subjectTraining.table') + quit() trials = (trials .set_index('session') .join(training.set_index('session')) diff --git a/canonical_infos.json b/canonical_infos.json index 12a3fa3aae799f418267abc5e98f113db7d5d248..99ae626eb5a4f9444331e2cd47bdbe9b3c2131ce 100644 --- a/canonical_infos.json +++ b/canonical_infos.json @@ -1 +1 @@ -{"SWC_022": {"seeds": ["412", "401", "403", "413", "406", "407", "415", "409", "405", "400", "408", "404", "410", "411", "414", "402"], "fit_nums": ["347", "54", "122", "132", "520", "386", "312", "59", "999", "849", "372", "300", "485", "593", "358", "550"], "chain_num": 19}, "SWC_023": {"seeds": ["302", "312", "304", "300", "315", "311", "308", "305", "303", "309", "306", "313", "307", "314", "301", "310"], "fit_nums": ["994", "913", "681", "816", "972", "790", "142", "230", "696", "537", "975", "773", "918", "677", "742", "745"], "chain_num": 4}, "ZFM-05236": {"seeds": ["404", "409", "401", "412", "406", "411", "410", "402", "408", "405", "415", "403", "414", "400", "413", "407"], "fit_nums": ["106", "111", "333", "253", "395", "76", "186", "192", "221", "957", "989", "612", "632", "304", "50", "493"], "chain_num": 14, "ignore": [3, 0, 8, 2, 6, 15, 14, 10]}, "ibl_witten_15": {"seeds": ["408", "412", "400", "411", "410", "407", "403", "406", "413", "405", "404", "402", "401", "415", "409", "414"], "fit_nums": ["40", "241", "435", "863", "941", "530", "382", "750", "532", "731", "146", "500", "967", "334", "375", "670"], "chain_num": 19}, "ibl_witten_13": {"seeds": ["401", "414", "409", "413", "415", "411", "410", "408", "402", "405", "406", "407", "412", "403", "400", "404"], "fit_nums": ["702", "831", "47", "740", "251", "929", "579", "351", "515", "261", "222", "852", "754", "892", "473", "29"], "chain_num": 19}, "KS016": {"seeds": ["315", "301", "309", "313", "302", "307", "303", "308", "311", "312", "314", "306", "310", "300", "305", "304"], "fit_nums": ["99", "57", "585", "32", "501", "558", "243", "413", "59", "757", "463", "172", "524", "957", "909", "292"], "chain_num": 4}, "ibl_witten_19": {"seeds": ["412", "415", "413", "408", "409", "404", "403", "401", "405", "411", "410", "406", "402", "414", "407", "400"], "fit_nums": ["234", "41", "503", "972", "935", "808", "912", "32", "331", "755", "117", "833", "822", "704", "901", "207"], "chain_num": 19}, "CSH_ZAD_017": {"seeds": ["408", "404", "413", "406", "414", "411", "400", "401", "415", "407", "402", "412", "403", "409", "405", "410"], "fit_nums": ["928", "568", "623", "841", "92", "251", "829", "922", "964", "257", "150", "970", "375", "113", "423", "564"], "chain_num": 19}, "KS022": {"seeds": ["315", "300", "314", "301", "303", "302", "306", "308", "305", "310", "313", "312", "304", "307", "311", "309"], "fit_nums": ["899", "681", "37", "957", "629", "637", "375", "980", "810", "51", "759", "664", "420", "127", "259", "555"], "chain_num": 4}, "CSH_ZAD_025": {"seeds": ["303", "311", "307", "312", "313", "314", "308", "315", "305", "306", "304", "302", "309", "310", "301", "300"], "fit_nums": ["581", "148", "252", "236", "581", "838", "206", "756", "449", "288", "756", "593", "733", "633", "418", "563"], "chain_num": 4}, "ibl_witten_17": {"seeds": ["406", "415", "408", "413", "402", "405", "409", "400", "414", "401", "412", "407", "404", "410", "403", "411"], "fit_nums": ["827", "797", "496", "6", "444", "823", "384", "873", "634", "27", "811", "142", "207", "322", "756", "275"], "chain_num": 9}, "SWC_021": {"seeds": ["404", "413", "406", "412", "403", "401", "410", "409", "400", "414", "415", "402", "405", "408", "411", "407"], "fit_nums": ["840", "978", "224", "38", "335", "500", "83", "509", "441", "9", "135", "890", "358", "460", "844", "30"], "chain_num": 19}, "ibl_witten_18": {"seeds": ["311", "310", "303", "314", "302", "309", "305", "307", "312", "300", "308", "306", "315", "313", "304", "301"], "fit_nums": ["236", "26", "838", "762", "826", "409", "496", "944", "280", "704", "930", "419", "637", "896", "876", "297"], "chain_num": 4}, "CSHL_018": {"seeds": ["302", "310", "306", "300", "314", "307", "309", "313", "311", "308", "304", "301", "312", "303", "305", "315"], "fit_nums": ["843", "817", "920", "900", "226", "36", "472", "676", "933", "453", "116", "263", "269", "897", "568", "438"], "chain_num": 4}, "ZFM-05245": {"seeds": ["400", "413", "414", "409", "411", "403", "405", "412", "406", "410", "407", "415", "401", "404", "402", "408"], "fit_nums": ["512", "765", "704", "17", "539", "449", "584", "987", "138", "932", "869", "313", "253", "540", "37", "634"], "chain_num": 11}, "GLM_Sim_06": {"seeds": ["313", "309", "302", "303", "305", "314", "300", "315", "311", "306", "304", "310", "301", "312", "308", "307"], "fit_nums": ["9", "786", "286", "280", "72", "587", "619", "708", "360", "619", "311", "189", "60", "708", "939", "733"], "chain_num": 2}, "ZM_1897": {"seeds": ["304", "308", "305", "311", "315", "314", "307", "306", "300", "303", "313", "310", "301", "312", "302", "309"], "fit_nums": ["549", "96", "368", "509", "424", "897", "287", "426", "968", "93", "725", "513", "837", "581", "989", "374"], "chain_num": 4}, "CSHL_020": {"seeds": ["305", "309", "313", "302", "314", "310", "300", "307", "315", "306", "312", "304", "311", "301", "303", "308"], "fit_nums": ["222", "306", "243", "229", "584", "471", "894", "238", "986", "660", "494", "657", "896", "459", "100", "283"], "chain_num": 4}, "CSHL054": {"seeds": ["401", "415", "409", "410", "414", "413", "407", "405", "406", "408", "411", "400", "412", "402", "403", "404"], "fit_nums": ["901", "734", "609", "459", "574", "793", "978", "66", "954", "906", "954", "111", "292", "850", "266", "967"], "chain_num": 9}, "CSHL_014": {"seeds": ["305", "311", "309", "300", "313", "310", "307", "306", "304", "312", "308", "302", "314", "303", "301", "315"], "fit_nums": ["371", "550", "166", "24", "705", "385", "870", "884", "831", "546", "404", "722", "287", "564", "613", "783"], "chain_num": 4}, "ZFM-04019": {"seeds": ["413", "404", "408", "403", "415", "406", "414", "410", "402", "405", "411", "400", "401", "412", "409", "407"], "fit_nums": ["493", "302", "590", "232", "121", "938", "270", "999", "95", "175", "576", "795", "728", "244", "32", "177"], "chain_num": 14, "ignore": [11, 12, 6, 3, 5, 2, 15, 1]}, "CSHL062": {"seeds": ["307", "313", "310", "303", "306", "312", "308", "305", "311", "314", "304", "302", "300", "301", "315", "309"], "fit_nums": ["846", "371", "94", "888", "499", "229", "546", "432", "71", "989", "986", "91", "935", "314", "975", "481"], "chain_num": 4}, "CSH_ZAD_001": {"seeds": ["313", "309", "311", "312", "305", "310", "315", "300", "314", "304", "301", "302", "308", "303", "306", "307"], "fit_nums": ["468", "343", "314", "544", "38", "120", "916", "170", "305", "569", "502", "496", "452", "336", "559", "572"], "chain_num": 4}, "KS003": {"seeds": ["405", "401", "414", "415", "410", "404", "409", "413", "412", "408", "411", "407", "402", "406", "403", "400"], "fit_nums": ["858", "464", "710", "285", "665", "857", "990", "438", "233", "177", "43", "509", "780", "254", "523", "695"], "chain_num": 19}, "NYU-06": {"seeds": ["314", "309", "306", "305", "312", "303", "307", "304", "300", "302", "310", "301", "315", "308", "313", "311"], "fit_nums": ["950", "862", "782", "718", "427", "645", "827", "612", "821", "834", "595", "929", "679", "668", "648", "869"], "chain_num": 4}, "KS019": {"seeds": ["404", "401", "411", "408", "400", "403", "410", "413", "402", "407", "415", "409", "406", "414", "412", "405"], "fit_nums": ["682", "4", "264", "200", "250", "267", "737", "703", "132", "855", "922", "686", "85", "176", "54", "366"], "chain_num": 9}, "CSHL049": {"seeds": ["411", "402", "414", "408", "409", "410", "413", "407", "406", "401", "404", "405", "403", "415", "400", "412"], "fit_nums": ["104", "553", "360", "824", "749", "519", "347", "228", "863", "671", "140", "883", "701", "445", "627", "898"], "chain_num": 9}, "ibl_witten_14": {"seeds": ["310", "311", "304", "306", "300", "302", "314", "313", "303", "308", "301", "309", "305", "315", "312", "307"], "fit_nums": ["563", "120", "85", "712", "277", "871", "183", "661", "505", "598", "210", "89", "310", "638", "564", "998"], "chain_num": 4}, "KS014": {"seeds": ["301", "310", "302", "312", "313", "308", "307", "303", "305", "300", "314", "306", "311", "309", "304", "315"], "fit_nums": ["668", "32", "801", "193", "269", "296", "74", "24", "270", "916", "21", "250", "342", "451", "517", "293"], "chain_num": 4}, "CSHL059": {"seeds": ["306", "309", "300", "304", "314", "303", "315", "311", "313", "305", "301", "307", "302", "312", "310", "308"], "fit_nums": ["821", "963", "481", "999", "986", "45", "551", "605", "701", "201", "629", "261", "972", "407", "165", "9"], "chain_num": 4}, "GLM_Sim_13": {"seeds": ["310", "303", "308", "306", "300", "312", "301", "313", "305", "311", "315", "304", "314", "309", "307", "302"], "fit_nums": ["982", "103", "742", "524", "614", "370", "926", "456", "133", "143", "302", "80", "395", "549", "579", "944"], "chain_num": 2}, "CSHL_007": {"seeds": ["314", "303", "308", "313", "301", "300", "302", "305", "315", "306", "310", "309", "311", "304", "307", "312"], "fit_nums": ["462", "703", "345", "286", "480", "313", "986", "165", "201", "102", "322", "894", "960", "438", "330", "169"], "chain_num": 4}, "CSH_ZAD_011": {"seeds": ["314", "311", "303", "300", "305", "310", "306", "301", "302", "315", "304", "309", "308", "312", "313", "307"], "fit_nums": ["320", "385", "984", "897", "315", "120", "320", "945", "475", "403", "210", "412", "695", "564", "664", "411"], "chain_num": 4}, "KS021": {"seeds": ["309", "312", "304", "310", "303", "311", "314", "302", "305", "301", "306", "300", "308", "315", "313", "307"], "fit_nums": ["874", "943", "925", "587", "55", "136", "549", "528", "349", "211", "401", "84", "225", "545", "153", "382"], "chain_num": 4}, "GLM_Sim_15": {"seeds": ["303", "312", "305", "308", "309", "302", "301", "310", "313", "315", "311", "314", "307", "306", "304", "300"], "fit_nums": ["769", "930", "328", "847", "899", "714", "144", "518", "521", "873", "914", "359", "242", "343", "45", "364"], "chain_num": 2}, "CSHL_015": {"seeds": ["301", "302", "307", "310", "309", "311", "304", "312", "300", "308", "313", "305", "314", "315", "306", "303"], "fit_nums": ["717", "705", "357", "539", "604", "971", "669", "76", "45", "413", "510", "122", "190", "821", "368", "472"], "chain_num": 4}, "ibl_witten_16": {"seeds": ["304", "313", "309", "314", "312", "307", "305", "301", "306", "310", "300", "315", "308", "311", "303", "302"], "fit_nums": ["392", "515", "696", "270", "7", "583", "880", "674", "23", "576", "579", "695", "149", "854", "184", "875"], "chain_num": 4}, "KS015": {"seeds": ["315", "305", "309", "303", "314", "310", "311", "312", "313", "300", "307", "308", "304", "301", "302", "306"], "fit_nums": ["257", "396", "387", "435", "133", "164", "403", "8", "891", "650", "111", "557", "473", "229", "842", "196"], "chain_num": 4}, "GLM_Sim_12": {"seeds": ["304", "312", "306", "303", "310", "302", "300", "305", "308", "313", "307", "311", "315", "301", "314", "309"], "fit_nums": ["971", "550", "255", "195", "952", "486", "841", "535", "559", "37", "654", "213", "864", "506", "732", "550"], "chain_num": 2}, "GLM_Sim_11": {"seeds": ["300", "312", "310", "315", "302", "313", "314", "311", "308", "303", "309", "307", "306", "304", "301", "305"], "fit_nums": ["477", "411", "34", "893", "195", "293", "603", "5", "887", "281", "956", "73", "346", "640", "532", "688"], "chain_num": 2}, "GLM_Sim_10": {"seeds": ["301", "300", "306", "305", "307", "309", "312", "314", "311", "315", "304", "313", "303", "308", "302", "310"], "fit_nums": ["391", "97", "897", "631", "239", "652", "19", "448", "807", "35", "972", "469", "280", "562", "42", "706"], "chain_num": 2}, "CSH_ZAD_026": {"seeds": ["312", "313", "308", "310", "303", "307", "302", "305", "300", "315", "306", "301", "311", "304", "314", "309"], "fit_nums": ["699", "87", "537", "628", "797", "511", "459", "770", "969", "240", "504", "948", "295", "506", "25", "378"], "chain_num": 4}, "KS023": {"seeds": ["304", "313", "306", "309", "300", "314", "302", "310", "303", "315", "307", "308", "301", "311", "305", "312"], "fit_nums": ["698", "845", "319", "734", "908", "507", "45", "499", "175", "108", "419", "443", "116", "779", "159", "231"], "chain_num": 4}, "GLM_Sim_05": {"seeds": ["301", "315", "300", "302", "305", "304", "313", "314", "311", "309", "306", "307", "308", "310", "303", "312"], "fit_nums": ["425", "231", "701", "375", "343", "902", "623", "125", "921", "637", "393", "964", "678", "930", "796", "42"], "chain_num": 2}, "CSHL061": {"seeds": ["305", "315", "304", "303", "309", "310", "302", "300", "314", "306", "311", "313", "301", "308", "307", "312"], "fit_nums": ["396", "397", "594", "911", "308", "453", "686", "552", "103", "209", "128", "892", "345", "925", "777", "396"], "chain_num": 4}, "CSHL051": {"seeds": ["303", "310", "306", "302", "309", "305", "313", "308", "300", "314", "311", "307", "312", "304", "315", "301"], "fit_nums": ["69", "186", "49", "435", "103", "910", "705", "367", "303", "474", "596", "334", "929", "796", "616", "790"], "chain_num": 4}, "GLM_Sim_14": {"seeds": ["310", "311", "309", "313", "314", "300", "302", "304", "305", "306", "307", "312", "303", "301", "315", "308"], "fit_nums": ["616", "872", "419", "106", "940", "986", "599", "704", "218", "808", "244", "825", "448", "397", "552", "316"], "chain_num": 2}, "GLM_Sim_11_trick": {"seeds": ["411", "400", "408", "409", "415", "413", "410", "412", "406", "414", "403", "404", "401", "405", "407", "402"], "fit_nums": ["95", "508", "886", "384", "822", "969", "525", "382", "489", "436", "344", "537", "251", "223", "458", "401"], "chain_num": 2}, "GLM_Sim_16": {"seeds": ["302", "311", "303", "307", "313", "308", "309", "300", "305", "315", "304", "310", "312", "301", "314", "306"], "fit_nums": ["914", "377", "173", "583", "870", "456", "611", "697", "13", "713", "159", "248", "617", "37", "770", "780"], "chain_num": 2}, "ZM_3003": {"seeds": ["300", "304", "307", "312", "305", "310", "311", "314", "303", "308", "313", "301", "315", "309", "306", "302"], "fit_nums": ["603", "620", "657", "735", "357", "390", "119", "33", "62", "617", "209", "810", "688", "21", "744", "426"], "chain_num": 4}, "CSH_ZAD_022": {"seeds": ["305", "310", "311", "315", "303", "312", "314", "313", "307", "302", "300", "304", "301", "308", "306", "309"], "fit_nums": ["143", "946", "596", "203", "576", "403", "900", "65", "478", "325", "282", "513", "460", "42", "161", "970"], "chain_num": 4}, "GLM_Sim_07": {"seeds": ["300", "309", "302", "304", "305", "312", "301", "311", "315", "314", "308", "307", "303", "310", "306", "313"], "fit_nums": ["724", "701", "118", "230", "648", "426", "689", "114", "832", "731", "592", "519", "559", "938", "672", "144"], "chain_num": 1}, "KS017": {"seeds": ["311", "310", "306", "309", "303", "302", "308", "300", "313", "301", "314", "307", "315", "304", "312", "305"], "fit_nums": ["97", "281", "808", "443", "352", "890", "703", "468", "780", "708", "674", "27", "345", "23", "939", "457"], "chain_num": 4}, "GLM_Sim_11_sub": {"seeds": ["410", "414", "413", "404", "409", "415", "406", "408", "402", "411", "400", "405", "403", "407", "412", "401"], "fit_nums": ["830", "577", "701", "468", "929", "374", "954", "749", "937", "488", "873", "416", "612", "792", "461", "488"], "chain_num": 2}} \ No newline at end of file +{"NYU-45": {"seeds": ["513", "503", "500", "506", "502", "501", "512", "509", "507", "515", "510", "505", "504", "514", "511", "508"], "fit_nums": ["768", "301", "96", "731", "879", "989", "915", "512", "295", "48", "157", "631", "666", "334", "682", "714"], "chain_num": 14}, "UCLA035": {"seeds": ["512", "509", "500", "510", "501", "503", "506", "513", "505", "504", "507", "502", "514", "508", "511", "515"], "fit_nums": ["715", "834", "996", "656", "242", "883", "870", "959", "483", "94", "864", "588", "390", "173", "967", "871"], "chain_num": 14}, "NYU-30": {"seeds": ["505", "508", "515", "507", "504", "513", "512", "503", "509", "500", "510", "514", "501", "502", "511", "506"], "fit_nums": ["885", "637", "318", "98", "209", "171", "472", "823", "956", "89", "762", "260", "76", "319", "139", "785"], "chain_num": 14}, "CSHL047": {"seeds": ["509", "510", "503", "502", "501", "508", "514", "505", "507", "511", "515", "504", "500", "506", "513", "512"], "fit_nums": ["60", "589", "537", "3", "178", "99", "877", "381", "462", "527", "6", "683", "771", "950", "294", "252"], "chain_num": 14}, "NYU-39": {"seeds": ["515", "502", "513", "508", "514", "503", "510", "506", "509", "504", "511", "500", "512", "501", "507", "505"], "fit_nums": ["722", "12", "207", "378", "698", "928", "15", "180", "650", "334", "388", "528", "608", "593", "988", "479"], "chain_num": 14}, "NYU-37": {"seeds": ["508", "509", "506", "512", "507", "503", "515", "504", "500", "505", "513", "501", "510", "511", "502", "514"], "fit_nums": ["94", "97", "793", "876", "483", "878", "886", "222", "66", "59", "601", "994", "526", "694", "304", "615"], "chain_num": 14}, "KS045": {"seeds": ["501", "514", "500", "502", "503", "505", "510", "515", "508", "507", "509", "512", "511", "506", "513", "504"], "fit_nums": ["731", "667", "181", "609", "489", "555", "995", "19", "738", "1", "267", "653", "750", "332", "218", "170"], "chain_num": 14}, "UCLA006": {"seeds": ["509", "505", "513", "515", "511", "500", "503", "507", "508", "514", "504", "510", "502", "501", "506", "512"], "fit_nums": ["807", "849", "196", "850", "293", "874", "216", "542", "400", "632", "781", "219", "331", "730", "740", "32"], "chain_num": 14}, "UCLA033": {"seeds": ["507", "505", "501", "512", "502", "510", "513", "514", "506", "515", "509", "504", "508", "500", "503", "511"], "fit_nums": ["366", "18", "281", "43", "423", "877", "673", "146", "921", "21", "353", "267", "674", "113", "905", "252"], "chain_num": 14}, "NYU-40": {"seeds": ["513", "500", "503", "507", "515", "504", "510", "508", "505", "514", "502", "512", "501", "509", "511", "506"], "fit_nums": ["656", "826", "657", "634", "861", "347", "334", "227", "747", "834", "460", "191", "489", "458", "24", "346"], "chain_num": 14}, "NYU-46": {"seeds": ["507", "509", "512", "508", "503", "500", "515", "501", "511", "506", "502", "505", "510", "513", "514", "504"], "fit_nums": ["503", "523", "5", "819", "190", "917", "707", "609", "145", "416", "376", "603", "655", "271", "223", "149"], "chain_num": 14}, "KS044": {"seeds": ["513", "503", "507", "504", "502", "515", "501", "511", "510", "500", "512", "508", "509", "506", "514", "505"], "fit_nums": ["367", "656", "73", "877", "115", "627", "610", "772", "558", "581", "398", "267", "353", "779", "393", "473"], "chain_num": 14}, "NYU-48": {"seeds": ["502", "507", "503", "515", "513", "505", "512", "501", "510", "504", "508", "511", "506", "514", "500", "509"], "fit_nums": ["854", "52", "218", "963", "249", "901", "248", "322", "566", "768", "256", "101", "303", "485", "577", "141"], "chain_num": 14}, "UCLA012": {"seeds": ["508", "506", "504", "513", "512", "502", "507", "505", "510", "503", "515", "509", "511", "501", "500", "514"], "fit_nums": ["492", "519", "577", "417", "717", "60", "130", "186", "725", "83", "841", "65", "441", "534", "856", "735"], "chain_num": 14}, "KS084": {"seeds": ["515", "507", "512", "503", "506", "508", "510", "502", "509", "505", "500", "511", "504", "501", "513", "514"], "fit_nums": ["816", "374", "140", "955", "399", "417", "733", "149", "300", "642", "644", "248", "324", "830", "889", "286"], "chain_num": 14}, "CSHL052": {"seeds": ["502", "508", "509", "514", "507", "515", "506", "500", "501", "505", "504", "511", "513", "512", "503", "510"], "fit_nums": ["572", "630", "784", "813", "501", "738", "517", "461", "690", "203", "40", "202", "412", "755", "837", "917"], "chain_num": 14}, "NYU-11": {"seeds": ["502", "510", "501", "513", "512", "500", "503", "515", "506", "514", "511", "505", "509", "504", "508", "507"], "fit_nums": ["650", "771", "390", "185", "523", "901", "387", "597", "57", "624", "12", "833", "433", "58", "276", "248"], "chain_num": 14}, "KS051": {"seeds": ["502", "505", "508", "515", "512", "511", "501", "509", "500", "506", "513", "503", "504", "507", "514", "510"], "fit_nums": ["620", "548", "765", "352", "402", "699", "370", "445", "159", "746", "449", "342", "642", "204", "726", "605"], "chain_num": 14}, "NYU-27": {"seeds": ["507", "513", "501", "505", "511", "504", "500", "514", "502", "509", "503", "515", "512", "510", "508", "506"], "fit_nums": ["485", "520", "641", "480", "454", "913", "526", "705", "138", "151", "962", "24", "21", "743", "119", "699"], "chain_num": 14}, "UCLA011": {"seeds": ["513", "503", "508", "501", "510", "505", "506", "511", "507", "514", "515", "500", "509", "502", "512", "504"], "fit_nums": ["957", "295", "743", "795", "643", "629", "142", "174", "164", "21", "835", "338", "368", "341", "209", "68"], "chain_num": 14}, "NYU-47": {"seeds": ["515", "504", "510", "503", "505", "506", "502", "501", "507", "500", "514", "513", "508", "509", "511", "512"], "fit_nums": ["169", "941", "329", "51", "788", "654", "224", "434", "385", "46", "712", "84", "930", "571", "273", "312"], "chain_num": 14}, "CSHL045": {"seeds": ["514", "502", "510", "507", "512", "501", "511", "506", "508", "509", "503", "513", "500", "504", "515", "505"], "fit_nums": ["862", "97", "888", "470", "620", "765", "874", "421", "104", "909", "924", "874", "158", "992", "25", "40"], "chain_num": 14}, "UCLA017": {"seeds": ["510", "502", "500", "515", "503", "507", "514", "501", "511", "505", "513", "504", "508", "512", "506", "509"], "fit_nums": ["875", "684", "841", "510", "209", "207", "806", "700", "989", "899", "812", "971", "526", "887", "160", "249"], "chain_num": 14}, "CSHL055": {"seeds": ["509", "503", "505", "512", "504", "508", "510", "511", "507", "501", "500", "514", "513", "515", "506", "502"], "fit_nums": ["957", "21", "710", "174", "689", "796", "449", "183", "193", "209", "437", "827", "990", "705", "540", "835"], "chain_num": 14}, "UCLA005": {"seeds": ["507", "503", "512", "515", "505", "500", "504", "513", "509", "506", "501", "511", "514", "502", "510", "508"], "fit_nums": ["636", "845", "712", "60", "733", "789", "990", "230", "335", "337", "307", "404", "297", "608", "428", "108"], "chain_num": 14}, "CSHL060": {"seeds": ["510", "515", "513", "503", "514", "501", "504", "502", "511", "500", "508", "509", "505", "507", "506", "512"], "fit_nums": ["626", "953", "497", "886", "585", "293", "580", "867", "113", "734", "88", "55", "949", "443", "210", "555"], "chain_num": 14}, "UCLA015": {"seeds": ["503", "501", "502", "500"], "fit_nums": ["877", "773", "109", "747"], "chain_num": 14}, "KS055": {"seeds": ["510", "502", "511", "501", "509", "506", "500", "515", "503", "512", "504", "505", "513", "508", "514", "507"], "fit_nums": ["526", "102", "189", "216", "673", "477", "293", "981", "960", "883", "899", "95", "31", "244", "385", "631"], "chain_num": 14}, "UCLA014": {"seeds": ["513", "509", "508", "510", "500", "504", "515", "511", "501", "514", "507", "502", "505", "512", "506", "503"], "fit_nums": ["628", "950", "418", "91", "25", "722", "792", "225", "287", "272", "23", "168", "821", "934", "194", "481"], "chain_num": 14}, "CSHL053": {"seeds": ["505", "513", "501", "508", "502", "515", "510", "507", "506", "509", "514", "511", "500", "503", "504", "512"], "fit_nums": ["56", "674", "979", "0", "221", "577", "25", "679", "612", "185", "464", "751", "648", "715", "344", "348"], "chain_num": 14}, "NYU-12": {"seeds": ["508", "506", "509", "515", "511", "510", "501", "504", "513", "503", "507", "512", "500", "514", "505", "502"], "fit_nums": ["853", "819", "692", "668", "730", "213", "846", "596", "644", "829", "976", "895", "974", "824", "179", "769"], "chain_num": 14}, "KS043": {"seeds": ["505", "504", "507", "500", "502", "503", "508", "511", "512", "509", "515", "513", "510", "514", "501", "506"], "fit_nums": ["997", "179", "26", "741", "476", "502", "597", "477", "511", "181", "233", "330", "299", "939", "542", "113"], "chain_num": 14}, "CSHL058": {"seeds": ["513", "505", "514", "506", "504", "500", "511", "503", "508", "509", "501", "510", "502", "515", "507", "512"], "fit_nums": ["752", "532", "949", "442", "400", "315", "106", "419", "903", "198", "553", "158", "674", "249", "723", "941"], "chain_num": 14}, "KS042": {"seeds": ["503", "502", "506", "511", "501", "505", "512", "509", "513", "508", "507", "500", "510", "504", "515", "514"], "fit_nums": ["241", "895", "503", "880", "283", "267", "944", "204", "921", "514", "392", "241", "28", "905", "334", "894"], "chain_num": 14}} \ No newline at end of file diff --git a/combine_canon_infos.py b/combine_canon_infos.py new file mode 100644 index 0000000000000000000000000000000000000000..a31fe3555793d0a19ceaae4c2aa4f01e9e50f841 --- /dev/null +++ b/combine_canon_infos.py @@ -0,0 +1,33 @@ +""" + Script for combining local canonical_infos.json and the one from the cluster +""" +import json + +dist_info = json.load(open("canonical_infos.json", 'r')) +local_info = json.load(open("canonical_infos_local.json", 'r')) + +cluster_subs = ['KS045', 'KS043', 'KS051', 'DY_008', 'KS084', 'KS052', 'KS046', 'KS096', 'KS086', 'UCLA033', 'UCLA005', 'NYU-21', 'KS055', 'KS091'] + +for key in cluster_subs: + print('ignore' in dist_info[key]) + +quit() + +for key in cluster_subs: + if key not in local_info: + print("Adding all of {} to local info".format(key)) + local_info[key] = dist_info[key] + continue + else: + for sub_key in dist_info[key]: + if sub_key not in local_info[key]: + print("Adding {} into local info for {}".format(key)) + local_info[key][sub_key] = dist_info[key][sub_key] + else: + if local_info[key][sub_key] == dist_info[key][sub_key]: + continue + else: + assert len(dist_info[key][sub_key]) == 16 + for x in dist_info[key][sub_key]: + assert x in local_info[key][sub_key] + local_info[key][sub_key] = dist_info[key][sub_key] diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py index e20781d88ad2813cf9ec63bd5ecb617bef87b722..1b8062228711583bf6024e9b68935547ab36d6ce 100644 --- a/dyn_glm_chain_analysis.py +++ b/dyn_glm_chain_analysis.py @@ -19,7 +19,7 @@ from matplotlib.ticker import MaxNLocator import json import time import multiprocessing as mp -from mcmc_chain_analysis import state_size_helper, state_num_helper, gamma_func, alpha_func, ll_func, r_hat_array_comp, rank_inv_normal_transform, eval_r_hat, eval_simple_r_hat +from mcmc_chain_analysis import state_num_helper, ll_func, r_hat_array_comp, rank_inv_normal_transform import pandas as pd from analysis_pmf import pmf_type, type2color @@ -670,10 +670,12 @@ def control_flow(test, indices, trials, func_init, first_for, second_for, end_fi continue first_for(test, results) + counter = -1 for i, m in enumerate([item for sublist in test.results for item in sublist.models]): if i not in indices: continue - second_for(m, j, session_trials, trial_counter, results) + counter += 1 + second_for(m, j, counter, session_trials, trial_counter, results) end_first_for(results, indices, j, trial_counter=trial_counter, session_trials=session_trials) trial_counter += len(only_for_length) @@ -681,181 +683,6 @@ def control_flow(test, indices, trials, func_init, first_for, second_for, end_fi return results -def create_mode_indices(test, subject, fit_type): - dim = 3 - - try: - xy, z = pickle.load(open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'rb')) - except Exception: - print('Doing PCA') - ev, eig, projection_matrix, dimreduc = test.state_pca(subject, pca_type='dists', dim=dim) - xy = np.vstack([dimreduc[i] for i in range(dim)]) - from scipy.stats import gaussian_kde - z = gaussian_kde(xy)(xy) - pickle.dump((xy, z), open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'wb')) - - print("Mode indices of " + subject) - - threshold_search(xy, z, test, 'first_', subject, fit_type) - - print("Find another mode?") - if input() not in ['yes', 'y']: - return - - threshold_search(xy, z, test, 'second_', subject, fit_type) - - print("Find another mode?") - if input() not in ['yes', 'y']: - return - - threshold_search(xy, z, test, 'third_', subject, fit_type) - return - -def threshold_search(xy, z, test, mode_prefix, subject, fit_type): - happy = False - conds = [0, None, None, None, None] - x_min, x_max, y_min, y_max = None, None, None, None - while not happy: - print() - print("Pick level") - prob_level = input() - if prob_level == 'cond': - print("x > ?") - resp = input() - if resp not in ['n', 'no']: - x_min = float(resp) - - print("x < ?") - resp = input() - if resp not in ['n', 'no']: - x_max = float(resp) - - print("y > ?") - resp = input() - if resp not in ['n', 'no']: - y_min = float(resp) - - print("y < ?") - resp = input() - if resp not in ['n', 'no']: - y_max = float(resp) - - print("Prob level") - prob_level = float(input()) - conds = [prob_level, x_min, x_max, y_min, y_max] - print("Condtions are {}".format(conds)) - else: - prob_level = float(prob_level) - conds[0] = prob_level - - print("Level is {}".format(prob_level)) - - mode = conditions_fulfilled(z, xy, conds) - print("# of samples: {}".format(mode.sum())) - mode_indices = np.where(mode)[0] - if mode.sum() > 0: - print(xy[0][mode_indices].min(), xy[0][mode_indices].max(), xy[1][mode_indices].min(), xy[1][mode_indices].max()) - print("Happy?") - happy = 'yes' == input() - - print("Subset by factor?") - if input() == 'yes': - print("Factor?") - print(mode_indices.shape) - factor = int(input()) - mode_indices = mode_indices[::factor] - print(mode_indices.shape) - loading_info[subject]['mode prob level'] = prob_level - - pickle.dump(mode_indices, open("multi_chain_saves/{}mode_indices_{}_{}.p".format(mode_prefix, subject, fit_type), 'wb')) - consistencies = test.consistency_rsa(indices=mode_indices) - pickle.dump(consistencies, open("multi_chain_saves/{}mode_consistencies_{}_{}.p".format(mode_prefix, subject, fit_type), 'wb')) - - -def conditions_fulfilled(z, xy, conds): - works = z > conds[0] - if conds[1]: - works = np.logical_and(works, xy[0] > conds[1]) - if conds[2]: - works = np.logical_and(works, xy[0] < conds[2]) - if conds[3]: - works = np.logical_and(works, xy[1] > conds[3]) - if conds[4]: - works = np.logical_and(works, xy[1] < conds[4]) - - return works - - -def state_set_and_plot(test, mode_prefix, subject, fit_type): - mode_indices = pickle.load(open("multi_chain_saves/{}mode_indices_{}_{}.p".format(mode_prefix, subject, fit_type), 'rb')) - consistencies = pickle.load(open("multi_chain_saves/{}mode_consistencies_{}_{}.p".format(mode_prefix, subject, fit_type), 'rb')) - session_bounds = list(np.cumsum([len(s) for s in test.results[0].models[-1].stateseqs])) - - import scipy.cluster.hierarchy as hc - consistencies /= consistencies[0, 0] - linkage = hc.linkage(consistencies[0, 0] - consistencies[np.triu_indices(consistencies.shape[0], k=1)], method='complete') - - # R = hc.dendrogram(linkage, truncate_mode='lastp', p=150, no_labels=True) - plt.savefig("peter figures/{}tree_{}_{}".format(mode_prefix, subject, 'complete')) - plt.close() - - session_bounds = list(np.cumsum([len(s) for s in test.results[0].models[-1].stateseqs])) - - fig, ax = plt.subplots(ncols=5, sharey=True, gridspec_kw={'width_ratios': [10, 1, 1, 1, 1]}, figsize=(13, 8)) - from matplotlib.pyplot import cm - for j, criterion in enumerate([0.95, 0.8, 0.5, 0.2]): - clustering_colors = np.zeros((consistencies.shape[0], 100, 4)) - a = hc.fcluster(linkage, criterion, criterion='distance') - b, c = np.unique(a, return_counts=1) - print(b.shape) - print(np.sort(c)) - - if criterion in [0.95]: - state_sets = [] - for x, y in zip(b, c): - state_sets.append(np.where(a == x)[0]) - print("dumping state set") - pickle.dump(state_sets, open("multi_chain_saves/{}state_sets_{}_{}.p".format(mode_prefix, subject, fit_type), 'wb')) - state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='_{}{}'.format(mode_prefix, criterion), show=True, separate_pmf=True, type_coloring=True) - - # I think this is about finding out where states start and how long they last - # state_id, session_appears = np.where(states) - # for s in np.unique(state_id): - # state_appear_mode.append(session_appears[state_id == s][0] / (test.results[0].n_sessions - 1)) - - # cmap = cm.rainbow(np.linspace(0, 1, 17))#len([x for x, y in zip(b, c) if y > 50]))) - # rank_to_color_place = dict(zip(range(17), [0, 16, 8, 4, 12, 2, 6, 10, 14, 1, 3, 5, 7, 9, 11, 13, 15])) # handcrafted to maximise color distance, I think - i = -1 - b = [x for _, x in sorted(zip(c, b))][::-1] - c = [x for x, _ in sorted(zip(c, b))][::-1] - for x, y in zip(b, c): - if y > 50: - i += 1 - # clustering_colors[a == x] = cmap[rank_to_color_place[i]] - clustering_colors[a == x] = cm.rainbow(np.mean(np.where(a == x)[0]) / a.shape[0]) - - ax[j+1].imshow(clustering_colors, aspect='auto', origin='upper') - for sb in session_bounds: - ax[j+1].axhline(sb, color='k') - ax[j+1].set_xticks([]) - ax[j+1].set_yticks([]) - ax[j+1].set_title("{}%".format(int(criterion * 100)), size=20) - - ax[0].imshow(consistencies, aspect='auto', origin='upper') - for sb in session_bounds: - ax[0].axhline(sb, color='k') - ax[0].set_xticks([]) - ax[0].set_yticks(session_bounds[::-1]) - ax[0].set_yticklabels(session_bounds[::-1], size=18) - ax[0].set_ylim(session_bounds[-1], 0) - ax[0].set_ylabel("Trials", size=28) - plt.yticks(rotation=45) - - plt.tight_layout() - plt.savefig("peter figures/{}clustered_trials_{}_{}".format(mode_prefix, subject, 'criteria comp').replace('.', '_')) - plt.close() - - def state_pmfs(test, trials, indices): def func_init(): return {'pmfs': [], 'session_js': [], 'pmf_weights': []} @@ -863,7 +690,7 @@ def state_pmfs(test, trials, indices): results['pmf'] = np.zeros(test.results[0].n_contrasts) results['pmf_weight'] = np.zeros(4) - def second_for(m, j, session_trials, trial_counter, results): + def second_for(m, j, counter, session_trials, trial_counter, results): states, counts = np.unique(m.stateseqs[j][session_trials - trial_counter], return_counts=True) for sub_state, c in zip(states, counts): results['pmf'] += weights_to_pmf(m.obs_distns[sub_state].weights[j]) * c / session_trials.shape[0] @@ -884,7 +711,7 @@ def state_weights(test, trials, indices): def first_for(test, results): results['weights'] = np.zeros(test.results[0].models[0].obs_distns[0].weights.shape[1]) - def second_for(m, j, session_trials, trial_counter, results): + def second_for(m, j, counter, session_trials, trial_counter, results): states, counts = np.unique(m.stateseqs[j][session_trials - trial_counter], return_counts=True) for sub_state, c in zip(states, counts): results['weights'] += m.obs_distns[sub_state].weights[j] * c / session_trials.shape[0] @@ -907,7 +734,7 @@ def lapse_sides(test, state_sets, indices): def first_for(test, results): results['pmf'] = np.zeros(test.results[0].n_contrasts) - def second_for(m, j, session_trials, trial_counter, results): + def second_for(m, j, counter, session_trials, trial_counter, results): states, counts = np.unique(m.stateseqs[j][session_trials - trial_counter], return_counts=True) for sub_state, c in zip(states, counts): results['pmf'] += weights_to_pmf(m.obs_distns[sub_state].weights[j]) * c / session_trials.shape[0] @@ -1102,6 +929,8 @@ def state_development(test, state_sets, indices, save=True, save_append='', show if test.results[0].name.startswith('GLM_Sim_'): print("./glm sim mice/truth_{}.p".format(test.results[0].name)) truth = pickle.load(open("./glm sim mice/truth_{}.p".format(test.results[0].name), "rb")) + # truth['state_posterior'] = truth['state_posterior'][:, [0, 1, 3, 4, 5, 6, 7]] # 17, mode 2 + # truth['weights'] = [w for i, w in enumerate(truth['weights']) if i != 2] # 17, mode 2 states_by_session = np.zeros((len(state_sets), test.results[0].n_sessions)) trial_counter = 0 @@ -1172,7 +1001,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show counter += 1 pmfs_to_score.append(np.mean(pmfs)) # test.state_mapping = dict(zip(range(len(state_sets)), np.argsort(np.argsort(pmfs_to_score)))) # double argsort for ranks - test.state_mapping = dict(zip(range(len(state_sets)), np.flip(np.argsort((states_by_session != 0).argmax(axis=1))))) + test.state_mapping = dict(zip(np.flip(np.argsort((states_by_session != 0).argmax(axis=1))), range(len(state_sets)))) for state, trials in enumerate(state_sets): cmap = matplotlib.cm.get_cmap(cmaps[state]) if state < len(cmaps) else matplotlib.cm.get_cmap('Greys') @@ -1218,7 +1047,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show test.state_mapping[state] + states_by_session[state] - 0.5, color=state_color) else: - n_points = 150 + n_points = 400 points = np.linspace(1, test.results[0].n_sessions, n_points) interpolation = np.interp(points, np.arange(1, 1 + test.results[0].n_sessions), states_by_session[state]) @@ -1229,7 +1058,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show ax1.annotate(len(state_sets) - test.state_mapping[state], (test.results[0].n_sessions + 0.1, test.state_mapping[state] - 0.15), fontsize=22, annotation_clip=False) if test.results[0].name.startswith('GLM_Sim_'): - ax1.plot(range(1, 1 + test.results[0].n_sessions), truth['state_map'][test.state_mapping[state]] + truth['state_posterior'][:, state] - 0.5, color='r') + ax1.plot(range(1, 1 + test.results[0].n_sessions), state + truth['state_posterior'][:, state] - 0.5, color='r') alpha_level = 0.3 ax2.axvline(0.5, c='grey', alpha=alpha_level, zorder=4) @@ -1259,7 +1088,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show if test.results[0].name.startswith('GLM_Sim_'): sim_pmf = weights_to_pmf(truth['weights'][state]) - ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), sim_pmf[defined_points] - 0.5 + truth['state_map'][test.state_mapping[state]], color='r') + ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), sim_pmf[defined_points] - 0.5 + state, color='r') if not test.state_mapping[state] in dont_plot: ax2.annotate("Type {}".format(1 + pmf_type(pmf[defined_points])), (1.05, test.state_mapping[state] - 0.37), rotation=90, size=13, color=type2color[pmf_type(pmf)], annotation_clip=False) @@ -1354,15 +1183,71 @@ def state_development(test, state_sets, indices, save=True, save_append='', show plt.tight_layout() if save: - # print("saving with {} dpi".format(dpi)) + print("saving with {} dpi".format(dpi)) plt.savefig("dynamic_GLM_figures/meta_state_development_{}_{}{}.png".format(test.results[0].name, separate_pmf, save_append), dpi=dpi) if show: plt.show() else: plt.close() + if subject.startswith("GLM_Sim_"): + plt.figure(figsize=(16, 9)) + for state, trials in enumerate(state_sets): + dur_params = dur_hists(test, trials, indices) + + plt.subplot(4, 4, 1 + 2 * test.state_mapping[state]) + plt.hist(dur_params[:, 0]) + plt.axvline(truth['durs'][state][0], color='red') + + plt.subplot(4, 4, 2 + 2 * test.state_mapping[state]) + plt.hist(dur_params[:, 1]) + plt.axvline(truth['durs'][state][1], color='red') + + plt.tight_layout() + plt.savefig("dur hists") + plt.show() + + from scipy.stats import nbinom + points = np.arange(900) + plt.figure(figsize=(16, 9)) + for state, trials in enumerate(state_sets): + dur_params = dur_hists(test, trials, indices) + + plt.subplot(2, 4, 1 + test.state_mapping[state]) + plt.plot(nbinom.pmf(points, np.mean(dur_params[:, 0]), np.mean(dur_params[:, 1]))) + plt.plot(nbinom.pmf(points, truth['durs'][state][0], truth['durs'][state][1]), color='red') + plt.xlabel("# of trials") + plt.ylabel("P") + + plt.tight_layout() + plt.savefig("dur dists") + plt.show() + return states_by_session, all_pmfs, all_pmf_weights, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage, states_per_type +def dur_hists(test, trials, indices): + def func_init(): return {'dur_params': np.zeros((len(indices), 2))} + + def first_for(test, results): + pass + + def second_for(m, j, counter, session_trials, trial_counter, results): + states, counts = np.unique(m.stateseqs[j][session_trials - trial_counter], return_counts=True) + for sub_state, c in zip(states, counts): + try: + p = m.dur_distns[sub_state].p_save # this cannot be accessed after deleting data, but not every model's data is deleted + except: + p = m.dur_distns[sub_state].p + results['dur_params'][counter] += np.array([m.dur_distns[sub_state].r * c / session_trials.shape[0], p * c / session_trials.shape[0]]) + + def end_first_for(results, indices, j, **kwargs): + pass + + results = control_flow(test, indices, trials, func_init, first_for, second_for, end_first_for) + results['dur_params'] = results['dur_params'] / len(indices) + return results['dur_params'] + + def compare_pmfs(test, states2compare, states_by_session, all_pmfs, title=""): """ Take a set of states, and plot out their PMFs on all sessions on which they occur. @@ -1645,14 +1530,19 @@ if __name__ == "__main__": fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0] if fit_type == 'bias': loading_info = json.load(open("canonical_infos_bias.json", 'r')) - r_hats = json.load(open("canonical_info_r_hats_bias.json", 'r')) elif fit_type == 'prebias': loading_info = json.load(open("canonical_infos.json", 'r')) - r_hats = json.load(open("canonical_info_r_hats.json", 'r')) no_good_pcas = ['NYU-06', 'SWC_023'] subjects = list(loading_info.keys()) - # subjects = ['SWC_021', 'ibl_witten_15', 'ibl_witten_13', 'KS003', 'ibl_witten_19', 'SWC_022', 'CSH_ZAD_017'] - subjects = ['KS021'] + + new_subs = ['KS044', 'NYU-11', 'DY_016', 'SWC_061', 'ZFM-05245', 'CSH_ZAD_029', 'SWC_021', 'CSHL058', 'DY_014', 'DY_009', 'KS094', 'DY_018', 'KS043', 'UCLA014', 'SWC_038', 'SWC_022', 'UCLA012', 'UCLA011', 'CSHL055', 'ZFM-04019', 'NYU-45', 'ZFM-02370', 'ZFM-02373', 'ZFM-02369', 'NYU-40', 'CSHL060', 'NYU-30', 'CSH_ZAD_019', 'UCLA017', 'KS052', 'ibl_witten_25', 'ZFM-02368', 'CSHL045', 'UCLA005', 'SWC_058', 'CSH_ZAD_024', 'SWC_042', 'DY_008', 'ibl_witten_13', 'SWC_043', 'KS046', 'DY_010', 'CSHL053', 'ZM_1898', 'UCLA033', 'NYU-47', 'DY_011', 'CSHL047', 'SWC_054', 'ibl_witten_19', 'ibl_witten_27', 'KS091', 'KS055', 'CSH_ZAD_017', 'UCLA035', 'SWC_060', 'DY_020', 'ZFM-01577', 'ZM_2240', 'ibl_witten_29', 'KS096', 'SWC_066', 'DY_013', 'ZFM-01592', 'GLM_Sim_17', 'NYU-48', 'UCLA006', 'NYU-39', 'KS051', 'NYU-27', 'NYU-46', 'ZFM-01936', 'ZFM-02372', 'ZFM-01935', 'ibl_witten_26', 'ZFM-05236', 'ZM_2241', 'NYU-37', 'KS086', 'KS084', 'ZFM-01576', 'KS042'] + + miss_count = 0 + for s in new_subs: + if s not in subjects: + print(s) + miss_count += 1 + quit() print(subjects) fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0] @@ -1660,8 +1550,6 @@ if __name__ == "__main__": # fig, ax = plt.subplots(1, 3, sharey=True, figsize=(16, 9)) - thinning = 25 - summary_info = {"thinning": thinning, "contains": [], "seeds": [], "fit_nums": []} pop_state_starts = np.zeros(20) state_appear_dist = np.zeros(10) state_appear_mode = [] @@ -1703,227 +1591,173 @@ if __name__ == "__main__": for subject in subjects: - if subject.startswith('GLM_Sim_') or subject == 'ibl_witten_18': + # NYU-11 is quite weird, super errrativ behaviour, has all contrasts introduced at once, no good session at end + if subject.startswith('GLM_Sim_'): continue print() print(subject) - try: - test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb')) + test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb')) - # mode_specifier = '' + mode_specifier = 'first' + try: + mode_indices = pickle.load(open("multi_chain_saves/{}_mode_indices_{}_{}.p".format(mode_specifier, subject, fit_type), 'rb')) + state_sets = pickle.load(open("multi_chain_saves/{}_state_sets_{}_{}.p".format(mode_specifier, subject, fit_type), 'rb')) + except: try: - mode_indices = pickle.load(open("multi_chain_saves/first_mode_indices_{}_{}.p".format(subject, fit_type), 'rb')) - state_sets = pickle.load(open("multi_chain_saves/first_state_sets_{}_{}.p".format(subject, fit_type), 'rb')) + mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb')) + state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb')) except: - try: - mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb')) - state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb')) - except: - continue + continue - # lapse differential - # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices) - - # training overview - # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 0', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(8)), plot_until=-1) - # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 1', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(7)), plot_until=2) - # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 2', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(6)), plot_until=7) - # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 3', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(4)), plot_until=13) - states, pmfs, pmf_weights, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True) - # state_nums_5.append((states > 0.05).sum(0)) - # state_nums_10.append((states > 0.1).sum(0)) - # continue - a = compare_params(test, 26, [s for s in state_sets if len(s) > 40], mode_indices, [3, 5]) - compare_pmfs(test, [3, 5], states, pmfs, title="{} convergence pmf".format(subject)) - quit() - new = type_2_appearance(states, pmfs) - - if new == 2: - print('____________________________') - print(subject) - print('____________________________') - if new == 1: - new_counter += 1 - if new == 0: - transform_counter += 1 - print(new_counter, transform_counter) - - consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) - consistencies /= consistencies[0, 0] - temp = contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies, CMF=False) - continue + # lapse differential + # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices) - state_dict = write_results(test, [s for s in state_sets if len(s) > 40], mode_indices) - pickle.dump(state_dict, open("state_dict_{}".format(subject), 'wb')) - quit() - # abs_state_durs.append(durs) - # continue - # all_pmf_weights += pmf_weights - # all_state_types.append(state_types) - # - # state_types_interpolation[0] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[0]) - # state_types_interpolation[1] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[1]) - # state_types_interpolation[2] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[2]) - # - # all_first_pmfs_typeless[subject] = [] - # for pmf in pmfs: - # all_first_pmfs_typeless[subject].append((pmf[0], pmf[1][0])) - # all_pmfs.append(pmf) - # - # first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs) - # for pmf in changing_pmfs: - # if type(pmf[0]) == int: - # continue - # all_changing_pmf_names.append(subject) - # all_changing_pmfs.append(pmf) - # - # all_first_pmfs[subject] = first_pmfs - # continue - # - # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) - # consistencies /= consistencies[0, 0] - # contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=False) - - - # b_flips = bias_flips(states, pmfs, durs) - # all_bias_flips.append(b_flips) - - # regression, diffs = pmf_regressions(states, pmfs, durs) - # regression_diffs += diffs - # regressions.append(regression) - # continue - # compare_pmfs(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, pmfs, title="{} convergence pmf".format(subject)) - # compare_weights(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, title="{} convergence weights".format(subject)) - # quit() - # all_intros.append(undiv_intros) - # all_intros_div.append(intros_by_type) - # if states_per_type != []: - # all_states_per_type.append(states_per_type) - # - # intros_by_type_sum += intros_by_type - # for pmf in pmfs: - # all_pmfs.append(pmf) - # for p in pmf[1]: - # all_pmf_diffs.append(p[-1] - p[0]) - # all_pmf_asymms.append(np.abs(p[0] + p[-1] - 1)) - # contrast_intro_types.append(contrast_intro_type) - - # num_states.append(np.mean(test.state_num_dist())) - # num_sessions.append(test.results[0].n_sessions) - - # a, b = np.where(states) - # for i in range(states.shape[0]): - # state_appear.append(b[a == i][0] / (test.results[0].n_sessions - 1)) - # state_dur.append(b[a == i].shape[0]) - - # state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False) - - # session overview - # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) - # consistencies /= consistencies[0, 0] - # c_n_a, rt_data = contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=True) - # compare_performance(cnas, (1, 1), title="{} contrast {} performance test".format(subject, (1, 1))) - # compare_performance(cnas, (0, 0.848)) - - # duration of different state types (and also percentage of type activities) - # continue - # simplex_durs = np.array(durs).reshape(1, 3) - # print(simplex_durs / np.sum(simplex_durs)) - # from simplex_plot import projectSimplex - # print(projectSimplex(simplex_durs / simplex_durs.sum(1)[:, None])) - - # compute state type proportions and split the pmfs accordingly - # ret, trans, pmf_types = state_cluster_interpolation(states, pmfs) - # state_trans += trans # matrix of how often the different types transition into one another - # for i, r in enumerate(ret): - # state_trajs[i] += np.interp(np.linspace(1, test.results[0].n_sessions, n_points), np.arange(1, 1 + test.results[0].n_sessions), r) # for interpolated population state trajectories - # plot_pmf_types(pmf_types, subject=subject, fit_type=fit_type) - # continue - # plt.plot(ret.T, label=[0, 1, 2]) - # plt.legend() - # plt.show() - - # Analyse single sample - # xy, z = pickle.load(open("multi_chain_saves/xyz_{}.p".format(subject), 'rb')) - # - # single_sample = [np.argmax(z)] - # for position, index in enumerate(np.where(np.logical_and(z > 2.7e-7, xy[0] < -500))[0]): - # print(position, index, index // test.n, index % test.n) - # states, pmfs = state_development_single_sample(test, [index], save_append='_{}_{}'.format('single_sample', position), show=True, separate_pmf=True) - # if position == 9: - # quit() - # quit() - - # all_state_starts = test.state_appearance_posterior(subject) - # pop_state_starts += all_state_starts - # - # a, b, c = test.state_start_and_dur() - # state_appear += a - # state_dur += b - # state_appear_dist += c - - - # all_state_starts = test.state_appearance_posterior(subject) - # plt.plot(all_state_starts) - # plt.savefig("temp") - # plt.close() - - print('Computing sub result') - create_mode_indices(test, subject, fit_type) - state_set_and_plot(test, 'first_', subject, fit_type) - print("second mode?") - if input() in ['y', 'yes']: - state_set_and_plot(test, 'second_', subject, fit_type) - - except FileNotFoundError as e: - print(e) - print('no canoncial result') - continue - print(r_hats[subject]) - if r_hats[subject] >= 1.05: - print("Skipping") - continue - else: - print("Making canonical result") - results = [] - for j, (seed, fit_num) in enumerate(zip(loading_info[subject]['seeds'], loading_info[subject]['fit_nums'])): - if j in loading_info[subject]['ignore']: - continue - summary_info["contains"].append(j) - summary_info["seeds"].append(seed) - summary_info["fit_nums"].append(fit_num) - info_dict = pickle.load(open("./session_data/{}_info_dict.p".format(subject), "rb")) - - samples = [] - mini_counter = 1 # start from one here, first 4000 as burn-in - while True: - try: - file = "./dynamic_GLMiHMM_crossvals/{}_fittype_{}_var_{}_{}_{}{}.p".format(subject, fit_type, fit_variance, seed, fit_num, '_{}'.format(mini_counter)) - lala = time.time() - samples += pickle.load(open(file, "rb")) - print("Loading {} took {:.4}".format(mini_counter, time.time() - lala)) - mini_counter += 1 - except Exception: - break - save_id = "{}_fittype_{}_var_{}_{}_{}.p".format(subject, fit_type, fit_variance, seed, fit_num).replace('.', '_') - - print("Loaded") - - result = MCMC_result(samples[::50], infos=info_dict, - data=samples[0].datas, sessions=fit_type, fit_variance=fit_variance, - dur=dur, save_id=save_id) - results.append(result) - test = MCMC_result_list(results, summary_info) - pickle.dump(test, open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'wb')) - - test.r_hat_and_ess(state_num_helper(0.05), False) - test.r_hat_and_ess(state_num_helper(0.02), False) - test.r_hat_and_ess(state_size_helper(), False) - test.r_hat_and_ess(state_size_helper(1), False) - test.r_hat_and_ess(gamma_func, True) - test.r_hat_and_ess(alpha_func, True) + # training overview + # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 0', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(8)), plot_until=-1) + # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 1', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(7)), plot_until=2) + # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 2', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(6)), plot_until=7) + # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 3', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(4)), plot_until=13) + states, pmfs, pmf_weights, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True) + quit() + # state_nums_5.append((states > 0.05).sum(0)) + # state_nums_10.append((states > 0.1).sum(0)) + # continue + # a = compare_params(test, 26, [s for s in state_sets if len(s) > 40], mode_indices, [3, 5]) + # compare_pmfs(test, [3, 2, 4], states, pmfs, title="{} convergence pmf".format(subject)) + consistencies = pickle.load(open("multi_chain_saves/first_mode_consistencies_{}_{}.p".format(subject, fit_type), 'rb')) + consistencies /= consistencies[0, 0] + temp = contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies, CMF=False) + continue + quit() + new = type_2_appearance(states, pmfs) + + if new == 2: + print('____________________________') + print(subject) + print('____________________________') + if new == 1: + new_counter += 1 + if new == 0: + transform_counter += 1 + print(new_counter, transform_counter) + + + state_dict = write_results(test, [s for s in state_sets if len(s) > 40], mode_indices) + pickle.dump(state_dict, open("state_dict_{}".format(subject), 'wb')) + quit() + # abs_state_durs.append(durs) + # continue + # all_pmf_weights += pmf_weights + # all_state_types.append(state_types) + # + # state_types_interpolation[0] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[0]) + # state_types_interpolation[1] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[1]) + # state_types_interpolation[2] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[2]) + # + # all_first_pmfs_typeless[subject] = [] + # for pmf in pmfs: + # all_first_pmfs_typeless[subject].append((pmf[0], pmf[1][0])) + # all_pmfs.append(pmf) + # + # first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs) + # for pmf in changing_pmfs: + # if type(pmf[0]) == int: + # continue + # all_changing_pmf_names.append(subject) + # all_changing_pmfs.append(pmf) + # + # all_first_pmfs[subject] = first_pmfs + # continue + # + # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) + # consistencies /= consistencies[0, 0] + # contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=False) + + + # b_flips = bias_flips(states, pmfs, durs) + # all_bias_flips.append(b_flips) + + # regression, diffs = pmf_regressions(states, pmfs, durs) + # regression_diffs += diffs + # regressions.append(regression) + # continue + # compare_pmfs(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, pmfs, title="{} convergence pmf".format(subject)) + # compare_weights(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, title="{} convergence weights".format(subject)) + # quit() + # all_intros.append(undiv_intros) + # all_intros_div.append(intros_by_type) + # if states_per_type != []: + # all_states_per_type.append(states_per_type) + # + # intros_by_type_sum += intros_by_type + # for pmf in pmfs: + # all_pmfs.append(pmf) + # for p in pmf[1]: + # all_pmf_diffs.append(p[-1] - p[0]) + # all_pmf_asymms.append(np.abs(p[0] + p[-1] - 1)) + # contrast_intro_types.append(contrast_intro_type) + + # num_states.append(np.mean(test.state_num_dist())) + # num_sessions.append(test.results[0].n_sessions) + + # a, b = np.where(states) + # for i in range(states.shape[0]): + # state_appear.append(b[a == i][0] / (test.results[0].n_sessions - 1)) + # state_dur.append(b[a == i].shape[0]) + + # state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False) + + # session overview + # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb')) + # consistencies /= consistencies[0, 0] + # c_n_a, rt_data = contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=True) + # compare_performance(cnas, (1, 1), title="{} contrast {} performance test".format(subject, (1, 1))) + # compare_performance(cnas, (0, 0.848)) + + # duration of different state types (and also percentage of type activities) + # continue + # simplex_durs = np.array(durs).reshape(1, 3) + # print(simplex_durs / np.sum(simplex_durs)) + # from simplex_plot import projectSimplex + # print(projectSimplex(simplex_durs / simplex_durs.sum(1)[:, None])) + + # compute state type proportions and split the pmfs accordingly + # ret, trans, pmf_types = state_cluster_interpolation(states, pmfs) + # state_trans += trans # matrix of how often the different types transition into one another + # for i, r in enumerate(ret): + # state_trajs[i] += np.interp(np.linspace(1, test.results[0].n_sessions, n_points), np.arange(1, 1 + test.results[0].n_sessions), r) # for interpolated population state trajectories + # plot_pmf_types(pmf_types, subject=subject, fit_type=fit_type) + # continue + # plt.plot(ret.T, label=[0, 1, 2]) + # plt.legend() + # plt.show() + + # Analyse single sample + # xy, z = pickle.load(open("multi_chain_saves/xyz_{}.p".format(subject), 'rb')) + # + # single_sample = [np.argmax(z)] + # for position, index in enumerate(np.where(np.logical_and(z > 2.7e-7, xy[0] < -500))[0]): + # print(position, index, index // test.n, index % test.n) + # states, pmfs = state_development_single_sample(test, [index], save_append='_{}_{}'.format('single_sample', position), show=True, separate_pmf=True) + # if position == 9: + # quit() + # quit() + + # all_state_starts = test.state_appearance_posterior(subject) + # pop_state_starts += all_state_starts + # + # a, b, c = test.state_start_and_dur() + # state_appear += a + # state_dur += b + # state_appear_dist += c + + + # all_state_starts = test.state_appearance_posterior(subject) + # plt.plot(all_state_starts) + # plt.savefig("temp") + # plt.close() # pickle.dump(all_first_pmfs, open("all_first_pmfs.p", 'wb')) diff --git a/dynamic_GLMiHMM_consistency.py b/dynamic_GLMiHMM_consistency.py index 8bf1924c1a08c92b50d8db433853faa7376093cb..a2aa017456ddf48031bc8af3978a6746102a464d 100644 --- a/dynamic_GLMiHMM_consistency.py +++ b/dynamic_GLMiHMM_consistency.py @@ -87,6 +87,7 @@ for subject in subjects: mega_data[:, 4] = 1 mega_data[:, 5] = data[~bad_trials, 1] - 1 mega_data[:, 5] = (mega_data[:, 5] + 1) / 2 + print(mega_data.sum(0)) posteriormodel.add_data(mega_data) import pyhsmm.util.profiling as prof diff --git a/dynamic_GLMiHMM_fit.py b/dynamic_GLMiHMM_fit.py index bf1d046da4a345335014298a708c541c5c6d4ec3..a16e7376dd52cb9a02256f1782c0ea1ffbd923fb 100644 --- a/dynamic_GLMiHMM_fit.py +++ b/dynamic_GLMiHMM_fit.py @@ -19,15 +19,6 @@ import json import sys -def crp_expec(n, theta): - """ - Return expected number of tables after n customers, given concentration theta. - - From Wikipedia - """ - return theta * (digamma(theta + n) - digamma(theta)) - - def eleven2nine(x): """Map from 11 possible contrasts to 9, for the non-training phases. @@ -87,13 +78,7 @@ subjects = ['ibl_witten_15', 'ibl_witten_17', 'ibl_witten_18', 'ibl_witten_19', # test subjects: subjects = ['KS014'] -# subjects = ['KS021', 'KS016', 'ibl_witten_16', 'SWC_022', 'KS003', 'CSHL054', 'ZM_3003', 'KS015', 'ibl_witten_13', 'CSHL059', 'CSH_ZAD_022', 'CSHL_007', 'CSHL062', 'NYU-06', 'KS014', 'ibl_witten_14', 'SWC_023'] - -# subjects = [['GLM_Sim_15', 'GLM_Sim_14', 'GLM_Sim_13', 'GLM_Sim_11', 'GLM_Sim_10', 'GLM_Sim_09', 'GLM_Sim_12'][2]] -# (0.03, 0.3, 5, 'contR', 'contL', 'prevA', 'bias', 1, 0.1): cv_nums = [15] -# conda activate hdp_pg_env -# python dynamic_GLMiHMM_fit.py cv_nums = [200 + int(sys.argv[1]) % 16] subjects = [subjects[int(sys.argv[1]) // 16]] @@ -267,12 +252,6 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)): print("meh, skipped session") continue - # if j == 15: - # import matplotlib.pyplot as plt - # for i in [0, 2, 3,4,5,6,7,8,10]: - # plt.plot(i, data[data[:, 0] == i, 1].mean(), 'ko') - # plt.show() - if params['obs_dur'] == 'glm': for i in range(data.shape[0]): data[i, 0] = num_to_contrast[data[i, 0]] @@ -291,10 +270,7 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)): elif reg == 'cont': mega_data[:, i] = data[mask, 0] elif reg == 'prevA': - # prev_ans = data[:, 1].copy() new_prev_ans = data[:, 1].copy() - # prev_ans[1:] = prev_ans[:-1] - # prev_ans -= 1 new_prev_ans -= 1 new_prev_ans = np.convolve(np.append(0, new_prev_ans), params['exp_filter'])[:-(params['exp_filter'].shape[0])] mega_data[:, i] = new_prev_ans[mask] @@ -334,9 +310,6 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)): posteriormodel.add_data(mega_data) - # for d in posteriormodel.datas: - # print(d.shape) - # if not os.path.isfile('./{}/data_save_{}.p'.format(data_folder, params['subject'])): pickle.dump(data_save, open('./{}/data_save_{}.p'.format(data_folder, params['subject']), 'wb')) quit() diff --git a/index_mice.py b/index_mice.py index 125425e2f6da0099d9282c43d3bc242ffed6ac75..96d35de5b11c000cdb713ffea432e4c8b74473e5 100644 --- a/index_mice.py +++ b/index_mice.py @@ -15,6 +15,8 @@ for filename in os.listdir("./dynamic_GLMiHMM_crossvals/"): regexp = re.compile(r'((\w|-)+)_fittype_(\w+)_var_0.03_(\d+)_(\d+)_(\d+)') result = regexp.search(filename) subject = result.group(1) + if subject == 'ibl_witten_26': + print('here') fit_type = result.group(3) seed = result.group(4) fit_num = result.group(5) @@ -49,8 +51,8 @@ big = [] non_big = [] sim_subjects = [] for s in prebias_subinfo.keys(): - assert len(prebias_subinfo[s]["seeds"]) in [16, 32] - assert len(prebias_subinfo[s]["fit_nums"]) in [16, 32] + # assert len(prebias_subinfo[s]["seeds"]) in [16, 32], s + " " + str(len(prebias_subinfo[s]["seeds"])) + # assert len(prebias_subinfo[s]["fit_nums"]) in [16, 32] print(s, len(prebias_subinfo[s]["fit_nums"])) if len(prebias_subinfo[s]["fit_nums"]) == 32: big.append(s) @@ -58,8 +60,6 @@ for s in prebias_subinfo.keys(): non_big.append(s) if s.startswith("GLM_Sim_"): sim_subjects.append(s) - else: - non_big.append(s) print(non_big) print() diff --git a/multi_chain_saves/.gitignore b/multi_chain_saves/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b84a8f82fc35d730fb072a10736910aaca487db3 --- /dev/null +++ b/multi_chain_saves/.gitignore @@ -0,0 +1,23 @@ + +old_ana_code/ +figures/ +dynamic_figures/ +iHMM_fits/ +dynamic_iHMM_fits/ +overview_figures/ +beliefs/ +session_data/ +WAIC/ +*.png +*.p +*.npz +*.pdf +*.csv +*.zip +*.json +peter_fiugres/ +consistency_data/ +dynamic_GLM_figures/ +dynamic_GLMiHMM_fits2/ +glm sim mice/ +dynamic_GLMiHMM_crossvals/ diff --git a/multi_chain_saves/work.py b/multi_chain_saves/work.py deleted file mode 100644 index 25f82f9db00bc52d70145533e1375f2c19b21859..0000000000000000000000000000000000000000 --- a/multi_chain_saves/work.py +++ /dev/null @@ -1,15 +0,0 @@ -import os - -i = 0 -for filename in os.listdir("./"): - if not filename.endswith('.p'): - continue - if 'bias' in filename: - continue - - if not filename.endswith('bias.p'): - i += 1 - print("Rename {} into {}".format(filename, filename[:-2] + '_prebias.p')) - os.rename(filename, filename[:-2] + '_prebias.p') - -print(i) diff --git a/raw_fit_processing.py b/raw_fit_processing_part1.py similarity index 60% rename from raw_fit_processing.py rename to raw_fit_processing_part1.py index 7a6c08b1abe03d177ccaf1970a233df70e69584b..82807f9ccd0a22da1f37d7f1bdd76070fa536d4f 100644 --- a/raw_fit_processing.py +++ b/raw_fit_processing_part1.py @@ -5,7 +5,7 @@ Script for taking a list of subects and extracting statistics from the chains which can be used to assess which chains have converged to the same regions - + This cannot be run in parallel (because the loading_info dict gets changed and dumped) """ import numpy as np import pyhsmm @@ -13,8 +13,9 @@ import pickle import json from dyn_glm_chain_analysis import MCMC_result import time -from mcmc_chain_analysis import state_size_helper, state_num_helper, find_good_chains_unsplit_greedy +from mcmc_chain_analysis import state_size_helper, state_num_helper, find_good_chains_unsplit_greedy, gamma_func, alpha_func import index_mice # executes function for creating dict of available fits +from dyn_glm_chain_analysis import MCMC_result_list fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0] @@ -26,9 +27,13 @@ elif fit_type == 'prebias': loading_info = json.load(open("canonical_infos.json", 'r')) r_hats = json.load(open("canonical_info_r_hats.json", 'r')) -subjects = ['ZFM-04019', 'ZFM-05236'] # list(loading_info.keys()) +# done: 'NYU-45', 'UCLA035', 'NYU-30', 'CSHL047', 'NYU-39', 'NYU-37', +# can't: 'UCLA006' +subjects = ['NYU-40', 'NYU-46', 'KS044', 'NYU-48'] +# 'UCLA012', 'CSHL052', 'NYU-11', 'UCLA011', 'NYU-47', 'CSHL045', 'UCLA017', 'CSHL055', 'UCLA005', 'CSHL060', 'UCLA015', 'UCLA014', 'CSHL053', 'NYU-12', 'CSHL058', 'KS042'] + -thinning = 25 +thinning = 50 fit_variance = 0.03 func1 = state_num_helper(0.2) @@ -37,13 +42,15 @@ func3 = state_size_helper() func4 = state_size_helper(1) dur = 'yes' -m = len(loading_info[subjects]["fit_nums"]) -assert m == 16 for subject in subjects: + results = [] + summary_info = {"thinning": thinning, "contains": [], "seeds": [], "fit_nums": []} + m = len(loading_info[subject]["fit_nums"]) + assert m == 16 print(subject) n_runs = -1 counter = -1 - n = (loading_info[subject]['chain_num'] + 1) * 4000 // thinning + n = (loading_info[subject]['chain_num']) * 4000 // thinning chains1 = np.zeros((m, n)) chains2 = np.zeros((m, n)) chains3 = np.zeros((m, n)) @@ -53,11 +60,27 @@ for subject in subjects: print(seed) info_dict = pickle.load(open("./session_data/{}_info_dict.p".format(subject), "rb")) samples = [] - for mini_counter in range(m): - file = "./dynamic_GLMiHMM_crossvals/{}_fittype_{}_var_{}_{}_{}{}.p".format(subject, fit_type, fit_variance, seed, fit_num, '_{}'.format(mini_counter)) - lala = time.time() - samples += pickle.load(open(file, "rb")) - print("Loading {} took {:.4}".format(mini_counter, time.time() - lala)) + + mini_counter = 1 # start at 1, discard first 4000 as burnin + while True: + try: + file = "./dynamic_GLMiHMM_crossvals/{}_fittype_{}_var_{}_{}_{}{}.p".format(subject, fit_type, fit_variance, seed, fit_num, '_{}'.format(mini_counter)) + lala = time.time() + samples += pickle.load(open(file, "rb")) + print("Loading {} took {:.4}".format(mini_counter, time.time() - lala)) + mini_counter += 1 + except Exception: + break + + if n_runs == -1: + n_runs = mini_counter + else: + if n_runs != mini_counter: + print("Problem") + print(n_runs, mini_counter) + quit() + + save_id = "{}_fittype_{}_var_{}_{}_{}.p".format(subject, fit_type, fit_variance, seed, fit_num).replace('.', '_') print("loaded seed {}".format(seed)) @@ -65,7 +88,8 @@ for subject in subjects: result = MCMC_result(samples[::thinning], infos=info_dict, data=samples[0].datas, sessions=fit_type, fit_variance=fit_variance, - dur=dur, save_id=save_id)) + dur=dur, save_id=save_id) + results.append(result) print("Making result {} took {:.4}".format(counter, time.time() - lala)) res = func1(result) @@ -82,8 +106,6 @@ for subject in subjects: pickle.dump(chains3, open("multi_chain_saves/{}_largest_state_0_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb')) pickle.dump(chains4, open("multi_chain_saves/{}_largest_state_1_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb')) - r_hats = {} - # R^hat tests # test = MCMC_result_list([fake_result(100) for i in range(8)]) # test.r_hat_and_ess(return_ascending, False) @@ -94,10 +116,6 @@ for subject in subjects: continue print() print("Checking R^hat, finding best subset of chains") - chains1 = chains1[:, 160:] - chains2 = chains2[:, 160:] - chains3 = chains3[:, 160:] - chains4 = chains4[:, 160:] # mins = find_good_chains(chains[:, :-1].reshape(32, chains.shape[1] // 2)) sol, final_r_hat = find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_to=chains1.shape[0] // 2) @@ -110,4 +128,27 @@ for subject in subjects: json.dump(r_hats, open("canonical_info_r_hats_bias.json", 'w')) elif fit_type == 'prebias': json.dump(loading_info, open("canonical_infos.json", 'w')) - json.dump(r_hats, open("canonical_info_r_hats.json", 'w')) \ No newline at end of file + json.dump(r_hats, open("canonical_info_r_hats.json", 'w')) + + if r_hats[subject] >= 1.05: + print("Skipping canonical result") + continue + else: + print("Making canonical result") + + # subset data + summary_info['contains'] = [i for i in range(m) if i not in sol] + summary_info['seeds'] = [loading_info[subject]['seeds'][i] for i in summary_info['contains']] + summary_info['fit_nums'] = [loading_info[subject]['fit_nums'] for i in summary_info['contains']] + results = [results[i] for i in summary_info['contains']] + + test = MCMC_result_list(results, summary_info) + pickle.dump(test, open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'wb')) + + test.r_hat_and_ess(state_num_helper(0.2), False) + test.r_hat_and_ess(state_num_helper(0.1), False) + test.r_hat_and_ess(state_num_helper(0.05), False) + test.r_hat_and_ess(state_size_helper(), False) + test.r_hat_and_ess(state_size_helper(1), False) + test.r_hat_and_ess(gamma_func, True) + test.r_hat_and_ess(alpha_func, True) diff --git a/raw_fit_processing_part2.py b/raw_fit_processing_part2.py new file mode 100644 index 0000000000000000000000000000000000000000..4a0835c4357c6dd189b4ae53cf06119cba4e9fc3 --- /dev/null +++ b/raw_fit_processing_part2.py @@ -0,0 +1,238 @@ +import json +import numpy as np +from dyn_glm_chain_analysis import MCMC_result_list +import pickle +import matplotlib.pyplot as plt +import os + + +def create_mode_indices(test, subject, fit_type): + dim = 3 + + try: + xy, z = pickle.load(open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'rb')) + except Exception: + return + print('Doing PCA') + ev, eig, projection_matrix, dimreduc = test.state_pca(subject, pca_type='dists', dim=dim) + xy = np.vstack([dimreduc[i] for i in range(dim)]) + from scipy.stats import gaussian_kde + z = gaussian_kde(xy)(xy) + pickle.dump((xy, z), open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'wb')) + + print("Mode indices of " + subject) + + threshold_search(xy, z, test, 'first_', subject, fit_type) + + print("Find another mode?") + if input() not in ['yes', 'y']: + return + + threshold_search(xy, z, test, 'second_', subject, fit_type) + + print("Find another mode?") + if input() not in ['yes', 'y']: + return + + threshold_search(xy, z, test, 'third_', subject, fit_type) + return + + +def threshold_search(xy, z, test, mode_prefix, subject, fit_type): + happy = False + conds = [0, None, None, None, None, None, None] + x_min, x_max, y_min, y_max, z_min, z_max = None, None, None, None, None, None + while not happy: + print() + print("Pick level") + prob_level = input() + if prob_level == 'cond': + print("x > ?") + resp = input() + if resp not in ['n', 'no']: + x_min = float(resp) + else: + x_min = None + + print("x < ?") + resp = input() + if resp not in ['n', 'no']: + x_max = float(resp) + else: + x_max = None + + print("y > ?") + resp = input() + if resp not in ['n', 'no']: + y_min = float(resp) + else: + y_min = None + + print("y < ?") + resp = input() + if resp not in ['n', 'no']: + y_max = float(resp) + else: + y_max = None + + print("z > ?") + resp = input() + if resp not in ['n', 'no']: + z_min = float(resp) + else: + z_min = None + + print("z < ?") + resp = input() + if resp not in ['n', 'no']: + z_max = float(resp) + else: + z_max = None + + print("Prob level") + prob_level = float(input()) + conds = [prob_level, x_min, x_max, y_min, y_max, z_min, z_max] + print("Condtions are {}".format(conds)) + else: + try: + prob_level = float(prob_level) + except: + print('mistake') + prob_level = float(input) + conds[0] = prob_level + + print("Level is {}".format(prob_level)) + + mode = conditions_fulfilled(z, xy, conds) + print("# of samples: {}".format(mode.sum())) + mode_indices = np.where(mode)[0] + if mode.sum() > 0: + print(xy[0][mode_indices].min(), xy[0][mode_indices].max(), xy[1][mode_indices].min(), xy[1][mode_indices].max(), xy[2][mode_indices].min(), xy[2][mode_indices].max()) + print("Happy?") + happy = 'yes' == input() + + print("Subset by factor?") + if input() == 'yes': + print("Factor?") + print(mode_indices.shape) + factor = int(input()) + mode_indices = mode_indices[::factor] + print(mode_indices.shape) + if subject not in loading_info: + loading_info[subject] = {} + loading_info[subject]['mode prob level'] = prob_level + + pickle.dump(mode_indices, open("multi_chain_saves/{}mode_indices_{}_{}.p".format(mode_prefix, subject, fit_type), 'wb')) + # consistencies = test.consistency_rsa(indices=mode_indices) # do this on the cluster from now on + # pickle.dump(consistencies, open("multi_chain_saves/{}mode_consistencies_{}_{}.p".format(mode_prefix, subject, fit_type), 'wb', protocol=4)) + + +def conditions_fulfilled(z, xy, conds): + works = z > conds[0] + if conds[1]: + works = np.logical_and(works, xy[0] > conds[1]) + if conds[2]: + works = np.logical_and(works, xy[0] < conds[2]) + if conds[3]: + works = np.logical_and(works, xy[1] > conds[3]) + if conds[4]: + works = np.logical_and(works, xy[1] < conds[4]) + if conds[5]: + works = np.logical_and(works, xy[2] > conds[5]) + if conds[6]: + works = np.logical_and(works, xy[2] < conds[6]) + + return works + + +def state_set_and_plot(test, mode_prefix, subject, fit_type): + mode_indices = pickle.load(open("multi_chain_saves/{}mode_indices_{}_{}.p".format(mode_prefix, subject, fit_type), 'rb')) + consistencies = pickle.load(open("multi_chain_saves/{}mode_consistencies_{}_{}.p".format(mode_prefix, subject, fit_type), 'rb')) + session_bounds = list(np.cumsum([len(s) for s in test.results[0].models[-1].stateseqs])) + + import scipy.cluster.hierarchy as hc + consistencies /= consistencies[0, 0] + linkage = hc.linkage(consistencies[0, 0] - consistencies[np.triu_indices(consistencies.shape[0], k=1)], method='complete') + + # R = hc.dendrogram(linkage, truncate_mode='lastp', p=150, no_labels=True) + # plt.savefig("peter figures/{}tree_{}_{}".format(mode_prefix, subject, 'complete')) + # plt.close() + + session_bounds = list(np.cumsum([len(s) for s in test.results[0].models[-1].stateseqs])) + + plot_criterion = 0.95 + a = hc.fcluster(linkage, plot_criterion, criterion='distance') + b, c = np.unique(a, return_counts=1) + state_sets = [] + for x, y in zip(b, c): + state_sets.append(np.where(a == x)[0]) + print("dumping state set") + pickle.dump(state_sets, open("multi_chain_saves/{}state_sets_{}_{}.p".format(mode_prefix, subject, fit_type), 'wb')) + state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='_{}{}'.format(mode_prefix, plot_criterion), show=True, separate_pmf=True, type_coloring=True) + + fig, ax = plt.subplots(ncols=5, sharey=True, gridspec_kw={'width_ratios': [10, 1, 1, 1, 1]}, figsize=(13, 8)) + from matplotlib.pyplot import cm + for j, criterion in enumerate([0.95, 0.8, 0.5, 0.2]): + clustering_colors = np.zeros((consistencies.shape[0], 100, 4)) + a = hc.fcluster(linkage, criterion, criterion='distance') + b, c = np.unique(a, return_counts=1) + print(b.shape) + print(np.sort(c)) + + cmap = cm.rainbow(np.linspace(0, 1, 17)) + rank_to_color_place = dict(zip(range(17), [0, 16, 8, 4, 12, 2, 6, 10, 14, 1, 3, 5, 7, 9, 11, 13, 15])) # handcrafted to maximise color distance, I think + i = -1 + b = [x for _, x in sorted(zip(c, b))][::-1] + c = [x for x, _ in sorted(zip(c, b))][::-1] + plot_above = 50 + while len([y for y in c if y > plot_above]) > 17: + plot_above += 1 + for x, y in zip(b, c): + if y > plot_above: + i += 1 + clustering_colors[a == x] = cmap[rank_to_color_place[i]] + # clustering_colors[a == x] = cm.rainbow(np.mean(np.where(a == x)[0]) / a.shape[0]) + + ax[j+1].imshow(clustering_colors, aspect='auto', origin='upper') + for sb in session_bounds: + ax[j+1].axhline(sb, color='k') + ax[j+1].set_xticks([]) + ax[j+1].set_yticks([]) + ax[j+1].set_title("{}%".format(int(criterion * 100)), size=20) + + ax[0].imshow(consistencies, aspect='auto', origin='upper') + for sb in session_bounds: + ax[0].axhline(sb, color='k') + ax[0].set_xticks([]) + ax[0].set_yticks(session_bounds[::-1]) + ax[0].set_yticklabels(session_bounds[::-1], size=18) + ax[0].set_ylim(session_bounds[-1], 0) + ax[0].set_ylabel("Trials", size=28) + plt.yticks(rotation=45) + + plt.tight_layout() + plt.savefig("peter figures/{}clustered_trials_{}_{}".format(mode_prefix, subject, 'criteria comp').replace('.', '_')) + plt.close() + + +fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0] +if fit_type == 'bias': + loading_info = json.load(open("canonical_infos_bias.json", 'r')) +elif fit_type == 'prebias': + loading_info = json.load(open("canonical_infos.json", 'r')) +subjects = list(loading_info.keys()) +# error: KS043, KS045, 'NYU-12', ibl_witten_15, NYU-21, CSHL052, KS003 +# done: NYU-46, NYU-39, ibl_witten_19, NYU-48 +subjects = ['DY_013', 'ZFM-01592', 'NYU-39', 'NYU-27', 'NYU-46', 'ZFM-01936', 'ZFM-02372', 'ZFM-01935', 'ibl_witten_26', 'ZM_2241', 'KS084', 'ZFM-01576'] + +for subject in subjects: + test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb')) + if os.path.isfile("multi_chain_saves/{}mode_indices_{}_{}.p".format('first_', subject, fit_type)): + print("It has been done") + continue + print('Computing sub result') + create_mode_indices(test, subject, fit_type) + # state_set_and_plot(test, 'first_', subject, fit_type) + # print("second mode?") + # if input() in ['y', 'yes']: + # state_set_and_plot(test, 'second_', subject, fit_type) diff --git a/remove_duplicates.py b/remove_duplicates.py new file mode 100644 index 0000000000000000000000000000000000000000..97f6683b06e9db6c9f9bd9e88da97bb77f1e6b7d --- /dev/null +++ b/remove_duplicates.py @@ -0,0 +1,86 @@ +import os +import re + +memory = {} +name_saves = {} +seed_saves = {} + +do_it = True + +for filename in os.listdir("./dynamic_GLMiHMM_crossvals/"): + if not filename.endswith('.p'): + continue + regexp = re.compile(r'((\w|-)+)_fittype_(\w+)_var_0.03_(\d+)_(\d+)_(\d+)') + result = regexp.search(filename) + subject = result.group(1) + fit_type = result.group(3) + seed = result.group(4) + fit_num = result.group(5) + chain_num = result.group(6) + + if fit_type == 'prebias': + + if subject not in seed_saves: + seed_saves[subject] = [seed] + else: + if seed not in seed_saves[subject]: + seed_saves[subject].append(seed) + + if (subject, seed) not in name_saves: + name_saves[(subject, seed)] = [] + + if fit_num not in name_saves[(subject, seed)]: + name_saves[(subject, seed)].append(fit_num) + + if (subject, seed, fit_num) not in memory: + memory[(subject, seed, fit_num)] = {"chain_num": int(chain_num), "counter": 1} + else: # if this is the first file of that chain, save some info + memory[(subject, seed, fit_num)]["chain_num"] = max(memory[(subject, seed, fit_num)]["chain_num"], int(chain_num)) + memory[(subject, seed, fit_num)]["counter"] += 1 + +total_move = 0 +nyu_11_move = 0 +dicts_removed = 0 +moved = [] +completed = [] +incompleted = [] +for key in name_saves: + subject = key[0] + seed = key[1] + complete = False + save_fit_num = -1 + for fit_num in name_saves[key]: + if memory[(subject, seed, fit_num)]['chain_num'] == 14 and memory[(subject, seed, fit_num)]['counter'] == 15: + save_fit_num = fit_num + complete = True + if len(seed_saves[subject]) == 16: + if subject not in completed: + completed.append(subject) + else: + if subject not in incompleted: + incompleted.append(subject) + if complete and len(name_saves[key]) > 1: + assert save_fit_num != -1 + for fit_num in name_saves[key]: + if fit_num != save_fit_num: + for i in range(15): + if do_it: + if os.path.exists("./dynamic_GLMiHMM_crossvals/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i)): + os.rename("./dynamic_GLMiHMM_crossvals/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i), + "./del_test/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i)) + if subject not in moved: + moved.append(subject) + total_move += 1 + else: + if os.path.exists("./dynamic_GLMiHMM_crossvals/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i)): + print("I would move ") + print("./dynamic_GLMiHMM_crossvals/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i)) + print(" to ") + print("./del_test/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i)) + total_move += 1 + nyu_11_move += subject == "NYU-11" + +print(moved) +print(completed) +print(incompleted) +print("Would move {} in total, and {} of NYU-11".format(total_move, nyu_11_move)) diff --git a/test_codes/bias_shift_test.py b/test_codes/bias_shift_test.py new file mode 100644 index 0000000000000000000000000000000000000000..597c7071b9a5eea45abeb6e7832acea3fae9879a --- /dev/null +++ b/test_codes/bias_shift_test.py @@ -0,0 +1,38 @@ +""" + Studying how much easier it is for small changes in the regressors to affect the pmf around 0 versus further from zero + +""" +import numpy as np +import matplotlib.pyplot as plt + + +weights = list(np.zeros((17, 3))) + +for i, weight in enumerate(weights): + weight[-1] = i * 0.2 - 1.6 + +contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0]) +contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1] +bias = np.ones(11) + +predictors = np.vstack((contrasts_L, contrasts_R, bias)).T + +for weight in weights: + plt.plot(1 / (1 + np.exp(- np.sum(weight * predictors, axis=1)))) + +plt.ylim(0, 1) +plt.show() + + + +weights = list(np.zeros((17, 3))) + +for i, weight in enumerate(weights): + weight[-1] = i * 0.2 - 1.6 + weight[0] = 2 + +for weight in weights: + plt.plot(1 / (1 + np.exp(- np.sum(weight * predictors, axis=1)))) + +plt.ylim(0, 1) +plt.show() \ No newline at end of file