diff --git a/__pycache__/mcmc_chain_analysis.cpython-34.pyc b/__pycache__/mcmc_chain_analysis.cpython-34.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1f4261481b77817f0a5299f8eaec91dab934679b
Binary files /dev/null and b/__pycache__/mcmc_chain_analysis.cpython-34.pyc differ
diff --git a/__pycache__/pmf_analysis.cpython-37.pyc b/__pycache__/pmf_analysis.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d73cb225dcfed4bac0c68eecdffceec3453a4d1
Binary files /dev/null and b/__pycache__/pmf_analysis.cpython-37.pyc differ
diff --git a/__pycache__/simplex_plot.cpython-37.pyc b/__pycache__/simplex_plot.cpython-37.pyc
index 8fa16cdde21c4a7d3d6daa3797043fc2e294e842..cc971e6effeb8c5cd96ac67921d02e35f9142f1b 100644
Binary files a/__pycache__/simplex_plot.cpython-37.pyc and b/__pycache__/simplex_plot.cpython-37.pyc differ
diff --git a/all_first_pmfs_typeless b/all_first_pmfs_typeless
new file mode 100644
index 0000000000000000000000000000000000000000..336970346beed422f9657b5d9db49519023836d9
Binary files /dev/null and b/all_first_pmfs_typeless differ
diff --git a/behaviour_overview.py b/behaviour_overview.py
index 48418b9c0dabfdb4e7b7bec914109f255a1fb815..24180ed205a00a9cb7e864c39782d62c92f8d658 100644
--- a/behaviour_overview.py
+++ b/behaviour_overview.py
@@ -2,6 +2,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 from one.api import ONE
+import pickle
 
 
 show = 1
@@ -21,27 +22,39 @@ def progression(data, contrasts, progression_variable='feedback', windowsize=6,
     plt.ylim(0, upper_bound)
     plt.gca().spines['top'].set_visible(False)
     plt.gca().spines['right'].set_visible(False)
-    # if title == "7 rt" or title == "7 fb":
-    #     plt.plot([271, 286], [1, 1], color='r', lw=3)
-    #     plt.plot([429, 439], [1, 1], color='r', lw=3)
-    #     plt.plot([632, 636], [1, 1], color='r', lw=3)
-    if title == "KS014, RT session 12 / 15" or title == "KS014, performance session 12 / 15":
-        plt.plot([599, 603], [1, 1], color='m', lw=6)
-        plt.plot([658, 662], [1, 1], color='m', lw=6)
-        plt.plot([703, 708], [1, 1], color='m', lw=6)
+    # if title == "KS014, RT session 12 / 15" or title == "KS014, performance session 12 / 15":
+    #     plt.plot([599, 603], [1, 1], color='m', lw=6)
+    #     plt.plot([658, 662], [1, 1], color='m', lw=6)
+    #     plt.plot([703, 708], [1, 1], color='m', lw=6)
 
     plt.title(title, size=22)
     plt.legend(fontsize=13)
     plt.ylabel(progression_variable, size=20)
     plt.xlabel('trial', size=20)
-    if title:
-        plt.savefig('./overview_figures/' + title.replace('/', '_') + '.png')
+    # if title:
+    #     plt.savefig('./overview_figures/' + title.replace('/', '_') + '.png')
 
     if show:
         plt.show()
     else:
         plt.close()
 
+    # if progression_variable == 'feedback':
+    #     means = data.groupby('signed_contrast').mean()['response']
+    #     stds = data.groupby('signed_contrast').std()['response']
+    #     plt.errorbar(means.index, means.values, stds.values)
+    #     plt.ylim(-0.2, 1.2)
+    #     plt.title(title, size=22)
+    #     plt.savefig("temp {}".format(title).replace('/', '_'))
+    #     plt.close()
+    # if progression_variable == 'rt':
+    #     means = data.groupby('signed_contrast').mean()['rt']
+    #     stds = data.groupby('signed_contrast').std()['rt']
+    #     plt.errorbar(means.index, means.values, stds.values)
+    #     plt.title(title, size=22)
+    #     plt.savefig("temp {}".format(title).replace('/', '_'))
+    #     plt.show()
+
 
 dataset_types = ['choice', 'contrastLeft', 'contrastRight', \
                  'feedbackType', 'probabilityLeft', 'response_times', \
@@ -49,7 +62,7 @@ dataset_types = ['choice', 'contrastLeft', 'contrastRight', \
 def get_df(eid):
     # dangerously many changes, check again
     try:
-        data_dict = one.load_object(eid, 'trials', attribute=dataset_types) # download_only=True messes with my code, returns path for whatever reason
+        data_dict = one.load_object(eid, 'trials') # download_only=True messes with my code, returns path for whatever reason
     except:
         print('lol')
         return None, None
@@ -82,25 +95,34 @@ exclude_eids = ['a66f1593-dafd-4982-9b66-f9554b6c86b5', 'ee40aece-cffd-4edb-a4b6
 #                      project='ibl_neuropixel_brainwide_01')
 # traj.reverse()
 
-for subject in ['KS014']:
-    eids, sess_info = one.search(subject=subject, date_range=['2015-01-01', '2022-01-01'], details=True)
+subject = 'KS014'
+eids, sess_info = one.search(subject=subject, date_range=['2015-01-01', '2022-01-01'], details=True)
 
-    start_times = [sess['date'] for sess in sess_info]
-    protocols = [sess['task_protocol'] for sess in sess_info]
+start_times = [sess['date'] for sess in sess_info]
+protocols = [sess['task_protocol'] for sess in sess_info]
 
-    eids = [x for _, x in sorted(zip(start_times, eids))]
+eids = [x for _, x in sorted(zip(start_times, eids))]
+protocols = [x for _, x in sorted(zip(start_times, protocols))]
 
 # for t in traj:
 counti = 0
-for i, eid in enumerate(eids):
+for i, (prot, eid) in enumerate(zip(protocols, eids)):
     print(i)
+    if not prot.startswith('_iblrig_tasks_trainingChoiceWorld'):
+        continue
     # eid = t['session']['id']
     df, _ = get_df(eid)
+
     if df is None:
         continue
 
     counti += 1
-    if counti != 12:
-        continue
-    progression(df, df['signed_contrast'].unique(), progression_variable='feedback', upper_bound=2, title="{}, performance session {} / 15".format(subject, counti))
-    progression(df, df['signed_contrast'].unique(), progression_variable='rt', upper_bound=4, title="{}, RT session {} / 15".format(subject, counti))
+
+    rt_data = np.zeros((len(df), 3))
+    rt_data[:, 0] = df['signed_contrast']
+    rt_data[:, 1] = df['rt']
+    rt_data[:, 2] = df['response']
+    pickle.dump(rt_data, open("./session_data/{} rt info {}".format(subject, counti), 'wb'))
+
+    progression(df, df['signed_contrast'].unique(), progression_variable='feedback', upper_bound=2, title="{} PMF {} / 15".format(subject, counti))
+    progression(df, df['signed_contrast'].unique(), progression_variable='rt', upper_bound=4, title="{} CMF {} / 15".format(subject, counti))
diff --git a/canonical_infos.json b/canonical_infos.json
index c2ca57ff4832606bb4d5af5dad36768fc4162753..5f8da99817e18fb2831226f46d3018b12ba8b064 100644
--- a/canonical_infos.json
+++ b/canonical_infos.json
@@ -1 +1 @@
-{"SWC_023": {"seeds": ["302", "312", "304", "300", "315", "311", "308", "305", "303", "309", "306", "313", "307", "314", "301", "310"], "fit_nums": ["994", "913", "681", "816", "972", "790", "142", "230", "696", "537", "975", "773", "918", "677", "742", "745"], "chain_num": 4}, "SWC_021": {"seeds": ["415", "403", "412", "407", "409", "408", "405", "404", "410", "414", "401", "413", "402", "400", "406", "411"], "fit_nums": ["773", "615", "107", "583", "564", "354", "142", "184", "549", "185", "924", "907", "105", "531", "9", "812"], "chain_num": 9}, "ibl_witten_15": {"seeds": ["409", "410", "401", "415", "414", "403", "411", "404", "402", "405", "400", "412", "408", "407", "406", "413"], "fit_nums": ["411", "344", "496", "600", "716", "18", "527", "467", "898", "334", "309", "326", "133", "823", "740", "253"], "chain_num": 9}, "ibl_witten_13": {"seeds": ["302", "312", "313", "306", "315", "307", "311", "314", "309", "301", "308", "300", "304", "310", "303", "305"], "fit_nums": ["897", "765", "433", "641", "967", "599", "984", "259", "853", "385", "887", "619", "434", "964", "483", "891"], "chain_num": 4}, "KS016": {"seeds": ["315", "301", "309", "313", "302", "307", "303", "308", "311", "312", "314", "306", "310", "300", "305", "304"], "fit_nums": ["99", "57", "585", "32", "501", "558", "243", "413", "59", "757", "463", "172", "524", "957", "909", "292"], "chain_num": 4}, "KS003": {"seeds": ["404", "407", "413", "403", "414", "405", "400", "401", "402", "410", "415", "408", "411", "409", "406", "412"], "fit_nums": ["846", "256", "845", "945", "293", "406", "420", "109", "690", "421", "54", "866", "784", "81", "997", "665"], "chain_num": 9}, "ibl_witten_19": {"seeds": ["315", "311", "307", "314", "308", "300", "305", "301", "313", "304", "302", "310", "306", "312", "309", "303"], "fit_nums": ["179", "951", "613", "6", "623", "382", "458", "504", "406", "554", "5", "631", "746", "817", "265", "328"], "chain_num": 4}, "SWC_022": {"seeds": ["411", "403", "414", "409", "407", "412", "410", "413", "415", "404", "405", "400", "402", "401", "408", "406"], "fit_nums": ["408", "884", "62", "962", "744", "854", "635", "70", "320", "952", "8", "67", "231", "381", "536", "962"], "chain_num": 9}, "KS022": {"seeds": ["315", "300", "314", "301", "303", "302", "306", "308", "305", "310", "313", "312", "304", "307", "311", "309"], "fit_nums": ["899", "681", "37", "957", "629", "637", "375", "980", "810", "51", "759", "664", "420", "127", "259", "555"], "chain_num": 4}, "CSH_ZAD_017": {"seeds": ["401", "409", "405", "403", "415", "404", "402", "411", "410", "414", "408", "406", "413", "412", "400", "407"], "fit_nums": ["883", "803", "637", "806", "356", "804", "662", "654", "684", "350", "947", "460", "569", "976", "103", "713"], "chain_num": 9}, "CSH_ZAD_025": {"seeds": ["303", "311", "307", "312", "313", "314", "308", "315", "305", "306", "304", "302", "309", "310", "301", "300"], "fit_nums": ["581", "148", "252", "236", "581", "838", "206", "756", "449", "288", "756", "593", "733", "633", "418", "563"], "chain_num": 4}, "ibl_witten_17": {"seeds": ["406", "415", "408", "413", "402", "405", "409", "400", "414", "401", "412", "407", "404", "410", "403", "411"], "fit_nums": ["827", "797", "496", "6", "444", "823", "384", "873", "634", "27", "811", "142", "207", "322", "756", "275"], "chain_num": 9}, "ibl_witten_18": {"seeds": ["311", "310", "303", "314", "302", "309", "305", "307", "312", "300", "308", "306", "315", "313", "304", "301"], "fit_nums": ["236", "26", "838", "762", "826", "409", "496", "944", "280", "704", "930", "419", "637", "896", "876", "297"], "chain_num": 4}, "CSHL_018": {"seeds": ["302", "310", "306", "300", "314", "307", "309", "313", "311", "308", "304", "301", "312", "303", "305", "315"], "fit_nums": ["843", "817", "920", "900", "226", "36", "472", "676", "933", "453", "116", "263", "269", "897", "568", "438"], "chain_num": 4}, "GLM_Sim_06": {"seeds": ["313", "309", "302", "303", "305", "314", "300", "315", "311", "306", "304", "310", "301", "312", "308", "307"], "fit_nums": ["9", "786", "286", "280", "72", "587", "619", "708", "360", "619", "311", "189", "60", "708", "939", "733"], "chain_num": 2}, "ZM_1897": {"seeds": ["304", "308", "305", "311", "315", "314", "307", "306", "300", "303", "313", "310", "301", "312", "302", "309"], "fit_nums": ["549", "96", "368", "509", "424", "897", "287", "426", "968", "93", "725", "513", "837", "581", "989", "374"], "chain_num": 4}, "CSHL_020": {"seeds": ["305", "309", "313", "302", "314", "310", "300", "307", "315", "306", "312", "304", "311", "301", "303", "308"], "fit_nums": ["222", "306", "243", "229", "584", "471", "894", "238", "986", "660", "494", "657", "896", "459", "100", "283"], "chain_num": 4}, "CSHL054": {"seeds": ["401", "415", "409", "410", "414", "413", "407", "405", "406", "408", "411", "400", "412", "402", "403", "404"], "fit_nums": ["901", "734", "609", "459", "574", "793", "978", "66", "954", "906", "954", "111", "292", "850", "266", "967"], "chain_num": 9}, "CSHL_014": {"seeds": ["305", "311", "309", "300", "313", "310", "307", "306", "304", "312", "308", "302", "314", "303", "301", "315"], "fit_nums": ["371", "550", "166", "24", "705", "385", "870", "884", "831", "546", "404", "722", "287", "564", "613", "783"], "chain_num": 4}, "CSHL062": {"seeds": ["307", "313", "310", "303", "306", "312", "308", "305", "311", "314", "304", "302", "300", "301", "315", "309"], "fit_nums": ["846", "371", "94", "888", "499", "229", "546", "432", "71", "989", "986", "91", "935", "314", "975", "481"], "chain_num": 4}, "CSH_ZAD_001": {"seeds": ["313", "309", "311", "312", "305", "310", "315", "300", "314", "304", "301", "302", "308", "303", "306", "307"], "fit_nums": ["468", "343", "314", "544", "38", "120", "916", "170", "305", "569", "502", "496", "452", "336", "559", "572"], "chain_num": 4}, "NYU-06": {"seeds": ["314", "309", "306", "305", "312", "303", "307", "304", "300", "302", "310", "301", "315", "308", "313", "311"], "fit_nums": ["950", "862", "782", "718", "427", "645", "827", "612", "821", "834", "595", "929", "679", "668", "648", "869"], "chain_num": 4}, "KS019": {"seeds": ["404", "401", "411", "408", "400", "403", "410", "413", "402", "407", "415", "409", "406", "414", "412", "405"], "fit_nums": ["682", "4", "264", "200", "250", "267", "737", "703", "132", "855", "922", "686", "85", "176", "54", "366"], "chain_num": 9}, "CSHL049": {"seeds": ["411", "402", "414", "408", "409", "410", "413", "407", "406", "401", "404", "405", "403", "415", "400", "412"], "fit_nums": ["104", "553", "360", "824", "749", "519", "347", "228", "863", "671", "140", "883", "701", "445", "627", "898"], "chain_num": 9}, "ibl_witten_14": {"seeds": ["310", "311", "304", "306", "300", "302", "314", "313", "303", "308", "301", "309", "305", "315", "312", "307"], "fit_nums": ["563", "120", "85", "712", "277", "871", "183", "661", "505", "598", "210", "89", "310", "638", "564", "998"], "chain_num": 4}, "KS014": {"seeds": ["301", "310", "302", "312", "313", "308", "307", "303", "305", "300", "314", "306", "311", "309", "304", "315"], "fit_nums": ["668", "32", "801", "193", "269", "296", "74", "24", "270", "916", "21", "250", "342", "451", "517", "293"], "chain_num": 4, "ignore": [9, 11, 0, 1, 14, 2, 12, 13]}, "CSHL059": {"seeds": ["306", "309", "300", "304", "314", "303", "315", "311", "313", "305", "301", "307", "302", "312", "310", "308"], "fit_nums": ["821", "963", "481", "999", "986", "45", "551", "605", "701", "201", "629", "261", "972", "407", "165", "9"], "chain_num": 4}, "GLM_Sim_13": {"seeds": ["310", "303", "308", "306", "300", "312", "301", "313", "305", "311", "315", "304", "314", "309", "307", "302"], "fit_nums": ["982", "103", "742", "524", "614", "370", "926", "456", "133", "143", "302", "80", "395", "549", "579", "944"], "chain_num": 2}, "CSHL_007": {"seeds": ["314", "303", "308", "313", "301", "300", "302", "305", "315", "306", "310", "309", "311", "304", "307", "312"], "fit_nums": ["462", "703", "345", "286", "480", "313", "986", "165", "201", "102", "322", "894", "960", "438", "330", "169"], "chain_num": 4}, "CSH_ZAD_011": {"seeds": ["314", "311", "303", "300", "305", "310", "306", "301", "302", "315", "304", "309", "308", "312", "313", "307"], "fit_nums": ["320", "385", "984", "897", "315", "120", "320", "945", "475", "403", "210", "412", "695", "564", "664", "411"], "chain_num": 4}, "KS021": {"seeds": ["309", "312", "304", "310", "303", "311", "314", "302", "305", "301", "306", "300", "308", "315", "313", "307"], "fit_nums": ["874", "943", "925", "587", "55", "136", "549", "528", "349", "211", "401", "84", "225", "545", "153", "382"], "chain_num": 4}, "GLM_Sim_15": {"seeds": ["303", "312", "305", "308", "309", "302", "301", "310", "313", "315", "311", "314", "307", "306", "304", "300"], "fit_nums": ["769", "930", "328", "847", "899", "714", "144", "518", "521", "873", "914", "359", "242", "343", "45", "364"], "chain_num": 2}, "CSHL_015": {"seeds": ["301", "302", "307", "310", "309", "311", "304", "312", "300", "308", "313", "305", "314", "315", "306", "303"], "fit_nums": ["717", "705", "357", "539", "604", "971", "669", "76", "45", "413", "510", "122", "190", "821", "368", "472"], "chain_num": 4}, "ibl_witten_16": {"seeds": ["304", "313", "309", "314", "312", "307", "305", "301", "306", "310", "300", "315", "308", "311", "303", "302"], "fit_nums": ["392", "515", "696", "270", "7", "583", "880", "674", "23", "576", "579", "695", "149", "854", "184", "875"], "chain_num": 4}, "KS015": {"seeds": ["315", "305", "309", "303", "314", "310", "311", "312", "313", "300", "307", "308", "304", "301", "302", "306"], "fit_nums": ["257", "396", "387", "435", "133", "164", "403", "8", "891", "650", "111", "557", "473", "229", "842", "196"], "chain_num": 4}, "GLM_Sim_12": {"seeds": ["304", "312", "306", "303", "310", "302", "300", "305", "308", "313", "307", "311", "315", "301", "314", "309"], "fit_nums": ["971", "550", "255", "195", "952", "486", "841", "535", "559", "37", "654", "213", "864", "506", "732", "550"], "chain_num": 2}, "GLM_Sim_11": {"seeds": ["300", "312", "310", "315", "302", "313", "314", "311", "308", "303", "309", "307", "306", "304", "301", "305"], "fit_nums": ["477", "411", "34", "893", "195", "293", "603", "5", "887", "281", "956", "73", "346", "640", "532", "688"], "chain_num": 2}, "GLM_Sim_10": {"seeds": ["301", "300", "306", "305", "307", "309", "312", "314", "311", "315", "304", "313", "303", "308", "302", "310"], "fit_nums": ["391", "97", "897", "631", "239", "652", "19", "448", "807", "35", "972", "469", "280", "562", "42", "706"], "chain_num": 2}, "CSH_ZAD_026": {"seeds": ["312", "313", "308", "310", "303", "307", "302", "305", "300", "315", "306", "301", "311", "304", "314", "309"], "fit_nums": ["699", "87", "537", "628", "797", "511", "459", "770", "969", "240", "504", "948", "295", "506", "25", "378"], "chain_num": 4}, "KS023": {"seeds": ["304", "313", "306", "309", "300", "314", "302", "310", "303", "315", "307", "308", "301", "311", "305", "312"], "fit_nums": ["698", "845", "319", "734", "908", "507", "45", "499", "175", "108", "419", "443", "116", "779", "159", "231"], "chain_num": 4}, "GLM_Sim_05": {"seeds": ["301", "315", "300", "302", "305", "304", "313", "314", "311", "309", "306", "307", "308", "310", "303", "312"], "fit_nums": ["425", "231", "701", "375", "343", "902", "623", "125", "921", "637", "393", "964", "678", "930", "796", "42"], "chain_num": 2}, "CSHL061": {"seeds": ["305", "315", "304", "303", "309", "310", "302", "300", "314", "306", "311", "313", "301", "308", "307", "312"], "fit_nums": ["396", "397", "594", "911", "308", "453", "686", "552", "103", "209", "128", "892", "345", "925", "777", "396"], "chain_num": 4}, "CSHL051": {"seeds": ["303", "310", "306", "302", "309", "305", "313", "308", "300", "314", "311", "307", "312", "304", "315", "301"], "fit_nums": ["69", "186", "49", "435", "103", "910", "705", "367", "303", "474", "596", "334", "929", "796", "616", "790"], "chain_num": 4}, "GLM_Sim_14": {"seeds": ["310", "311", "309", "313", "314", "300", "302", "304", "305", "306", "307", "312", "303", "301", "315", "308"], "fit_nums": ["616", "872", "419", "106", "940", "986", "599", "704", "218", "808", "244", "825", "448", "397", "552", "316"], "chain_num": 2}, "GLM_Sim_11_trick": {"seeds": ["411", "400", "408", "409", "415", "413", "410", "412", "406", "414", "403", "404", "401", "405", "407", "402"], "fit_nums": ["95", "508", "886", "384", "822", "969", "525", "382", "489", "436", "344", "537", "251", "223", "458", "401"], "chain_num": 2, "ignore": [10, 12, 4, 1, 0, 3, 2, 6]}, "GLM_Sim_16": {"seeds": ["302", "311", "303", "307", "313", "308", "309", "300", "305", "315", "304", "310", "312", "301", "314", "306"], "fit_nums": ["914", "377", "173", "583", "870", "456", "611", "697", "13", "713", "159", "248", "617", "37", "770", "780"], "chain_num": 2}, "ZM_3003": {"seeds": ["300", "304", "307", "312", "305", "310", "311", "314", "303", "308", "313", "301", "315", "309", "306", "302"], "fit_nums": ["603", "620", "657", "735", "357", "390", "119", "33", "62", "617", "209", "810", "688", "21", "744", "426"], "chain_num": 4}, "CSH_ZAD_022": {"seeds": ["305", "310", "311", "315", "303", "312", "314", "313", "307", "302", "300", "304", "301", "308", "306", "309"], "fit_nums": ["143", "946", "596", "203", "576", "403", "900", "65", "478", "325", "282", "513", "460", "42", "161", "970"], "chain_num": 4}, "GLM_Sim_07": {"seeds": ["300", "309", "302", "304", "305", "312", "301", "311", "315", "314", "308", "307", "303", "310", "306", "313"], "fit_nums": ["724", "701", "118", "230", "648", "426", "689", "114", "832", "731", "592", "519", "559", "938", "672", "144"], "chain_num": 1}, "KS017": {"seeds": ["311", "310", "306", "309", "303", "302", "308", "300", "313", "301", "314", "307", "315", "304", "312", "305"], "fit_nums": ["97", "281", "808", "443", "352", "890", "703", "468", "780", "708", "674", "27", "345", "23", "939", "457"], "chain_num": 4}, "GLM_Sim_11_sub": {"seeds": ["410", "414", "413", "404", "409", "415", "406", "408", "402", "411", "400", "405", "403", "407", "412", "401"], "fit_nums": ["830", "577", "701", "468", "929", "374", "954", "749", "937", "488", "873", "416", "612", "792", "461", "488"], "chain_num": 2}}
\ No newline at end of file
+{"SWC_023": {"seeds": ["302", "312", "304", "300", "315", "311", "308", "305", "303", "309", "306", "313", "307", "314", "301", "310"], "fit_nums": ["994", "913", "681", "816", "972", "790", "142", "230", "696", "537", "975", "773", "918", "677", "742", "745"], "chain_num": 4, "ignore": [12, 1, 15, 14, 8, 6, 4, 10]}, "SWC_021": {"seeds": ["415", "403", "412", "407", "409", "408", "405", "404", "410", "414", "401", "413", "402", "400", "406", "411"], "fit_nums": ["773", "615", "107", "583", "564", "354", "142", "184", "549", "185", "924", "907", "105", "531", "9", "812"], "chain_num": 9, "ignore": [14, 12, 0, 10, 9, 4, 5, 1]}, "ibl_witten_15": {"seeds": ["409", "410", "401", "415", "414", "403", "411", "404", "402", "405", "400", "412", "408", "407", "406", "413"], "fit_nums": ["411", "344", "496", "600", "716", "18", "527", "467", "898", "334", "309", "326", "133", "823", "740", "253"], "chain_num": 9, "ignore": [14, 13, 8, 4, 5, 12, 11, 9]}, "ibl_witten_13": {"seeds": ["302", "312", "313", "306", "315", "307", "311", "314", "309", "301", "308", "300", "304", "310", "303", "305"], "fit_nums": ["897", "765", "433", "641", "967", "599", "984", "259", "853", "385", "887", "619", "434", "964", "483", "891"], "chain_num": 4, "ignore": [3, 5, 15, 0, 2, 12, 11, 10]}, "KS016": {"seeds": ["315", "301", "309", "313", "302", "307", "303", "308", "311", "312", "314", "306", "310", "300", "305", "304"], "fit_nums": ["99", "57", "585", "32", "501", "558", "243", "413", "59", "757", "463", "172", "524", "957", "909", "292"], "chain_num": 4, "ignore": [0, 2, 14, 12, 1, 7, 11, 6]}, "KS003": {"seeds": ["404", "407", "413", "403", "414", "405", "400", "401", "402", "410", "415", "408", "411", "409", "406", "412"], "fit_nums": ["846", "256", "845", "945", "293", "406", "420", "109", "690", "421", "54", "866", "784", "81", "997", "665"], "chain_num": 9, "ignore": [8, 15, 0, 13, 7, 12, 11, 1]}, "ibl_witten_19": {"seeds": ["315", "311", "307", "314", "308", "300", "305", "301", "313", "304", "302", "310", "306", "312", "309", "303"], "fit_nums": ["179", "951", "613", "6", "623", "382", "458", "504", "406", "554", "5", "631", "746", "817", "265", "328"], "chain_num": 4, "ignore": [13, 4, 10, 9, 2, 1, 3, 6]}, "SWC_022": {"seeds": ["411", "403", "414", "409", "407", "412", "410", "413", "415", "404", "405", "400", "402", "401", "408", "406"], "fit_nums": ["408", "884", "62", "962", "744", "854", "635", "70", "320", "952", "8", "67", "231", "381", "536", "962"], "chain_num": 9, "ignore": [4, 8, 7, 9, 1, 2, 10, 6]}, "KS022": {"seeds": ["315", "300", "314", "301", "303", "302", "306", "308", "305", "310", "313", "312", "304", "307", "311", "309"], "fit_nums": ["899", "681", "37", "957", "629", "637", "375", "980", "810", "51", "759", "664", "420", "127", "259", "555"], "chain_num": 4, "ignore": [10, 1, 0, 13, 5, 9, 12, 3]}, "CSH_ZAD_017": {"seeds": ["401", "409", "405", "403", "415", "404", "402", "411", "410", "414", "408", "406", "413", "412", "400", "407"], "fit_nums": ["883", "803", "637", "806", "356", "804", "662", "654", "684", "350", "947", "460", "569", "976", "103", "713"], "chain_num": 9, "ignore": [3, 4, 6, 7, 5, 0, 15, 12]}, "CSH_ZAD_025": {"seeds": ["303", "311", "307", "312", "313", "314", "308", "315", "305", "306", "304", "302", "309", "310", "301", "300"], "fit_nums": ["581", "148", "252", "236", "581", "838", "206", "756", "449", "288", "756", "593", "733", "633", "418", "563"], "chain_num": 4, "ignore": [8, 10, 13, 5, 12, 9, 7, 1]}, "ibl_witten_17": {"seeds": ["406", "415", "408", "413", "402", "405", "409", "400", "414", "401", "412", "407", "404", "410", "403", "411"], "fit_nums": ["827", "797", "496", "6", "444", "823", "384", "873", "634", "27", "811", "142", "207", "322", "756", "275"], "chain_num": 9, "ignore": [9, 0, 1, 7, 11, 3, 10, 8]}, "ibl_witten_18": {"seeds": ["311", "310", "303", "314", "302", "309", "305", "307", "312", "300", "308", "306", "315", "313", "304", "301"], "fit_nums": ["236", "26", "838", "762", "826", "409", "496", "944", "280", "704", "930", "419", "637", "896", "876", "297"], "chain_num": 4, "ignore": [11, 0, 4, 2, 12, 13, 8, 3]}, "CSHL_018": {"seeds": ["302", "310", "306", "300", "314", "307", "309", "313", "311", "308", "304", "301", "312", "303", "305", "315"], "fit_nums": ["843", "817", "920", "900", "226", "36", "472", "676", "933", "453", "116", "263", "269", "897", "568", "438"], "chain_num": 4, "ignore": [15, 4, 8, 0, 5, 10, 12, 11]}, "GLM_Sim_06": {"seeds": ["313", "309", "302", "303", "305", "314", "300", "315", "311", "306", "304", "310", "301", "312", "308", "307"], "fit_nums": ["9", "786", "286", "280", "72", "587", "619", "708", "360", "619", "311", "189", "60", "708", "939", "733"], "chain_num": 2, "ignore": [15, 9, 8, 14, 1, 12, 10, 3]}, "ZM_1897": {"seeds": ["304", "308", "305", "311", "315", "314", "307", "306", "300", "303", "313", "310", "301", "312", "302", "309"], "fit_nums": ["549", "96", "368", "509", "424", "897", "287", "426", "968", "93", "725", "513", "837", "581", "989", "374"], "chain_num": 4, "ignore": [0, 14, 5, 8, 7, 11, 13, 10]}, "CSHL_020": {"seeds": ["305", "309", "313", "302", "314", "310", "300", "307", "315", "306", "312", "304", "311", "301", "303", "308"], "fit_nums": ["222", "306", "243", "229", "584", "471", "894", "238", "986", "660", "494", "657", "896", "459", "100", "283"], "chain_num": 4, "ignore": [6, 5, 9, 15, 0, 8, 4, 13]}, "CSHL054": {"seeds": ["401", "415", "409", "410", "414", "413", "407", "405", "406", "408", "411", "400", "412", "402", "403", "404"], "fit_nums": ["901", "734", "609", "459", "574", "793", "978", "66", "954", "906", "954", "111", "292", "850", "266", "967"], "chain_num": 9, "ignore": [5, 12, 7, 10, 11, 2, 6, 4]}, "CSHL_014": {"seeds": ["305", "311", "309", "300", "313", "310", "307", "306", "304", "312", "308", "302", "314", "303", "301", "315"], "fit_nums": ["371", "550", "166", "24", "705", "385", "870", "884", "831", "546", "404", "722", "287", "564", "613", "783"], "chain_num": 4, "ignore": [15, 0, 3, 4, 7, 6, 1, 11]}, "CSHL062": {"seeds": ["307", "313", "310", "303", "306", "312", "308", "305", "311", "314", "304", "302", "300", "301", "315", "309"], "fit_nums": ["846", "371", "94", "888", "499", "229", "546", "432", "71", "989", "986", "91", "935", "314", "975", "481"], "chain_num": 4, "ignore": [14, 6, 3, 11, 15, 13, 4, 12]}, "CSH_ZAD_001": {"seeds": ["313", "309", "311", "312", "305", "310", "315", "300", "314", "304", "301", "302", "308", "303", "306", "307"], "fit_nums": ["468", "343", "314", "544", "38", "120", "916", "170", "305", "569", "502", "496", "452", "336", "559", "572"], "chain_num": 4, "ignore": [12, 8, 5, 1, 9, 3, 13, 15]}, "NYU-06": {"seeds": ["314", "309", "306", "305", "312", "303", "307", "304", "300", "302", "310", "301", "315", "308", "313", "311"], "fit_nums": ["950", "862", "782", "718", "427", "645", "827", "612", "821", "834", "595", "929", "679", "668", "648", "869"], "chain_num": 4, "ignore": [8, 2, 7, 12, 3, 4, 13, 11]}, "KS019": {"seeds": ["404", "401", "411", "408", "400", "403", "410", "413", "402", "407", "415", "409", "406", "414", "412", "405"], "fit_nums": ["682", "4", "264", "200", "250", "267", "737", "703", "132", "855", "922", "686", "85", "176", "54", "366"], "chain_num": 9, "ignore": [12, 14, 1, 2, 4, 7, 10, 15]}, "CSHL049": {"seeds": ["411", "402", "414", "408", "409", "410", "413", "407", "406", "401", "404", "405", "403", "415", "400", "412"], "fit_nums": ["104", "553", "360", "824", "749", "519", "347", "228", "863", "671", "140", "883", "701", "445", "627", "898"], "chain_num": 9, "ignore": [10, 11, 6, 7, 12, 13, 1, 8]}, "ibl_witten_14": {"seeds": ["310", "311", "304", "306", "300", "302", "314", "313", "303", "308", "301", "309", "305", "315", "312", "307"], "fit_nums": ["563", "120", "85", "712", "277", "871", "183", "661", "505", "598", "210", "89", "310", "638", "564", "998"], "chain_num": 4, "ignore": [11, 14, 6, 13, 5, 12, 15, 8]}, "KS014": {"seeds": ["301", "310", "302", "312", "313", "308", "307", "303", "305", "300", "314", "306", "311", "309", "304", "315"], "fit_nums": ["668", "32", "801", "193", "269", "296", "74", "24", "270", "916", "21", "250", "342", "451", "517", "293"], "chain_num": 4, "ignore": [9, 11, 0, 1, 14, 2, 12, 13]}, "CSHL059": {"seeds": ["306", "309", "300", "304", "314", "303", "315", "311", "313", "305", "301", "307", "302", "312", "310", "308"], "fit_nums": ["821", "963", "481", "999", "986", "45", "551", "605", "701", "201", "629", "261", "972", "407", "165", "9"], "chain_num": 4, "ignore": [9, 3, 5, 15, 6, 10, 2, 1]}, "GLM_Sim_13": {"seeds": ["310", "303", "308", "306", "300", "312", "301", "313", "305", "311", "315", "304", "314", "309", "307", "302"], "fit_nums": ["982", "103", "742", "524", "614", "370", "926", "456", "133", "143", "302", "80", "395", "549", "579", "944"], "chain_num": 2, "ignore": [12, 4, 11, 6, 7, 14, 0, 1]}, "CSHL_007": {"seeds": ["314", "303", "308", "313", "301", "300", "302", "305", "315", "306", "310", "309", "311", "304", "307", "312"], "fit_nums": ["462", "703", "345", "286", "480", "313", "986", "165", "201", "102", "322", "894", "960", "438", "330", "169"], "chain_num": 4, "ignore": [3, 12, 4, 5, 2, 0, 13, 1]}, "CSH_ZAD_011": {"seeds": ["314", "311", "303", "300", "305", "310", "306", "301", "302", "315", "304", "309", "308", "312", "313", "307"], "fit_nums": ["320", "385", "984", "897", "315", "120", "320", "945", "475", "403", "210", "412", "695", "564", "664", "411"], "chain_num": 4, "ignore": [0, 2, 14, 11, 7, 10, 13, 9]}, "KS021": {"seeds": ["309", "312", "304", "310", "303", "311", "314", "302", "305", "301", "306", "300", "308", "315", "313", "307"], "fit_nums": ["874", "943", "925", "587", "55", "136", "549", "528", "349", "211", "401", "84", "225", "545", "153", "382"], "chain_num": 4, "ignore": [11, 12, 0, 8, 2, 14, 5, 1]}, "GLM_Sim_15": {"seeds": ["303", "312", "305", "308", "309", "302", "301", "310", "313", "315", "311", "314", "307", "306", "304", "300"], "fit_nums": ["769", "930", "328", "847", "899", "714", "144", "518", "521", "873", "914", "359", "242", "343", "45", "364"], "chain_num": 2, "ignore": [8, 1, 0, 3, 2, 5, 10, 4]}, "CSHL_015": {"seeds": ["301", "302", "307", "310", "309", "311", "304", "312", "300", "308", "313", "305", "314", "315", "306", "303"], "fit_nums": ["717", "705", "357", "539", "604", "971", "669", "76", "45", "413", "510", "122", "190", "821", "368", "472"], "chain_num": 4, "ignore": [7, 6, 10, 2, 15, 13, 1, 3]}, "ibl_witten_16": {"seeds": ["304", "313", "309", "314", "312", "307", "305", "301", "306", "310", "300", "315", "308", "311", "303", "302"], "fit_nums": ["392", "515", "696", "270", "7", "583", "880", "674", "23", "576", "579", "695", "149", "854", "184", "875"], "chain_num": 4, "ignore": [3, 12, 2, 6, 10, 14, 4, 1]}, "KS015": {"seeds": ["315", "305", "309", "303", "314", "310", "311", "312", "313", "300", "307", "308", "304", "301", "302", "306"], "fit_nums": ["257", "396", "387", "435", "133", "164", "403", "8", "891", "650", "111", "557", "473", "229", "842", "196"], "chain_num": 4, "ignore": [7, 8, 0, 10, 2, 3, 12, 9]}, "GLM_Sim_12": {"seeds": ["304", "312", "306", "303", "310", "302", "300", "305", "308", "313", "307", "311", "315", "301", "314", "309"], "fit_nums": ["971", "550", "255", "195", "952", "486", "841", "535", "559", "37", "654", "213", "864", "506", "732", "550"], "chain_num": 2, "ignore": [0, 7, 15, 14, 3, 10, 11, 13]}, "GLM_Sim_11": {"seeds": ["300", "312", "310", "315", "302", "313", "314", "311", "308", "303", "309", "307", "306", "304", "301", "305"], "fit_nums": ["477", "411", "34", "893", "195", "293", "603", "5", "887", "281", "956", "73", "346", "640", "532", "688"], "chain_num": 2}, "GLM_Sim_10": {"seeds": ["301", "300", "306", "305", "307", "309", "312", "314", "311", "315", "304", "313", "303", "308", "302", "310"], "fit_nums": ["391", "97", "897", "631", "239", "652", "19", "448", "807", "35", "972", "469", "280", "562", "42", "706"], "chain_num": 2, "ignore": [1, 9, 15, 3, 13, 12, 7, 11]}, "CSH_ZAD_026": {"seeds": ["312", "313", "308", "310", "303", "307", "302", "305", "300", "315", "306", "301", "311", "304", "314", "309"], "fit_nums": ["699", "87", "537", "628", "797", "511", "459", "770", "969", "240", "504", "948", "295", "506", "25", "378"], "chain_num": 4, "ignore": [12, 13, 4, 11, 8, 3, 15, 0]}, "KS023": {"seeds": ["304", "313", "306", "309", "300", "314", "302", "310", "303", "315", "307", "308", "301", "311", "305", "312"], "fit_nums": ["698", "845", "319", "734", "908", "507", "45", "499", "175", "108", "419", "443", "116", "779", "159", "231"], "chain_num": 4, "ignore": [8, 10, 1, 13, 4, 15, 14, 7]}, "GLM_Sim_05": {"seeds": ["301", "315", "300", "302", "305", "304", "313", "314", "311", "309", "306", "307", "308", "310", "303", "312"], "fit_nums": ["425", "231", "701", "375", "343", "902", "623", "125", "921", "637", "393", "964", "678", "930", "796", "42"], "chain_num": 2, "ignore": [11, 2, 5, 1, 4, 9, 15, 12]}, "CSHL061": {"seeds": ["305", "315", "304", "303", "309", "310", "302", "300", "314", "306", "311", "313", "301", "308", "307", "312"], "fit_nums": ["396", "397", "594", "911", "308", "453", "686", "552", "103", "209", "128", "892", "345", "925", "777", "396"], "chain_num": 4, "ignore": [11, 13, 7, 15, 14, 3, 0, 4]}, "CSHL051": {"seeds": ["303", "310", "306", "302", "309", "305", "313", "308", "300", "314", "311", "307", "312", "304", "315", "301"], "fit_nums": ["69", "186", "49", "435", "103", "910", "705", "367", "303", "474", "596", "334", "929", "796", "616", "790"], "chain_num": 4, "ignore": [15, 12, 8, 13, 0, 2, 4, 5]}, "GLM_Sim_14": {"seeds": ["310", "311", "309", "313", "314", "300", "302", "304", "305", "306", "307", "312", "303", "301", "315", "308"], "fit_nums": ["616", "872", "419", "106", "940", "986", "599", "704", "218", "808", "244", "825", "448", "397", "552", "316"], "chain_num": 2, "ignore": [7, 11, 2, 15, 0, 13, 5, 10]}, "GLM_Sim_11_trick": {"seeds": ["411", "400", "408", "409", "415", "413", "410", "412", "406", "414", "403", "404", "401", "405", "407", "402"], "fit_nums": ["95", "508", "886", "384", "822", "969", "525", "382", "489", "436", "344", "537", "251", "223", "458", "401"], "chain_num": 2, "ignore": [10, 12, 4, 1, 0, 3, 2, 6]}, "GLM_Sim_16": {"seeds": ["302", "311", "303", "307", "313", "308", "309", "300", "305", "315", "304", "310", "312", "301", "314", "306"], "fit_nums": ["914", "377", "173", "583", "870", "456", "611", "697", "13", "713", "159", "248", "617", "37", "770", "780"], "chain_num": 2, "ignore": [4, 10, 5, 0, 13, 8, 6, 7]}, "ZM_3003": {"seeds": ["300", "304", "307", "312", "305", "310", "311", "314", "303", "308", "313", "301", "315", "309", "306", "302"], "fit_nums": ["603", "620", "657", "735", "357", "390", "119", "33", "62", "617", "209", "810", "688", "21", "744", "426"], "chain_num": 4, "ignore": [14, 7, 12, 1, 3, 4, 11, 8]}, "CSH_ZAD_022": {"seeds": ["305", "310", "311", "315", "303", "312", "314", "313", "307", "302", "300", "304", "301", "308", "306", "309"], "fit_nums": ["143", "946", "596", "203", "576", "403", "900", "65", "478", "325", "282", "513", "460", "42", "161", "970"], "chain_num": 4, "ignore": [9, 12, 4, 8, 3, 7, 0, 1]}, "GLM_Sim_07": {"seeds": ["300", "309", "302", "304", "305", "312", "301", "311", "315", "314", "308", "307", "303", "310", "306", "313"], "fit_nums": ["724", "701", "118", "230", "648", "426", "689", "114", "832", "731", "592", "519", "559", "938", "672", "144"], "chain_num": 1}, "KS017": {"seeds": ["311", "310", "306", "309", "303", "302", "308", "300", "313", "301", "314", "307", "315", "304", "312", "305"], "fit_nums": ["97", "281", "808", "443", "352", "890", "703", "468", "780", "708", "674", "27", "345", "23", "939", "457"], "chain_num": 4, "ignore": [0, 13, 8, 1, 12, 5, 10, 9]}, "GLM_Sim_11_sub": {"seeds": ["410", "414", "413", "404", "409", "415", "406", "408", "402", "411", "400", "405", "403", "407", "412", "401"], "fit_nums": ["830", "577", "701", "468", "929", "374", "954", "749", "937", "488", "873", "416", "612", "792", "461", "488"], "chain_num": 2}}
\ No newline at end of file
diff --git a/dyn_glm_chain_analysis.py b/dyn_glm_chain_analysis.py
index 04907a5a632d01e90a49247b007a1febd4f1fe6e..c44c7897447e7df141ad2fd4e516f65052f99d28 100644
--- a/dyn_glm_chain_analysis.py
+++ b/dyn_glm_chain_analysis.py
@@ -11,7 +11,7 @@ import pyhsmm
 import pickle
 import seaborn as sns
 import sys
-from scipy.stats import zscore
+from scipy.stats import zscore, norm
 from scipy.optimize import minimize
 from itertools import combinations, product
 import matplotlib.gridspec as gridspec
@@ -20,6 +20,8 @@ import json
 import time
 import multiprocessing as mp
 from mcmc_chain_analysis import state_size_helper, state_num_helper, gamma_func, alpha_func, ll_func, r_hat_array_comp, rank_inv_normal_transform, eval_r_hat, eval_simple_r_hat
+import pandas as pd
+from pmf_analysis import pmf_type, type2color
 
 
 colors = np.genfromtxt('colors.csv', delimiter=',')
@@ -38,12 +40,20 @@ bias_cont_ticks = (np.arange(9), [-1, -.25, -.125, -.062, 0, .062, .125, .25, 1]
 contrasts_L = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])
 contrasts_R = np.array([1., 0.987, 0.848, 0.555, 0.302, 0, 0, 0, 0, 0, 0])[::-1]
 
-type_colours = ['g', 'b', 'r']
-
 
 def weights_to_pmf(weights, with_bias=1):
-    psi = weights[0] * contrasts_L + weights[1] * contrasts_R + with_bias * weights[-1]
-    return 1 / (1 + np.exp(-psi))
+    psi = weights[0] * contrasts_R + weights[1] * contrasts_L + with_bias * weights[-1]
+    return 1 / (1 + np.exp(psi))  # we somehow got the answers twisted, so we drop the minus here to get the opposite response probability for plotting
+
+performance_points = np.array([-1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0])
+def pmf_to_perf(pmf, def_points):
+    # determine performance of a pmf
+    # we use this to determine regressions in the behaviour of animals
+    # therefore, we exclude 0 as performance on it is 0.5 regardless of PMF, but it might
+    # overall lower performance. The removal of 0.5 later might also be a problem, let's see
+    relevant_points = def_points
+    relevant_points[5] = False
+    return np.mean(np.abs(performance_points[relevant_points] + pmf[relevant_points]))
 
 
 class MCMC_result_list:
@@ -477,9 +487,11 @@ def return_ascending_shuffled():
     return temp
 
 
-def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figure', save_append='', consistencies=None):
+def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figure', save_append='', consistencies=None, CMF=False):
     n = test.results[0].n_sessions
     trial_counter = 0
+    cnas = [] # contrasts aNd actions
+
     for seq_num in range(n):
         # if seq_num + 1 != 12:
         #     trial_counter += len(test.results[0].models[0].stateseqs[seq_num])
@@ -518,9 +530,13 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur
 
             plt.plot(active_trials, color=cmap(0.2 + 0.8 * seq_num / test.results[0].n_sessions), lw=4, label=label, alpha=0.7)
 
+        trial_counter += len(test.results[0].models[0].stateseqs[seq_num])
+
         ms = 6
         noise = np.zeros(len(c_n_a))# np.random.rand(len(c_n_a)) * 0.4 - 0.2
 
+        cnas.append(c_n_a)
+
         mask = c_n_a[:, -1] == 0
         plt.plot(np.where(mask)[0], 0.5 + 0.25 * (noise[mask] - c_n_a[mask, 0] + c_n_a[mask, 1]), 'o', c='b', ms=ms, label='Leftward', alpha=0.6)
 
@@ -557,8 +573,51 @@ def contrasts_plot(test, state_sets, subject, save=False, show=False, dpi='figur
             plt.show()
         else:
             plt.close()
-        trial_counter += len(test.results[0].models[0].stateseqs[seq_num])
 
+        if CMF and not test.results[0].name.startswith('Sim_'):
+            rt_data = pickle.load(open("./session_data/{} rt info {}".format(subject, seq_num + 1), 'rb'))
+            rt_data = rt_data[1:]
+            assert c_n_a.shape[0] == rt_data.shape[0]
+            df = pd.DataFrame(rt_data, columns = ['Signed contrast', 'RT', 'Responses'])
+            means = df.groupby('Signed contrast').mean()['RT']
+            stds = df.groupby('Signed contrast').sem()['RT']
+            plt.errorbar(means.index, means.values, stds.values)
+            plt.title(seq_num + 1, size=22)
+            plt.show()
+
+    return cnas
+
+def pmf_regressions(states_by_session, pmfs, durs):
+    # find out whether pmfs regressed
+    state_perfs = {}
+    state_counter = {}
+    current_best_state = -1
+    counter = 0
+    types = [0, 0, 0]
+
+    for sess in range(states_by_session.shape[1]):
+        states = np.where(states_by_session[:, sess])[0]
+
+        for s in states:
+            if s not in state_counter:
+                state_counter[s] = -1
+            state_counter[s] += 1
+            state_perfs[s] = pmf_to_perf(pmfs[s][1][state_counter[s]], pmfs[s][0])
+            if current_best_state == -1 or state_perfs[current_best_state] < state_perfs[s]:
+                current_best_state = s
+
+        if state_perfs[np.argmax(states_by_session[:, sess])] + 0.05 < state_perfs[current_best_state] and state_counter[np.argmax(states_by_session[:, sess])] > 1:
+            counter += 1
+            if sess < durs[0]:
+                types[0] += 1
+            elif sess < durs[0] + durs[1]:
+                types[1] += 1
+            else:
+                types[2] += 1
+            a, b = state_perfs[np.argmax(states_by_session[:, sess])], state_perfs[current_best_state]
+            print("Regression in session {} during {:.2f}% of session ({:.2f} instead of {:.2f})".format(sess + 1, np.max(states_by_session[:, sess]) * 100, a, b))
+
+    return [counter, states_by_session.shape[1], *types]
 
 def control_flow(test, indices, trials, func_init, first_for, second_for, end_first_for):
     # Generalised control flow for iterating over samples of mode across individually across sessions
@@ -602,6 +661,25 @@ def state_pmfs(test, trials, indices):
     return results['session_js'], results['pmfs']
 
 
+def state_weights(test, trials, indices):
+    def func_init(): return {'weightss': [], 'session_js': []}
+
+    def first_for(test, results):
+        results['weights'] = np.zeros(test.results[0].models[0].obs_distns[0].weights.shape[1])
+
+    def second_for(m, j, session_trials, trial_counter, results):
+        states, counts = np.unique(m.stateseqs[j][session_trials - trial_counter], return_counts=True)
+        for sub_state, c in zip(states, counts):
+            results['weights'] += m.obs_distns[sub_state].weights[j] * c / session_trials.shape[0]
+
+    def end_first_for(results, indices, j, **kwargs):
+        results['weightss'].append(results['weights'] / len(indices))
+        results['session_js'].append(j)
+
+    results = control_flow(test, indices, trials, func_init, first_for, second_for, end_first_for)
+    return results['session_js'], results['weightss']
+
+
 def lapse_sides(test, state_sets, indices):
     """Compute and plot a lapse differential across sessions.
 
@@ -875,7 +953,8 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
                     trial_counter += len(state_seq)
                 counter += 1
         pmfs_to_score.append(np.mean(pmfs))
-    state_mapping = dict(zip(range(len(state_sets)), np.argsort(np.argsort(pmfs_to_score))))  # double argsort for ranks
+    # test.state_mapping = dict(zip(range(len(state_sets)), np.argsort(np.argsort(pmfs_to_score))))  # double argsort for ranks
+    test.state_mapping = dict(zip(range(len(state_sets)), np.flip(np.argsort((states_by_session != 0).argmax(axis=1)))))
 
     for state, trials in enumerate(state_sets):
         cmap = matplotlib.cm.get_cmap(cmaps[state]) if state < len(cmaps) else matplotlib.cm.get_cmap('Greys')
@@ -914,8 +993,8 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
             temp = np.sum(pmfs[:, defined_points]) / (np.sum(defined_points))
             state_color = colors[int(temp * 101 - 1)]
 
-            ax1.fill_between(range(1, 1 + test.results[0].n_sessions), state_mapping[state] - 0.5,
-                             state_mapping[state] + states_by_session[state] - 0.5, color=state_color)
+            ax1.fill_between(range(1, 1 + test.results[0].n_sessions), test.state_mapping[state] - 0.5,
+                             test.state_mapping[state] + states_by_session[state] - 0.5, color=state_color)
 
         else:
             n_points = 150
@@ -924,11 +1003,11 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
 
             for k in range(n_points-1):
                 ax1.fill_between([points[k], points[k+1]],
-                                 state_mapping[state] - 0.5, [state_mapping[state] + interpolation[k] - 0.5, state_mapping[state] + interpolation[k+1] - 0.5], color=cmap(0.3 + 0.7 * k / n_points))
-        ax1.annotate(state_mapping[state] + 1, (test.results[0].n_sessions + 0.1, state_mapping[state] - 0.15), fontsize=22, annotation_clip=False)
+                                 test.state_mapping[state] - 0.5, [test.state_mapping[state] + interpolation[k] - 0.5, test.state_mapping[state] + interpolation[k+1] - 0.5], color=cmap(0.3 + 0.7 * k / n_points))
+        ax1.annotate(test.state_mapping[state] + 1, (test.results[0].n_sessions + 0.1, test.state_mapping[state] - 0.15), fontsize=22, annotation_clip=False)
 
         if test.results[0].name.startswith('GLM_Sim_'):
-            ax1.plot(range(1, 1 + test.results[0].n_sessions), truth['state_map'][state_mapping[state]] + truth['state_posterior'][:, state] - 0.5, color='r')
+            ax1.plot(range(1, 1 + test.results[0].n_sessions), truth['state_map'][test.state_mapping[state]] + truth['state_posterior'][:, state] - 0.5, color='r')
 
         alpha_level = 0.3
         ax2.axvline(0.5, c='grey', alpha=alpha_level, zorder=4)
@@ -943,23 +1022,23 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
         #         defined_points[[0, 1, -2, -1]] = True
         if separate_pmf:
             for j, pmf in zip(session_js, pmfs):
-                ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + state_mapping[state], color=cmap(0.2 + 0.8 * j / test.results[0].n_sessions))
-                ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + state_mapping[state], ls='', ms=7, marker='*', color=cmap(j / test.results[0].n_sessions))
+                ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + test.state_mapping[state], color=cmap(0.2 + 0.8 * j / test.results[0].n_sessions))
+                ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmf[defined_points] - 0.5 + test.state_mapping[state], ls='', ms=7, marker='*', color=cmap(j / test.results[0].n_sessions))
             all_pmfs.append((defined_points, pmfs))
         else:
             temp = np.percentile(pmfs, [2.5, 97.5], axis=0)
-            ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmfs[:, defined_points].mean(axis=0) - 0.5 + state_mapping[state], color=state_color)
-            ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmfs[:, defined_points].mean(axis=0) - 0.5 + state_mapping[state], ls='', ms=7, marker='*', color=state_color)
-            ax2.fill_between(np.where(defined_points)[0] / (len(defined_points)-1), temp[1, defined_points] - 0.5 + state_mapping[state], temp[0, defined_points] - 0.5 + state_mapping[state], alpha=0.2, color=state_color)
+            ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmfs[:, defined_points].mean(axis=0) - 0.5 + test.state_mapping[state], color=state_color)
+            ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), pmfs[:, defined_points].mean(axis=0) - 0.5 + test.state_mapping[state], ls='', ms=7, marker='*', color=state_color)
+            ax2.fill_between(np.where(defined_points)[0] / (len(defined_points)-1), temp[1, defined_points] - 0.5 + test.state_mapping[state], temp[0, defined_points] - 0.5 + test.state_mapping[state], alpha=0.2, color=state_color)
             all_pmfs.append((defined_points, pmfs[:, defined_points].mean(axis=0)))
 
         if test.results[0].name.startswith('GLM_Sim_'):
             sim_pmf = weights_to_pmf(truth['weights'][state])
-            ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), sim_pmf[defined_points] - 0.5 + truth['state_map'][state_mapping[state]], color='r')
+            ax2.plot(np.where(defined_points)[0] / (len(defined_points)-1), sim_pmf[defined_points] - 0.5 + truth['state_map'][test.state_mapping[state]], color='r')
 
-        ax2.axhline(state_mapping[state] + 0.5, c='k')
-        ax2.axhline(state_mapping[state], c='grey', alpha=alpha_level, zorder=4)
-        ax1.axhline(state_mapping[state] + 0.5, c='grey', alpha=alpha_level, zorder=4)
+        ax2.axhline(test.state_mapping[state] + 0.5, c='k')
+        ax2.axhline(test.state_mapping[state], c='grey', alpha=alpha_level, zorder=4)
+        ax1.axhline(test.state_mapping[state] + 0.5, c='grey', alpha=alpha_level, zorder=4)
 
     if not test.results[0].name.startswith('Sim_'):
         perf = np.zeros(test.results[0].n_sessions)
@@ -980,6 +1059,14 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
         ax0.fill_between(range(1, 1 + test.results[0].n_sessions), perf - 0.5, -0.5, color='k')
 
     durs, state_types = state_type_durs(states_by_session, all_pmfs)
+
+    # how many states per session per type
+    states_per_sess = np.sum(states_by_session > 0.05, axis=0)
+    if durs[0] > 0 and durs[1] > 0 and durs[2] > 1:
+        states_per_type = [np.mean(states_per_sess[:durs[0]]), np.mean(states_per_sess[durs[0]:durs[0]+durs[1]]), np.mean(states_per_sess[durs[0]+durs[1]:])]
+    else:
+        states_per_type = []
+    # other statistics
     dur_counter = 1
     contrast_intro_types = [0, 0, 0, 0]
     state, when = np.where(states_by_session > 0.05)
@@ -987,7 +1074,7 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     covered_states = []
     for i, d in enumerate(durs):
         if type_coloring:
-            ax0.fill_between(range(dur_counter, 1 + dur_counter + d), 0.5, -0.5, color=type_colours[i], zorder=0, alpha=0.3)
+            ax0.fill_between(range(dur_counter, 1 + dur_counter + d), 0.5, -0.5, color=type2color[i], zorder=0, alpha=0.3)
         dur_counter += d
 
         # find out during which state type which contrast was introduced
@@ -1041,13 +1128,86 @@ def state_development(test, state_sets, indices, save=True, save_append='', show
     plt.tight_layout()
     if save:
         print("saving with {} dpi".format(dpi))
-        plt.savefig("dynamic_GLM_figures/meta state development_{}_{}{}.png".format(test.results[0].name, separate_pmf, save_append), dpi=dpi)
+        plt.savefig("dynamic_GLM_figures/meta_state_development_{}_{}{}.png".format(test.results[0].name, separate_pmf, save_append), dpi=dpi)
     if show:
         plt.show()
     else:
         plt.close()
 
-    return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage
+    return states_by_session, all_pmfs, durs, state_types, contrast_intro_types, smart_divide(introductions_by_stage, np.array(durs)), introductions_by_stage, states_per_type
+
+def compare_pmfs(test, state_sets, indices, states2compare, states_by_session, all_pmfs, title=""):
+    """
+       Take a set of states, and plot out their PMFs on all sessions on which they occur.
+       See how different they really are.
+
+       Takes states_by_session and all_pmfs as input from state_development
+    """
+    colors = ['blue', 'orange', 'green', 'black', 'red']
+    assert len(states2compare) <= len(colors)
+    # subtract 1 to get internal numbering
+    states2compare = [s - 1 for s in states2compare]
+    # transform desired states into the actual numbering, before ordering by bias
+    states2compare = [key for key in test.state_mapping.keys() if test.state_mapping[key] in states2compare]
+
+    sessions = np.where(states_by_session[states2compare].sum(0))[0]
+
+    for i, state in enumerate(states2compare):
+        counter = 0
+        for j, session in enumerate(sessions):
+            plt.subplot(1, len(sessions), j + 1)
+            if i == 0:
+                plt.title(session)
+            if states_by_session[state, session] > 0:
+                plt.plot(np.where(all_pmfs[state][0])[0], (all_pmfs[state][1][counter])[all_pmfs[state][0]], c=colors[i])
+                counter += 1
+            plt.ylim(0, 1)
+            if j != 0:
+                plt.gca().set_yticks([])
+    # plt.tight_layout()
+    if title != "":
+        plt.savefig(title)
+    plt.show()
+
+
+def compare_weights(test, state_sets, indices, states2compare, states_by_session, title=""):
+    """
+       Take a set of states, and plot out their weights on all sessions on which they occur.
+       See how different they really are.
+       Similar to compare_pmfs
+
+       Takes states_by_session as input from state_development
+    """
+    colors = ['blue', 'orange', 'green', 'black', 'red']
+    assert len(states2compare) <= len(colors)
+    # subtract 1 to get internal numbering
+    states2compare = [s - 1 for s in states2compare]
+    # transform desired states into the actual numbering, before ordering by bias
+    states2compare = [key for key in test.state_mapping.keys() if test.state_mapping[key] in states2compare]
+
+    sessions = np.where(states_by_session[states2compare].sum(0))[0]
+
+    state_counter = -1
+    for state, trials in enumerate(state_sets):
+        if state not in states2compare:
+            continue
+        state_counter += 1
+        _, weights = state_weights(test, trials, indices)
+        counter = 0
+        for j, session in enumerate(sessions):
+            plt.subplot(1, len(sessions), j + 1)
+            if state == 0:
+                plt.title(session)
+            if states_by_session[state, session] > 0:
+                plt.plot(weights[counter], c=colors[state_counter])
+                counter += 1
+            plt.ylim(-6, 6)
+            if j != 0:
+                plt.gca().set_yticks([])
+    # plt.tight_layout()
+    if title != "":
+        plt.savefig(title)
+    plt.show()
 
 
 def smart_divide(a, b):
@@ -1097,11 +1257,13 @@ if __name__ == "__main__":
     fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
     if fit_type == 'bias':
         loading_info = json.load(open("canonical_infos_bias.json", 'r'))
+        r_hats = json.load(open("canonical_info_r_hats_bias.json", 'r'))
     elif fit_type == 'prebias':
         loading_info = json.load(open("canonical_infos.json", 'r'))
+        r_hats = json.load(open("canonical_info_r_hats.json", 'r'))
     subjects = list(loading_info.keys())
 
-    r_hats = []
+    r_hats = {}
 
     # R^hat tests
     # test = MCMC_result_list([fake_result(100) for i in range(8)])
@@ -1112,10 +1274,9 @@ if __name__ == "__main__":
     check_r_hats = False
     if check_r_hats:
         subjects = list(loading_info.keys())
-        subjects = ['KS014']
         for subject in subjects:
-            # if subject.startswith('GLM'):
-            #     continue
+            if subject.startswith('GLM_Sim_07') or subject.startswith('GLM_Sim_11'):
+                continue
             print("_________________________________")
             print(subject)
             fit_num = loading_info[subject]['fit_nums'][-1]
@@ -1132,7 +1293,7 @@ if __name__ == "__main__":
 
             # mins = find_good_chains(chains[:, :-1].reshape(32, chains.shape[1] // 2))
             sol, final_r_hat = find_good_chains_unsplit_greedy(chains1, chains2, chains3, chains4, reduce_to=chains1.shape[0] // 2)
-            r_hats.append((subject, final_r_hat))
+            r_hats[subject] = final_r_hat
             loading_info[subject]['ignore'] = sol
 
         print(r_hats)
@@ -1233,171 +1394,88 @@ def plot_pmf_types(pmf_types, subject, fit_type, save=True, show=False):
         plt.close()
 
 
-def pmf_type(pmf):
-    if pmf[-1] - pmf[0] < 0.2:
-        return 0
-    elif pmf[-1] - pmf[0] < 0.6:# and np.abs(pmf[0] + pmf[-1] - 1) > 0.1:
-        return 1
-    else:
-        return 2
-
-
-type2color = {0: 'green', 1: 'blue', 2: 'red'}
+def two_sample_binom_test(s1, s2):
+    p1, p2 = np.mean(s1), np.mean(s2)
+    n1, n2 = s1.size, s2.size
+    p = (n1 * p1 + n2 * p2) / (n1 + n2)
+    if p == 1. or p == 0.:
+        return 0., 0.99
+    z = (p1 - p2) / np.sqrt(p * (1 - p) * (1 / n1 + 1 / n2))
+    p_val = (1 - norm.cdf(np.abs(z))) * 2
+    return z, p_val
+
+
+def compare_performance(cnas, contrasts=(1, 1), title=""):
+    # compare the performance on a certain contrast
+    # cnas is set of data for all sessions (from contrasts_plot)
+    # contrasts is a tuple specifying the contrast of interest (first or second column, contrast strength)
+    plt.figure(figsize=(16*0.75, 9*0.75))
+    for i in range(len(cnas) - 1):
+        if cnas[i][cnas[i][:, contrasts[0]] == contrasts[1], -1].shape[0] == 0 or cnas[i+1][cnas[i+1][:, contrasts[0]] == contrasts[1], -1].shape[0] == 0:
+            continue
+        perf1, perf2 = np.mean(cnas[i][cnas[i][:, contrasts[0]] == contrasts[1], -1]), np.mean(cnas[i+1][cnas[i+1][:, contrasts[0]] == contrasts[1], -1])
+        p = two_sample_binom_test(cnas[i][cnas[i][:, contrasts[0]] == contrasts[1], -1], cnas[i+1][cnas[i+1][:, contrasts[0]] == contrasts[1], -1])[1]
+        factor = (perf1 > perf2) * 2 - 1
+        plt.annotate(xy=(i + 0.5, (perf1 + perf2) / 2 + factor * 0.025), text=("{:.3f}".format(p))[1:])
+        color = 'red' if p < 0.05 else 'blue'
+        plt.plot([i, i+1], [perf1, perf2], color=color)
+    sns.despine()
+    cont_string = "Contrast {} ".format('right' if contrasts[0] else 'left')
+    plt.title(cont_string + str(contrasts[1]))
+    plt.ylim(-0.07, 1.07)
+    plt.tight_layout()
+    if title != "":
+        plt.savefig(title)
+    plt.show()
 
-if False:
 
-    all_changing_pmfs = pickle.load(open("changing_pmfs.p", 'rb'))
-    plt.figure(figsize=(16, 9))
-    for i, pmf in enumerate(all_changing_pmfs):
-        plt.subplot(4, 7, i + 1)
-        for p in pmf[1]:
-            plt.plot(np.where(pmf[0])[0], p[pmf[0]], color=type2color[pmf_type(p)])
-        plt.ylim(0, 1)
+def type_hist(data):
+    highest = int(data.max())
+    if (data % 1 == 0).all():
+        bins = np.arange(highest + 2) - 0.5
+    else:
+        bins = np.histogram(data)[1]
+    hist_max = 0
+    for i in range(data.shape[1]):
+        hist_max = max(hist_max, np.histogram(data[:, i], bins)[0].max())
+
+    plt.subplot(3, 1, 1)
+    assert np.histogram(data[:, 0])[0].sum() == np.histogram(data[:, 0], bins)[0].sum()
+    plt.hist(data[:, 0], alpha=1/3, label="type 1", align='mid', bins=bins)
+    plt.xlim(-0.5, highest + 1)
+    plt.ylim(0, hist_max + 1)
+
+    plt.subplot(3, 1, 2)
+    assert np.histogram(data[:, 1])[0].sum() == np.histogram(data[:, 1], bins)[0].sum()
+    plt.hist(data[:, 1], alpha=1/3, label="type 2", align='mid', bins=bins)
+    plt.xlim(-0.5, highest + 1)
+    plt.ylim(0, hist_max + 1)
+
+    plt.subplot(3, 1, 3)
+    assert np.histogram(data[:, 2])[0].sum() == np.histogram(data[:, 2], bins)[0].sum()
+    plt.hist(data[:, 2], alpha=1/3, label="type 3", align='mid', bins=bins)
+    plt.xlim(-0.5, highest + 1)
+    plt.ylim(0, hist_max + 1)
+
+    # plt.legend()
+    plt.show()
 
-        sns.despine()
-        if i+1 != 22:
-            plt.gca().set_xticks([])
-            plt.gca().set_yticks([])
-        else:
-            plt.xlabel("Contrasts", size=22)
-            plt.ylabel("P(rightwards)", size=22)
-            plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=16)
-            plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=16)
 
-        if i + 1 == 30:
-            break
+if 0:
+    all_intros = pickle.load(open("all_intros.p", 'rb'))
+    all_intros_div = pickle.load(open("all_intros_div.p", 'rb'))
+    all_states_per_type = pickle.load(open("all_states_per_type.p", 'rb'))
 
-    plt.tight_layout()
-    plt.savefig("changing pmfs")
-    plt.show()
+    # There are 5 mice with 0 type 2 intros, but only 3 mice with no type 2 stats.
+    # That is because they come up, but don't explain the necessary 50% to start phase 2.
+    type_hist(all_intros)
+    type_hist(all_intros_div)
+    type_hist(all_states_per_type)
     quit()
 
-    type_2_assyms = []
-    tick_size = 14
-    label_size = 26
-    all_first_pmfs = pickle.load(open("pmfs_temp.p", 'rb'))
-    plt.figure(figsize=(16, 9))
-    plt.subplot(1, 3, 1)
-    counter = [[0, 0], [0, 0]]
-    save_title = "all types" if False else "KS014 types"
-    if save_title == "KS014 types":
-        all_first_pmfs = {'KS014': all_first_pmfs['KS014']}
-
-    for key in all_first_pmfs:
-        x = all_first_pmfs[key]
-        if type(x[0]) == int:
-            continue
-        linestyle = '-' if x[2] == 0 else '--'
-        plt.plot(np.where(x[1])[0], x[0][x[1]], linestyle=linestyle, c='g')
-    plt.ylim(0, 1)
-    plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size)
-    plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
-    plt.gca().spines[['right', 'top']].set_visible(False)
-    plt.xlim(0, 10)
-    plt.xticks(rotation=45)
-    plt.gca().set_ylabel("P(rightwards)", size=label_size)
-
-    plt.subplot(1, 3, 2)
-    for key in all_first_pmfs:
-        x = all_first_pmfs[key]
-        if type(x[3]) == int:
-            continue
-        type_2_assyms.append(np.abs(x[3][0] + x[3][-1] - 1))
-        linestyle = '-' if x[5] == 0 else '--'
-        counter[0][0 if x[5] == 0 else 1] += 1
-        if linestyle == '--':
-            continue
-        plt.plot(np.where(x[4])[0], x[3][x[4]], linestyle=linestyle, c='b')
-    plt.gca().set_yticks([])
-    plt.ylim(0, 1)
-    plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size)
-    plt.gca().spines[['right', 'top']].set_visible(False)
-    plt.xticks(rotation=45)
-    plt.xlim(0, 10)
-    plt.gca().set_xlabel("Contrasts", size=label_size)
-
-    plt.subplot(1, 3, 3)
-    for key in all_first_pmfs:
-        x = all_first_pmfs[key]
-        if type(x[6]) == int:
-            continue
-        linestyle = '-' if x[8] == 0 else '--'
-        counter[1][0 if x[8] == 0 else 1] += 1
-        if linestyle == '--':
-            continue
-        plt.plot(np.where(x[7])[0], x[6][x[7]], linestyle=linestyle, c='r')
-    plt.gca().set_yticks([])
-    plt.ylim(0, 1)
-    plt.gca().set_xticks(np.arange(11), all_conts, size=tick_size)
-    plt.gca().spines[['right', 'top']].set_visible(False)
-    plt.xlim(0, 10)
-    plt.xticks(rotation=45)
-
-    print(counter)
-    plt.tight_layout()
-    plt.savefig(save_title)
-    plt.show()
-    if save_title == "KS014 types":
-        quit()
-
-    counter = 0
-    fig, ax = plt.subplots(1, 3, figsize=(16, 9))
-    for key in all_first_pmfs:
-        x = all_first_pmfs[key]
-        if type(x[3]) == int:
-            continue
-        linestyle = '-' if x[5] == 0 else '--'
-        if linestyle == '--':
-            continue
-        if np.abs(x[3][0] + x[3][-1] - 1) <= 0.1:
-            counter += 1
-            use_ax = 2
-        else:
-            use_ax = int(x[3][0] > 1 - x[3][-1])
-
-        ax[use_ax].plot(np.where(x[4])[0], x[3][x[4]], linestyle=linestyle, c='b')
-    ax[0].set_ylim(0, 1)
-    ax[0].set_xlim(0, 10)
-    ax[0].spines[['right', 'top']].set_visible(False)
-    ax[0].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
-    ax[0].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
-    ax[0].set_ylabel("P(rightwards)", size=label_size)
-
-    ax[1].set_ylim(0, 1)
-    ax[1].set_xlim(0, 10)
-    ax[1].set_yticks([])
-    ax[1].spines[['right', 'top']].set_visible(False)
-    ax[1].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
-    ax[1].set_xlabel("Contrasts", size=label_size)
-
-    ax[2].set_ylim(0, 1)
-    ax[2].set_xlim(0, 10)
-    ax[2].set_yticks([])
-    ax[2].spines[['right', 'top']].set_visible(False)
-    ax[2].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
-    print(counter)
-    plt.tight_layout()
-    plt.savefig("differentiate type 2")
-    plt.show()
-    quit()
 
 if __name__ == "__main__":
 
-    # visualise pmf types
-    # lapses = [0.1, 0.2, 0.25, 0.33, 0.4, 0.45, 0.5, 0.55, 0.66, 0.9]
-    # test_pmf = np.zeros(4)
-    # for i, lapse_l in enumerate(lapses):
-    #     plt.subplot(1, 10, 1+i)
-    #     if i != 0:
-    #         plt.gca().set_yticklabels([])
-    #     plt.ylim(0, 1)
-    #     test_pmf[:2] = lapse_l
-    #     for lapse_r in np.linspace(0.02, 0.98, 33):
-    #         test_pmf[2:] = lapse_r
-    #         plt.plot([0, 1, 9, 10], test_pmf, c=type_colours[pmf_type(test_pmf)])
-    # plt.show()
-    # quit()
-
     fit_type = ['prebias', 'bias', 'all', 'prebias_plus', 'zoe_style'][0]
     if fit_type == 'bias':
         loading_info = json.load(open("canonical_infos_bias.json", 'r'))
@@ -1405,8 +1483,11 @@ if __name__ == "__main__":
     elif fit_type == 'prebias':
         loading_info = json.load(open("canonical_infos.json", 'r'))
         r_hats = json.load(open("canonical_info_r_hats.json", 'r'))
-    no_good_pcas = ['NYU-06', 'SWC_023']  # no good rhat: 'ibl_witten_13'
-    subjects = ['KS014']  # list(loading_info.keys())
+    no_good_pcas = ['NYU-06', 'SWC_023']
+    subjects = list(loading_info.keys())
+    # subjects = ['ZM_1897']
+
+    # meh pmfs: KS021
     print(subjects)
     fit_variance = [0.03, 0.002, 0.0005, 'uniform', 0, 0.008][0]
     dur = 'yes'
@@ -1433,11 +1514,16 @@ if __name__ == "__main__":
 
     abs_state_durs = []
     all_first_pmfs = {}
+    all_first_pmfs_typeless = {}
     all_pmf_diffs = []
     all_pmf_asymms = []
     all_pmfs = []
     all_changing_pmfs = []
+    all_changing_pmf_names = []
     all_intros = []
+    all_intros_div = []
+    all_states_per_type = []
+    regressions = []
 
     for subject in subjects:
 
@@ -1445,48 +1531,66 @@ if __name__ == "__main__":
             continue
 
         print(subject)
-        results = []
 
         try:
+            continue
             test = pickle.load(open("multi_chain_saves/canonical_result_{}_{}.p".format(subject, fit_type), 'rb'))
             print('loaded canoncial result')
 
+            mode_specifier = ''
             mode_indices = pickle.load(open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'rb'))
-            quit()
             state_sets = pickle.load(open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'rb'))
+
             # lapse differential
             # lapse_sides(test, [s for s in state_sets if len(s) > 40], mode_indices)
 
             # training overview
-            states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True)
+            states, pmfs, durs, _, contrast_intro_type, intros_by_type, undiv_intros, states_per_type = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, show=0, separate_pmf=1, type_coloring=True)
+            regression = pmf_regressions(states, pmfs, durs)
+            regressions.append(regression)
+            continue
+            # compare_pmfs(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, pmfs, title="{} convergence pmf".format(subject))
+            # compare_weights(test, [s for s in state_sets if len(s) > 40], mode_indices, [4, 5], states, title="{} convergence weights".format(subject))
+            # quit()
+            all_first_pmfs_typeless[subject] = []
+            for pmf in pmfs:
+                all_first_pmfs_typeless[subject].append((pmf[0], pmf[1][0]))
             all_intros.append(undiv_intros)
+            all_intros_div.append(intros_by_type)
+            if states_per_type != []:
+                all_states_per_type.append(states_per_type)
+
             intros_by_type_sum += intros_by_type
             first_pmfs, changing_pmfs = get_first_pmfs(states, pmfs)
             for pmf in changing_pmfs:
                 if type(pmf[0]) == int:
                     continue
+                all_changing_pmf_names.append(subject)
                 all_changing_pmfs.append(pmf)
             all_first_pmfs[subject] = first_pmfs
             for pmf in pmfs:
+                all_pmfs.append(pmf)
                 for p in pmf[1]:
                     all_pmf_diffs.append(p[-1] - p[0])
                     all_pmf_asymms.append(np.abs(p[0] + p[-1] - 1))
-                    all_pmfs.append(p)
             contrast_intro_types.append(contrast_intro_type)
+            continue
             # state_development_single_sample(test, [mode_indices[0]], show=True, separate_pmf=True, save=False)
 
             # session overview
-            consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
-            consistencies /= consistencies[0, 0]
-            contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=True, consistencies=consistencies)
+            # consistencies = pickle.load(open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'rb'))
+            # consistencies /= consistencies[0, 0]
+            # c_n_a, rt_data = contrasts_plot(test, [s for s in state_sets if len(s) > 40], dpi=300, subject=subject, save=True, show=False, consistencies=consistencies, CMF=True)
+            # compare_performance(cnas, (1, 1), title="{} contrast {} performance test".format(subject, (1, 1)))
+            # compare_performance(cnas, (0, 0.848))
 
             # duration of different state types (and also percentage of type activities)
-            abs_state_durs.append(durs)
-            simplex_durs = np.array(durs).reshape(1, 3)
-            print(simplex_durs / np.sum(simplex_durs))
-            from simplex_plot import projectSimplex
-            print(projectSimplex(simplex_durs / simplex_durs.sum(1)[:, None]))
-            continue
+            # abs_state_durs.append(durs)
+            # continue
+            # simplex_durs = np.array(durs).reshape(1, 3)
+            # print(simplex_durs / np.sum(simplex_durs))
+            # from simplex_plot import projectSimplex
+            # print(projectSimplex(simplex_durs / simplex_durs.sum(1)[:, None]))
 
             # compute state type proportions and split the pmfs accordingly
             # ret, trans, pmf_types = state_cluster_interpolation(states, pmfs)
@@ -1552,47 +1656,46 @@ if __name__ == "__main__":
             # plt.savefig("temp")
             # plt.close()
 
-            quit()
-
-            dim = 3
-            ev, eig, projection_matrix, dimreduc = test.state_pca(subject, pca_vecs='dists', dim=dim)
-            xy = np.vstack([dimreduc[i] for i in range(dim)])
-            from scipy.stats import gaussian_kde
-            z = gaussian_kde(xy)(xy)
-            pickle.dump((xy, z), open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'wb'))
-
-            if 'mode prob level' not in loading_info[subject]:
-                print(subject)
-                xy, z = pickle.load(open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'rb'))
-                happy = False
-                while not happy:
-                    print()
-                    print("Pick level")
-                    prob_level = float(input())
-                    print("Level is {}".format(prob_level))
-                    print("# of samples: {}".format((z > prob_level).sum()))
-                    mode_indices = np.where(z > prob_level)[0]
-                    if (z > prob_level).sum() > 0:
-                        print(xy[0][mode_indices].min(), xy[0][mode_indices].max(), xy[1][mode_indices].min(), xy[1][mode_indices].max())
-                        print("Happy?")
-                        happy = 'yes' == input()
-                print("Subset by factor?")
-                if input() == 'yes':
-                    print("Factor?")
-                    print(mode_indices.shape)
-                    factor = int(input())
-                    mode_indices = mode_indices[::factor]
-                    print(mode_indices.shape)
-                loading_info[subject]['mode prob level'] = prob_level
-
-                pickle.dump(mode_indices, open("multi_chain_saves/mode_indices_{}_{}.p".format(subject, fit_type), 'wb'))
-                consistencies = test.consistency_rsa(indices=mode_indices)
-                pickle.dump(consistencies, open("multi_chain_saves/consistencies_{}_{}.p".format(subject, fit_type), 'wb'))
-
-            string_prefix = ''
+            # dim = 3
+            # ev, eig, projection_matrix, dimreduc = test.state_pca(subject, pca_vecs='dists', dim=dim)
+            # xy = np.vstack([dimreduc[i] for i in range(dim)])
+            # from scipy.stats import gaussian_kde
+            # z = gaussian_kde(xy)(xy)
+            # pickle.dump((xy, z), open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'wb'))
+            # #
+            # string_prefix = 'second_'
+            #
+            # print(subject)
+            # xy, z = pickle.load(open("multi_chain_saves/xyz_{}_{}.p".format(subject, fit_type), 'rb'))
+            # quit()
+            # happy = False
+            # while not happy:
+            #     print()
+            #     print("Pick level")
+            #     prob_level = float(input())
+            #     print("Level is {}".format(prob_level))
+            #     print("# of samples: {}".format((z > prob_level).sum()))
+            #     mode_indices = np.where(z > prob_level)[0]
+            #     if (z > prob_level).sum() > 0:
+            #         print(xy[0][mode_indices].min(), xy[0][mode_indices].max(), xy[1][mode_indices].min(), xy[1][mode_indices].max())
+            #         print("Happy?")
+            #         happy = 'yes' == input()
+            # print("Subset by factor?")
+            # if input() == 'yes':
+            #     print("Factor?")
+            #     print(mode_indices.shape)
+            #     factor = int(input())
+            #     mode_indices = mode_indices[::factor]
+            #     print(mode_indices.shape)
+            # loading_info[subject]['mode prob level'] = prob_level
+            #
+            # pickle.dump(mode_indices, open("multi_chain_saves/{}mode_indices_{}_{}.p".format(string_prefix, subject, fit_type), 'wb'))
+            # consistencies = test.consistency_rsa(indices=mode_indices)
+            # pickle.dump(consistencies, open("multi_chain_saves/{}consistencies_{}_{}.p".format(string_prefix, subject, fit_type), 'wb'))
 
             mode_indices = pickle.load(open("multi_chain_saves/{}mode_indices_{}_{}.p".format(string_prefix, subject, fit_type), 'rb'))
-            consistencies = pickle.load(open("multi_chain_saves/{}consistencies_{}_{}.p".format(string_prefix, subject, fit_type), 'rb'))
+            consistencies = pickle.load(open("multi_chain_saves/{}mode_consistencies_{}_{}.p".format(string_prefix, subject, fit_type), 'rb'))
+            session_bounds = list(np.cumsum([len(s) for s in test.results[0].models[-1].stateseqs]))
 
             import scipy.cluster.hierarchy as hc
             consistencies /= consistencies[0, 0]
@@ -1620,10 +1723,9 @@ if __name__ == "__main__":
                     for x, y in zip(b, c):
                         state_sets.append(np.where(a == x)[0])
                     print("dumping state set")
-                    pickle.dump(state_sets, open("multi_chain_saves/state_sets_{}_{}.p".format(subject, fit_type), 'wb'))
-                    quit()
+                    pickle.dump(state_sets, open("multi_chain_saves/{}state_sets_{}_{}.p".format(string_prefix, subject, fit_type), 'wb'))
                     states, pmfs = state_development(test, [s for s in state_sets if len(s) > 40], mode_indices, save_append='_{}{}'.format(string_prefix, criterion), show=True, separate_pmf=True)
-                    contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save_append='_{}{}'.format(string_prefix, criterion), save=True, show=True)
+                    # contrasts_plot(test, [s for s in state_sets if len(s) > 40], subject=subject, save_append='_{}{}'.format(string_prefix, criterion), save=True, show=True)
 
                     # I think this is about finding out where states start and how long they last
                     # state_id, session_appears = np.where(states)
@@ -1660,18 +1762,14 @@ if __name__ == "__main__":
 
             plt.tight_layout()
             plt.savefig("peter figures/{}clustered_trials_{}_{}".format(string_prefix, subject, 'criteria comp').replace('.', '_'))
-            plt.show()
+            plt.close()
 
         except FileNotFoundError as e:
-            continue
             print(e)
-            r_hat = 1.5
-            for r in r_hats:
-                if r[0] == subject:
-                    r_hat = r[1]
+            continue
             print('no canoncial result')
-            print(r_hat)
-            if r_hat >= 1.05:
+            print(r_hats[subject])
+            if r_hats[subject] >= 1.05:
                 print("Skipping")
                 continue
             else:
@@ -1729,6 +1827,13 @@ if __name__ == "__main__":
 
     # pickle.dump(all_first_pmfs, open("pmfs_temp.p", 'wb'))
     # pickle.dump(all_changing_pmfs, open("changing_pmfs.p", 'wb'))
+    # pickle.dump(all_changing_pmf_names, open("changing_pmf_names.p", 'wb'))
+    # pickle.dump(all_first_pmfs_typeless, open("all_first_pmfs_typeless.p", 'wb'))
+    # pickle.dump(all_pmfs, open("all_pmfs.p", 'wb'))
+    # pickle.dump(all_intros, open("all_intros.p", 'wb'))
+    # pickle.dump(all_intros_div, open("all_intros_div.p", 'wb'))
+    # pickle.dump(all_states_per_type, open("all_states_per_type.p", 'wb'))
+    # pickle.dump(regressions, open("regressions.p", 'wb'))
     #
     # a = [x for x, y in zip(all_pmf_asymms, all_pmf_diffs) if y >= 0.2]
     # b = [y for x, y in zip(all_pmf_asymms, all_pmf_diffs) if y >= 0.2]
@@ -1758,6 +1863,16 @@ if __name__ == "__main__":
 
         plotSimplex(np.array(abs_state_durs), c='k', show=True)
 
+        plt.hist(abs_state_durs.sum(1), color='grey', bins=12)
+        sns.despine()
+        plt.xticks(size=26)
+        plt.yticks(size=26)
+        plt.ylabel("# of mice", size=40)
+        plt.xlabel('# of sessions', size=40)
+        plt.tight_layout()
+        plt.savefig("session_num_hist.png", dpi=300, transparent=True)
+        plt.show()
+
     if False:
         ax[0].set_ylim(0, 1)
         ax[1].set_ylim(0, 1)
diff --git a/index_mice.py b/index_mice.py
index 59af129239564a66b3a9965c4f03454c26f8e81a..ebb58b840677ea1ffc0cdd34584d08c4403b3a75 100644
--- a/index_mice.py
+++ b/index_mice.py
@@ -34,7 +34,7 @@ for filename in os.listdir("./dynamic_GLMiHMM_crossvals/"):
         local_dict = bias_subinfo
     if subject not in local_dict:
         local_dict[subject] = {"seeds": [], "fit_nums": [], "chain_num": 0}
-    if int(chain_num) == 0:
+    if int(chain_num) == 0:  # if this is the first file of that chain, save some info
         local_dict[subject]["seeds"].append(seed)
         local_dict[subject]["fit_nums"].append(fit_num)
     else:
@@ -62,7 +62,6 @@ for s in prebias_subinfo.keys():
         prebias_subinfo[s]["seeds"] = new_seeds
 
 
-
 big = []
 non_big = []
 sim_subjects = []
diff --git a/pmf_analysis.py b/pmf_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e7a0338c09ff15877860996f52d72999c16cd15
--- /dev/null
+++ b/pmf_analysis.py
@@ -0,0 +1,343 @@
+import pickle
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+type2color = {0: 'green', 1: 'blue', 2: 'red'}
+all_conts = np.array([-1, -0.5, -.25, -.125, -.062, 0, .062, .125, .25, 0.5, 1])
+
+
+def pmf_type(pmf):
+    if pmf[-1] - pmf[0] < 0.2:
+        return 0
+    # elif pmf[-1] - pmf[0] < 0.4:# and np.abs(pmf[0] + pmf[-1] - 1) > 0.1:
+    elif max(pmf[0], 1 - pmf[-1]) > 0.37 or pmf[-1] - pmf[0] < 0.55:
+        if pmf[0] > 0.37 and 1 - pmf[-1] > 0.37:
+            print("___________Troublesome PMF: {}___________".format(pmf))
+            return 0
+        return 1
+    else:
+        return 2
+
+
+if __name__ == "__main__":
+    all_first_pmfs_typeless = pickle.load(open("all_first_pmfs_typeless", 'rb'))
+    all_pmfs = pickle.load(open("all_pmfs.p", 'rb'))
+
+    fewer_states_side = []
+    for key in all_first_pmfs_typeless:
+        animal_biases = np.zeros(2)
+        for defined_points, pmf in all_first_pmfs_typeless[key]:
+            bias = np.mean(pmf[defined_points])
+            if bias > 0.55:
+                animal_biases[0] += 1
+            elif bias < 0.45:
+                animal_biases[1] += 1
+        fewer_states_side.append(np.min(animal_biases / animal_biases.sum()))
+    plt.hist(fewer_states_side)
+    sns.despine()
+    plt.tight_layout()
+    plt.savefig("./meeting_figures/proportion_other_bias")
+    plt.show()
+
+    total_counter = 0
+    bias_counter = 0
+    tendency_counter = 0
+    lapse_counter = 0
+    for pmf in all_pmfs:
+        max_b, min_b = 0, 1
+        max_tendency, min_tendency = 0, 1
+        max_lapse_diff, min_lapse_diff = 0, 1
+        for p in pmf[1]:
+            if pmf[0][5]:  # if this part of the pmf is defined, just take it
+                bias = p[5]
+            deviation = 0
+            while True:  # just take the closest thing
+                deviation += 1
+                if pmf[0][5 - deviation] and pmf[0][5 + deviation]:
+                    bias = (p[5 - deviation] + p[5 + deviation]) / 2
+                    break
+            max_b = max(max_b, bias)
+            min_b = min(min_b, bias)
+            max_tendency = max(max_tendency, np.mean(p[pmf[0]]))
+            min_tendency = min(min_tendency, np.mean(p[pmf[0]]))
+            max_lapse_diff = max(max_lapse_diff, p[0] + p[-1] - 1)
+            min_lapse_diff = min(min_lapse_diff, p[0] + p[-1] - 1)
+        bias_changed = max_b > 0.55 and min_b < 0.45
+        tendency_changed = max_tendency > 0.55 and min_tendency < 0.45
+        lapse_changed = max_lapse_diff > 0.1 and min_lapse_diff < -0.1
+        bias_counter += bias_changed
+        tendency_counter += tendency_changed
+        lapse_counter += lapse_changed
+        if bias_changed or tendency_changed or lapse_changed:
+            total_counter += 1
+            for p in pmf[1]:
+                plt.plot(np.where(pmf[0])[0], p[pmf[0]])
+            plt.title("bias: {}, tendency: {}, lapse: {}".format(bias_changed, tendency_changed, lapse_changed))
+            plt.ylim(0, 1)
+            plt.axvline(5, color='grey')
+            plt.axhline(0.5, color='grey')
+            plt.savefig("./meeting_figures/bias_change_{}".format(total_counter))
+            plt.close()
+    print(total_counter)
+    print(bias_counter)
+    print(tendency_counter)
+    print(lapse_counter)
+
+    pmf_ranges = []
+    for key in all_first_pmfs_typeless:
+        for defined_points, pmf in all_first_pmfs_typeless[key]:
+            if pmf_type(pmf) == 2:
+                pmf_ranges.append(pmf[-1] - pmf[0])
+                # if pmf_ranges[-1] < 0.6:
+                #     plt.plot(pmf)
+                #     plt.title(pmf_ranges[-1])
+                #     plt.ylim(0, 1)
+                #     plt.show()
+    plt.hist(pmf_ranges, bins=40)
+    plt.title("Type 2 ranges")
+    plt.show()
+
+    lapses = []
+    for key in all_first_pmfs_typeless:
+        for defined_points, pmf in all_first_pmfs_typeless[key]:
+            if pmf_type(pmf) != 0:
+                lapses.append(max(pmf[0], 1 - pmf[-1]))
+    plt.hist(lapses, bins=40)
+    plt.title("Higher lapse rate of type != 0")
+    plt.show()
+
+    n_rows, n_cols = 5, 6
+    _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9))
+    for i, key in enumerate(all_first_pmfs_typeless):
+        if i == n_rows * 2:
+            break
+        axs[(i * 3) // n_cols, (i * 3) % n_cols].set_ylabel(key, size=17)
+        for defined_points, pmf in all_first_pmfs_typeless[key]:
+            axs[(i * 3 + pmf_type(pmf)) // n_cols, (i * 3 + pmf_type(pmf)) % n_cols].plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)])
+    for i, ax in enumerate(axs):
+        for j, a in enumerate(ax):
+            a.spines[['right', 'top']].set_visible(False)
+            a.set_ylim(0, 1)
+            if i != n_rows - 1 or j != 0:
+                a.set_xticks([])
+                a.set_yticks([])
+            else:
+                tick_size = 12
+                a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+                a.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
+    plt.tight_layout()
+    plt.savefig("animals 1")
+    plt.show()
+
+    n_rows, n_cols = 5, 6
+    _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9))
+    for i, key in enumerate(all_first_pmfs_typeless):
+        if i < n_rows * 2:
+            continue
+        elif i < n_rows * 4:
+            i = i - n_rows * 2
+        else:
+            break
+        if i == n_rows * 2:
+            break
+        axs[(i * 3) // n_cols, (i * 3) % n_cols].set_ylabel(key, size=17)
+        for defined_points, pmf in all_first_pmfs_typeless[key]:
+            axs[(i * 3 + pmf_type(pmf)) // n_cols, (i * 3 + pmf_type(pmf)) % n_cols].plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)])
+    for i, ax in enumerate(axs):
+        for j, a in enumerate(ax):
+            a.spines[['right', 'top']].set_visible(False)
+            a.set_ylim(0, 1)
+            if i != n_rows - 1 or j != 0:
+                a.set_xticks([])
+                a.set_yticks([])
+            else:
+                tick_size = 12
+                a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+                a.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
+    plt.tight_layout()
+    plt.savefig("animals 2")
+    plt.show()
+
+    n_rows, n_cols = 5, 6
+    _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9))
+    for i, key in enumerate(all_first_pmfs_typeless):
+        if i < n_rows * 4:
+            continue
+        elif i < n_rows * 6:
+            i = i - n_rows * 4
+        else:
+            break
+        if i == n_rows * 4:
+            break
+        axs[(i * 3) // n_cols, (i * 3) % n_cols].set_ylabel(key, size=17)
+        for defined_points, pmf in all_first_pmfs_typeless[key]:
+            axs[(i * 3 + pmf_type(pmf)) // n_cols, (i * 3 + pmf_type(pmf)) % n_cols].plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)])
+    for i, ax in enumerate(axs):
+        for j, a in enumerate(ax):
+            a.spines[['right', 'top']].set_visible(False)
+            a.set_ylim(0, 1)
+            if i != n_rows - 1 or j != 0:
+                a.set_xticks([])
+                a.set_yticks([])
+            else:
+                tick_size = 12
+                a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+                a.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
+    plt.tight_layout()
+    plt.savefig("animals 3")
+    plt.show()
+
+    n_rows, n_cols = 5, 6
+    _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9))
+    for i, key in enumerate(all_first_pmfs_typeless):
+        if i < n_rows * 6:
+            continue
+        elif i < n_rows * 8:
+            i = i - n_rows * 6
+        else:
+            break
+        if i == n_rows * 6:
+            break
+        axs[(i * 3) // n_cols, (i * 3) % n_cols].set_ylabel(key, size=17)
+        for defined_points, pmf in all_first_pmfs_typeless[key]:
+            axs[(i * 3 + pmf_type(pmf)) // n_cols, (i * 3 + pmf_type(pmf)) % n_cols].plot(np.where(defined_points)[0], pmf[defined_points], c=type2color[pmf_type(pmf)])
+    for i, ax in enumerate(axs):
+        for j, a in enumerate(ax):
+            a.spines[['right', 'top']].set_visible(False)
+            a.set_ylim(0, 1)
+            if i != n_rows - 1 or j != 0:
+                a.set_xticks([])
+                a.set_yticks([])
+            else:
+                tick_size = 12
+                a.set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+                a.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
+    plt.tight_layout()
+    plt.savefig("animals 4")
+    plt.show()
+
+    # Collection of PMFs which change state type
+    all_changing_pmfs = pickle.load(open("changing_pmfs.p", 'rb'))
+    all_changing_pmf_names = pickle.load(open("changing_pmf_names.p", 'rb'))
+
+    plt.figure(figsize=(16, 9))
+    for i, (pmf, name) in enumerate(zip(all_changing_pmfs, all_changing_pmf_names)):
+        plt.subplot(4, 7, i + 1)
+        plt.title(name)
+        for p in pmf[1]:
+            plt.plot(np.where(pmf[0])[0], p[pmf[0]], color=type2color[pmf_type(p)])
+        plt.ylim(0, 1)
+
+        sns.despine()
+        if i+1 != 22:
+            plt.gca().set_xticks([])
+            plt.gca().set_yticks([])
+        else:
+            plt.xlabel("Contrasts", size=22)
+            plt.ylabel("P(rightwards)", size=22)
+            plt.gca().set_xticks([0, 5, 10], [-1, 0, 1], size=16)
+            plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=16)
+
+        if i + 1 == 30:
+            break
+
+    plt.tight_layout()
+    plt.savefig("changing pmfs")
+    plt.show()
+
+    # All first PMFs
+    tick_size = 14
+    label_size = 26
+    all_first_pmfs = pickle.load(open("pmfs_temp.p", 'rb'))
+    n_rows, n_cols = 1, 3
+    _, axs = plt.subplots(n_rows, n_cols, figsize=(16, 9))
+    counter = [[0, 0], [0, 0]]
+    save_title = "all types" if True else "KS014 types"
+    if save_title == "KS014 types":
+        all_first_pmfs_typeless = {'KS014': all_first_pmfs_typeless['KS014']}
+
+    for key in all_first_pmfs_typeless:
+        x = all_first_pmfs_typeless[key]
+        for pmf in x:
+            axs[pmf_type(pmf[1])].plot(np.where(pmf[0])[0], pmf[1][pmf[0]], c=type2color[pmf_type(pmf[1])])
+    axs[0].set_ylim(0, 1)
+    axs[1].set_ylim(0, 1)
+    axs[2].set_ylim(0, 1)
+    axs[0].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    axs[1].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    axs[2].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    axs[0].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
+    axs[1].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
+    axs[2].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
+    axs[0].spines[['right', 'top']].set_visible(False)
+    axs[1].spines[['right', 'top']].set_visible(False)
+    axs[2].spines[['right', 'top']].set_visible(False)
+    axs[0].set_xlim(0, 10)
+    axs[1].set_xlim(0, 10)
+    axs[2].set_xlim(0, 10)
+    axs[0].set_ylabel("P(rightwards)", size=label_size)
+    axs[0].set_xlabel("Contrasts", size=label_size)
+
+    print(counter)
+    plt.tight_layout()
+    plt.savefig(save_title)
+    plt.show()
+    if save_title == "KS014 types":
+        quit()
+
+    # Type 2 PMFs
+    counter = 0
+    fig, ax = plt.subplots(1, 3, figsize=(16, 9))
+    for key in all_first_pmfs_typeless:
+        for defined_points, pmf in all_first_pmfs_typeless[key]:
+            if pmf_type(pmf) != 1:
+                continue
+            # linestyle = '-' if x[5] == 0 else '--'
+            # if linestyle == '--':
+            #     continue
+            if np.abs(pmf[0] + pmf[-1] - 1) <= 0.1:
+                counter += 1
+                use_ax = 2
+            else:
+                use_ax = int(pmf[0] > 1 - pmf[-1])
+
+            ax[use_ax].plot(np.where(defined_points)[0], pmf[defined_points], c='b')
+    ax[0].set_ylim(0, 1)
+    ax[0].set_xlim(0, 10)
+    ax[0].spines[['right', 'top']].set_visible(False)
+    ax[0].set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1], size=tick_size)
+    ax[0].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    ax[0].set_ylabel("P(rightwards)", size=label_size)
+
+    ax[1].set_ylim(0, 1)
+    ax[1].set_xlim(0, 10)
+    ax[1].set_yticks([])
+    ax[1].spines[['right', 'top']].set_visible(False)
+    ax[1].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    ax[1].set_xlabel("Contrasts", size=label_size)
+
+    ax[2].set_ylim(0, 1)
+    ax[2].set_xlim(0, 10)
+    ax[2].set_yticks([])
+    ax[2].spines[['right', 'top']].set_visible(False)
+    ax[2].set_xticks(np.arange(11), all_conts, size=tick_size, rotation=45)
+    print(counter)
+    plt.tight_layout()
+    plt.savefig("differentiate type 2")
+    plt.show()
+
+
+    # visualise pmf types
+    # lapses = [0.1, 0.2, 0.25, 0.33, 0.4, 0.45, 0.5, 0.55, 0.66, 0.9]
+    # test_pmf = np.zeros(4)
+    # for i, lapse_l in enumerate(lapses):
+    #     plt.subplot(1, 10, 1+i)
+    #     if i != 0:
+    #         plt.gca().set_yticklabels([])
+    #     plt.ylim(0, 1)
+    #     test_pmf[:2] = lapse_l
+    #     for lapse_r in np.linspace(0.02, 0.98, 33):
+    #         test_pmf[2:] = lapse_r
+    #         plt.plot([0, 1, 9, 10], test_pmf, c=type_colours[pmf_type(test_pmf)])
+    # plt.show()
diff --git a/simplex_plot.py b/simplex_plot.py
index dd8ba57d9f3d356cda8e04524404b4833a627a60..d518ca4025c531ac269c63afabe7eec0417be66f 100644
--- a/simplex_plot.py
+++ b/simplex_plot.py
@@ -50,7 +50,7 @@ def plotSimplex(points, fig=None,
 
     P.axis('off')
 
-    P.savefig("dur_simplex.png", bbox_inches='tight')
+    P.savefig("dur_simplex.png", bbox_inches='tight', dpi=300, transparent=True)
     if show:
         P.show()
     else: