diff --git a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc
index 4a89c93f30137d03c0b66cc48856b5c8f6b5a111..797f6ea4ca71054cc38cb50347dc59cdea7fef0b 100644
Binary files a/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc and b/__pycache__/dyn_glm_chain_analysis.cpython-37.pyc differ
diff --git a/__pycache__/index_mice.cpython-37.pyc b/__pycache__/index_mice.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc0d0f83258d0e8e4559bfb0f57d89d9f097425c
Binary files /dev/null and b/__pycache__/index_mice.cpython-37.pyc differ
diff --git a/__pycache__/mcmc_chain_analysis.cpython-37.pyc b/__pycache__/mcmc_chain_analysis.cpython-37.pyc
index dfcc898f4df0897c89643b35287844e38e54193b..852f09637cc815a928e81c03093be6c9e08945c3 100644
Binary files a/__pycache__/mcmc_chain_analysis.cpython-37.pyc and b/__pycache__/mcmc_chain_analysis.cpython-37.pyc differ
diff --git a/behavioral_state_data.py b/behavioral_state_data.py
index be9aef832e5d8d480af615b126d5528cad2f8229..11a8cc41aef7d5e3b48b1fd791604b18415cff46 100644
--- a/behavioral_state_data.py
+++ b/behavioral_state_data.py
@@ -73,7 +73,7 @@ to_introduce = [2, 3, 4, 5]
 #             "ibl_witten_06", "ibl_witten_07", "ibl_witten_12", "ibl_witten_13", "ibl_witten_14", "ibl_witten_15",
 #             "ibl_witten_16", "KS003", "KS005", "KS019", "NYU-01", "NYU-02", "NYU-04", "NYU-06", "ZM_1367", "ZM_1369",
 #             "ZM_1371", "ZM_1372", "ZM_1743", "ZM_1745", "ZM_1746"]  # zoe's subjects
-subjects = ['ibl_witten_14']
+subjects = ['ZFM-05236']
 
 data_folder = 'session_data'
 # why does CSHL058 not work?
@@ -149,7 +149,7 @@ for subject in subjects:
     contrast_set = {0, 1, 9, 10}
 
     rel_count = -1
-    for i, (eid, extra_eids) in enumerate(zip(fixed_eids, additional_eids)):
+    for i, (eid, extra_eids, date) in enumerate(zip(fixed_eids, additional_eids, fixed_dates)):
 
         try:
             trials = one.load_object(eid, 'trials')
@@ -182,6 +182,8 @@ for subject in subjects:
             df = pd.concat([df, df2], ignore_index=1)
             print('new size: {}'.format(len(df)))
 
+        pickle.dump(df, open("./sofiya_data/{}_df_{}_{}.p".format(subject, rel_count, date), "wb"))
+
         current_contrasts = set(df['signed_contrast'])
         diff = current_contrasts.difference(contrast_set)
         for c in to_introduce:
diff --git a/behavioral_state_data_easier.py b/behavioral_state_data_easier.py
index e526d7cc485df51c38469af52b0e669aee793b5c..18a9de78e8aea1c241ff6a474e44e58a204d84c7 100644
--- a/behavioral_state_data_easier.py
+++ b/behavioral_state_data_easier.py
@@ -34,14 +34,12 @@ misses = []
 to_introduce = [2, 3, 4, 5]
 
 amiss = ['UCLA034', 'UCLA036', 'UCLA037', 'PL015', 'PL016', 'PL017', 'PL024', 'NR_0017', 'NR_0019', 'NR_0020', 'NR_0021', 'NR_0027']
-subjects = ['NYU-21']
+subjects = ['ZFM-04019', 'ZFM-05236']
 fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
 if fit_type == 'bias':
     loading_info = json.load(open("canonical_infos_bias.json", 'r'))
-    r_hats = json.load(open("canonical_info_r_hats_bias.json", 'r'))
 elif fit_type == 'prebias':
     loading_info = json.load(open("canonical_infos.json", 'r'))
-    r_hats = json.load(open("canonical_info_r_hats.json", 'r'))
 already_fit = list(loading_info.keys())
 
 remaining_subs = [s for s in subjects if s not in amiss and s not in already_fit]
@@ -70,13 +68,13 @@ for subject in subjects:
     print('_____________________')
     print(subject)
 
-    if subject in already_fit or subject in amiss:
-        continue
+    # if subject in already_fit or subject in amiss:
+    #     continue
 
     trials = one.load_aggregate('subjects', subject, '_ibl_subjectTrials.table')
 
     # Load training status and join to trials table
     training = one.load_aggregate('subjects', subject, '_ibl_subjectTraining.table')
     trials = (trials
               .set_index('session')
               .join(training.set_index('session'))
diff --git a/canonical_infos.json b/canonical_infos.json
index 12a3fa3aae799f418267abc5e98f113db7d5d248..99ae626eb5a4f9444331e2cd47bdbe9b3c2131ce 100644
--- a/canonical_infos.json
+++ b/canonical_infos.json
@@ -1 +1 @@
-{"SWC_022": {"seeds": ["412", "401", "403", "413", "406", "407", "415", "409", "405", "400", "408", "404", "410", "411", "414", "402"], "fit_nums": ["347", "54", "122", "132", "520", "386", "312", "59", "999", "849", "372", "300", "485", "593", "358", "550"], "chain_num": 19}, "SWC_023": {"seeds": ["302", "312", "304", "300", "315", "311", "308", "305", "303", "309", "306", "313", "307", "314", "301", "310"], "fit_nums": ["994", "913", "681", "816", "972", "790", "142", "230", "696", "537", "975", "773", "918", "677", "742", "745"], "chain_num": 4}, "ZFM-05236": {"seeds": ["404", "409", "401", "412", "406", "411", "410", "402", "408", "405", "415", "403", "414", "400", "413", "407"], "fit_nums": ["106", "111", "333", "253", "395", "76", "186", "192", "221", "957", "989", "612", "632", "304", "50", "493"], "chain_num": 14, "ignore": [3, 0, 8, 2, 6, 15, 14, 10]}, "ibl_witten_15": {"seeds": ["408", "412", "400", "411", "410", "407", "403", "406", "413", "405", "404", "402", "401", "415", "409", "414"], "fit_nums": ["40", "241", "435", "863", "941", "530", "382", "750", "532", "731", "146", "500", "967", "334", "375", "670"], "chain_num": 19}, "ibl_witten_13": {"seeds": ["401", "414", "409", "413", "415", "411", "410", "408", "402", "405", "406", "407", "412", "403", "400", "404"], "fit_nums": ["702", "831", "47", "740", "251", "929", "579", "351", "515", "261", "222", "852", "754", "892", "473", "29"], "chain_num": 19}, "KS016": {"seeds": ["315", "301", "309", "313", "302", "307", "303", "308", "311", "312", "314", "306", "310", "300", "305", "304"], "fit_nums": ["99", "57", "585", "32", "501", "558", "243", "413", "59", "757", "463", "172", "524", "957", "909", "292"], "chain_num": 4}, "ibl_witten_19": {"seeds": ["412", "415", "413", "408", "409", "404", "403", "401", "405", "411", "410", "406", "402", "414", "407", "400"], "fit_nums": ["234", "41", "503", "972", "935", "808", "912", "32", "331", "755", "117", "833", "822", "704", "901", "207"], "chain_num": 19}, 
"CSH_ZAD_017": {"seeds": ["408", "404", "413", "406", "414", "411", "400", "401", "415", "407", "402", "412", "403", "409", "405", "410"], "fit_nums": ["928", "568", "623", "841", "92", "251", "829", "922", "964", "257", "150", "970", "375", "113", "423", "564"], "chain_num": 19}, "KS022": {"seeds": ["315", "300", "314", "301", "303", "302", "306", "308", "305", "310", "313", "312", "304", "307", "311", "309"], "fit_nums": ["899", "681", "37", "957", "629", "637", "375", "980", "810", "51", "759", "664", "420", "127", "259", "555"], "chain_num": 4}, "CSH_ZAD_025": {"seeds": ["303", "311", "307", "312", "313", "314", "308", "315", "305", "306", "304", "302", "309", "310", "301", "300"], "fit_nums": ["581", "148", "252", "236", "581", "838", "206", "756", "449", "288", "756", "593", "733", "633", "418", "563"], "chain_num": 4}, "ibl_witten_17": {"seeds": ["406", "415", "408", "413", "402", "405", "409", "400", "414", "401", "412", "407", "404", "410", "403", "411"], "fit_nums": ["827", "797", "496", "6", "444", "823", "384", "873", "634", "27", "811", "142", "207", "322", "756", "275"], "chain_num": 9}, "SWC_021": {"seeds": ["404", "413", "406", "412", "403", "401", "410", "409", "400", "414", "415", "402", "405", "408", "411", "407"], "fit_nums": ["840", "978", "224", "38", "335", "500", "83", "509", "441", "9", "135", "890", "358", "460", "844", "30"], "chain_num": 19}, "ibl_witten_18": {"seeds": ["311", "310", "303", "314", "302", "309", "305", "307", "312", "300", "308", "306", "315", "313", "304", "301"], "fit_nums": ["236", "26", "838", "762", "826", "409", "496", "944", "280", "704", "930", "419", "637", "896", "876", "297"], "chain_num": 4}, "CSHL_018": {"seeds": ["302", "310", "306", "300", "314", "307", "309", "313", "311", "308", "304", "301", "312", "303", "305", "315"], "fit_nums": ["843", "817", "920", "900", "226", "36", "472", "676", "933", "453", "116", "263", "269", "897", "568", "438"], "chain_num": 4}, "ZFM-05245": {"seeds": ["400", "413", "414", 
"409", "411", "403", "405", "412", "406", "410", "407", "415", "401", "404", "402", "408"], "fit_nums": ["512", "765", "704", "17", "539", "449", "584", "987", "138", "932", "869", "313", "253", "540", "37", "634"], "chain_num": 11}, "GLM_Sim_06": {"seeds": ["313", "309", "302", "303", "305", "314", "300", "315", "311", "306", "304", "310", "301", "312", "308", "307"], "fit_nums": ["9", "786", "286", "280", "72", "587", "619", "708", "360", "619", "311", "189", "60", "708", "939", "733"], "chain_num": 2}, "ZM_1897": {"seeds": ["304", "308", "305", "311", "315", "314", "307", "306", "300", "303", "313", "310", "301", "312", "302", "309"], "fit_nums": ["549", "96", "368", "509", "424", "897", "287", "426", "968", "93", "725", "513", "837", "581", "989", "374"], "chain_num": 4}, "CSHL_020": {"seeds": ["305", "309", "313", "302", "314", "310", "300", "307", "315", "306", "312", "304", "311", "301", "303", "308"], "fit_nums": ["222", "306", "243", "229", "584", "471", "894", "238", "986", "660", "494", "657", "896", "459", "100", "283"], "chain_num": 4}, "CSHL054": {"seeds": ["401", "415", "409", "410", "414", "413", "407", "405", "406", "408", "411", "400", "412", "402", "403", "404"], "fit_nums": ["901", "734", "609", "459", "574", "793", "978", "66", "954", "906", "954", "111", "292", "850", "266", "967"], "chain_num": 9}, "CSHL_014": {"seeds": ["305", "311", "309", "300", "313", "310", "307", "306", "304", "312", "308", "302", "314", "303", "301", "315"], "fit_nums": ["371", "550", "166", "24", "705", "385", "870", "884", "831", "546", "404", "722", "287", "564", "613", "783"], "chain_num": 4}, "ZFM-04019": {"seeds": ["413", "404", "408", "403", "415", "406", "414", "410", "402", "405", "411", "400", "401", "412", "409", "407"], "fit_nums": ["493", "302", "590", "232", "121", "938", "270", "999", "95", "175", "576", "795", "728", "244", "32", "177"], "chain_num": 14, "ignore": [11, 12, 6, 3, 5, 2, 15, 1]}, "CSHL062": {"seeds": ["307", "313", "310", "303", "306", 
"312", "308", "305", "311", "314", "304", "302", "300", "301", "315", "309"], "fit_nums": ["846", "371", "94", "888", "499", "229", "546", "432", "71", "989", "986", "91", "935", "314", "975", "481"], "chain_num": 4}, "CSH_ZAD_001": {"seeds": ["313", "309", "311", "312", "305", "310", "315", "300", "314", "304", "301", "302", "308", "303", "306", "307"], "fit_nums": ["468", "343", "314", "544", "38", "120", "916", "170", "305", "569", "502", "496", "452", "336", "559", "572"], "chain_num": 4}, "KS003": {"seeds": ["405", "401", "414", "415", "410", "404", "409", "413", "412", "408", "411", "407", "402", "406", "403", "400"], "fit_nums": ["858", "464", "710", "285", "665", "857", "990", "438", "233", "177", "43", "509", "780", "254", "523", "695"], "chain_num": 19}, "NYU-06": {"seeds": ["314", "309", "306", "305", "312", "303", "307", "304", "300", "302", "310", "301", "315", "308", "313", "311"], "fit_nums": ["950", "862", "782", "718", "427", "645", "827", "612", "821", "834", "595", "929", "679", "668", "648", "869"], "chain_num": 4}, "KS019": {"seeds": ["404", "401", "411", "408", "400", "403", "410", "413", "402", "407", "415", "409", "406", "414", "412", "405"], "fit_nums": ["682", "4", "264", "200", "250", "267", "737", "703", "132", "855", "922", "686", "85", "176", "54", "366"], "chain_num": 9}, "CSHL049": {"seeds": ["411", "402", "414", "408", "409", "410", "413", "407", "406", "401", "404", "405", "403", "415", "400", "412"], "fit_nums": ["104", "553", "360", "824", "749", "519", "347", "228", "863", "671", "140", "883", "701", "445", "627", "898"], "chain_num": 9}, "ibl_witten_14": {"seeds": ["310", "311", "304", "306", "300", "302", "314", "313", "303", "308", "301", "309", "305", "315", "312", "307"], "fit_nums": ["563", "120", "85", "712", "277", "871", "183", "661", "505", "598", "210", "89", "310", "638", "564", "998"], "chain_num": 4}, "KS014": {"seeds": ["301", "310", "302", "312", "313", "308", "307", "303", "305", "300", "314", "306", "311", 
"309", "304", "315"], "fit_nums": ["668", "32", "801", "193", "269", "296", "74", "24", "270", "916", "21", "250", "342", "451", "517", "293"], "chain_num": 4}, "CSHL059": {"seeds": ["306", "309", "300", "304", "314", "303", "315", "311", "313", "305", "301", "307", "302", "312", "310", "308"], "fit_nums": ["821", "963", "481", "999", "986", "45", "551", "605", "701", "201", "629", "261", "972", "407", "165", "9"], "chain_num": 4}, "GLM_Sim_13": {"seeds": ["310", "303", "308", "306", "300", "312", "301", "313", "305", "311", "315", "304", "314", "309", "307", "302"], "fit_nums": ["982", "103", "742", "524", "614", "370", "926", "456", "133", "143", "302", "80", "395", "549", "579", "944"], "chain_num": 2}, "CSHL_007": {"seeds": ["314", "303", "308", "313", "301", "300", "302", "305", "315", "306", "310", "309", "311", "304", "307", "312"], "fit_nums": ["462", "703", "345", "286", "480", "313", "986", "165", "201", "102", "322", "894", "960", "438", "330", "169"], "chain_num": 4}, "CSH_ZAD_011": {"seeds": ["314", "311", "303", "300", "305", "310", "306", "301", "302", "315", "304", "309", "308", "312", "313", "307"], "fit_nums": ["320", "385", "984", "897", "315", "120", "320", "945", "475", "403", "210", "412", "695", "564", "664", "411"], "chain_num": 4}, "KS021": {"seeds": ["309", "312", "304", "310", "303", "311", "314", "302", "305", "301", "306", "300", "308", "315", "313", "307"], "fit_nums": ["874", "943", "925", "587", "55", "136", "549", "528", "349", "211", "401", "84", "225", "545", "153", "382"], "chain_num": 4}, "GLM_Sim_15": {"seeds": ["303", "312", "305", "308", "309", "302", "301", "310", "313", "315", "311", "314", "307", "306", "304", "300"], "fit_nums": ["769", "930", "328", "847", "899", "714", "144", "518", "521", "873", "914", "359", "242", "343", "45", "364"], "chain_num": 2}, "CSHL_015": {"seeds": ["301", "302", "307", "310", "309", "311", "304", "312", "300", "308", "313", "305", "314", "315", "306", "303"], "fit_nums": ["717", "705", 
"357", "539", "604", "971", "669", "76", "45", "413", "510", "122", "190", "821", "368", "472"], "chain_num": 4}, "ibl_witten_16": {"seeds": ["304", "313", "309", "314", "312", "307", "305", "301", "306", "310", "300", "315", "308", "311", "303", "302"], "fit_nums": ["392", "515", "696", "270", "7", "583", "880", "674", "23", "576", "579", "695", "149", "854", "184", "875"], "chain_num": 4}, "KS015": {"seeds": ["315", "305", "309", "303", "314", "310", "311", "312", "313", "300", "307", "308", "304", "301", "302", "306"], "fit_nums": ["257", "396", "387", "435", "133", "164", "403", "8", "891", "650", "111", "557", "473", "229", "842", "196"], "chain_num": 4}, "GLM_Sim_12": {"seeds": ["304", "312", "306", "303", "310", "302", "300", "305", "308", "313", "307", "311", "315", "301", "314", "309"], "fit_nums": ["971", "550", "255", "195", "952", "486", "841", "535", "559", "37", "654", "213", "864", "506", "732", "550"], "chain_num": 2}, "GLM_Sim_11": {"seeds": ["300", "312", "310", "315", "302", "313", "314", "311", "308", "303", "309", "307", "306", "304", "301", "305"], "fit_nums": ["477", "411", "34", "893", "195", "293", "603", "5", "887", "281", "956", "73", "346", "640", "532", "688"], "chain_num": 2}, "GLM_Sim_10": {"seeds": ["301", "300", "306", "305", "307", "309", "312", "314", "311", "315", "304", "313", "303", "308", "302", "310"], "fit_nums": ["391", "97", "897", "631", "239", "652", "19", "448", "807", "35", "972", "469", "280", "562", "42", "706"], "chain_num": 2}, "CSH_ZAD_026": {"seeds": ["312", "313", "308", "310", "303", "307", "302", "305", "300", "315", "306", "301", "311", "304", "314", "309"], "fit_nums": ["699", "87", "537", "628", "797", "511", "459", "770", "969", "240", "504", "948", "295", "506", "25", "378"], "chain_num": 4}, "KS023": {"seeds": ["304", "313", "306", "309", "300", "314", "302", "310", "303", "315", "307", "308", "301", "311", "305", "312"], "fit_nums": ["698", "845", "319", "734", "908", "507", "45", "499", "175", "108", 
"419", "443", "116", "779", "159", "231"], "chain_num": 4}, "GLM_Sim_05": {"seeds": ["301", "315", "300", "302", "305", "304", "313", "314", "311", "309", "306", "307", "308", "310", "303", "312"], "fit_nums": ["425", "231", "701", "375", "343", "902", "623", "125", "921", "637", "393", "964", "678", "930", "796", "42"], "chain_num": 2}, "CSHL061": {"seeds": ["305", "315", "304", "303", "309", "310", "302", "300", "314", "306", "311", "313", "301", "308", "307", "312"], "fit_nums": ["396", "397", "594", "911", "308", "453", "686", "552", "103", "209", "128", "892", "345", "925", "777", "396"], "chain_num": 4}, "CSHL051": {"seeds": ["303", "310", "306", "302", "309", "305", "313", "308", "300", "314", "311", "307", "312", "304", "315", "301"], "fit_nums": ["69", "186", "49", "435", "103", "910", "705", "367", "303", "474", "596", "334", "929", "796", "616", "790"], "chain_num": 4}, "GLM_Sim_14": {"seeds": ["310", "311", "309", "313", "314", "300", "302", "304", "305", "306", "307", "312", "303", "301", "315", "308"], "fit_nums": ["616", "872", "419", "106", "940", "986", "599", "704", "218", "808", "244", "825", "448", "397", "552", "316"], "chain_num": 2}, "GLM_Sim_11_trick": {"seeds": ["411", "400", "408", "409", "415", "413", "410", "412", "406", "414", "403", "404", "401", "405", "407", "402"], "fit_nums": ["95", "508", "886", "384", "822", "969", "525", "382", "489", "436", "344", "537", "251", "223", "458", "401"], "chain_num": 2}, "GLM_Sim_16": {"seeds": ["302", "311", "303", "307", "313", "308", "309", "300", "305", "315", "304", "310", "312", "301", "314", "306"], "fit_nums": ["914", "377", "173", "583", "870", "456", "611", "697", "13", "713", "159", "248", "617", "37", "770", "780"], "chain_num": 2}, "ZM_3003": {"seeds": ["300", "304", "307", "312", "305", "310", "311", "314", "303", "308", "313", "301", "315", "309", "306", "302"], "fit_nums": ["603", "620", "657", "735", "357", "390", "119", "33", "62", "617", "209", "810", "688", "21", "744", "426"], 
"chain_num": 4}, "CSH_ZAD_022": {"seeds": ["305", "310", "311", "315", "303", "312", "314", "313", "307", "302", "300", "304", "301", "308", "306", "309"], "fit_nums": ["143", "946", "596", "203", "576", "403", "900", "65", "478", "325", "282", "513", "460", "42", "161", "970"], "chain_num": 4}, "GLM_Sim_07": {"seeds": ["300", "309", "302", "304", "305", "312", "301", "311", "315", "314", "308", "307", "303", "310", "306", "313"], "fit_nums": ["724", "701", "118", "230", "648", "426", "689", "114", "832", "731", "592", "519", "559", "938", "672", "144"], "chain_num": 1}, "KS017": {"seeds": ["311", "310", "306", "309", "303", "302", "308", "300", "313", "301", "314", "307", "315", "304", "312", "305"], "fit_nums": ["97", "281", "808", "443", "352", "890", "703", "468", "780", "708", "674", "27", "345", "23", "939", "457"], "chain_num": 4}, "GLM_Sim_11_sub": {"seeds": ["410", "414", "413", "404", "409", "415", "406", "408", "402", "411", "400", "405", "403", "407", "412", "401"], "fit_nums": ["830", "577", "701", "468", "929", "374", "954", "749", "937", "488", "873", "416", "612", "792", "461", "488"], "chain_num": 2}}
\ No newline at end of file
+{"NYU-45": {"seeds": ["513", "503", "500", "506", "502", "501", "512", "509", "507", "515", "510", "505", "504", "514", "511", "508"], "fit_nums": ["768", "301", "96", "731", "879", "989", "915", "512", "295", "48", "157", "631", "666", "334", "682", "714"], "chain_num": 14}, "UCLA035": {"seeds": ["512", "509", "500", "510", "501", "503", "506", "513", "505", "504", "507", "502", "514", "508", "511", "515"], "fit_nums": ["715", "834", "996", "656", "242", "883", "870", "959", "483", "94", "864", "588", "390", "173", "967", "871"], "chain_num": 14}, "NYU-30": {"seeds": ["505", "508", "515", "507", "504", "513", "512", "503", "509", "500", "510", "514", "501", "502", "511", "506"], "fit_nums": ["885", "637", "318", "98", "209", "171", "472", "823", "956", "89", "762", "260", "76", "319", "139", "785"], "chain_num": 14}, "CSHL047": {"seeds": ["509", "510", "503", "502", "501", "508", "514", "505", "507", "511", "515", "504", "500", "506", "513", "512"], "fit_nums": ["60", "589", "537", "3", "178", "99", "877", "381", "462", "527", "6", "683", "771", "950", "294", "252"], "chain_num": 14}, "NYU-39": {"seeds": ["515", "502", "513", "508", "514", "503", "510", "506", "509", "504", "511", "500", "512", "501", "507", "505"], "fit_nums": ["722", "12", "207", "378", "698", "928", "15", "180", "650", "334", "388", "528", "608", "593", "988", "479"], "chain_num": 14}, "NYU-37": {"seeds": ["508", "509", "506", "512", "507", "503", "515", "504", "500", "505", "513", "501", "510", "511", "502", "514"], "fit_nums": ["94", "97", "793", "876", "483", "878", "886", "222", "66", "59", "601", "994", "526", "694", "304", "615"], "chain_num": 14}, "KS045": {"seeds": ["501", "514", "500", "502", "503", "505", "510", "515", "508", "507", "509", "512", "511", "506", "513", "504"], "fit_nums": ["731", "667", "181", "609", "489", "555", "995", "19", "738", "1", "267", "653", "750", "332", "218", "170"], "chain_num": 14}, "UCLA006": {"seeds": ["509", "505", "513", "515", "511", "500", "503", 
"507", "508", "514", "504", "510", "502", "501", "506", "512"], "fit_nums": ["807", "849", "196", "850", "293", "874", "216", "542", "400", "632", "781", "219", "331", "730", "740", "32"], "chain_num": 14}, "UCLA033": {"seeds": ["507", "505", "501", "512", "502", "510", "513", "514", "506", "515", "509", "504", "508", "500", "503", "511"], "fit_nums": ["366", "18", "281", "43", "423", "877", "673", "146", "921", "21", "353", "267", "674", "113", "905", "252"], "chain_num": 14}, "NYU-40": {"seeds": ["513", "500", "503", "507", "515", "504", "510", "508", "505", "514", "502", "512", "501", "509", "511", "506"], "fit_nums": ["656", "826", "657", "634", "861", "347", "334", "227", "747", "834", "460", "191", "489", "458", "24", "346"], "chain_num": 14}, "NYU-46": {"seeds": ["507", "509", "512", "508", "503", "500", "515", "501", "511", "506", "502", "505", "510", "513", "514", "504"], "fit_nums": ["503", "523", "5", "819", "190", "917", "707", "609", "145", "416", "376", "603", "655", "271", "223", "149"], "chain_num": 14}, "KS044": {"seeds": ["513", "503", "507", "504", "502", "515", "501", "511", "510", "500", "512", "508", "509", "506", "514", "505"], "fit_nums": ["367", "656", "73", "877", "115", "627", "610", "772", "558", "581", "398", "267", "353", "779", "393", "473"], "chain_num": 14}, "NYU-48": {"seeds": ["502", "507", "503", "515", "513", "505", "512", "501", "510", "504", "508", "511", "506", "514", "500", "509"], "fit_nums": ["854", "52", "218", "963", "249", "901", "248", "322", "566", "768", "256", "101", "303", "485", "577", "141"], "chain_num": 14}, "UCLA012": {"seeds": ["508", "506", "504", "513", "512", "502", "507", "505", "510", "503", "515", "509", "511", "501", "500", "514"], "fit_nums": ["492", "519", "577", "417", "717", "60", "130", "186", "725", "83", "841", "65", "441", "534", "856", "735"], "chain_num": 14}, "KS084": {"seeds": ["515", "507", "512", "503", "506", "508", "510", "502", "509", "505", "500", "511", "504", "501", "513", "514"], 
"fit_nums": ["816", "374", "140", "955", "399", "417", "733", "149", "300", "642", "644", "248", "324", "830", "889", "286"], "chain_num": 14}, "CSHL052": {"seeds": ["502", "508", "509", "514", "507", "515", "506", "500", "501", "505", "504", "511", "513", "512", "503", "510"], "fit_nums": ["572", "630", "784", "813", "501", "738", "517", "461", "690", "203", "40", "202", "412", "755", "837", "917"], "chain_num": 14}, "NYU-11": {"seeds": ["502", "510", "501", "513", "512", "500", "503", "515", "506", "514", "511", "505", "509", "504", "508", "507"], "fit_nums": ["650", "771", "390", "185", "523", "901", "387", "597", "57", "624", "12", "833", "433", "58", "276", "248"], "chain_num": 14}, "KS051": {"seeds": ["502", "505", "508", "515", "512", "511", "501", "509", "500", "506", "513", "503", "504", "507", "514", "510"], "fit_nums": ["620", "548", "765", "352", "402", "699", "370", "445", "159", "746", "449", "342", "642", "204", "726", "605"], "chain_num": 14}, "NYU-27": {"seeds": ["507", "513", "501", "505", "511", "504", "500", "514", "502", "509", "503", "515", "512", "510", "508", "506"], "fit_nums": ["485", "520", "641", "480", "454", "913", "526", "705", "138", "151", "962", "24", "21", "743", "119", "699"], "chain_num": 14}, "UCLA011": {"seeds": ["513", "503", "508", "501", "510", "505", "506", "511", "507", "514", "515", "500", "509", "502", "512", "504"], "fit_nums": ["957", "295", "743", "795", "643", "629", "142", "174", "164", "21", "835", "338", "368", "341", "209", "68"], "chain_num": 14}, "NYU-47": {"seeds": ["515", "504", "510", "503", "505", "506", "502", "501", "507", "500", "514", "513", "508", "509", "511", "512"], "fit_nums": ["169", "941", "329", "51", "788", "654", "224", "434", "385", "46", "712", "84", "930", "571", "273", "312"], "chain_num": 14}, "CSHL045": {"seeds": ["514", "502", "510", "507", "512", "501", "511", "506", "508", "509", "503", "513", "500", "504", "515", "505"], "fit_nums": ["862", "97", "888", "470", "620", "765", "874", 
"421", "104", "909", "924", "874", "158", "992", "25", "40"], "chain_num": 14}, "UCLA017": {"seeds": ["510", "502", "500", "515", "503", "507", "514", "501", "511", "505", "513", "504", "508", "512", "506", "509"], "fit_nums": ["875", "684", "841", "510", "209", "207", "806", "700", "989", "899", "812", "971", "526", "887", "160", "249"], "chain_num": 14}, "CSHL055": {"seeds": ["509", "503", "505", "512", "504", "508", "510", "511", "507", "501", "500", "514", "513", "515", "506", "502"], "fit_nums": ["957", "21", "710", "174", "689", "796", "449", "183", "193", "209", "437", "827", "990", "705", "540", "835"], "chain_num": 14}, "UCLA005": {"seeds": ["507", "503", "512", "515", "505", "500", "504", "513", "509", "506", "501", "511", "514", "502", "510", "508"], "fit_nums": ["636", "845", "712", "60", "733", "789", "990", "230", "335", "337", "307", "404", "297", "608", "428", "108"], "chain_num": 14}, "CSHL060": {"seeds": ["510", "515", "513", "503", "514", "501", "504", "502", "511", "500", "508", "509", "505", "507", "506", "512"], "fit_nums": ["626", "953", "497", "886", "585", "293", "580", "867", "113", "734", "88", "55", "949", "443", "210", "555"], "chain_num": 14}, "UCLA015": {"seeds": ["503", "501", "502", "500"], "fit_nums": ["877", "773", "109", "747"], "chain_num": 14}, "KS055": {"seeds": ["510", "502", "511", "501", "509", "506", "500", "515", "503", "512", "504", "505", "513", "508", "514", "507"], "fit_nums": ["526", "102", "189", "216", "673", "477", "293", "981", "960", "883", "899", "95", "31", "244", "385", "631"], "chain_num": 14}, "UCLA014": {"seeds": ["513", "509", "508", "510", "500", "504", "515", "511", "501", "514", "507", "502", "505", "512", "506", "503"], "fit_nums": ["628", "950", "418", "91", "25", "722", "792", "225", "287", "272", "23", "168", "821", "934", "194", "481"], "chain_num": 14}, "CSHL053": {"seeds": ["505", "513", "501", "508", "502", "515", "510", "507", "506", "509", "514", "511", "500", "503", "504", "512"], 
"fit_nums": ["56", "674", "979", "0", "221", "577", "25", "679", "612", "185", "464", "751", "648", "715", "344", "348"], "chain_num": 14}, "NYU-12": {"seeds": ["508", "506", "509", "515", "511", "510", "501", "504", "513", "503", "507", "512", "500", "514", "505", "502"], "fit_nums": ["853", "819", "692", "668", "730", "213", "846", "596", "644", "829", "976", "895", "974", "824", "179", "769"], "chain_num": 14}, "KS043": {"seeds": ["505", "504", "507", "500", "502", "503", "508", "511", "512", "509", "515", "513", "510", "514", "501", "506"], "fit_nums": ["997", "179", "26", "741", "476", "502", "597", "477", "511", "181", "233", "330", "299", "939", "542", "113"], "chain_num": 14}, "CSHL058": {"seeds": ["513", "505", "514", "506", "504", "500", "511", "503", "508", "509", "501", "510", "502", "515", "507", "512"], "fit_nums": ["752", "532", "949", "442", "400", "315", "106", "419", "903", "198", "553", "158", "674", "249", "723", "941"], "chain_num": 14}, "KS042": {"seeds": ["503", "502", "506", "511", "501", "505", "512", "509", "513", "508", "507", "500", "510", "504", "515", "514"], "fit_nums": ["241", "895", "503", "880", "283", "267", "944", "204", "921", "514", "392", "241", "28", "905", "334", "894"], "chain_num": 14}}
\ No newline at end of file
diff --git a/combine_canon_infos.py b/combine_canon_infos.py
new file mode 100644
index 0000000000000000000000000000000000000000..a31fe3555793d0a19ceaae4c2aa4f01e9e50f841
--- /dev/null
+++ b/combine_canon_infos.py
@@ -0,0 +1,29 @@
+"""
+    Script for combining local canonical_infos.json and the one from the cluster
+"""
+import json
+
+dist_info = json.load(open("canonical_infos.json", 'r'))
+local_info = json.load(open("canonical_infos_local.json", 'r'))
+
+cluster_subs = ['KS045', 'KS043', 'KS051', 'DY_008', 'KS084', 'KS052', 'KS046', 'KS096', 'KS086', 'UCLA033', 'UCLA005', 'NYU-21', 'KS055', 'KS091']
+
+for key in cluster_subs:
+    if key not in local_info:
+        print("Adding all of {} to local info".format(key))
+        local_info[key] = dist_info[key]
+        continue
+    for sub_key in dist_info[key]:
+        if sub_key not in local_info[key]:
+            print("Adding {} into local info for {}".format(sub_key, key))
+            local_info[key][sub_key] = dist_info[key][sub_key]
+        elif local_info[key][sub_key] != dist_info[key][sub_key]:
+            # Take the cluster version, but only after checking it is a full
+            # 16-seed record whose entries all already appear locally.
+            assert len(dist_info[key][sub_key]) == 16
+            for x in dist_info[key][sub_key]:
+                assert x in local_info[key][sub_key]
+            local_info[key][sub_key] = dist_info[key][sub_key]
+
+# Persist the merged result (otherwise the merge above has no effect).
+json.dump(local_info, open("canonical_infos_local.json", 'w'))
diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py
index e20781d88ad2813cf9ec63bd5ecb617bef87b722..1b8062228711583bf6024e9b68935547ab36d6ce 100644
--- a/dyn_glm_chain_analysis.py
+++ b/dyn_glm_chain_analysis.py
@@ -19,7 +19,7 @@ from matplotlib.ticker import MaxNLocator
 import json
 import time
 import multiprocessing as mp
-from mcmc_chain_analysis import state_size_helper, state_num_helper, gamma_func, alpha_func, ll_func, r_hat_array_comp, rank_inv_normal_transform, eval_r_hat, eval_simple_r_hat
+from mcmc_chain_analysis import state_num_helper, ll_func, r_hat_array_comp, rank_inv_normal_transform
 import pandas as pd
 from analysis_pmf import pmf_type, type2color
 
@@ -670,10 +670,12 @@ def control_flow(test, indices, trials, func_init, first_for, second_for, end_fi
             continue
         first_for(test, results)
 
+        counter = -1
         for i, m in enumerate([item for sublist in test.results for item in sublist.models]):
             if i not in indices:
                 continue
-            second_for(m, j, session_trials, trial_counter, results)
+            counter += 1
+            second_for(m, j, counter, session_trials, trial_counter, results)
 
         end_first_for(results, indices, j, trial_counter=trial_counter, session_trials=session_trials)
         trial_counter += len(only_for_length)
@@ -681,181 +683,6 @@ def control_flow(test, indices, trials, func_init, first_for, second_for, end_fi
     return results
 
 
-def create_mode_indices(test, subject, fit_type):
-    dim = 3
-
-    try:
-        xy, z = pickle.load(open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'rb'))
-    except Exception:
-        print('Doing PCA')
-        ev, eig, projection_matrix, dimreduc = test.state_pca(subject, pca_type='dists', dim=dim)
-        xy = np.vstack([dimreduc[i] for i in range(dim)])
-        from scipy.stats import gaussian_kde
-        z = gaussian_kde(xy)(xy)
-        pickle.dump((xy, z), open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'wb'))
-
-    print("Mode indices of " + subject)
-
-    threshold_search(xy, z, test, 'first_', subject, fit_type)
-
-    print("Find another mode?")
-    if input() not in ['yes', 'y']:
-        return
-
-    threshold_search(xy, z, test, 'second_', subject, fit_type)
-
-    print("Find another mode?")
-    if input() not in ['yes', 'y']:
-        return
-
-    threshold_search(xy, z, test, 'third_', subject, fit_type)
-    return
-
-def threshold_search(xy, z, test, mode_prefix, subject, fit_type):
-    happy = False
-    conds = [0, None, None, None, None]
-    x_min, x_max, y_min, y_max = None, None, None, None
-    while not happy:
-        print()
-        print("Pick level")
-        prob_level = input()
-        if prob_level == 'cond':
-            print("x > ?")
-            resp = input()
-            if resp not in ['n', 'no']:
-                x_min = float(resp)
-
-            print("x < ?")
-            resp = input()
-            if resp not in ['n', 'no']:
-                x_max = float(resp)
-
-            print("y > ?")
-            resp = input()
-            if resp not in ['n', 'no']:
-                y_min = float(resp)
-
-            print("y < ?")
-            resp = input()
-            if resp not in ['n', 'no']:
-                y_max = float(resp)
-
-            print("Prob level")
-            prob_level = float(input())
-            conds = [prob_level, x_min, x_max, y_min, y_max]
-            print("Condtions are {}".format(conds))
-        else:
-            prob_level = float(prob_level)
-            conds[0] = prob_level
-
-        print("Level is {}".format(prob_level))
-
-        mode = conditions_fulfilled(z, xy, conds)
-        print("# of samples: {}".format(mode.sum()))
-        mode_indices = np.where(mode)[0]
-        if mode.sum() > 0:
-            print(xy[0][mode_indices].min(), xy[0][mode_indices].max(), xy[1][mode_indices].min(), xy[1][mode_indices].max())
-            print("Happy?")
-            happy = 'yes' == input()
-
-    print("Subset by factor?")
-    if input() == 'yes':
-        print("Factor?")
-        print(mode_indices.shape)
-        factor = int(input())
-        mode_indices = mode_indices[::factor]
-        print(mode_indices.shape)
-    loading_info[subject]['mode prob level'] = prob_level
-
-    pickle.dump(mode_indices, open("multi_chain_saves/{}mode_indices_{}_{}.p".format(mode_prefix, subject, fit_type), 'wb'))
-    consistencies = test.consistency_rsa(indices=mode_indices)
-    pickle.dump(consistencies, open("multi_chain_saves/{}mode_consistencies_{}_{}.p".format(mode_prefix, subject, fit_type), 'wb'))
-
-
-def conditions_fulfilled(z, xy, conds):
-    works = z > conds[0]
-    if conds[1]:
-        works = np.logical_and(works, xy[0] > conds[1])
-    if conds[2]:
-        works = np.logical_and(works, xy[0] < conds[2])
-    if conds[3]:
-        works = np.logical_and(works, xy[1] > conds[3])
-    if conds[4]:
-        works = np.logical_and(works, xy[1] < conds[4])
-
-    return works
-
-
-def state_set_and_plot(test, mode_prefix, subject, fit_type):
-    mode_indices = pickle.load(open("multi_chain_saves/{}mode_indices_{}_{}.p".format(mode_prefix, subject, fit_type), 'rb'))
-    consistencies = pickle.load(open("multi_chain_saves/{}mode_consistencies_{}_{}.p".format(mode_prefix, subject, fit_type), 'rb'))
-    session_bounds = list(np.cumsum([len(s) for s in test.results[0].models[-1].stateseqs]))
-
-    import scipy.cluster.hierarchy as hc
-    consistencies /= consistencies[0, 0]
-    linkage = hc.linkage(consistencies[0, 0] - consistencies[np.triu_indices(consistencies.shape[0], k=1)], method='complete')
-
-    # R = hc.dendrogram(linkage, truncate_mode='lastp', p=150, no_labels=True)
-    plt.savefig("peter figures/{}tree_{}_{}".format(mode_prefix, subject, 'complete'))
-    plt.close()
-
-    session_bounds = list(np.cumsum([len(s) for s in test.results[0].models[-1].stateseqs]))
-
-    fig, ax = plt.subplots(ncols=5, sharey=True, gridspec_kw={'width_ratios': [10, 1, 1, 1, 1]}, figsize=(13, 8))
-    from matplotlib.pyplot import cm
-    for j, criterion in enumerate([0.95, 0.8, 0.5, 0.2]):
-        clustering_colors = np.zeros((consistencies.shape[0], 100, 4))
-        a = hc.fcluster(linkage, criterion, criterion='distance')
-        b, c = np.unique(a, return_counts=1)
-        print(b.shape)
-        print(np.sort(c))
-
-        if criterion in [0.95]:
-            state_sets = []
-            for x, y in zip(b, c):
-                state_sets.append(np.where(a == x)[0])
-            print("dumping state set")
-            pickle.dump(state_sets, open("multi_chain_saves/{}state_sets_{}_{}.p".format(mode_prefix, subject, fit_type), 'wb'))
-            state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='_{}{}'.format(mode_prefix, criterion), show=True, separate_pmf=True, type_coloring=True)
-
-            # I think this is about finding out where states start and how long they last
-            # state_id, session_appears = np.where(states)
-            # for s in np.unique(state_id):
-            #     state_appear_mode.append(session_appears[state_id == s][0] / (test.results[0].n_sessions - 1))
-
-        # cmap = cm.rainbow(np.linspace(0, 1, 17))#len([x for x, y in zip(b, c) if y > 50])))
-        # rank_to_color_place = dict(zip(range(17), [0, 16, 8, 4, 12, 2, 6, 10, 14, 1, 3, 5, 7, 9, 11, 13, 15]))  # handcrafted to maximise color distance, I think
-        i = -1
-        b = [x for _, x in sorted(zip(c, b))][::-1]
-        c = [x for x, _ in sorted(zip(c, b))][::-1]
-        for x, y in zip(b, c):
-            if y > 50:
-                i += 1
-                # clustering_colors[a == x] = cmap[rank_to_color_place[i]]
-                clustering_colors[a == x] = cm.rainbow(np.mean(np.where(a == x)[0]) / a.shape[0])
-
-        ax[j+1].imshow(clustering_colors, aspect='auto', origin='upper')
-        for sb in session_bounds:
-            ax[j+1].axhline(sb, color='k')
-        ax[j+1].set_xticks([])
-        ax[j+1].set_yticks([])
-        ax[j+1].set_title("{}%".format(int(criterion * 100)), size=20)
-
-    ax[0].imshow(consistencies, aspect='auto', origin='upper')
-    for sb in session_bounds:
-        ax[0].axhline(sb, color='k')
-    ax[0].set_xticks([])
-    ax[0].set_yticks(session_bounds[::-1])
-    ax[0].set_yticklabels(session_bounds[::-1], size=18)
-    ax[0].set_ylim(session_bounds[-1], 0)
-    ax[0].set_ylabel("Trials", size=28)
-    plt.yticks(rotation=45)
-
-    plt.tight_layout()
-    plt.savefig("peter figures/{}clustered_trials_{}_{}".format(mode_prefix, subject, 'criteria comp').replace('.', '_'))
-    plt.close()
-
-
 def state_pmfs(test, trials, indices):
     def func_init(): return {'pmfs': [], 'session_js': [], 'pmf_weights': []}
 
@@ -863,7 +690,7 @@ def state_pmfs(test, trials, indices):
         results['pmf'] = np.zeros(test.results[0].n_contrasts)
         results['pmf_weight'] = np.zeros(4)
 
-    def second_for(m, j, session_trials, trial_counter, results):
+    def second_for(m, j, counter, session_trials, trial_counter, results):
         states, counts = np.unique(m.stateseqs[j][session_trials - trial_counter], return_counts=True)
         for sub_state, c in zip(states, counts):
             results['pmf'] += weights_to_pmf(m.obs_distns[sub_state].weights[j]) * c / session_trials.shape[0]
@@ -884,7 +711,7 @@ def state_weights(test, trials, indices):
     def first_for(test, results):
         results['weights'] = np.zeros(test.results[0].models[0].obs_distns[0].weights.shape[1])
 
-    def second_for(m, j, session_trials, trial_counter, results):
+    def second_for(m, j, counter, session_trials, trial_counter, results):
         states, counts = np.unique(m.stateseqs[j][session_trials - trial_counter], return_counts=True)
         for sub_state, c in zip(states, counts):
             results['weights'] += m.obs_distns[sub_state].weights[j] * c / session_trials.shape[0]
@@ -907,7 +734,7 @@ def lapse_sides(test, state_sets, indices):
     def first_for(test, results):
         results['pmf'] = np.zeros(test.results[0].n_contrasts)
 
-    def second_for(m, j, session_trials, trial_counter, results):
+    def second_for(m, j, counter, session_trials, trial_counter, results):
         states, counts = np.unique(m.stateseqs[j][session_trials - trial_counter], return_counts=True)
         for sub_state, c in zip(states, counts):
             results['pmf'] += weights_to_pmf(m.obs_distns[sub_state].weights[j]) * c / session_trials.shape[0]
@@ -1102,6 +929,8 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     if test.results[0].name.startswith('GLM_Sim_'):
         print("./glm sim mice/truth_{}.p".format(test.results[0].name))
         truth = pickle.load(open("./glm sim mice/truth_{}.p".format(test.results[0].name), "rb"))
+        # truth['state_posterior'] = truth['state_posterior'][:, [0, 1, 3, 4, 5, 6, 7]]  # 17, mode 2
+        # truth['weights'] = [w for i, w in enumerate(truth['weights']) if i != 2]  # 17, mode 2
 
     states_by_session = np.zeros((len(state_sets), test.results[0].n_sessions))
     trial_counter = 0
@@ -1172,7 +1001,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
                 counter += 1
         pmfs_to_score.append(np.mean(pmfs))
     # test.state_mapping = dict(zip(range(len(state_sets)), np.argsort(np.argsort(pmfs_to_score))))  # double argsort for ranks
-    test.state_mapping = dict(zip(range(len(state_sets)), np.flip(np.argsort((states_by_session != 0).argmax(axis=1)))))
+    test.state_mapping = dict(zip(np.flip(np.argsort((states_by_session != 0).argmax(axis=1))), range(len(state_sets))))
 
     for state, trials in enumerate(state_sets):
         cmap = matplotlib.cm.get_cmap(cmaps[state]) if state < len(cmaps) else matplotlib.cm.get_cmap('Greys')
@@ -1218,7 +1047,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
                                  test.state_mapping[state] + states_by_session[state] - 0.5, color=state_color)
 
         else:
-            n_points = 150
+            n_points = 400
             points = np.linspace(1, test.results[0].n_sessions, n_points)
             interpolation = np.interp(points, np.arange(1, 1 + test.results[0].n_sessions), states_by_session[state])
 
@@ -1229,7 +1058,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
         ax1.annotate(len(state_sets) - test.state_mapping[state], (test.results[0].n_sessions + 0.1, test.state_mapping[state] - 0.15), fontsize=22, annotation_clip=False)
 
         if test.results[0].name.startswith('GLM_Sim_'):
-            ax1.plot(range(1, 1 + test.results[0].n_sessions), truth['state_map'][test.state_mapping[state]] + truth['state_posterior'][:, state] - 0.5, color='r')
+            ax1.plot(range(1, 1 + test.results[0].n_sessions), state + truth['state_posterior'][:, state] - 0.5, color='r')
 
         alpha_level = 0.3
         ax2.axvline(0.5, c='grey', alpha=alpha_level, zorder=4)
@@ -1259,7 +1088,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
 
         if test.results[0].name.startswith('GLM_Sim_'):
             sim_pmf = weights_to_pmf(truth['weights'][state])
-            ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), sim_pmf[defined_points] - 0.5 + truth['state_map'][test.state_mapping[state]], color='r')
+            ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), sim_pmf[defined_points] - 0.5 + state, color='r')
 
         if not test.state_mapping[state] in dont_plot:
             ax2.annotate("Type {}".format(1 + pmf_type(pmf[defined_points])), (1.05, test.state_mapping[state] - 0.37), rotation=90, size=13, color=type2color[pmf_type(pmf)], annotation_clip=False)
@@ -1354,15 +1183,71 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
 
     plt.tight_layout()
     if save:
-        # print("saving with {} dpi".format(dpi))
+        print("saving with {} dpi".format(dpi))
         plt.savefig("dynamic_GLM_figures/meta_state_development_{}_{}{}.png".format(test.results[0].name, separate_pmf, save_append), dpi=dpi)
     if show:
         plt.show()
     else:
         plt.close()
 
+    if subject.startswith("GLM_Sim_"):
+        plt.figure(figsize=(16, 9))
+        for state, trials in enumerate(state_sets):
+            dur_params = dur_hists(test, trials, indices)
+
+            plt.subplot(4, 4, 1 + 2 * test.state_mapping[state])
+            plt.hist(dur_params[:, 0])
+            plt.axvline(truth['durs'][state][0], color='red')
+
+            plt.subplot(4, 4, 2 + 2 * test.state_mapping[state])
+            plt.hist(dur_params[:, 1])
+            plt.axvline(truth['durs'][state][1], color='red')
+
+        plt.tight_layout()
+        plt.savefig("dur hists")
+        plt.show()
+
+        from scipy.stats import nbinom
+        points = np.arange(900)
+        plt.figure(figsize=(16, 9))
+        for state, trials in enumerate(state_sets):
+            dur_params = dur_hists(test, trials, indices)
+
+            plt.subplot(2, 4, 1 + test.state_mapping[state])
+            plt.plot(nbinom.pmf(points, np.mean(dur_params[:, 0]), np.mean(dur_params[:, 1])))
+            plt.plot(nbinom.pmf(points, truth['durs'][state][0], truth['durs'][state][1]), color='red')
+            plt.xlabel("# of trials")
+            plt.ylabel("P")
+
+        plt.tight_layout()
+        plt.savefig("dur dists")
+        plt.show()
+
     return states_by_session, all_pmfs, all_pmf_weights, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage, states_per_type
 
+def dur_hists(test, trials, indices):
+    def func_init(): return {'dur_params': np.zeros((len(indices), 2))}
+
+    def first_for(test, results):
+        pass
+
+    def second_for(m, j, counter, session_trials, trial_counter, results):
+        states, counts = np.unique(m.stateseqs[j][session_trials - trial_counter], return_counts=True)
+        for sub_state, c in zip(states, counts):
+            try:
+                p = m.dur_distns[sub_state].p_save  # this cannot be accessed after deleting data, but not every model's data is deleted
+            except:
+                p = m.dur_distns[sub_state].p
+            results['dur_params'][counter] += np.array([m.dur_distns[sub_state].r * c / session_trials.shape[0], p * c / session_trials.shape[0]])
+
+    def end_first_for(results, indices, j, **kwargs):
+        pass
+
+    results = control_flow(test, indices, trials, func_init, first_for, second_for, end_first_for)
+    results['dur_params'] = results['dur_params'] / len(indices)
+    return results['dur_params']
+
+
 def compare_pmfs(test, states2compare, states_by_session, all_pmfs, title=""):
     """
        Take a set of states, and plot out their PMFs on all sessions on which they occur.
@@ -1645,14 +1530,19 @@ if __name__ == "__main__":
     fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
     if fit_type == 'bias':
         loading_info = json.load(open("canonical_infos_bias.json", 'r'))
-        r_hats = json.load(open("canonical_info_r_hats_bias.json", 'r'))
     elif fit_type == 'prebias':
         loading_info = json.load(open("canonical_infos.json", 'r'))
-        r_hats = json.load(open("canonical_info_r_hats.json", 'r'))
     no_good_pcas = ['NYU-06', 'SWC_023']
     subjects = list(loading_info.keys())
-    # subjects = ['SWC_021', 'ibl_witten_15', 'ibl_witten_13', 'KS003', 'ibl_witten_19', 'SWC_022', 'CSH_ZAD_017']
-    subjects = ['KS021']
+
+    new_subs = ['KS044', 'NYU-11', 'DY_016', 'SWC_061', 'ZFM-05245', 'CSH_ZAD_029', 'SWC_021', 'CSHL058', 'DY_014', 'DY_009', 'KS094', 'DY_018', 'KS043', 'UCLA014', 'SWC_038', 'SWC_022', 'UCLA012', 'UCLA011', 'CSHL055', 'ZFM-04019', 'NYU-45', 'ZFM-02370', 'ZFM-02373', 'ZFM-02369', 'NYU-40', 'CSHL060', 'NYU-30', 'CSH_ZAD_019', 'UCLA017', 'KS052', 'ibl_witten_25', 'ZFM-02368', 'CSHL045', 'UCLA005', 'SWC_058', 'CSH_ZAD_024', 'SWC_042', 'DY_008', 'ibl_witten_13', 'SWC_043', 'KS046', 'DY_010', 'CSHL053', 'ZM_1898', 'UCLA033', 'NYU-47', 'DY_011', 'CSHL047', 'SWC_054', 'ibl_witten_19', 'ibl_witten_27', 'KS091', 'KS055', 'CSH_ZAD_017', 'UCLA035', 'SWC_060', 'DY_020', 'ZFM-01577', 'ZM_2240', 'ibl_witten_29', 'KS096', 'SWC_066', 'DY_013', 'ZFM-01592', 'GLM_Sim_17', 'NYU-48', 'UCLA006', 'NYU-39', 'KS051', 'NYU-27', 'NYU-46', 'ZFM-01936', 'ZFM-02372', 'ZFM-01935', 'ibl_witten_26', 'ZFM-05236', 'ZM_2241', 'NYU-37', 'KS086', 'KS084', 'ZFM-01576', 'KS042']
+
+    miss_count = 0
+    for s in new_subs:
+        if s not in subjects:
+            print(s)
+            miss_count += 1
+    quit()
 
     print(subjects)
     fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0]
@@ -1660,8 +1550,6 @@ if __name__ == "__main__":
 
     # fig, ax = plt.subplots(1, 3, sharey=True, figsize=(16, 9))
 
-    thinning = 25
-    summary_info = {"thinning": thinning, "contains": [], "seeds": [], "fit_nums": []}
     pop_state_starts = np.zeros(20)
     state_appear_dist = np.zeros(10)
     state_appear_mode = []
@@ -1703,227 +1591,173 @@ if __name__ == "__main__":
 
     for subject in subjects:
 
-        if subject.startswith('GLM_Sim_') or subject == 'ibl_witten_18':
+        # NYU-11 is quite weird, super errrativ behaviour, has all contrasts introduced at once, no good session at end
+        if subject.startswith('GLM_Sim_'):
             continue
 
         print()
         print(subject)
 
-        try:
-            test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb'))
+        test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb'))
 
-            # mode_specifier = ''
+        mode_specifier = 'first'
+        try:
+            mode_indices = pickle.load(open("multi_chain_saves/{}_mode_indices_{}_{}.p".format(mode_specifier, subject, fit_type), 'rb'))
+            state_sets = pickle.load(open("multi_chain_saves/{}_state_sets_{}_{}.p".format(mode_specifier, subject, fit_type), 'rb'))
+        except:
             try:
-                mode_indices = pickle.load(open("multi_chain_saves/first_mode_indices_{}_{}.p".format(subject, fit_type), 'rb'))
-                state_sets = pickle.load(open("multi_chain_saves/first_state_sets_{}_{}.p".format(subject, fit_type), 'rb'))
+                mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb'))
+                state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb'))
             except:
-                try:
-                    mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb'))
-                    state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb'))
-                except:
-                    continue
+                continue
 
-            # lapse differential
-            # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices)
-
-            # training overview
-            # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 0', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(8)), plot_until=-1)
-            # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 1', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(7)), plot_until=2)
-            # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 2', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(6)), plot_until=7)
-            # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 3', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(4)), plot_until=13)
-            states, pmfs, pmf_weights, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True)
-            # state_nums_5.append((states > 0.05).sum(0))
-            # state_nums_10.append((states > 0.1).sum(0))
-            # continue
-            a = compare_params(test, 26, [s for s in state_sets if len(s) > 40], mode_indices, [3, 5])
-            compare_pmfs(test, [3, 5], states, pmfs, title="{} convergence pmf".format(subject))
-            quit()
-            new = type_2_appearance(states, pmfs)
-
-            if new == 2:
-                print('____________________________')
-                print(subject)
-                print('____________________________')
-            if new == 1:
-                new_counter += 1
-            if new == 0:
-                transform_counter += 1
-            print(new_counter, transform_counter)
-
-            consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
-            consistencies /= consistencies[0, 0]
-            temp = contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies, CMF=False)
-            continue
+        # lapse differential
+        # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices)
 
-            state_dict = write_results(test, [s for s in state_sets if len(s) > 40], mode_indices)
-            pickle.dump(state_dict, open("state_dict_{}".format(subject), 'wb'))
-            quit()
-            # abs_state_durs.append(durs)
-            # continue
-            # all_pmf_weights += pmf_weights
-            # all_state_types.append(state_types)
-            #
-            # state_types_interpolation[0] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[0])
-            # state_types_interpolation[1] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[1])
-            # state_types_interpolation[2] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[2])
-            #
-            # all_first_pmfs_typeless[subject] = []
-            # for pmf in pmfs:
-            #     all_first_pmfs_typeless[subject].append((pmf[0], pmf[1][0]))
-            #     all_pmfs.append(pmf)
-            #
-            # first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs)
-            # for pmf in changing_pmfs:
-            #     if type(pmf[0]) == int:
-            #         continue
-            #     all_changing_pmf_names.append(subject)
-            #     all_changing_pmfs.append(pmf)
-            #
-            # all_first_pmfs[subject] = first_pmfs
-            # continue
-            #
-            # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
-            # consistencies /= consistencies[0, 0]
-            # contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=False)
-
-
-            # b_flips = bias_flips(states, pmfs, durs)
-            # all_bias_flips.append(b_flips)
-
-            # regression, diffs = pmf_regressions(states, pmfs, durs)
-            # regression_diffs += diffs
-            # regressions.append(regression)
-            # continue
-            # compare_pmfs(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, pmfs, title="{} convergence pmf".format(subject))
-            # compare_weights(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, title="{} convergence weights".format(subject))
-            # quit()
-            # all_intros.append(undiv_intros)
-            # all_intros_div.append(intros_by_type)
-            # if states_per_type != []:
-            #     all_states_per_type.append(states_per_type)
-            #
-            # intros_by_type_sum += intros_by_type
-            # for pmf in pmfs:
-            #     all_pmfs.append(pmf)
-            #     for p in pmf[1]:
-            #         all_pmf_diffs.append(p[-1] - p[0])
-            #         all_pmf_asymms.append(np.abs(p[0] + p[-1] - 1))
-            # contrast_intro_types.append(contrast_intro_type)
-
-            # num_states.append(np.mean(test.state_num_dist()))
-            # num_sessions.append(test.results[0].n_sessions)
-
-            # a, b = np.where(states)
-            # for i in range(states.shape[0]):
-            #     state_appear.append(b[a == i][0] / (test.results[0].n_sessions - 1))
-            #     state_dur.append(b[a == i].shape[0])
-
-            # state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False)
-
-            # session overview
-            # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
-            # consistencies /= consistencies[0, 0]
-            # c_n_a, rt_data = contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=True)
-            # compare_performance(cnas, (1, 1), title="{} contrast {} performance test".format(subject, (1, 1)))
-            # compare_performance(cnas, (0, 0.848))
-
-            # duration of different state types (and also percentage of type activities)
-            # continue
-            # simplex_durs = np.array(durs).reshape(1, 3)
-            # print(simplex_durs / np.sum(simplex_durs))
-            # from simplex_plot import projectSimplex
-            # print(projectSimplex(simplex_durs / simplex_durs.sum(1)[:, None]))
-
-            # compute state type proportions and split the pmfs accordingly
-            # ret, trans, pmf_types = state_cluster_interpolation(states, pmfs)
-            # state_trans += trans  # matrix of how often the different types transition into one another
-            # for i, r in enumerate(ret):
-            #     state_trajs[i] += np.interp(np.linspace(1, test.results[0].n_sessions, n_points), np.arange(1, 1 + test.results[0].n_sessions), r)  # for interpolated population state trajectories
-            # plot_pmf_types(pmf_types, subject=subject, fit_type=fit_type)
-            # continue
-            # plt.plot(ret.T, label=[0, 1, 2])
-            # plt.legend()
-            # plt.show()
-
-            # Analyse single sample
-            # xy, z = pickle.load(open("multi_chain_saves/xyz_{}.p".format(subject), 'rb'))
-            #
-            # single_sample = [np.argmax(z)]
-            # for position, index in enumerate(np.where(np.logical_and(z > 2.7e-7, xy[0] < -500))[0]):
-            #     print(position, index, index // test.n, index % test.n)
-            #     states, pmfs = state_development_single_sample(test, [index], save_append='_{}_{}'.format('single_sample', position), show=True, separate_pmf=True)
-            #     if position == 9:
-            #         quit()
-            # quit()
-
-            # all_state_starts = test.state_appearance_posterior(subject)
-            # pop_state_starts += all_state_starts
-            #
-            # a, b, c = test.state_start_and_dur()
-            # state_appear += a
-            # state_dur += b
-            # state_appear_dist += c
-
-
-            # all_state_starts = test.state_appearance_posterior(subject)
-            # plt.plot(all_state_starts)
-            # plt.savefig("temp")
-            # plt.close()
-
-            print('Computing sub result')
-            create_mode_indices(test, subject, fit_type)
-            state_set_and_plot(test, 'first_', subject, fit_type)
-            print("second mode?")
-            if input() in ['y', 'yes']:
-                state_set_and_plot(test, 'second_', subject, fit_type)
-
-        except FileNotFoundError as e:
-            print(e)
-            print('no canoncial result')
-            continue
-            print(r_hats[subject])
-            if r_hats[subject] >= 1.05:
-                print("Skipping")
-                continue
-            else:
-                print("Making canonical result")
-            results = []
-            for j, (seed, fit_num) in enumerate(zip(loading_info[subject]['seeds'], loading_info[subject]['fit_nums'])):
-                if j in loading_info[subject]['ignore']:
-                    continue
-                summary_info["contains"].append(j)
-                summary_info["seeds"].append(seed)
-                summary_info["fit_nums"].append(fit_num)
-                info_dict = pickle.load(open("./session_data/{}_info_dict.p".format(subject), "rb"))
-
-                samples = []
-                mini_counter = 1  # start from one here, first 4000 as burn-in
-                while True:
-                    try:
-                        file = "./dynamic_GLMiHMM_crossvals/{}_fittype_{}_var_{}_{}_{}{}.p".format(subject, fit_type, fit_variance, seed, fit_num, '_{}'.format(mini_counter))
-                        lala = time.time()
-                        samples += pickle.load(open(file, "rb"))
-                        print("Loading {} took {:.4}".format(mini_counter, time.time() - lala))
-                        mini_counter += 1
-                    except Exception:
-                        break
-                save_id = "{}_fittype_{}_var_{}_{}_{}.p".format(subject, fit_type, fit_variance, seed, fit_num).replace('.', '_')
-
-                print("Loaded")
-
-                result = MCMC_result(samples[::50], infos=info_dict,
-                                     data=samples[0].datas, sessions=fit_type, fit_variance=fit_variance,
-                                     dur=dur, save_id=save_id)
-                results.append(result)
-            test = MCMC_result_list(results, summary_info)
-            pickle.dump(test, open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'wb'))
-
-            test.r_hat_and_ess(state_num_helper(0.05), False)
-            test.r_hat_and_ess(state_num_helper(0.02), False)
-            test.r_hat_and_ess(state_size_helper(), False)
-            test.r_hat_and_ess(state_size_helper(1), False)
-            test.r_hat_and_ess(gamma_func, True)
-            test.r_hat_and_ess(alpha_func, True)
+        # training overview
+        # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 0', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(8)), plot_until=-1)
+        # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 1', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(7)), plot_until=2)
+        # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 2', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(6)), plot_until=7)
+        # _ = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='step 3', show=1, separate_pmf=1, type_coloring=False, dont_plot=list(range(4)), plot_until=13)
+        states, pmfs, pmf_weights, durs, state_types, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True)
+        quit()
+        # state_nums_5.append((states > 0.05).sum(0))
+        # state_nums_10.append((states > 0.1).sum(0))
+        # continue
+        # a = compare_params(test, 26, [s for s in state_sets if len(s) > 40], mode_indices, [3, 5])
+        # compare_pmfs(test, [3, 2, 4], states, pmfs, title="{} convergence pmf".format(subject))
+        consistencies = pickle.load(open("multi_chain_saves/first_mode_consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
+        consistencies /= consistencies[0, 0]
+        temp = contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies, CMF=False)
+        continue
+        quit()
+        new = type_2_appearance(states, pmfs)
+
+        if new == 2:
+            print('____________________________')
+            print(subject)
+            print('____________________________')
+        if new == 1:
+            new_counter += 1
+        if new == 0:
+            transform_counter += 1
+        print(new_counter, transform_counter)
+
+
+        state_dict = write_results(test, [s for s in state_sets if len(s) > 40], mode_indices)
+        pickle.dump(state_dict, open("state_dict_{}".format(subject), 'wb'))
+        quit()
+        # abs_state_durs.append(durs)
+        # continue
+        # all_pmf_weights += pmf_weights
+        # all_state_types.append(state_types)
+        #
+        # state_types_interpolation[0] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[0])
+        # state_types_interpolation[1] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[1])
+        # state_types_interpolation[2] += np.interp(np.linspace(1, state_types.shape[1], 150), np.arange(1, 1 + state_types.shape[1]), state_types[2])
+        #
+        # all_first_pmfs_typeless[subject] = []
+        # for pmf in pmfs:
+        #     all_first_pmfs_typeless[subject].append((pmf[0], pmf[1][0]))
+        #     all_pmfs.append(pmf)
+        #
+        # first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs)
+        # for pmf in changing_pmfs:
+        #     if type(pmf[0]) == int:
+        #         continue
+        #     all_changing_pmf_names.append(subject)
+        #     all_changing_pmfs.append(pmf)
+        #
+        # all_first_pmfs[subject] = first_pmfs
+        # continue
+        #
+        # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
+        # consistencies /= consistencies[0, 0]
+        # contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=False)
+
+
+        # b_flips = bias_flips(states, pmfs, durs)
+        # all_bias_flips.append(b_flips)
+
+        # regression, diffs = pmf_regressions(states, pmfs, durs)
+        # regression_diffs += diffs
+        # regressions.append(regression)
+        # continue
+        # compare_pmfs(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, pmfs, title="{} convergence pmf".format(subject))
+        # compare_weights(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, title="{} convergence weights".format(subject))
+        # quit()
+        # all_intros.append(undiv_intros)
+        # all_intros_div.append(intros_by_type)
+        # if states_per_type != []:
+        #     all_states_per_type.append(states_per_type)
+        #
+        # intros_by_type_sum += intros_by_type
+        # for pmf in pmfs:
+        #     all_pmfs.append(pmf)
+        #     for p in pmf[1]:
+        #         all_pmf_diffs.append(p[-1] - p[0])
+        #         all_pmf_asymms.append(np.abs(p[0] + p[-1] - 1))
+        # contrast_intro_types.append(contrast_intro_type)
+
+        # num_states.append(np.mean(test.state_num_dist()))
+        # num_sessions.append(test.results[0].n_sessions)
+
+        # a, b = np.where(states)
+        # for i in range(states.shape[0]):
+        #     state_appear.append(b[a == i][0] / (test.results[0].n_sessions - 1))
+        #     state_dur.append(b[a == i].shape[0])
+
+        # state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False)
+
+        # session overview
+        # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
+        # consistencies /= consistencies[0, 0]
+        # c_n_a, rt_data = contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=True)
+        # compare_performance(cnas, (1, 1), title="{} contrast {} performance test".format(subject, (1, 1)))
+        # compare_performance(cnas, (0, 0.848))
+
+        # duration of different state types (and also percentage of type activities)
+        # continue
+        # simplex_durs = np.array(durs).reshape(1, 3)
+        # print(simplex_durs / np.sum(simplex_durs))
+        # from simplex_plot import projectSimplex
+        # print(projectSimplex(simplex_durs / simplex_durs.sum(1)[:, None]))
+
+        # compute state type proportions and split the pmfs accordingly
+        # ret, trans, pmf_types = state_cluster_interpolation(states, pmfs)
+        # state_trans += trans  # matrix of how often the different types transition into one another
+        # for i, r in enumerate(ret):
+        #     state_trajs[i] += np.interp(np.linspace(1, test.results[0].n_sessions, n_points), np.arange(1, 1 + test.results[0].n_sessions), r)  # for interpolated population state trajectories
+        # plot_pmf_types(pmf_types, subject=subject, fit_type=fit_type)
+        # continue
+        # plt.plot(ret.T, label=[0, 1, 2])
+        # plt.legend()
+        # plt.show()
+
+        # Analyse single sample
+        # xy, z = pickle.load(open("multi_chain_saves/xyz_{}.p".format(subject), 'rb'))
+        #
+        # single_sample = [np.argmax(z)]
+        # for position, index in enumerate(np.where(np.logical_and(z > 2.7e-7, xy[0] < -500))[0]):
+        #     print(position, index, index // test.n, index % test.n)
+        #     states, pmfs = state_development_single_sample(test, [index], save_append='_{}_{}'.format('single_sample', position), show=True, separate_pmf=True)
+        #     if position == 9:
+        #         quit()
+        # quit()
+
+        # all_state_starts = test.state_appearance_posterior(subject)
+        # pop_state_starts += all_state_starts
+        #
+        # a, b, c = test.state_start_and_dur()
+        # state_appear += a
+        # state_dur += b
+        # state_appear_dist += c
+
+
+        # all_state_starts = test.state_appearance_posterior(subject)
+        # plt.plot(all_state_starts)
+        # plt.savefig("temp")
+        # plt.close()
 
 
     # pickle.dump(all_first_pmfs, open("all_first_pmfs.p", 'wb'))
diff --git a/dynamic_GLMiHMM_consistency.py b/dynamic_GLMiHMM_consistency.py
index 8bf1924c1a08c92b50d8db433853faa7376093cb..a2aa017456ddf48031bc8af3978a6746102a464d 100644
--- a/dynamic_GLMiHMM_consistency.py
+++ b/dynamic_GLMiHMM_consistency.py
@@ -87,6 +87,7 @@ for subject in subjects:
         mega_data[:, 4] = 1
         mega_data[:, 5] = data[~bad_trials, 1] - 1
         mega_data[:, 5] = (mega_data[:, 5] + 1) / 2
+        print(mega_data.sum(0))
         posteriormodel.add_data(mega_data)
 
     import pyhsmm.util.profiling as prof
diff --git a/dynamic_GLMiHMM_fit.py b/dynamic_GLMiHMM_fit.py
index bf1d046da4a345335014298a708c541c5c6d4ec3..a16e7376dd52cb9a02256f1782c0ea1ffbd923fb 100644
--- a/dynamic_GLMiHMM_fit.py
+++ b/dynamic_GLMiHMM_fit.py
@@ -19,15 +19,6 @@ import json
 import sys
 
 
-def crp_expec(n, theta):
-    """
-    Return expected number of tables after n customers, given concentration theta.
-
-    From Wikipedia
-    """
-    return theta * (digamma(theta + n) - digamma(theta))
-
-
 def eleven2nine(x):
     """Map from 11 possible contrasts to 9, for the non-training phases.
 
@@ -87,13 +78,7 @@ subjects = ['ibl_witten_15', 'ibl_witten_17', 'ibl_witten_18', 'ibl_witten_19',
 
 # test subjects:
 subjects = ['KS014']
-# subjects = ['KS021', 'KS016', 'ibl_witten_16', 'SWC_022', 'KS003', 'CSHL054', 'ZM_3003', 'KS015', 'ibl_witten_13', 'CSHL059', 'CSH_ZAD_022', 'CSHL_007', 'CSHL062', 'NYU-06', 'KS014', 'ibl_witten_14', 'SWC_023']
-
-# subjects = [['GLM_Sim_15', 'GLM_Sim_14', 'GLM_Sim_13', 'GLM_Sim_11', 'GLM_Sim_10', 'GLM_Sim_09', 'GLM_Sim_12'][2]]
-# (0.03, 0.3, 5, 'contR', 'contL', 'prevA', 'bias', 1, 0.1):
 cv_nums = [15]
-# conda activate hdp_pg_env
-# python dynamic_GLMiHMM_fit.py
 
 cv_nums = [200 + int(sys.argv[1]) % 16]
 subjects = [subjects[int(sys.argv[1]) // 16]]
@@ -267,12 +252,6 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
             print("meh, skipped session")
             continue
 
-        # if j == 15:
-        #     import matplotlib.pyplot as plt
-        #     for i in [0, 2, 3,4,5,6,7,8,10]:
-        #         plt.plot(i, data[data[:, 0] == i, 1].mean(), 'ko')
-        #     plt.show()
-
         if params['obs_dur'] == 'glm':
             for i in range(data.shape[0]):
                 data[i, 0] = num_to_contrast[data[i, 0]]
@@ -291,10 +270,7 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
                 elif reg == 'cont':
                     mega_data[:, i] = data[mask, 0]
                 elif reg == 'prevA':
-                    # prev_ans = data[:, 1].copy()
                     new_prev_ans = data[:, 1].copy()
-                    # prev_ans[1:] = prev_ans[:-1]
-                    # prev_ans -= 1
                     new_prev_ans -= 1
                     new_prev_ans = np.convolve(np.append(0, new_prev_ans), params['exp_filter'])[:-(params['exp_filter'].shape[0])]
                     mega_data[:, i] = new_prev_ans[mask]
@@ -334,9 +310,6 @@ for loop_count_i, (s, cv_num) in enumerate(product(subjects, cv_nums)):
 
         posteriormodel.add_data(mega_data)
 
-    # for d in posteriormodel.datas:
-    #     print(d.shape)
-
     # if not os.path.isfile('./{}/data_save_{}.p'.format(data_folder, params['subject'])):
     pickle.dump(data_save, open('./{}/data_save_{}.p'.format(data_folder, params['subject']), 'wb'))
     quit()
diff --git a/index_mice.py b/index_mice.py
index 125425e2f6da0099d9282c43d3bc242ffed6ac75..96d35de5b11c000cdb713ffea432e4c8b74473e5 100644
--- a/index_mice.py
+++ b/index_mice.py
@@ -15,6 +15,8 @@ for filename in os.listdir("./dynamic_GLMiHMM_crossvals/"):
     regexp = re.compile(r'((\w|-)+)_fittype_(\w+)_var_0.03_(\d+)_(\d+)_(\d+)')
     result = regexp.search(filename)
     subject = result.group(1)
+    if subject == 'ibl_witten_26':
+        print('here')
     fit_type = result.group(3)
     seed = result.group(4)
     fit_num = result.group(5)
@@ -49,8 +51,8 @@ big = []
 non_big = []
 sim_subjects = []
 for s in prebias_subinfo.keys():
-    assert len(prebias_subinfo[s]["seeds"]) in [16, 32]
-    assert len(prebias_subinfo[s]["fit_nums"]) in [16, 32]
+    # assert len(prebias_subinfo[s]["seeds"]) in [16, 32], s + " " + str(len(prebias_subinfo[s]["seeds"]))
+    # assert len(prebias_subinfo[s]["fit_nums"]) in [16, 32]
     print(s, len(prebias_subinfo[s]["fit_nums"]))
     if len(prebias_subinfo[s]["fit_nums"]) == 32:
         big.append(s)
@@ -58,8 +60,6 @@ for s in prebias_subinfo.keys():
         non_big.append(s)
     if s.startswith("GLM_Sim_"):
         sim_subjects.append(s)
-    else:
-        non_big.append(s)
 
 print(non_big)
 print()
diff --git a/multi_chain_saves/.gitignore b/multi_chain_saves/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..b84a8f82fc35d730fb072a10736910aaca487db3
--- /dev/null
+++ b/multi_chain_saves/.gitignore
@@ -0,0 +1,23 @@
+
+old_ana_code/
+figures/
+dynamic_figures/
+iHMM_fits/
+dynamic_iHMM_fits/
+overview_figures/
+beliefs/
+session_data/
+WAIC/
+*.png
+*.p
+*.npz
+*.pdf
+*.csv
+*.zip
+*.json
+peter_fiugres/
+consistency_data/
+dynamic_GLM_figures/
+dynamic_GLMiHMM_fits2/
+glm sim mice/
+dynamic_GLMiHMM_crossvals/
diff --git a/multi_chain_saves/work.py b/multi_chain_saves/work.py
deleted file mode 100644
index 25f82f9db00bc52d70145533e1375f2c19b21859..0000000000000000000000000000000000000000
--- a/multi_chain_saves/work.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import os
-
-i = 0
-for filename in os.listdir("./"):
-    if not filename.endswith('.p'):
-        continue
-    if 'bias' in filename:
-        continue
-
-    if not filename.endswith('bias.p'):
-        i += 1
-        print("Rename {} into {}".format(filename, filename[:-2] + '_prebias.p'))
-        os.rename(filename, filename[:-2] + '_prebias.p')
-
-print(i)
diff --git a/raw_fit_processing.py b/raw_fit_processing_part1.py
similarity index 60%
rename from raw_fit_processing.py
rename to raw_fit_processing_part1.py
index 7a6c08b1abe03d177ccaf1970a233df70e69584b..82807f9ccd0a22da1f37d7f1bdd76070fa536d4f 100644
--- a/raw_fit_processing.py
+++ b/raw_fit_processing_part1.py
@@ -5,7 +5,7 @@
     Script for taking a list of subects and extracting statistics from the chains
     which can be used to assess which chains have converged to the same regions
 
-
+    This cannot be run in parallel (because the loading_info dict gets changed and dumped)
 """
 import numpy as np
 import pyhsmm
@@ -13,8 +13,9 @@ import pickle
 import json
 from dyn_glm_chain_analysis import MCMC_result
 import time
-from mcmc_chain_analysis import state_size_helper, state_num_helper, find_good_chains_unsplit_greedy
+from mcmc_chain_analysis import state_size_helper, state_num_helper, find_good_chains_unsplit_greedy, gamma_func, alpha_func
 import index_mice  # executes function for creating dict of available fits
+from dyn_glm_chain_analysis import MCMC_result_list
 
 
 fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
@@ -26,9 +27,13 @@ elif fit_type == 'prebias':
     loading_info = json.load(open("canonical_infos.json", 'r'))
     r_hats = json.load(open("canonical_info_r_hats.json", 'r'))
 
-subjects = ['ZFM-04019', 'ZFM-05236'] # list(loading_info.keys())
+# done: 'NYU-45', 'UCLA035', 'NYU-30', 'CSHL047', 'NYU-39', 'NYU-37',
+# can't: 'UCLA006'
+subjects = ['NYU-40', 'NYU-46', 'KS044', 'NYU-48']
+# 'UCLA012', 'CSHL052', 'NYU-11', 'UCLA011', 'NYU-47', 'CSHL045', 'UCLA017', 'CSHL055', 'UCLA005', 'CSHL060', 'UCLA015', 'UCLA014', 'CSHL053', 'NYU-12', 'CSHL058', 'KS042']
+
 
-thinning = 25
+thinning = 50
 
 fit_variance = 0.03
 func1 = state_num_helper(0.2)
@@ -37,13 +42,15 @@ func3 = state_size_helper()
 func4 = state_size_helper(1)
 dur = 'yes'
 
-m = len(loading_info[subjects]["fit_nums"])
-assert m == 16
 for subject in subjects:
+    results = []
+    summary_info = {"thinning": thinning, "contains": [], "seeds": [], "fit_nums": []}
+    m = len(loading_info[subject]["fit_nums"])
+    assert m == 16
     print(subject)
     n_runs = -1
     counter = -1
-    n = (loading_info[subject]['chain_num'] + 1) * 4000 // thinning
+    n = (loading_info[subject]['chain_num']) * 4000 // thinning
     chains1 = np.zeros((m, n))
     chains2 = np.zeros((m, n))
     chains3 = np.zeros((m, n))
@@ -53,11 +60,27 @@ for subject in subjects:
         print(seed)
         info_dict = pickle.load(open("./session_data/{}_info_dict.p".format(subject), "rb"))
         samples = []
-        for mini_counter in range(m):
-            file = "./dynamic_GLMiHMM_crossvals/{}_fittype_{}_var_{}_{}_{}{}.p".format(subject, fit_type, fit_variance, seed, fit_num, '_{}'.format(mini_counter))
-            lala = time.time()
-            samples += pickle.load(open(file, "rb"))
-            print("Loading {} took {:.4}".format(mini_counter, time.time() - lala))
+
+        mini_counter = 1 # start at 1, discard first 4000 as burnin
+        while True:
+            try:
+                file = "./dynamic_GLMiHMM_crossvals/{}_fittype_{}_var_{}_{}_{}{}.p".format(subject, fit_type, fit_variance, seed, fit_num, '_{}'.format(mini_counter))
+                lala = time.time()
+                samples += pickle.load(open(file, "rb"))
+                print("Loading {} took {:.4}".format(mini_counter, time.time() - lala))
+                mini_counter += 1
+            except Exception:
+                break
+
+        if n_runs == -1:
+            n_runs = mini_counter
+        else:
+            if n_runs != mini_counter:
+                print("Problem")
+                print(n_runs, mini_counter)
+                quit()
+
+
         save_id = "{}_fittype_{}_var_{}_{}_{}.p".format(subject, fit_type, fit_variance, seed, fit_num).replace('.', '_')
 
         print("loaded seed {}".format(seed))
@@ -65,7 +88,8 @@ for subject in subjects:
         result = MCMC_result(samples[::thinning],
                              infos=info_dict, data=samples[0].datas,
                              sessions=fit_type, fit_variance=fit_variance,
-                             dur=dur, save_id=save_id))
+                             dur=dur, save_id=save_id)
+        results.append(result)
         print("Making result {} took {:.4}".format(counter, time.time() - lala))
 
         res = func1(result)
@@ -82,8 +106,6 @@ for subject in subjects:
     pickle.dump(chains3, open("multi_chain_saves/{}_largest_state_0_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
     pickle.dump(chains4, open("multi_chain_saves/{}_largest_state_1_fittype_{}_var_{}_{}_{}_state_num.p".format(subject, fit_type, fit_variance, seed, fit_num), 'wb'))
 
-    r_hats = {}
-
     # R^hat tests
     # test = MCMC_result_list([fake_result(100) for i in range(8)])
     # test.r_hat_and_ess(return_ascending, False)
@@ -94,10 +116,6 @@ for subject in subjects:
         continue
     print()
     print("Checking R^hat, finding best subset of chains")
-    chains1 = chains1[:, 160:]
-    chains2 = chains2[:, 160:]
-    chains3 = chains3[:, 160:]
-    chains4 = chains4[:, 160:]
 
     # mins = find_good_chains(chains[:, :-1].reshape(32, chains.shape[1] // 2))
     sol, final_r_hat = find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_to=chains1.shape[0] // 2)
@@ -110,4 +128,27 @@ for subject in subjects:
         json.dump(r_hats, open("canonical_info_r_hats_bias.json", 'w'))
     elif fit_type == 'prebias':
         json.dump(loading_info, open("canonical_infos.json", 'w'))
-        json.dump(r_hats, open("canonical_info_r_hats.json", 'w'))
\ No newline at end of file
+        json.dump(r_hats, open("canonical_info_r_hats.json", 'w'))
+
+    if r_hats[subject] >= 1.05:
+        print("Skipping canonical result")
+        continue
+    else:
+        print("Making canonical result")
+
+    # subset data
+    summary_info['contains'] = [i for i in range(m) if i not in sol]
+    summary_info['seeds'] = [loading_info[subject]['seeds'][i] for i in summary_info['contains']]
+    summary_info['fit_nums'] = [loading_info[subject]['fit_nums'][i] for i in summary_info['contains']]
+    results = [results[i] for i in summary_info['contains']]
+
+    test = MCMC_result_list(results, summary_info)
+    pickle.dump(test, open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'wb'))
+
+    test.r_hat_and_ess(state_num_helper(0.2), False)
+    test.r_hat_and_ess(state_num_helper(0.1), False)
+    test.r_hat_and_ess(state_num_helper(0.05), False)
+    test.r_hat_and_ess(state_size_helper(), False)
+    test.r_hat_and_ess(state_size_helper(1), False)
+    test.r_hat_and_ess(gamma_func, True)
+    test.r_hat_and_ess(alpha_func, True)
diff --git a/raw_fit_processing_part2.py b/raw_fit_processing_part2.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a0835c4357c6dd189b4ae53cf06119cba4e9fc3
--- /dev/null
+++ b/raw_fit_processing_part2.py
@@ -0,0 +1,238 @@
+import json
+import numpy as np
+from dyn_glm_chain_analysis import MCMC_result_list
+import pickle
+import matplotlib.pyplot as plt
+import os
+
+
+def create_mode_indices(test, subject, fit_type):
+    dim = 3
+
+    try:
+        xy, z = pickle.load(open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'rb'))
+    except Exception:
+        return
+        print('Doing PCA')
+        ev, eig, projection_matrix, dimreduc = test.state_pca(subject, pca_type='dists', dim=dim)
+        xy = np.vstack([dimreduc[i] for i in range(dim)])
+        from scipy.stats import gaussian_kde
+        z = gaussian_kde(xy)(xy)
+        pickle.dump((xy, z), open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'wb'))
+
+    print("Mode indices of " + subject)
+
+    threshold_search(xy, z, test, 'first_', subject, fit_type)
+
+    print("Find another mode?")
+    if input() not in ['yes', 'y']:
+        return
+
+    threshold_search(xy, z, test, 'second_', subject, fit_type)
+
+    print("Find another mode?")
+    if input() not in ['yes', 'y']:
+        return
+
+    threshold_search(xy, z, test, 'third_', subject, fit_type)
+    return
+
+
+def threshold_search(xy, z, test, mode_prefix, subject, fit_type):
+    happy = False
+    conds = [0, None, None, None, None, None, None]
+    x_min, x_max, y_min, y_max, z_min, z_max = None, None, None, None, None, None
+    while not happy:
+        print()
+        print("Pick level")
+        prob_level = input()
+        if prob_level == 'cond':
+            print("x > ?")
+            resp = input()
+            if resp not in ['n', 'no']:
+                x_min = float(resp)
+            else:
+                x_min = None
+
+            print("x < ?")
+            resp = input()
+            if resp not in ['n', 'no']:
+                x_max = float(resp)
+            else:
+                x_max = None
+
+            print("y > ?")
+            resp = input()
+            if resp not in ['n', 'no']:
+                y_min = float(resp)
+            else:
+                y_min = None
+
+            print("y < ?")
+            resp = input()
+            if resp not in ['n', 'no']:
+                y_max = float(resp)
+            else:
+                y_max = None
+
+            print("z > ?")
+            resp = input()
+            if resp not in ['n', 'no']:
+                z_min = float(resp)
+            else:
+                z_min = None
+
+            print("z < ?")
+            resp = input()
+            if resp not in ['n', 'no']:
+                z_max = float(resp)
+            else:
+                z_max = None
+
+            print("Prob level")
+            prob_level = float(input())
+            conds = [prob_level, x_min, x_max, y_min, y_max, z_min, z_max]
+            print("Condtions are {}".format(conds))
+        else:
+            try:
+                prob_level = float(prob_level)
+            except:
+                print('mistake')
+                prob_level = float(input())
+            conds[0] = prob_level
+
+        print("Level is {}".format(prob_level))
+
+        mode = conditions_fulfilled(z, xy, conds)
+        print("# of samples: {}".format(mode.sum()))
+        mode_indices = np.where(mode)[0]
+        if mode.sum() > 0:
+            print(xy[0][mode_indices].min(), xy[0][mode_indices].max(), xy[1][mode_indices].min(), xy[1][mode_indices].max(), xy[2][mode_indices].min(), xy[2][mode_indices].max())
+            print("Happy?")
+            happy = 'yes' == input()
+
+    print("Subset by factor?")
+    if input() == 'yes':
+        print("Factor?")
+        print(mode_indices.shape)
+        factor = int(input())
+        mode_indices = mode_indices[::factor]
+        print(mode_indices.shape)
+    if subject not in loading_info:
+        loading_info[subject] = {}
+    loading_info[subject]['mode prob level'] = prob_level
+
+    pickle.dump(mode_indices, open("multi_chain_saves/{}mode_indices_{}_{}.p".format(mode_prefix, subject, fit_type), 'wb'))
+    # consistencies = test.consistency_rsa(indices=mode_indices)  # do this on the cluster from now on
+    # pickle.dump(consistencies, open("multi_chain_saves/{}mode_consistencies_{}_{}.p".format(mode_prefix, subject, fit_type), 'wb', protocol=4))
+
+
+def conditions_fulfilled(z, xy, conds):
+    works = z > conds[0]
+    if conds[1]:
+        works = np.logical_and(works, xy[0] > conds[1])
+    if conds[2]:
+        works = np.logical_and(works, xy[0] < conds[2])
+    if conds[3]:
+        works = np.logical_and(works, xy[1] > conds[3])
+    if conds[4]:
+        works = np.logical_and(works, xy[1] < conds[4])
+    if conds[5]:
+        works = np.logical_and(works, xy[2] > conds[5])
+    if conds[6]:
+        works = np.logical_and(works, xy[2] < conds[6])
+
+    return works
+
+
+def state_set_and_plot(test, mode_prefix, subject, fit_type):
+    mode_indices = pickle.load(open("multi_chain_saves/{}mode_indices_{}_{}.p".format(mode_prefix, subject, fit_type), 'rb'))
+    consistencies = pickle.load(open("multi_chain_saves/{}mode_consistencies_{}_{}.p".format(mode_prefix, subject, fit_type), 'rb'))
+    session_bounds = list(np.cumsum([len(s) for s in test.results[0].models[-1].stateseqs]))
+
+    import scipy.cluster.hierarchy as hc
+    consistencies /= consistencies[0, 0]
+    linkage = hc.linkage(consistencies[0, 0] - consistencies[np.triu_indices(consistencies.shape[0], k=1)], method='complete')
+
+    # R = hc.dendrogram(linkage, truncate_mode='lastp', p=150, no_labels=True)
+    # plt.savefig("peter figures/{}tree_{}_{}".format(mode_prefix, subject, 'complete'))
+    # plt.close()
+
+    session_bounds = list(np.cumsum([len(s) for s in test.results[0].models[-1].stateseqs]))
+
+    plot_criterion = 0.95
+    a = hc.fcluster(linkage, plot_criterion, criterion='distance')
+    b, c = np.unique(a, return_counts=1)
+    state_sets = []
+    for x, y in zip(b, c):
+        state_sets.append(np.where(a == x)[0])
+    print("dumping state set")
+    pickle.dump(state_sets, open("multi_chain_saves/{}state_sets_{}_{}.p".format(mode_prefix, subject, fit_type), 'wb'))
+    state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='_{}{}'.format(mode_prefix, plot_criterion), show=True, separate_pmf=True, type_coloring=True)
+
+    fig, ax = plt.subplots(ncols=5, sharey=True, gridspec_kw={'width_ratios': [10, 1, 1, 1, 1]}, figsize=(13, 8))
+    from matplotlib.pyplot import cm
+    for j, criterion in enumerate([0.95, 0.8, 0.5, 0.2]):
+        clustering_colors = np.zeros((consistencies.shape[0], 100, 4))
+        a = hc.fcluster(linkage, criterion, criterion='distance')
+        b, c = np.unique(a, return_counts=1)
+        print(b.shape)
+        print(np.sort(c))
+
+        cmap = cm.rainbow(np.linspace(0, 1, 17))
+        rank_to_color_place = dict(zip(range(17), [0, 16, 8, 4, 12, 2, 6, 10, 14, 1, 3, 5, 7, 9, 11, 13, 15]))  # handcrafted to maximise color distance, I think
+        i = -1
+        b = [x for _, x in sorted(zip(c, b))][::-1]
+        c = [x for x, _ in sorted(zip(c, b))][::-1]
+        plot_above = 50
+        while len([y for y in c if y > plot_above]) > 17:
+            plot_above += 1
+        for x, y in zip(b, c):
+            if y > plot_above:
+                i += 1
+                clustering_colors[a == x] = cmap[rank_to_color_place[i]]
+                # clustering_colors[a == x] = cm.rainbow(np.mean(np.where(a == x)[0]) / a.shape[0])
+
+        ax[j+1].imshow(clustering_colors, aspect='auto', origin='upper')
+        for sb in session_bounds:
+            ax[j+1].axhline(sb, color='k')
+        ax[j+1].set_xticks([])
+        ax[j+1].set_yticks([])
+        ax[j+1].set_title("{}%".format(int(criterion * 100)), size=20)
+
+    ax[0].imshow(consistencies, aspect='auto', origin='upper')
+    for sb in session_bounds:
+        ax[0].axhline(sb, color='k')
+    ax[0].set_xticks([])
+    ax[0].set_yticks(session_bounds[::-1])
+    ax[0].set_yticklabels(session_bounds[::-1], size=18)
+    ax[0].set_ylim(session_bounds[-1], 0)
+    ax[0].set_ylabel("Trials", size=28)
+    plt.yticks(rotation=45)
+
+    plt.tight_layout()
+    plt.savefig("peter figures/{}clustered_trials_{}_{}".format(mode_prefix, subject, 'criteria comp').replace('.', '_'))
+    plt.close()
+
+
+fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
+if fit_type == 'bias':
+    loading_info = json.load(open("canonical_infos_bias.json", 'r'))
+elif fit_type == 'prebias':
+    loading_info = json.load(open("canonical_infos.json", 'r'))
+subjects = list(loading_info.keys())
+# error: KS043, KS045,  'NYU-12', ibl_witten_15, NYU-21, CSHL052, KS003
+# done: NYU-46, NYU-39, ibl_witten_19, NYU-48
+subjects = ['DY_013', 'ZFM-01592', 'NYU-39', 'NYU-27', 'NYU-46', 'ZFM-01936', 'ZFM-02372', 'ZFM-01935', 'ibl_witten_26', 'ZM_2241', 'KS084', 'ZFM-01576']
+
+for subject in subjects:
+    test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb'))
+    if os.path.isfile("multi_chain_saves/{}mode_indices_{}_{}.p".format('first_', subject, fit_type)):
+        print("It has been done")
+        continue
+    print('Computing sub result')
+    create_mode_indices(test, subject, fit_type)
+    # state_set_and_plot(test, 'first_', subject, fit_type)
+    # print("second mode?")
+    # if input() in ['y', 'yes']:
+    #     state_set_and_plot(test, 'second_', subject, fit_type)
diff --git a/remove_duplicates.py b/remove_duplicates.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f6683b06e9db6c9f9bd9e88da97bb77f1e6b7d
--- /dev/null
+++ b/remove_duplicates.py
@@ -0,0 +1,86 @@
+import os
+import re
+
+memory = {}
+name_saves = {}
+seed_saves = {}
+
+do_it = True
+
+for filename in os.listdir("./dynamic_GLMiHMM_crossvals/"):
+    if not filename.endswith('.p'):
+        continue
+    regexp = re.compile(r'((\w|-)+)_fittype_(\w+)_var_0.03_(\d+)_(\d+)_(\d+)')
+    result = regexp.search(filename)
+    subject = result.group(1)
+    fit_type = result.group(3)
+    seed = result.group(4)
+    fit_num = result.group(5)
+    chain_num = result.group(6)
+
+    if fit_type == 'prebias':
+
+        if subject not in seed_saves:
+            seed_saves[subject] = [seed]
+        else:
+            if seed not in seed_saves[subject]:
+                seed_saves[subject].append(seed)
+
+        if (subject, seed) not in name_saves:
+            name_saves[(subject, seed)] = []
+
+        if fit_num not in name_saves[(subject, seed)]:
+            name_saves[(subject, seed)].append(fit_num)
+
+        if (subject, seed, fit_num) not in memory:
+            memory[(subject, seed, fit_num)] = {"chain_num": int(chain_num), "counter": 1}
+        else:  # if this is the first file of that chain, save some info
+            memory[(subject, seed, fit_num)]["chain_num"] = max(memory[(subject, seed, fit_num)]["chain_num"], int(chain_num))
+            memory[(subject, seed, fit_num)]["counter"] += 1
+
+total_move = 0
+nyu_11_move = 0
+dicts_removed = 0
+moved = []
+completed = []
+incompleted = []
+for key in name_saves:
+    subject = key[0]
+    seed = key[1]
+    complete = False
+    save_fit_num = -1
+    for fit_num in name_saves[key]:
+        if memory[(subject, seed, fit_num)]['chain_num'] == 14 and memory[(subject, seed, fit_num)]['counter'] == 15:
+            save_fit_num = fit_num
+            complete = True
+    if len(seed_saves[subject]) == 16:
+        if subject not in completed:
+            completed.append(subject)
+    else:
+        if subject not in incompleted:
+            incompleted.append(subject)
+    if complete and len(name_saves[key]) > 1:
+        assert save_fit_num != -1
+        for fit_num in name_saves[key]:
+            if fit_num != save_fit_num:
+                for i in range(15):
+                    if do_it:
+                        if os.path.exists("./dynamic_GLMiHMM_crossvals/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i)):
+                            os.rename("./dynamic_GLMiHMM_crossvals/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i),
+                                      "./del_test/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i))
+                            if subject not in moved:
+                                moved.append(subject)
+                            total_move += 1
+                    else:
+                        if os.path.exists("./dynamic_GLMiHMM_crossvals/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i)):
+                            print("I would move ")
+                            print("./dynamic_GLMiHMM_crossvals/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i))
+                            print(" to ")
+                            print("./del_test/{}_fittype_prebias_var_0.03_{}_{}_{}.p".format(subject, seed, fit_num, i))
+                            total_move += 1
+                            nyu_11_move += subject == "NYU-11"
+
+print(moved)
+print(completed)
+print(incompleted)
+print("Would move {} in total, and {} of NYU-11".format(total_move, nyu_11_move))
diff --git a/test_codes/bias_shift_test.py b/test_codes/bias_shift_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..597c7071b9a5eea45abeb6e7832acea3fae9879a
--- /dev/null
+++ b/test_codes/bias_shift_test.py
@@ -0,0 +1,38 @@
+"""
+    Studying how much easier it is for small changes in the regressors to affect the pmf around 0 versus further from zero
+
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+weights = list(np.zeros((17, 3)))
+
+for i, weight in enumerate(weights):
+    weight[-1] = i * 0.2 - 1.6
+
+contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
+contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
+bias = np.ones(11)
+
+predictors = np.vstack((contrasts_L, contrasts_R, bias)).T
+
+for weight in weights:
+    plt.plot(1 / (1 + np.exp(- np.sum(weight * predictors, axis=1))))
+
+plt.ylim(0, 1)
+plt.show()
+
+
+
+weights = list(np.zeros((17, 3)))
+
+for i, weight in enumerate(weights):
+    weight[-1] = i * 0.2 - 1.6
+    weight[0] = 2
+
+for weight in weights:
+    plt.plot(1 / (1 + np.exp(- np.sum(weight * predictors, axis=1))))
+
+plt.ylim(0, 1)
+plt.show()
\ No newline at end of file